import sys
import os
import pylab as plt
import numpy as np
import matplotlib as mpl
from astropy import cosmology
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, FormatStrFormatter


from astropy.table import Table
from astropy.io import fits
from astropy.cosmology import WMAP9 as cosmo
import astropy.units as u
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable 


# Read the combined catalogue that holds the velocity dispersions compared below.
t = Table.read('combine_RCSED_DR7_veldisp.fits')

# 'velDisp' is drawn on the x axis (labelled sigma_SDSS further down);
# 'ssp_dispvel' and 'exp_dispvel' feed the 'SSP' and 'EXP-SFH' panels.
d_sdss = t['velDisp']
d_ssp = t['ssp_dispvel']
d_exp = t['exp_dispvel']

# Plot window shared by both panels: x is the SDSS value, y the difference.
xlims = [30, 370]
ylims = [-50, 80]

# Two vertically stacked panels separated by a small gap.
fig = plt.figure(figsize=(5, 5))
fig.subplots_adjust(hspace=0.05, wspace=0)
ax1 = fig.add_subplot(2, 1, 1)
ax2 = fig.add_subplot(2, 1, 2)

# Log-scaled 2D histograms of (model - SDSS) against SDSS, one per panel.
# A fresh LogNorm is created for every call so that each panel keeps its own
# colour scaling.
for panel, sigma in ((ax1, d_ssp), (ax2, d_exp)):
    panel.hist2d(
        d_sdss,
        sigma - d_sdss,
        bins=[200, 100],
        range=[xlims, ylims],
        cmap='Blues',
        norm=LogNorm(),
    )

# Identical limits, tick spacing and zero-difference reference line on both
# panels.
for ax in (ax1, ax2):
    ax.set_xlim(xlims)
    ax.set_ylim(ylims)
    ax.xaxis.set_major_locator(MultipleLocator(50))
    ax.xaxis.set_minor_locator(MultipleLocator(10))
    ax.yaxis.set_major_locator(MultipleLocator(20))
    ax.yaxis.set_minor_locator(MultipleLocator(5))
    # Dashed black line at y = 0; it spans well beyond the visible x range.
    ax.plot([0, 1e3], [0, 0], 'k--', lw=2)

# Hide the x tick labels of the upper panel (the lower panel carries them).
# tick_params switches the labels off without replacing the tick formatter,
# unlike set_xticklabels('') which installs an empty fixed formatter on top of
# the MultipleLocator configured above.
ax1.tick_params(axis='x', labelbottom=False)

# One y label shared by both panels, placed just left of the figure area.
fig.text(-0.015, 0.5, r'$\sigma_{RCSED}$ - $\sigma_{SDSS}$, km/s', va='center', rotation='vertical')
ax2.set_xlabel(r'$\sigma_{SDSS}$, km/s')

# Panel tags, positioned in data coordinates.
ax1.text(50, 60, 'SSP')
ax2.text(50, 60, 'EXP-SFH')

# Running median and scatter of the differences in bins of d_sdss.
def _binned_median_std(x, y, edges, y_range):
    """Return the per-bin median and standard deviation of *y* in bins of *x*.

    Parameters
    ----------
    x : array-like
        Values used to assign each point to a bin.
    y : array-like
        Values summarised in each bin; same length as *x*.
    edges : array-like
        Monotonically increasing bin edges (``len(edges) - 1`` bins).
    y_range : sequence of two floats
        ``(low, high)``; only points with ``low < y < high`` are used.

    Returns
    -------
    (median, std) : tuple of numpy.ndarray
        One entry per bin; NaN where a bin has no usable points.
    """
    low, high = y_range
    # NaN fails both comparisons anyway; the explicit test documents intent.
    usable = ~np.isnan(y) & (y > low) & (y < high)
    which = np.digitize(x, edges)

    n_bins = len(edges) - 1
    median = np.full(n_bins, np.nan)
    std = np.full(n_bins, np.nan)
    # np.digitize labels the real bins 1..n_bins; 0 and len(edges) are the
    # underflow/overflow bins and are deliberately skipped.
    for k in range(1, n_bins + 1):
        sample = y[usable & (which == k)]
        if len(sample):
            median[k - 1] = np.median(sample)
            std[k - 1] = np.std(sample)
    return median, std


n_edges = 15  # 15 edges -> 14 bins across the plotted x range
y1 = d_ssp - d_sdss
y2 = d_exp - d_sdss
bins = np.linspace(xlims[0], xlims[1], n_edges)
xbins = 0.5 * (bins[:-1] + bins[1:])  # bin centres

# Each panel is filtered on its own difference column.
r1_med, r1_std = _binned_median_std(d_sdss, y1, bins, ylims)
r2_med, r2_std = _binned_median_std(d_sdss, y2, bins, ylims)

# Dashed median curve with +/- 1 standard-deviation bars on each panel.
for ax, med, std in ((ax1, r1_med, r1_std), (ax2, r2_med, r2_std)):
    ax.plot(xbins, med, '--', c='darkred', lw=2)
    ax.errorbar(xbins, med, std, ecolor='darkred', fmt='none')


# Write the figure to a tightly cropped PDF, then release it.
fig.savefig('fig_veldisp.pdf', bbox_inches='tight')
plt.close(fig)




