{ "cells": [ { "cell_type": "markdown", "id": "9476fc52", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Image Classification (CIFAR-10) on Kaggle\n", "\n", "The web address of the competition is https://www.kaggle.com/c/cifar-10" ] }, { "cell_type": "code", "execution_count": 1, "id": "e5a3fb64", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:26.260293Z", "iopub.status.busy": "2023-08-18T19:29:26.259545Z", "iopub.status.idle": "2023-08-18T19:29:29.089576Z", "shell.execute_reply": "2023-08-18T19:29:29.088675Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "import collections\n", "import math\n", "import os\n", "import shutil\n", "import pandas as pd\n", "import torch\n", "import torchvision\n", "from torch import nn\n", "from d2l import torch as d2l" ] }, { "cell_type": "markdown", "id": "0241c1c6", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "We provide a small-scale sample of the dataset that\n", "contains the first 1000 training images and 5 random testing images" ] }, { "cell_type": "code", "execution_count": 2, "id": "0d41dcd1", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.095074Z", "iopub.status.busy": "2023-08-18T19:29:29.094404Z", "iopub.status.idle": "2023-08-18T19:29:29.393994Z", "shell.execute_reply": "2023-08-18T19:29:29.393137Z" }, "origin_pos": 4, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading ../data/kaggle_cifar10_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_cifar10_tiny.zip...\n" ] } ], "source": [ "d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',\n", " '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')\n", "\n", "demo = True\n", "\n", "if demo:\n", " data_dir = d2l.download_extract('cifar10_tiny')\n", "else:\n", " data_dir = '../data/cifar-10/'" ] }, { "cell_type": "markdown", "id": "863e1e66", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Organizing the Dataset" ] }, { "cell_type": "code", "execution_count": 3, "id": "04bf8387", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.399003Z", "iopub.status.busy": "2023-08-18T19:29:29.398718Z", "iopub.status.idle": "2023-08-18T19:29:29.406335Z", "shell.execute_reply": "2023-08-18T19:29:29.405552Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# training examples: 1000\n", "# classes: 10\n" ] } ], "source": [ "def read_csv_labels(fname):\n", " \"\"\"Read `fname` to return a filename to label dictionary.\"\"\"\n", " with open(fname, 'r') as f:\n", " lines = f.readlines()[1:]\n", " tokens = [l.rstrip().split(',') for l in lines]\n", " return dict(((name, label) for name, label in tokens))\n", "\n", "labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))\n", "print('\n", "print('" ] }, { "cell_type": "markdown", "id": "a42d795c", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Split the validation set out of the original training set" ] }, { "cell_type": "code", "execution_count": 4, "id": "0ae3357e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.411145Z", "iopub.status.busy": "2023-08-18T19:29:29.410869Z", "iopub.status.idle": "2023-08-18T19:29:29.418258Z", "shell.execute_reply": "2023-08-18T19:29:29.417439Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def copyfile(filename, target_dir):\n", " \"\"\"Copy a file into a 
target directory.\"\"\"\n", " os.makedirs(target_dir, exist_ok=True)\n", " shutil.copy(filename, target_dir)\n", "\n", "def reorg_train_valid(data_dir, labels, valid_ratio):\n", " \"\"\"Split the validation set out of the original training set.\"\"\"\n", " n = collections.Counter(labels.values()).most_common()[-1][1]\n", " n_valid_per_label = max(1, math.floor(n * valid_ratio))\n", " label_count = {}\n", " for train_file in os.listdir(os.path.join(data_dir, 'train')):\n", " label = labels[train_file.split('.')[0]]\n", " fname = os.path.join(data_dir, 'train', train_file)\n", " copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n", " 'train_valid', label))\n", " if label not in label_count or label_count[label] < n_valid_per_label:\n", " copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n", " 'valid', label))\n", " label_count[label] = label_count.get(label, 0) + 1\n", " else:\n", " copyfile(fname, os.path.join(data_dir, 'train_valid_test',\n", " 'train', label))\n", " return n_valid_per_label" ] }, { "cell_type": "markdown", "id": "ab600f93", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Organizes the testing set for data loading during prediction" ] }, { "cell_type": "code", "execution_count": 5, "id": "890972a8", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.422565Z", "iopub.status.busy": "2023-08-18T19:29:29.422289Z", "iopub.status.idle": "2023-08-18T19:29:29.426856Z", "shell.execute_reply": "2023-08-18T19:29:29.426083Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def reorg_test(data_dir):\n", " \"\"\"Organize the testing set for data loading during prediction.\"\"\"\n", " for test_file in os.listdir(os.path.join(data_dir, 'test')):\n", " copyfile(os.path.join(data_dir, 'test', test_file),\n", " os.path.join(data_dir, 'train_valid_test', 'test',\n", " 'unknown'))" ] }, { "cell_type": "markdown", "id": "5c5a3b9e", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Invoke\n", "functions defined above" ] }, { "cell_type": "code", "execution_count": 7, "id": "1daf58c4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.439643Z", "iopub.status.busy": "2023-08-18T19:29:29.438882Z", "iopub.status.idle": "2023-08-18T19:29:29.700309Z", "shell.execute_reply": "2023-08-18T19:29:29.699321Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def reorg_cifar10_data(data_dir, valid_ratio):\n", " labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))\n", " reorg_train_valid(data_dir, labels, valid_ratio)\n", " reorg_test(data_dir)\n", "\n", "batch_size = 32 if demo else 128\n", "valid_ratio = 0.1\n", "reorg_cifar10_data(data_dir, valid_ratio)" ] }, { "cell_type": "markdown", "id": "c8ef27b5", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Image Augmentation" ] }, { "cell_type": "code", "execution_count": 9, "id": "be0d5428", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.714890Z", "iopub.status.busy": "2023-08-18T19:29:29.714292Z", "iopub.status.idle": "2023-08-18T19:29:29.718602Z", "shell.execute_reply": "2023-08-18T19:29:29.717807Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "transform_train = torchvision.transforms.Compose([\n", " torchvision.transforms.Resize(40),\n", " torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),\n", " ratio=(1.0, 1.0)),\n", " torchvision.transforms.RandomHorizontalFlip(),\n", " torchvision.transforms.ToTensor(),\n", 
" torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],\n", " [0.2023, 0.1994, 0.2010])])\n", "\n", "transform_test = torchvision.transforms.Compose([\n", " torchvision.transforms.ToTensor(),\n", " torchvision.transforms.Normalize([0.4914, 0.4822, 0.4465],\n", " [0.2023, 0.1994, 0.2010])])" ] }, { "cell_type": "markdown", "id": "e743a95e", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Read the organized dataset consisting of raw image files" ] }, { "cell_type": "code", "execution_count": 10, "id": "056ac33a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.722917Z", "iopub.status.busy": "2023-08-18T19:29:29.722506Z", "iopub.status.idle": "2023-08-18T19:29:29.733889Z", "shell.execute_reply": "2023-08-18T19:29:29.733119Z" }, "origin_pos": 23, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(\n", " os.path.join(data_dir, 'train_valid_test', folder),\n", " transform=transform_train) for folder in ['train', 'train_valid']]\n", "\n", "valid_ds, test_ds = [torchvision.datasets.ImageFolder(\n", " os.path.join(data_dir, 'train_valid_test', folder),\n", " transform=transform_test) for folder in ['valid', 'test']]" ] }, { "cell_type": "markdown", "id": "d0ab6608", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Specify all the image augmentation operations defined above" ] }, { "cell_type": "code", "execution_count": 11, "id": "06fa7207", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.738557Z", "iopub.status.busy": "2023-08-18T19:29:29.737952Z", "iopub.status.idle": "2023-08-18T19:29:29.743073Z", "shell.execute_reply": "2023-08-18T19:29:29.742323Z" }, "origin_pos": 26, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "train_iter, train_valid_iter = [torch.utils.data.DataLoader(\n", " dataset, batch_size, shuffle=True, drop_last=True)\n", " for dataset in (train_ds, train_valid_ds)]\n", "\n", "valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,\n", " drop_last=True)\n", "\n", "test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,\n", " drop_last=False)" ] }, { "cell_type": "markdown", "id": "1c95d19b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Model" ] }, { "cell_type": "code", "execution_count": 12, "id": "d527425d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.747678Z", "iopub.status.busy": "2023-08-18T19:29:29.747059Z", "iopub.status.idle": "2023-08-18T19:29:29.751129Z", "shell.execute_reply": "2023-08-18T19:29:29.750380Z" }, "origin_pos": 35, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def get_net():\n", " num_classes = 10\n", " net = d2l.resnet18(num_classes, 3)\n", " return net\n", "\n", "loss = nn.CrossEntropyLoss(reduction=\"none\")" ] }, { "cell_type": "markdown", "id": "f6c66fef", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Training Function" ] }, { "cell_type": "code", "execution_count": 13, "id": "bde40789", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:29:29.755665Z", "iopub.status.busy": "2023-08-18T19:29:29.755131Z", "iopub.status.idle": "2023-08-18T19:29:29.764392Z", "shell.execute_reply": "2023-08-18T19:29:29.763621Z" }, "origin_pos": 38, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n", " lr_decay):\n", " trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,\n", " 
"def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n", "          lr_decay):\n", "    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,\n", "                              weight_decay=wd)\n", "    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)\n", "    num_batches, timer = len(train_iter), d2l.Timer()\n", "    legend = ['train loss', 'train acc']\n", "    if valid_iter is not None:\n", "        legend.append('valid acc')\n", "    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n", "                            legend=legend)\n", "    net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n", "    for epoch in range(num_epochs):\n", "        net.train()\n", "        metric = d2l.Accumulator(3)\n", "        for i, (features, labels) in enumerate(train_iter):\n", "            timer.start()\n", "            l, acc = d2l.train_batch_ch13(net, features, labels,\n", "                                          loss, trainer, devices)\n", "            metric.add(l, acc, labels.shape[0])\n", "            timer.stop()\n", "            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n", "                animator.add(epoch + (i + 1) / num_batches,\n", "                             (metric[0] / metric[2], metric[1] / metric[2],\n", "                              None))\n", "        if valid_iter is not None:\n", "            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)\n", "            animator.add(epoch + 1, (None, None, valid_acc))\n", "        scheduler.step()\n", "    measures = (f'train loss {metric[0] / metric[2]:.3f}, '\n", "                f'train acc {metric[1] / metric[2]:.3f}')\n", "    if valid_iter is not None:\n", "        measures += f', valid acc {valid_acc:.3f}'\n", "    print(measures + f'\\n{metric[2] * num_epochs / timer.sum():.1f}'\n", "          f' examples/sec on {str(devices)}')" ] },
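{ "cell_type": "markdown", "id": "dd40adde", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "As an illustrative aside (with hypothetical values, not the settings used\n", "below), the `StepLR` schedule in `train` multiplies the learning rate by\n", "`lr_decay` once every `lr_period` epochs:" ] },
{ "cell_type": "code", "execution_count": null, "id": "dd40addf", "metadata": {}, "outputs": [], "source": [ "# Sketch with hypothetical values: lr starts at 0.1 and is multiplied by\n", "# gamma=0.9 every step_size=4 scheduler steps (one step per epoch).\n", "w = torch.zeros(1, requires_grad=True)\n", "opt = torch.optim.SGD([w], lr=0.1)\n", "sched = torch.optim.lr_scheduler.StepLR(opt, step_size=4, gamma=0.9)\n", "lrs = []\n", "for epoch in range(12):\n", "    lrs.append(opt.param_groups[0]['lr'])\n", "    opt.step()\n", "    sched.step()\n", "print(lrs)  # ~0.1 for epochs 0-3, ~0.09 for 4-7, ~0.081 for 8-11" ] },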
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4\n", "lr_period, lr_decay, net = 4, 0.9, get_net()\n", "net(next(iter(train_iter))[0])\n", "train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n", " lr_decay)" ] }, { "cell_type": "markdown", "id": "cd839c40", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Classifying the Testing Set" ] }, { "cell_type": "code", "execution_count": 15, "id": "a66ef205", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:30:37.501313Z", "iopub.status.busy": "2023-08-18T19:30:37.500748Z", "iopub.status.idle": "2023-08-18T19:31:40.934103Z", "shell.execute_reply": "2023-08-18T19:31:40.932837Z" }, "origin_pos": 44, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train loss 0.608, train acc 0.786\n", "1040.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:31:40.877905\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net, preds = get_net(), []\n", "net(next(iter(train_valid_iter))[0])\n", "train(net, train_valid_iter, None, num_epochs, lr, wd, devices, lr_period,\n", " lr_decay)\n", "\n", "for X, _ in test_iter:\n", " y_hat = net(X.to(devices[0]))\n", " preds.extend(y_hat.argmax(dim=1).type(torch.int32).cpu().numpy())\n", "sorted_ids = list(range(1, len(test_ds) + 1))\n", "sorted_ids.sort(key=lambda x: str(x))\n", "df = pd.DataFrame({'id': sorted_ids, 'label': preds})\n", "df['label'] = df['label'].apply(lambda x: train_valid_ds.classes[x])\n", "df.to_csv('submission.csv', index=False)" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }