from astropy.table import Table
import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.colors import LogNorm
from matplotlib.ticker import NullLocator,MultipleLocator,FormatStrFormatter,FixedLocator, FuncFormatter,AutoMinorLocator,AutoLocator, ScalarFormatter
from matplotlib.patches import Rectangle
from astropy import stats
from scipy.integrate import simps
import fitsio
import numpy as np
import sys, os
from flux_sig_to_mbh import flux_sig_to_mbh


def print_info(dec, psfcor):
    """Print a one-row text summary of a BLR emission-line decomposition.

    The row holds:
      * the BH mass stored in the FITS extension;
      * the BH mass recomputed here with ``flux_sig_to_mbh`` (reines13);
      * the BLR width;
      * the BLR Halpha flux;
      * the reduced chi2 with and without a BLR component.

    Parameters
    ----------
    dec : record
        Row of the EMIS_SPEC_DECOMP extension. It must provide the PARS_*
        fields read below.
    psfcor : float
        Aperture/PSF correction. The BLR flux and its error are divided by
        it before the BH mass is recomputed. The printed BLR_FLUX_Ha column
        shows the stored value, not divided by ``psfcor``.
    """
    speed_of_light = 299792.45  # km/s; PARS_VSYS / c gives the redshift

    blr_flux = dec['PARS_BLR_FLUX_HA'] / psfcor
    blr_sig = dec['PARS_BLR_SIG']
    redshift = dec['PARS_VSYS'] / speed_of_light
    blr_flux_err = dec['PARS_BLR_FLUX_HALPHA_ERR'] / psfcor
    mbh, mbh_err = flux_sig_to_mbh(blr_flux, blr_sig, redshift, ref='reines13',
                                   flux_err=blr_flux_err,
                                   sigma_err=dec['PARS_BLR_SIG_ERR'])

    header = "MBH from fits ext     |MBH Python formulae   | BLR_SIG        | BLR_FLUX_Ha   | c2dof vs noblr|"
    row_fmt = "{:.2e} +/- {:.2e} |{:.2e} +/- {:.2e} | {:.1f} +/- {:.1f} | {:.1f} +/- {:.1f} | {:.2f} vs. {:.2f} | "
    print(header)

    row = [dec['PARS_MBH'], dec['PARS_MBH_ERR'], mbh.value, mbh_err.value]
    row += [dec[key] for key in ('PARS_BLR_SIG', 'PARS_BLR_SIG_ERR',
                                 'PARS_BLR_FLUX_HA', 'PARS_BLR_FLUX_HALPHA_ERR',
                                 'PARS_CHI2DOF', 'PARS_NOBLR_CHI2DOF')]
    print(row_fmt.format(*row))
    print("NB: flux values corrected for aperture and PSF correction ({})".format(psfcor))


def calc_chi2(spec_full, spec, fit, fit_noblr, err, dec):
    """Return the reduced chi2 of the fits with and without a BLR component.

    Parameters
    ----------
    spec_full : array_like
        Full emission-line spectrum. Not used by the statistic; kept so
        existing callers keep working.
    spec, fit, fit_noblr, err : numpy.ndarray
        Observed spectrum, best fit with BLR, best fit without BLR and
        1-sigma errors, all sampled on the same pixels (the fitting window).
    dec : record
        Decomposition record. Not used by the statistic; kept so existing
        callers keep working.

    Returns
    -------
    tuple of float
        ``(chi2, chi2noblr)``: the sum of squared error-normalised residuals
        divided by the number of pixels in ``spec``. NaN terms are skipped
        in the sum (``np.nansum``) but those pixels still count in the
        denominator.

    Notes
    -----
    No correction for the number of fitted parameters is applied. Both
    statistics use the plain pixel count as the degrees of freedom.
    """
    dof = float(len(spec))

    chi2 = np.nansum((spec - fit) ** 2 / err ** 2) / dof
    chi2noblr = np.nansum((spec - fit_noblr) ** 2 / err ** 2) / dof

    return chi2, chi2noblr


def _read_results(path, psfcor):
    """Read one results FITS file and derive the BH mass from its BLR fit.

    Returns ``(sp, dec, mbh, mbh_err)``:
      * ``sp``: first row of the SPECTRUM extension;
      * ``dec``: first row of the EMIS_SPEC_DECOMP extension;
      * ``mbh, mbh_err``: output of ``flux_sig_to_mbh`` (reines13) for the
        broad Halpha flux divided by ``psfcor``.

    The file is closed before returning.
    """
    with fitsio.FITS(path) as fits:
        sp = fits['SPECTRUM'].read()[0]
        dec = fits['EMIS_SPEC_DECOMP'].read()[0]
    mbh, mbh_err = flux_sig_to_mbh(
        dec['PARS_BLR_FLUX_HA'] / psfcor, dec['PARS_BLR_SIG'],
        dec['PARS_VSYS'] / 299792.45, ref='reines13',
        flux_err=dec['PARS_BLR_FLUX_HALPHA_ERR'] / psfcor,
        sigma_err=dec['PARS_BLR_SIG_ERR'])
    return sp, dec, mbh, mbh_err


def _halpha_chi2(spec_wave, dd):
    """Return (chi2, chi2noblr) of decomposition ``dd`` within 6535-6595 A.

    This window brackets Halpha+[NII]. ``spec_wave`` is the rest-frame
    wavelength grid of the ``dd['SPEC_*']`` arrays.
    """
    sel = (spec_wave > 6535.0) & (spec_wave < 6595.0)
    return calc_chi2(dd['SPEC_EMIS'], dd['SPEC_EMIS'][sel], dd['SPEC_FIT'][sel],
                     dd['SPEC_FIT_NOBLR'][sel], dd['SPEC_ERROR'][sel], dd)


def _panel_peak(emis, fit, sel, blr_halpha=None):
    """Return the y-scale peak of a line panel.

    The peak is the maximum of data and fit inside ``sel``. When the broad
    Halpha component ``blr_halpha`` is given and its maximum is below 1/5 of
    that peak, the peak becomes 5x the broad maximum instead. This zooms the
    y-axis onto the weak broad component.
    """
    mm = max(np.concatenate((emis[sel], fit[sel])))
    if blr_halpha is not None:
        maxblr = max(blr_halpha)
        if maxblr < mm / 5.0:
            mm = maxblr * 5.0
    return mm


def _draw_residuals(ax, mm, resid_wave, resid, err_wave, err, xlo, xhi):
    """Set a line panel's y-range from ``mm`` and draw the residual strip.

    The strip is offset to ``-0.3 * mm`` and shows the residuals, the
    +/-1 sigma errors and a horizontal zero line spanning ``[xlo, xhi]``.
    """
    resid_level = -0.3 * mm
    ax.set_ylim(resid_level + 0.4 * resid_level, mm * 1.05)
    ax.step(resid_wave, resid + resid_level, color='k', where='mid', alpha=0.5)
    ax.plot(err_wave, err + resid_level, color='gray', ls='-', alpha=0.5)
    ax.plot(err_wave, -err + resid_level, color='gray', ls='-', alpha=0.5)
    ax.plot([xlo, xhi], [resid_level, resid_level], color='k', alpha=0.5)


def _plot_losvd(axlosvd, dd, show_xlabel):
    """Draw the narrow-line and broad-line LOSVDs of decomposition ``dd``.

    The narrow-line LOSVD is drawn as a normalised step histogram. The
    broad-line LOSVD is a unit-area Gaussian of width PARS_BLR_SIG.
    """
    vbin = dd['LOSVD_VBIN']
    # NOTE(review): the minus sign presumably compensates for vbin being in
    # descending order (negative Simpson integral) -- confirm.
    norm = -simps(dd['LOSVD_NLR'], vbin)
    # Divide out of place so the record read from disk is left untouched.
    losvd_nlr = dd['LOSVD_NLR'] / norm
    vbinexp = np.linspace(-5e3, 5e3, 1000)
    sig = dd['PARS_BLR_SIG']
    losvd_blr = 1.0 / np.sqrt(2.0 * np.pi) / sig * np.exp(-0.5 * (vbinexp / sig) ** 2)

    axlosvd.step(vbin, losvd_nlr, color='blue', lw=1.5)
    axlosvd.plot(vbinexp, losvd_blr, color='red', lw=2)
    axlosvd.set_xlim(-700, 700)
    axlosvd.set_ylim(0, max(losvd_nlr) * 1.2)
    if show_xlabel:
        axlosvd.set_xlabel('Velocity, km/s')
    empty_patch = Rectangle((0, 0), 0, 0, alpha=0.0)
    axlosvd.legend([empty_patch] * 1, ["Normalized LOSVDs"],
                   handlelength=0, handletextpad=0, numpoints=None, labelspacing=0.3, borderpad=0.2,
                   framealpha=0.7, facecolor='white', loc='upper left', frameon=False)
    axlosvd.xaxis.set_minor_locator(MultipleLocator(50))
    axlosvd.xaxis.set_major_locator(MultipleLocator(500))
    formatter = ScalarFormatter()
    formatter.set_scientific(True)
    formatter.set_powerlimits((-1, 1))
    axlosvd.yaxis.set_major_formatter(formatter)


def make_plot(p, verbose=True, psfcor=1.0, mage=False):
    """Plot the continuum fit and emission-line decomposition of one object.

    The figure has three parts:
      * a top panel with the SDSS spectrum and the Starlight model;
      * a row of SDSS line panels (Halpha+[NII], Hbeta, [OIII], [SII]) with
        the decomposition and residuals;
      * a LOSVD panel.

    When ``p`` has a ``'file_in_mage'`` key, a third row shows the MagE
    decomposition (Halpha+[NII], [SII], LOSVD). The figure is saved to
    ``p['file_out']``.

    Parameters
    ----------
    p : dict
        Needs ``'file_in'`` (results FITS with SPECTRUM and
        EMIS_SPEC_DECOMP extensions), ``'file_out'`` and ``'sname'``.
        ``'file_in_mage'`` is optional.
    verbose : bool
        Print the fit summary via ``print_info``.
    psfcor : float
        Aperture/PSF correction. BLR fluxes are divided by it before the BH
        mass is computed.
    mage : bool
        Unused. MagE mode is selected by ``'file_in_mage'`` in ``p``. The
        parameter is kept for call compatibility.
    """
    has_mage = 'file_in_mage' in p

    print("Input: {}".format(p['file_in']))
    print("Output: {}".format(p['file_out']))

    sp, dec, mbh, mbh_err = _read_results(p['file_in'], psfcor)
    if has_mage:
        mage_sp, mage_dec, mage_mbh, mage_mbh_err = _read_results(p['file_in_mage'], psfcor)

    if verbose:
        print_info(dec, psfcor)
        if has_mage:
            print("---------------- MagE data ----------------")
            print_info(mage_dec, psfcor)

    c = 299792.45
    # The SDSS systemic velocity defines the rest frame for both datasets.
    z = dec['PARS_VSYS'] / c
    residuals = sp['FLUX'] - sp['FIT']
    residuals_clipped = stats.sigma_clip(residuals, sigma=4)
    residuals_std = np.nanstd(residuals_clipped)
    print("RMS of residuals: {}".format(residuals_std))
    waves = sp['WAVE'] / (1 + z)
    sp_emis = residuals  # continuum-subtracted SDSS spectrum

    spec_wave = dec['SPEC_WAVE'] / (1 + z)
    spec_emis = dec['SPEC_EMIS']
    spec_error = dec['SPEC_ERROR']
    spec_fit = dec['SPEC_FIT']
    spec_comps_nlr = dec['SPEC_COMPS_NLR']
    spec_comps_blr = dec['SPEC_COMPS_BLR']

    chi2, chi2noblr = _halpha_chi2(spec_wave, dec)
    print("Chi2", chi2, chi2noblr)

    # --- layout: continuum row, SDSS line row, optional MagE row ------------
    if has_mage:
        fig = plt.figure(figsize=(12.5, 7))
        gs = gridspec.GridSpec(3, 6, width_ratios=[1, 1, 1, 1, 1, 1.5],
                               wspace=0.25, hspace=0.25, height_ratios=[1, 1.5, 1.5])
    else:
        fig = plt.figure(figsize=(12.5, 4.5))
        gs = gridspec.GridSpec(2, 6, width_ratios=[1, 1, 1, 1, 1, 1.5],
                               wspace=0.25, hspace=0.25, height_ratios=[1, 1.5])
    ax0 = plt.subplot(gs[0, :])
    axs = [plt.subplot(gs[1, 0:2])] + [plt.subplot(gs[1, col]) for col in range(2, 5)]
    axs_bnlr = plt.subplot(gs[1, -1])
    if has_mage:
        axm = [plt.subplot(gs[2, 0:4]), plt.subplot(gs[2, 4])]
        axm_bnlr = plt.subplot(gs[2, -1])

    # --- continuum panel -----------------------------------------------------
    ax0.plot(waves, sp['FLUX'], color='navy', alpha=0.8, lw=0.75, label='SDSS spectrum')
    ax0.plot(waves, sp['FIT'], color='coral', lw=1.00, label='Starlight model')
    ax0.legend()
    ax0.set_xlim(3800, 6800)
    index = (waves > 3800) & (waves < 6800)
    ax0.set_ylim(min(sp['FIT'][index]) - residuals_std,
                 1.3 * max(sp['FIT'][index]) + 1.5 * residuals_std)
    ax0.xaxis.set_minor_locator(AutoMinorLocator())
    ax0.xaxis.set_major_locator(AutoLocator())
    ax0.yaxis.set_minor_locator(AutoMinorLocator())
    ax0.yaxis.set_major_locator(AutoLocator())
    ax0.set_title(p['sname'][:5] + p['sname'][7:12])

    # --- SDSS line panels ------------------------------------------------------
    empty_patch = Rectangle((0, 0), 0, 0, alpha=0.0)
    # Per panel: central wavelength and (blue, red) half-window in Angstrom.
    winwl = [[40, 30], [15, 15], [15, 15], [22, 18]]
    wlines = np.array([6562.79 + 3, 4861.36, 5006.84, 6724])
    lnames = [r'H$\alpha$+[NII]', r'H$\beta$', '[OIII]', '[SII]']
    resid = spec_emis - spec_fit

    for k, ax in enumerate(axs):
        xlo = wlines[k] - winwl[k][0]
        xhi = wlines[k] + winwl[k][1]
        idx = (spec_wave > xlo) & (spec_wave < xhi)
        ax.step(waves, sp_emis, color='k', where='mid')
        ax.set_xlim(xlo, xhi)

        if k == 1:  # Hbeta
            ax.plot(spec_wave, spec_comps_nlr[0, :], color='blue', ls='-')
        if k == 0:  # Halpha
            ax.plot(spec_wave, spec_comps_nlr[2, :], color='blue', ls='-')
            ax.plot(spec_wave, spec_comps_nlr[3, :], color='blue', ls='-')

        ax.step(spec_wave, spec_fit, color='magenta', where='mid', lw=2, alpha=0.7)

        if k == 1:  # Hbeta
            ax.step(spec_wave, spec_comps_blr[0, :], color='red', where='mid', lw=1.5)
        if k == 0:  # Halpha
            ax.step(spec_wave, spec_comps_blr[1, :], color='red', where='mid', lw=1.5)
            ax.legend([empty_patch] * 3,
                      [r"H$\alpha$+[NII]",
                       r"$\chi^2$={:.1f} ({:.1f})".format(chi2, chi2noblr),
                       r"$\rm {:.0f} \pm {:.0f} \, kM_\odot$".format(mbh.value / 1e3, mbh_err.value / 1e3)],
                      handlelength=0, handletextpad=0, numpoints=None, labelspacing=0.5,
                      borderpad=0, frameon=False, loc='upper left')

        ax.xaxis.set_minor_locator(MultipleLocator(2))
        ax.xaxis.set_major_locator(MultipleLocator(20))
        ax.yaxis.set_minor_locator(AutoMinorLocator())

        mm = _panel_peak(spec_emis, spec_fit, idx, spec_comps_blr[1] if k == 0 else None)
        _draw_residuals(ax, mm, spec_wave, resid, spec_wave, spec_error, xlo, xhi)
        if k != 0:
            ax.legend([empty_patch], [lnames[k]], handlelength=0, handletextpad=0,
                      numpoints=None, labelspacing=0.1, borderpad=0.2, loc='upper left',
                      frameon=False)

    # --- LOSVD panels ------------------------------------------------------------
    _plot_losvd(axs_bnlr, dec, show_xlabel=False)
    if has_mage:
        _plot_losvd(axm_bnlr, mage_dec, show_xlabel=True)

    # --- MagE line panels ---------------------------------------------------------
    if has_mage:
        m_wave = mage_sp['WAVE'] / (1 + z)
        m_emis = mage_sp['FLUX'] - mage_sp['FIT']
        m_spec_wave = mage_dec['SPEC_WAVE'] / (1 + z)
        m_spec_emis = mage_dec['SPEC_EMIS']
        m_spec_error = mage_dec['SPEC_ERROR']
        m_spec_fit = mage_dec['SPEC_FIT']
        m_comps_nlr = mage_dec['SPEC_COMPS_NLR']
        m_comps_blr = mage_dec['SPEC_COMPS_BLR']

        m_chi2, m_chi2noblr = _halpha_chi2(m_spec_wave, mage_dec)
        print("Mage chi2:", m_chi2, m_chi2noblr)

        # MagE row: Halpha+[NII] (wider red window) and [SII].
        m_winwl = [[40.0, 108], [22, 18]]
        # Residuals on the observed MagE grid, against the fit interpolated onto it.
        m_resid = m_emis - np.interp(m_wave, m_spec_wave, m_spec_fit)

        for i, k, ax in zip([0, 1], [0, 3], axm):
            xlo = wlines[k] - m_winwl[i][0]
            xhi = wlines[k] + m_winwl[i][1]
            idx = (m_spec_wave > xlo) & (m_spec_wave < xhi)
            ax.step(m_wave, m_emis, color='k', where='mid')
            ax.set_xlim(xlo, xhi)

            if k == 0:  # Halpha
                ax.plot(m_spec_wave, m_comps_nlr[2, :], color='blue', ls='-')
                ax.plot(m_spec_wave, m_comps_nlr[3, :], color='blue', ls='-')

            ax.step(m_spec_wave, m_spec_fit, color='magenta', where='mid', lw=2, alpha=0.6)

            if k == 0:  # Halpha
                ax.step(m_spec_wave, m_comps_blr[1, :], color='red', where='mid', lw=1.5)
                name_legend = ax.legend([empty_patch] * 1, [r"H$\alpha$+[NII]"],
                                        handlelength=0, handletextpad=0, numpoints=None, labelspacing=0.5,
                                        borderpad=0, frameon=False, loc='upper left')
                # A second legend() call replaces the axes legend; re-adding
                # the first one as an artist keeps both visible.
                ax.add_artist(name_legend)
                ax.legend([empty_patch] * 2,
                          [r"$\chi^2$={:.1f} ({:.1f})".format(m_chi2, m_chi2noblr),
                           r"$\rm {:.0f}\pm{:.0f} \,kM_\odot$".format(mage_mbh.value / 1e3, mage_mbh_err.value / 1e3)],
                          handlelength=0, handletextpad=0, numpoints=None, labelspacing=0.5,
                          borderpad=0.5, frameon=False, loc='best')

            ax.xaxis.set_minor_locator(MultipleLocator(2))
            ax.xaxis.set_major_locator(MultipleLocator(20))
            ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
            ax.yaxis.set_minor_locator(AutoMinorLocator())

            # Like the SDSS row, the broad-Halpha y-cap applies only to the Halpha panel.
            mm = _panel_peak(m_spec_emis, m_spec_fit, idx, m_comps_blr[1] if k == 0 else None)
            _draw_residuals(ax, mm, m_wave, m_resid, m_spec_wave, m_spec_error, xlo, xhi)
            if k == 3:
                ax.legend([empty_patch], [lnames[k]], handlelength=0, handletextpad=0,
                          numpoints=None, labelspacing=0.1, borderpad=0.2, frameon=False,
                          loc='upper right')

    # --- figure-level axis labels --------------------------------------------------
    flux_label = r'$F_\lambda$, 10$^{-17}$ erg cm$^{-2}$ s$^{-1}$ ${\rm \AA}^{-1}$'
    if has_mage:
        fig.text(0.47, 0.03, r'Restframe wavelengths, \AA', ha='center', fontsize=16)
        fig.text(0.06, 0.5, flux_label, rotation='vertical', va='center', fontsize=16)
        fig.text(0.085, 0.25, r'MagE', rotation='vertical', va='center', fontsize=15)
        fig.text(0.085, 0.54, r'SDSS', rotation='vertical', va='center', fontsize=15)
    else:
        fig.text(0.47, 0.01, r'Restframe wavelengths, \AA', ha='center', fontsize=16)
        fig.text(0.07, 0.5, flux_label, rotation='vertical', va='center', fontsize=16)

    fig.savefig(p['file_out'], bbox_inches='tight')
    plt.close(fig)


# Global plot style: inward ticks on all four sides, Type 42 (TrueType) fonts
# in PS/PDF output, and LaTeX-rendered serif text.
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['xtick.top'] = True
plt.rcParams['ytick.right'] = True

plt.rcParams['ps.fonttype'] = 42
plt.rcParams['pdf.fonttype'] = 42

plt.rc('text', usetex=True)
plt.rc('font', family='serif')
plt.rcParams['font.size'] = 12


wds = '../../../data/spectra/'
wdp = '../../../paper/diagrams/'
wdm = '../../../data/mage_processed/'


def _entry(sname, spec_id, mage_file=None):
    """Build one ``make_plot`` parameter dict.

    The output PDF name comes from ``sname``, with a ``mage_`` prefix when
    ``mage_file`` is given. The SDSS input comes from the spSpec
    ``MJD-plate-fiber`` id. ``mage_file`` is relative to ``wdm``.
    """
    prefix = 'mage_' if mage_file is not None else ''
    entry = {
        'sname': sname,
        'file_out': wdp + 'spec_decomp_' + prefix + sname + '.pdf',
        'file_in': wds + 'spSpec' + spec_id + '_results.fits',
    }
    if mage_file is not None:
        entry['file_in_mage'] = wdm + mage_file
    return entry


parameters = [
    # Literature objects
    _entry('J152304+114553', '54234-2753-127'),
    _entry('J153425+040806', '52026-0593-547'),
    _entry('J160531+174826', '53534-2196-561'),
    _entry('J112333+671109', '51942-0491-483'),
    _entry('J022849-090153', '51908-0454-109'),
    # Our objects, SDSS only
    _entry('J122732+075747', '53472-1626-592'),
    _entry('J171409+584906', '51703-0353-103'),
    _entry('J111552-000436', '51608-0279-176'),
    _entry('J110731+134712', '53377-1751-147'),
    # Same spSpec file as this object's MagE entry below; the previous
    # 51908-0454-109 file belongs to J022849-090153.
    _entry('J134244+053056', '52373-0854-369'),
    # MagE decomposition
    _entry('J110731+134712', '53377-1751-147',
           'analysis_1107+1347_170705/1107+1347_sdss_05psf.fits'),
    _entry('J122732+075747', '53472-1626-592',
           'analysis_1227+0758_170710/1227+0758_sdss_05psf.fits'),
    _entry('J134244+053056', '52373-0854-369',
           'analysis_1342+0530_170531/1342+0530_sdss_05psf.fits'),
    _entry('J022849-090153', '51908-0454-109',
           'analysis_0228-0901_180101/0228-0901_sdss_05psf.fits'),
]

for p in parameters:
    make_plot(p)
