import astropy
import math
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import matplotlib.pyplot as plt
import numpy as np
import random
import matplotlib.transforms as transforms
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LogNorm
from astropy.io import fits
from matplotlib import gridspec
from astropy.wcs import WCS

# Lookup table filled from 'list_fits_files.csv' further down: the first
# ';'-separated field of a row is the key, the second field is the value.
# Keys are later matched against str(bestObjID) of the catalogue rows and the
# matching value is passed to fits.open().
list_fits_files = {}

def make_ticklabels_invisible(fig):
    """Hide the x and y tick labels of every axes attached to *fig*.

    Only the tick labels are switched off, via ``set_visible(False)`` on
    each label object; nothing else on the axes is touched.
    """
    for axes in fig.axes:
        labels = axes.get_xticklabels() + axes.get_yticklabels()
        for label in labels:
            label.set_visible(False)
# --- Select the catalogue objects for which a FITS frame is listed ----------
#
# Result: ``frames`` is a list of ``[frame_path, [ra, dec, obj_id]]`` entries,
# one per catalogue row whose ``bestObjID`` appears in the CSV lookup table.

# Fill the lookup table: first CSV column -> second CSV column (the path that
# is later handed to fits.open()).
with open('list_fits_files.csv') as csv_file:
    csv_file.readline()  # skip the header row
    for line in csv_file:
        arr_splitted = line.split(';')
        if len(arr_splitted) < 2:
            # Blank or truncated row (e.g. a trailing empty line): nothing
            # to register, and indexing it would raise IndexError.
            continue
        # Strip surrounding whitespace so a trailing newline/space never ends
        # up inside the key or the file path.
        list_fits_files[arr_splitted[0].strip()] = arr_splitted[1].strip()

frames = []
# The context manager closes the catalogue file once the rows have been read;
# everything that needs the table therefore happens inside the block.
with fits.open('garik6.fits') as hdulist:
    data = hdulist[1].data
    for i in range(data.shape[0]):
        row = data[i]
        obj_id = str(row['bestObjID'])
        # ``in`` replaces dict.has_key(), which only exists in Python 2.
        if obj_id in list_fits_files:
            frames.append([list_fits_files[obj_id],
                           [row['ra'], row['dec'], obj_id]])

# Parenthesised single-argument print gives the same output on Python 2 and 3.
print(frames)
#frames = [['../frame-g-005322-2-0142.fits',[230.77068,11.76503, 'number']],['../frame-g-005322-2-0142.fits',[230.77068,11.76503, 'number']],['../frame-g-005322-2-0142.fits',[230.77068,11.76503, 'number']]]


# --- Draw one cutout per selected object on an nx-by-ny grid ----------------
#
# For every entry of ``frames`` the object's sky position is converted to
# pixel coordinates of its frame, a square cutout centred on it is extracted
# and drawn (arcsinh-stretched, grey colormap) in the grid cell given by the
# entry's index. The finished figure is written to 'rowimages.png'.
f = plt.figure(figsize=(18,18))
nx = 25  # number of grid rows (first GridSpec argument)
ny = 25  # number of grid columns (second GridSpec argument)
gs1 = gridspec.GridSpec(nx, ny)
gs1.update(left=0.01, right=0.99, wspace=0.00, hspace=0.00)
# These three axes are created up front exactly as in the original script;
# they are not referenced again below.
ax1 = plt.subplot(gs1[0, 1])
ax2 = plt.subplot(gs1[1, 1])
ax3 = plt.subplot(gs1[-1, -1])

# Upper bound for the cutout half-size in pixels (cutout is 2*width square).
max_half_width = 100

for i, (name, (ra, dec, label)) in enumerate(frames):
    if i >= nx * ny:
        # The grid has only nx*ny cells; indexing past it would raise
        # IndexError and the figure would never be saved.
        print("grid is full, skipping the remaining %d frame(s)"
              % (len(frames) - i))
        break
    # Row-major placement: the row advances once per ``ny`` (column count)
    # cells. Dividing by ``nx`` only worked because nx == ny.
    x_pos = i // ny
    y_pos = i % ny

    # Close each frame as soon as its cutout has been drawn, so a long list
    # of frames does not accumulate open file handles.
    with fits.open(name) as hdulist:
        data = hdulist[0].data
        w = WCS(hdulist[0].header)
        x, y = w.wcs_world2pix(ra, dec, 0)
        # "%s %s" formatting prints the same text on Python 2 and Python 3.
        print("%s %s" % (x, y))
        inside = (x > 0) and (y > 0) and (x < data.shape[1]) and (y < data.shape[0])
        if not inside:
            continue

        # Integer pixel centre and integer half-width: slice bounds must be
        # integers. The half-width is the largest value up to the cap that
        # keeps the cutout inside the image, with the same 1 px (low edge)
        # and 2 px (high edge) margins the original arithmetic used.
        cx = int(x)
        cy = int(y)
        width = min(max_half_width,
                    cx - 1,
                    cy - 1,
                    data.shape[1] - cx - 2,
                    data.shape[0] - cy - 2)
        print("%s %s" % (x, y))
        print("%s %s" % (width, data.shape))
        if width < 1:
            # Object sits on the very edge: no usable cutout.
            continue
        pic = data[cy - width:cy + width, cx - width:cx + width]

        # 1-based pixel coordinate grids matching the cutout shape.
        y_arr, x_arr = np.mgrid[1:pic.shape[0] + 1, 1:pic.shape[1] + 1]
        ax = plt.subplot(gs1[x_pos, y_pos])
        # np.arcsinh(pic*5) builds a new array, so nothing drawn depends on
        # the FITS file staying open after this block.
        ax.pcolor(x_arr, y_arr, np.arcsinh(pic*5),  cmap='gray', linewidth=0)
        ax.text(70, 20, label,fontsize=3, ha='center', va='bottom', color='green')


make_ticklabels_invisible(plt.gcf())
plt.savefig('rowimages.png', dpi=600)