import numpy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchmetrics import JaccardIndex
# from torchmetrics.functional import dice_score
# from torchmetrics import Dice
# from tensorboardX import SummaryWriter
import dataset
import logging
import model_net
from model_net import *
from dataset import *
from PIL import Image
import pdb
from medpy import metric
from torchvision.datasets import ImageFolder
import os
# from utils.dice_score import multiclass_dice_coeff, dice_coeff
import torchvision.transforms as TF
# from torchmetrics.functional import precision_recall
from torchmetrics import Specificity, JaccardIndex
import argparse
from PIL import Image, ImageDraw

parse = argparse.ArgumentParser()
# parse.add_argument("action", type=str, help="train or test")
parse.add_argument("--log_name", type=str, default="./log/test.log")
parse.add_argument("--batch_size", type=int, default=1)
parse.add_argument("--EPOCH", type=int, default=100)
parse.add_argument("--LR", type=float, default=0.0001)
# 0 selects cuda:0; any other value selects cuda:1 (see DEVICE below).
parse.add_argument("--DEVICE", type=int, default=0)
# M / DNM1 / DNM2 are forwarded to model_net.DKNet; their meaning is defined there.
parse.add_argument("--M", type=int, default=10)
parse.add_argument("--DNM1", type=int, default=1)
parse.add_argument("--DNM2", type=int, default=1)
parse.add_argument("--ckpt", type=str, help="the path of model weight file")
args = parse.parse_args()

# Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0" if args.DEVICE == 0 else "cuda:1")
else:
    DEVICE = torch.device("cpu")

# ImageNet channel statistics; in this file they are only referenced by the disabled Normalize step.
mean_nums = [0.485, 0.456, 0.406]
std_nums = [0.229, 0.224, 0.225]

# Training pipeline: bilinear resize to 384x384, then convert to a tensor.
# Previously tried and currently disabled: RandomResizedCrop(224), RandomRotation(10 / 15 / 90, expand=True),
# RandomHorizontalFlip / RandomVerticalFlip, Normalize(mean_nums, std_nums), Normalize((.5,.5,.5), (.5,.5,.5)).
transform = transforms.Compose([
    # InterpolationMode is the documented argument type; passing the PIL integer constant is deprecated.
    transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
])

# Evaluation pipeline: currently the same resize + ToTensor, with no augmentation.
# Previously tried and disabled: Grayscale(1), RandomResizedCrop(224), CenterCrop(224), Normalize(...).
transform_test = transforms.Compose([
    transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
])

# Candidate dataset roots; `filepath` selects the one used for this run.
filepath_busi = './data/Dataset_BUSI/Dataset_BUSI_with_GT/'
filepath_bus = './data/BUS/BUS/'
filepath_busi_m = './data/Dataset_BUSI_malignant/Dataset_BUSI_with_GT/'
filepath_cloth = './data/archive/'
filepath_Polyp = './data/Kvasir-SEG/'
filepath = filepath_bus

# Image and annotation folders for each split, relative to the selected root.
imagefilepath, imagefilepath_label = filepath + 'data/train/', filepath + 'data/trainannot/'
valfilepath, valfilepath_label = filepath + 'data/val/', filepath + 'data/valannot/'

# Both splits receive the same (transform, transform_test) pair.
train_dataset = dataset.Busi(imagefilepath, imagefilepath_label, transform, transform_test)
test_dataset = dataset.Busi(valfilepath, valfilepath_label, transform, transform_test)

# Shuffle only the training split so evaluation order (and output file names) is reproducible.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=False,
)


# train_size = len(train_loader.dataset)
# test_size = len(test_loader.dataset) #incorrect
# train_num_batches = len(train_loader)
# test_num_batches = len(test_loader)


class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, computed from logits.

    The inputs are passed through a sigmoid. Both tensors are then flattened across
    every dimension, batch included. The result is ``1 - dice``, where
    ``dice = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.
    """

    def __init__(self, weight=None, size_average=True):
        # `weight` and `size_average` are accepted for signature compatibility but are unused.
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return the scalar Dice loss between logits ``inputs`` and ``targets``.

        Args:
            inputs: raw model outputs (logits); any shape.
            targets: ground truth with the same number of elements as ``inputs``.
            smooth: additive smoothing term that avoids division by zero.
        """
        # Drop this sigmoid if the model already ends with a sigmoid or equivalent activation.
        inputs = torch.sigmoid(inputs)

        # reshape (not view) so non-contiguous tensors, e.g. permuted or transposed, are accepted.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice


# Loss functions and model construction (nn.CrossEntropyLoss was used previously).
criterion_dice = DiceLoss()
criterion_mse = nn.MSELoss()
criterion = nn.BCELoss()

# nn.Module.to returns the module itself, so construction and device placement can be chained.
model = model_net.DKNet(DEVICE, args.M, args.DNM1, args.DNM2).to(DEVICE)

# Pretrained checkpoint and optimizer selection.
# Other checkpoints used previously: "./log/bus_0.5loss_norm.log.pth", "./model_pth/ori.pth".
# --ckpt overrides the default checkpoint path when given (it was previously parsed but ignored).
pretrained_model = args.ckpt if args.ckpt else "./model_pth/bus_0.5loss.log.pth"

# optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.09)
optimizer = optim.Adam(model.parameters(), lr=args.LR)

# Load pretrained weights.
pretrained = True
if pretrained:
    pretrain_model = model_net.DKNet(DEVICE, args.M, args.DNM1, args.DNM2)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine (or a different GPU).
    pre_dic = torch.load(pretrained_model, map_location=DEVICE)
    pretrain_model.load_state_dict(pre_dic["model_static_dict"])
    model_dict = model.state_dict()
    pretrained_dict = pretrain_model.state_dict()
    # Keep only the entries whose names also exist in the current model.
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.eval()


def calculate_metric_percase(pred, gt):
    """Compute binary overlap metrics between a predicted mask and its ground truth.

    Both inputs are converted to NumPy and binarised with ``astype(bool)``, so any
    non-zero value counts as foreground. They are then passed to ``medpy.metric.binary``.

    Args:
        pred: predicted mask, as a torch tensor (any device) or an array-like.
        gt: ground-truth mask, as a torch tensor or array-like; expected to match ``pred``'s shape.

    Returns:
        tuple: (dice, jaccard, precision, recall, specificity) as computed by medpy.
    """
    # Accept tensors on any device as well as plain arrays; non-tensors previously left
    # `predict` / `target` unbound and raised UnboundLocalError.
    predict = pred.detach().cpu().numpy() if torch.is_tensor(pred) else numpy.asarray(pred)
    target = gt.detach().cpu().numpy() if torch.is_tensor(gt) else numpy.asarray(gt)

    # `numpy.bool` was removed in NumPy 1.24; the builtin `bool` is the equivalent dtype.
    pred_bin = numpy.atleast_1d(predict.astype(bool))
    gt_bin = numpy.atleast_1d(target.astype(bool))

    dice = metric.binary.dc(pred_bin, gt_bin)
    jc = metric.binary.jc(pred_bin, gt_bin)
    pre = metric.binary.precision(pred_bin, gt_bin)
    rec = metric.binary.recall(pred_bin, gt_bin)
    spe = metric.binary.specificity(pred_bin, gt_bin)
    return dice, jc, pre, rec, spe


def _mask_to_contour_image(mask, color):
    """Rasterise one mask as a 3-channel float32 image and draw its external contours on it.

    Args:
        mask: a single sample's mask tensor (any device); singleton dimensions are squeezed away.
        color: BGR tuple used for the 1-pixel contour lines.

    Returns:
        H x W x 3 float32 ndarray. The mask values are kept as floats, so a {0, 1} mask stays {0., 1.}.
    """
    mask_u8 = mask.cpu().squeeze().numpy().astype('uint8')
    contours, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # ToTensor divides uint8 input by 255; multiplying by 255 restores the original mask values.
    canvas = transforms.ToTensor()(mask_u8) * 255
    canvas = canvas.permute(1, 2, 0).numpy()
    canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)
    cv2.drawContours(canvas, contours, -1, color, 1)
    return canvas


def test(epoch):
    """Write ground-truth vs. prediction contour overlays for every validation sample.

    The prediction is ``output[2] > 0.9``; the original comment labels the output heads
    as 0: DCN, 1: DUnet, 2: Segnet. For each sample:
      * the ground-truth contour is drawn in red (BGR (0, 0, 255)) and the predicted one in green;
      * the two images are summed and values above 255 are clamped;
      * 255 is added wherever ``cv2.bitwise_and`` of the two images is non-zero, marking the overlap.
    The result is written to ``./pic_RRCnet/2/{batch_idx}_{i}_Conv.png``.

    NOTE(review): the images are float32 with mask interiors around 1.0, so after cv2.imwrite's
    8-bit conversion the interiors come out nearly black. Confirm that this is intended.

    Args:
        epoch: unused; kept for call compatibility with the other evaluation functions.
    """
    model.eval()
    # cv2.imwrite fails silently (returns False) if the directory does not exist.
    os.makedirs('./pic_RRCnet/2/', exist_ok=True)
    for batch_idx, (img, mask_true) in tqdm(enumerate(test_loader)):
        img = img.to(DEVICE)
        with torch.no_grad():
            output = model(img)
            mask_pred = torch.where(output[2] > 0.9, 1., 0.)

        # Use the real batch size: the last batch can be smaller than args.batch_size.
        for i in range(img.size(0)):
            img_true = _mask_to_contour_image(mask_true[i], (0, 0, 255))
            img_pred = _mask_to_contour_image(mask_pred[i], (0, 255, 0))

            # Overlay ground truth and prediction, clamping the sum at 255.
            mask = img_pred + img_true
            mask[mask > 255] = 255

            # Mark the overlap of the two images in white.
            intersection = cv2.bitwise_and(img_true, img_pred)
            intersection[intersection > 0] = 255

            mask_pic = mask + intersection
            cv2.imwrite('./pic_RRCnet/2/{}_{}_Conv.png'.format(batch_idx, i), mask_pic)

               
       
       
def predicted_ori2(epoch):
    """Save every validation image with its ground-truth outline drawn on it.

    For each sample, the ground-truth mask is binarised at 0.5 and its external contours are found.
    The contours are drawn in red (BGR (0, 0, 255), thickness 2) on the input image, which is
    converted RGB -> BGR and scaled by 255. The result is written to ``./pic/{batch_idx}_{i}_mask.png``.

    Args:
        epoch: unused; kept for call compatibility with the other evaluation functions.
    """
    model.eval()
    # cv2.imwrite fails silently (returns False) if the directory does not exist.
    os.makedirs('./pic/', exist_ok=True)
    for batch_idx, (img, mask_true) in tqdm(enumerate(test_loader)):
        img = img.to(DEVICE)
        with torch.no_grad():
            output = model(img)
            # NOTE(review): mask_pred is computed but never drawn; only the ground truth is rendered.
            mask_pred = torch.where(output[1] > 0.5, 1., 0.)
            mask_true = torch.where(mask_true > 0.5, 1., 0.)

        # Use the real batch size: the last batch can be smaller than args.batch_size.
        for i in range(img.size(0)):
            gt_u8 = mask_true[i].cpu().squeeze().numpy().astype('uint8')
            contours_true, _ = cv2.findContours(gt_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            # Tensor (C, H, W) -> ndarray (H, W, C), then RGB -> BGR for OpenCV.
            img_np = img[i].cpu().numpy().transpose((1, 2, 0))
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

            # Draw the ground-truth contours on the original image and save it.
            img_np = (img_np * 255).astype(np.uint8)
            cv2.drawContours(img_np, contours_true, -1, (0, 0, 255), thickness=2)
            cv2.imwrite('./pic/{}_{}_mask.png'.format(batch_idx, i), img_np)
                

def test_report(epoch):
    """Evaluate the model on ``test_loader``, print the averaged metrics and return them.

    The prediction is ``output[1]`` thresholded at 0.5. The ground truth is truncated to
    integer labels. For each batch, the MSE between the two and the metrics from
    ``calculate_metric_percase`` are accumulated. Every sum is divided by the number of
    batches, so with ``batch_size > 1`` these are per-batch means, not per-image means.

    Args:
        epoch: only used in the printed header.

    Returns:
        tuple: (precision, recall, dice, jaccard, specificity, mse_loss), each averaged over batches.
    """
    model.eval()
    pre_score = recall_score = dice_score = jaccard_score = spe_score = 0
    sum_total_loss = 0
    for img, mask_true in tqdm(test_loader):
        img = img.to(DEVICE)
        with torch.no_grad():
            output = model(img)
        mask_pred = torch.where(output[1] > 0.5, 1., 0.).float().cpu()
        # Same .long() truncation as before, but done on the CPU to avoid a host->device->host round trip.
        mask_true = mask_true.long().float().cpu()
        sum_total_loss += criterion_mse(mask_true, mask_pred).item()

        dice, jc, pre, rec, spe = calculate_metric_percase(mask_pred, mask_true)
        pre_score += pre
        recall_score += rec
        dice_score += dice
        jaccard_score += jc
        spe_score += spe

    n_batches = len(test_loader)
    averages = (
        pre_score / n_batches,
        recall_score / n_batches,
        dice_score / n_batches,
        jaccard_score / n_batches,
        spe_score / n_batches,
        sum_total_loss / n_batches,
    )
    pre_avg, recall_avg, dice_avg, jaccard_avg, spe_avg, loss_avg = averages

    print("Test Epoch: {}".format(epoch))
    print("pre_score: \t{:.4f}".format(pre_avg))
    print("recall_score: \t{:.4f}".format(recall_avg))
    print("dice_score: \t{:.4f}".format(dice_avg))
    print("jaccard_score: \t{:.4f}".format(jaccard_avg))
    print("spe_score: \t{:.4f}".format(spe_avg))
    print("test_loss: \t{:.4f}".format(loss_avg))
    return averages

         
def adjust_learning_rate(optimizer, epoch, step=80, gamma=0.1):
    """Step-decay the learning rate of every parameter group in place.

    When ``epoch`` is a multiple of ``step``, each group's ``lr`` is multiplied by
    ``gamma``; otherwise the optimizer is left unchanged.

    Args:
        optimizer: a torch optimizer (anything exposing ``param_groups``).
        epoch: current epoch number. Note that ``epoch == 0`` also triggers a decay.
        step: decay period in epochs (default 80, the previously hard-coded value).
        gamma: multiplicative decay factor (default 0.1).
    """
    if epoch % step == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * gamma


if __name__ == '__main__':
    # One evaluation pass (epoch 1) that writes contour overlays.
    # Alternatives in this script: test_report(epoch) prints and returns averaged metrics;
    # predicted_ori2(epoch) draws ground-truth contours on the input images.
    test(1)
