# -*- coding: utf-8 -*-
"""
Created on Wed Mar 25 10:51:20 2015

@author: adrienkuntz
"""

import numpy as np
from scipy import integrate
from scipy.special import sici
from scipy.optimize import brentq
#import matplotlib.pyplot as plt
import Trispectrum_PT as pt
import kernels as knl
from Params import *
import math




# Wavenumber grid for the M_ij tables: knum points, log-spaced between
# 10**kmin_log and 10**kmax_log (all three settings come from Params).
ktab = np.logspace(start=kmin_log, stop=kmax_log, num=knum)



# Helper: combine several 1-D vectors into one multidimensional (outer-product style) array

def mafonction(ufct, *vectors):
    """Combine the 1-D ``vectors`` into one N-D array using the binary ufunc ``ufct``.

    ``np.ix_`` turns the i-th vector into an open-mesh array that broadcasts
    along axis i.  Those arrays are folded together starting from
    ``ufct.identity``.  With ``np.multiply`` this gives the outer product,
    ``out[i, j, ...] = vectors[0][i] * vectors[1][j] * ...``.
    With no vectors, ``ufct.identity`` itself is returned.
    """
    result = ufct.identity
    for open_axis in np.ix_(*vectors):
        result = ufct(result, open_axis)
    return result
    




###Powerspectrum
    
    
def WTH(x):
    """Fourier transform of a top-hat window: W(x) = 3 (sin x - x cos x) / x**3.

    Accepts scalars or NumPy arrays.  Undefined at x = 0 (division by x**3);
    in this file it is evaluated at x = k*R.
    """
    prefactor = 3 / x**3
    bracket = np.sin(x) - x * np.cos(x)
    return prefactor * bracket
def Wprime(x):
    """Derivative dW/dx of the top-hat window W(x) = 3 (sin x - x cos x) / x**3.

    dW/dx = 3 (sin x (x**2 - 3) + 3 x cos x) / x**4.  Accepts scalars or
    NumPy arrays; undefined at x = 0.
    """
    bracket = np.sin(x) * (x**2 - 3) + 3 * x * np.cos(x)
    return 3 / x**4 * bracket

def Pasympt(k):
    """Large-k tail shape used to extend P(k): k**(nspectral - 4) * ln(k)**2.

    Unnormalised; this file only uses it through ratios Pasympt(k) / Pasympt(k0).
    ``nspectral`` comes from Params.  Expects a positive scalar ``k``
    (math.pow does not accept arrays).
    """
    power_law = math.pow(k, nspectral - 4)
    return power_law * np.log(k)**2
    
def omegam(z):
    """Matter density parameter at redshift z: omegam0 * (1+z)**3 / E(z)**2.

    ``E`` comes from the project ``kernels`` module (presumably H(z)/H0 -- confirm).
    """
    cube = (1 + z)**3
    return omegam0 * cube / knl.E(z)**2
def rho(z):
    """Density used at redshift z to convert halo masses to radii: rhobar * omegam(z) / omegam0.

    NOTE(review): originally described as the comoving density.  With omegam(z)
    defined as omegam0 * (1+z)**3 / E(z)**2 this equals
    rhobar * (1+z)**3 / E(z)**2, which depends on z, whereas a comoving mean
    density would not.  Confirm the intent and the definition of rhobar in Params.
    """
    scaled = rhobar * omegam(z)
    return scaled / omegam0
def delta(z):
    """Spherical-collapse overdensity, expressed relative to the mean matter density.

    Evaluates (18 pi**2 + 82 x - 39 x**2) / omegam(z) with x = omegam(z) - 1.
    """
    om = omegam(z)
    x = om - 1
    return (18*np.pi**2 + 82*x - 39*x**2) / om
    
# scipy.integrate.simps was removed in SciPy 1.14 (its successor is simpson).
# Use simps where it still exists so older installs keep exactly the same numerics.
_simps = getattr(integrate, 'simps', None) or integrate.simpson


def _simps_product(x, weight, *factors):
    """Simpson-integrate ``weight * factors[0] * factors[1] * ...`` over ``x``.

    The product is formed left to right, in the same order as the written-out
    expression ``weight * f1 * f2 * ...``, so the result is bit-identical to it.
    """
    integrand = weight
    for factor in factors:
        integrand = integrand * factor
    return _simps(integrand, x=x)


def _write_rows(path, rows):
    """Write every row of ``rows`` to ``path`` as one space-separated line.

    The file is closed even if formatting or writing raises.
    """
    with open(path, 'w') as out:
        for row in rows:
            out.write(' '.join('{}'.format(value) for value in row) + '\n')


def _grid_rows(k, table):
    """Return a generator of (k[i], k[j], table[i, j]) for all i, j, in row-major order."""
    n = len(k)
    return ((k[i], k[j], table[i, j]) for i in range(n) for j in range(n))


for z in ztab:

    powerspectrum = pt.powerspectrum.copy()

    # A single parenthesised string prints identically under Python 2 and 3.
    print('z = ' + str(z))

    # Column 0 holds k; column 1 is filled with pt.P(k, z) at this redshift.
    powerspectrum[:, 1] = np.array([pt.P(k_tab, z) for k_tab in pt.powerspectrum[:, 0]])

    def Rf(m):
        """Radius of a sphere of density rho(z) that contains mass m."""
        return np.exp(1/3. * np.log(3. * m / (4. * np.pi * rho(z))))

    ### sigma2 must be integrated up to k*R(m) ~ xmax because k**2*P(k) converges
    ### very slowly; the smallest mass sets the largest k needed.
    kmax_sigma = xmax / Rf(mmin)
    if kmax_sigma > powerspectrum[-1, 0]:
        # Not enough tabulated points: append the asymptotic shape Pasympt,
        # matched to the last tabulated value, continuing the table's
        # log10-spacing (measured from its first two k values).
        step = np.log10(powerspectrum[1, 0]) - np.log10(powerspectrum[0, 0])
        ratio = math.pow(10, step)
        if ratio <= 1:
            # A non-increasing k grid would make the loop below never terminate.
            raise ValueError('cannot extend P(k): tabulated k values are not increasing')
        k0 = powerspectrum[-1, 0]
        Pk0 = powerspectrum[-1, 1]
        extra_rows = []
        k = k0 * ratio
        while k <= kmax_sigma:
            extra_rows.append([k, Pasympt(k) * Pk0 / Pasympt(k0)])
            k = k * ratio
        if extra_rows:
            # One concatenation instead of one per appended row (was O(n**2)).
            powerspectrum = np.concatenate((powerspectrum, np.array(extra_rows)))

    # (mnum, nk) arrays whose every row is a copy of P(k) (resp. k), so that
    # integrals over k can be done for all masses at once along axis 1.
    pow_matrix = mafonction(np.multiply, np.ones(mnum), powerspectrum[:, 1])
    k_matrix = mafonction(np.multiply, np.ones(mnum), powerspectrum[:, 0])

    def sigma2f(m):
        """Variance of the density field smoothed with a top hat of radius Rf(m)."""
        f = powerspectrum[:, 0]**2 / (2 * np.pi**2) * powerspectrum[:, 1] * WTH(powerspectrum[:, 0] * Rf(m))**2
        return _simps(f, x=powerspectrum[:, 0])

    def nuf(m):
        """Peak height nu = delta_sc**2 / sigma2(m)."""
        return (delta_sc)**2 / sigma2f(m)

    # Critical mass: nu(mstar) = 1.  brentq raises ValueError unless nu - 1
    # changes sign between mmin and 1e14.
    mstar = brentq(lambda m: nuf(m) - 1., mmin, 1e14)

    m_min_log = np.log10(mmin)
    m_max_log = np.log10(mstar) + 6
    mtab = np.logspace(m_min_log, m_max_log, mnum)

    ### Coefficients b (link the overdensity of halos to the overdensity of dark matter)

    R = np.exp(1/3. * np.log(3. * mtab / (4. * np.pi * rho(z))))                     # Rf(mtab)
    kR = mafonction(np.multiply, R, powerspectrum[:, 0])                            # k*R, shape (mnum, nk)
    dRdm = 1./3 * np.exp(1/3. * np.log(3. / (4. * np.pi * rho(z) * mtab**2)))     # dR/dm

    sigma2 = _simps(k_matrix**2 / (2 * np.pi**2) * pow_matrix * WTH(kR)**2, x=powerspectrum[:, 0], axis=1)

    nu = (delta_sc)**2 / sigma2

    epsilon1 = (q * nu - 1) / (delta_sc)
    epsilon2 = q * nu * (q * nu - 3) / (delta_sc)**2
    epsilon3 = q * nu * (q * nu - 3)**2 / (delta_sc)**3

    E1 = 2 * p / (delta_sc) / (1 + np.exp(p * np.log(q * nu)))
    E2 = E1 * (2 * epsilon1 + (1 + 2 * p) / (delta_sc))
    E3 = E1 * ((4*(p**2-1) + 6*p*q*nu) / delta_sc**2 + 3*epsilon1**2)

    b1 = 1. + epsilon1 + E1
    b2 = 2 * (1+a2) * (epsilon1 + E1) + epsilon2 + E2
    b3 = 6 * (a2+a3) * (epsilon1 + E1) + 3 * (1+2*a2) * (epsilon2 + E2) + epsilon3 + E3

    ### Sheth-Tormen function, renormalised over the sampled nu range (so the
    ### 0.3222 prefactor cancels).  Afterwards, up to rounding:
    ### simps(Sheth, nu) = 1, simps(Sheth*b1, nu) = 1, simps(Sheth*b2, nu) = simps(Sheth*b3, nu) = 0.

    Sheth = 0.3222 * (1 + np.exp(-p * np.log(q * nu))) * np.sqrt(q / (2 * np.pi * nu)) * np.exp(-q * nu / 2.)

    normfactor1 = _simps(Sheth, x=nu)
    Sheth = Sheth / normfactor1

    normfactor2 = _simps(Sheth * b1, x=nu)
    b1 = b1 / normfactor2

    normfactor3 = _simps(Sheth * b2, x=nu)
    b2 = b2 - normfactor3

    normfactor4 = _simps(Sheth * b3, x=nu)
    b3 = b3 - normfactor4

    ### Density of halos: dn/dm = -(rho/m) * Sheth * dsigma2/dm * delta_sc**2 / sigma2**2

    dsigma2dm = dRdm * _simps(k_matrix**3 / (np.pi**2) * pow_matrix * WTH(kR) * Wprime(kR), x=powerspectrum[:, 0], axis=1)

    # NOTE(review): the original comment here said "convert to Mpc"; the units
    # follow those of rho(z) and mtab -- confirm against Params.
    numberdens = - rho(z) / mtab * Sheth * dsigma2dm * (delta_sc)**2 / sigma2**2

    ### Profile of a halo: NFW profile

    cbar = 9. / (1+z) * np.exp(-0.13 * np.log(mtab/mstar))    # concentration, 9/(1+z) * (m/mstar)**-0.13
    rs = R / math.pow(delta(z), 1/3.) / cbar                    # scale radius

    def profile(k):
        """Fourier-space NFW profile at scalar wavenumber k, one value per mass in mtab."""
        (si1, ci1) = sici((1+cbar) * k * rs)
        (si2, ci2) = sici(k * rs)
        u0 = 1/(np.log(1 + cbar) - cbar/(1+cbar))
        return u0 * (np.sin(k * rs) * (si1 - si2) - np.sin(k * rs * cbar)/((1 + cbar) * k * rs) + np.cos(k * rs) * (ci1 - ci2))

    ### Functions Mij: integral over mtab of numberdens * (m/rho)**j * b_i * u(k1)...u(kj),
    ### with b_0 = 1.  The weights below are the k-independent part of each integrand.

    m_over_rho = mtab / rho(z)
    weight02 = numberdens * m_over_rho**2
    weight04 = numberdens * m_over_rho**4
    weight11 = numberdens * m_over_rho * b1
    weight12 = numberdens * m_over_rho**2 * b1
    weight13 = numberdens * m_over_rho**3 * b1
    weight14 = numberdens * m_over_rho**4 * b1
    weight21 = numberdens * m_over_rho * b2
    weight22 = numberdens * m_over_rho**2 * b2
    weight23 = numberdens * m_over_rho**3 * b2
    weight24 = numberdens * m_over_rho**4 * b2
    weight31 = numberdens * m_over_rho * b3

    MF04 = lambda k1, k2, k3, k4: _simps_product(mtab, weight04, profile(k1), profile(k2), profile(k3), profile(k4))
    MF02 = lambda k1, k2: _simps_product(mtab, weight02, profile(k1), profile(k2))

    MF11 = lambda k1: _simps_product(mtab, weight11, profile(k1))
    MF12 = lambda k1, k2: _simps_product(mtab, weight12, profile(k1), profile(k2))
    MF13 = lambda k1, k2, k3: _simps_product(mtab, weight13, profile(k1), profile(k2), profile(k3))
    MF14 = lambda k1, k2, k3, k4: _simps_product(mtab, weight14, profile(k1), profile(k2), profile(k3), profile(k4))

    MF21 = lambda k1: _simps_product(mtab, weight21, profile(k1))
    MF22 = lambda k1, k2: _simps_product(mtab, weight22, profile(k1), profile(k2))
    MF23 = lambda k1, k2, k3: _simps_product(mtab, weight23, profile(k1), profile(k2), profile(k3))
    MF24 = lambda k1, k2, k3, k4: _simps_product(mtab, weight24, profile(k1), profile(k2), profile(k3), profile(k4))

    MF31 = lambda k1: _simps_product(mtab, weight31, profile(k1))

    ### Storage of the values of M in arrays. Only the particular points needed for the trispectrum are stored.
    # Each halo profile is evaluated once per k and reused for every (i, j) pair.
    profiles = [profile(k1) for k1 in ktab]

    MT11_k1 = np.array([_simps_product(mtab, weight11, u) for u in profiles])
    MT21_k1 = np.array([_simps_product(mtab, weight21, u) for u in profiles])
    MT31_k1 = np.array([_simps_product(mtab, weight31, u) for u in profiles])
    MT02_k1k1 = np.array([_simps_product(mtab, weight02, u, u) for u in profiles])

    MT12_k1k2 = np.zeros((knum, knum))
    MT22_k1k2 = np.zeros((knum, knum))
    MT04_k1k1k2k2 = np.zeros((knum, knum))
    MT13_k1k1k2 = np.zeros((knum, knum))
    for i, u_i in enumerate(profiles):
        for j, u_j in enumerate(profiles):
            MT12_k1k2[i, j] = _simps_product(mtab, weight12, u_i, u_j)
            MT22_k1k2[i, j] = _simps_product(mtab, weight22, u_i, u_j)
            MT04_k1k1k2k2[i, j] = _simps_product(mtab, weight04, u_i, u_j, u_i, u_j)
            MT13_k1k1k2[i, j] = _simps_product(mtab, weight13, u_i, u_i, u_j)

    ### Write the result in files: "k value" lines for 1-D tables, "k1 k2 value" for 2-D tables.

    _write_rows('M11_k1_{}.dat'.format(z), zip(ktab, MT11_k1))
    _write_rows('M21_k1_{}.dat'.format(z), zip(ktab, MT21_k1))
    _write_rows('M31_k1_{}.dat'.format(z), zip(ktab, MT31_k1))
    _write_rows('M02_k1k1_{}.dat'.format(z), zip(ktab, MT02_k1k1))

    _write_rows('M12_k1k2_{}.dat'.format(z), _grid_rows(ktab, MT12_k1k2))
    _write_rows('M22_k1k2_{}.dat'.format(z), _grid_rows(ktab, MT22_k1k2))
    _write_rows('M04_k1k1k2k2_{}.dat'.format(z), _grid_rows(ktab, MT04_k1k1k2k2))
    _write_rows('M13_k1k1k2_{}.dat'.format(z), _grid_rows(ktab, MT13_k1k1k2))