{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Training a classifier using PyTorch Lightning\n", "\n", "This notebook trains a classifier on CIFAR-10 using PyTorch Lightning.\n", "\n", "![](../images/lightning-cifar10.png)\n", "\n", "When using a `LightningModule` that defines the `train_dataloader`, `val_dataloader` and/or `test_dataloader` methods, we can decorate our\n", "`LightningModule` with the `tlc.lightning_module` decorator to automatically generate Tables for our datasets and collect any desired metrics into a Run." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Project setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "PROJECT_NAME = \"3LC Tutorials - PyTorch Lightning Classification\"\n", "DOWNLOAD_PATH = \"../../transient_data\"\n", "EPOCHS = 5\n", "BATCH_SIZE = 32\n", "NUM_WORKERS = 0" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install 3lc[pacmap]\n", "%pip install pytorch-lightning\n", "%pip install git+https://github.com/3lc-ai/3lc-examples.git" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "import tlc\n", "import torch\n", "import torch.nn.functional as F\n", "import torchvision\n", "\n", "from tlc_tools.common import infer_torch_device" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Define model creation function" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create a randomly initialized ResNet-18 with 10 output classes for CIFAR-10 training\n", "def create_model():\n", "    return torchvision.models.resnet18(weights=None, num_classes=10)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Define the structure of our dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "################## 3LC ##################\n", "\n", "# Define the structure of a sample in the dataset(s)\n", "# Here, the structure is a tuple, where the first element is a PIL image which we will call \"Image\",\n", "# and the second element is an integer label, which maps to the given classes.\n", "classes = [\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\", \"dog\", \"frog\", \"horse\", \"ship\", \"truck\"]\n", "structure = (tlc.PILImage(\"Image\"), tlc.CategoricalLabel(\"Label\", classes=classes))\n", "\n", "#########################################" ] },
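{ "cell_type": "markdown", "metadata": {}, "source": [ "As an optional, illustrative aside (not required for training), the next cell inspects one raw CIFAR-10 sample to show why the structure above is a `(PIL image, integer label)` pair. It is a minimal sketch; it only assumes the dataset can be downloaded to `DOWNLOAD_PATH`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional sanity check: a raw (untransformed) CIFAR-10 sample is a PIL image plus an integer label,\n", "# matching the (tlc.PILImage, tlc.CategoricalLabel) structure declared above.\n", "raw_dataset = torchvision.datasets.CIFAR10(root=DOWNLOAD_PATH, train=True, download=True)\n", "sample_image, sample_label = raw_dataset[0]\n", "print(type(sample_image), sample_image.size, classes[sample_label])" ] },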
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Describe the metrics we want to collect" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "################## 3LC ##################\n", "\n", "# Define a function for the metrics we want to collect\n", "def metrics_fn(batch, predictor_output: tlc.PredictorOutput):\n", "    # batch is a tuple[torch.Tensor, torch.Tensor] of (images, labels)\n", "    labels = batch[1].to(infer_torch_device())\n", "    predictions = predictor_output.forward\n", "    num_classes = predictions.shape[1]\n", "    one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()\n", "\n", "    # Confidence & Predicted\n", "    softmax_output = torch.nn.functional.softmax(predictions, dim=1)\n", "    predicted_indices = torch.argmax(predictions, dim=1)\n", "    confidence = torch.gather(softmax_output, 1, predicted_indices.unsqueeze(1)).squeeze(1)\n", "\n", "    # Per-sample accuracy (1 if correct, 0 otherwise)\n", "    accuracy = (predicted_indices == labels).float()\n", "\n", "    # Unreduced Cross Entropy Loss\n", "    cross_entropy_loss: torch.Tensor = torch.nn.CrossEntropyLoss(reduction=\"none\")(predictions, labels)\n", "\n", "    # RMSE\n", "    mse: torch.Tensor = torch.nn.MSELoss(reduction=\"none\")(softmax_output, one_hot_labels)\n", "    mse = mse.mean(dim=1)\n", "    rmse = torch.sqrt(mse)\n", "\n", "    # MAE\n", "    mae: torch.Tensor = torch.nn.L1Loss(reduction=\"none\")(softmax_output, one_hot_labels)\n", "    mae = mae.mean(dim=1)\n", "\n", "    # These values will be the columns of the Run in the 3LC Dashboard\n", "    return {\n", "        \"loss\": cross_entropy_loss.cpu().numpy(),\n", "        \"predicted\": predicted_indices.cpu().numpy(),\n", "        \"accuracy\": accuracy.cpu().numpy(),\n", "        \"confidence\": confidence.cpu().numpy(),\n", "        \"rmse\": rmse.cpu().numpy(),\n", "        \"mae\": mae.cpu().numpy(),\n", "    }\n", "\n", "\n", "# Schemas will be inferred automatically, but can be explicitly defined if customizations are needed,\n", "# for example to set a description or a value map for an integer label.\n", "schemas = {\n", "    \"loss\": tlc.Schema(\n", "        description=\"Cross entropy loss\",\n", "        value=tlc.Float32Value(),\n", "    ),\n", "    \"predicted\": tlc.CategoricalLabelSchema(\n", "        display_name=\"predicted label\",\n", "        class_names=classes,\n", "    ),\n", "}\n", "\n", "# Use the metrics function and schemas to create a metrics collector\n", "classification_metrics_collector = tlc.FunctionalMetricsCollector(\n", "    collection_fn=metrics_fn,\n", "    column_schemas=schemas,\n", ")\n", "\n", "#########################################" ] },
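{ "cell_type": "markdown", "metadata": {}, "source": [ "Optionally, `metrics_fn` can be smoke-tested on a random batch before training. The sketch below uses a plain `SimpleNamespace` as a stand-in for `tlc.PredictorOutput` (it only provides the `forward` attribute read above), so it exercises shapes and device handling rather than the real collection path." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional smoke test of metrics_fn on a random batch.\n", "# NOTE: SimpleNamespace is only a hypothetical stand-in for tlc.PredictorOutput; the real object\n", "# is created by 3LC during metrics collection.\n", "from types import SimpleNamespace\n", "\n", "device = infer_torch_device()\n", "dummy_batch = (torch.randn(4, 3, 32, 32), torch.randint(0, len(classes), (4,)))\n", "dummy_output = SimpleNamespace(forward=torch.randn(4, len(classes)).to(device))\n", "{name: values.shape for name, values in metrics_fn(dummy_batch, dummy_output).items()}" ] },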
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Define our LightningModule (With 3LC decorator)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "################## 3LC ##################\n", "@tlc.lightning_module(\n", "    structure=structure,\n", "    project_name=PROJECT_NAME,\n", "    metrics_collectors=classification_metrics_collector,\n", ")\n", "#########################################\n", "class MyModule(pl.LightningModule):\n", "    def __init__(self, batch_size=BATCH_SIZE, lr=1e-3):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "        self.model = create_model()\n", "        self.batch_size = batch_size\n", "        self.lr = lr\n", "\n", "    def forward(self, x):\n", "        return self.model(x)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.cross_entropy(logits, y)\n", "        return loss\n", "\n", "    def configure_optimizers(self):\n", "        return torch.optim.Adam(self.parameters(), lr=self.lr)\n", "\n", "    def train_dataloader(self):\n", "        # Define transformations for the training dataset\n", "        train_transform = torchvision.transforms.Compose(\n", "            [\n", "                torchvision.transforms.ToTensor(),\n", "                torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", "            ]\n", "        )\n", "\n", "        # Create the training dataset, including the transformations\n", "        train_dataset = torchvision.datasets.CIFAR10(\n", "            root=DOWNLOAD_PATH,\n", "            train=True,\n", "            download=True,\n", "            transform=train_transform,\n", "        )\n", "\n", "        # Create a DataLoader for the training dataset\n", "        return torch.utils.data.DataLoader(\n", "            train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=NUM_WORKERS\n", "        )\n", "\n", "    def val_dataloader(self):\n", "        # Define transformations for the validation dataset\n", "        val_transform = torchvision.transforms.Compose(\n", "            [\n", "                torchvision.transforms.ToTensor(),\n", "                torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", "            ]\n", "        )\n", "\n", "        # Create the validation dataset, including the transformations\n", "        val_dataset = torchvision.datasets.CIFAR10(\n", "            root=DOWNLOAD_PATH,\n", "            train=False,\n", "            download=True,\n", "            transform=val_transform,\n", "        )\n", "\n", "        # Create a DataLoader for the validation dataset\n", "        return torch.utils.data.DataLoader(val_dataset, batch_size=self.batch_size, num_workers=NUM_WORKERS)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Run training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create the LightningModule\n", "module = MyModule()\n", "\n", "# Train the model\n", "trainer = pl.Trainer(max_epochs=EPOCHS)\n", "trainer.fit(module)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "After training has completed, the Run can be viewed in the 3LC Dashboard." ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.10" }, "test_marks": [ "slow" ] }, "nbformat": 4, "nbformat_minor": 2 }