from search.MOEA import MOEA
import numpy as np
import time

class MOPSO(MOEA):
    """Grid-based multi-objective particle swarm optimiser.

    Each element of ``proteins`` is treated as one particle whose position is
    the vector returned by its ``angle_view()`` and whose objective values are
    stored in its ``obj`` attribute.  Smaller objective values are treated as
    better (see ``update_pbest``).

    NOTE(review): ``self.config``, ``self.energy``, ``self.coder``,
    ``self.logging``, ``self.current_gen``, ``self.pop_size``, ``self.max_gen``
    and ``self.nonDominatedSort`` are used here but never assigned in this
    class; they are presumably provided by ``MOEA`` -- confirm.
    """

    def __init__(self, proteins, config, energy, coder, logging, current_gen=1):
        """Store the swarm and read the PSO hyper-parameters from ``config``.

        Args:
            proteins: Sequence of particle objects.  The first one supplies
                ``obj_num`` and the angle bounds via ``get_angles_field()``.
            config: Mapping holding at least ``w``, ``c1``, ``c2``,
                ``mesh_div`` and ``archive_thresh``.
            energy: Object whose ``calculate_energy(proteins)`` is called once
                per generation in ``run``.
            coder: Passed through to ``logging.write``.
            logging: Object whose ``write(proteins, coder)`` is called once
                per generation in ``run``.
            current_gen: Generation index at which ``run`` starts iterating.
        """
        super(MOPSO, self).__init__(config, energy, coder, logging, current_gen)

        self.proteins = proteins
        self.obj_num = self.proteins[0].obj_num
        self.max_angles, self.min_angles = proteins[0].get_angles_field()
        self.init_param()

    def init_param(self):
        """Copy the PSO hyper-parameters out of ``self.config``."""
        self.w = self.config['w']
        # NOTE(review): c1 and c2 are loaded but not used by update_v.
        self.c1 = self.config['c1']
        self.c2 = self.config['c2']
        self.mesh_div = self.config['mesh_div']
        self.archive_thresh = self.config['archive_thresh']

    def _grid_crowding(self, obj_view):
        """Assign every row of ``obj_view`` to a grid cell and count each cell.

        The objective range of ``obj_view`` is divided into ``self.mesh_div``
        intervals per objective.

        Args:
            obj_view: 2-D array with one row of objective values per solution.

        Returns:
            tuple: ``(site, crowd_degree)`` where ``site[i]`` is the index of
            the occupied cell that row ``i`` falls into and
            ``crowd_degree[k]`` is the number of rows in occupied cell ``k``.
        """
        obj_min = np.min(obj_view, axis=0)
        obj_max = np.max(obj_view, axis=0)
        cell_size = (obj_max - obj_min) / self.mesh_div

        # An objective that is constant over all rows has a zero cell size and
        # yields 0/0 = NaN; that is expected and mapped to cell 0 below.
        with np.errstate(divide='ignore', invalid='ignore'):
            grid_location = np.floor((obj_view - obj_min) / cell_size)
            # The row holding the maximum lands exactly on mesh_div; pull it
            # back into the last interval.
            grid_location[grid_location >= self.mesh_div] = self.mesh_div - 1
        grid_location[np.isnan(grid_location)] = 0

        _, site, crowd_degree = np.unique(
            grid_location, return_inverse=True, return_counts=True, axis=0)
        # The shape of the inverse returned with ``axis=0`` differs between
        # NumPy releases; flatten it so it always indexes solutions directly.
        site = np.reshape(site, (-1,))
        return site, crowd_degree

    def update_gbest(self):
        """Pick one archive member as the global guide for every particle.

        A grid cell is drawn for each particle by roulette wheel, favouring
        sparsely populated cells, and then one archive member is drawn
        uniformly from that cell.

        Returns:
            list: ``self.pop_size`` copies of archive members.
        """
        site, crowd_degree = self._grid_crowding(self.archive_obj_view)
        selected_grid = self.roulette_wheel_selection(crowd_degree)

        leaders = []
        for grid in selected_grid:
            in_grid = np.where(site == grid)[0]
            chosen = in_grid[np.random.randint(0, len(in_grid))]
            leaders.append(self.archive[chosen].copy())
        return leaders

    def roulette_wheel_selection(self, crowd_degree_p):
        """Draw ``self.pop_size`` indices with probability proportional to 1/weight.

        Args:
            crowd_degree_p: Array-like of weights; a smaller weight is more
                likely to be selected.  The input is not modified.

        Returns:
            numpy.ndarray: ``self.pop_size`` integer indices into the
            flattened weights.
        """
        weights = np.reshape(np.array(crowd_degree_p, dtype=float), (-1,))

        # Inverse weighting needs strictly positive weights.  Crowd counts are
        # always >= 1, so this branch only matters for other callers; it lifts
        # the smallest weight to a small positive value.
        lowest = np.min(weights)
        if lowest <= 0:
            weights = weights - lowest + 1e-6

        cumulative = np.cumsum(1 / weights)
        cumulative = cumulative / np.max(cumulative)

        # For each draw, count the thresholds it has reached or passed; that
        # count is the index of the first threshold strictly above the draw.
        draws = np.random.rand(self.pop_size, 1)
        return np.sum(np.int64(~(draws < cumulative)), axis=1)

    def update_v(self):
        """Update ``self.v`` from the inertia, personal-best and global-best terms.

        One random factor per particle is drawn for each attraction term and
        shared across all of that particle's angles.
        """
        r1 = np.tile(np.random.rand(self.pop_size, 1), (1, self.angles_num))
        r2 = np.tile(np.random.rand(self.pop_size, 1), (1, self.angles_num))

        # NOTE(review): self.c1 / self.c2 are not applied here, and the result
        # is not clamped to self.min_v / self.max_v -- confirm this is intended.
        to_pbest = self.pbest_angles_view - self.proteins_angles_view
        to_gbest = self.gbest_angles_view - self.proteins_angles_view
        self.v = self.w * self.v + r1 * to_pbest + r2 * to_gbest

    def update_angle(self):
        """Move every particle by its velocity and clip to the angle bounds."""
        moved = self.proteins_angles_view + self.v
        self.proteins_angles_view = moved.clip(self.min_angles, self.max_angles)

    def update_pbest(self):
        """Refresh each particle's personal best from its current position.

        The current solution replaces the personal best when it dominates it
        (no objective larger, at least one smaller).  When neither dominates
        the other, it replaces it with probability 0.5.
        """
        diff = self.pbest_obj_view - self.proteins_obj_view
        # +1: pbest dominates, -1: current dominates, 0: neither (or equal).
        dominate = np.int64(np.any(diff < 0, axis=1)) - np.int64(np.any(diff > 0, axis=1))

        coin_flip = np.random.rand(len(dominate), ) < 0.5
        replace = (dominate == -1) | ((dominate == 0) & coin_flip)

        self.pbest_obj_view[replace] = self.proteins_obj_view[replace]
        # Iterate over the selected positions, not over the mask values, so the
        # object list stays in step with pbest_obj_view.
        for idx in np.where(replace)[0]:
            self.pbest[idx] = self.proteins[idx].copy()

    def update_archive(self):
        """Rebuild the archive from the first front of swarm plus old archive.

        If the front holds more than ``self.archive_thresh`` solutions it is
        thinned by ``delete_archive``.  ``self.archive_obj_view`` is refreshed
        afterwards.
        """
        candidates = [x.copy() for x in self.proteins + self.archive]
        pareto_front_dict = self.nonDominatedSort(candidates)
        first_front = [candidates[i] for i in pareto_front_dict[1]]

        if len(first_front) > self.archive_thresh:
            front_obj_view = np.array([x.obj for x in first_front])
            self.delete_archive(first_front, front_obj_view)
        else:
            self.archive = first_front

        self.archive_obj_view = np.array([x.obj for x in self.archive])

    def delete_archive(self, temp_archive, temp_archive_obj_view):
        """Thin ``temp_archive`` to ``self.archive_thresh`` and store it.

        Solutions are removed one at a time from a most crowded grid cell,
        with ties between cells and the choice inside a cell made at random.

        Args:
            temp_archive: Candidate archive members.
            temp_archive_obj_view: 2-D array of their objective values, in the
                same order as ``temp_archive``.
        """
        num_archive = len(temp_archive)
        site, crowd_degree = self._grid_crowding(temp_archive_obj_view)

        del_index = np.zeros(num_archive, dtype=bool)
        for _ in range(num_archive - self.archive_thresh):
            max_grid = np.where(crowd_degree == np.max(crowd_degree))[0]
            grid = max_grid[np.random.randint(0, len(max_grid))]

            in_grid = np.where(site == grid)[0]
            victim = in_grid[np.random.randint(0, len(in_grid))]

            del_index[victim] = True
            # Move the removed solution to an impossible cell so that it can
            # never be drawn again.
            site[victim] = -1
            crowd_degree[grid] -= 1

        self.archive = [
            solution.copy()
            for solution, dropped in zip(temp_archive, del_index)
            if not dropped
        ]

    def reshape_angles(self):
        """Stack the angle vectors of the swarm, pbest and gbest into arrays."""
        self.proteins_angles_view = np.array([x.angle_view() for x in self.proteins])
        self.pbest_angles_view = np.array([x.angle_view() for x in self.pbest])
        self.gbest_angles_view = np.array([x.angle_view() for x in self.gbest])

    def reshape_obj(self):
        """Stack the objectives of the swarm, pbest, gbest and archive into arrays."""
        self.proteins_obj_view = np.array([x.obj for x in self.proteins])
        self.pbest_obj_view = np.array([x.obj for x in self.pbest])
        self.gbest_obj_view = np.array([x.obj for x in self.gbest])
        self.archive_obj_view = np.array([x.obj for x in self.archive])

    def run(self):
        """Run the optimisation from ``self.current_gen`` up to ``self.max_gen``.

        Each generation moves the swarm, re-evaluates it through
        ``self.energy.calculate_energy``, updates the personal bests, the
        archive and the global guides, writes the swarm through
        ``self.logging`` and prints the generation index, elapsed seconds and
        archive size.
        """
        # Initial velocities are drawn within +/-5% of the angle range.
        self.max_v = (self.max_angles - self.min_angles) * 0.05
        self.min_v = (self.max_angles - self.min_angles) * 0.05 * -1

        # NOTE(review): this draws a single vector with the shape of max_v, so
        # all particles start with the same velocity (it is broadcast over the
        # swarm in update_v) -- confirm this is intended.
        self.v = np.random.random(self.max_v.shape) * (self.max_v - self.min_v) + self.min_v
        self.angles_num = len(self.max_angles)

        self.pbest = [x.copy() for x in self.proteins]

        # The initial archive is the first non-dominated front of the swarm.
        pareto_front_dict = self.nonDominatedSort(self.proteins)
        self.archive = [self.proteins[i].copy() for i in pareto_front_dict[1]]
        self.archive_obj_view = np.array([x.obj.copy() for x in self.archive])

        self.gbest = self.update_gbest()

        for gen in range(self.current_gen, self.max_gen):
            start_time = time.time()

            self.reshape_angles()
            self.update_v()
            self.update_angle()

            # Write the new positions back into the particle objects.
            for protein, angle in zip(self.proteins, self.proteins_angles_view):
                protein.update_angle_from_view(angle)

            self.energy.calculate_energy(self.proteins)

            self.reshape_obj()
            self.update_pbest()
            self.update_archive()
            self.gbest = self.update_gbest()

            self.logging.write(self.proteins, self.coder)
            print(gen, time.time() - start_time, len(self.archive))
            self.current_gen = self.current_gen + 1

