{ "cells": [ { "cell_type": "markdown", "id": "e10fb2c2", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# 锚框\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "e079962e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:42.677769Z", "iopub.status.busy": "2023-08-18T07:00:42.676695Z", "iopub.status.idle": "2023-08-18T07:00:45.106116Z", "shell.execute_reply": "2023-08-18T07:00:45.104773Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "from d2l import torch as d2l\n", "\n", "torch.set_printoptions(2)" ] }, { "cell_type": "markdown", "id": "b96c5129", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "锚框的宽度和高度分别是$hs\\sqrt{r}$和$hs/\\sqrt{r}$。\n", "我们只考虑\n", "组合:\n", "\n", "$$(s_1, r_1), (s_1, r_2), \\ldots, (s_1, r_m), (s_2, r_1), (s_3, r_1), \\ldots, (s_n, r_1)$$" ] }, { "cell_type": "code", "execution_count": 2, "id": "4c5fb635", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.112186Z", "iopub.status.busy": "2023-08-18T07:00:45.111657Z", "iopub.status.idle": "2023-08-18T07:00:45.126939Z", "shell.execute_reply": "2023-08-18T07:00:45.125859Z" }, "origin_pos": 6, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def multibox_prior(data, sizes, ratios):\n", " \"\"\"生成以每个像素为中心具有不同形状的锚框\"\"\"\n", " in_height, in_width = data.shape[-2:]\n", " device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)\n", " boxes_per_pixel = (num_sizes + num_ratios - 1)\n", " size_tensor = torch.tensor(sizes, device=device)\n", " ratio_tensor = torch.tensor(ratios, device=device)\n", "\n", " offset_h, offset_w = 0.5, 0.5\n", " steps_h = 1.0 / in_height\n", " steps_w = 1.0 / in_width\n", "\n", " center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h\n", " center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w\n", " shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')\n", " shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)\n", "\n", " w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),\n", " sizes[0] * torch.sqrt(ratio_tensor[1:])))\\\n", " * in_height / in_width\n", " h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),\n", " sizes[0] / torch.sqrt(ratio_tensor[1:])))\n", " anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(\n", " in_height * in_width, 1) / 2\n", "\n", " out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],\n", " dim=1).repeat_interleave(boxes_per_pixel, dim=0)\n", " output = out_grid + anchor_manipulations\n", " return output.unsqueeze(0)" ] }, { "cell_type": "markdown", "id": "27883d90", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "返回的锚框变量`Y`的形状" ] }, { "cell_type": "code", "execution_count": 3, "id": "f411d4af", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.131714Z", "iopub.status.busy": "2023-08-18T07:00:45.131003Z", "iopub.status.idle": "2023-08-18T07:00:45.238891Z", "shell.execute_reply": "2023-08-18T07:00:45.237843Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "561 728\n" ] }, { "data": { "text/plain": [ "torch.Size([1, 2042040, 4])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = d2l.plt.imread('../img/catdog.jpg')\n", "h, w = img.shape[:2]\n", "\n", "print(h, w)\n", "X = torch.rand(size=(1, 3, h, w))\n", "Y = multibox_prior(X, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5])\n", "Y.shape" ] }, { "cell_type": "markdown", "id": "3ebaca32", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "访问以(250,250)为中心的第一个锚框" ] }, { "cell_type": "code", "execution_count": 4, "id": "a7b7cfa3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.244522Z", "iopub.status.busy": "2023-08-18T07:00:45.243982Z", "iopub.status.idle": "2023-08-18T07:00:45.252916Z", "shell.execute_reply": "2023-08-18T07:00:45.251985Z" }, "origin_pos": 13, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([0.06, 0.07, 0.63, 0.82])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "boxes = Y.reshape(h, w, 5, 4)\n", "boxes[250, 250, 0, :]" ] }, { "cell_type": "markdown", "id": "afa57c6b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "显示以图像中以某个像素为中心的所有锚框" ] }, { "cell_type": "code", "execution_count": 6, "id": "c199e557", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.272415Z", "iopub.status.busy": "2023-08-18T07:00:45.271753Z", "iopub.status.idle": "2023-08-18T07:00:45.634073Z", "shell.execute_reply": "2023-08-18T07:00:45.632866Z" }, "origin_pos": 18, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:45.532795\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def show_bboxes(axes, bboxes, labels=None, colors=None):\n", " \"\"\"显示所有边界框\"\"\"\n", " def _make_list(obj, default_values=None):\n", " if obj is None:\n", " obj = default_values\n", " elif not isinstance(obj, (list, tuple)):\n", " obj = [obj]\n", " return obj\n", "\n", " labels = _make_list(labels)\n", " colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])\n", " for i, bbox in enumerate(bboxes):\n", " color = colors[i % len(colors)]\n", " rect = d2l.bbox_to_rect(bbox.detach().numpy(), color)\n", " axes.add_patch(rect)\n", " if labels and len(labels) > i:\n", " text_color = 'k' if color == 'w' else 'w'\n", " axes.text(rect.xy[0], rect.xy[1], labels[i],\n", " va='center', ha='center', fontsize=9, color=text_color,\n", " bbox=dict(facecolor=color, lw=0))\n", "\n", "d2l.set_figsize()\n", "bbox_scale = torch.tensor((w, h, w, h))\n", "fig = d2l.plt.imshow(img)\n", "show_bboxes(fig.axes, boxes[250, 250, :, :] * bbox_scale,\n", " ['s=0.75, r=1', 's=0.5, r=1', 's=0.25, r=1', 's=0.75, r=2',\n", " 's=0.75, r=0.5'])" ] }, { "cell_type": "markdown", "id": "3fbf419b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "交并比(IoU)" ] }, { "cell_type": "code", "execution_count": 7, "id": "feab924c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.639984Z", "iopub.status.busy": "2023-08-18T07:00:45.639243Z", "iopub.status.idle": "2023-08-18T07:00:45.649165Z", "shell.execute_reply": "2023-08-18T07:00:45.648025Z" }, "origin_pos": 21, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def box_iou(boxes1, boxes2):\n", " \"\"\"计算两个锚框或边界框列表中成对的交并比\"\"\"\n", " box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *\n", " (boxes[:, 3] - boxes[:, 1]))\n", " areas1 = box_area(boxes1)\n", " areas2 = box_area(boxes2)\n", " inter_upperlefts = torch.max(boxes1[:, None, :2], boxes2[:, :2])\n", " inter_lowerrights = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])\n", " inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)\n", " inter_areas = inters[:, :, 0] * inters[:, :, 1]\n", " union_areas = areas1[:, None] + areas2 - inter_areas\n", " return inter_areas / union_areas" ] }, { "cell_type": "markdown", "id": "e5e46ca6", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "将真实边界框分配给锚框" ] }, { "cell_type": "code", "execution_count": 8, "id": "8ac4113a", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.654100Z", "iopub.status.busy": "2023-08-18T07:00:45.653303Z", "iopub.status.idle": "2023-08-18T07:00:45.664045Z", "shell.execute_reply": "2023-08-18T07:00:45.663020Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def assign_anchor_to_bbox(ground_truth, anchors, device, iou_threshold=0.5):\n", " \"\"\"将最接近的真实边界框分配给锚框\"\"\"\n", " num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]\n", " jaccard = box_iou(anchors, ground_truth)\n", " anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long,\n", " device=device)\n", " max_ious, indices = torch.max(jaccard, dim=1)\n", " anc_i = torch.nonzero(max_ious >= iou_threshold).reshape(-1)\n", " box_j = indices[max_ious >= iou_threshold]\n", " anchors_bbox_map[anc_i] = box_j\n", " col_discard = torch.full((num_anchors,), -1)\n", " row_discard = torch.full((num_gt_boxes,), -1)\n", " for _ in range(num_gt_boxes):\n", " max_idx = torch.argmax(jaccard)\n", " box_idx = (max_idx % num_gt_boxes).long()\n", " anc_idx = (max_idx / num_gt_boxes).long()\n", " anchors_bbox_map[anc_idx] = box_idx\n", " jaccard[:, box_idx] = col_discard\n", " jaccard[anc_idx, :] = row_discard\n", " return anchors_bbox_map" ] }, { "cell_type": "markdown", "id": "592381a5", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "给定框$A$和$B$,中心坐标分别为$(x_a, y_a)$和$(x_b, y_b)$,宽度分别为$w_a$和$w_b$,高度分别为$h_a$和$h_b$,可以将$A$的偏移量标记为:\n", "\n", "$$\\left( \\frac{ \\frac{x_b - x_a}{w_a} - \\mu_x }{\\sigma_x},\n", "\\frac{ \\frac{y_b - y_a}{h_a} - \\mu_y }{\\sigma_y},\n", "\\frac{ \\log \\frac{w_b}{w_a} - \\mu_w }{\\sigma_w},\n", "\\frac{ \\log \\frac{h_b}{h_a} - \\mu_h }{\\sigma_h}\\right)$$" ] }, { "cell_type": "code", "execution_count": 9, "id": "a5b6bbdc", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.668909Z", "iopub.status.busy": "2023-08-18T07:00:45.668107Z", "iopub.status.idle": "2023-08-18T07:00:45.675569Z", "shell.execute_reply": "2023-08-18T07:00:45.674509Z" }, "origin_pos": 28, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def offset_boxes(anchors, assigned_bb, eps=1e-6):\n", " \"\"\"对锚框偏移量的转换\"\"\"\n", " c_anc = d2l.box_corner_to_center(anchors)\n", " c_assigned_bb = d2l.box_corner_to_center(assigned_bb)\n", " offset_xy = 10 * (c_assigned_bb[:, :2] - c_anc[:, :2]) / c_anc[:, 2:]\n", " offset_wh = 5 * torch.log(eps + c_assigned_bb[:, 2:] / c_anc[:, 2:])\n", " offset = torch.cat([offset_xy, offset_wh], axis=1)\n", " return offset" ] }, { "cell_type": "markdown", "id": "94c6aea8", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "标记锚框的类别和偏移量" ] }, { "cell_type": "code", "execution_count": 10, "id": "291738a2", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.680339Z", "iopub.status.busy": "2023-08-18T07:00:45.679651Z", "iopub.status.idle": "2023-08-18T07:00:45.692690Z", "shell.execute_reply": "2023-08-18T07:00:45.691647Z" }, "origin_pos": 31, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def multibox_target(anchors, labels):\n", " \"\"\"使用真实边界框标记锚框\"\"\"\n", " batch_size, anchors = labels.shape[0], anchors.squeeze(0)\n", " batch_offset, batch_mask, batch_class_labels = [], [], []\n", " device, num_anchors = anchors.device, anchors.shape[0]\n", " for i in range(batch_size):\n", " label = labels[i, :, :]\n", " anchors_bbox_map = assign_anchor_to_bbox(\n", " label[:, 1:], anchors, device)\n", " bbox_mask = ((anchors_bbox_map >= 0).float().unsqueeze(-1)).repeat(\n", " 1, 4)\n", " class_labels = torch.zeros(num_anchors, dtype=torch.long,\n", " device=device)\n", " assigned_bb = torch.zeros((num_anchors, 4), dtype=torch.float32,\n", " device=device)\n", " indices_true = torch.nonzero(anchors_bbox_map >= 0)\n", " bb_idx = anchors_bbox_map[indices_true]\n", " class_labels[indices_true] = label[bb_idx, 0].long() + 1\n", " assigned_bb[indices_true] = label[bb_idx, 1:]\n", " offset = offset_boxes(anchors, assigned_bb) * bbox_mask\n", " batch_offset.append(offset.reshape(-1))\n", " batch_mask.append(bbox_mask.reshape(-1))\n", " batch_class_labels.append(class_labels)\n", " bbox_offset = torch.stack(batch_offset)\n", " bbox_mask = torch.stack(batch_mask)\n", " class_labels = torch.stack(batch_class_labels)\n", " return (bbox_offset, bbox_mask, class_labels)" ] }, { "cell_type": "markdown", "id": "524b87ff", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "在图像中绘制这些真实边界框和锚框" ] }, { "cell_type": "code", "execution_count": 11, "id": "e22f46a6", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:45.697286Z", "iopub.status.busy": "2023-08-18T07:00:45.696846Z", "iopub.status.idle": "2023-08-18T07:00:46.054321Z", "shell.execute_reply": "2023-08-18T07:00:46.053210Z" }, "origin_pos": 34, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:45.953248\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ground_truth = torch.tensor([[0, 0.1, 0.08, 0.52, 0.92],\n", " [1, 0.55, 0.2, 0.9, 0.88]])\n", "anchors = torch.tensor([[0, 0.1, 0.2, 0.3], [0.15, 0.2, 0.4, 0.4],\n", " [0.63, 0.05, 0.88, 0.98], [0.66, 0.45, 0.8, 0.8],\n", " [0.57, 0.3, 0.92, 0.9]])\n", "\n", "fig = d2l.plt.imshow(img)\n", "show_bboxes(fig.axes, ground_truth[:, 1:] * bbox_scale, ['dog', 'cat'], 'k')\n", "show_bboxes(fig.axes, anchors * bbox_scale, ['0', '1', '2', '3', '4']);" ] }, { "cell_type": "markdown", "id": "f0319b92", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "根据狗和猫的真实边界框,标注这些锚框的分类和偏移量" ] }, { "cell_type": "code", "execution_count": 13, "id": "78a2e091", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.075902Z", "iopub.status.busy": "2023-08-18T07:00:46.075000Z", "iopub.status.idle": "2023-08-18T07:00:46.084256Z", "shell.execute_reply": "2023-08-18T07:00:46.083114Z" }, "origin_pos": 40, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[0, 1, 2, 0, 2]])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels = multibox_target(anchors.unsqueeze(dim=0),\n", " ground_truth.unsqueeze(dim=0))\n", "\n", "labels[2]" ] }, { "cell_type": "code", "execution_count": 14, "id": "94a1597f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.097601Z", "iopub.status.busy": "2023-08-18T07:00:46.097228Z", "iopub.status.idle": "2023-08-18T07:00:46.104883Z", "shell.execute_reply": "2023-08-18T07:00:46.103744Z" }, "origin_pos": 42, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1.,\n", " 1., 1.]])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels[1]" ] }, { "cell_type": "code", "execution_count": 15, "id": "25e7f69b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.112301Z", "iopub.status.busy": "2023-08-18T07:00:46.111934Z", "iopub.status.idle": "2023-08-18T07:00:46.118802Z", "shell.execute_reply": "2023-08-18T07:00:46.117910Z" }, "origin_pos": 44, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, 1.40e+00, 1.00e+01,\n", " 2.59e+00, 7.18e+00, -1.20e+00, 2.69e-01, 1.68e+00, -1.57e+00,\n", " -0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, -5.71e-01, -1.00e+00,\n", " 4.17e-06, 6.26e-01]])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels[0]" ] }, { "cell_type": "markdown", "id": "bff0477c", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "应用逆偏移变换来返回预测的边界框坐标" ] }, { "cell_type": "code", "execution_count": 16, "id": "227ae1a8", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.127370Z", "iopub.status.busy": "2023-08-18T07:00:46.127014Z", "iopub.status.idle": "2023-08-18T07:00:46.133873Z", "shell.execute_reply": "2023-08-18T07:00:46.132925Z" }, "origin_pos": 46, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def offset_inverse(anchors, offset_preds):\n", " \"\"\"根据带有预测偏移量的锚框来预测边界框\"\"\"\n", " anc = d2l.box_corner_to_center(anchors)\n", " pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]\n", " pred_bbox_wh = torch.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]\n", " pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), axis=1)\n", " predicted_bbox = d2l.box_center_to_corner(pred_bbox)\n", " return predicted_bbox" ] }, { "cell_type": "markdown", "id": "83ebeef8", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "以下`nms`函数按降序对置信度进行排序并返回其索引" ] }, { "cell_type": "code", "execution_count": 17, "id": "ac5c4e3c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.138652Z", "iopub.status.busy": "2023-08-18T07:00:46.138060Z", "iopub.status.idle": "2023-08-18T07:00:46.151360Z", "shell.execute_reply": "2023-08-18T07:00:46.150447Z" }, "origin_pos": 49, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def nms(boxes, scores, iou_threshold):\n", " \"\"\"对预测边界框的置信度进行排序\"\"\"\n", " B = torch.argsort(scores, dim=-1, descending=True)\n", " keep = []\n", " while B.numel() > 0:\n", " i = B[0]\n", " keep.append(i)\n", " if B.numel() == 1: break\n", " iou = box_iou(boxes[i, :].reshape(-1, 4),\n", " boxes[B[1:], :].reshape(-1, 4)).reshape(-1)\n", " inds = torch.nonzero(iou <= iou_threshold).reshape(-1)\n", " B = B[inds + 1]\n", " return torch.tensor(keep, device=boxes.device)" ] }, { "cell_type": "markdown", "id": "d5097b86", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "将非极大值抑制应用于预测边界框" ] }, { "cell_type": "code", "execution_count": 18, "id": "baa9f34f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.157117Z", "iopub.status.busy": "2023-08-18T07:00:46.156701Z", "iopub.status.idle": "2023-08-18T07:00:46.175104Z", "shell.execute_reply": "2023-08-18T07:00:46.174174Z" }, "origin_pos": 53, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,\n", " pos_threshold=0.009999999):\n", " \"\"\"使用非极大值抑制来预测边界框\"\"\"\n", " device, batch_size = cls_probs.device, cls_probs.shape[0]\n", " anchors = anchors.squeeze(0)\n", " num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]\n", " out = []\n", " for i in range(batch_size):\n", " cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)\n", " conf, class_id = torch.max(cls_prob[1:], 0)\n", " predicted_bb = offset_inverse(anchors, offset_pred)\n", " keep = nms(predicted_bb, conf, nms_threshold)\n", "\n", " all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)\n", " combined = torch.cat((keep, all_idx))\n", " uniques, counts = combined.unique(return_counts=True)\n", " non_keep = uniques[counts == 1]\n", " all_id_sorted = torch.cat((keep, non_keep))\n", " class_id[non_keep] = -1\n", " class_id = class_id[all_id_sorted]\n", " conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]\n", " below_min_idx = (conf < pos_threshold)\n", " class_id[below_min_idx] = -1\n", " conf[below_min_idx] = 1 - conf[below_min_idx]\n", " pred_info = torch.cat((class_id.unsqueeze(1),\n", " conf.unsqueeze(1),\n", " predicted_bb), dim=1)\n", " out.append(pred_info)\n", " return torch.stack(out)" ] }, { "cell_type": "markdown", "id": "0492308f", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "将上述算法应用到一个带有四个锚框的具体示例中" ] }, { "cell_type": "code", "execution_count": 19, "id": "4654e2f7", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.179666Z", "iopub.status.busy": "2023-08-18T07:00:46.179298Z", "iopub.status.idle": "2023-08-18T07:00:46.188634Z", "shell.execute_reply": "2023-08-18T07:00:46.187737Z" }, "origin_pos": 56, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "anchors = torch.tensor([[0.1, 0.08, 0.52, 0.92], [0.08, 0.2, 0.56, 0.95],\n", " [0.15, 0.3, 0.62, 0.91], [0.55, 0.2, 0.9, 0.88]])\n", "offset_preds = torch.tensor([0] * anchors.numel())\n", "cls_probs = torch.tensor([[0] * 4,\n", " [0.9, 0.8, 0.7, 0.1],\n", " [0.1, 0.2, 0.3, 0.9]])" ] }, { "cell_type": "markdown", "id": "34ef9113", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "在图像上绘制这些预测边界框和置信度" ] }, { "cell_type": "code", "execution_count": 20, "id": "b9619cba", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.193833Z", "iopub.status.busy": "2023-08-18T07:00:46.193198Z", "iopub.status.idle": "2023-08-18T07:00:46.435994Z", "shell.execute_reply": "2023-08-18T07:00:46.434923Z" }, "origin_pos": 59, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:46.369699\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig = d2l.plt.imshow(img)\n", "show_bboxes(fig.axes, anchors * bbox_scale,\n", " ['dog=0.9', 'dog=0.8', 'dog=0.7', 'cat=0.9'])" ] }, { "cell_type": "markdown", "id": "10393ba6", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "返回结果的形状是(批量大小,锚框的数量,6)" ] }, { "cell_type": "code", "execution_count": 21, "id": "ab9c180c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.440560Z", "iopub.status.busy": "2023-08-18T07:00:46.439738Z", "iopub.status.idle": "2023-08-18T07:00:46.458973Z", "shell.execute_reply": "2023-08-18T07:00:46.457996Z" }, "origin_pos": 62, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 0.00, 0.90, 0.10, 0.08, 0.52, 0.92],\n", " [ 1.00, 0.90, 0.55, 0.20, 0.90, 0.88],\n", " [-1.00, 0.80, 0.08, 0.20, 0.56, 0.95],\n", " [-1.00, 0.70, 0.15, 0.30, 0.62, 0.91]]])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = multibox_detection(cls_probs.unsqueeze(dim=0),\n", " offset_preds.unsqueeze(dim=0),\n", " anchors.unsqueeze(dim=0),\n", " nms_threshold=0.5)\n", "output" ] }, { "cell_type": "markdown", "id": "e8a12bcd", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "输出由非极大值抑制保存的最终预测边界框" ] }, { "cell_type": "code", "execution_count": 22, "id": "f1e04b3f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T07:00:46.496750Z", "iopub.status.busy": "2023-08-18T07:00:46.495866Z", "iopub.status.idle": "2023-08-18T07:00:46.753536Z", "shell.execute_reply": "2023-08-18T07:00:46.752302Z" }, "origin_pos": 65, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T07:00:46.686810\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.5.1, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig = d2l.plt.imshow(img)\n", "for i in output[0].detach().numpy():\n", " if i[0] == -1:\n", " continue\n", " label = ('dog=', 'cat=')[int(i[0])] + str(i[1])\n", " show_bboxes(fig.axes, [torch.tensor(i[2:]) * bbox_scale], label)" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }