import numpy as np
from scipy.spatial.distance import cdist
import time

from search.MOEA import MOEA
from search.SelectMethod import tournament


class SDR(MOEA):
    """Evolutionary search over protein conformations with SDR-based environmental selection.

    Each generation (see ``run``):
      1. picks parents by tournament on ``(front_no, -crowding distance)``;
      2. applies binary crossover and polynomial mutation;
      3. re-evaluates the offspring energies;
      4. keeps the best ``pop_size`` solutions of parents + offspring via
         ``select_protein``, which ranks them with ``NDsort_SDR``.

    NOTE(review): the structure mirrors PlatEMO's SDR algorithm (strengthened
    dominance relation, Tian et al. 2019). Confirm against that reference when
    changing the selection logic.

    From their use in this class, protein objects are expected to expose ``obj``,
    ``obj_num``, ``copy()``, ``angle_view()``, ``update_angle_from_view()`` and
    ``get_angles_field()``.
    """

    def __init__(self, proteins, config, energy, coder, logging, current_gen=1):
        super(SDR, self).__init__(config, energy, coder, logging, current_gen)
        self.pop_size = self.config['pop_size']
        self.proteins = proteins
        # objective count and angle bounds are taken from the first protein
        self.obj_num = self.proteins[0].obj_num
        self.max_angles, self.min_angles = self.proteins[0].get_angles_field()
        self.init_param()
        self.evolution_parameters_init()

    def init_param(self):
        """Read SDR-specific settings: ``mesh_div`` is the grid resolution used by ``delete_archive``."""
        self.mesh_div = self.config['mesh_div']

    def generate_offspring(self, obj_view, z_min):
        """Return ``pop_size`` mating indices chosen by binary tournament on ASF and minimum angle.

        ``obj_view`` holds one objective vector per row. It is assumed to have
        ``pop_size`` rows, because candidates are drawn from ``range(pop_size)``.
        ``z_min`` is the ideal point subtracted before computing ASF values and angles.
        This method is not used by ``run``, which uses ``tournament`` instead.
        """
        obj_view = np.asarray(obj_view, dtype=float)

        # ASF with per-solution weights proportional to its (unshifted) objectives
        weight = obj_view / obj_view.sum(axis=1, keepdims=True)
        weight = np.maximum(weight, 1e-6)
        shifted = obj_view - z_min
        asf = np.max(shifted / weight, axis=1)

        # position of each solution in ascending ASF order
        asf_rank = np.argsort(np.argsort(asf))

        # smallest angle from each solution to any other (self excluded).
        # Clip because rounding can push 1 - cosine distance just outside [-1, 1].
        cosine = np.clip(1 - cdist(shifted, shifted, metric='cosine'), -1.0, 1.0)
        angle = np.arccos(cosine)
        np.fill_diagonal(angle, np.inf)
        angle_min = angle.min(axis=1)

        mating_pool = np.zeros(self.pop_size, dtype=int)
        for k in range(self.pop_size):
            first, second = np.random.permutation(self.pop_size)[:2]
            # first wins only if it is better on ASF AND more isolated
            winner = first if (asf[first] < asf[second] and angle_min[first] > angle_min[second]) else second

            # NOTE(review): np.random.random() is in [0, 1) while the threshold is >= 1.0002,
            # so this test is always true and the random fallback never fires. The constant
            # looks like a typo; kept unchanged until the intended formula is confirmed.
            if np.random.random() < 1.0002 + asf_rank[winner] / self.pop_size:
                mating_pool[k] = winner
            else:
                mating_pool[k] = np.random.randint(0, high=self.pop_size)

        return mating_pool

    def NDSort(self, PopObj, nSort):
        """Pareto non-dominated sort (minimisation) of the rows of ``PopObj``.

        Sorting stops once at least ``min(nSort, len(PopObj))`` solutions have a
        front number. Duplicate rows share the same front.

        Returns ``(front_no, max_f_no)``. ``front_no[k]`` is the front of row ``k``
        (1 = non-dominated), or ``inf`` if that row was never ranked.
        """
        unique_obj, _, loc = np.unique(np.asarray(PopObj), return_index=True, return_inverse=True, axis=0)
        loc = np.asarray(loc).reshape(-1)

        # lexicographic order on (obj0, obj1, ...); np.unique already returns rows in
        # this order, so rank is the identity, but it is kept explicit
        rank = np.lexsort(unique_obj[:, ::-1].T)
        sorted_obj = unique_obj[rank]
        # how many original rows collapsed onto each unique row (in sorted order)
        counts = np.bincount(loc, minlength=len(unique_obj))[rank]

        n_unique, n_obj = sorted_obj.shape
        front_no = np.full(n_unique, np.inf)
        max_f_no = 0
        target = min(nSort, len(loc))

        while counts[front_no < np.inf].sum() < target:
            max_f_no += 1
            for i in range(n_unique):
                if front_no[i] != np.inf:
                    continue
                dominated = False
                # scan every earlier row, down to and including index 0
                for j in range(i - 1, -1, -1):
                    if front_no[j] != max_f_no:
                        continue
                    # j precedes i lexicographically (j <= i on obj0, rows distinct),
                    # so j dominates i iff it is no worse on all remaining objectives
                    dominated = bool(np.all(sorted_obj[i, 1:] >= sorted_obj[j, 1:]))
                    # with two objectives, the nearest earlier front member has the
                    # smallest obj1 in the front, so checking it alone is enough
                    if dominated or n_obj == 2:
                        break
                if not dominated:
                    front_no[i] = max_f_no

        # map back from sorted order -> unique order -> original rows
        unsorted = np.empty_like(front_no)
        unsorted[rank] = front_no
        return unsorted[loc], max_f_no

    def NDsort_SDR(self, obj_view, nSort):
        """Non-dominated sort of the rows of ``obj_view`` under the SDR relation.

        Solution ``i`` SDR-dominates ``j`` when ``sum(obj_i) * theta_ij < sum(obj_j)``.
        Here ``theta_ij = max(1, angle_ij / min_a)``, and ``min_a`` is the
        ``ceil(n/2)``-th smallest distinct nearest-neighbour angle. Each pair is
        tested ``i -> j`` first (for ``i < j``); ``j -> i`` is tested only if that fails.

        Returns ``(front_no, max_f_no)`` and stops once at least ``min(nSort, n)``
        solutions are ranked. Unranked solutions keep ``front_no == inf``.
        """
        obj_view = np.asarray(obj_view, dtype=float)
        protein_num = obj_view.shape[0]
        norm_p = obj_view.sum(axis=1)

        # clip: rounding can push 1 - cosine distance slightly outside [-1, 1] -> NaN angles
        cosine = np.clip(1 - cdist(obj_view, obj_view, metric='cosine'), -1.0, 1.0)
        np.fill_diagonal(cosine, 0)  # self-angle becomes pi/2 so it is never the minimum
        angle = np.arccos(cosine)

        distinct_min = np.unique(angle.min(axis=1))  # np.unique returns sorted values
        min_a = distinct_min[int(min(np.ceil(protein_num / 2) - 1, len(distinct_min) - 1))]
        theta = np.maximum(1.0, angle / min_a)

        # first[i, j]: i passes the SDR test against j
        first = norm_p[:, None] * theta < norm_p[None, :]
        upper = np.triu(np.ones((protein_num, protein_num), dtype=bool), k=1)
        dominate = first & upper
        # for i < j, the reverse test j -> i only applies when i -> j failed
        dominate |= (upper & ~first & first.T).T

        front_no = np.full(protein_num, np.inf)
        max_f_no = 0
        target = min(nSort, protein_num)
        while np.count_nonzero(front_no != np.inf) < target:
            max_f_no += 1
            current = ~dominate.any(axis=0) & (front_no == np.inf)
            if not current.any():
                # cyclic dominance (possible when objective sums are negative): place all
                # remaining solutions in this front instead of looping forever
                current = front_no == np.inf
            front_no[current] = max_f_no
            dominate[current] = False

        return front_no, max_f_no

    def get_pareto_rank_array(self, pareto_front_dict):
        """Convert ``{rank: indices}`` into a per-solution rank array.

        Returns ``(pareto_rank_array, max_rank)``.
        NOTE(review): assumes the indices over all fronts are exactly
        ``0 .. total-1``; this is not checked.
        """
        pareto_size = sum(len(members) for members in pareto_front_dict.values())
        pareto_rank_array = np.zeros(pareto_size, dtype=int)

        for rank, members in pareto_front_dict.items():
            for index in members:
                pareto_rank_array[index] = rank
        return pareto_rank_array, np.max(pareto_rank_array)

    def select_protein(self, combine_proteins, combine_obj_view, z_min, z_max):
        """Environmental selection: keep at most ``pop_size`` proteins.

        The objectives are translated by ``z_min``. They are also divided by
        ``z_max - z_min`` when no objective's range is below 5% of the largest
        range. Near-duplicates (equal after rounding to 1e-6) are then dropped.
        Whole SDR fronts are taken in order, and the last front is filled by
        descending crowding distance.

        Returns ``(proteins, front_no, crow_dist)``; all three are aligned with each other.
        """
        combine_obj_view = np.asarray(combine_obj_view, dtype=float) - z_min

        obj_range = z_max - z_min
        if 0.05 * np.max(obj_range) < np.min(obj_range):
            combine_obj_view = combine_obj_view / obj_range

        _, select_index = np.unique(np.round(combine_obj_view * 1e6) / 1e6, return_index=True, axis=0)
        select_index = np.asarray(select_index, dtype=int)

        combine_proteins = [combine_proteins[i].copy() for i in select_index]
        # keep the normalised objectives; re-reading raw p.obj here would discard
        # the translation/scaling computed above
        combine_obj_view = combine_obj_view[select_index]

        select_num = min(self.pop_size, len(combine_proteins))

        front_no, max_f_no = self.NDsort_SDR(combine_obj_view, select_num)
        next_mask = front_no < max_f_no

        crow_dist = self.crowding_distance(combine_obj_view, front_no)

        # fill the remaining slots from the last front, most isolated first
        last = np.flatnonzero(front_no == max_f_no)
        order = np.argsort(crow_dist[last])[::-1]
        next_mask[last[order[:select_num - np.count_nonzero(next_mask)]]] = True

        next_index = np.flatnonzero(next_mask)
        proteins = [combine_proteins[i].copy() for i in next_index]
        return proteins, np.array(front_no[next_index]), np.array(crow_dist[next_index])

    def update_archive(self):
        """Merge the population into ``self.archive`` and keep only the first Pareto front.

        Duplicates (by ``angle_view()``) are removed first. If the remaining front is
        larger than ``self.archive_thresh``, it is pruned by ``delete_archive``.
        ``self.archive_obj_view`` is refreshed afterwards.
        """
        merged = self.proteins + self.archive
        merged_angle_view = np.array([p.angle_view().copy() for p in merged])
        _, keep_index = np.unique(merged_angle_view, axis=0, return_index=True)
        merged = [merged[i].copy() for i in keep_index]

        merged_obj = [p.obj.copy() for p in merged]
        pareto_rank, _ = self.NDSort(merged_obj, len(merged_obj))
        first_front = [merged[i] for i in np.flatnonzero(pareto_rank == 1)]
        first_front_obj_view = np.array([p.obj for p in first_front])

        if len(first_front) > self.archive_thresh:
            self.delete_archive(first_front, first_front_obj_view)
        else:
            self.archive = first_front

        self.archive_obj_view = np.array([p.obj for p in self.archive])

    def delete_archive(self, temp_archive, temp_archive_obj_view):
        """Shrink ``temp_archive`` to ``self.archive_thresh`` entries and store it in ``self.archive``.

        Objective space is split into a ``mesh_div``-per-axis grid. A random member
        of a random most-crowded cell is removed repeatedly until the size limit is met.
        """
        temp_archive_obj_view = np.asarray(temp_archive_obj_view, dtype=float)
        num_archive = len(temp_archive)

        # grid cell of each solution; a zero-width axis yields 0/0 = NaN -> cell 0
        obj_max = temp_archive_obj_view.max(axis=0)
        obj_min = temp_archive_obj_view.min(axis=0)
        div = (obj_max - obj_min) / self.mesh_div
        with np.errstate(divide='ignore', invalid='ignore'):
            grid_location = np.floor((temp_archive_obj_view - obj_min) / div)
        grid_location[grid_location >= self.mesh_div] = self.mesh_div - 1
        grid_location[np.isnan(grid_location)] = 0

        # site[k]: id of the occupied cell solution k belongs to
        _, site = np.unique(grid_location, return_inverse=True, axis=0)
        site = np.asarray(site).reshape(-1)
        crowd_degree = np.bincount(site)

        deleted = np.zeros(num_archive, dtype=bool)
        n_delete = num_archive - self.archive_thresh
        while np.count_nonzero(deleted) < n_delete:
            most_crowded = np.flatnonzero(crowd_degree == crowd_degree.max())
            grid = most_crowded[np.random.randint(0, len(most_crowded))]

            members = np.flatnonzero(site == grid)
            victim = members[np.random.randint(0, len(members))]
            deleted[victim] = True
            site[victim] = -100  # detach from every cell so it cannot be picked again
            crowd_degree[grid] -= 1

        self.archive = [temp_archive[i].copy() for i in np.flatnonzero(~deleted)]

    def run(self):
        """Run generations ``self.current_gen`` .. ``self.max_gen - 1``.

        Each generation prints the elapsed time and the current ideal point ``z_min``
        to stdout and appends the same line to ``<protein_name>_print_result.txt``.
        The population is written through ``self.logging.write`` after every
        generation. At the end, ``write_archive`` is called and then ``self.energy.stop()``.
        """
        proteins_obj_view = np.array([p.obj.copy() for p in self.proteins])
        z_min = np.min(proteins_obj_view, axis=0)
        z_max = np.max(proteins_obj_view, axis=0)

        # select_protein de-duplicates and reorders its input, so its front numbers and
        # crowding distances refer to the proteins it returns. Adopt that population so
        # the first tournament and z_max index matching rows.
        self.proteins, front_no, crow_dist = self.select_protein(self.proteins, proteins_obj_view, z_min, z_max)

        print_path = self.logging.protein_name + '_print_result.txt'
        with open(print_path, 'a') as log_file:
            for gen in range(self.current_gen, self.max_gen):
                start_time = time.time()
                mating_pool = tournament(2, self.pop_size, np.c_[front_no, -crow_dist])

                offspring = [self.proteins[i].copy() for i in mating_pool]
                offspring_angle_view = self.crossover_binary(offspring)
                offspring_angle_view = self.mutation_polynomial(offspring_angle_view)

                for protein, angle in zip(offspring, offspring_angle_view):
                    protein.update_angle_from_view(angle)
                self.energy.calculate_energy(offspring)

                # ideal point: running minimum over everything evaluated so far;
                # z_max: maximum over the current population's first front
                offspring_obj_view = [p.obj.copy() for p in offspring]
                z_min = np.min(np.vstack([z_min, offspring_obj_view]), axis=0)
                proteins_obj_view = np.array([p.obj.copy() for p in self.proteins])
                z_max = np.max(proteins_obj_view[front_no == 1], axis=0)

                combine_proteins = [p.copy() for p in self.proteins + offspring]
                combine_proteins_obj_view = np.array([p.obj.copy() for p in combine_proteins])

                self.proteins, front_no, crow_dist = self.select_protein(combine_proteins, combine_proteins_obj_view,
                                                                         z_min, z_max)

                self.current_gen = gen
                elapsed = time.time() - start_time
                print(self.logging.protein_name, gen, elapsed, z_min)
                print(self.logging.protein_name, gen, elapsed, z_min, file=log_file)
                self.logging.write(self.proteins, self.coder, self.config['save_all'], self.current_gen)

        self.logging.write_archive(self.proteins, self.coder)
        self.energy.stop()
