{
"cells": [
{
"cell_type": "markdown",
"id": "9476fc52",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Image Classification (CIFAR-10) on Kaggle\n",
"\n",
"The web address of the competition is https://www.kaggle.com/c/cifar-10"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e5a3fb64",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:26.260293Z",
"iopub.status.busy": "2023-08-18T19:29:26.259545Z",
"iopub.status.idle": "2023-08-18T19:29:29.089576Z",
"shell.execute_reply": "2023-08-18T19:29:29.088675Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import collections\n",
"import math\n",
"import os\n",
"import shutil\n",
"import pandas as pd\n",
"import torch\n",
"import torchvision\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "0241c1c6",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"We provide a small-scale sample of the dataset that\n",
"contains the first 1000 training images and 5 random testing images"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0d41dcd1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:29.095074Z",
"iopub.status.busy": "2023-08-18T19:29:29.094404Z",
"iopub.status.idle": "2023-08-18T19:29:29.393994Z",
"shell.execute_reply": "2023-08-18T19:29:29.393137Z"
},
"origin_pos": 4,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading ../data/kaggle_cifar10_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_cifar10_tiny.zip...\n"
]
}
],
"source": [
"# Register the tiny CIFAR-10 sample with d2l's data hub as a pair: the archive\n",
"# URL and a hex digest string (presumably the archive's checksum -- confirm\n",
"# against d2l.download_extract).\n",
"d2l.DATA_HUB['cifar10_tiny'] = (\n",
"    d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',\n",
"    '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd',\n",
")\n",
"\n",
"# True: download and use the tiny sample. False: read from ../data/cifar-10/.\n",
"demo = True\n",
"\n",
"data_dir = (d2l.download_extract('cifar10_tiny') if demo\n",
"            else '../data/cifar-10/')"
]
},
{
"cell_type": "markdown",
"id": "863e1e66",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Organizing the Dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "04bf8387",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:29.399003Z",
"iopub.status.busy": "2023-08-18T19:29:29.398718Z",
"iopub.status.idle": "2023-08-18T19:29:29.406335Z",
"shell.execute_reply": "2023-08-18T19:29:29.405552Z"
},
"origin_pos": 6,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# training examples: 1000\n",
"# classes: 10\n"
]
}
],
"source": [
"def read_csv_labels(fname):\n",
"    \"\"\"Read `fname` to return a filename to label dictionary.\n",
"\n",
"    The first line of the file is treated as a header and skipped. Every\n",
"    remaining non-blank line must hold exactly two comma-separated fields,\n",
"    `name,label`. Blank lines (for example a trailing empty line at the end\n",
"    of the file) are ignored instead of aborting the whole read.\n",
"\n",
"    Args:\n",
"        fname: Path of the CSV file to read.\n",
"\n",
"    Returns:\n",
"        A dict mapping each name to its label, both as strings.\n",
"\n",
"    Raises:\n",
"        ValueError: If a non-blank data line does not split into exactly\n",
"            two fields.\n",
"    \"\"\"\n",
"    with open(fname, 'r') as f:\n",
"        next(f, None)  # skip the header row; an empty file yields {}\n",
"        rows = (line.rstrip() for line in f)\n",
"        # Blank rows would otherwise fail the two-field requirement of dict().\n",
"        return dict(row.split(',') for row in rows if row)\n",
"\n",
"# Load the labels of the training images, then report how many examples and\n",
"# how many distinct classes they contain.\n",
"labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))\n",
"print('# training examples:', len(labels))\n",
"print('# classes:', len(set(labels.values())))"
]
},
{
"cell_type": "markdown",
"id": "a42d795c",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Split the validation set out of the original training set"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0ae3357e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:29.411145Z",
"iopub.status.busy": "2023-08-18T19:29:29.410869Z",
"iopub.status.idle": "2023-08-18T19:29:29.418258Z",
"shell.execute_reply": "2023-08-18T19:29:29.417439Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def copyfile(filename, target_dir):\n",
"    \"\"\"Copy `filename` into `target_dir`.\n",
"\n",
"    The target directory, including any missing parents, is created first;\n",
"    an already existing directory is not an error.\n",
"    \"\"\"\n",
"    os.makedirs(target_dir, exist_ok=True)\n",
"    shutil.copy(src=filename, dst=target_dir)\n",
"\n",
"def reorg_train_valid(data_dir, labels, valid_ratio):\n",
"    \"\"\"Split the validation set out of the original training set.\n",
"\n",
"    Every file in `data_dir/train` is copied to\n",
"    `data_dir/train_valid_test/train_valid/LABEL` and, in addition, to either\n",
"    `valid/LABEL` or `train/LABEL` under the same root. Files are visited in\n",
"    sorted name order; for each label the first `n_valid_per_label` files go\n",
"    to `valid` and all later ones go to `train`.\n",
"\n",
"    Args:\n",
"        data_dir: Dataset root that contains the `train` directory.\n",
"        labels: Dict mapping the part of each file name before its first '.'\n",
"            to the label of that file.\n",
"        valid_ratio: Fraction of the least frequent class's example count to\n",
"            reserve for validation in every class.\n",
"\n",
"    Returns:\n",
"        The number of validation examples reserved per label (at least 1).\n",
"\n",
"    Raises:\n",
"        IndexError: If `labels` is empty.\n",
"        KeyError: If a file in `data_dir/train` has no entry in `labels`.\n",
"    \"\"\"\n",
"    # Example count of the least frequent class.\n",
"    n = collections.Counter(labels.values()).most_common()[-1][1]\n",
"    n_valid_per_label = max(1, math.floor(n * valid_ratio))\n",
"    train_dir = os.path.join(data_dir, 'train')\n",
"    out_root = os.path.join(data_dir, 'train_valid_test')\n",
"    label_count = collections.Counter()\n",
"    # os.listdir returns entries in arbitrary order; sorting makes the\n",
"    # train/valid split identical on every run and on every machine.\n",
"    for train_file in sorted(os.listdir(train_dir)):\n",
"        label = labels[train_file.split('.')[0]]\n",
"        fname = os.path.join(train_dir, train_file)\n",
"        copyfile(fname, os.path.join(out_root, 'train_valid', label))\n",
"        if label_count[label] < n_valid_per_label:\n",
"            label_count[label] += 1\n",
"            subset = 'valid'\n",
"        else:\n",
"            subset = 'train'\n",
"        copyfile(fname, os.path.join(out_root, subset, label))\n",
"    return n_valid_per_label"
]
},
{
"cell_type": "markdown",
"id": "ab600f93",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Organize the testing set for data loading during prediction"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "890972a8",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:29.422565Z",
"iopub.status.busy": "2023-08-18T19:29:29.422289Z",
"iopub.status.idle": "2023-08-18T19:29:29.426856Z",
"shell.execute_reply": "2023-08-18T19:29:29.426083Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def reorg_test(data_dir):\n",
"    \"\"\"Organize the testing set for data loading during prediction.\n",
"\n",
"    Copies every entry of `data_dir/test` into the single folder\n",
"    `data_dir/train_valid_test/test/unknown`.\n",
"    \"\"\"\n",
"    src_dir = os.path.join(data_dir, 'test')\n",
"    dst_dir = os.path.join(data_dir, 'train_valid_test', 'test', 'unknown')\n",
"    for test_file in os.listdir(src_dir):\n",
"        copyfile(os.path.join(src_dir, test_file), dst_dir)"
]
},
{
"cell_type": "markdown",
"id": "5c5a3b9e",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Invoke\n",
"functions defined above"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "1daf58c4",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:29.439643Z",
"iopub.status.busy": "2023-08-18T19:29:29.438882Z",
"iopub.status.idle": "2023-08-18T19:29:29.700309Z",
"shell.execute_reply": "2023-08-18T19:29:29.699321Z"
},
"origin_pos": 14,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def reorg_cifar10_data(data_dir, valid_ratio):\n",
"    \"\"\"Reorganize the raw data under `data_dir` for training and prediction.\n",
"\n",
"    Reads `trainLabels.csv`, splits the training images with\n",
"    `reorg_train_valid` using `valid_ratio`, then arranges the test images\n",
"    with `reorg_test`.\n",
"    \"\"\"\n",
"    label_file = os.path.join(data_dir, 'trainLabels.csv')\n",
"    reorg_train_valid(data_dir, read_csv_labels(label_file), valid_ratio)\n",
"    reorg_test(data_dir)\n",
"\n",
"# The tiny demo sample uses a smaller batch size than the full dataset.\n",
"batch_size, valid_ratio = (32 if demo else 128), 0.1\n",
"reorg_cifar10_data(data_dir, valid_ratio=valid_ratio)"
]
},
{
"cell_type": "markdown",
"id": "c8ef27b5",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Image Augmentation"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "be0d5428",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:29.714890Z",
"iopub.status.busy": "2023-08-18T19:29:29.714292Z",
"iopub.status.idle": "2023-08-18T19:29:29.718602Z",
"shell.execute_reply": "2023-08-18T19:29:29.717807Z"
},
"origin_pos": 20,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# Mean and standard deviation handed to Normalize for the three image\n",
"# channels; the training and evaluation pipelines use the same values.\n",
"_channel_mean = [0.4914, 0.4822, 0.4465]\n",
"_channel_std = [0.2023, 0.1994, 0.2010]\n",
"\n",
"# Training pipeline: random augmentation followed by tensor conversion and\n",
"# normalization.\n",
"transform_train = torchvision.transforms.Compose([\n",
"    # Resize to 40 pixels before cropping.\n",
"    torchvision.transforms.Resize(40),\n",
"    # Crop a random square (aspect ratio fixed at 1) covering 0.64 to 1.0 of\n",
"    # the area and rescale it to 32 x 32 pixels.\n",
"    torchvision.transforms.RandomResizedCrop(\n",
"        32, scale=(0.64, 1.0), ratio=(1.0, 1.0)),\n",
"    torchvision.transforms.RandomHorizontalFlip(),\n",
"    torchvision.transforms.ToTensor(),\n",
"    torchvision.transforms.Normalize(_channel_mean, _channel_std)])\n",
"\n",
"# Evaluation pipeline: no random operations, only tensor conversion and the\n",
"# same normalization.\n",
"transform_test = torchvision.transforms.Compose([\n",
"    torchvision.transforms.ToTensor(),\n",
"    torchvision.transforms.Normalize(_channel_mean, _channel_std)])"
]
},
{
"cell_type": "markdown",
"id": "e743a95e",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Read the organized dataset consisting of raw image files"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "056ac33a",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:29.722917Z",
"iopub.status.busy": "2023-08-18T19:29:29.722506Z",
"iopub.status.idle": "2023-08-18T19:29:29.733889Z",
"shell.execute_reply": "2023-08-18T19:29:29.733119Z"
},
"origin_pos": 23,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# Datasets used for fitting get the augmenting training transform.\n",
"train_ds = torchvision.datasets.ImageFolder(\n",
"    os.path.join(data_dir, 'train_valid_test', 'train'),\n",
"    transform=transform_train)\n",
"train_valid_ds = torchvision.datasets.ImageFolder(\n",
"    os.path.join(data_dir, 'train_valid_test', 'train_valid'),\n",
"    transform=transform_train)\n",
"\n",
"# Datasets used for evaluation and prediction get the test transform.\n",
"valid_ds = torchvision.datasets.ImageFolder(\n",
"    os.path.join(data_dir, 'train_valid_test', 'valid'),\n",
"    transform=transform_test)\n",
"test_ds = torchvision.datasets.ImageFolder(\n",
"    os.path.join(data_dir, 'train_valid_test', 'test'),\n",
"    transform=transform_test)"
]
},
{
"cell_type": "markdown",
"id": "d0ab6608",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Create the data iterators: training uses all the image augmentation operations defined above, while validation and testing use only the test-time transforms"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "06fa7207",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:29.738557Z",
"iopub.status.busy": "2023-08-18T19:29:29.737952Z",
"iopub.status.idle": "2023-08-18T19:29:29.743073Z",
"shell.execute_reply": "2023-08-18T19:29:29.742323Z"
},
"origin_pos": 26,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"# Training loaders reshuffle the data and drop the final incomplete batch.\n",
"train_iter, train_valid_iter = [torch.utils.data.DataLoader(\n",
"    dataset, batch_size, shuffle=True, drop_last=True)\n",
"    for dataset in (train_ds, train_valid_ds)]\n",
"\n",
"# Validation must score every example. With drop_last=True the final partial\n",
"# batch is silently discarded, and a validation set smaller than batch_size\n",
"# produces no batches at all, so the last batch is kept here.\n",
"valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,\n",
"                                         drop_last=False)\n",
"\n",
"# Prediction keeps the file order and every test example.\n",
"test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,\n",
"                                        drop_last=False)"
]
},
{
"cell_type": "markdown",
"id": "1c95d19b",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Model"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "d527425d",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:29.747678Z",
"iopub.status.busy": "2023-08-18T19:29:29.747059Z",
"iopub.status.idle": "2023-08-18T19:29:29.751129Z",
"shell.execute_reply": "2023-08-18T19:29:29.750380Z"
},
"origin_pos": 35,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def get_net():\n",
"    \"\"\"Return a new d2l ResNet-18 built for 10 output classes.\n",
"\n",
"    The second argument, 3, is presumably the number of input channels\n",
"    (RGB) -- confirm against d2l.resnet18.\n",
"    \"\"\"\n",
"    return d2l.resnet18(10, 3)\n",
"\n",
"# reduction='none' yields one loss value per example rather than a batch mean.\n",
"loss = nn.CrossEntropyLoss(reduction='none')"
]
},
{
"cell_type": "markdown",
"id": "f6c66fef",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Training Function"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "bde40789",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:29.755665Z",
"iopub.status.busy": "2023-08-18T19:29:29.755131Z",
"iopub.status.idle": "2023-08-18T19:29:29.764392Z",
"shell.execute_reply": "2023-08-18T19:29:29.763621Z"
},
"origin_pos": 38,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n",
"          lr_decay):\n",
"    \"\"\"Train `net` with SGD and a step learning-rate schedule, plotting progress.\n",
"\n",
"    Uses the module-level `loss` as the training criterion. After the last\n",
"    epoch it prints the final training loss and accuracy, the validation\n",
"    accuracy when `valid_iter` is given, and the training throughput.\n",
"\n",
"    Args:\n",
"        net: Model to train; it is wrapped in `nn.DataParallel` and moved to\n",
"            `devices[0]`.\n",
"        train_iter: Iterable of `(features, labels)` batches supporting `len()`.\n",
"        valid_iter: Validation data iterator, or None to skip validation.\n",
"        num_epochs: Number of passes over `train_iter`.\n",
"        lr: Initial learning rate for SGD.\n",
"        wd: Weight decay for SGD.\n",
"        devices: Devices used for data-parallel training.\n",
"        lr_period: Step size of the StepLR scheduler, in epochs.\n",
"        lr_decay: Multiplicative factor of the StepLR scheduler.\n",
"    \"\"\"\n",
"    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,\n",
"                              weight_decay=wd)\n",
"    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)\n",
"    num_batches, timer = len(train_iter), d2l.Timer()\n",
"    # Plot roughly five points per epoch. The interval is clamped to at least\n",
"    # one batch: with fewer than five batches `num_batches // 5` is 0 and the\n",
"    # modulo below would raise ZeroDivisionError.\n",
"    plot_every = max(1, num_batches // 5)\n",
"    legend = ['train loss', 'train acc']\n",
"    if valid_iter is not None:\n",
"        legend.append('valid acc')\n",
"    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n",
"                            legend=legend)\n",
"    net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n",
"    for epoch in range(num_epochs):\n",
"        net.train()\n",
"        # Accumulates, per epoch: the loss value and the accuracy value\n",
"        # returned for each batch, and the number of examples seen.\n",
"        metric = d2l.Accumulator(3)\n",
"        for i, (features, labels) in enumerate(train_iter):\n",
"            timer.start()\n",
"            l, acc = d2l.train_batch_ch13(net, features, labels,\n",
"                                          loss, trainer, devices)\n",
"            metric.add(l, acc, labels.shape[0])\n",
"            timer.stop()\n",
"            if (i + 1) % plot_every == 0 or i == num_batches - 1:\n",
"                animator.add(epoch + (i + 1) / num_batches,\n",
"                             (metric[0] / metric[2], metric[1] / metric[2],\n",
"                              None))\n",
"        if valid_iter is not None:\n",
"            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)\n",
"            animator.add(epoch + 1, (None, None, valid_acc))\n",
"        # Advance the learning-rate schedule once per epoch.\n",
"        scheduler.step()\n",
"    measures = (f'train loss {metric[0] / metric[2]:.3f}, '\n",
"                f'train acc {metric[1] / metric[2]:.3f}')\n",
"    if valid_iter is not None:\n",
"        measures += f', valid acc {valid_acc:.3f}'\n",
"    print(measures + f'\\n{metric[2] * num_epochs / timer.sum():.1f}'\n",
"          f' examples/sec on {str(devices)}')"
]
},
{
"cell_type": "markdown",
"id": "c12f609a",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Training and Validating the Model"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "cd4a55c7",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:29:29.768734Z",
"iopub.status.busy": "2023-08-18T19:29:29.768227Z",
"iopub.status.idle": "2023-08-18T19:30:37.496878Z",
"shell.execute_reply": "2023-08-18T19:30:37.495860Z"
},
"origin_pos": 41,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss 0.654, train acc 0.789, valid acc 0.438\n",
"958.1 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"