import sys
import os
import datetime
import time
import pidly
import numpy as np
import fitsio
import re
import glob
import scoop
from scoop import utils
from scipy.stats import f as fprob
import operator

# local project imports
import settings
from common_pidly import setup_idl_environment
from common_utils import make_file_list, log_head
from flux_sig_to_mbh import flux_sig_to_mbh


# Speed of light, used below to turn velocities into dimensionless v/c factors
# (presumably km/s, matching the velocity columns read from the FITS tables).
# NOTE(review): the defined value is 299792.458 km/s; confirm whether the
# truncation to 299792.45 is intentional (e.g. to match the IDL side).
c = 299792.45


def fit_lines_gaus(args):
    """
    Worker function for gaussian line fitting.

    Runs the IDL procedure ``emis_line_fitting_gaus`` in a fresh IDL session.

    @param args: sequence where args[0] is the zero-based worker index,
                 args[1] is the total number of workers and args[2] is the
                 string passed as the first argument of the IDL procedure.

    @return:     None
    """
    scoop.logger.info( "Gaussian line fitting" )
    scoop.logger.info( args )
    idl_command = "emis_line_fitting_gaus, '{}', cpu=[{},{}], /logging".format(args[2], args[0]+1, args[1])
    idl = pidly.IDL()
    try:
        setup_idl_environment( idl )
        scoop.logger.info( "Computing command:" )
        scoop.logger.info( idl_command )
        idl( idl_command )
    finally:
        # always release the IDL session, even if the setup or the command fails
        idl.close()
    scoop.logger.info("Finished")


def fit_lines_nonpar(args):
    """
    Worker function for non-parametric line fitting.

    Runs the IDL procedure ``emis_line_fitting_nonpar`` in a fresh IDL session.

    @param args: sequence where args[0] is the zero-based worker index,
                 args[1] is the total number of workers and args[2] is the
                 string passed as the first argument of the IDL procedure.

    @return:     None
    """
    scoop.logger.info( "Non-parametric line fitting" )
    scoop.logger.info( args )
    idl_command = "emis_line_fitting_nonpar, '{}', cpu=[{},{}], /logging".format(args[2], args[0]+1, args[1])
    idl = pidly.IDL()
    try:
        setup_idl_environment( idl )
        scoop.logger.info( "Computing command:" )
        scoop.logger.info( idl_command )
        idl( idl_command )
    finally:
        # always release the IDL session, even if the setup or the command fails
        idl.close()
    scoop.logger.info("Finished")


def mc_nonp_profiles_fluxes(args):
    """
    Worker function for MC simulation of non-parametric line fitting results.

    Runs the IDL procedure ``calc_errors_nonp_profile`` in a fresh IDL session.

    @param args: sequence where args[0] is the zero-based worker index,
                 args[1] is the total number of workers and args[2] is the
                 string passed as the first argument of the IDL procedure.
                 The whole sequence is also handed to makelog().

    @return:     None
    """
    flog = makelog(args)
    scoop.logger.info( "MC simulations for non-parametric line fitting" )
    scoop.logger.info( args )
    scoop.logger.info( "Log file: " + flog )
    # The keyword is "/logging" as in the other workers; the previous
    # "/logging/" had a stray trailing slash that is not valid IDL syntax.
    idl_command = "calc_errors_nonp_profile, '{}', cpu=[{},{}], /logging".format(args[2], args[0]+1, args[1])
    idl = pidly.IDL()
    try:
        setup_idl_environment( idl )
        scoop.logger.info( "Computing command:" )
        scoop.logger.info( idl_command )
        idl( idl_command )
    finally:
        # always release the IDL session, even if the setup or the command fails
        idl.close()
    scoop.logger.info( "Finished: ID {} from {} nIDs".format(args[0]+1, args[1]) )


def profile_decompose(args):
    """
    Worker function for profile decomposition.

    Runs the IDL procedure ``emis_line_fitting_decomp_spec`` in a fresh IDL
    session and copies the text returned by IDL into the log, skipping the
    "% Compiled module:" lines.

    @param args: sequence where args[0] is the zero-based worker index,
                 args[1] is the total number of workers and args[2] is the
                 string passed as the first argument of the IDL procedure.
                 The whole sequence is also handed to makelog().

    @return:     None
    """
    flog = makelog(args)
    scoop.logging.basicConfig(filename=flog, format="[%(asctime)-15s] %(module)-9s %(levelname)-7s %(message)s")
    scoop.logger.info( "Profile decomposition" )
    scoop.logger.info( args )
    scoop.logger.info( "Log file: " + flog )
    idl_command = "emis_line_fitting_decomp_spec, '{}', cpu=[{},{}], /regularization, /lsf, /logging, logfile='{}'".format(args[2], args[0]+1, args[1], flog)
    idl = pidly.IDL()
    try:
        setup_idl_environment( idl )
        scoop.logger.info( "Computing command:" )
        scoop.logger.info( idl_command )
        # start IDL and capture its textual output
        idl_feedback = idl( idl_command, ret=True )
    finally:
        # always release the IDL session, even if the setup or the command fails
        idl.close()

    # write to log IDL feedback
    scoop.logger.info("================== IDL FEEDBACK ==================")
    for line in idl_feedback.split("\n"):
        # compilation notices are noise, everything else is kept
        if '% Compiled module:' not in line:
            scoop.logger.info(line)
    scoop.logger.info("====================== END =======================")

    scoop.logger.info( "Finished: ID {} from {} nIDs".format(args[0]+1, args[1]) )

def find_bad_files(args):
    """
    Routine to find files which have an Image instead of a Bintable in the
    first FITS extension.

    For every file of this worker's chunk the EXTNAME keyword of extension 1
    is read; files whose header can not be read are reported as "BAD_FILE".
    The "<file> <extname>" pairs are written to a text file in the current
    working directory.

    @param args: sequence where args[0] is the zero-based worker index,
                 args[1] is the total number of workers and args[2] is the
                 path of a text file listing one FITS file per line.

    @return:     name of the written text file
    """
    list_name = args[2]
    ID = args[0] + 1
    nIDs = args[1]

    with open(list_name) as f:
        files = f.read().splitlines()

    # select files for given chunk
    first, last = utils_chunk(len(files), ID, nIDs)
    files = files[first:last + 1]
    nfiles = len(files)

    out = []
    for count, fits_name in enumerate(files):
        # Read header of 1st extension
        try:
            extname = fitsio.read_header(fits_name, ext=1)['EXTNAME']
        except (IOError, ValueError) as e:
            scoop.logger.error(fits_name)
            scoop.logger.error(e)
            extname = "BAD_FILE"
        else:
            # progress report, only for files that were read successfully
            if count % 500 == 0:
                scoop.logger.info("{} {:d} from {:d}".format(os.path.basename(fits_name), count, nfiles))
        out.append([fits_name, extname])

    log_name = 'tmp_findbadfile_{:02d}{:02d}.log'.format(ID, nIDs)
    scoop.logger.info( "File with 1st extname: {}".format(log_name) )

    # context manager guarantees the file is flushed and closed on errors too
    with open(log_name, 'w') as thefile:
        for fits_name, extname in out:
            thefile.write( "{} {}\n".format(fits_name, extname) )

    scoop.logger.info( "Finished: ID {} from {} nIDs".format(ID, nIDs) )
    return log_name

def collect_data(args):
    """
    Worker function assembling the result tables for one chunk of files.

    @param args: sequence where args[0] is the zero-based worker index,
                 args[1] is the total number of workers and args[2] is the
                 path of the file list handed to collect_tables().

    @return:     whatever collect_tables() returns for this chunk
    """
    log_head(args)

    worker_id = args[0] + 1
    n_workers = args[1]
    tables = collect_tables(args[2], worker_id, n_workers, nprint=100, write_fits=False)

    scoop.logger.info( "Finished: ID {} from {} nIDs".format(worker_id, n_workers) )
    return tables

def collect_data_decomp(args):
    """
    Worker function assembling the decomposition table for one chunk of files.

    @param args: sequence where args[0] is the zero-based worker index,
                 args[1] is the total number of workers and args[2] is the
                 path of the file list handed to collect_decomp_table().

    @return:     whatever collect_decomp_table() returns for this chunk
    """
    worker_id = args[0] + 1
    n_workers = args[1]
    table = collect_decomp_table(args[2], worker_id, n_workers, nprint=100, write_fits=False)
    scoop.logger.info( "Finished: ID {} from {} nIDs".format(worker_id, n_workers) )
    return table

# File list built once at import time (so it is evaluated on every process
# that imports this module).
# NOTE(review): no function visible in this part of the file reads
# `list_files`; presumably it feeds a map over collect_data_fake -- confirm.
list_files = make_file_list(mask='fit_agn/0???', return_array=True)
def collect_data_fake(file):
    """
    Fake worker function to process single file.

    Opens the FITS file and reads its 'EMISSPEC' extension, discarding the
    contents; failures are logged and otherwise ignored.

    @param file: path of the FITS file to open

    @return:     np.ones(2), regardless of whether the file could be read
    """
    # keep the handle defined so that cleanup works when the open itself fails
    fits = None
    try:
        fits = fitsio.FITS(file)
        # the table is read only to exercise the I/O; its contents are unused
        fits['EMISSPEC'].read()
    except Exception as e:
        scoop.logger.error('heh')
        scoop.logger.error(file)
        scoop.logger.error(e)
    finally:
        if fits is not None:
            fits.close()

    return np.ones(2)
    

def utils_chunk(full_length, ID, nIDs):
    """
    Compute index limits of one chunk when a list is split between workers.

    The list is cut into nIDs consecutive chunks of ceil(full_length / nIDs)
    elements; the last non-empty chunk may be shorter.

    @param full_length: number of elements in the full list
    @param ID:          1-based chunk number; values outside 1..nIDs are
                        clamped into that range
    @param nIDs:        total number of chunks

    @return:            [first, last] inclusive indices, to be used as
                        items[first:last+1]. When there are more chunks than
                        needed, the surplus chunks get the empty range
                        [full_length, full_length - 1], so that no element is
                        handed to two workers.
    """
    ID = min(max(ID, 1), nIDs)

    # builtin int/float: the np.int / np.float aliases were removed from NumPy
    chunk_size = int(np.ceil(float(full_length) / nIDs))

    first = chunk_size * (ID - 1)
    last = min(chunk_size * ID, full_length) - 1
    if first > last:
        # nothing left for this chunk; previously the limits were clipped to the
        # last element, which made surplus workers re-process the final item
        return [int(full_length), int(full_length) - 1]

    return [int(first), int(last)]


def collect_nonp_mc(list, outfile, nprint=None):
    """
    Collect non-parametric emission line fitting parameters from a set of
    files. _MC fields are taken into account.

    The column prefixes ("F<rounded wavelength>_<name>") are derived from the
    LINE_NAME column of the 'EMISSPEC_NONP' extension of the first file.
    Files that can not be read (IOError / ValueError) are reported on stdout
    and their rows keep the NaN filling.

    @param list:    path of a text file listing one FITS file per line
    @param outfile: name of the FITS file to write (overwritten if present)
    @param nprint:  if set, print progress for every nprint-th file

    @return:        (data, suffixes) where data is the structured array
                    written to outfile and suffixes is a list of
                    [column suffix, FITS column name, [output column names]]
    """

    with open(list) as f:
        files = f.read().splitlines()

    nfiles = len(files)

    # make prefixes
    nonp = fitsio.read( files[0], ext='EMISSPEC_NONP')
    prefixes = []
    for ln in nonp['LINE_NAME'][0]:
        wn = re.findall(r'\d+\.\d+', ln)[0]
        # everything before the wavelength is the line name; plain substring
        # search, so that the '.' of the wavelength is not a regex wildcard
        prefixes.append( 'F{}_{}'.format( int(round(float(wn))), ln[:ln.find(wn)] ) )

    # pairs of output column suffix and name of the source FITS column
    suffixes = ([['_FLX','LINE_FLUX'],
                 ['_FLX_ERR','LINE_FLUX_ERR'],
                 ['_FLX_ERR_MC','mc_err_flux'],
                 ['_CNT','LINE_CONT'],
                 ['_CNT_ERR','LINE_CONT_ERR'],
                 ['_EW','LINE_EW'],
                 ['_EW_ERR','LINE_EW_ERR']] )

    # make output structured array
    dtypes = ( [('mjd',np.int32),
                ('plate',np.int32),
                ('fiberid',np.int16),
                ('forbid_v',np.float32),
                ('forbid_sig',np.float32),
                ('allowed_v',np.float32),
                ('allowed_sig',np.float32),
                ('losvd_vsys',np.float32),
                ('chi2dof_emis',np.float32)]
            )

    for p in prefixes:
        for suf in suffixes:
            dtypes.append( ( p + suf[0], np.float32) )

    data = np.empty( nfiles, dtype = dtypes )
    data[:] = np.nan

    # attach to every suffix the list of output columns carrying it
    for suf in suffixes:
        suf.append( [ s for s in data.dtype.names if s.endswith(suf[0]) ] )

    for count, file_name in enumerate(files):
        fname = os.path.basename(file_name)

        # identifiers are taken from fixed positions of the file name
        data['mjd'][count] = fname[6:11]
        data['plate'][count] = fname[12:16]
        data['fiberid'][count] = fname[17:20]

        if nprint and count % nprint == 0:
            # single formatted argument prints the same on Python 2 and 3
            print("{} {}".format(count, fname))

        try:
            nonp = fitsio.read( file_name, ext='EMISSPEC_NONP')
            for kind, col in (('forbid', 'FORBIDDEN'), ('allowed', 'ALLOWED')):
                losvd = nonp['LOSVD_' + col]
                vbin = nonp['LOSVD_VBIN_' + col]
                # first moment and square root of the second central moment
                data[kind + '_v'][count] = np.sum(losvd*vbin)
                data[kind + '_sig'][count] = np.sqrt( np.sum(losvd*vbin**2) - np.sum(losvd*vbin)**2 )
            data['losvd_vsys'][count] = nonp['LOSVD_VSYS']
            data['chi2dof_emis'][count] = nonp['SPEC_CHI2DOF']

            for suf in suffixes:
                for icol, col in enumerate(suf[2]):
                    data[col][count] = nonp[suf[1]][0][icol]

        except (IOError, ValueError) as e:
            print(file_name)
            print(e)

    fitsio.write(outfile, data, extname='restable_nonpar',clobber=True)
    return data, suffixes


def collect_tables(list, ID, nIDs, nprint=None, gaus=False, nonpar=False, decomp=False, all=True, write_fits=True):
    """
    Collect emission line fitting parameters from one chunk of a file list.

    Example of usage:
    tables = collect_tables('../agn_bh/fit_agn/flist_checks.txt', 1, 1, nprint=100)

    Files that can not be read (IOError / ValueError) are logged and their
    rows keep the NaN filling of the float columns.

    @param list:       path of a text file listing one FITS file per line
    @param ID:         1-based number of the chunk to process
    @param nIDs:       total number of chunks
    @param nprint:     if set, log progress for every nprint-th file
    @param gaus:       collect the gaussian fit table ('EMISSPEC' extension)
    @param nonpar:     collect the non-parametric tables, ordinary and '_SM'
                       ('EMISSPEC_NONP' extension)
    @param decomp:     collect the decomposition table ('EMIS_PROF_DECOMP'
                       extension)
    @param all:        if True, collect every table regardless of the flags above
    @param write_fits: if True, write out resulting FITS next to the file list
                       (the decomposition table is not written)

    @return:           Returns list of [np.array(), table name] pairs filled with
                       collected information for the subsequent assembly of global
                       arrays from individual chunks returned by workers
    """
    do_gaus = gaus or all
    do_nonpar = nonpar or all
    do_decomp = decomp or all

    with open(list) as f:
        files = f.read().splitlines()

    # select files for given chunk
    list_limits = utils_chunk(len(files), ID, nIDs)
    files = files[list_limits[0]:list_limits[1]+1]
    nfiles = len(files)

    # make output structured array for emission line fitting results
    if do_gaus or do_nonpar:
        line_names = [
                'F3727_OII', 'F3730_OII', 'F3751_H_kappa', 'F3772_H_iota',
                'F3799_H_theta', 'F3836_H_eta', 'F3870_NeIII', 'F3889_HeI',
                'F3890_H_dzita', 'F3971_H_epsilon', 'F4070_SII', 'F4078_SII',
                'F4103_H_delta', 'F4342_H_gamma', 'F4364_OIII', 'F4687_HeII',
                'F4713_ArIV', 'F4742_ArIV', 'F4863_H_beta', 'F4960_OIII', 'F5008_OIII',
                'F5199_NI', 'F5202_NI', 'F5756_NII', 'F5877_HeI', 'F6302_OI',
                'F6366_OI', 'F6550_NII', 'F6565_H_alpha', 'F6585_NII', 'F6680_HeI',
                'F6718_SII', 'F6733_SII'
                ]

        line_dtypes = ([('mjd', np.int32),
                   ('plate', np.int32),
                   ('fiberid', np.int16),
                   ('forbid_v', np.float32),
                   ('forbid_v_err', np.float32),
                   ('forbid_sig', np.float32),
                   ('forbid_sig_err', np.float32),
                   ('allowed_v', np.float32),
                   ('allowed_v_err', np.float32),
                   ('allowed_sig', np.float32),
                   ('allowed_sig_err', np.float32),
                   ('chi2dof_emis', np.float32)]
                  )

        for p in line_names:
            for suf in ['_FLX', '_FLX_ERR', '_CNT', '_CNT_ERR', '_EW', '_EW_ERR']:
                line_dtypes.append((p + suf, np.float32))

        # create a table filled by NANs
        line_data = np.empty(nfiles, dtype=line_dtypes)
        line_data[:] = np.nan

        suffixes = ([['_FLX', 'LINE_FLUX'],
                     ['_FLX_ERR', 'LINE_FLUX_ERR'],
                     ['_CNT', 'LINE_CONT'],
                     ['_CNT_ERR', 'LINE_CONT_ERR'],
                     ['_EW', 'LINE_EW'],
                     ['_EW_ERR', 'LINE_EW_ERR']])

        for suf in suffixes:
            suf.append([s for s in line_data.dtype.names if s.endswith(suf[0])])

        # after this manipulations suffix list looks like:
        # [['_FLX', 'LINE_FLUX', ['F3727_OII_FLX', 'F3730_OII_FLX', ..., 'F6718_SII_FLX', 'F6733_SII_FLX']],
        #  ['_FLX_ERR', 'LINE_FLUX_ERR', ['F3727_OII_FLX_ERR', ...]],
        #  ['_CNT', 'LINE_CONT', ['F3727_OII_CNT', 'F3730_OII_CNT', ...]],
        #  ['_CNT_ERR', 'LINE_CONT_ERR', ['F3727_OII_CNT_ERR', ...]],
        #  ['_EW', 'LINE_EW', ['F3727_OII_EW', 'F3730_OII_EW', ...]],
        #  ['_EW_ERR', 'LINE_EW_ERR', ['F3727_OII_EW_ERR', ...]]]

        if do_gaus:
            data_gaus = np.copy(line_data)
        if do_nonpar:
            # this construction in order to avoid forbid_v_err, forbid_sig_err,
            # allowed_v_err, allowed_sig_err fields which is inapplicable for nonpar table
            idx_no_err_vectors = [x for x in line_data.dtype.names if '_err' not in x]
            data_nonpar = np.copy(line_data[idx_no_err_vectors])
            data_nonpar_sm = np.copy(line_data[idx_no_err_vectors])
        # the template exists only in this branch; deleting it unconditionally
        # raised NameError when only the decomposition table was requested
        del line_data

    # make output structured array for decomposition results
    if do_decomp:
        decomp_dtypes = ([
                ('mjd', np.int32),
                ('plate', np.int32),
                ('fiberid', np.int16),
                ('MAX_ALLOWED', np.float32),
                ('ERROR_LEVEL', np.float32),
                ('NLR_INT', np.float32),
                ('NLR_NOBLR_INT', np.float32),
                ('NLR_FRAC', np.float32),
                ('NLR_STDDEV', np.float32),
                ('NLR_MAX', np.float32),
                ('BLR_INT', np.float32),
                ('BLR_POS', np.float32),
                ('BLR_SIG', np.float32),
                ('BLR_MAX', np.float32),
                ('BLR_FRAC', np.float32),
                ('BLR_FLUX_Halpha', np.float32),
                ('RMS', np.float32),
                ('RMS_BS', np.float32),
                ('CHI2', np.float32),
                ('CHI2DOF', np.float32),
                ('CHI2_NOBLR', np.float32),
                ('CHI2DOF_NOBLR', np.float32),
                ('NBIN', np.float32),
                ('F_VALUE', np.float32),
                ('PROB_F', np.float32),
                ('NLR_INT_ERR', np.float32),
                ('NLR_NOBLR_INT_ERR', np.float32),
                ('BLR_INT_ERR', np.float32),
                ('BLR_POS_ERR', np.float32),
                ('BLR_SIG_ERR', np.float32),
                ('blr_flux_halpha_err', np.float32)
                ])
        # create a table filled by NANs
        data_decomp = np.empty(nfiles, dtype=decomp_dtypes)
        data_decomp[:] = np.nan

    tables = []
    if do_gaus:
        tables.append([data_gaus,'gaus'])
    if do_nonpar:
        tables.append([data_nonpar,'nonpar'])
        tables.append([data_nonpar_sm,'nonpar_sm'])
    if do_decomp:
        tables.append([data_decomp,'decomp'])

    for count, file_name in enumerate(files):
        fname = os.path.basename(file_name)

        # identifiers are taken from fixed positions of the file name
        mjd = fname[6:11]
        plate = fname[12:16]
        fiberid = fname[17:20]

        for t in tables:
            t[0]['mjd'][count] = mjd
            t[0]['plate'][count] = plate
            t[0]['fiberid'][count] = fiberid

        if nprint and count % nprint == 0:
            scoop.logger.info("{} {}".format(count, fname))

        # keep the handle defined so that cleanup works when the open itself fails
        fits = None
        try:
            fits = fitsio.FITS(file_name)

            # filling gaussian emission line table
            if do_gaus:
                emis = fits['EMISSPEC'].read()
                for kind, col in (('forbid', 'LOSVD_FORBIDDEN'), ('allowed', 'LOSVD_ALLOWED')):
                    # element 0 is stored as velocity, element 1 as dispersion
                    data_gaus[kind + '_v'][count] = emis[col][0][0]
                    data_gaus[kind + '_v_err'][count] = emis[col + '_ERR'][0][0]
                    data_gaus[kind + '_sig'][count] = emis[col][0][1]
                    data_gaus[kind + '_sig_err'][count] = emis[col + '_ERR'][0][1]
                data_gaus['chi2dof_emis'][count] = emis['SPEC_CHI2DOF'][0]

                for suf in suffixes:
                    for icol, col in enumerate(suf[2]):
                        data_gaus[col][count] = emis[suf[1]][0][icol]

            # filling nonparametric table of results
            if do_nonpar:
                nonp = fits['EMISSPEC_NONP'].read()

                # filling nonparametric table and smoothed nonparametric
                for t, s in zip([data_nonpar, data_nonpar_sm], ['', '_SM']):
                    for kind, col in (('forbid', 'FORBIDDEN'), ('allowed', 'ALLOWED')):
                        losvd = nonp['LOSVD_' + col + s]
                        vbin = nonp['LOSVD_VBIN_' + col]
                        # first moment and square root of the second central moment
                        mean_v = np.nansum(losvd * vbin)
                        t[kind + '_v'][count] = mean_v
                        t[kind + '_sig'][count] = np.sqrt(np.nansum(losvd * vbin**2) - mean_v**2)
                    t['chi2dof_emis'][count] = nonp['SPEC_CHI2DOF' + s][0]

                    # iteration over all emission lines and various types of parameters like _FLUX, _EW
                    for suf in suffixes:
                        # ordinary and regularised tables have the same error vectors
                        src = suf[1] if '_ERR' in suf[1] else suf[1] + s
                        for icol, col in enumerate(suf[2]):
                            t[col][count] = nonp[src][0][icol]

            # filling results of balmer line decomposition
            if do_decomp:
                dec = fits['EMIS_PROF_DECOMP'].read()

                nbin = len(dec['LOSVD_ALLOWED'][0])
                # NOTE(review): operator precedence makes this
                # chi2dof_noblr*(nbin-1) - chi2/3/chi2dof, which does not look like
                # the F statistic used in collect_decomp_table -- confirm the formula.
                f_value = dec['PARS_NOBLR_CHI2DOF'][0] * (nbin - 1) - dec['PARS_CHI2'][0] / 3 / dec['PARS_CHI2DOF'][0]
                prob_f = fprob.cdf(f_value,nbin-1,nbin-4)
                rms = np.nanstd(dec['LOSVD_ALLOWED'][0] - dec['LOSVD_FIT'][0])
                vsys = np.sum(dec['LOSVD_VBIN'][0] * dec['LOSVD_ALLOWED'][0])
                # residual scatter restricted to bins within one BLR sigma of vsys
                mask_line_center = np.abs(dec['LOSVD_VBIN'][0] - vsys) < dec['PARS_BLR_SIG'][0]
                rms_bs = np.nanstd(dec['LOSVD_ALLOWED'][0][mask_line_center] - dec['LOSVD_FIT'][0][mask_line_center])

                data_decomp['MAX_ALLOWED'][count] = dec['PARS_MAX_ALLOWED'][0]
                data_decomp['ERROR_LEVEL'][count] = dec['PARS_ERROR_LEVEL'][0]
                data_decomp['NLR_INT'][count] = dec['PARS_NLR_INT'][0]
                # NOTE(review): same source column as NLR_INT; possibly meant to be a
                # no-BLR column (cf. 'PARS_NOBLR_NLR_INT_ERR' below) -- confirm.
                data_decomp['NLR_NOBLR_INT'][count] = dec['PARS_NLR_INT'][0]
                data_decomp['NLR_FRAC'][count] = dec['PARS_NLR_FRAC'][0]
                data_decomp['NLR_STDDEV'][count] = dec['PARS_NLR_STDDEV'][0]
                data_decomp['NLR_MAX'][count] = dec['PARS_NLR_MAX'][0]
                data_decomp['BLR_INT'][count] = dec['PARS_BLR_INT'][0]
                data_decomp['BLR_POS'][count] = dec['PARS_BLR_POS'][0]
                data_decomp['BLR_SIG'][count] = dec['PARS_BLR_SIG'][0]
                data_decomp['BLR_MAX'][count] = dec['PARS_BLR_MAX'][0]
                data_decomp['BLR_FRAC'][count] = dec['PARS_BLR_FRAC'][0]
                data_decomp['BLR_FLUX_Halpha'][count] = dec['PARS_BLR_FLUX_HALPHA'][0]
                data_decomp['RMS'][count] = rms
                data_decomp['RMS_BS'][count] = rms_bs
                data_decomp['CHI2'][count] = dec['PARS_CHI2'][0]
                data_decomp['CHI2DOF'][count] = dec['PARS_CHI2DOF'][0]
                data_decomp['CHI2_NOBLR'][count] = dec['PARS_NOBLR_CHI2'][0]
                data_decomp['CHI2DOF_NOBLR'][count] = dec['PARS_NOBLR_CHI2DOF'][0]
                data_decomp['NBIN'][count] = nbin
                data_decomp['F_VALUE'][count] = f_value
                data_decomp['PROB_F'][count] = prob_f
                data_decomp['NLR_INT_ERR'][count] = dec['PARS_NLR_INT_ERR'][0]
                data_decomp['NLR_NOBLR_INT_ERR'][count] = dec['PARS_NOBLR_NLR_INT_ERR'][0]
                data_decomp['BLR_INT_ERR'][count] = dec['PARS_BLR_INT_ERR'][0]
                data_decomp['BLR_POS_ERR'][count] = dec['PARS_BLR_POS_ERR'][0]
                data_decomp['BLR_SIG_ERR'][count] = dec['PARS_BLR_SIG_ERR'][0]
                data_decomp['blr_flux_halpha_err'][count] = dec['PARS_BLR_FLUX_HALPHA_ERR'][0]

        except (IOError, ValueError) as e:
            scoop.logger.error(file_name)
            scoop.logger.error(e)

        finally:
            if fits is not None:
                fits.close()

    if write_fits:
        out_dir = os.path.dirname(list)
        if do_gaus:
            outfile = '{}/restable_emis_gaus_test_{}.fits'.format(out_dir, ID)
            fitsio.write(outfile, data_gaus, extname='restable_emis', clobber=True)
        if do_nonpar:
            outfile = '{}/restable_emis_nonpar_test_{}.fits'.format(out_dir, ID)
            fitsio.write(outfile, data_nonpar, extname='restable_nonpar', clobber=True)
            outfile = '{}/restable_emis_nonpar_sm_test_{}.fits'.format(out_dir, ID)
            fitsio.write(outfile, data_nonpar_sm, extname='restable_nonpar_sm', clobber=True)
        if do_decomp:
            # NOTE(review): placeholder kept from the original -- the decomposition
            # table is not written, only the value of the `decomp` flag is printed.
            print(decomp)

    return tables

def collect_decomp_table(list, ID, nIDs, nprint=None, write_fits=True):
    """
    Collect results of the spectral decomposition from one chunk of a file list.

    For every file the 'SPECTRUM' and 'EMIS_SPEC_DECOMP' extensions are read
    and copied into one row of the output table. In addition chi2, reduced
    chi2 and the F-test of the fits with and without the BLR component are
    recomputed around Halpha inside a fixed (+/-40) and an adaptive wavelength
    window. Files that can not be read (IOError / ValueError) are logged and
    their rows keep the NaN filling of the float columns.

    @param list:       path of a text file listing one FITS file per line
    @param ID:         1-based number of the chunk to process
    @param nIDs:       total number of chunks
    @param nprint:     if set, log progress for every nprint-th file
    @param write_fits: currently ignored, nothing is written by this function

    @return:           structured np.array() with one row per file of the chunk
    """
    with open(list) as f:
        files = f.read().splitlines()

    # select files for given chunk
    list_limits = utils_chunk(len(files), ID, nIDs)
    files = files[list_limits[0]:list_limits[1]+1]
    nfiles = len(files)

    # make output structured array for decomposition results
    decomp_dtypes = ([
            ('mjd', np.int32),
            ('plate', np.int32),
            ('fiberid', np.int16),
            ('MBH_IDL', np.float32),
            ('MBH_ERR_IDL', np.float32),
            ('MBH', np.float32),
            ('VSYS', np.float32),
            ('MAX_HALPHA', np.float32),
            ('ERROR_LEVEL', np.float32),
            ('FLUX_HALPHA', np.float32),
            ('NLR_MAX', np.float32),
            ('NLR_STDDEV', np.float32),
            ('BLR_POS', np.float32),
            ('BLR_SIG', np.float32),
            ('BLR_MAX', np.float32),
            ('BLR_FRAC', np.float32),
            ('BLR_FLUX_HALPHA', np.float32),
            ('BLR_FLUX_HBETA', np.float32),
            ('NLR_FLUX_HBETA', np.float32),
            ('NLR_FLUX_OIII', np.float32),
            ('NLR_FLUX_HALPHA', np.float32),
            ('NLR_FLUX_NII', np.float32),
            ('NLR_FLUX_SII_1', np.float32),
            ('NLR_FLUX_SII_2', np.float32),
            ('RMS', np.float32),
            ('CHI2', np.float32),
            ('CHI2DOF', np.float32),
            ('CHI2_NOBLR', np.float32),
            ('CHI2DOF_NOBLR', np.float32),
            ('F_VALUE', np.float32),
            ('PROB_F', np.float32),
            ('CHI2_40', np.float32),
            ('CHI2DOF_40', np.float32),
            ('CHI2_NOBLR_40', np.float32),
            ('CHI2DOF_NOBLR_40', np.float32),
            ('F_VALUE_40', np.float32),
            ('PROB_F_40', np.float32),
            ('CHI2_ADP', np.float32),
            ('CHI2DOF_ADP', np.float32),
            ('CHI2_NOBLR_ADP', np.float32),
            ('CHI2DOF_NOBLR_ADP', np.float32),
            ('F_VALUE_ADP', np.float32),
            ('PROB_F_ADP', np.float32),
            ('BLR_POS_ERR', np.float32),
            ('BLR_SIG_ERR', np.float32),
            ('BLR_FLUX_HALPHA_ERR', np.float32),
            ('BLR_FLUX_HBETA_ERR', np.float32),
            ('NLR_FLUX_HBETA_ERR', np.float32),
            ('NLR_FLUX_OIII_ERR', np.float32),
            ('NLR_FLUX_HALPHA_ERR', np.float32),
            ('NLR_FLUX_NII_ERR', np.float32),
            ('NLR_FLUX_SII_1_ERR', np.float32),
            ('NLR_FLUX_SII_2_ERR', np.float32),
            ('SPECTRUM_SNR', np.float32),
            ('FIT_V', np.float32),
            ('FIT_SIG', np.float32),
            ('FIT_AGE', np.float32),
            ('FIT_MET', np.float32),
            ('FIT_CHI2', np.float32)
            ])

    # create a table filled by NANs
    data_decomp = np.empty(nfiles, dtype=decomp_dtypes)
    data_decomp[:] = np.nan

    # number of free parameters of the two competing models
    p1 = 2  # noblr: Halpha nlr, NII fluxes (2)
    p2 = 5  # +blr fit: Halpha nlr+blr, NII fluxes (3), BLR_POS, BLR_SIG (2)
    winwl_40 = 40.0  # A - fixed width window (+/-winwl) around Halpha

    for count, file_name in enumerate(files):
        fname = os.path.basename(file_name)

        # identifiers are taken from fixed positions of the file name
        data_decomp['mjd'][count] = fname[6:11]
        data_decomp['plate'][count] = fname[12:16]
        data_decomp['fiberid'][count] = fname[17:20]

        if nprint and count % nprint == 0:
            scoop.logger.info("{} {}".format(count, fname))

        # keep the handle defined so that cleanup works when the open itself fails
        fits = None
        try:
            fits = fitsio.FITS(file_name)

            spec = fits['SPECTRUM'].read()
            dec = fits['EMIS_SPEC_DECOMP'].read()

            mbh = (flux_sig_to_mbh(dec['PARS_BLR_FLUX_HA'][0], dec['PARS_BLR_SIG'][0],
                dec['PARS_VSYS'][0] / c, ref='reines13')[0].value)

            data_decomp['VSYS'][count] = dec['PARS_VSYS'][0]
            data_decomp['MAX_HALPHA'][count] = dec['PARS_MAX_HA'][0]
            data_decomp['ERROR_LEVEL'][count] = dec['PARS_ERROR_LEVEL'][0]
            data_decomp['NLR_MAX'][count] = dec['PARS_NLR_MAX'][0]
            data_decomp['NLR_STDDEV'][count] = dec['PARS_NLR_STDDEV'][0]
            data_decomp['BLR_POS'][count] = dec['PARS_BLR_POS'][0]
            data_decomp['BLR_SIG'][count] = dec['PARS_BLR_SIG'][0]
            data_decomp['BLR_MAX'][count] = dec['PARS_BLR_MAX'][0]
            data_decomp['BLR_FRAC'][count] = dec['PARS_BLR_FRAC_HA'][0]
            data_decomp['BLR_FLUX_HALPHA'][count] = dec['PARS_BLR_FLUX_HA'][0]
            data_decomp['BLR_FLUX_HBETA'][count] = dec['PARS_COMP_FLUX_BLR'][0][0]
            data_decomp['MBH_IDL'][count] = dec['PARS_MBH'][0]
            data_decomp['MBH_ERR_IDL'][count] = dec['PARS_MBH_ERR'][0]
            data_decomp['MBH'][count] = mbh
            data_decomp['FLUX_HALPHA'][count] = dec['PARS_FLUX_HA'][0]
            data_decomp['NLR_FLUX_HBETA'][count] = dec['PARS_COMP_FLUX_NLR'][0][0]
            data_decomp['NLR_FLUX_OIII'][count] = dec['PARS_COMP_FLUX_NLR'][0][1]
            data_decomp['NLR_FLUX_HALPHA'][count] = dec['PARS_COMP_FLUX_NLR'][0][2]
            data_decomp['NLR_FLUX_NII'][count] = dec['PARS_COMP_FLUX_NLR'][0][3]
            data_decomp['NLR_FLUX_SII_1'][count] = dec['PARS_COMP_FLUX_NLR'][0][4]
            data_decomp['NLR_FLUX_SII_2'][count] = dec['PARS_COMP_FLUX_NLR'][0][5]
            data_decomp['RMS'][count] = dec['PARS_RMS'][0]
            data_decomp['CHI2'][count] = dec['PARS_CHI2'][0]
            data_decomp['CHI2DOF'][count] = dec['PARS_CHI2DOF'][0]
            data_decomp['CHI2_NOBLR'][count] = dec['PARS_NOBLR_CHI2'][0]
            data_decomp['CHI2DOF_NOBLR'][count] = dec['PARS_NOBLR_CHI2DOF'][0]
            data_decomp['F_VALUE'][count] = dec['PARS_F_VALUE'][0]
            data_decomp['PROB_F'][count] = dec['PARS_PROB_F'][0]
            data_decomp['BLR_POS_ERR'][count] = dec['PARS_BLR_POS_ERR'][0]
            data_decomp['BLR_SIG_ERR'][count] = dec['PARS_BLR_SIG_ERR'][0]
            data_decomp['BLR_FLUX_HALPHA_ERR'][count] = dec['PARS_BLR_FLUX_HALPHA_ERR'][0]
            # NOTE(review): the flux itself is taken from element [0] of
            # PARS_COMP_FLUX_BLR while its error uses element [1] -- confirm the index.
            data_decomp['BLR_FLUX_HBETA_ERR'][count] = dec['PARS_COMP_FLUX_BLR_ERR'][0][1]
            data_decomp['NLR_FLUX_HBETA_ERR'][count] = dec['PARS_COMP_FLUX_NLR_ERR'][0][0]
            data_decomp['NLR_FLUX_OIII_ERR'][count] = dec['PARS_COMP_FLUX_NLR_ERR'][0][1]
            data_decomp['NLR_FLUX_HALPHA_ERR'][count] = dec['PARS_COMP_FLUX_NLR_ERR'][0][2]
            data_decomp['NLR_FLUX_NII_ERR'][count] = dec['PARS_COMP_FLUX_NLR_ERR'][0][3]
            data_decomp['NLR_FLUX_SII_1_ERR'][count] = dec['PARS_COMP_FLUX_NLR_ERR'][0][4]
            data_decomp['NLR_FLUX_SII_2_ERR'][count] = dec['PARS_COMP_FLUX_NLR_ERR'][0][5]
            data_decomp['SPECTRUM_SNR'][count] = np.nanmedian(spec['FLUX'][0])/np.nanmedian(spec['ERROR'][0])
            data_decomp['FIT_V'][count] = spec['V'][0]
            data_decomp['FIT_SIG'][count] = spec['SIG'][0]
            data_decomp['FIT_AGE'][count] = spec['AGE'][0]
            data_decomp['FIT_MET'][count] = spec['MET'][0]
            data_decomp['FIT_CHI2'][count] = spec['CHI2'][0]

            # calculate chi2 values around Halpha in fixed and adaptive window
            spec_wave = dec['SPEC_WAVE'][0]
            spec_emis = dec['SPEC_EMIS'][0]
            spec_error = dec['SPEC_ERROR'][0]
            spec_fit = dec['SPEC_FIT'][0]
            spec_fit_noblr = dec['SPEC_FIT_NOBLR'][0]

            # adaptive window: 3 sigma of the broadest component converted to
            # wavelength at the shifted Halpha position, capped at 100
            sig_adp = 3 * max(dec['PARS_BLR_SIG'][0], dec['PARS_NLR_STDDEV'][0])
            winwl_adp = sig_adp / c * 6562.79 * ( 1 + dec['PARS_VSYS'][0] / c )
            winwl_adp = min([winwl_adp,100])

            # Halpha wavelength shifted by the systemic velocity
            halpha_shifted = 6562.79 * ( 1 + dec['PARS_VSYS'][0] / c )

            for winwl, suffix in zip([winwl_40, winwl_adp], ['_40', '_ADP']):
                idx = np.abs( spec_wave - halpha_shifted ) < winwl
                # number of pixels inside the window; len(idx) would be the length
                # of the whole spectrum because idx is a boolean mask
                nidx = int(np.count_nonzero(idx))

                chi2 = np.sum(((spec_emis[idx] - spec_fit[idx]) / spec_error[idx])**2)
                chi2_noblr = np.sum(((spec_emis[idx] - spec_fit_noblr[idx]) / spec_error[idx])**2)
                chi2dof = chi2 / ( nidx - p2 )
                chi2dof_noblr = chi2_noblr / ( nidx - p1 )
                # Using this description for F-test
                # https://en.wikipedia.org/wiki/F-test#Regression_problems
                f_value = ( (chi2_noblr - chi2) / (p2 - p1) ) / (chi2 / (nidx - p2))
                prob_f = fprob.cdf(f_value, p2 - p1, nidx - p2)

                data_decomp['CHI2' + suffix][count] = chi2
                data_decomp['CHI2DOF' + suffix][count] = chi2dof
                data_decomp['CHI2_NOBLR' + suffix][count] = chi2_noblr
                data_decomp['CHI2DOF_NOBLR' + suffix][count] = chi2dof_noblr
                data_decomp['F_VALUE' + suffix][count] = f_value
                data_decomp['PROB_F' + suffix][count] = prob_f

        except (IOError, ValueError) as e:
            scoop.logger.error(file_name)
            scoop.logger.error(e)

        finally:
            if fits is not None:
                fits.close()

    return data_decomp

def concatenate_arrays(arr1, arr2):
    """
    Join two arrays end to end and return the combined array.

    This is a named two-argument wrapper around np.concatenate() so that it can
    be handed to reduce(). A lambda can't be used for that with scoop, because
    lambdas are not picklable.
    """
    pair = (arr1, arr2)
    return np.concatenate(pair)

def reduce_map_decomp_output(map_output):
    """
    Reduce function to join tables with decomposition results.

    Parameters
    ----------
    map_output : list of ndarray
        One table per map call, in the order the results were collected.

    Returns
    -------
    ndarray
        The table itself (not a copy) when map_output holds a single table,
        otherwise all tables joined in order along the first axis.

    Raises
    ------
    ValueError
        If map_output is empty.
    """
    if len(map_output) == 0:
        raise ValueError('No decomposition tables to merge: map output is empty')
    if len(map_output) == 1:
        return map_output[0]
    # A single concatenate call copies every row once; joining the tables
    # pairwise would re-copy the accumulated result at each step.
    return np.concatenate(map_output)


def reduce_output(list1, list2):
    """
    Reduce function that merges two lists of [table, name] entries position by position.

    With all=True each input looks like
    [[data_gaus,'gaus'], [data_nonpar,'nonpar'], [data_nonpar_sm,'nonpar_sm'], [data_decomp,'decomp']]
    where data_* are ndarrays. The tables at the same position are concatenated
    (list1's first) and the name is taken from list1. Positions beyond the
    shorter of the two lists are dropped.
    """
    merged = []
    for left, right in zip(list1, list2):
        joined_table = concatenate_arrays(left[0], right[0])
        merged.append([joined_table, left[1]])
    return merged


def reduce_map_output(map_output):
    """
    Reduce function to join tables from map function output.

    Parameters
    ----------
    map_output : list
        One item per map call; each item is a list of [table, name] entries
        such as [[data_gaus, 'gaus'], [data_nonpar, 'nonpar'], ...].

    Returns
    -------
    list
        The single item itself (not a copy) when map_output holds one item.
        Otherwise a list of [table, name] entries where the table at each
        position is the concatenation, in order, of the tables found at that
        position in every item, and the name comes from the first item.
        Positions missing from any item are dropped.

    Raises
    ------
    ValueError
        If map_output is empty.
    """
    if len(map_output) == 0:
        raise ValueError('No tables to merge: map output is empty')
    if len(map_output) == 1:
        return map_output[0]
    merged = []
    # zip(*...) groups the entries found at the same position in every item
    # and stops at the shortest item.
    for entries in zip(*map_output):
        tables = [entry[0] for entry in entries]
        # Concatenate all tables for this position at once instead of
        # re-copying the accumulated table for every additional item.
        merged.append([np.concatenate(tables), entries[0][1]])
    return merged


def test(args):
    """
    Smoke test for a worker: create its log file through makelog() and
    report the resulting log path via the scoop logger.
    """
    scoop.logger.info(makelog(args))


def makelog(args):
    """
    Create a log file for one worker and write a short header into it.

    Parameters
    ----------
    args : sequence
        Worker arguments laid out as (proc_id, n_procs, file_list, task):
        args[0] -- zero-based worker index (reported as ID = args[0] + 1),
        args[1] -- total number of workers,
        args[2] -- path of the file list; the log is created in its directory,
        args[3] -- task name, used as part of the log file name.

    Returns
    -------
    str
        Path of the created log file.
    """
    # isoformat() leaves out the fractional part when microsecond == 0, so
    # cutting a fixed number of characters off its end would occasionally eat
    # into the time itself. Zero the microseconds first instead.
    datetime_stamp = datetime.datetime.now().replace(microsecond=0).isoformat()
    # The replace() calls run on the whole name, so '-' and ':' are stripped
    # from the task name as well as from the timestamp.
    log_name = 'log_{}_{:02d}{:02d}_{}.log'.format(args[3], args[0] + 1, args[1], datetime_stamp).replace('-','').replace(':','')
    log_name = os.path.join(os.path.dirname(args[2]), log_name)
    with open(log_name, 'w') as log:
        log.write( 'Current dir: {} \n'.format(os.getcwd()) )
        log.write( 'File list: {} \n'.format(args[2]) )
        log.write( 'Hostname: {} \n'.format(utils.socket.gethostname()) )
        log.write( 'IP: {} \n'.format(utils.ip) )
        log.write( 'ID: {} \n'.format(args[0]+1) )
        log.write( 'nIDs: {} \n'.format(args[1]) )

    return log_name


if __name__ == '__main__':
    """
    Be sure that all environment settings as well as paths and code are the same on all host servers.

    Command line arguments: <task> <number of processes> <file list>

    Usage examples:
    (for server)
        python -m scoop --hostfile hosts -n 38 multiproc.py nonp 38 ../agn_bh/speclist_imbh_joined.txt
        python -m scoop -n 16 multiproc.py collect_data 16 /tmp/tmp_filelist.txt > collect_data.log 2>&1
    (for laptop)
        python -m scoop -n 4 multiproc.py decomp 4 ../agn_bh/fit_agn/flist_checks.txt
    (test)
        python -m scoop -n 4 multiproc.py test 4 blabla

    First line in hostfile must point to localhost. The localhost must have at least one worker.
    Debug principle: 1) find the buggy host by editing hostfile; 2) copy-paste worker function.
    """
    scoop_time = time.time()

    try:
        task = sys.argv[1]
        proc_n = int(sys.argv[2])
        file_list = sys.argv[3]
    except IndexError:
        # Single-argument print(...) gives the same output on Python 2 and 3.
        print(utils.getEnv())
        scoop.logger.error('Launch requires 3 arguments, got {:d}'.format(len(sys.argv) - 1))
        sys.exit(1)
    except ValueError:
        # int() failed: the second argument is not a valid process count.
        scoop.logger.error('Number of processes must be an integer, got {!r}'.format(sys.argv[2]))
        sys.exit(1)

    proc_ids = range(proc_n)
    # args is like [(0, 4, 'filename', 'function_name'), (1, 4, 'filename', 'function_name'), (2, 4, ...), ...]
    args = zip( proc_ids, [proc_n] * proc_n, [file_list] * proc_n, [task] * proc_n )

    if task == 'gaus':
        returnValues = list(
            scoop.futures.map_as_completed(fit_lines_gaus, args))

    elif task == 'nonp':
        returnValues = list(
            scoop.futures.map_as_completed(fit_lines_nonpar, args))

    elif task == 'mc_nonp':
        returnValues = list(
            scoop.futures.map_as_completed(mc_nonp_profiles_fluxes, args))

    elif task == 'decomp':
        returnValues = list(
            scoop.futures.map_as_completed(profile_decompose, args))
        scoop.logger.info('All done in {:.3f} seconds'.format(time.time() - scoop_time))

    elif task == 'collect_data':
        #scoop.logger.setLevel(scoop.logging.INFO)
        returnValues = list(scoop.futures.map_as_completed(collect_data, args))
        merged_array = reduce_map_output(returnValues)
        # The output location and date mark are the same for every table,
        # so compute them once. strftime() is used for the YYMMDD mark because
        # slicing isoformat() breaks when the microsecond field is zero.
        file_list = os.path.abspath(file_list)
        datamark = datetime.datetime.now().strftime('%y%m%d')
        output_dir = os.path.dirname(file_list)
        list_stem = os.path.splitext(os.path.basename(file_list))[0]
        for tbl in merged_array:
            # tbl is a [table, name] pair
            outputfile = os.path.join(output_dir, 'restable_{}_{}_{}.fits'.format(list_stem, tbl[1], datamark))
            scoop.logger.info('Writing {:d} collected rows in a FITS file: {}'.format(len(tbl[0]), outputfile))
            fitsio.write(outputfile, tbl[0], extname='restable_' + tbl[1], clobber=True)
        scoop.logger.info('All done in {:.3f} seconds'.format(time.time() - scoop_time))

    elif task == 'collect_decomp':
        returnValues = list(scoop.futures.map_as_completed(collect_data_decomp, args))
        merged_tbl = reduce_map_decomp_output(returnValues)
        file_list = os.path.abspath(file_list)
        # YYMMDD date mark; see the note in the 'collect_data' branch.
        datamark = datetime.datetime.now().strftime('%y%m%d')
        outputfile = os.path.join(os.path.dirname(file_list), 'restable_{}_{}_{}.fits'.format(os.path.splitext(os.path.basename(file_list))[0],'decomp_spec',datamark))
        scoop.logger.info('Writing {:d} collected rows in a FITS file: {}'.format(len(merged_tbl), outputfile))
        fitsio.write(outputfile, merged_tbl, extname='restable_decomp_spec', clobber=True)
        scoop.logger.info('All done in {:.3f} seconds'.format(time.time() - scoop_time))

    elif task == 'collect_data_fake':
        # NOTE(review): `list_files` is not defined in this block (the command
        # line value is stored in `file_list`); confirm it exists at module
        # level, otherwise this branch fails with a NameError.
        merged_array = scoop.futures.mapReduce(collect_data_fake, concatenate_arrays, list_files)
        scoop.logger.info('All done in {:.3f} seconds'.format(time.time() - scoop_time))

    elif task == 'test':
        returnValues = list(scoop.futures.map_as_completed(test, args))

    elif task == 'find_bad_files':
        filenames = list(scoop.futures.map_as_completed(find_bad_files, args))
        # join files
        # abspath() first: for a bare file name dirname() is '' and a plain
        # '/check_extname.list' suffix would point at the filesystem root.
        joined_list = os.path.join(os.path.dirname(os.path.abspath(file_list)), 'check_extname.list')
        with open(joined_list, 'w') as outfile:
            for fname in filenames:
                with open(fname) as infile:
                    for line in infile:
                        outfile.write(line)
        # delete temp files
        for fname in filenames:
            os.unlink(fname)
        # returnValues = scoop.futures.mapReduce(find_bad_files, operator.add, args)
        # print(returnValues)
        scoop.logger.info('All done in {:.3f} seconds'.format(time.time() - scoop_time))

    else:
        scoop.logger.error("Wrong task {}".format(task))
        sys.exit(1)
        
