{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://i.typlog.com/tanxy/8331661530_1720295.jpg)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import scipy.io as sio\n", "import numpy as np\n", "from sklearn import preprocessing\n", "import torch\n", "import torch.nn as nn\n", "import math\n", "from torch.utils.data import Dataset,DataLoader\n", "import torch.optim as optim\n", "from torch.autograd import Variable\n", "import torch.nn.functional as F\n", "import logging" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Spectral Residual Learning 模块的输入\n", "class SPCModuleIN(nn.Module):\n", " def __init__(self, in_channels, out_channels, bias=True):\n", " super(SPCModuleIN, self).__init__()\n", " \n", " self.s1 = nn.Conv3d(in_channels, out_channels, kernel_size=(7,1,1), stride=(2,1,1), bias=False)\n", " #self.bn = nn.BatchNorm3d(out_channels)\n", "\n", " def forward(self, input):\n", " \n", " input = input.unsqueeze(1)\n", " \n", " out = self.s1(input)\n", " \n", " return out.squeeze(1) " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Spectral Residual Leaning 部分\n", "class ResSPC(nn.Module):\n", " def __init__(self, in_channels, out_channels, bias=True):\n", " super(ResSPC, self).__init__()\n", " \n", " self.spc1 = nn.Sequential(nn.Conv3d(in_channels, in_channels, kernel_size=(7,1,1), padding=(3,0,0), bias=False),\n", " nn.LeakyReLU(inplace=True),\n", " nn.BatchNorm3d(in_channels),)\n", " \n", " self.spc2 = nn.Sequential(nn.Conv3d(in_channels, in_channels, kernel_size=(7,1,1), padding=(3,0,0), bias=False),\n", " nn.LeakyReLU(inplace=True),)\n", " \n", " self.bn2 = nn.BatchNorm3d(out_channels)\n", "\n", " def forward(self, input):\n", " \n", " out = self.spc1(input)\n", " out = self.bn2(self.spc2(out))\n", " \n", " return F.leaky_relu(out + input)" ] }, { "cell_type": "code", "execution_count": 5, 
"metadata": {}, "outputs": [], "source": [ "# Spatial Residual Learning 输入部分\n", "class SPAModuleIN(nn.Module):\n", " def __init__(self, in_channels, out_channels, k=97, bias=True):\n", " super(SPAModuleIN, self).__init__()\n", " \n", " # print('k=',k)\n", " self.s1 = nn.Conv3d(in_channels, out_channels, kernel_size=(k,3,3), bias=False)\n", " #self.bn = nn.BatchNorm2d(out_channels)\n", "\n", " def forward(self, input):\n", " \n", " # print(input.size())\n", " out = self.s1(input)\n", " out = out.squeeze(2)\n", " # print(out.size)\n", " \n", " return out" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Spatial Residual Learning 部分\n", "class ResSPA(nn.Module):\n", " def __init__(self, in_channels, out_channels, bias=True):\n", " super(ResSPA, self).__init__()\n", " \n", " self.spa1 = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),\n", " nn.LeakyReLU(inplace=True),\n", " nn.BatchNorm2d(in_channels),)\n", " \n", " self.spa2 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),\n", " nn.LeakyReLU(inplace=True),)\n", " self.bn2 = nn.BatchNorm2d(out_channels)\n", "\n", " def forward(self, input):\n", " \n", " out = self.spa1(input)\n", " out = self.bn2(self.spa2(out))\n", " \n", " return F.leaky_relu(out + input)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# 整个 SSRN 的网络\n", "# 由 SPCModuleIN + ResSPC + SPAModuleI + ResSPA 组成\n", "class SSRN(nn.Module):\n", " def __init__(self, num_classes=9, k=97):\n", " super(SSRN, self).__init__()\n", " \n", " # 第一层输入,经过第一个卷积层\n", " self.layer1 = SPCModuleIN(1, 24)\n", " #self.bn1 = nn.BatchNorm3d(28)\n", " \n", " # 第二层,经过第一个光谱卷积块\n", " self.layer2 = ResSPC(24,24)\n", " \n", " # 第三层进入第二个光谱卷积块\n", " self.layer3 = ResSPC(24,24)\n", " \n", " # 第四层,经过空间残差块的输入部分的第一个卷积层\n", " self.layer4 = SPAModuleIN(24, 24, k=k)\n", " \n", " # \n", " self.bn4 = nn.BatchNorm2d(24)\n", " \n", " self.layer5 = 
ResSPA(24, 24)\n", " self.layer6 = ResSPA(24, 24)\n", "\n", " self.fc = nn.Linear(24, num_classes)\n", "\n", " def forward(self, x):\n", "\n", " x = F.leaky_relu(self.layer1(x)) #self.bn1(F.leaky_relu(self.layer1(x)))\n", "# print(x.shape)\n", " #print(x.size())\n", " x = self.layer2(x)\n", "# print(x.shape)\n", " x = self.layer3(x)\n", "# print(x.shape)\n", " #x = self.layer31(x)\n", "\n", " x = self.bn4(F.leaky_relu(self.layer4(x)))\n", "# print(x.shape)\n", " x = self.layer5(x)\n", "# print(x.shape)\n", " x = self.layer6(x)\n", "# print(x.shape)\n", "\n", " x = F.avg_pool2d(x, x.size()[-1])\n", "# print(x.shape)\n", " x = self.fc(x.squeeze())\n", "# print(x.shape)\n", " \n", " return x" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SSRN(\n", " (layer1): SPCModuleIN(\n", " (s1): Conv3d(1, 24, kernel_size=(7, 1, 1), stride=(2, 1, 1), bias=False)\n", " )\n", " (layer2): ResSPC(\n", " (spc1): Sequential(\n", " (0): Conv3d(24, 24, kernel_size=(7, 1, 1), stride=(1, 1, 1), padding=(3, 0, 0), bias=False)\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " (2): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (spc2): Sequential(\n", " (0): Conv3d(24, 24, kernel_size=(7, 1, 1), stride=(1, 1, 1), padding=(3, 0, 0), bias=False)\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " )\n", " (bn2): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (layer3): ResSPC(\n", " (spc1): Sequential(\n", " (0): Conv3d(24, 24, kernel_size=(7, 1, 1), stride=(1, 1, 1), padding=(3, 0, 0), bias=False)\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " (2): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (spc2): Sequential(\n", " (0): Conv3d(24, 24, kernel_size=(7, 1, 1), stride=(1, 1, 1), padding=(3, 0, 0), bias=False)\n", " (1): 
LeakyReLU(negative_slope=0.01, inplace=True)\n",
  " )\n",
  " (bn2): BatchNorm3d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  " )\n",
  " (layer4): SPAModuleIN(\n",
  " (s1): Conv3d(24, 24, kernel_size=(97, 3, 3), stride=(1, 1, 1), bias=False)\n",
  " )\n",
  " (bn4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  " (layer5): ResSPA(\n",
  " (spa1): Sequential(\n",
  " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
  " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
  " (2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  " )\n",
  " (spa2): Sequential(\n",
  " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
  " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
  " )\n",
  " (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  " )\n",
  " (layer6): ResSPA(\n",
  " (spa1): Sequential(\n",
  " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
  " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
  " (2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  " )\n",
  " (spa2): Sequential(\n",
  " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
  " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
  " )\n",
  " (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  " )\n",
  " (fc): Linear(in_features=24, out_features=16, bias=True)\n",
  ")\n",
  "torch.Size([16])\n"
 ] } ], "source": [
  "# Smoke test: push one random (batch, bands, height, width) = (1, 200, 7, 7) input through SSRN.\n",
  "x = torch.randn(1, 200, 7, 7)\n",
  "net = SSRN(num_classes=16, k=97)\n",
  "print(net)\n",
  "y = net(x)\n",
  "print(y.shape)"
 ] },
 { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
  "# Preprocessing helpers: index <-> coordinate conversion, patch extraction, dataset splits, padding.\n",
  "def indexToAssignment(index_, pad_length, Row, Col):\n",
  "    \"\"\"Map flat (row-major) pixel indices to [row, col] positions in the zero-padded image.\n",
  "\n",
  "    Returns a dict {position in index_: [row + pad_length, col + pad_length]}.\n",
  "    `Row` is unused and kept only for signature compatibility.\n",
  "    \"\"\"\n",
  "    new_assign = {}\n",
  "    for counter, value in enumerate(index_):\n",
  "        row, col = divmod(value, Col)\n",
  "        new_assign[counter] = [row + pad_length, col + pad_length]\n",
  "    return new_assign\n",
  "\n",
  "\n",
  "def assignmentToIndex(assign_0, assign_1, Row, Col):\n",
  "    \"\"\"Convert (row, col) to a row-major flat index; does not remove any padding offset.\"\"\"\n",
  "    return assign_0 * Col + assign_1\n",
  "\n",
  "\n",
  "def selectNeighboringPatch(matrix, ex_len, pos_row, pos_col):\n",
  "    \"\"\"Return the (bands, 2*ex_len+1, 2*ex_len+1) patch of a (bands, rows, cols) array centred on (pos_row, pos_col).\n",
  "\n",
  "    Raises:\n",
  "        IndexError: if the patch would extend past the array edges (negative row/col indices\n",
  "            would otherwise silently wrap around to the opposite side of the image).\n",
  "    \"\"\"\n",
  "    _, n_rows, n_cols = matrix.shape\n",
  "    if (pos_row - ex_len < 0 or pos_col - ex_len < 0\n",
  "            or pos_row + ex_len >= n_rows or pos_col + ex_len >= n_cols):\n",
  "        raise IndexError('patch centred at (%d, %d) with half-width %d does not fit in a %dx%d image'\n",
  "                         % (pos_row, pos_col, ex_len, n_rows, n_cols))\n",
  "    # Basic slicing returns a view; copy() keeps the returned patch independent of `matrix`.\n",
  "    return matrix[:, pos_row - ex_len:pos_row + ex_len + 1, pos_col - ex_len:pos_col + ex_len + 1].copy()\n",
  "\n",
  "\n",
  "def sampling(proptionVal, groundTruth):  # divide dataset into train and test datasets\n",
  "    \"\"\"Per-class random split: the last int(proptionVal * n) shuffled indices of each class go to test.\n",
  "\n",
  "    Labels are 1..max(groundTruth); label 0 is ignored. Returns (whole_indices, train_indices,\n",
  "    test_indices) as lists of flat pixel indices; train and test are shuffled.\n",
  "    \"\"\"\n",
  "    gt_flat = np.asarray(groundTruth).ravel()\n",
  "    labels_loc = {}\n",
  "    train = {}\n",
  "    test = {}\n",
  "    m = int(np.max(gt_flat))\n",
  "    for i in range(m):\n",
  "        indices = np.flatnonzero(gt_flat == i + 1).tolist()\n",
  "        np.random.shuffle(indices)\n",
  "        labels_loc[i] = indices\n",
  "        nb_val = int(proptionVal * len(indices))\n",
  "        # Split at an explicit position so nb_val == 0 keeps every sample in train\n",
  "        # (indices[:-0] would be empty and indices[-0:] would be the whole class).\n",
  "        split = len(indices) - nb_val\n",
  "        train[i] = indices[:split]\n",
  "        test[i] = indices[split:]\n",
  "    whole_indices = []\n",
  "    train_indices = []\n",
  "    test_indices = []\n",
  "    for i in range(m):\n",
  "        whole_indices += labels_loc[i]\n",
  "        train_indices += train[i]\n",
  "        test_indices += test[i]\n",
  "    np.random.shuffle(train_indices)\n",
  "    np.random.shuffle(test_indices)\n",
  "    return whole_indices, train_indices, test_indices\n",
  "\n",
  "\n",
  "# Per-class sample counts for the 16 Indian Pines classes (each list sums to 200).\n",
  "sample_200 = [2, 27, 19, 4, 9, 14, 2, 10, 3, 24, 41, 14, 4, 18, 7, 2]\n",
  "rsample_200 = [1, 28, 16, 5, 9, 14, 1, 9, 1, 19, 47, 12, 4, 24, 8, 2]\n",
  "\n",
  "\n",
  "def rsampling(groundTruth, sample_num=sample_200, rsample_num=rsample_200):  # divide dataset into train and test datasets\n",
  "    \"\"\"Fixed-count per-class split into labeled / train2 / val / test index lists.\n",
  "\n",
  "    For class label c + 1 with shuffled indices idx, s = sample_num[c] and r = rsample_num[c]:\n",
  "        labeled = idx[:s], train2 = idx[s:s + r], val = idx[-(s + r):], test = idx[s + r:-(s + r)]\n",
  "    Returns (whole, labeled, train2, val, test). `whole` is grouped by class; the other four are shuffled.\n",
  "\n",
  "    Raises:\n",
  "        ValueError: if a class has fewer than 2 * (s + r) samples, because val would then overlap\n",
  "            labeled/train2 and leak training pixels into validation.\n",
  "    \"\"\"\n",
  "    gt_flat = np.asarray(groundTruth).ravel()\n",
  "    labels_loc = {}\n",
  "    labeled = {}\n",
  "    train2 = {}\n",
  "    val = {}\n",
  "    test = {}\n",
  "    m = int(np.max(gt_flat))\n",
  "    for i in range(m):\n",
  "        indices = np.flatnonzero(gt_flat == i + 1).tolist()\n",
  "        np.random.shuffle(indices)\n",
  "        labels_loc[i] = indices\n",
  "        n_head = sample_num[i] + rsample_num[i]\n",
  "        if len(indices) < 2 * n_head:\n",
  "            raise ValueError('class %d has %d samples but needs at least %d for disjoint labeled/train2/val splits'\n",
  "                             % (i + 1, len(indices), 2 * n_head))\n",
  "        labeled[i] = indices[:sample_num[i]]\n",
  "        train2[i] = indices[sample_num[i]:n_head]\n",
  "        val[i] = indices[-n_head:]\n",
  "        test[i] = indices[n_head:-n_head]\n",
  "    whole_indices = []\n",
  "    labeled_indices = []\n",
  "    train2_indices = []\n",
  "    val_indices = []\n",
  "    test_indices = []\n",
  "    for i in range(m):\n",
  "        whole_indices += labels_loc[i]\n",
  "        labeled_indices += labeled[i]\n",
  "        train2_indices += train2[i]\n",
  "        val_indices += val[i]\n",
  "        test_indices += test[i]\n",
  "    np.random.shuffle(labeled_indices)\n",
  "    np.random.shuffle(train2_indices)\n",
  "    np.random.shuffle(val_indices)\n",
  "    np.random.shuffle(test_indices)\n",
  "    return whole_indices, labeled_indices, train2_indices, val_indices, test_indices\n",
  "\n",
  "\n",
  "def zeroPadding_3D(old_matrix, pad_length, pad_depth=0):\n",
  "    \"\"\"Zero-pad a (depth, rows, cols) array by pad_depth on axis 0 and pad_length on axes 1 and 2.\"\"\"\n",
  "    # np.pad is the public API; the np.lib.pad alias is not available in NumPy 2.x.\n",
  "    return np.pad(old_matrix,\n",
  "                  ((pad_depth, pad_depth), (pad_length, pad_length), (pad_length, pad_length)),\n",
  "                  'constant', constant_values=0)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### 预处理 (Preprocessing)" ] },
 { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
  "# Load the Indian Pines cube and its ground-truth label map.\n",
  "mat_data = sio.loadmat('../data/Indian_pines_corrected.mat')\n",
  "data_IN = mat_data['indian_pines_corrected']\n",
  "mat_gt = sio.loadmat('../data/Indian_pines_gt.mat')\n",
  "gt_IN = mat_gt['indian_pines_gt']\n",
  "# print(data_IN.shape, gt_IN.shape) → (145, 145, 200) (145, 145)"
 ] },
 { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TEST_SIZE= 9449\n" ] } ], "source": [
  "# Dataset configuration: every labeled pixel becomes a (bands, 2*PATCH_LENGTH+1, 2*PATCH_LENGTH+1) patch.\n",
  "new_gt_IN = gt_IN\n",
  "# Number of land-cover categories (labels 1..nb_classes in gt_IN; 16 for Indian Pines).\n",
  "nb_classes = int(gt_IN.max())\n",
  "\n",
  "# Number of spectral bands, read from the raw cube so this cell does not depend on data_IN's current layout.\n",
  "INPUT_DIMENSION_CONV = mat_data['indian_pines_corrected'].shape[-1]\n",
  "INPUT_DIMENSION = INPUT_DIMENSION_CONV\n",
  "\n",
  "# Split sizes follow the per-class counts used by rsampling(): sample_200 -> training,\n",
  "# rsample_200 -> dev, the last sample_200[c] + rsample_200[c] pixels of each class -> validation,\n",
  "# and the remaining labeled pixels -> test.\n",
  "TOTAL_SIZE = int(np.count_nonzero(gt_IN))  # labeled pixels (10249 for Indian Pines)\n",
  "\n",
  "TRAIN_SIZE = sum(sample_200)\n",
  "DEV_SIZE = sum(rsample_200)\n",
  "VAL_SIZE = TRAIN_SIZE + DEV_SIZE\n",
  "TEST_SIZE = TOTAL_SIZE - TRAIN_SIZE - DEV_SIZE - VAL_SIZE  # 10249-200-200-400\n",
  "print('TEST_SIZE=', TEST_SIZE)  # TEST_SIZE= 9449\n",
  "\n",
  "# Number of spectral bands used as network input.\n",
  "img_channels = INPUT_DIMENSION_CONV\n",
  "PATCH_LENGTH = 4"
 ] },
 { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [
  "# Normalisation: subtract each band's spatial mean, then divide by the single global maximum of the\n",
  "# raw cube (not by a variance). Rebuilt from mat_data so re-running this cell is idempotent.\n",
  "raw_IN = mat_data['indian_pines_corrected']\n",
  "MAX = raw_IN.max()\n",
  "# (height, width, bands) -> (bands, height, width)\n",
  "data_IN = np.transpose(raw_IN, (2, 0, 1))\n",
  "\n",
  "data_IN = data_IN - np.mean(data_IN, axis=(1, 2), keepdims=True)\n",
  "data_IN = data_IN / MAX"
 ] },
 { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "----- Some Shapes For details ----\n",
  "flat_data.shape = (200, 21025)\n",
  "gt.shape = (21025,)\n",
  "whole_data.shape = (200, 145, 145)\n",
  "padded_data.shape = (200, 153, 153)\n",
  "train_data.shape = (200, 200, 9, 9)\n",
  "test_data.shape = (9449, 200, 9, 9)\n",
  "all_data.shape = (10249, 200, 9, 9)\n"
 ] } ], "source": [
  "# Flatten the (bands, height, width) cube to (bands, height * width) and the label map to (height * width,).\n",
  "# Named flat_data (not data) so it does not collide with the torch.utils.data module imported later.\n",
  "flat_data = data_IN.reshape(data_IN.shape[0], -1)\n",
  "print('----- Some Shapes For details ----')\n",
  "print('flat_data.shape =', flat_data.shape)\n",
  "gt = new_gt_IN.reshape(-1)\n",
  "print('gt.shape =', gt.shape)\n",
  "\n",
  "whole_data = flat_data.reshape(data_IN.shape[0], data_IN.shape[1], data_IN.shape[2])\n",
  "print('whole_data.shape =', whole_data.shape)\n",
  "# Zero-pad PATCH_LENGTH pixels on each spatial side (145 + 2 * 4 = 153) so border pixels get full patches.\n",
  "padded_data = zeroPadding_3D(whole_data, PATCH_LENGTH)\n",
  "print('padded_data.shape =', padded_data.shape)\n",
  "\n",
  "#CATEGORY = 
9\n",
  "\n",
  "def extract_patches(assign):\n",
  "    \"\"\"Stack the patch around every padded [row, col] in `assign` (as built by indexToAssignment).\n",
  "\n",
  "    Returns a float array of shape (len(assign), bands, 2 * PATCH_LENGTH + 1, 2 * PATCH_LENGTH + 1).\n",
  "    \"\"\"\n",
  "    side = 2 * PATCH_LENGTH + 1\n",
  "    patches = np.zeros((len(assign), padded_data.shape[0], side, side))\n",
  "    for n in range(len(assign)):\n",
  "        row, col = assign[n]\n",
  "        patches[n] = selectNeighboringPatch(padded_data, PATCH_LENGTH, row, col)\n",
  "    return patches\n",
  "\n",
  "\n",
  "all_indices, train_indices, dev_indices, val_indices, test_indices = rsampling(gt)\n",
  "\n",
  "# Class labels in gt start at 1; shift them to 0-based ids for CrossEntropyLoss.\n",
  "y_train = gt[train_indices] - 1\n",
  "y_test = gt[test_indices] - 1\n",
  "y_all = gt[all_indices] - 1\n",
  "\n",
  "# Arrays are sized from the sampled indices, so a split/constant mismatch fails loudly below\n",
  "# instead of leaving all-zero rows in a preallocated array.\n",
  "train_assign = indexToAssignment(train_indices, PATCH_LENGTH, whole_data.shape[1], whole_data.shape[2])\n",
  "train_data = extract_patches(train_assign)\n",
  "print('train_data.shape =', train_data.shape)\n",
  "\n",
  "test_assign = indexToAssignment(test_indices, PATCH_LENGTH, whole_data.shape[1], whole_data.shape[2])\n",
  "test_data = extract_patches(test_assign)\n",
  "print('test_data.shape =', test_data.shape)\n",
  "\n",
  "all_assign = indexToAssignment(all_indices, PATCH_LENGTH, whole_data.shape[1], whole_data.shape[2])\n",
  "all_data = extract_patches(all_assign)\n",
  "print('all_data.shape =', all_data.shape)\n",
  "\n",
  "assert train_data.shape[0] == TRAIN_SIZE, 'rsampling() train split does not match TRAIN_SIZE'\n",
  "assert test_data.shape[0] == TEST_SIZE, 'rsampling() test split does not match TEST_SIZE'\n",
  "assert all_data.shape[0] == TOTAL_SIZE, 'rsampling() labeled pixels do not match TOTAL_SIZE'"
 ] },
 { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([3, 8])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "# Scratch check of np.prod's axis semantics (product down each column).\n",
  "np.prod([[1, 2], [3, 4]], axis=0)"
 ] },
 { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [
  "import torch\n",
  "from torch.utils import data\n",
  "\n",
  "\n",
  "class HSIDataset(torch.utils.data.Dataset):\n",
  "    \"\"\"Map-style dataset yielding (sample, label) pairs.\n",
  "\n",
  "    list_IDs: positions into `samples` / `labels`, in the order this dataset exposes them.\n",
  "    samples, labels: indexable containers aligned on their first axis.\n",
  "    The base class is spelled torch.utils.data.Dataset because the name `data` is rebound by other cells.\n",
  "    \"\"\"\n",
  "\n",
  "    def __init__(self, list_IDs, samples, labels):\n",
  "        self.list_IDs = list_IDs\n",
  "        self.samples = samples\n",
  "        self.labels = labels\n",
  "\n",
  "    def __len__(self):\n",
  "        return len(self.list_IDs)\n",
  "\n",
  "    def __getitem__(self, index):\n",
  "        # Translate the loader's position into a sample ID, then return that sample and its label.\n",
  "        ID = self.list_IDs[index]\n",
  "        return self.samples[ID], self.labels[ID]"
 ] },
 { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [
  "# CUDA for PyTorch: use the first GPU when one is available.\n",
  "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
  "\n",
  "# DataLoader parameters\n",
  "BATCH_SIZE = 50\n",
  "NUM_WORKERS = 8\n",
  "params = {'batch_size': BATCH_SIZE,\n",
  "          'shuffle': True,\n",
  "          'num_workers': NUM_WORKERS}\n",
  "\n",
  "# Datasets. NOTE(review): validation_set is built from the *test* split (test_indices); val_indices and\n",
  "# dev_indices from rsampling() are not used by any cell above, so checkpoint selection on\n",
  "# validationloader is selection on the test set. TODO: build a real validation set from val_indices.\n",
  "training_set = HSIDataset(range(len(train_indices)), train_data, y_train)\n",
  "validation_set = HSIDataset(range(len(test_indices)), test_data, y_test)\n",
  "all_set = HSIDataset(range(len(all_indices)), all_data, y_all)\n",
  "\n",
  "# Generators. torch.utils.data is spelled out because the name `data` is rebound by other cells.\n",
  "training_generator = torch.utils.data.DataLoader(training_set, **params)\n",
  "validation_generator = torch.utils.data.DataLoader(validation_set, **params)\n",
  "all_generator = torch.utils.data.DataLoader(all_set, **params)\n",
  "\n",
  "# Loaders used for training and evaluation; the evaluation loaders keep a fixed order.\n",
  "trainloader = torch.utils.data.DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)\n",
  "validationloader = torch.utils.data.DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)\n",
  "allloader = torch.utils.data.DataLoader(all_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)"
 ] },
 { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
  "[1, 4] loss: 2.7162\n",
  "Accuracy of the network on the validation set: 24.11895 %\n",
  "[2, 4] loss: 2.0227\n",
  "Accuracy of the network on the validation set: 24.11895 %\n",
  "[3, 4] loss: 1.7701\n",
  "Accuracy of the network on the validation set: 0.12700 %\n",
  "[4, 4] loss: 1.5836\n",
  "Accuracy of the network on the validation set: 5.72547 %\n",
  "[5, 4] loss: 1.4111\n",
  "Accuracy of the network on the validation set: 
9.37665 %\n", "[6, 4] loss: 1.3106\n", "Accuracy of the network on the validation set: 9.37665 %\n", "[7, 4] loss: 1.2268\n", "Accuracy of the network on the validation set: 9.37665 %\n", "[8, 4] loss: 1.0714\n", "Accuracy of the network on the validation set: 8.04318 %\n", "[9, 4] loss: 0.9874\n", "Accuracy of the network on the validation set: 8.04318 %\n", "[10, 4] loss: 0.9272\n", "Accuracy of the network on the validation set: 10.06456 %\n", "[11, 4] loss: 0.8138\n", "Accuracy of the network on the validation set: 8.04318 %\n", "[12, 4] loss: 0.7358\n", "Accuracy of the network on the validation set: 8.04318 %\n", "[13, 4] loss: 0.6966\n", "Accuracy of the network on the validation set: 8.04318 %\n", "[14, 4] loss: 0.6242\n", "Accuracy of the network on the validation set: 8.04318 %\n", "[15, 4] loss: 0.5519\n", "Accuracy of the network on the validation set: 8.04318 %\n", "[16, 4] loss: 0.4973\n", "Accuracy of the network on the validation set: 8.05376 %\n", "[17, 4] loss: 0.4350\n", "Accuracy of the network on the validation set: 9.14382 %\n", "[18, 4] loss: 0.4468\n", "Accuracy of the network on the validation set: 7.83152 %\n", "[19, 4] loss: 0.4053\n", "Accuracy of the network on the validation set: 7.09070 %\n", "[20, 4] loss: 0.3519\n", "Accuracy of the network on the validation set: 9.04858 %\n", "[21, 4] loss: 0.3069\n", "Accuracy of the network on the validation set: 12.86909 %\n", "[22, 4] loss: 0.3460\n", "Accuracy of the network on the validation set: 38.69193 %\n", "[23, 4] loss: 0.3244\n", "Accuracy of the network on the validation set: 44.45973 %\n", "[24, 4] loss: 0.3181\n", "Accuracy of the network on the validation set: 59.91110 %\n", "[25, 4] loss: 0.2651\n", "Accuracy of the network on the validation set: 60.32384 %\n", "[26, 4] loss: 0.2398\n", "Accuracy of the network on the validation set: 61.40332 %\n", "[27, 4] loss: 0.2439\n", "Accuracy of the network on the validation set: 74.31474 %\n", "[28, 4] loss: 0.1928\n", "Accuracy of the 
network on the validation set: 84.12530 %\n", "[29, 4] loss: 0.1847\n", "Accuracy of the network on the validation set: 87.05683 %\n", "[30, 4] loss: 0.1567\n", "Accuracy of the network on the validation set: 85.49053 %\n", "[31, 4] loss: 0.1418\n", "Accuracy of the network on the validation set: 79.54281 %\n", "[32, 4] loss: 0.1176\n", "Accuracy of the network on the validation set: 87.00392 %\n", "[33, 4] loss: 0.1178\n", "Accuracy of the network on the validation set: 85.43761 %\n", "[34, 4] loss: 0.0870\n", "Accuracy of the network on the validation set: 84.85554 %\n", "[35, 4] loss: 0.0871\n", "Accuracy of the network on the validation set: 85.50111 %\n", "[36, 4] loss: 0.0721\n", "Accuracy of the network on the validation set: 87.64949 %\n", "[37, 4] loss: 0.0771\n", "Accuracy of the network on the validation set: 86.54884 %\n", "[38, 4] loss: 0.0801\n", "Accuracy of the network on the validation set: 84.33697 %\n", "[39, 4] loss: 0.0545\n", "Accuracy of the network on the validation set: 81.76527 %\n", "[40, 4] loss: 0.0690\n", "Accuracy of the network on the validation set: 88.71838 %\n", "[41, 4] loss: 0.0667\n", "Accuracy of the network on the validation set: 86.18901 %\n", "[42, 4] loss: 0.0838\n", "Accuracy of the network on the validation set: 77.74368 %\n", "[43, 4] loss: 0.0986\n", "Accuracy of the network on the validation set: 81.44777 %\n", "[44, 4] loss: 0.0815\n", "Accuracy of the network on the validation set: 80.92920 %\n", "[45, 4] loss: 0.0761\n", "Accuracy of the network on the validation set: 76.42079 %\n", "[46, 4] loss: 0.1073\n", "Accuracy of the network on the validation set: 84.24172 %\n", "[47, 4] loss: 0.0947\n", "Accuracy of the network on the validation set: 75.99746 %\n", "[48, 4] loss: 0.1019\n", "Accuracy of the network on the validation set: 75.96571 %\n", "[49, 4] loss: 0.0650\n", "Accuracy of the network on the validation set: 77.75426 %\n", "[50, 4] loss: 0.0534\n", "Accuracy of the network on the validation set: 81.90285 
%\n", "[51, 4] loss: 0.0517\n", "Accuracy of the network on the validation set: 85.18362 %\n", "[52, 4] loss: 0.0547\n", "Accuracy of the network on the validation set: 74.71690 %\n", "[53, 4] loss: 0.0423\n", "Accuracy of the network on the validation set: 83.75489 %\n", "[54, 4] loss: 0.0448\n", "Accuracy of the network on the validation set: 87.30024 %\n", "[55, 4] loss: 0.0419\n", "Accuracy of the network on the validation set: 88.31622 %\n", "[56, 4] loss: 0.0571\n", "Accuracy of the network on the validation set: 70.70589 %\n", "[57, 4] loss: 0.0555\n", "Accuracy of the network on the validation set: 88.04106 %\n", "[58, 4] loss: 0.0805\n", "Accuracy of the network on the validation set: 70.98106 %\n", "[59, 4] loss: 0.1174\n", "Accuracy of the network on the validation set: 78.04000 %\n", "[60, 4] loss: 0.1034\n", "Accuracy of the network on the validation set: 77.39443 %\n", "[61, 4] loss: 0.0609\n", "Accuracy of the network on the validation set: 78.30458 %\n", "[62, 4] loss: 0.0595\n", "Accuracy of the network on the validation set: 80.74929 %\n", "[63, 4] loss: 0.0719\n", "Accuracy of the network on the validation set: 84.88729 %\n", "[64, 4] loss: 0.0696\n", "Accuracy of the network on the validation set: 78.89724 %\n", "[65, 4] loss: 0.0566\n", "Accuracy of the network on the validation set: 84.66504 %\n", "[66, 4] loss: 0.1038\n", "Accuracy of the network on the validation set: 81.42661 %\n", "[67, 4] loss: 0.0988\n", "Accuracy of the network on the validation set: 63.66811 %\n", "[68, 4] loss: 0.1816\n", "Accuracy of the network on the validation set: 65.44608 %\n", "[69, 4] loss: 0.1896\n", "Accuracy of the network on the validation set: 72.94952 %\n", "[70, 4] loss: 0.1280\n", "Accuracy of the network on the validation set: 61.19166 %\n", "[71, 4] loss: 0.2858\n", "Accuracy of the network on the validation set: 69.97566 %\n", "[72, 4] loss: 0.1813\n", "Accuracy of the network on the validation set: 66.67372 %\n", "[73, 4] loss: 0.1547\n", "Accuracy 
of the network on the validation set: 50.84136 %\n", "[74, 4] loss: 0.1030\n", "Accuracy of the network on the validation set: 74.43116 %\n", "[75, 4] loss: 0.0767\n", "Accuracy of the network on the validation set: 77.20394 %\n", "[76, 4] loss: 0.1155\n", "Accuracy of the network on the validation set: 75.90221 %\n", "[77, 4] loss: 0.0701\n", "Accuracy of the network on the validation set: 73.80675 %\n", "[78, 4] loss: 0.1133\n", "Accuracy of the network on the validation set: 77.22510 %\n", "[79, 4] loss: 0.0995\n", "Accuracy of the network on the validation set: 77.70134 %\n", "[80, 4] loss: 0.0996\n", "Accuracy of the network on the validation set: 78.95015 %\n", "[81, 4] loss: 0.0716\n", "Accuracy of the network on the validation set: 84.00889 %\n", "[82, 4] loss: 0.0912\n", "Accuracy of the network on the validation set: 82.98233 %\n", "[83, 4] loss: 0.1056\n", "Accuracy of the network on the validation set: 75.91280 %\n", "[84, 4] loss: 0.0839\n", "Accuracy of the network on the validation set: 75.38364 %\n", "[85, 4] loss: 0.0483\n", "Accuracy of the network on the validation set: 75.98688 %\n", "[86, 4] loss: 0.0384\n", "Accuracy of the network on the validation set: 77.87067 %\n", "[87, 4] loss: 0.0483\n", "Accuracy of the network on the validation set: 85.15187 %\n", "[88, 4] loss: 0.0378\n", "Accuracy of the network on the validation set: 85.86094 %\n", "[89, 4] loss: 0.0398\n", "Accuracy of the network on the validation set: 87.22616 %\n", "[90, 4] loss: 0.0416\n", "Accuracy of the network on the validation set: 87.17325 %\n", "[91, 4] loss: 0.0288\n", "Accuracy of the network on the validation set: 85.42703 %\n", "[92, 4] loss: 0.0315\n", "Accuracy of the network on the validation set: 86.25251 %\n", "[93, 4] loss: 0.0273\n", "Accuracy of the network on the validation set: 81.34194 %\n", "[94, 4] loss: 0.0281\n", "Accuracy of the network on the validation set: 82.66483 %\n", "[95, 4] loss: 0.0197\n", "Accuracy of the network on the validation set: 
86.15727 %\n", "[96, 4] loss: 0.0170\n", "Accuracy of the network on the validation set: 88.18923 %\n", "[97, 4] loss: 0.0221\n", "Accuracy of the network on the validation set: 88.77130 %\n", "[98, 4] loss: 0.0214\n", "Accuracy of the network on the validation set: 89.38512 %\n", "[99, 4] loss: 0.0099\n", "Accuracy of the network on the validation set: 89.88253 %\n", "[100, 4] loss: 0.0117\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Accuracy of the network on the validation set: 90.26352 %\n", "[101, 4] loss: 0.0108\n", "Accuracy of the network on the validation set: 90.55985 %\n", "[102, 4] loss: 0.0090\n", "Accuracy of the network on the validation set: 90.53868 %\n", "[103, 4] loss: 0.0105\n", "Accuracy of the network on the validation set: 90.29527 %\n", "[104, 4] loss: 0.0091\n", "Accuracy of the network on the validation set: 90.07302 %\n", "[105, 4] loss: 0.0109\n", "Accuracy of the network on the validation set: 90.03069 %\n", "[106, 4] loss: 0.0081\n", "Accuracy of the network on the validation set: 90.15769 %\n", "[107, 4] loss: 0.0094\n", "Accuracy of the network on the validation set: 90.30585 %\n", "[108, 4] loss: 0.0056\n", "Accuracy of the network on the validation set: 90.35877 %\n", "[109, 4] loss: 0.0073\n", "Accuracy of the network on the validation set: 90.39052 %\n", "[110, 4] loss: 0.0060\n", "Accuracy of the network on the validation set: 90.46460 %\n", "[111, 4] loss: 0.0071\n", "Accuracy of the network on the validation set: 90.60218 %\n", "[112, 4] loss: 0.0061\n", "Accuracy of the network on the validation set: 90.65510 %\n", "[113, 4] loss: 0.0096\n", "Accuracy of the network on the validation set: 90.59160 %\n", "[114, 4] loss: 0.0092\n", "Accuracy of the network on the validation set: 89.78728 %\n", "[115, 4] loss: 0.0068\n", "Accuracy of the network on the validation set: 89.44862 %\n", "[116, 4] loss: 0.0072\n", "Accuracy of the network on the validation set: 89.58620 %\n", "[117, 4] loss: 0.0152\n", "Accuracy of 
the network on the validation set: 88.81363 %\n", "[118, 4] loss: 0.0093\n", "Accuracy of the network on the validation set: 88.88771 %\n", "[119, 4] loss: 0.0138\n", "Accuracy of the network on the validation set: 89.22637 %\n", "[120, 4] loss: 0.0070\n", "Accuracy of the network on the validation set: 86.52767 %\n", "[121, 4] loss: 0.0152\n", "Accuracy of the network on the validation set: 85.87152 %\n", "[122, 4] loss: 0.0101\n", "Accuracy of the network on the validation set: 89.01471 %\n", "[123, 4] loss: 0.0089\n", "Accuracy of the network on the validation set: 89.80845 %\n", "[124, 4] loss: 0.0065\n", "Accuracy of the network on the validation set: 89.77670 %\n", "[125, 4] loss: 0.0056\n", "Accuracy of the network on the validation set: 89.89311 %\n", "[126, 4] loss: 0.0089\n", "Accuracy of the network on the validation set: 90.30585 %\n", "[127, 4] loss: 0.0063\n", "Accuracy of the network on the validation set: 90.10477 %\n", "[128, 4] loss: 0.0067\n", "Accuracy of the network on the validation set: 90.09419 %\n", "[129, 4] loss: 0.0057\n", "Accuracy of the network on the validation set: 90.16827 %\n", "[130, 4] loss: 0.0053\n", "Accuracy of the network on the validation set: 89.98836 %\n", "[131, 4] loss: 0.0061\n", "Accuracy of the network on the validation set: 90.33760 %\n", "[132, 4] loss: 0.0071\n", "Accuracy of the network on the validation set: 90.35877 %\n", "[133, 4] loss: 0.0046\n", "Accuracy of the network on the validation set: 90.49635 %\n", "[134, 4] loss: 0.0043\n", "Accuracy of the network on the validation set: 90.59160 %\n", "[135, 4] loss: 0.0034\n", "Accuracy of the network on the validation set: 90.62335 %\n", "[136, 4] loss: 0.0049\n", "Accuracy of the network on the validation set: 90.63393 %\n", "[137, 4] loss: 0.0034\n", "Accuracy of the network on the validation set: 89.73436 %\n", "[138, 4] loss: 0.0042\n", "Accuracy of the network on the validation set: 88.93005 %\n", "[139, 4] loss: 0.0045\n", "Accuracy of the network on the 
validation set: 89.15229 %\n", "[140, 4] loss: 0.0079\n", "Accuracy of the network on the validation set: 90.46460 %\n", "[141, 4] loss: 0.0046\n", "Accuracy of the network on the validation set: 90.75034 %\n", "[142, 4] loss: 0.0053\n", "Accuracy of the network on the validation set: 90.75034 %\n", "[143, 4] loss: 0.0057\n", "Accuracy of the network on the validation set: 90.77151 %\n", "[144, 4] loss: 0.0034\n", "Accuracy of the network on the validation set: 90.63393 %\n", "[145, 4] loss: 0.0058\n", "Accuracy of the network on the validation set: 90.81384 %\n", "[146, 4] loss: 0.0042\n", "Accuracy of the network on the validation set: 90.69743 %\n", "[147, 4] loss: 0.0057\n", "Accuracy of the network on the validation set: 90.82443 %\n", "[148, 4] loss: 0.0038\n", "Accuracy of the network on the validation set: 90.77151 %\n", "[149, 4] loss: 0.0052\n", "Accuracy of the network on the validation set: 90.63393 %\n", "[150, 4] loss: 0.0046\n", "Accuracy of the network on the validation set: 90.48577 %\n", "[151, 4] loss: 0.0027\n", "Accuracy of the network on the validation set: 90.32702 %\n", "[152, 4] loss: 0.0045\n", "Accuracy of the network on the validation set: 90.43285 %\n", "[153, 4] loss: 0.0034\n", "Accuracy of the network on the validation set: 90.35877 %\n", "[154, 4] loss: 0.0035\n", "Accuracy of the network on the validation set: 90.34818 %\n", "[155, 4] loss: 0.0038\n", "Accuracy of the network on the validation set: 90.36935 %\n", "[156, 4] loss: 0.0028\n", "Accuracy of the network on the validation set: 90.27410 %\n", "[157, 4] loss: 0.0032\n", "Accuracy of the network on the validation set: 90.21060 %\n", "[158, 4] loss: 0.0065\n", "Accuracy of the network on the validation set: 82.49550 %\n", "[159, 4] loss: 0.0110\n", "Accuracy of the network on the validation set: 88.66547 %\n", "[160, 4] loss: 0.0144\n", "Accuracy of the network on the validation set: 83.75489 %\n", "[161, 4] loss: 0.0189\n", "Accuracy of the network on the validation set: 
87.48016 %\n", "[162, 4] loss: 0.0049\n", "Accuracy of the network on the validation set: 87.98815 %\n", "[163, 4] loss: 0.0083\n", "Accuracy of the network on the validation set: 87.89290 %\n", "[164, 4] loss: 0.0062\n", "Accuracy of the network on the validation set: 88.96179 %\n", "[165, 4] loss: 0.0037\n", "Accuracy of the network on the validation set: 89.00413 %\n", "[166, 4] loss: 0.0046\n", "Accuracy of the network on the validation set: 89.30046 %\n", "[167, 4] loss: 0.0054\n", "Accuracy of the network on the validation set: 89.63912 %\n", "[168, 4] loss: 0.0050\n", "Accuracy of the network on the validation set: 90.66568 %\n", "[169, 4] loss: 0.0033\n", "Accuracy of the network on the validation set: 91.09959 %\n", "[170, 4] loss: 0.0037\n", "Accuracy of the network on the validation set: 91.13134 %\n", "[171, 4] loss: 0.0037\n", "Accuracy of the network on the validation set: 91.07842 %\n", "[172, 4] loss: 0.0030\n", "Accuracy of the network on the validation set: 90.90909 %\n", "[173, 4] loss: 0.0038\n", "Accuracy of the network on the validation set: 90.57043 %\n", "[174, 4] loss: 0.0030\n", "Accuracy of the network on the validation set: 90.57043 %\n", "[175, 4] loss: 0.0033\n", "Accuracy of the network on the validation set: 90.73976 %\n", "[176, 4] loss: 0.0023\n", "Accuracy of the network on the validation set: 90.87734 %\n", "[177, 4] loss: 0.0023\n", "Accuracy of the network on the validation set: 90.98317 %\n", "[178, 4] loss: 0.0029\n", "Accuracy of the network on the validation set: 90.94084 %\n", "[179, 4] loss: 0.0040\n", "Accuracy of the network on the validation set: 90.83501 %\n", "[180, 4] loss: 0.0049\n", "Accuracy of the network on the validation set: 90.70801 %\n", "[181, 4] loss: 0.0036\n", "Accuracy of the network on the validation set: 90.59160 %\n", "[182, 4] loss: 0.0029\n", "Accuracy of the network on the validation set: 90.52810 %\n", "[183, 4] loss: 0.0025\n", "Accuracy of the network on the validation set: 90.78209 %\n", 
"[184, 4] loss: 0.0032\n", "Accuracy of the network on the validation set: 91.06784 %\n", "[185, 4] loss: 0.0035\n", "Accuracy of the network on the validation set: 91.39591 %\n", "[186, 4] loss: 0.0029\n", "Accuracy of the network on the validation set: 91.19484 %\n", "[187, 4] loss: 0.0041\n", "Accuracy of the network on the validation set: 91.07842 %\n", "[188, 4] loss: 0.0042\n", "Accuracy of the network on the validation set: 90.72918 %\n", "[189, 4] loss: 0.0017\n", "Accuracy of the network on the validation set: 90.26352 %\n", "[190, 4] loss: 0.0023\n", "Accuracy of the network on the validation set: 90.14711 %\n", "[191, 4] loss: 0.0021\n", "Accuracy of the network on the validation set: 90.34818 %\n", "[192, 4] loss: 0.0019\n", "Accuracy of the network on the validation set: 90.50693 %\n", "[193, 4] loss: 0.0027\n", "Accuracy of the network on the validation set: 90.59160 %\n", "[194, 4] loss: 0.0028\n", "Accuracy of the network on the validation set: 91.13134 %\n", "[195, 4] loss: 0.0027\n", "Accuracy of the network on the validation set: 91.29008 %\n", "[196, 4] loss: 0.0024\n", "Accuracy of the network on the validation set: 91.35358 %\n", "[197, 4] loss: 0.0026\n", "Accuracy of the network on the validation set: 91.17367 %\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[198, 4] loss: 0.0022\n", "Accuracy of the network on the validation set: 90.49635 %\n", "[199, 4] loss: 0.0026\n", "Accuracy of the network on the validation set: 90.25294 %\n", "[200, 4] loss: 0.0024\n", "Accuracy of the network on the validation set: 90.20002 %\n", "Finished Training\n" ] } ], "source": [ "# 实例化网络,准备放到 GPU 上训练\n", "net = SSRN(num_classes=16, k=97)\n", "net.to(device)\n", "\n", "import torch\n", "import torch.optim as optim\n", "# 引入时间模块格式化时间生成唯一训练文件\n", "import time\n", "unique_time = time.strftime(\"%Y%m%d-%H%M%S\")\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "#optimizer = optim.RMSprop(net.parameters())\n", "# 论文中使用的是 RMSprop 优化器,但是最后作者还是使用了 
Adam,估计优化效果更好吧\n", "optimizer = optim.Adam(net.parameters(), lr=0.002)\n", "\n", "best_pred = 0\n", "#SAVE_PATH3 = './saved_models/ssnet_best3_up_seed' + str(args.seed) + '.pth' \n", "SSRN_TRAIN_SAVE_PATH = 'SSRN-Train-' + unique_time + '.pth' \n", "#torch.save(net.state_dict(), SAVE_PATH)\n", "\n", "for epoch in range(200): # loop over the dataset multiple times\n", " \n", " running_loss = 0.0\n", " #iters = len(trainloader)\n", " net = net.train()\n", " for i, data in enumerate(trainloader, 0):\n", " # get the inputs\n", " inputs, labels = data\n", " inputs, labels = inputs.to(device), labels.to(device)\n", "\n", " # zero the parameter gradients\n", " optimizer.zero_grad()\n", "\n", " # forward + backward + optimize\n", " outputs = net(inputs.float())\n", " loss = criterion(outputs, labels.long())\n", " loss.backward()\n", " optimizer.step()\n", "\n", " # print statistics\n", " running_loss += loss.item()\n", " if i % 4 == 3: # print every 2000 mini-batches\n", " print('[%d, %5d] loss: %.4f' %\n", " (epoch + 1, i + 1, running_loss / 4))\n", " running_loss = 0.0\n", " #schedular.step()\n", " \n", " correct = 0\n", " total = 0\n", " net = net.eval()\n", " counter = 0 \n", " with torch.no_grad():\n", " for data in validationloader:\n", "# if counter <= 10:\n", "# counter += 1\n", " images, labels = data\n", " images, labels = images.to(device), labels.to(device)\n", " outputs = net(images.float())\n", " _, predicted = torch.max(outputs.data, 1)\n", " total += labels.size(0)\n", " correct += (predicted == labels.long()).sum().item()\n", "\n", " new_pred = correct / total\n", " print('Accuracy of the network on the validation set: %.5f %%' % (\n", " 100 * new_pred))\n", " \n", " if new_pred > best_pred:\n", " logging.info('new_pred %f', new_pred)\n", " logging.info('best_pred %f', best_pred)\n", " torch.save(net.state_dict(), SSRN_TRAIN_SAVE_PATH)\n", " best_pred=new_pred\n", " \n", "print('Finished Training')" ] }, { "cell_type": "code", "execution_count": 18, 
"metadata": {}, "outputs": [], "source": [ "## Validation Stage Functions\n", "def cal_results(matrix):\n", " shape = np.shape(matrix)\n", " number = 0\n", " sum = 0\n", " AA = np.zeros([shape[0]], dtype=np.float)\n", " for i in range(shape[0]):\n", " number += matrix[i, i]\n", " AA[i] = matrix[i, i] / np.sum(matrix[i, :])\n", " sum += np.sum(matrix[i, :]) * np.sum(matrix[:, i])\n", " OA = number / np.sum(matrix)\n", " AA_mean = np.mean(AA)\n", " pe = sum / (np.sum(matrix) ** 2)\n", " Kappa = (OA - pe) / (1 - pe)\n", " return OA, AA_mean, Kappa, AA\n", "\n", "\n", "def predVisIN(indices, pred, size1, size2):\n", " \n", " if pred.ndim > 1:\n", " pred = np.ravel(pred)\n", " \n", " x = np.zeros(size1*size2)\n", " x[indices] = pred\n", " \n", " y = np.ones((x.shape[0], 3))\n", "\n", " for index, item in enumerate(x):\n", " if item == 0:\n", " y[index] = np.array([230, 230, 230]) / 255. # np.array([255, 255, 255]) / 255.\n", " if item == 1:\n", " y[index] = np.array([255, 0, 0]) / 255.\n", " if item == 2:\n", " y[index] = np.array([0, 255, 0]) / 255.\n", " if item == 3:\n", " y[index] = np.array([0, 0, 255]) / 255.\n", " if item == 4:\n", " y[index] = np.array([255, 255, 0]) / 255.\n", " if item == 5:\n", " y[index] = np.array([0, 255, 255]) / 255.\n", " if item == 6:\n", " y[index] = np.array([255, 0, 255]) / 255.\n", " if item == 7:\n", " y[index] = np.array([192, 192, 192]) / 255.\n", " if item == 8:\n", " y[index] = np.array([128, 128, 128]) / 255.\n", " if item == 9:\n", " y[index] = np.array([128, 0, 0]) / 255.\n", " if item == 10: \n", " y[index] = np.array([128, 128, 0]) / 255.\n", " if item == 11:\n", " y[index] = np.array([0, 128, 0]) / 255.\n", " if item == 12:\n", " y[index] = np.array([128, 0, 128]) / 255.\n", " if item == 13:\n", " y[index] = np.array([0, 128, 128]) / 255.\n", " if item == 14:\n", " y[index] = np.array([0, 0, 128]) / 255.\n", " if item == 15:\n", " y[index] = np.array([255, 165, 0]) / 255.\n", " if item == 16:\n", " y[index] = 
np.array([255, 215, 0]) / 255.\n", " \n", " y_rgb = np.reshape(y, (size1, size2, 3))\n", " \n", " return y_rgb" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OA, AA_Mean, Kappa: %f, %f, %f, 0.9139591491163086 0.932226390035374 0.9019106981267497\n", "('AA for each class: ', array([0.975 , 0.90743551, 0.88289474, 0.88127854, 0.9082774 ,\n", " 0.99703264, 1. , 0.98863636, 1. , 0.88148984,\n", " 0.89863975, 0.84473198, 1. , 0.9703641 , 0.80337079,\n", " 0.97647059]))\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_5205/1516455691.py:6: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " AA = np.zeros([shape[0]], dtype=np.float)\n" ] } ], "source": [ "# Validation Stage \n", "\n", "from sklearn import metrics, preprocessing\n", "\n", "trained_net = SSRN(num_classes=16, k=97)\n", "\n", "trained_net.load_state_dict(torch.load(SSRN_TRAIN_SAVE_PATH))\n", "trained_net.eval()\n", "trained_net = trained_net.cuda()\n", "\n", "label_val = []\n", "pred_val = []\n", "\n", "with torch.no_grad():\n", " for data in validationloader:\n", " images, labels = data\n", " #label_val = torch.stack([label_val.type_as(labels), labels])\n", " label_val.append(labels)\n", " \n", " images, labels = images.to(device), labels.to(device)\n", " outputs = trained_net(images.float())\n", " _, predicted = torch.max(outputs.data, 1)\n", " #pred_val = torch.stack([pred_val.type_as(predicted), predicted])\n", " pred_val.append(predicted)\n", " \n", "label_val_cpu = [x.cpu() for x in label_val]\n", "pred_val_cpu = [x.cpu() for x in pred_val]\n", "\n", 
"label_cat = np.concatenate(label_val_cpu)\n", "pred_cat = np.concatenate(pred_val_cpu)\n", "\n", "matrix = metrics.confusion_matrix(label_cat, pred_cat)\n", "\n", "OA, AA_mean, Kappa, AA = cal_results(matrix)\n", "\n", "print('OA, AA_Mean, Kappa: %f, %f, %f, ', OA, AA_mean, Kappa)\n", "print(str((\"AA for each class: \", AA)))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# generate classification maps\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "all_pred = []\n", "\n", "with torch.no_grad():\n", " for data in allloader:\n", " images, _ = data\n", " images, _ = images.to(device), labels.to(device)\n", " outputs = trained_net(images.float())\n", " _, predicted = torch.max(outputs.data, 1)\n", " all_pred.append(predicted)\n", "\n", "all_pred = torch.cat(all_pred)\n", "all_pred = all_pred.cpu().numpy() + 1\n", "\n", "y_pred = predVisIN(all_indices, all_pred, 145, 145)\n", "\n", "\n", "#plt.plot(x, y)\n", "plt.imshow(y_pred)\n", "plt.axis('off')\n", "fig_path = 'SSRN-Train-' + unique_time + '.png'\n", "plt.savefig(fig_path, bbox_inches=0)\n", "#plt.savefig(fig_path, bbox_inches='tight')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "BBP", "language": "python", "name": "bbp" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 2 }