import numpy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchmetrics import JaccardIndex
from torchmetrics.functional import dice_score
# from torchmetrics import Dice
# from tensorboardX import SummaryWriter
import dataset
import logging
import model_net
from model_net import *
from dataset import *
from PIL import Image
import pdb
from medpy import metric
from torchvision.datasets import ImageFolder
import os
# from utils.dice_score import multiclass_dice_coeff, dice_coeff
import torchvision.transforms as TF

from model.segnet import SegNet
from model.unet_model import R2U_Net, AttU_Net, R2AttU_Net, U_Net
# from model.unext import UNext
from model.transunet_model import TransUNet
from model.sknet import SKNet26
from model.nestedUNet import NestedUNet

from torchmetrics.functional import precision_recall
from torchmetrics import Specificity, JaccardIndex
import argparse



# Collect the per-model panel images (*.png) from ./pic. Files named
# "row_<n>.png" are this script's own output (it saves its tiled rows into the
# same directory), so they are skipped; otherwise a second run would pass them
# to sort_key, whose int() conversion of the "row" field raises ValueError.
image_files = [os.path.join(".//pic", filename)
               for filename in os.listdir(".//pic")
               if filename.endswith(".png") and not filename.startswith("row_")]
            
# Sort by file name (assumes names of the form "<number>_<number>_<model name>.png").
# pdb.set_trace()
# image_files = sorted(image_files, key=lambda x: (tuple(map(int, os.path.basename(x).split('_')[:2])), os.path.basename(x).split('_')[-1]))

#model_order = {"DRCNET": 0, "UNet": 1, "SegNet": 2, "R2U_Net": 3, "AttU_Net": 4, "R2AttU_Net": 5, "NestedUNet": 6, "AAUnet": 7}

#image_files = sorted(image_files, key=lambda x: (tuple(map(int, os.path.basename(x).split('_')[:2])), model_order[os.path.basename(x).split('_')[-1]]))


# model_order = {"DRCNET": 0, "UNet": 1, "SegNet": 2, "R2UNet": 3, "AttUNet": 4, "R2AttUNet": 5, "NestedUNet": 6, "AAUnet": 7}

# image_files = sorted(image_files, key=lambda x: (tuple(map(int, os.path.basename(x).split('_')[:-2])), model_order[os.path.basename(x).split('_')[-1].split('.')[0]]))

# Rank of each display name (0 = leftmost), in the same order as model_order below.
# NOTE(review): model_order1 is not referenced in the visible code, and several
# spellings here ("MBSNet", "UNet++", "ICNet", "AAU-net", "AttUNet") differ from
# the labels actually drawn under the panels ("MBSnet", "U-Net++", "IC-Net",
# "AAUnet", "AttU-Net") -- confirm which spelling is intended.
model_order1 = {name: rank for rank, name in enumerate([
    "Mask", "Ours", "U-Net", "AttUNet", "MBSNet", "UNet++", "AAU-net",
    "SegNet", "TransUNet", "RRCnet", "ICNet",
])}
# Secondary sort rank used by sort_key, keyed by the model name taken from the
# last "_"-separated field of the file name (extension removed).
model_order = {name: rank for rank, name in enumerate([
    "mask", "DRCNET", "UNet", "AttUNet", "mbsnet", "NestedUNet", "AAUnet",
    "SegNet", "transunet", "RRCnet", "ICDnet",
])}


def sort_key(file_name):
    parts = os.path.basename(file_name).split('_')
    num_parts = tuple(map(int, parts[:-1]))
    model_name = parts[-1].split('.')[0]
    model_order_value = model_order[model_name]
    return (num_parts, model_order_value)

image_files = sorted(image_files, key=sort_key)

# Number of panels per output row. It matches the 11 entries of model_order
# (mask + 10 models), so a row holds one sample's panels -- assuming every
# sample has exactly one image per model.
PANELS_PER_ROW = 11

# Raw model name (last "_"-separated field of the file name, extension removed)
# -> label drawn under its panel. Names not listed are drawn unchanged.
DISPLAY_NAMES = {
    "mask": "Mask",
    "DRCNET": "Ours",
    "UNet": "U-Net",
    "mbsnet": "MBSnet",
    "NestedUNet": "U-Net++",
    "transunet": "TransUNet",
    "ICDnet": "IC-Net",
    "AttUNet": "AttU-Net",
}

# NOTE(review): unexplained index list kept from the original; possibly the
# start indices of rows of interest -- confirm.
# 0, 66, 77, 88, 121, 132, 165, 242, 253

# Split the sorted image files into consecutive groups of PANELS_PER_ROW.
image_lists = [image_files[i:i + PANELS_PER_ROW]
               for i in range(0, len(image_files), PANELS_PER_ROW)]

# Draw each group as one horizontal row and save it as
# ./pic/row_<index of the row's first file in image_files>.png.
for row_idx, image_list in enumerate(image_lists):
    start = row_idx * PANELS_PER_ROW  # index of image_list[0] in image_files
    fig, axs = plt.subplots(1, PANELS_PER_ROW, figsize=(20, 5))
    fig.subplots_adjust(wspace=0.05)  # narrow the horizontal gap between panels
    # Hide every axis up front so unused panels in a short final row stay blank.
    for ax in axs:
        ax.axis('off')
    for ax, image_file in zip(axs, image_list):
        image = plt.imread(image_file)
        ax.imshow(image, extent=[0, image.shape[1], 0, image.shape[0]])
        # Parse the model name the same way sort_key does: the last "_" field,
        # cut at its first ".".
        raw_name = os.path.basename(image_file).split('_')[-1].split('.')[0]
        label = DISPLAY_NAMES.get(raw_name, raw_name)
        # Put the model label centred just below the image.
        ax.text(0.5, -0.15, label, size=14, ha="center", transform=ax.transAxes)
    fig.savefig('./pic/row_{}.png'.format(start))
    plt.close(fig)

