import numpy as np
from random import *
from multiprocessing import Process, Queue, cpu_count
import multiprocessing as mp

def popu_reshape(individual, dnm):
    """Unpack a flat individual into the four weight matrices used by ``dnm``.

    The first half of ``individual`` holds the w-weights and the second half
    the q-weights. Each half is reshaped to
    ``(dnm.M + dnm.M_SNU, dnm.data_dim)``. Its first ``dnm.M`` rows are the
    regular weights (w / q) and the remaining rows are the SNU weights
    (w_snu / q_snu).

    Returns:
        Tuple ``(w, q, w_snu, q_snu)``.
    """
    block_shape = [(dnm.M + dnm.M_SNU), dnm.data_dim]
    split_at = len(individual) // 2
    w_block = individual[:split_at].reshape(block_shape)
    q_block = individual[split_at:].reshape(block_shape)

    n_main = dnm.M
    return (
        w_block[:n_main, :],
        q_block[:n_main, :],
        w_block[n_main:, :],
        q_block[n_main:, :],
    )


def mutation(population, lamb, mu, alpha, sigma, pMutation):
    """Return a migrated-and-mutated copy of ``population`` (BBO operators).

    For every gene ``(i, k)``:

    * with probability ``lamb[i]`` the gene migrates. A source row ``j`` is
      drawn by RouletteWheelSelection() with weights ``mu`` (row ``i``
      excluded), and the gene moves ``alpha`` of the way towards
      ``population[j, k]``;
    * then, with probability ``pMutation``, zero-mean Gaussian noise scaled
      by ``sigma[k]`` is added.

    Args:
        population: 2-D array, one individual per row (m rows, n genes).
        lamb: per-row migration probabilities, length m.
        mu: per-row emigration weights, length m. Not modified.
        alpha: migration step size.
        sigma: per-column mutation scale, length n.
        pMutation: per-gene mutation probability.

    Returns:
        A new array with the same shape as ``population``. Neither
        ``population`` nor ``mu`` is modified.
    """
    mutated = population.copy()
    m, n = mutated.shape
    for i in range(m):
        for k in range(n):
            if random() < lamb[i]:
                # Copy before zeroing. ``EP = mu`` aliased the caller's array,
                # so ``EP[i] = 0`` permanently zeroed mu[i] and skewed every
                # later source selection (mu is reused across generations).
                EP = np.array(mu, dtype=float)
                EP[i] = 0
                EP = EP / np.sum(EP)
                # RouletteWheelSelection returns a 1-element index array; take
                # the scalar so the assignment below stores a plain float.
                j = RouletteWheelSelection(EP)[0]
                mutated[i, k] = population[i, k] + alpha * (population[j, k] - population[i, k])
            if random() <= pMutation:
                # Zero-mean noise. ``sigma[k] * random()`` was uniform on
                # [0, sigma[k]) and therefore only ever pushed genes upwards.
                mutated[i, k] = mutated[i, k] + sigma[k] * gauss(0.0, 1.0)
    return mutated


def boundary(population):
    """Clamp every gene of ``population`` into the closed interval [-1, 1].

    Returns a new array; the input is left untouched.
    """
    raised_to_floor = np.maximum(population, -1)
    return np.minimum(raised_to_floor, 1)


def init_population(popsize, dim):
    """Draw ``popsize`` individuals of length ``dim`` uniformly from [-1, 1).

    Uses NumPy's global random state (``np.random.random``).
    """
    unit_draws = np.random.random([popsize, dim])
    return 2 * unit_draws - 1


def RouletteWheelSelection(EP):
    """Pick an index at random with probability proportional to ``EP``.

    Args:
        EP: 1-D array of non-negative selection weights. They do not need to
            sum to exactly 1.

    Returns:
        A 1-element integer array holding the selected index (the same shape
        that ``np.argwhere(...)[0]`` produces).

    Raises:
        ValueError: if the weights do not have a positive, finite total.
        IndexError: if ``EP`` is empty.
    """
    C = np.cumsum(EP)
    total = C[-1]
    if not (np.isfinite(total) and total > 0):
        raise ValueError('RouletteWheelSelection needs weights with a positive, finite sum')
    # Scale the draw by the real total instead of assuming it is exactly 1.
    # The cumsum of normalised weights can end slightly below 1, and an
    # unscaled draw above C[-1] left argwhere() empty (-> IndexError).
    r = random() * total
    j = np.argwhere(r <= C)[0]
    return j


def selection(population, mutation_population, fitness, fitness_mutation, nKeep, nNew):
    """Build the next generation from elites plus the best mutants.

    Keeps the first ``nKeep`` rows of ``population``. Callers such as run()
    sort it best-first beforehand, so these are the elites. It then appends
    the ``nNew`` rows of ``mutation_population`` with the lowest
    ``fitness_mutation``.

    Args:
        population: current population, one individual per row.
        mutation_population: candidate rows produced by mutation().
        fitness: fitness of ``population`` rows (lower is better).
        fitness_mutation: fitness of ``mutation_population`` rows.
        nKeep: number of leading ``population`` rows to carry over.
        nNew: number of best mutants to take.

    Returns:
        ``(population, fitness)`` for the next generation. Elites come first,
        then mutants in ascending fitness order.
    """
    order = np.argsort(fitness_mutation)
    best_mutants = mutation_population[order, :][0:nNew, :]
    best_mutant_fitness = fitness_mutation[order][0:nNew]

    # np.vstack rather than np.row_stack: row_stack is a deprecated alias of
    # vstack as of NumPy 2.0.
    new_population = np.vstack((population[0:nKeep, :], best_mutants))
    new_fitness = np.hstack((fitness[0:nKeep], best_mutant_fitness))

    return new_population, new_fitness


def worker(input, output):
    """Process ``(func, args)`` tasks from ``input`` until ``'STOP'`` arrives.

    Each task is evaluated with calculate(), and its result is put on
    ``output``. ``input`` and ``output`` only need ``get()`` / ``put()``.
    The parameter name ``input`` shadows the builtin but is kept so
    existing callers are unaffected.
    """
    while True:
        task = input.get()
        if task == 'STOP':
            break
        func, args = task
        output.put(calculate(func, args))


def calculate(func, args):
    """Call ``func(*args)`` and return its two-element result as a tuple.

    ``func`` must return exactly two values (here: a fitness and the index
    it belongs to). Any other length raises ``ValueError`` on unpacking.
    """
    outcome = func(*args)
    value, position = outcome
    return (value, position)


def eva_fitenss(popu_CorssOverTmp, index, dnm):
    """Load row ``index`` of ``popu_CorssOverTmp`` into ``dnm`` and score it.

    Side effect: overwrites ``dnm.w``, ``dnm.q``, ``dnm.w_SNU`` and
    ``dnm.q_SNU`` (in that order) with the individual's weights.

    Returns:
        ``(mse, index)``, where ``mse`` is the mean squared difference
        between ``dnm.train()`` and ``dnm.train_label``. The index is echoed
        back so out-of-order (multiprocess) results can be placed correctly.
    """
    (dnm.w, dnm.q, dnm.w_SNU, dnm.q_SNU) = popu_reshape(popu_CorssOverTmp[index, :], dnm)
    predictions = dnm.train()
    squared_error = (predictions - dnm.train_label) ** 2
    return np.mean(squared_error), index


def run(dnm, popsize, dim, KeepRate, alpha, pMutation, max_iter):
    """Train ``dnm`` with Biogeography-Based Optimisation (serial evaluation).

    Each individual is a flat vector of length ``dim``. popu_reshape() splits
    it into the w / q / w_SNU / q_SNU matrices of ``dnm``, and its fitness is
    the training MSE from eva_fitenss() (lower is better).

    Args:
        dnm: model object exposing ``M``, ``M_SNU``, ``data_dim``,
            ``train_label``, ``train()`` and writable ``w``, ``q``,
            ``w_SNU``, ``q_SNU``.
        popsize: number of individuals.
        dim: individual length. Presumably ``2 * (M + M_SNU) * data_dim``
            so popu_reshape() can split it; TODO confirm with callers.
        KeepRate: fraction of the sorted population carried over unchanged.
        alpha: migration step size passed to mutation().
        pMutation: per-gene mutation probability passed to mutation().
        max_iter: number of generations.

    Returns:
        ``(result, dnm)``. ``result[t]`` is the best fitness after generation
        ``t``, and ``dnm`` is loaded with the best individual found.
    """
    print('======  Training the model BY BBO  ======')
    nKeep = round(KeepRate * popsize)
    nNew = popsize - nKeep
    # Rank-based rates: after sorting, row 0 has the highest emigration weight
    # (mu) and the lowest migration probability (lamb).
    mu = np.linspace(1, 0, popsize)
    lamb = 1 - mu

    population = init_population(popsize, dim)
    # Per-gene mutation scale. Sized from ``dim`` rather than
    # ``population[1, :]``, which raised IndexError when popsize == 1.
    sigma = np.full(dim, 0.02 * 2)
    fitness = np.zeros(popsize)
    for i in range(popsize):
        fitness[i], _ = eva_fitenss(population, i, dnm)

    result = np.zeros(max_iter)

    for generation in range(max_iter):
        order = np.argsort(fitness)
        population = population[order, :]
        fitness = fitness[order]
        popu_MutationTmp = boundary(mutation(population, lamb, mu, alpha, sigma, pMutation))

        # calculation fitness
        fitness_crossover = np.zeros(popsize)
        for i in range(popsize):
            fitness_crossover[i], _ = eva_fitenss(popu_MutationTmp, i, dnm)

        population, fitness = selection(population, popu_MutationTmp, fitness, fitness_crossover, nKeep, nNew)
        result[generation] = np.min(fitness)
        print('Iteration: %d The best fitness:%f' % (generation, np.min(fitness)))

    best_index = np.argmin(fitness)
    dnm.w, dnm.q, dnm.w_SNU, dnm.q_SNU = popu_reshape(population[best_index, :], dnm)

    return result, dnm

def run_multiprocess(dnm, popsize, dim, KeepRate, alpha, pMutation, max_iter, num_workers=None):
    """BBO training like run(), but mutant fitness is evaluated in worker processes.

    One pool of ``worker`` processes is started up front, reused for every
    generation, and shut down (sentinel + join) when training ends. The
    original spawned 16 fresh processes per generation and never joined them.

    Args:
        dnm: model object; see run(). It is pickled into every task, so it
            must be picklable.
        popsize, dim, KeepRate, alpha, pMutation, max_iter: as in run().
        num_workers: number of worker processes. Defaults to ``cpu_count()``
            and is capped at ``popsize`` (minimum 1).

    Returns:
        ``(result, dnm)`` as in run().
    """
    print('======  Training the model BY BBO  ======')
    nKeep = round(KeepRate * popsize)
    nNew = popsize - nKeep
    mu = np.linspace(1, 0, popsize)
    lamb = 1 - mu

    population = init_population(popsize, dim)
    # Sized from ``dim``; ``population[1, :]`` failed when popsize == 1.
    sigma = np.full(dim, 0.02 * 2)
    # Initial fitness is computed serially in this process, as before.
    fitness = np.zeros(popsize)
    for i in range(popsize):
        fitness[i], _ = eva_fitenss(population, i, dnm)

    result = np.zeros(max_iter)

    if num_workers is None:
        num_workers = cpu_count()
    num_workers = max(1, min(num_workers, popsize))

    task_queue = Queue()
    done_queue = Queue()
    workers = [Process(target=worker, args=(task_queue, done_queue)) for _ in range(num_workers)]
    for proc in workers:
        proc.start()

    finished = False
    try:
        for generation in range(max_iter):
            order = np.argsort(fitness)
            population = population[order, :]
            fitness = fitness[order]
            popu_MutationTmp = boundary(mutation(population, lamb, mu, alpha, sigma, pMutation))

            # calculation fitness: results come back in arbitrary order, so
            # each one carries the row index it belongs to.
            # NOTE(review): every task pickles the whole mutant population
            # plus dnm; sending only the needed row would cut IPC cost.
            for i in range(popsize):
                task_queue.put((eva_fitenss, (popu_MutationTmp, i, dnm)))
            fitness_crossover = np.zeros(popsize)
            for _ in range(popsize):
                temp_fitness, idx = done_queue.get()
                fitness_crossover[idx] = temp_fitness

            population, fitness = selection(population, popu_MutationTmp, fitness, fitness_crossover, nKeep, nNew)
            result[generation] = np.min(fitness)
            print('Iteration: %d The best fitness:%f' % (generation, np.min(fitness)))
        finished = True
    finally:
        if finished:
            # One sentinel per worker ends its task loop cleanly.
            for _ in workers:
                task_queue.put('STOP')
        else:
            # On an error or interrupt, workers may still hold unread results,
            # and joining them could block forever, so stop them forcibly.
            for proc in workers:
                proc.terminate()
        for proc in workers:
            proc.join()

    best_index = np.argmin(fitness)
    dnm.w, dnm.q, dnm.w_SNU, dnm.q_SNU = popu_reshape(population[best_index, :], dnm)

    return result, dnm




