{ "cells": [ { "cell_type": "markdown", "id": "0ca5f47a-43c1-47ba-9322-c48f645b0908", "metadata": { "id": "0ca5f47a-43c1-47ba-9322-c48f645b0908" }, "source": [ "# PyTorch Example with Confusion Matrix with Comet ML\n", "\n", "For this example, we will use PyTorch and create an interactive Confusion Matrix\n", "in Comet ML. You'll need a Comet API key to log the Confusion Matrix; the key is free for anyone.\n", "\n", "## Overview\n", "\n", "Our goal in this demonstration is to train a PyTorch model to categorize images of digits from the MNIST dataset, and to view example images for each cell of the resulting confusion matrix.\n", "\n", "Comet provides a very easy way to make such confusion matrices. You can do that with a single command:\n", "\n", "```python\n", "experiment.log_confusion_matrix(actual, predicted, images=images)\n", "```\n", "\n", "where `actual` is the ground truth (given as vectors or labels), `predicted` is the model's prediction (given as vectors or labels), and `images` is a list of image data.\n", "\n", "## End-to-End Example\n", "\n", "Let's explore a complete example from start to finish.\n", "\n", "First, we install the needed Python libraries:" ] }, { "cell_type": "code", "execution_count": 1, "id": "dec1NN24O_MH", "metadata": { "id": "dec1NN24O_MH" }, "outputs": [], "source": [ "%pip install \"comet_ml>=3.44.0\" torch torchvision" ] }, { "cell_type": "markdown", "id": "9aY1iWM88iUa", "metadata": { "id": "9aY1iWM88iUa" }, "source": [ "Now we import Comet:" ] }, { "cell_type": "code", "execution_count": 2, "id": "82048786-8d61-44c7-a6d2-0484a69779d0", "metadata": { "id": "82048786-8d61-44c7-a6d2-0484a69779d0" }, "outputs": [], "source": [ "import comet_ml" ] }, { "cell_type": "markdown", "id": "65e82fc4-adda-4a7b-a84b-91d7cce363c4", "metadata": { "id": "65e82fc4-adda-4a7b-a84b-91d7cce363c4" }, "source": [ "We can then make sure that our Comet API key is properly configured. 
The following command will give instructions if not:" ] }, { "cell_type": "code", "execution_count": 3, "id": "406fa63d-8854-4812-a082-e3ad011ed3fc", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "406fa63d-8854-4812-a082-e3ad011ed3fc", "outputId": "44a0adc9-12c7-4251-80ac-34bc8318644c" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "COMET INFO: Comet API key is valid\n" ] } ], "source": [ "comet_ml.login()" ] }, { "cell_type": "markdown", "id": "6cc97390-6fb2-45eb-9b81-47f25ae73cc7", "metadata": { "id": "6cc97390-6fb2-45eb-9b81-47f25ae73cc7" }, "source": [ "Now, we import the rest of the Python libraries that we will need:" ] }, { "cell_type": "code", "execution_count": 4, "id": "f11f4599-c2d6-4356-9936-aab28ade436b", "metadata": { "id": "f11f4599-c2d6-4356-9936-aab28ade436b" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torchvision.datasets as dsets\n", "import torchvision.transforms as transforms\n", "from torch.utils.data import SubsetRandomSampler\n", "from torch.autograd import Variable" ] }, { "cell_type": "markdown", "id": "32a480fb-f40e-48ec-b27b-12a0cbe72efb", "metadata": { "id": "32a480fb-f40e-48ec-b27b-12a0cbe72efb" }, "source": [ "## MNIST Dataset\n", "\n" ] }, { "cell_type": "markdown", "id": "b5b7c632-760c-490b-9d14-8e0e7dcbd277", "metadata": { "id": "b5b7c632-760c-490b-9d14-8e0e7dcbd277" }, "source": [ "The first time this runs may take a few minutes to download, and then a couple more minutes to process:" ] }, { "cell_type": "code", "execution_count": 5, "id": "0cbb42ca-f73e-4de5-bb4a-d05dc45d6b47", "metadata": { "id": "0cbb42ca-f73e-4de5-bb4a-d05dc45d6b47" }, "outputs": [], "source": [ "train_dataset = dsets.MNIST(\n", " root=\"./data/\", train=True, transform=transforms.ToTensor(), download=True\n", ")\n", "\n", "test_dataset = dsets.MNIST(root=\"./data/\", train=False, transform=transforms.ToTensor())" ] }, { "cell_type": "markdown", "id": 
"563109fd-e169-4238-ac70-519a1c00c633", "metadata": { "id": "563109fd-e169-4238-ac70-519a1c00c633" }, "source": [ "## Create the Model\n", "\n", "We'll now write a function that will create the model.\n", "\n", "In this example, we'll take advantage of Comet's `Experiment` to get access to the hyperparameters via `experiment.get_parameter()`. This will be very handy when we later use Comet's Hyperparameter Optimizer to generate the Experiments.\n", "\n", "This function will actually return the three components of the model: the rnn, the criterion, and the optimizer." ] }, { "cell_type": "code", "execution_count": 6, "id": "71cc8afc-0b6e-42e4-a0aa-f66a5b2fe006", "metadata": { "id": "71cc8afc-0b6e-42e4-a0aa-f66a5b2fe006" }, "outputs": [], "source": [ "def build_model(experiment):\n", " input_size = experiment.get_parameter(\"input_size\")\n", " hidden_size = experiment.get_parameter(\"hidden_size\")\n", " num_layers = experiment.get_parameter(\"num_layers\")\n", " num_classes = experiment.get_parameter(\"num_classes\")\n", " learning_rate = experiment.get_parameter(\"learning_rate\")\n", "\n", " class RNN(nn.Module):\n", " def __init__(self, input_size, hidden_size, num_layers, num_classes):\n", " super(RNN, self).__init__()\n", " self.hidden_size = hidden_size\n", " self.num_layers = num_layers\n", " self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)\n", " self.fc = nn.Linear(hidden_size, num_classes)\n", "\n", " def forward(self, x):\n", " # Set initial states\n", " h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))\n", " c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))\n", "\n", " # Forward propagate RNN\n", " self.out, _ = self.lstm(x, (h0, c0))\n", "\n", " # Decode hidden state of last time step\n", " out = self.fc(self.out[:, -1, :])\n", " return out\n", "\n", " rnn = RNN(\n", " input_size,\n", " hidden_size,\n", " num_layers,\n", " num_classes,\n", " )\n", "\n", " # Loss and Optimizer\n", " 
criterion = nn.CrossEntropyLoss()\n", " optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)\n", "\n", " return (rnn, criterion, optimizer)" ] }, { "cell_type": "markdown", "id": "53886701-7210-4895-b946-cfd7d788662b", "metadata": { "id": "53886701-7210-4895-b946-cfd7d788662b" }, "source": [ "We'll call this function below, once we create an `Experiment`." ] }, { "cell_type": "markdown", "id": "8114f369-7ba3-4783-93f4-c13ac1c70a4b", "metadata": { "id": "8114f369-7ba3-4783-93f4-c13ac1c70a4b" }, "source": [ "## Train the Model on the Dataset\n", "\n", "Now we are ready to set up a Comet Experiment and train the model.\n", "\n", "First, we set all of the hyperparameters of the model:" ] }, { "cell_type": "code", "execution_count": 7, "id": "c8b50453-3097-4a0b-9cf3-9b69f8b51df7", "metadata": { "id": "c8b50453-3097-4a0b-9cf3-9b69f8b51df7" }, "outputs": [], "source": [ "hyper_params = {\n", " \"epochs\": 10,\n", " \"batch_size\": 120,\n", " \"first_layer_units\": 128,\n", " \"sequence_length\": 28,\n", " \"input_size\": 28,\n", " \"hidden_size\": 128,\n", " \"num_layers\": 2,\n", " \"num_classes\": 10,\n", " \"learning_rate\": 0.01,\n", "}" ] }, { "cell_type": "markdown", "id": "062c6192-5c29-41ca-8841-5ceb82b7d978", "metadata": { "id": "062c6192-5c29-41ca-8841-5ceb82b7d978" }, "source": [ "Next, we create the experiment and log the hyperparameters:" ] }, { "cell_type": "code", "execution_count": 8, "id": "174f3481-0798-4a11-acda-e651753a1caa", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "174f3481-0798-4a11-acda-e651753a1caa", "outputId": "b9038054-a7b9-483a-8a19-11eb7c2f114c" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "COMET WARNING: As you are running in a Jupyter environment, you will need to call `experiment.end()` when finished to ensure all metrics and code are logged before exiting.\n", "COMET INFO: Experiment is live on comet.ml 
https://www.comet.com/dsblank/pytorch-confusion-matrix/819f19ee68ba4b91bab88421b795451d\n", "\n" ] } ], "source": [ "experiment = comet_ml.start(project_name=\"pytorch-confusion-matrix\")\n", "experiment.log_parameters(hyper_params)" ] }, { "cell_type": "markdown", "id": "4a5976cc-79e6-43ae-8301-8c39abf302a0", "metadata": { "id": "4a5976cc-79e6-43ae-8301-8c39abf302a0" }, "source": [ "We can now construct the model components:" ] }, { "cell_type": "code", "execution_count": 9, "id": "8778542d-da44-4f7e-a01e-2e24d2dfec3e", "metadata": { "id": "8778542d-da44-4f7e-a01e-2e24d2dfec3e" }, "outputs": [], "source": [ "rnn, criterion, optimizer = build_model(experiment)" ] }, { "cell_type": "markdown", "id": "423e17c2-e9ef-473c-b2ee-055669033984", "metadata": { "id": "423e17c2-e9ef-473c-b2ee-055669033984" }, "source": [ "To make this demonstration go a little faster, we'll just use a sample of the items from the training set:" ] }, { "cell_type": "code", "execution_count": 10, "id": "56724a82-9a93-4af8-b0d2-6492364d4be9", "metadata": { "id": "56724a82-9a93-4af8-b0d2-6492364d4be9" }, "outputs": [], "source": [ "SAMPLE_SIZE = 1000" ] }, { "cell_type": "markdown", "id": "acczEZl36Bml", "metadata": { "id": "acczEZl36Bml" }, "source": [ "Now we can construct the loader:" ] }, { "cell_type": "code", "execution_count": 11, "id": "761fe054-c539-4883-a58f-ad46a0108fcb", "metadata": { "id": "761fe054-c539-4883-a58f-ad46a0108fcb" }, "outputs": [], "source": [ "sampler = SubsetRandomSampler(list(range(SAMPLE_SIZE)))\n", "train_loader = torch.utils.data.DataLoader(\n", " dataset=train_dataset,\n", " batch_size=experiment.get_parameter(\"batch_size\"),\n", " sampler=sampler,\n", " # shuffle=True, # can't use shuffle with sampler\n", ")" ] }, { "cell_type": "markdown", "id": "EaVEhHpJC2nY", "metadata": { "id": "EaVEhHpJC2nY" }, "source": [ "Instead, if you would rather train on the entire dataset, you can:\n", "\n", "```python\n", "train_loader = torch.utils.data.DataLoader(\n", " 
dataset=train_dataset,\n", " batch_size=experiment.get_parameter('batch_size'),\n", " shuffle=True,\n", ")\n", "```" ] }, { "cell_type": "markdown", "id": "RZvVJzzE6FgW", "metadata": { "id": "RZvVJzzE6FgW" }, "source": [ "Now we can train the model. Some items to note:\n", "\n", "1. We use `experiment.train()` to provide the context for logged metrics\n", "2. We log the accuracy of each batch at every step, and the accuracy over the whole epoch at the end of each epoch\n", "3. After training, we collect the actual labels, predictions, and images for each test batch, then compute and log the confusion matrix" ] }, { "cell_type": "code", "execution_count": 12, "id": "8548f3f5-e1ef-4c35-b383-3dbe01941025", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8548f3f5-e1ef-4c35-b383-3dbe01941025", "outputId": "5df2bd81-5202-46b4-bb09-d3af5f07d5d3", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 0\n", ".........\n", "epoch: 1\n", ".........\n", "epoch: 2\n", ".........\n", "epoch: 3\n", ".........\n", "epoch: 4\n", ".........\n", "epoch: 5\n", ".........\n", "epoch: 6\n", ".........\n", "epoch: 7\n", ".........\n", "epoch: 8\n", ".........\n", "epoch: 9\n", "........." 
] } ], "source": [ "with experiment.train():\n", " step = 0\n", " for epoch in range(experiment.get_parameter(\"epochs\")):\n", " print(\"\\nepoch:\", epoch)\n", " correct = 0\n", " total = 0\n", " for batch_step, (images, labels) in enumerate(train_loader):\n", " print(\".\", end=\"\")\n", " images = Variable(\n", " images.view(\n", " -1,\n", " experiment.get_parameter(\"sequence_length\"),\n", " experiment.get_parameter(\"input_size\"),\n", " )\n", " )\n", "\n", " labels = Variable(labels)\n", "\n", " # Forward + Backward + Optimize\n", " optimizer.zero_grad()\n", " outputs = rnn(images)\n", " loss = criterion(outputs, labels)\n", " loss.backward()\n", " optimizer.step()\n", "\n", " # Compute train accuracy\n", " _, predicted = torch.max(outputs.data, 1)\n", " batch_total = labels.size(0)\n", " total += batch_total\n", "\n", " # .item() converts the 0-dim tensor to a plain Python int\n", " batch_correct = (predicted == labels.data).sum().item()\n", " correct += batch_correct\n", "\n", " # Log batch_accuracy to Comet.ml; step is each batch\n", " step += 1\n", " experiment.log_metric(\n", " \"batch_accuracy\", batch_correct / batch_total, step=step\n", " )\n", "\n", " if (batch_step + 1) % 100 == 0:\n", " print(\n", " \"Epoch [%d/%d], Step [%d/%d], Loss: %.4f\"\n", " % (\n", " epoch + 1,\n", " experiment.get_parameter(\"epochs\"),\n", " batch_step + 1,\n", " len(train_loader), # number of batches actually iterated\n", " loss.item(),\n", " )\n", " )\n", "\n", " # Log epoch accuracy to Comet.ml; step is each epoch\n", " experiment.log_metric(\n", " \"epoch_accuracy\", correct / total, step=epoch, epoch=epoch\n", " )" ] }, { "cell_type": "markdown", "id": "92436469-0313-46ca-817e-6f0ed4c6a663", "metadata": { "id": "92436469-0313-46ca-817e-6f0ed4c6a663" }, "source": [ "### Comet Confusion Matrix\n", "\n", "After the training loop, we can evaluate the model on the test dataset with:\n", "\n", "```python\n", "confusion_matrix = experiment.create_confusion_matrix()\n", "for batch in batches:\n", " ...\n", " confusion_matrix.compute_matrix(actual, 
predicted, images=images)\n", "experiment.log_confusion_matrix(matrix=confusion_matrix)\n", "```\n", "and that will create a nice Confusion Matrix visualization in Comet with image examples.\n", "\n", "Here is the actual code:" ] }, { "cell_type": "code", "execution_count": 13, "id": "9d04c26b-94ad-4869-bbf7-7a817bad7dcb", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9d04c26b-94ad-4869-bbf7-7a817bad7dcb", "outputId": "9a3c00d9-ef1d-4b8a-9d36-137c834415ca" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "........................................................................................................................................................................................................................................................................................................................." ] } ], "source": [ "test_loader = torch.utils.data.DataLoader(\n", " dataset=test_dataset,\n", " batch_size=32,\n", " shuffle=False,\n", ")\n", "\n", "confusion_matrix = experiment.create_confusion_matrix()\n", "\n", "for batch_step, (images, labels) in enumerate(test_loader):\n", " print(\".\", end=\"\")\n", " images = Variable(\n", " images.view(\n", " -1,\n", " experiment.get_parameter(\"sequence_length\"),\n", " experiment.get_parameter(\"input_size\"),\n", " )\n", " )\n", " labels = Variable(labels)\n", "\n", " outputs = rnn(images)\n", " _, predicted = torch.max(outputs.data, 1)\n", "\n", " confusion_matrix.compute_matrix(labels.data, predicted, images=images)\n", "\n", "experiment.log_confusion_matrix(\n", " matrix=confusion_matrix,\n", " title=\"MNIST Confusion Matrix, Epoch #%d\" % (epoch + 1),\n", " file_name=\"confusion-matrix-%03d.json\" % (epoch + 1),\n", ");" ] }, { "cell_type": "markdown", "id": "B9Rl3sOHJt8u", "metadata": { "id": "B9Rl3sOHJt8u" }, "source": [ "Now, because we are in a Jupyter Notebook, we signal that the experiment has completed:" ] }, { "cell_type": "code", 
"execution_count": 14, "id": "0d993454-6ea1-4b61-9cba-8425ba2c7eba", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0d993454-6ea1-4b61-9cba-8425ba2c7eba", "outputId": "5145f3dd-9e41-466c-ea56-923ea67306cd" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "COMET INFO: ---------------------------\n", "COMET INFO: Comet.ml Experiment Summary\n", "COMET INFO: ---------------------------\n", "COMET INFO: Data:\n", "COMET INFO: display_summary_level : 1\n", "COMET INFO: url : https://www.comet.com/dsblank/pytorch-confusion-matrix/819f19ee68ba4b91bab88421b795451d\n", "COMET INFO: Metrics [count] (min, max):\n", "COMET INFO: train_batch_accuracy [100] : (0.10000000149011612, 0.925000011920929)\n", "COMET INFO: train_loss [9] : (0.4030756652355194, 2.309687614440918)\n", "COMET INFO: Parameters:\n", "COMET INFO: batch_size : 120\n", "COMET INFO: epochs : 10\n", "COMET INFO: first_layer_units : 128\n", "COMET INFO: hidden_size : 128\n", "COMET INFO: input_size : 28\n", "COMET INFO: learning_rate : 0.01\n", "COMET INFO: num_classes : 10\n", "COMET INFO: num_layers : 2\n", "COMET INFO: sequence_length : 28\n", "COMET INFO: Uploads [count]:\n", "COMET INFO: confusion-matrix : 1\n", "COMET INFO: environment details : 1\n", "COMET INFO: filename : 1\n", "COMET INFO: images [1258] : 1258\n", "COMET INFO: installed packages : 1\n", "COMET INFO: model graph : 1\n", "COMET INFO: notebook : 1\n", "COMET INFO: os packages : 1\n", "COMET INFO: source_code : 1\n", "COMET INFO: ---------------------------\n", "COMET INFO: Uploading metrics, params, and assets to Comet before program termination (may take several seconds)\n", "COMET INFO: The Python SDK has 3600 seconds to finish before aborting...\n" ] } ], "source": [ "experiment.end()" ] }, { "cell_type": "markdown", "id": "4JAptoP87RW1", "metadata": { "id": "4JAptoP87RW1" }, "source": [ "Finally, we can explore the Confusion Matrix in the Comet UI. 
You can select the epoch by selecting the \"Confusion Matrix Name\" and click on a cell to see examples of that type." ] }, { "cell_type": "code", "execution_count": 15, "id": "6af03e15-bc30-4d26-b743-ff425d789137", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 821 }, "id": "6af03e15-bc30-4d26-b743-ff425d789137", "outputId": "4f40a1ae-fd2f-4cba-f09f-6709216086b2" }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "experiment.display(tab=\"confusion-matrix\")" ] }, { "cell_type": "markdown", "id": "fkZRIi8CJ7-G", "metadata": { "id": "fkZRIi8CJ7-G" }, "source": [ "Clicking on a cell in the matrix should show up to 25 examples of that type of confusion or correct classification.\n", "\n", "For more information about Comet ML, please see:\n", "\n", "1. [Getting started in 30 seconds](https://www.comet.com/docs/v2/guides/quickstart/)\n", "2. [Experiments](https://www.comet.com/docs/v2/guides/experiment-management/create-an-experiment/)\n", "3. [Working with Jupyter Notebooks](https://www.comet.com/docs/v2/guides/experiment-management/jupyter-notebook/)\n", "4. [Confusion Matrix](https://www.comet.com/docs/v2/guides/experiment-management/log-data/confusion-matrix/)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "PytorchConfusionMatrixSimple.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 5 }