# Example script to produce a publication-quality plot from the CZKM catalog
# Usage:
#	 python example_plot.py /path/to/catalog_table.fits
#			or
#	 ipython> %run example_plot.py /path/to/catalog_table.fits
#
import sys
import os
import pylab as plt
import numpy as np
from astropy.table import Table
from astropy.cosmology import WMAP9 as cosmo
from matplotlib.colors import LogNorm


def make_diagram(t,fileout):
  """
  Draws a 2D histogram (logarithmic colour scale) of 'exp_met' versus
  log10('exp_age'/1000) and saves the figure to the given file.

  Rows are used only if both 'exp_age' and 'exp_met' are finite and 'exp_age'
  is positive: any other value turns into NaN/inf after log10 and makes
  hist2d fail while auto-detecting the bin range.

  The X axis is labelled in Gyr, so 'exp_age' is presumably stored in Myr
  (it is divided by 1000 before plotting) -- TODO confirm against the catalog.

  @param t: astropy table object returned by Table.read() function;
            must contain the 'exp_age' and 'exp_met' columns
  @param fileout: name of the output image file, passed to plt.savefig()
  """
  age = t['exp_age']
  met = t['exp_met']

  # select rows usable on a log-age axis; comparisons with NaN are expected
  # here, so silence the "invalid value" warning they may trigger
  with np.errstate(invalid='ignore'):
    good = np.isfinite(age) & np.isfinite(met) & (age > 0)

  xx = np.log10(age[good] / 1000.0)
  yy = met[good]

  # axis ranges: X is log10 of 0.01..20, Y is -1.5..0.8
  xra = np.log10([0.01, 20])
  yra = [-1.5, 0.8]

  # bins span the full data range; the visible window is set further below
  plt.hist2d(xx, yy, bins=[200, 180], norm=LogNorm())

  plt.xlabel(r"$\tau$, Gyr")
  plt.ylabel("[Z/H], dex")
  plt.title('exp-SFH') # I usually don't use plot titles...

  # in case you need to change axis limits / scale
  ax = plt.gca()
  ax.set_xlim(xra)
  ax.set_ylim(yra)

  # The X data are already log10 values, so logarithmic tick marks are placed
  # by hand: a tick at every value below, a text label only at selected ones.
  tick_values = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09,
                 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
                 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
                 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
  labelled = {0.01: '0.01', 0.1: '0.1', 1: '1', 5: '5',
              10: '10', 15: '15', 20: '20'}
  tick_labels = [labelled.get(value, '') for value in tick_values]
  plt.xticks(np.log10(tick_values), tick_labels)

  plt.savefig(fileout, bbox_inches='tight')
  plt.close()
  

# This block runs when the script is launched as `python example_plot.py [table.fits]`
# or `%run example_plot.py [table.fits]` (if called from ipython).
# The catalog path is taken from the first command-line argument; when no argument
# is given, the default file name below is used. If the file does not exist the
# script stops with an error message, otherwise it reads the table, calls the
# plotting function defined above and prints the table.
if __name__ == '__main__':
  #default_filename = '/db1/Data/pro/SDSS/fit_PEGASEHR_exp/tbl_res_PEGASE_exp_entire_sample.fits'
  default_filename = 'czkm_catalog_final_150114.fits'
  filename = sys.argv[1] if len(sys.argv) > 1 else default_filename
  fileout = 'fig10_tau_met_expSFH.png'

  # fail early with a readable message instead of a traceback from Table.read()
  if not os.path.isfile(filename):
    sys.exit('Catalog table file not found: %s' % filename)

  t = Table.read(filename)
  make_diagram(t,fileout)
  # parenthesised form works as a statement on Python 2 and a call on Python 3
  print(t)
