from scipy.spatial.distance import cdist

from search.MOPSO import *


class DDFC(MOEA):
    """Many-objective evolutionary optimizer with dominance/FC mating and
    density-based environmental selection, plus a grid-truncated archive.

    Presumably a port of MaOEA-DDFC (directional diversity + favorable
    convergence) -- confirm against the reference implementation.

    Config keys read in this class: ``k``, ``l``, ``mesh_div``,
    ``archive_thresh`` and (in ``run``) ``save_all``. Attributes such as
    ``pop_size``, ``max_gen``, ``crossover_binary``, ``mutation_polynomial``,
    ``evolution_parameters_init`` and ``nonDominatedSort`` are not defined
    here; presumably they come from ``MOEA`` -- confirm.
    """

    def __init__(self, proteins, config, energy, coder, logging, current_gen=1):
        super().__init__(config, energy, coder, logging, current_gen)

        self.proteins = proteins
        # Objective count and angle bounds are read from the first protein.
        self.obj_num = self.proteins[0].obj_num
        self.max_angles, self.min_angles = proteins[0].get_angles_field()
        self.init_param()

    def init_param(self):
        """Read the DDFC-specific hyper-parameters from ``self.config``."""
        self.k = self.config['k']  # number of neighbours used to estimate density
        self.l = self.config['l']  # number of low-density candidates entering the FC roulette
        self.mesh_div = self.config['mesh_div']  # grid divisions per objective (archive truncation)
        self.archive_thresh = self.config['archive_thresh']  # archive size above which it is truncated

    def cal_fc(self, proteins_obj_view, z_min):
        """Return the favorable-convergence (FC) value of every row.

        Args:
            proteins_obj_view: objective matrix, one row per solution
                (assumed shape (N, M) -- confirm).
            z_min: ideal point, broadcast against each row.

        Returns:
            1-D array of length N, floored at 1e-6. A row that attains the
            ideal value in any objective gets exactly 1e-6. For other rows the
            weights are w_i = (1/d_i) / sum_j(1/d_j) with d = f - z_min, and
            FC = max_i(w_i * d_i), which equals 1 / sum_j(1/d_j).
        """
        obj = np.asarray(proteins_obj_view)
        diff = obj - z_min

        at_ideal = obj == z_min
        favor_weight = at_ideal.astype(float)

        # Rows touching the ideal point keep the 0/1 weights; the product with
        # diff is then 0 everywhere, so their FC collapses to the 1e-6 floor.
        free = ~np.any(at_ideal, axis=1)
        inv_diff = 1 / diff[free]
        favor_weight[free] = inv_diff / np.sum(inv_diff, axis=1, keepdims=True)

        fc = np.max(favor_weight * diff, axis=1)
        fc[fc < 1e-6] = 1e-6
        return fc

    def mating_select(self, proteins_obj_view, z_min):
        """Binary tournament between two shuffled index sequences.

        The Pareto-dominating parent of each pair wins. For mutually
        non-dominated (or equal) pairs, the parent with the smaller FC wins,
        and ties go to the first parent.

        Returns:
            Integer array of winner indices into ``proteins_obj_view``. It is
            ordered as: first-parent dominance wins, second-parent dominance
            wins, first-parent FC wins, second-parent FC wins.
        """
        fc = self.cal_fc(proteins_obj_view, z_min)
        indices = np.arange(len(fc))[:self.pop_size]
        parents_1 = indices.copy()
        parents_2 = indices.copy()
        np.random.shuffle(parents_1)
        np.random.shuffle(parents_2)

        delta = proteins_obj_view[parents_1, :] - proteins_obj_view[parents_2, :]
        # +1: parent 1 better somewhere and never worse; -1: the converse; 0: otherwise.
        dominate = (np.any(delta < 0, axis=1).astype(np.int64)
                    - np.any(delta > 0, axis=1).astype(np.int64))

        fc_1 = fc[parents_1]
        fc_2 = fc[parents_2]
        tie = dominate == 0
        return np.concatenate([
            parents_1[dominate == 1],
            parents_2[dominate == -1],
            parents_1[tie & (fc_1 <= fc_2)],
            parents_2[tie & (fc_1 > fc_2)],
        ])

    def NDSort(self, PopObj, nSort):
        """Efficient non-dominated sort (ENS-SS style), stopping early.

        Fronts are assigned until at least ``min(nSort, len(PopObj))`` rows
        have a front number. Duplicate rows share a front.

        Args:
            PopObj: objective matrix, one row per solution.
            nSort: minimum number of rows that must receive a front number.

        Returns:
            (FrontNo, MaxFNo): a float array with one front number per input
            row (``np.inf`` for rows left unsorted), and the last front number
            assigned.
        """
        unique_obj, inverse = np.unique(PopObj, return_inverse=True, axis=0)
        # NumPy 2.0.0 briefly returned a non-1-D inverse when ``axis`` is given.
        inverse = np.asarray(inverse).reshape(-1)

        # ENS requires lexicographic order (column 0 primary). np.unique already
        # sorts rows that way; the explicit lexsort keeps the precondition local.
        order = np.lexsort(unique_obj[:, ::-1].T)
        sorted_obj = unique_obj[order]
        # Number of original rows behind each sorted unique row.
        counts = np.bincount(inverse, minlength=len(unique_obj))[order]

        n_rows, n_obj = sorted_obj.shape
        front_sorted = np.full(n_rows, np.inf)
        max_front = 0
        target = min(nSort, len(inverse))

        while np.sum(counts[front_sorted < np.inf]) < target:
            max_front += 1
            for i in range(n_rows):
                if front_sorted[i] != np.inf:
                    continue
                dominated = False
                # Only earlier rows can dominate row i; scan ALL of them (down to 0).
                for j in range(i - 1, -1, -1):
                    if front_sorted[j] != max_front:
                        continue
                    # Column 0 is already <= by the sort order, so start at column 1.
                    m = 1
                    while m < n_obj and sorted_obj[i, m] >= sorted_obj[j, m]:
                        m += 1
                    dominated = m >= n_obj
                    # With 2 objectives the latest same-front row is decisive.
                    if dominated or n_obj == 2:
                        break
                if not dominated:
                    front_sorted[i] = max_front

        front_unique = np.empty(n_rows)
        front_unique[order] = front_sorted
        return front_unique[inverse], max_front

    def _normalize_and_project(self, obj_view, ideal):
        """Normalize ``obj_view`` by hyperplane intercepts and project rows onto the unit simplex.

        Extreme points are found with an axis-aligned achievement scalarizing
        function on ``obj_view - ideal``. The intercepts come from the
        hyperplane through those points, solving ``E @ b = 1`` and taking
        ``1 / b``. If that system is singular, or yields non-finite or
        non-positive intercepts, the per-objective spread
        ``max(obj_view) - ideal`` is used instead.
        """
        translated = obj_view - ideal
        n_obj = translated.shape[1]

        weights = np.full((n_obj, n_obj), 1e-6)
        np.fill_diagonal(weights, 1.0)
        asf = np.column_stack([np.max(translated / weights[i], axis=1) for i in range(n_obj)])
        extreme = np.argmin(asf, axis=0)

        intercepts = None
        try:
            coef = np.linalg.solve(translated[extreme], np.ones(n_obj))
        except np.linalg.LinAlgError:
            coef = None  # e.g. one solution is extreme on several axes
        if coef is not None:
            with np.errstate(divide='ignore'):
                intercepts = 1.0 / coef
        if intercepts is None or not np.all(np.isfinite(intercepts)) or np.any(intercepts <= 1e-6):
            intercepts = np.max(obj_view, axis=0) - ideal

        normalized = translated / intercepts
        return normalized / np.sum(normalized, axis=1, keepdims=True)

    def select_population(self, combine_proteins, z_min):
        """Environmental selection: choose ``self.pop_size`` survivors.

        Fronts before the critical front ``MaxFNo`` are kept whole. Remaining
        slots are filled from the critical front one at a time. Each step takes
        the ``self.l`` remaining candidates with the lowest density, where
        density is the sum of inverse distances to the ``self.k`` nearest
        already-chosen solutions on the normalized simplex. One of them is then
        drawn by a cumulative roulette on 1/FC.

        Args:
            combine_proteins: parents + offspring; each item must expose
                ``.obj`` and ``.copy()``.
            z_min: running ideal point used for the FC values.

        Returns:
            List of copies of the selected members of ``combine_proteins``,
            in their original order.
        """
        all_obj = np.array([p.obj.copy() for p in combine_proteins])
        front_no, max_front = self.NDSort(all_obj, self.pop_size)

        selected = front_no < max_front
        last = np.flatnonzero(front_no == max_front)

        kept_obj = all_obj[selected]
        candidate_obj = np.vstack([kept_obj, all_obj[last]])
        n_kept = kept_obj.shape[0]
        n_total = candidate_obj.shape[0]

        fc = self.cal_fc(candidate_obj, z_min)
        # Normalization uses the ideal point of the current union, not the running z_min.
        projected = self._normalize_and_project(candidate_obj, np.min(all_obj, axis=0))

        choose = np.zeros(n_total, dtype=bool)
        choose[:n_kept] = True
        distance = cdist(projected, projected, metric='euclidean')
        # Self-distances and exact duplicates must not dominate the density estimate.
        distance[distance == 0] = np.inf

        # Never ask for more survivors than there are candidates.
        target = min(self.pop_size, n_total)
        while np.count_nonzero(choose) < target:
            remain = np.flatnonzero(~choose)
            # Before anything is chosen, density is measured among the candidates themselves.
            reference = np.flatnonzero(choose) if choose.any() else remain
            dis = np.sort(distance[np.ix_(remain, reference)], axis=1)

            density = np.sum(1 / dis[:, :self.k], axis=1)
            candidates = remain[np.argsort(density)[:self.l]]

            fitness = np.cumsum(1 / fc[candidates])
            fitness = fitness / np.max(fitness)
            hits = candidates[np.random.rand() <= fitness]
            # Fall back to the least crowded candidate so a NaN fitness cannot stall the loop.
            choose[hits[0] if hits.size else candidates[0]] = True

        selected[last[choose[n_kept:]]] = True
        return [p.copy() for p, keep in zip(combine_proteins, selected) if keep]

    def update_archive_obj_matrix(self):
        """Replace the archive with the first non-dominated front of proteins + archive.

        Rows with identical objective vectors are collapsed to their first
        occurrence. If the front is larger than ``self.archive_thresh``, it is
        grid-truncated via ``delete_archive_obj_matrix``.
        """
        candidates = self.proteins + self.archive
        obj = np.array([p.obj.copy() for p in candidates])
        unique_obj, first_idx = np.unique(obj, return_index=True, axis=0)
        front_no, _ = self.NDSort(unique_obj, len(unique_obj))
        survivors = [candidates[i].copy() for i in first_idx[front_no == 1]]

        if len(survivors) > self.archive_thresh:
            self.delete_archive_obj_matrix(survivors)
        else:
            self.archive = survivors

    def update_archive(self):
        """Rebuild the archive using ``self.nonDominatedSort`` and refresh ``self.archive_obj_view``.

        Candidates (proteins + archive) with identical objective vectors are
        collapsed to their first occurrence before sorting. The first front is
        kept and grid-truncated via ``delete_archive`` if it exceeds
        ``self.archive_thresh``.
        """
        candidates = self.proteins + self.archive
        obj = np.array([p.obj.copy() for p in candidates])
        _, first_idx = np.unique(obj, return_index=True, axis=0)
        unique_candidates = [candidates[i] for i in first_idx]

        # presumably {front_number: indices into the list passed in} -- confirm in MOEA
        pareto_front_dict = self.nonDominatedSort(unique_candidates)
        temp_archive = [unique_candidates[i].copy() for i in pareto_front_dict[1]]
        temp_archive_obj_view = np.array([p.obj.copy() for p in temp_archive])

        if len(temp_archive) > self.archive_thresh:
            self.delete_archive(temp_archive, temp_archive_obj_view)
        else:
            self.archive = temp_archive

        self.archive_obj_view = np.array([p.obj for p in self.archive])

    def delete_archive_obj_matrix(self, temp_archive):
        """Grid-truncate ``temp_archive`` into ``self.archive``; objectives are read from ``.obj``."""
        temp_archive_obj_view = np.array([p.obj.copy() for p in temp_archive])
        self.delete_archive(temp_archive, temp_archive_obj_view)

    def delete_archive(self, temp_archive, temp_archive_obj_view):
        """Shrink ``temp_archive`` to ``self.archive_thresh`` members by grid crowding.

        Each objective range is split into ``self.mesh_div`` cells. Solutions
        are repeatedly removed at random from a randomly chosen most-crowded
        cell. The survivors are copied into ``self.archive``, keeping their
        order.
        """
        obj = np.asarray(temp_archive_obj_view)
        num_archive = len(temp_archive)

        obj_min = np.min(obj, axis=0)
        div = (np.max(obj, axis=0) - obj_min) / self.mesh_div
        # A constant objective gives 0/0 -> NaN, which is mapped to cell 0 below.
        with np.errstate(divide='ignore', invalid='ignore'):
            grid_location = np.floor((obj - obj_min) / div)
        grid_location[grid_location >= self.mesh_div] = self.mesh_div - 1
        grid_location[np.isnan(grid_location)] = 0

        # Cell id of each solution, and how many solutions sit in each cell.
        _, site = np.unique(grid_location, return_inverse=True, axis=0)
        site = np.asarray(site).reshape(-1)
        crowd_degree = np.bincount(site)

        removed = np.zeros(num_archive, dtype=bool)
        n_delete = num_archive - self.archive_thresh
        while np.count_nonzero(removed) < n_delete:
            crowded = np.flatnonzero(crowd_degree == np.max(crowd_degree))
            grid = crowded[np.random.randint(0, len(crowded))]

            in_grid = np.flatnonzero(site == grid)
            victim = in_grid[np.random.randint(0, len(in_grid))]
            removed[victim] = True
            site[victim] = -1  # no longer belongs to any cell
            crowd_degree[grid] -= 1

        self.archive = [p.copy() for p, gone in zip(temp_archive, removed) if not gone]

    def run(self):
        """Evolve generations ``self.current_gen .. self.max_gen - 1``.

        Each generation performs mating selection, ``crossover_binary``,
        ``mutation_polynomial`` (``self.pro_m`` decays with the generation),
        energy evaluation, environmental selection and an archive update.
        Progress lines are printed and appended to
        ``<protein_name>_print_result.txt``. ``self.energy.stop()`` is called
        even if the loop raises.
        """
        total_start = time.time()
        print_path = self.logging.protein_name + '_print_result.txt'
        try:
            with open(print_path, 'a') as f:
                self.evolution_parameters_init()
                pareto_front_dict = self.nonDominatedSort(self.proteins)
                self.archive = [self.proteins[i].copy() for i in pareto_front_dict[1]]
                z_min = np.min(np.array([p.obj.copy() for p in self.proteins]), axis=0)

                for gen in range(self.current_gen, self.max_gen):
                    gen_start = time.time()

                    proteins_obj_view = np.array([p.obj.copy() for p in self.proteins])
                    mating_pool = self.mating_select(proteins_obj_view, z_min)
                    self.offspring = [self.proteins[i].copy() for i in mating_pool]
                    self.offspring_size = len(self.offspring)

                    # crossover
                    self.offspring_anlge_view = self.crossover_binary(self.offspring)

                    # mutation
                    self.pro_m = np.exp(-self.current_gen / (4 * self.max_gen))
                    self.offspring_anlge_view = self.mutation_polynomial(self.offspring_anlge_view)

                    # evaluation
                    for protein, angle in zip(self.offspring, self.offspring_anlge_view):
                        protein.update_angle_from_view(angle)
                    self.energy.calculate_energy(self.offspring)

                    # selection (z_min is the running ideal point over all evaluated solutions)
                    offspring_obj_view = np.array([p.obj.copy() for p in self.offspring])
                    z_min = np.min(np.vstack([z_min, offspring_obj_view]), axis=0)
                    self.proteins = self.select_population(self.proteins + self.offspring, z_min)

                    self.update_archive_obj_matrix()

                    self.current_gen = gen
                    elapsed = time.time() - gen_start
                    print(self.logging.protein_name, gen, elapsed, z_min, len(self.archive))
                    print(self.logging.protein_name, gen, elapsed, z_min, len(self.archive), file=f)
                    self.logging.write(self.proteins, self.coder, self.config['save_all'], self.current_gen)

                self.logging.write_archive(self.archive, self.coder)
                total = time.time() - total_start
                print('Total time:', total)
                print('Total time:', total, file=f)
        finally:
            self.energy.stop()
