{ "cells": [ { "cell_type": "markdown", "id": "ac6561ad", "metadata": { "slideshow": { "slide_type": "-" } }, "source": [ "# Anchor Boxes\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "14c6e4a6", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:40.922407Z", "iopub.status.busy": "2023-08-18T19:32:40.921572Z", "iopub.status.idle": "2023-08-18T19:32:43.887912Z", "shell.execute_reply": "2023-08-18T19:32:43.886600Z" }, "origin_pos": 2, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "from d2l import torch as d2l\n", "\n", "torch.set_printoptions(2)" ] }, { "cell_type": "markdown", "id": "3e327544", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The width and height of the anchor box are $ws\\sqrt{r}$ and $hs/\\sqrt{r}$, respectively.\n", "Consider those combinations\n", "containing\n", "$$(s_1, r_1), (s_1, r_2), \\ldots, (s_1, r_m), (s_2, r_1), (s_3, r_1), \\ldots, (s_n, r_1)$$" ] }, { "cell_type": "code", "execution_count": 2, "id": "c0e17016", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:43.893598Z", "iopub.status.busy": "2023-08-18T19:32:43.893199Z", "iopub.status.idle": "2023-08-18T19:32:43.902717Z", "shell.execute_reply": "2023-08-18T19:32:43.901834Z" }, "origin_pos": 5, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def multibox_prior(data, sizes, ratios):\n", " \"\"\"Generate anchor boxes with different shapes centered on each pixel.\"\"\"\n", " in_height, in_width = data.shape[-2:]\n", " device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)\n", " boxes_per_pixel = (num_sizes + num_ratios - 1)\n", " size_tensor = torch.tensor(sizes, device=device)\n", " ratio_tensor = torch.tensor(ratios, device=device)\n", " offset_h, offset_w = 0.5, 0.5\n", " steps_h = 1.0 / in_height\n", " steps_w = 1.0 / in_width\n", "\n", " center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h\n", " center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w\n", " shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')\n", " shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)\n", "\n", " w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),\n", " sizes[0] * torch.sqrt(ratio_tensor[1:])))\\\n", " * in_height / in_width\n", " h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),\n", " sizes[0] / torch.sqrt(ratio_tensor[1:])))\n", " anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(\n", " in_height * in_width, 1) / 2\n", "\n", " out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],\n", " dim=1).repeat_interleave(boxes_per_pixel, dim=0)\n", " output = out_grid + anchor_manipulations\n", " return output.unsqueeze(0)" ] }, { "cell_type": "markdown", "id": "fca5f107", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The shape of the returned anchor box variable `Y`" ] }, { "cell_type": "code", "execution_count": 3, "id": "0509b5af", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:43.906825Z", "iopub.status.busy": "2023-08-18T19:32:43.906238Z", "iopub.status.idle": "2023-08-18T19:32:44.026992Z", "shell.execute_reply": "2023-08-18T19:32:44.025888Z" }, "origin_pos": 8, "tab": [ "pytorch" ] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "561 728\n" ] }, { "data": { "text/plain": [ "torch.Size([1, 2042040, 4])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = d2l.plt.imread('../img/catdog.jpg')\n", "h, w = img.shape[:2]\n", "\n", "print(h, w)\n", "X = torch.rand(size=(1, 3, h, w))\n", "Y = multibox_prior(X, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5])\n", "Y.shape" ] }, { "cell_type": "markdown", "id": "3b70f7ac", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Access the first anchor box centered on (250, 250)" ] }, { "cell_type": "code", "execution_count": 4, "id": "68fde78e", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.031304Z", "iopub.status.busy": "2023-08-18T19:32:44.030463Z", "iopub.status.idle": "2023-08-18T19:32:44.039395Z", "shell.execute_reply": "2023-08-18T19:32:44.038385Z" }, "origin_pos": 10, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([0.06, 0.07, 0.63, 0.82])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "boxes = Y.reshape(h, w, 5, 4)\n", "boxes[250, 250, 0, :]" ] }, { "cell_type": "markdown", "id": "29c2016a", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Show all the anchor boxes centered on one pixel in the image" ] }, { "cell_type": "code", "execution_count": 6, "id": "f4e0c959", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.055283Z", "iopub.status.busy": "2023-08-18T19:32:44.054735Z", "iopub.status.idle": "2023-08-18T19:32:44.372820Z", "shell.execute_reply": "2023-08-18T19:32:44.371734Z" }, "origin_pos": 14, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:32:44.289426\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def show_bboxes(axes, bboxes, labels=None, colors=None):\n", " \"\"\"Show bounding boxes.\"\"\"\n", "\n", " def make_list(obj, default_values=None):\n", " if obj is None:\n", " obj = default_values\n", " elif not isinstance(obj, (list, tuple)):\n", " obj = [obj]\n", " return obj\n", "\n", " labels = make_list(labels)\n", " colors = make_list(colors, ['b', 'g', 'r', 'm', 'c'])\n", " for i, bbox in enumerate(bboxes):\n", " color = colors[i % len(colors)]\n", " rect = d2l.bbox_to_rect(bbox.detach().numpy(), color)\n", " axes.add_patch(rect)\n", " if labels and len(labels) > i:\n", " text_color = 'k' if color == 'w' else 'w'\n", " axes.text(rect.xy[0], rect.xy[1], labels[i],\n", " va='center', ha='center', fontsize=9, color=text_color,\n", " bbox=dict(facecolor=color, lw=0))\n", "\n", "d2l.set_figsize()\n", "bbox_scale = torch.tensor((w, h, w, h))\n", "fig = d2l.plt.imshow(img)\n", "show_bboxes(fig.axes, boxes[250, 250, :, :] * bbox_scale,\n", " ['s=0.75, r=1', 's=0.5, r=1', 's=0.25, r=1', 's=0.75, r=2',\n", " 's=0.75, r=0.5'])" ] }, { "cell_type": "markdown", "id": "d7322def", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Intersection over Union (IoU)" ] }, { "cell_type": "code", "execution_count": 7, "id": "ed9ef04d", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.377494Z", "iopub.status.busy": "2023-08-18T19:32:44.376956Z", "iopub.status.idle": "2023-08-18T19:32:44.384235Z", "shell.execute_reply": "2023-08-18T19:32:44.382935Z" }, "origin_pos": 17, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def box_iou(boxes1, boxes2):\n", " \"\"\"Compute pairwise IoU across two lists of anchor or bounding boxes.\"\"\"\n", " box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *\n", " (boxes[:, 3] - boxes[:, 1]))\n", " areas1 = box_area(boxes1)\n", " areas2 = box_area(boxes2)\n", " inter_upperlefts = torch.max(boxes1[:, None, :2], boxes2[:, :2])\n", " inter_lowerrights = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])\n", " inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)\n", " inter_areas = inters[:, :, 0] * inters[:, :, 1]\n", " union_areas = areas1[:, None] + areas2 - inter_areas\n", " return inter_areas / union_areas" ] }, { "cell_type": "markdown", "id": "91164ce8", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Assigning Ground-Truth Bounding Boxes to Anchor Boxes" ] }, { "cell_type": "code", "execution_count": 8, "id": "50237ceb", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.388937Z", "iopub.status.busy": "2023-08-18T19:32:44.388033Z", "iopub.status.idle": "2023-08-18T19:32:44.397131Z", "shell.execute_reply": "2023-08-18T19:32:44.395737Z" }, "origin_pos": 20, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def assign_anchor_to_bbox(ground_truth, anchors, device, iou_threshold=0.5):\n", " \"\"\"Assign closest ground-truth bounding boxes to anchor boxes.\"\"\"\n", " num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]\n", " jaccard = box_iou(anchors, ground_truth)\n", " anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long,\n", " device=device)\n", " max_ious, indices = torch.max(jaccard, dim=1)\n", " anc_i = torch.nonzero(max_ious >= iou_threshold).reshape(-1)\n", " box_j = indices[max_ious >= iou_threshold]\n", " anchors_bbox_map[anc_i] = box_j\n", " col_discard = torch.full((num_anchors,), -1)\n", " row_discard = torch.full((num_gt_boxes,), -1)\n", " for _ in range(num_gt_boxes):\n", " max_idx = torch.argmax(jaccard)\n", " box_idx = (max_idx % num_gt_boxes).long()\n", " anc_idx = (max_idx / num_gt_boxes).long()\n", " anchors_bbox_map[anc_idx] = box_idx\n", " jaccard[:, box_idx] = col_discard\n", " jaccard[anc_idx, :] = row_discard\n", " return anchors_bbox_map" ] }, { "cell_type": "markdown", "id": "06ff993e", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Given the central coordinates of $A$ and $B$ as $(x_a, y_a)$ and $(x_b, y_b)$, \n", "their widths as $w_a$ and $w_b$, \n", "and their heights as $h_a$ and $h_b$, respectively. We may label the offset of $A$ as\n", "\n", "$$\\left( \\frac{ \\frac{x_b - x_a}{w_a} - \\mu_x }{\\sigma_x},\n", "\\frac{ \\frac{y_b - y_a}{h_a} - \\mu_y }{\\sigma_y},\n", "\\frac{ \\log \\frac{w_b}{w_a} - \\mu_w }{\\sigma_w},\n", "\\frac{ \\log \\frac{h_b}{h_a} - \\mu_h }{\\sigma_h}\\right)$$" ] }, { "cell_type": "code", "execution_count": 9, "id": "555ef063", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.401240Z", "iopub.status.busy": "2023-08-18T19:32:44.400631Z", "iopub.status.idle": "2023-08-18T19:32:44.406150Z", "shell.execute_reply": "2023-08-18T19:32:44.405146Z" }, "origin_pos": 22, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def offset_boxes(anchors, assigned_bb, eps=1e-6):\n", " \"\"\"Transform for anchor box offsets.\"\"\"\n", " c_anc = d2l.box_corner_to_center(anchors)\n", " c_assigned_bb = d2l.box_corner_to_center(assigned_bb)\n", " offset_xy = 10 * (c_assigned_bb[:, :2] - c_anc[:, :2]) / c_anc[:, 2:]\n", " offset_wh = 5 * torch.log(eps + c_assigned_bb[:, 2:] / c_anc[:, 2:])\n", " offset = torch.cat([offset_xy, offset_wh], axis=1)\n", " return offset" ] }, { "cell_type": "markdown", "id": "a076b90b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Label classes and offsets for anchor boxes" ] }, { "cell_type": "code", "execution_count": 10, "id": "16d715d6", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.410237Z", "iopub.status.busy": "2023-08-18T19:32:44.409587Z", "iopub.status.idle": "2023-08-18T19:32:44.420196Z", "shell.execute_reply": "2023-08-18T19:32:44.419115Z" }, "origin_pos": 25, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def multibox_target(anchors, labels):\n", " \"\"\"Label anchor boxes using ground-truth bounding boxes.\"\"\"\n", " batch_size, anchors = labels.shape[0], anchors.squeeze(0)\n", " batch_offset, batch_mask, batch_class_labels = [], [], []\n", " device, num_anchors = anchors.device, anchors.shape[0]\n", " for i in range(batch_size):\n", " label = labels[i, :, :]\n", " anchors_bbox_map = assign_anchor_to_bbox(\n", " label[:, 1:], anchors, device)\n", " bbox_mask = ((anchors_bbox_map >= 0).float().unsqueeze(-1)).repeat(\n", " 1, 4)\n", " class_labels = torch.zeros(num_anchors, dtype=torch.long,\n", " device=device)\n", " assigned_bb = torch.zeros((num_anchors, 4), dtype=torch.float32,\n", " device=device)\n", " indices_true = torch.nonzero(anchors_bbox_map >= 0)\n", " bb_idx = anchors_bbox_map[indices_true]\n", " class_labels[indices_true] = label[bb_idx, 0].long() + 1\n", " assigned_bb[indices_true] = label[bb_idx, 1:]\n", " offset = offset_boxes(anchors, assigned_bb) * bbox_mask\n", " batch_offset.append(offset.reshape(-1))\n", " batch_mask.append(bbox_mask.reshape(-1))\n", " batch_class_labels.append(class_labels)\n", " bbox_offset = torch.stack(batch_offset)\n", " bbox_mask = torch.stack(batch_mask)\n", " class_labels = torch.stack(batch_class_labels)\n", " return (bbox_offset, bbox_mask, class_labels)" ] }, { "cell_type": "markdown", "id": "e8bd575e", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Plot these ground-truth bounding boxes \n", "and anchor boxes \n", "in the image" ] }, { "cell_type": "code", "execution_count": 11, "id": "5241a5a4", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.424500Z", "iopub.status.busy": "2023-08-18T19:32:44.423804Z", "iopub.status.idle": "2023-08-18T19:32:44.818235Z", "shell.execute_reply": "2023-08-18T19:32:44.817279Z" }, "origin_pos": 27, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:32:44.708706\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ground_truth = torch.tensor([[0, 0.1, 0.08, 0.52, 0.92],\n", " [1, 0.55, 0.2, 0.9, 0.88]])\n", "anchors = torch.tensor([[0, 0.1, 0.2, 0.3], [0.15, 0.2, 0.4, 0.4],\n", " [0.63, 0.05, 0.88, 0.98], [0.66, 0.45, 0.8, 0.8],\n", " [0.57, 0.3, 0.92, 0.9]])\n", "\n", "fig = d2l.plt.imshow(img)\n", "show_bboxes(fig.axes, ground_truth[:, 1:] * bbox_scale, ['dog', 'cat'], 'k')\n", "show_bboxes(fig.axes, anchors * bbox_scale, ['0', '1', '2', '3', '4']);" ] }, { "cell_type": "markdown", "id": "33f5ee91", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Label classes and offsets\n", "of these anchor boxes based on\n", "the ground-truth bounding boxes" ] }, { "cell_type": "code", "execution_count": 13, "id": "27063d40", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.834163Z", "iopub.status.busy": "2023-08-18T19:32:44.833548Z", "iopub.status.idle": "2023-08-18T19:32:44.840313Z", "shell.execute_reply": "2023-08-18T19:32:44.839244Z" }, "origin_pos": 32, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[0, 1, 2, 0, 2]])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels = multibox_target(anchors.unsqueeze(dim=0),\n", " ground_truth.unsqueeze(dim=0))\n", "\n", "labels[2]" ] }, { "cell_type": "code", "execution_count": 14, "id": "0b55433b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.844357Z", "iopub.status.busy": "2023-08-18T19:32:44.843744Z", "iopub.status.idle": "2023-08-18T19:32:44.850721Z", "shell.execute_reply": "2023-08-18T19:32:44.849660Z" }, "origin_pos": 34, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1.,\n", " 1., 1.]])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels[1]" ] }, { "cell_type": "code", "execution_count": 15, "id": "3ce74d8f", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.854664Z", "iopub.status.busy": "2023-08-18T19:32:44.854067Z", "iopub.status.idle": "2023-08-18T19:32:44.861006Z", "shell.execute_reply": "2023-08-18T19:32:44.859933Z" }, "origin_pos": 36, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, 1.40e+00, 1.00e+01,\n", " 2.59e+00, 7.18e+00, -1.20e+00, 2.69e-01, 1.68e+00, -1.57e+00,\n", " -0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, -5.71e-01, -1.00e+00,\n", " 4.17e-06, 6.26e-01]])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels[0]" ] }, { "cell_type": "markdown", "id": "1250fda8", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Applies inverse offset transformations to\n", "return the predicted bounding box coordinates" ] }, { "cell_type": "code", "execution_count": 16, "id": "2aa3364b", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.865010Z", "iopub.status.busy": "2023-08-18T19:32:44.864433Z", "iopub.status.idle": "2023-08-18T19:32:44.870968Z", "shell.execute_reply": "2023-08-18T19:32:44.869955Z" }, "origin_pos": 38, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def offset_inverse(anchors, offset_preds):\n", " \"\"\"Predict bounding boxes based on anchor boxes with predicted offsets.\"\"\"\n", " anc = d2l.box_corner_to_center(anchors)\n", " pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]\n", " pred_bbox_wh = torch.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]\n", " pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), axis=1)\n", " predicted_bbox = d2l.box_center_to_corner(pred_bbox)\n", " return predicted_bbox" ] }, { "cell_type": "markdown", "id": "4cb8a18b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The following `nms` function sorts confidence scores in descending order and returns their indices" ] }, { "cell_type": "code", "execution_count": 17, "id": "4ab91d73", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.874979Z", "iopub.status.busy": "2023-08-18T19:32:44.874405Z", "iopub.status.idle": "2023-08-18T19:32:44.881796Z", "shell.execute_reply": "2023-08-18T19:32:44.880768Z" }, "origin_pos": 41, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def nms(boxes, scores, iou_threshold):\n", " \"\"\"Sort confidence scores of predicted bounding boxes.\"\"\"\n", " B = torch.argsort(scores, dim=-1, descending=True)\n", " keep = []\n", " while B.numel() > 0:\n", " i = B[0]\n", " keep.append(i)\n", " if B.numel() == 1: break\n", " iou = box_iou(boxes[i, :].reshape(-1, 4),\n", " boxes[B[1:], :].reshape(-1, 4)).reshape(-1)\n", " inds = torch.nonzero(iou <= iou_threshold).reshape(-1)\n", " B = B[inds + 1]\n", " return torch.tensor(keep, device=boxes.device)" ] }, { "cell_type": "markdown", "id": "0e7e8486", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Apply non-maximum suppression\n", "to predicting bounding boxes" ] }, { "cell_type": "code", "execution_count": 18, "id": "4a5b6c3c", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.885871Z", "iopub.status.busy": "2023-08-18T19:32:44.885083Z", "iopub.status.idle": "2023-08-18T19:32:44.896254Z", "shell.execute_reply": "2023-08-18T19:32:44.895293Z" }, "origin_pos": 44, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,\n", " pos_threshold=0.009999999):\n", " \"\"\"Predict bounding boxes using non-maximum suppression.\"\"\"\n", " device, batch_size = cls_probs.device, cls_probs.shape[0]\n", " anchors = anchors.squeeze(0)\n", " num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]\n", " out = []\n", " for i in range(batch_size):\n", " cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)\n", " conf, class_id = torch.max(cls_prob[1:], 0)\n", " predicted_bb = offset_inverse(anchors, offset_pred)\n", " keep = nms(predicted_bb, conf, nms_threshold)\n", " all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)\n", " combined = torch.cat((keep, all_idx))\n", " uniques, counts = combined.unique(return_counts=True)\n", " non_keep = uniques[counts == 1]\n", " all_id_sorted = torch.cat((keep, non_keep))\n", " class_id[non_keep] = -1\n", " class_id = class_id[all_id_sorted]\n", " conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]\n", " below_min_idx = (conf < pos_threshold)\n", " class_id[below_min_idx] = -1\n", " conf[below_min_idx] = 1 - conf[below_min_idx]\n", " pred_info = torch.cat((class_id.unsqueeze(1),\n", " conf.unsqueeze(1),\n", " predicted_bb), dim=1)\n", " out.append(pred_info)\n", " return torch.stack(out)" ] }, { "cell_type": "markdown", "id": "72eb0d88", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Apply the above implementations\n", "to a concrete example with four anchor boxes" ] }, { "cell_type": "code", "execution_count": 19, "id": "e0fd4db5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.900401Z", "iopub.status.busy": "2023-08-18T19:32:44.899731Z", "iopub.status.idle": "2023-08-18T19:32:44.906161Z", "shell.execute_reply": "2023-08-18T19:32:44.905287Z" }, "origin_pos": 46, "tab": [ "pytorch" ] }, "outputs": [], "source": [ "anchors = torch.tensor([[0.1, 0.08, 0.52, 0.92], [0.08, 0.2, 0.56, 0.95],\n", " [0.15, 0.3, 0.62, 0.91], [0.55, 0.2, 0.9, 0.88]])\n", "offset_preds = torch.tensor([0] * anchors.numel())\n", "cls_probs = torch.tensor([[0] * 4,\n", " [0.9, 0.8, 0.7, 0.1],\n", " [0.1, 0.2, 0.3, 0.9]])" ] }, { "cell_type": "markdown", "id": "ca311d2b", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Plot these predicted bounding boxes with their confidence on the image" ] }, { "cell_type": "code", "execution_count": 20, "id": "a6637511", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:44.910439Z", "iopub.status.busy": "2023-08-18T19:32:44.909614Z", "iopub.status.idle": "2023-08-18T19:32:45.357943Z", "shell.execute_reply": "2023-08-18T19:32:45.357080Z" }, "origin_pos": 48, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:32:45.272849\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = d2l.plt.imshow(img)\n", "show_bboxes(fig.axes, anchors * bbox_scale,\n", " ['dog=0.9', 'dog=0.8', 'dog=0.7', 'cat=0.9'])" ] }, { "cell_type": "markdown", "id": "117a9ef6", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "The shape of the returned result" ] }, { "cell_type": "code", "execution_count": 21, "id": "81abdaa3", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:45.361596Z", "iopub.status.busy": "2023-08-18T19:32:45.360996Z", "iopub.status.idle": "2023-08-18T19:32:45.369505Z", "shell.execute_reply": "2023-08-18T19:32:45.368704Z" }, "origin_pos": 51, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "text/plain": [ "tensor([[[ 0.00, 0.90, 0.10, 0.08, 0.52, 0.92],\n", " [ 1.00, 0.90, 0.55, 0.20, 0.90, 0.88],\n", " [-1.00, 0.80, 0.08, 0.20, 0.56, 0.95],\n", " [-1.00, 0.70, 0.15, 0.30, 0.62, 0.91]]])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = multibox_detection(cls_probs.unsqueeze(dim=0),\n", " offset_preds.unsqueeze(dim=0),\n", " anchors.unsqueeze(dim=0),\n", " nms_threshold=0.5)\n", "output" ] }, { "cell_type": "markdown", "id": "c6fd8d37", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Output the final predicted bounding box\n", "kept by non-maximum suppression" ] }, { "cell_type": "code", "execution_count": 22, "id": "a9d97bd5", "metadata": { "execution": { "iopub.execute_input": "2023-08-18T19:32:45.373009Z", "iopub.status.busy": "2023-08-18T19:32:45.372440Z", "iopub.status.idle": "2023-08-18T19:32:45.667474Z", "shell.execute_reply": "2023-08-18T19:32:45.666054Z" }, "origin_pos": 53, "tab": [ "pytorch" ] }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2023-08-18T19:32:45.565008\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.7.2, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig = d2l.plt.imshow(img)\n", "for i in output[0].detach().numpy():\n", " if i[0] == -1:\n", " continue\n", " label = ('dog=', 'cat=')[int(i[0])] + str(i[1])\n", " show_bboxes(fig.axes, [torch.tensor(i[2:]) * bbox_scale], label)" ] } ], "metadata": { "celltoolbar": "Slideshow", "language_info": { "name": "python" }, "required_libs": [], "rise": { "autolaunch": true, "enable_chalkboard": true, "overlay": "
", "scroll": true } }, "nbformat": 4, "nbformat_minor": 5 }