#Importing required modules
import math
import random
import matplotlib.pyplot as plt
import numpy as np

import Users
import tool
import GDOP
import time
import NSGA2 as ns
import TS as ts


def index_of(a, list):
    """Return the position of the first element of *list* equal to *a*, or -1.

    Elements are compared with ``element == a`` in order; the scan is linear.
    """
    return next((pos for pos, item in enumerate(list) if item == a), -1)

# Upper bound on tournament re-draws when looking for a second parent that
# differs from the first. Without it the search never ends once every
# individual in the population compares equal.
_MAX_PARENT_DRAWS = 100


def _crowding_distances(f1_values, f2_values, fronts):
    """Return ``ns.crowding_distance`` for every front, in front order."""
    # Copies are passed, as the original code did, in case the ns helpers
    # mutate their arguments.
    return [
        ns.crowding_distance(f1_values[:], f2_values[:], front[:])
        for front in fronts
    ]


def _select_survivors(fronts, crowding_distances, pop_size):
    """Collect population indices front by front until *pop_size* are chosen.

    Within a front, members are taken in the reverse of the order produced
    by ``ns.sort_by_values`` on their crowding distances. Presumably that
    order is ascending, so the largest distance comes first (confirm in
    NSGA2). Stops as soon as *pop_size* indices have been collected.
    """
    survivors = []
    for front, distances in zip(fronts, crowding_distances):
        # Within-front position of each member.
        local_positions = [index_of(member, front) for member in front]
        order = ns.sort_by_values(local_positions[:], distances[:])
        ranked = [front[order[j]] for j in range(len(front))]
        ranked.reverse()
        for member in ranked:
            survivors.append(member)
            if len(survivors) == pop_size:
                return survivors
    return survivors


def main(users, pop_size, max_gen, k=3):
    """Run NSGA-II with tabu-search refinement and return the final first front.

    Each generation:

    1. Evaluate the population and rank it with fast non-dominated sorting.
    2. Breed children (tournament selection, crossover, mutation) until
       parents plus children number ``2 * pop_size``.
    3. Keep ``pop_size`` of them, chosen by front and crowding distance.
    4. Stop early when all survivors are on the first front.
    5. When fewer than half of the survivors are on the first front, refine
       every individual with ``ts.tabu_search``.

    Args:
        users: user data as returned by ``Users.getUsers``; passed through to
            the NSGA2/TS helpers.
        pop_size: population size.
        max_gen: maximum number of generations.
        k: forwarded to ``ns.init`` (callers in this file pass the UAV count).

    Returns:
        ``(f1_best, f2_best, solution_num)``:

        - ``f1_best`` and ``f2_best``: both objective values of every
          individual on the first non-dominated front of the last evaluated
          survivor population.
        - ``solution_num``: the first-front size recorded each generation.

        If no generation ran (``max_gen <= 0``), the first two lists are empty.
    """
    pos_solution, p_solution = ns.init(pop_size, k, users)
    pos_range = [0, 1000]
    gen_no = 0
    best_pos = []
    function1_best = []
    function2_best = []
    solution_num = []
    while gen_no < max_gen:
        function1_values, function2_values, _ = ns.get_function_values(
            pop_size, pos_solution, p_solution, users)
        non_dominated_sorted_solution, rank = ns.fast_non_dominated_sort(
            function1_values[:], function2_values[:])
        print(non_dominated_sorted_solution)
        crowding_distance_values = _crowding_distances(
            function1_values, function2_values, non_dominated_sorted_solution)

        # TODO: optionally apply tabu search to every parent before breeding.
        # TODO: the ranking should take the penalty term into account.
        pos_solution2 = pos_solution[:]
        p_solution2 = p_solution[:]
        # Produce children until parents + children == 2 * pop_size.
        while len(pos_solution2) != 2 * pop_size:
            # Tournament selection of two parents. The re-draw is bounded so
            # a population of identical individuals cannot hang the run.
            parent1 = ns.tournament(pos_solution, non_dominated_sorted_solution,
                                    rank, crowding_distance_values)
            parent2 = parent1
            draws = 0
            while parent1 == parent2 and draws < _MAX_PARENT_DRAWS:
                parent2 = ns.tournament(pos_solution, non_dominated_sorted_solution,
                                        rank, crowding_distance_values)
                draws += 1
            # Crossover and mutation; both also receive gen_no and max_gen.
            child = ns.cross_over(parent1, parent2, gen_no, max_gen)
            child_mute = ns.mutate(child, pos_range, gen_no, max_gen)
            child_p = ns.init_p(child_mute, users)
            # TODO: optionally apply tabu search to each child.
            pos_solution2.append(child_mute)
            p_solution2.append(child_p)

        # Environmental selection over the combined parent + child pool.
        function1_values2, function2_values2, _ = ns.get_function_values(
            2 * pop_size, pos_solution2, p_solution2, users)
        non_dominated_sorted_solution2, _ = ns.fast_non_dominated_sort(
            function1_values2[:], function2_values2[:])
        crowding_distance_values2 = _crowding_distances(
            function1_values2, function2_values2, non_dominated_sorted_solution2)
        new_solution = _select_survivors(
            non_dominated_sorted_solution2, crowding_distance_values2, pop_size)
        pos_solution = [pos_solution2[i] for i in new_solution]
        p_solution = [p_solution2[i] for i in new_solution]

        # Re-rank the survivors and record the size of their first front.
        function1_best, function2_best, _ = ns.get_function_values(
            len(pos_solution), pos_solution, p_solution, users)
        best_pos, _ = ns.fast_non_dominated_sort(function1_best[:], function2_best[:])
        solution_num.append(len(best_pos[0]))

        # Stop early once every survivor is on the first front.
        if len(best_pos[0]) >= pop_size:
            break

        # Too few non-dominated survivors: refine every individual with TS.
        # NOTE: function1_best/function2_best above were computed before this
        # refinement. In the last generation the refined individuals are
        # never re-evaluated or returned.
        if len(best_pos[0]) < pop_size / 2:
            for index in range(len(pos_solution)):
                pos_solution[index] = ts.tabu_search(
                    pos_solution[index], users, 1, len(best_pos[0]),
                    pop_size, 1, len(best_pos))
                p_solution[index] = ns.init_p(pos_solution[index], users)

        gen_no += 1

    if not best_pos:
        # max_gen <= 0: nothing was evaluated, so there is no front to report.
        return [], [], solution_num
    f1_best = [function1_best[i] for i in best_pos[0]]
    f2_best = [function2_best[i] for i in best_pos[0]]
    return f1_best, f2_best, solution_num

def compare_pop(filename="./data.csv", pop_sizes=None, iter_max=150, uav_num=3):
    """Time one optimisation run per population size and print the timings.

    The defaults reproduce the original hard-coded experiment.

    Args:
        filename: data file passed to ``run_pop`` for each run.
        pop_sizes: population sizes to try; defaults to 10, 20, ..., 100.
        iter_max: generation limit for each run.
        uav_num: forwarded to ``run_pop`` (and from there to ``main`` as ``k``).

    Returns:
        Run times in seconds, one per population size (also printed).
    """
    if pop_sizes is None:
        pop_sizes = range(10, 110, 10)
    pop = list(pop_sizes)
    print(pop)
    times = [run_pop(filename, size, iter_max, uav_num) for size in pop]
    print(times)
    return times


def run_pop(filename, pop_size, iter_max, uav_num):
    """Load users from *filename*, run ``main`` once and return its run time.

    Only the optimisation is timed; loading the users is excluded. The
    optimiser's results are discarded.

    ``time.perf_counter`` is used because it is monotonic and high-resolution.
    ``time.time`` follows the system clock and can jump if that clock is
    adjusted during the run.

    Returns:
        Elapsed time of ``main`` in seconds.
    """
    users = Users.getUsers(filename)
    start_time = time.perf_counter()
    _f1, _f2, _solution_num = main(users, pop_size, iter_max, uav_num)
    return time.perf_counter() - start_time

def run(filename, pop_size, iter_max, uav_num):
    """Run the optimiser once, print the results and scatter-plot the first front.

    Prints the first-front objective values, the run time of ``main`` and
    the per-generation first-front sizes. It then shows a scatter plot of
    average SNR (x) against average -GDOP (y).

    Timing uses the monotonic ``time.perf_counter`` rather than the
    wall-clock ``time.time``.
    """
    users = Users.getUsers(filename)
    start_time = time.perf_counter()
    f1, f2, solution_num = main(users, pop_size, iter_max, uav_num)
    run_time = time.perf_counter() - start_time
    print(f1, f2)
    print("程序运行时间为：", run_time)  # "Program run time:"
    print(solution_num)
    plt.xlabel('Average snr', fontsize=12)
    plt.ylabel('Average -gdop', fontsize=12)
    plt.scatter(f1, f2)
    plt.show()
def run_ts(filename="./data.csv", repeats=10, pop_size=30, iter_max=150, uav_num=3):
    """Time *repeats* identical optimisation runs and print each run time.

    The defaults reproduce the original hard-coded experiment.

    Args:
        filename: data file passed to ``run_pop``.
        repeats: number of runs.
        pop_size: population size for each run.
        iter_max: generation limit for each run.
        uav_num: forwarded to ``run_pop``.

    Returns:
        Run times in seconds, in run order (also printed at the end).
    """
    times = []
    for _ in range(repeats):
        run_time = run_pop(filename, pop_size, iter_max, uav_num)
        print("time-", run_time)
        times.append(run_time)
    print(times)
    return times


if __name__ == '__main__':
    # Default entry point: one plotted run with 60 individuals, 200
    # generations and uav_num=3.
    # Other entry points:
    #   compare_pop()                      - timing sweep over population sizes
    #   run_ts()                           - repeated timing runs
    #   print(run_pop(filename, 60, 200, 3)) - a single timed run without a plot
    filename = "./data.csv"
    run(filename, pop_size=60, iter_max=200, uav_num=3)

