from astropy.io import fits
from astropy.wcs import WCS
import astropy.units as u

import matplotlib.pyplot as plt
import matplotlib as mpl

# Global tick styling for every figure in this script: tick marks point
# inward on both axes, and x tick marks are drawn in black.
mpl.rcParams.update({
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.color': 'black',
})

def cm2inch(value):
    """Convert a length from centimetres to inches (1 in = 2.54 cm)."""
    cm_per_inch = 2.54
    return value / cm_per_inch

def load_maps(infile, field):
    """Read one image extension (plus the primary HDU) from a FITS file.

    Parameters
    ----------
    infile : str or path-like
        Path to the FITS file.
    field : int or str
        Extension holding the image. It is passed straight to ``HDUList``
        indexing, so it may be an index such as ``1`` or an EXTNAME.

    Returns
    -------
    data
        ``.data`` of the selected extension.
    hdu
        The primary HDU (index 0) of the file.
    wcs : astropy.wcs.WCS
        WCS built from the selected extension.
    """
    # Previously ``field`` was ignored, the file had to contain exactly two
    # HDUs (tuple unpacking), and the file handle was never closed.
    with fits.open(infile) as hdul:
        primary_hdu = hdul[0]
        image_hdu = hdul[field]
        # Touch .data while the file is still open so the array is loaded
        # before the HDUList is closed.
        data = image_hdu.data
        wcs = WCS(image_hdu)

    return data, primary_hdu, wcs


def skymap_subplotsv3(file1, ext1, file2, ext2, file3, ext3):
    """Show three sky maps side by side on WCS axes, each with its own colorbar.

    Each (file, ext) pair is read with ``load_maps`` and drawn in one panel
    of a 1x3 row. Every panel is zoomed to the same pixel window and gets a
    "+" marker and a panel letter (A, B, C) at fixed ICRS positions. The
    figure is displayed with ``plt.show()``; nothing is returned.

    Parameters
    ----------
    file1, file2, file3 : str
        FITS files for the left, middle and right panels.
    ext1, ext2, ext3
        Passed to ``load_maps`` as its ``field`` argument for the
        corresponding file.

    Notes
    -----
    Calls ``plt.rc('font', size=6)``, which changes the global matplotlib
    font size for the rest of the session.
    """
    plt.rc('font', size=6)

    # Configure here
    cmap = "viridis"

    # Pixel window shown in every panel: +/- half_width around pixel 250.
    half_width = 50
    centre_pix = 250
    xlim = (centre_pix - half_width, centre_pix + half_width)
    ylim = (centre_pix - half_width, centre_pix + half_width)

    # Per-panel settings, ordered left, middle, right.
    vmins = [-5, -5, -3]
    vmaxs = [20, 5, 3]
    cbar_ticks = [[-5, 0., 20.],
                  [-5, 0, 5],
                  [-3, 0, 3]]
    # Colorbar rectangles [left, bottom, width, height] in figure fractions.
    cbar_rects = [[0.314, 0.5, 0.017, 0.15],
                  [0.575, 0.5, 0.017, 0.15],
                  [0.834, 0.5, 0.017, 0.15]]
    ylabels = ["Dec (J2000)", "", ""]
    ylabel_visible = [True, False, False]
    letters = ["A", "B", "C"]

    # Loop-invariant decorations.
    color_hess = (219. / 255., 41. / 255., 25. / 255.)
    tick_spacing = 0.2 * u.degree
    marker_ra, marker_dec = 44.546, -8.96  # resulting fitted position (ICRS)
    letter_ra, letter_dec = 44.95, -8.6    # panel-letter position (ICRS)

    maps = [load_maps(file1, ext1),
            load_maps(file2, ext2),
            load_maps(file3, ext3)]

    fig = plt.figure(figsize=(cm2inch(12.0), cm2inch(8.5)))

    # Create each WCS panel exactly once. The previous version built every
    # subplot twice (up front and again inside the loop) and drew through the
    # implicit "current axes" with plt.imshow.
    axes = [fig.add_subplot(1, 3, i + 1, projection=wcs)
            for i, (_, _, wcs) in enumerate(maps)]

    for i, ax in enumerate(axes):
        image_data = maps[i][0]
        sky = ax.imshow(image_data, cmap=cmap, origin='lower',
                        vmin=vmins[i], vmax=vmaxs[i])

        ra, dec = ax.coords[0], ax.coords[1]
        ra.set_ticks(spacing=tick_spacing, color='white')
        dec.set_ticks(spacing=tick_spacing, color='white')
        ra.set_major_formatter('hh:mm')
        dec.set_major_formatter('dd:mm')
        ra.display_minor_ticks(True)
        dec.display_minor_ticks(True)
        ra.set_minor_frequency(5)

        ax.set_xlabel("RA (J2000)", fontsize=7, labelpad=1.)
        ax.set_ylabel(ylabels[i], fontsize=7, labelpad=1.0)
        ax.tick_params(axis="both", which="major", labelsize=7, pad=5)
        # Only the leftmost panel shows Dec tick labels.
        dec.set_ticklabel_visible(ylabel_visible[i])
        ax.set_ylim(ylim)
        ax.set_xlim(xlim)

        ax.scatter(marker_ra, marker_dec, s=5, color=color_hess,
                   transform=ax.get_transform('icrs'), marker="+")

        cbar_ax = fig.add_axes(cbar_rects[i])
        # fraction/pad are dropped: they only apply when matplotlib creates
        # the colorbar axes itself, not when an explicit cax is given.
        cbar = fig.colorbar(sky, cax=cbar_ax, ticks=cbar_ticks[i],
                            orientation="vertical")

        plt.setp(cbar.ax.get_yticklabels(), color='white', weight='bold')
        cbar.ax.set_ylabel('', rotation=90)
        cbar.ax.tick_params(labelsize=7, color='white')
        cbar.ax.yaxis.set_tick_params(color='white')
        cbar.ax.xaxis.set_tick_params(color='white')
        cbar.set_label(r"$\sigma$", color='white', fontsize=7, labelpad=-30,
                       weight="bold", rotation=90, y=0.5)
        cbar.ax.set_title('', color="white")

        ax.text(letter_ra, letter_dec, letters[i], color='white',
                transform=ax.get_transform('icrs'), fontsize=7,
                fontweight='bold')

    fig.subplots_adjust(wspace=-0.03)

    plt.show()


if __name__ == "__main__":
    # Render the GRB 190829A sky maps only when run as a script, so that
    # importing this module does not open files or pop up a figure.
    skymap_subplotsv3('grb190829A_skymap1.fits', 1,
                      'grb190829A_skymap2.fits', 1,
                      'grb190829A_skymap3.fits', 1)
