from astropy.table import Table
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import griddata

# Track grid: one run directory per (mass, metallicity) pair, read below from
# stellar_models/M<mass>Z<z>/LOGS/history.data. The strings are interpolated
# verbatim into that path, so they must match the directory names exactly.
masses = ['0.6', '0.8', '1.0', '1.2', '1.5', '2.0', '2.5', '3.0', '4.0', '5.0', '6.0', '7.0', '8.0']
z_arr = ['0.02', '0.001']
# Matplotlib line format for each metallicity in z_arr (same order).
markers = ['k-', 'g-']

# Observed target: effective temperature [K] and lg(L/L_sun), each with the
# uncertainty used as the Gaussian width when scoring models.
Teff = 5200
Teff_err = 300
lgL = 2.1
lgL_err = 0.1

# Accumulators filled by the track loop (first metallicity only).
# NOTE: the names are historical -- teff_arr collects model ages (star_age),
# lgl_arr collects the track's mass value (float of the masses entry) and
# prob_arr the matching unnormalised likelihood of each model.
teff_arr = []
lgl_arr = []
prob_arr = []
# Read every track, keep only its segment before hydrogen burning takes over,
# draw it on the HR diagram and -- for the first metallicity only -- score each
# model against the observed (Teff, lgL).
for i, zval in enumerate(z_arr):
    for mval in masses:
        # Path is relative to the current working directory.
        path = 'stellar_models/M%sZ%s/LOGS/history.data' % (mval, zval)
        tbl = Table.read(path, format='ascii', header_start=4, data_start=5)
        # Cut at the first model where log_L - log_LH < 0.05 dex, i.e. where
        # L_H supplies almost all of L (log_LH presumably being the
        # hydrogen-burning luminosity, as the tbl_prehburn name implies).
        hburn_ind = np.flatnonzero((tbl['log_L'] - tbl['log_LH']) < 0.05)
        if hburn_ind.size == 0 or hburn_ind[0] == 0:
            # Without a non-empty pre-burning segment nothing below can work;
            # fail with a message naming the offending track.
            raise ValueError('%s: no pre-hydrogen-burning segment found' % path)
        tbl_prehburn = tbl[:hburn_ind[0]]
        teff_lin = 10**tbl_prehburn['log_Teff']
        if i == 0:
            # Half the squared normalised distance in the (Teff, lgL) plane;
            # exp(-d_dist2) is an unnormalised Gaussian likelihood.
            d_dist2 = ((Teff - teff_lin)**2 / (2 * Teff_err**2)
                       + (lgL - tbl_prehburn['log_L'])**2 / (2 * lgL_err**2))
            prob_arr.extend(np.exp(-d_dist2))
            # Historical names: teff_arr receives ages, lgl_arr the track mass.
            teff_arr.extend(tbl_prehburn['star_age'])
            lgl_arr.extend([float(mval)] * len(tbl_prehburn))
        plt.plot(teff_lin, tbl_prehburn['log_L'], markers[i], zorder=-32)
        # Label the track with its mass at the last plotted model.
        plt.text(teff_lin[-1], tbl_prehburn['log_L'][-1], mval,
                 horizontalalignment='right', verticalalignment='top')

# HR diagram: the observed star with its error bars, on top of the tracks.
plt.errorbar([Teff], [lgL], xerr=[Teff_err], yerr=[lgL_err], barsabove=True, fmt='r*')
plt.xscale('log')
plt.xlabel(r'$T_{eff}$, K')
plt.ylabel(r'$lg(L/L_{\odot})$')
# Hot stars on the left, as is conventional for HR diagrams.
plt.gca().invert_xaxis()
plt.savefig('task1_diagram.pdf', dpi=150)
# Zoom on the target. Limits are (left, right) = (hot, cool): an ascending
# pair would silently undo invert_xaxis() above and flip the zoomed plot.
plt.xlim(5600, 4800)
plt.ylim(1.7, 2.5)
plt.savefig('task1_diagram_zoom.pdf', dpi=150)
plt.clf()
# Likelihood map over (age, mass) and likelihood-weighted parameter estimates.
# Historical names: teff_arr holds model ages (star_age), lgl_arr the track
# mass and prob_arr the unnormalised likelihood of each model.
teff_arr = np.asarray(teff_arr, dtype=float)
lgl_arr = np.asarray(lgl_arr, dtype=float)
prob_arr = np.asarray(prob_arr, dtype=float)

# Fixed (age, mass) window for the interpolated map; retune for another target.
xi = np.linspace(6e4, 1.2e5, 20)
yi = np.linspace(3.0, 7.0, 20)
# Grid points outside the convex hull of the models come back as NaN
# (griddata's default fill_value) and are left blank by contourf.
zi = griddata((teff_arr, lgl_arr), prob_arr, (xi[None, :], yi[:, None]), method='linear')
CS = plt.contourf(xi, yi, zi, 15, cmap=plt.cm.jet)
plt.colorbar(CS, label='relative likelihood')
# Save before the statistics so the map is kept even if they fail.
plt.savefig('task1_diagram_prob.pdf', dpi=150)

# Best-matching model, then a window of +-30% in age and +-2 in mass around it.
ind_max = np.nanargmax(prob_arr)
a_c, m_c = teff_arr[ind_max], lgl_arr[ind_max]
near = ((teff_arr > a_c * 0.7) & (teff_arr < a_c * 1.3)
        & (lgl_arr > m_c - 2.0) & (lgl_arr < m_c + 2.0)
        & np.isfinite(prob_arr))
teff_arr_near = teff_arr[near]
lgl_arr_near = lgl_arr[near]
prob_arr_near = prob_arr[near]
if not np.sum(prob_arr_near) > 0:
    raise ValueError('no model with non-zero likelihood near the best match')

# Weighted mean and standard deviation. Central moments keep the variance
# non-negative; E[x^2] - E[x]^2 can round below zero (-> NaN sigma), e.g. when
# nearly all the weight sits on a single mass track.
age_mean = np.average(teff_arr_near, weights=prob_arr_near)
mass_mean = np.average(lgl_arr_near, weights=prob_arr_near)
age_sigma = np.sqrt(np.average((teff_arr_near - age_mean)**2, weights=prob_arr_near))
mass_sigma = np.sqrt(np.average((lgl_arr_near - mass_mean)**2, weights=prob_arr_near))
# Single-argument print: identical output under Python 2 and Python 3.
print('%s %s %s %s' % (age_mean, age_sigma, mass_mean, mass_sigma))
