import sys
import os
import time
import pidly
import numpy as np
import fitsio
import glob
import re
import shutil
import scoop
from scoop import utils
import matplotlib.pyplot as plt
from matplotlib.ticker import NullLocator, MultipleLocator, FormatStrFormatter, FixedLocator, FuncFormatter

# local project imports
# we can't easily make emission-line-fitting a package, so no relative imports, have to hack sys.path to import local code
#sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import settings
from common_pidly import setup_idl_environment
from common_utils import make_tmp_file, write_to_tmp, log_head, get_spectrum_abspath, extract_synth_table


def fit_abs_mdeg(work_dir, filelist, mpf):
    """
    Fit the absorption spectrum once for every mdeg value from 5 to 35.

    For each mdeg the IDL procedure ``process_sdss`` is run on ``filelist``
    and the file ``spSpec<mjd>-<plate>-<fiber>_results.fits`` in ``work_dir``
    is renamed to ``..._results_<mdeg>.fits``.

    @param: work_dir: directory prefix of the result files; it is prepended
        verbatim, so it must end with a path separator
    @param: filelist: file name passed to ``process_sdss`` as ``inptab``
    @param: mpf: sequence (mjd, plate, fiberid) identifying the spectrum
    """
    mdegrees = range(5, 36)

    # path_ssp = '../template/pegase.hr/MILES/bspl/' # IC default path
    path_ssp = '~/sci/IDL_libs/libCHIL/pegase.hr/MILES/bspl/'# IK laptop path

    # NOTE(review): the IDL output path is hard-coded to 'tmp_fit/' while the
    # rename below uses work_dir -- the two presumably have to match.
    command_template = """process_sdss,inptab='{}',\
            mom=2,lammi=3600.0,lamma=6790.0,mdeg={:d},\
            imin=imin, imax=imax, datapath='', \
            path_ssp='{}',/plot, \
            suffix='_KU.fits', inpageidx=range(0,67), inpmetidx=range(1,7), \
            outpath='tmp_fit/',/iter"""

    result_stem = work_dir + "spSpec{:05d}-{:04d}-{:03d}_results".format(mpf[0], mpf[1], mpf[2])

    idl = pidly.IDL()
    try:
        setup_idl_environment(idl)
        for mdeg in mdegrees:
            idl(command_template.format(filelist, mdeg, path_ssp))

            # Move the fresh result aside under a name that carries mdeg.
            name_old = result_stem + ".fits"
            name_new = "{}_{:02d}.fits".format(result_stem, mdeg)
            os.rename(name_old, name_new)
    finally:
        # Shut the IDL session down even when a fit or a rename fails.
        idl.close()


def fit_lines(filelist):
    """
    Run the emission-line profile decomposition in IDL.

    Only ``emis_line_fitting_decomp`` (with ``/sm`` and ``verbose=1``) is
    executed; ``emis_line_fitting_gaus`` and ``emis_line_fitting_nonpar`` are
    not invoked by this function.
    NOTE(review): confirm whether the Gaussian and non-parametric fits are
    expected to be run from here as well.

    @param: filelist: file name passed to the IDL routine as its first argument
    """
    idl = pidly.IDL()
    try:
        setup_idl_environment(idl)
        idl("emis_line_fitting_decomp,'{}',/sm,verbose=1".format(filelist))
    finally:
        # Shut the IDL session down even when the fit fails.
        idl.close()


def test_mdeg():
    """
    Test to see how mdeg influences the BH mass estimation.

    Collects the per-mdeg result files ``*_results_??.fits`` from ``tmp_fit/``,
    reads from each of them the chi2 of the absorption and non-parametric
    fits, the BLR width, the BLR H-alpha flux and the derived BH mass, and
    saves a four-panel figure of these quantities against mdeg next to the
    spectrum file (``*_mdeg.png``).

    The fitting steps themselves (fit_abs_mdeg, fit_lines) are disabled
    below, so the result files are expected to exist already.
    """
    # mjd plate fiberid
    mpf = [53818, 2013, 9]

    # path to original SDSS files
    # /sdss7/spectro/1d_26/0342/1d/spSpec-51691-0342-430.fit 0342

    work_dir = "tmp_fit/"
    sdss_file = work_dir + "spSpec-{:05d}-{:04d}-{:03d}.fit".format(mpf[0], mpf[1], mpf[2])

    # Absorption fitting for different mdegree.
    # write_to_tmp presumably stores its argument in a temporary file and
    # returns that file's name -- verify against common_utils.
    tmpname = write_to_tmp(sdss_file)
    # fit_abs_mdeg(work_dir, tmpname, mpf)
    os.unlink(tmpname)

    # Emission line fitting.
    # glob returns files in arbitrary order; sorting the zero-padded names
    # keeps mdeg increasing so the plotted lines do not zigzag.
    result_files = sorted(glob.glob(work_dir + 'spSpec?????-????-???_results_??.fits'))
    tmpname = write_to_tmp("\n".join(result_files))
    # fit_lines(tmpname)
    os.unlink(tmpname)

    # Collect data and plot figures
    nfiles = len(result_files)
    mdeg = np.zeros(nfiles)
    chi2_abs = np.zeros(nfiles)
    chi2_nonp = np.zeros(nfiles)
    mbh = np.zeros(nfiles)
    sig = np.zeros(nfiles)
    flx = np.zeros(nfiles)

    for i, fname in enumerate(result_files):
        # mdeg is the token after the last underscore of the file stem; taking
        # it from the basename keeps underscores in the directory harmless.
        stem = os.path.splitext(os.path.basename(fname))[0]
        mdeg[i] = float(stem.rsplit('_', 1)[1])

        # The context manager closes the FITS file after each iteration.
        with fitsio.FITS(fname) as hdus:
            absorption = hdus['SPECTRUM']
            gaus = hdus['EMISSPEC']
            nonp = hdus['EMISSPEC_NONP']
            dec = hdus['EMIS_PROF_DECOMP']

            blr_flux = dec['PARS_BLR_FLUX_HALPHA'][0][0]
            blr_sig = dec['PARS_BLR_SIG'][0][0]

            chi2_abs[i] = absorption['chi2'][0][0]
            chi2_nonp[i] = nonp['SPEC_CHI2DOF_SM'][0][0]
            sig[i] = blr_sig
            flx[i] = blr_flux
            # NOTE(review): flux_sig_to_mbh is not among this module's visible
            # imports -- confirm where it is defined.
            # 299792.45 is presumably the speed of light in km/s.
            mbh[i] = flux_sig_to_mbh(blr_flux, blr_sig, gaus[0]['LOSVD_ALLOWED'][0][0]/299792.45).value

    # plot figures
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(4, 8))

    ax1.plot(mdeg, chi2_abs, marker='o', mfc='magenta')
    ax1.plot(mdeg, chi2_nonp, marker='o', mfc='green')
    ax2.plot(mdeg, sig, marker='o', mfc='magenta')
    ax3.plot(mdeg, flx, marker='o', mfc='magenta')
    ax4.plot(mdeg, mbh, marker='o', mfc='magenta')

    ax1.set_ylabel(r'$\chi^2$ abs, nonp')
    ax2.set_ylabel(r'$\sigma_{BLR}$')
    ax3.set_ylabel(r'$F(H\alpha)$')
    ax4.set_ylabel(r'$M_{BH}$')
    ax4.set_xlabel('mdegree')
    ax1.set_title(os.path.basename(sdss_file))

    fig.savefig(sdss_file.replace('.fit','_mdeg.png'), bbox_inches='tight')
    plt.close(fig)


def fit_nonp_synth(args):
    """
    MC simulation to estimate the BH detection limit for given object.
    This worker runs the non-parametric fitting of the synthetic mock spectra
    through the IDL procedure ``emis_line_fitting_nonpar_stack``.

    @param: args: sequence (worker index starting at 0, number of workers,
        task name, input FITS file name)
    @return: filename of fits file where output table is stored
    """
    log_head(args)

    # The worker index is shifted by one before it is handed to IDL.
    pid = args[0] + 1
    proc_n = args[1]
    task_name = args[2]
    infile = args[3]

    outfile = "{}/tmp_{}_{:02g}{:02g}.fits".format(os.path.dirname(infile), task_name, pid, proc_n)
    scoop.logger.info("Output file {}".format(outfile))

    # The synthetic spectra live next to the input file as *_synth.fits.
    synth_file = infile.replace('.fits','_synth.fits')

    idl = pidly.IDL()
    try:
        setup_idl_environment(idl)
        # run nonp fit for synth spectra
        idl.pro('emis_line_fitting_nonpar_stack', synth_file, outfile, cpu=[pid, proc_n], verbose=0)
    finally:
        # Shut the IDL session down even when the fit fails.
        idl.close()

    return outfile


def fit_decomp_synth(args):
    """
    MC simulation to estimate the BH detection limit for given object.
    This worker runs the IDL procedure ``mc_decomp_spec`` on the synthetic
    mock spectra to decompose the recovered profiles.

    @param: args: sequence (worker index starting at 0, number of workers,
        task name, input FITS file name); the task name is not used here
    @return: filename of fits file where output table is stored
    """
    log_head(args)

    # The worker index is shifted by one before it is handed to IDL.
    pid = args[0] + 1
    proc_n = args[1]
    infile = args[3]

    # The synthetic spectra live next to the input file as *_synth.fits.
    synth_file = infile.replace('.fits','_synth.fits')

    # The name of the file where the resulting table is stored. It is not
    # passed to IDL, so it has to equal the name built inside the IDL routine.
    # NOTE(review): confirm the naming scheme against mc_decomp_spec.
    outfile = synth_file.replace('.fits','_tmp_{:02g}{:02g}.fits'.format(pid, proc_n))
    scoop.logger.info("Output file {}".format(outfile))

    idl = pidly.IDL()
    try:
        setup_idl_environment(idl)
        # run decomposition for synth spectra
        idl.pro('mc_decomp_spec', synth_file, cpu=[pid, proc_n],
            verbose=0, regularization=True, lsf_sig_corr=True)
    finally:
        # Shut the IDL session down even when the decomposition fails.
        idl.close()

    return outfile


def make_synth_sample(infile, nsim=1000):
    """
    Make the synthetic mock data sample by executing the IDL procedure
    ``mc_make_synth_sample``.

    @param: infile: file name passed to the IDL procedure
    @param: nsim: number of synthetic spectra
    """
    idl = pidly.IDL()
    try:
        setup_idl_environment(idl)
        # make file with synthetic spectra
        idl.pro('mc_make_synth_sample', infile, nsim=nsim)
    finally:
        # close() has to be called; a bare attribute access leaves the IDL
        # session running.
        idl.close()


def test_emisnonp():
    """
    Temporary function for testing changes in the emis_nonpar script.

    Runs ``emis_line_fitting_nonpar`` on ``tests/test_fits/tlist`` and prints
    LINE_FLUX, SPEC_CHI2DOF and SPEC_CHI2DOF_SM of the first row of the
    EMISSPEC_NONP extension, first for the reference file and then for the
    current file, so the two can be compared by eye.
    """
    columns = ['LINE_FLUX', 'SPEC_CHI2DOF', 'SPEC_CHI2DOF_SM']
    # (banner, file) pairs in the order they are reported
    reports = [
        ("============= Reference =============",
         "tests/test_fits/spSpec52826-1037-499_test_reference.fits"),
        ("============= Current file =============",
         "tests/test_fits/spSpec52826-1037-499_test_orig.fits"),
    ]

    idl = pidly.IDL()
    try:
        setup_idl_environment(idl)
        idl("emis_line_fitting_nonpar,'{}',verbose=1".format("tests/test_fits/tlist"))
    finally:
        # Shut the IDL session down even when the fit fails.
        idl.close()

    for banner, fname in reports:
        print(banner)
        nonp = fitsio.read(fname, 'EMISSPEC_NONP')
        for col in columns:
            # Single-argument print() behaves the same on Python 2 and 3.
            print(nonp[0][col])


def test_decomp():
    """
    Temporary function for testing changes in the decomp script.

    Runs the IDL procedure ``emis_line_fitting_decomp_stack`` on the list file
    ``tmp_fit/tlist`` with ``cpu=[3,4]`` and ``verbose=1``.
    """
    # NOTE(review): the list presumably names
    # tmp_fit/spSpec53818-2013-009_results_25_for_test_decomp.fits -- confirm.
    lst = "tmp_fit/tlist"

    idl = pidly.IDL()
    try:
        setup_idl_environment(idl)
        idl.pro('emis_line_fitting_decomp_stack', lst, cpu=[3,4], verbose=1)
    finally:
        # Shut the IDL session down even when the decomposition fails.
        idl.close()


def concatenate_synth_fits(file_list, type):
    """
    Function to concatenate fits files and write it to the new extension of the main file.
    @param: file_list: is a list of file names needed to join
    @param: type: nonpar or decomp result tables should be processed
    @return: numpy structured array
    """

    config = {
        'nonpar': {
            'ext': 'EMIS_SYNTH_SPECTRA_FITTED',
            'ref': 'SPEC_CHI2DOF',
            'names': ['LINE_FLUX', 'LINE_FLUX_SM', 'LINE_FLUX_ERR',
                'LOSVD_FORBIDDEN', 'LOSVD_ALLOWED', 'LOSVD_FORBIDDEN_SM', 
                'LOSVD_ALLOWED_SM', 'SPEC_FIT', 'SPEC_FIT_SM', 'SPEC_CHI2DOF', 
                'SPEC_CHI2DOF_SM']
            },
        'decomp': {
            'ext': 'STACK_PROF_DECOMP',
            'ref': 'DECOMP_CHI2',
            'names': ['SPEC_BLR_HALPHA', 'LOSVD_ALLOWED', 'LOSVD_NLR', 'LOSVD_BLR',
                'LOSVD_NLR_NOBLR', 'LOSVD_FIT', 'LOSVD_FIT_NOBLR', 'ID_SIM', 'MAX_ALLOWED',
                'ERROR_LEVEL', 'NLR_INT', 'NLR_NOBLR_INT', 'NLR_FRAC', 'NLR_STDDEV',
                'NLR_MAX', 'BLR_INT', 'BLR_POS', 'BLR_SIG', 'BLR_MAX', 'BLR_FRAC',
                'BLR_FLUX_HALPHA', 'FLUX_HALPHA', 'RMS', 'RMS_BS', 'DECOMP_CHI2',
                'DECOMP_CHI2DOF', 'DECOMP_CHI2_NOBLR', 'DECOMP_CHI2DOF_NOBLR', 
                'NBIN', 'F_VALUE', 'PROB_F', 'NLR_INT_ERR', 'NLR_NOBLR_INT_ERR', 'BLR_INT_ERR', 
                'BLR_POS_ERR', 'BLR_SIG_ERR', 'BLR_FLUX_HALPHA_ERR', 'MBH', 
                'MBH_ERR', 'MBH_INPUT', 'BLR_SIG_INPUT', 'BLR_FLUX_HALPHA_INPUT', 'FULL_FLUX_HALPHA_INPUT', 
                'BLR_POS_INPUT', 'BLR_FRAC_INPUT']
            },
        'decomp_spec': {
            'ext': 'EMIS_SYNTH_SPECTRA_DECOMP',
            'ref': 'PARS_CHI2',
            'names': ['ID_SIM', 'SPEC_EMIS', 
                'SPEC_COMPS_NLR', 'SPEC_COMPS_BLR', 'SPEC_FIT', 'SPEC_FIT_NOBLR', 
                'LOSVD_NLR', 'LOSVD_NOBLR_NLR', 'PARS_COMP_FLUX_NLR', 
                'PARS_COMP_FLUX_BLR', 'PARS_COMP_FLUX_NLR_ERR', 'PARS_COMP_FLUX_BLR_ERR', 
                'PARS_NLR_FLUX_HBETA', 'PARS_NLR_FLUX_HALPHA', 'PARS_BLR_FLUX_HBETA', 
                'PARS_BLR_FLUX_HALPHA', 'PARS_NLR_FLUX_HBETA_ERR', 'PARS_NLR_FLUX_HALPHA_ERR',
                'PARS_BLR_FLUX_HBETA_ERR', 'PARS_BLR_FLUX_HALPHA_ERR',
                'PARS_MAX_HA', 'PARS_ERROR_LEVEL', 'PARS_NLR_MAX', 'PARS_NLR_STDDEV', 
                'PARS_BLR_POS', 'PARS_BLR_SIG', 'PARS_BLR_MAX', 
                'PARS_BLR_FRAC_HA', 'PARS_BLR_FLUX_HA', 'PARS_FLUX_HA', 'PARS_MBH', 
                'PARS_MBH_ERR', 'PARS_RMS', 'PARS_CHI2', 'PARS_CHI2DOF', 'PARS_F_VALUE', 
                'PARS_PROB_F', 'PARS_NOBLR_CHI2', 'PARS_NOBLR_CHI2DOF', 'PARS_BLR_POS_ERR', 
                'PARS_BLR_SIG_ERR', 'MBH_INPUT', 
                'BLR_SIG_INPUT', 'BLR_FLUX_HALPHA_INPUT', 'FULL_FLUX_HALPHA_INPUT', 
                'BLR_POS_INPUT', 'BLR_DECR_INPUT']
            }
        }

    try:
        types = config[type]
    except KeyError:
        raise

    file_list = sorted(file_list)

    # read files
    tbls = [fitsio.read(f, ext=types['ext']) for f in file_list]

    # calculate total number of simulations
    nsim = 0
    for t in tbls:
        nsim += len(t[types['ref']].flatten())

    # create new extended table
    dtypes_orig = tbls[0].dtype
    nsim_in_one_table = len(tbls[0][types['ref']].flatten())

    # temporarly code lines
    # for desc in dtypes_orig.descr:
    #     if(desc[0] in types['names']): # for fields which should be extend
    #         print desc

    dtypes = []
    for desc in dtypes_orig.descr:
        if(desc[0] in types['names']): # for fields which should be extend
            if(len(desc[2]) == 3): # for 3d column like SPEC_COMPS_NLR
                dtypes.append( (desc[0], desc[1], (nsim, desc[2][1], desc[2][2]) ) )
            if(len(desc[2]) == 2): # for 2d column like LOSVD_BLR
                dtypes.append( (desc[0], desc[1], (nsim, desc[2][1]) ) )
            if(len(desc[2]) == 1): # for 1d column like MBH
                dtypes.append( (desc[0], desc[1], (nsim,) ))
        else:
            dtypes.append( desc ) # for other unchanged columns

    data = np.zeros(1, dtype=dtypes)

    names = data.dtype.names
    for name in names:
        if name in types['names']: # for fields which should be extend
            data_tbls = []
            for t in tbls: # iterate over all tables where columns can be in differens shape
                nrows = np.size(t[types['ref']])
                if nrows == 1: # for tables with only 1 row
                    if(np.size(t[0][name]) == 1): # 1 element field
                        data_tbls.append(t[0][name].reshape(1))
                    else:
                        data_tbls.append(t[0][name].reshape((1,np.size(t[0][name])))) # fields of elements array
                else:
                    data_tbls.append(t[0][name])
            # import ipdb; ipdb.set_trace()
            data[name] = np.concatenate((data_tbls))
        else:
            data[name] = tbls[0][name]

    return data


if __name__ == '__main__':
    # Be sure that all environment settings as well as paths and code are the
    # same on all host servers.
    #
    # Command line arguments: task, number of workers, input FITS file,
    # number of simulations.
    #
    # Usage examples:
    # (for Gal-04 server)
    #     python -m scoop -n 24 object_qc.py make_synth 24  /data1/pro/SDSS/fit_agn/2013/spSpec53818-2013-009_results.fits 3000
    # (for laptop)
    #     python -m scoop -n 4 object_qc.py make_synth 4 ~/sci/work/catalog/uv-to-nir-catalog/emission-line-fitting/tmp_fit/spSpec53818-2013-009_results.fits 100
    #
    # First line in scoop's hostfile must point to localhost. The localhost
    # must have at least one worker.
    # Debug principle: 1) find the buggy host by editing hostfile;
    # 2) copy-paste worker function.
    scoop_time = time.time()

    try:
        task = sys.argv[1]
        proc_n = int(sys.argv[2])
        infile = sys.argv[3]
        # converted to a number like proc_n, so IDL does not receive a string
        nsim = int(sys.argv[4])
    except IndexError:
        print(utils.getEnv())
        scoop.logger.error('Launch requires 4 arguments, got {:d}'.format(len(sys.argv) - 1))
        sys.exit(1)
    except ValueError:
        scoop.logger.error('Number of workers and number of simulations must be integers, got: {}'.format(' '.join(sys.argv[1:])))
        sys.exit(1)

    # create MC dir if it does not exist
    if not os.path.exists(settings.MC_TMP_DIR):
        os.makedirs(settings.MC_TMP_DIR)
        scoop.logger.info('Created temporary directory: {}'.format(settings.MC_TMP_DIR))

    # resolve input spectrum file to abspath
    infile_abs = get_spectrum_abspath(None, None, None, filename=infile)
    if not infile_abs:
        scoop.logger.error("Input file {} can't be resolved to abspath".format(infile))
        sys.exit(1)
    # copy file to MC dir, don't overwrite if already exists
    infile_copy = os.path.join(settings.MC_TMP_DIR, os.path.basename(infile_abs))
    if not os.path.exists(infile_copy):
        shutil.copy(infile_abs, settings.MC_TMP_DIR)
        scoop.logger.info('Copied {} to {}'.format(infile_abs, infile_copy))
    synthfile = infile_copy.replace('.fits','_synth.fits')

    proc_ids = range(proc_n)

    # args is like [(0, 4, 'task', infile_copy), (1, 4, 'task', infile_copy), ...]
    args = zip( proc_ids, [proc_n] * proc_n, [task] * proc_n, [infile_copy] * proc_n )

    if task == 'make_synth':
        # generate synthetic dataset
        make_synth_sample(infile_copy, nsim=nsim)

        # decompose  generated dataset
        chunk_fits_files_decomp = list(scoop.futures.map_as_completed(fit_decomp_synth, args))

        # concatenate temporary tables with worker results into the single one
        data = concatenate_synth_fits(chunk_fits_files_decomp, 'decomp_spec')
        # the context manager closes the file even if writing the table fails
        with fitsio.FITS(synthfile, 'rw') as fits:
            fits.write_table(data, extname='EMIS_SYNTH_SPECTRA_DECOMP')

        # extract table for analysing
        extract_synth_table(synthfile, synthfile.replace('_synth.fits', '_synth_restable.fits'))

    else:
        scoop.logger.error("Wrong task {}".format(task))
        sys.exit(1)

    # common cleanup and messaging code
    # clean tmp tables
    for f in chunk_fits_files_decomp:
        os.unlink(f)
        scoop.logger.info('Cleaning up file {}'.format(f))

    scoop.logger.info('Created or updated {}'.format(synthfile))
    scoop.logger.info('All done in {:.3f} seconds'.format(time.time() - scoop_time))
