{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"17_TL-ants-bees-classification.ipynb","provenance":[],"authorship_tag":"ABX9TyOGTId2qIBFmh8Io3/6dOFc"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"MrAbSEftR2qn"},"source":["## 15.7.2 实战项目: 蚂蚁和蜜蜂的分类问题\n","\n","今天我们要解决的问题是通过迁移学习训练一个模型来实现蚂蚁和蜜蜂的分类。如果从头开始训练的话,这是一个非常小的数据集,就算做了数据增强也难以达到很好的效果。因此我们引入迁移学习的方法,采用在 ImageNet 上训练过的 resnet18 作为我们的预训练模型。"]},{"cell_type":"markdown","metadata":{"id":"E0U1NNo_R6e1"},"source":["### 15.7.2.1 下载数据\n","\n","> - 推荐:https://download.csdn.net/download/qq_42951560/13201074\n","> - 备用:https://ghgxj.lanzous.com/i9EGHiv97za\n","\n","imagenet数据集三通道的均值和标准差分别是:$[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]$。\n","\n","该数据集是imagenet非常小的一个子集。只包含蚂蚁和蜜蜂两类。\n","\n","所以数据标准化Normalize的时候我们也继承使用imagenet的均值和标准差。\n","\n","\n","|种类\t|训练集\t|验证集|\n","|--|--|--|\n","|蚂蚁|\t123|\t70|\n","|蜜蜂|\t121|\t83|\n","|总计|\t244|\t153|"]},{"cell_type":"markdown","metadata":{"id":"QFTYiicCR9t3"},"source":["### 导入模块"]},{"cell_type":"code","metadata":{"id":"vCCr2lo2R1un","executionInfo":{"status":"ok","timestamp":1621673866162,"user_tz":-480,"elapsed":3234,"user":{"displayName":"Charmve","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gj9zT4vUzycNuqCVAvPYhbS-GWoSOslxHjcLtv4=s64","userId":"07530046818488914519"}}},"source":["import torch\n","import torch.nn as nn\n","import torch.optim as optim\n","from torch.utils.data import DataLoader\n","from torch.optim import lr_scheduler\n","import torchvision\n","from torchvision import datasets, models, transforms\n","import numpy as np\n","import matplotlib.pyplot as plt\n","import os\n","import time\n","import copy"],"execution_count":1,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"2Aenq3EISAow"},"source":["### 
数据增强"]},{"cell_type":"code","metadata":{"id":"vSi3FFw6Pacn","executionInfo":{"status":"ok","timestamp":1621673874696,"user_tz":-480,"elapsed":477,"user":{"displayName":"Charmve","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gj9zT4vUzycNuqCVAvPYhbS-GWoSOslxHjcLtv4=s64","userId":"07530046818488914519"}}},"source":["data_transforms = {\n"," 'train': transforms.Compose([\n"," transforms.RandomResizedCrop(224),\n"," transforms.RandomHorizontalFlip(),\n"," transforms.ToTensor(),\n"," transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n"," ]),\n"," 'val': transforms.Compose([\n"," transforms.Resize(256),\n"," transforms.CenterCrop(224),\n"," transforms.ToTensor(),\n"," transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n"," ])\n","}"],"execution_count":2,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ci1sxca5SDn_"},"source":["### 制作数据集"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":365},"id":"8k0D3kN5SEp3","executionInfo":{"status":"error","timestamp":1621673877692,"user_tz":-480,"elapsed":4,"user":{"displayName":"Charmve","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gj9zT4vUzycNuqCVAvPYhbS-GWoSOslxHjcLtv4=s64","userId":"07530046818488914519"}},"outputId":"ce51a3b0-6bc8-4baf-8c4f-30d20311bb90"},"source":["datasets_path = './dataset'\n","image_datasets = {\n"," x: datasets.ImageFolder(\n"," root=os.path.join('./dataset', x),\n"," transform=data_transforms[x]\n"," ) for x in ['train', 'val']\n","}"],"execution_count":3,"outputs":[{"output_type":"error","ename":"FileNotFoundError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m 
\u001b[0mroot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'./dataset'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata_transforms\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m ) for x in ['train', 'val']\n\u001b[0m\u001b[1;32m 6\u001b[0m }\n","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'./dataset'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata_transforms\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m ) for x in ['train', 'val']\n\u001b[0m\u001b[1;32m 6\u001b[0m }\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchvision/datasets/folder.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, transform, target_transform, loader, is_valid_file)\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 255\u001b[0m 
\u001b[0mtarget_transform\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtarget_transform\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 256\u001b[0;31m is_valid_file=is_valid_file)\n\u001b[0m\u001b[1;32m 257\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimgs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msamples\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchvision/datasets/folder.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root, loader, extensions, transform, target_transform, is_valid_file)\u001b[0m\n\u001b[1;32m 124\u001b[0m super(DatasetFolder, self).__init__(root, transform=transform,\n\u001b[1;32m 125\u001b[0m target_transform=target_transform)\n\u001b[0;32m--> 126\u001b[0;31m \u001b[0mclasses\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclass_to_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_find_classes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 127\u001b[0m \u001b[0msamples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclass_to_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mextensions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_valid_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msamples\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m 
\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torchvision/datasets/folder.py\u001b[0m in \u001b[0;36m_find_classes\u001b[0;34m(self, dir)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0mNo\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0ma\u001b[0m \u001b[0msubdirectory\u001b[0m \u001b[0mof\u001b[0m \u001b[0manother\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \"\"\"\n\u001b[0;32m--> 164\u001b[0;31m \u001b[0mclasses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0md\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscandir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdir\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_dir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 165\u001b[0m \u001b[0mclasses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msort\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0mclass_to_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mcls_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls_name\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclasses\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: './dataset/train'"]}]},{"cell_type":"markdown","metadata":{"id":"DSWDFzSoSFrn"},"source":["### 
数据加载器"]},{"cell_type":"code","metadata":{"id":"r-ujhIOaSFWI"},"source":["dataloaders = {\n","    x: DataLoader(\n","        dataset=image_datasets[x],\n","        batch_size=4,\n","        shuffle=True,\n","        num_workers=0\n","    ) for x in ['train', 'val']\n","}\n","\n","# Globals required by later cells: split sizes, class names and the compute device\n","dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n","class_names = image_datasets['train'].classes\n","device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"aJGrcHtNSE_g"},"source":["### 训练数据可视化"]},{"cell_type":"code","metadata":{"id":"l2fBfXOvSSY5"},"source":["inputs, labels = next(iter(dataloaders['train']))\n","grid_images = torchvision.utils.make_grid(inputs)\n","\n","def no_normalize(im):\n","    im = im.permute(1, 2, 0)\n","    im = im*torch.Tensor([0.229, 0.224, 0.225])+torch.Tensor([0.485, 0.456, 0.406])\n","    return im\n","\n","grid_images = no_normalize(grid_images)\n","plt.title([class_names[x] for x in labels])\n","plt.imshow(grid_images)\n","plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"WmvpjB-oSPpZ"},"source":["### 训练模型\n","之前提到过,迁移学习有两种常见的方法,我们就简单的称之为参数微调和特征提取吧。下面,我们将分别使用这两种方法来训练我们的模型,最后再进行对比分析。两种方法用同一个函数训练,只不过传的参数不同。公用的训练函数如下:"]},{"cell_type":"code","metadata":{"id":"oqo_d1RCSXoR"},"source":["def train_model(model, criterion, optimizer, scheduler, num_epochs=10):\n","    t1 = time.time()\n","\n","    best_model_wts = copy.deepcopy(model.state_dict())\n","    best_acc = 0.0\n","\n","    for epoch in range(num_epochs):\n","        lr = optimizer.param_groups[0]['lr']\n","        print(\n","            f'EPOCH: {epoch+1:0>{len(str(num_epochs))}}/{num_epochs}',\n","            f'LR: {lr:.4f}',\n","            end=' '\n","        )\n","        # 每轮都需要训练和评估\n","        for phase in ['train', 'val']:\n","            if phase == 'train':\n","                model.train() # 将模型设置为训练模式\n","            else:\n","                model.eval() # 将模型设置为评估模式\n","\n","            running_loss = 0.0\n","            running_corrects = 0\n","\n","            # 遍历数据\n","            for inputs, labels in dataloaders[phase]:\n","                inputs = inputs.to(device)\n","                labels = labels.to(device)\n","\n","                # 梯度归零\n","                optimizer.zero_grad()\n","\n","                # 前向传播\n","                with torch.set_grad_enabled(phase == 'train'):\n","                    outputs = model(inputs)\n","                    preds = 
outputs.argmax(1)\n"," loss = criterion(outputs, labels)\n","\n"," # 反向传播+参数更新\n"," if phase == 'train':\n"," loss.backward()\n"," optimizer.step()\n","\n"," # 统计\n"," running_loss += loss.item() * inputs.size(0)\n"," running_corrects += (preds == labels.data).sum()\n"," if phase == 'train':\n"," # 调整学习率\n"," scheduler.step()\n","\n"," epoch_loss = running_loss / dataset_sizes[phase]\n"," epoch_acc = running_corrects.double() / dataset_sizes[phase]\n","\n"," # 打印训练过程\n"," if phase == 'train':\n"," print(\n"," f'LOSS: {epoch_loss:.4f}',\n"," f'ACC: {epoch_acc:.4f} ',\n"," end=' '\n"," )\n"," else:\n"," print(\n"," f'VAL-LOSS: {epoch_loss:.4f}',\n"," f'VAL-ACC: {epoch_acc:.4f} ',\n"," end='\\n'\n"," )\n","\n"," # 深度拷贝模型参数\n"," if phase == 'val' and epoch_acc > best_acc:\n"," best_acc = epoch_acc\n"," best_model_wts = copy.deepcopy(model.state_dict())\n","\n"," t2 = time.time()\n"," total_time = t2-t1\n"," print('-'*10)\n"," print(\n"," f'TOTAL-TIME: {total_time//60:.0f}m{total_time%60:.0f}s',\n"," f'BEST-VAL-ACC: {best_acc:.4f}'\n"," )\n"," # 加载最佳的模型权重\n"," model.load_state_dict(best_model_wts)\n"," return model"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xGvs0fhASYLR"},"source":["\n","#### 参数微调的方法\n","\n","该方法使用预训练的参数来初始化我们的网络模型,修改全连接层后再训练所有层。"]},{"cell_type":"code","metadata":{"id":"1QAxXgMbSYbh"},"source":["# 加载预训练模型\n","model_ft = models.resnet18(pretrained=True)\n","\n","# 获取resnet18的全连接层的输入特征数\n","num_ftrs = model_ft.fc.in_features\n","\n","# 调整全连接层的输出特征数为2\n","model_ft.fc = nn.Linear(num_ftrs, len(class_names))\n","\n","# 将模型放到GPU/CPU\n","model_ft = model_ft.to(device)\n","\n","# 定义损失函数\n","criterion = nn.CrossEntropyLoss()\n","\n","# 选择优化器\n","optimizer_ft = optim.SGD(model_ft.parameters(), lr=1e-3, momentum=0.9)\n","\n","# 定义优化器器调整策略,每5轮后学习率下调0.1个乘法因子\n","exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.1)\n","\n","# 调用训练函数训练\n","model_ft = train_model(\n"," model_ft, \n"," criterion, \n"," 
optimizer_ft, \n","    exp_lr_scheduler,\n","    num_epochs=10\n",")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"k9vHTxsRSZF5"},"source":["#### 特征提取的方法\n","该方法冻结除全连接层外的所有层的权重,修改全连接层后仅训练全连接层。"]},{"cell_type":"code","metadata":{"id":"f0MwJuqoSZVp"},"source":["# 加载预训练模型\n","model_conv = models.resnet18(pretrained=True)\n","\n","# 冻结除全连接层外的所有层, 使其梯度不会在反向传播中计算\n","for param in model_conv.parameters():\n","    param.requires_grad = False\n","\n","# 获取resnet18的全连接层的输入特征数\n","num_ftrs = model_conv.fc.in_features\n","\n","# 调整全连接层的输出特征数为2\n","model_conv.fc = nn.Linear(num_ftrs, 2)\n","\n","# 将模型放到GPU/CPU\n","model_conv = model_conv.to(device)\n","\n","# 定义损失函数\n","criterion = nn.CrossEntropyLoss()\n","\n","# 选择优化器, 只传全连接层的参数\n","optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=1e-3, momentum=0.9)\n","\n","# 定义优化器器调整策略,每5轮后学习率下调0.1个乘法因子\n","exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=5, gamma=0.1)\n","\n","# 调用训练函数训练\n","model_conv = train_model(\n","    model_conv,\n","    criterion,\n","    optimizer_conv,\n","    exp_lr_scheduler,\n","    num_epochs=10\n",")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MMYvpHtcSjbZ"},"source":["### 验证结果可视化"]},{"cell_type":"code","metadata":{"id":"Rnnuu06xSjxh"},"source":["def visualize_model(model):\n","    model.eval()\n","    with torch.no_grad():\n","        inputs, labels = next(iter(dataloaders['val']))\n","        inputs = inputs.to(device)\n","        labels = labels.to(device)\n","\n","        outputs = model(inputs)\n","        preds = outputs.argmax(1)\n","\n","        plt.figure(figsize=(9, 9))\n","        for i in range(inputs.size(0)):\n","            plt.subplot(2,2,i+1)\n","            plt.axis('off')\n","            plt.title(f'pred: {class_names[preds[i]]}|true: {class_names[labels[i]]}')\n","            im = no_normalize(inputs[i].cpu())\n","            plt.imshow(im)\n","        plt.savefig('train.jpg')\n","        plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Mr7kRjjhSpOj"},"source":["### 
保存模型"]},{"cell_type":"code","metadata":{"id":"fg6-psq2Spnk"},"source":["torch.save(model_conv.state_dict(), 'model.pt')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"sDrq507wSsmD"},"source":["### 加载模型"]},{"cell_type":"code","metadata":{"id":"suxrsVVjSs9j"},"source":["device = torch.device('cpu')\n","model = models.resnet18(pretrained=False)\n","num_ftrs = model.fc.in_features\n","model.fc = nn.Linear(num_ftrs, len(class_names))\n","model.load_state_dict(torch.load('model.pt', map_location=device))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"CsdgN6NhSwX7"},"source":["### 测试模型\n","百度或必应图片中随便找几张张蚂蚁和蜜蜂的图片,或者用手机拍几张照片也行。用上一步加载的模型测试一下分类的效果。"]}]}