"""
Script to extract information from the HEGS maps
 author Francois Brun  francois.brun@cea.fr
"""

import numpy as np
import matplotlib.pyplot as plt
from astropy.visualization import simple_norm
import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.table import Table
from gammapy.maps import Map
from regions import CircleSkyRegion
import astropy
import math
from math import *
import sys
import argparse
import os
import glob

def plot_mask(ax, mask, **kwargs):
    """Draw ``mask`` on ``ax`` by delegating to ``mask.plot_mask``.

    Parameters
    ----------
    ax : axes object forwarded as the ``ax`` keyword of ``mask.plot_mask``.
    mask : object providing a ``plot_mask(ax=..., **kwargs)`` method, or None.
        When None, nothing is drawn.
    **kwargs : extra keyword arguments forwarded unchanged to ``mask.plot_mask``.

    Returns
    -------
    None
    """
    # Guard clause: a missing mask is a silent no-op.
    if mask is None:
        return
    mask.plot_mask(ax=ax, **kwargs)

def get_norm_param(skymap):
    """Return the value range of the finite, non-zero pixels of ``skymap``.

    Zero-valued pixels are treated as "no data" and excluded. NaN and
    infinite pixels are excluded too, so that a NaN cannot leak into a
    colour-scale range built from the result.

    Parameters
    ----------
    skymap : object exposing a ``.data`` array (e.g. a gammapy ``Map``).

    Returns
    -------
    vmin, vmax : minimum and maximum of the selected pixel values.

    Raises
    ------
    ValueError
        If the map has no finite, non-zero pixel.
    """
    data = np.asarray(skymap.data)
    # Build the selection once; isfinite drops NaN/inf that `!= 0` would keep.
    values = data[np.isfinite(data) & (data != 0)]
    if values.size == 0:
        raise ValueError("get_norm_param: map contains no finite, non-zero pixel")
    return values.min(), values.max()

def integflux(Phi_0, E_0, Index, E_thresh, E_max):
    """Integrate the power law ``Phi_0 * (E / E_0) ** -|Index|`` over energy.

    Parameters
    ----------
    Phi_0 : differential flux normalisation at ``E_0``.
    E_0 : reference energy (same unit as the energy bounds).
    Index : spectral index. Only its magnitude is used; the spectrum is
        always taken as falling, ``E ** -|Index|``.
    E_thresh : lower integration bound.
    E_max : upper integration bound. ``None``, a value ``<= 0`` (e.g. the
        ``-1`` sentinel) or ``inf`` means "no upper bound".

    Returns
    -------
    float
        The integrated flux, in units of ``Phi_0`` times energy. For example,
        /cm2/s/TeV times TeV gives /cm2/s. It is 0.0 when a finite
        ``E_max <= E_thresh``.

    Raises
    ------
    ValueError
        If the integral has no upper bound and ``|Index| <= 1``, in which
        case it diverges.
    """
    gamma = -abs(float(Index))  # the sign of Index is deliberately ignored
    slope = 1.0 + gamma         # exponent of the antiderivative
    unbounded = E_max is None or E_max <= 0 or math.isinf(E_max)

    if unbounded and slope >= 0.0:
        raise ValueError(
            "integflux: integral to infinity diverges for |Index| <= 1 "
            "(Index={})".format(Index))
    if not unbounded and E_max <= E_thresh:
        return 0.0  # empty integration range

    if slope == 0.0:
        # |Index| == 1: the antiderivative is logarithmic, not a power law.
        return Phi_0 * E_0 * math.log(E_max / E_thresh)

    scale = Phi_0 * E_0 / slope
    lower = (E_thresh / E_0) ** slope
    if unbounded:
        # (E / E_0) ** slope -> 0 as E -> inf because slope < 0 here.
        return -scale * lower
    return scale * ((E_max / E_0) ** slope - lower)

def _ra_offset(ra_a, ra_b):
    """Return ``ra_a - ra_b`` in degrees, wrapped into the interval [-180, 180)."""
    return (ra_a - ra_b + 180.0) % 360.0 - 180.0


def _map_path(datapath, fov_name):
    """Return the path of the ``<fov_name>_Maps_Loose_Index3.fits.gz`` file in ``datapath``."""
    return os.path.join(datapath, fov_name + "_Maps_Loose_Index3.fits.gz")


def _find_fovs(tFoV, user_ra, user_dec):
    """Return the names of the FOV-table rows whose box contains (user_ra, user_dec).

    Columns 0-4 of each row are read as (name, ra, dec, extx, exty). A
    position is inside when |dRA| < extx and |dDec| < exty. The extents are
    presumably in degrees like ra/dec (TODO confirm), and dRA is not scaled
    by cos(dec).
    """
    found = []
    for row in tFoV:
        name, ra, dec, extx, exty = row[0], row[1], row[2], row[3], row[4]
        # Wrap the RA difference so that FOVs straddling RA = 0/360 deg
        # match from either side, with the same Dec cut in every case.
        if abs(_ra_offset(user_ra, ra)) < extx and abs(user_dec - dec) < exty:
            print("Found corresponding FOV : {}".format(name))
            print("> name : {}".format(name))
            print("> ra,dec : {}, {}".format(ra, dec))
            found.append(name)
    return found


def _fovs_with_exposure(datapath, filenames, user_ra, user_dec):
    """Return the FOVs in ``filenames`` whose HDU-1 (observation time) map is not NaN at the position."""
    selected = []
    for fname in filenames:
        skymap_time = Map.read(_map_path(datapath, fname), hdu=1)
        time = skymap_time.get_by_coord((user_ra, user_dec))
        if np.isnan(time[0]):
            print("nan")
        else:
            selected.append(fname)
    return selected


def main():
    """Command-line entry point: report the HEGS map values at a given RA/Dec.

    The script finds the FOV covering the position, prints its statistics
    and flux values, and optionally shows the significance map.
    """
    parser = argparse.ArgumentParser(description="Search for a position in the HEGS dataset")
    parser.add_argument("datapath", type=str, help="Path to the HEGS data files")
    parser.add_argument("ra", type=float, help="Right Ascension (J2000) of the searched position")
    parser.add_argument("dec", type=float, help="Declination (J2000) of the searched position")
    parser.add_argument("--nomap", action="store_true")
    args = parser.parse_args()
    datapath = args.datapath
    user_ra = args.ra
    user_dec = args.dec

    user_coord = SkyCoord(user_ra * u.deg, user_dec * u.deg)
    tFoV = Table.read(os.path.join(datapath, 'HEGS_FOVs_v1.fits'), hdu=1)

    print("Searching for coordinates RA:{} deg, DEC:{} in the HEGS dataset. \n".format(user_ra, user_dec))

    filenames = _find_fovs(tFoV, user_ra, user_dec)
    selectedname = _fovs_with_exposure(datapath, filenames, user_ra, user_dec)

    if not selectedname:
        sys.exit("No observation time available at the requested position.")
    if len(selectedname) > 1:
        # Overlapping exposure is not expected. Fall back explicitly to the
        # first match rather than reusing a leftover, unrelated FOV name.
        print("Problem -> this should not happen!")
        print("Using the first matching FOV: {}".format(selectedname[0]))
    else:
        print("Found observation time in FOV {}".format(selectedname[0]))
    name = selectedname[0]

    mapname = _map_path(datapath, name)
    print("File : {}".format(mapname))

    skymap_time = Map.read(mapname, hdu=1)
    skymap_sig = Map.read(mapname, hdu=2)
    skymap_meanth = Map.read(mapname, hdu=3)
    skymap_flux = Map.read(mapname, hdu=4)
    skymap_fluxerr = Map.read(mapname, hdu=5)
    skymap_fluxUL = Map.read(mapname, hdu=6)
    skymap_fluxSens_5p0 = Map.read(mapname, hdu=7)
    skymap_fluxSens_5p7 = Map.read(mapname, hdu=8)

    skymap_sig_mask = (skymap_sig != 0)

    # NOTE(review): a (lon, lat) tuple is interpreted in each map's own
    # frame; this assumes the HEGS maps are in RA/Dec (ICRS) - confirm.
    position = (user_ra, user_dec)

    # Printing values
    time = skymap_time.get_by_coord(position)
    if np.isnan(time[0]):
        print("nan")
    else:
        print("\n ----- STATS -----")
        print("Observation time =  {} hours".format(time[0]))
        print("Significance = {}".format(skymap_sig.get_by_coord(position)[0]))

    if time[0] == 0 or np.isnan(time[0]):
        print("No observation time available at this position")
    else:
        meanth = skymap_meanth.get_by_coord(position)[0]
        flux = skymap_flux.get_by_coord(position)[0]
        dflux = skymap_fluxerr.get_by_coord(position)[0]
        fluxul = skymap_fluxUL.get_by_coord(position)[0]
        fluxsens5p0 = skymap_fluxSens_5p0.get_by_coord(position)[0]
        fluxsens5p7 = skymap_fluxSens_5p7.get_by_coord(position)[0]

        print("E Threshold (TeV) - mean = {}".format(meanth))
        print("\n ----- FLUX VALUES -----")
        print("Flux (1 TeV) = {} +/- {} /cm2/s/TeV".format(flux, dflux))
        print("Flux UL (1 TeV, 95% CL) = {} /cm2/s/TeV".format(fluxul))
        print("Flux Sensitivity (1 TeV, 5.0 sigma) = {} /cm2/s/TeV".format(fluxsens5p0))
        print("Flux Sensitivity (1 TeV, 5.7 sigma) = {} /cm2/s/TeV".format(fluxsens5p7))
        print("---")
        # Integrate the 1 TeV differential values assuming an E^-3 spectrum
        # (E_0 = 1); E_max = -1 means no upper bound.
        print('Flux > 0.3 TeV = {}  +/- {} /cm2/s'.format(integflux(flux, 1, -3, 0.3, -1), integflux(dflux, 1, -3, 0.3, -1)))
        print('Flux UL (> 0.3 TeV,  95% CL) = {} /cm2/s'.format(integflux(fluxul, 1, -3, 0.3, -1)))
        print('Flux UL (> Eth, 95% CL) = {} /cm2/s'.format(integflux(fluxul, 1, -3, meanth, -1)))

    if not args.nomap:
        # Plot the significance map
        _, ax = plt.subplots(figsize=(9, 9), subplot_kw={"projection": skymap_sig.geom.wcs}, ncols=1, nrows=1)
        plt.title("Test position : RA={}, DEC={}".format(user_ra, user_dec))

        # Colour range from the non-zero pixels. Pass it on only if it is
        # finite, so that a NaN range never reaches imshow.
        plot_kwargs = {}
        min1, max1 = get_norm_param(skymap_sig)
        if np.isfinite(min1) and np.isfinite(max1):
            plot_kwargs["norm"] = simple_norm(skymap_sig.data, vmin=min1, vmax=max1)
        skymap_sig.plot(ax=ax, add_cbar=True, cmap="turbo", **plot_kwargs)
        skymap_sig_mask.plot_mask(ax=ax, colors='lightgrey', alpha=1)
        # 0.118 deg circle around the test position (value kept from the
        # original script; presumably the analysis integration radius - TODO confirm).
        reg = CircleSkyRegion(user_coord, 0.118 * u.deg)
        reg_pix = reg.to_pixel(ax.wcs)
        reg_pix.plot(ax=ax, edgecolor="blue", linestyle="--")

        plt.show()


if __name__ == "__main__":
    main()
