import torch
import torch.nn as nn
import torch.nn.functional as F
# coding=utf-8
import pdb

import matplotlib
import matplotlib.pyplot as plt
import argparse
import numpy as np
from sklearn.preprocessing import LabelEncoder

from PIL import Image
import cv2
import random
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

# from torchinfo import summary


# Default compute device: the first CUDA GPU when one is visible, else the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


class Soma(nn.Module):
    """Soma layer: a learnable, shifted sigmoid applied elementwise.

    Computes ``sigmoid(k * (x - qs))`` where ``k`` (slope) and ``qs``
    (threshold) are trainable parameters.

    Args:
        k: Initial slope tensor, wrapped in an ``nn.Parameter``.
        qs: Initial threshold tensor, wrapped in an ``nn.Parameter``.
    """

    def __init__(self, k, qs):
        super(Soma, self).__init__()
        self.params = nn.ParameterDict({'k': nn.Parameter(k)})
        self.params.update({'qs': nn.Parameter(qs)})

    def forward(self, x):
        # torch.sigmoid equals 1 / (1 + exp(-z)) but stays finite in the
        # backward pass; the hand-written form overflows exp() for large
        # negative z and produces NaN gradients.
        return torch.sigmoid(self.params['k'] * (x - self.params['qs']))


class Membrane(nn.Module):
    """Membrane layer: aggregates its input by summing over dimension 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Collapse dim 1 (the branch axis in BASE_DNM) into one value.
        return x.sum(dim=1)


class Dendritic(nn.Module):
    """Dendritic layer: combines its input multiplicatively along dimension 2."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Product over dim 2; a sum is the alternative the author considered.
        return x.prod(dim=2)


class Synapse(nn.Module):
    """Synapse layer: a learnable sigmoid gate for every (branch, input) pair.

    For ``x`` of shape ``(batch, dim)`` and weights ``w``/``q`` of shape
    ``(num, dim)``, returns ``sigmoid(k * (w * x - q))`` with shape
    ``(batch, num, dim)``.

    Args:
        w: Initial weight tensor ``(num, dim)``.
        q: Initial offset tensor ``(num, dim)``.
        k: Initial slope tensor (broadcast against the rest).
    """

    def __init__(self, w, q, k):
        super(Synapse, self).__init__()
        self.params = nn.ParameterDict({'w': nn.Parameter(w)})
        self.params.update({'q': nn.Parameter(q)})
        self.params.update({'k': nn.Parameter(k)})

    def forward(self, x):
        # (batch, dim) -> (batch, 1, dim); broadcasting against the (num, dim)
        # parameters replicates x across branches without materialising a copy.
        x = torch.unsqueeze(x, 1)
        z = self.params['k'] * (x * self.params['w'] - self.params['q'])
        # torch.sigmoid avoids the exp() overflow (NaN gradients) of the
        # hand-written 1 / (1 + exp(-z)).
        return torch.sigmoid(z)


class BASE_DNM(nn.Module):
    """Dendritic neuron model: Synapse -> Dendritic -> Membrane -> Soma.

    Args:
        dim: Number of input features.
        M: Number of dendritic branches (rows of the synapse weights).
        kv: Unused; the slope ``k`` is drawn with ``torch.rand(1)`` instead.
        qv: Unused; the soma threshold is drawn with ``torch.rand(1)`` instead.
    """

    def __init__(self, dim, M, kv=5, qv=0.3):
        # Random initialisation, drawn in the order w, q, k, qs.
        synapse_w = torch.rand([M, dim])
        synapse_q = torch.rand([M, dim])
        slope = torch.rand(1)
        soma_threshold = torch.rand(1)

        super(BASE_DNM, self).__init__()
        # NOTE(review): ``slope`` is handed to both Synapse and Soma. Each layer wraps it
        # in its own nn.Parameter, and nn.Parameter does not copy, so the two
        # parameters share storage. Confirm whether this tying is intended.
        stages = [
            Synapse(synapse_w, synapse_q, slope),
            Dendritic(),
            Membrane(),
            Soma(slope, soma_threshold),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        return self.model(x)


def visualize_feature_map(x, batch_index=4, rows=8, cols=8):
    """Show the first ``rows * cols`` feature maps of one sample in a grid.

    Args:
        x: Tensor indexed as ``x[sample, feature]`` -> 2-D map, presumably
            ``(batch_size, num_features, height, width)``.
        batch_index: Sample to display (default 4, the previously fixed value).
        rows: Grid rows (default 8).
        cols: Grid columns (default 8).
    """
    num_features = x.shape[1]
    # squeeze=False keeps ``axes`` a 2-D array for any grid size.
    fig, axes = plt.subplots(rows, cols, figsize=(16, 16), squeeze=False)
    fig.subplots_adjust(wspace=0.01, hspace=0.1)
    # detach() so tensors that require grad can be converted to NumPy.
    sample = x[batch_index].detach().cpu().numpy()
    for feature_index, ax in enumerate(axes.flat):
        if feature_index < num_features:
            ax.imshow(sample[feature_index], cmap='viridis')
        # Hide axes on every cell, including unused ones, so empty cells draw no frame.
        ax.axis('off')
    plt.show()

def pillow_feature(x):
    """Show channel 0 of every sample in ``x`` side by side in one figure.

    NOTE: this module later defines ``pillow_feature(x, cnt)``, which rebinds
    the name, so this one-argument version is not reachable as
    ``module.pillow_feature`` once the module has loaded.

    Args:
        x: Tensor indexed as ``x[sample][0]`` -> 2-D map.
    """
    x = x.detach().cpu().numpy()
    # squeeze=False keeps ``axes`` 2-D even for a single sample; without it
    # plt.subplots(1, 1) returns a bare Axes and indexing it fails.
    fig, axes = plt.subplots(1, x.shape[0], figsize=(16, 4), squeeze=False)
    for ax, sample in zip(axes[0], x):
        ax.imshow(sample[0], cmap='viridis')
        ax.axis('off')
    plt.show()

def pillow_combine(x, num_images=64):
    """Sum the first ``num_images`` 2-D maps of ``x`` and display the result.

    Args:
        x: Tensor whose leading-axis entries ``x[i]`` are equally shaped 2-D maps.
        num_images: Number of leading entries to overlay (default 64). Fewer
            are used if ``x`` is shorter.
    """
    maps = x[:num_images].detach().cpu().numpy()
    # Accumulate in float32, with the output shape taken from the data
    # rather than assumed to be 384 x 384.
    combined = np.zeros(maps.shape[1:], dtype=np.float32)
    for single_map in maps:
        combined += single_map
    plt.imshow(combined, cmap='viridis')
    plt.axis('off')
    plt.show()


def combined_image(x, cnt, batch_stride=6):
    """Save every channel of each sample, plus their sum, as grayscale PNGs.

    Args:
        x: 5-D tensor indexed as ``x[sample, :, :, 0, channel]`` (presumably
            ``(batch, H, W, 1, channels)``; confirm at the call site).
        cnt: Call counter. It offsets sample numbers by ``batch_stride * cnt``
            and picks the per-channel colormap: ``'gray'`` on even calls,
            ``'gray_r'`` on odd ones.
        batch_stride: Samples per call used for numbering (default 6).

    Images are written to ``./pic_feature/``, which is created if missing.
    """
    x = x.detach().cpu().numpy()
    os.makedirs('./pic_feature', exist_ok=True)
    channel_cmap = 'gray' if cnt % 2 == 0 else 'gray_r'

    def save(image, cmap, title, path):
        # Use a fresh figure per image and close it afterwards. Drawing repeatedly
        # into the shared current axes stacks every image and never frees memory.
        fig, ax = plt.subplots()
        ax.imshow(image, cmap=cmap)
        ax.axis('off')
        ax.set_title(title)
        fig.savefig(path)
        plt.close(fig)

    num_groups, num_channels = x.shape[0], x.shape[-1]
    for batch in range(num_groups):
        batch_no = batch + batch_stride * cnt + 1
        combined = np.zeros(x.shape[1:3], dtype=np.float32)
        for channel in range(num_channels):
            channel_map = x[batch, :, :, 0, channel]
            combined += channel_map
            save(channel_map, channel_cmap,
                 f'Batch {batch_no}_channel_{channel + 1}',
                 f'./pic_feature/batch_{batch_no}_channel_{channel + 1}_combined_image.png')
        save(combined, 'gray', f'Batch {batch_no}',
             f'./pic_feature/batch_{batch_no}_combined_image.png')

def combined_image_dnm(x):
    """Save every channel of each sample, plus their sum, as grayscale PNGs.

    Args:
        x: 5-D tensor indexed as ``x[sample, :, :, 0, channel]`` (presumably
            ``(batch, H, W, 1, channels)``; confirm at the call site).

    Images are written to ``./pic_feature_dnm/``, which is created if missing.
    """
    x = x.detach().cpu().numpy()
    os.makedirs('./pic_feature_dnm', exist_ok=True)

    def save(image, title, path):
        # Use a fresh figure per image and close it afterwards, so images do not
        # pile up on the shared current axes.
        fig, ax = plt.subplots()
        ax.imshow(image, cmap='gray')
        ax.axis('off')
        ax.set_title(title)
        fig.savefig(path)
        plt.close(fig)

    num_groups, num_channels = x.shape[0], x.shape[-1]
    for batch in range(num_groups):
        # Sum in float32, with the map shape taken from the data (was fixed at 384 x 384).
        combined = np.zeros(x.shape[1:3], dtype=np.float32)
        for channel in range(num_channels):
            channel_map = x[batch, :, :, 0, channel]
            combined += channel_map
            save(channel_map, f'Batch {batch + 1}_channel_{channel + 1}',
                 f'./pic_feature_dnm/batch_{batch + 1}_channel_{channel + 1}_combined_image.png')
        save(combined, f'Batch {batch + 1}',
             f'./pic_feature_dnm/batch_{batch + 1}_combined_image.png')

def pillow_feature(x, cnt, batch_stride=6):
    """Save channel 0 of each sample in ``x`` as a PNG.

    Args:
        x: Tensor indexed as ``x[sample][0]`` -> 2-D map.
        cnt: Call counter; sample numbers are offset by ``batch_stride * cnt``.
        batch_stride: Samples per call used for numbering (default 6).

    Images are written to ``./pic_feature/``, which is created if missing.
    """
    x = x.detach().cpu().numpy()
    os.makedirs('./pic_feature', exist_ok=True)
    for i in range(x.shape[0]):
        batch_no = i + batch_stride * cnt + 1
        fig, ax = plt.subplots()
        # Both parities of ``cnt`` use 'viridis'; the inverted 'gray_r' option
        # for odd calls is disabled, so there is nothing to branch on.
        ax.imshow(x[i][0], cmap='viridis')
        ax.axis('off')
        ax.set_title(f'Batch {batch_no}_dnm.png')
        fig.savefig(f'./pic_feature/batch_{batch_no}_dnm.png')
        # Close each figure so repeated calls do not accumulate images/memory.
        plt.close(fig)


def pillow_channel_0(x, batch_index=3):
    """Save every channel of one sample as a borderless PNG.

    Args:
        x: Tensor indexed as ``x[sample, channel, :, :]``.
        batch_index: Sample to export (default 3, the previously fixed value).

    Files go to ``./pic_feature/pic_channel/channel_{j}_dnm_0.png``; the
    directory is created if missing.
    """
    out_dir = './pic_feature/pic_channel'
    os.makedirs(out_dir, exist_ok=True)
    sample = x[batch_index].detach().cpu().numpy()
    for j, channel_map in enumerate(sample):
        fig, ax = plt.subplots()
        ax.imshow(channel_map, cmap='viridis')
        ax.axis('off')
        fig.savefig(f'{out_dir}/channel_{j}_dnm_0.png', bbox_inches='tight', pad_inches=0)
        # One figure per image, closed after saving, so images do not stack up.
        plt.close(fig)

def pillow_DNM_feature(x, cnt, batch_stride=6):
    """Save channel 0 of each sample of a 6-D DNM intermediate as a PNG.

    Args:
        x: 6-D tensor. It is permuted with ``(0, 5, 4, 3, 1, 2)`` and squeezed
            before plotting. Per the original shape notes, an input of
            ``[6, 384, 384, 1, 10, 64]`` becomes ``[6, 64, 10, 384, 384]``.
        cnt: Call counter; sample numbers are offset by ``batch_stride * cnt``.
        batch_stride: Samples per call used for numbering (default 6).

    Images are written to ``./pic_feature/``, which is created if missing.
    """
    x = torch.permute(x, (0, 5, 4, 3, 1, 2))
    x = torch.squeeze(x)
    x = x.detach().cpu().numpy()
    os.makedirs('./pic_feature', exist_ok=True)

    # A single loop, so each sample is drawn and saved exactly once under its
    # own index. A nested loop that reuses ``i`` would save only the last sample.
    for i in range(x.shape[0]):
        batch_no = i + batch_stride * cnt + 1
        fig, ax = plt.subplots()
        # NOTE(review): imshow needs a 2-D (or RGB/RGBA) array. With the shape
        # above, x[i][0] is still 3-D unless the size-10 axis is 1. Confirm
        # which slice is meant.
        ax.imshow(x[i][0], cmap='viridis')
        ax.axis('off')
        ax.set_title(f'Batch {batch_no}_dnm.png')
        fig.savefig(f'./pic_feature/batch_{batch_no}_dnm.png')
        plt.close(fig)


def pillow_channel_1(x, batch_index=3, slice_indices=(0, 9, 1)):  # torch.Size([6, 64, 10, 384, 384])
    """Export selected axis-2 slices of one sample, one borderless PNG per channel.

    For every channel ``j`` and every ``m`` in ``slice_indices``, saves
    ``x[batch_index, j, m]`` to
    ``./pic_feature/pic_channel_dnm_{m}thm/channel_{j}_dnm_0.png``.

    Args:
        x: 5-D tensor indexed as ``x[sample, channel, m, :, :]``.
        batch_index: Sample to export (default 3, the previously fixed value).
        slice_indices: Axis-2 positions to export (default 0, 9, 1 as before).

    Output directories are created if missing.
    """
    targets = [(m, f'./pic_feature/pic_channel_dnm_{m}thm') for m in slice_indices]
    for _, out_dir in targets:
        os.makedirs(out_dir, exist_ok=True)

    sample = x[batch_index].detach().cpu().numpy()
    for j in range(sample.shape[0]):
        for m, out_dir in targets:
            fig, ax = plt.subplots()
            ax.imshow(sample[j, m, :, :], cmap='viridis')
            ax.axis('off')
            fig.savefig(f'{out_dir}/channel_{j}_dnm_0.png', bbox_inches='tight', pad_inches=0)
            # Close each figure so images do not stack on the shared axes.
            plt.close(fig)
    
    
def pillow_channel_2(x, batch_index=3):
    """Save every channel of one sample of a layer output as a borderless PNG.

    Args:
        x: Tensor indexed as ``x[sample, channel, :, :]``.
        batch_index: Sample to export (default 3, the previously fixed value).

    Files go to ``./pic_feature/pic_channel_output/channel_{j}_dnm_output.png``;
    the directory is created if missing.
    """
    out_dir = './pic_feature/pic_channel_output'
    os.makedirs(out_dir, exist_ok=True)
    sample = x[batch_index].detach().cpu().numpy()
    for j, channel_map in enumerate(sample):
        fig, ax = plt.subplots()
        ax.imshow(channel_map, cmap='viridis')
        ax.axis('off')
        fig.savefig(f'{out_dir}/channel_{j}_dnm_output.png', bbox_inches='tight', pad_inches=0)
        # One figure per image, closed after saving, so images do not stack up.
        plt.close(fig)


# cnt = 1
class DNM_Conv(nn.Module):
    """Dendritic-neuron layer applied independently at every spatial position.

    Maps ``(batch, input_size, H, W)`` to ``(batch, out_size, H, W)``. At each
    pixel the ``input_size`` channel vector passes through these steps:

    * LayerNorm over channels (``norm1``);
    * synapses: ``relu(x * DNM_W - q)`` for every (output, branch, input);
    * LayerNorm over the input axis (``norm2``);
    * dendrites: sum over inputs; membrane: sum over the ``M`` branches;
    * soma: ``activation(v - qs)`` when ``activation`` is not ``None``.

    Args:
        input_size: Number of input channels.
        out_size: Number of output channels.
        M: Number of dendritic branches per output channel.
        activation: Soma nonlinearity, or ``None`` to return the membrane sum.
    """

    # Class-level counter kept for external callers; forward() does not update it.
    cnt = 0

    def __init__(self, input_size, out_size, M, activation=F.relu):
        super(DNM_Conv, self).__init__()
        DNM_W = torch.rand([out_size, M, input_size])
        DNM_q = torch.rand([out_size, M, input_size])
        # NOTE(review): dendritic_W and membrane_W are registered but not used
        # by forward(); they are kept so existing state_dicts still load.
        dendritic_W = torch.rand([input_size])
        membrane_W = torch.rand([M])
        self.params = nn.ParameterDict({'DNM_W': nn.Parameter(DNM_W)})
        self.params.update({'q': nn.Parameter(DNM_q)})
        self.params.update({'dendritic_W': nn.Parameter(dendritic_W)})
        self.params.update({'membrane_W': nn.Parameter(membrane_W)})
        # Soma threshold as a non-persistent buffer. It moves with .to()/.cuda()
        # together with the module instead of being pinned to one device at
        # construction, and state_dict() is unchanged. It is drawn after the four
        # tensors above, in the same RNG order as before.
        self.register_buffer('qs', torch.rand(1), persistent=False)
        self.activation = activation

        self.norm1 = nn.LayerNorm(input_size)
        self.norm2 = nn.LayerNorm(input_size)

    def forward(self, x):
        # (batch, C, H, W) -> (batch, H, W, 1, 1, C): channels last, plus
        # singleton axes for the output-channel and branch dimensions.
        x = torch.permute(x, (0, 2, 3, 1))[:, :, :, None, None, :]
        x = self.norm1(x)

        # Synapse: broadcasting against the (out_size, M, C) parameters yields
        # (batch, H, W, out_size, M, C) without first repeating x in memory.
        x = F.relu(x * self.params['DNM_W'] - self.params['q'])
        x = self.norm2(x)

        # Dendrite: sum over inputs -> (batch, H, W, out_size, M).
        x = torch.sum(x, 5)
        # Membrane: sum over branches -> (batch, H, W, out_size).
        x = torch.sum(x, 4)
        x = torch.permute(x, (0, 3, 1, 2))  # (batch, out_size, H, W)

        # Soma
        if self.activation is not None:
            x = self.activation(x - self.qs)
        return x


class ConvDNMBase(nn.Module):
    """Wrapper holding a single ``DNM_Conv`` layer with no soma activation.

    Args:
        in_channal: Input channels (default 64).
        out_channal: Output channels (default 1).
        m: Dendritic branches per output channel (default 10).
    """

    def __init__(self, in_channal=64, out_channal=1, m=10):
        super(ConvDNMBase, self).__init__()
        # Record the configured input width; it was hard-coded to 64 regardless
        # of in_channal.
        self.inchannel = in_channal
        self.DNM_Conv1 = DNM_Conv(in_channal, out_channal, m, activation=None)

    def forward(self, x):
        return self.DNM_Conv1(x)


class conv_dnm(nn.Module):
    """Top-level model: runs a single ``ConvDNMBase`` block on its input.

    Args:
        input_size: Input channels.
        out_size: Output channels.
        M: Dendritic branches per output channel.
    """

    def __init__(self, input_size, out_size, M):
        super().__init__()
        self.Dnm_conv2d = ConvDNMBase(input_size, out_size, M)

    def forward(self, x):
        return self.Dnm_conv2d(x)


# model = conv_dnm()
# # # print(model)
# # #
# summary(model, (3, 384, 384))
