{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "01-mnist-hello-world.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyOtAKVa5POQ6Xg3UcTQqXDJ",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/01-mnist-hello-world.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i7XbLCXGkll9",
        "colab_type": "text"
      },
      "source": [
        "# Introduction to Pytorch Lightning ⚡\n",
        "\n",
        "In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n",
        "\n",
        "---\n",
        "  - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
        "  - Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
        "  - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2LODD6w9ixlT",
        "colab_type": "text"
      },
      "source": [
        "### Setup  \n",
        "Lightning is easy to install. Simply ```pip install pytorch-lightning```"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zK7-Gg69kMnG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "! pip install pytorch-lightning --quiet"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "w4_TYnt_keJi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "from torch.nn import functional as F\n",
        "from torch.utils.data import DataLoader, random_split\n",
        "from torchvision.datasets import MNIST\n",
        "from torchvision import transforms\n",
        "import pytorch_lightning as pl\n",
        "from pytorch_lightning.metrics.functional import accuracy"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EHpyMPKFkVbZ",
        "colab_type": "text"
      },
      "source": [
        "## Simplest example\n",
        "\n",
        "Here's the simplest most minimal example with just a training loop (no validation, no testing).\n",
        "\n",
        "**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features."
      ]
    },
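    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a quick check of the point above, the sketch below asks Python whether `LightningModule` subclasses `nn.Module`. It assumes `torch` and `pytorch_lightning` are installed as in the Setup step; the variable name `is_nn_module` is only illustrative.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import pytorch_lightning as pl\n",
        "\n",
        "# LightningModule inherits from nn.Module, so the usual nn.Module tools\n",
        "# (parameters(), state_dict(), calling the module on a batch) apply to it too.\n",
        "is_nn_module = issubclass(pl.LightningModule, torch.nn.Module)\n",
        "print(is_nn_module)\n",
        "```"
      ]
    },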
    {
      "cell_type": "code",
      "metadata": {
        "id": "V7ELesz1kVQo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class MNISTModel(pl.LightningModule):\n",
        "\n",
        "    def __init__(self):\n",
        "        super(MNISTModel, self).__init__()\n",
        "        self.l1 = torch.nn.Linear(28 * 28, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return torch.relu(self.l1(x.view(x.size(0), -1)))\n",
        "\n",
        "    def training_step(self, batch, batch_nb):\n",
        "        x, y = batch\n",
        "        loss = F.cross_entropy(self(x), y)\n",
        "        return loss\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        return torch.optim.Adam(self.parameters(), lr=0.02)"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hIrtHg-Dv8TJ",
        "colab_type": "text"
      },
      "source": [
        "By using the `Trainer` you automatically get:\n",
        "1. Tensorboard logging\n",
        "2. Model checkpointing\n",
        "3. Training and validation loop\n",
        "4. early-stopping"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4Dk6Ykv8lI7X",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Init our model\n",
        "mnist_model = MNISTModel()\n",
        "\n",
        "# Init DataLoader from MNIST Dataset\n",
        "train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())\n",
        "train_loader = DataLoader(train_ds, batch_size=32)\n",
        "\n",
        "# Initialize a trainer\n",
        "trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
        "\n",
        "# Train the model ⚡\n",
        "trainer.fit(mnist_model, train_loader)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KNpOoBeIjscS",
        "colab_type": "text"
      },
      "source": [
        "## A more complete MNIST Lightning Module Example\n",
        "\n",
        "That wasn't so hard was it?\n",
        "\n",
        "Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n",
        "\n",
        "This time, we'll bake in all the dataset specific pieces directly in the `LightningModule`. This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n",
        "\n",
        "---\n",
        "\n",
        "### Note what the following built-in functions are doing:\n",
        "\n",
        "1. [prepare_data()](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.prepare_data) 💾\n",
        "    - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n",
        "    - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n",
        "\n",
        "2. [setup(stage)](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#setup) ⚙️\n",
        "    - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test). \n",
        "    - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n",
        "    - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n",
        "    - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n",
        "\n",
        "3. [x_dataloader()](https://pytorch-lightning.readthedocs.io/en/latest/lightning-module.html#data-hooks) ♻️\n",
        "    - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"
      ]
    },
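    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The division of labour between these hooks can be sketched in plain Python. This is a toy stand-in under stated assumptions, not Lightning's implementation: `ToyModule` and `toy_fit` are hypothetical names, the lists stand in for real datasets, and the call order shown is a simplification of what the `Trainer` does.\n",
        "\n",
        "```python\n",
        "class ToyModule:\n",
        "    # Hypothetical stand-in that records which hooks run; not a LightningModule.\n",
        "\n",
        "    def __init__(self):\n",
        "        self.calls = []\n",
        "\n",
        "    def prepare_data(self):\n",
        "        # Download step: nothing is assigned to self for later use.\n",
        "        self.calls.append('prepare_data')\n",
        "\n",
        "    def setup(self, stage=None):\n",
        "        # State assignments are fine here; stage selects which splits to build.\n",
        "        self.calls.append('setup:' + str(stage))\n",
        "        if stage == 'fit' or stage is None:\n",
        "            self.train_data, self.val_data = [0, 1, 2], [3]\n",
        "        if stage == 'test' or stage is None:\n",
        "            self.test_data = [4, 5]\n",
        "\n",
        "    def train_dataloader(self):\n",
        "        self.calls.append('train_dataloader')\n",
        "        return self.train_data\n",
        "\n",
        "    def val_dataloader(self):\n",
        "        self.calls.append('val_dataloader')\n",
        "        return self.val_data\n",
        "\n",
        "\n",
        "def toy_fit(module):\n",
        "    # Simplified order assumed for illustration; the real Trainer does much more.\n",
        "    module.prepare_data()\n",
        "    module.setup('fit')\n",
        "    return module.train_dataloader(), module.val_dataloader()\n",
        "\n",
        "\n",
        "module = ToyModule()\n",
        "train_data, val_data = toy_fit(module)\n",
        "print(module.calls)\n",
        "```\n",
        "\n",
        "Because `setup('fit')` only builds the train and val splits, the toy module has no `test_data` attribute until `setup('test')` (or `setup(None)`) is called."
      ]
    },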
    {
      "cell_type": "code",
      "metadata": {
        "id": "4DNItffri95Q",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class LitMNIST(pl.LightningModule):\n",
        "    \n",
        "    def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):\n",
        "\n",
        "        super().__init__()\n",
        "\n",
        "        # Set our init args as class attributes\n",
        "        self.data_dir = data_dir\n",
        "        self.hidden_size = hidden_size\n",
        "        self.learning_rate = learning_rate\n",
        "\n",
        "        # Hardcode some dataset specific attributes\n",
        "        self.num_classes = 10\n",
        "        self.dims = (1, 28, 28)\n",
        "        channels, width, height = self.dims\n",
        "        self.transform = transforms.Compose([\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize((0.1307,), (0.3081,))\n",
        "        ])\n",
        "\n",
        "        # Define PyTorch model\n",
        "        self.model = nn.Sequential(\n",
        "            nn.Flatten(),\n",
        "            nn.Linear(channels * width * height, hidden_size),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(0.1),\n",
        "            nn.Linear(hidden_size, hidden_size),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(0.1),\n",
        "            nn.Linear(hidden_size, self.num_classes)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.model(x)\n",
        "        return F.log_softmax(x, dim=1)\n",
        "\n",
        "    def training_step(self, batch, batch_idx):\n",
        "        x, y = batch\n",
        "        logits = self(x)\n",
        "        loss = F.nll_loss(logits, y)\n",
        "        return loss\n",
        "\n",
        "    def validation_step(self, batch, batch_idx):\n",
        "        x, y = batch\n",
        "        logits = self(x)\n",
        "        loss = F.nll_loss(logits, y)\n",
        "        preds = torch.argmax(logits, dim=1)\n",
        "        acc = accuracy(preds, y)\n",
        "\n",
        "        # Calling self.log will surface up scalars for you in TensorBoard\n",
        "        self.log('val_loss', loss, prog_bar=True)\n",
        "        self.log('val_acc', acc, prog_bar=True)\n",
        "        return loss\n",
        "\n",
        "    def test_step(self, batch, batch_idx):\n",
        "        # Here we just reuse the validation_step for testing\n",
        "        return self.validation_step(batch, batch_idx)\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
        "        return optimizer\n",
        "\n",
        "    ####################\n",
        "    # DATA RELATED HOOKS\n",
        "    ####################\n",
        "\n",
        "    def prepare_data(self):\n",
        "        # download\n",
        "        MNIST(self.data_dir, train=True, download=True)\n",
        "        MNIST(self.data_dir, train=False, download=True)\n",
        "\n",
        "    def setup(self, stage=None):\n",
        "\n",
        "        # Assign train/val datasets for use in dataloaders\n",
        "        if stage == 'fit' or stage is None:\n",
        "            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n",
        "            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n",
        "\n",
        "        # Assign test dataset for use in dataloader(s)\n",
        "        if stage == 'test' or stage is None:\n",
        "            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n",
        "\n",
        "    def train_dataloader(self):\n",
        "        return DataLoader(self.mnist_train, batch_size=32)\n",
        "\n",
        "    def val_dataloader(self):\n",
        "        return DataLoader(self.mnist_val, batch_size=32)\n",
        "\n",
        "    def test_dataloader(self):\n",
        "        return DataLoader(self.mnist_test, batch_size=32)"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Mb0U5Rk2kLBy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model = LitMNIST()\n",
        "trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20)\n",
        "trainer.fit(model)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nht8AvMptY6I",
        "colab_type": "text"
      },
      "source": [
        "### Testing\n",
        "\n",
        "To test a model, call `trainer.test(model)`.\n",
        "\n",
        "Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PA151FkLtprO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "trainer.test()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T3-3lbbNtr5T",
        "colab_type": "text"
      },
      "source": [
        "### Bonus Tip\n",
        "\n",
        "You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IFBwCbLet2r6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "trainer.fit(model)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8TRyS5CCt3n9",
        "colab_type": "text"
      },
      "source": [
        "In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wizS-QiLuAYo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Start tensorboard.\n",
        "%load_ext tensorboard\n",
        "%tensorboard --logdir lightning_logs/"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}