from search.MOEA import MOEA
from scipy.spatial.distance import cdist
import numpy as np
import time


class K_RVEA(MOEA):
    """Reference-vector guided evolutionary optimizer over protein conformations.

    Each generation runs the following steps:

    - Mate the current population.
    - Apply ``crossover_binary`` and ``mutation_polynomial`` to the offspring's
      angle view.
    - Evaluate the offspring with ``self.energy``.
    - Keep the survivors chosen by ``self.apd_select`` against the current
      reference vectors.

    Every ``ceil(fr * max_gen)`` generations the reference vectors are
    rescaled to the population's objective range.

    ``nonDominatedSort``, ``apd_select``, ``select_method``,
    ``crossover_binary``, ``mutation_polynomial``, ``generate_uniform_point``
    and ``max_gen`` are not defined here; presumably they are provided by the
    ``MOEA`` base class.
    """

    def __init__(self, proteins, uniform_point, config, energy, coder, logging, current_gen=1):
        """Set up the optimizer.

        Args:
            proteins: Initial population. Items are used via ``obj``,
                ``obj_num``, ``copy()``, ``update_angle_from_view()`` and
                ``get_angles_field()``.
            uniform_point: Base reference vectors, one per row. Their count
                fixes the population size.
            config: Forwarded unchanged to ``MOEA``.
            energy: Forwarded unchanged to ``MOEA``.
            coder: Forwarded unchanged to ``MOEA``.
            logging: Forwarded unchanged to ``MOEA``.
            current_gen: Forwarded unchanged to ``MOEA``.
        """
        super(K_RVEA, self).__init__(config, energy, coder, logging, current_gen)

        self.proteins = proteins
        self.uniform_point = uniform_point
        # Keep an independent copy so in-place edits of the active reference
        # vectors can never corrupt the base vectors used for rescaling in run().
        self.ref_points = np.array(uniform_point, copy=True)
        self.num_ref_points = len(self.ref_points)
        self.pop_size = self.num_ref_points
        self.obj_num = self.proteins[0].obj_num
        self.max_angles, self.min_angles = proteins[0].get_angles_field()

        self.init_param()

    def init_param(self):
        """Read variation and reference-adaptation parameters from ``self.config``."""
        self.pro_c = self.config['prob_crossover']
        self.dis_c = self.config['dis_crossover']
        self.dis_m = self.config['dis_mutation']
        # NOTE(review): a per-variable mutation probability is usually
        # 1 / (number of decision variables); here it is 1 / population size.
        # Confirm this is intended.
        self.pro_m = 1 / self.pop_size

        self.a = self.config['alpha']  # passed to apd_select as its alpha argument
        self.fr = self.config['fr']    # fraction of max_gen between reference rescalings

        # State returned by apd_select and fed back into it on the next call.
        # It is reset to None whenever the reference vectors are rescaled.
        self.gamma = None

    def reinsert(self, offsprings, ref_point):
        """Choose the next population from parents plus offspring.

        The combined pool is copied and non-dominated sorted. Only the first
        front (``nonDominatedSort(...)[1]``) is handed to ``apd_select``, so
        survivors are drawn from that front only.

        Args:
            offsprings: Evaluated offspring proteins.
            ref_point: Reference vectors passed to ``apd_select``.

        Returns:
            List of surviving protein copies.
        """
        combine_proteins = [x.copy() for x in self.proteins + offsprings]
        pareto_front_dict = self.nonDominatedSort(combine_proteins)
        combine_proteins = [combine_proteins[x] for x in pareto_front_dict[1]]

        selection, self.gamma = self.apd_select(combine_proteins, ref_point, self.a, self.gamma)

        return [combine_proteins[x] for x in selection]

    def update_ref_point(self, ref_points):
        """Regenerate reference vectors that no protein is associated with.

        The method works in three steps:

        1. Translate objectives so each objective's minimum is zero.
        2. Link each protein to the reference vector with the highest cosine
           similarity.
        3. Replace every unlinked vector with a random point in the box
           ``[0, max translated objective)``.

        Args:
            ref_points: 2-D array of reference vectors, one per row.

        Returns:
            A new float array with the unlinked rows regenerated. Neither
            ``ref_points`` nor ``self.proteins_obj_view`` is modified.
        """
        ref_points = np.array(ref_points, dtype=float)  # copy; never edit the caller's array
        translated = self.proteins_obj_view - np.min(self.proteins_obj_view, axis=0)

        # A zero vector has no direction (its cosine distance is NaN), so it must
        # not claim a reference vector. np.argmax would otherwise pick the NaN.
        has_direction = np.any(translated != 0, axis=1)
        if np.any(has_direction):
            similarity = 1 - cdist(translated[has_direction], ref_points, 'cosine')
            similarity[np.isnan(similarity)] = -np.inf
            link_index = np.argmax(similarity, axis=1)
        else:
            link_index = np.empty(0, dtype=int)

        # Sorted, deterministic set of reference rows that no protein linked to.
        no_link_idx = np.setdiff1d(np.arange(ref_points.shape[0]), link_index)
        ref_points[no_link_idx, :] = (np.random.rand(len(no_link_idx), ref_points.shape[1])
                                      * np.max(translated, axis=0))
        return ref_points

    def ref_select(self):
        """Build an augmented reference set: uniform vectors plus random ones.

        Also refreshes ``self.proteins_obj_view`` from the current population.

        Returns:
            Array with the output of ``generate_uniform_point`` stacked on top
            of ``num_points`` rows drawn uniformly from [0, 1).
        """
        self.proteins_obj_view = np.array([x.obj for x in self.proteins])

        uniform_point, num_points = self.generate_uniform_point(self.pop_size, self.obj_num)
        return np.vstack([uniform_point, np.random.rand(num_points, self.obj_num)])

    def run(self):
        """Evolve the population from ``self.current_gen`` up to ``self.max_gen``.

        Each generation runs the following steps:

        - Mate the population.
        - Vary the offspring's angles.
        - Evaluate the offspring's energy.
        - Reinsert to form the next population.

        Every ``ceil(fr * max_gen)`` generations the base reference vectors are
        rescaled to the population's objective range and ``self.gamma`` is
        reset. Each generation prints its index, elapsed seconds and
        population size.
        """
        # NOTE: a variant that appends random reference vectors and refreshes
        # them with update_ref_point is currently disabled:
        #   self.ref_points = np.vstack([self.ref_points, np.random.rand(self.pop_size, self.obj_num)])
        #   self.ref_points[self.pop_size:, :] = self.update_ref_point(self.ref_points[self.pop_size:, :])
        self.num_ref_points = len(self.ref_points)

        # Loop invariant. A non-positive period disables rescaling instead of
        # performing a modulo by zero.
        adapt_period = int(np.ceil(self.fr * self.max_gen))

        for gen in range(self.current_gen, self.max_gen):
            start_time = time.time()

            mating_pool = self.select_method(len(self.proteins), self.pop_size)
            offsprings = [self.proteins[idx].copy() for idx in mating_pool]

            offsprings_view = self.crossover_binary(offsprings)
            offsprings_view = self.mutation_polynomial(offsprings_view)
            for protein, angle in zip(offsprings, offsprings_view):
                protein.update_angle_from_view(angle)

            self.energy.calculate_energy(offsprings)

            self.proteins = self.reinsert(offsprings, self.ref_points)
            self.proteins_obj_view = np.array([p.obj for p in self.proteins])

            if adapt_period > 0 and self.current_gen % adapt_period == 0:
                # Scale the base vectors by the population's objective range.
                span = np.max(self.proteins_obj_view, 0) - np.min(self.proteins_obj_view, 0)
                self.ref_points = self.uniform_point * span
                self.gamma = None

            print(gen, time.time() - start_time, len(self.proteins))

            self.current_gen = self.current_gen + 1









