{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [
"# Training a classifier using PyTorch Lightning\n",
"\n",
"This notebook trains a classifier on CIFAR-10 using PyTorch Lightning.\n",
"\n",
"![](../images/lightning-cifar10.png)\n",
"\n",
"We integrate 3LC with a `LightningModule` by creating Tables up front (outside the module) and\n",
"calling 3LC's public API (`tlc.init`, `tlc.collect_metrics`, `tlc.metrics.Predictor`) directly\n",
"from standard Lightning hooks (`on_train_start`, `on_train_end`)." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Project setup" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [
"PROJECT_NAME = \"3LC Tutorials - PyTorch Lightning Classification\"\n",
"RUN_NAME = \"Train classifier\"\n",
"RUN_DESCRIPTION = \"Train a resnet model on CIFAR-10\"\n",
"TRAIN_DATASET_NAME = \"cifar-10-train\"\n",
"VAL_DATASET_NAME = \"cifar-10-val\"\n",
"DOWNLOAD_PATH = \"../../transient_data\"\n",
"EPOCHS = 5\n",
"BATCH_SIZE = 32\n",
"NUM_WORKERS = 0\n",
"INSTALL_DEPENDENCIES = True" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"if INSTALL_DEPENDENCIES:\n",
"    # Quote the extras specifier: shells such as zsh would otherwise treat `[...]` as a glob.\n",
"    %pip install -q \"3lc[pacmap]\"\n",
"    %pip install -q pytorch-lightning\n",
"    %pip install -q git+https://github.com/3lc-ai/3lc-examples.git" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import pytorch_lightning as pl\n",
"import tlc\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"from tlc.integration.torch.samplers import create_sampler\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from tlc_tools.common import infer_torch_device" ] },
{ "cell_type": "markdown", "metadata": {}, 
"source": [ "## Define model creation function" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Create model for cifar10 training\n",
"def create_model(num_classes: int = 10) -> torch.nn.Module:\n",
"    \"\"\"Return a randomly initialized ResNet-18 with a `num_classes`-way classification head.\"\"\"\n",
"    # `weights=None` (no pretrained weights) replaces the deprecated `pretrained=False` argument.\n",
"    return torchvision.models.resnet18(weights=None, num_classes=num_classes)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Define the schema of our dataset" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"################## 3LC ##################\n",
"\n",
"# The schema describes the columns 3LC will write into the Table. We include an explicit\n",
"# weight column (`SampleWeightSchema`) so the 3LC sampler has weights to sample from;\n",
"# unlike `Table.from_torch_dataset`, `TableWriter` does not add one automatically.\n",
"classes = [\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\", \"dog\", \"frog\", \"horse\", \"ship\", \"truck\"]\n",
"schema = {\n",
"    \"image\": tlc.schemas.ImageSchema(),\n",
"    \"label\": tlc.schemas.CategoricalLabelSchema(classes=classes),\n",
"    \"weight\": tlc.schemas.SampleWeightSchema(),\n",
"}\n",
"\n",
"#########################################" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Describe the metrics we want to collect" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"################## 3LC ##################\n",
"\n",
"# Define a function for the metrics we want to collect\n",
"def metrics_fn(batch, predictor_output: tlc.metrics.PredictorOutput):\n",
"    \"\"\"Compute per-sample classification metrics for one batch.\n",
"\n",
"    `batch` is the collated `(images, labels)` pair produced by `val_fn`; each returned array\n",
"    holds one value per sample and becomes a column of the Run in the 3LC Dashboard.\n",
"    \"\"\"\n",
"    predictions = predictor_output.forward\n",
"    # Put the labels on the same device as the model outputs so the comparisons and losses\n",
"    # below never mix devices, wherever the Trainer placed the model.\n",
"    labels = batch[1].to(predictions.device)\n",
"    num_classes = predictions.shape[1]\n",
"    one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()\n",
"\n",
"    # Confidence & Predicted\n",
"    softmax_output = F.softmax(predictions, dim=1)\n",
"    predicted_indices = torch.argmax(predictions, dim=1)\n",
"    confidence = torch.gather(softmax_output, 1, predicted_indices.unsqueeze(1)).squeeze(1)\n",
"\n",
"    # Per-sample accuracy (1 if correct, 0 otherwise)\n",
"    accuracy = (predicted_indices == labels).float()\n",
"\n",
"    # Unreduced cross-entropy loss\n",
"    cross_entropy_loss = F.cross_entropy(predictions, labels, reduction=\"none\")\n",
"\n",
"    # RMSE and MAE between the softmax output and the one-hot labels, averaged over classes\n",
"    error = softmax_output - one_hot_labels\n",
"    rmse = torch.sqrt((error**2).mean(dim=1))\n",
"    mae = error.abs().mean(dim=1)\n",
"\n",
"    # These values will be the columns of the Run in the 3LC Dashboard\n",
"    return {\n",
"        \"loss\": cross_entropy_loss.cpu().numpy(),\n",
"        \"predicted\": predicted_indices.cpu().numpy(),\n",
"        \"accuracy\": accuracy.cpu().numpy(),\n",
"        \"confidence\": confidence.cpu().numpy(),\n",
"        \"rmse\": rmse.cpu().numpy(),\n",
"        \"mae\": mae.cpu().numpy(),\n",
"    }\n",
"\n",
"\n",
"# Schemas will be inferred automatically, but can be explicitly defined if customizations are needed,\n",
"# for example to set a description or a value map for an integer label.\n",
"schemas = {\n",
"    \"loss\": tlc.schemas.Float32Schema(description=\"Cross entropy loss\"),\n",
"    \"predicted\": tlc.schemas.CategoricalLabelSchema(\n",
"        display_name=\"predicted label\",\n",
"        classes=classes,\n",
"    ),\n",
"}\n",
"\n",
"# Use the metrics function and schemas to create a metrics collector\n",
"classification_metrics_collector = tlc.metrics.FunctionalMetricsCollector(\n",
"    collection_fn=metrics_fn,\n",
"    schema=schemas,\n",
")\n",
"\n",
"#########################################" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Create 3LC Tables\n", "\n", "We create the Tables eagerly, outside the `LightningModule`. 
This sidesteps the DDP\n",
"table-coordination problem that arises when tables are created inside `train_dataloader`\n",
"(which Lightning replicates per process). With tables on disk before `Trainer.fit()`, every\n",
"process simply opens the same Table.\n",
"\n",
"We use `tlc.TableWriter` directly — iterating the torchvision dataset and pushing rows\n",
"in batches. Transforms are attached at sample-time via `Table.with_transform`, so the\n",
"Table itself preserves the original PIL images for visualization. Training uses an augmented\n",
"transform (random crops and horizontal flips); we use the validation transform (no augmentation)\n",
"when collecting metrics on the training table." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Both pipelines share the same tensor conversion and normalization.\n",
"normalize = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
"\n",
"# Training applies standard CIFAR-10 augmentation (random crop + horizontal flip).\n",
"train_transform = torchvision.transforms.Compose(\n",
"    [\n",
"        torchvision.transforms.RandomCrop(32, padding=4),\n",
"        torchvision.transforms.RandomHorizontalFlip(),\n",
"        torchvision.transforms.ToTensor(),\n",
"        normalize,\n",
"    ]\n",
")\n",
"\n",
"# Validation and metrics collection use no augmentation.\n",
"val_transform = torchvision.transforms.Compose(\n",
"    [\n",
"        torchvision.transforms.ToTensor(),\n",
"        normalize,\n",
"    ]\n",
")\n",
"\n",
"\n",
"def train_fn(sample):\n",
"    \"\"\"Map a Table row to an augmented `(image_tensor, label)` training pair.\"\"\"\n",
"    return train_transform(sample[\"image\"]), sample[\"label\"]\n",
"\n",
"\n",
"def val_fn(sample):\n",
"    \"\"\"Map a Table row to an un-augmented `(image_tensor, label)` pair.\"\"\"\n",
"    return val_transform(sample[\"image\"]), sample[\"label\"]\n",
"\n",
"\n",
"def _add_rows(writer, images, labels):\n",
"    \"\"\"Append one batch of rows to `writer`, giving every row a weight of 1.0.\"\"\"\n",
"    writer.add_batch({\"image\": images, \"label\": labels, \"weight\": [1.0] * len(images)})\n",
"\n",
"\n",
"def write_cifar_table(dataset, dataset_name, batch_size=1000):\n",
"    \"\"\"Stream a torchvision CIFAR-10 split into a 3LC Table via TableWriter.\n",
"\n",
"    Args:\n",
"        dataset: Iterable of `(image, label)` pairs, e.g. a `torchvision.datasets.CIFAR10` split.\n",
"        dataset_name: Name of the 3LC dataset; an existing Table with this name is overwritten.\n",
"        batch_size: Number of rows buffered before each `TableWriter.add_batch` call.\n",
"\n",
"    Returns:\n",
"        The Table produced by `TableWriter.finalize()`.\n",
"    \"\"\"\n",
"    writer = tlc.TableWriter(\n",
"        project_name=PROJECT_NAME,\n",
"        dataset_name=dataset_name,\n",
"        schema=schema,\n",
"        if_exists=\"overwrite\",\n",
"    )\n",
"    images, labels = [], []\n",
"    for image, label in dataset:\n",
"        images.append(image)\n",
"        labels.append(label)\n",
"        if len(images) >= batch_size:\n",
"            _add_rows(writer, images, labels)\n",
"            # Start fresh lists instead of clearing, so a batch already passed to\n",
"            # `add_batch` is never mutated afterwards.\n",
"            images, labels = [], []\n",
"    if images:\n",
"        _add_rows(writer, images, labels)\n",
"    return writer.finalize()\n",
"\n",
"\n",
"raw_train_dataset = torchvision.datasets.CIFAR10(root=DOWNLOAD_PATH, train=True, download=True)\n",
"raw_val_dataset = torchvision.datasets.CIFAR10(root=DOWNLOAD_PATH, train=False, download=True)\n",
"\n",
"train_table = write_cifar_table(raw_train_dataset, TRAIN_DATASET_NAME)\n",
"val_table = write_cifar_table(raw_val_dataset, VAL_DATASET_NAME)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"## Define our LightningModule\n",
"\n",
"The 3LC integration is just a few standard Lightning hooks:\n",
"\n",
"- `on_train_start` initializes a Run and records hyperparameters.\n",
"- `on_train_end` collects per-sample metrics on the train and val tables, then marks the\n",
"  Run completed.\n",
"\n",
"The dataloaders are built directly from the Tables, with a 3LC weighted sampler on the\n",
"training side via `tlc.integration.torch.samplers.create_sampler`. The module also defines a\n",
"`validation_step`, so Lightning reports validation loss and accuracy after every epoch." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"class MyModule(pl.LightningModule):\n",
"    \"\"\"ResNet-18 CIFAR-10 classifier that records a 3LC Run around `Trainer.fit()`.\"\"\"\n",
"\n",
"    def __init__(self, train_table, val_table, batch_size=BATCH_SIZE, lr=1e-3):\n",
"        super().__init__()\n",
"        # The Tables are not hyperparameters; keep them out of `hparams` (and so out of the Run's parameters).\n",
"        self.save_hyperparameters(ignore=[\"train_table\", \"val_table\"])\n",
"        self.train_table = train_table\n",
"        self.val_table = val_table\n",
"        self.model = create_model()\n",
"        self.batch_size = batch_size\n",
"        self.lr = lr\n",
"        self.tlc_run: tlc.Run | None = None\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        loss = F.cross_entropy(self(x), y)\n",
"        self.log(\"train_loss\", loss, prog_bar=True)\n",
"        return loss\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        # Without this hook Lightning skips the validation loop and `val_dataloader` goes unused.\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        self.log(\"val_loss\", F.cross_entropy(logits, y), prog_bar=True)\n",
"        self.log(\"val_acc\", (logits.argmax(dim=1) == y).float().mean(), prog_bar=True)\n",
"\n",
"    def configure_optimizers(self):\n",
"        return torch.optim.Adam(self.parameters(), lr=self.lr)\n",
"\n",
"    def train_dataloader(self):\n",
"        return DataLoader(\n",
"            self.train_table.with_transform(train_fn),\n",
"            sampler=create_sampler(self.train_table, weighted=True, exclude_zero_weights=True),\n",
"            batch_size=self.batch_size,\n",
"            num_workers=NUM_WORKERS,\n",
"        )\n",
"\n",
"    def val_dataloader(self):\n",
"        return DataLoader(\n",
"            self.val_table.with_transform(val_fn),\n",
"            batch_size=self.batch_size,\n",
"            num_workers=NUM_WORKERS,\n",
"        )\n",
"\n",
"    def on_train_start(self):\n",
"        super().on_train_start()\n",
"        self.tlc_run = tlc.init(\n",
"            project_name=PROJECT_NAME,\n",
"            run_name=RUN_NAME,\n",
"            description=RUN_DESCRIPTION,\n",
"            parameters=dict(self.hparams_initial),\n",
"            if_exists=\"rename\",\n",
"        )\n",
"        self.tlc_run.set_status_running()\n",
"\n",
"    def on_train_end(self):\n",
"        super().on_train_end()\n",
"        # Per-sample metrics are collected once, after the final training epoch.\n",
"        self._collect_3lc_metrics()\n",
"        if self.tlc_run is not None:\n",
"            self.tlc_run.set_status_completed()\n",
"\n",
"    def _collect_3lc_metrics(self):\n",
"        \"\"\"Collect per-sample metrics on the train and val Tables with `classification_metrics_collector`.\"\"\"\n",
"        predictor = tlc.metrics.Predictor(self)\n",
"        # Use the val transform on both splits so metrics aren't computed on augmented images.\n",
"        for split, table in [(\"train\", self.train_table), (\"val\", self.val_table)]:\n",
"            tlc.collect_metrics(\n",
"                table=table.with_transform(val_fn),\n",
"                metrics_collectors=[classification_metrics_collector],\n",
"                predictor=predictor,\n",
"                split=split,\n",
"                constants={\"epoch\": self.current_epoch},\n",
"                exclude_zero_weights=True,\n",
"            )" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Run training" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Create the LightningModule, passing in the Tables we created above.\n",
"module = MyModule(train_table=train_table, val_table=val_table)\n",
"\n",
"# Train the model\n",
"trainer = pl.Trainer(max_epochs=EPOCHS)\n",
"trainer.fit(module)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "After training has completed, the Run can be viewed in the 3LC Dashboard."
] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.10" }, "test_marks": [ "slow" ] }, "nbformat": 4, "nbformat_minor": 2 }