{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Collect metrics using a pre-trained model\n", "\n", "This notebook demonstrates how to use a pre-trained model to collect metrics on a dataset.\n", "\n", "![](../images/cifar10-collect-metrics.png)\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install 3lc\n", "%pip install timm\n", "%pip install git+https://github.com/3lc-ai/3lc-examples.git" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import timm\n", "import tlc\n", "import torchvision\n", "\n", "from tlc_tools.common import infer_torch_device" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Project setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "NUM_WORKERS = 4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare Tables\n", "\n", "We will reuse the tables from the notebook [create-table-from-torch-dataset.ipynb](../1-create-tables/create-table-from-torch-dataset.ipynb),\n", "and use a pre-trained model from the Hugging Face Hub."
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "device = infer_torch_device()\n",
  "\n",
  "# resnet18 checkpoint already trained on CIFAR-10, loaded through timm\n",
  "model = timm.create_model(\"hf_hub:FredMell/resnet18-cifar10\", pretrained=True)\n",
  "model = model.to(device)\n",
  "\n",
  "# Load the train and validation tables created by the earlier notebook\n",
  "PROJECT_NAME = \"3LC Tutorials - CIFAR-10\"\n",
  "train_table = tlc.Table.from_names(\"initial\", \"CIFAR-10-train\", PROJECT_NAME)\n",
  "val_table = tlc.Table.from_names(\"initial\", \"CIFAR-10-val\", PROJECT_NAME)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Per-channel normalization constants (these are the standard ImageNet statistics)\n",
  "NORMALIZE_MEAN = [0.485, 0.456, 0.406]\n",
  "NORMALIZE_STD = [0.229, 0.224, 0.225]\n",
  "\n",
  "transforms = torchvision.transforms\n",
  "\n",
  "# Resize, convert to a tensor, then normalize every image\n",
  "image_transform = transforms.Compose(\n",
  "    [\n",
  "        transforms.Resize(224),\n",
  "        transforms.ToTensor(),\n",
  "        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),\n",
  "    ]\n",
  ")\n",
  "\n",
  "\n",
  "def transform(sample):\n",
  "    \"\"\"Run ``image_transform`` on a sample's image and pass its label through.\n",
  "\n",
  "    Args:\n",
  "        sample: Indexable sample whose element 0 is the image and whose\n",
  "            element 1 is the label.\n",
  "\n",
  "    Returns:\n",
  "        A ``(transformed_image, label)`` tuple.\n",
  "    \"\"\"\n",
  "    image, label = sample[0], sample[1]\n",
  "    return image_transform(image), label"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Make the tables model-compatible: remove any maps that are already attached,\n",
  "# then map every sample through `transform`\n",
  "\n",
  "train_table.clear_maps()\n",
  "train_table = train_table.map(transform)\n",
  "\n",
  "val_table.clear_maps()\n",
  "val_table = val_table.map(transform)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
  "## Collect metrics"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Create a 3LC run in the same project as the tables\n",
  "run = tlc.init(\n",
  "    project_name=train_table.project_name,\n",
  "    description=\"Only collect metrics with trained model on CIFAR-10\",\n",
  ")\n",
  "\n",
  "dataloader_args = dict(batch_size=128, num_workers=NUM_WORKERS, pin_memory=True)\n",
  "\n",
  "# Class list, read from the value map of the \"Label\" column\n",
  "label_value_map = train_table.get_simple_value_map(\"Label\")\n",
  "classes = list(label_value_map.values())\n",
  "\n",
  "# Collect classification metrics split by split, with a fresh collector per call\n",
  "for split_name, split_table in ((\"train\", train_table), (\"val\", val_table)):\n",
  "    tlc.collect_metrics(\n",
  "        table=split_table,\n",
  "        predictor=model,\n",
  "        metrics_collectors=tlc.ClassificationMetricsCollector(classes=classes),\n",
  "        dataloader_args=dataloader_args,\n",
  "        split=split_name,\n",
  "    )\n",
  "\n",
  "run.set_status_completed()"
] }
], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" }, "test_marks": [ "slow" ] }, "nbformat": 4, "nbformat_minor": 2 }