{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# PyTorch 3LC CIFAR-10 Sample Notebook\n", "\n", "This notebook demonstrates fine-tuning a pretrained ResNet-18 model on the CIFAR-10 dataset using PyTorch and 3LC. We\n", "run the fine-tuning process for 5 epochs. During training, both classification metrics and embeddings are collected.\n", "\n", "![](../images/pytorch-cifar10.png)\n", "\n", "The notebook covers:\n", "\n", "+ Creating a Table from a PyTorch Dataset.\n", "+ Fine-tuning a pretrained ResNet-18 on CIFAR-10 using the Table.\n", "+ Using `FunctionalMetricsCollector` and `EmbeddingsMetricsCollector` to collect per-sample metrics and embeddings.\n", "+ Reducing the dimensionality of the embeddings with PaCMAP after training completes." ] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "## Project Setup" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "PROJECT_NAME = \"3LC Tutorials - CIFAR-10\"\n", "RUN_NAME = \"CIFAR-10 Demo Run\"\n", "DESCRIPTION = \"Fine-tune ResNet18 on CIFAR-10\"\n", "TRAIN_DATASET_NAME = \"cifar-10-train\"\n", "VAL_DATASET_NAME = \"cifar-10-val\"\n", "NUM_CLASSES = 10\n", "EMBEDDINGS_COLLECTION_FREQUENCY = 4\n", "TRANSIENT_DATA_PATH = \"../transient_data\"\n", "EPOCHS = 5\n", "BATCH_SIZE = 32\n", "INITIAL_LR = 0.01\n", "LR_GAMMA = 0.9\n", "NUM_WORKERS = 0\n", "MODEL_NAME = \"resnet18\"\n", "PRETRAINED = True\n", "DEVICE = None\n", "DROP_RATE = 0.2\n", "DROP_PATH_RATE = 0.2" ] }, { "cell_type": "code", "execution_count": null, "id": "3", "metadata": {}, "outputs": [], "source": [ "%pip --quiet install timm\n", "%pip --quiet install \"3lc[pacmap]\"" ] }, { "cell_type": "markdown", "id": "4", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "5", "metadata": {}, "outputs": [], "source": [ "import tlc\n", "import torch\n", "import torchvision\n", "from 
tqdm.auto import tqdm" ] }, { "cell_type": "code", "execution_count": null, "id": "6", "metadata": {}, "outputs": [], "source": [ "if DEVICE is None:\n", "    if torch.cuda.is_available():\n", "        device = \"cuda:0\"\n", "    elif torch.backends.mps.is_available():\n", "        device = \"mps\"\n", "    else:\n", "        device = \"cpu\"\n", "else:\n", "    device = DEVICE\n", "\n", "device = torch.device(device)\n", "print(f\"Using device: {device}\")" ] }, { "cell_type": "markdown", "id": "7", "metadata": {}, "source": [ "## Initialize a 3LC Run\n", "\n", "First, we initialize a 3LC run. This creates a new, empty run that will be visible in the 3LC dashboard.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "8", "metadata": {}, "outputs": [], "source": [ "run = tlc.init(\n", "    project_name=PROJECT_NAME,\n", "    run_name=RUN_NAME,\n", "    description=DESCRIPTION,\n", "    if_exists=\"overwrite\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "9", "metadata": {}, "outputs": [], "source": [ "config = {\n", "    \"epochs\": EPOCHS,\n", "    \"batch_size\": BATCH_SIZE,\n", "    \"initial_lr\": INITIAL_LR,\n", "    \"lr_gamma\": LR_GAMMA,\n", "    \"model_name\": MODEL_NAME,\n", "    \"pretrained\": PRETRAINED,\n", "    \"drop_rate\": DROP_RATE,\n", "    \"drop_path_rate\": DROP_PATH_RATE,\n", "}\n", "\n", "# Persist the notebook configuration parameters to the run\n", "run.set_parameters(config)" ] }, { "cell_type": "markdown", "id": "10", "metadata": {}, "source": [ "## Setup Datasets\n", "\n", "We create a 3LC Table from the torchvision CIFAR-10 dataset. The Table is used for visualization in the 3LC\n", "dashboard and for associating metrics with individual samples. It also allows the user to make virtual edits to the\n", "dataset and to run new experiments on the modified dataset.\n", "\n", "Since the underlying CIFAR-10 dataset is not stored as individual image files, 3LC will copy the images to the\n", "configured sample root."
] }, { "cell_type": "code", "execution_count": null, "id": "11", "metadata": {}, "outputs": [], "source": [ "train_dataset = torchvision.datasets.CIFAR10(root=TRANSIENT_DATA_PATH, train=True, download=True)\n", "val_dataset = torchvision.datasets.CIFAR10(root=TRANSIENT_DATA_PATH, train=False, download=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "12", "metadata": {}, "outputs": [], "source": [ "# The `structure` describes the layout of each sample in the dataset.\n", "# This helps 3LC create a Table with the correct columns and schemas.\n", "class_names = [\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\", \"dog\", \"frog\", \"horse\", \"ship\", \"truck\"]\n", "\n", "structure = (\n", "    tlc.PILImage(\"image\"),\n", "    tlc.CategoricalLabel(\"label\", classes=class_names),\n", ")\n", "\n", "train_transform = torchvision.transforms.Compose(\n", "    [\n", "        torchvision.transforms.Resize(224),\n", "        torchvision.transforms.RandomHorizontalFlip(),\n", "        torchvision.transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),\n", "        torchvision.transforms.ToTensor(),\n", "        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n", "    ]\n", ")\n", "\n", "val_transform = torchvision.transforms.Compose(\n", "    [\n", "        torchvision.transforms.Resize(224),\n", "        torchvision.transforms.ToTensor(),\n", "        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n", "    ]\n", ")\n", "\n", "\n", "# Each sample is an (image, label) pair: transform the image and pass the label through unchanged.\n", "def train_fn(sample):\n", "    return train_transform(sample[0]), sample[1]\n", "\n", "\n", "def val_fn(sample):\n", "    return val_transform(sample[0]), sample[1]\n", "\n", "\n", "# Create the 3LC Tables\n", "\n", "# Notice that instead of assigning the transforms to the torch dataset\n", "# we created, we assign them to the 3LC Table using `map`. 
This is because\n", "# we want the Table to be able to capture the untransformed images.\n", "\n", "# Notice that we also call `map_collect_metrics` on the training table\n", "# to specify the transforms which should be used to collect metrics.\n", "# Since we don't want 3LC to collect metrics on augmented images, we\n", "# use the validation transforms for metrics collection. If\n", "# `map_collect_metrics` is not called, the transforms given to `map`\n", "# will be used for metrics collection.\n", "\n", "tlc_train_dataset = (\n", "    tlc.Table.from_torch_dataset(\n", "        dataset=train_dataset,\n", "        dataset_name=TRAIN_DATASET_NAME,\n", "        table_name=\"train\",\n", "        description=\"CIFAR-10 training dataset\",\n", "        structure=structure,\n", "        if_exists=\"overwrite\",\n", "    )\n", "    .map(train_fn)\n", "    .map_collect_metrics(val_fn)\n", ")\n", "\n", "tlc_val_dataset = tlc.Table.from_torch_dataset(\n", "    dataset=val_dataset,\n", "    dataset_name=VAL_DATASET_NAME,\n", "    table_name=\"val\",\n", "    description=\"CIFAR-10 validation dataset\",\n", "    structure=structure,\n", "    if_exists=\"overwrite\",\n", ").map(val_fn)\n", "\n", "\n", "# Automatically pick up the latest version of the tables to include edits committed in the dashboard.\n", "initial_train_url = tlc_train_dataset.url\n", "initial_val_url = tlc_val_dataset.url\n", "\n", "tlc_train_dataset = tlc_train_dataset.latest()\n", "tlc_val_dataset = tlc_val_dataset.latest()\n", "\n", "if tlc_train_dataset.url != initial_train_url:\n", "    print(f\"Using latest training table {tlc_train_dataset.url}\")\n", "else:\n", "    print(f\"Using source training table {initial_train_url}\")\n", "\n", "if tlc_val_dataset.url != initial_val_url:\n", "    print(f\"Using latest validation table {tlc_val_dataset.url}\")\n", "else:\n", "    print(f\"Using source validation table {initial_val_url}\")" ] }, { "cell_type": "markdown", "id": "13", "metadata": {}, "source": [ "## Setup Model\n", "\n", "We use a ResNet-18 model from the `timm` model 
repository." ] }, { "cell_type": "code", "execution_count": null, "id": "14", "metadata": {}, "outputs": [], "source": [ "import timm\n", "\n", "torch.backends.cudnn.benchmark = True\n", "\n", "model = timm.create_model(\n", "    MODEL_NAME, pretrained=PRETRAINED, num_classes=NUM_CLASSES, drop_rate=DROP_RATE, drop_path_rate=DROP_PATH_RATE\n", ").to(device)" ] }, { "cell_type": "markdown", "id": "15", "metadata": {}, "source": [ "## Setup Training Loop" ] }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, "outputs": [], "source": [ "criterion = torch.nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(\n", "    model.parameters(),\n", "    lr=INITIAL_LR,\n", "    momentum=0.9,\n", "    weight_decay=1e-4,\n", "    nesterov=True,\n", ")\n", "lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)\n", "scaler = torch.cuda.amp.GradScaler(enabled=device.type == \"cuda\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "17", "metadata": {}, "outputs": [], "source": [ "def train(model, loader, criterion, optimizer, scaler):\n", "    model.train()\n", "    train_loss = 0\n", "    correct = 0\n", "    total = 0\n", "    for images, labels in tqdm(loader, desc=\"Training\"):\n", "        images, labels = images.to(device), labels.to(device)\n", "        optimizer.zero_grad()\n", "        # Mixed precision is only enabled on CUDA; elsewhere autocast and the scaler are no-ops\n", "        with torch.cuda.amp.autocast(enabled=device.type == \"cuda\"):\n", "            outputs = model(images)\n", "            loss = criterion(outputs, labels)\n", "        scaler.scale(loss).backward()\n", "        scaler.step(optimizer)\n", "        scaler.update()\n", "        train_loss += loss.item() * labels.size(0)\n", "        _, predicted = torch.max(outputs.data, 1)\n", "        total += labels.size(0)\n", "        correct += (predicted == labels).sum().item()\n", "    return train_loss / total, 100 * correct / total" ] },
{ "cell_type": "code", "execution_count": null, "id": "18", "metadata": {}, "outputs": [], "source": [ "def validate(model, loader, criterion):\n", "    model.eval()\n", "    val_loss = 0\n", "    correct = 0\n", "    total = 0\n", "    with torch.no_grad():\n", "        for images, labels in tqdm(loader, desc=\"Validating\"):\n", "            images, labels = images.to(device), labels.to(device)\n", "            outputs = model(images)\n", "            loss = criterion(outputs, labels)\n", "            val_loss += loss.item() * labels.size(0)\n", "            _, predicted = torch.max(outputs.data, 1)\n", "            total += labels.size(0)\n", "            correct += (predicted == labels).sum().item()\n", "    return val_loss / total, 100 * correct / total" ] }, { "cell_type": "markdown", "id": "19", "metadata": {}, "source": [ "## Setup Metrics Collectors" ] },
{ "cell_type": "code", "execution_count": null, "id": "20", "metadata": {}, "outputs": [], "source": [ "# Print the model layers. This is useful for finding the indices of named modules in a model.\n", "# The index will be used for creating the embeddings metrics collector.\n", "indices_and_modules = list(enumerate(model.named_modules()))\n", "for idx, (name, _module) in indices_and_modules:\n", "    print(idx, name)\n", "\n", "# The model's last module (its final fully connected layer) will be used for collecting the embeddings.\n", "final_flatten_layer_index = indices_and_modules[-1][0]\n", "final_flatten_layer_name = indices_and_modules[-1][1][0]\n", "\n", "print(f\"Using layer {final_flatten_layer_index} ({final_flatten_layer_name}) for embeddings collection\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "21", "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "\n", "# Define a function for the metrics we want to collect; it will be passed to a FunctionalMetricsCollector\n", "def metrics_fn(batch, predictor_output: tlc.PredictorOutput):\n", "    # `batch` holds the (images, labels) tensors produced by the metrics-collection map function\n", "    labels = batch[1].to(device)\n", "    predictions = predictor_output.forward\n", "    num_classes = predictions.shape[1]\n", "    one_hot_labels = F.one_hot(labels, num_classes=num_classes).float()\n", "\n", "    # Confidence & predicted label\n", "    softmax_output = F.softmax(predictions, dim=1)\n", "    predicted_indices = torch.argmax(predictions, dim=1)\n", "    confidence = torch.gather(softmax_output, 1, predicted_indices.unsqueeze(1)).squeeze(1)\n", "\n", "    # Per-sample accuracy (1 if correct, 0 otherwise)\n", "    accuracy = (predicted_indices == labels).float()\n", "\n", "    # Unreduced (per-sample) cross-entropy loss\n", "    cross_entropy_loss: torch.Tensor = torch.nn.CrossEntropyLoss(reduction=\"none\")(predictions, labels)\n", "\n", "    # Per-sample RMSE between the softmax probabilities and the one-hot labels\n", "    mse: torch.Tensor = torch.nn.MSELoss(reduction=\"none\")(softmax_output, one_hot_labels)\n", "    mse = mse.mean(dim=1)\n", "    rmse = torch.sqrt(mse)\n", "\n", "    # Per-sample MAE between the softmax probabilities and the one-hot labels\n", "    mae: torch.Tensor = torch.nn.L1Loss(reduction=\"none\")(softmax_output, one_hot_labels)\n", "    mae = mae.mean(dim=1)\n", "\n", "    return {\n", "        \"loss\": cross_entropy_loss.cpu().numpy(),\n", "        \"predicted\": predicted_indices.cpu().numpy(),\n", "        \"accuracy\": accuracy.cpu().numpy(),\n", "        \"confidence\": confidence.cpu().numpy(),\n", "        \"rmse\": rmse.cpu().numpy(),\n", "        \"mae\": mae.cpu().numpy(),\n", "    }\n", "\n", "\n", "# Schemas will be inferred automatically, but can be explicitly defined if customizations are needed,\n", "# for example to set the description, display name, display_importance, class_names, etc.\n", "\n", "schemas = {\n", "    \"loss\": tlc.Schema(\n", "        description=\"Cross entropy loss\",\n", "        value=tlc.Float32Value(),\n", "    ),\n", "    \"predicted\": tlc.CategoricalLabelSchema(\n", "        display_name=\"predicted label\",\n", "        classes=class_names,\n", "    ),\n", "}\n", "\n", "# Define metrics collectors\n", "\n", "classification_metrics_collector = tlc.FunctionalMetricsCollector(\n", "    collection_fn=metrics_fn,\n", "    column_schemas=schemas,\n", ")\n", "\n", "embeddings_metrics_collector = tlc.EmbeddingsMetricsCollector(layers=[final_flatten_layer_index])" ] },
{ "cell_type": "markdown", "id": "22", "metadata": {}, "source": [ "## Run Training\n", "\n", "We have now created our training and validation Tables, defined our model, and configured our 
metrics collectors. We\n", "are ready to run training." ] }, { "cell_type": "code", "execution_count": null, "id": "23", "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "# Create a weighted sampler to determine the sampling probability for each sample\n", "sampler = tlc_train_dataset.create_sampler()\n", "\n", "predictor = tlc.Predictor(model, layers=[final_flatten_layer_index])\n", "\n", "# Create training and validation dataloaders using our 3LC Tables\n", "train_loader = DataLoader(\n", "    tlc_train_dataset,\n", "    batch_size=BATCH_SIZE,\n", "    num_workers=NUM_WORKERS,\n", "    sampler=sampler,\n", ")\n", "\n", "val_loader = DataLoader(\n", "    tlc_val_dataset,\n", "    batch_size=BATCH_SIZE,\n", "    num_workers=NUM_WORKERS,\n", ")\n", "\n", "# We can use a larger batch size for metrics collection, as we don't need to backpropagate\n", "metrics_collection_dataloader_args = {\"num_workers\": NUM_WORKERS, \"batch_size\": 512}\n", "\n", "# We will collect the learning rate as a constant value per metrics-collection run,\n", "# but we want it to be hidden by default in the Dashboard.\n", "learning_rate_schema = tlc.Schema(\n", "    display_name=\"LR\",\n", "    description=\"Learning rate\",\n", "    value=tlc.Float32Value(),\n", "    default_visible=False,\n", ")\n", "\n", "# Train the model\n", "for epoch in range(EPOCHS):\n", "    train_loss, train_acc = train(model, train_loader, criterion, optimizer, scaler)\n", "    val_loss, val_acc = validate(model, val_loader, criterion)\n", "\n", "    # Collect classification metrics and embeddings on both splits every epoch\n", "    tlc.collect_metrics(\n", "        tlc_val_dataset,\n", "        metrics_collectors=[classification_metrics_collector, embeddings_metrics_collector],\n", "        predictor=predictor,\n", "        split=\"val\",\n", "        constants={\"epoch\": epoch, \"learning_rate\": lr_scheduler.get_last_lr()[0]},\n", "        constants_schemas={\"learning_rate\": learning_rate_schema},\n", "        dataloader_args=metrics_collection_dataloader_args,\n", "    )\n",
"    tlc.collect_metrics(\n", "        tlc_train_dataset,\n", "        metrics_collectors=[classification_metrics_collector, embeddings_metrics_collector],\n", "        predictor=predictor,\n", "        split=\"train\",\n", "        constants={\"epoch\": epoch, \"learning_rate\": lr_scheduler.get_last_lr()[0]},\n", "        constants_schemas={\"learning_rate\": learning_rate_schema},\n", "        dataloader_args=metrics_collection_dataloader_args,\n", "    )\n", "\n", "    print(\n", "        f\"Epoch {epoch + 1}/{EPOCHS}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, \"\n", "        f\"lr: {lr_scheduler.get_last_lr()[0]:.6f}\"\n", "    )\n", "    print(f\"Epoch {epoch + 1}/{EPOCHS}: Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%\")\n", "\n", "    lr_scheduler.step()" ] }, { "cell_type": "code", "execution_count": null, "id": "24", "metadata": {}, "outputs": [], "source": [ "# Reduce the embeddings to 3 dimensions, fitting the PaCMAP model on the final training-set embeddings\n", "url_mapping = run.reduce_embeddings_by_foreign_table_url(\n", "    tlc_train_dataset.url,\n", "    method=\"pacmap\",\n", "    n_components=3,\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "vscode": { "interpreter": { "hash": "20406f9571d76e3cd1c9018bc6ce1da8abf2dfe1b9830ddfcfda1f866a1db7c5" } } }, "nbformat": 4, "nbformat_minor": 5 }