from astropy.io import fits
from astropy.table import Table
import numpy as np
import glob
import ntpath
import os

# Column annotations applied to the output FITS table headers.
# Maps a column name (the value of a TTYPEn keyword) to a dict with an
# optional 'ucd' entry (written as TUCDn) and an optional 'unit' entry
# (written as TUNITn). A missing entry means that keyword is not written.
_FLUX_DENSITY_UNIT = '1e-17 erg cm**(-2) s**(-1) angstrom**(-1)'
_WAVELENGTH_UNIT = 'nm'
_VELOCITY_UNIT = 'km/s'
_AGE_UNIT = 'Myr'
_MAG_UNIT = 'mag'

# NOTE(review): several UCDs below contain 'em...' or 'em.*', which are not
# complete UCD words -- they look like placeholders; confirm before release.
fields = {
    # Spectral fit columns.
    'WAVE': {'ucd': 'em.wl', 'unit': _WAVELENGTH_UNIT},
    'FLUX': {'ucd': 'phot.flux.density;em.wl', 'unit': _FLUX_DENSITY_UNIT},
    'FIT': {'ucd': 'phot.flux.density;em.wl', 'unit': _FLUX_DENSITY_UNIT},
    'GOODPIXELS': {'ucd': 'meta.code.qual;phot.flux.density, em...'},
    'ERROR': {
        'ucd': 'phot.flux.density;em...;stat.error',
        'unit': _FLUX_DENSITY_UNIT,
    },
    'LSF_SIG': {'ucd': 'spect.resolution; em.*', 'unit': _VELOCITY_UNIT},
    # Kinematics.
    'V': {'ucd': 'phys.veloc', 'unit': _VELOCITY_UNIT},
    'E_V': {'ucd': 'stat.error;phys.veloc', 'unit': _VELOCITY_UNIT},
    'SIG': {'ucd': 'phys.veloc.dispersion', 'unit': _VELOCITY_UNIT},
    'E_SIG': {
        'ucd': 'stat.error;phys.veloc.dispersion',
        'unit': _VELOCITY_UNIT,
    },
    # Stellar population.
    'AGE': {'ucd': 'time.age', 'unit': _AGE_UNIT},
    'E_AGE': {'ucd': 'stat.error;time.age', 'unit': _AGE_UNIT},
    'CHI2': {'ucd': 'stat.fit.chi2'},
    # Photometric (SED) columns.
    # NOTE(review): PHOT_WAVE carries a flux-density UCD but a wavelength
    # unit -- possibly meant to be 'em.wl'; confirm.
    'PHOT_WAVE': {'ucd': 'phot.flux.density;em.wl', 'unit': _WAVELENGTH_UNIT},
    # NOTE(review): a quality-flag column with unit 'nm' looks unintended.
    'PHOT_GOODBANDS': {'ucd': 'meta.code.qual;phot.mag, em...', 'unit': 'nm'},
    'PHOT_FLUX': {'ucd': 'phot.mag', 'unit': _MAG_UNIT},
    'PHOT_ERROR': {'ucd': 'phot.mag;em...;stat.error', 'unit': _MAG_UNIT},
    # NOTE(review): the two model columns below reuse the stat.error UCD of
    # PHOT_ERROR -- possibly a copy-paste slip; confirm.
    'PHOT_FIT': {'ucd': 'phot.mag;em...;stat.error', 'unit': _MAG_UNIT},
    'PHOT_FIT_COMP': {'ucd': 'phot.mag;em...;stat.error', 'unit': _MAG_UNIT},
    'SED_CHI2': {'ucd': 'stat.fit.chi2'},
}

out_dir = 'main_sample_for_zenodo'


def annotate_columns(header, field_defs):
    """Write TUCDn / TUNITn keywords for the known columns of a table header.

    For every ``TTYPEn`` keyword whose value (the column name) is a key of
    *field_defs*, set ``TUCDn`` from the entry's ``'ucd'`` and ``TUNITn``
    from its ``'unit'``, each only when present. Other columns are untouched.

    Args:
        header: mapping-like table header (e.g. an astropy ``Header``);
            modified in place.
        field_defs: column name -> ``{'ucd': ..., 'unit': ...}`` (both keys
            optional).
    """
    # Snapshot the keywords: new cards are added to the header below, and
    # mutating a container while iterating it is unsafe.
    for keyword in list(header):
        if 'TTYPE' not in keyword:
            continue
        field_rec = field_defs.get(header[keyword])
        if field_rec is None:
            continue
        if 'ucd' in field_rec:
            header[keyword.replace('TTYPE', 'TUCD')] = field_rec['ucd']
        if 'unit' in field_rec:
            header[keyword.replace('TTYPE', 'TUNIT')] = field_rec['unit']


def truncate_rows(data):
    """Return the slice of *data* to publish.

    Tables with more than one row are cut to zero rows; tables with zero or
    one row are returned unchanged (as a slice).
    """
    # NOTE(review): dropping *all* rows of multi-row tables while keeping the
    # row of single-row tables is preserved from the original script. If the
    # intent was "keep only the first row", use ``data[:1]`` unconditionally.
    if len(data) > 1:
        return data[:0]
    return data[:1]


input_files = glob.glob('main_sample/*ppxf_results*.fits')
if input_files:
    # writeto() fails when the destination directory does not exist.
    os.makedirs(out_dir, exist_ok=True)

for f in input_files:
    # The context manager closes each file (and any memmap) once it has been
    # written, instead of keeping every input open until the script exits.
    with fits.open(f) as h:
        table_hdu = h[1]
        t = Table(table_hdu.data)
        print(f, t['AGE'], t['E_AGE'], t['MET'], t['E_MET'])

        # Replace the data before annotating: assigning ``.data`` on a table
        # HDU may regenerate column keywords (e.g. TUNITn) from the column
        # definitions, which could discard header edits made beforehand.
        table_hdu.data = truncate_rows(table_hdu.data)
        annotate_columns(table_hdu.header, fields)

        outfile = os.path.join(out_dir, ntpath.basename(f))
        h.writeto(outfile, overwrite=True)
    print(outfile)



