{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic Regression\n", "\n", "Logistic Regression is a simple classification technique that analyzes the relationship between a quantitative variable x and a dichotomous categorical variable y. Similar to linear regression, this relationship is inferred by applying a linear transformation to the data\n", "\n", "$$\n", "Y = Xw\n", "$$\n", "\n", "# Binary-Cross-Entropy\n", "\n", "However, Logistic Regression is unique in the way that it **learns relationships.** Given that we have to classify y given x, a natural choice of performance/criterion is the Binary Cross Entropy Loss.\n", "\n", "In short, Cross Entropy is a popular choice of option to judge classification task models whose output is a probability mass function. It is computed by independently applying below function to each prediction\n", "\n", "$$\n", "L(\\hat{y},y)=y\\cdot -log(\\hat{y})^T\n", "$$\n", "\n", "And its binary re-formulation\n", "\n", "$$\n", "L(\\hat{y},y)=y\\cdot-log(\\hat{y})+(1-y)\\cdot(-log(1-\\hat{y}))\n", "$$\n", "\n", "\n", "In general, Cross-Entropy loss exponentially increases as our predictions diverge from the truth; conversely, as our predictions become infinitely close to our target, the loss equates to zero. \n", "\n", "Below graph models Binary-Cross-Entropy when target is either 1 or 0 at different levels of prediction\n", "\n", "\"image-20200508081156737\"\n", "\n", "\n", "\n", "We can derive the gradient of our Loss function w.r.t. 
our prediction for the backward pass using some simple calculus:\n", "\n", "$$\n", "\\begin{split}\\frac{\\partial L(\\hat{y},y)}{\\partial \\hat{y}} & =\\frac{\\partial}{\\partial \\hat{y}}\\left(-y\\log(\\hat{y})\\right) + \\frac{\\partial}{\\partial \\hat{y}}\\left(-(1-y)\\log(1-\\hat{y})\\right) \\\\& =\\frac{-y}{\\hat{y}} - \\left(\\frac{1-y}{1-\\hat{y}}\\cdot (-1)\\right) \\\\& = \\frac{-y}{\\hat{y}} + \\frac{1-y}{1-\\hat{y}} \\\\& = \\frac{\\hat{y}-y}{\\hat{y}(1-\\hat{y})}\\end{split}\n", "$$\n", "\n", "Given that the Loss function is usually the last forward operation, it is also the first gradient we compute in the backward pass. For this reason, there is no incoming gradient that we need to integrate (equivalently, the incoming gradient is 1).\n", "\n", "However, instead of back-propagating the loss of *each prediction* separately, we usually average the loss over the $B$ predictions of a batch, which gives a single scalar that better gauges our model's performance.\n", "\n", "In this case, we will be calculating the gradients below\n", "\n", "$$\n", "\\frac{\\partial }{\\partial w}\\frac{1}{B}\\sum_{i=1}^{B}L(\\hat{y}_i,y_i)\n", "$$\n", "\n", "Once these gradients have been computed, we follow the usual gradient-descent procedure and take a \"step\" in the direction of steepest descent, where $\\alpha$ is the learning rate:\n", "\n", "$$\n", "w_j \\leftarrow w_j-\\alpha\\frac{\\partial L}{\\partial w_j}\n", "$$\n", "\n", "# Sigmoid\n", "\n", "In order for us to use the Binary-Cross-Entropy as our criterion, we must first ensure that our model's output ranges between ```[0,1]```. 
Of course, once rounded, this output gives our binary prediction:\n", "\n", "* 1 = Yes\n", "* 0 = No\n", "\n", "We will do this by applying a ```Sigmoid``` layer to the output of the linear transformation before feeding it to the Loss function.\n", "\n", "A sigmoid layer is an activation function that \"squeezes\" each of its inputs into the range ```(0,1)``` by applying the function below\n", "\n", "$$\n", "\\sigma(y)=\\frac{1}{1+e^{-y}}\n", "$$\n", "\n", "*(Figure: the Sigmoid function)*\n", "\n", "One distinct property of the Sigmoid function is that its derivative can be expressed using only its own forward output:\n", "\n", "$$\n", "\\frac{\\partial \\sigma}{\\partial y} = \\sigma(y)(1-\\sigma(y))\n", "$$\n", "\n", "Because the activation function is applied independently to each element, its derivative is also computed element-wise:\n", "\n", "$$\n", "\\begin{aligned}\\sigma(y) &= \\sigma\\begin{pmatrix}y_1 & y_2 & y_3\\end{pmatrix} = \\begin{pmatrix}\\sigma(y_1) & \\sigma(y_2) & \\sigma(y_3)\\end{pmatrix} \\\\\\frac{\\partial \\sigma}{\\partial y} &= \\begin{pmatrix}\\sigma(y_1)(1-\\sigma(y_1)) & \\sigma(y_2)(1-\\sigma(y_2)) & \\sigma(y_3)(1-\\sigma(y_3))\\end{pmatrix}\\end{aligned}\n", "$$\n", "\n", "Further, given that Sigmoid introduces no new parameters, its backward pass is an intermediate operation. 
As a result, we can integrate the latest incoming gradient of the chain rule ($\\frac{\\partial L}{\\partial \\sigma}$) with the partial of our sigmoid function ($\\frac{\\partial \\sigma}{\\partial y}$) through a simple Hadamard (element-wise) product:\n", "\n", "$$\n", "\\frac{\\partial L}{\\partial y}=\\frac{\\partial L}{\\partial \\sigma}\\odot \\frac{\\partial \\sigma}{\\partial y}\n", "$$\n", "\n", "**NOTE:** if the above statements do not make much sense, make sure to review the [Linear Layer]() and/or [ReLU]() tutorials, where I expand on these concepts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Build Logistic Regression\n", "\n", "Now that we have defined all the operations we need, we will manually implement the forward/backward pass of each one as a ```torch.autograd.Function```.\n", "\n", "**NOTE:** We will not go in-depth on our Linear Layer implementation, as this was covered in this [tutorial](Linear Layer.ipynb)\n" ] }, { "cell_type": "code", "execution_count": 220, "metadata": {}, "outputs": [], "source": [ "import torch\n", "# quick sanity check that a CUDA device is available (the rest of the notebook runs on GPU)\n", "torch.randn((2,2)).cuda()\n", "import torch.nn as nn" ] }, { "cell_type": "code", "execution_count": 221, "metadata": {}, "outputs": [], "source": [ "####################### Linear Layer ###################\n", "\n", "\n", "class Linear_Layer(torch.autograd.Function):\n", " \"\"\"\n", " Define a Linear Layer operation\n", " \"\"\"\n", " @staticmethod\n", " def forward(ctx, input,weights, bias = None):\n", " \"\"\"\n", " In the forward pass, we feed this class all necessary objects to \n", " compute a linear layer (input, weights, and bias)\n", " \"\"\"\n", " # input.shape = (B, in_dim)\n", " # weights.shape = (in_dim, out_dim)\n", " \n", " # the gradient w.r.t. the weights depends on the input (and the gradient\n", " # w.r.t. the input depends on the weights), so we save them for backpropagation\n", " ctx.save_for_backward(input, weights, bias)\n", " \n", " \n", " # linear transformation\n", " # (B, out_dim) = (B, in_dim) * (in_dim, out_dim)\n", " output = torch.mm(input, 
weights)\n", " \n", " if bias is not None:\n", " # bias.shape = (out_dim)\n", " \n", " # expanded_bias.shape = (B, out_dim), repeats bias B times\n", " expanded_bias = bias.unsqueeze(0).expand_as(output)\n", " \n", " # element-wise addition\n", " output += expanded_bias\n", " \n", " return output\n", "\n", " \n", " @staticmethod\n", " def backward(ctx, incoming_grad):\n", " \"\"\"\n", " In the backward pass we receive a Tensor (incoming_grad) containing the \n", " gradient of the loss with respect to our output, \n", " and we now need to compute the gradient of the loss\n", " with respect to each input of forward().\n", " \"\"\"\n", " # incoming_grad.shape = (B, out_dim)\n", " \n", " # extract inputs from forward pass\n", " input, weights, bias = ctx.saved_tensors \n", " \n", " # assume none of the inputs need gradients\n", " grad_input = grad_weight = grad_bias = None\n", " \n", " \n", " # if input requires grad\n", " if ctx.needs_input_grad[0]:\n", " # (B, in_dim) = (B, out_dim) * (out_dim, in_dim)\n", " grad_input = incoming_grad.mm(weights.t())\n", " \n", " # if weights require grad\n", " if ctx.needs_input_grad[1]:\n", " # (in_dim, out_dim) = (in_dim, B) * (B, out_dim), same layout as weights\n", " grad_weight = input.t().mm(incoming_grad)\n", " \n", " # if bias requires grad\n", " if bias is not None and ctx.needs_input_grad[2]:\n", " # equivalent to torch.ones((1,B)).mm(incoming_grad)\n", " # (out_dim) = (1,B) * (B,out_dim)\n", " grad_bias = incoming_grad.sum(0)\n", " \n", " \n", " # below, if any of the grads = None, they will simply be ignored\n", " return grad_input, grad_weight, grad_bias\n", " \n", " " ] }, { "cell_type": "code", "execution_count": 222, "metadata": {}, "outputs": [], "source": [ "class Linear(nn.Module):\n", " def __init__(self, in_dim, out_dim, bias = True):\n", " super().__init__()\n", " self.in_dim = in_dim\n", " self.out_dim = out_dim\n", " \n", " # define parameters\n", " \n", " # weight 
parameter\n", " self.weight = nn.Parameter(torch.randn((in_dim, out_dim)))\n", " \n", " # bias parameter\n", " if bias:\n", " self.bias = nn.Parameter(torch.randn((out_dim)))\n", " else:\n", " # register parameter as None if not initialized\n", " self.register_parameter('bias',None)\n", " \n", " def forward(self, input):\n", " output = Linear_Layer.apply(input, self.weight, self.bias)\n", " return output" ] }, { "cell_type": "code", "execution_count": 223, "metadata": {}, "outputs": [], "source": [ "################## Sigmoid Layer #######################\n", "\n", "# The sigmoid is applied element-wise, so the incoming gradient has the same\n", "# shape as our output, and the output acts as an intermediate variable:\n", "# input.shape == out.shape == incoming_gradient.shape\n", "\n", "import torch.nn as nn\n", "import torch\n", "\n", "# forward/backward must be @staticmethods that receive a context object (ctx);\n", "# the operation is then called through sigmoid_layer.apply(...)\n", "class sigmoid_layer(torch.autograd.Function):\n", " \n", " @staticmethod\n", " def sigmoid(x):\n", " sig = 1 / (1 + (-1*x).exp())\n", " return sig\n", " \n", " # forward pass\n", " @staticmethod\n", " def forward(ctx, input):\n", " # save input for backward() pass \n", " ctx.save_for_backward(input) \n", " activated_input = sigmoid_layer.sigmoid(input)\n", " return activated_input\n", "\n", " # integrate backward pass with incoming_grad\n", " @staticmethod\n", " def backward(ctx, incoming_grad):\n", " \"\"\"\n", " In the backward pass we receive a Tensor containing the \n", " gradient of the loss with respect to our output, \n", " and we need to compute the gradient of the loss\n", " with respect to the input.\n", " \"\"\"\n", " input, = ctx.saved_tensors\n", " sig = sigmoid_layer.sigmoid(input)\n", " # Hadamard product of dsigma/dy with the incoming gradient\n", " chained_grad = (sig * (1 - sig)) * incoming_grad\n", " return chained_grad" ] }, { "cell_type": "code", "execution_count": 224, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.7311], grad_fn=)" ] }, "execution_count": 224, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# test forward pass\n", "\n", "weight = torch.tensor([1.], requires_grad = True)\n", "input = torch.tensor([1.])\n", "x = input * weight\n", "sig = sigmoid_layer.apply(x)\n", "sig" ] }, { "cell_type": "code", "execution_count": 225, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.1966])" ] }, "execution_count": 225, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# test backward pass\n", "\n", "sig.backward(torch.tensor([1.]))\n", "weight.grad" ] }, { "cell_type": "code", "execution_count": 226, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.7311], grad_fn=)" ] }, "execution_count": 226, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# compare output with PyTorch's built-in nn.Sigmoid\n", "\n", "weight = torch.tensor([1.], requires_grad = True)\n", "input = torch.tensor([1.])\n", "x = input * weight\n", "\n", "sig = nn.Sigmoid()(x)\n", "sig" ] }, { "cell_type": "code", "execution_count": 227, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.1966])" ] }, "execution_count": 227, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sig.backward()\n", "weight.grad" ] }, { "cell_type": "code", "execution_count": 228, "metadata": {}, "outputs": [], "source": [ "# Wrap sigmoid_layer function in nn.Module\n", "\n", "class Sigmoid(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", "\n", " \n", " def forward(self, input):\n", " output = sigmoid_layer.apply(input)\n", " return output\n", " " ] }, { "cell_type": "code", "execution_count": 229, "metadata": {}, "outputs": [], "source": [ "#################### Binary Cross Entropy ################\n", "\n", "# inputs must all be of type .float()\n", "class BCE_loss(torch.autograd.Function):\n", " \n", "\n", " @staticmethod\n", " def forward(self, yhat, y):\n", " # save input for backward() pass \n", " self.save_for_backward(y,yhat) \n", " loss = - (y * yhat.log() + (1-y)* (1-yhat).log())\n", " return loss\n", "\n", " @staticmethod\n", " def backward(self, output_grad):\n", " 
y,yhat = self.saved_tensors\n", " # dL/dyhat = (yhat - y) / (yhat * (1 - yhat)), chained with the incoming gradient\n", " # (e.g. 1/B for every element when the loss is averaged over a batch)\n", " chained_grad = ((yhat-y) / (yhat * (1- yhat))) * output_grad\n", " \n", " # y does not need gradient and thus we pass None to signify this\n", " return chained_grad, None" ] }, { "cell_type": "code", "execution_count": 230, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.6931], grad_fn=)" ] }, "execution_count": 230, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# test forward() method\n", "output = torch.tensor([.50], requires_grad = True)\n", "y = torch.tensor([1.])\n", "loss = BCE_loss.apply(output,y)\n", "loss" ] }, { "cell_type": "code", "execution_count": 231, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-2.])" ] }, "execution_count": 231, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# test backward() method\n", "loss.backward()\n", "output.grad" ] }, { "cell_type": "code", "execution_count": 232, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.6931, grad_fn=)" ] }, "execution_count": 232, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# compare with PyTorch's built-in nn.BCELoss\n", "output = torch.tensor([.50], requires_grad = True)\n", "y = torch.tensor([1.])\n", "bce = nn.BCELoss()\n", "loss = bce(output,y)\n", "loss" ] }, { "cell_type": "code", "execution_count": 233, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-2.])" ] }, "execution_count": 233, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# test backward() method\n", "loss.backward()\n", "output.grad" ] }, { "cell_type": "code", "execution_count": 234, "metadata": {}, "outputs": [], "source": [ "# Wrap BCE_loss function in nn.Module\n", "\n", "class BCELoss(nn.Module):\n", " # only the 'mean' reduction is implemented; the argument is kept for parity with nn.BCELoss\n", " def __init__(self, reduction = 'mean'):\n", " super().__init__()\n", "\n", " \n", " def forward(self, pred, target):\n", " output = BCE_loss.apply(pred,target)\n", " # reduce output by average\n", " output = output.mean()\n", " return output\n", " " ] }, { "cell_type": "markdown", 
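"metadata": {}, "source": [ "Putting the pieces together, we can work out the gradient that reaches the output of the linear layer. Write $z_i$ for the linear output of sample $i$ (a symbol introduced here only for clarity), so that $\\hat{y}_i=\\sigma(z_i)$. With the mean reduction used in ```BCELoss```, every per-sample loss receives an incoming gradient of $\\frac{1}{B}$, and chaining the BCE derivative with the Sigmoid derivative gives\n", "\n", "$$\n", "\\frac{\\partial}{\\partial z_i}\\frac{1}{B}\\sum_{k=1}^{B}L(\\hat{y}_k,y_k)=\\frac{1}{B}\\cdot\\frac{\\hat{y}_i-y_i}{\\hat{y}_i(1-\\hat{y}_i)}\\cdot\\hat{y}_i(1-\\hat{y}_i)=\\frac{\\hat{y}_i-y_i}{B}\n", "$$\n", "\n", "As a worked example, take $B=2$, $\\hat{y}=(0.5, 0.5)$ and $y=(1, 0)$: the gradients w.r.t. $\\hat{y}$ are $\\frac{1}{2}\\cdot\\frac{-0.5}{0.25}=-1$ and $\\frac{1}{2}\\cdot\\frac{0.5}{0.25}=1$, while the gradients w.r.t. $z$ are $\\frac{0.5-1}{2}=-0.25$ and $\\frac{0.5-0}{2}=0.25$. With $B=1$, $\\hat{y}=0.5$ and $y=1$, the BCE factor alone is $\\frac{0.5-1}{0.25}=-2$, which matches the ```tensor([-2.])``` gradients printed above." ] }, { "cell_type": "markdown", 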
"metadata": {}, "source": [ "Now that we have all of our \"ingredients\", we can now create our Logistic model" ] }, { "cell_type": "code", "execution_count": 235, "metadata": {}, "outputs": [], "source": [ "# Create Logistic Regression function\n", "class LogisticRegression(nn.Module):\n", " def __init__(self, input_dim = 30):\n", " super().__init__()\n", " self.linear = Linear(input_dim, 1) \n", " self.sigmoid = Sigmoid()\n", " \n", " def forward(self,x):\n", " # output.shape = (B, 1)\n", " output = self.sigmoid(self.linear(x))\n", " return output.view(-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Wisconcin Breast Cancer Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To showcase our Linear Regression, we will train our model to differentiate between malignant and benign cancer cells, given the characteristics of the cell nuclei. \n", "\n", "Refer to [link](https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)) to learn more about the data" ] }, { "cell_type": "code", "execution_count": 236, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
radius_meantexture_meanperimeter_meanarea_meansmoothness_meancompactness_meanconcavity_meanpoints_meansymmetry_meandimension_mean...radius_worsttexture_worstperimeter_worstarea_worstsmoothness_worstcompactness_worstconcavity_worstpoints_worstsymmetry_worstdimension_worst
diagnosis
B12.3212.3978.85464.10.102800.069810.039870.037000.19590.05955...13.5015.6486.97549.10.13850.12660.124200.093910.28270.06771
B10.6018.9569.28346.40.096880.114700.063870.026420.19220.06491...11.8822.9478.28424.80.12130.25150.191600.079260.29400.07587
B11.0416.8370.92373.20.107700.078040.030460.024800.17140.06340...12.4126.4479.93471.40.13690.14820.106700.074310.29980.07881
B11.2813.3973.00384.80.116400.113600.046350.047960.17710.06072...11.9215.7776.53434.00.13670.18220.086690.086110.21020.06784
B15.1913.2197.65711.80.079630.069340.033930.026570.17210.05544...16.2015.73104.50819.10.11260.17370.136200.081780.24870.06766
\n", "

5 rows × 30 columns

\n", "
" ], "text/plain": [ " radius_mean texture_mean perimeter_mean area_mean \\\n", "diagnosis \n", "B 12.32 12.39 78.85 464.1 \n", "B 10.60 18.95 69.28 346.4 \n", "B 11.04 16.83 70.92 373.2 \n", "B 11.28 13.39 73.00 384.8 \n", "B 15.19 13.21 97.65 711.8 \n", "\n", " smoothness_mean compactness_mean concavity_mean points_mean \\\n", "diagnosis \n", "B 0.10280 0.06981 0.03987 0.03700 \n", "B 0.09688 0.11470 0.06387 0.02642 \n", "B 0.10770 0.07804 0.03046 0.02480 \n", "B 0.11640 0.11360 0.04635 0.04796 \n", "B 0.07963 0.06934 0.03393 0.02657 \n", "\n", " symmetry_mean dimension_mean ... radius_worst texture_worst \\\n", "diagnosis ... \n", "B 0.1959 0.05955 ... 13.50 15.64 \n", "B 0.1922 0.06491 ... 11.88 22.94 \n", "B 0.1714 0.06340 ... 12.41 26.44 \n", "B 0.1771 0.06072 ... 11.92 15.77 \n", "B 0.1721 0.05544 ... 16.20 15.73 \n", "\n", " perimeter_worst area_worst smoothness_worst compactness_worst \\\n", "diagnosis \n", "B 86.97 549.1 0.1385 0.1266 \n", "B 78.28 424.8 0.1213 0.2515 \n", "B 79.93 471.4 0.1369 0.1482 \n", "B 76.53 434.0 0.1367 0.1822 \n", "B 104.50 819.1 0.1126 0.1737 \n", "\n", " concavity_worst points_worst symmetry_worst dimension_worst \n", "diagnosis \n", "B 0.12420 0.09391 0.2827 0.06771 \n", "B 0.19160 0.07926 0.2940 0.07587 \n", "B 0.10670 0.07431 0.2998 0.07881 \n", "B 0.08669 0.08611 0.2102 0.06784 \n", "B 0.13620 0.08178 0.2487 0.06766 \n", "\n", "[5 rows x 30 columns]" ] }, "execution_count": 236, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# import data\n", "import pandas as pd\n", "url = 'https://raw.githubusercontent.com/PacktPublishing/Machine-Learning-with-R-Second-Edition/master/Chapter%2003/wisc_bc_data.csv'\n", "df = pd.read_csv(url)\n", "df.index = df.diagnosis\n", "df.drop(columns = ['diagnosis','id'],inplace = True)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 237, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Index: 569 entries, B to M\n", "Data 
columns (total 30 columns):\n", "radius_mean 569 non-null float64\n", "texture_mean 569 non-null float64\n", "perimeter_mean 569 non-null float64\n", "area_mean 569 non-null float64\n", "smoothness_mean 569 non-null float64\n", "compactness_mean 569 non-null float64\n", "concavity_mean 569 non-null float64\n", "points_mean 569 non-null float64\n", "symmetry_mean 569 non-null float64\n", "dimension_mean 569 non-null float64\n", "radius_se 569 non-null float64\n", "texture_se 569 non-null float64\n", "perimeter_se 569 non-null float64\n", "area_se 569 non-null float64\n", "smoothness_se 569 non-null float64\n", "compactness_se 569 non-null float64\n", "concavity_se 569 non-null float64\n", "points_se 569 non-null float64\n", "symmetry_se 569 non-null float64\n", "dimension_se 569 non-null float64\n", "radius_worst 569 non-null float64\n", "texture_worst 569 non-null float64\n", "perimeter_worst 569 non-null float64\n", "area_worst 569 non-null float64\n", "smoothness_worst 569 non-null float64\n", "compactness_worst 569 non-null float64\n", "concavity_worst 569 non-null float64\n", "points_worst 569 non-null float64\n", "symmetry_worst 569 non-null float64\n", "dimension_worst 569 non-null float64\n", "dtypes: float64(30)\n", "memory usage: 137.8+ KB\n" ] } ], "source": [ "df.info(verbose = True)" ] }, { "cell_type": "code", "execution_count": 238, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# visualize the distribution of our binary classes\n", "\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "plt.style.use('ggplot')\n", "\n", "sns.countplot(df.index);plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given that there is about twice more data on Benign cells than Malignant, a model will become bias towards classifying Benign cells" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data Preprocessing\n", "\n" ] }, { "cell_type": "code", "execution_count": 239, "metadata": { "scrolled": false }, "outputs": [], "source": [ "# Separate features (X) from target (y)\n", "import numpy as np\n", "\n", "X = df.values\n", "y = (df.index == 'M')\n", "y = y.astype(np.double)" ] }, { "cell_type": "code", "execution_count": 240, "metadata": {}, "outputs": [], "source": [ "# normalize features\n", "from sklearn.preprocessing import normalize\n", "X = normalize(X, axis = 0)" ] }, { "cell_type": "code", "execution_count": 241, "metadata": {}, "outputs": [], "source": [ "# parse data to training and testing set for evaluation\n", "\n", "from sklearn.model_selection import train_test_split\n", "X_train, X_test, y_train, y_test = train_test_split(X,y, test_size = .20, random_state = 42, shuffle = True)" ] }, { "cell_type": "code", "execution_count": 242, "metadata": {}, "outputs": [], "source": [ "# Transform data to PyTorch tensors and separate data into batches\n", "from skorch.dataset import Dataset\n", "from torch.utils.data import DataLoader\n", "\n", "# Wrap each observation with its corresponding target\n", "train = Dataset(X_train,y_train) \n", "test = Dataset(X_test,y_test) \n", "\n", "# separate data into batches of 16\n", "train_dl = DataLoader(train, batch_size = 16, pin_memory = True)\n", "test_dl = DataLoader(test, batch_size = 16, pin_memory = True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have all 
the data formatted, let's instantiate our model, criterion, and optimizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Instantiate Logistic Regression" ] }, { "cell_type": "code", "execution_count": 243, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LogisticRegression(\n", " (linear): Linear()\n", " (sigmoid): Sigmoid()\n", ")" ] }, "execution_count": 243, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# instantiate model and place it on GPU\n", "\n", "device = torch.device('cuda')\n", "model = LogisticRegression(30).to(device)\n", "model" ] }, { "cell_type": "code", "execution_count": 244, "metadata": {}, "outputs": [], "source": [ "# instantiate loss function\n", "criterion = BCELoss()" ] }, { "cell_type": "code", "execution_count": 245, "metadata": {}, "outputs": [], "source": [ "# instantiate optimizer\n", "from torch import optim\n", "optimizer = optim.SGD(model.parameters(), lr = .01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make one forward/backward pass on a single batch to make sure everything works as it should" ] }, { "cell_type": "code", "execution_count": 246, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "batch_X.shape: torch.Size([16, 30])\n", "-----------------------------------\n", "batch_y.shape: torch.Size([16])\n" ] } ], "source": [ "# test train_dl\n", "batch_X,batch_y = next(iter(train_dl))\n", "print(f\"batch_X.shape: {batch_X.shape}\")\n", "print('-'*35)\n", "print(f\"batch_y.shape: {batch_y.shape}\")" ] }, { "cell_type": "code", "execution_count": 247, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "torch.Size([16])" ] }, "execution_count": 247, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check that our model makes one prediction per sample in the batch\n", "# all inputs must be of type .float()\n", "\n", "output = model(batch_X.cuda().float())\n", "output.shape" ] }, { "cell_type": "code", "execution_count": 248, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.6896, device='cuda:0', grad_fn=)" ] }, "execution_count": 248, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# average loss\n", "loss = criterion(output,batch_y.cuda().float())\n", "loss" ] }, { "cell_type": "code", "execution_count": 249, "metadata": {}, "outputs": [], "source": [ "# compute gradients by calling .backward()\n", "loss.backward()" ] }, { "cell_type": "code", "execution_count": 250, "metadata": {}, "outputs": [], "source": [ "# take a step\n", "optimizer.step()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Train Logistic Regression\n", "\n", "Now that we have asserted our model works as should, it's time to train it" ] }, { "cell_type": "code", "execution_count": 251, "metadata": {}, "outputs": [], "source": [ "\n", "def train(model, iterator, optimizer, criterion):\n", " \n", " # hold avg loss and acc sum of all batches\n", " epoch_loss = 0\n", " epoch_acc = 0\n", " \n", " \n", " for batch in iterator:\n", " \n", " # zero-out all gradients (if any) from our model parameters\n", " model.zero_grad()\n", " \n", " \n", " \n", " # extract input and label\n", " \n", " # input.shape = (B, fetures)\n", " input = batch[0].cuda().float()\n", " # label.shape = (B)\n", " label = batch[1].cuda().float()\n", " \n", " \n", " # Start PyTorch's Dynamic Graph\n", " \n", " # predictions.shape = (B)\n", " predictions = model(input)\n", " \n", " # average batch loss \n", " loss = criterion(predictions, label)\n", " \n", " # calculate grad(loss) / grad(parameters)\n", " # \"clears\" PyTorch's dynamic graph\n", " loss.backward()\n", " \n", " \n", " # perform SGD \"step\" operation\n", " optimizer.step()\n", " \n", " \n", " # Given that PyTorch variables are \"contagious\" (they record all operations)\n", " # we need to \".detach()\" to stop them from recording any performance\n", " # statistics\n", " \n", " \n", " # average batch accuracy\n", " acc = 
binary_accuracy(predictions.detach(), label)\n", " \n", "\n", " \n", " # record our stats\n", " epoch_loss += loss.detach()\n", " epoch_acc += acc\n", " \n", " # NOTE: tensor.item() unpacks a single-element Tensor into a regular Python number\n", " # torch.tensor([1]).item() == 1\n", " \n", " # return average loss and acc of epoch\n", " return epoch_loss.item() / len(iterator), epoch_acc / len(iterator)\n" ] }, { "cell_type": "code", "execution_count": 252, "metadata": {}, "outputs": [], "source": [ "# compute average accuracy per batch\n", "\n", "def binary_accuracy(preds, y):\n", " # preds.shape = (B)\n", " # y.shape = (B)\n", "\n", " # round predictions to the closest integer (0 or 1)\n", " rounded_preds = torch.round(preds)\n", " correct = (rounded_preds == y).sum()\n", " acc = correct.item() / len(y)\n", " return acc" ] }, { "cell_type": "code", "execution_count": 253, "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "def epoch_time(start_time, end_time):\n", " elapsed_time = end_time - start_time\n", " elapsed_mins = int(elapsed_time / 60)\n", " elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n", " return elapsed_mins, elapsed_secs\n", " \n", " " ] }, { "cell_type": "code", "execution_count": 254, "metadata": {}, "outputs": [], "source": [ "def evaluate(model, iterator, criterion):\n", " \n", " epoch_loss = 0\n", " epoch_acc = 0\n", " \n", " # turn off grad tracking as we are only evaluating performance\n", " with torch.no_grad():\n", " \n", " for batch in iterator:\n", "\n", " # extract input and label \n", " input = batch[0].cuda().float()\n", " label = batch[1].cuda().float()\n", "\n", "\n", " # predictions.shape = (B)\n", " predictions = model(input)\n", "\n", " # average batch loss \n", " loss = criterion(predictions, label)\n", "\n", " # average batch accuracy\n", " acc = binary_accuracy(predictions, label)\n", "\n", " epoch_loss += loss\n", " epoch_acc += acc\n", " \n", " return epoch_loss.item() / len(iterator), epoch_acc / len(iterator)" ] }, { 
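"cell_type": "markdown", "metadata": {}, "source": [ "Before training, it helps to know the accuracy of a trivial baseline. The Wisconsin Diagnostic dataset contains 357 Benign and 212 Malignant samples (569 in total), so a classifier that always predicts Benign would score roughly\n", "\n", "$$\n", "\\frac{357}{569}\\approx 0.627\n", "$$\n", "\n", "assuming the random train/test split roughly preserves this class ratio. Accuracies close to 62.7%, such as those of the first epochs below, therefore suggest that the model is still mostly predicting the majority class; only accuracy clearly above this baseline reflects what the model has learned." ] }, { 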
"cell_type": "code", "execution_count": 255, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "---------------------------------------------------------------------------\n", "Epoch: 01 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.662 | Train Acc: 62.65%\n", "\t Val. Loss: 0.653 | Val. Acc: 63.28%\n", "---------------------------------------------------------------------------\n", "Epoch: 02 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.658 | Train Acc: 62.65%\n", "\t Val. Loss: 0.649 | Val. Acc: 63.28%\n", "---------------------------------------------------------------------------\n", "Epoch: 03 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.655 | Train Acc: 62.65%\n", "\t Val. Loss: 0.646 | Val. Acc: 63.28%\n", "---------------------------------------------------------------------------\n", "Epoch: 04 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.652 | Train Acc: 62.65%\n", "\t Val. Loss: 0.643 | Val. Acc: 63.28%\n", "---------------------------------------------------------------------------\n", "Epoch: 05 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.649 | Train Acc: 62.65%\n", "\t Val. Loss: 0.639 | Val. Acc: 63.28%\n", "---------------------------------------------------------------------------\n", "Epoch: 06 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.645 | Train Acc: 62.65%\n", "\t Val. Loss: 0.636 | Val. Acc: 63.28%\n", "---------------------------------------------------------------------------\n", "Epoch: 07 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.642 | Train Acc: 62.87%\n", "\t Val. Loss: 0.633 | Val. Acc: 63.28%\n", "---------------------------------------------------------------------------\n", "Epoch: 08 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.639 | Train Acc: 62.87%\n", "\t Val. Loss: 0.630 | Val. Acc: 63.28%\n", "---------------------------------------------------------------------------\n", "Epoch: 09 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.636 | Train Acc: 62.87%\n", "\t Val. Loss: 0.627 | Val. 
Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 10 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.633 | Train Acc: 62.87%\n", "\t Val. Loss: 0.624 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 11 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.630 | Train Acc: 62.87%\n", "\t Val. Loss: 0.621 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 12 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.627 | Train Acc: 62.65%\n", "\t Val. Loss: 0.618 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 13 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.624 | Train Acc: 62.65%\n", "\t Val. Loss: 0.615 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 14 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.621 | Train Acc: 62.44%\n", "\t Val. Loss: 0.612 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 15 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.618 | Train Acc: 62.44%\n", "\t Val. Loss: 0.609 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 16 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.615 | Train Acc: 62.44%\n", "\t Val. Loss: 0.606 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 17 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.612 | Train Acc: 62.87%\n", "\t Val. Loss: 0.603 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 18 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.609 | Train Acc: 63.52%\n", "\t Val. Loss: 0.600 | Val. 
Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 19 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.607 | Train Acc: 63.95%\n", "\t Val. Loss: 0.597 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 20 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.604 | Train Acc: 63.95%\n", "\t Val. Loss: 0.594 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 21 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.601 | Train Acc: 64.59%\n", "\t Val. Loss: 0.592 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 22 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.599 | Train Acc: 64.81%\n", "\t Val. Loss: 0.589 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 23 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.596 | Train Acc: 65.02%\n", "\t Val. Loss: 0.586 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 24 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.593 | Train Acc: 65.02%\n", "\t Val. Loss: 0.584 | Val. Acc: 64.06%\n", "---------------------------------------------------------------------------\n", "Epoch: 25 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.591 | Train Acc: 65.24%\n", "\t Val. Loss: 0.581 | Val. Acc: 64.84%\n", "---------------------------------------------------------------------------\n", "Epoch: 26 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.588 | Train Acc: 65.46%\n", "\t Val. Loss: 0.578 | Val. Acc: 64.84%\n", "---------------------------------------------------------------------------\n", "Epoch: 27 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.586 | Train Acc: 65.89%\n", "\t Val. Loss: 0.576 | Val. 
Acc: 64.84%\n", "---------------------------------------------------------------------------\n", "Epoch: 28 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.583 | Train Acc: 66.32%\n", "\t Val. Loss: 0.573 | Val. Acc: 65.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 29 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.581 | Train Acc: 66.75%\n", "\t Val. Loss: 0.571 | Val. Acc: 65.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 30 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.578 | Train Acc: 66.96%\n", "\t Val. Loss: 0.568 | Val. Acc: 65.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 31 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.576 | Train Acc: 66.96%\n", "\t Val. Loss: 0.566 | Val. Acc: 66.41%\n", "---------------------------------------------------------------------------\n", "Epoch: 32 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.573 | Train Acc: 67.40%\n", "\t Val. Loss: 0.564 | Val. Acc: 68.75%\n", "---------------------------------------------------------------------------\n", "Epoch: 33 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.571 | Train Acc: 67.61%\n", "\t Val. Loss: 0.561 | Val. Acc: 68.75%\n", "---------------------------------------------------------------------------\n", "Epoch: 34 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.569 | Train Acc: 68.04%\n", "\t Val. Loss: 0.559 | Val. Acc: 69.53%\n", "---------------------------------------------------------------------------\n", "Epoch: 35 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.566 | Train Acc: 68.26%\n", "\t Val. Loss: 0.556 | Val. Acc: 69.53%\n", "---------------------------------------------------------------------------\n", "Epoch: 36 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.564 | Train Acc: 68.26%\n", "\t Val. Loss: 0.554 | Val. 
Acc: 69.53%\n", "---------------------------------------------------------------------------\n", "Epoch: 37 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.562 | Train Acc: 68.69%\n", "\t Val. Loss: 0.552 | Val. Acc: 69.53%\n", "---------------------------------------------------------------------------\n", "Epoch: 38 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.560 | Train Acc: 68.69%\n", "\t Val. Loss: 0.550 | Val. Acc: 69.53%\n", "---------------------------------------------------------------------------\n", "Epoch: 39 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.558 | Train Acc: 68.90%\n", "\t Val. Loss: 0.547 | Val. Acc: 69.53%\n", "---------------------------------------------------------------------------\n", "Epoch: 40 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.555 | Train Acc: 69.77%\n", "\t Val. Loss: 0.545 | Val. Acc: 69.53%\n", "---------------------------------------------------------------------------\n", "Epoch: 41 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.553 | Train Acc: 70.84%\n", "\t Val. Loss: 0.543 | Val. Acc: 69.53%\n", "---------------------------------------------------------------------------\n", "Epoch: 42 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.551 | Train Acc: 71.55%\n", "\t Val. Loss: 0.541 | Val. Acc: 69.53%\n", "---------------------------------------------------------------------------\n", "Epoch: 43 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.549 | Train Acc: 71.77%\n", "\t Val. Loss: 0.539 | Val. Acc: 70.31%\n", "---------------------------------------------------------------------------\n", "Epoch: 44 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.547 | Train Acc: 71.77%\n", "\t Val. Loss: 0.536 | Val. Acc: 70.31%\n", "---------------------------------------------------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 45 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.545 | Train Acc: 71.98%\n", "\t Val. Loss: 0.534 | Val. 
Acc: 70.31%\n", "---------------------------------------------------------------------------\n", "Epoch: 46 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.543 | Train Acc: 72.20%\n", "\t Val. Loss: 0.532 | Val. Acc: 70.31%\n", "---------------------------------------------------------------------------\n", "Epoch: 47 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.541 | Train Acc: 73.06%\n", "\t Val. Loss: 0.530 | Val. Acc: 71.09%\n", "---------------------------------------------------------------------------\n", "Epoch: 48 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.539 | Train Acc: 73.06%\n", "\t Val. Loss: 0.528 | Val. Acc: 71.09%\n", "---------------------------------------------------------------------------\n", "Epoch: 49 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.537 | Train Acc: 73.06%\n", "\t Val. Loss: 0.526 | Val. Acc: 71.09%\n", "---------------------------------------------------------------------------\n", "Epoch: 50 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.535 | Train Acc: 73.49%\n", "\t Val. Loss: 0.524 | Val. Acc: 71.09%\n", "---------------------------------------------------------------------------\n", "Epoch: 51 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.533 | Train Acc: 73.71%\n", "\t Val. Loss: 0.522 | Val. Acc: 71.88%\n", "---------------------------------------------------------------------------\n", "Epoch: 52 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.531 | Train Acc: 73.92%\n", "\t Val. Loss: 0.520 | Val. Acc: 71.88%\n", "---------------------------------------------------------------------------\n", "Epoch: 53 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.529 | Train Acc: 74.14%\n", "\t Val. Loss: 0.518 | Val. Acc: 71.88%\n", "---------------------------------------------------------------------------\n", "Epoch: 54 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.527 | Train Acc: 74.14%\n", "\t Val. Loss: 0.516 | Val. 
Acc: 72.66%\n", "---------------------------------------------------------------------------\n", "Epoch: 55 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.525 | Train Acc: 74.57%\n", "\t Val. Loss: 0.514 | Val. Acc: 73.44%\n", "---------------------------------------------------------------------------\n", "Epoch: 56 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.524 | Train Acc: 74.57%\n", "\t Val. Loss: 0.512 | Val. Acc: 74.22%\n", "---------------------------------------------------------------------------\n", "Epoch: 57 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.522 | Train Acc: 74.78%\n", "\t Val. Loss: 0.511 | Val. Acc: 74.22%\n", "---------------------------------------------------------------------------\n", "Epoch: 58 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.520 | Train Acc: 75.00%\n", "\t Val. Loss: 0.509 | Val. Acc: 74.22%\n", "---------------------------------------------------------------------------\n", "Epoch: 59 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.518 | Train Acc: 75.43%\n", "\t Val. Loss: 0.507 | Val. Acc: 74.22%\n", "---------------------------------------------------------------------------\n", "Epoch: 60 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.516 | Train Acc: 76.08%\n", "\t Val. Loss: 0.505 | Val. Acc: 74.22%\n", "---------------------------------------------------------------------------\n", "Epoch: 61 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.515 | Train Acc: 76.51%\n", "\t Val. Loss: 0.503 | Val. Acc: 74.22%\n", "---------------------------------------------------------------------------\n", "Epoch: 62 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.513 | Train Acc: 76.72%\n", "\t Val. Loss: 0.502 | Val. Acc: 75.78%\n", "---------------------------------------------------------------------------\n", "Epoch: 63 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.511 | Train Acc: 77.16%\n", "\t Val. Loss: 0.500 | Val. 
Acc: 75.78%\n", "---------------------------------------------------------------------------\n", "Epoch: 64 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.510 | Train Acc: 77.37%\n", "\t Val. Loss: 0.498 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 65 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.508 | Train Acc: 77.37%\n", "\t Val. Loss: 0.496 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 66 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.506 | Train Acc: 77.37%\n", "\t Val. Loss: 0.495 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 67 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.505 | Train Acc: 77.59%\n", "\t Val. Loss: 0.493 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 68 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.503 | Train Acc: 77.59%\n", "\t Val. Loss: 0.491 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 69 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.501 | Train Acc: 78.02%\n", "\t Val. Loss: 0.490 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 70 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.500 | Train Acc: 78.02%\n", "\t Val. Loss: 0.488 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 71 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.498 | Train Acc: 78.45%\n", "\t Val. Loss: 0.486 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 72 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.497 | Train Acc: 78.45%\n", "\t Val. Loss: 0.485 | Val. 
Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 73 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.495 | Train Acc: 78.66%\n", "\t Val. Loss: 0.483 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 74 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.494 | Train Acc: 79.09%\n", "\t Val. Loss: 0.482 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 75 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.492 | Train Acc: 79.31%\n", "\t Val. Loss: 0.480 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 76 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.491 | Train Acc: 79.31%\n", "\t Val. Loss: 0.478 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 77 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.489 | Train Acc: 79.53%\n", "\t Val. Loss: 0.477 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 78 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.488 | Train Acc: 79.74%\n", "\t Val. Loss: 0.475 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 79 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.486 | Train Acc: 80.17%\n", "\t Val. Loss: 0.474 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 80 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.485 | Train Acc: 80.17%\n", "\t Val. Loss: 0.472 | Val. Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 81 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.483 | Train Acc: 80.60%\n", "\t Val. Loss: 0.471 | Val. 
Acc: 76.56%\n", "---------------------------------------------------------------------------\n", "Epoch: 82 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.482 | Train Acc: 81.03%\n", "\t Val. Loss: 0.469 | Val. Acc: 77.34%\n", "---------------------------------------------------------------------------\n", "Epoch: 83 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.480 | Train Acc: 81.47%\n", "\t Val. Loss: 0.468 | Val. Acc: 77.34%\n", "---------------------------------------------------------------------------\n", "Epoch: 84 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.479 | Train Acc: 81.68%\n", "\t Val. Loss: 0.466 | Val. Acc: 78.12%\n", "---------------------------------------------------------------------------\n", "Epoch: 85 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.478 | Train Acc: 81.68%\n", "\t Val. Loss: 0.465 | Val. Acc: 78.12%\n", "---------------------------------------------------------------------------\n", "Epoch: 86 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.476 | Train Acc: 81.68%\n", "\t Val. Loss: 0.464 | Val. Acc: 78.12%\n", "---------------------------------------------------------------------------\n", "Epoch: 87 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.475 | Train Acc: 81.68%\n", "\t Val. Loss: 0.462 | Val. Acc: 78.12%\n", "---------------------------------------------------------------------------\n", "Epoch: 88 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.474 | Train Acc: 82.11%\n", "\t Val. Loss: 0.461 | Val. Acc: 78.91%\n", "---------------------------------------------------------------------------\n", "Epoch: 89 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.472 | Train Acc: 82.11%\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\t Val. Loss: 0.459 | Val. Acc: 78.91%\n", "---------------------------------------------------------------------------\n", "Epoch: 90 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.471 | Train Acc: 82.11%\n", "\t Val. Loss: 0.458 | Val. 
Acc: 78.91%\n", "---------------------------------------------------------------------------\n", "Epoch: 91 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.470 | Train Acc: 82.33%\n", "\t Val. Loss: 0.457 | Val. Acc: 78.91%\n", "---------------------------------------------------------------------------\n", "Epoch: 92 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.468 | Train Acc: 82.76%\n", "\t Val. Loss: 0.455 | Val. Acc: 78.91%\n", "---------------------------------------------------------------------------\n", "Epoch: 93 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.467 | Train Acc: 83.41%\n", "\t Val. Loss: 0.454 | Val. Acc: 79.69%\n", "---------------------------------------------------------------------------\n", "Epoch: 94 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.466 | Train Acc: 83.62%\n", "\t Val. Loss: 0.453 | Val. Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 95 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.465 | Train Acc: 83.84%\n", "\t Val. Loss: 0.451 | Val. Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 96 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.463 | Train Acc: 84.27%\n", "\t Val. Loss: 0.450 | Val. Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 97 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.462 | Train Acc: 84.48%\n", "\t Val. Loss: 0.449 | Val. Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 98 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.461 | Train Acc: 84.48%\n", "\t Val. Loss: 0.447 | Val. Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 99 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.460 | Train Acc: 84.48%\n", "\t Val. Loss: 0.446 | Val. 
Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 100 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.458 | Train Acc: 84.48%\n", "\t Val. Loss: 0.445 | Val. Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 101 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.457 | Train Acc: 84.48%\n", "\t Val. Loss: 0.444 | Val. Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 102 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.456 | Train Acc: 84.48%\n", "\t Val. Loss: 0.442 | Val. Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 103 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.455 | Train Acc: 84.91%\n", "\t Val. Loss: 0.441 | Val. Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 104 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.454 | Train Acc: 84.91%\n", "\t Val. Loss: 0.440 | Val. Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 105 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.452 | Train Acc: 84.91%\n", "\t Val. Loss: 0.439 | Val. Acc: 80.47%\n", "---------------------------------------------------------------------------\n", "Epoch: 106 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.451 | Train Acc: 84.91%\n", "\t Val. Loss: 0.438 | Val. Acc: 81.25%\n", "---------------------------------------------------------------------------\n", "Epoch: 107 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.450 | Train Acc: 84.91%\n", "\t Val. Loss: 0.436 | Val. Acc: 81.25%\n", "---------------------------------------------------------------------------\n", "Epoch: 108 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.449 | Train Acc: 84.91%\n", "\t Val. Loss: 0.435 | Val. 
Acc: 82.03%\n", "---------------------------------------------------------------------------\n", "Epoch: 109 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.448 | Train Acc: 85.13%\n", "\t Val. Loss: 0.434 | Val. Acc: 82.03%\n", "---------------------------------------------------------------------------\n", "Epoch: 110 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.447 | Train Acc: 85.13%\n", "\t Val. Loss: 0.433 | Val. Acc: 82.03%\n", "---------------------------------------------------------------------------\n", "Epoch: 111 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.446 | Train Acc: 85.13%\n", "\t Val. Loss: 0.432 | Val. Acc: 82.03%\n", "---------------------------------------------------------------------------\n", "Epoch: 112 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.445 | Train Acc: 85.34%\n", "\t Val. Loss: 0.430 | Val. Acc: 82.03%\n", "---------------------------------------------------------------------------\n", "Epoch: 113 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.444 | Train Acc: 85.34%\n", "\t Val. Loss: 0.429 | Val. Acc: 82.03%\n", "---------------------------------------------------------------------------\n", "Epoch: 114 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.443 | Train Acc: 85.99%\n", "\t Val. Loss: 0.428 | Val. Acc: 82.03%\n", "---------------------------------------------------------------------------\n", "Epoch: 115 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.441 | Train Acc: 85.99%\n", "\t Val. Loss: 0.427 | Val. Acc: 82.03%\n", "---------------------------------------------------------------------------\n", "Epoch: 116 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.440 | Train Acc: 85.99%\n", "\t Val. Loss: 0.426 | Val. Acc: 82.03%\n", "---------------------------------------------------------------------------\n", "Epoch: 117 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.439 | Train Acc: 85.99%\n", "\t Val. Loss: 0.425 | Val. 
Acc: 82.03%\n", "---------------------------------------------------------------------------\n", "Epoch: 118 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.438 | Train Acc: 86.21%\n", "\t Val. Loss: 0.424 | Val. Acc: 82.03%\n", "---------------------------------------------------------------------------\n", "Epoch: 119 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.437 | Train Acc: 86.42%\n", "\t Val. Loss: 0.423 | Val. Acc: 89.84%\n", "---------------------------------------------------------------------------\n", "Epoch: 120 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.436 | Train Acc: 86.42%\n", "\t Val. Loss: 0.422 | Val. Acc: 89.84%\n", "---------------------------------------------------------------------------\n", "Epoch: 121 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.435 | Train Acc: 86.42%\n", "\t Val. Loss: 0.421 | Val. Acc: 89.84%\n", "---------------------------------------------------------------------------\n", "Epoch: 122 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.434 | Train Acc: 86.64%\n", "\t Val. Loss: 0.419 | Val. Acc: 89.84%\n", "---------------------------------------------------------------------------\n", "Epoch: 123 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.433 | Train Acc: 86.85%\n", "\t Val. Loss: 0.418 | Val. Acc: 89.84%\n", "---------------------------------------------------------------------------\n", "Epoch: 124 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.432 | Train Acc: 86.85%\n", "\t Val. Loss: 0.417 | Val. Acc: 89.84%\n", "---------------------------------------------------------------------------\n", "Epoch: 125 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.431 | Train Acc: 86.85%\n", "\t Val. Loss: 0.416 | Val. Acc: 89.84%\n", "---------------------------------------------------------------------------\n", "Epoch: 126 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.430 | Train Acc: 86.85%\n", "\t Val. Loss: 0.415 | Val. 
Acc: 89.84%\n", "---------------------------------------------------------------------------\n", "Epoch: 127 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.429 | Train Acc: 86.85%\n", "\t Val. Loss: 0.414 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 128 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.428 | Train Acc: 86.85%\n", "\t Val. Loss: 0.413 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 129 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.427 | Train Acc: 86.85%\n", "\t Val. Loss: 0.412 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 130 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.426 | Train Acc: 86.85%\n", "\t Val. Loss: 0.411 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 131 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.425 | Train Acc: 86.85%\n", "\t Val. Loss: 0.410 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 132 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.425 | Train Acc: 86.85%\n", "\t Val. Loss: 0.409 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 133 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.424 | Train Acc: 86.85%\n", "\t Val. Loss: 0.408 | Val. Acc: 90.62%\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "---------------------------------------------------------------------------\n", "Epoch: 134 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.423 | Train Acc: 86.85%\n", "\t Val. Loss: 0.407 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 135 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.422 | Train Acc: 87.07%\n", "\t Val. Loss: 0.406 | Val. 
Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 136 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.421 | Train Acc: 87.28%\n", "\t Val. Loss: 0.405 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 137 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.420 | Train Acc: 87.28%\n", "\t Val. Loss: 0.404 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 138 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.419 | Train Acc: 87.28%\n", "\t Val. Loss: 0.403 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 139 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.418 | Train Acc: 87.28%\n", "\t Val. Loss: 0.402 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 140 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.417 | Train Acc: 87.28%\n", "\t Val. Loss: 0.402 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 141 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.416 | Train Acc: 87.50%\n", "\t Val. Loss: 0.401 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 142 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.416 | Train Acc: 87.50%\n", "\t Val. Loss: 0.400 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 143 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.415 | Train Acc: 87.50%\n", "\t Val. Loss: 0.399 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 144 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.414 | Train Acc: 87.50%\n", "\t Val. Loss: 0.398 | Val. 
Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 145 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.413 | Train Acc: 87.50%\n", "\t Val. Loss: 0.397 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 146 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.412 | Train Acc: 87.50%\n", "\t Val. Loss: 0.396 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 147 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.411 | Train Acc: 87.50%\n", "\t Val. Loss: 0.395 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 148 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.411 | Train Acc: 87.50%\n", "\t Val. Loss: 0.394 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 149 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.410 | Train Acc: 87.50%\n", "\t Val. Loss: 0.393 | Val. Acc: 90.62%\n", "---------------------------------------------------------------------------\n", "Epoch: 150 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.409 | Train Acc: 87.50%\n", "\t Val. Loss: 0.393 | Val. 
Acc: 90.62%\n" ] } ], "source": [ "N_EPOCHS = 150\n", "\n", "# track statistics\n", "track_stats = {'epoch': [],\n", "               'train_loss': [],\n", "               'train_acc': [],\n", "               'valid_loss': [],\n", "               'valid_acc': []}\n", "\n", "\n", "best_valid_loss = float('inf')\n", "\n", "for epoch in range(N_EPOCHS):\n", "\n", "    start_time = time.time()\n", "    \n", "    train_loss, train_acc = train(model, train_dl, optimizer, criterion)\n", "    valid_loss, valid_acc = evaluate(model, test_dl, criterion)\n", "    \n", "    end_time = time.time()\n", "    \n", "    # record this epoch's statistics\n", "    track_stats['epoch'].append(epoch + 1)\n", "    track_stats['train_loss'].append(train_loss)\n", "    track_stats['train_acc'].append(train_acc)\n", "    track_stats['valid_loss'].append(valid_loss)\n", "    track_stats['valid_acc'].append(valid_acc)\n", "\n", "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n", "    \n", "    # if this is the lowest validation loss so far, save the model parameters\n", "    if valid_loss < best_valid_loss:\n", "        best_valid_loss = valid_loss\n", "        torch.save(model.state_dict(), 'best_log_regression.pt')\n", "    \n", "    # print out stats\n", "    print('-'*75)\n", "    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n", "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n", "    print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')\n" ] }, { "cell_type": "markdown", "metadata": { "scrolled": true }, "source": [ "# Visualization\n", "\n", "Our model performed well, reaching a top validation accuracy of 90.62% (first reached at epoch 127 and held through epoch 150). Note that the checkpoint saved to `best_log_regression.pt` is selected by the lowest validation loss, not the highest validation accuracy.\n", "\n", "Now, let us graph our results." ] }, { "cell_type": "code", "execution_count": 256, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "
     epoch  train_loss  train_acc  valid_loss  valid_acc
0        1    0.661884   0.626539    0.652851   0.632812
1        2    0.658471   0.626539    0.649458   0.632812
2        3    0.655121   0.626539    0.646100   0.632812
3        4    0.651815   0.626539    0.642780   0.632812
4        5    0.648550   0.626539    0.639496   0.632812
5        6    0.645323   0.626539    0.636249   0.632812
6        7    0.642134   0.628695    0.633039   0.632812
7        8    0.638983   0.628695    0.629865   0.632812
8        9    0.635868   0.628695    0.626728   0.640625
9       10    0.632790   0.628695    0.623626   0.640625
10      11    0.629748   0.628695    0.620559   0.640625
11      12    0.626742   0.626539    0.617527   0.640625
12      13    0.623771   0.626539    0.614529   0.640625
13      14    0.620835   0.624384    0.611565   0.640625
14      15    0.617933   0.624384    0.608634   0.640625
15      16    0.615065   0.624384    0.605737   0.640625
16      17    0.612231   0.628695    0.602873   0.640625
17      18    0.609430   0.635160    0.600040   0.640625
18      19    0.606661   0.639470    0.597240   0.640625
19      20    0.603925   0.639470    0.594471   0.640625
20      21    0.601220   0.645936    0.591733   0.640625
21      22    0.598547   0.648091    0.589025   0.640625
22      23    0.595905   0.650246    0.586348   0.640625
23      24    0.593293   0.650246    0.583701   0.640625
24      25    0.590712   0.652401    0.581083   0.648438
25      26    0.588160   0.654557    0.578494   0.648438
26      27    0.585638   0.658867    0.575934   0.648438
27      28    0.583144   0.663177    0.573402   0.656250
28      29    0.580679   0.667488    0.570897   0.656250
29      30    0.578242   0.669643    0.568421   0.656250
...    ...         ...        ...         ...        ...
120    121    0.435226   0.864224    0.420519   0.898438
121    122    0.434218   0.866379    0.419455   0.898438
122    123    0.433219   0.868534    0.418398   0.898438
123    124    0.432226   0.868534    0.417349   0.898438
124    125    0.431242   0.868534    0.416308   0.898438
125    126    0.430265   0.868534    0.415275   0.898438
126    127    0.429295   0.868534    0.414249   0.906250
127    128    0.428333   0.868534    0.413231   0.906250
128    129    0.427378   0.868534    0.412220   0.906250
129    130    0.426430   0.868534    0.411216   0.906250
130    131    0.425490   0.868534    0.410220   0.906250
131    132    0.424556   0.868534    0.409231   0.906250
132    133    0.423630   0.868534    0.408249   0.906250
133    134    0.422710   0.868534    0.407273   0.906250
134    135    0.421797   0.870690    0.406305   0.906250
135    136    0.420891   0.872845    0.405344   0.906250
136    137    0.419992   0.872845    0.404389   0.906250
137    138    0.419099   0.872845    0.403441   0.906250
138    139    0.418212   0.872845    0.402499   0.906250
139    140    0.417332   0.872845    0.401564   0.906250
140    141    0.416458   0.875000    0.400636   0.906250
141    142    0.415591   0.875000    0.399714   0.906250
142    143    0.414730   0.875000    0.398798   0.906250
143    144    0.413874   0.875000    0.397888   0.906250
144    145    0.413025   0.875000    0.396985   0.906250
145    146    0.412182   0.875000    0.396088   0.906250
146    147    0.411345   0.875000    0.395196   0.906250
147    148    0.410514   0.875000    0.394311   0.906250
148    149    0.409689   0.875000    0.393432   0.906250
149    150    0.408869   0.875000    0.392558   0.906250
\n", "

150 rows × 5 columns

\n", "
" ], "text/plain": [ " epoch train_loss train_acc valid_loss valid_acc\n", "0 1 0.661884 0.626539 0.652851 0.632812\n", "1 2 0.658471 0.626539 0.649458 0.632812\n", "2 3 0.655121 0.626539 0.646100 0.632812\n", "3 4 0.651815 0.626539 0.642780 0.632812\n", "4 5 0.648550 0.626539 0.639496 0.632812\n", ".. ... ... ... ... ...\n", "145 146 0.412182 0.875000 0.396088 0.906250\n", "146 147 0.411345 0.875000 0.395196 0.906250\n", "147 148 0.410514 0.875000 0.394311 0.906250\n", "148 149 0.409689 0.875000 0.393432 0.906250\n", "149 150 0.408869 0.875000 0.392558 0.906250\n", "\n", "[150 rows x 5 columns]" ] }, "execution_count": 256, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# format data \n", "import pandas as pd\n", "\n", "stats = pd.DataFrame(track_stats)\n", "stats" ] }, { "cell_type": "code", "execution_count": 257, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'epoch': 1.0,\n", " 'train_loss': 0.6618835843842605,\n", " 'train_acc': 0.6265394088669951,\n", " 'valid_loss': 0.6528506278991699,\n", " 'valid_acc': 0.6328125},\n", " {'epoch': 2.0,\n", " 'train_loss': 0.6584710745975889,\n", " 'train_acc': 0.6265394088669951,\n", " 'valid_loss': 0.6494578719139099,\n", " 'valid_acc': 0.6328125},\n", " {'epoch': 3.0,\n", " 'train_loss': 0.6551205207561624,\n", " 'train_acc': 0.6265394088669951,\n", " 'valid_loss': 0.6461004018783569,\n", " 'valid_acc': 0.6328125},\n", " {'epoch': 4.0,\n", " 'train_loss': 0.6518148882635708,\n", " 'train_acc': 0.6265394088669951,\n", " 'valid_loss': 0.642779529094696,\n", " 'valid_acc': 0.6328125},\n", " {'epoch': 5.0,\n", " 'train_loss': 0.6485495074041958,\n", " 'train_acc': 0.6265394088669951,\n", " 'valid_loss': 0.6394957304000854,\n", " 'valid_acc': 0.6328125},\n", " {'epoch': 6.0,\n", " 'train_loss': 0.6453227339119747,\n", " 'train_acc': 0.6265394088669951,\n", " 'valid_loss': 0.6362490057945251,\n", " 'valid_acc': 0.6328125},\n", " {'epoch': 7.0,\n", " 'train_loss': 0.6421338443098397,\n", " 
'train_acc': 0.6286945812807881,\n", " 'valid_loss': 0.6330390572547913,\n", " 'valid_acc': 0.6328125},\n", " {'epoch': 8.0,\n", " 'train_loss': 0.6389825755152209,\n", " 'train_acc': 0.6286945812807881,\n", " 'valid_loss': 0.6298654079437256,\n", " 'valid_acc': 0.6328125},\n", " {'epoch': 9.0,\n", " 'train_loss': 0.6358680067391231,\n", " 'train_acc': 0.6286945812807881,\n", " 'valid_loss': 0.6267277598381042,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 10.0,\n", " 'train_loss': 0.6327900722109038,\n", " 'train_acc': 0.6286945812807881,\n", " 'valid_loss': 0.623625636100769,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 11.0,\n", " 'train_loss': 0.629748245765423,\n", " 'train_acc': 0.6286945812807881,\n", " 'valid_loss': 0.6205587387084961,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 12.0,\n", " 'train_loss': 0.6267420670081829,\n", " 'train_acc': 0.6265394088669951,\n", " 'valid_loss': 0.6175265908241272,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 13.0,\n", " 'train_loss': 0.6237710097740436,\n", " 'train_acc': 0.6265394088669951,\n", " 'valid_loss': 0.6145287752151489,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 14.0,\n", " 'train_loss': 0.6208349425217201,\n", " 'train_acc': 0.624384236453202,\n", " 'valid_loss': 0.6115648746490479,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 15.0,\n", " 'train_loss': 0.6179331417741447,\n", " 'train_acc': 0.624384236453202,\n", " 'valid_loss': 0.6086344718933105,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 16.0,\n", " 'train_loss': 0.6150652129074623,\n", " 'train_acc': 0.624384236453202,\n", " 'valid_loss': 0.6057370901107788,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 17.0,\n", " 'train_loss': 0.6122307612978178,\n", " 'train_acc': 0.6286945812807881,\n", " 'valid_loss': 0.6028725504875183,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 18.0,\n", " 'train_loss': 0.609429721174569,\n", " 'train_acc': 0.6351600985221675,\n", " 'valid_loss': 0.6000401973724365,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 
19.0,\n", " 'train_loss': 0.6066611059780779,\n", " 'train_acc': 0.6394704433497537,\n", " 'valid_loss': 0.5972397327423096,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 20.0,\n", " 'train_loss': 0.603924718396417,\n", " 'train_acc': 0.6394704433497537,\n", " 'valid_loss': 0.5944706201553345,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 21.0,\n", " 'train_loss': 0.6012202295763739,\n", " 'train_acc': 0.645935960591133,\n", " 'valid_loss': 0.5917326211929321,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 22.0,\n", " 'train_loss': 0.598547179123451,\n", " 'train_acc': 0.6480911330049262,\n", " 'valid_loss': 0.5890252590179443,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 23.0,\n", " 'train_loss': 0.5959050408725081,\n", " 'train_acc': 0.6502463054187192,\n", " 'valid_loss': 0.5863481163978577,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 24.0,\n", " 'train_loss': 0.5932934201996902,\n", " 'train_acc': 0.6502463054187192,\n", " 'valid_loss': 0.5837007761001587,\n", " 'valid_acc': 0.640625},\n", " {'epoch': 25.0,\n", " 'train_loss': 0.5907119882517847,\n", " 'train_acc': 0.6524014778325123,\n", " 'valid_loss': 0.581082820892334,\n", " 'valid_acc': 0.6484375},\n", " {'epoch': 26.0,\n", " 'train_loss': 0.5881601530930092,\n", " 'train_acc': 0.6545566502463054,\n", " 'valid_loss': 0.5784939527511597,\n", " 'valid_acc': 0.6484375},\n", " {'epoch': 27.0,\n", " 'train_loss': 0.5856377831820784,\n", " 'train_acc': 0.6588669950738917,\n", " 'valid_loss': 0.5759336352348328,\n", " 'valid_acc': 0.6484375},\n", " {'epoch': 28.0,\n", " 'train_loss': 0.5831442865832098,\n", " 'train_acc': 0.6631773399014779,\n", " 'valid_loss': 0.5734015703201294,\n", " 'valid_acc': 0.65625},\n", " {'epoch': 29.0,\n", " 'train_loss': 0.580679202901906,\n", " 'train_acc': 0.6674876847290641,\n", " 'valid_loss': 0.5708974003791809,\n", " 'valid_acc': 0.65625},\n", " {'epoch': 30.0,\n", " 'train_loss': 0.5782424005968817,\n", " 'train_acc': 0.6696428571428572,\n", " 'valid_loss': 
0.5684206485748291,\n", " 'valid_acc': 0.65625},\n", " {'epoch': 31.0,\n", " 'train_loss': 0.5758331561910695,\n", " 'train_acc': 0.6696428571428572,\n", " 'valid_loss': 0.5659710168838501,\n", " 'valid_acc': 0.6640625},\n", " {'epoch': 32.0,\n", " 'train_loss': 0.5734513381431843,\n", " 'train_acc': 0.6739532019704434,\n", " 'valid_loss': 0.5635480284690857,\n", " 'valid_acc': 0.6875},\n", " {'epoch': 33.0,\n", " 'train_loss': 0.5710964202880859,\n", " 'train_acc': 0.6761083743842364,\n", " 'valid_loss': 0.561151385307312,\n", " 'valid_acc': 0.6875},\n", " {'epoch': 34.0,\n", " 'train_loss': 0.5687679422312769,\n", " 'train_acc': 0.6804187192118227,\n", " 'valid_loss': 0.5587806701660156,\n", " 'valid_acc': 0.6953125},\n", " {'epoch': 35.0,\n", " 'train_loss': 0.5664656408901872,\n", " 'train_acc': 0.6825738916256158,\n", " 'valid_loss': 0.5564355850219727,\n", " 'valid_acc': 0.6953125},\n", " {'epoch': 36.0,\n", " 'train_loss': 0.5641893847235318,\n", " 'train_acc': 0.6825738916256158,\n", " 'valid_loss': 0.5541156530380249,\n", " 'valid_acc': 0.6953125},\n", " {'epoch': 37.0,\n", " 'train_loss': 0.5619383844836005,\n", " 'train_acc': 0.686884236453202,\n", " 'valid_loss': 0.5518207550048828,\n", " 'valid_acc': 0.6953125},\n", " {'epoch': 38.0,\n", " 'train_loss': 0.5597124428584658,\n", " 'train_acc': 0.686884236453202,\n", " 'valid_loss': 0.5495502352714539,\n", " 'valid_acc': 0.6953125},\n", " {'epoch': 39.0,\n", " 'train_loss': 0.5575110994536301,\n", " 'train_acc': 0.6890394088669951,\n", " 'valid_loss': 0.5473039746284485,\n", " 'valid_acc': 0.6953125},\n", " {'epoch': 40.0,\n", " 'train_loss': 0.5553342227278084,\n", " 'train_acc': 0.6976600985221675,\n", " 'valid_loss': 0.5450814962387085,\n", " 'valid_acc': 0.6953125},\n", " {'epoch': 41.0,\n", " 'train_loss': 0.5531812865158607,\n", " 'train_acc': 0.708435960591133,\n", " 'valid_loss': 0.5428824424743652,\n", " 'valid_acc': 0.6953125},\n", " {'epoch': 42.0,\n", " 'train_loss': 0.551051929079253,\n", " 
'train_acc': 0.7155172413793104,\n", " 'valid_loss': 0.5407066345214844,\n", " 'valid_acc': 0.6953125},\n", " {'epoch': 43.0,\n", " 'train_loss': 0.5489459531060581,\n", " 'train_acc': 0.7176724137931034,\n", " 'valid_loss': 0.5385535955429077,\n", " 'valid_acc': 0.703125},\n", " {'epoch': 44.0,\n", " 'train_loss': 0.5468627995458143,\n", " 'train_acc': 0.7176724137931034,\n", " 'valid_loss': 0.5364230871200562,\n", " 'valid_acc': 0.703125},\n", " {'epoch': 45.0,\n", " 'train_loss': 0.5448023697425579,\n", " 'train_acc': 0.7198275862068966,\n", " 'valid_loss': 0.5343146920204163,\n", " 'valid_acc': 0.703125},\n", " {'epoch': 46.0,\n", " 'train_loss': 0.5427641704164702,\n", " 'train_acc': 0.7219827586206896,\n", " 'valid_loss': 0.5322281718254089,\n", " 'valid_acc': 0.703125},\n", " {'epoch': 47.0,\n", " 'train_loss': 0.5407478398290174,\n", " 'train_acc': 0.7306034482758621,\n", " 'valid_loss': 0.5301631689071655,\n", " 'valid_acc': 0.7109375},\n", " {'epoch': 48.0,\n", " 'train_loss': 0.538753345094878,\n", " 'train_acc': 0.7306034482758621,\n", " 'valid_loss': 0.5281195044517517,\n", " 'valid_acc': 0.7109375},\n", " {'epoch': 49.0,\n", " 'train_loss': 0.5367800942782698,\n", " 'train_acc': 0.7306034482758621,\n", " 'valid_loss': 0.5260966420173645,\n", " 'valid_acc': 0.7109375},\n", " {'epoch': 50.0,\n", " 'train_loss': 0.5348278900672649,\n", " 'train_acc': 0.7349137931034483,\n", " 'valid_loss': 0.5240945219993591,\n", " 'valid_acc': 0.7109375},\n", " {'epoch': 51.0,\n", " 'train_loss': 0.5328963049526872,\n", " 'train_acc': 0.7370689655172413,\n", " 'valid_loss': 0.5221127271652222,\n", " 'valid_acc': 0.71875},\n", " {'epoch': 52.0,\n", " 'train_loss': 0.5309851745079304,\n", " 'train_acc': 0.7392241379310345,\n", " 'valid_loss': 0.520150899887085,\n", " 'valid_acc': 0.71875},\n", " {'epoch': 53.0,\n", " 'train_loss': 0.5290941369944605,\n", " 'train_acc': 0.7413793103448276,\n", " 'valid_loss': 0.5182088613510132,\n", " 'valid_acc': 0.71875},\n", " {'epoch': 
54.0,\n", " 'train_loss': 0.5272229951003502,\n", " 'train_acc': 0.7413793103448276,\n", " 'valid_loss': 0.516286313533783,\n", " 'valid_acc': 0.7265625},\n", " {'epoch': 55.0,\n", " 'train_loss': 0.5253712555457806,\n", " 'train_acc': 0.7456896551724138,\n", " 'valid_loss': 0.5143830180168152,\n", " 'valid_acc': 0.734375},\n", " {'epoch': 56.0,\n", " 'train_loss': 0.5235388525601091,\n", " 'train_acc': 0.7456896551724138,\n", " 'valid_loss': 0.5124986171722412,\n", " 'valid_acc': 0.7421875},\n", " {'epoch': 57.0,\n", " 'train_loss': 0.5217253586341595,\n", " 'train_acc': 0.7478448275862069,\n", " 'valid_loss': 0.5106328129768372,\n", " 'valid_acc': 0.7421875},\n", " {'epoch': 58.0,\n", " 'train_loss': 0.5199305764560042,\n", " 'train_acc': 0.75,\n", " 'valid_loss': 0.5087854862213135,\n", " 'valid_acc': 0.7421875},\n", " {'epoch': 59.0,\n", " 'train_loss': 0.5181542429430731,\n", " 'train_acc': 0.7543103448275862,\n", " 'valid_loss': 0.5069563388824463,\n", " 'valid_acc': 0.7421875},\n", " {'epoch': 60.0,\n", " 'train_loss': 0.5163960950127964,\n", " 'train_acc': 0.7607758620689655,\n", " 'valid_loss': 0.5051449537277222,\n", " 'valid_acc': 0.7421875},\n", " {'epoch': 61.0,\n", " 'train_loss': 0.5146557380413187,\n", " 'train_acc': 0.7650862068965517,\n", " 'valid_loss': 0.5033513307571411,\n", " 'valid_acc': 0.7421875},\n", " {'epoch': 62.0,\n", " 'train_loss': 0.512933139143319,\n", " 'train_acc': 0.7672413793103449,\n", " 'valid_loss': 0.5015749931335449,\n", " 'valid_acc': 0.7578125},\n", " {'epoch': 63.0,\n", " 'train_loss': 0.5112278379242996,\n", " 'train_acc': 0.771551724137931,\n", " 'valid_loss': 0.49981582164764404,\n", " 'valid_acc': 0.7578125},\n", " {'epoch': 64.0,\n", " 'train_loss': 0.509539768613618,\n", " 'train_acc': 0.7737068965517241,\n", " 'valid_loss': 0.4980735778808594,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 65.0,\n", " 'train_loss': 0.5078684708167767,\n", " 'train_acc': 0.7737068965517241,\n", " 'valid_loss': 
0.4963480234146118,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 66.0,\n", " 'train_loss': 0.5062139116484543,\n", " 'train_acc': 0.7737068965517241,\n", " 'valid_loss': 0.4946388006210327,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 67.0,\n", " 'train_loss': 0.5045756964847959,\n", " 'train_acc': 0.7758620689655172,\n", " 'valid_loss': 0.4929458796977997,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 68.0,\n", " 'train_loss': 0.5029537924404802,\n", " 'train_acc': 0.7758620689655172,\n", " 'valid_loss': 0.49126893281936646,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 69.0,\n", " 'train_loss': 0.501347673350367,\n", " 'train_acc': 0.7801724137931034,\n", " 'valid_loss': 0.48960769176483154,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 70.0,\n", " 'train_loss': 0.4997573720997778,\n", " 'train_acc': 0.7801724137931034,\n", " 'valid_loss': 0.487962007522583,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 71.0,\n", " 'train_loss': 0.49818255983549975,\n", " 'train_acc': 0.7844827586206896,\n", " 'valid_loss': 0.4863317012786865,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 72.0,\n", " 'train_loss': 0.4966230063602842,\n", " 'train_acc': 0.7844827586206896,\n", " 'valid_loss': 0.4847164750099182,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 73.0,\n", " 'train_loss': 0.4950785472475249,\n", " 'train_acc': 0.7866379310344828,\n", " 'valid_loss': 0.48311617970466614,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 74.0,\n", " 'train_loss': 0.4935490509559368,\n", " 'train_acc': 0.790948275862069,\n", " 'valid_loss': 0.4815305769443512,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 75.0,\n", " 'train_loss': 0.49203392554973735,\n", " 'train_acc': 0.7931034482758621,\n", " 'valid_loss': 0.47995948791503906,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 76.0,\n", " 'train_loss': 0.49053353276746026,\n", " 'train_acc': 0.7931034482758621,\n", " 'valid_loss': 0.4784027636051178,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 77.0,\n", " 'train_loss': 0.489047280673323,\n", 
" 'train_acc': 0.7952586206896551,\n", " 'valid_loss': 0.47686007618904114,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 78.0,\n", " 'train_loss': 0.4875750706113618,\n", " 'train_acc': 0.7974137931034483,\n", " 'valid_loss': 0.4753313660621643,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 79.0,\n", " 'train_loss': 0.4861166723843279,\n", " 'train_acc': 0.8017241379310345,\n", " 'valid_loss': 0.4738163650035858,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 80.0,\n", " 'train_loss': 0.4846720531068999,\n", " 'train_acc': 0.8017241379310345,\n", " 'valid_loss': 0.4723149538040161,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 81.0,\n", " 'train_loss': 0.4832408181552229,\n", " 'train_acc': 0.8060344827586207,\n", " 'valid_loss': 0.4708269238471985,\n", " 'valid_acc': 0.765625},\n", " {'epoch': 82.0,\n", " 'train_loss': 0.48182293464397563,\n", " 'train_acc': 0.8103448275862069,\n", " 'valid_loss': 0.4693520665168762,\n", " 'valid_acc': 0.7734375},\n", " {'epoch': 83.0,\n", " 'train_loss': 0.48041820526123047,\n", " 'train_acc': 0.8146551724137931,\n", " 'valid_loss': 0.4678902328014374,\n", " 'valid_acc': 0.7734375},\n", " {'epoch': 84.0,\n", " 'train_loss': 0.47902633403909617,\n", " 'train_acc': 0.8168103448275862,\n", " 'valid_loss': 0.46644127368927,\n", " 'valid_acc': 0.78125},\n", " {'epoch': 85.0,\n", " 'train_loss': 0.4776471894362877,\n", " 'train_acc': 0.8168103448275862,\n", " 'valid_loss': 0.4650050103664398,\n", " 'valid_acc': 0.78125},\n", " {'epoch': 86.0,\n", " 'train_loss': 0.4762807056821626,\n", " 'train_acc': 0.8168103448275862,\n", " 'valid_loss': 0.46358129382133484,\n", " 'valid_acc': 0.78125},\n", " {'epoch': 87.0,\n", " 'train_loss': 0.47492652103818694,\n", " 'train_acc': 0.8168103448275862,\n", " 'valid_loss': 0.4621698558330536,\n", " 'valid_acc': 0.78125},\n", " {'epoch': 88.0,\n", " 'train_loss': 0.47358470127500335,\n", " 'train_acc': 0.8211206896551724,\n", " 'valid_loss': 0.4607706367969513,\n", " 'valid_acc': 0.7890625},\n", " 
{'epoch': 89.0,\n", " 'train_loss': 0.4722549504247205,\n", " 'train_acc': 0.8211206896551724,\n", " 'valid_loss': 0.45938345789909363,\n", " 'valid_acc': 0.7890625},\n", " {'epoch': 90.0,\n", " 'train_loss': 0.4709371040607321,\n", " 'train_acc': 0.8211206896551724,\n", " 'valid_loss': 0.45800817012786865,\n", " 'valid_acc': 0.7890625},\n", " {'epoch': 91.0,\n", " 'train_loss': 0.46963099775643186,\n", " 'train_acc': 0.8232758620689655,\n", " 'valid_loss': 0.45664459466934204,\n", " 'valid_acc': 0.7890625},\n", " {'epoch': 92.0,\n", " 'train_loss': 0.468336532855856,\n", " 'train_acc': 0.8275862068965517,\n", " 'valid_loss': 0.45529258251190186,\n", " 'valid_acc': 0.7890625},\n", " {'epoch': 93.0,\n", " 'train_loss': 0.46705347916175577,\n", " 'train_acc': 0.834051724137931,\n", " 'valid_loss': 0.45395195484161377,\n", " 'valid_acc': 0.796875},\n", " {'epoch': 94.0,\n", " 'train_loss': 0.4657817709034887,\n", " 'train_acc': 0.8362068965517241,\n", " 'valid_loss': 0.452622652053833,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 95.0,\n", " 'train_loss': 0.4645211449984846,\n", " 'train_acc': 0.8383620689655172,\n", " 'valid_loss': 0.45130446553230286,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 96.0,\n", " 'train_loss': 0.4632716014467437,\n", " 'train_acc': 0.8426724137931034,\n", " 'valid_loss': 0.44999733567237854,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 97.0,\n", " 'train_loss': 0.46203294293633823,\n", " 'train_acc': 0.8448275862068966,\n", " 'valid_loss': 0.4487009644508362,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 98.0,\n", " 'train_loss': 0.4608049721553408,\n", " 'train_acc': 0.8448275862068966,\n", " 'valid_loss': 0.447415292263031,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 99.0,\n", " 'train_loss': 0.45958755756246633,\n", " 'train_acc': 0.8448275862068966,\n", " 'valid_loss': 0.4461402893066406,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 100.0,\n", " 'train_loss': 0.45838060050175106,\n", " 'train_acc': 0.8448275862068966,\n", 
" 'valid_loss': 0.44487571716308594,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 101.0,\n", " 'train_loss': 0.4571839036612675,\n", " 'train_acc': 0.8448275862068966,\n", " 'valid_loss': 0.443621426820755,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 102.0,\n", " 'train_loss': 0.4559974341556944,\n", " 'train_acc': 0.8448275862068966,\n", " 'valid_loss': 0.44237738847732544,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 103.0,\n", " 'train_loss': 0.45482102755842535,\n", " 'train_acc': 0.8491379310344828,\n", " 'valid_loss': 0.44114333391189575,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 104.0,\n", " 'train_loss': 0.45365445367221174,\n", " 'train_acc': 0.8491379310344828,\n", " 'valid_loss': 0.43991929292678833,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 105.0,\n", " 'train_loss': 0.45249781115301724,\n", " 'train_acc': 0.8491379310344828,\n", " 'valid_loss': 0.4387049674987793,\n", " 'valid_acc': 0.8046875},\n", " {'epoch': 106.0,\n", " 'train_loss': 0.45135073826230804,\n", " 'train_acc': 0.8491379310344828,\n", " 'valid_loss': 0.4375004470348358,\n", " 'valid_acc': 0.8125},\n", " {'epoch': 107.0,\n", " 'train_loss': 0.45021320211476296,\n", " 'train_acc': 0.8491379310344828,\n", " 'valid_loss': 0.4363054931163788,\n", " 'valid_acc': 0.8125},\n", " {'epoch': 108.0,\n", " 'train_loss': 0.4490851040544181,\n", " 'train_acc': 0.8491379310344828,\n", " 'valid_loss': 0.4351199269294739,\n", " 'valid_acc': 0.8203125},\n", " {'epoch': 109.0,\n", " 'train_loss': 0.4479663783106311,\n", " 'train_acc': 0.8512931034482759,\n", " 'valid_loss': 0.4339437484741211,\n", " 'valid_acc': 0.8203125},\n", " {'epoch': 110.0,\n", " 'train_loss': 0.4468568604567955,\n", " 'train_acc': 0.8512931034482759,\n", " 'valid_loss': 0.4327767789363861,\n", " 'valid_acc': 0.8203125},\n", " {'epoch': 111.0,\n", " 'train_loss': 0.4457563202956627,\n", " 'train_acc': 0.8512931034482759,\n", " 'valid_loss': 0.43161892890930176,\n", " 'valid_acc': 0.8203125},\n", " {'epoch': 112.0,\n", 
" 'train_loss': 0.4446647249419114,\n", " 'train_acc': 0.853448275862069,\n", " 'valid_loss': 0.4304701089859009,\n", " 'valid_acc': 0.8203125},\n", " {'epoch': 113.0,\n", " 'train_loss': 0.4435821072808627,\n", " 'train_acc': 0.853448275862069,\n", " 'valid_loss': 0.4293302297592163,\n", " 'valid_acc': 0.8203125},\n", " {'epoch': 114.0,\n", " 'train_loss': 0.442508105573983,\n", " 'train_acc': 0.8599137931034483,\n", " 'valid_loss': 0.42819908261299133,\n", " 'valid_acc': 0.8203125},\n", " {'epoch': 115.0,\n", " 'train_loss': 0.44144278559191474,\n", " 'train_acc': 0.8599137931034483,\n", " 'valid_loss': 0.42707669734954834,\n", " 'valid_acc': 0.8203125},\n", " {'epoch': 116.0,\n", " 'train_loss': 0.44038595002273034,\n", " 'train_acc': 0.8599137931034483,\n", " 'valid_loss': 0.4259628653526306,\n", " 'valid_acc': 0.8203125},\n", " {'epoch': 117.0,\n", " 'train_loss': 0.43933756598110857,\n", " 'train_acc': 0.8599137931034483,\n", " 'valid_loss': 0.424857497215271,\n", " 'valid_acc': 0.8203125},\n", " {'epoch': 118.0,\n", " 'train_loss': 0.43829746904044314,\n", " 'train_acc': 0.8620689655172413,\n", " 'valid_loss': 0.4237605333328247,\n", " 'valid_acc': 0.8203125},\n", " {'epoch': 119.0,\n", " 'train_loss': 0.4372657249713766,\n", " 'train_acc': 0.8642241379310345,\n", " 'valid_loss': 0.4226718544960022,\n", " 'valid_acc': 0.8984375},\n", " {'epoch': 120.0,\n", " 'train_loss': 0.43624200492069637,\n", " 'train_acc': 0.8642241379310345,\n", " 'valid_loss': 0.4215913414955139,\n", " 'valid_acc': 0.8984375},\n", " {'epoch': 121.0,\n", " 'train_loss': 0.43522624311776,\n", " 'train_acc': 0.8642241379310345,\n", " 'valid_loss': 0.4205189347267151,\n", " 'valid_acc': 0.8984375},\n", " {'epoch': 122.0,\n", " 'train_loss': 0.4342184724478886,\n", " 'train_acc': 0.8663793103448276,\n", " 'valid_loss': 0.41945451498031616,\n", " 'valid_acc': 0.8984375},\n", " {'epoch': 123.0,\n", " 'train_loss': 0.43321859425511855,\n", " 'train_acc': 0.8685344827586207,\n", " 
'valid_loss': 0.41839802265167236,\n", " 'valid_acc': 0.8984375},\n", " {'epoch': 124.0,\n", " 'train_loss': 0.4322263454568797,\n", " 'train_acc': 0.8685344827586207,\n", " 'valid_loss': 0.41734933853149414,\n", " 'valid_acc': 0.8984375},\n", " {'epoch': 125.0,\n", " 'train_loss': 0.43124179182381467,\n", " 'train_acc': 0.8685344827586207,\n", " 'valid_loss': 0.4163084030151367,\n", " 'valid_acc': 0.8984375},\n", " {'epoch': 126.0,\n", " 'train_loss': 0.43026470315867454,\n", " 'train_acc': 0.8685344827586207,\n", " 'valid_loss': 0.41527506709098816,\n", " 'valid_acc': 0.8984375},\n", " {'epoch': 127.0,\n", " 'train_loss': 0.42929527677338697,\n", " 'train_acc': 0.8685344827586207,\n", " 'valid_loss': 0.41424936056137085,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 128.0,\n", " 'train_loss': 0.42833305227345436,\n", " 'train_acc': 0.8685344827586207,\n", " 'valid_loss': 0.4132310152053833,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 129.0,\n", " 'train_loss': 0.4273781940854829,\n", " 'train_acc': 0.8685344827586207,\n", " 'valid_loss': 0.4122201204299927,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 130.0,\n", " 'train_loss': 0.4264304391269026,\n", " 'train_acc': 0.8685344827586207,\n", " 'valid_loss': 0.41121649742126465,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 131.0,\n", " 'train_loss': 0.4254899189389985,\n", " 'train_acc': 0.8685344827586207,\n", " 'valid_loss': 0.4102201461791992,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 132.0,\n", " 'train_loss': 0.42455633755387934,\n", " 'train_acc': 0.8685344827586207,\n", " 'valid_loss': 0.4092308580875397,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 133.0,\n", " 'train_loss': 0.42362982651283,\n", " 'train_acc': 0.8685344827586207,\n", " 'valid_loss': 0.4082486629486084,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 134.0,\n", " 'train_loss': 0.4227101556186018,\n", " 'train_acc': 0.8685344827586207,\n", " 'valid_loss': 0.40727338194847107,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 135.0,\n", " 
'train_loss': 0.4217972919858735,\n", " 'train_acc': 0.8706896551724138,\n", " 'valid_loss': 0.4063050448894501,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 136.0,\n", " 'train_loss': 0.4208910711880388,\n", " 'train_acc': 0.8728448275862069,\n", " 'valid_loss': 0.40534353256225586,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 137.0,\n", " 'train_loss': 0.41999155899574014,\n", " 'train_acc': 0.8728448275862069,\n", " 'valid_loss': 0.40438878536224365,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 138.0,\n", " 'train_loss': 0.41909852521172886,\n", " 'train_acc': 0.8728448275862069,\n", " 'valid_loss': 0.4034406840801239,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 139.0,\n", " 'train_loss': 0.41821210137728987,\n", " 'train_acc': 0.8728448275862069,\n", " 'valid_loss': 0.4024991989135742,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 140.0,\n", " 'train_loss': 0.4173319586392107,\n", " 'train_acc': 0.8728448275862069,\n", " 'valid_loss': 0.40156421065330505,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 141.0,\n", " 'train_loss': 0.41645826142409753,\n", " 'train_acc': 0.875,\n", " 'valid_loss': 0.4006357192993164,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 142.0,\n", " 'train_loss': 0.4155908781906654,\n", " 'train_acc': 0.875,\n", " 'valid_loss': 0.39971357583999634,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 143.0,\n", " 'train_loss': 0.4147295458563443,\n", " 'train_acc': 0.875,\n", " 'valid_loss': 0.39879775047302246,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 144.0,\n", " 'train_loss': 0.4138744617330617,\n", " 'train_acc': 0.875,\n", " 'valid_loss': 0.39788818359375,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 145.0,\n", " 'train_loss': 0.41302542850889007,\n", " 'train_acc': 0.875,\n", " 'valid_loss': 0.39698484539985657,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 146.0,\n", " 'train_loss': 0.4121824461838295,\n", " 'train_acc': 0.875,\n", " 'valid_loss': 0.3960876166820526,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 147.0,\n", " 'train_loss': 
0.4113452845606311,\n", " 'train_acc': 0.875,\n", " 'valid_loss': 0.3951963782310486,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 148.0,\n", " 'train_loss': 0.4105140422952586,\n", " 'train_acc': 0.875,\n", " 'valid_loss': 0.3943111300468445,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 149.0,\n", " 'train_loss': 0.40968858784642714,\n", " 'train_acc': 0.875,\n", " 'valid_loss': 0.3934318423271179,\n", " 'valid_acc': 0.90625},\n", " {'epoch': 150.0,\n", " 'train_loss': 0.40886885544349405,\n", " 'train_acc': 0.875,\n", " 'valid_loss': 0.39255842566490173,\n", " 'valid_acc': 0.90625}]" ] }, "execution_count": 257, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = []\n", "for row in stats.iterrows():\n", " data.append(row[1].to_dict())\n", "data" ] }, { "cell_type": "code", "execution_count": 258, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "" ] }, "execution_count": 258, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import hiplot as hip\n", "hip.Experiment.from_iterable(data).display(force_full_width = True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above graph gives us a convenient way to confirm the expected training pattern: \n", "as the number of epochs increases, training and validation loss decrease while training and validation accuracy increase." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Further, we can inspect the magnitude of each weight parameter to shed light on the variables that had a greater influence on our prediction (assuming that larger magnitudes correspond to greater importance).\n", "\n", "**NOTE**: this is only possible because our model is a single linear layer, and comparing magnitudes is only fair when the input features share a comparable scale (e.g. after standardization). If this were a \"deep\" model, weight magnitudes could not be interpreted this way." ] }, { "cell_type": "code", "execution_count": 259, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
   radius_mean  texture_mean  perimeter_mean  area_mean  smoothness_mean  compactness_mean  concavity_mean  points_mean  symmetry_mean  dimension_mean  ...  radius_worst  texture_worst  perimeter_worst  area_worst  smoothness_worst  compactness_worst  concavity_worst  points_worst  symmetry_worst  dimension_worst
0     1.161765      0.677237        2.255095   3.453925        -0.381582          2.046112        5.107539     3.268645       0.249327         0.61387  ...      2.525464       0.678428        -0.570225    5.125237          0.906672           3.380883         1.896072      4.835514        2.884102         2.432787
\n", "

1 rows × 30 columns

\n", "
" ], "text/plain": [ " radius_mean texture_mean perimeter_mean area_mean smoothness_mean \\\n", "0 1.161765 0.677237 2.255095 3.453925 -0.381582 \n", "\n", " compactness_mean concavity_mean points_mean symmetry_mean \\\n", "0 2.046112 5.107539 3.268645 0.249327 \n", "\n", " dimension_mean ... radius_worst texture_worst perimeter_worst \\\n", "0 0.61387 ... 2.525464 0.678428 -0.570225 \n", "\n", " area_worst smoothness_worst compactness_worst concavity_worst \\\n", "0 5.125237 0.906672 3.380883 1.896072 \n", "\n", " points_worst symmetry_worst dimension_worst \n", "0 4.835514 2.884102 2.432787 \n", "\n", "[1 rows x 30 columns]" ] }, "execution_count": 259, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params = list(model.parameters())[0].detach().cpu().view(-1).numpy()\n", "param_df = pd.DataFrame(params).T\n", "param_df.columns = df.columns\n", "param_df" ] }, { "cell_type": "code", "execution_count": 260, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 260, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize = (4,8))\n", "sns.heatmap(param_df.T.sort_values(0,ascending = False))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Assuming that weights with larger magnitudes indicate greater importance, we can see that \n", "\n", "1. area_worst\n", "2. area_se and\n", "3. concavity_mean\n", "\n", "were the top 3 variables that helped us differentiate between malignant and benign cells, with points_worst (4.84) close behind concavity_mean (5.11)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Conclusion\n", "\n", "Logistic Regression is a powerful method for analyzing the relationship between quantitative variables and a binary qualitative variable, and it can be trained with the same techniques used in DL to learn a meaningful relationship.\n", "\n", "Given its \"shallow\" architecture compared with deeper DL architectures, it has far fewer parameters and typically needs much less data to learn a relationship.\n", "\n", "As such, Logistic Regression is an important tool to have in anyone's Data Science arsenal." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }