# Example script to produce publication quality plot with CZKM catalog 
# Usage:
#	 python BPTdiagram.py /path/to/catalog_table.fits
#			or
#                 %run BPTdiagram.py /Volumes/ALM/CZKM/Data/ResuMain04012016.fits
#                 %run LINERdiagnostic.py /Volumes/ALM_ext/CZKM/Data/res_emis_MILES_x_20151123_gaus.fits
#                 %run LINERdiagnostic.py /Volumes/ALM_ext/CZKM/Data/res_emis_MILES_x_20151124_nonpar.fits
#
import sys
import os
import pylab as plt
import numpy as np
import math   # This will import math module
from astropy.table import Table
from astropy.cosmology import WMAP9 as cosmo
from matplotlib.colors import LogNorm
from matplotlib.ticker import LogFormatter
import scipy.ndimage
#import seaborn as sns

#from __future__ import division
from scipy.signal import convolve2d

def unred_cardelli(lam, R_V=3.1):
    """
    Cardelli et al. (1989) extinction law

    @param lam: wavelength at which to evaluate the law, in Angstroms (scalar);
                internally converted to inverse microns as x = 1e4 / lam
    @param R_V: total-to-selective extinction ratio, 3.1 by default
    @return: A_lambda / A_V value at the given wavelength
    @raise ValueError: if lam is not positive, or if x = 1e4 / lam falls outside
                       the supported range 0.005 <= x <= 3.3001
    """
    # A zero wavelength would otherwise surface as ZeroDivisionError; report it
    # the same way as any other wavelength the law does not cover.
    if lam <= 0:
        raise ValueError('Cardelli law is not defined at given wavelength')

    # Value x in Cardelli89 is measured in inverse microns, mu^-1
    x = 1.0 / lam * 1e4

    # Originally, Cardelli defines extinction law only up to x = 0.3, but we
    # extend it to ~200 microns (i.e. x ~= 0.005) to include Spitzer/MIPS
    if 0.005 <= x <= 1.1:
        ax = 0.574 * x ** 1.61
        bx = -0.527 * x ** 1.61
    elif 1.1 < x <= 3.3001:
        y = x - 1.82
        # Polynomial coefficients in increasing powers of y (y**0 ... y**7).
        a_coeffs = (1.0, 0.17699, -0.50447, -0.02427, 0.72085, 0.01979,
                    -0.77530, 0.32999)
        b_coeffs = (0.0, 1.41338, 2.28305, 1.07233, -5.38434, -0.62251,
                    5.30260, -2.09002)
        ax = sum(coeff * y ** power for power, coeff in enumerate(a_coeffs))
        bx = sum(coeff * y ** power for power, coeff in enumerate(b_coeffs))
    else:
        raise ValueError('Cardelli law is not defined at given wavelength')
    return ax + bx / R_V


def moving_average_2d(data, window):
    """Moving average on two-dimensional data.

    @param data: two-dimensional array-like to be smoothed
    @param window: two-dimensional array-like of weights; it is normalized to
                   unit sum on a private copy, so the caller's array is left
                   untouched
    @return: smoothed array with the same dimensions as data
    @raise ValueError: if the window weights sum to zero and therefore cannot
                       be normalized
    """
    # Work on a float copy: normalizing in place would modify the caller's
    # array and fails outright for integer-typed windows. The explicit float
    # dtype also avoids integer division under Python 2.
    kernel = np.array(window, dtype=float)
    total = kernel.sum()
    if total == 0:
        raise ValueError('window must have a non-zero sum')
    kernel /= total

    # Makes sure data array is a numpy array (MaskedArray is an ndarray
    # subclass and is passed through as-is).
    if not isinstance(data, np.ndarray):
        data = np.asarray(data)

    # The output array has the same dimensions as the input data
    # (mode='same') and symmetrical boundary conditions are assumed
    # (boundary='symm').
    return convolve2d(data, kernel, mode='same', boundary='symm')
  

def make_LINERdiagram(t):
    """
    Takes astropy table object as input and does publication-quality
    [OIII]/[OII] versus [OI]/H-alpha diagram, colour-coded by H-alpha
    equivalent width and overlaid with smoothed density contours. Saves it to
    LINERdiagram.png in the current working directory.

    The emission-line flux columns listed in the body are corrected for
    extinction in place, i.e. the table passed in is modified.

    @param t: astropy table object returned by Table.read() function
    """

    # Row selection: signal-to-noise above 3 in both lines.
    # NOTE(review): the cut uses F6366_OI_FLX while the X axis below is built
    # from F6302_OI_FLX -- confirm this is intended.
    snr_mask = (t['F6366_OI_FLX'] / t['F6366_OI_FLX_ERR'] > 3) & \
               (t['F6565_H_ALPHA_FLX'] / t['F6565_H_ALPHA_FLX_ERR'] > 3)

    # correction for background extinction
    balmer_decrement_theoretical = 2.85
    balmer_decrement = t['F6565_H_ALPHA_FLX'] / t['F4863_H_BETA_FLX']
    ebv = 1.97 * np.log10(balmer_decrement / balmer_decrement_theoretical)
    # Rows with an unusable or below-theoretical decrement get a fixed 0.01.
    ebv[(balmer_decrement < balmer_decrement_theoretical) | np.isnan(ebv) | np.isinf(ebv)] = 0.01

    # IMPORTANT!!!
    # BE SURE THAT ALL LINES USED BELOW ARE IN THIS LIST
    # Column name and wavelength are kept as pairs so that the two can never
    # get out of step (a separate wavelength list previously lacked 6733, so
    # F6733_SII_FLX was silently left uncorrected).
    emission_lines = (
        ('F3727_OII_FLX', 3727),
        ('F3730_OII_FLX', 3730),
        ('F4863_H_BETA_FLX', 4863),
        ('F4960_OIII_FLX', 4960),
        ('F5008_OIII_FLX', 5008),
        ('F6302_OI_FLX', 6302),
        ('F6366_OI_FLX', 6366),
        ('F6550_NII_FLX', 6550),
        ('F6565_H_ALPHA_FLX', 6565),
        ('F6585_NII_FLX', 6585),
        ('F6718_SII_FLX', 6718),
        ('F6733_SII_FLX', 6733),
    )
    for name, wave in emission_lines:
        t[name] = t[name] * 10 ** (0.4 * ebv * 3.1 * unred_cardelli(wave))

    # prepare X and Y and Z data (the row selection is applied once)
    selected = t[snr_mask]
    Y = np.log10(selected['F5008_OIII_FLX'] /
                 (selected['F3727_OII_FLX'] + selected['F3730_OII_FLX']))
    X = np.log10(selected['F6302_OI_FLX'] / selected['F6565_H_ALPHA_FLX'])
    Z = selected['F6565_H_ALPHA_EW']

    xra = np.log10([0.002, 1.0])
    yra = np.log10([0.03, 30])

    # 2D histogram of the points, used further down for the density contours.
    # Non-finite ratios are dropped before binning.
    good = np.isfinite(X) & np.isfinite(Y)
    myrange = np.array([xra, yra])
    counts, xbins, ybins, image = plt.hist2d(X[good], Y[good],
                                             bins=100, range=myrange)
    # Only the counts are needed; wipe the histogram image from the figure.
    plt.clf()
    # Single-argument form prints the same text under Python 2 and 3.
    print('N= %d' % len(X))

    # setting larger tickmarks and fonts for publication style plot
    plt.rcdefaults()
    plt.rcParams.update({
        'xtick.major.size': 7.0,
        'xtick.minor.size': 4.0,
        'ytick.major.size': 7.0,
        'ytick.minor.size': 4.0,
        'font.size': 16
    })

    marker_size = 8  # size of points

    # The colour limits live inside LogNorm: passing vmin/vmax next to an
    # explicit norm is rejected by current matplotlib.
    plt.scatter(X, Y, c=Z, s=marker_size, norm=LogNorm(vmin=3, vmax=300),
                cmap="gist_rainbow", lw=0, alpha=0.5)
    l_f = LogFormatter(base=10, labelOnlyBase=False)
    lvls = (5, 10., 20., 50., 100., 200.)
    CB = plt.colorbar(ticks=lvls, format=l_f)
    CB.set_label(r"$H_\alpha$ equivalent width ($\AA$)", labelpad=10)

    plt.xlabel(r'$[OI]/H_\alpha \, (6300\AA/6563\AA)$', labelpad=10)
    plt.ylabel(r'$[OIII]/[OII] \, (5007\AA/3727\AA)$', labelpad=2)

    # Kewley et al. (2006) MNRAS 372:961
    tt = np.arange(0.01, 1.0, 0.01)
    plt.plot(np.log10(tt), -1.701 * (np.log10(tt)) - 2.163, 'b--', linewidth=2.0)
    tt = np.arange(0.09, 1.0, 0.01)
    plt.plot(np.log10(tt), 1.0 * (np.log10(tt)) + 0.7, 'r-', linewidth=2.0)
    plt.text(-0.9, 1.1, '%d galaxies\n' % (len(X)))

    # Axes are linear in log10(ratio); the tick labels show the ratio itself.
    ax = plt.gca()
    ax.set_xlim(xra)
    ax.set_ylim(yra)
    plt.xticks(
        np.log10([0.002, 0.005, 0.0075, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06,
                  0.07, 0.08, 0.09, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8,
                  0.9, 1.0]),
        ['0.002', '', '', '0.01', '', '', '', '0.05', '', '', '', '', '0.1',
         '0.2', '', '', '0.5', '', '', '', '', '1.0'])
    plt.yticks(
        np.log10([0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.2, 0.3,
                  0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                  10, 30]),
        ['', '', '', '', '', '', '', '0.1', '0.2', '0.3', '', '0.5', '', '',
         '', '', '1', '2', '3', '', '5', '', '', '', '', '10', '30'])

    levels = np.logspace(0.0, 2.8, 10)

    # Smoothing
    m, n = 4, 4  # shape of the window array
    win = np.ones((m, n))
    counts1 = moving_average_2d(counts, win)

    # hist2d returns counts indexed [x, y]; contour expects [row=y, col=x],
    # hence the transpose.
    plt.contour(counts1.transpose(), levels,
                extent=[xbins.min(), xbins.max(), ybins.min(), ybins.max()],
                linewidths=1.5, colors='black')

    plt.savefig('LINERdiagram.png', bbox_inches='tight')
    plt.close()

# Script entry point: this block runs when the file is executed directly
# (`python <this script> /path/to/catalog_table.fits`, or `%run` from ipython).
# It first checks input arguments and, if the table file exists, reads it and
# calls the plotting function defined above.
if __name__ == '__main__':

    if len(sys.argv) == 2 and os.path.isfile(sys.argv[1]):
        # filename and t are kept as module-level names so they remain
        # available for inspection after `%run` in an interactive session.
        filename = sys.argv[1]
        t = Table.read(filename)
        make_LINERdiagram(t)
    else:
        # Single-argument print() emits the same text under Python 2 and 3.
        print('Wrong input arguments or table filename!')
                                                                                                                                                                                          
