{
"cells": [
{
"cell_type": "markdown",
"id": "ac6561ad",
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"source": [
"# Anchor Boxes\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "14c6e4a6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:40.922407Z",
"iopub.status.busy": "2023-08-18T19:32:40.921572Z",
"iopub.status.idle": "2023-08-18T19:32:43.887912Z",
"shell.execute_reply": "2023-08-18T19:32:43.886600Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import torch\n",
"from d2l import torch as d2l\n",
"\n",
"torch.set_printoptions(2)"
]
},
{
"cell_type": "markdown",
"id": "3e327544",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"The width and height of the anchor box are $ws\\sqrt{r}$ and $hs/\\sqrt{r}$, respectively.\n",
"Consider those combinations\n",
"containing\n",
"$$(s_1, r_1), (s_1, r_2), \\ldots, (s_1, r_m), (s_2, r_1), (s_3, r_1), \\ldots, (s_n, r_1)$$"
]
},
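  {
   "cell_type": "markdown",
   "id": "anchor-wh-example-note",
   "metadata": {},
   "source": [
    "As a quick numeric sketch (the image size, scale, and ratio below are made up for illustration), we can evaluate the two expressions directly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "anchor-wh-example",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical example: a 728 x 561 image, scale s=0.75, aspect ratio r=2\n",
    "w_img, h_img, s, r = 728, 561, 0.75, 2\n",
    "w_img * s * r ** 0.5, h_img * s / r ** 0.5  # ws*sqrt(r), hs/sqrt(r)"
   ]
  },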
{
"cell_type": "code",
"execution_count": 2,
"id": "c0e17016",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:43.893598Z",
"iopub.status.busy": "2023-08-18T19:32:43.893199Z",
"iopub.status.idle": "2023-08-18T19:32:43.902717Z",
"shell.execute_reply": "2023-08-18T19:32:43.901834Z"
},
"origin_pos": 5,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def multibox_prior(data, sizes, ratios):\n",
" \"\"\"Generate anchor boxes with different shapes centered on each pixel.\"\"\"\n",
" in_height, in_width = data.shape[-2:]\n",
" device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)\n",
" boxes_per_pixel = (num_sizes + num_ratios - 1)\n",
" size_tensor = torch.tensor(sizes, device=device)\n",
" ratio_tensor = torch.tensor(ratios, device=device)\n",
" offset_h, offset_w = 0.5, 0.5\n",
" steps_h = 1.0 / in_height\n",
" steps_w = 1.0 / in_width\n",
"\n",
" center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h\n",
" center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w\n",
" shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')\n",
" shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)\n",
"\n",
" w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),\n",
" sizes[0] * torch.sqrt(ratio_tensor[1:])))\\\n",
" * in_height / in_width\n",
" h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),\n",
" sizes[0] / torch.sqrt(ratio_tensor[1:])))\n",
" anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(\n",
" in_height * in_width, 1) / 2\n",
"\n",
" out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],\n",
" dim=1).repeat_interleave(boxes_per_pixel, dim=0)\n",
" output = out_grid + anchor_manipulations\n",
" return output.unsqueeze(0)"
]
},
{
"cell_type": "markdown",
"id": "fca5f107",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"The shape of the returned anchor box variable `Y`"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0509b5af",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:43.906825Z",
"iopub.status.busy": "2023-08-18T19:32:43.906238Z",
"iopub.status.idle": "2023-08-18T19:32:44.026992Z",
"shell.execute_reply": "2023-08-18T19:32:44.025888Z"
},
"origin_pos": 8,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"561 728\n"
]
},
{
"data": {
"text/plain": [
"torch.Size([1, 2042040, 4])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"img = d2l.plt.imread('../img/catdog.jpg')\n",
"h, w = img.shape[:2]\n",
"\n",
"print(h, w)\n",
"X = torch.rand(size=(1, 3, h, w))\n",
"Y = multibox_prior(X, sizes=[0.75, 0.5, 0.25], ratios=[1, 2, 0.5])\n",
"Y.shape"
]
},
{
"cell_type": "markdown",
"id": "3b70f7ac",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Access the first anchor box centered on (250, 250)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "68fde78e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.031304Z",
"iopub.status.busy": "2023-08-18T19:32:44.030463Z",
"iopub.status.idle": "2023-08-18T19:32:44.039395Z",
"shell.execute_reply": "2023-08-18T19:32:44.038385Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.06, 0.07, 0.63, 0.82])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"boxes = Y.reshape(h, w, 5, 4)\n",
"boxes[250, 250, 0, :]"
]
},
{
"cell_type": "markdown",
"id": "29c2016a",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Show all the anchor boxes centered on one pixel in the image"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f4e0c959",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.055283Z",
"iopub.status.busy": "2023-08-18T19:32:44.054735Z",
"iopub.status.idle": "2023-08-18T19:32:44.372820Z",
"shell.execute_reply": "2023-08-18T19:32:44.371734Z"
},
"origin_pos": 14,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def show_bboxes(axes, bboxes, labels=None, colors=None):\n",
" \"\"\"Show bounding boxes.\"\"\"\n",
"\n",
" def make_list(obj, default_values=None):\n",
" if obj is None:\n",
" obj = default_values\n",
" elif not isinstance(obj, (list, tuple)):\n",
" obj = [obj]\n",
" return obj\n",
"\n",
" labels = make_list(labels)\n",
" colors = make_list(colors, ['b', 'g', 'r', 'm', 'c'])\n",
" for i, bbox in enumerate(bboxes):\n",
" color = colors[i % len(colors)]\n",
" rect = d2l.bbox_to_rect(bbox.detach().numpy(), color)\n",
" axes.add_patch(rect)\n",
" if labels and len(labels) > i:\n",
" text_color = 'k' if color == 'w' else 'w'\n",
" axes.text(rect.xy[0], rect.xy[1], labels[i],\n",
" va='center', ha='center', fontsize=9, color=text_color,\n",
" bbox=dict(facecolor=color, lw=0))\n",
"\n",
"d2l.set_figsize()\n",
"bbox_scale = torch.tensor((w, h, w, h))\n",
"fig = d2l.plt.imshow(img)\n",
"show_bboxes(fig.axes, boxes[250, 250, :, :] * bbox_scale,\n",
" ['s=0.75, r=1', 's=0.5, r=1', 's=0.25, r=1', 's=0.75, r=2',\n",
" 's=0.75, r=0.5'])"
]
},
{
"cell_type": "markdown",
"id": "d7322def",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Intersection over Union (IoU)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ed9ef04d",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.377494Z",
"iopub.status.busy": "2023-08-18T19:32:44.376956Z",
"iopub.status.idle": "2023-08-18T19:32:44.384235Z",
"shell.execute_reply": "2023-08-18T19:32:44.382935Z"
},
"origin_pos": 17,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def box_iou(boxes1, boxes2):\n",
" \"\"\"Compute pairwise IoU across two lists of anchor or bounding boxes.\"\"\"\n",
" box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *\n",
" (boxes[:, 3] - boxes[:, 1]))\n",
" areas1 = box_area(boxes1)\n",
" areas2 = box_area(boxes2)\n",
" inter_upperlefts = torch.max(boxes1[:, None, :2], boxes2[:, :2])\n",
" inter_lowerrights = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])\n",
" inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)\n",
" inter_areas = inters[:, :, 0] * inters[:, :, 1]\n",
" union_areas = areas1[:, None] + areas2 - inter_areas\n",
" return inter_areas / union_areas"
]
},
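  {
   "cell_type": "markdown",
   "id": "box-iou-check-note",
   "metadata": {},
   "source": [
    "A minimal sanity check of `box_iou` (the boxes are made up for illustration): two unit squares overlapping in a $0.5 \\times 0.5$ region should give $0.25 / 1.75 \\approx 0.14$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "box-iou-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two hypothetical unit squares whose intersection is a 0.5 x 0.5 square\n",
    "b1 = torch.tensor([[0.0, 0.0, 1.0, 1.0]])\n",
    "b2 = torch.tensor([[0.5, 0.5, 1.5, 1.5]])\n",
    "box_iou(b1, b2)  # expected: 0.25 / 1.75 = 0.14"
   ]
  },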
{
"cell_type": "markdown",
"id": "91164ce8",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Assigning Ground-Truth Bounding Boxes to Anchor Boxes"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "50237ceb",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.388937Z",
"iopub.status.busy": "2023-08-18T19:32:44.388033Z",
"iopub.status.idle": "2023-08-18T19:32:44.397131Z",
"shell.execute_reply": "2023-08-18T19:32:44.395737Z"
},
"origin_pos": 20,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def assign_anchor_to_bbox(ground_truth, anchors, device, iou_threshold=0.5):\n",
" \"\"\"Assign closest ground-truth bounding boxes to anchor boxes.\"\"\"\n",
" num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]\n",
" jaccard = box_iou(anchors, ground_truth)\n",
" anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long,\n",
" device=device)\n",
" max_ious, indices = torch.max(jaccard, dim=1)\n",
" anc_i = torch.nonzero(max_ious >= iou_threshold).reshape(-1)\n",
" box_j = indices[max_ious >= iou_threshold]\n",
" anchors_bbox_map[anc_i] = box_j\n",
" col_discard = torch.full((num_anchors,), -1)\n",
" row_discard = torch.full((num_gt_boxes,), -1)\n",
" for _ in range(num_gt_boxes):\n",
" max_idx = torch.argmax(jaccard)\n",
" box_idx = (max_idx % num_gt_boxes).long()\n",
" anc_idx = (max_idx / num_gt_boxes).long()\n",
" anchors_bbox_map[anc_idx] = box_idx\n",
" jaccard[:, box_idx] = col_discard\n",
" jaccard[anc_idx, :] = row_discard\n",
" return anchors_bbox_map"
]
},
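  {
   "cell_type": "markdown",
   "id": "assign-anchor-check-note",
   "metadata": {},
   "source": [
    "A small synthetic example (both boxes made up for illustration): the first anchor overlaps the single ground-truth box strongly and is assigned to it, while the second barely overlaps and remains background ($-1$):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "assign-anchor-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One hypothetical ground-truth box and two anchors\n",
    "gt = torch.tensor([[0.0, 0.0, 0.5, 0.5]])\n",
    "ancs = torch.tensor([[0.0, 0.0, 0.45, 0.45], [0.4, 0.4, 0.9, 0.9]])\n",
    "assign_anchor_to_bbox(gt, ancs, device=ancs.device)  # expected: [0, -1]"
   ]
  },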
{
"cell_type": "markdown",
"id": "06ff993e",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Given the central coordinates of $A$ and $B$ as $(x_a, y_a)$ and $(x_b, y_b)$, \n",
"their widths as $w_a$ and $w_b$, \n",
"and their heights as $h_a$ and $h_b$, respectively. We may label the offset of $A$ as\n",
"\n",
"$$\\left( \\frac{ \\frac{x_b - x_a}{w_a} - \\mu_x }{\\sigma_x},\n",
"\\frac{ \\frac{y_b - y_a}{h_a} - \\mu_y }{\\sigma_y},\n",
"\\frac{ \\log \\frac{w_b}{w_a} - \\mu_w }{\\sigma_w},\n",
"\\frac{ \\log \\frac{h_b}{h_a} - \\mu_h }{\\sigma_h}\\right)$$"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "555ef063",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.401240Z",
"iopub.status.busy": "2023-08-18T19:32:44.400631Z",
"iopub.status.idle": "2023-08-18T19:32:44.406150Z",
"shell.execute_reply": "2023-08-18T19:32:44.405146Z"
},
"origin_pos": 22,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def offset_boxes(anchors, assigned_bb, eps=1e-6):\n",
" \"\"\"Transform for anchor box offsets.\"\"\"\n",
" c_anc = d2l.box_corner_to_center(anchors)\n",
" c_assigned_bb = d2l.box_corner_to_center(assigned_bb)\n",
" offset_xy = 10 * (c_assigned_bb[:, :2] - c_anc[:, :2]) / c_anc[:, 2:]\n",
" offset_wh = 5 * torch.log(eps + c_assigned_bb[:, 2:] / c_anc[:, 2:])\n",
" offset = torch.cat([offset_xy, offset_wh], axis=1)\n",
" return offset"
]
},
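  {
   "cell_type": "markdown",
   "id": "offset-boxes-check-note",
   "metadata": {},
   "source": [
    "As a consistency check (the box is made up for illustration), labeling an anchor box with itself as the assigned ground-truth box should produce offsets of approximately zero (exactly zero up to the `eps` term):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "offset-boxes-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A hypothetical box offset against itself\n",
    "box = torch.tensor([[0.1, 0.1, 0.4, 0.5]])\n",
    "offset_boxes(box, box)  # expected: close to zero"
   ]
  },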
{
"cell_type": "markdown",
"id": "a076b90b",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Label classes and offsets for anchor boxes"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "16d715d6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.410237Z",
"iopub.status.busy": "2023-08-18T19:32:44.409587Z",
"iopub.status.idle": "2023-08-18T19:32:44.420196Z",
"shell.execute_reply": "2023-08-18T19:32:44.419115Z"
},
"origin_pos": 25,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def multibox_target(anchors, labels):\n",
" \"\"\"Label anchor boxes using ground-truth bounding boxes.\"\"\"\n",
" batch_size, anchors = labels.shape[0], anchors.squeeze(0)\n",
" batch_offset, batch_mask, batch_class_labels = [], [], []\n",
" device, num_anchors = anchors.device, anchors.shape[0]\n",
" for i in range(batch_size):\n",
" label = labels[i, :, :]\n",
" anchors_bbox_map = assign_anchor_to_bbox(\n",
" label[:, 1:], anchors, device)\n",
" bbox_mask = ((anchors_bbox_map >= 0).float().unsqueeze(-1)).repeat(\n",
" 1, 4)\n",
" class_labels = torch.zeros(num_anchors, dtype=torch.long,\n",
" device=device)\n",
" assigned_bb = torch.zeros((num_anchors, 4), dtype=torch.float32,\n",
" device=device)\n",
" indices_true = torch.nonzero(anchors_bbox_map >= 0)\n",
" bb_idx = anchors_bbox_map[indices_true]\n",
" class_labels[indices_true] = label[bb_idx, 0].long() + 1\n",
" assigned_bb[indices_true] = label[bb_idx, 1:]\n",
" offset = offset_boxes(anchors, assigned_bb) * bbox_mask\n",
" batch_offset.append(offset.reshape(-1))\n",
" batch_mask.append(bbox_mask.reshape(-1))\n",
" batch_class_labels.append(class_labels)\n",
" bbox_offset = torch.stack(batch_offset)\n",
" bbox_mask = torch.stack(batch_mask)\n",
" class_labels = torch.stack(batch_class_labels)\n",
" return (bbox_offset, bbox_mask, class_labels)"
]
},
{
"cell_type": "markdown",
"id": "e8bd575e",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Plot these ground-truth bounding boxes \n",
"and anchor boxes \n",
"in the image"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5241a5a4",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.424500Z",
"iopub.status.busy": "2023-08-18T19:32:44.423804Z",
"iopub.status.idle": "2023-08-18T19:32:44.818235Z",
"shell.execute_reply": "2023-08-18T19:32:44.817279Z"
},
"origin_pos": 27,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"ground_truth = torch.tensor([[0, 0.1, 0.08, 0.52, 0.92],\n",
" [1, 0.55, 0.2, 0.9, 0.88]])\n",
"anchors = torch.tensor([[0, 0.1, 0.2, 0.3], [0.15, 0.2, 0.4, 0.4],\n",
" [0.63, 0.05, 0.88, 0.98], [0.66, 0.45, 0.8, 0.8],\n",
" [0.57, 0.3, 0.92, 0.9]])\n",
"\n",
"fig = d2l.plt.imshow(img)\n",
"show_bboxes(fig.axes, ground_truth[:, 1:] * bbox_scale, ['dog', 'cat'], 'k')\n",
"show_bboxes(fig.axes, anchors * bbox_scale, ['0', '1', '2', '3', '4']);"
]
},
{
"cell_type": "markdown",
"id": "33f5ee91",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Label classes and offsets\n",
"of these anchor boxes based on\n",
"the ground-truth bounding boxes"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "27063d40",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.834163Z",
"iopub.status.busy": "2023-08-18T19:32:44.833548Z",
"iopub.status.idle": "2023-08-18T19:32:44.840313Z",
"shell.execute_reply": "2023-08-18T19:32:44.839244Z"
},
"origin_pos": 32,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 1, 2, 0, 2]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels = multibox_target(anchors.unsqueeze(dim=0),\n",
" ground_truth.unsqueeze(dim=0))\n",
"\n",
"labels[2]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "0b55433b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.844357Z",
"iopub.status.busy": "2023-08-18T19:32:44.843744Z",
"iopub.status.idle": "2023-08-18T19:32:44.850721Z",
"shell.execute_reply": "2023-08-18T19:32:44.849660Z"
},
"origin_pos": 34,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1.,\n",
" 1., 1.]])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels[1]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "3ce74d8f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.854664Z",
"iopub.status.busy": "2023-08-18T19:32:44.854067Z",
"iopub.status.idle": "2023-08-18T19:32:44.861006Z",
"shell.execute_reply": "2023-08-18T19:32:44.859933Z"
},
"origin_pos": 36,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, 1.40e+00, 1.00e+01,\n",
" 2.59e+00, 7.18e+00, -1.20e+00, 2.69e-01, 1.68e+00, -1.57e+00,\n",
" -0.00e+00, -0.00e+00, -0.00e+00, -0.00e+00, -5.71e-01, -1.00e+00,\n",
" 4.17e-06, 6.26e-01]])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels[0]"
]
},
{
"cell_type": "markdown",
"id": "1250fda8",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Applies inverse offset transformations to\n",
"return the predicted bounding box coordinates"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "2aa3364b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.865010Z",
"iopub.status.busy": "2023-08-18T19:32:44.864433Z",
"iopub.status.idle": "2023-08-18T19:32:44.870968Z",
"shell.execute_reply": "2023-08-18T19:32:44.869955Z"
},
"origin_pos": 38,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def offset_inverse(anchors, offset_preds):\n",
" \"\"\"Predict bounding boxes based on anchor boxes with predicted offsets.\"\"\"\n",
" anc = d2l.box_corner_to_center(anchors)\n",
" pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]\n",
" pred_bbox_wh = torch.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]\n",
" pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), axis=1)\n",
" predicted_bbox = d2l.box_center_to_corner(pred_bbox)\n",
" return predicted_bbox"
]
},
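  {
   "cell_type": "markdown",
   "id": "offset-inverse-check-note",
   "metadata": {},
   "source": [
    "`offset_inverse` undoes `offset_boxes`. A round-trip check (both boxes made up for illustration) should recover the assigned box up to the small `eps` term:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "offset-inverse-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical anchor and assigned ground-truth box\n",
    "anc = torch.tensor([[0.1, 0.1, 0.4, 0.5]])\n",
    "bb = torch.tensor([[0.15, 0.2, 0.45, 0.6]])\n",
    "offset_inverse(anc, offset_boxes(anc, bb))  # expected: close to bb"
   ]
  },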
{
"cell_type": "markdown",
"id": "4cb8a18b",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"The following `nms` function sorts confidence scores in descending order and returns their indices"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4ab91d73",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.874979Z",
"iopub.status.busy": "2023-08-18T19:32:44.874405Z",
"iopub.status.idle": "2023-08-18T19:32:44.881796Z",
"shell.execute_reply": "2023-08-18T19:32:44.880768Z"
},
"origin_pos": 41,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def nms(boxes, scores, iou_threshold):\n",
" \"\"\"Sort confidence scores of predicted bounding boxes.\"\"\"\n",
" B = torch.argsort(scores, dim=-1, descending=True)\n",
" keep = []\n",
" while B.numel() > 0:\n",
" i = B[0]\n",
" keep.append(i)\n",
" if B.numel() == 1: break\n",
" iou = box_iou(boxes[i, :].reshape(-1, 4),\n",
" boxes[B[1:], :].reshape(-1, 4)).reshape(-1)\n",
" inds = torch.nonzero(iou <= iou_threshold).reshape(-1)\n",
" B = B[inds + 1]\n",
" return torch.tensor(keep, device=boxes.device)"
]
},
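  {
   "cell_type": "markdown",
   "id": "nms-check-note",
   "metadata": {},
   "source": [
    "A minimal demonstration of `nms` on its own (boxes and scores made up for illustration): the first two boxes overlap heavily, so only the higher-scoring one survives:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "nms-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three hypothetical boxes: the first two overlap, the third is far away\n",
    "bxs = torch.tensor([[0.0, 0.0, 1.0, 1.0],\n",
    "                    [0.05, 0.05, 1.0, 1.0],\n",
    "                    [2.0, 2.0, 3.0, 3.0]])\n",
    "scrs = torch.tensor([0.9, 0.8, 0.7])\n",
    "nms(bxs, scrs, iou_threshold=0.5)  # expected: indices [0, 2]"
   ]
  },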
{
"cell_type": "markdown",
"id": "0e7e8486",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Apply non-maximum suppression\n",
"to predicting bounding boxes"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "4a5b6c3c",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.885871Z",
"iopub.status.busy": "2023-08-18T19:32:44.885083Z",
"iopub.status.idle": "2023-08-18T19:32:44.896254Z",
"shell.execute_reply": "2023-08-18T19:32:44.895293Z"
},
"origin_pos": 44,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,\n",
" pos_threshold=0.009999999):\n",
" \"\"\"Predict bounding boxes using non-maximum suppression.\"\"\"\n",
" device, batch_size = cls_probs.device, cls_probs.shape[0]\n",
" anchors = anchors.squeeze(0)\n",
" num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]\n",
" out = []\n",
" for i in range(batch_size):\n",
" cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)\n",
" conf, class_id = torch.max(cls_prob[1:], 0)\n",
" predicted_bb = offset_inverse(anchors, offset_pred)\n",
" keep = nms(predicted_bb, conf, nms_threshold)\n",
" all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)\n",
" combined = torch.cat((keep, all_idx))\n",
" uniques, counts = combined.unique(return_counts=True)\n",
" non_keep = uniques[counts == 1]\n",
" all_id_sorted = torch.cat((keep, non_keep))\n",
" class_id[non_keep] = -1\n",
" class_id = class_id[all_id_sorted]\n",
" conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]\n",
" below_min_idx = (conf < pos_threshold)\n",
" class_id[below_min_idx] = -1\n",
" conf[below_min_idx] = 1 - conf[below_min_idx]\n",
" pred_info = torch.cat((class_id.unsqueeze(1),\n",
" conf.unsqueeze(1),\n",
" predicted_bb), dim=1)\n",
" out.append(pred_info)\n",
" return torch.stack(out)"
]
},
{
"cell_type": "markdown",
"id": "72eb0d88",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Apply the above implementations\n",
"to a concrete example with four anchor boxes"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "e0fd4db5",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.900401Z",
"iopub.status.busy": "2023-08-18T19:32:44.899731Z",
"iopub.status.idle": "2023-08-18T19:32:44.906161Z",
"shell.execute_reply": "2023-08-18T19:32:44.905287Z"
},
"origin_pos": 46,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"anchors = torch.tensor([[0.1, 0.08, 0.52, 0.92], [0.08, 0.2, 0.56, 0.95],\n",
" [0.15, 0.3, 0.62, 0.91], [0.55, 0.2, 0.9, 0.88]])\n",
"offset_preds = torch.tensor([0] * anchors.numel())\n",
"cls_probs = torch.tensor([[0] * 4,\n",
" [0.9, 0.8, 0.7, 0.1],\n",
" [0.1, 0.2, 0.3, 0.9]])"
]
},
{
"cell_type": "markdown",
"id": "ca311d2b",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Plot these predicted bounding boxes with their confidence on the image"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "a6637511",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:44.910439Z",
"iopub.status.busy": "2023-08-18T19:32:44.909614Z",
"iopub.status.idle": "2023-08-18T19:32:45.357943Z",
"shell.execute_reply": "2023-08-18T19:32:45.357080Z"
},
"origin_pos": 48,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = d2l.plt.imshow(img)\n",
"show_bboxes(fig.axes, anchors * bbox_scale,\n",
" ['dog=0.9', 'dog=0.8', 'dog=0.7', 'cat=0.9'])"
]
},
{
"cell_type": "markdown",
"id": "117a9ef6",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"The shape of the returned result"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "81abdaa3",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:45.361596Z",
"iopub.status.busy": "2023-08-18T19:32:45.360996Z",
"iopub.status.idle": "2023-08-18T19:32:45.369505Z",
"shell.execute_reply": "2023-08-18T19:32:45.368704Z"
},
"origin_pos": 51,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.00, 0.90, 0.10, 0.08, 0.52, 0.92],\n",
" [ 1.00, 0.90, 0.55, 0.20, 0.90, 0.88],\n",
" [-1.00, 0.80, 0.08, 0.20, 0.56, 0.95],\n",
" [-1.00, 0.70, 0.15, 0.30, 0.62, 0.91]]])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output = multibox_detection(cls_probs.unsqueeze(dim=0),\n",
" offset_preds.unsqueeze(dim=0),\n",
" anchors.unsqueeze(dim=0),\n",
" nms_threshold=0.5)\n",
"output"
]
},
{
"cell_type": "markdown",
"id": "c6fd8d37",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Output the final predicted bounding box\n",
"kept by non-maximum suppression"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "a9d97bd5",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:32:45.373009Z",
"iopub.status.busy": "2023-08-18T19:32:45.372440Z",
"iopub.status.idle": "2023-08-18T19:32:45.667474Z",
"shell.execute_reply": "2023-08-18T19:32:45.666054Z"
},
"origin_pos": 53,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig = d2l.plt.imshow(img)\n",
"for i in output[0].detach().numpy():\n",
" if i[0] == -1:\n",
" continue\n",
" label = ('dog=', 'cat=')[int(i[0])] + str(i[1])\n",
" show_bboxes(fig.axes, [torch.tensor(i[2:]) * bbox_scale], label)"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}