{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# PyTorch 3LC Fashion-MNIST Sample Notebook\n", "\n", "This notebook closely follows the MNIST notebook, but uses the slightly more interesting Fashion-MNIST dataset.\n", "\n", "![](../images/pytorch-fashion-mnist.png)\n", "\n", "As the original authors of the dataset noted: \"MNIST is too easy... MNIST is overused... MNIST cannot represent modern CV tasks.\"\n", "\n", "While this criticism now applies to Fashion-MNIST as well, it remains a more interesting example thanks to its somewhat\n", "more complex images and labels.\n", "\n", "This notebook demonstrates training a Convolutional Neural Network (CNN) on the Fashion-MNIST dataset using PyTorch and\n", "3LC. Training runs for 5 epochs, and after each epoch, classification metrics and embeddings are collected on both the\n", "training and validation sets.\n", "\n", "The notebook demonstrates:\n", "\n", "+ How to wrap a built-in PyTorch dataset in a 3LC Table.\n", "+ Metrics collection using a custom `MetricsCollector` subclass and an `EmbeddingsMetricsCollector`.\n", "+ Reducing the dimensionality of embeddings using PaCMAP as a post-processing step." 
] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "## Project Setup" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": { "papermill": { "duration": 0.034774, "end_time": "2023-04-19T12:15:44.870371", "exception": false, "start_time": "2023-04-19T12:15:44.835597", "status": "completed" }, "tags": [ "parameters" ] }, "outputs": [], "source": [ "PROJECT_NAME = \"3LC Tutorials - Fashion-MNIST\"\n", "RUN_NAME = \"Train a Fashion-MNIST Classifier\"\n", "DESCRIPTION = \"Train a simple CNN to classify Fashion-MNIST images\"\n", "TRAIN_DATASET_NAME = \"fashion-mnist-train\"\n", "VAL_DATASET_NAME = \"fashion-mnist-val\"\n", "DOWNLOAD_PATH = \"../../transient_data\"\n", "COLLECT_METRICS_BATCH_SIZE = 2048\n", "TRAIN_BATCH_SIZE = 64\n", "INITIAL_LR = 1.0\n", "LR_GAMMA = 0.7\n", "EPOCHS = 5\n", "NUM_WORKERS = 0\n", "DEVICE = None" ] }, { "cell_type": "code", "execution_count": null, "id": "3", "metadata": { "tags": [] }, "outputs": [], "source": [ "%pip --quiet install 3lc[pacmap]" ] }, { "cell_type": "markdown", "id": "4", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": { "papermill": { "duration": 1.254606, "end_time": "2023-04-19T12:15:46.131975", "exception": false, "start_time": "2023-04-19T12:15:44.877369", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import tlc\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torchvision\n", "from tqdm.auto import tqdm" ] }, { "cell_type": "code", "execution_count": null, "id": "6", "metadata": {}, "outputs": [], "source": [ "# Prefer CUDA, then Apple MPS, then CPU, unless DEVICE is set explicitly\n", "if DEVICE is None:\n", "    if torch.cuda.is_available():\n", "        device = \"cuda:0\"\n", "    elif torch.backends.mps.is_available():\n", "        device = \"mps\"\n", "    else:\n", "        device = \"cpu\"\n", "else:\n", "    device = DEVICE\n", "\n", "device = torch.device(device)\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "markdown", "id": 
"7", "metadata": {}, "source": [ "## Initialize a 3LC Run" ] }, { "cell_type": "code", "execution_count": null, "id": "8", "metadata": {}, "outputs": [], "source": [ "config = {\n", "    \"train_batch_size\": TRAIN_BATCH_SIZE,\n", "    \"initial_lr\": INITIAL_LR,\n", "    \"lr_gamma\": LR_GAMMA,\n", "    \"epochs\": EPOCHS,\n", "}\n", "\n", "run = tlc.init(\n", "    project_name=PROJECT_NAME,\n", "    run_name=RUN_NAME,\n", "    description=DESCRIPTION,\n", "    parameters=config,\n", "    if_exists=\"overwrite\",\n", ")" ] }, { "cell_type": "markdown", "id": "9", "metadata": {}, "source": [ "## Set Up Datasets" ] }, { "cell_type": "code", "execution_count": null, "id": "10", "metadata": {}, "outputs": [], "source": [ "# Normalize with the mean and standard deviation of the Fashion-MNIST training set\n", "transform = torchvision.transforms.Compose(\n", "    [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.2860,), (0.3530,))]\n", ")\n", "\n", "train_dataset = torchvision.datasets.FashionMNIST(root=DOWNLOAD_PATH, train=True, download=True)\n", "eval_dataset = torchvision.datasets.FashionMNIST(root=DOWNLOAD_PATH, train=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "11", "metadata": {}, "outputs": [], "source": [ "class_names = [\n", "    \"T-shirt/top\",\n", "    \"Trouser\",\n", "    \"Pullover\",\n", "    \"Dress\",\n", "    \"Coat\",\n", "    \"Sandal\",\n", "    \"Shirt\",\n", "    \"Sneaker\",\n", "    \"Bag\",\n", "    \"Ankle boot\",\n", "]\n", "\n", "structure = (tlc.PILImage(\"image\"), tlc.CategoricalLabel(\"label\", class_names))\n", "\n", "\n", "# Convert a (PIL image, int label) sample into a (normalized image tensor, label tensor) pair\n", "def transforms(x):\n", "    return transform(x[0]), torch.tensor(x[1])\n", "\n", "\n", "# We pick up the latest version of the dataset, so that we can re-run this notebook as-is\n", "# after adding new revisions to the dataset.\n", "tlc_train_dataset = (\n", "    tlc.Table.from_torch_dataset(\n", "        dataset=train_dataset,\n", "        structure=structure,\n", "        dataset_name=TRAIN_DATASET_NAME,\n", "        project_name=PROJECT_NAME,\n", "        description=\"Fashion-MNIST training dataset\",\n", "        table_name=\"train\",\n", "
if_exists=\"overwrite\",\n", "    )\n", "    .map(transforms)\n", "    .latest()\n", ")\n", "\n", "tlc_val_dataset = (\n", "    tlc.Table.from_torch_dataset(\n", "        dataset=eval_dataset,\n", "        dataset_name=VAL_DATASET_NAME,\n", "        structure=structure,\n", "        project_name=PROJECT_NAME,\n", "        description=\"Fashion-MNIST validation dataset\",\n", "        table_name=\"val\",\n", "        if_exists=\"overwrite\",\n", "    )\n", "    .map(transforms)\n", "    .latest()\n", ")" ] }, { "cell_type": "markdown", "id": "12", "metadata": {}, "source": [ "## Set Up Model" ] }, { "cell_type": "code", "execution_count": null, "id": "13", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "class Net(nn.Module):\n", "    # From https://github.com/pytorch/examples/blob/main/mnist/main.py\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n", "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n", "        self.dropout1 = nn.Dropout(0.25)\n", "        self.dropout2 = nn.Dropout(0.5)\n", "        # 9216 = 64 channels x 12 x 12: two 3x3 convolutions and one 2x2 max pool on a 28x28 input\n", "        self.fc1 = nn.Linear(9216, 128)\n", "        self.fc2 = nn.Linear(128, 10)\n", "\n", "    def forward(self, x):\n", "        x = self.conv1(x)\n", "        x = F.relu(x)\n", "        x = self.conv2(x)\n", "        x = F.relu(x)\n", "        x = F.max_pool2d(x, 2)\n", "        x = self.dropout1(x)\n", "        x = torch.flatten(x, 1)\n", "        x = self.fc1(x)\n", "        x = F.relu(x)\n", "        x = self.dropout2(x)\n", "        x = self.fc2(x)\n", "        output = F.log_softmax(x, dim=1)\n", "        return output\n", "\n", "\n", "model = Net().to(device)" ] }, { "cell_type": "markdown", "id": "14", "metadata": {}, "source": [ "## Set Up Training Loop" ] }, { "cell_type": "code", "execution_count": null, "id": "15", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "optimizer = torch.optim.Adadelta(model.parameters(), lr=INITIAL_LR)\n", "scheduler = 
torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=LR_GAMMA)" ] }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "def train(model, device, train_loader, optimizer, epoch):\n", "    model.train()\n", "    for data, target in tqdm(train_loader, desc=f\"Training {epoch + 1}/{EPOCHS}\"):  # Epoch is 0-indexed\n", "        data, target = data.to(device), target.to(device)\n", "        optimizer.zero_grad()\n", "        output = model(data)\n", "        loss = F.nll_loss(output, target)\n", "\n", "        loss.backward()\n", "        optimizer.step()" ] }, { "cell_type": "markdown", "id": "17", "metadata": {}, "source": [ "## Set Up Metrics Collectors" ] }, { "cell_type": "code", "execution_count": null, "id": "18", "metadata": {}, "outputs": [], "source": [ "class FashionMNISTMetricsCollector(tlc.MetricsCollector):\n", "    def __init__(self, criterion):\n", "        super().__init__()\n", "        self.criterion = criterion\n", "\n", "    def compute_metrics(self, batch, predictor_output):\n", "        # The model's forward pass returns log-probabilities (log_softmax)\n", "        predictions = predictor_output.forward\n", "        labels = batch[1].to(device)\n", "\n", "        metrics = {\n", "            \"loss\": self.criterion(predictions, labels).cpu().numpy(),\n", "            \"predicted\": torch.argmax(predictions, dim=1).cpu().numpy(),\n", "            # exp() of the largest log-probability is the probability of the predicted class\n", "            \"confidence\": torch.exp(torch.max(predictions, dim=1).values).cpu().numpy(),\n", "            \"accuracy\": (torch.argmax(predictions, dim=1) == labels).cpu().numpy(),\n", "        }\n", "        return metrics\n", "\n", "    @property\n", "    def column_schemas(self):\n", "        # Explicitly override the schema of the predicted label so that it is displayed as a\n", "        # categorical label in the Dashboard.\n", "        schemas = {\n", "            \"predicted\": tlc.CategoricalLabelSchema(\n", "                classes=class_names,\n", "                display_name=\"predicted label\",\n", "            )\n", "        }\n", "        return schemas\n", "\n", "\n", "mnist_metrics_collector = 
FashionMNISTMetricsCollector(nn.NLLLoss(reduction=\"none\"))\n", "embeddings_metrics_collector = tlc.EmbeddingsMetricsCollector(layers=[4])" ] }, { "cell_type": "markdown", "id": "19", "metadata": {}, "source": [ "## Run Training\n", "\n", "We train with a weighted sampler created from the 3LC Table. The sampler draws samples according to the Table's\n", "default `weights` column. The weights can be edited in the Dashboard, and the sampler picks up the changes automatically.\n", "After each epoch, metrics and embeddings are collected on both the training and validation sets." ] }, { "cell_type": "code", "execution_count": null, "id": "20", "metadata": {}, "outputs": [], "source": [ "sampler = tlc_train_dataset.create_sampler()\n", "\n", "train_loader = torch.utils.data.DataLoader(\n", "    tlc_train_dataset,\n", "    batch_size=TRAIN_BATCH_SIZE,\n", "    sampler=sampler,\n", "    num_workers=NUM_WORKERS,\n", ")\n", "\n", "metrics_collection_dataloader_args = {\n", "    \"num_workers\": NUM_WORKERS,\n", "    \"batch_size\": COLLECT_METRICS_BATCH_SIZE,\n", "}\n", "\n", "predictor = tlc.Predictor(model, layers=[4])\n", "\n", "# Train the model\n", "for epoch in range(EPOCHS):\n", "    train(model, device, train_loader, optimizer, epoch)\n", "    # Decay the learning rate by a factor of LR_GAMMA once per epoch\n", "    scheduler.step()\n", "\n", "    tlc.collect_metrics(\n", "        tlc_train_dataset,\n", "        metrics_collectors=[\n", "            mnist_metrics_collector,\n", "            embeddings_metrics_collector,\n", "        ],\n", "        predictor=predictor,\n", "        split=\"train\",\n", "        constants={\"epoch\": epoch},\n", "        dataloader_args=metrics_collection_dataloader_args,\n", "    )\n", "    tlc.collect_metrics(\n", "        tlc_val_dataset,\n", "        metrics_collectors=[\n", "            mnist_metrics_collector,\n", "            embeddings_metrics_collector,\n", "        ],\n", "        predictor=predictor,\n", "        split=\"val\",\n", "        constants={\"epoch\": epoch},\n", "        dataloader_args=metrics_collection_dataloader_args,\n", "    )" ] }, { "cell_type": "code", "execution_count": null, "id": "21", "metadata": {}, "outputs": [], "source": [ "# Reduce embeddings to 3 dimensions, fitting the PaCMAP model on the embeddings collected for the training set\n", "url_mapping = 
run.reduce_embeddings_by_foreign_table_url(\n", " tlc_train_dataset.url,\n", " method=\"pacmap\",\n", " n_components=3,\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }