from search.MOEA import MOEA
import numpy as np
import time
from scipy.spatial.distance import cdist
import collections

class MOPSO(MOEA):
    """Multi-objective particle swarm optimisation over protein angles.

    Each particle is a protein conformation whose flattened angle vector
    (``angle_view()``) is moved by the PSO velocity update. An archive of
    non-dominated conformations (bounded by ``archive_thresh``) supplies the
    global guide (gbest) of every particle.
    """

    def __init__(self, proteins, config, energy, coder, logging, current_gen=1):
        super(MOPSO, self).__init__(config, energy, coder, logging, current_gen)

        self.proteins = proteins
        self.obj_num = self.proteins[0].obj_num
        # Per-angle upper/lower bounds, shared by every particle.
        self.max_angles, self.min_angles = proteins[0].get_angles_field()
        self.init_param()

    def init_param(self):
        """Read the PSO hyper-parameters from ``self.config``."""
        self.w = self.config['w']      # inertia weight applied to the previous velocity
        self.c1 = self.config['c1']    # weight of the pull towards the personal best
        self.c2 = self.config['c2']    # weight of the pull towards the global best
        self.mesh_div = self.config['mesh_div']
        self.archive_thresh = self.config['archive_thresh']

    @staticmethod
    def _angle_vector(protein, mode):
        """Return the angles of ``protein`` selected by ``mode`` as one flat sequence."""
        if mode == 'all':
            return protein.angle_view()
        angles = []
        for res in protein.res:
            if mode == 'main':
                angles.extend([res.get_angle('phi'), res.get_angle('psi')])
            else:  # 'side'
                angles.extend(res.get_angle('sidechain').tolist())
        return angles

    def cosine_similarity(self, proteins_x, proteins_y=None, mode='side'):
        """Pairwise cosine *distance* between protein angle vectors.

        NOTE: despite the name, this returns SciPy's cosine distance
        (``1 - cos(theta)``, range [0, 2]), so SMALLER values mean MORE similar.

        Args:
            proteins_x: proteins forming the columns of the result.
            proteins_y: proteins forming the rows; when None, ``proteins_x``
                is compared with itself.
            mode: which angles to compare -- 'main' (phi/psi of every residue),
                'side' (side-chain angles of every residue) or 'all'
                (``angle_view()``).

        Returns:
            ndarray of shape (len(proteins_y), len(proteins_x)), or
            (len(proteins_x), len(proteins_x)) when ``proteins_y`` is None.

        Raises:
            ValueError: if ``mode`` is not one of 'main', 'side', 'all'.
        """
        if mode not in ('main', 'side', 'all'):
            raise ValueError("mode must be one of 'main', 'side', 'all', got %r" % (mode,))

        x_angles = np.array([self._angle_vector(p, mode) for p in proteins_x])
        if proteins_y is None:
            return cdist(x_angles, x_angles, 'cosine')
        y_angles = np.array([self._angle_vector(p, mode) for p in proteins_y])
        return cdist(y_angles, x_angles, 'cosine')

    def linear_weight(self):
        """Placeholder for a linearly decreasing inertia weight; not implemented."""
        pass

    def adaptive_wight(self):
        """Placeholder for an adaptive inertia weight; not implemented."""
        pass

    def update_gbest(self):
        """Pick the most similar archive member as gbest for each particle.

        Similarity is measured on side-chain angles. ``cosine_similarity``
        returns a cosine distance, so the most similar member is the one with
        the SMALLEST value.

        Returns:
            list holding one archive copy per particle, in particle order.
        """
        # Rows are particles, columns are archive members.
        cosine_dist = self.cosine_similarity(self.archive, self.proteins, 'side')
        selected_index = np.argmin(cosine_dist, axis=1)
        return [self.archive[i].copy() for i in selected_index]

    def roulette_wheel_selection(self, crowd_degree):
        """Draw ``self.pop_size`` indices, favouring entries with LOW crowd degree.

        Entry i is selected with probability proportional to
        ``1 / crowd_degree[i]`` after all degrees are shifted to be positive.

        Returns:
            int64 ndarray of shape (self.pop_size,) with indices into crowd_degree.
        """
        crowd_degree = np.reshape(np.asarray(crowd_degree, dtype=float), (-1,))
        # Subtract the (non-positive) minimum so every degree is >= 0, plus a
        # small epsilon so none is zero when inverted.
        crowd_degree = crowd_degree - np.minimum(np.min(crowd_degree), 0) + 1e-6
        cumulative = np.cumsum(1.0 / crowd_degree)
        # Normalise so the last bin edge is exactly 1.0.
        cumulative = cumulative / cumulative[-1]
        # For each uniform draw, count bin edges <= draw => index of the first edge above it.
        index = np.sum(np.int64(~(np.random.rand(self.pop_size, 1) < cumulative)), axis=1)
        return index

    def update_v(self):
        """PSO velocity update, clamped per angle to [self.min_v, self.max_v].

        v <- w * v + c1 * r1 * (pbest - x) + c2 * r2 * (gbest - x)

        ``r1`` and ``r2`` are drawn once per particle and shared by all of its
        angles.
        """
        n_particles = self.proteins_angles_view.shape[0]
        # (n, 1) columns broadcast across every angle of a particle.
        r1 = np.random.rand(n_particles, 1)
        r2 = np.random.rand(n_particles, 1)

        self.v = (self.w * self.v
                  + self.c1 * r1 * (self.pbest_angles_view - self.proteins_angles_view)
                  + self.c2 * r2 * (self.gbest_angles_view - self.proteins_angles_view))
        self.v = np.clip(self.v, self.min_v, self.max_v)

    def update_angle(self):
        """Move every particle by its velocity and clip to the angle bounds."""
        self.proteins_angles_view = self.proteins_angles_view + self.v
        self.proteins_angles_view = self.proteins_angles_view.clip(self.min_angles, self.max_angles)

    def update_pbest(self):
        """Update personal bests using Pareto dominance (smaller objectives are better).

        - current position dominates pbest  -> replace pbest;
        - neither dominates (or all equal)  -> replace with probability 0.5;
        - pbest dominates current position  -> keep pbest.
        """
        diff = self.pbest_obj_view - self.proteins_obj_view
        # +1: pbest dominates, -1: current dominates, 0: mutually non-dominated / equal.
        dominate = np.int64(np.any(diff < 0, axis=1)) - np.int64(np.any(diff > 0, axis=1))

        coin = np.random.rand(len(dominate)) < 0.5
        replace = (dominate == -1) | ((dominate == 0) & coin)

        self.pbest_obj_view[replace] = self.proteins_obj_view[replace]
        # Iterate integer positions: iterating the boolean mask itself would
        # yield True/False and only ever touch pbest[1] / pbest[0].
        for i in np.flatnonzero(replace):
            self.pbest[i] = self.proteins[i].copy()

    def update_archive(self):
        """Merge the swarm into the archive, keep front 1, and enforce the size limit."""
        temp_archive = [x.copy() for x in self.proteins + self.archive]
        pareto_front_dict = self.nonDominatedSort(temp_archive)
        temp_archive = [temp_archive[i] for i in pareto_front_dict[1]]
        temp_archive_obj_view = np.array([x.obj for x in temp_archive])

        if len(temp_archive) > self.archive_thresh:
            self.delete_archive(temp_archive, temp_archive_obj_view)
        else:
            self.archive = temp_archive

        self.archive_obj_view = np.array([x.obj for x in self.archive])

    def delete_archive(self, temp_archive, temp_archive_obj_view):
        """Shrink ``temp_archive`` to at most ``self.archive_thresh`` members.

        Members are removed one at a time:
        - Find each remaining member's nearest neighbour (side-chain cosine
          distance, self excluded).
        - Drop the member that is the nearest neighbour of the most others,
          i.e. the one in the most crowded region. Ties go to the member with
          the smallest nearest-neighbour distance.
        - Repeat until the size limit holds.

        The result is stored in ``self.archive``.

        ``temp_archive_obj_view`` is accepted for interface compatibility but
        is not used.
        """
        cosine_dist = self.cosine_similarity(temp_archive, mode='side')
        # Exclude self-matches outright; adding 1 to the diagonal is not enough
        # because cosine distance ranges over [0, 2].
        np.fill_diagonal(cosine_dist, np.inf)

        keep = list(range(len(temp_archive)))
        limit = max(self.archive_thresh, 0)
        while len(keep) > limit:
            sub_dist = cosine_dist[np.ix_(keep, keep)]
            nearest = np.argmin(sub_dist, axis=1)
            nn_dist = sub_dist[np.arange(len(keep)), nearest]
            crowd = np.bincount(nearest, minlength=len(keep))
            # Primary key: crowd count descending; secondary: NN distance ascending.
            victim = int(np.lexsort((nn_dist, -crowd))[0])
            del keep[victim]

        self.archive = [temp_archive[i].copy() for i in keep]

    def reshape_angles(self):
        """Stack angle vectors of particles, pbests and gbests into 2-D arrays."""
        self.proteins_angles_view = np.array([x.angle_view() for x in self.proteins])
        self.pbest_angles_view = np.array([x.angle_view() for x in self.pbest])
        self.gbest_angles_view = np.array([x.angle_view() for x in self.gbest])

    def reshape_obj(self):
        """Stack objective vectors of particles, pbests, gbests and archive into arrays."""
        self.proteins_obj_view = np.array([x.obj for x in self.proteins])
        self.pbest_obj_view = np.array([x.obj for x in self.pbest])
        self.gbest_obj_view = np.array([x.obj for x in self.gbest])
        self.archive_obj_view = np.array([x.obj for x in self.archive])

    def run(self):
        """Iterate the swarm for generations ``self.current_gen`` .. ``self.max_gen - 1``."""
        # Velocity limit: 5% of each angle's range, symmetric around zero.
        self.max_v = (self.max_angles - self.min_angles) * 0.05
        self.min_v = -self.max_v
        self.angles_num = len(self.max_angles)

        # Independent random initial velocity for every particle, within the limit.
        self.v = (np.random.random((len(self.proteins), self.angles_num))
                  * (self.max_v - self.min_v) + self.min_v)

        self.pbest = [x.copy() for x in self.proteins]

        pareto_front_dict = self.nonDominatedSort(self.proteins)
        self.archive = [self.proteins[i].copy() for i in pareto_front_dict[1]]
        self.archive_obj_view = np.array([x.obj for x in self.archive])

        self.gbest = self.update_gbest()

        for gen in range(self.current_gen, self.max_gen):
            start_time = time.time()

            self.reshape_angles()
            self.update_v()
            self.update_angle()

            # Write the moved angle vectors back into the protein objects.
            for protein, angle in zip(self.proteins, self.proteins_angles_view):
                protein.update_angle_from_view(angle)

            self.energy.calculate_energy(self.proteins)

            self.reshape_obj()
            self.update_pbest()
            self.update_archive()
            self.gbest = self.update_gbest()
            self.logging.write(self.proteins, self.coder)
            print(gen, time.time() - start_time, len(self.archive))
            self.current_gen = self.current_gen + 1


