import numpy as np
import time
from multiprocessing import Process, Queue
import multiprocessing as mp


def init_population(popsize, dim):
    """Create the initial DE population.

    Every gene is drawn uniformly from [-1, 1) using the global NumPy RNG
    (one ``np.random.random`` call of shape ``[popsize, dim]``).

    Args:
        popsize: number of individuals (rows).
        dim: number of genes per individual (columns).

    Returns:
        A ``(popsize, dim)`` float array.
    """
    unit_samples = np.random.random([popsize, dim])
    # Affine map [0, 1) -> [-1, 1), matching the box enforced by `boundary`.
    return 2 * unit_samples - 1


def boundary(population):
    """Project candidate solutions back into the search box [-1, 1].

    Genes below -1 become -1 and genes above 1 become 1. The input is not
    modified; a new clipped array is returned.
    """
    lower_bound, upper_bound = -1, 1
    return np.clip(population, lower_bound, upper_bound)


def mutation(population, F):
    """Build DE/rand/1 mutant vectors for every individual.

    For each target row ``i``, three mutually distinct donor rows
    ``r1, r2, r3`` are drawn with ``np.random.choice``. The donors are also
    distinct from ``i`` whenever the population has more than 3 rows. The
    donors are combined as ``population[r1] + F * (population[r2] - population[r3])``.

    Args:
        population: array with one individual per row.
        F: differential weight applied to the difference vector.

    Returns:
        A new array with the same shape as ``population`` holding the mutants.

    Raises:
        ValueError: if ``population`` has fewer than 3 rows, because three
            distinct donors cannot be drawn.
    """
    m = population.shape[0]
    if m < 3:
        raise ValueError(
            'mutation needs at least 3 individuals to draw distinct donors, got %d' % m)

    sequence = np.arange(m, dtype=np.int64)
    donors = np.empty((m, 3), dtype=np.int64)
    for i in range(m):
        # Canonical DE requires the donors to differ from the target itself.
        # With exactly 3 individuals that is impossible, so fall back to
        # sampling from the whole population.
        candidates = sequence[sequence != i] if m > 3 else sequence
        donors[i] = np.random.choice(candidates, 3, replace=False)

    r1, r2, r3 = donors.T
    return population[r1, :] + F * (population[r2, :] - population[r3, :])


def crossover(population, popu_MutationTmp, CR):
    """Binomial crossover between target vectors and mutant vectors.

    One uniform draw in [0, 1) is made per gene (a single ``np.random.random``
    call of the mutant's shape). A gene comes from the mutant when its draw
    is below ``CR`` and from the target ``population`` otherwise.

    NOTE(review): canonical DE also forces one random gene per row to come
    from the mutant. This implementation does not, so a trial row can equal
    its target when no draw falls below ``CR``.

    Returns:
        A new float64 array with the mutant's shape.
    """
    rows, cols = popu_MutationTmp.shape
    take_mutant = np.random.random([rows, cols]) < CR
    trial = np.where(take_mutant, popu_MutationTmp, population)
    # Keep the float64 result dtype of the original zero-filled buffer.
    return trial.astype(np.float64, copy=False)


def selection(population, popu_CorssOverTmp, fitness, fitnessCrossOverVal):
    """Greedy one-to-one DE selection, applied in place.

    A trial row replaces its target only when its fitness is strictly lower.
    ``population`` and ``fitness`` are modified in place and also returned.
    """
    better = fitnessCrossOverVal < fitness
    # Boolean-mask indexing already yields copies, so no explicit .copy().
    population[better, :] = popu_CorssOverTmp[better, :]
    fitness[better] = fitnessCrossOverVal[better]
    return population, fitness
"""
 temp_data = np.matlib.repmat(self.train_data[h, :], self.M, 1)
            Y = 1 / (1 + np.exp(-self.k * (self.w * temp_data - self.q)))
            Z = np.prod(Y, 1)
            V = np.sum(Z)
            O = 1 / (1 + np.exp(-self.k * (V - self.qs)))
            Q[h] = O
"""
def popu_reshape(individual, dnm):
    """Unpack a flat DE individual into the model's parameter matrices.

    The first half of ``individual`` holds the ``w`` entries and the second
    half the ``q`` entries. Each half is reshaped to
    ``(dnm.M + dnm.M_SNU, dnm.data_dim)``. The first ``dnm.M`` rows become
    ``w`` / ``q`` and the remaining rows become ``w_snu`` / ``q_snu``.

    Returns:
        ``(w, q, w_snu, q_snu)`` as views into ``individual``.
    """
    shape = [dnm.M + dnm.M_SNU, dnm.data_dim]
    half = len(individual) // 2
    w_all = individual[:half].reshape(shape)
    q_all = individual[half:].reshape(shape)
    split = dnm.M
    return w_all[:split, :], q_all[:split, :], w_all[split:, :], q_all[split:, :]


def worker(input, output):
    """Process ``(func, args)`` tasks from ``input`` until the ``'STOP'`` sentinel.

    For each task, ``calculate(func, args)`` is put on ``output``. The
    parameter name ``input`` shadows the builtin but is kept for interface
    compatibility.
    """
    while True:
        task = input.get()
        if task == 'STOP':
            break
        func, args = task
        output.put(calculate(func, args))


def calculate(func, args):
    """Call ``func(*args)`` and return its two-element result as a tuple.

    Raises ``ValueError`` (from unpacking) if ``func`` does not return
    exactly two items.
    """
    outcome = func(*args)
    value, position = outcome
    return value, position


def eva_fitenss(popu_CorssOverTmp, index, dnm):
    """Score one individual by the mean squared training error of ``dnm``.

    Row ``index`` of ``popu_CorssOverTmp`` is unpacked with ``popu_reshape``
    and written into ``dnm.w``, ``dnm.q``, ``dnm.w_SNU`` and ``dnm.q_SNU``
    (this mutates ``dnm``). Then ``dnm.train()`` is compared with
    ``dnm.train_label``.

    Returns:
        ``(mse, index)``. The index lets callers match results that arrive
        out of order back to their rows.
    """
    dnm.w, dnm.q, dnm.w_SNU, dnm.q_SNU = popu_reshape(popu_CorssOverTmp[index, :], dnm)
    predictions = dnm.train()
    squared_error = (predictions - dnm.train_label) ** 2
    return np.mean(squared_error), index


def run_multiprocess(dnm, popsize, dim, max_iter, F, CR, print_flag, num_workers=None):
    """Train ``dnm`` with differential evolution, scoring trials in parallel.

    The initial population is evaluated serially in this process. Each
    generation's trial vectors are then scored by one worker pool that is
    created once and reused for all ``max_iter`` generations.

    Args:
        dnm: model object. It must provide what ``popu_reshape`` and
            ``eva_fitenss`` read (``M``, ``M_SNU``, ``data_dim``, ``train()``,
            ``train_label``) and be picklable so it can be sent to workers.
        popsize: number of individuals.
        dim: length of each flat individual.
        max_iter: number of generations.
        F: mutation scale factor.
        CR: crossover rate.
        print_flag: when truthy, print the best fitness after each generation.
        num_workers: number of worker processes. ``None`` lets
            ``multiprocessing.Pool`` use ``os.cpu_count()``.

    Returns:
        ``(result, dnm)``. ``result[g]`` is the best fitness after generation
        ``g``, and ``dnm`` holds the best individual's parameters.
    """
    population = init_population(popsize, dim)

    # Initial population: serial evaluation in this process.
    fitness = np.zeros(popsize)
    for i in range(popsize):
        fitness[i], _ = eva_fitenss(population, i, dnm)

    result = np.zeros(max_iter)
    # One pool for the whole run, instead of starting (and never joining) a
    # fresh batch of processes every generation. The `with` block tears the
    # workers down, and a failing evaluation is re-raised here rather than
    # leaving the parent blocked forever on a result queue.
    with mp.Pool(processes=num_workers) as pool:
        for generation in range(max_iter):
            popu_MutationTmp = mutation(population, F)
            popu_CorssOverTmp = crossover(population, popu_MutationTmp, CR)
            popu_CorssOverTmp = boundary(popu_CorssOverTmp)

            # Send each task only its own row (as a 1-row matrix at index 0)
            # rather than the whole trial population. starmap returns results
            # in task order, so position i belongs to individual i.
            tasks = [(popu_CorssOverTmp[i:i + 1], 0, dnm) for i in range(popsize)]
            fitness_crossover = np.zeros(popsize)
            for i, (value, _) in enumerate(pool.starmap(eva_fitenss, tasks)):
                fitness_crossover[i] = value

            population, fitness = selection(population, popu_CorssOverTmp, fitness, fitness_crossover)
            result[generation] = np.min(fitness)
            if print_flag:
                print('Iteration: %d The best fitness:%f' % (generation, np.min(fitness)))

    best_index = np.argmin(fitness)
    dnm.w, dnm.q, dnm.w_SNU, dnm.q_SNU = popu_reshape(population[best_index, :], dnm)
    return result, dnm


def run(dnm, popsize, dim, max_iter, F, CR, flag):
    """Train ``dnm`` with serial differential evolution.

    Each generation applies ``mutation`` -> ``crossover`` -> ``boundary`` and
    scores every trial with ``eva_fitenss``, which mutates ``dnm``. Greedy
    ``selection`` is then applied. Progress is printed every generation.
    ``flag`` is accepted for interface compatibility but is not read here.

    Returns:
        ``(result, dnm)``. ``result[g]`` is the best fitness after generation
        ``g``, and ``dnm`` holds the best individual's parameters.
    """
    print('======  Training the model BY DE   ======')
    population = init_population(popsize, dim)

    fitness = np.zeros(popsize)
    for idx in range(popsize):
        fitness[idx], _ = eva_fitenss(population, idx, dnm)
    print(fitness)

    result = np.zeros(max_iter)
    for generation in range(max_iter):
        trial = boundary(crossover(population, mutation(population, F), CR))

        trial_fitness = np.zeros(popsize)
        for idx in range(popsize):
            trial_fitness[idx], _ = eva_fitenss(trial, idx, dnm)

        population, fitness = selection(population, trial, fitness, trial_fitness)
        best = np.min(fitness)
        result[generation] = best
        print('Iteration: %d The best fitness: %f' % (generation, best))

    winner = population[np.argmin(fitness), :]
    dnm.w, dnm.q, dnm.w_SNU, dnm.q_SNU = popu_reshape(winner, dnm)

    return result, dnm
