import numpy as np
from astropy.io import fits
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
from photutils import CircularAperture
from photutils import aperture_photometry
from astropy import units as u
import os
import matplotlib.pyplot as plt
from pylab import * # This program requires the matplotlib python library for the plots.
from scipy import ndimage
from matplotlib import rc
from scipy.interpolate import interp1d
from scipy import interpolate
from scipy.interpolate import griddata
import pylab
from scipy.ndimage.interpolation import rotate


# Load the stacked i-band cutout and rotate it by PA_angle degrees with
# scipy.ndimage.rotate.  reshape=True enlarges the output so that the whole
# input image is contained in it (presumably to align the feature of
# interest with the pixel grid -- confirm).
PA_angle = -43
with fits.open('cutout_rings.v3.skycell.1736.048.stk.i.unconv.fits') as calib:
    # rotate() returns a new in-memory array, so raw_data stays valid after
    # the (possibly memory-mapped) file is closed.
    raw_data = rotate(calib[0].data, PA_angle, reshape=True)

# Integer pixel indices of the rotated image centre.  Floor division keeps
# them usable as slice bounds on both Python 2 and Python 3.
x_c = raw_data.shape[1] // 2
y_c = raw_data.shape[0] // 2

# Save the rotated image for inspection.  `overwrite` replaces the
# deprecated `clobber` keyword of astropy's writeto.
hdu = fits.PrimaryHDU(data=raw_data)
hdulist = fits.HDUList([hdu])
hdulist.writeto('rotated.fits', overwrite=True)

# Cut a 14-row x 4-column strip centred on the image centre.  int() keeps the
# slice bounds integral even if the centre indices are floats (Python 3 `/`).
row_c = int(y_c)
col_c = int(x_c)
calib_dat = raw_data[row_c - 7:row_c + 7, col_c - 2:col_c + 2]

# Pixel coordinates of every sample in the strip: x = column index,
# y = row index, flattened in the same row-major order as the data.
x = np.arange(calib_dat.shape[1], dtype=float)
y = np.arange(calib_dat.shape[0], dtype=float)
X, Y = np.meshgrid(x, y, copy=False)

Z = calib_dat
X = X.flatten()
Y = Y.flatten()

# Exponent pairs (p, q) of the monomials x**p * y**q used in the fit:
# all 15 monomials of total degree <= 4.
degrees = np.array([[0,0],[1,0],[0,1],[1,1],[2,0],[0,2],[1,2],[2,1],[3,0],[0,3],[3,1],[1,3],[2,2],[4,0],[0,4]])
# Build each design-matrix column from its exponent pair, so the columns of A
# are guaranteed to correspond row-for-row to `degrees` (the analytic
# integral of the fit reuses these exponents).
A = np.column_stack([X**p * Y**q for p, q in degrees])
B = Z.flatten()

# NOTE(review): the strip has only 4 distinct x samples (0..3), on which
# x**4 == 6*x**3 - 11*x**2 + 6*x exactly, so A is rank-deficient (rank 14 of
# 15) and lstsq returns the minimum-norm solution.  The fit is fine on the
# grid, but its values between/beyond the samples -- and hence any analytic
# integral of it -- depend on that arbitrary split.  Consider dropping [4, 0].
# rcond=-1 keeps NumPy's historical machine-precision singular-value cutoff
# and silences the FutureWarning emitted when rcond is omitted.
coeff, r, rank, s = np.linalg.lstsq(A, B, rcond=-1)

# Evaluate the fitted polynomial at every sample point of the strip.
# Column i of A holds the i-th monomial, so the model is
# sum_i coeff[i] * A[:, i], i.e. the matrix-vector product A @ coeff.
model = np.dot(A, coeff)

# Integration limits, in the pixel coordinates of the fitted strip
# (x = column index, y = row index).
# NOTE(review): x_max = 3.9 lies beyond the last sampled column (x = 3), so
# part of the x-range extrapolates the fit -- confirm the intended limits.
x_min = 1.1
x_max = 3.9
y_min = 1.
y_max = 13.

# Analytic integral of sum_k c_k * x**p_k * y**q_k over the rectangle
# [x_min, x_max] x [y_min, y_max].  Each monomial factorises into
# (x_max**(p+1) - x_min**(p+1))/(p+1) * (y_max**(q+1) - y_min**(q+1))/(q+1);
# ex, ey below are the raised exponents p+1 and q+1.
integral = 0.
for c, (ex, ey) in zip(coeff, degrees + 1):
    integral = integral + c*(x_max**ex - x_min**ex)/ex * (y_max**ey - y_min**ey)/ey
    
# Flux inside a 6-pixel-radius circular aperture about the rotated image
# centre.  photutils expects aperture positions as (x, y) = (column, row),
# i.e. (x_c, y_c).
apertures = CircularAperture([(x_c, y_c)], r=6.)
phot_table = aperture_photometry(raw_data, apertures)

# Aperture correction: analytic integral of the fitted surface divided by
# the flux measured inside the aperture.
aper_corr = integral/phot_table[0]['aperture_sum']
print(aper_corr)

# Fitted surface evaluated at every strip sample (same order as B); the
# relative residuals can be inspected as (model - B) / model.
model = np.dot(A, coeff)