{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# How to Train a Model on MNIST with FiftyOne and Torch\n", "\n", "This recipe demonstrates how to train a PyTorch model on the **MNIST** dataset using `FiftyOneTorchDataset`. This is useful when you want to build and evaluate models in PyTorch while managing your data pipeline directly from FiftyOne. Specifically, it covers:\n", "\n", "* Loading the MNIST dataset from the [Dataset Zoo](https://voxel51.com/docs/fiftyone/user_guide/dataset_zoo/index.html)\n", "* Creating train/validation/test splits with FiftyOne’s tagging and random splitting utilities\n", "* Building a subset of the dataset for faster experimentation\n", "* Running a simple training loop via an external script (`mnist_training.py`)\n", "* Saving model weights for later evaluation or reuse\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "\n", "If you haven't already, install FiftyOne:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install fiftyone" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this recipe, we'll use [PyTorch](https://pytorch.org/) to train the model. 
To follow along, install `torch` and `torchvision` if you haven't already:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install torch torchvision" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import Libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import fiftyone as fo\n", "import fiftyone.zoo as foz\n", "import fiftyone.utils.random as four" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.utils.data import DataLoader\n", "import numpy as np\n", "import torchvision.transforms.v2 as transforms\n", "from torchvision import tv_tensors\n", "import matplotlib.pyplot as plt\n", "import matplotlib.patches as plt_patches\n", "from PIL import Image\n", "import urllib.request" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To run this recipe, you’ll need the `mnist_training.py` script, which contains a simple PyTorch training loop. The next two cells download it, along with a `utils.py` helper module, into your working directory so that `mnist_training` can be imported directly."
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "url = \"https://cdn.voxel51.com/tutorials_torch_dataset_examples/notebook_simple_training_example/mnist_training.py\"\n", "urllib.request.urlretrieve(url, \"mnist_training.py\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "url = \"https://cdn.voxel51.com/tutorials_torch_dataset_examples/notebook_the_cache_field_names_argument/utils.py\"\n", "urllib.request.urlretrieve(url, \"utils.py\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import mnist_training" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Use the 'forkserver' start method for worker processes (e.g. DataLoader workers).\n", "# The start method can only be set once per kernel session; re-running this cell raises a RuntimeError.\n", "torch.multiprocessing.set_start_method('forkserver')\n", "# Have the forkserver process import these modules once so that new workers inherit them\n", "torch.multiprocessing.set_forkserver_preload(['torch', 'fiftyone'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic Training Example on MNIST" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's run a complete training example built on `FiftyOneTorchDataset`. First, load MNIST from the Dataset Zoo and mark the dataset as persistent so that it is kept between sessions." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Split 'train' already downloaded\n", "Split 'test' already downloaded\n", "Loading existing dataset 'mnist'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use\n" ] } ], "source": [ "mnist = foz.load_zoo_dataset(\"mnist\")\n", "mnist.persistent = True" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Session launched. 
Run `session.show()` to open the App in a cell output.\n" ] }, { "data": { "text/plain": [ "Dataset: mnist\n", "Media type: image\n", "Num samples: 70000\n", "Selected samples: 0\n", "Selected labels: 0\n", "Session URL: http://localhost:5151/" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fo.launch_app(mnist, auto=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's carve a validation set out of the training data. We first remove any existing `train` and `validation` tags, then use `four.random_split()` to randomly tag 90% of the non-test samples as `train` and the remaining 10% as `validation`." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'train': 54000, 'validation': 6000, 'test': 10000}\n" ] } ], "source": [ "# remove existing 'train' or 'validation' tags if they exist\n", "mnist.untag_samples(['train', 'validation'])\n", "\n", "# create a random split, just on the samples not tagged 'test'\n", "not_test = mnist.match_tags('test', bool=False)\n", "four.random_split(not_test, {'train' : 0.9, 'validation' : 0.1})\n", "print(mnist.count_sample_tags())" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Optional: for faster iteration, keep a random 1,000 training samples\n", "# plus all validation and test samples\n", "samples = []\n", "samples += mnist.match_tags('train').take(1000).values('id')\n", "for tag in ['test', 'validation']:\n", "    samples += mnist.match_tags(tag).values('id')\n", "\n", "subset = mnist.select(samples)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Average Train Loss = 3.999074: 100%|██████████| 63/63 [00:01<00:00, 58.45it/s]\n", "Average Validation Loss = 2.811698: 100%|██████████| 375/375 [00:02<00:00, 149.01it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "New best lost achieved : 2.801392190893491. 
Saving model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Average Train Loss = 1.072026: 100%|██████████| 63/63 [00:00<00:00, 119.78it/s]\n", "Average Validation Loss = 0.396746: 100%|██████████| 375/375 [00:01<00:00, 215.10it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "New best lost achieved : 0.39641891201337176. Saving model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Average Train Loss = 0.148484: 100%|██████████| 63/63 [00:00<00:00, 120.53it/s]\n", "Average Validation Loss = 0.319500: 100%|██████████| 375/375 [00:01<00:00, 211.25it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "New best lost achieved : 0.3149221637323499. Saving model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Average Train Loss = 0.627752: 100%|██████████| 63/63 [00:00<00:00, 97.89it/s] \n", "Average Validation Loss = 0.304854: 100%|██████████| 375/375 [00:01<00:00, 207.17it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "New best lost achieved : 0.2977131818582614. Saving model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Average Train Loss = 0.204026: 100%|██████████| 63/63 [00:00<00:00, 119.48it/s]\n", "Average Validation Loss = 0.210062: 100%|██████████| 375/375 [00:01<00:00, 214.69it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "New best lost achieved : 0.2064167803612848. 
Saving model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Average Train Loss = 0.070824: 100%|██████████| 63/63 [00:00<00:00, 106.55it/s]\n", "Average Validation Loss = 1.467735: 100%|██████████| 375/375 [00:02<00:00, 173.34it/s]\n", "Average Train Loss = 0.509837: 100%|██████████| 63/63 [00:00<00:00, 112.51it/s]\n", "Average Validation Loss = 0.387830: 100%|██████████| 375/375 [00:02<00:00, 163.92it/s]\n", "Average Train Loss = 0.236021: 100%|██████████| 63/63 [00:00<00:00, 116.83it/s]\n", "Average Validation Loss = 0.287110: 100%|██████████| 375/375 [00:01<00:00, 211.45it/s]\n", "Average Train Loss = 0.047093: 100%|██████████| 63/63 [00:00<00:00, 99.11it/s] \n", "Average Validation Loss = 0.156705: 100%|██████████| 375/375 [00:01<00:00, 213.70it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "New best lost achieved : 0.14917240004179377. Saving model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Average Train Loss = 0.009842: 100%|██████████| 63/63 [00:00<00:00, 97.05it/s] \n", "Average Validation Loss = 0.138089: 100%|██████████| 375/375 [00:01<00:00, 211.95it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "New best lost achieved : 0.13520573990046977. 
Saving model...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Average Validation Loss = 0.113355: 100%|██████████| 625/625 [00:10<00:00, 61.62it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Final Test Results:\n", "Loss = 0.11413920720983296\n", " precision recall f1-score support\n", "\n", " 0 - zero 0.98 0.97 0.98 980\n", " 1 - one 0.98 0.99 0.99 1135\n", " 2 - two 0.96 0.97 0.96 1032\n", " 3 - three 0.95 0.97 0.96 1010\n", " 4 - four 0.96 0.97 0.96 982\n", " 5 - five 0.95 0.96 0.95 892\n", " 6 - six 0.96 0.97 0.96 958\n", " 7 - seven 0.97 0.93 0.95 1028\n", " 8 - eight 0.98 0.94 0.96 974\n", " 9 - nine 0.95 0.96 0.96 1009\n", "\n", " accuracy 0.96 10000\n", " macro avg 0.96 0.96 0.96 10000\n", "weighted avg 0.96 0.96 0.96 10000\n", "\n" ] } ], "source": [ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "# Placeholder: replace with the path where the best model weights should be saved\n", "path_to_save_weights = '/path/to/save/weights'\n", "mnist_training.main(subset, 10, 10, device, path_to_save_weights)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This recipe showed how to train a PyTorch model on MNIST using `FiftyOneTorchDataset`: creating train/validation/test splits with tags, building a smaller subset for quick experiments, and running a simple training loop that saves the best weights." ] } ], "metadata": { "kernelspec": { "display_name": "torch-dataset", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 2 }