from astropy.io import fits
from astropy.utils.data import download_file
import matplotlib.pyplot as plt
from photutils.isophote import Ellipse, EllipseGeometry, build_ellipse_model
import numpy as np
from astropy.modeling.models import Sersic1D
from scipy.optimize import curve_fit
from astropy.nddata import Cutout2D
from astropy import units as u
from lmfit import minimize, Minimizer, Parameters, Parameter, report_fit
from scipy.interpolate import interp1d
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import gridspec
import pandas as pd
from matplotlib.colors import LogNorm
from photutils.isophote import Ellipse
import sys

# Command-line mode:
#   'plotmodel' -> fit elliptical isophotes to the HST cutout with photutils,
#                  show data / model / residual images and write the extracted
#                  profile to isolist.csv;
#   anything else -> reuse the profile previously written to isolist.csv.
if len(sys.argv) != 2:
    sys.exit('usage: %s <plotmodel|other>' % sys.argv[0])
_, keyword = sys.argv

# Photometric zeropoint, from
# https://www.stsci.edu/hst/wfpc2/Wfpc2_dhb/wfpc2_ch52.html
zeropoint = 21.665

if keyword == 'plotmodel':
    # Alternative target (ACS/WFC drizzled image, 1000x1000 cutout):
    # image_data = fits.getdata('587725491599114507/hst_11142_0m_acs_wfc_f814w_drz.fits')
    # position = (4595.7384, 3135.4924)
    # size = (1000, 1000)   # pixels
    image_data = fits.getdata('J022849.51-090153.8/hst_11130_11_wfpc2_f814w_pc_drz.fits')
    position = (429.3, 461.8)
    size = (300, 300)   # pixels
    cutout = Cutout2D(image_data, position, size)
    data = cutout.data

    # Initial ellipse geometry. The centre must be in *cutout* pixel
    # coordinates: the previously hard-coded (500, 500) only made sense for
    # the 1000x1000 cutout above and lies outside this 300x300 one.
    x0, y0 = cutout.input_position_cutout
    sma = 50.    # initial semi-major axis (no minsma is passed to fit_image)
    eps = 0.17
    pa = 70. / 180. * np.pi
    g = EllipseGeometry(x0, y0, sma, eps, pa)
    g.find_center(data)
    ellipse = Ellipse(data, geometry=g)
    isolist = ellipse.fit_image(sclip=2., nclip=3, maxsma=143, step=0.05)
    print(isolist.to_table())
    print(isolist.intens)
    print(isolist.int_err)
    # Value used by build_ellipse_model outside the fitted isophotes: the mean
    # of a 10x10 corner of the cutout.
    fill = np.mean(data[0:10, 0:10])
    model_image = build_ellipse_model(data.shape, isolist, fill=fill)
    residual = data - model_image

    # Top row: data / model / residual images, each with its own colorbar.
    fig = plt.figure(figsize=(15, 10))
    gs = gridspec.GridSpec(3, 3, height_ratios=[3, 3, 1])
    image_panels = [
        (data, "Data", 0, 1),
        (model_image, "Model", 0, 1),
        (residual, "Residual", -0.5, 0.5),
    ]
    for col, (image, title, vmin, vmax) in enumerate(image_panels):
        ax = plt.subplot(gs[0, col])
        im = ax.imshow(image, vmin=vmin, vmax=vmax)
        ax.set_title(title)
        cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')
    # Middle row: brightness profile; bottom row: fit residuals.
    axbig = fig.add_subplot(gs[1, :])
    axbig2 = fig.add_subplot(gs[2, :], sharex=axbig)
    axbig2.set_xticks([10, 100, 200, 500])

    d = {'x': isolist.sma, 'data': isolist.intens, 'data_err': isolist.int_err}
    df = pd.DataFrame(data=d)
    df.to_csv('J022849.51-090153.8/isolist.csv')
else:
    fig = plt.figure(figsize=(15, 10))
    gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
    axbig2 = fig.add_subplot(gs[1, :])
    axbig = fig.add_subplot(gs[0, :], sharex=axbig2)
    df = pd.read_csv('J022849.51-090153.8/isolist.csv')

# Profile preparation shared by both modes.
x = df['x']
print(np.min(np.diff(x)))
data0 = df['data']                    # isophote intensities
data_err = df['data_err'] / data0     # relative intensity error
# Replace the first entry's error with the second one's.
# NOTE(review): presumably the innermost (central-pixel) isophote reports a
# zero error, which would make its weight in the fit infinite -- confirm.
data_err.iloc[0] = data_err.iloc[1]
data = zeropoint - 2.5*np.log10(data0)

# Plot exactly the magnitudes that are fitted and used for the residuals.
# (The non-plotmodel branch used to plot data/1230/(0.05**2)/(0.05**2), which
# divides by the pixel area twice, was never applied to the fit, and put the
# points ~5 mag away from the model.)
axbig.plot(x, data, 'ko', markersize=2)
axbig.set_title("Brightness profile")
axbig2.set_xlabel('Radius')
axbig.set_ylabel('Magnitude')
axbig.invert_yaxis()


def b(n):
    """Return the Sersic b_n coefficient for index ``n``.

    Uses the asymptotic series 2n - 1/3 + 4/(405 n) + 46/(25515 n^2).
    ``n`` may be a scalar or a NumPy array (evaluated element-wise).
    """
    leading = 2*n - 1.0/3.0
    first_correction = 4.0 / (405.0 * n)
    second_correction = 46.0 / (25515.0 * n**2)
    return leading + first_correction + second_correction

def fcn2min(params, components, x, data):
    """Residual function minimised by lmfit for the multi-component profile fit.

    Sums the component profiles on the module-level fine grid ``x_prof``,
    mirrors the sum about r = 0, convolves it with the module-level 1-D
    ``psf`` kernel, converts to magnitudes (-2.5 log10 I) and linearly
    interpolates onto the observed radii ``x``.

    Parameters
    ----------
    params : lmfit.Parameters
        ``mu_0<i>`` for every component, plus ``r_0<i>`` and ``n<i>`` for
        Sersic components.
    components : dict
        Maps component index ``i`` (0 .. len-1) to a dict whose ``'type'``
        is ``'sersic'`` or ``'psf'``.
    x : array-like
        Radii of the observed profile points (must lie within ``x_prof``).
    data : array-like
        Observed profile in magnitudes at ``x``.

    Returns
    -------
    numpy.ndarray
        ``(model - data) / data_err`` at each observed radius.

    Raises
    ------
    ValueError
        If a component has an unknown ``'type'``.

    Notes
    -----
    Reads the module-level globals ``x_prof``, ``psf``, ``sigma`` and
    ``data_err``.
    """
    I_sum = np.zeros(x_prof.shape)
    for i in range(len(components)):
        comp_type = components[i]['type']
        if comp_type == 'sersic':
            mu_0 = params['mu_0'+str(i)].value
            I_0 = 10.**(-0.4*mu_0)
            r0 = params['r_0'+str(i)].value
            n = params['n'+str(i)].value
            I_sum = I_sum + I_0*np.exp(-b(n)*((x_prof/r0)**(1.0/n) - 1.0))
        elif comp_type == 'psf':
            # Gaussian of fixed width sigma centred on r = 0.
            mu_0 = params['mu_0'+str(i)].value
            I_0 = 10.**(-0.4*mu_0)
            I_sum = I_sum + I_0*np.exp(-0.5*(x_prof/sigma)**2)
        else:
            raise ValueError('unknown component type %r for component %d'
                             % (comp_type, i))

    # Mirror about r = 0 so samples near the centre are convolved with the
    # profile's own reflection rather than with zero padding.
    # NOTE(review): the r = 0 sample appears twice, so the mirror axis sits
    # half a grid step below r = 0; negligible while step << sigma.
    mirrored = np.concatenate((I_sum[::-1], I_sum))
    # Same result as np.convolve(mirrored, psf, 'same') when the kernel is
    # shorter than the signal, but always len(mirrored) samples long ('same'
    # returns max(len(mirrored), len(psf)) samples, which would break the
    # half-slice below for a long kernel).
    full = np.convolve(mirrored, psf, 'full')
    start = (len(psf) - 1) // 2
    smoothed = full[start:start + len(mirrored)]
    # Keep the r >= 0 half.
    smoothed = smoothed[len(smoothed) // 2:]
    # Floor at the smallest positive float: exp() underflow for extreme trial
    # parameters would otherwise give log10(0) = -inf, i.e. non-finite
    # residuals, which lmfit's default nan_policy='raise' aborts on.
    model = -2.5*np.log10(np.maximum(smoothed, np.finfo(float).tiny))
    model_pix = interp1d(x_prof, model)(x)
    return (model_pix - data)/data_err

def make_final_res(params, components, x, data):
    """Evaluate the PSF-convolved total model on the fine grid ``x_prof``.

    Builds the same model as ``fcn2min`` from ``params`` (previously the
    argument was ignored and the global ``result.params`` was read instead,
    so the function only worked after a global named ``result`` existed).

    Parameters
    ----------
    params : lmfit.Parameters
        ``mu_0<i>`` for every component, plus ``r_0<i>`` and ``n<i>`` for
        Sersic components.
    components : dict
        Maps component index ``i`` to a dict whose ``'type'`` is
        ``'sersic'`` or ``'psf'``.
    x, data :
        Unused; kept for call compatibility.

    Returns
    -------
    tuple of numpy.ndarray
        ``(model, x_prof)``: model magnitudes (-2.5 log10 I) and the grid
        they are evaluated on.

    Raises
    ------
    ValueError
        If a component has an unknown ``'type'``.

    Notes
    -----
    Reads the module-level globals ``x_prof``, ``psf`` and ``sigma``.
    """
    I_sum = np.zeros(x_prof.shape)
    for i in range(len(components)):
        comp_type = components[i]['type']
        if comp_type == 'sersic':
            mu_0 = params['mu_0'+str(i)].value
            r0 = params['r_0'+str(i)].value
            n = params['n'+str(i)].value
            I_0 = 10.**(-0.4*mu_0)
            I_sum = I_sum + I_0*np.exp(-b(n)*((x_prof/r0)**(1.0/n) - 1.0))
        elif comp_type == 'psf':
            mu_0 = params['mu_0'+str(i)].value
            I_0 = 10.**(-0.4*mu_0)
            I_sum = I_sum + I_0*np.exp(-0.5*(x_prof/sigma)**2)
        else:
            raise ValueError('unknown component type %r for component %d'
                             % (comp_type, i))

    # Mirror about r = 0, convolve with the PSF kernel, keep the r >= 0 half.
    # The explicit 'full' + centre slice equals np.convolve(..., 'same') when
    # the kernel is shorter than the signal and keeps the length correct
    # otherwise.
    mirrored = np.concatenate((I_sum[::-1], I_sum))
    full = np.convolve(mirrored, psf, 'full')
    start = (len(psf) - 1) // 2
    smoothed = full[start:start + len(mirrored)]
    smoothed = smoothed[len(smoothed) // 2:]
    # Same floor as the fit uses, so log10 never sees zero.
    model = -2.5*np.log10(np.maximum(smoothed, np.finfo(float).tiny))
    return (model, x_prof)

# --- PSF kernel and model grid ----------------------------------------------
fwhm = 2.69   # PSF FWHM, assumed to be in the same units as x -- TODO confirm
sigma = fwhm / 2.355
print('SIGMA', sigma)

# Fine, uniform radius grid for evaluating the intrinsic profile: half the
# smallest spacing of the observed radii, extending 10% past the last one.
step_prof = np.min(np.diff(x)) / 2.0
if not step_prof > 0:
    raise ValueError('observed radii x must be strictly increasing')
N_points = int(np.max(x * 1.1) / step_prof)
print('N points = ', N_points)
x_prof = np.linspace(0, np.max(x * 1.1), N_points)
# linspace spacing is max / (N_points - 1), not exactly the nominal step; the
# kernel below must be sampled with the grid's real spacing.
step_prof = x_prof[1] - x_prof[0]

# Gaussian PSF sampled on the x_prof spacing. It was previously sampled on a
# 1-unit grid, which made its effective width sigma * step_prof instead of
# sigma, and peaked at 500.5 in a 1001-sample array (a half-sample shift).
# Odd length with the peak in the middle sample, truncated at +-5 sigma.
half_width = int(np.ceil(5.0 * sigma / step_prof))
x_psf = np.arange(-half_width, half_width + 1) * step_prof
psf = np.exp(-0.5 * (x_psf / sigma) ** 2)
psf = psf / np.sum(psf)

# Empirical alternative (was computed and then immediately overwritten by the
# Gaussian above; kept here as an option). It must likewise be resampled onto
# the x_prof spacing:
# dat_psf = fits.getdata('J022849.51-090153.8/res_catalog.psf')[0][0][0]
# psfi = dat_psf[12, :]
# psf = interp1d(np.arange(len(psfi)), psfi)(
#     np.linspace(0, len(psfi) - 1, int(len(psfi) / step_prof)))

# --- Model definition and fit -----------------------------------------------
# Initial guesses keyed by component index i; the index is appended to the
# lmfit parameter names (mu_0<i>, r_0<i>, n<i>) that fcn2min reads.
components = {0: {'type': 'sersic', 'mu_0': 25, 'r_0': 500, 'n': 1},
              1: {'type': 'sersic', 'mu_0': 25, 'r_0': 10, 'n': 4}}
# Alternative decomposition with a PSF-shaped central component:
# components = {0: {'type': 'psf', 'mu_0': 19},
#               1: {'type': 'sersic', 'mu_0': 22.5, 'r_0': 20, 'n': 4.},
#               2: {'type': 'sersic', 'mu_0': 23.9, 'r_0': 400, 'n': 2.}}

# r_0 and n are divisors in the Sersic profile (x_prof / r0, 1.0 / n, and
# inside b(n)). A lower bound of exactly 0 lets the optimizer reach 0 and
# fail with a division by zero, so use small positive bounds.
R0_MIN = 1e-3
N_MIN = 0.1

params = Parameters()
for i in range(len(components)):
    comp = components[i]
    params.add('mu_0'+str(i), value=comp['mu_0'], min=0)
    # 'psf' components only have an amplitude; r_0 / n belong to Sersic ones.
    if comp['type'] == 'sersic':
        params.add('r_0'+str(i), value=comp['r_0'], min=R0_MIN)
        params.add('n'+str(i), value=comp['n'], min=N_MIN)
# To hold a parameter fixed, e.g.: params['n2'].set(vary=False)
print('before: ', params)
print(components)
minner = Minimizer(fcn2min, params, fcn_args=(components, x, data))
result = minner.minimize(method='leastsq')
finalr = make_final_res(result.params, components, x, data)
print('after: ', result.params)

report_fit(result)
print(finalr[1], finalr[0])
# --- Plot the fit -----------------------------------------------------------
# Total PSF-convolved model on the fine grid.
axbig.plot(finalr[1], finalr[0], 'r-')
# Each component on its own, convolved with the same PSF kernel as the total.
for i in range(len(components)):
    comp_type = components[i]['type']
    mu_0 = result.params['mu_0'+str(i)].value
    I_0 = 10.**(-0.4*mu_0)
    if comp_type == 'sersic':
        print(comp_type)
        r0 = result.params['r_0'+str(i)].value
        n = result.params['n'+str(i)].value
        comp_prof = I_0*np.exp(-b(n)*((x_prof/r0)**(1.0/n) - 1.0))
    elif comp_type == 'psf':
        comp_prof = I_0*np.exp(-0.5*(x_prof/sigma)**2)
    else:
        # Previously an unknown type silently re-plotted the preceding
        # component's profile.
        raise ValueError('unknown component type %r for component %d'
                         % (comp_type, i))
    # Mirror about r = 0, convolve ('full' + centre slice keeps the length
    # correct even for a long kernel), keep the r >= 0 half.
    mirrored = np.concatenate((comp_prof[::-1], comp_prof))
    full = np.convolve(mirrored, psf, 'full')
    start = (len(psf) - 1) // 2
    smoothed = full[start:start + len(mirrored)][len(mirrored) // 2:]
    axbig.plot(x_prof, -2.5*np.log10(smoothed), 'g--', label='comp'+str(i))
axbig.legend()
axbig.set_ylim([25, 18.5])

# Residuals at the observed radii. The fit (fcn2min) interpolates the model
# linearly onto x, so do the same here rather than snapping to the nearest
# grid point (which was also an O(len(x) * len(x_prof)) loop).
model_at_x = np.interp(x, finalr[1], finalr[0])
# NOTE(review): data_err is the *relative* intensity error; the matching
# magnitude error is about 1.0857 times larger -- confirm the intended yerr.
axbig2.errorbar(x, zeropoint-2.5*np.log10(data0)-model_at_x, yerr=data_err,
                ls='', marker='o', ms=2, mfc='k')
axbig2.set_ylim([-0.5, 0.5])
axbig2.set_ylabel('Residuals')
print('SIGMA: ', sigma)


plt.show()
