import torch
from torch import nn
import torch.nn.functional as F
import math
from typical_models import KANLinear

# Residual + LSTM
class Res_LSTM_model(nn.Module):
    """Residual convolutional front end followed by an LSTM regressor.

    The input is expected as (batch, 1, seq_len, feature_num):

    * ``conv1`` takes a single input channel.
    * ``conv2``'s (1, feature_num) kernel collapses the last axis, leaving one
      ``conv_out_channel``-dim vector per time step for the LSTM.

    The LSTM output at the last time step is mapped to a single value.

    Args:
        feature_num: width of the last input axis (number of features).
        res_channel: channels produced by the 1x1 residual convolution.
        conv_out_channel: channels produced by ``conv2``; the LSTM input size.
        lstm_hidden: LSTM hidden size.
    """

    def __init__(self, feature_num, res_channel=32, conv_out_channel=64, lstm_hidden=128):
        super(Res_LSTM_model, self).__init__()

        self.feature_num = feature_num
        self.res_channel = res_channel
        self.lstm_hidden = lstm_hidden
        self.conv_out_channel = conv_out_channel

        # 1x1 conv lifting the single input channel to res_channel channels.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, res_channel, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU()
        )

        # Kernel spans the whole feature axis, so each time step becomes one vector.
        self.conv2 = nn.Sequential(
            nn.Conv2d(res_channel, conv_out_channel, kernel_size=(1, feature_num), stride=(1, 1)),
            nn.ReLU()
        )

        self.lstm1 = nn.LSTM(input_size=conv_out_channel, hidden_size=lstm_hidden, batch_first=True, bidirectional=False)
        # Registered but currently not applied in forward().
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(lstm_hidden, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, input):
        """Return predictions of shape (batch, 1)."""
        # Residual connection. The single-channel input broadcasts across all
        # res_channel channels: same values as input.repeat(...), without the copy.
        tem = self.conv1(input) + input
        tem = self.conv2(tem)
        # Drop only the collapsed feature axis. A bare squeeze() would also drop
        # a batch (or time) axis of size 1 and break the permute below.
        tem = tem.squeeze(-1)
        # (batch, channels, time) -> (batch, time, channels) for the batch_first LSTM.
        tem = tem.permute(0, 2, 1)
        tem, _ = self.lstm1(tem)
        tem = tem[:, -1, :]  # take the LSTM output at the last time step
        tem = F.relu(self.fc1(tem))
        out = self.fc2(tem)

        return out


class Res_GRU_model(nn.Module):
    """Residual convolutional front end followed by a GRU regressor.

    The input is expected as (batch, 1, seq_len, feature_num):

    * ``conv1`` takes a single input channel.
    * ``conv2``'s (1, feature_num) kernel collapses the last axis, leaving one
      ``conv_out_channel``-dim vector per time step for the GRU.

    The GRU output at the last time step is mapped to a single value.

    Args:
        feature_num: width of the last input axis (number of features).
        res_channel: channels produced by the 1x1 residual convolution.
        conv_out_channel: channels produced by ``conv2``; the GRU input size.
        gru_hidden: GRU hidden size.
    """

    def __init__(self, feature_num, res_channel=32, conv_out_channel=64, gru_hidden=128):
        super(Res_GRU_model, self).__init__()

        self.feature_num = feature_num
        self.res_channel = res_channel
        self.gru_hidden = gru_hidden
        self.conv_out_channel = conv_out_channel

        # 1x1 conv lifting the single input channel to res_channel channels.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, res_channel, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU()
        )

        # Kernel spans the whole feature axis, so each time step becomes one vector.
        self.conv2 = nn.Sequential(
            nn.Conv2d(res_channel, conv_out_channel, kernel_size=(1, feature_num), stride=(1, 1)),
            nn.ReLU()
        )

        self.gru = nn.GRU(input_size=conv_out_channel, hidden_size=gru_hidden, batch_first=True, bidirectional=False)
        # Registered but currently not applied in forward().
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(gru_hidden, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, input):
        """Return predictions of shape (batch, 1)."""
        # Residual connection. The single-channel input broadcasts across all
        # res_channel channels: same values as input.repeat(...), without the copy.
        tem = self.conv1(input) + input
        tem = self.conv2(tem)
        # Drop only the collapsed feature axis. A bare squeeze() would also drop
        # a batch (or time) axis of size 1 and break the permute below.
        tem = tem.squeeze(-1)
        # (batch, channels, time) -> (batch, time, channels) for the batch_first GRU.
        tem = tem.permute(0, 2, 1)
        tem, _ = self.gru(tem)
        tem = tem[:, -1, :]  # take the GRU output at the last time step
        tem = F.relu(self.fc1(tem))
        out = self.fc2(tem)

        return out


class GRU_KANmodel(nn.Module):
    """GRU encoder followed by a two-layer KAN regression head.

    The input goes straight into a ``batch_first`` GRU whose input size is
    ``feature_num``. It is therefore expected as (batch, seq_len, feature_num).
    The GRU output at the final time step is fed through the head to produce
    one value per sample.
    """

    def __init__(self, feature_num, gru_hidden=128):
        super().__init__()

        self.feature_num = feature_num
        self.gru_hidden = gru_hidden

        self.gru = nn.GRU(
            input_size=feature_num,
            hidden_size=gru_hidden,
            batch_first=True,
            bidirectional=False,
        )
        # Registered but not applied in forward().
        self.dropout = nn.Dropout(0.25)
        self.fc1 = KANLinear(gru_hidden, 64)
        self.fc2 = KANLinear(64, 1)

    def forward(self, input):
        sequence_out, _ = self.gru(input)
        # Keep only the GRU output at the final time step.
        final_step = sequence_out[:, -1, :]
        hidden = F.relu(self.fc1(final_step))
        return self.fc2(hidden)

class Res_KANmodel(nn.Module):
    """Residual convolutional feature extractor with a flattened KAN head.

    ``conv1`` takes one input channel. ``conv2``'s kernel spans
    ``feature_num`` columns. ``fc1`` expects ``n_hours * conv_out_channel``
    inputs. Together these imply an input of (batch, 1, n_hours, feature_num);
    that shape is inferred from the layer sizes, not checked.
    """

    def __init__(self, feature_num, n_hours, res_channel=32, conv_out_channel=64):
        super().__init__()

        self.feature_num = feature_num
        self.n_hours = n_hours
        self.res_channel = res_channel
        self.conv_out_channel = conv_out_channel

        # 1x1 conv lifting the single input channel to res_channel channels.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, res_channel, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(),
        )
        # Kernel spans the full feature axis, collapsing it to width 1.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                res_channel,
                conv_out_channel,
                kernel_size=(1, feature_num),
                stride=(1, 1),
            ),
            nn.ReLU(),
        )

        # Registered but not applied in forward().
        self.dropout = nn.Dropout(0.25)
        self.flat = nn.Flatten()
        self.fc1 = KANLinear(n_hours * conv_out_channel, 64)
        self.fc2 = KANLinear(64, 1)

    def forward(self, input):
        # Residual: add the input, tiled across all res_channel channels.
        tiled_input = input.repeat(1, self.res_channel, 1, 1)
        features = self.conv1(input) + tiled_input
        # Drop the width-1 feature axis left by conv2.
        features = self.conv2(features).squeeze(-1)
        # (batch, channels, time) -> (batch, time, channels), then flatten per sample.
        flattened = self.flat(features.permute(0, 2, 1))
        hidden = F.relu(self.fc1(flattened))
        return self.fc2(hidden)

class Res_GRU_KANmodel(nn.Module):
    """Residual convolutional front end, GRU encoder, and a KAN regression head.

    ``conv1`` takes one input channel. ``conv2``'s (1, feature_num) kernel
    collapses the last axis. The input is therefore presumably
    (batch, 1, seq_len, feature_num); confirm against callers. The GRU output
    at the final time step is mapped to one value per sample.
    """

    def __init__(self, feature_num, res_channel=32, conv_out_channel=64, gru_hidden=128):
        super().__init__()

        self.feature_num = feature_num
        self.res_channel = res_channel
        self.gru_hidden = gru_hidden
        self.conv_out_channel = conv_out_channel

        # 1x1 conv lifting the single input channel to res_channel channels.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, res_channel, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(),
        )
        # Kernel spans the full feature axis, collapsing it to width 1.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                res_channel,
                conv_out_channel,
                kernel_size=(1, feature_num),
                stride=(1, 1),
            ),
            nn.ReLU(),
        )
        self.gru = nn.GRU(
            input_size=conv_out_channel,
            hidden_size=gru_hidden,
            batch_first=True,
            bidirectional=False,
        )
        # Registered but not applied in forward().
        self.dropout = nn.Dropout(0.25)
        self.fc1 = KANLinear(gru_hidden, 64)
        self.fc2 = KANLinear(64, 1)

    def forward(self, input):
        # Residual: add the input, tiled across all res_channel channels.
        tiled_input = input.repeat(1, self.res_channel, 1, 1)
        features = self.conv1(input) + tiled_input
        # Drop the width-1 feature axis left by conv2.
        features = self.conv2(features).squeeze(-1)
        # (batch, channels, time) -> (batch, time, channels) for the batch_first GRU.
        sequence = features.permute(0, 2, 1)
        gru_out, _ = self.gru(sequence)
        # Use the GRU output at the final time step.
        hidden = F.relu(self.fc1(gru_out[:, -1, :]))
        return self.fc2(hidden)


def main():
    """Smoke-test one model on a random batch and print its output shape.

    The batch has shape (32, 1, 5, 4), matching feature_num=4 and n_hours=5
    passed to the model below. To try another architecture, swap in one of
    the alternatives listed in the comment.
    """
    batch = torch.randn(32, 1, 5, 4)

    # Alternatives:
    #   Res_LSTM_model(4, res_channel=32, conv_out_channel=64, lstm_hidden=128)
    #   Res_GRU_model(4, res_channel=32, conv_out_channel=64, gru_hidden=128)
    #   GRU_KANmodel(4, gru_hidden=128)
    #   Res_GRU_KANmodel(4, res_channel=32, conv_out_channel=64, gru_hidden=128)
    net = Res_KANmodel(4, n_hours=5, res_channel=32, conv_out_channel=64)
    prediction = net(batch)

    # Total parameter count, if needed:
    #   print(sum(p.numel() for p in net.parameters()))
    # Per-parameter shapes, if needed:
    #   for name, param in net.named_parameters():
    #       print(name, param.shape)
    print(prediction.shape)


if __name__ == "__main__":
    main()
