{
"cells": [
{
"cell_type": "markdown",
"id": "e10fb2c2",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# 锚框\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e079962e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:42.677769Z",
"iopub.status.busy": "2023-08-18T07:00:42.676695Z",
"iopub.status.idle": "2023-08-18T07:00:45.106116Z",
"shell.execute_reply": "2023-08-18T07:00:45.104773Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import torch\n",
"from d2l import torch as d2l\n",
"\n",
"torch.set_printoptions(2)"
]
},
{
"cell_type": "markdown",
"id": "b96c5129",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"锚框的宽度和高度分别是$hs\\sqrt{r}$和$hs/\\sqrt{r}$。\n",
"我们只考虑\n",
"组合:\n",
"\n",
"$$(s_1, r_1), (s_1, r_2), \\ldots, (s_1, r_m), (s_2, r_1), (s_3, r_1), \\ldots, (s_n, r_1)$$"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4c5fb635",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:45.112186Z",
"iopub.status.busy": "2023-08-18T07:00:45.111657Z",
"iopub.status.idle": "2023-08-18T07:00:45.126939Z",
"shell.execute_reply": "2023-08-18T07:00:45.125859Z"
},
"origin_pos": 6,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def multibox_prior(data, sizes, ratios):\n",
" \"\"\"生成以每个像素为中心具有不同形状的锚框\"\"\"\n",
" in_height, in_width = data.shape[-2:]\n",
" device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)\n",
" boxes_per_pixel = (num_sizes + num_ratios - 1)\n",
" size_tensor = torch.tensor(sizes, device=device)\n",
" ratio_tensor = torch.tensor(ratios, device=device)\n",
"\n",
" offset_h, offset_w = 0.5, 0.5\n",
" steps_h = 1.0 / in_height\n",
" steps_w = 1.0 / in_width\n",
"\n",
" center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h\n",
" center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w\n",
" shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')\n",
" shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)\n",
"\n",
" w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),\n",
" sizes[0] * torch.sqrt(ratio_tensor[1:])))\\\n",
" * in_height / in_width\n",
" h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),\n",
" sizes[0] / torch.sqrt(ratio_tensor[1:])))\n",
" anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(\n",
" in_height * in_width, 1) / 2\n",
"\n",
" out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],\n",
" dim=1).repeat_interleave(boxes_per_pixel, dim=0)\n",
" output = out_grid + anchor_manipulations\n",
" return output.unsqueeze(0)"
]
},
{
"cell_type": "markdown",
"id": "27883d90",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"返回的锚框变量`Y`的形状"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f411d4af",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:45.131714Z",
"iopub.status.busy": "2023-08-18T07:00:45.131003Z",
"iopub.status.idle": "2023-08-18T07:00:45.238891Z",
"shell.execute_reply": "2023-08-18T07:00:45.237843Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"561 728\n"
]
},
{
"data": {
"text/plain": [
"torch.Size([1, 2042040, 4])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"img = d2l.plt.imread('../img/catdog.jpg')\n",
"h, w = img.shape[:2]\n",
"\n",
"print(h, w)\n",
"X = torch.rand(size=(1, 3, h, w))\n",
"Y = multibox_prior(X, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5])\n",
"Y.shape"
]
},
{
"cell_type": "markdown",
"id": "3ebaca32",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"访问以(250,250)为中心的第一个锚框"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a7b7cfa3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:45.244522Z",
"iopub.status.busy": "2023-08-18T07:00:45.243982Z",
"iopub.status.idle": "2023-08-18T07:00:45.252916Z",
"shell.execute_reply": "2023-08-18T07:00:45.251985Z"
},
"origin_pos": 13,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.06, 0.07, 0.63, 0.82])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"boxes = Y.reshape(h, w, 5, 4)\n",
"boxes[250, 250, 0, :]"
]
},
{
"cell_type": "markdown",
"id": "afa57c6b",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"显示以图像中以某个像素为中心的所有锚框"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c199e557",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:45.272415Z",
"iopub.status.busy": "2023-08-18T07:00:45.271753Z",
"iopub.status.idle": "2023-08-18T07:00:45.634073Z",
"shell.execute_reply": "2023-08-18T07:00:45.632866Z"
},
"origin_pos": 18,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def show_bboxes(axes, bboxes, labels=None, colors=None):\n",
" \"\"\"显示所有边界框\"\"\"\n",
" def _make_list(obj, default_values=None):\n",
" if obj is None:\n",
" obj = default_values\n",
" elif not isinstance(obj, (list, tuple)):\n",
" obj = [obj]\n",
" return obj\n",
"\n",
" labels = _make_list(labels)\n",
" colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])\n",
" for i, bbox in enumerate(bboxes):\n",
" color = colors[i % len(colors)]\n",
" rect = d2l.bbox_to_rect(bbox.detach().numpy(), color)\n",
" axes.add_patch(rect)\n",
" if labels and len(labels) > i:\n",
" text_color = 'k' if color == 'w' else 'w'\n",
" axes.text(rect.xy[0], rect.xy[1], labels[i],\n",
" va='center', ha='center', fontsize=9, color=text_color,\n",
" bbox=dict(facecolor=color, lw=0))\n",
"\n",
"d2l.set_figsize()\n",
"bbox_scale = torch.tensor((w, h, w, h))\n",
"fig = d2l.plt.imshow(img)\n",
"show_bboxes(fig.axes, boxes[250, 250, :, :] * bbox_scale,\n",
" ['s=0.75, r=1', 's=0.5, r=1', 's=0.25, r=1', 's=0.75, r=2',\n",
" 's=0.75, r=0.5'])"
]
},
{
"cell_type": "markdown",
"id": "3fbf419b",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"交并比(IoU)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "feab924c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:45.639984Z",
"iopub.status.busy": "2023-08-18T07:00:45.639243Z",
"iopub.status.idle": "2023-08-18T07:00:45.649165Z",
"shell.execute_reply": "2023-08-18T07:00:45.648025Z"
},
"origin_pos": 21,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def box_iou(boxes1, boxes2):\n",
" \"\"\"计算两个锚框或边界框列表中成对的交并比\"\"\"\n",
" box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *\n",
" (boxes[:, 3] - boxes[:, 1]))\n",
" areas1 = box_area(boxes1)\n",
" areas2 = box_area(boxes2)\n",
" inter_upperlefts = torch.max(boxes1[:, None, :2], boxes2[:, :2])\n",
" inter_lowerrights = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])\n",
" inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)\n",
" inter_areas = inters[:, :, 0] * inters[:, :, 1]\n",
" union_areas = areas1[:, None] + areas2 - inter_areas\n",
" return inter_areas / union_areas"
]
},
{
"cell_type": "markdown",
"id": "e5e46ca6",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"将真实边界框分配给锚框"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8ac4113a",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:45.654100Z",
"iopub.status.busy": "2023-08-18T07:00:45.653303Z",
"iopub.status.idle": "2023-08-18T07:00:45.664045Z",
"shell.execute_reply": "2023-08-18T07:00:45.663020Z"
},
"origin_pos": 25,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def assign_anchor_to_bbox(ground_truth, anchors, device, iou_threshold=0.5):\n",
" \"\"\"将最接近的真实边界框分配给锚框\"\"\"\n",
" num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]\n",
" jaccard = box_iou(anchors, ground_truth)\n",
" anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long,\n",
" device=device)\n",
" max_ious, indices = torch.max(jaccard, dim=1)\n",
" anc_i = torch.nonzero(max_ious >= iou_threshold).reshape(-1)\n",
" box_j = indices[max_ious >= iou_threshold]\n",
" anchors_bbox_map[anc_i] = box_j\n",
" col_discard = torch.full((num_anchors,), -1)\n",
" row_discard = torch.full((num_gt_boxes,), -1)\n",
" for _ in range(num_gt_boxes):\n",
" max_idx = torch.argmax(jaccard)\n",
" box_idx = (max_idx % num_gt_boxes).long()\n",
" anc_idx = (max_idx / num_gt_boxes).long()\n",
" anchors_bbox_map[anc_idx] = box_idx\n",
" jaccard[:, box_idx] = col_discard\n",
" jaccard[anc_idx, :] = row_discard\n",
" return anchors_bbox_map"
]
},
{
"cell_type": "markdown",
"id": "592381a5",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"给定框$A$和$B$,中心坐标分别为$(x_a, y_a)$和$(x_b, y_b)$,宽度分别为$w_a$和$w_b$,高度分别为$h_a$和$h_b$,可以将$A$的偏移量标记为:\n",
"\n",
"$$\\left( \\frac{ \\frac{x_b - x_a}{w_a} - \\mu_x }{\\sigma_x},\n",
"\\frac{ \\frac{y_b - y_a}{h_a} - \\mu_y }{\\sigma_y},\n",
"\\frac{ \\log \\frac{w_b}{w_a} - \\mu_w }{\\sigma_w},\n",
"\\frac{ \\log \\frac{h_b}{h_a} - \\mu_h }{\\sigma_h}\\right)$$"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a5b6bbdc",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:45.668909Z",
"iopub.status.busy": "2023-08-18T07:00:45.668107Z",
"iopub.status.idle": "2023-08-18T07:00:45.675569Z",
"shell.execute_reply": "2023-08-18T07:00:45.674509Z"
},
"origin_pos": 28,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def offset_boxes(anchors, assigned_bb, eps=1e-6):\n",
" \"\"\"对锚框偏移量的转换\"\"\"\n",
" c_anc = d2l.box_corner_to_center(anchors)\n",
" c_assigned_bb = d2l.box_corner_to_center(assigned_bb)\n",
" offset_xy = 10 * (c_assigned_bb[:, :2] - c_anc[:, :2]) / c_anc[:, 2:]\n",
" offset_wh = 5 * torch.log(eps + c_assigned_bb[:, 2:] / c_anc[:, 2:])\n",
" offset = torch.cat([offset_xy, offset_wh], axis=1)\n",
" return offset"
]
},
{
"cell_type": "markdown",
"id": "94c6aea8",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"标记锚框的类别和偏移量"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "291738a2",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:45.680339Z",
"iopub.status.busy": "2023-08-18T07:00:45.679651Z",
"iopub.status.idle": "2023-08-18T07:00:45.692690Z",
"shell.execute_reply": "2023-08-18T07:00:45.691647Z"
},
"origin_pos": 31,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def multibox_target(anchors, labels):\n",
" \"\"\"使用真实边界框标记锚框\"\"\"\n",
" batch_size, anchors = labels.shape[0], anchors.squeeze(0)\n",
" batch_offset, batch_mask, batch_class_labels = [], [], []\n",
" device, num_anchors = anchors.device, anchors.shape[0]\n",
" for i in range(batch_size):\n",
" label = labels[i, :, :]\n",
" anchors_bbox_map = assign_anchor_to_bbox(\n",
" label[:, 1:], anchors, device)\n",
" bbox_mask = ((anchors_bbox_map >= 0).float().unsqueeze(-1)).repeat(\n",
" 1, 4)\n",
" class_labels = torch.zeros(num_anchors, dtype=torch.long,\n",
" device=device)\n",
" assigned_bb = torch.zeros((num_anchors, 4), dtype=torch.float32,\n",
" device=device)\n",
" indices_true = torch.nonzero(anchors_bbox_map >= 0)\n",
" bb_idx = anchors_bbox_map[indices_true]\n",
" class_labels[indices_true] = label[bb_idx, 0].long() + 1\n",
" assigned_bb[indices_true] = label[bb_idx, 1:]\n",
" offset = offset_boxes(anchors, assigned_bb) * bbox_mask\n",
" batch_offset.append(offset.reshape(-1))\n",
" batch_mask.append(bbox_mask.reshape(-1))\n",
" batch_class_labels.append(class_labels)\n",
" bbox_offset = torch.stack(batch_offset)\n",
" bbox_mask = torch.stack(batch_mask)\n",
" class_labels = torch.stack(batch_class_labels)\n",
" return (bbox_offset, bbox_mask, class_labels)"
]
},
{
"cell_type": "markdown",
"id": "524b87ff",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"在图像中绘制这些真实边界框和锚框"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "e22f46a6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:45.697286Z",
"iopub.status.busy": "2023-08-18T07:00:45.696846Z",
"iopub.status.idle": "2023-08-18T07:00:46.054321Z",
"shell.execute_reply": "2023-08-18T07:00:46.053210Z"
},
"origin_pos": 34,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"ground_truth = torch.tensor([[0, 0.1, 0.08, 0.52, 0.92],\n",
" [1, 0.55, 0.2, 0.9, 0.88]])\n",
"anchors = torch.tensor([[0, 0.1, 0.2, 0.3], [0.15, 0.2, 0.4, 0.4],\n",
" [0.63, 0.05, 0.88, 0.98], [0.66, 0.45, 0.8, 0.8],\n",
" [0.57, 0.3, 0.92, 0.9]])\n",
"\n",
"fig = d2l.plt.imshow(img)\n",
"show_bboxes(fig.axes, ground_truth[:, 1:] * bbox_scale, ['dog', 'cat'], 'k')\n",
"show_bboxes(fig.axes, anchors * bbox_scale, ['0', '1', '2', '3', '4']);"
]
},
{
"cell_type": "markdown",
"id": "f0319b92",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"根据狗和猫的真实边界框,标注这些锚框的分类和偏移量"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "78a2e091",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:46.075902Z",
"iopub.status.busy": "2023-08-18T07:00:46.075000Z",
"iopub.status.idle": "2023-08-18T07:00:46.084256Z",
"shell.execute_reply": "2023-08-18T07:00:46.083114Z"
},
"origin_pos": 40,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 1, 2, 0, 2]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels = multibox_target(anchors.unsqueeze(dim=0),\n",
" ground_truth.unsqueeze(dim=0))\n",
"\n",
"labels[2]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "94a1597f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:46.097601Z",
"iopub.status.busy": "2023-08-18T07:00:46.097228Z",
"iopub.status.idle": "2023-08-18T07:00:46.104883Z",
"shell.execute_reply": "2023-08-18T07:00:46.103744Z"
},
"origin_pos": 42,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1.,\n",
" 1., 1.]])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels[1]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "25e7f69b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:46.112301Z",
"iopub.status.busy": "2023-08-18T07:00:46.111934Z",
"iopub.status.idle": "2023-08-18T07:00:46.118802Z",
"shell.execute_reply": "2023-08-18T07:00:46.117910Z"
},
"origin_pos": 44,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, 1.40e+00, 1.00e+01,\n",
" 2.59e+00, 7.18e+00, -1.20e+00, 2.69e-01, 1.68e+00, -1.57e+00,\n",
" -0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, -5.71e-01, -1.00e+00,\n",
" 4.17e-06, 6.26e-01]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels[0]"
]
},
{
"cell_type": "markdown",
"id": "bff0477c",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"应用逆偏移变换来返回预测的边界框坐标"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "227ae1a8",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:46.127370Z",
"iopub.status.busy": "2023-08-18T07:00:46.127014Z",
"iopub.status.idle": "2023-08-18T07:00:46.133873Z",
"shell.execute_reply": "2023-08-18T07:00:46.132925Z"
},
"origin_pos": 46,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def offset_inverse(anchors, offset_preds):\n",
" \"\"\"根据带有预测偏移量的锚框来预测边界框\"\"\"\n",
" anc = d2l.box_corner_to_center(anchors)\n",
" pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]\n",
" pred_bbox_wh = torch.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]\n",
" pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), axis=1)\n",
" predicted_bbox = d2l.box_center_to_corner(pred_bbox)\n",
" return predicted_bbox"
]
},
{
"cell_type": "markdown",
"id": "83ebeef8",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"以下`nms`函数按降序对置信度进行排序并返回其索引"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "ac5c4e3c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:46.138652Z",
"iopub.status.busy": "2023-08-18T07:00:46.138060Z",
"iopub.status.idle": "2023-08-18T07:00:46.151360Z",
"shell.execute_reply": "2023-08-18T07:00:46.150447Z"
},
"origin_pos": 49,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def nms(boxes, scores, iou_threshold):\n",
" \"\"\"对预测边界框的置信度进行排序\"\"\"\n",
" B = torch.argsort(scores, dim=-1, descending=True)\n",
" keep = []\n",
" while B.numel() > 0:\n",
" i = B[0]\n",
" keep.append(i)\n",
" if B.numel() == 1: break\n",
" iou = box_iou(boxes[i, :].reshape(-1, 4),\n",
" boxes[B[1:], :].reshape(-1, 4)).reshape(-1)\n",
" inds = torch.nonzero(iou <= iou_threshold).reshape(-1)\n",
" B = B[inds + 1]\n",
" return torch.tensor(keep, device=boxes.device)"
]
},
{
"cell_type": "markdown",
"id": "d5097b86",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"将非极大值抑制应用于预测边界框"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "baa9f34f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:46.157117Z",
"iopub.status.busy": "2023-08-18T07:00:46.156701Z",
"iopub.status.idle": "2023-08-18T07:00:46.175104Z",
"shell.execute_reply": "2023-08-18T07:00:46.174174Z"
},
"origin_pos": 53,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,\n",
" pos_threshold=0.009999999):\n",
" \"\"\"使用非极大值抑制来预测边界框\"\"\"\n",
" device, batch_size = cls_probs.device, cls_probs.shape[0]\n",
" anchors = anchors.squeeze(0)\n",
" num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]\n",
" out = []\n",
" for i in range(batch_size):\n",
" cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)\n",
" conf, class_id = torch.max(cls_prob[1:], 0)\n",
" predicted_bb = offset_inverse(anchors, offset_pred)\n",
" keep = nms(predicted_bb, conf, nms_threshold)\n",
"\n",
" all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)\n",
" combined = torch.cat((keep, all_idx))\n",
" uniques, counts = combined.unique(return_counts=True)\n",
" non_keep = uniques[counts == 1]\n",
" all_id_sorted = torch.cat((keep, non_keep))\n",
" class_id[non_keep] = -1\n",
" class_id = class_id[all_id_sorted]\n",
" conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]\n",
" below_min_idx = (conf < pos_threshold)\n",
" class_id[below_min_idx] = -1\n",
" conf[below_min_idx] = 1 - conf[below_min_idx]\n",
" pred_info = torch.cat((class_id.unsqueeze(1),\n",
" conf.unsqueeze(1),\n",
" predicted_bb), dim=1)\n",
" out.append(pred_info)\n",
" return torch.stack(out)"
]
},
{
"cell_type": "markdown",
"id": "0492308f",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"将上述算法应用到一个带有四个锚框的具体示例中"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "4654e2f7",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:46.179666Z",
"iopub.status.busy": "2023-08-18T07:00:46.179298Z",
"iopub.status.idle": "2023-08-18T07:00:46.188634Z",
"shell.execute_reply": "2023-08-18T07:00:46.187737Z"
},
"origin_pos": 56,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"anchors = torch.tensor([[0.1, 0.08, 0.52, 0.92], [0.08, 0.2, 0.56, 0.95],\n",
" [0.15, 0.3, 0.62, 0.91], [0.55, 0.2, 0.9, 0.88]])\n",
"offset_preds = torch.tensor([0] * anchors.numel())\n",
"cls_probs = torch.tensor([[0] * 4,\n",
" [0.9, 0.8, 0.7, 0.1],\n",
" [0.1, 0.2, 0.3, 0.9]])"
]
},
{
"cell_type": "markdown",
"id": "34ef9113",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"在图像上绘制这些预测边界框和置信度"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "b9619cba",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:46.193833Z",
"iopub.status.busy": "2023-08-18T07:00:46.193198Z",
"iopub.status.idle": "2023-08-18T07:00:46.435994Z",
"shell.execute_reply": "2023-08-18T07:00:46.434923Z"
},
"origin_pos": 59,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = d2l.plt.imshow(img)\n",
"show_bboxes(fig.axes, anchors * bbox_scale,\n",
" ['dog=0.9', 'dog=0.8', 'dog=0.7', 'cat=0.9'])"
]
},
{
"cell_type": "markdown",
"id": "10393ba6",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"返回结果的形状是(批量大小,锚框的数量,6)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "ab9c180c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:46.440560Z",
"iopub.status.busy": "2023-08-18T07:00:46.439738Z",
"iopub.status.idle": "2023-08-18T07:00:46.458973Z",
"shell.execute_reply": "2023-08-18T07:00:46.457996Z"
},
"origin_pos": 62,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.00, 0.90, 0.10, 0.08, 0.52, 0.92],\n",
" [ 1.00, 0.90, 0.55, 0.20, 0.90, 0.88],\n",
" [-1.00, 0.80, 0.08, 0.20, 0.56, 0.95],\n",
" [-1.00, 0.70, 0.15, 0.30, 0.62, 0.91]]])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output = multibox_detection(cls_probs.unsqueeze(dim=0),\n",
" offset_preds.unsqueeze(dim=0),\n",
" anchors.unsqueeze(dim=0),\n",
" nms_threshold=0.5)\n",
"output"
]
},
{
"cell_type": "markdown",
"id": "e8a12bcd",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"输出由非极大值抑制保存的最终预测边界框"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "f1e04b3f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T07:00:46.496750Z",
"iopub.status.busy": "2023-08-18T07:00:46.495866Z",
"iopub.status.idle": "2023-08-18T07:00:46.753536Z",
"shell.execute_reply": "2023-08-18T07:00:46.752302Z"
},
"origin_pos": 65,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = d2l.plt.imshow(img)\n",
"for i in output[0].detach().numpy():\n",
" if i[0] == -1:\n",
" continue\n",
" label = ('dog=', 'cat=')[int(i[0])] + str(i[1])\n",
" show_bboxes(fig.axes, [torch.tensor(i[2:]) * bbox_scale], label)"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}