import numpy as np
import time
from threading import Thread, Semaphore


class MyThread(Thread):
    """Thread that runs ``func(*args)`` and keeps its return value.

    Start the thread, ``join()`` it, then call :meth:`get_result`.
    """

    def __init__(self, func, args=()):
        super().__init__()
        self.func = func
        self.args = args
        # None until run() finishes; stays None if func raises (the exception
        # is then reported by the threading machinery, not stored here).
        self.result = None

    def run(self):
        self.result = self.func(*self.args)

    def get_result(self):
        """Return the value produced by ``func``.

        Returns None if the thread has not finished running ``func`` yet, or
        if ``func`` raised instead of returning.
        """
        return self.result


def init_population(popsize, dim):
    """Draw a (popsize, dim) population uniformly from the interval [-1, 1)."""
    unit_samples = np.random.random([popsize, dim])
    return 2 * unit_samples - 1


def boundary(population):
    """Return a copy of *population* with every value clamped into [-1, 1]."""
    lower, upper = -1, 1
    return np.clip(population, lower, upper)


def mutation(population, F):
    """DE/rand/1 mutation.

    For each row ``i`` three mutually distinct indices ``r1, r2, r3`` are
    drawn, and the mutant ``population[r1] + F * (population[r2] -
    population[r3])`` is formed. When the population has at least four
    members, the indices are also required to differ from ``i``, as
    differential evolution prescribes. With exactly three members that is
    impossible, so ``i`` may then be drawn (the previous behaviour).

    Parameters
    ----------
    population : ndarray, shape (m, n)
        Current population, one individual per row; requires ``m >= 3``.
    F : float
        Differential weight (scale factor).

    Returns
    -------
    ndarray, shape (m, n)
        New array of mutant vectors; *population* is not modified.

    Raises
    ------
    ValueError
        If ``0 < m < 3`` (raised by ``np.random.choice``).
    """
    m = population.shape[0]
    exclude_self = m > 3
    idx = np.empty((m, 3), dtype=np.int64)
    for i in range(m):
        if exclude_self:
            # Sample from the m-1 indices other than i: draw from [0, m-1)
            # and shift every pick >= i up by one to skip i.
            picks = np.random.choice(m - 1, 3, replace=False)
            picks[picks >= i] += 1
        else:
            picks = np.random.choice(m, 3, replace=False)
        idx[i] = picks
    r1, r2, r3 = idx.T
    return population[r1, :] + F * (population[r2, :] - population[r3, :])


def crossover(population, popu_MutationTmp, CR):
    """Binomial crossover between target and mutant vectors.

    Each gene of a trial vector is taken from *popu_MutationTmp* with
    probability *CR* and from *population* otherwise. As in canonical
    DE/rand/1/bin, one randomly chosen gene per row (``j_rand``) is always
    taken from the mutant. Every trial therefore inherits at least one
    mutant gene, even for small *CR*.

    Parameters
    ----------
    population : ndarray, shape (m, n)
        Current (target) population.
    popu_MutationTmp : ndarray, shape (m, n)
        Mutant vectors produced by ``mutation``.
    CR : float
        Crossover rate, expected in [0, 1].

    Returns
    -------
    ndarray, shape (m, n)
        New array of trial vectors; neither input is modified.
    """
    m, n = popu_MutationTmp.shape
    take_mutant = np.random.random((m, n)) < CR
    if n > 0:
        # j_rand: force one mutant gene per row so no trial is a pure copy
        # of its target (which would waste a fitness evaluation).
        j_rand = np.random.randint(0, n, size=m)
        take_mutant[np.arange(m), j_rand] = True
    return np.where(take_mutant, popu_MutationTmp, population)


def selection(population, popu_CorssOverTmp, fitness, fitnessCrossOverVal):
    """Greedy selection: keep a trial vector only if it is strictly better.

    Rows of *population* whose trial fitness is strictly lower are replaced
    by the matching rows of *popu_CorssOverTmp*, and *fitness* is updated to
    match. Both arrays are modified in place and returned.
    """
    improved = fitness > fitnessCrossOverVal
    population[improved, :] = popu_CorssOverTmp[improved, :]
    fitness[improved] = fitnessCrossOverVal[improved]
    return population, fitness


def popu_reshape(individual, dnm):
    """Split a flat parameter vector into the four dnm weight matrices.

    The first half of *individual* supplies the ``w`` values and the second
    half the ``q`` values. Each half is reshaped to
    ``(dnm.M + dnm.M_SNU, dnm.data_dim)``. The first ``dnm.M`` rows become
    ``w`` / ``q`` and the remaining ``dnm.M_SNU`` rows become
    ``w_snu`` / ``q_snu``.

    Returns
    -------
    tuple
        ``(w, q, w_snu, q_snu)``, produced by slicing and reshaping
        *individual*.
    """
    half = len(individual) // 2
    n_rows = dnm.M + dnm.M_SNU
    split_at = dnm.M
    w_block = individual[:half].reshape(n_rows, dnm.data_dim)
    q_block = individual[half:].reshape(n_rows, dnm.data_dim)
    return (
        w_block[:split_at, :],
        q_block[:split_at, :],
        w_block[split_at:, :],
        q_block[split_at:, :],
    )


def eva_fitness(population, dnm, i, max_thread, fitness=None):
    """Evaluate one individual's training MSE while holding *max_thread*.

    Previously the computed MSE was thrown away. It is now returned, and is
    also written to ``fitness[i]`` when an output array is supplied.

    Parameters
    ----------
    population : array
        Flat parameter vector of a single individual (despite the name); it
        is unpacked with ``popu_reshape``.
    dnm : model object
        Receives the unpacked weights as ``w``, ``q``, ``w_SNU`` and
        ``q_SNU``; then ``dnm.train()`` is called.
    i : int
        Index in *fitness* where the result is stored.
    max_thread : context manager, e.g. ``threading.Semaphore``
        Held for the whole evaluation to bound concurrency.
    fitness : array-like, optional
        When given, ``fitness[i]`` is set to the MSE.

    Returns
    -------
    The mean squared error between ``dnm.train()`` and ``dnm.train_label``.
    """
    # NOTE(review): the weights are written onto the shared *dnm* object. If
    # several threads use the same dnm and the semaphore admits more than one
    # at a time, they overwrite each other's weights before train() runs
    # (presumably train() reads these attributes -- confirm).
    with max_thread:
        w, q, w_snu, q_snu = popu_reshape(population, dnm)
        dnm.w = w
        dnm.q = q
        dnm.w_SNU = w_snu
        dnm.q_SNU = q_snu
        train_fit = dnm.train()
        mse = np.mean((train_fit - dnm.train_label) ** 2)
        if fitness is not None:
            fitness[i] = mse
        return mse


def eva_fitness1(population, dnm, i, max_thread):
    """Return the training MSE of one individual, computed under *max_thread*.

    *population* is a single flat parameter vector. It is unpacked with
    ``popu_reshape`` into ``dnm.w``, ``dnm.q``, ``dnm.w_SNU`` and
    ``dnm.q_SNU``, after which ``dnm.train()`` is compared with
    ``dnm.train_label``. *i* is accepted for call compatibility but is not
    used.
    """
    with max_thread:
        unpacked = popu_reshape(population, dnm)
        dnm.w, dnm.q, dnm.w_SNU, dnm.q_SNU = unpacked
        prediction = dnm.train()
        squared_error = (prediction - dnm.train_label) ** 2
        return np.mean(squared_error)


def _load_individual(individual, dnm):
    """Unpack a flat parameter vector with ``popu_reshape`` and store it on *dnm*.

    Sets ``dnm.w``, ``dnm.q``, ``dnm.w_SNU`` and ``dnm.q_SNU``.
    """
    w, q, w_snu, q_snu = popu_reshape(individual, dnm)
    dnm.w = w
    dnm.q = q
    dnm.w_SNU = w_snu
    dnm.q_SNU = q_snu


def _evaluate(individual, dnm):
    """Load *individual* into *dnm* and return the training MSE.

    The MSE is taken between ``dnm.train()`` and ``dnm.train_label``.
    """
    _load_individual(individual, dnm)
    train_fit = dnm.train()
    return np.mean((train_fit - dnm.train_label) ** 2)


def _evaluate_all(individuals, dnm):
    """Return a float array holding the training MSE of each row of *individuals*."""
    return np.array([_evaluate(ind, dnm) for ind in individuals], dtype=float)


def run(dnm, popsize, dim, max_iter, F, CR, print_flag):
    """Train *dnm* with differential evolution and return the loss curve.

    The population is initialised uniformly in [-1, 1). Each iteration
    applies ``mutation``, then ``crossover``, then ``boundary`` (clipping to
    [-1, 1]), then greedy ``selection``. Fitness is the mean squared error
    between ``dnm.train()`` and ``dnm.train_label``.

    Parameters
    ----------
    dnm : model object
        Must expose ``M``, ``M_SNU``, ``data_dim``, ``train()`` and
        ``train_label``. Its ``w``, ``q``, ``w_SNU`` and ``q_SNU`` attributes
        are overwritten during training.
    popsize : int
        Number of individuals (``mutation`` needs at least 3).
    dim : int
        Length of each individual. It must equal
        ``2 * (dnm.M + dnm.M_SNU) * dnm.data_dim`` for ``popu_reshape`` to
        succeed.
    max_iter : int
        Number of DE iterations.
    F : float
        Differential weight passed to ``mutation``.
    CR : float
        Crossover rate passed to ``crossover``.
    print_flag : bool
        When true, print the best fitness after every iteration.

    Returns
    -------
    result : ndarray, shape (max_iter,)
        Best (lowest) population fitness after each iteration.
    dnm : model object
        The same object, with the weights of the best individual loaded.
    """
    print('====== Training the model BY DE SNU======')
    population = init_population(popsize, dim)
    fitness = _evaluate_all(population, dnm)

    result = np.zeros(max_iter)
    for it in range(max_iter):
        mutants = mutation(population, F)
        trials = boundary(crossover(population, mutants, CR))
        trial_fitness = _evaluate_all(trials, dnm)
        population, fitness = selection(population, trials, fitness, trial_fitness)
        best_fitness = np.min(fitness)
        result[it] = best_fitness
        if print_flag:
            print('Iteration: %d The best fitness: %f' % (it, best_fitness))

    # Every evaluation overwrites dnm's weights, so they currently hold the
    # last *trial* vector evaluated. Load the best individual instead.
    if fitness.size:
        _load_individual(population[np.argmin(fitness)], dnm)
    return result, dnm


