from astropy.io import fits
import matplotlib.pyplot as plt
from matplotlib.ticker import NullLocator,MultipleLocator,FormatStrFormatter,FixedLocator, FuncFormatter,AutoMinorLocator,AutoLocator
import matplotlib.patches as patches
from astropy.visualization import make_lupton_rgb
from astropy import wcs
from astropy.coordinates import SkyCoord
from astropy import units as u
import fitsio
import pandas as pd

import numpy as np

# Global matplotlib look: serif font family for all text.
# (LaTeX text rendering is left disabled: plt.rc('text', usetex=True))
plt.rcParams['font.family'] = 'serif'

# One row per target galaxy:
#   (short id used in file names, display name, redshift, data directory,
#    sky position handed to the WCS -- (RA, Dec), presumably in degrees,
#    exposure time -- presumably in seconds)
_targets = [
    ('0249', 'SDSS J0249-0815', 0.029, 'J0249-08/',
     (42.30362, -8.25714), 2028.0),
    ('0228', 'SDSS J0228-0901', 0.072, 'J022849.51-090153.8/',
     (37.20632, -9.03163), 1230.0),
    ('chandra1', 'SDSS J1107+1347', 0.045, '../4star/chandra1/',
     (166.88013, 13.78691), 897.0),
]

# Parallel per-galaxy lists, all in the same order as the rows above.
galns = [row[0] for row in _targets]
galnames = [row[1] for row in _targets]
redshifts = [row[2] for row in _targets]
dirnames = [row[3] for row in _targets]

# Each position is stored as a (1, 2) array, i.e. a one-row coordinate list.
galcors = [np.array([[ra, dec]]) for _, _, _, _, (ra, dec), _ in _targets]

exptimes = [row[5] for row in _targets]

#xraycors = SkyCoord(['15:23:05.0 +11:45:53.18'], unit=(u.hour, u.deg))
#xraycors = [np.array([[xr.ra.deg]]) for xr in xraycors]


def _read_image_hdu(path, ext=1):
    """Read one image extension of a FITS file and close the file again.

    Parameters
    ----------
    path : str
        Location of the FITS file.
    ext : int, optional
        Index of the HDU to read (default 1, the extension this script uses
        for every band).

    Returns
    -------
    tuple
        ``(data, header)`` of the requested HDU.
    """
    # memmap=False: the pixel array is read into memory while the file is
    # open, so nothing keeps the file handle alive once the ``with`` exits.
    with fits.open(path, memmap=False) as hdul:
        hdu = hdul[ext]
        return hdu.data, hdu.header


# Cut-out / display geometry, identical for every galaxy.
winpx = 150         # half-size of the square cut-out around the target [pixels]
offset = [30, 70]   # displayed window spans winpx-offset[0] .. winpx+offset[1]

# Constants used to turn the isophote table into a surface-brightness profile.
scale = 0.05        # multiplies tab['x']; scale**2 is the per-pixel area.
                    # Presumably arcsec/pixel of the image the isophotes were
                    # measured on -- TODO confirm.
zeropnt = 21.665    # photometric zero point entering the magnitude formula

# For every galaxy: build a g/r/i colour cut-out, overlay a 3"x3" field of
# view, a 3" scale bar and the galaxy name, add an inset with the radial
# surface-brightness profile, and save the figure as PNG.
for galn, coords, gn, redshift, exptime, dirname in zip(
        galns, galcors, galnames, redshifts, exptimes, dirnames):
    print("Working on {:}".format(galn))

    # Each band file is opened exactly once and closed right away; the WCS
    # header is taken from the g-band image.
    (im_g, hdr), (im_r, _), (im_i, _) = [
        _read_image_hdu(dirname + "panstrs_{:}_{:}.fits".format(galn, band))
        for band in ['g', 'r', 'i']]

    w = wcs.WCS(hdr)
    # origin=0: the pixel coordinates are used to index numpy arrays and to
    # position artists on imshow axes, both of which are 0-based.
    pixcrd = w.wcs_world2pix(coords, 0)[0]
    # Norm of the first row of the pixel scale matrix, converted from
    # degrees to arcsec per pixel.
    pixscl = np.sqrt(np.sum(w.pixel_scale_matrix[0, :]**2))*3600.0

    pixcrd_int = pixcrd.astype(int)
    regx = [pixcrd_int[0]-winpx, pixcrd_int[0]+winpx]
    regy = [pixcrd_int[1]-winpx, pixcrd_int[1]+winpx]
    if regx[0] < 0 or regy[0] < 0:
        # A negative start index would silently wrap around to the far side
        # of the image and yield an empty or unrelated cut-out.
        raise ValueError(
            "{:}: target is closer than {:d} px to the image edge, "
            "cannot extract the cut-out".format(galn, winpx))

    # Colour mapping: i -> red, r -> green, g -> blue.
    image_b, image_g, image_r = [
        im[regy[0]:regy[1], regx[0]:regx[1]] for im in [im_g, im_r, im_i]]

    image = make_lupton_rgb(image_r, image_g, image_b,
                            minimum=0.01, Q=0, stretch=7)

    fig, ax = plt.subplots(1, figsize=(5, 5))
    # fig.subplots_adjust(left=0,right=1,bottom=0,top=1)
    ax.set_position([0, 0, 1, 1])   # image fills the whole figure
    ax.imshow(image)
    # Increasing y-limits put row 0 at the bottom; the asymmetric window
    # keeps the target towards the lower left, leaving room for the inset.
    ax.set_xlim(winpx-offset[0], winpx+offset[1])
    ax.set_ylim(winpx-offset[0], winpx+offset[1])

    # oplot NIRSpec FoV: closed 3" x 3" square centred on the target.
    half_width = 3.0/2.0/pixscl
    corners = np.array([
        [-half_width, -half_width],
        [-half_width,  half_width],
        [ half_width,  half_width],
        [ half_width, -half_width],
        [-half_width, -half_width]])
    # Shift from full-image to cut-out coordinates using the cut-out's
    # integer origin, so the sub-pixel position of the target is preserved.
    fov = pixcrd - np.array([regx[0], regy[0]]) + corners
    ax.plot(fov[:, 0], fov[:, 1], ls=':', color='yellow', lw=1.3)

    ax.text(0.97, 0.08, gn, color='yellow', fontsize=20,
            transform=ax.transAxes, ha='right')

    # 3-arcsec scale bar in the lower-left corner of the displayed window.
    yyran = ax.get_ylim()[1]-ax.get_ylim()[0]
    xxran = ax.get_xlim()[1]-ax.get_xlim()[0]
    yypos = ax.get_ylim()[0] + 0.1*yyran
    xxpos = ax.get_xlim()[0] + 0.06*xxran
    sclline = np.array([[xxpos, xxpos + 3.0/pixscl], [yypos, yypos]])
    ax.plot(sclline[0], sclline[1], color='yellow', marker='|')

    # Physical size of 3 arcsec from the linear Hubble law d = c*z/H0 [Mpc],
    # converted to kpc and from radians to arcsec (206265 arcsec/rad).
    # NOTE(review): uses H0 = 79 km/s/Mpc -- confirm this is intended.
    arcsec3_to_kpc = 3.0 * redshift * 299792.45 / 79.0 * 1e3 / 206265.0
    ax.text(xxpos + 3.0/pixscl/2.0, yypos + 0.03*yyran, r"3 arcsec",
            color='yellow', fontsize=18, ha='center', va='bottom')
    ax.text(xxpos + 3.0/pixscl/2.0, yypos - 0.03*yyran,
            r"{:.1f} kpc".format(arcsec3_to_kpc),
            color='yellow', fontsize=18, ha='center', va='top')

    # inset plot with surface brightness
    #tab = fits.open(file_ellipse)[1].data
    # Isophote table; the columns 'x', 'data' and 'data_err' are used.
    tab = pd.read_csv(dirname+'isolist.csv')

    x = tab['x'] * scale
    intens = tab['data']
    #eint = 0.5*tab['data_err'] * np.sqrt(tab.field('NDATA'))
    # NOTE(review): the reason for the factor 0.5 is not visible here.
    eint = 0.5*tab['data_err']
    # Intensity per unit area (scale**2) turned into magnitudes; adding
    # 2.5*log10(exptime) is equivalent to dividing the intensity by exptime.
    mag = (zeropnt-2.5*np.log10(intens/scale**2)) + 2.5*np.log10(exptime)
    mag_ep = (zeropnt-2.5*np.log10((intens+eint)/scale**2)) + 2.5*np.log10(exptime)
    mag_en = (zeropnt-2.5*np.log10((intens-eint)/scale**2)) + 2.5*np.log10(exptime)

    axmu = fig.add_axes((0.5, 0.5, 0.48, 0.47), facecolor='white', alpha=0.1)
    # Grey lines: profile for intensity +/- eint; blue dots: the profile.
    axmu.plot(x, mag_ep, '-', color='gray')
    axmu.plot(x, mag_en, '-', color='gray')
    axmu.plot(x, mag, 'b.')
    # Inverted y-axis so that brighter (smaller) magnitudes are at the top.
    axmu.set_ylim(np.max(mag)+0.5, np.min(mag)-0.5)
    #axmu.set_xlim(.0, 6.0)
    axmu.set_xlabel("R, arcsec", fontsize=14, color='yellow')
    axmu.set_ylabel(r"$\mu_i$, mag/arcsec$^2$", fontsize=14, color='yellow')
    axmu.xaxis.set_tick_params(which='both', color='white',
                               labelcolor='yellow', labelsize=14)
    axmu.yaxis.set_tick_params(which='both', color='white',
                               labelcolor='yellow', labelsize=14)
    plt.setp(axmu.spines.values(), color='yellow')
    # axmu.yaxis.set_minor_locator(AutoMinorLocator())
    axmu.yaxis.set_minor_locator(MultipleLocator(0.2))
    axmu.xaxis.set_minor_locator(MultipleLocator(0.2))
    # Semi-transparent inset background so the image shows through.
    axmu.patch.set_alpha(0.7)

    # ax.axis('tight')
    # ax.axis('off')

    # Save and close this figure explicitly rather than relying on
    # pyplot's notion of the "current" figure.
    fig.savefig("panstrs_{:}_rgb.png".format(galn), dpi=300)
    plt.close(fig)
