{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch 実装の基本フロー\n", "## MNIST手書き数字の分類" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 畳み込み層+全結合層で構築したネットワークモデルの実装を行い、MNIST分類を学習し推論します。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 事前準備" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## モジュールインポート" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torchvision import datasets, transforms\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from tqdm import tqdm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Datasetの準備" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MNISTデータセット" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "#データ前処理 transform を設定\n", "transform = transforms.Compose(\n", " [transforms.ToTensor(), # Tensor変換とshape変換 [H, W, C] -> [C, H, W]\n", " transforms.Normalize((0.5, ), (0.5, ))]) # 標準化 平均:0.5 標準偏差:0.5\n", "\n", "#訓練用Datasetを作成\n", "train_dataset = datasets.MNIST(root='./data', \n", " train=True,\n", " download=True,\n", " transform=transform)\n", "\n", "#検証用Datasetを作成\n", "val_dataset = datasets.MNIST(root='./data', \n", " train=False, \n", " download=True, \n", " transform=transform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# DataLoaderの作成" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "#訓練用 DataLoader\n", "train_dataloader = torch.utils.data.DataLoader(train_dataset,\n", " batch_size=64,\n", " shuffle=True)\n", "\n", "#検証用 DataLoader\n", "val_dataloader = torch.utils.data.DataLoader(val_dataset, \n", " batch_size=64,\n", " shuffle=False)\n", "\n", "# 辞書型変数にまとめる\n", "dataloaders_dict = {\"train\": train_dataloader, \"val\": val_dataloader}" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "## 動作の確認" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "imges size = torch.Size([64, 1, 28, 28])\n", "labels size = torch.Size([64])\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAPd0lEQVR4nO3df+wUdX7H8edLDi8V5Sq1UMpZpVdIVUw9xF949iz08FcapQ2Xs02lQf3aniRnUxt/NHpq02ptz+aSphAMRu5yh2f8SbR3ao1WavUUqAgenHAUFeUgxKv80PgD3v1j51u+wndnl52Znf1+P69Hstndee/MvNnw+s7szsx+FBGY2fB3WN0NmFl3OOxmiXDYzRLhsJslwmE3S4TDbpYIh30IkrRZ0u+3+dqQ9Fsdrqfjea33OOzWNZI+K2mxpDck7ZL035IuqLuvVDjs1k2fAd4Cvgx8DrgJuF/S8TX2lAyHfYiTdLqkFyT9r6Stkv5F0uEHvOxCSZsk7ZD0j5IOGzD/PEnrJP1C0hOSjquq14jYExG3RMTmiNgXEY8B/wOcWtU6bT+HfejbC/wlcAxwFjAT+PoBr5kNTAOmAhcD8wAkXQLcCPwh8KvAcmBpOyuV9K/ZH5jBbq+2uYxxwGTgtXZeb8XI58YPPZI2A1dExL8PUrsG+HJEzM6eB3BBRPwoe/514I8iYqakHwIPRMTirHYYsBs4ISLeyOadFBEbK/g3jAR+CPwsIq4qe/l2MG/ZhzhJkyU9JunnknYCf09jKz/QWwMevwH8evb4OODb/Vtk4F1AwISKez4M+C7wETC/ynXZfg770LcAWE9jCzyaxm65DnjNsQMe/wbwTvb4LeCqiPjlAbdfioj/arVSSQsl7W5ya7pbLknAYmAcjT2Mj9v/p1oRDvvQdxSwE9gt6beBvxjkNX8t6WhJxwLfAH6QTV8I3CDpJABJn5M0p52VRsSfR8SRTW4n5cy6ADgB+IOI+KDNf6OVwGEf+q4F/hjYBdzN/iAP9CiwEngFeJzGlpWIeBj4B+C+7CPAWqCy497ZN/1XAacAPx+wJ/AnVa3T9vMXdGaJ8JbdLBEOu1kiHHazRDjsZon4TDdXlp2RZWYViogDz7MACm7ZJZ0v6aeSNkq6vsiyzKxaHR96kzQCeB34CrAFeBm4NCJ+kjOPt+xmFatiy346sDEiNkXER8B9NK6oMrMeVCTsE/j0BRZbGOQCCkl9klZIWlFgXWZWUJEv6AbbVThoNz0iFgGLwLvxZnUqsmXfwqevpvo8+6+mMrMeUyTsLwOTJE3Mfgbpa8Cyctoys7J1vBsfEZ9Img88AYwA7okI/7yQWY/q6lVv/sxuVr1KTqoxs6HDYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIro6ZHMvmz17dm79jDPO6FInB5s6dWpufcaMGU1rCxYsyJ13/vz5ufUHHnggt37rrbfm1tetW9e0tnfv3tx5rVzespslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmifAorpl9+/bl1rv5Pg0nJ510UtPa+vXru9hJOpqN4lropBpJm4FdwF7gk4iYVmR5ZladMs6g+7
2I2FHCcsysQv7MbpaIomEP4ElJKyX1DfYCSX2SVkhaUXBdZlZA0d34syPiHUljgackrY+I5wa+ICIWAYugt7+gMxvuCm3ZI+Kd7H478DBwehlNmVn5Og67pFGSjup/DMwC1pbVmJmVq8hu/DjgYUn9y/l+RPyolK5qMGfOnNz6rFmzmtamT5+eO++ECRNy66+//npuvZW868IXLlyYO++1116bW58yZUpHPfXL+52A22+/vdCy7dB0HPaI2AT8Tom9mFmFfOjNLBEOu1kiHHazRDjsZolw2M0S4UtcSzB69Ojc+sSJE3Prq1evLrOdQ3LiiSfm1tesWVNo+UuWLGlamzdvXqFl2+CaXeLqLbtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulggP2VyCnTt35tbrPI5et5NPPrlpbdSoUbnz7tmzp+x2kuYtu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCB9nt0rlXQ/v4+jd5S27WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIH2e3Sr3wwgt1t2CZllt2SfdI2i5p7YBpYyQ9JWlDdn90tW2aWVHt7MbfC5x/wLTrgacjYhLwdPbczHpYy7BHxHPAuwdMvhjoH9dnCXBJyX2ZWck6/cw+LiK2AkTEVkljm71QUh/Q1+F6zKwklX9BFxGLgEUwfAd2NBsKOj30tk3SeIDsfnt5LZlZFToN+zJgbvZ4LvBoOe2YWVVa7sZLWgqcCxwjaQvwTeAO4H5JlwNvAnOqbNKqM3PmzEqXv3z58kqXb+1rGfaIuLRJqdr/JWZWKp8ua5YIh90sEQ67WSIcdrNEOOxmifAlrok79dRT627BusRbdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sET7OPswdfvjhufWxY5v+olhbXnrppdz622+/XWj5Vh5v2c0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRPg4+zA3ZsyY3Pp5551XaPnPP/98bn3Xrl2Flm/l8ZbdLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEj7MPc1dccUWh+VetWpVbv/nmmwst37qn5ZZd0j2StktaO2DaLZLelvRKdruw2jbNrKh2duPvBc4fZPo/R8Qp2e3fym3LzMrWMuwR8Rzwbhd6MbMKFfmCbr6kV7Pd/KObvUhSn6QVklYUWJeZFdRp2BcAXwBOAbYC32r2wohYFBHTImJah+sysxJ0FPaI2BYReyNiH3A3cHq5bZlZ2ToKu6TxA57OBtY2e62Z9YaWx9klLQXOBY6RtAX4JnCupFOAADYDV1XYo7WQ99vw1113XaFlf/DBB7n1999/v9DyrXtahj0iLh1k8uIKejGzCvl0WbNEOOxmiXDYzRLhsJslwmE3S4QvcR0GzjnnnKa1I444otCyly9fXmj+IiTl1keNGpVbHzduXNPa3Llzc+e97LLLcuvLli3Lrd9000259ffeey+3XgVv2c0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRPg4+zBQ5OecP/7449z6M8880/GyAcaPH9+0Nn369Nx5Z82alVsv+jPZRVx99dW59T179uTWb7jhhjLbaYu37GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZInycfQhodbz5rLPO6njZd911V279xRdfzK3feeedufUrr7yyaW306NG581q5vGU3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLRzpDNxwLfAX4N2AcsiohvSxoD/AA4nsawzV+NiF9U12q6Jk2alFsfMWJEx8u+6KKLcuszZszIrZ922mkdr7uoDz/8MLf+yCOPNK3de++9hdbd19eXW9+0aVOh5VehnS37J8BfRcQJwJnA1ZJOBK4Hno6IScDT2XMz61Etwx4RWyNiVfZ4F7AOmABcDCzJXrYEuKSqJs2suEP6zC7peOCLwI+BcRGxFRp/EICxZTdnZuVp+9x4SUcCDwLXRM
TOVuNwDZivD8j/gGNmlWtryy5pJI2gfy8iHsomb5M0PquPB7YPNm9ELIqIaRExrYyGzawzLcOuxiZ8MbAuIgZeIrUM6B8Kcy7waPntmVlZ2tmNPxv4U2CNpFeyaTcCdwD3S7oceBOYU02LNnny5MqWPWXKlMqW3cqGDRty6ytXrsytt7q8dvXq1YfcU7uefPLJypZdlZZhj4j/BJp9QJ9ZbjtmVhWfQWeWCIfdLBEOu1kiHHazRDjsZolw2M0S4Z+SHgImTpxYdwtNvfnmm7n1vJ+qXrp0ae68O3bs6KgnG5y37GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhQR3VuZ1L2VDSNnnnlmbv3ZZ59tWhs5cmTuvI8//nhu/bbbbsutr1+/Pre+e/fu3LqVLyIGvSTdW3azRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBE+zm42zPg4u1niHHazRDjsZolw2M0S4bCbJcJhN0uEw26WiJZhl3SspGckrZP0mqRvZNNvkfS2pFey24XVt2tmnWp5Uo2k8cD4iFgl6ShgJXAJ8FVgd0T8U9sr80k1ZpVrdlJNyxFhImIrsDV7vEvSOmBCue2ZWdUO6TO7pOOBLwI/zibNl/SqpHskHd1knj5JKyStKNSpmRXS9rnxko4E/gP4u4h4SNI4YAcQwN/S2NWf12IZ3o03q1iz3fi2wi5pJPAY8EREHDRSX7bFfywiprRYjsNuVrGOL4SRJGAxsG5g0LMv7vrNBtYWbdLMqtPOt/FfApYDa4B92eQbgUuBU2jsxm8Grsq+zMtblrfsZhUrtBtfFofdrHq+nt0scQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslouUPTpZsB/DGgOfHZNN6Ua/21qt9gXvrVJm9Hdes0NXr2Q9aubQiIqbV1kCOXu2tV/sC99apbvXm3XizRDjsZomoO+yLal5/nl7trVf7AvfWqa70VutndjPrnrq37GbWJQ67WSJqCbuk8yX9VNJGSdfX0UMzkjZLWpMNQ13r+HTZGHrbJa0dMG2MpKckbcjuBx1jr6beemIY75xhxmt97+oe/rzrn9kljQBeB74CbAFeBi6NiJ90tZEmJG0GpkVE7SdgSPpdYDfwnf6htSTdCbwbEXdkfyiPjojreqS3WzjEYbwr6q3ZMON/Ro3vXZnDn3eiji376cDGiNgUER8B9wEX19BHz4uI54B3D5h8MbAke7yExn+WrmvSW0+IiK0RsSp7vAvoH2a81vcup6+uqCPsE4C3BjzfQm+N9x7Ak5JWSuqru5lBjOsfZiu7H1tzPwdqOYx3Nx0wzHjPvHedDH9eVB1hH2xoml46/nd2REwFLgCuznZXrT0LgC/QGANwK/CtOpvJhhl/ELgmInbW2ctAg/TVlfetjrBvAY4d8PzzwDs19DGoiHgnu98OPEzjY0cv2dY/gm52v73mfv5fRGyLiL0RsQ+4mxrfu2yY8QeB70XEQ9nk2t+7wfrq1vtWR9hfBiZJmijpcOBrwLIa+jiIpFHZFydIGgXMoveGol4GzM0ezwUerbGXT+mVYbybDTNOze9d7cOfR0TXb8CFNL6R/xnwN3X00KSv3wRWZ7fX6u4NWEpjt+5jGntElwO/AjwNbMjux/RQb9+lMbT3qzSCNb6m3r5E46Phq8Ar2e3Cut+7nL668r75dFmzRPgMOrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEf8H/mrOsdln0l4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "batch_iterator = iter(dataloaders_dict[\"train\"]) # イテレータに変換\n", "imges, labels = next(batch_iterator) # 1番目の要素を取り出す\n", "print(\"imges size = \", imges.size())\n", "print(\"labels size = \", labels.size())\n", "\n", "#試しに1枚 plot してみる\n", "plt.imshow(imges[0].numpy().reshape(28,28), cmap='gray')\n", "plt.title(\"label = {}\".format(labels[0].numpy()))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# ネットワークモデルの作成\n", "- 畳み込み層CNNのモデル" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# 畳み込み層+全結合層のネットワークモデル\n", "class Net(nn.Module):\n", "\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.conv1 = nn.Conv2d(1, 32, 3, 1) #畳み込み層\n", " self.conv2 = nn.Conv2d(32, 64, 3, 1) #畳み込み層\n", " self.fc1 = nn.Linear(9216, 128) #全結合層\n", " self.fc2 = nn.Linear(128, 10) #全結合層\n", "\n", " def forward(self, x):\n", " x = self.conv1(x) # (Batch, 1, 28, 28) -> (Batch, 32, 26, 26)\n", " x = F.relu(x)\n", " x = self.conv2(x) # (Batch, 32, 26, 26) -> (Batch, 64, 24, 24)\n", " x = F.relu(x)\n", " x = F.max_pool2d(x, 2) # (Batch, 64, 24, 24) -> (Batch, 64, 12, 12)\n", " x = torch.flatten(x, 1) # (Batch, 64, 12, 12) -> (Batch, 9216)\n", " x = self.fc1(x) # (Batch, 9216) -> (Batch, 128)\n", " x = self.fc2(x) # (Batch, 128) -> (Batch, 10)\n", "\n", " return x" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "#モデル作成\n", "net = Net()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Net(\n", " (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n", " (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n", " (fc1): Linear(in_features=9216, out_features=128, bias=True)\n", " (fc2): Linear(in_features=128, out_features=10, bias=True)\n", ")\n" ] } ], "source": [ 
"#ネットワークのレイヤー確認\n", "print(net)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 損失関数の定義" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# nn.CrossEntropyLoss() はソフトマックス関数+クロスエントロピー誤差\n", "criterion = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 最適化手法の設定" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "optimizer = optim.Adam(net.parameters(), lr=0.001)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 学習・検証の実施" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# モデルを学習させる関数を作成\n", "def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):\n", " \n", " # epochのループ\n", " for epoch in range(num_epochs):\n", " print('Epoch {}/{}'.format(epoch+1, num_epochs))\n", " print('-------------')\n", "\n", " # epochごとの学習と検証のループ\n", " for phase in ['train', 'val']:\n", " if phase == 'train':\n", " net.train() # モデルを訓練モードに\n", " else:\n", " net.eval() # モデルを検証モードに\n", "\n", " epoch_loss = 0.0 # epochの損失和\n", " epoch_corrects = 0 # epochの正解数\n", "\n", " # 未学習時の検証性能を確かめるため、epoch=0の訓練は省略\n", " if (epoch == 0) and (phase == 'train'):\n", " continue\n", "\n", " # データローダーからミニバッチを取り出すループ\n", " for i , (inputs, labels) in tqdm(enumerate(dataloaders_dict[phase])):\n", "\n", " # optimizerを初期化\n", " optimizer.zero_grad()\n", "\n", " # 順伝搬(forward)計算\n", " with torch.set_grad_enabled(phase == 'train'): # 訓練モードのみ勾配を算出\n", " outputs = net(inputs) # 順伝播\n", " loss = criterion(outputs, labels) # 損失を計算\n", " _, preds = torch.max(outputs, 1) # ラベルを予測\n", " \n", " \n", " # 訓練時はバックプロパゲーション\n", " if phase == 'train':\n", " loss.backward()\n", " optimizer.step()\n", "\n", " # イタレーション結果の計算\n", " # lossの合計を更新\n", " epoch_loss += loss.item() * inputs.size(0) \n", " # 正解数の合計を更新\n", " epoch_corrects += torch.sum(preds == labels.data)\n", "\n", " # epochごとのlossと正解率を表示\n", " epoch_loss = 
epoch_loss / len(dataloaders_dict[phase].dataset)\n", " epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)\n", " \n", " print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "3it [00:00, 23.63it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "-------------\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "157it [00:05, 30.87it/s]\n", "1it [00:00, 9.87it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "val Loss: 2.3047 Acc: 0.1128\n", "Epoch 2/3\n", "-------------\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "938it [01:11, 13.09it/s]\n", "3it [00:00, 28.23it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "train Loss: 0.1247 Acc: 0.9617\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "157it [00:04, 33.20it/s]\n", "2it [00:00, 13.95it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "val Loss: 0.0486 Acc: 0.9844\n", "Epoch 3/3\n", "-------------\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "938it [01:26, 10.81it/s]\n", "3it [00:00, 26.54it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "train Loss: 0.0447 Acc: 0.9862\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "157it [00:05, 26.97it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "val Loss: 0.0338 Acc: 0.9894\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# 学習・検証を実行する\n", "num_epochs = 3\n", "train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# テストデータに対する予測" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAASjUlEQVR4nO3dfbBU9X3H8fcnSKIVo4CiiDwYo6nWcVDRmBEdjIkh2AwaNZFMUtpGr51RU2cso2PjoK3tMObByHTGDD5UNEakEh9rLIhG6qiJaFQg+ACWAIKABRWsCsi3f5yDWa+7Z+/dZ+7v85rZubvnu+fsd8+9n3ue9uxRRGBmfd+n2t2AmbWGw26WCIfdLBEOu1kiHHazRDjsZolw2LuR9BtJ57V6XGs8SW9IGpvfv1rSv7W7p3bqs2GXtELSV9rdRzmSfi5pS37bKmlbyeNft6GfWZJ+WMN4nynpe+dth6Qf9eJ1P8jH2yjpYUmH9v4dVBcRUyPioh721PJ50Qp9NuydLCL+LiIGRMQA4F+Bu3Y+joiv93Z6kvo1vsvqIuKDkr4HAAcBHwD/0YvJ/HM+7ghgM3BjuSdJ2q3uhpuoQfOiqZILu6SBkh6UtEHSpvz+Qd2edoik30l6W9J9kgaVjH+CpCclvSXpBUnjmtDjbpLmSFqXv85jkr5QUp8labqkuZLeBb4kaYikX0t6R9LTkqZJeqRknCMlPZq/56WSzsiH/wA4C7gyXxrV88f5bWBFRPyutyNGxBZgFnBk3tc0Sb+UdJekzcC5kvpJulLSa5LelHSHpH1K3uP3Ja3Mf7dTSqefT++mksfj8vn0dj7OdzplXjRLcmEne8//DowkW5q8B3Tflvsr4G+BA4HtwHQAScOA/wSuAQYB/wDMkbRf9xeRNCIP6oga+7wfOAQ4AHgJmNmt/l3gSmAv4BlgBrAB2B/oAiaX9PJZYB5wM7Bv/v5ukfT5iJgOzCFfwkbEOfk48/L+y93urtDz5DJ99kje4yTg9yWDz8qnt3fe4xTgNGAs2ZJzG3BdPv5o4GdkITsIGJW/13Kv9XngQeBHwGDgWGBJp8yLpomIPnkDVgBf6cHzRgObSh7/BphW8vgIYCvQD7gMuL3b+P8FTC4Z97xe9nkV8IsqzzkA2AHsnj+eBcwoqe+e10eWDPsx8Eh+fzIwr9s0ZwKXlUzvh3XO70PJ/jEO68U4s8j+2b4FrAXu2fkegGnA3G7P/x/gxJLHBwP/B4hsc+jWktre+TwZWzK9m/L7VwN3FvTU8nnRiltHbwc1g6Q/I1sajAcG5oP3ktQvIj7MH68qGeWPQH+ypcRI4BxJ3yip9wcea3CPu5H9cX4zf90dZH/Qg4HXy/R4QF5fXTJsFdk/MvK+T5b0Vkl9N2BTA9ueDMyPiNerPvPj/iUirqlQ++g9ShIwHHhIUunZW58imy8Hlj4/It6W9HaF6Q4Hlveyz96odV40VYqr8ZcCXwC+GBGfBU7Oh6vkOcNL7o8gW118k+yP6faI2KfktmdETGtwj39Dtrp6CtkS6s/L9Fj6B/9G/nhYhfewimwpWdr3gIi4pMy0shfKtu+7713eebun23MFfI/Gr7Z+1Fdki8zXgS93ex+7R8SbZGsGH71nSXuTzbtyVpFtIhW+Zsm0OmFe1K2vh72/pN1LbruRbeO+B7yV73ibWma870o6Il8L+Cfg7nyp/wvgG5K+lu8s2j3f0dN9B1+99gLeB/4X2JNsH0FFEfE+8ABwdd7TkcB3Sp5yL3C0pG9L6i/p0/mOxsPy+jrgc92m+eUo2bvc7XZmtxZOAfYhWw3/SN5LSDqhd2+/op8D0yQNz6c/pGQtazbwTUlflPQZsnm2o8J0bgP+UtKZ+e9xP0lH5bWmzItO0NfD/hBZsHferiLbibMH2ZL6aeDhMuPdDtxKtsTcHfgBQESsAiYCV5DtDFtFttPoE/Mx30G3pcYddDfn038DWAQ80YNxLiB
bld0A3ATcSXboh4jYBHyNbI1hLbCGLAz983FnAMflO5xm1dDvZGB2RLzXbfhwsu3xP9QwzXKuBR4BHs330D8JHAMQEb8nW2u7m2xzZiXZ7/gTImI5f/o9bgIWAn+Rl5s1L9pO+Q4F62MkXU+2Q++CNvZwHtlOqqvb1YP9icPeR+Sr7kG2FP0S2SHCSRFRbs3FEpTc3vg+bG+yzY8DyFb/r3HQrZSX7GaJ6Os76Mws19LV+G4fhjCzJogIlRte15Jd0nhJL0taJunyeqZlZs1V8za7stMqXwG+SnZc8xmyvb8Vj6l6yW7WfM1Ysh8PLIuI1yJiK9kJBBPrmJ6ZNVE9YR/Gx0/GWM3HP5sNgKQuSQslLazjtcysTvXsoCu3qvCJ1fSImEH2EUSvxpu1UT1L9tV8/Myqg8g+c21mHaiesD8DHCrpYEmfBs4l+3YVM+tANa/GR8R2SReRfVNLP+CWiFjSsM7MrKFa+nFZb7ObNV9TPlRjZrsOh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRJR8/XZASStADYDHwLbI2JMI5oys8arK+y5UyLizQZMx8yayKvxZomoN+wBzJX0rKSuck+Q1CVpoaSFdb6WmdVBEVH7yNKBEbFG0hBgHnBxRCwoeH7tL2ZmPRIRKje8riV7RKzJf64H7gGOr2d6ZtY8NYdd0p6S9tp5HzgNWNyoxsysserZG78/cI+kndP5ZUQ83JCuzKzh6tpm7/WLeZvdrOmass1uZrsOh90sEQ67WSIcdrNEOOxmiWjEiTBJOPvssyvWzj///MJx16xZU1h///33C+t33HFHYf2NN96oWFu2bFnhuJYOL9nNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0T4rLceeu211yrWRo0a1bpGyti8eXPF2pIlS1rYSWdZvXp1xdq1115bOO7Chbvut6j5rDezxDnsZolw2M0S4bCbJcJhN0uEw26WCIfdLBE+n72His5ZP+qoowrHXbp0aWH98MMPL6wfc8wxhfVx48ZVrJ1wwgmF465ataqwPnz48MJ6PbZv315Y37BhQ2F96NChNb/2ypUrC+u78nH2SrxkN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S4fPZ+4CBAwdWrI0ePbpw3Geffbawftxxx9XUU09U+778V155pbBe7fMLgwYNqli78MILC8e94YYbCuudrObz2SXdImm9pMUlwwZJmifp1fxn5b82M+sIPVmNvxUY323Y5cD8iDgUmJ8/NrMOVjXsEbEA2Nht8ERgZn5/JnBGg/syswar9bPx+0fEWoCIWCtpSKUnSuoCump8HTNrkKafCBMRM4AZ4B10Zu1U66G3dZKGAuQ/1zeuJTNrhlrDfj8wOb8/GbivMe2YWbNUPc4u6U5gHLAvsA6YCtwLzAZGACuBcyKi+068ctPyarz12FlnnVVYnz17dmF98eLFFWunnHJK4bgbN1b9c+5YlY6zV91mj4hJFUqn1tWRmbWUPy5rlgiH3SwRDrtZIhx2s0Q47GaJ8Cmu1jZDhlT8lDUAixYtqmv8s88+u2Jtzpw5hePuynzJZrPEOexmiXDYzRLhsJslwmE3S4TDbpYIh90sEb5ks7VNta9z3m+//QrrmzZtKqy//PLLve6pL/OS3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhM9nt6Y68cQTK9YeffTRwnH79+9fWB83blxhfcGCBYX1vsrns5slzmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmifD57NZUEyZMqFirdhx9/vz5hfWnnnqqpp5SVXXJLukWSeslLS4ZdpWk1yU9n98q/0bNrCP0ZDX+VmB8meHXRcTo/PZQY9s
ys0arGvaIWABsbEEvZtZE9eygu0jSi/lq/sBKT5LUJWmhpIV1vJaZ1anWsN8AHAKMBtYCP6n0xIiYERFjImJMja9lZg1QU9gjYl1EfBgRO4AbgeMb25aZNVpNYZc0tOThmcDiSs81s85Q9Ti7pDuBccC+klYDU4FxkkYDAawALmhij9bB9thjj8L6+PHlDuRktm7dWjju1KlTC+vbtm0rrNvHVQ17REwqM/jmJvRiZk3kj8uaJcJhN0uEw26WCIfdLBEOu1kifIqr1WXKlCmF9aOPPrpi7eGHHy4c98knn6ypJyvPS3azRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBG+ZLMVOv300wvr9957b2H93XffrVgrOv0V4Omnny6sW3m+ZLNZ4hx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgifz564wYMHF9anT59eWO/Xr19h/aGHKl/z08fRW8tLdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEVXPZ5c0HLgNOADYAcyIiOslDQLuAkaRXbb5WxGxqcq0fD57i1U7Dl7tWPexxx5bWF++fHlhveic9WrjWm3qOZ99O3BpRBwOnABcKOkI4HJgfkQcCszPH5tZh6oa9ohYGxHP5fc3A0uBYcBEYGb+tJnAGc1q0szq16ttdkmjgKOB3wL7R8RayP4hAEMa3ZyZNU6PPxsvaQAwB7gkIt6Rym4WlBuvC+iqrT0za5QeLdkl9ScL+h0R8at88DpJQ/P6UGB9uXEjYkZEjImIMY1o2MxqUzXsyhbhNwNLI+KnJaX7gcn5/cnAfY1vz8wapSeH3sYC/w0sIjv0BnAF2Xb7bGAEsBI4JyI2VpmWD7212GGHHVZYf+mll+qa/sSJEwvrDzzwQF3Tt96rdOit6jZ7RDwBVNpAP7WepsysdfwJOrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIf5V0HzBy5MiKtblz59Y17SlTphTWH3zwwbqmb63jJbtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulggfZ+8Duroqf+vXiBEj6pr2448/Xliv9n0I1jm8ZDdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuHj7LuAsWPHFtYvvvjiFnViuzIv2c0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRFQ9zi5pOHAbcADZ9dlnRMT1kq4Czgc25E+9IiIealajKTvppJMK6wMGDKh52suXLy+sb9mypeZpW2fpyYdqtgOXRsRzkvYCnpU0L69dFxE/bl57ZtYoVcMeEWuBtfn9zZKWAsOa3ZiZNVavttkljQKOBn6bD7pI0ouSbpE0sMI4XZIWSlpYV6dmVpceh13SAGAOcElEvAPcABwCjCZb8v+k3HgRMSMixkTEmAb0a2Y16lHYJfUnC/odEfErgIhYFxEfRsQO4Ebg+Oa1aWb1qhp2SQJuBpZGxE9Lhg8tedqZwOLGt2dmjdKTvfEnAt8DFkl6Ph92BTBJ0mgggBXABU3p0OrywgsvFNZPPfXUwvrGjRsb2Y61UU/2xj8BqEzJx9TNdiH+BJ1ZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhFp5yV1Jvr6vWZNFRLlD5V6ym6XCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJaPUlm98E/ljyeN98WCfq1N46tS9wb7VqZG8jKxVa+qGaT7y4tLBTv5uuU3vr1L7AvdWqVb15Nd4sEQ67WSLaHfYZbX79Ip3aW6f2Be6tVi3pra3b7GbWOu1esptZizjsZoloS9gljZf0sqRlki5vRw+VSFohaZGk59t9fbr8GnrrJS0uGTZI0jxJr+Y/y15jr029XSXp9XzePS9pQpt6Gy7pMUlLJS2R9Pf58LbOu4K+WjLfWr7NLqkf8ArwVWA18AwwKSL+0NJGKpC0AhgTEW3/AIakk4EtwG0RcWQ+7FpgY0RMy/9RDoyIyzqkt6uALe2+jHd+taKhpZcZB84
A/po2zruCvr5FC+ZbO5bsxwPLIuK1iNgKzAImtqGPjhcRC4Dul2SZCMzM788k+2NpuQq9dYSIWBsRz+X3NwM7LzPe1nlX0FdLtCPsw4BVJY9X01nXew9grqRnJXW1u5ky9o+ItZD98QBD2txPd1Uv491K3S4z3jHzrpbLn9erHWEv9/1YnXT878SIOAb4OnBhvrpqPdOjy3i3SpnLjHeEWi9/Xq92hH01MLzk8UHAmjb0UVZErMl/rgfuofMuRb1u5xV085/r29zPRzrpMt7lLjNOB8y7dl7+vB1hfwY4VNLBkj4NnAvc34Y+PkHSnvmOEyTtCZxG512K+n5gcn5/MnBfG3v5mE65jHely4zT5nnX9sufR0TLb8AEsj3yy4F/bEcPFfr6HPBCflvS7t6AO8lW67aRrRF9HxgMzAdezX8O6qDebgcWAS+SBWtom3obS7Zp+CLwfH6b0O55V9BXS+abPy5rlgh/gs4sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S8T/AzzvhOkEgmGIAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "batch_iterator = iter(dataloaders_dict[\"val\"]) # イテレータに変換\n", "imges, labels = next(batch_iterator) # 1番目の要素を取り出す\n", "\n", "net.eval() #推論モード\n", "with torch.set_grad_enabled(False): # 推論モードでは勾配を算出しない\n", " outputs = net(imges) # 順伝播\n", " _, preds = torch.max(outputs, 1) # ラベルを予測\n", " \n", "#テストデータの予測結果を描画\n", "plt.imshow(imges[0].numpy().reshape(28,28), cmap='gray')\n", "plt.title(\"Label: Target={}, Predict={}\".format(labels[0], preds[0].numpy()))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "以上" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }