import sys
import numpy as np
from astropy.io import fits
from scipy.special import erf as erf
from astropy.wcs import WCS
import matplotlib.pyplot as plt


def calc_refr_indx(T, p, l, water_vapor_pressure=8.0):
    """Refraction index of air at wavelength ``l``.

    The value is assembled from three pieces:

    * a wavelength-dependent dispersion term,
    * a scaling of that term by a function of pressure and temperature,
    * a water-vapour correction that is subtracted from the result.

    Parameters
    ----------
    T : float
        Temperature in degrees Celsius.
    p : float
        Pressure in mm Hg.
    l : float or np.ndarray
        Wavelength(s) in microns.  Arrays are handled element-wise.
    water_vapor_pressure : float, optional
        Multiplier of the water-vapour correction term.  The default of 8
        is the value that used to be hard-coded here, so existing callers
        get the same result as before.
        NOTE(review): the coefficients look like the formula of
        Filippenko (1982), where this quantity is the water-vapour
        pressure in mm Hg -- confirm.

    Returns
    -------
    float or np.ndarray
        Refraction index ``n`` (same shape as ``l``).
    """
    inv_l_sq = (1. / l)**2

    # Dispersion term: (n - 1) * 1e6 before the T/p scaling is applied.
    dispersion = 64.328 + 29498.1 / (146 - inv_l_sq) + 255.4 / (41 - inv_l_sq)

    # Both the dry-air scaling and the water-vapour term share this factor.
    thermal = 1. + 0.003661 * T

    pressure_term = p * (1 + (1.049 - 0.0157 * T) * 1e-6 * p)
    dry = dispersion * 1e-6 * pressure_term / (720.883 * thermal)

    wet = water_vapor_pressure * 1e-6 * (0.0624 - 0.00068 / (l**2)) / thermal

    return dry - wet + 1


def calc_R(z, T, p, l):
    """Refraction at wavelength ``l`` relative to 0.65 microns.

    The difference between the refraction index at ``l`` and the one at
    0.65 microns is multiplied by ``tan(z)`` and by 206265 (the number of
    arcseconds in one radian).

    ``z`` is the zenith angle in radians; ``T``, ``p`` and ``l`` are passed
    straight to ``calc_refr_indx`` (temperature in C, pressure in mm Hg,
    wavelength in microns).
    """
    arcsec_per_radian = 206265.0
    index_at_l = calc_refr_indx(T, p, l)
    index_at_reference = calc_refr_indx(T, p, 0.65)
    return arcsec_per_radian * (index_at_l - index_at_reference) * np.tan(z)


def calc_throughput(lambdas, rotangle, parangle, airmass, temperature, pressure,
                    slit_width, seeing_fwhm, plot=False):
    """
    Fraction of light passed through the slit at given wavelengths.

    The seeing profile is treated as a Gaussian whose centre is displaced
    across the slit by the wavelength-dependent refraction returned by
    ``calc_R``.  Only the displacement perpendicular to the slit enters the
    result; the displacement along the slit is computed for the plot only.

    Parameters
    ----------
    lambdas : array_like
        Wavelengths in microns.
    rotangle : float
        Rotational angle, in degrees.
    parangle : float
        Parallactic angle, in degrees.
    airmass : float
        Airmass of the object.  It is converted to a zenith angle through
        ``tan(z) = sqrt(airmass**2 - 1)``, i.e. ``airmass = sec(z)``.
    temperature : float
        Temperature in Celsius.
    pressure : float
        Pressure in mm Hg.
    slit_width : float
        Slit width of the spectrograph, in the same units as the
        refraction offsets and ``seeing_fwhm`` (arcseconds).
    seeing_fwhm : float
        Seeing FWHM at 0.7 micron.  It is scaled to the other wavelengths
        as ``lambda**(-1/5)``.
    plot : bool
        If True, plot the throughput and both refraction offsets against
        wavelength and show the figure.

    Returns
    -------
    np.ndarray
        Fraction of light (between 0 and 1) passing through the slit at
        each wavelength in ``lambdas``.
    """
    # Accept lists/tuples as well as arrays.
    lambdas = np.asarray(lambdas, dtype=float)

    # Gaussian sigma from the FWHM (FWHM ~ 2.355 sigma), then scaled from
    # the 0.7 micron reference wavelength to each requested wavelength.
    seeing_sigma = seeing_fwhm / 2.355
    seeing_sigma = seeing_sigma * (lambdas / 0.7)**(-1. / 5.)

    z = np.arctan(np.sqrt(airmass**2 - 1))

    # Refraction offset at each wavelength; evaluated once and projected
    # onto the two slit axes below.
    refraction = calc_R(z, temperature, pressure, lambdas)

    # Angle between the refraction direction and the slit, in radians.
    slit_angle = -np.deg2rad(parangle - rotangle)

    # offset perpendicular to the slit
    dl = refraction * np.sin(slit_angle)

    # Fraction of light passed through the slit.  For a Gaussian of width
    # sigma, the fraction of the profile between a and b is
    #     0.5 * (erf(b / (sqrt(2) sigma)) - erf(a / (sqrt(2) sigma))),
    # which is where the factor of one half comes from.
    half_width = slit_width / 2.
    scale = np.sqrt(2) * seeing_sigma
    throughput = 0.5 * (
        erf((dl + half_width) / scale) -
        erf((dl - half_width) / scale)
        )

    if plot:
        # offset along the slit (not used in the throughput itself)
        dh = refraction * np.cos(slit_angle)

        plt.plot(lambdas, throughput, label='Slit Throughput')
        plt.plot(lambdas, dl, label="Shift cross slit (arcsec)")
        plt.plot(lambdas, dh, label="Shift along slit (arcsec)")
        # Raw string: '\m' is not a valid escape sequence in a normal string.
        plt.xlabel(r'Wavelength, $\mu$m')
        plt.legend()
        plt.show()

    return throughput


if __name__ == "__main__":
    # Run Example:
    # %run ../aper_corr/loss_script.py 1605+1748_021.fits 0
    #
    # With no arguments a built-in default file and extension are used;
    # otherwise both the file name and the HDU index must be given.

    # input filename -- reduced MagE spectrum
    if len(sys.argv) == 1:
        filename = '2101-0031_001_merged.fits'
        ext = 1
    elif len(sys.argv) >= 3:
        filename = sys.argv[1]
        ext = int(sys.argv[2])
    else:
        # Only the file name was given: report how to call the script
        # instead of failing with an IndexError on sys.argv[2].
        sys.exit("usage: {} [FITS_FILE HDU_INDEX]".format(sys.argv[0]))

    # Only the header is needed; it is read while the file is still open.
    with fits.open(filename) as hdu:
        header = hdu[ext].header

    rotangle = header['ROTANGLE'] + 44.5 # PA correction for MagE data
    parangle = header['PARANGLE']
    airmass = header['AIRMASS']

    lambdas = np.linspace(0.4, 0.9, 100) # in microns

    # Physical conditions and parameters
    # temperature (in C)
    temperature = 7
    # pressure (mm Hg)
    pressure = 600
    # slit width in arcseconds
    slit_width = 0.7
    # seeing in arcseconds in r
    seeing_fwhm = 0.5

    # Kept as a module-level name so it can be inspected after `%run`.
    slit_fraction_pass = calc_throughput(lambdas, rotangle, parangle, airmass,
                                         temperature, pressure, slit_width,
                                         seeing_fwhm, plot=True)
