# -*- coding: utf-8 -*-
"""
Created on Mon Mar 30 17:06:01 2015

@author: adrienkuntz
"""

##### Calculation of the bispectrum and the trispectrum in perturbation theory
##### The trispectrum is only evaluated at the configuration (k1, -k1, k2, -k2), which is the one needed for this study
##### Throughout this program, uij denotes cos(theta(i,j)), where theta(i,j) is the angle between ki and kj

import numpy as np
cimport numpy as np
cimport cython_gsl as gsl
from Params import *
import kernels as knl
cdef extern from "math.h" :
    double pow(double, double)

#@cython.boundscheck(False)
#@cython.wraparound(False)




### Get the power spectrum from an external file
# Each line of matterpower.dat is split on runs of exactly four spaces and
# fields 1 and 2 are read as (k, P(k)); field 0 is discarded (presumably the
# empty string before a leading four-space indent -- confirm against the file).
# NOTE(review): this fixed-width split breaks on tabs or other spacings.
# Blank / whitespace-only lines are skipped (they used to raise IndexError).

with open('matterpower.dat', 'r') as monfichier :
    matterpower = monfichier.readlines()

rows = []
for ligne in matterpower :
    if not ligne.strip() :
        continue
    part = ligne.split('    ')
    rows.append((float(part[1]), float(part[2])))

# Build the (N, 2) table in a single allocation instead of re-concatenating on
# every line (quadratic); reshape keeps the 2-column shape even if rows is empty.
powerspectrum = np.array(rows, dtype=float).reshape(-1, 2)

# Number of tabulated points.  shape[0] is an exact integer, whereas size/2
# would be a float under true division.
cpdef size_t N = powerspectrum.shape[0]

# C-contiguous copies of the k and P(k) columns; their raw buffers are handed
# to the GSL interpolation routines as double*.
cdef double [::1] ktab = powerspectrum[:,0].copy()
cdef double [::1] Ptab = powerspectrum[:,1].copy()


# Raw C pointers to the first element of each table.
cdef double *kptr
kptr = &ktab[0]
cdef double *Pptr
Pptr = &Ptab[0]


###Growth factor

cpdef double E2(double z) :
    """Return omegam0*(1+z)^3 + omegalambda0.

    This is E(z)^2 = H(z)^2 / H0^2 for a matter + cosmological-constant model
    (no radiation or curvature term), with omegam0 and omegalambda0 taken from
    Params.
    """
    cdef double matter_scaling = (1 + z)**3
    return omegam0*matter_scaling + omegalambda0

cpdef double D(double z) :
    """Linear growth factor at redshift z.

    Uses the fitting formula D(z) = g(z) / (1+z), with
        g = 5/2 * Om / (Om^(4/7) - OL + (1 + Om/2) * (1 + OL/70)),
    where Om and OL are the matter and Lambda density parameters evaluated at
    z (the Carroll, Press & Turner 1992 form).  P() only uses it through the
    ratio D(z)/D(0), so the overall normalization does not matter there.
    """
    cdef double e2_z = E2(z)
    cdef double one_plus_z = 1 + z
    cdef double om_z = omegam0 / e2_z * one_plus_z**3
    cdef double ol_z = omegalambda0 / e2_z
    cdef double denominator = pow(om_z, 4/7.) - ol_z + (1 + om_z/2)*(1 + ol_z/70)
    return 5/2. * om_z / (one_plus_z * denominator)




### function which interpolates the power spectrum
# A single cubic-spline interpolator (plus its lookup accelerator) is built
# over (ktab, Ptab) at import time and reused by every call to P().
# GSL's cspline requires at least 3 points and strictly increasing x values.
# With GSL's default error handler a violation calls abort() and kills the
# interpreter, so check the table first and raise a catchable error instead.
if N < 3 :
    raise ValueError('cubic-spline interpolation of the power spectrum needs at least 3 points, got {}'.format(N))
if not np.all(np.diff(np.asarray(ktab)) > 0) :
    raise ValueError('k values of the power spectrum must be strictly increasing')

cdef gsl.gsl_interp_accel *acc = gsl.gsl_interp_accel_alloc ()
cdef gsl.gsl_interp *interp = gsl.gsl_interp_alloc(gsl.gsl_interp_cspline, N) 
gsl.gsl_interp_init(interp, &ktab[0], &Ptab[0], N)


cpdef double P(double k, double z) :
    """Power spectrum at wavenumber k and redshift z.

    The tabulated spectrum (ktab, Ptab) is cubic-spline interpolated and
    rescaled by (D(z)/D(0))^2 (this assumes the table is at z = 0 -- confirm).
    Outside the tabulated range the nearest edge value is used (constant
    extrapolation), except that P(0) is exactly 0.
    """
    cdef double tabulated

    if k < ktab[0] :
        if k == 0 :
            return 0
        tabulated = Ptab[0]          # below k_min: reuse the first tabulated value
    elif k > ktab[N-1] :
        tabulated = Ptab[N-1]        # above k_max: reuse the last tabulated value
    else :
        tabulated = gsl.gsl_interp_eval(interp, &ktab[0], &Ptab[0], k, acc)

    return tabulated * (D(z)/D(0))**2




#### Bispectrum

cdef double norme2(double k1, double k2, double u12) :
    """Norm |k1 + k2| from the two norms and the cosine u12 between the vectors.

    Rounding can leave a tiny negative squared norm; fabs absorbs it.
    """
    cdef double diagonal, squared
    diagonal = k1**2 + k2**2
    squared = diagonal + 2*k1*k2*u12
    return gsl.sqrt(gsl.fabs(squared))
    
cdef double norme3(double k1, double k2, double k3, double u12, double u13, double u23) :
    """Norm |k1 + k2 + k3| from the three norms and the pairwise cosines uij.

    As in norme2, fabs absorbs tiny negative squared norms due to rounding.
    """
    cdef double diagonal, squared
    diagonal = k1**2 + k2**2 + k3**2
    squared = diagonal + 2*k1*k2*u12 + 2*k1*k3*u13 + 2*k2*k3*u23
    return gsl.sqrt(gsl.fabs(squared))


cdef double F2(double k1, double k2, double u12) :
    """Second-order perturbation-theory density kernel F2(k1, k2).

    k1, k2 : norms of the two wave vectors
    u12    : cosine of the angle between them
    Returns 0 when either norm is zero, which avoids dividing by zero in the
    k1/k2 + k2/k1 term.
    """
    cdef double ratio_sum
    if k1 == 0. or k2 == 0. :
        return 0.
    ratio_sum = k1/k2 + k2/k1
    return 5./7 + 1/2.*ratio_sum*u12 + 2./7*u12**2
    

cdef double BPT(double k1, double k2, double k3, double u12, double u13, double u23, double z) :
    """Tree-level bispectrum B(k1, k2, k3) at redshift z.

    B = 2 F2(k1,k2) P(k1) P(k2) + 2 F2(k1,k3) P(k1) P(k3) + 2 F2(k2,k3) P(k2) P(k3),
    with ki the norms and uij the pairwise cosines.
    """
    cdef double p1 = P(k1, z)
    cdef double p2 = P(k2, z)
    cdef double p3 = P(k3, z)
    return (2*F2(k1, k2, u12)*p1*p2
            + 2*F2(k1, k3, u13)*p1*p3
            + 2*F2(k2, k3, u23)*p2*p3)



### Trispectrum
##return T(k1, -k1, k2, -k2)

cdef double G2(double k1, double k2, double u12) :
    """Second-order perturbation-theory kernel G2(k1, k2) (velocity-divergence counterpart of F2).

    k1, k2 : norms of the two wave vectors
    u12    : cosine of the angle between them
    Returns 0 when either norm is zero, mirroring the guard in F2. Without it,
    the k1/k2 + k2/k1 term would divide by zero.
    """
    if (k1 == 0.) or (k2 == 0.) :
        return 0.
    return 3./7 + 1/2.*(k1/k2 + k2/k1)*u12 + 4./7*u12**2
    


cdef inline double _F3_outer_coeff(double k1, double k2, double k3, double u12, double u13, double u23, double n123, double n12) :
    """Coefficient of G2(k1, k2) in F3: 7*alpha(k1+k2, k3) + 2*beta(k1+k2, k3).

    Here alpha(a, b) = (a+b).a / |a|^2 and beta(a, b) = |a+b|^2 (a.b) / (2 |a|^2 |b|^2).
    n123 = |k1+k2+k3|^2 and n12 = |k1+k2|^2 are squared norms; n12 must be nonzero.
    """
    return 7.*(k1**2 + k2**2 + 2*k1*k2*u12 + k1*k3*u13 + k2*k3*u23)/(n12) + n123*(k1*k3*u13 + k2*k3*u23)/(n12*k3**2)


cdef double F3(double k1, double k2, double k3, double u12, double u13, double u23) :
    """Unsymmetrized third-order density kernel F3(k1, k2, k3) from the standard PT recursion.

    F3 = 1/18 * [ (7 alpha(k1, k2+k3)) F2(k2,k3) + (2 beta(k1, k2+k3)) G2(k2,k3)
                  + (7 alpha(k1+k2, k3) + 2 beta(k1+k2, k3)) G2(k1,k2) ]
    with ki the norms and uij the pairwise cosines.  When k1+k2 or k2+k3 is
    exactly zero, the terms whose coefficients would divide by that squared norm
    are dropped; the corresponding kernels vanish in that limit.
    """
    cdef double n123 = norme3(k1, k2, k3, u12, u13, u23)**2     # |k1+k2+k3|^2
    cdef double n12 = norme2(k1, k2, u12)**2                    # |k1+k2|^2
    cdef double n23 = norme2(k2, k3, u23)**2                    # |k2+k3|^2
    cdef double inner_alpha, inner_beta

    if n12 == 0. and n23 == 0. :
        return 0.

    if n23 == 0. :
        # k2 + k3 = 0: only the G2(k1, k2) term is kept.
        return 1/18.*(_F3_outer_coeff(k1, k2, k3, u12, u13, u23, n123, n12) * G2(k1, k2, u12))

    # 7*alpha(k1, k2+k3) and 2*beta(k1, k2+k3)
    inner_alpha = 7.*(1 + k2/k1*u12 + k3/k1*u13)
    inner_beta = n123*(k1*k2*u12 + k1*k3*u13)/(k1**2 * n23)

    if n12 == 0. :
        # k1 + k2 = 0 (F3(k1, -k1, k2) is frequently used): drop the G2(k1, k2) term.
        return 1/18.*(inner_alpha * F2(k2, k3, u23) + inner_beta * G2(k2, k3, u23))

    return 1/18.*(inner_alpha * F2(k2, k3, u23) + inner_beta * G2(k2, k3, u23)
                  + _F3_outer_coeff(k1, k2, k3, u12, u13, u23, n123, n12) * G2(k1, k2, u12))
    
    
    
    
cdef double F3s(double k1, double k2, double k3, double u12, double u13, double u23):
    """Symmetrized F3: the average of F3 over the 6 orderings of (k1, k2, k3).

    Each permutation passes the cosines that match its reordered arguments.
    """
    cdef double f123 = F3(k1, k2, k3, u12, u13, u23)
    cdef double f132 = F3(k1, k3, k2, u13, u12, u23)
    cdef double f213 = F3(k2, k1, k3, u12, u23, u13)
    cdef double f312 = F3(k3, k1, k2, u13, u23, u12)
    cdef double f321 = F3(k3, k2, k1, u23, u13, u12)
    cdef double f231 = F3(k2, k3, k1, u23, u12, u13)
    return 1/6.*(f123 + f132 + f213 + f312 + f321 + f231)



cdef double TPT(double k1, double k2, double u12, double z):
    """Tree-level trispectrum T(k1, -k1, k2, -k2) at redshift z.

    k1, k2 : norms of the two wave vectors; u12 : cosine of the angle between them.
    The result is the sum of two groups of terms:
      - F2 x F2 exchange terms through the internal momenta k1 + k2 and k1 - k2,
        each 4 P(q) [F2(.,q) P(k1) + F2(.,q) P(k2)]^2;
      - F3 terms, 12 [F3s(k1,-k1,k2) P(k1)^2 P(k2) + F3s(k1,k2,-k2) P(k1) P(k2)^2].
    Exchange terms through k1 + (-k1) = 0 are not included (presumably
    because P(0) = 0).
    """
    cdef double n12 = norme2(k1, k2, u12)       # |k1 + k2|
    cdef double m12 = norme2(k1, k2, -u12)      # |k1 - k2|
    cdef double u1_n12 = 0.
    cdef double u2_n12 = 0.
    cdef double u1_m12 = 0.
    cdef double u2_m12 = 0.
    cdef double p1, p2, coupling_n, coupling_m, exchange, f3_terms

    # Cosines between k1 (resp. k2) and the internal momenta k1+k2 and k1-k2.
    # They stay 0 when the internal momentum vanishes (F2 then returns 0 anyway).
    if n12 != 0 :
        u1_n12 = (k1**2 + k1*k2*u12)/(k1 * n12)
        u2_n12 = (k2**2 + k1*k2*u12)/(k2 * n12)
    if m12 != 0 :
        u1_m12 = (k1**2 - k1*k2*u12)/(k1 * m12)
        u2_m12 = (-k2**2 + k1*k2*u12)/(k2 * m12)

    p1 = P(k1, z)
    p2 = P(k2, z)

    coupling_n = F2(k1, n12, -u1_n12) * p1 + F2(k2, n12, -u2_n12) * p2
    coupling_m = F2(k1, m12, -u1_m12) * p1 + F2(k2, m12, u2_m12) * p2
    exchange = 4*P(n12, z)*coupling_n**2 + 4*P(m12, z)*coupling_m**2

    f3_terms = 12*(F3s(k1, k1, k2, -1., u12, -u12) * p1**2 * p2 + F3s(k1, k2, k2, u12, -u12, -1.) * p1 * p2**2)

    return exchange + f3_terms