{
"cells": [
{
"cell_type": "markdown",
"id": "8c28a85d",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# 实战Kaggle比赛:狗的品种识别(ImageNet Dogs)\n",
"\n",
"比赛网址是https://www.kaggle.com/c/dog-breed-identification"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b6e1a2a2",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:14.555794Z",
"iopub.status.busy": "2023-08-18T06:58:14.555246Z",
"iopub.status.idle": "2023-08-18T06:58:16.563976Z",
"shell.execute_reply": "2023-08-18T06:58:16.563095Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torchvision\n",
"from torch import nn\n",
"from d2l import torch as d2l"
]
},
{
"cell_type": "markdown",
"id": "19c4ff5d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"我们提供完整数据集的小规模样本"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9ecb1309",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:16.567802Z",
"iopub.status.busy": "2023-08-18T06:58:16.567412Z",
"iopub.status.idle": "2023-08-18T06:58:17.348683Z",
"shell.execute_reply": "2023-08-18T06:58:17.347865Z"
},
"origin_pos": 5,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading ../data/kaggle_dog_tiny.zip from http://d2l-data.s3-accelerate.amazonaws.com/kaggle_dog_tiny.zip...\n"
]
}
],
"source": [
"d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip',\n",
" '0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')\n",
"\n",
"demo = True\n",
"if demo:\n",
" data_dir = d2l.download_extract('dog_tiny')\n",
"else:\n",
" data_dir = os.path.join('..', 'data', 'dog-breed-identification')"
]
},
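{
"cell_type": "markdown",
"id": "a1b2c3d4",
"metadata": {},
"source": [
"As a quick sanity check (an addition to the original notebook), the next cell lists the extracted files and counts the labeled images and breeds in `labels.csv` via the d2l helper `d2l.read_csv_labels`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4c3b2a1",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: inspect the extracted sample dataset\n",
"print(sorted(os.listdir(data_dir)))\n",
"labels = d2l.read_csv_labels(os.path.join(data_dir, 'labels.csv'))\n",
"print('# labeled training images:', len(labels))\n",
"print('# breeds:', len(set(labels.values())))"
]
},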
{
"cell_type": "markdown",
"id": "5dc5441a",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"整理数据集"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3b420853",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:17.352573Z",
"iopub.status.busy": "2023-08-18T06:58:17.352101Z",
"iopub.status.idle": "2023-08-18T06:58:17.685237Z",
"shell.execute_reply": "2023-08-18T06:58:17.683473Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def reorg_dog_data(data_dir, valid_ratio):\n",
" labels = d2l.read_csv_labels(os.path.join(data_dir, 'labels.csv'))\n",
" d2l.reorg_train_valid(data_dir, labels, valid_ratio)\n",
" d2l.reorg_test(data_dir)\n",
"\n",
"\n",
"batch_size = 32 if demo else 128\n",
"valid_ratio = 0.1\n",
"reorg_dog_data(data_dir, valid_ratio)"
]
},
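{
"cell_type": "markdown",
"id": "b2c3d4e5",
"metadata": {},
"source": [
"The folder layout produced by `reorg_dog_data` is what `ImageFolder` expects later: one subfolder per breed under each split. The following check is an addition; it only prints how many class folders each split contains."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5d4c3b2",
"metadata": {},
"outputs": [],
"source": [
"# Inspect the directory structure created by reorg_dog_data\n",
"reorg_dir = os.path.join(data_dir, 'train_valid_test')\n",
"for split in sorted(os.listdir(reorg_dir)):\n",
"    n_classes = len(os.listdir(os.path.join(reorg_dir, split)))\n",
"    print(f'{split}: {n_classes} class folders')"
]
},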
{
"cell_type": "markdown",
"id": "4b2671b6",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"图像增广"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0b467084",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:17.704258Z",
"iopub.status.busy": "2023-08-18T06:58:17.703547Z",
"iopub.status.idle": "2023-08-18T06:58:17.710398Z",
"shell.execute_reply": "2023-08-18T06:58:17.709360Z"
},
"origin_pos": 14,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"transform_train = torchvision.transforms.Compose([\n",
" torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0),\n",
" ratio=(3.0/4.0, 4.0/3.0)),\n",
" torchvision.transforms.RandomHorizontalFlip(),\n",
" torchvision.transforms.ColorJitter(brightness=0.4,\n",
" contrast=0.4,\n",
" saturation=0.4),\n",
" torchvision.transforms.ToTensor(),\n",
" torchvision.transforms.Normalize([0.485, 0.456, 0.406],\n",
" [0.229, 0.224, 0.225])])\n",
"\n",
"transform_test = torchvision.transforms.Compose([\n",
" torchvision.transforms.Resize(256),\n",
" torchvision.transforms.CenterCrop(224),\n",
" torchvision.transforms.ToTensor(),\n",
" torchvision.transforms.Normalize([0.485, 0.456, 0.406],\n",
" [0.229, 0.224, 0.225])])"
]
},
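{
"cell_type": "markdown",
"id": "c3d4e5f6",
"metadata": {},
"source": [
"To see what the training transform produces, the added sketch below applies `transform_train` to a synthetic PIL image (not a real photo from the dataset) and prints the resulting tensor shape and value range."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6e5d4c3",
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"\n",
"# Apply the training augmentation once to a synthetic RGB image\n",
"dummy = Image.new('RGB', (400, 300), color=(128, 64, 32))\n",
"out = transform_train(dummy)\n",
"print(out.shape, out.dtype)  # expect torch.Size([3, 224, 224]) torch.float32\n",
"print(float(out.min()), float(out.max()))"
]
},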
{
"cell_type": "markdown",
"id": "9a8f744d",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"读取数据集"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8ef84d02",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:17.756485Z",
"iopub.status.busy": "2023-08-18T06:58:17.755671Z",
"iopub.status.idle": "2023-08-18T06:58:17.764122Z",
"shell.execute_reply": "2023-08-18T06:58:17.762916Z"
},
"origin_pos": 22,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(\n",
" os.path.join(data_dir, 'train_valid_test', folder),\n",
" transform=transform_train) for folder in ['train', 'train_valid']]\n",
"\n",
"valid_ds, test_ds = [torchvision.datasets.ImageFolder(\n",
" os.path.join(data_dir, 'train_valid_test', folder),\n",
" transform=transform_test) for folder in ['valid', 'test']]\n",
"\n",
"train_iter, train_valid_iter = [torch.utils.data.DataLoader(\n",
" dataset, batch_size, shuffle=True, drop_last=True)\n",
" for dataset in (train_ds, train_valid_ds)]\n",
"\n",
"valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False,\n",
" drop_last=True)\n",
"\n",
"test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False,\n",
" drop_last=False)"
]
},
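{
"cell_type": "markdown",
"id": "d4e5f6a7",
"metadata": {},
"source": [
"Before training, it helps to pull one minibatch and confirm the tensor shapes and the number of classes; the following cell is an added check along those lines."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a7f6e5d4",
"metadata": {},
"outputs": [],
"source": [
"# Fetch one minibatch and check shapes and the class list\n",
"X, y = next(iter(train_iter))\n",
"print('features:', X.shape, 'labels:', y.shape)\n",
"print('# classes:', len(train_ds.classes))"
]
},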
{
"cell_type": "markdown",
"id": "b773e0c9",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"微调预训练模型"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1fd0cd74",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:17.769780Z",
"iopub.status.busy": "2023-08-18T06:58:17.768697Z",
"iopub.status.idle": "2023-08-18T06:58:17.777622Z",
"shell.execute_reply": "2023-08-18T06:58:17.776303Z"
},
"origin_pos": 26,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def get_net(devices):\n",
" finetune_net = nn.Sequential()\n",
" finetune_net.features = torchvision.models.resnet34(pretrained=True)\n",
" finetune_net.output_new = nn.Sequential(nn.Linear(1000, 256),\n",
" nn.ReLU(),\n",
" nn.Linear(256, 120))\n",
" finetune_net = finetune_net.to(devices[0])\n",
" for param in finetune_net.features.parameters():\n",
" param.requires_grad = False\n",
" return finetune_net"
]
},
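{
"cell_type": "markdown",
"id": "e5f6a7b8",
"metadata": {},
"source": [
"Since the pretrained ResNet-34 body is frozen, only the small output network is updated during training. The added sketch below builds the model with `d2l.try_all_gpus()` and counts trainable versus frozen parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8a7f6e5",
"metadata": {},
"outputs": [],
"source": [
"# Count trainable (new output network) vs. frozen (pretrained body) parameters\n",
"devices = d2l.try_all_gpus()\n",
"net = get_net(devices)\n",
"n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)\n",
"n_frozen = sum(p.numel() for p in net.parameters() if not p.requires_grad)\n",
"print(f'trainable: {n_trainable}, frozen: {n_frozen}')"
]
},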
{
"cell_type": "markdown",
"id": "8e5c7a5f",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"计算损失"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d6936a15",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:17.783286Z",
"iopub.status.busy": "2023-08-18T06:58:17.782296Z",
"iopub.status.idle": "2023-08-18T06:58:17.791061Z",
"shell.execute_reply": "2023-08-18T06:58:17.789830Z"
},
"origin_pos": 30,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"loss = nn.CrossEntropyLoss(reduction='none')\n",
"\n",
"def evaluate_loss(data_iter, net, devices):\n",
" l_sum, n = 0.0, 0\n",
" for features, labels in data_iter:\n",
" features, labels = features.to(devices[0]), labels.to(devices[0])\n",
" outputs = net(features)\n",
" l = loss(outputs, labels)\n",
" l_sum += l.sum()\n",
" n += labels.numel()\n",
" return (l_sum / n).to('cpu')"
]
},
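{
"cell_type": "markdown",
"id": "f6a7b8c9",
"metadata": {},
"source": [
"The evaluation loop above runs with gradient tracking enabled. The optional variant below (an addition, not the implementation used later) wraps the loop in `torch.no_grad()` to avoid building the autograd graph and to save memory during validation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9b8a7f6",
"metadata": {},
"outputs": [],
"source": [
"def evaluate_loss_no_grad(data_iter, net, devices):\n",
"    # Same average loss as evaluate_loss, but without autograd bookkeeping\n",
"    l_sum, n = 0.0, 0\n",
"    with torch.no_grad():\n",
"        for features, labels in data_iter:\n",
"            features, labels = features.to(devices[0]), labels.to(devices[0])\n",
"            l = loss(net(features), labels)\n",
"            l_sum += float(l.sum())\n",
"            n += labels.numel()\n",
"    return l_sum / n"
]
},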
{
"cell_type": "markdown",
"id": "e391dd8f",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"训练函数"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "4a196c68",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:17.796668Z",
"iopub.status.busy": "2023-08-18T06:58:17.795696Z",
"iopub.status.idle": "2023-08-18T06:58:17.813822Z",
"shell.execute_reply": "2023-08-18T06:58:17.812372Z"
},
"origin_pos": 34,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,\n",
" lr_decay):\n",
" net = nn.DataParallel(net, device_ids=devices).to(devices[0])\n",
" trainer = torch.optim.SGD((param for param in net.parameters()\n",
" if param.requires_grad), lr=lr,\n",
" momentum=0.9, weight_decay=wd)\n",
" scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)\n",
" num_batches, timer = len(train_iter), d2l.Timer()\n",
" legend = ['train loss']\n",
" if valid_iter is not None:\n",
" legend.append('valid loss')\n",
" animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n",
" legend=legend)\n",
" for epoch in range(num_epochs):\n",
" metric = d2l.Accumulator(2)\n",
" for i, (features, labels) in enumerate(train_iter):\n",
" timer.start()\n",
" features, labels = features.to(devices[0]), labels.to(devices[0])\n",
" trainer.zero_grad()\n",
" output = net(features)\n",
" l = loss(output, labels).sum()\n",
" l.backward()\n",
" trainer.step()\n",
" metric.add(l, labels.shape[0])\n",
" timer.stop()\n",
" if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:\n",
" animator.add(epoch + (i + 1) / num_batches,\n",
" (metric[0] / metric[1], None))\n",
" measures = f'train loss {metric[0] / metric[1]:.3f}'\n",
" if valid_iter is not None:\n",
" valid_loss = evaluate_loss(valid_iter, net, devices)\n",
" animator.add(epoch + 1, (None, valid_loss.detach().cpu()))\n",
" scheduler.step()\n",
" if valid_iter is not None:\n",
" measures += f', valid loss {valid_loss:.3f}'\n",
" print(measures + f'\\n{metric[1] * num_epochs / timer.sum():.1f}'\n",
" f' examples/sec on {str(devices)}')"
]
},
{
"cell_type": "markdown",
"id": "befc0634",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"训练和验证模型"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d407d036",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T06:58:17.819464Z",
"iopub.status.busy": "2023-08-18T06:58:17.818676Z",
"iopub.status.idle": "2023-08-18T07:00:28.078597Z",
"shell.execute_reply": "2023-08-18T07:00:28.077772Z"
},
"origin_pos": 38,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss 1.237, valid loss 1.503\n",
"442.6 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"