# Example script to produce a publication-quality extinction diagram (A_V vs. g-r,
# colour-coded by H-alpha equivalent width) from the CZKM catalog.
# Usage:
#     python extinction2.py /path/to/catalog_table.fits
#   or, from IPython:
#     %run extinction2.py /Volumes/ALM/CZKM/Data/ResuMain04012016.fits
#
import sys
import os
import pylab as plt
import numpy as np
import math   # This will import math module
from astropy.table import Table
from astropy.cosmology import WMAP9 as cosmo
from matplotlib.colors import LogNorm
from matplotlib.ticker import LogFormatter
import scipy.ndimage
#import seaborn as sns

#from __future__ import division
from scipy.signal import convolve2d

def unred_cardelli(lam, R_V=3.1):
  """
  Cardelli, Clayton & Mathis (1989) extinction law.

  @param lam: wavelength at which to evaluate the law, in Angstroms
  @param R_V: total-to-selective extinction ratio A_V / E(B-V); defaults to 3.1,
              the value previously hard-coded here
  @return: A_lambda / A_V at the given wavelength
  @raise ValueError: if 1e4/lam (inverse microns) lies outside 0.005 - 3.3001,
                     i.e. roughly outside 3030 Angstrom - 200 micron
  """
  # Cardelli89 works in x = 1/lambda measured in inverse microns (mu^-1);
  # lam is in Angstroms, hence the factor 1e4.
  x = 1.0 / lam * 1e4
  # Originally, Cardelli defines the IR law only down to x = 0.3; here it is
  # extended to ~200 microns (x ~= 0.005) to include Spitzer/MIPS.

  if x >= 0.005 and x <= 1.1:
    # Infrared branch: power law in x.
    ax = 0.574 * x ** 1.61
    bx = -0.527 * x ** 1.61
  elif x > 1.1 and x <= 3.3001:
    # Optical / near-IR branch: 7th-order polynomials in y = x - 1.82.
    y = x - 1.82
    ax = 1 + 0.17699 * y - 0.50447 * y ** 2 - 0.02427 * y ** 3 + 0.72085 * y **4 + 0.01979 * y ** 5 - \
      0.77530 * y ** 6 + 0.32999 * y ** 7
    bx = 1.41338 * y + 2.28305 * y ** 2 + 1.07233 * y ** 3 - 5.38434 * y ** 4 - 0.62251 * y ** 5 + \
      5.30260 * y ** 6 - 2.09002 * y ** 7
  else:
    raise ValueError('Cardelli law is not defined at given wavelength')
  return ax + bx / R_V


def moving_average_2d(data, window):
    """Smooth two-dimensional data with a normalized moving-average kernel.

    Parameters
    ----------
    data : array_like
        2-D input. Anything that is not already an ndarray (or an ndarray
        subclass such as a masked array) is converted with ``np.asarray``.
    window : array_like
        2-D kernel weights. They are normalized to unit sum on a private
        float copy, so the caller's array is left untouched and integer
        kernels are accepted.

    Returns
    -------
    ndarray
        Convolution result with the same shape as ``data`` (``mode='same'``),
        using symmetric boundary conditions (``boundary='symm'``).
    """
    # Normalize a float copy. An in-place ``window /= window.sum()`` would
    # mutate the caller's kernel and fails outright for integer arrays.
    window = np.asarray(window, dtype=float)
    window = window / window.sum()

    if not isinstance(data, np.ndarray):
        data = np.asarray(data)

    return convolve2d(data, window, mode='same', boundary='symm')
  

def make_Extdiagram(t):
  """
  Make a publication-quality A_V vs. (g - r) diagram and save it to
  extinction2.png in the current working directory.

  Points are coloured by H-alpha equivalent width on a log scale (3-300) and
  overlaid with smoothed number-density contours.

  Side effect: t is de-reddened in place (magnitudes and line fluxes are
  corrected and column 'e_bv' is added), so calling this twice on the same
  table applies the correction twice.

  @param t: astropy table object returned by Table.read(); must provide the
            corrmag_*, corrmag_*_err, kcorr_*, emission-line F*_FLX / F*_FLX_ERR
            and F6565_H_ALPHA_EW columns used below
  """
  # Signal-to-noise > 3 on g, r, H-alpha and H-beta, evaluated on the
  # uncorrected values (before the table is modified below).
  sel = ((t['corrmag_r'] / t['corrmag_r_err'] > 3) &
         (t['corrmag_g'] / t['corrmag_g_err'] > 3) &
         (t['F6565_H_ALPHA_FLX'] / t['F6565_H_ALPHA_FLX_ERR'] > 3) &
         (t['F4863_H_BETA_FLX'] / t['F4863_H_BETA_FLX_ERR'] > 3))

  _deredden_table(t)

  # Slice the table once instead of re-filtering it for every column.
  sub = t[sel]
  Y = sub['e_bv']  # holds A_V, not E(B-V) (see _deredden_table)
  X = (sub['corrmag_g'] - sub['kcorr_g']) - (sub['corrmag_r'] - sub['kcorr_r'])
  Z = sub['F6565_H_ALPHA_EW']

  xra = [-1., 1.8]
  yra = [-0.1, 6]

  # Number-density histogram of the non-NaN points, smoothed and drawn as
  # contours further down.
  good = ~np.isnan(X) & ~np.isnan(Y)
  counts, xbins, ybins = np.histogram2d(X[good], Y[good], bins=100,
                                        range=np.array([xra, yra]))
  print('N= %d' % len(X))

  # Start from a clean figure with larger tickmarks and fonts for publication.
  plt.clf()
  plt.rcdefaults()
  plt.rcParams.update({
    'xtick.major.size': 7.0,
    'xtick.minor.size': 4.0,
    'ytick.major.size': 7.0,
    'ytick.minor.size': 4.0,
    'font.size': 16
  })

  marker_size = 8  # size of points

  # vmin/vmax must live on the norm itself: matplotlib rejects passing a norm
  # together with separate vmin/vmax arguments.
  plt.scatter(X, Y, c=Z, s=marker_size, norm=LogNorm(vmin=3, vmax=300),
              cmap="gist_rainbow", lw=0, alpha=0.5)
  l_f = LogFormatter(base=10, labelOnlyBase=False)
  lvls = (5, 10., 20., 50., 100., 200.)
  CB = plt.colorbar(ticks=lvls, format=l_f)
  CB.set_label(r"$H_\alpha$ equivalent width ($\AA$)", labelpad=10)

  plt.ylabel(r'$A_V$', labelpad=10)
  plt.xlabel(r'$g-r$', labelpad=2)

  plt.text(0.2, 5., '%d galaxies\n' % (len(X)))

  ax = plt.gca()
  ax.set_xlim(xra)
  ax.set_ylim(yra)
  plt.xticks([-1, -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
             ['-1', '', '-0.5', '', '0', '', '0.5', '', '1.0', '', '1.5', ''])
  plt.yticks([0., 0.25, 0.5, 0.75, 1.0, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5, 6],
             ['0', '0.25', '0.5', '0.75', '1', '', '2', '', '3', '', '4', '', '5', '', '6'])

  levels = np.logspace(0.0, 2.8, 10)

  # Smooth the counts with a 4x4 moving-average window before contouring.
  counts_smooth = moving_average_2d(counts, np.ones((4, 4)))

  # histogram2d indexes counts as [x, y]; contour expects [row=y, col=x].
  plt.contour(counts_smooth.transpose(), levels,
              extent=[xbins.min(), xbins.max(), ybins.min(), ybins.max()],
              linewidths=1.5, colors='black')

  plt.savefig('extinction2.png', bbox_inches='tight')
  plt.close()


def _deredden_table(t):
  """
  Correct table t in place for extinction estimated from the Balmer decrement.

  E(B-V) = 1.97 * log10[(Halpha/Hbeta) / 2.85]; rows whose decrement is below
  2.85, or whose E(B-V) is NaN/inf, get E(B-V) = 0.01. With A_V = 3.1 * E(B-V)
  and A_lambda = A_V * unred_cardelli(lambda):
    - corrmag_g / corrmag_r have A_lambda subtracted (at 4770 / 6231 Angstrom);
    - every emission-line flux column is multiplied by 10**(0.4 * A_lambda);
    - column 'e_bv' is set to A_V (despite its name it is NOT E(B-V)).

  @param t: astropy table; modified in place
  """
  R_V = 3.1
  balmer_decrement_theoretical = 2.85

  # Each emission-line flux column paired with its wavelength in Angstroms.
  # Keeping name and wavelength together stops the two from drifting apart;
  # with separate parallel lists F6733_SII_FLX had no wavelength and zip()
  # silently skipped it.
  line_columns = [
    ('F3727_OII_FLX', 3727),
    ('F3730_OII_FLX', 3730),
    ('F4863_H_BETA_FLX', 4863),
    ('F4960_OIII_FLX', 4960),
    ('F5008_OIII_FLX', 5008),
    ('F6302_OI_FLX', 6302),
    ('F6366_OI_FLX', 6366),
    ('F6550_NII_FLX', 6550),
    ('F6565_H_ALPHA_FLX', 6565),
    ('F6585_NII_FLX', 6585),
    ('F6718_SII_FLX', 6718),
    ('F6733_SII_FLX', 6733),
  ]
  mag_columns = [('corrmag_g', 4770), ('corrmag_r', 6231)]

  # Zero or negative fluxes yield inf/NaN here. Those rows are reset below,
  # so silence numpy's floating-point warnings for this step only.
  with np.errstate(divide='ignore', invalid='ignore'):
    balmer_decrement = t['F6565_H_ALPHA_FLX'] / t['F4863_H_BETA_FLX']
    ebv = 1.97 * np.log10(balmer_decrement / balmer_decrement_theoretical)
    unphysical = (balmer_decrement < balmer_decrement_theoretical) | ~np.isfinite(ebv)
  ebv[unphysical] = 0.01

  a_v = ebv * R_V
  for name, wave in mag_columns:
    t[name] = t[name] - a_v * unred_cardelli(wave)
  t['e_bv'] = a_v
  for name, wave in line_columns:
    t[name] = t[name] * 10 ** (0.4 * a_v * unred_cardelli(wave))

# This block runs when the script is launched as `python extinction2.py` or `%run extinction2.py` (from IPython).
# It checks the input arguments and, if the given table file exists, calls the plotting function defined above.
if __name__ == '__main__':
  # Expect exactly one argument: the path to an existing catalog table file.
  if len(sys.argv) == 2 and os.path.isfile(sys.argv[1]):
    filename = sys.argv[1]
    t = Table.read(filename)
    make_Extdiagram(t)
  else:
    # sys.exit(message) writes the message to stderr and exits with status 1,
    # so shell callers can detect the failure.
    sys.exit('Wrong input arguments or table filename!')
                                                                                                                                                                                          
