{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.6.8\n", "IPython 7.2.0\n", "\n", "torch 1.0.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Multilayer Perceptron From Scratch (Sigmoid activation, MSE Loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation of a 1-hidden layer multi-layer perceptron from scratch using\n", "- sigmoid activation in the hidden layer\n", "- sigmoid activation in the output layer\n", "- Mean Squared Error loss function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import torch\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "from torch.utils.data import DataLoader\n", "import torch.nn.functional as F\n", "import torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Settings and Dataset" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image batch dimensions: torch.Size([100, 1, 28, 28])\n", "Image label dimensions: torch.Size([100])\n" ] } ], "source": [ "##########################\n", 
"### SETTINGS\n", "##########################\n", "\n", "RANDOM_SEED = 1\n", "BATCH_SIZE = 100\n", "NUM_EPOCHS = 50\n", "\n", "##########################\n", "### MNIST DATASET\n", "##########################\n", "\n", "# Note transforms.ToTensor() scales input images\n", "# to 0-1 range\n", "train_dataset = datasets.MNIST(root='data', \n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", "test_dataset = datasets.MNIST(root='data', \n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=BATCH_SIZE, \n", " shuffle=True)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=BATCH_SIZE, \n", " shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Implementation" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "class MultilayerPerceptron():\n", "\n", " def __init__(self, num_features, num_hidden, num_classes):\n", " super(MultilayerPerceptron, self).__init__()\n", " \n", " self.num_classes = num_classes\n", " \n", " # hidden 1\n", " self.weight_1 = torch.zeros(num_hidden, num_features, \n", " dtype=torch.float).normal_(0.0, 0.1)\n", " self.bias_1 = torch.zeros(num_hidden, dtype=torch.float)\n", " \n", " # output\n", " self.weight_o = torch.zeros(self.num_classes, num_hidden, \n", " dtype=torch.float).normal_(0.0, 0.1)\n", " self.bias_o = torch.zeros(self.num_classes, dtype=torch.float)\n", " \n", " def forward(self, x):\n", " # hidden 1\n", " \n", " # input dim: [n_hidden, n_features] dot [n_features, n_examples] .T\n", " # output dim: [n_examples, n_hidden]\n", " z_1 = 
torch.mm(x, self.weight_1.t()) + self.bias_1\n", " a_1 = torch.sigmoid(z_1)\n", "\n", " # hidden 2\n", " # input dim: [n_classes, n_hidden] dot [n_hidden, n_examples] .T\n", " # output dim: [n_examples, n_classes]\n", " z_2 = torch.mm(a_1, self.weight_o.t()) + self.bias_o\n", " a_2 = torch.sigmoid(z_2)\n", " return a_1, a_2\n", "\n", " def backward(self, x, a_1, a_2, y): \n", " \n", " #########################\n", " ### Output layer weights\n", " #########################\n", " \n", " # onehot encoding\n", " y_onehot = torch.FloatTensor(y.size(0), self.num_classes)\n", " y_onehot.zero_()\n", " y_onehot.scatter_(1, y.view(-1, 1).long(), 1)\n", " \n", "\n", " # Part 1: dLoss/dOutWeights\n", " ## = dLoss/dOutAct * dOutAct/dOutNet * dOutNet/dOutWeight\n", " ## where DeltaOut = dLoss/dOutAct * dOutAct/dOutNet\n", " ## for convenient re-use\n", " \n", " # input/output dim: [n_examples, n_classes]\n", " dloss_da2 = 2.*(a_2 - y_onehot) / y.size(0)\n", "\n", " # input/output dim: [n_examples, n_classes]\n", " da2_dz2 = a_2 * (1. - a_2) # sigmoid derivative\n", "\n", " # output dim: [n_examples, n_classes]\n", " delta_out = dloss_da2 * da2_dz2 # \"delta (rule) placeholder\"\n", "\n", " # gradient for output weights\n", " \n", " # [n_examples, n_hidden]\n", " dz2__dw_out = a_1\n", " \n", " # input dim: [n_classlabels, n_examples] dot [n_examples, n_hidden]\n", " # output dim: [n_classlabels, n_hidden]\n", " dloss__dw_out = torch.mm(delta_out.t(), dz2__dw_out)\n", " dloss__db_out = torch.sum(delta_out, dim=0)\n", " \n", "\n", " ################################# \n", " # Part 2: dLoss/dHiddenWeights\n", " ## = DeltaOut * dOutNet/dHiddenAct * dHiddenAct/dHiddenNet * dHiddenNet/dWeight\n", " \n", " # [n_classes, n_hidden]\n", " dz2__a1 = self.weight_o\n", " \n", " # output dim: [n_examples, n_hidden]\n", " dloss_a1 = torch.mm(delta_out, dz2__a1)\n", " \n", " # [n_examples, n_hidden]\n", " da1__dz1 = a_1 * (1. 
- a_1) # sigmoid derivative\n", " \n", " # [n_examples, n_features]\n", " dz1__dw1 = x\n", " \n", " # output dim: [n_hidden, n_features]\n", " dloss_dw1 = torch.mm((dloss_a1 * da1__dz1).t(), dz1__dw1)\n", " dloss_db1 = torch.sum((dloss_a1 * da1__dz1), dim=0)\n", "\n", " return dloss__dw_out, dloss__db_out, dloss_dw1, dloss_db1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "####################################################\n", "##### Training and evaluation wrappers\n", "###################################################\n", "\n", "def to_onehot(y, num_classes):\n", " y_onehot = torch.FloatTensor(y.size(0), num_classes)\n", " y_onehot.zero_()\n", " y_onehot.scatter_(1, y.view(-1, 1).long(), 1).float()\n", " return y_onehot\n", "\n", "\n", "def loss_func(targets_onehot, probas_onehot):\n", " return torch.mean(torch.mean((targets_onehot - probas_onehot)**2, dim=0))\n", "\n", "\n", "def compute_mse(net, data_loader):\n", " # Mean squared error of net's output probabilities over every example\n", " # in data_loader (summed per class, then averaged over the classes).\n", " curr_mse, num_examples = torch.zeros(net.num_classes).float(), 0\n", " with torch.no_grad():\n", " for features, targets in data_loader:\n", " features = features.view(-1, 28*28)\n", " logits, probas = net.forward(features)\n", " y_onehot = to_onehot(targets, net.num_classes)\n", " loss = torch.sum((y_onehot - probas)**2, dim=0)\n", " num_examples += targets.size(0)\n", " curr_mse += loss\n", "\n", " curr_mse = torch.mean(curr_mse/num_examples, dim=0)\n", " return curr_mse\n", "\n", "\n", "def train(model, data_loader, num_epochs,\n", " learning_rate=0.1):\n", " # Minibatch gradient descent on data_loader for num_epochs epochs.\n", " # Returns (per-minibatch costs, per-epoch train MSE on data_loader).\n", " \n", " minibatch_cost = []\n", " epoch_cost = []\n", " \n", " for e in range(num_epochs):\n", " \n", " for batch_idx, (features, targets) in enumerate(data_loader):\n", " \n", " features = features.view(-1, 28*28)\n", " \n", " #### Compute outputs ####\n", " a_1, a_2 = model.forward(features)\n", "\n", " #### Compute gradients ####\n", " dloss__dw_out, 
dloss__db_out, dloss_dw1, dloss_db1 = \\\n", " model.backward(features, a_1, a_2, targets)\n", "\n", " #### Update weights ####\n", " model.weight_1 -= learning_rate * dloss_dw1\n", " model.bias_1 -= learning_rate * dloss_db1\n", " model.weight_o -= learning_rate * dloss__dw_out\n", " model.bias_o -= learning_rate * dloss__db_out\n", " \n", " #### Logging ####\n", " curr_cost = loss_func(to_onehot(targets, model.num_classes), a_2)\n", " minibatch_cost.append(curr_cost)\n", " if not batch_idx % 50:\n", " print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n", " %(e+1, num_epochs, batch_idx, \n", " len(data_loader), curr_cost))\n", " \n", " #### Logging #### \n", " curr_cost = compute_mse(model, data_loader)\n", " epoch_cost.append(curr_cost)\n", " print('Epoch: %03d/%03d |' % (e+1, num_epochs), end=\"\")\n", " print(' Train MSE: %.5f' % curr_cost)\n", "\n", " return minibatch_cost, epoch_cost" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/050 | Batch 000/600 | Cost: 0.2471\n", "Epoch: 001/050 | Batch 050/600 | Cost: 0.0885\n", "Epoch: 001/050 | Batch 100/600 | Cost: 0.0880\n", "Epoch: 001/050 | Batch 150/600 | Cost: 0.0877\n", "Epoch: 001/050 | Batch 200/600 | Cost: 0.0847\n", "Epoch: 001/050 | Batch 250/600 | Cost: 0.0838\n", "Epoch: 001/050 | Batch 300/600 | Cost: 0.0808\n", "Epoch: 001/050 | Batch 350/600 | Cost: 0.0801\n", "Epoch: 001/050 | Batch 400/600 | Cost: 0.0766\n", "Epoch: 001/050 | Batch 450/600 | Cost: 0.0740\n", "Epoch: 001/050 | Batch 500/600 | Cost: 0.0730\n", "Epoch: 001/050 | Batch 550/600 | Cost: 0.0730\n", "Epoch: 001/050 | Train MSE: 0.06566\n", "Epoch: 002/050 | Batch 000/600 | Cost: 0.0644\n", "Epoch: 002/050 | Batch 050/600 | Cost: 0.0637\n", "Epoch: 002/050 | Batch 100/600 | Cost: 0.0600\n", "Epoch: 002/050 | Batch 150/600 | Cost: 0.0580\n", "Epoch: 002/050 | Batch 200/600 | Cost: 0.0541\n", "Epoch: 002/050 | Batch 250/600 | Cost: 
0.0546\n", "Epoch: 002/050 | Batch 300/600 | Cost: 0.0547\n", "Epoch: 002/050 | Batch 350/600 | Cost: 0.0488\n", "Epoch: 002/050 | Batch 400/600 | Cost: 0.0515\n", "Epoch: 002/050 | Batch 450/600 | Cost: 0.0476\n", "Epoch: 002/050 | Batch 500/600 | Cost: 0.0486\n", "Epoch: 002/050 | Batch 550/600 | Cost: 0.0447\n", "Epoch: 002/050 | Train MSE: 0.04302\n", "Epoch: 003/050 | Batch 000/600 | Cost: 0.0413\n", "Epoch: 003/050 | Batch 050/600 | Cost: 0.0404\n", "Epoch: 003/050 | Batch 100/600 | Cost: 0.0374\n", "Epoch: 003/050 | Batch 150/600 | Cost: 0.0351\n", "Epoch: 003/050 | Batch 200/600 | Cost: 0.0374\n", "Epoch: 003/050 | Batch 250/600 | Cost: 0.0371\n", "Epoch: 003/050 | Batch 300/600 | Cost: 0.0347\n", "Epoch: 003/050 | Batch 350/600 | Cost: 0.0359\n", "Epoch: 003/050 | Batch 400/600 | Cost: 0.0373\n", "Epoch: 003/050 | Batch 450/600 | Cost: 0.0305\n", "Epoch: 003/050 | Batch 500/600 | Cost: 0.0333\n", "Epoch: 003/050 | Batch 550/600 | Cost: 0.0318\n", "Epoch: 003/050 | Train MSE: 0.03251\n", "Epoch: 004/050 | Batch 000/600 | Cost: 0.0292\n", "Epoch: 004/050 | Batch 050/600 | Cost: 0.0312\n", "Epoch: 004/050 | Batch 100/600 | Cost: 0.0273\n", "Epoch: 004/050 | Batch 150/600 | Cost: 0.0293\n", "Epoch: 004/050 | Batch 200/600 | Cost: 0.0293\n", "Epoch: 004/050 | Batch 250/600 | Cost: 0.0292\n", "Epoch: 004/050 | Batch 300/600 | Cost: 0.0350\n", "Epoch: 004/050 | Batch 350/600 | Cost: 0.0312\n", "Epoch: 004/050 | Batch 400/600 | Cost: 0.0276\n", "Epoch: 004/050 | Batch 450/600 | Cost: 0.0312\n", "Epoch: 004/050 | Batch 500/600 | Cost: 0.0315\n", "Epoch: 004/050 | Batch 550/600 | Cost: 0.0291\n", "Epoch: 004/050 | Train MSE: 0.02692\n", "Epoch: 005/050 | Batch 000/600 | Cost: 0.0282\n", "Epoch: 005/050 | Batch 050/600 | Cost: 0.0264\n", "Epoch: 005/050 | Batch 100/600 | Cost: 0.0231\n", "Epoch: 005/050 | Batch 150/600 | Cost: 0.0242\n", "Epoch: 005/050 | Batch 200/600 | Cost: 0.0260\n", "Epoch: 005/050 | Batch 250/600 | Cost: 0.0215\n", "Epoch: 005/050 | Batch 
300/600 | Cost: 0.0294\n", "Epoch: 005/050 | Batch 350/600 | Cost: 0.0220\n", "Epoch: 005/050 | Batch 400/600 | Cost: 0.0246\n", "Epoch: 005/050 | Batch 450/600 | Cost: 0.0229\n", "Epoch: 005/050 | Batch 500/600 | Cost: 0.0289\n", "Epoch: 005/050 | Batch 550/600 | Cost: 0.0240\n", "Epoch: 005/050 | Train MSE: 0.02365\n", "Epoch: 006/050 | Batch 000/600 | Cost: 0.0201\n", "Epoch: 006/050 | Batch 050/600 | Cost: 0.0223\n", "Epoch: 006/050 | Batch 100/600 | Cost: 0.0253\n", "Epoch: 006/050 | Batch 150/600 | Cost: 0.0258\n", "Epoch: 006/050 | Batch 200/600 | Cost: 0.0216\n", "Epoch: 006/050 | Batch 250/600 | Cost: 0.0282\n", "Epoch: 006/050 | Batch 300/600 | Cost: 0.0203\n", "Epoch: 006/050 | Batch 350/600 | Cost: 0.0218\n", "Epoch: 006/050 | Batch 400/600 | Cost: 0.0249\n", "Epoch: 006/050 | Batch 450/600 | Cost: 0.0211\n", "Epoch: 006/050 | Batch 500/600 | Cost: 0.0230\n", "Epoch: 006/050 | Batch 550/600 | Cost: 0.0174\n", "Epoch: 006/050 | Train MSE: 0.02156\n", "Epoch: 007/050 | Batch 000/600 | Cost: 0.0186\n", "Epoch: 007/050 | Batch 050/600 | Cost: 0.0223\n", "Epoch: 007/050 | Batch 100/600 | Cost: 0.0208\n", "Epoch: 007/050 | Batch 150/600 | Cost: 0.0231\n", "Epoch: 007/050 | Batch 200/600 | Cost: 0.0219\n", "Epoch: 007/050 | Batch 250/600 | Cost: 0.0206\n", "Epoch: 007/050 | Batch 300/600 | Cost: 0.0227\n", "Epoch: 007/050 | Batch 350/600 | Cost: 0.0249\n", "Epoch: 007/050 | Batch 400/600 | Cost: 0.0214\n", "Epoch: 007/050 | Batch 450/600 | Cost: 0.0203\n", "Epoch: 007/050 | Batch 500/600 | Cost: 0.0209\n", "Epoch: 007/050 | Batch 550/600 | Cost: 0.0160\n", "Epoch: 007/050 | Train MSE: 0.02006\n", "Epoch: 008/050 | Batch 000/600 | Cost: 0.0171\n", "Epoch: 008/050 | Batch 050/600 | Cost: 0.0232\n", "Epoch: 008/050 | Batch 100/600 | Cost: 0.0227\n", "Epoch: 008/050 | Batch 150/600 | Cost: 0.0156\n", "Epoch: 008/050 | Batch 200/600 | Cost: 0.0157\n", "Epoch: 008/050 | Batch 250/600 | Cost: 0.0189\n", "Epoch: 008/050 | Batch 300/600 | Cost: 0.0154\n", "Epoch: 
008/050 | Batch 350/600 | Cost: 0.0213\n", "Epoch: 008/050 | Batch 400/600 | Cost: 0.0158\n", "Epoch: 008/050 | Batch 450/600 | Cost: 0.0201\n", "Epoch: 008/050 | Batch 500/600 | Cost: 0.0176\n", "Epoch: 008/050 | Batch 550/600 | Cost: 0.0254\n", "Epoch: 008/050 | Train MSE: 0.01892\n", "Epoch: 009/050 | Batch 000/600 | Cost: 0.0195\n", "Epoch: 009/050 | Batch 050/600 | Cost: 0.0214\n", "Epoch: 009/050 | Batch 100/600 | Cost: 0.0255\n", "Epoch: 009/050 | Batch 150/600 | Cost: 0.0153\n", "Epoch: 009/050 | Batch 200/600 | Cost: 0.0184\n", "Epoch: 009/050 | Batch 250/600 | Cost: 0.0247\n", "Epoch: 009/050 | Batch 300/600 | Cost: 0.0151\n", "Epoch: 009/050 | Batch 350/600 | Cost: 0.0165\n", "Epoch: 009/050 | Batch 400/600 | Cost: 0.0171\n", "Epoch: 009/050 | Batch 450/600 | Cost: 0.0136\n", "Epoch: 009/050 | Batch 500/600 | Cost: 0.0206\n", "Epoch: 009/050 | Batch 550/600 | Cost: 0.0142\n", "Epoch: 009/050 | Train MSE: 0.01803\n", "Epoch: 010/050 | Batch 000/600 | Cost: 0.0183\n", "Epoch: 010/050 | Batch 050/600 | Cost: 0.0222\n", "Epoch: 010/050 | Batch 100/600 | Cost: 0.0203\n", "Epoch: 010/050 | Batch 150/600 | Cost: 0.0224\n", "Epoch: 010/050 | Batch 200/600 | Cost: 0.0234\n", "Epoch: 010/050 | Batch 250/600 | Cost: 0.0229\n", "Epoch: 010/050 | Batch 300/600 | Cost: 0.0179\n", "Epoch: 010/050 | Batch 350/600 | Cost: 0.0181\n", "Epoch: 010/050 | Batch 400/600 | Cost: 0.0122\n", "Epoch: 010/050 | Batch 450/600 | Cost: 0.0176\n", "Epoch: 010/050 | Batch 500/600 | Cost: 0.0198\n", "Epoch: 010/050 | Batch 550/600 | Cost: 0.0142\n", "Epoch: 010/050 | Train MSE: 0.01727\n", "Epoch: 011/050 | Batch 000/600 | Cost: 0.0156\n", "Epoch: 011/050 | Batch 050/600 | Cost: 0.0178\n", "Epoch: 011/050 | Batch 100/600 | Cost: 0.0102\n", "Epoch: 011/050 | Batch 150/600 | Cost: 0.0188\n", "Epoch: 011/050 | Batch 200/600 | Cost: 0.0177\n", "Epoch: 011/050 | Batch 250/600 | Cost: 0.0196\n", "Epoch: 011/050 | Batch 300/600 | Cost: 0.0115\n", "Epoch: 011/050 | Batch 350/600 | Cost: 
0.0109\n", "Epoch: 011/050 | Batch 400/600 | Cost: 0.0212\n", "Epoch: 011/050 | Batch 450/600 | Cost: 0.0162\n", "Epoch: 011/050 | Batch 500/600 | Cost: 0.0139\n", "Epoch: 011/050 | Batch 550/600 | Cost: 0.0144\n", "Epoch: 011/050 | Train MSE: 0.01665\n", "Epoch: 012/050 | Batch 000/600 | Cost: 0.0185\n", "Epoch: 012/050 | Batch 050/600 | Cost: 0.0137\n", "Epoch: 012/050 | Batch 100/600 | Cost: 0.0160\n", "Epoch: 012/050 | Batch 150/600 | Cost: 0.0142\n", "Epoch: 012/050 | Batch 200/600 | Cost: 0.0138\n", "Epoch: 012/050 | Batch 250/600 | Cost: 0.0169\n", "Epoch: 012/050 | Batch 300/600 | Cost: 0.0141\n", "Epoch: 012/050 | Batch 350/600 | Cost: 0.0137\n", "Epoch: 012/050 | Batch 400/600 | Cost: 0.0134\n", "Epoch: 012/050 | Batch 450/600 | Cost: 0.0141\n", "Epoch: 012/050 | Batch 500/600 | Cost: 0.0139\n", "Epoch: 012/050 | Batch 550/600 | Cost: 0.0175\n", "Epoch: 012/050 | Train MSE: 0.01609\n", "Epoch: 013/050 | Batch 000/600 | Cost: 0.0197\n", "Epoch: 013/050 | Batch 050/600 | Cost: 0.0134\n", "Epoch: 013/050 | Batch 100/600 | Cost: 0.0213\n", "Epoch: 013/050 | Batch 150/600 | Cost: 0.0172\n", "Epoch: 013/050 | Batch 200/600 | Cost: 0.0149\n", "Epoch: 013/050 | Batch 250/600 | Cost: 0.0155\n", "Epoch: 013/050 | Batch 300/600 | Cost: 0.0224\n", "Epoch: 013/050 | Batch 350/600 | Cost: 0.0177\n", "Epoch: 013/050 | Batch 400/600 | Cost: 0.0125\n", "Epoch: 013/050 | Batch 450/600 | Cost: 0.0191\n", "Epoch: 013/050 | Batch 500/600 | Cost: 0.0196\n", "Epoch: 013/050 | Batch 550/600 | Cost: 0.0167\n", "Epoch: 013/050 | Train MSE: 0.01561\n", "Epoch: 014/050 | Batch 000/600 | Cost: 0.0206\n", "Epoch: 014/050 | Batch 050/600 | Cost: 0.0139\n", "Epoch: 014/050 | Batch 100/600 | Cost: 0.0145\n", "Epoch: 014/050 | Batch 150/600 | Cost: 0.0210\n", "Epoch: 014/050 | Batch 200/600 | Cost: 0.0113\n", "Epoch: 014/050 | Batch 250/600 | Cost: 0.0160\n", "Epoch: 014/050 | Batch 300/600 | Cost: 0.0188\n", "Epoch: 014/050 | Batch 350/600 | Cost: 0.0247\n", "Epoch: 014/050 | Batch 
400/600 | Cost: 0.0208\n", "Epoch: 014/050 | Batch 450/600 | Cost: 0.0170\n", "Epoch: 014/050 | Batch 500/600 | Cost: 0.0148\n", "Epoch: 014/050 | Batch 550/600 | Cost: 0.0197\n", "Epoch: 014/050 | Train MSE: 0.01518\n", "Epoch: 015/050 | Batch 000/600 | Cost: 0.0138\n", "Epoch: 015/050 | Batch 050/600 | Cost: 0.0183\n", "Epoch: 015/050 | Batch 100/600 | Cost: 0.0117\n", "Epoch: 015/050 | Batch 150/600 | Cost: 0.0123\n", "Epoch: 015/050 | Batch 200/600 | Cost: 0.0114\n", "Epoch: 015/050 | Batch 250/600 | Cost: 0.0116\n", "Epoch: 015/050 | Batch 300/600 | Cost: 0.0199\n", "Epoch: 015/050 | Batch 350/600 | Cost: 0.0165\n", "Epoch: 015/050 | Batch 400/600 | Cost: 0.0199\n", "Epoch: 015/050 | Batch 450/600 | Cost: 0.0143\n", "Epoch: 015/050 | Batch 500/600 | Cost: 0.0148\n", "Epoch: 015/050 | Batch 550/600 | Cost: 0.0130\n", "Epoch: 015/050 | Train MSE: 0.01481\n", "Epoch: 016/050 | Batch 000/600 | Cost: 0.0195\n", "Epoch: 016/050 | Batch 050/600 | Cost: 0.0150\n", "Epoch: 016/050 | Batch 100/600 | Cost: 0.0145\n", "Epoch: 016/050 | Batch 150/600 | Cost: 0.0139\n", "Epoch: 016/050 | Batch 200/600 | Cost: 0.0108\n", "Epoch: 016/050 | Batch 250/600 | Cost: 0.0110\n", "Epoch: 016/050 | Batch 300/600 | Cost: 0.0119\n", "Epoch: 016/050 | Batch 350/600 | Cost: 0.0175\n", "Epoch: 016/050 | Batch 400/600 | Cost: 0.0133\n", "Epoch: 016/050 | Batch 450/600 | Cost: 0.0144\n", "Epoch: 016/050 | Batch 500/600 | Cost: 0.0168\n", "Epoch: 016/050 | Batch 550/600 | Cost: 0.0131\n", "Epoch: 016/050 | Train MSE: 0.01447\n", "Epoch: 017/050 | Batch 000/600 | Cost: 0.0128\n", "Epoch: 017/050 | Batch 050/600 | Cost: 0.0160\n", "Epoch: 017/050 | Batch 100/600 | Cost: 0.0183\n", "Epoch: 017/050 | Batch 150/600 | Cost: 0.0136\n", "Epoch: 017/050 | Batch 200/600 | Cost: 0.0144\n", "Epoch: 017/050 | Batch 250/600 | Cost: 0.0109\n", "Epoch: 017/050 | Batch 300/600 | Cost: 0.0104\n", "Epoch: 017/050 | Batch 350/600 | Cost: 0.0146\n", "Epoch: 017/050 | Batch 400/600 | Cost: 0.0099\n", "Epoch: 
017/050 | Batch 450/600 | Cost: 0.0096\n", "Epoch: 017/050 | Batch 500/600 | Cost: 0.0145\n", "Epoch: 017/050 | Batch 550/600 | Cost: 0.0160\n", "Epoch: 017/050 | Train MSE: 0.01415\n", "Epoch: 018/050 | Batch 000/600 | Cost: 0.0140\n", "Epoch: 018/050 | Batch 050/600 | Cost: 0.0145\n", "Epoch: 018/050 | Batch 100/600 | Cost: 0.0167\n", "Epoch: 018/050 | Batch 150/600 | Cost: 0.0136\n", "Epoch: 018/050 | Batch 200/600 | Cost: 0.0102\n", "Epoch: 018/050 | Batch 250/600 | Cost: 0.0164\n", "Epoch: 018/050 | Batch 300/600 | Cost: 0.0094\n", "Epoch: 018/050 | Batch 350/600 | Cost: 0.0169\n", "Epoch: 018/050 | Batch 400/600 | Cost: 0.0108\n", "Epoch: 018/050 | Batch 450/600 | Cost: 0.0155\n", "Epoch: 018/050 | Batch 500/600 | Cost: 0.0106\n", "Epoch: 018/050 | Batch 550/600 | Cost: 0.0143\n", "Epoch: 018/050 | Train MSE: 0.01386\n", "Epoch: 019/050 | Batch 000/600 | Cost: 0.0226\n", "Epoch: 019/050 | Batch 050/600 | Cost: 0.0175\n", "Epoch: 019/050 | Batch 100/600 | Cost: 0.0165\n", "Epoch: 019/050 | Batch 150/600 | Cost: 0.0118\n", "Epoch: 019/050 | Batch 200/600 | Cost: 0.0174\n", "Epoch: 019/050 | Batch 250/600 | Cost: 0.0132\n", "Epoch: 019/050 | Batch 300/600 | Cost: 0.0136\n", "Epoch: 019/050 | Batch 350/600 | Cost: 0.0090\n", "Epoch: 019/050 | Batch 400/600 | Cost: 0.0064\n", "Epoch: 019/050 | Batch 450/600 | Cost: 0.0168\n", "Epoch: 019/050 | Batch 500/600 | Cost: 0.0135\n", "Epoch: 019/050 | Batch 550/600 | Cost: 0.0166\n", "Epoch: 019/050 | Train MSE: 0.01360\n", "Epoch: 020/050 | Batch 000/600 | Cost: 0.0184\n", "Epoch: 020/050 | Batch 050/600 | Cost: 0.0124\n", "Epoch: 020/050 | Batch 100/600 | Cost: 0.0142\n", "Epoch: 020/050 | Batch 150/600 | Cost: 0.0167\n", "Epoch: 020/050 | Batch 200/600 | Cost: 0.0140\n", "Epoch: 020/050 | Batch 250/600 | Cost: 0.0112\n", "Epoch: 020/050 | Batch 300/600 | Cost: 0.0140\n", "Epoch: 020/050 | Batch 350/600 | Cost: 0.0115\n", "Epoch: 020/050 | Batch 400/600 | Cost: 0.0106\n", "Epoch: 020/050 | Batch 450/600 | Cost: 
0.0156\n", "Epoch: 020/050 | Batch 500/600 | Cost: 0.0150\n", "Epoch: 020/050 | Batch 550/600 | Cost: 0.0113\n", "Epoch: 020/050 | Train MSE: 0.01335\n", "Epoch: 021/050 | Batch 000/600 | Cost: 0.0127\n", "Epoch: 021/050 | Batch 050/600 | Cost: 0.0100\n", "Epoch: 021/050 | Batch 100/600 | Cost: 0.0183\n", "Epoch: 021/050 | Batch 150/600 | Cost: 0.0138\n", "Epoch: 021/050 | Batch 200/600 | Cost: 0.0120\n", "Epoch: 021/050 | Batch 250/600 | Cost: 0.0115\n", "Epoch: 021/050 | Batch 300/600 | Cost: 0.0125\n", "Epoch: 021/050 | Batch 350/600 | Cost: 0.0085\n", "Epoch: 021/050 | Batch 400/600 | Cost: 0.0121\n", "Epoch: 021/050 | Batch 450/600 | Cost: 0.0140\n", "Epoch: 021/050 | Batch 500/600 | Cost: 0.0098\n", "Epoch: 021/050 | Batch 550/600 | Cost: 0.0145\n", "Epoch: 021/050 | Train MSE: 0.01312\n", "Epoch: 022/050 | Batch 000/600 | Cost: 0.0141\n", "Epoch: 022/050 | Batch 050/600 | Cost: 0.0147\n", "Epoch: 022/050 | Batch 100/600 | Cost: 0.0172\n", "Epoch: 022/050 | Batch 150/600 | Cost: 0.0161\n", "Epoch: 022/050 | Batch 200/600 | Cost: 0.0108\n", "Epoch: 022/050 | Batch 250/600 | Cost: 0.0108\n", "Epoch: 022/050 | Batch 300/600 | Cost: 0.0149\n", "Epoch: 022/050 | Batch 350/600 | Cost: 0.0133\n", "Epoch: 022/050 | Batch 400/600 | Cost: 0.0077\n", "Epoch: 022/050 | Batch 450/600 | Cost: 0.0101\n", "Epoch: 022/050 | Batch 500/600 | Cost: 0.0177\n", "Epoch: 022/050 | Batch 550/600 | Cost: 0.0120\n", "Epoch: 022/050 | Train MSE: 0.01291\n", "Epoch: 023/050 | Batch 000/600 | Cost: 0.0165\n", "Epoch: 023/050 | Batch 050/600 | Cost: 0.0132\n", "Epoch: 023/050 | Batch 100/600 | Cost: 0.0169\n", "Epoch: 023/050 | Batch 150/600 | Cost: 0.0135\n", "Epoch: 023/050 | Batch 200/600 | Cost: 0.0133\n", "Epoch: 023/050 | Batch 250/600 | Cost: 0.0137\n", "Epoch: 023/050 | Batch 300/600 | Cost: 0.0149\n", "Epoch: 023/050 | Batch 350/600 | Cost: 0.0185\n", "Epoch: 023/050 | Batch 400/600 | Cost: 0.0091\n", "Epoch: 023/050 | Batch 450/600 | Cost: 0.0141\n", "Epoch: 023/050 | Batch 
500/600 | Cost: 0.0170\n", "Epoch: 023/050 | Batch 550/600 | Cost: 0.0096\n", "Epoch: 023/050 | Train MSE: 0.01270\n", "Epoch: 024/050 | Batch 000/600 | Cost: 0.0122\n", "Epoch: 024/050 | Batch 050/600 | Cost: 0.0095\n", "Epoch: 024/050 | Batch 100/600 | Cost: 0.0099\n", "Epoch: 024/050 | Batch 150/600 | Cost: 0.0063\n", "Epoch: 024/050 | Batch 200/600 | Cost: 0.0133\n", "Epoch: 024/050 | Batch 250/600 | Cost: 0.0108\n", "Epoch: 024/050 | Batch 300/600 | Cost: 0.0149\n", "Epoch: 024/050 | Batch 350/600 | Cost: 0.0143\n", "Epoch: 024/050 | Batch 400/600 | Cost: 0.0124\n", "Epoch: 024/050 | Batch 450/600 | Cost: 0.0116\n", "Epoch: 024/050 | Batch 500/600 | Cost: 0.0083\n", "Epoch: 024/050 | Batch 550/600 | Cost: 0.0079\n", "Epoch: 024/050 | Train MSE: 0.01251\n", "Epoch: 025/050 | Batch 000/600 | Cost: 0.0147\n", "Epoch: 025/050 | Batch 050/600 | Cost: 0.0104\n", "Epoch: 025/050 | Batch 100/600 | Cost: 0.0120\n", "Epoch: 025/050 | Batch 150/600 | Cost: 0.0127\n", "Epoch: 025/050 | Batch 200/600 | Cost: 0.0094\n", "Epoch: 025/050 | Batch 250/600 | Cost: 0.0085\n", "Epoch: 025/050 | Batch 300/600 | Cost: 0.0138\n", "Epoch: 025/050 | Batch 350/600 | Cost: 0.0086\n", "Epoch: 025/050 | Batch 400/600 | Cost: 0.0130\n", "Epoch: 025/050 | Batch 450/600 | Cost: 0.0136\n", "Epoch: 025/050 | Batch 500/600 | Cost: 0.0135\n", "Epoch: 025/050 | Batch 550/600 | Cost: 0.0155\n", "Epoch: 025/050 | Train MSE: 0.01232\n", "Epoch: 026/050 | Batch 000/600 | Cost: 0.0138\n", "Epoch: 026/050 | Batch 050/600 | Cost: 0.0136\n", "Epoch: 026/050 | Batch 100/600 | Cost: 0.0076\n", "Epoch: 026/050 | Batch 150/600 | Cost: 0.0179\n", "Epoch: 026/050 | Batch 200/600 | Cost: 0.0119\n", "Epoch: 026/050 | Batch 250/600 | Cost: 0.0142\n", "Epoch: 026/050 | Batch 300/600 | Cost: 0.0138\n", "Epoch: 026/050 | Batch 350/600 | Cost: 0.0107\n", "Epoch: 026/050 | Batch 400/600 | Cost: 0.0103\n", "Epoch: 026/050 | Batch 450/600 | Cost: 0.0091\n", "Epoch: 026/050 | Batch 500/600 | Cost: 0.0116\n", "Epoch: 
026/050 | Batch 550/600 | Cost: 0.0091\n", "Epoch: 026/050 | Train MSE: 0.01215\n", "Epoch: 027/050 | Batch 000/600 | Cost: 0.0085\n", "Epoch: 027/050 | Batch 050/600 | Cost: 0.0065\n", "Epoch: 027/050 | Batch 100/600 | Cost: 0.0102\n", "Epoch: 027/050 | Batch 150/600 | Cost: 0.0152\n", "Epoch: 027/050 | Batch 200/600 | Cost: 0.0162\n", "Epoch: 027/050 | Batch 250/600 | Cost: 0.0079\n", "Epoch: 027/050 | Batch 300/600 | Cost: 0.0118\n", "Epoch: 027/050 | Batch 350/600 | Cost: 0.0111\n", "Epoch: 027/050 | Batch 400/600 | Cost: 0.0081\n", "Epoch: 027/050 | Batch 450/600 | Cost: 0.0100\n", "Epoch: 027/050 | Batch 500/600 | Cost: 0.0103\n", "Epoch: 027/050 | Batch 550/600 | Cost: 0.0117\n", "Epoch: 027/050 | Train MSE: 0.01199\n", "Epoch: 028/050 | Batch 000/600 | Cost: 0.0077\n", "Epoch: 028/050 | Batch 050/600 | Cost: 0.0164\n", "Epoch: 028/050 | Batch 100/600 | Cost: 0.0095\n", "Epoch: 028/050 | Batch 150/600 | Cost: 0.0112\n", "Epoch: 028/050 | Batch 200/600 | Cost: 0.0109\n", "Epoch: 028/050 | Batch 250/600 | Cost: 0.0148\n", "Epoch: 028/050 | Batch 300/600 | Cost: 0.0126\n", "Epoch: 028/050 | Batch 350/600 | Cost: 0.0082\n", "Epoch: 028/050 | Batch 400/600 | Cost: 0.0115\n", "Epoch: 028/050 | Batch 450/600 | Cost: 0.0194\n", "Epoch: 028/050 | Batch 500/600 | Cost: 0.0111\n", "Epoch: 028/050 | Batch 550/600 | Cost: 0.0145\n", "Epoch: 028/050 | Train MSE: 0.01181\n", "Epoch: 029/050 | Batch 000/600 | Cost: 0.0112\n", "Epoch: 029/050 | Batch 050/600 | Cost: 0.0137\n", "Epoch: 029/050 | Batch 100/600 | Cost: 0.0192\n", "Epoch: 029/050 | Batch 150/600 | Cost: 0.0105\n", "Epoch: 029/050 | Batch 200/600 | Cost: 0.0107\n", "Epoch: 029/050 | Batch 250/600 | Cost: 0.0081\n", "Epoch: 029/050 | Batch 300/600 | Cost: 0.0079\n", "Epoch: 029/050 | Batch 350/600 | Cost: 0.0126\n", "Epoch: 029/050 | Batch 400/600 | Cost: 0.0135\n", "Epoch: 029/050 | Batch 450/600 | Cost: 0.0062\n", "Epoch: 029/050 | Batch 500/600 | Cost: 0.0121\n", "Epoch: 029/050 | Batch 550/600 | Cost: 
0.0091\n", "Epoch: 029/050 | Train MSE: 0.01167\n", "Epoch: 030/050 | Batch 000/600 | Cost: 0.0068\n", "Epoch: 030/050 | Batch 050/600 | Cost: 0.0115\n", "Epoch: 030/050 | Batch 100/600 | Cost: 0.0145\n", "Epoch: 030/050 | Batch 150/600 | Cost: 0.0128\n", "Epoch: 030/050 | Batch 200/600 | Cost: 0.0129\n", "Epoch: 030/050 | Batch 250/600 | Cost: 0.0128\n", "Epoch: 030/050 | Batch 300/600 | Cost: 0.0085\n", "Epoch: 030/050 | Batch 350/600 | Cost: 0.0149\n", "Epoch: 030/050 | Batch 400/600 | Cost: 0.0080\n", "Epoch: 030/050 | Batch 450/600 | Cost: 0.0168\n", "Epoch: 030/050 | Batch 500/600 | Cost: 0.0106\n", "Epoch: 030/050 | Batch 550/600 | Cost: 0.0125\n", "Epoch: 030/050 | Train MSE: 0.01152\n", "Epoch: 031/050 | Batch 000/600 | Cost: 0.0137\n", "Epoch: 031/050 | Batch 050/600 | Cost: 0.0080\n", "Epoch: 031/050 | Batch 100/600 | Cost: 0.0122\n", "Epoch: 031/050 | Batch 150/600 | Cost: 0.0121\n", "Epoch: 031/050 | Batch 200/600 | Cost: 0.0125\n", "Epoch: 031/050 | Batch 250/600 | Cost: 0.0120\n", "Epoch: 031/050 | Batch 300/600 | Cost: 0.0123\n", "Epoch: 031/050 | Batch 350/600 | Cost: 0.0166\n", "Epoch: 031/050 | Batch 400/600 | Cost: 0.0099\n", "Epoch: 031/050 | Batch 450/600 | Cost: 0.0099\n", "Epoch: 031/050 | Batch 500/600 | Cost: 0.0103\n", "Epoch: 031/050 | Batch 550/600 | Cost: 0.0099\n", "Epoch: 031/050 | Train MSE: 0.01138\n", "Epoch: 032/050 | Batch 000/600 | Cost: 0.0125\n", "Epoch: 032/050 | Batch 050/600 | Cost: 0.0114\n", "Epoch: 032/050 | Batch 100/600 | Cost: 0.0118\n", "Epoch: 032/050 | Batch 150/600 | Cost: 0.0110\n", "Epoch: 032/050 | Batch 200/600 | Cost: 0.0137\n", "Epoch: 032/050 | Batch 250/600 | Cost: 0.0156\n", "Epoch: 032/050 | Batch 300/600 | Cost: 0.0084\n", "Epoch: 032/050 | Batch 350/600 | Cost: 0.0187\n", "Epoch: 032/050 | Batch 400/600 | Cost: 0.0101\n", "Epoch: 032/050 | Batch 450/600 | Cost: 0.0071\n", "Epoch: 032/050 | Batch 500/600 | Cost: 0.0104\n", "Epoch: 032/050 | Batch 550/600 | Cost: 0.0135\n", "Epoch: 032/050 | Train MSE: 
0.01126\n", "Epoch: 033/050 | Batch 000/600 | Cost: 0.0159\n", "Epoch: 033/050 | Batch 050/600 | Cost: 0.0126\n", "Epoch: 033/050 | Batch 100/600 | Cost: 0.0077\n", "Epoch: 033/050 | Batch 150/600 | Cost: 0.0093\n", "Epoch: 033/050 | Batch 200/600 | Cost: 0.0092\n", "Epoch: 033/050 | Batch 250/600 | Cost: 0.0128\n", "Epoch: 033/050 | Batch 300/600 | Cost: 0.0095\n", "Epoch: 033/050 | Batch 350/600 | Cost: 0.0108\n", "Epoch: 033/050 | Batch 400/600 | Cost: 0.0116\n", "Epoch: 033/050 | Batch 450/600 | Cost: 0.0082\n", "Epoch: 033/050 | Batch 500/600 | Cost: 0.0151\n", "Epoch: 033/050 | Batch 550/600 | Cost: 0.0097\n", "Epoch: 033/050 | Train MSE: 0.01112\n", "Epoch: 034/050 | Batch 000/600 | Cost: 0.0119\n", "Epoch: 034/050 | Batch 050/600 | Cost: 0.0079\n", "Epoch: 034/050 | Batch 100/600 | Cost: 0.0118\n", "Epoch: 034/050 | Batch 150/600 | Cost: 0.0122\n", "Epoch: 034/050 | Batch 200/600 | Cost: 0.0078\n", "Epoch: 034/050 | Batch 250/600 | Cost: 0.0142\n", "Epoch: 034/050 | Batch 300/600 | Cost: 0.0066\n", "Epoch: 034/050 | Batch 350/600 | Cost: 0.0112\n", "Epoch: 034/050 | Batch 400/600 | Cost: 0.0067\n", "Epoch: 034/050 | Batch 450/600 | Cost: 0.0105\n", "Epoch: 034/050 | Batch 500/600 | Cost: 0.0119\n", "Epoch: 034/050 | Batch 550/600 | Cost: 0.0145\n", "Epoch: 034/050 | Train MSE: 0.01099\n", "Epoch: 035/050 | Batch 000/600 | Cost: 0.0100\n", "Epoch: 035/050 | Batch 050/600 | Cost: 0.0072\n", "Epoch: 035/050 | Batch 100/600 | Cost: 0.0071\n", "Epoch: 035/050 | Batch 150/600 | Cost: 0.0111\n", "Epoch: 035/050 | Batch 200/600 | Cost: 0.0096\n", "Epoch: 035/050 | Batch 250/600 | Cost: 0.0089\n", "Epoch: 035/050 | Batch 300/600 | Cost: 0.0098\n", "Epoch: 035/050 | Batch 350/600 | Cost: 0.0116\n", "Epoch: 035/050 | Batch 400/600 | Cost: 0.0128\n", "Epoch: 035/050 | Batch 450/600 | Cost: 0.0091\n", "Epoch: 035/050 | Batch 500/600 | Cost: 0.0093\n", "Epoch: 035/050 | Batch 550/600 | Cost: 0.0103\n", "Epoch: 035/050 | Train MSE: 0.01088\n", "Epoch: 036/050 | Batch 
000/600 | Cost: 0.0065\n", "Epoch: 036/050 | Batch 050/600 | Cost: 0.0164\n", "Epoch: 036/050 | Batch 100/600 | Cost: 0.0118\n", "Epoch: 036/050 | Batch 150/600 | Cost: 0.0075\n", "Epoch: 036/050 | Batch 200/600 | Cost: 0.0193\n", "Epoch: 036/050 | Batch 250/600 | Cost: 0.0208\n", "Epoch: 036/050 | Batch 300/600 | Cost: 0.0096\n", "Epoch: 036/050 | Batch 350/600 | Cost: 0.0084\n", "Epoch: 036/050 | Batch 400/600 | Cost: 0.0096\n", "Epoch: 036/050 | Batch 450/600 | Cost: 0.0109\n", "Epoch: 036/050 | Batch 500/600 | Cost: 0.0104\n", "Epoch: 036/050 | Batch 550/600 | Cost: 0.0063\n", "Epoch: 036/050 | Train MSE: 0.01076\n", "Epoch: 037/050 | Batch 000/600 | Cost: 0.0092\n", "Epoch: 037/050 | Batch 050/600 | Cost: 0.0120\n", "Epoch: 037/050 | Batch 100/600 | Cost: 0.0107\n", "Epoch: 037/050 | Batch 150/600 | Cost: 0.0139\n", "Epoch: 037/050 | Batch 200/600 | Cost: 0.0127\n", "Epoch: 037/050 | Batch 250/600 | Cost: 0.0082\n", "Epoch: 037/050 | Batch 300/600 | Cost: 0.0073\n", "Epoch: 037/050 | Batch 350/600 | Cost: 0.0072\n", "Epoch: 037/050 | Batch 400/600 | Cost: 0.0083\n", "Epoch: 037/050 | Batch 450/600 | Cost: 0.0087\n", "Epoch: 037/050 | Batch 500/600 | Cost: 0.0187\n", "Epoch: 037/050 | Batch 550/600 | Cost: 0.0128\n", "Epoch: 037/050 | Train MSE: 0.01064\n", "Epoch: 038/050 | Batch 000/600 | Cost: 0.0145\n", "Epoch: 038/050 | Batch 050/600 | Cost: 0.0082\n", "Epoch: 038/050 | Batch 100/600 | Cost: 0.0116\n", "Epoch: 038/050 | Batch 150/600 | Cost: 0.0114\n", "Epoch: 038/050 | Batch 200/600 | Cost: 0.0089\n", "Epoch: 038/050 | Batch 250/600 | Cost: 0.0110\n", "Epoch: 038/050 | Batch 300/600 | Cost: 0.0130\n", "Epoch: 038/050 | Batch 350/600 | Cost: 0.0155\n", "Epoch: 038/050 | Batch 400/600 | Cost: 0.0107\n", "Epoch: 038/050 | Batch 450/600 | Cost: 0.0076\n", "Epoch: 038/050 | Batch 500/600 | Cost: 0.0138\n", "Epoch: 038/050 | Batch 550/600 | Cost: 0.0123\n", "Epoch: 038/050 | Train MSE: 0.01054\n", "Epoch: 039/050 | Batch 000/600 | Cost: 0.0106\n", "Epoch: 
039/050 | Batch 050/600 | Cost: 0.0153\n", "Epoch: 039/050 | Batch 100/600 | Cost: 0.0108\n", "Epoch: 039/050 | Batch 150/600 | Cost: 0.0097\n", "Epoch: 039/050 | Batch 200/600 | Cost: 0.0116\n", "Epoch: 039/050 | Batch 250/600 | Cost: 0.0123\n", "Epoch: 039/050 | Batch 300/600 | Cost: 0.0082\n", "Epoch: 039/050 | Batch 350/600 | Cost: 0.0114\n", "Epoch: 039/050 | Batch 400/600 | Cost: 0.0083\n", "Epoch: 039/050 | Batch 450/600 | Cost: 0.0162\n", "Epoch: 039/050 | Batch 500/600 | Cost: 0.0108\n", "Epoch: 039/050 | Batch 550/600 | Cost: 0.0110\n", "Epoch: 039/050 | Train MSE: 0.01043\n", "Epoch: 040/050 | Batch 000/600 | Cost: 0.0121\n", "Epoch: 040/050 | Batch 050/600 | Cost: 0.0137\n", "Epoch: 040/050 | Batch 100/600 | Cost: 0.0094\n", "Epoch: 040/050 | Batch 150/600 | Cost: 0.0080\n", "Epoch: 040/050 | Batch 200/600 | Cost: 0.0107\n", "Epoch: 040/050 | Batch 250/600 | Cost: 0.0092\n", "Epoch: 040/050 | Batch 300/600 | Cost: 0.0088\n", "Epoch: 040/050 | Batch 350/600 | Cost: 0.0097\n", "Epoch: 040/050 | Batch 400/600 | Cost: 0.0084\n", "Epoch: 040/050 | Batch 450/600 | Cost: 0.0134\n", "Epoch: 040/050 | Batch 500/600 | Cost: 0.0144\n", "Epoch: 040/050 | Batch 550/600 | Cost: 0.0094\n", "Epoch: 040/050 | Train MSE: 0.01033\n", "Epoch: 041/050 | Batch 000/600 | Cost: 0.0112\n", "Epoch: 041/050 | Batch 050/600 | Cost: 0.0063\n", "Epoch: 041/050 | Batch 100/600 | Cost: 0.0117\n", "Epoch: 041/050 | Batch 150/600 | Cost: 0.0126\n", "Epoch: 041/050 | Batch 200/600 | Cost: 0.0181\n", "Epoch: 041/050 | Batch 250/600 | Cost: 0.0158\n", "Epoch: 041/050 | Batch 300/600 | Cost: 0.0140\n", "Epoch: 041/050 | Batch 350/600 | Cost: 0.0109\n", "Epoch: 041/050 | Batch 400/600 | Cost: 0.0105\n", "Epoch: 041/050 | Batch 450/600 | Cost: 0.0130\n", "Epoch: 041/050 | Batch 500/600 | Cost: 0.0081\n", "Epoch: 041/050 | Batch 550/600 | Cost: 0.0126\n", "Epoch: 041/050 | Train MSE: 0.01023\n", "Epoch: 042/050 | Batch 000/600 | Cost: 0.0100\n", "Epoch: 042/050 | Batch 050/600 | Cost: 
0.0114\n", "Epoch: 042/050 | Batch 100/600 | Cost: 0.0109\n", "Epoch: 042/050 | Batch 150/600 | Cost: 0.0066\n", "Epoch: 042/050 | Batch 200/600 | Cost: 0.0080\n", "Epoch: 042/050 | Batch 250/600 | Cost: 0.0101\n", "Epoch: 042/050 | Batch 300/600 | Cost: 0.0122\n", "Epoch: 042/050 | Batch 350/600 | Cost: 0.0108\n", "Epoch: 042/050 | Batch 400/600 | Cost: 0.0088\n", "Epoch: 042/050 | Batch 450/600 | Cost: 0.0132\n", "Epoch: 042/050 | Batch 500/600 | Cost: 0.0103\n", "Epoch: 042/050 | Batch 550/600 | Cost: 0.0083\n", "Epoch: 042/050 | Train MSE: 0.01013\n", "Epoch: 043/050 | Batch 000/600 | Cost: 0.0097\n", "Epoch: 043/050 | Batch 050/600 | Cost: 0.0103\n", "Epoch: 043/050 | Batch 100/600 | Cost: 0.0144\n", "Epoch: 043/050 | Batch 150/600 | Cost: 0.0095\n", "Epoch: 043/050 | Batch 200/600 | Cost: 0.0108\n", "Epoch: 043/050 | Batch 250/600 | Cost: 0.0124\n", "Epoch: 043/050 | Batch 300/600 | Cost: 0.0125\n", "Epoch: 043/050 | Batch 350/600 | Cost: 0.0117\n", "Epoch: 043/050 | Batch 400/600 | Cost: 0.0085\n", "Epoch: 043/050 | Batch 450/600 | Cost: 0.0097\n", "Epoch: 043/050 | Batch 500/600 | Cost: 0.0163\n", "Epoch: 043/050 | Batch 550/600 | Cost: 0.0099\n", "Epoch: 043/050 | Train MSE: 0.01005\n", "Epoch: 044/050 | Batch 000/600 | Cost: 0.0090\n", "Epoch: 044/050 | Batch 050/600 | Cost: 0.0079\n", "Epoch: 044/050 | Batch 100/600 | Cost: 0.0089\n", "Epoch: 044/050 | Batch 150/600 | Cost: 0.0110\n", "Epoch: 044/050 | Batch 200/600 | Cost: 0.0072\n", "Epoch: 044/050 | Batch 250/600 | Cost: 0.0089\n", "Epoch: 044/050 | Batch 300/600 | Cost: 0.0138\n", "Epoch: 044/050 | Batch 350/600 | Cost: 0.0069\n", "Epoch: 044/050 | Batch 400/600 | Cost: 0.0086\n", "Epoch: 044/050 | Batch 450/600 | Cost: 0.0100\n", "Epoch: 044/050 | Batch 500/600 | Cost: 0.0076\n", "Epoch: 044/050 | Batch 550/600 | Cost: 0.0076\n", "Epoch: 044/050 | Train MSE: 0.00995\n", "Epoch: 045/050 | Batch 000/600 | Cost: 0.0098\n", "Epoch: 045/050 | Batch 050/600 | Cost: 0.0064\n", "Epoch: 045/050 | Batch 
100/600 | Cost: 0.0097\n", "Epoch: 045/050 | Batch 150/600 | Cost: 0.0077\n", "Epoch: 045/050 | Batch 200/600 | Cost: 0.0136\n", "Epoch: 045/050 | Batch 250/600 | Cost: 0.0181\n", "Epoch: 045/050 | Batch 300/600 | Cost: 0.0085\n", "Epoch: 045/050 | Batch 350/600 | Cost: 0.0102\n", "Epoch: 045/050 | Batch 400/600 | Cost: 0.0058\n", "Epoch: 045/050 | Batch 450/600 | Cost: 0.0099\n", "Epoch: 045/050 | Batch 500/600 | Cost: 0.0061\n", "Epoch: 045/050 | Batch 550/600 | Cost: 0.0077\n", "Epoch: 045/050 | Train MSE: 0.00986\n", "Epoch: 046/050 | Batch 000/600 | Cost: 0.0074\n", "Epoch: 046/050 | Batch 050/600 | Cost: 0.0109\n", "Epoch: 046/050 | Batch 100/600 | Cost: 0.0090\n", "Epoch: 046/050 | Batch 150/600 | Cost: 0.0079\n", "Epoch: 046/050 | Batch 200/600 | Cost: 0.0085\n", "Epoch: 046/050 | Batch 250/600 | Cost: 0.0104\n", "Epoch: 046/050 | Batch 300/600 | Cost: 0.0121\n", "Epoch: 046/050 | Batch 350/600 | Cost: 0.0101\n", "Epoch: 046/050 | Batch 400/600 | Cost: 0.0091\n", "Epoch: 046/050 | Batch 450/600 | Cost: 0.0114\n", "Epoch: 046/050 | Batch 500/600 | Cost: 0.0082\n", "Epoch: 046/050 | Batch 550/600 | Cost: 0.0104\n", "Epoch: 046/050 | Train MSE: 0.00978\n", "Epoch: 047/050 | Batch 000/600 | Cost: 0.0109\n", "Epoch: 047/050 | Batch 050/600 | Cost: 0.0111\n", "Epoch: 047/050 | Batch 100/600 | Cost: 0.0075\n", "Epoch: 047/050 | Batch 150/600 | Cost: 0.0144\n", "Epoch: 047/050 | Batch 200/600 | Cost: 0.0092\n", "Epoch: 047/050 | Batch 250/600 | Cost: 0.0080\n", "Epoch: 047/050 | Batch 300/600 | Cost: 0.0118\n", "Epoch: 047/050 | Batch 350/600 | Cost: 0.0110\n", "Epoch: 047/050 | Batch 400/600 | Cost: 0.0038\n", "Epoch: 047/050 | Batch 450/600 | Cost: 0.0159\n", "Epoch: 047/050 | Batch 500/600 | Cost: 0.0084\n", "Epoch: 047/050 | Batch 550/600 | Cost: 0.0110\n", "Epoch: 047/050 | Train MSE: 0.00969\n", "Epoch: 048/050 | Batch 000/600 | Cost: 0.0071\n", "Epoch: 048/050 | Batch 050/600 | Cost: 0.0095\n", "Epoch: 048/050 | Batch 100/600 | Cost: 0.0093\n", "Epoch: 
048/050 | Batch 150/600 | Cost: 0.0144\n", "Epoch: 048/050 | Batch 200/600 | Cost: 0.0123\n", "Epoch: 048/050 | Batch 250/600 | Cost: 0.0070\n", "Epoch: 048/050 | Batch 300/600 | Cost: 0.0107\n", "Epoch: 048/050 | Batch 350/600 | Cost: 0.0123\n", "Epoch: 048/050 | Batch 400/600 | Cost: 0.0064\n", "Epoch: 048/050 | Batch 450/600 | Cost: 0.0129\n", "Epoch: 048/050 | Batch 500/600 | Cost: 0.0065\n", "Epoch: 048/050 | Batch 550/600 | Cost: 0.0121\n", "Epoch: 048/050 | Train MSE: 0.00961\n", "Epoch: 049/050 | Batch 000/600 | Cost: 0.0031\n", "Epoch: 049/050 | Batch 050/600 | Cost: 0.0115\n", "Epoch: 049/050 | Batch 100/600 | Cost: 0.0046\n", "Epoch: 049/050 | Batch 150/600 | Cost: 0.0104\n", "Epoch: 049/050 | Batch 200/600 | Cost: 0.0070\n", "Epoch: 049/050 | Batch 250/600 | Cost: 0.0056\n", "Epoch: 049/050 | Batch 300/600 | Cost: 0.0114\n", "Epoch: 049/050 | Batch 350/600 | Cost: 0.0099\n", "Epoch: 049/050 | Batch 400/600 | Cost: 0.0110\n", "Epoch: 049/050 | Batch 450/600 | Cost: 0.0077\n", "Epoch: 049/050 | Batch 500/600 | Cost: 0.0071\n", "Epoch: 049/050 | Batch 550/600 | Cost: 0.0120\n", "Epoch: 049/050 | Train MSE: 0.00953\n", "Epoch: 050/050 | Batch 000/600 | Cost: 0.0113\n", "Epoch: 050/050 | Batch 050/600 | Cost: 0.0132\n", "Epoch: 050/050 | Batch 100/600 | Cost: 0.0060\n", "Epoch: 050/050 | Batch 150/600 | Cost: 0.0071\n", "Epoch: 050/050 | Batch 200/600 | Cost: 0.0069\n", "Epoch: 050/050 | Batch 250/600 | Cost: 0.0151\n", "Epoch: 050/050 | Batch 300/600 | Cost: 0.0106\n", "Epoch: 050/050 | Batch 350/600 | Cost: 0.0122\n", "Epoch: 050/050 | Batch 400/600 | Cost: 0.0081\n", "Epoch: 050/050 | Batch 450/600 | Cost: 0.0095\n", "Epoch: 050/050 | Batch 500/600 | Cost: 0.0122\n", "Epoch: 050/050 | Batch 550/600 | Cost: 0.0075\n", "Epoch: 050/050 | Train MSE: 0.00945\n" ] } ], "source": [ "####################################################\n", "##### Training \n", "###################################################\n", "\n", "torch.manual_seed(RANDOM_SEED)\n", 
"model = MultilayerPerceptron(num_features=28*28,\n", " num_hidden=50,\n", " num_classes=10)\n", "\n", "minibatch_cost, epoch_cost = train(model, \n", " train_loader,\n", " num_epochs=NUM_EPOCHS,\n", " learning_rate=0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(range(len(minibatch_cost)), minibatch_cost)\n", "plt.ylabel('Mean Squared Error')\n", "plt.xlabel('Minibatch')\n", "plt.show()\n", "\n", "plt.plot(range(len(epoch_cost)), epoch_cost)\n", "plt.ylabel('Mean Squared Error')\n", "plt.xlabel('Epoch')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training Accuracy: 94.72\n", "Test Accuracy: 94.49\n" ] } ], "source": [ "def compute_accuracy(net, data_loader):\n", " correct_pred, num_examples = 0, 0\n", " with torch.no_grad():\n", " for features, targets in data_loader:\n", " features = features.view(-1, 28*28)\n", " _, outputs = net.forward(features)\n", " predicted_labels = torch.argmax(outputs, 1)\n", " num_examples += targets.size(0)\n", " correct_pred += (predicted_labels == targets).sum()\n", " return correct_pred.float()/num_examples * 100\n", " \n", "print('Training Accuracy: %.2f' % compute_accuracy(model, train_loader))\n", "print('Test Accuracy: %.2f' % compute_accuracy(model, test_loader))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visual Inspection" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAABrCAYAAABnlHmpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAENtJREFUeJzt3XmMlNWax/HvI6hERZEB2xZZ7ohRiKKO7XVwEM1VCSKIimuIYCSyOLiEiV7c4xpFUMclCnivMuICKiAY1FxxcEBBaTecCwoojXDZ0XFFRD3zR9f79qnu6rXet97qt3+fpNOnz9tVderh7cOps5pzDhERaf72SLoAIiISDVXoIiIpoQpdRCQlVKGLiKSEKnQRkZRQhS4ikhKq0EVEUiKvCt3M+pvZ52a2xszGR1UoqaT4xkexjY9imxxr6sIiM2sFrALOADYAy4BLnHMroitey6X4xkexjY9im6zWeTz2j8Aa59yXAGb2AjAYqPUfrkOHDq5bt255vGT6ffDBB9udcx1pZHwV2/o1Nbag+NanoqKC7du3G4ptLLx7t075VOidgPXezxuAE6v/kpmNBEYCdOnShfLy8jxeMv3MbF0mWW98FdvGaUxsM7+v+DZQWVlZkFRsY+Ddu3WKfVDUOTfFOVfmnCvr2LHe/2CkERTbeCm+8VFs45FPhf4PoLP386GZPImG4hsfxTY+im2C8qnQlwGHm9kfzGwv4GJgbjTFEhTfOCm28VFsE9TkPnTn3K9mNhZ4A2gF/NU59/fIStbCKb7xUWzjo9gmK59BUZxz84H5EZVFqlF846PYxkexTY5WioqIpEReLXRpeSZOnBimd+7cCcDy5cvDvJdeeqnGY8aMGROme/fuDcCll14aVxFFWiy10EVEUkItdKnXRRddFKZffPHFOn/XzGrkPfHEE2H6zTffBOCUU04J87p06ZJvESVj1apVABxxxBFh3sMPPxymr7rqqoKXqZj9+OOPYfq6664Dsu9Xb8FUeO937dq1QKVrPLXQRURSQhW6iEhKqMtFahV0tdTXzXLkkUeG6f79+wPw5Zdfhnlz51atK1mzZg0A06dPD/NuvPHG/AsrAHz00UcA7LFHVVutU6dOSRWn6G3cuDFMT506FYBWrVqFef4eM/PmzQNg7NixBSpd46mFLiKSEqrQRURSQl0uksX/iDl79uwa14866qgwHXSldOjQIczbb7/9APjll1/CvBNPrNo99ZNPPgFgx44dEZVYfB9//DFQ9e8AcN555yVVnKK1bds2AIYPH55wSaKlFrqISEo0qxZ6sAoxGLwAOOSQQwBo06ZNmDd06NAwffDBBwPQvXv3QhSx2du0aVOYDo4n9Fvlb7zxRpguLS2t9Xn8FaUrV66scX3gwIF5lVOqfPrpp2H6kUceAWDYsGFJFado+fPx58yZA8CyZcsa/PhFixYBVX8XAMccc0yY7tu3b75FzJta6CIiKaEKXUQkJZpVl0uwNLeioqLO3/OX7u6///4A9OzZM/LydO5cdTDL9ddfD2QvFW6OBg0aFKaDOeNt27YN89q3b9+g55kxY0aY9gdIJXqff/55mA6WsvvbNUila6+9Nkz7c80batasWVnfIXvbipkzZwJw/PHHN7WIeVMLXUQkJZpVC/3JJ58Eqqa+QVXLe8WKFWFesFoOYOHChQAsXbo0zAv+V/3qq6/qfL0999wzTPtT84KBQ/85g9Z6c2+h+5qyCdH9998PVG0SVV0whdGfyij5mTBhQpju1q0bkK77MB8DBgwI0/5g5m+//dagx/t/9/vuuy8A69atC/PWrl0bpk844QQAfv/996YVNgL1ttDN7K9mttXM/tfLa29mfzOz1ZnvB8ZbzPS6/PLLOeigg7Jmkii+0VBs46PYFqeGdLk8DfSvljceWOCcOxxYkPlZmuCyyy7j9ddfr56t+EZAsY2PYluc6u1ycc79j5l1q5Y9GDg1k54GLAT+HGG
5cjrttNOyvvuCTaGq++abb4Dsbpjg42h9c1D33nvvMO3vLx1sRvX111+HeYcddlidz1Wbvn375hrkTSS+TfXqq6+G6VtvvRWAXbt2hXklJSVh+t577wVgn332ib1caYhtbfz35d/HwX0adA/Epdhj+/bbbwPw2WefhXn+Xv11DYqOHj06TPfr1y9MH3DAAQC89dZbYd7dd99d4/GPP/54mPZP6yqEpg6KljjnghUom4GS2n7RzEaaWbmZlQfLbaVeDYqvYtskunfjo9gmLO9ZLq5ypMHVcX2Kc67MOVfWsWPHfF+uxakrvoptfnTvxkexTUZTZ7lsMbNS59wmMysFtkZZqCgdeGDluMyf/vSnGtdydd3U5uWXXw7TQTdOr169wryLL764qUXMpdnEF7I39PK7WgL+nGj/6LmENKvY1iboUqgu4cox0dj6XUDB3+P27dvrfIw/j/z8888H4LbbbgvzcnUN+rO/Jk+eHKaD1wrWpAD8/PPPQPYe6v7suag1tYU+Fwi2KRsOvBJNcSRD8Y2PYhsfxTZh9bbQzex5Kgc6OpjZBuA24F5gppmNANYBF8ZZyKRs3VrVwLjyyivDdDCfNRgAhIavoKzukksuYeHChWzfvp1DDz0UoAPNJL7nnHMOkL1hV8DflvSuu+4qWJl8zTm29Vm+fHnOfL91GKfqsW3dujUkHNvdu3eH6fpa5sFGWv6KZn/OeV38Frp/2ta4ceOA7IOng3+Ps88+O8xr6gSKhmjILJdLarnU8P4KqdXzzz+f9bOZbXfO7UDxzZtiG5/qsS0rK6OiokKxTZiW/ouIpESzWvpfaI899liY9rtf2rVrB2TPTW8p/P3S3333XSB7IDQYlLv55pvDPP/0HMnPkiVLAHjqqafCvOOOOy5Mn3HGGQUvU3MRLM2Hqvg1tJulNn5XyrPPPgvA+++/n9dz5kMtdBGRlFALPYfFixcDVasaq3vllcrBe38fi5bCP58y18BTcFpUnAM/LdmCBQuAqqmzkL1K2j+5qyXLtfnWe++9F/nr+Bt+BZty5doEzJ8KOX369MjLEVALXUQkJVShi4ikhLpccpg/fz6QfdLO6aefHqZ79+5d8DIlae7cuWHa3+QscOqpp4bpO+64oxBFarH8swACF1xwQQIlKT7+SWVNOZGoKebNmxemg7+NXJuA3X777QUpj1roIiIpoQpdRCQl1OWSsXPnzjAdbNzv74fuf2SKc3OdYrJjxw4A7rnnnjAv14HPxx57bJjWnPPobd68OUwvWrQIqNqTH+Dcc88teJmKkb8vfxyCbX794y79v41cgnnuhaoz1EIXEUkJtdAzgsONoWpw48wzzwzzTjrppIKXKWmTJk0Cal/5FmzOpYHQeD399NNhesuWLUD2vSmFEZxO5K8gzyU4qBtg2rRpQPY2vXFSC11EJCVUoYuIpESL7nLxB1HuvPPOMB0cBnvLLbcUvEzF5IEHHqjzevDRUwOh8Vq3bl2NvOAkLonXgAEDwrR/4HRdevbsGaZPPvnkyMtUF7XQRURSQhW6iEhKNOQIus7AfwElVJ7iPcU5959m1h6YAXQDKoALnXPf1PY8xSSYX3311VeHeb/++muYDj5mxb3Ef/369QwbNowtW7ZgZowcORKA5hLbII6NmWMbdGf5j/GPDvv2229rPCbYWfDBBx+s87n95d733XcfGzZsYNSoUWF8gYOg+cQ34C8vDwwcODCBklSpfu9+9913QLKxzbXLoe+1116rkXfFFVeE6Y0bN9b5nP6S/rrEPR++Lg1pof8K/Idzrifwr8C/m1lPYDywwDl3OLAg87M0QuvWrZk0aRIrVqxg6dKlQZ90GxTbSLRq1SorvsBBunejUf3e3bZtG4pt8hpypugmYFMm/b2ZrQQ6AYOpPDwaYBqwEPhzLKWMgP8/drB/9Nq1a8O87t27h2l/gDROpaWllJaWAtC2bVt69OjB6tWr96KZxLZXr16NfsyFF1aeGxy8b6iaWw3
wwgsv5F8woKSkBKg6Oalt27YAO2lG926wKtSPT7Gofu+2adOGXbt2JRrbMWPGhOlch2WfddZZYTrX5l258vx6o74Nv0aPHt2gcsapUX3oZtYNOA54DyjJVPYAm6nsksn1mJFmVm5m5cHSWampoqIiWND0A4pt5CoqKgD2Qfdu5CoqKvjpp59AsU1cgyt0M9sPeBm41jn3nX/NVXY0uVyPc85Ncc6VOefKgvMmJdsPP/zAkCFDeOihhwB+968ptvkL4gus170brSC2nTt3RrFNXoPmoZvZnlRW5s8652ZlsreYWalzbpOZlQJba3+G5H3xxRdhury8vMZ1f851IY9P2717N0OGDGHo0KH+8W5FEdtgcHjOnDmRPefMmTMb9Hv+oOkee9Rsd/iH85aVldW43qdPHyA7vh9++OH/ZS4XRXzrM3v2bCB7wD44EPqUU05JpEw+P7bPPfdckJ1YbP3jESdMmADkPiaxqYKNtnr06BHmTZ06NUz73YhJqbeFbpVDu38BVjrn/JUmc4HhmfRw4JXoi5duzjlGjBhBjx49GDdunH9JsY2A4hsfxbY4NaSF/m/ApcCnZvZxJu9G4F5gppmNANYBF8ZTxPwEq+z69etX49rEiRPDdBLTwN555x2eeeYZjj76aH8L2gMoktjOmlX5YSxo7UDu7XN9wdai9Q1ujhgxIkx37dq1xvVMFwmQ3SJqjMWLF1ePb08zG0CRxDeXTF80kHuaXXA6UaFO5KlN9Xt31apVJB1b/z6aMWMGkP3pMtOl2WQ33XQTAGPHjs3reeLUkFkui4HaJmCeFm1xWpY+ffpkzXMFMLNvnXM7UGzzVj2+ZrbCOTc/86Pim4fqsS0rK6O8vFyxTZhWioqIpETqN+eaPHkykHuDI39gqaGrwFqiXHN66+MNkkkj+IPB7dq1A2Dw4MFh3jXXXFPwMjVHffv2zfoO2d2uU6ZMAbJX4Q4aNAiAUaNGhXn+pxB/061ipRa6iEhKqEIXEUmJVHa5BEumAR599NEESyLSOH6Xy5IlSxIsSfoEW35UT6eJWugiIimRyhb64sWLw/T3339f43qwEZdO2hGRNFELXUQkJVShi4ikRCq7XHLxltazYMECANq3b59UcUREIqcWuohISqSyhX7DDTfkTIuIpJla6CIiKaEKXUQkJaz69q2xvpjZNuBHILpjRJLXgWjfT1fnXKPP5FJsG6RJsQXFtwEU22yJ3LsFrdABzKzcOVfzzLBmqpjeTzGVJQrF9n6KrTz5Kqb3U0xliUJS70ddLiIiKaEKXUQkJZKo0Kck8JpxKqb3U0xliUKxvZ9iK0++iun9FFNZopDI+yl4H7qIiMRDXS4iIilR0ArdzPqb2edmtsbMxhfytaNgZp3N7L/NbIWZ/d3Mrsnktzezv5nZ6sz3AxMom2IbX9kU23jLp/hGxTlXkC+gFfAF8M/AXsAnQM9CvX5E76EU+JdMui2wCugJTADGZ/LHA/cVuFyKrWLb7GKr+Eb/VcgW+h+BNc65L51zvwAvAIPreUxRcc5tcs59mEl/D6wEOlH5PqZlfm0acE6Bi6bYxkexjZfiG6FCVuidgPXezxsyec2SmXUDjgPeA0qcc5sylzYDJQUujmIbH8U2XopvhDQo2gRmth/wMnCtc+47/5qr/HylqUNNpNjGR7GNVzHEt5AV+j+Azt7Ph2bymhUz25PKf7RnnXOzMtlbzKw0c70U2FrgYim28VFs46X4RqiQFfoy4HAz+4OZ7QVcDMwt4OvnzcwM+Auw0jn3gHdpLjA8kx4OvFLgoim28VFs46X4RqnAo8EDqBwB/gK4KcmR6SaWvw+VH5uWAx9nvgYA/wQsAFYDbwLtEyibYqvYNrvYKr7RfmmlqIhISmhQVEQkJVShi4ikhCp0EZGUUIUuIpISqtBFRFJCFbqISEqoQhcRSQlV6CIiKfH/BswZZS+r8qoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for features, targets in test_loader:\n", " break\n", " \n", "fig, ax = plt.subplots(1, 4)\n", "for i in range(4):\n", " ax[i].imshow(features[i].view(28, 28), cmap=matplotlib.cm.binary)\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted labels tensor([7, 2, 1, 0])\n" ] } ], "source": [ "_, predictions = model.forward(features[:4].view(-1, 28*28))\n", "predictions = torch.argmax(predictions, dim=1)\n", "print('Predicted labels', predictions)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }