from scipy.spatial.distance import cdist

from search.MOPSO import *
import math
import itertools
from scipy.spatial.distance import cdist
from search.SelectMethod import *


class IGD(MOEA):
    def __init__(self, proteins, config, energy, coder, logging, current_gen=1):
        """Initialise the optimiser on top of the ``MOEA`` base class.

        ``config``, ``energy``, ``coder``, ``logging`` and ``current_gen`` are
        forwarded unchanged to the base initialiser.  ``proteins`` is stored as
        the working population; its first element supplies the number of
        objectives (``obj_num``) and the angle bounds, which
        ``get_angles_field()`` returns as a ``(max, min)`` pair.
        """
        super().__init__(config, energy, coder, logging, current_gen)

        self.proteins = proteins
        reference_protein = self.proteins[0]
        self.obj_num = reference_protein.obj_num
        angle_bounds = reference_protein.get_angles_field()
        self.max_angles, self.min_angles = angle_bounds
        # Algorithm-specific settings are read from the config last.
        self.init_param()

    def init_param(self):
        self.k = self.config['k']  # The number of neighbors for estimating density
        self.l = self.config['l']  # The number of candidates for convergence-based selection
        self.mesh_div = self.config['mesh_div']
        self.archive_thresh = self.config['archive_thresh']

    def arr_nchoosek(self, array, m):
        all_combos = list(itertools.combinations(array, m))
        return np.array(all_combos)

    def nchoosek(self, n, k):
        return math.factorial(n) // math.factorial(k) // math.factorial(n - k)

    def UniformPoint(self, N, M):
        # NBI
        H1 = 1
        while self.nchoosek(H1 + M, M - 1) <= N:
            H1 += 1
        np.arange(1, (H1 + M - 1))
        W = self.arr_nchoosek(np.arange(1, (H1 + M)), M - 1) - np.tile(np.arange(M - 2 + 1),
                                                                       [self.nchoosek(H1 + M - 1, M - 1),
                                                                        1]) - 1
        W = (np.hstack([W, np.zeros([W.shape[0], 1]) + H1]) - np.hstack([np.zeros([W.shape[0], 1]), W])) / H1
        if H1 < M:
            H2 = 0
            while self.nchoosek(H1 + M - 1, M - 1) + self.nchoosek(H2 + M, M - 1) <= N:
                H2 += 1
            if H2 > 0:
                W2 = self.arr_nchoosek(np.arange(1, (H2 + M)), M - 1) - np.tile(np.arange(M - 2 + 1),
                                                                                [self.nchoosek(H2 + M - 1, M - 1),
                                                                                 1]) - 1
                W2 = (np.hstack([W2, np.zeros([W2.shape[0], 1]) + H2]) - np.hstack(
                    [np.zeros([W2.shape[0], 1]), W2])) / H2
                W2_temp = W2 // 2 + 1 // (2 * M)
                W = np.vstack([W, W2_temp])
        W = np.where(W < 1e-6, 1e-6, W)
        N = W.shape[0]
        return W, N

    def fitness(self, obj):
        fit = np.zeros_like(obj)
        for i in range(obj.shape[1]):
            index = np.hstack([np.arange(i), np.arange(i + 1, obj.shape[1])])
            fit[:, i] = np.abs(obj[:, i]) + 100 * np.sum(obj[:, index] ** 2, axis=1)
        return fit

    def EnvironmentalSelection(self, proteins, W, N):
        protines_obj_view = np.array([x.obj.copy() for x in proteins])
        _, x = np.unique(np.round(protines_obj_view * 1e4) / 1e4, axis=0, return_index=True)
        protines_obj_view = protines_obj_view[x, :]
        N = min(N, protines_obj_view.shape[0])
        rank = np.zeros(protines_obj_view.shape[0])
        dis = np.zeros([protines_obj_view.shape[0], W.shape[0]])
        for i in range(protines_obj_view.shape[0]):
            temp = np.tile(protines_obj_view[i, :], [W.shape[0], 1])
            domi = np.int64(np.any(temp < W, axis=1)) - np.int64(np.any(temp > W, axis=1))
            if np.any(domi == 1):
                rank[i] = 1
                dis[i, :] = -np.sqrt(np.sum((temp - W) ** 2, axis=1)).T
            elif np.any(domi == -1):
                rank[i] = 3
                dis[i, :] = np.sqrt(np.sum((temp - W) ** 2, axis=1)).T
            else:
                rank[i] = 2
                dis[i, :] = -np.sqrt(np.sum(np.where(temp - W < 0, 0, temp - W) ** 2, axis=1)).T

        his_index = np.histogram(rank, np.arange(4) + 1)[0]
        cumsum = np.cumsum(his_index)
        max_ind = np.where(cumsum >= N)[0]
        MaxFNo = max_ind[0] + 1
        Next = rank < MaxFNo
        Last = np.where(rank == MaxFNo)[0]
        Choose = self.LastSelection(dis[Last, :], N - np.sum(Next), W)


        Next[Last[Choose]] = True

        offspring = []
        for x, y in zip(proteins, Next):
            if y:
                offspring.append(x.copy())

        rank = rank[Next]
        Dis = dis[Next, :]

        return offspring, rank, Dis

        # Population = Population[Next]
        # Rank = Rank(Next);
        # Dis = Dis(Next,:);

    def LastSelection(self, Dis, K, W):
        Distance = cdist(W, W)
        Distance[np.eye(len(Distance), dtype=bool)] = np.inf
        Del = np.zeros([W.shape[0]])

        while np.sum(Del == False) > K:
            Remain = np.where(Del == False)[0]
            ttt = np.array([Distance[i, Remain] for i in Remain])
            Temp = np.sort(ttt, axis=1)
            rank = Temp[:, 0].argsort()
            Del[Remain[rank[0]]] = True
        Dis = Dis[:, np.where(Del == False)[0]]
        Choose = np.zeros([Dis.shape[0]], dtype=bool)
        for i in range(Dis.shape[1]):
            remain = np.where(Choose == False)[0]
            best = np.argmin(Dis[remain, i])
            Choose[remain[best]] = True
        return Choose

    def run(self):
        """Run the optimisation and write progress and results to the logger.

        The method works in two phases:

        1. While the generation counter is below 2, offspring angles are
           produced by crossover and mutation, energies are evaluated, and the
           population is reduced to the union of the best
           ``ceil(pop_size / obj_num)`` individuals per objective according to
           ``fitness``.  The surviving population is then used to derive
           ``zmax``/``zmin`` and to rescale the reference points ``W``.
        2. From the current generation up to ``self.max_gen``, parents are
           drawn by binary tournament on (rank, minimum distance), offspring
           are generated and evaluated, and ``EnvironmentalSelection`` picks
           the next population from parents plus offspring.

        Per-generation timing is printed to stdout and appended to
        ``<protein_name>_print_result.txt``; populations are saved through
        ``self.logging``.  ``self.energy.stop()`` is called at the end.
        """
        time1 = time.time()
        print_path = self.logging.protein_name + '_print_result.txt'
        # The context manager closes the progress log even if a generation
        # raises part-way through the run.
        with open(print_path, 'a') as f:
            self.evolution_parameters_init()
            W, N = self.UniformPoint(self.pop_size, self.obj_num)
            self.pop_size = N
            self.proteins_point = self.proteins
            # NOTE(review): offspring_point aliases self.proteins, so the angle
            # updates below modify those protein objects in place - confirm
            # that this is intended.
            self.offspring_point = self.proteins
            for gen in range(self.current_gen, 2):
                st_nn = time.time()
                picks = np.random.randint(len(self.proteins_point), size=self.pop_size)
                Offspring_angle = self.crossover_binary([self.proteins_point[i] for i in picks])
                Offspring_angle = self.mutation_polynomial(Offspring_angle)
                for protein, angle in zip(self.offspring_point, Offspring_angle):
                    protein.update_angle_from_view(angle)
                self.energy.calculate_energy(self.offspring_point)
                self.proteins_point = self.proteins_point + self.offspring_point
                protines_obj_view = np.array([p.obj.copy() for p in self.proteins_point])
                fit = self.fitness(protines_obj_view)
                # Sort individuals (rows) separately for every objective
                # (column): axis=0 yields individual indices, which is what
                # the row slice and the population indexing below require.
                rank = np.argsort(fit, axis=0)
                index = np.unique(rank[:math.ceil(self.pop_size / self.obj_num), :])

                self.proteins_point = [self.proteins_point[i].copy() for i in index]
                self.current_gen = gen
                print(gen, time.time() - st_nn)

            # Rescale the reference points to the objective range spanned by
            # the per-objective best individuals (zmax) and the minima (zmin).
            protines_obj_view = np.array([p.obj.copy() for p in self.proteins_point])
            ext = np.argmin(self.fitness(protines_obj_view), axis=0)
            zmax = np.diag(protines_obj_view[ext]).T
            zmin = np.min(protines_obj_view, axis=0)
            zmax = np.where(zmax < 1e-6, 1, zmax)

            W = W * np.tile(zmax - zmin, [self.pop_size, 1]) + np.tile(zmin, [self.pop_size, 1])

            self.proteins, Rank, Dis = self.EnvironmentalSelection(self.proteins, W, self.pop_size)

            for gen in range(self.current_gen, self.max_gen):
                # Tournament criteria: rank first, then the smallest distance
                # to any reference point.
                Rank_r = Rank.reshape(Rank.shape[0], 1)
                min_dis = np.min(Dis, axis=1).reshape(Dis.shape[0], 1)
                fitness = np.hstack([Rank_r, min_dis])
                MatingPool = tournament(2, self.pop_size, fitness)
                self.offspring = [self.proteins[i].copy() for i in MatingPool]
                start_time = time.time()

                # crossover
                self.offspring_anlge_view = self.crossover_binary(self.offspring)

                # mutation
                self.offspring_anlge_view = self.mutation_polynomial(self.offspring_anlge_view)

                # evaluation
                for protein, angle in zip(self.offspring, self.offspring_anlge_view):
                    protein.update_angle_from_view(angle)
                self.energy.calculate_energy(self.offspring)

                # selection
                combine_proteins = self.proteins + self.offspring
                offspring_obj_view = np.array([p.obj.copy() for p in combine_proteins])
                z_min = np.min(offspring_obj_view, axis=0)
                self.proteins, Rank, Dis = self.EnvironmentalSelection(combine_proteins, W, self.pop_size)

                self.current_gen = gen
                print(self.logging.protein_name, gen, time.time() - start_time, z_min)
                print(self.logging.protein_name, gen, time.time() - start_time, z_min, file=f)
                self.logging.write(self.proteins, self.coder, self.config['save_all'], self.current_gen)
            self.logging.write_archive(self.proteins, self.coder, self.current_gen)
            print('Total time:', time.time() - time1)
            print('Total time:', time.time() - time1, file=f)
        self.energy.stop()
