import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.feature_selection import mutual_info_regression
from sklearn.metrics import mutual_info_score

# Compute the mean and variance at each of the 24 positions of a daily cycle
# for the time series stored in a CSV file.
def calculate_mean_and_var(file_path):
    """
    Return per-slot mean/variance profiles of the CSV's second column.

    The second column is cast to float32 and min-max scaled to [0, 1] over the
    whole series. The scaled series is then sampled with stride 24 starting at
    each offset 0..23, and the mean and variance of every such sub-series are
    computed.

    NOTE(review): the stride-24 slicing presumably assumes hourly rows with no
    gaps, so that offset ``i`` corresponds to one hour of the day -- confirm
    against the data files.

    :param file_path: path of the CSV file, passed to ``pandas.read_csv``
    :return: tuple ``(means_list, var_list, scaled)`` where ``means_list`` and
             ``var_list`` are lists of 24 Python floats and ``scaled`` is the
             1-D scaled series. With fewer than 24 rows, some slices are empty
             and the corresponding entries are NaN.
    """
    # Only the second column (index 1) of the file is used.
    series = pd.read_csv(file_path).iloc[:, 1]
    # MinMaxScaler expects a 2-D (n_samples, n_features) array.
    power = series.values.astype('float32').reshape(-1, 1)

    # fit_transform() already fits the scaler; a separate fit() beforehand
    # would just repeat the same work.
    scaled = MinMaxScaler(feature_range=(0, 1)).fit_transform(power)

    # Convert NumPy scalars to plain Python floats while collecting them.
    means_list = [float(np.mean(scaled[i::24])) for i in range(24)]
    var_list = [float(np.var(scaled[i::24])) for i in range(24)]

    return means_list, var_list, scaled.reshape(-1)


def calculate_mutual_information(df, method='regression', bins=10, random_state=None):
    """
    Compute the pairwise mutual-information matrix between the columns of ``df``.

    Entry ``[i, j]`` (for ``i != j``) is the mutual information between column
    ``i`` and column ``j``. Diagonal entries are set to 1.0 by convention in
    this code; they are not the entropy of the series.

    :param df: DataFrame whose columns are the time series to compare
    :param method: ``'regression'`` estimates MI between continuous variables
                   with ``mutual_info_regression``; ``'classification'``
                   discretises every column into equal-width bins and uses
                   ``mutual_info_score`` on the resulting labels
    :param bins: number of equal-width bins used by ``'classification'``
    :param random_state: forwarded to ``mutual_info_regression``; the default
                         ``None`` leaves the estimator's randomness unseeded
    :return: DataFrame of shape (n_columns, n_columns) indexed and labelled by
             ``df.columns``
    :raises ValueError: if ``method`` is not one of the two supported values
    """
    # Validate once up front so a bad method is rejected even when df has
    # fewer than two columns (the pair loop would never reach the check).
    if method not in ('regression', 'classification'):
        raise ValueError("Method must be 'regression' or 'classification'")

    n_series = df.shape[1]
    columns = [df.iloc[:, k].values for k in range(n_series)]

    if method == 'classification':
        # mutual_info_score needs one label per sample. Bin *counts* (what
        # np.histogram returns first) are not per-sample labels, so map every
        # sample to the index of its bin using the interior bin edges; this
        # yields labels 0..bins-1 with the maximum value in the last bin.
        # Done once per column rather than once per pair.
        columns = [
            np.digitize(values, np.histogram(values, bins=bins)[1][1:-1])
            for values in columns
        ]

    mi_matrix = np.zeros((n_series, n_series))
    for i in range(n_series):
        for j in range(n_series):
            if i == j:
                mi_matrix[i, j] = 1.0  # self-information fixed at 1 by convention
            elif method == 'regression':
                # Feature matrix must be 2-D: (n_samples, 1).
                mi_matrix[i, j] = mutual_info_regression(
                    columns[i].reshape(-1, 1),
                    columns[j],
                    random_state=random_state,
                )[0]
            else:
                mi_matrix[i, j] = mutual_info_score(columns[i], columns[j])

    return pd.DataFrame(mi_matrix, index=df.columns, columns=df.columns)



if __name__ == '__main__':
    # Use Times New Roman for all figure text.
    plt.rcParams['font.family'] = 'Times New Roman'

    # Input CSV files, keyed by the label each series gets in the heatmap.
    csv_paths = {
        'panama': 'D:\\小论文\\小论文2\\experiment_content\\dealt_data\\panama_power.csv',
        'region1': 'D:\\小论文\\小论文2\\experiment_content\\dealt_data\\tetouan_region1_power.csv',
        'region2': 'D:\\小论文\\小论文2\\experiment_content\\dealt_data\\tetouan_region2_power.csv',
        'region3': 'D:\\小论文\\小论文2\\experiment_content\\dealt_data\\tetouan_region3_power.csv',
    }

    # Per-slot mean / variance profiles and the scaled series for every file.
    slot_means = {}
    slot_vars = {}
    scaled_series = {}
    for label, csv_path in csv_paths.items():
        means, variances, scaled = calculate_mean_and_var(csv_path)
        slot_means[label] = means
        slot_vars[label] = variances
        scaled_series[label] = scaled

    # Per the original note, the Panama data starts at 1:00, so its last
    # element is moved to the front to line it up with the other series.
    for profiles in (slot_means, slot_vars):
        panama_profile = profiles['panama']
        profiles['panama'] = panama_profile[-1:] + panama_profile[:-1]

    # Correlation analysis of the four mean profiles: one column per series.
    df = pd.DataFrame(slot_means)

    # Alternative measures, kept for quick switching:
    # correlation_matrix = df.corr(method='pearson')   # Pearson correlation matrix
    # correlation_matrix = df.corr(method='spearman')  # Spearman correlation matrix
    # correlation_matrix = calculate_mutual_information(df, method='regression')  # mutual information
    correlation_matrix = df.corr(method='kendall')  # Kendall correlation matrix

    # Draw the heatmap; annot_kws controls the size of the numbers in the cells.
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        correlation_matrix,
        annot=True,
        cmap='coolwarm',
        annot_kws={'size': 14},
    )
    # Tick-label font sizes for the x and y axes.
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    # plt.title('Correlation Heatmap of Energy Values')
    plt.show()