from astropy.io import fits
from scipy.ndimage import rotate
import matplotlib.pyplot as plt
import numpy as np
from reproject import reproject_interp
from astropy.wcs import WCS
from astropy.table import Table

# Colormap used to give each slit offset its own curve colour in the plot.
# plt.get_cmap works on old and new Matplotlib (plt.cm.get_cmap is deprecated).
cmap = plt.get_cmap('hsv')

# Target galaxy: identifier (used in output file names) and sky position (deg).
obj_name = '587741602572796110'
ra_gal = 194.94282
dec_gal = 27.74625

# Input cutout. fits.getdata opens, reads and closes the file, so no file
# handle is left open.
fnm = '/data1/Data/SUBARU/cutouts_psg/587741602572796110_W-J-B_SUPM139B835C0E40E400_extended.fits'
imgdata, inp_hdr = fits.getdata(fnm, ext=0, header=True)
inp_w = WCS(inp_hdr)

# numpy shape is (rows, cols) = (ny, nx), while WCS pixel calls take (x, y);
# evaluate the sky position of the geometric centre of the input image.
n_y, n_x = imgdata.shape
ra_c, dec_c = inp_w.all_pix2world((n_x - 1) / 2.0, (n_y - 1) / 2.0, 0)
print('%s %s' % (ra_c, dec_c))

# Output-grid parameters.
fac_sc = 1.0                  # oversampling factor of the output grid
sc = 5.608e-5 / fac_sc        # output pixel scale, deg/pixel
slitw = 1.0 / 3600.0          # slit width, deg (1 arcsec)
theta = np.deg2rad(160.0)     # rotation of the output grid (exact deg->rad)
velscale = 60.0               # multiplies centroid offsets (arcsec) before tabulating/plotting;
                              # presumably a velocity-per-arcsec scale -- confirm
n_out = int(2000 * fac_sc)    # output image is n_out x n_out pixels

# Rotated TAN projection: pixel scale sc rotated by theta; the negative CD1_1
# (at theta = 0) makes RA decrease with increasing x.
hdr = fits.Header()
hdr['CD1_1'] = -np.cos(theta) * sc
hdr['CD1_2'] = np.sin(theta) * sc
hdr['CD2_1'] = np.sin(theta) * sc
hdr['CD2_2'] = np.cos(theta) * sc
hdr['CRVAL1'] = float(ra_c)
hdr['CRVAL2'] = float(dec_c)
hdr['CTYPE1'] = 'RA---TAN'
hdr['CTYPE2'] = 'DEC--TAN'
# Reference pixel at the centre of the *output* grid (FITS CRPIX is 1-based),
# so the image stays centred for any input size and any fac_sc.
hdr['CRPIX1'] = (n_out + 1) / 2.0
hdr['CRPIX2'] = (n_out + 1) / 2.0
w = WCS(header=hdr)
data_new, fprnt = reproject_interp((imgdata, inp_w), w, shape_out=(n_out, n_out))

# Galaxy position on the reprojected grid (0-based pixels) and the slit width
# expressed in output pixels.
x_gal, y_gal = w.all_world2pix(ra_gal, dec_gal, 0)
slitwp = slitw / sc
print('%s %s %s' % (x_gal, y_gal, slitwp))
fig, ax1 = plt.subplots()
# Secondary y axis (shared x) used for the slit light profile.
ax2 = ax1.twinx()
# Pixel-index grids with the same shape as the reprojected image:
# xvals[i, j] == j (column index) and yvals[i, j] == i (row index).
yvals, xvals = np.indices(data_new.shape)
# Slit offsets from the galaxy x position, in output pixels (scaled by fac_sc;
# presumably specified in fac_sc == 1 pixels -- confirm).
dx_arr = np.array([0]) * fac_sc
# Rows within +/-20 arcsec of the galaxy (sc * 3600 is arcsec/pixel), clipped
# to the image so a galaxy near the edge cannot produce a wrapped slice.
half_rows = 20.0 / (sc * 3600.0)
y_min = max(int(y_gal - half_rows), 0)
y_max = min(int(y_gal + half_rows), data_new.shape[0])
# Normalised Gaussian line-spread function sampled on an odd-length,
# zero-centred grid (+/-5 sigma). An even-length kernel such as np.arange(10)
# peaks between two samples, and np.convolve(..., mode='same') then shifts
# every convolved row by half a pixel, biasing the convolved centroids.
sigma = 1.0
half_width = int(5 * sigma)
x_ = np.arange(-half_width, half_width + 1)
lsf = np.exp(-x_**2 / (2.0 * sigma**2))
lsf = lsf / np.sum(lsf)

# Position of each analysed row along the slit, relative to the galaxy, in
# arcsec (sc * 3600 is arcsec/pixel). Stored in the table column 'xpos'.
ypos = (yvals[y_min:y_max, 0] - y_gal) * sc * 3600
# Colour scale over the slit offsets; a single offset (zero span) would
# otherwise give 0/0 = nan, and cmap(nan) is the transparent "bad" colour.
dx_min = np.min(dx_arr)
dx_span = float(np.max(dx_arr) - dx_min)

for dx in dx_arr:
    x_gal_ = x_gal + dx
    half = slitwp / 2.0
    # Slit columns: every column j whose centre satisfies |j - x_gal_| <= half,
    # i.e. j in [x_lo, x_hi). The same rule selects the narrow slit cut and
    # masks the wider window, so both centroids use the same pixels. (Cutting
    # with int() truncation instead shifts the window up to one pixel towards
    # smaller x and biases the centroid.)
    x_lo = int(np.ceil(x_gal_ - half))
    x_hi = int(np.floor(x_gal_ + half)) + 1
    # Wider window (5x the slit width) that receives the LSF-broadened light.
    x_lo_ext = int(np.ceil(x_gal_ - 5 * half))
    x_hi_ext = int(np.floor(x_gal_ + 5 * half)) + 1

    slit_img = data_new[y_min:y_max, x_lo:x_hi]
    xvals_slt = xvals[y_min:y_max, x_lo:x_hi] - x_gal_

    cols_ext = xvals[y_min:y_max, x_lo_ext:x_hi_ext]
    xvals_slt_ext = cols_ext - x_gal_
    slit_img_ext = data_new[y_min:y_max, x_lo_ext:x_hi_ext].copy()
    in_slit = (cols_ext >= x_lo) & (cols_ext < x_hi)
    # Keep only light inside the slit; non-finite pixels (e.g. outside the
    # reprojection footprint) are zeroed so the convolution cannot smear
    # NaNs along the row.
    slit_img_ext[~in_slit | ~np.isfinite(slit_img_ext)] = 0

    # Slit light profile along the slit; NaNs ignored, matching the nansum
    # used in the centroid numerator.
    l_prof = np.nansum(slit_img, axis=1)
    # Broaden each row with the line-spread function (rows are views, so the
    # assignment updates slit_img_ext in place).
    for row in slit_img_ext:
        row[:] = np.convolve(row, lsf, mode='same')

    # Flux-weighted mean column offset per row, converted from pixels to
    # arcsec. The convolved centroid is normalised by its own flux.
    centroids = 3600.0 * sc * np.nansum(xvals_slt * slit_img, axis=1) / l_prof
    centroids_ext = (3600.0 * sc * np.nansum(xvals_slt_ext * slit_img_ext, axis=1)
                     / np.nansum(slit_img_ext, axis=1))

    tbl_cent = Table((ypos, centroids * velscale, centroids_ext * velscale),
                     names=('xpos', 'centroids', 'centroids_lsf'))
    tbl_cent.write('%s_dx%.3f_centroids.csv' % (obj_name, dx))

    color = cmap((dx - dx_min) / dx_span if dx_span > 0 else 0.0)
    ax1.plot(ypos, centroids * velscale, c=color)
    # Dashed line distinguishes the LSF-convolved centroids from the raw ones.
    ax1.plot(ypos, centroids_ext * velscale, c=color, ls='--')

ax1.set_ylim(-10.0, 0)
# Light profile of the last slit offset processed, on the secondary axis.
ax2.plot(ypos, l_prof)
# plt.show()  # uncomment to display the figure; it is not saved otherwise
