import numpy as np
import skfuzzy as fuzz
from skfuzzy import control as ctrl
import matplotlib.pyplot as plt

# Conversational (session) traffic class
#   bandwidth range: [0, 300]
#   QoS range: [0, 25]

class mamdani:
    """Mamdani fuzzy controller that maps bandwidth and td to a QoS score.

    Three antecedents are declared (``bandwidth``, ``td``, ``plr``) together
    with one consequent named ``tip`` whose universe is the integers 0..25.
    Each variable has three fuzzy sets labelled ``'L'``, ``'M'`` and ``'H'``.

    NOTE(review): ``td`` and ``plr`` presumably stand for transmission delay
    and packet loss rate -- confirm with the author.

    Args:
        bandwidth_range: ``(start, stop)`` of the bandwidth universe; ``stop``
            is exclusive and the step is 1 (``np.arange`` semantics).
        td_range: ``(start, stop)`` of the td universe, same convention.
        plr_range: ``(start, stop)`` of the plr universe, same convention.
        bandwidth_membership: three break-point lists for the bandwidth sets:
            4 points for the trapezoidal ``'L'`` set, then 3 points each for
            the triangular ``'M'`` and ``'H'`` sets.
        td_membership: break points for the td sets, same layout.
        plr_membership: break points for the plr sets, same layout.
    """

    def __init__(self, bandwidth_range, td_range, plr_range,
                 bandwidth_membership, td_membership, plr_membership):
        self.bandwidth_range = bandwidth_range
        self.td_range = td_range
        self.plr_range = plr_range

        self.bandwidth_membership = bandwidth_membership
        self.td_membership = td_membership
        self.plr_membership = plr_membership

        # Placeholders; replaced by skfuzzy control variables in
        # generate_membership_func().
        self.bandwidth = []
        self.td = []
        self.plr = []
        self.tip = []

        # Built by set_rules() and test()/show_3dresult() respectively.
        self.tipping_ctrl = None
        self.tipping = None

    @staticmethod
    def _universe(value_range):
        """Return the step-1 universe ``[start, stop)`` for a ``(start, stop)`` pair."""
        return np.arange(value_range[0], value_range[1], 1)

    def generate_membership_func(self):
        """Create the fuzzy variables and attach their membership functions."""
        x_bandwidth = self._universe(self.bandwidth_range)
        x_td = self._universe(self.td_range)
        x_plr = self._universe(self.plr_range)
        x_qos = np.arange(0, 26, 1)

        # Fuzzy control variables.
        self.bandwidth = ctrl.Antecedent(x_bandwidth, 'bandwidth')
        self.td = ctrl.Antecedent(x_td, 'td')
        self.plr = ctrl.Antecedent(x_plr, 'plr')
        self.tip = ctrl.Consequent(x_qos, 'tip')

        # Every input uses a trapezoidal 'L' set and triangular 'M'/'H' sets.
        inputs = (
            (self.bandwidth, x_bandwidth, self.bandwidth_membership),
            (self.td, x_td, self.td_membership),
            (self.plr, x_plr, self.plr_membership),
        )
        for variable, universe, breakpoints in inputs:
            variable['L'] = fuzz.trapmf(universe, breakpoints[0])
            variable['M'] = fuzz.trimf(universe, breakpoints[1])
            variable['H'] = fuzz.trimf(universe, breakpoints[2])

        # Output sets are triangular over the fixed 0..25 universe.
        self.tip['L'] = fuzz.trimf(x_qos, [0, 0, 13])
        self.tip['M'] = fuzz.trimf(x_qos, [0, 13, 25])
        self.tip['H'] = fuzz.trimf(x_qos, [13, 25, 25])

        # Defuzzify with the 'lom' method instead of the library default.
        self.tip.defuzzify_method = 'lom'

    def show_membership(self):
        """Plot the membership functions of the three input variables."""
        self.bandwidth.view()
        self.td.view()
        self.plr.view()

    def set_rules(self):
        """Build the rule base and the control system.

        Only ``bandwidth`` and ``td`` take part in the rules; ``plr`` is not
        referenced yet. The rule base is a placeholder: the original author
        notes that 27 rules in total are still to be filled in.
        """
        bw, td, tip = self.bandwidth, self.td, self.tip

        rule1 = ctrl.Rule(
            antecedent=((bw['L'] & td['L']) | (bw['L'] & td['M']) | (bw['M'] & td['L'])),
            consequent=tip['L'], label='Low')
        rule2 = ctrl.Rule(
            antecedent=((bw['M'] & td['M']) | (bw['L'] & td['H']) | (bw['H'] & td['L'])),
            consequent=tip['M'], label='Medium')
        rule3 = ctrl.Rule(
            antecedent=((bw['M'] & td['H']) | (bw['H'] & td['M']) | (bw['H'] & td['H'])),
            consequent=tip['H'], label='High')

        # Visualise the 'Medium' rule graph.
        rule2.view()
        self.tipping_ctrl = ctrl.ControlSystem([rule1, rule2, rule3])

    def test(self, bandwidth=6.5, td=9.8):
        """Run one inference, print the crisp output and plot it.

        Args:
            bandwidth: crisp bandwidth input (default 6.5, the original
                hard-coded sample).
            td: crisp td input (default 9.8, the original hard-coded sample).
        """
        self.tipping = ctrl.ControlSystemSimulation(self.tipping_ctrl)
        self.tipping.input['bandwidth'] = bandwidth
        self.tipping.input['td'] = td
        self.tipping.compute()
        print(self.tipping.output['tip'])
        self.tip.view(sim=self.tipping)

    def _surface_grid(self, points=21):
        """Return ``(x, y)`` mesh grids spanning the bandwidth and td universes.

        ``x`` varies along columns over the bandwidth universe and ``y`` along
        rows over the td universe; both have shape ``(points, points)``.
        """
        bw_universe = self._universe(self.bandwidth_range)
        td_universe = self._universe(self.td_range)
        # Sample between the first and last universe values so that every
        # grid point is a valid input (the range's stop value is exclusive).
        bw_axis = np.linspace(bw_universe.min(), bw_universe.max(), points)
        td_axis = np.linspace(td_universe.min(), td_universe.max(), points)
        x, y = np.meshgrid(bw_axis, td_axis)
        return x, y

    def show_3dresult(self, points=21):
        """Evaluate the controller on a grid and show the output surface in 3D.

        The grid spans the configured bandwidth and td universes instead of a
        fixed 0..11 square, so the surface covers the whole input domain.

        Args:
            points: number of samples per axis (default 21).
        """
        # Allow calling this method without having run test() first.
        if getattr(self, 'tipping', None) is None:
            self.tipping = ctrl.ControlSystemSimulation(self.tipping_ctrl)

        x, y = self._surface_grid(points)
        z = np.zeros_like(x)
        for idx in np.ndindex(x.shape):
            self.tipping.input['bandwidth'] = x[idx]
            self.tipping.input['td'] = y[idx]
            self.tipping.compute()
            z[idx] = self.tipping.output['tip']
        print('max:', z.max())
        print('min:', z.min())

        fig = plt.figure(figsize=(8, 8))  # canvas size
        ax = fig.add_subplot(111, projection='3d')
        ax.plot_surface(x, y, z, rstride=1, cstride=1, cmap='viridis',
                        linewidth=0.4, antialiased=True)
        ax.view_init(30, 200)  # viewing angle
        plt.show()

    def process(self):
        """Run the full pipeline: build, plot, create rules, test, plot surface."""
        self.generate_membership_func()
        self.show_membership()
        self.set_rules()
        self.test()
        self.show_3dresult()
        # The original referenced plt.show without calling it (a no-op).
        plt.show()
if __name__ == '__main__':
    # Conversational (session) traffic class.
    # Universe bounds are (start, stop) with an exclusive stop.
    bandwidth_range, td_range, plr_range = [0, 301], [0, 101], [-4, -1]

    # Break points per variable: trapezoidal 'L', triangular 'M', triangular 'H'.
    bandwidth_membership = [
        [0, 5, 5, 64],
        [5, 64, 300],
        [64, 300, 300],
    ]
    td_membership = [
        [0, 0, 30, 75],
        [30, 75, 100],
        [75, 100, 100],
    ]
    plr_membership = [
        [-4, -4, -4, -3],
        [-4, -3, -2],
        [-3, -2, -2],
    ]

    mam1 = mamdani(
        bandwidth_range=bandwidth_range,
        td_range=td_range,
        plr_range=plr_range,
        bandwidth_membership=bandwidth_membership,
        td_membership=td_membership,
        plr_membership=plr_membership,
    )
    mam1.process()