# -*- coding: utf-8 -*-
"""
Created on Tue Apr  7 11:41:13 2015

@author: adrienkuntz
"""

import numpy as np
cimport numpy as np
cimport cython_gsl as gsl
cimport Trispectrum_PTc as ptc
from Params import *
from libc.stdlib cimport malloc, free


# k grid on which every M table is tabulated: knum points log-spaced (base 10)
# between 10**kmin_log and 10**kmax_log. kmin_log, kmax_log and knum are not
# defined in this file; presumably they come from `from Params import *`.
cdef double [::1] ktab = np.logspace(kmin_log, kmax_log, knum)
# C-level copy of the redshift grid (ztab presumably from Params). The z
# interpolation below treats it as sorted in increasing order — TODO confirm.
cdef double [:] ztabc = ztab.copy()
# NOTE(review): this declares a C-level module variable that shadows the Python
# name `zmax` brought in by `from Params import *`; the right-hand side most
# likely refers to this C variable itself. Cython's star-import assigns matching
# cdef globals, which is presumably where the actual value comes from — confirm.
cdef double zmax = zmax


# Linear search in a sorted array, done entirely in C.

cdef int search(double [:] tab, int length, double value) :
    """Return the first index idx in [0, length) with tab[idx] >= value.

    Returns -1 when no such index exists. For an ascending ``tab`` this is
    the position of the first element not smaller than ``value``.
    """
    cdef int idx
    for idx in range(length):
        if tab[idx] >= value:
            return idx
    return -1



### Arrays holding the tabulated values of M, indexed first by the redshift
### slice (position in ztabc). 1D tables have shape (znum, knum); 2D tables
### have shape (znum, knum, knum). All are zero-initialised and C-contiguous.

# 1D tables: one row of knum values per redshift
cdef double [:,::1] MT11_k1 = np.zeros((znum, knum), order = 'C')
cdef double [:,::1] MT21_k1 = np.zeros((znum, knum), order = 'C')
cdef double [:,::1] MT31_k1 = np.zeros((znum, knum), order = 'C')
cdef double [:,::1] MT02_k1k1 = np.zeros((znum, knum), order = 'C')

# 2D tables: one knum x knum slab per redshift
cdef double [:,:,::1] MT12_k1k2 = np.zeros((znum, knum, knum), order = 'C')
cdef double [:,:,::1] MT22_k1k2 = np.zeros((znum, knum, knum), order = 'C')
cdef double [:,:,::1] MT04_k1k1k2k2 = np.zeros((znum, knum, knum), order = 'C')
cdef double [:,:,::1] MT13_k1k1k2 = np.zeros((znum, knum, knum), order = 'C')



## Interpolation objects, one set per tabulated redshift (index into ztabc).
## Interpolation in z is done by hand, linearly between neighbouring slices,
## because no 3D (interp3d) interpolator is available.
##
## Each array below holds znum pointers; the GSL objects they point to are
## allocated when the tables are loaded. The element size is written as
## sizeof(<pointer type>) of the slot type.
## NOTE(review): these arrays and the GSL objects are never freed; they live
## for the lifetime of the module.

# 1D (cubic-spline in k) interpolators and accelerators
cdef gsl.gsl_interp_accel **accM11_k1 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.gsl_interp **interpM11_k1 = <gsl.gsl_interp **> malloc(znum * sizeof(gsl.gsl_interp *))

cdef gsl.gsl_interp_accel **accM21_k1 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.gsl_interp **interpM21_k1 = <gsl.gsl_interp **> malloc(znum * sizeof(gsl.gsl_interp *))

cdef gsl.gsl_interp_accel **accM31_k1 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.gsl_interp **interpM31_k1 = <gsl.gsl_interp **> malloc(znum * sizeof(gsl.gsl_interp *))

cdef gsl.gsl_interp_accel **accM02_k1k1 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.gsl_interp **interpM02_k1k1 = <gsl.gsl_interp **> malloc(znum * sizeof(gsl.gsl_interp *))

# 2D (bicubic in k1, k2) interpolators with one accelerator per axis
cdef gsl.gsl_interp_accel **xaccM12_k1k2 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.gsl_interp_accel **yaccM12_k1k2 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.interp2d **interpM12_k1k2 = <gsl.interp2d **> malloc(znum * sizeof(gsl.interp2d *))

cdef gsl.gsl_interp_accel **xaccM22_k1k2 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.gsl_interp_accel **yaccM22_k1k2 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.interp2d **interpM22_k1k2 = <gsl.interp2d **> malloc(znum * sizeof(gsl.interp2d *))

cdef gsl.gsl_interp_accel **xaccM04_k1k1k2k2 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.gsl_interp_accel **yaccM04_k1k1k2k2 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.interp2d **interpM04_k1k1k2k2 = <gsl.interp2d **> malloc(znum * sizeof(gsl.interp2d *))

cdef gsl.gsl_interp_accel **xaccM13_k1k1k2 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.gsl_interp_accel **yaccM13_k1k1k2 = <gsl.gsl_interp_accel **> malloc(znum * sizeof(gsl.gsl_interp_accel *))
cdef gsl.interp2d **interpM13_k1k1k2 = <gsl.interp2d **> malloc(znum * sizeof(gsl.interp2d *))

# Fail the import cleanly instead of dereferencing NULL later.
if znum > 0 and (
        accM11_k1 == NULL or interpM11_k1 == NULL
        or accM21_k1 == NULL or interpM21_k1 == NULL
        or accM31_k1 == NULL or interpM31_k1 == NULL
        or accM02_k1k1 == NULL or interpM02_k1k1 == NULL
        or xaccM12_k1k2 == NULL or yaccM12_k1k2 == NULL or interpM12_k1k2 == NULL
        or xaccM22_k1k2 == NULL or yaccM22_k1k2 == NULL or interpM22_k1k2 == NULL
        or xaccM04_k1k1k2k2 == NULL or yaccM04_k1k1k2k2 == NULL or interpM04_k1k1k2k2 == NULL
        or xaccM13_k1k1k2 == NULL or yaccM13_k1k1k2 == NULL or interpM13_k1k1k2 == NULL):
    raise MemoryError('could not allocate the M interpolator pointer arrays')










def _load_column(fname, double[:] out, int col):
    """Fill ``out`` with column ``col`` of the whitespace-separated file ``fname``.

    One value is read per non-blank line, in file order. Blank lines are skipped.

    Raises:
        IOError/OSError if the file cannot be opened.
        IndexError/ValueError if a line has too few columns or is not numeric.
        ValueError if the number of data lines differs from ``out.shape[0]``
        (otherwise a truncated file would silently leave zeros in the table).
    """
    cdef Py_ssize_t size = out.shape[0]
    cdef Py_ssize_t n = 0
    with open(fname, 'r') as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                continue  # tolerate blank (e.g. trailing) lines
            if n == size:
                raise ValueError('{}: more than {} data lines'.format(fname, size))
            out[n] = float(fields[col])
            n += 1
    if n != size:
        raise ValueError('{}: expected {} data lines, found {}'.format(fname, size, n))


def _load_grid(fname, double[:, :] out, int col):
    """Fill the 2D table ``out`` from column ``col`` of ``fname`` in row-major order.

    The n-th non-blank line goes to ``out[n // ncols, n % ncols]``, i.e. the
    last index varies fastest. The same errors as ``_load_column`` are raised
    when the line count differs from ``out.shape[0] * out.shape[1]``.
    """
    cdef Py_ssize_t ncols = out.shape[1]
    cdef Py_ssize_t size = out.shape[0] * ncols
    cdef Py_ssize_t n = 0
    with open(fname, 'r') as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                continue  # tolerate blank (e.g. trailing) lines
            if n == size:
                raise ValueError('{}: more than {} data lines'.format(fname, size))
            out[n // ncols, n % ncols] = float(fields[col])
            n += 1
    if n != size:
        raise ValueError('{}: expected {} data lines, found {}'.format(fname, size, n))


cdef int l

for l in range(znum) :

    ### Read the tabulated M values for this redshift.
    # 1D files: value in column 1. 2D files: value in column 2, one line per (k1, k2) pair.

    _load_column('M11_k1_{}.dat'.format(ztabc[l]), MT11_k1[l, :], 1)
    _load_column('M21_k1_{}.dat'.format(ztabc[l]), MT21_k1[l, :], 1)
    _load_column('M31_k1_{}.dat'.format(ztabc[l]), MT31_k1[l, :], 1)

    _load_grid('M12_k1k2_{}.dat'.format(ztabc[l]), MT12_k1k2[l, :, :], 2)
    _load_grid('M22_k1k2_{}.dat'.format(ztabc[l]), MT22_k1k2[l, :, :], 2)
    _load_grid('M04_k1k1k2k2_{}.dat'.format(ztabc[l]), MT04_k1k1k2k2[l, :, :], 2)
    _load_grid('M13_k1k1k2_{}.dat'.format(ztabc[l]), MT13_k1k1k2[l, :, :], 2)

    _load_column('M02_k1k1_{}.dat'.format(ztabc[l]), MT02_k1k1[l, :], 1)

    ### Initialize the interpolators for this redshift slice.
    # The interpolators keep no copy of the data: the same ktab / table rows
    # are passed again at evaluation time.

    accM11_k1[l] = gsl.gsl_interp_accel_alloc()
    interpM11_k1[l] = gsl.gsl_interp_alloc(gsl.gsl_interp_cspline, knum)
    gsl.gsl_interp_init(interpM11_k1[l], &ktab[0], &MT11_k1[l, 0], knum)

    accM21_k1[l] = gsl.gsl_interp_accel_alloc()
    interpM21_k1[l] = gsl.gsl_interp_alloc(gsl.gsl_interp_cspline, knum)
    gsl.gsl_interp_init(interpM21_k1[l], &ktab[0], &MT21_k1[l, 0], knum)

    accM31_k1[l] = gsl.gsl_interp_accel_alloc()
    interpM31_k1[l] = gsl.gsl_interp_alloc(gsl.gsl_interp_cspline, knum)
    gsl.gsl_interp_init(interpM31_k1[l], &ktab[0], &MT31_k1[l, 0], knum)

    accM02_k1k1[l] = gsl.gsl_interp_accel_alloc()
    interpM02_k1k1[l] = gsl.gsl_interp_alloc(gsl.gsl_interp_cspline, knum)
    gsl.gsl_interp_init(interpM02_k1k1[l], &ktab[0], &MT02_k1k1[l, 0], knum)

    xaccM12_k1k2[l] = gsl.gsl_interp_accel_alloc()
    yaccM12_k1k2[l] = gsl.gsl_interp_accel_alloc()
    interpM12_k1k2[l] = gsl.interp2d_alloc(gsl.interp2d_bicubic, knum, knum)
    gsl.interp2d_init(interpM12_k1k2[l], &ktab[0], &ktab[0], &MT12_k1k2[l, 0, 0], knum, knum)

    xaccM22_k1k2[l] = gsl.gsl_interp_accel_alloc()
    yaccM22_k1k2[l] = gsl.gsl_interp_accel_alloc()
    interpM22_k1k2[l] = gsl.interp2d_alloc(gsl.interp2d_bicubic, knum, knum)
    gsl.interp2d_init(interpM22_k1k2[l], &ktab[0], &ktab[0], &MT22_k1k2[l, 0, 0], knum, knum)

    xaccM04_k1k1k2k2[l] = gsl.gsl_interp_accel_alloc()
    yaccM04_k1k1k2k2[l] = gsl.gsl_interp_accel_alloc()
    interpM04_k1k1k2k2[l] = gsl.interp2d_alloc(gsl.interp2d_bicubic, knum, knum)
    gsl.interp2d_init(interpM04_k1k1k2k2[l], &ktab[0], &ktab[0], &MT04_k1k1k2k2[l, 0, 0], knum, knum)

    xaccM13_k1k1k2[l] = gsl.gsl_interp_accel_alloc()
    yaccM13_k1k1k2[l] = gsl.gsl_interp_accel_alloc()
    interpM13_k1k1k2[l] = gsl.interp2d_alloc(gsl.interp2d_bicubic, knum, knum)
    gsl.interp2d_init(interpM13_k1k1k2[l], &ktab[0], &ktab[0], &MT13_k1k1k2[l, 0, 0], knum, knum)






## Interpolation functions.
## To avoid excessive verbosity they do not print a message when the requested
## point is outside the range of ktab. Only P does this.
##
## In k they use the GSL interpolators built for each redshift slice; in z they
## interpolate linearly between the two neighbouring slices.


cdef int _z_slice(double z, double *t) except -1:
    """Select the redshift slice(s) used to evaluate a table at redshift z.

    Returns the index i of the upper slice and stores in ``t[0]`` the weight
    of the lower slice i-1, so that the value is ``t*M[i-1] + (1-t)*M[i]``.

    Redshifts at/above ``zmax`` or the last tabulated redshift are clamped to
    the last slice. Redshifts at/below ``ztabc[0]`` are clamped to the first
    slice. Both cases set ``t[0] = 0``, so the lower neighbour is never read
    (it would be out of bounds there).
    """
    cdef int last = znum - 1
    cdef int i
    t[0] = 0.
    if z >= zmax or z >= ztabc[last]:
        return last
    if not (z > ztabc[0]):  # also catches NaN, which cannot be bracketed
        return 0
    # Here ztabc[0] < z < ztabc[last], so 1 <= i <= last.
    i = search(ztabc, znum, z)
    t[0] = (z - ztabc[i])/(ztabc[i-1] - ztabc[i])
    return i


cdef double _eval_1d(gsl.gsl_interp *interp, gsl.gsl_interp_accel *acc, double *ya, double k):
    """Evaluate the GSL interpolator of table row ``ya`` (tabulated on ktab) at k.

    Outside ktab the first/last tabulated value is returned (constant clamp).
    """
    cdef Py_ssize_t last = ktab.shape[0] - 1
    if k < ktab[0]:
        return ya[0]
    if k > ktab[last]:
        return ya[last]
    return gsl.gsl_interp_eval(interp, &ktab[0], ya, k, acc)


cdef double _eval_2d(gsl.interp2d *interp, gsl.gsl_interp_accel *xacc,
                     gsl.gsl_interp_accel *yacc, double *za, double k1, double k2):
    """Evaluate the 2D interpolator of slab ``za`` (tabulated on ktab x ktab) at (k1, k2).

    No clamping is applied: interp2d_eval_no_boundary_check is called directly.
    """
    return gsl.interp2d_eval_no_boundary_check(interp, &ktab[0], &ktab[0], za, k1, k2, xacc, yacc)


## 1D ###########################################################################################

cpdef double M11_k1(double k1, double z):
    """Interpolated M11 at wavenumber k1 and redshift z (end values used outside ktab)."""
    cdef double t
    cdef int i = _z_slice(z, &t)
    cdef double value = _eval_1d(interpM11_k1[i], accM11_k1[i], &MT11_k1[i, 0], k1)
    if t != 0.:
        value = (t*_eval_1d(interpM11_k1[i-1], accM11_k1[i-1], &MT11_k1[i-1, 0], k1)
                 + (1-t)*value)
    return value


cpdef double M21_k1(double k1, double z):
    """Interpolated M21 at wavenumber k1 and redshift z (end values used outside ktab)."""
    cdef double t
    cdef int i = _z_slice(z, &t)
    cdef double value = _eval_1d(interpM21_k1[i], accM21_k1[i], &MT21_k1[i, 0], k1)
    if t != 0.:
        value = (t*_eval_1d(interpM21_k1[i-1], accM21_k1[i-1], &MT21_k1[i-1, 0], k1)
                 + (1-t)*value)
    return value


cpdef double M31_k1(double k1, double z):
    """Interpolated M31 at wavenumber k1 and redshift z (end values used outside ktab)."""
    cdef double t
    cdef int i = _z_slice(z, &t)
    cdef double value = _eval_1d(interpM31_k1[i], accM31_k1[i], &MT31_k1[i, 0], k1)
    if t != 0.:
        value = (t*_eval_1d(interpM31_k1[i-1], accM31_k1[i-1], &MT31_k1[i-1, 0], k1)
                 + (1-t)*value)
    return value


cpdef double M02_k1k1(double k1, double z):
    """Interpolated M02 at wavenumber k1 and redshift z (end values used outside ktab)."""
    cdef double t
    cdef int i = _z_slice(z, &t)
    cdef double value = _eval_1d(interpM02_k1k1[i], accM02_k1k1[i], &MT02_k1k1[i, 0], k1)
    if t != 0.:
        value = (t*_eval_1d(interpM02_k1k1[i-1], accM02_k1k1[i-1], &MT02_k1k1[i-1, 0], k1)
                 + (1-t)*value)
    return value


## 2D ###################################################################################################

cpdef double M12_k1k2(double k1, double k2, double z):
    """Interpolated M12 at (k1, k2) and redshift z; k values outside ktab are not clamped."""
    cdef double t
    cdef int i = _z_slice(z, &t)
    cdef double value = _eval_2d(interpM12_k1k2[i], xaccM12_k1k2[i], yaccM12_k1k2[i],
                                 &MT12_k1k2[i, 0, 0], k1, k2)
    if t != 0.:
        value = (t*_eval_2d(interpM12_k1k2[i-1], xaccM12_k1k2[i-1], yaccM12_k1k2[i-1],
                            &MT12_k1k2[i-1, 0, 0], k1, k2)
                 + (1-t)*value)
    return value


cpdef double M22_k1k2(double k1, double k2, double z):
    """Interpolated M22 at (k1, k2) and redshift z; k values outside ktab are not clamped."""
    cdef double t
    cdef int i = _z_slice(z, &t)
    cdef double value = _eval_2d(interpM22_k1k2[i], xaccM22_k1k2[i], yaccM22_k1k2[i],
                                 &MT22_k1k2[i, 0, 0], k1, k2)
    if t != 0.:
        value = (t*_eval_2d(interpM22_k1k2[i-1], xaccM22_k1k2[i-1], yaccM22_k1k2[i-1],
                            &MT22_k1k2[i-1, 0, 0], k1, k2)
                 + (1-t)*value)
    return value


cpdef double M04_k1k1k2k2(double k1, double k2, double z):
    """Interpolated M04 at (k1, k2) and redshift z; k values outside ktab are not clamped."""
    cdef double t
    cdef int i = _z_slice(z, &t)
    cdef double value = _eval_2d(interpM04_k1k1k2k2[i], xaccM04_k1k1k2k2[i], yaccM04_k1k1k2k2[i],
                                 &MT04_k1k1k2k2[i, 0, 0], k1, k2)
    if t != 0.:
        value = (t*_eval_2d(interpM04_k1k1k2k2[i-1], xaccM04_k1k1k2k2[i-1], yaccM04_k1k1k2k2[i-1],
                            &MT04_k1k1k2k2[i-1, 0, 0], k1, k2)
                 + (1-t)*value)
    return value


cpdef double M13_k1k1k2(double k1, double k2, double z):
    """Interpolated M13 at (k1, k2) and redshift z; k values outside ktab are not clamped."""
    cdef double t
    cdef int i = _z_slice(z, &t)
    cdef double value = _eval_2d(interpM13_k1k1k2[i], xaccM13_k1k1k2[i], yaccM13_k1k1k2[i],
                                 &MT13_k1k1k2[i, 0, 0], k1, k2)
    if t != 0.:
        value = (t*_eval_2d(interpM13_k1k1k2[i-1], xaccM13_k1k1k2[i-1], yaccM13_k1k1k2[i-1],
                            &MT13_k1k1k2[i-1, 0, 0], k1, k2)
                 + (1-t)*value)
    return value

    






### Trispectrum matter/galaxy in the halo model evaluated at points (k1, -k1, k2, -k2)


cpdef double trispectrum(double k1, double k2, double u12, double z) :
    """Halo-model trispectrum evaluated at the configuration (k1, -k1, k2, -k2).

    Parameters
    ----------
    k1, k2 : magnitudes of the two wavevectors (the M tables are interpolated on ktab).
    u12 : cosine of the angle between k1 and k2.
    z : redshift.

    Returns
    -------
    T1h + T2h + T3h + T4h, the sum of the 1-, 2-, 3- and 4-halo contributions.
    """

    # Quantities that appear many times below
    cdef double M11k1 = M11_k1(k1, z)
    cdef double M11k2 = M11_k1(k2, z)

    cdef double M21k1 = M21_k1(k1, z)
    cdef double M21k2 = M21_k1(k2, z)

    cdef double M31k1 = M31_k1(k1, z)
    cdef double M31k2 = M31_k1(k2, z)

    cdef double M12k1k2 = M12_k1k2(k1, k2, z)

    cdef double M22k1k2 = M22_k1k2(k1, k2, z)

    # Presumably |k1 + k2| and |k1 - k2| (ptc.norme2 with +u12 / -u12)
    cdef double n12 = ptc.norme2(k1, k2, u12)
    cdef double m12 = ptc.norme2(k1, k2, -u12)

    cdef double Pk1 = ptc.P(k1, z)
    cdef double Pk2 = ptc.P(k2, z)
    cdef double Pk1plusk2 = ptc.P(n12, z)
    cdef double Pk1moinsk2 = ptc.P(m12, z)

    cdef double u1_n12, u2_n12, u1_m12, u2_m12

    # Cosines between k1 (resp. k2) and k1+k2 / k1-k2. They are set to 0 when
    # the combined vector has zero length, to avoid dividing by zero.
    if n12 == 0 :
        u1_n12 = 0.
        u2_n12 = 0.
    else :
        u1_n12 = (k1*k1 + k1*k2*u12)/(k1 * n12)      # angle between k1 and k1+k2, etc.
        u2_n12 = (k2*k2 + k1*k2*u12)/(k2 * n12)

    if m12 == 0 :
        u1_m12 = 0.
        u2_m12 = 0.
    else :
        u1_m12 = (k1*k1 - k1*k2*u12)/(k1 * m12)
        u2_m12 = (-k2*k2 + k1*k2*u12)/(k2 * m12)

    # Each of the two perturbation-theory bispectrum configurations below is
    # needed in both T3h and T4h. Evaluate each once instead of three times.
    cdef double Bplus = ptc.BPT(k1, k2, n12, u12, -u1_n12, -u2_n12, z)
    cdef double Bminus = ptc.BPT(k1, k2, m12, -u12, -u1_m12, u2_m12, z)
    cdef double Psum = Pk1plusk2 + Pk1moinsk2

    cdef double T1h = M04_k1k1k2k2(k1, k2, z)

    cdef double T2h = (
          2 * M11k1 * M13_k1k1k2(k2, k1, z) * Pk1
        + 2 * M11k2 * M13_k1k1k2(k1, k2, z) * Pk2

        + M12k1k2 * M12k1k2 * Psum
    )

    cdef double T3h = (
          ( 2 * M11k1 * M12k1k2 * M11k2 ) * Bplus
        + ( 2 * M12k1k2 * M11k1 * M11k2 ) * Bminus

        + M11k1 * M11k1 * M22_k1k2(k2, k2, z) * Pk1*Pk1

        + 4 * M11k1 * M22k1k2 * M11k2 * Pk1 * Pk2

        + 2 * M21k1 * M12k1k2 * M11k2 * Pk2 * Psum
        + 2 * M11k1 * M12k1k2 * M21k2 * Pk1 * Psum

        + M22_k1k2(k1, k1, z) * M11k2 * M11k2 * Pk2 * Pk2
    )

    cdef double T4h = (
          M11k1 * M11k1 * M11k2 * M11k2 * ptc.TPT(k1, k2, u12, z)

        + 2 * M31k1 * M11k1 * M11k2 * M11k2 * Pk1 * Pk2 * Pk2
        + 2 * M31k2 * M11k2 * M11k1 * M11k1 * Pk2 * Pk1 * Pk1

        + 2 * M21k1 * M11k1 * M11k2 * M11k2 * Pk2 * (Bplus + Bminus)
        + 2 * M21k2 * M11k2 * M11k1 * M11k1 * Pk1 * (Bplus + Bminus)

        + 2 * M21k1 * M11k2 * Pk2 * (M21k1 * M11k2 * Pk2 + M11k1 * M21k2 * Pk1) * Psum
        + 2 * M21k2 * M11k1 * Pk1 * (M21k2 * M11k1 * Pk1 + M11k2 * M21k1 * Pk2) * Psum
    )

    return T1h + T2h + T3h + T4h
    
   
    

### Power spectrum in the halo model


cpdef double Pmm(double k, double z) :
    """Halo-model power spectrum at wavenumber k and redshift z.

    Sum of the one-halo term M02(k, z) and the two-halo term
    M11(k, z)**2 * P(k, z), where P is ptc.P.
    """
    cdef double one_halo = M02_k1k1(k, z)
    cdef double m11 = M11_k1(k, z)
    return one_halo + m11**2 * ptc.P(k, z)