{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Spiral classification\n",
    "\n",
    "Generate a synthetic 2-D dataset with `C` spiral-shaped classes, then fit and plot a linear classifier and a two-layer ReLU network on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import random\n",
    "\n",
    "import torch\n",
    "from IPython import display\n",
    "from torch import nn, optim\n",
    "\n",
    "from res.plot_lib import plot_data, plot_model, set_default"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_default()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 12345\n",
    "random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "N = 1000  # num_samples_per_class\n",
    "D = 2  # dimensions\n",
    "C = 3  # num_classes\n",
    "H = 100  # num_hidden_units"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each class c is one spiral arm: sample i sits at radius radius[i] and\n",
    "# angle angle[i], i.e. at (radius * sin(angle), radius * cos(angle)).\n",
    "X = torch.zeros(N * C, D, device=device)\n",
    "y = torch.zeros(N * C, dtype=torch.long, device=device)\n",
    "\n",
    "# The radius ramp is identical for every class, so build it once.\n",
    "radius = torch.linspace(0, 1, N)\n",
    "\n",
    "for c in range(C):\n",
    "    # The angle grows linearly with the radius and is perturbed by Gaussian noise.\n",
    "    # The noise is drawn on the CPU, once per class, so the seeded data\n",
    "    # does not depend on which device was selected.\n",
    "    angle = torch.linspace(\n",
    "        (2 * math.pi / C) * c,        # angle at radius 0\n",
    "        (2 * math.pi / C) * (2 + c),  # angle at radius 1\n",
    "        N,\n",
    "    ) + torch.randn(N) * 0.2\n",
    "\n",
    "    # Polar to Cartesian for the whole class at once; result has shape (N, 2).\n",
    "    points = radius.unsqueeze(1) * torch.stack((torch.sin(angle), torch.cos(angle)), dim=1)\n",
    "\n",
    "    rows = slice(N * c, N * (c + 1))\n",
    "    X[rows] = points.to(device)\n",
    "    y[rows] = c\n",
    "\n",
    "assert X.shape == (N * C, D) and y.shape == (N * C,)\n",
    "\n",
    "print(\"Shapes:\")\n",
    "print(\"X:\", tuple(X.size()))\n",
    "print(\"y:\", tuple(y.size()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualise the data\n",
    "plot_data(X, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training helper\n",
    "\n",
    "Both models below are trained with the same full-batch loop, so it is defined once here and called with each model and optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, criterion, optimizer, inputs, targets, num_epochs=1000):\n",
    "    \"\"\"Train `model` in place on the full batch and report progress every epoch.\n",
    "\n",
    "    Args:\n",
    "        model: module mapping `inputs` to per-class logits.\n",
    "        criterion: callable taking (logits, targets) and returning a scalar loss tensor.\n",
    "        optimizer: optimizer already constructed over `model.parameters()`.\n",
    "        inputs: feature tensor fed to the model as a single batch.\n",
    "        targets: integer class labels, one per row of `inputs`.\n",
    "        num_epochs: number of gradient steps to take.\n",
    "    \"\"\"\n",
    "    for epoch in range(num_epochs):\n",
    "        # Feed forward to get the logits\n",
    "        y_pred = model(inputs)\n",
    "\n",
    "        # Compute the loss and accuracy (measured before this epoch's update)\n",
    "        loss = criterion(y_pred, targets)\n",
    "        predicted = y_pred.argmax(dim=1)\n",
    "        acc = (predicted == targets).float().mean()\n",
    "        print(\"[EPOCH]: %i, [LOSS]: %.6f, [ACCURACY]: %.3f\" % (epoch, loss.item(), acc.item()))\n",
    "        # wait=True defers the clear until the next output arrives,\n",
    "        # so the status line is replaced instead of flickering.\n",
    "        display.clear_output(wait=True)\n",
    "\n",
    "        # Zero the gradients before running the backward pass;\n",
    "        # otherwise they would accumulate across epochs.\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Backward pass to compute the gradient\n",
    "        # of loss w.r.t our learnable params.\n",
    "        loss.backward()\n",
    "\n",
    "        # Update params\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Linear model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 1e-3\n",
    "lambda_l2 = 1e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nn package to create our linear model\n",
    "# each Linear module has a weight and bias;\n",
    "# with no non-linearity in between, the stack is still an affine map of the input\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(D, H),\n",
    "    nn.Linear(H, C)\n",
    ")\n",
    "model.to(device)  # move the parameters to the selected device (GPU if available)\n",
    "\n",
    "# nn package also has different loss functions.\n",
    "# we use cross entropy loss for our classification task\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# we use the optim package to apply\n",
    "# stochastic gradient descent for our parameter updates\n",
    "optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=lambda_l2)  # built-in L2\n",
    "\n",
    "# Training lives in the same cell as the model construction so that\n",
    "# re-running the cell always trains a freshly initialised model.\n",
    "train(model, criterion, optimizer, X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot trained model\n",
    "print(model)\n",
    "plot_model(X, y, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Two-layered network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 1e-3\n",
    "lambda_l2 = 1e-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# nn package to create our two-layer network:\n",
    "# the same two Linear modules as before, now with a ReLU between them,\n",
    "# which makes the model non-linear in its input\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(D, H),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(H, C)\n",
    ")\n",
    "model.to(device)  # move the parameters to the selected device (GPU if available)\n",
    "\n",
    "# nn package also has different loss functions.\n",
    "# we use cross entropy loss for our classification task\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# we use the optim package to apply\n",
    "# ADAM for our parameter updates\n",
    "optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=lambda_l2)  # built-in L2\n",
    "\n",
    "# Training lives in the same cell as the model construction so that\n",
    "# re-running the cell always trains a freshly initialised model.\n",
    "train(model, criterion, optimizer, X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot trained model\n",
    "print(model)\n",
    "plot_model(X, y, model)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dl-minicourse] *",
   "language": "python",
   "name": "conda-env-dl-minicourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}