import numpy as np
from scipy.special import expn
from virtual_force import dBm2w
from ray import get_gain_based_dist

K = 1  # presumably the Rician K-factor (original note: values 1-5); read by get_common_rate / get_private_rate
OMEGA = 1  # Omega in get_C5: C_5 = (K + 1) / OMEGA -- presumably the mean channel power gain
NOISE = dBm2w(-100)  # noise power: -100 dBm converted by dBm2w (presumably to watts)
# Error level 0 dBm via dBm2w (original note lists 0.2 and 0.5 as alternatives).
# NOTE(review): no function visible here reads this global -- get_C3 and
# get_total_rate_icsi each take their own SIGMA_E argument.
SIGMA_E = dBm2w(0)
power_upper = 3  # total power budget split between p_c and p_i (see get_C2 / get_C7); units presumably W

def get_C1(p_c, pl):
    """Return p_c * pl / NOISE (common-stream power times gain, over noise)."""
    received = p_c * pl
    return received / NOISE
def get_C2(p_c, pl):
    """Return (power_upper - p_c) / NOISE * pl: power left after the common stream, scaled by gain over noise."""
    headroom = power_upper - p_c
    return headroom / NOISE * pl
def get_C3(SIGMA_E, pl):
    """Return SIGMA_E * pl * power_upper / NOISE + 1.

    Presumably the normalized noise-plus-channel-estimation-error term.
    """
    scaled_error = SIGMA_E * pl * power_upper / NOISE
    return scaled_error + 1
def get_C4(C1, C2, C3):
    """Return (C1 / 2 + C2) / C3."""
    half_common = C1 / 2
    return (half_common + C2) / C3
def get_C5(K, OMEGA):
    """Return (K + 1) / OMEGA."""
    shifted = K + 1
    return shifted / OMEGA

def get_C6(p_i, pl):
    """Return p_i * pl / NOISE (private-stream power times gain, over noise)."""
    received = p_i * pl
    return received / NOISE
def get_C7(p_i, p_c, pl):
    """Return pl * (power_upper - p_c - p_i) / NOISE: remaining power, scaled by gain over noise."""
    residual_power = power_upper - p_c - p_i
    return pl * residual_power / NOISE
def get_C8(C6, C7, C3):
    """Return (C6 / 2 + C7) / C3."""
    numerator = C6 / 2 + C7
    return numerator / C3

def _mean_inverse_rician(C_b, C_5, rician_k, m_max):
    """Return E[1 / (1 + C_b * x)] for the series channel-power density.

    The density integrated against is
        f(x) = C_5 * exp(-k) * exp(-C_5 * x) * sum_m (k * C_5 * x)**m / (m!)**2,
    which is the Rician power pdf when C_5 = (k + 1) / Omega (cf. get_C5).
    Term-by-term integration gives, with z = C_5 / C_b,
        E[1 / (1 + C_b x)] = z * exp(z - k) * sum_m k**m / m! * E_{m+1}(z),
    where E_n is scipy.special.expn. The sum is truncated after m_max terms.

    Raises:
        ValueError: if m_max < 1.
    """
    if m_max < 1:
        raise ValueError("m_max must be at least 1")
    z = C_5 / C_b
    orders = np.arange(1, m_max + 1)  # n = m + 1 for m = 0 .. m_max - 1
    # k**m / m! via the recurrence w_m = w_{m-1} * k / m: avoids huge factorials
    # (and np.math, which NumPy 2 removed) and yields [1, 0, 0, ...] for k = 0.
    weights = np.cumprod(np.concatenate(([1.0], rician_k / orders[:-1])))
    # Series terms run along a new trailing axis so array-valued z broadcasts.
    series = np.sum(weights * expn(orders, np.asarray(z)[..., None]), axis=-1)
    # NOTE: np.exp overflows once z - k exceeds roughly 709 (very small C_b).
    return z * np.exp(z - rician_k) * series


def _approx_rate(C_sig, C_3, C_b, C_5, rician_k, m_max):
    """Return C_sig / (C_3 * C_b * ln 2) * (1 - E[1 / (1 + C_b * x)]).

    Algebraically this equals E[(C_sig / C_3) * x / (1 + C_b * x)] / ln 2.
    """
    mean_inverse = _mean_inverse_rician(C_b, C_5, rician_k, m_max)
    return (C_sig / (C_3 * C_b * np.log(2))) * (1 - mean_inverse)


def get_common_rate(C_1, C_3, C_4, C_5, rician_k=None, m_max=100):
    """Approximate ergodic common-stream rate.

    Computes E[(C_1 / C_3) * x / (1 + C_4 * x)] / ln 2 in closed form, with x
    following the Rician power density parameterized by C_5 and the K-factor.
    NOTE(review): this matches ln(1 + w) ~= 2w / (2 + w) applied to
    w = C_1 x / (C_3 + C_2 x), which produces C_4 = (C_1/2 + C_2)/C_3 as in
    get_C4 -- inferred from the formulas, confirm against the derivation.

    Args:
        C_1, C_3, C_4, C_5: coefficients from get_C1 / get_C3 / get_C4 / get_C5.
        rician_k: K-factor for the series; defaults to the module-level K.
        m_max: number of series terms kept (must be >= 1).

    Returns:
        The approximate rate (scalar, or array if the inputs are arrays).
    """
    k = K if rician_k is None else rician_k
    return _approx_rate(C_1, C_3, C_4, C_5, k, m_max)


def get_private_rate(C_3, C_5, C_6, C_8, rician_k=None, m_max=100):
    """Approximate ergodic private-stream rate.

    Same closed form as get_common_rate with (C_6, C_8) in place of
    (C_1, C_4): E[(C_6 / C_3) * x / (1 + C_8 * x)] / ln 2.

    Args:
        C_3, C_5, C_6, C_8: coefficients from get_C3 / get_C5 / get_C6 / get_C8.
        rician_k: K-factor for the series; defaults to the module-level K.
        m_max: number of series terms kept (must be >= 1).

    Returns:
        The approximate rate (scalar, or array if the inputs are arrays).
    """
    k = K if rician_k is None else rician_k
    return _approx_rate(C_6, C_3, C_8, C_5, k, m_max)


def get_total_rate_icsi(BS_list, UE_list, dist_matrix, K=1, SIGMA_E=0):
    """Return the approximate total rate over all BS-UE links with imperfect CSI.

    For every (BS i, UE j) pair the common-stream rate is computed and, when
    BS i gives UE j non-zero private power, the private-stream rate too.
    Total = sum(private rates) + min(common rates) * len(UE_list).
    NOTE(review): the minimum is over *all* BS-UE pairs, not per BS; confirm
    that is intended when several BSs each carry a common stream.

    Args:
        BS_list: base stations exposing ``common_power`` and an indexable
            ``power_allocation`` (one entry per UE).
        UE_list: users; only its length is used here.
        dist_matrix: indexed ``dist_matrix[j][i]`` (UE j, BS i) and passed to
            get_gain_based_dist to obtain the path gain.
        K: fading K-factor, used for C_5 and for the rate series.
        SIGMA_E: estimation-error level in dBm, converted with dBm2w.

    Returns:
        The total rate; the common-stream term is 0 when there are no pairs.
    """
    sigma_e = dBm2w(SIGMA_E)
    C_5 = get_C5(K, OMEGA)  # depends only on K and OMEGA: compute once
    result = 0
    common_rate_list = []

    for i, uav in enumerate(BS_list):
        for j in range(len(UE_list)):
            p_c = uav.common_power
            p_i = uav.power_allocation[j]
            pl = get_gain_based_dist(dist_matrix[j][i])

            C_1 = get_C1(p_c, pl)
            C_2 = get_C2(p_c, pl)
            C_3 = get_C3(sigma_e, pl)
            C_4 = get_C4(C_1, C_2, C_3)
            common_rate_list.append(
                get_common_rate(C_1, C_3, C_4, C_5, rician_k=K))

            # Skip links without private power: C_6 would be 0, and C_8 can
            # also be 0 there, which would divide by zero.
            if p_i != 0:
                C_6 = get_C6(p_i, pl)
                C_7 = get_C7(p_i, p_c, pl)
                C_8 = get_C8(C_6, C_7, C_3)
                result += get_private_rate(C_3, C_5, C_6, C_8, rician_k=K)

    # No BS-UE pairs means there is no common stream to count.
    result += min(common_rate_list, default=0) * len(UE_list)
    return result
# if __name__ == '__main__':
#     get_rate()