import config
import measfcts
import os
import numpy as np
import matplotlib.ticker as ticker

import momentsml.plot
from momentsml.tools.feature import Feature
import matplotlib.pyplot as plt

import logging
logger = logging.getLogger(__name__)



from matplotlib import rc
# Typeset all figure text with LaTeX, in a sans-serif (Helvetica) font.
# To switch to Palatino or another serif font, use instead:
#   rc('font', family='serif', serif=['Palatino'])
rc('font', family='sans-serif', **{'sans-serif': ['Helvetica']})
rc('text', usetex=True)

###############################

weights = True # Run on a validation catalog with weights (True) or without weights (False)?

###############################

select = True # Select according to <SNR>. Only used without weights; forced to False when weights is True.
regressmethod = 2 # Only matters when running with weights. 2 should be used, as we care about each PSF equally.

###############################


# The weights flag picks the validation catalog and overrides the two settings above.
if weights is True:
	select = False
	valname = config.wvalname
else:
	regressmethod = 1
	valname = config.valname

# Load the (pickled) validation catalog that all the plots below are made from.
valcat = os.path.join(config.valdir, valname + ".pkl")
cat = momentsml.tools.io.readpickle(valcat)


# Pick the figure annotation, legend visibility and y-range from tags in the catalog name.
# Later matches override earlier ones, because the tests are independent ifs, not elifs.
# NOTE(review): this tests config.valname even when weights is True, in which case the
# plotted catalog is config.wvalname. Confirm that both names carry the same tags.
widescale = False
showlegend = False
text = "" # Default, so that the fig.text call at the end never meets an undefined name.
if "sum55" in config.valname:
	text = "Ignoring the variability of the PSF"
	showlegend = True
	widescale = True
if "sum77" in config.valname or "pos" in config.valname:
	text = "Using field coordinates as features"
if "sum88" in config.valname or "mom" in config.valname:
	text = "Using PSF moments as features"
if not text:
	logger.warning("No known feature tag found in valname '%s', the figure annotation will be empty.", config.valname)

# Value passed as lim to make_symlog: ten times larger when the PSF variability is ignored.
lim = 1e0 if widescale else 1e-1


# Optionally cut the catalog on the "snr_mean" statistic (added here by addstats),
# requiring a minimum value of 10.0.
if select:
	momentsml.tools.table.addstats(cat, "snr")
	snr_selector = momentsml.tools.table.Selector("snr_mean > 10", [("min", "snr_mean", 10.0)])
	cat = snr_selector.select(cat)

"""
for comp in ["1","2"]:

	# If no weights are in the catalog (or not yet), we add ones
	if not "pre_s{}w".format(comp) in cat.colnames:
		
		# First putting all weights to 1.0:
		cat["pre_s{}w".format(comp)] = np.ones(cat["adamom_g1"].shape)
		
	cat["pre_s{}w_norm".format(comp)] = cat["pre_s{}w".format(comp)] / np.max(cat["pre_s{}w".format(comp)])

	momentsml.tools.table.addrmsd(cat, "pre_s{}".format(comp), "tru_s{}".format(comp))
	momentsml.tools.table.addstats(cat, "pre_s{}".format(comp), "pre_s{}w".format(comp))
	cat["pre_s{}_wbias".format(comp)] = cat["pre_s{}_wmean".format(comp)] - cat["tru_s{}".format(comp)]
"""



# Ratio FWHM / sigma of a Gaussian profile, 2*sqrt(2*ln 2), used to derive the PSF FWHM column.
fwhmpersigma = 2.3548
cat["tru_psf_fwhm"] = cat["tru_psf_sigma"] * fwhmpersigma

# Binning features for the three panels: (column name, low, high, axis label).
tru_psf_g1, tru_psf_g2, tru_psf_fwhm = (
	Feature("tru_psf_g1", -0.27, 0.27, nicename=r"PSF $\varepsilon_1$"),
	Feature("tru_psf_g2", -0.02, 0.27, nicename=r"PSF $\varepsilon_2$"),
	Feature("tru_psf_fwhm", 4.1, 5.35, nicename=r"PSF FWHM [pix]"),
)


def make_plot(ax, featbin, showlegend=False):
	"""Draw the binned mcbin metrics of both shear components on ``ax``.

	Component 1 is drawn with the default mcbin options. Component 2 is overlaid
	with ``showbins=False`` and carries the legend when ``showlegend`` is True.
	When the module-level ``weights`` flag is True, the predicted weight columns
	``pre_s1w`` / ``pre_s2w`` are passed to mcbin as ``featprew``.
	Also reads the module-level ``cat``, ``regressmethod`` and ``lim``.

	:param ax: matplotlib axes to draw on.
	:param featbin: Feature along which the catalog is binned; its ``nicename`` labels the x axis.
	:param showlegend: whether the component-2 call should draw a legend.
	"""
	ax.axhline(0.0, color='gray', lw=0.5)

	for comp in (1, 2):
		truefeat = Feature("tru_s{}".format(comp))
		predfeat = Feature("pre_s{}".format(comp), rea="all")

		options = {}
		if weights is True:
			options["featprew"] = Feature("pre_s{}w".format(comp), rea="all")
		options["comp"] = comp
		if comp == 2:
			# The second component only overlays its points: no bin markers, optional legend.
			options["showbins"] = False
			options["showlegend"] = showlegend
		options["regressmethod"] = regressmethod

		momentsml.plot.mcbin.mcbin(ax, cat, truefeat, predfeat, featbin, **options)

	momentsml.plot.mcbin.make_symlog(ax, featbin, lim=lim)
	ax.set_xlabel(featbin.nicename)


fig = plt.figure(figsize=(10, 3.0))
# Margins of the single row of three panels. left/right/bottom/top are fractions of the
# figure size. The annotation added with fig.text at the end sits above top, at y = 0.94.
fig.subplots_adjust(
	left=0.08,    # left side of the subplots
	right=0.99,   # right side of the subplots
	bottom=0.15,  # bottom of the subplots
	top=0.86,     # top of the subplots
	wspace=0.08,  # blank space between subplots, as a fraction of the average axis width
	hspace=0.2,   # blank space between subplot rows, as a fraction of the average axis height
)

# One panel per PSF property, side by side. Only the leftmost panel keeps its y label
# and y tick labels, and only the rightmost panel may show the legend.
panelfeats = [tru_psf_g1, tru_psf_g2, tru_psf_fwhm]
for panelnum, panelfeat in enumerate(panelfeats, start=1):
	ax = plt.subplot(1, len(panelfeats), panelnum)
	if panelnum == len(panelfeats):
		make_plot(ax, panelfeat, showlegend=showlegend)
	else:
		make_plot(ax, panelfeat)
	if panelnum == 1:
		ax.set_ylabel("Metric value")
	else:
		ax.set_yticklabels([])



# Annotate the figure, save it in the validation directory, then display it.
fig.text(.005, .94, text, fontdict={"fontsize": 12})

plotpath = os.path.join(config.valdir, valname + "_psf_biases")
momentsml.plot.figures.savefig(plotpath, fig, fancy=True)

plt.show()

