from astropy.io import fits
import matplotlib.pyplot as plt
import numpy as np

# Spectrum-fit results file to plot (binary table in extension 1).
url = '/Users/katkov/sci/work/deep2/results/spSpec_31030070_results.fits'
# Remote variant (expects a web `request` object providing plate/mjd/fiberid):
# url = 'http://gal-02.sai.msu.ru:8080/getproduct/specphot/ssap/%04d_fit/spSpec%s-%04d-%03d_results.fits' % (int(request.GET['plate']), request.GET['mjd'], int(request.GET['plate']), int(request.GET['fiberid']))

# Plot options, keyed like request parameters:
# 'smooth', 'wave_min', 'wave_max', 'rwave_min', 'rwave_max'.
req = {'smooth': 10}
# req = {}

# Speed of light in km/s; turns the fitted velocity V into z = V / c.
SPEED_OF_LIGHT_KMS = 299792.458


def _first_row(table, name):
    """Return column `name` of the first table row as a detached 1-D array.

    The raw field is technically 2-D for numpy, so it is transposed and
    column 0 (i.e. the first row) is taken. ``np.array`` copies the values
    so they remain valid after the FITS file has been closed.
    """
    return np.array(table.field(name).T[:, 0])


# Read everything inside `with` so the file handle is always released.
with fits.open(url) as hdu:
    sp = hdu[1].data
    wave = _first_row(sp, 'WAVE')
    flux = _first_row(sp, 'FLUX')
    fit = _first_row(sp, 'FIT')
    err = _first_row(sp, 'ERROR')
    # emis = _first_row(em, 'SPEC_FIT')
    # emis_wave = _first_row(em, 'SPEC_WAVE')
    # Take V from the same (first) row as the spectra and make it a plain
    # scalar, so later (1 + z) arithmetic and axis limits stay scalar.
    z = float(np.ravel(sp.field('V'))[0]) / SPEED_OF_LIGHT_KMS

# Running average over a window of `smooth` pixels.
# The window is centred on each pixel, so spectral features are not shifted
# along `wave`. Near the edges it averages only the pixels that exist, so the
# smoothed spectrum does not sag towards zero there.
if 'smooth' in req:
    smooth = min(int(req['smooth']), len(flux))
    if smooth > 1:
        kernel = np.ones(smooth)
        # Number of real pixels contributing to each output sample
        # (== smooth in the interior, fewer at the edges).
        npix = np.convolve(np.ones(len(flux)), kernel, mode='same')
        flux = np.convolve(flux, kernel, mode='same') / npix
        fit = np.convolve(fit, kernel, mode='same') / npix
        # Averaging N independent errors reduces the error by sqrt(N).
        err = np.convolve(err, kernel, mode='same') / npix / np.sqrt(npix)
#    emis= np.convolve(np.ones((smooth,)) / smooth, emis)[(smooth - 1):]

# Wavelength range to plot, expressed in observed-frame wavelength.
# Default: the full spectrum.
# 'wave_min' / 'wave_max' give observed-frame limits.
# 'rwave_min' / 'rwave_max' give rest-frame limits, converted to the observed
# frame with (1 + z); when both kinds are given, the rest-frame ones win.
# Previously the rest-frame `else` branches reset the limits unconditionally,
# so observed-frame limits were silently ignored.
min_wave = np.min(wave)
max_wave = np.max(wave)

if 'wave_min' in req:
    min_wave = float(req['wave_min'])
if 'wave_max' in req:
    max_wave = float(req['wave_max'])

if 'rwave_min' in req:
    min_wave = float(req['rwave_min']) * (1 + z)
if 'rwave_max' in req:
    max_wave = float(req['rwave_max']) * (1 + z)

# Plot the spectrum and its fit. Residuals and the +/- error band are drawn
# on a shifted zero line below the spectrum. A twin top axis shows the
# rest-frame wavelength scale.
res = flux - fit

# Pixels strictly inside the plotted window; used to scale the y-axis.
in_range = (wave > min_wave) & (wave < max_wave)
if not np.any(in_range):
    raise ValueError('no pixels between %s and %s A' % (min_wave, max_wave))
min_fit = np.min(fit[in_range])
max_fit = np.max(fit[in_range])
yr = max_fit - min_fit
mederr = np.median(err)
# Offset zero line for residuals, placed 2.5 median errors below the fit.
zeroline = min_fit - 2.5 * mederr

# Margin (in A) on each side of the x-range. It is shared by both axes so the
# observed and rest-frame scales stay aligned.
pad = 30
xlim_obs = [min_wave - pad, max_wave + pad]

f = plt.figure(figsize=[12, 8], dpi=60, facecolor='w')
ax1 = f.add_subplot(111)
ax2 = ax1.twiny()
ax1.plot(wave, flux, 'k')
ax1.plot(wave, fit, 'r')
ax1.plot(wave, res + zeroline, 'k')
#ax1.plot(emis_wave, emis+zeroline, 'r')
ax1.plot(wave, -err + zeroline, 'b')
ax1.plot(wave, err + zeroline, 'b')
ax1.plot([min_wave, max_wave], zeroline * np.array([1, 1]), '--b')
ax1.set_xlim(xlim_obs)
ax1.set_ylim([zeroline - 2 * mederr, max_fit + max([0.2 * yr, mederr])])
ax1.set_xlabel('Wavelength, A')
ax1.set_ylabel('Relative Flux')
ax1.grid(True, which='minor', color='grey', linestyle='--', linewidth=0.5)
ax1.grid(True, which='major', color='grey', linestyle='-', linewidth=1.0)
ax1.minorticks_on()
ax2.minorticks_on()

# Giving the twin axis the same limits divided by (1 + z) is enough to
# produce a rest-frame scale.
ax2.set_xlim([x / (1 + z) for x in xlim_obs])
ax2.set_xlabel('Rest-frame wavelength, A')

plt.show()
