import os
import time
import numpy as np
import torch
from torch import nn
import torch.optim as optim
from sklearn.model_selection import train_test_split



class Soma(nn.Module):
    """Soma (cell body) of the dendritic neuron model.

    Applies a sigmoid with steepness ``k`` and threshold ``qs`` to its input:
    ``y = 1 / (1 + exp(-k * (x - qs)))``.

    Args:
        k: Steepness of the sigmoid (scalar tensor or number). Stored as a
            plain attribute, so it is not a trainable parameter.
        qs: Threshold of the sigmoid (scalar tensor or number). Also not
            trainable.
    """

    def __init__(self, k, qs):
        super(Soma, self).__init__()
        self.k = k
        self.qs = qs

    def forward(self, x):
        # sigmoid(k*(x-qs)) is mathematically identical to the hand-written
        # 1/(1+exp(-k*(x-qs))). Its backward pass stays finite when exp()
        # would overflow, whereas the manual form then produces NaN gradients
        # (-0 * inf).
        return torch.sigmoid(self.k * (x - self.qs))


class Membrane(nn.Module):
    """Membrane layer of the dendritic neuron model.

    Has no parameters. Reduces its input by summing along dimension 1. In
    ``DNM`` it receives the output of ``Dendritic``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        summed = x.sum(dim=1)
        return summed


class Dendritic(nn.Module):
    """Dendritic layer of the dendritic neuron model.

    Has no parameters. Reduces its input by multiplying along dimension 2. In
    ``DNM`` it receives the output of ``Synapse``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        product = x.prod(dim=2)
        return product


class Synapse(nn.Module):
    """Synaptic layer of the dendritic neuron model.

    Holds trainable weights ``w`` and thresholds ``q``, both 2-D with shape
    ``(M, D)``. For an input ``x`` of shape ``(B, D)`` it returns
    ``sigmoid(k * (x[:, None, :] * w - q))`` with shape ``(B, M, D)``: every
    input feature is fed through every one of the ``M`` rows of ``w``/``q``.

    Args:
        w: Initial weight tensor, shape ``(M, D)``; wrapped as a Parameter.
        q: Initial threshold tensor, shape ``(M, D)``; wrapped as a Parameter.
        k: Sigmoid steepness (scalar tensor or number). Stored as a plain
            attribute, so it is not trained.
    """

    def __init__(self, w, q, k):
        super(Synapse, self).__init__()
        # Keys 'w' and 'q' determine the state_dict names; keep them stable.
        self.params = nn.ParameterDict({'w': nn.Parameter(w)})
        self.params.update({'q': nn.Parameter(q)})
        self.k = k

    def forward(self, x):
        # (B, 1, D) broadcasts against (M, D) -> (B, M, D). This is the same
        # result the old x.repeat((1, M, 1)) produced, without materialising an
        # extra (B, M, D) copy of x.
        z = self.k * (torch.unsqueeze(x, 1) * self.params['w'] - self.params['q'])
        # sigmoid(z) == 1 / (1 + exp(-z)), but its gradient stays finite when
        # exp(-z) overflows; the manual form yields NaN there.
        return torch.sigmoid(z)


class DNM(nn.Module):
    """Dendritic neuron model: Synapse -> Dendritic -> Membrane -> Soma.

    Args:
        w: Initial synaptic weights passed to ``Synapse``.
        q: Initial synaptic thresholds passed to ``Synapse``.
        k: Sigmoid steepness shared by ``Synapse`` and ``Soma``.
        qs: Threshold of the ``Soma`` sigmoid.

    NOTE(review): the driver script feeds a 2-D ``(num_samples, num_features)``
    tensor; other input ranks are not exercised here.
    """

    def __init__(self, w, q, k, qs):
        super().__init__()
        # Stage order fixes the Sequential child names ('0'..'3') and hence the
        # state_dict keys, so keep it unchanged.
        stages = [
            Synapse(w, q, k),
            Dendritic(),
            Membrane(),
            Soma(k, qs),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)


def load_data(abs_path, name, test_size=0.3, random_state=42):
    """Load a comma-separated dataset and split it into train and test sets.

    The file is read from ``<abs_path>/Dataset/<name>``. In every row the last
    column is taken as the label (cast to int64); all preceding columns are
    the features.

    Args:
        abs_path: Directory that contains the ``Dataset`` folder.
        name: File name of the dataset inside ``Dataset``.
        test_size: Fraction of rows held out for testing, forwarded to
            ``train_test_split``. Defaults to the previously hard-coded 0.3.
        random_state: Seed for the shuffle, forwarded to ``train_test_split``
            so the split is reproducible. Defaults to the previously
            hard-coded 42.

    Returns:
        ``(train_data, train_label, test_data, test_label)`` as NumPy arrays.
    """
    print('======         Read %s Dataset        ======' % name)
    data = np.loadtxt(os.path.join(abs_path, 'Dataset', name), delimiter=',')
    train_set, test_set = train_test_split(
        data, test_size=test_size, random_state=random_state)

    train_data = train_set[:, :-1]
    train_label = np.int64(train_set[:, -1])
    test_data = test_set[:, :-1]
    test_label = np.int64(test_set[:, -1])
    return train_data, train_label, test_data, test_label


# Experiment driver. Trains a fresh DNM `runs` times (new random w/q each
# run), prints the per-epoch MSE loss and the per-run train/test accuracy,
# then prints the mean accuracies and the total wall-clock time.
st = time.time()
M = 20  # number of rows of w and q (dendritic branches)
abs_path = os.getcwd()
name = 'cleveland'
epochs = 200
runs = 52  # number of independent restarts averaged in the final report
train_data, train_label, test_data, test_label = load_data(abs_path, name)
num, dim = train_data.shape

# Fall back to the CPU when CUDA is unavailable (both branches previously
# selected 'cuda:0', which crashed on CPU-only machines).
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

train_data = torch.from_numpy(train_data)
train_label = torch.from_numpy(train_label)
test_data = torch.from_numpy(test_data)
test_label = torch.from_numpy(test_label)
data, target = train_data.to(device), train_label.to(device)
# The test set is the same for every run, so move it to the device only once.
test_data, test_target = test_data.to(device), test_label.to(device)
train_acc = []
test_acc = []

for run in range(runs):

    w = torch.rand([M, dim]).to(device)
    q = torch.rand([M, dim]).to(device)
    k = torch.tensor(5).to(device)
    qs = torch.tensor(0.3).to(device)

    learning_rate = 0.001

    net = DNM(w, q, k, qs).to(device)
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    for epoch in range(epochs):
        logits = net(data)
        # Mean squared error between the model output and the integer labels.
        loss = torch.mean((logits - target) ** 2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, loss.item()))

    # Score train and test data with the same, fully trained parameters. The
    # old code used logits from before the last optimizer step for training
    # accuracy. No autograd graph is needed for evaluation.
    with torch.no_grad():
        classifi = torch.where(net(data) > 0.5, 1, 0)
        train_acc.append((classifi == target).float().mean().item())
        test_fit = torch.where(net(test_data) > 0.5, 1, 0)
        test_acc.append((test_fit == test_target).float().mean().item())
    print(run, ' train ACC:', train_acc[run], ' test ACC:', test_acc[run])
print('mean train:', np.mean(train_acc), ' test:', np.mean(test_acc))
print(time.time() - st)
