from astropy.table import Table
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import griddata
from scipy import polyfit, poly1d
import os

# Matplotlib defaults for publication-quality figures, applied in order.
_PUBLICATION_RC = [
    # Ticks point inwards and are drawn on all four sides of each axes box.
    ('xtick.direction', 'in'),
    ('ytick.direction', 'in'),
    ('xtick.top', True),
    ('ytick.right', True),
    # Use font type 42 (TrueType) in PostScript and PDF output. This is
    # recommended for any science-publication figure, and some paper
    # submission systems (e.g. MNRAS) require it.
    # Background on font types: https://en.wikipedia.org/wiki/PostScript_fonts
    ('ps.fonttype', 42),
    ('pdf.fonttype', 42),
    # Render all text with LaTeX in a serif face at 16 pt.
    ('text.usetex', True),
    ('font.family', 'serif'),
    ('font.size', 16),
]
for _rc_key, _rc_value in _PUBLICATION_RC:
    plt.rcParams[_rc_key] = _rc_value

def human_format(num):
    """Format *num* with two decimals and a metric-style suffix.

    Examples: 1500 -> '1.50K', 4.6e9 -> '4.60G', -2.5e6 -> '-2.50M'.

    Values beyond the largest suffix ('P') are printed with 'P' instead of
    raising IndexError. A value that rounds up to 1000 at two decimals is
    promoted to the next suffix ('1.00M', not '1000.00K').
    """
    suffixes = ['', 'K', 'M', 'G', 'T', 'P']  # add more suffixes if you need them
    largest = len(suffixes) - 1
    magnitude = 0
    while abs(num) >= 1000 and magnitude < largest:
        magnitude += 1
        num /= 1000.0
    # '%.2f' rounding can turn e.g. 999.999 into "1000.00"; carry it over.
    if abs(float('%.2f' % num)) >= 1000 and magnitude < largest:
        magnitude += 1
        num /= 1000.0
    return '%.2f%s' % (num, suffixes[magnitude])

# Colour map used to tell models (and, per mass, profile ages) apart.
cmap = plt.get_cmap('hsv')
# Masses and metallicities as they appear in the model directory names M<mass>Z<z>.
masses = ['0.6', '0.8', '1.0', '1.2', '1.5', '2.0', '2.5', '3.0', '4.0',
          '5.0', '6.0', '7.0', '8.0', '10', '12', '15', '20', '25', '30', '40', '50']
z_arr = ['0.02']
markers = ['k-', 'g-']  # not referenced by the plotting code below


def _new_figure_with_axes():
    """Create a new figure with a single subplot and return (figure, axes)."""
    figure = plt.figure()
    return figure, figure.add_subplot(111)


# Summary figures: the main loop adds one curve per model to each of these.
hr, ax_hr = _new_figure_with_axes()
rm, ax_rm = _new_figure_with_axes()
lm, ax_lm = _new_figure_with_axes()
rt, ax_rt = _new_figure_with_axes()
lt, ax_lt = _new_figure_with_axes()
ctt, ax_ctt = _new_figure_with_axes()
crt, ax_crt = _new_figure_with_axes()
ppcno, ax_ppcno = _new_figure_with_axes()
dmm, ax_dmm = _new_figure_with_axes()
dmtot, ax_dmtot = _new_figure_with_axes()

# Per-mass profile figures; the main loop clears them and recreates their
# axes for every mass, so no axes are made here.
mtp, mrp, rpp, mhp, mcp = [plt.figure() for _ in range(5)]

# Per-model fractions filled in by the main loop.
h2burned_frac, mtotloss_frac = [], []

# Text outputs: fitted power-law slopes and per-profile polytropic indices.
slopes = open('slopes.txt', 'w+')
polyind = open('polyind.txt', 'w+')

# Log-space ZAMS points for the HR, mass-luminosity and mass-radius fits.
HR_fit, ML_fit, MR_fit = [], [], []

# Main loop over every (metallicity, mass) model: plot main-sequence tracks,
# collect ZAMS points for the scaling-relation fits, and plot internal profiles.
for zval in z_arr:
    for j, mval in enumerate(masses):
        path = '/data1/kirg/stellar_models/M%sZ%s/LOGS/history.data' % (mval, zval)
        path_ind = '/data1/kirg/stellar_models/M%sZ%s/LOGS/profiles.index' % (mval, zval)
        tbl = Table.read(path, format='ascii', header_start=4, data_start=5)

        # Start of the main sequence: first model where log_L < log_LH, i.e. the
        # hydrogen-burning luminosity exceeds the total luminosity.
        ms_start = np.where((tbl['log_L'] - tbl['log_LH']) < 0.0)[0][0]
        tbl_ms = tbl[ms_start:]
        # End of the main sequence: central hydrogen falls below 0.1% of its
        # value at the start of the main sequence.
        ms_end = np.where(tbl_ms['center_h1'] < 1e-3 * tbl_ms['center_h1'][0])[0][0]
        tbl_ms = tbl_ms[:ms_end]

        mass_color = cmap(j / float(len(masses)))
        ax_hr.plot(10**tbl_ms['log_Teff'], 10**tbl_ms['log_L'], c=mass_color, label=mval)
        ax_rm.plot(tbl_ms['star_mass'], 10**tbl_ms['log_R'], c=mass_color, label=mval)
        ax_lm.plot(tbl_ms['star_mass'], 10**tbl_ms['log_L'], c=mass_color, label=mval)
        ax_rt.plot(tbl_ms['star_age'], 10**tbl_ms['log_R'], c=mass_color, label=mval)
        ax_lt.plot(tbl_ms['star_age'], 10**tbl_ms['log_L'], c=mass_color, label=mval)
        ax_ctt.plot(tbl_ms['star_age'], 10**tbl_ms['log_center_T'], c=mass_color, label=mval)
        ax_crt.plot(tbl_ms['star_age'], 10**tbl_ms['log_center_Rho'], c=mass_color, label=mval)
        # pp/CNO ratio, assuming the 'pp' and 'cno' columns are log10 values.
        ax_ppcno.plot(tbl_ms['center_h1'], 10**(tbl_ms['pp'] - tbl_ms['cno']), c=mass_color, label=mval)

        # Log-space points at the first main-sequence model, fitted after the loop.
        HR_fit.append([tbl_ms['log_Teff'][0], tbl_ms['log_L'][0]])
        ML_fit.append([np.log10(tbl_ms['star_mass'][0]), tbl_ms['log_L'][0]])
        MR_fit.append([np.log10(tbl_ms['star_mass'][0]), tbl_ms['log_R'][0]])

        # Energy radiated per step. This presumably takes log_L in Lsun (~4e33 erg/s)
        # and log_dt in years (~3.15e7 s) -- TODO confirm units. The 0.5 prefactor
        # is kept unchanged and is not explained here.
        h2_burned = 0.5 * 3.15e7 * 4e33 * 10**tbl_ms['log_L'] * 10**tbl_ms['log_dt']
        # Hydrogen mass consumed, in units of 2e33 g (~Msun): E / c^2 / 0.008.
        # 0.008 is presumably the rest-mass fraction released by hydrogen fusion.
        dMH = (np.sum(h2_burned) / ((3e10)**2)) / (0.008 * 2e33)
        h2burned_frac.append(dMH / float(mval))
        # Fraction of the initial mass lost by the end of the main sequence.
        dMtot = (float(mval) - tbl_ms['star_mass'][-1]) / float(mval)
        mtotloss_frac.append(dMtot)

        index_table = Table.read(path_ind, format='ascii.no_header', data_start=1)
        # The per-mass profile figures are reused for every mass: clear them
        # and give each a fresh axes.
        for prof_fig in (mtp, mrp, rpp, mhp, mcp):
            prof_fig.clf()
        ax_mtp = mtp.add_subplot(111)
        ax_mrp = mrp.add_subplot(111)
        ax_rpp = rpp.add_subplot(111)
        ax_mhp = mhp.add_subplot(111)
        ax_mcp = mcp.add_subplot(111)

        ms_age_start = tbl_ms['star_age'][0]
        ms_age_end = tbl_ms['star_age'][-1]
        N_prof = 0.0
        for indval in index_table:
            path_profile = '/data1/kirg/stellar_models/M%sZ%s/LOGS/profile%i.data' % (mval, zval, int(indval['col3']))
            if not os.path.isfile(path_profile):
                continue
            # Read only the header first, so profiles outside the main sequence
            # are skipped without loading their data section.
            prof_hdr = Table.read(path_profile, format='ascii', header_start=1, data_start=2, data_end=3)
            prof_age = prof_hdr[0]['star_age']
            if not (ms_age_start < prof_age < ms_age_end):
                continue
            prof_tbl = Table.read(path_profile, format='ascii', header_start=4, data_start=5)

            print('%s %s' % (mval, prof_age))
            age_str = human_format(prof_age)
            prof_color = cmap(N_prof / 10.0)
            ax_mtp.plot(prof_tbl['mass'], 10**prof_tbl['logT'], c=prof_color, label=age_str)
            ax_mrp.plot(prof_tbl['mass'], 10**prof_tbl['logRho'], c=prof_color, label=age_str)
            ax_rpp.plot(10**prof_tbl['logRho'], 10**prof_tbl['logP'], c=prof_color, label=age_str)
            # Hydrogen abundance against the mass coordinate; the x axis is
            # labelled m/Msun below.
            ax_mhp.plot(prof_tbl['mass'], prof_tbl['h1'], c=prof_color, label=age_str)
            # Mark zones where gradr > grada (Schwarzschild criterion) with 1.
            conv = np.zeros(prof_tbl['logT'].shape)
            conv[prof_tbl['gradr'] > prof_tbl['grada']] = 1
            ax_mcp.plot(prof_tbl['mass'], conv, c=prof_color, label=age_str)
            N_prof += 1.0
            # Fit log P = a * log rho + b. For a polytrope P ~ rho**(1 + 1/n),
            # so the polytropic index is n = 1 / (a - 1).
            slope = np.polyfit(prof_tbl['logRho'], prof_tbl['logP'], 1)[0]
            polyind.write("M=%s\t T=%s\t poly_slope=%.2f\n" % (mval, age_str, 1.0 / (slope - 1.0)))

        mass_tag = mval.replace('.', 'd')
        legend_kw = dict(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, ncol=3,
                         mode="expand", borderaxespad=0.)

        ax_mtp.set_yscale('log')
        ax_mtp.set_xlabel(r'$m/M_{\odot}$')
        ax_mtp.set_ylabel(r'$T, K$')
        ax_mtp.legend(**legend_kw)
        mtp.savefig('task3_MTp_%s.pdf' % mass_tag, bbox_inches='tight', dpi=150)

        ax_mcp.set_xlabel(r'$m/M_{\odot}$')
        ax_mcp.legend(**legend_kw)
        mcp.savefig('task3_MCp_%s.pdf' % mass_tag, bbox_inches='tight', dpi=150)

        ax_mrp.set_yscale('log')
        ax_mrp.set_xlabel(r'$m/M_{\odot}$')
        ax_mrp.set_ylabel(r'$\rho, g/cm^3$')
        ax_mrp.legend(**legend_kw)
        mrp.savefig('task3_MRhop_%s.pdf' % mass_tag, bbox_inches='tight', dpi=150)

        ax_rpp.set_yscale('log')
        ax_rpp.set_xscale('log')
        ax_rpp.set_ylabel(r'$P$')
        ax_rpp.set_xlabel(r'$\rho, g/cm^3$')
        ax_rpp.legend(**legend_kw)
        rpp.savefig('task3_PRhop_%s.pdf' % mass_tag, bbox_inches='tight', dpi=150)

        ax_mhp.set_ylabel(r'$H_2$')
        ax_mhp.set_xlabel(r'$m/M_{\odot}$')
        ax_mhp.legend(**legend_kw)
        mhp.savefig('task3_mHp_%s.pdf' % mass_tag, bbox_inches='tight', dpi=150)
def _fit_slope(x, y, mask=None):
    """Return the slope of the least-squares straight line through (x, y).

    x, y : 1-D arrays of equal length.
    mask : optional boolean array; when given, only points where it is True
           take part in the fit.
    """
    if mask is not None:
        x, y = x[mask], y[mask]
    # np.polyfit is called directly; SciPy's top-level re-export of it
    # (scipy.polyfit) is deprecated.
    return np.polyfit(x, y, 1)[0]


# Power-law exponents of the scaling relations at the first main-sequence
# model, from straight-line fits to the log-space points collected above
# (columns: [log x, log y]). The split masks are strict on both sides, so a
# model whose log-mass equals the break point exactly is left out of both
# segments.
HR_fit = np.array(HR_fit)
slopes.write("L_T\t poly_slope=%.2f\n" % _fit_slope(HR_fit[:, 0], HR_fit[:, 1]))

MR_fit = np.array(MR_fit)
log_mass_break_R = np.log10(1.5)
slopes.write("M_Rl1.5\t poly_slope=%.2f\n"
             % _fit_slope(MR_fit[:, 0], MR_fit[:, 1], MR_fit[:, 0] < log_mass_break_R))
slopes.write("M_Rh1.5\t poly_slope=%.2f\n"
             % _fit_slope(MR_fit[:, 0], MR_fit[:, 1], MR_fit[:, 0] > log_mass_break_R))

ML_fit = np.array(ML_fit)
log_mass_break_L = np.log10(10)
slopes.write("M_Ll10\t poly_slope=%.2f\n"
             % _fit_slope(ML_fit[:, 0], ML_fit[:, 1], ML_fit[:, 0] < log_mass_break_L))
slopes.write("M_Lh10\t poly_slope=%.2f\n"
             % _fit_slope(ML_fit[:, 0], ML_fit[:, 1], ML_fit[:, 0] > log_mass_break_L))

# Both text outputs are complete at this point; close them so their contents
# are flushed to disk rather than left to interpreter shutdown.
slopes.close()
polyind.close()

def _finish_track_plot(fig, ax, filename, xlabel, ylabel, xscale='log', invert_x=False):
    """Style one summary plot (log y axis, labels, five-column top legend) and save it.

    xscale   : scale for the x axis, or None to leave it unchanged (linear).
    invert_x : flip the x axis (used for the temperature and mass axes).
    """
    if xscale is not None:
        ax.set_xscale(xscale)
    ax.set_yscale('log')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if invert_x:
        ax.invert_xaxis()
    ax.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc=3, ncol=5, mode="expand", borderaxespad=0.)
    fig.savefig(filename, bbox_inches='tight', dpi=150)


# Task 1: HR diagram and mass-radius / mass-luminosity tracks.
_finish_track_plot(hr, ax_hr, 'task1_HR.pdf', r'$T_{eff}$, K', r'$L/L_{\odot}$', invert_x=True)
_finish_track_plot(rm, ax_rm, 'task1_MR.pdf', r'$M/M_{\odot}$', r'$R/R_{\odot}$', invert_x=True)
_finish_track_plot(lm, ax_lm, 'task1_ML.pdf', r'$M/M_{\odot}$', r'$L/L_{\odot}$', invert_x=True)

# Task 2: time evolution of radius, luminosity and central temperature/density.
_finish_track_plot(rt, ax_rt, 'task2_Rt.pdf', r'$t, y$', r'$R/R_{\odot}$')
_finish_track_plot(lt, ax_lt, 'task2_Lt.pdf', r'$t, y$', r'$L/L_{\odot}$')
_finish_track_plot(ctt, ax_ctt, 'task2_cTempt.pdf', r'$t, y$', r'$T, K$')
_finish_track_plot(crt, ax_crt, 'task2_cRho.pdf', r'$t, y$', r'$\rho, g/cm^3$')

# Task 5: pp/CNO ratio against central hydrogen, linear x axis.
_finish_track_plot(ppcno, ax_ppcno, 'task5_ppcno.pdf', r'$H_2$', r'$pp/cno$', xscale=None)

# Mass budget versus initial mass, each as a fraction of the initial mass:
# hydrogen mass burned over the main sequence (task 4) and total mass lost
# (task 6).
# dtype=float: the np.float alias was only ever the builtin float; it was
# deprecated in NumPy 1.20 and removed in NumPy 1.24.
mass_values = np.array(masses, dtype=float)

ax_dmm.plot(mass_values, h2burned_frac, marker='*')
ax_dmm.set_xlabel(r'$M/M_{\odot}$')
ax_dmm.set_ylabel(r'$\Delta M_H/M_*$')
ax_dmm.set_yscale('log')
ax_dmm.set_xscale('log')
dmm.savefig('task4_massloss.pdf', bbox_inches='tight', dpi=150)

ax_dmtot.plot(mass_values, mtotloss_frac, marker='*')
ax_dmtot.set_xlabel(r'$M/M_{\odot}$')
ax_dmtot.set_ylabel(r'$\Delta M_*/M_*$')
ax_dmtot.set_yscale('log')
ax_dmtot.set_xscale('log')
dmtot.savefig('task6_totmassloss.pdf', bbox_inches='tight', dpi=150)

