import numpy as np
from scipy.spatial.distance import cdist
import time

from search.MOEA import MOEA
from search.SelectMethod import tournament


class IT(MOEA):
    def __init__(self, proteins, config, energy, coder, logging, current_gen=1):
        """Initialise the optimiser on top of the ``MOEA`` base class.

        Args:
            proteins: Initial population. The first element supplies the
                number of objectives (``obj_num``) and the angle bounds
                (``get_angles_field()``, unpacked as ``(max, min)``).
            config: Settings mapping; forwarded to the base class and read
                here for ``'pop_size'`` (further keys are read by
                ``init_param``).
            energy: Forwarded to the base class.
            coder: Forwarded to the base class.
            logging: Forwarded to the base class.
            current_gen: Forwarded to the base class; defaults to 1.
        """
        super().__init__(config, energy, coder, logging, current_gen)

        self.pop_size = self.config['pop_size']
        self.proteins = proteins

        # Every protein is assumed to share the layout of the first one.
        reference_protein = self.proteins[0]
        self.obj_num = reference_protein.obj_num
        angle_bounds = reference_protein.get_angles_field()
        self.max_angles, self.min_angles = angle_bounds

        self.init_param()
        self.evolution_parameters_init()

    def init_param(self):
        self.evaluation_0_step = self.config['evaluation_0_step']
        self.evaluation_1_step = self.config['evaluation_1_step']
        self.epsilon = self.config['epsilon']

        self.archive_thresh = self.config['archive_threshold']
        self.mesh_div = self.config['mesh_div']

    def NDSort(self, PopObj, nSort):
        # get all proteins' energy values
        PopObj, a, Loc = np.unique(PopObj, return_index=True, return_inverse=True, axis=0)

        rank = np.lexsort(PopObj[:, ::-1].T)
        PopObj = PopObj[rank].copy()
        table, _ = np.histogram(Loc, max(Loc) + 1)

        N, M = PopObj.shape
        # FrontNo = np.full((1, N), np.inf)
        FrontNo = np.ones(N) * np.inf
        FrontNo = np.array(FrontNo)
        MaxFNo = 0

        index_table = np.where(FrontNo < np.inf)
        # print(sum(table[index_table]))

        while sum(table[index_table]) < min(nSort, len(Loc)):
            # print(sum(table[index_table]))
            MaxFNo = MaxFNo + 1
            for i in range(N):
                if FrontNo[i] == np.inf:
                    Dominated = False
                    for j in range(i - 1, 0, -1):
                        if FrontNo[j] == MaxFNo:
                            m = 1
                            while m <= M - 1:
                                if PopObj[i, m] >= PopObj[j, m]:
                                    m = m + 1
                                else:
                                    break
                            Dominated = m >= M
                            if Dominated | M == 1:
                                break
                    if not Dominated:
                        FrontNo[i] = MaxFNo
            index_table = np.where(FrontNo < np.inf)

        FrontNo[rank] = FrontNo
        FrontNo = FrontNo[Loc]

        # pareto_dict = {}
        # for x in range(1, MaxFNo + 1):
        #     pareto_dict[x] = np.where(FrontNo == x)[0]

        return FrontNo, MaxFNo  # , pareto_dict

    def get_pareto_rank_array(self, pareto_front_dict):
        pareto_size = np.sum([len(x) for x in pareto_front_dict.values()])
        pareto_rank_array = np.zeros(pareto_size, dtype=int)

        for rank in pareto_front_dict:
            for x in pareto_front_dict[rank]:
                pareto_rank_array[x] = rank
        max_rank = np.max(pareto_rank_array)
        return pareto_rank_array, max_rank

    def select_protein_ndwa(self, combine_proteins, sf, protein_num):
        combine_proteins_obj_view = np.array([x.obj.copy() for x in combine_proteins])
        front_no, max_f_no = self.NDSort(combine_proteins_obj_view, protein_num)

        next_index = front_no < max_f_no
        last_index = np.where(front_no == max_f_no)[0]
        next_index[last_index[:protein_num - np.sum(next_index)]] = True
        next_index = np.where(next_index == True)[0]

        new_proteins = [combine_proteins[x].copy() for x in next_index]
        sf[:protein_num] = sf[next_index].copy()

        return new_proteins, sf

    def find_subspace(self, angle_view):
        x = angle_view.T.copy()
        avg = np.mean(x, axis=0)
        x = x - np.tile(avg, (x.shape[0], 1))
        sigma = np.dot(x, x.T) / x.shape[0]
        _, s, _ = np.linalg.svd(sigma)
        v_list = np.cumsum(s) / np.sum(s)
        select_i = 0
        while v_list[select_i] < self.epsilon:
            select_i = select_i + 1

        x1 = angle_view.copy()
        avg_1 = np.round(np.mean(x1, axis=0), 2)
        x1[:, select_i:] = np.tile(avg_1[select_i:], (x1.shape[0], 1))

        return x1

    def cos_v_func(self, obj_view, v):
        # calculate the objective value of each solution on each single-objective
        r1 = np.sum(obj_view * v, axis=1)
        r2 = np.sqrt(np.sum(obj_view ** 2, axis=1))
        sf = -r1 / r2
        return sf

    def update_archive(self):
        temp_archive = self.proteins + self.archive
        temp_archive_angle_view = np.array([x.angle_view().copy() for x in temp_archive])
        _, save_index = np.unique(temp_archive_angle_view, axis=0, return_index=True)
        temp_archive = [temp_archive[x].copy() for x in save_index]

        temp_archive_obj = [x.obj.copy() for x in temp_archive]
        pareto_rank, max_rank = self.NDSort(temp_archive_obj, len(temp_archive_obj))
        select_pareto_front = np.where(pareto_rank == 1)[0]

        # pareto_front_dict = self.nonDominatedSort(temp_archive) 
        # select_pareto_front = pareto_front_dict[1]  # + pareto_front_dict[2]

        self.archive = [temp_archive[x] for x in select_pareto_front]
        self.archive_obj_view = np.array([x.obj.copy() for x in temp_archive])

    def run(self):
        """Run the four-stage search and write the final solutions.

        Stages, in order:

        1. Non-dominated dynamic weight aggregation (NDWA): repeated until
           ``self.current_gen`` reaches ``self.evaluation_0_step``. Each
           generation scores the population with one weight vector, selects
           survivors with ``select_protein_ndwa`` and refreshes the archive.
        2. Subspace learning: ``find_subspace`` is applied to the archive's
           angle views and the angle bounds are narrowed accordingly; the
           result replaces ``self.min_angles`` / ``self.max_angles``.
        3. Reference-line mapping: one single-objective run per objective
           (unit vectors as directions), sharing ``self.evaluation_1_step``
           evaluations, to obtain the extreme points.
        4. Diversity maintaining: one single-objective run per reference
           point spanned between the ideal and nadir of the extreme points,
           sharing the evaluations left up to ``self.max_gen``.

        Side effects:
            Advances ``self.current_gen``, prints and logs progress after each
            evaluated batch, replaces ``self.proteins`` and ``self.archive``,
            and finally sets the population to the collected result set
            (with ``self.pop_size`` updated to its length), writes that set
            through ``self.logging.write_archive`` and stops ``self.energy``.
        """
        self.archive = []

        # generate the weight vectors
        weight, weight_num = MOEA.generate_uniform_point(self.pop_size, self.obj_num)
        # Forward then backward, so cycling through them sweeps back and forth.
        u_weight = np.r_[weight, weight[::-1]]
        num_W = u_weight.shape[0]

        # Scalar fitness of each solution on the current single objective:
        # first half for the population, second half for the offspring.
        sf = np.zeros(self.pop_size * 2)
        proteins_obj_view = np.array([x.obj.copy() for x in self.proteins])
        sf[:self.pop_size] = np.sum(proteins_obj_view * u_weight[0, :], axis=1)

        # Stage 1: non-dominated dynamic weight aggregation (NDWA)
        while self.current_gen < self.evaluation_0_step:
            start_time = time.time()
            # Cycle through the weight vectors, one per generation.
            w_index = int(np.ceil(self.current_gen / self.pop_size)) % num_W

            offspring = self._make_offspring(sf)
            offspring_obj_view = np.array([x.obj.copy() for x in offspring])
            sf[self.pop_size:] = np.sum(offspring_obj_view * u_weight[w_index, :], axis=1)

            combine_proteins = [x.copy() for x in self.proteins + offspring]
            self.proteins, sf = self.select_protein_ndwa(combine_proteins, sf, self.pop_size)
            self.update_archive()

            self._report_progress(start_time)

        # Stage 2: pareto-optimal subspace learning
        if not self.archive:
            # Stage 1 ran no generation; seed the archive from the population
            # so there is data to learn from.
            self.update_archive()
        archive_angle_view = np.array([x.angle_view().copy() for x in self.archive])
        learned_data = self.find_subspace(archive_angle_view)
        l_limit, u_limit = self._learn_angle_bounds(learned_data)
        self.min_angles = l_limit.copy()
        self.max_angles = u_limit.copy()

        # Stage 3: reference line mapping
        # optimize the problem as one single-objective problem per objective
        rf_list = np.eye(self.obj_num)
        so_maxeval = int(np.floor(self.evaluation_1_step / self.obj_num))
        extreme_point = [
            self._optimize_direction(sf, rf_list[i, :], so_maxeval, l_limit, u_limit)
            for i in range(self.obj_num)
        ]

        extreme_point_obj_view = np.array([x.obj.copy() for x in extreme_point])
        ideal_point = np.min(extreme_point_obj_view, axis=0)
        nadir_point = np.max(extreme_point_obj_view, axis=0)
        # Stretch the uniform weights over the box between ideal and nadir.
        ref_point = weight * (nadir_point - ideal_point) + ideal_point

        # Stage 4: diversity maintaining
        # optimize weight_num single-objective optimization problems
        remaining_eval = self.max_gen - self.evaluation_0_step - self.evaluation_1_step
        so_maxeval = int(np.floor(remaining_eval / weight_num))
        result = [self.proteins[0].copy() for _ in range(ref_point.shape[0])]
        for i in range(weight_num):
            result[i] = self._optimize_direction(sf, ref_point[i, :], so_maxeval, l_limit, u_limit)

        # The collected per-direction winners become the final population.
        self.proteins = [x.copy() for x in result]
        self.pop_size = len(self.proteins)

        self.logging.write_archive(result, self.coder)
        self.energy.stop()

    def _report_progress(self, start_time):
        """Print the generation counter with elapsed time and log the population."""
        generation = int(self.current_gen / self.pop_size)
        print(generation, time.time() - start_time)
        self.logging.write(self.proteins, self.coder, self.config['save_all'], generation)

    def _make_offspring(self, sf):
        """Breed, evaluate and return one batch of offspring; advances ``current_gen``."""
        mating_pool = tournament(2, self.pop_size, sf[:self.pop_size].reshape(-1, 1))
        offspring = [self.proteins[x].copy() for x in mating_pool]

        offspring_angle_view = self.crossover_binary(offspring)
        offspring_angle_view = self.mutation_polynomial(offspring_angle_view)

        for protein, angle in zip(offspring, offspring_angle_view):
            protein.update_angle_from_view(angle)
        self.energy.calculate_energy(offspring)
        self.current_gen = self.current_gen + len(offspring)
        return offspring

    def _learn_angle_bounds(self, learned_data):
        """Return ``(l_limit, u_limit)`` with near-constant learned angles pinned to their mean."""
        mean_val = np.mean(learned_data, axis=0)
        count = np.size(mean_val)
        l_limit = np.zeros_like(self.min_angles)
        u_limit = np.ones_like(self.max_angles)

        # An angle is pinned when the first learned row lies within 0.1 of the
        # column mean; both of its bounds then collapse onto the rounded mean.
        pinned = np.abs(learned_data[0, :] - mean_val) < 0.1
        pinned_value = np.round(mean_val, 1)
        l_limit[:count] = np.where(pinned, pinned_value, np.asarray(self.min_angles)[:count])
        u_limit[:count] = np.where(pinned, pinned_value, np.asarray(self.max_angles)[:count])
        return l_limit, u_limit

    def _optimize_direction(self, sf, direction, max_eval, l_limit, u_limit):
        """Restart the population randomly and minimise ``cos_v_func`` along ``direction``.

        Runs until ``max_eval`` evaluations have been spent (counted from
        entry), updating ``sf`` in place, and returns a copy of the first
        protein of the final population.
        """
        start_time = time.time()
        budget_end = self.current_gen + max_eval
        angle_num = len(self.min_angles)

        # Uniform random restart inside the learned bounds.
        angle_view = np.random.random((self.pop_size, angle_num)) * (u_limit - l_limit) + l_limit
        for protein, angle in zip(self.proteins, angle_view):
            protein.update_angle_from_view(angle)
        self.energy.calculate_energy(self.proteins)
        self.current_gen = self.current_gen + len(self.proteins)
        self._report_progress(start_time)

        proteins_obj_view = np.array([x.obj.copy() for x in self.proteins])
        sf[:self.pop_size] = self.cos_v_func(proteins_obj_view, direction)

        while self.current_gen < budget_end:
            start_time = time.time()
            offspring = self._make_offspring(sf)
            offspring_obj_view = np.array([x.obj.copy() for x in offspring])
            sf[self.pop_size:] = self.cos_v_func(offspring_obj_view, direction)

            # Keep the pop_size best of parents and offspring, best first.
            rank = np.argsort(sf, axis=0)
            combined = self.proteins + offspring
            self.proteins = [combined[x].copy() for x in rank[:self.pop_size]]
            sf[:self.pop_size] = sf[rank[:self.pop_size]]

            self._report_progress(start_time)

        return self.proteins[0].copy()
