from astroquery.ukidss import Ukidss
import astropy.units as u
import astropy.coordinates as coord
import os
from astropy.io import fits
from astropy.coordinates import SkyCoord
import subprocess as sp
import numpy as np
from imageio import imread, imsave
import glob
from prepare_frames import prepare

def _make_subdir(parent, name):
    """Create ``parent/name`` if it does not exist yet and return its path."""
    path = os.path.join(parent, name)
    os.makedirs(path, exist_ok=True)
    return path


def _query_frame_urls(c, radius, bands):
    """Return the URLs of UKIDSS 'stack' frames within `radius` of `c`, band by band."""
    urls = []
    for band in bands:
        urls.extend(Ukidss.get_image_list(c, radius=radius, frame_type='stack',
                                          waveband=band))
    return urls


def _download(urls, download_dir):
    """Write `urls` to ``download_dir/urls.txt`` and fetch them all with wget.

    Raises subprocess.CalledProcessError if wget exits with a non-zero status.
    """
    url_file = os.path.join(download_dir, 'urls.txt')
    with open(url_file, 'w') as f:
        f.writelines('%s\n' % url for url in urls)
    sp.check_call(['wget',
                   '-P', download_dir,
                   '--content-disposition',
                   '-i', url_file])


def _weight_name(fn):
    """Return the weight-map path paired with frame `fn` (``x.fit`` -> ``x.weight.fit``).

    Splitting on the real extension makes this work for both ``.fit`` and
    ``.fits`` names. (A plain ``str.replace('.fits', ...)`` silently does
    nothing for the ``.fit`` files this module globs for.)
    NOTE(review): assumes prepare() writes its weight map next to the frame as
    ``<root>.weight<ext>``; confirm against prepare_frames.
    """
    root, ext = os.path.splitext(fn)
    return root + '.weight' + ext


def _group_by_filter(prep_dir):
    """Map each FILTER header value to the prepared science frames (weight maps excluded)."""
    by_filter = {}
    for fn in sorted(glob.glob(os.path.join(prep_dir, '*.fit'))):
        # Check only the file name, so a 'weight' elsewhere in the path does
        # not cause every frame to be skipped.
        if 'weight' in os.path.basename(fn):
            continue
        with fits.open(fn) as hdul:
            band = hdul[0].header['FILTER']
        by_filter.setdefault(band, []).append(fn)
    return by_filter


def _write_swarp_lists(swarp_dir, band, frames, extensions=(1, 2, 3, 4)):
    """Write the SWarp input lists for `band` and return ``(image_list, weight_list)``.

    Each frame contributes one ``path[ext]`` line per image extension in
    `extensions` (presumably the four WFCAM detector chips; confirm). The
    weight list mirrors the image list line for line.
    """
    img_list = os.path.join(swarp_dir, '%s_INPFITS.txt' % band)
    wgt_list = os.path.join(swarp_dir, '%s_INPFITS_WEIGHT.txt' % band)
    with open(img_list, 'w') as img_out, open(wgt_list, 'w') as wgt_out:
        for fn in frames:
            wfn = _weight_name(fn)
            for ext_n in extensions:
                img_out.write('%s[%i]\n' % (fn, ext_n))
                wgt_out.write('%s[%i]\n' % (wfn, ext_n))
    return img_list, wgt_list


def _run_swarp(img_list, wgt_list, outcoadd, resamp_dir, center, npix, pixel_scale):
    """Median-coadd the frames listed in `img_list` with SWarp into `outcoadd`.

    Raises subprocess.CalledProcessError if SWarp exits with a non-zero status.
    """
    # NOTE(review): the image list is passed as '@<file>', but the weight list
    # is passed as a bare path and no -WEIGHT_TYPE is set. As far as I know
    # SWarp's default WEIGHT_TYPE is NONE, in which case the weight maps are
    # ignored. Confirm whether '-WEIGHT_TYPE MAP_WEIGHT' and '@<file>' were
    # intended here.
    sp.check_call(['SWarp',
                   '-SUBTRACT_BACK', 'Y',
                   '-COMBINE', 'Y',
                   '-COMBINE_TYPE', 'MEDIAN',
                   '-WRITE_XML', 'N',
                   '-NTHREADS', '1',
                   '-RESAMPLE_DIR', resamp_dir,
                   '-CENTER_TYPE', 'MANUAL',
                   '-CENTER', center,
                   '-PIXELSCALE_TYPE', 'MANUAL',
                   '-WEIGHT_IMAGE', wgt_list,
                   '-PIXEL_SCALE', str(pixel_scale),
                   '-IMAGEOUT_NAME', outcoadd,
                   '-IMAGE_SIZE', '%d,%d' % (npix, npix),
                   '-VERBOSE_TYPE', 'NORMAL',
                   '@%s' % img_list])


def make_coadd_ukidss(ra, dec, img_size, plate_dir=None):
    """Download UKIDSS stack frames around a position and median-coadd them with SWarp.

    Steps:
      1. Query the UKIDSS archive for 'stack' frames in Y, H and K.
      2. Download the frames with wget.
      3. Uncompress them with `imcopy`.
      4. Run them through `prepare_frames.prepare`.
      5. Median-coadd each band with SWarp onto a square, 0.4 arcsec/pixel
         grid centred on the target.

    Parameters
    ----------
    ra, dec : float
        Target position in degrees.
    img_size : float
        Side length of the output coadd, in arcminutes.
    plate_dir : str, optional
        Root working directory. The sub-directories ``original``, ``uncomp``,
        ``prepared``, ``swarp`` and ``swarp_resamp`` are created inside it.
        Defaults to ``images/UKIDSS_J<RA hhmmss.ss><Dec +ddmmss.ss>``.

    Side effects
    ------------
    Writes ``<plate_dir>/swarp/<band>_COADD.fits`` for every band that has at
    least one prepared frame. Bands without frames are reported and skipped.

    Raises
    ------
    subprocess.CalledProcessError
        If wget, imcopy or SWarp exits with a non-zero status.
    """
    bands = ('Y', 'H', 'K')
    pixel_scale = 0.4  # output pixel scale, arcsec/pixel

    c = SkyCoord(ra*u.degree, dec*u.degree)
    if plate_dir is None:
        plate_dir = os.path.join('images', 'UKIDSS_J{0}{1}'.format(
            c.ra.to_string(unit=u.hourangle, sep='', precision=2, pad=True),
            c.dec.to_string(sep='', precision=2, alwayssign=True, pad=True)))

    download_dir = _make_subdir(plate_dir, 'original')
    uncomp_dir = _make_subdir(plate_dir, 'uncomp')
    prep_dir = _make_subdir(plate_dir, 'prepared')
    swarp_dir = _make_subdir(plate_dir, 'swarp')
    resamp_dir = _make_subdir(plate_dir, 'swarp_resamp')

    # Search radius: half the output size plus a 6.25' margin, times a further
    # factor of 1.5 (presumably so frames overlapping the field edge are kept).
    radius = (img_size/2.0 + 6.25) * 1.5 * u.arcmin
    _download(_query_frame_urls(c, radius, bands), download_dir)

    # Uncompress each downloaded frame into uncomp/ with imcopy.
    for fn in glob.glob(os.path.join(download_dir, '*.fit')):
        sp.check_call(['imcopy', fn,
                       os.path.join(uncomp_dir, os.path.basename(fn))])

    for fn in glob.glob(os.path.join(uncomp_dir, '*.fit')):
        prepare(fn, os.path.join(prep_dir, os.path.basename(fn)))

    frames_by_filter = _group_by_filter(prep_dir)

    npix = img_size * 60.0 / pixel_scale  # img_size is in arcmin
    center = '%s,%s' % (
        c.ra.to_string(unit=u.hourangle, sep=':', precision=2, pad=True),
        c.dec.to_string(sep=':', precision=2, pad=True))

    for band in bands:
        frames = frames_by_filter.get(band)
        if not frames:
            # The archive query may return nothing for a band; skip it rather
            # than abort the remaining bands with a KeyError.
            print('No prepared %s-band frames in %s; skipping %s coadd.'
                  % (band, prep_dir, band))
            continue
        img_list, wgt_list = _write_swarp_lists(swarp_dir, band, frames)
        outcoadd = os.path.join(swarp_dir, '%s_COADD.fits' % band)
        _run_swarp(img_list, wgt_list, outcoadd, resamp_dir, center, npix,
                   pixel_scale)

if __name__ == '__main__':
    # Example run: a 60' x 60' coadd at RA=194.988 deg, Dec=28.05 deg.
    ra = 194.9883333
    dec = 28.00 + 3.0/60.0
    image_size_arcmin = 60.0

    # make_coadd_ukidss() needs a working directory. Previously it was called
    # without one, which raised TypeError. Name the directory after the
    # target's sexagesimal coordinates, under images/.
    target = SkyCoord(ra*u.degree, dec*u.degree)
    plate_dir = os.path.join('images', 'UKIDSS_J{0}{1}'.format(
        target.ra.to_string(unit=u.hourangle, sep='', precision=2, pad=True),
        target.dec.to_string(sep='', precision=2, alwayssign=True, pad=True)))

    make_coadd_ukidss(ra, dec, image_size_arcmin, plate_dir)
