import astropy
import math
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import matplotlib.pyplot as plt
import numpy as np
import random
import matplotlib.transforms as transforms
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LogNorm
from astropy.io import fits
from matplotlib import gridspec

# Bold 5 pt font settings, presumably meant for ``matplotlib.rc('font', **font)``.
# NOTE(review): not referenced anywhere in the visible code of this file, and
# 'normal' is not one of matplotlib's generic font families -- confirm whether
# this is still needed before relying on it.
font = dict(
    family='normal',
    weight='bold',
    size=5,
)


def lupt(x, scale=1):
    """Return the inverse hyperbolic sine of ``scale * x`` as a float.

    The name presumably refers to asinh ("Lupton") magnitudes -- confirm.

    Args:
        x: real scalar to stretch.
        scale: multiplier applied to ``x`` before ``asinh``.  The default of
            1 reproduces the previous hard-coded behaviour.

    Returns:
        ``math.asinh(scale * x)``.
    """
    return math.asinh(scale * x)
  
def delete_er(s):
    """Return ``s`` with every ``*``, ``[`` and ``]`` character removed.

    Used to strip marker characters from header value strings before they
    are passed to ``float()``.
    """
    cleaned = s
    for marker in '*[]':
        cleaned = cleaned.replace(marker, '')
    return cleaned

 
def _parse_fitsect(section):
    """Return the lower ``(x, y)`` bounds of a ``"[x1:x2,y1:y2]"`` section.

    Both bounds are returned as floats exactly as written in the string
    (FITS image sections are normally 1-based).
    """
    x_part, y_part = section.replace('[', '').replace(']', '').split(',')[:2]
    return float(x_part.split(':')[0]), float(y_part.split(':')[0])


def _fit_centre_index(header, key, start):
    """Return header value ``key`` minus ``start``, truncated to an ``int``.

    The value is read as the text before ``+/-`` with ``*``, ``[`` and ``]``
    removed (via ``delete_er``).  Subtracting the 1-based section start is
    intended to give a 0-based index into the cutout array -- this assumes
    the header coordinate is also 1-based; confirm against the fitting tool.
    """
    value = float(delete_er(header[key].split('+/-')[0]))
    return int(value - start)


def _add_side_colorbar(mappable, ax):
    """Draw a colour bar for ``mappable`` in a narrow axes right of ``ax``."""
    cax = make_axes_locatable(ax).append_axes("right", size="15%", pad=0.05)
    return plt.colorbar(mappable, cax=cax)


def _style_offset_axes(ax, title, xlim, ylim):
    """Apply the shared square aspect, title, limits and arcsec labels."""
    ax.set_aspect(1)
    ax.set_title(title)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.set_xlabel(r"$\Delta X$, arcsec")
    ax.set_ylabel(r"$\Delta Y$, arcsec")


def print_plots(name, output='foo.png', dpi=700):
    """Render a five-panel fit summary from FITS file ``name`` and save it.

    Extensions read from ``name``:
      * 1 -- shown as "Raw image".  Its ``CD1_1``/``CD1_2`` keywords give
        the pixel scale (arcsec per pixel, assuming the CD matrix is in
        degrees).
      * 2 -- shown as "Model".  Its header supplies ``NAXIS1``/``NAXIS2``,
        ``FITSECT``, the fit centre ``1_XC``/``1_YC`` and the ``PSF`` path.
      * 3 -- shown as "Residuals".  Divided by extension 2 it gives the
        "Relative residuals" panel.
    The primary HDU of the file named by the ``PSF`` keyword (opened
    relative to the current directory) is shown as a fifth panel.

    Image axes are offsets from the fitted centre in arcsec.  PSF axes are
    offsets from the array centre, using the same scale.  The layout
    presumably expects GALFIT output -- confirm.

    Args:
        name: path of the FITS file to plot.
        output: path of the image file to write (default ``'foo.png'``).
        dpi: resolution passed to ``savefig`` (default 700).
    """
    # Read headers and copy the data into memory so both files are closed
    # before any plotting happens.
    with fits.open(name) as hdulist:
        with fits.open(hdulist[2].header['PSF']) as hdulistpsf:
            hdulist.info()
            image_header = hdulist[1].header
            model_header = hdulist[2].header
            raw = np.array(hdulist[1].data)
            model = np.array(hdulist[2].data)
            residual = np.array(hdulist[3].data)
            psf = np.array(hdulistpsf[0].data)

    # Zero-valued model pixels yield inf/nan here.  Only the warning is
    # silenced; the values are unchanged.
    with np.errstate(divide='ignore', invalid='ignore'):
        relative = residual / model

    width = model_header['NAXIS1']   # number of columns (x)
    height = model_header['NAXIS2']  # number of rows (y)
    x_start, y_start = _parse_fitsect(model_header['FITSECT'])
    x_c = _fit_centre_index(model_header, '1_XC', x_start)
    y_c = _fit_centre_index(model_header, '1_YC', y_start)
    # One scale is used for both axes.
    scale = 3600 * math.hypot(image_header['CD1_1'], image_header['CD1_2'])

    # Pixel grids in arcsec, offset so the fitted centre sits at (0, 0).
    # Row 0 and column 0 are skipped, matching the data slices below.
    dy, dx = np.mgrid[1:height, 1:width]
    dx = (dx - x_c) * scale
    dy = (dy - y_c) * scale
    inner = (slice(1, height), slice(1, width))
    z_rel = relative[inner]
    z_res = residual[inner]
    z_model = model[inner]
    z_raw = raw[inner]

    # numpy shape is (rows, cols), i.e. (y, x); keep the axes straight so
    # non-square PSF images are indexed correctly.
    psf_rows, psf_cols = psf.shape
    psf_dy, psf_dx = np.mgrid[1:psf_rows, 1:psf_cols]
    psf_dx = (psf_dx - psf_cols // 2) * scale
    psf_dy = (psf_dy - psf_rows // 2) * scale
    psf_z = np.arcsinh(psf[1:, 1:] * 1000)

    # Shift model and raw image by one common offset so both are >= 1.
    # Then apply the same arcsinh stretch and display range to both.
    offset = min(z_raw.min(), z_model.min())
    stretched_model = np.arcsinh((z_model - offset + 1) / 5)
    stretched_raw = np.arcsinh((z_raw - offset + 1) / 5)
    flux_min = np.percentile(stretched_model, 0.25)
    flux_max = np.percentile(stretched_model, 99.99)

    fig = plt.figure()
    gs = gridspec.GridSpec(3, 2, width_ratios=[1, 1.3])
    ax_model = fig.add_subplot(gs[0])
    ax_rel = fig.add_subplot(gs[1])
    ax_raw = fig.add_subplot(gs[2])
    ax_res = fig.add_subplot(gs[3])
    ax_psf = fig.add_subplot(gs[4])

    residual_style = dict(cmap=cm.jet, linewidth=0, vmin=-0.3, vmax=0.3)
    flux_style = dict(cmap='gray', linewidth=0, vmin=flux_min, vmax=flux_max)

    _add_side_colorbar(ax_rel.pcolor(dx, dy, z_rel, **residual_style), ax_rel)
    _add_side_colorbar(ax_res.pcolor(dx, dy, z_res, **residual_style), ax_res)
    ax_model.pcolor(dx, dy, stretched_model, **flux_style)
    ax_raw.pcolor(dx, dy, stretched_raw, **flux_style)
    ax_psf.pcolor(psf_dx, psf_dy, psf_z, cmap='gray', linewidth=0)

    xlim = (dx.min(), dx.max())
    ylim = (dy.min(), dy.max())
    _style_offset_axes(ax_rel, "Relative residuals", xlim, ylim)
    _style_offset_axes(ax_model, "Model", xlim, ylim)
    _style_offset_axes(ax_raw, "Raw image", xlim, ylim)
    _style_offset_axes(ax_res, "Residuals", xlim, ylim)
    _style_offset_axes(ax_psf, "PSF",
                       (psf_dx.min(), psf_dx.max()),
                       (psf_dy.min(), psf_dy.max()))

    plt.tight_layout()
    fig.savefig(output, dpi=dpi)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
  
  
# Only render when run as a script, so importing this module has no side
# effects (no FITS reads, no PNG written).
if __name__ == '__main__':
    print_plots('output.fits')