import argh
import numpy as np
import pandas as pd
import matplotlib as mpl
mpl.use('Agg')
from matplotlib.ticker import FixedLocator, FormatStrFormatter
import pylab as plt
from astropy.table import Table
from collections import OrderedDict


# Major / minor tick positions for the logarithmic ratio (y) axis: 0.1, 0.2, ... 9.
_RATIO_MAJOR_TICKS = np.concatenate((np.arange(0.1, 1, 0.1), np.arange(1, 10, 1)))
_RATIO_MINOR_TICKS = np.concatenate((np.arange(0.1, 1, 0.01), np.arange(1, 10, 0.1)))


def _flatten_single_row_table(table):
    """Return a new Table whose columns are the array cells stored in the first row of *table*.

    The synthetic result table holds each quantity as a single row containing a
    whole array; unpacking row 0 yields an ordinary one-value-per-row table.
    """
    flat = Table()
    for name in table.colnames:
        flat[name] = table[name][0]
    return flat


def _criteria_columns(criteria, available):
    """Return names from *available* mentioned in *criteria*, in first-seen order, without duplicates."""
    import re
    found = []
    # Grab every identifier-like token; keywords ('and'), function names ('abs')
    # and exponent fragments ('e3' from '1e3') are dropped because they are not columns.
    for token in re.findall(r'[A-Za-z_][A-Za-z0-9_]*', criteria):
        if token in available and token not in found:
            found.append(token)
    return found


def _plot_ratio_panel(ax, df_all, df_selected, name, xlim):
    """Plot NAME / NAME_INPUT versus NAME_INPUT on log-log *ax*: all rows grey, selected rows red."""
    input_name = name + '_INPUT'
    ax.plot(df_all[input_name], df_all[name] / df_all[input_name],
            color='grey', marker='.', linestyle='None', markersize=2, alpha=0.6)
    ax.plot(df_selected[input_name], df_selected[name] / df_selected[input_name],
            color='red', marker='.', linestyle='None', markersize=2)
    # reference line for perfect recovery (ratio == 1)
    ax.hlines(1, xlim[0], xlim[1], color='black', linewidth=3, linestyle='dashed')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(xlim)
    ax.set_ylim((0.3, 5.))
    ax.set_xlabel(input_name)
    ax.set_ylabel('{} / INPUT'.format(name))
    # a fresh locator/formatter per axis: matplotlib tickers are bound to one Axis
    ax.yaxis.set_major_locator(FixedLocator(_RATIO_MAJOR_TICKS))
    ax.yaxis.set_major_formatter(FormatStrFormatter('%g'))
    ax.yaxis.set_minor_locator(FixedLocator(_RATIO_MINOR_TICKS))


def selection_criteria(imbh_fits, synthetic_imbh_fits, criteria="BLR_SIG / NLR_STDDEV > 1.0", output='heh.png'):
    """
    Plot the most important Mbh reconstruction metrics, highlighting entries that pass a criterion.

    Prints how many observed and synthetic entries pass the criterion, prints the
    passing observed candidates, and saves a figure of recovered/input ratios for
    the synthetic run (grey: all realisations, red: those passing the criterion).

    @param imbh_fits: filename of the table with observed data results (the so-called imbh.fits; preferably a candidates subset)
    @param synthetic_imbh_fits: filename of the MC run results (synthetic restable)
    @param criteria: TOPCAT subset expression used to filter both datasets ('&&' and '||' are translated to pandas 'and' / 'or')
    @param output: filename of the figure to write
    """
    observed = Table.read(imbh_fits)
    # the synthetic table stores everything as one row of arrays; unpack it
    synthetic = _flatten_single_row_table(Table.read(synthetic_imbh_fits))

    # translate TOPCAT-compatible string to pandas query syntax
    pandas_criteria = criteria.replace('&&', ' and ').replace('||', ' or ')

    df1 = observed.to_pandas()
    df2 = synthetic.to_pandas()
    df1_filtered = df1.query(pandas_criteria)
    df2_filtered = df2.query(pandas_criteria)
    print("Candidates: {} of {} filtered".format(len(df1_filtered), len(df1)))
    print("Synthetic: {} of {} filtered".format(len(df2_filtered), len(df2)))

    # details of the passing candidates: fixed columns plus every column the criteria references
    cols = ['url', 'Mbh', 'xray', 'literature']
    cols += [c for c in _criteria_columns(criteria, df1.columns) if c not in cols]
    # scoped display options so the global pandas configuration is left untouched
    with pd.option_context('display.max_rows', 500, 'display.max_colwidth', 200):
        print(df1_filtered[cols])

    # plotting of synthetic selection: one panel per metric
    plot_cols = OrderedDict([
        ('BLR_SIG', {'xlim': (100., 1100.)}),
        ('BLR_FLUX_HALPHA', {'xlim': (100., 2500.)}),
        ('MBH', {'xlim': (5e3, 1e7)}),
    ])
    fig, axes = plt.subplots(1, len(plot_cols), figsize=(22, 6))
    for ax, (name, settings) in zip(axes, plot_cols.items()):
        _plot_ratio_panel(ax, df2, df2_filtered, name, settings['xlim'])

    fig.suptitle(criteria, fontsize=10)
    # 'bbox_inches' is the savefig keyword for tight cropping ('bbox' is not)
    fig.savefig(output, bbox_inches='tight')
    plt.close(fig)
    
    
if __name__ == '__main__':
    # Command-line entry point: argh turns the parameters of selection_criteria
    # into command-line arguments and calls it with the parsed values.
    _entry_point = selection_criteria
    argh.dispatch_command(_entry_point)
    