{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai import *\n", "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PATH = Path('../data/coco')\n", "ANNOT_PATH = PATH/'annotations'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds = ObjectDetectDataset.from_json(PATH/'train2017', ANNOT_PATH/'train_sample.json')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = get_transforms()\n", "train_tds = DatasetTfm(train_ds, tfms[0], tfm_y=True, size=224)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = train_tds[5]\n", "x.show(y=y, classes=train_ds.classes, figsize=(6,4))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "size = 224" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = ([flip_lr(p=0.5), crop_pad(size=size)], [crop_pad(size=size)])\n", "train_tds = DatasetTfm(train_ds, tfms[0], tfm_y=True, size=size, padding_mode='zeros', do_crop=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = train_tds[0]\n", "x.show(y=y, classes=train_ds.classes, figsize=(6,4))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y.data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x.size" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def bb_pad_collate(samples:BatchSamples, pad_idx:int=0, pad_first:bool=True) -> Tuple[FloatTensor, Tuple[LongTensor, LongTensor]]:\n", " \"Function that collect samples and adds padding.\"\n", " max_len = max([len(s[1].data[1]) for s in samples])\n", " bboxes = torch.zeros(len(samples), max_len, 4)\n", " labels = torch.zeros(len(samples), max_len).long() + pad_idx\n", " imgs = []\n", " for i,s in enumerate(samples): \n", " imgs.append(s[0].data[None])\n", " bbs, lbls = s[1].data\n", " bboxes[i,-len(lbls):] = bbs\n", " labels[i,-len(lbls):] = lbls\n", " return torch.cat(imgs,0), (bboxes,labels)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_dl = DataLoader(train_tds, 64, shuffle=False, collate_fn=bb_pad_collate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def show_sample(dl, rows, start=0):\n", " x,y = next(iter(dl))\n", " x = x[start:start+rows*rows].cpu()\n", " _,axs = plt.subplots(rows,rows,figsize=(9,9))\n", " for i, ax in enumerate(axs.flatten()):\n", " img = Image(x[i])\n", " idxs = y[1][start+i].nonzero()[:,0]\n", " if len(idxs) != 0:\n", " bbs,lbls = y[0][start+i][idxs],y[1][start+i][idxs]\n", " h,w = img.size\n", " bbs = ((bbs+1) * torch.tensor([h/2,w/2, h/2, w/2])).long()\n", " bbox = ImageBBox.create(bbs, *img.size, lbls)\n", " img.show(ax=ax, y=bbox, classes=dl.dataset.classes)\n", " else: img.show(ax=ax)\n", " plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_sample(train_dl, 3, 18)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds, valid_ds = ObjectDetectDataset.from_json(PATH/'train2017', ANNOT_PATH/'train_sample.json', valid_pct=0.2)\n", "data = DataBunch.create(train_ds, valid_ds, path=PATH, ds_tfms=tfms, tfms=imagenet_norm, collate_fn=bb_pad_collate, \n", " num_workers=8, bs=16, size=128, tfm_y=True, padding_mode='zeros', do_crop=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def show_sample(dl, rows, denorm=None):\n", " x,y = next(iter(dl))\n", " x = x[:rows*rows].cpu()\n", " if denorm: x = denorm(x)\n", " _,axs = plt.subplots(rows,rows,figsize=(9,9))\n", " for i, ax in enumerate(axs.flatten()):\n", " img = Image(x[i])\n", " idxs = y[1][i].nonzero()[:,0]\n", " if len(idxs) != 0:\n", " bbs,lbls = y[0][i][idxs],y[1][i][idxs]\n", " h,w = img.size\n", " bbs = ((bbs.cpu()+1) * torch.tensor([h/2,w/2, h/2, w/2])).long()\n", " bbox = ImageBBox.create(bbs, *img.size, lbls)\n", " img.show(ax=ax, y=bbox, classes=dl.dataset.classes)\n", " else: img.show(ax=ax)\n", " plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_sample(data.train_dl, 3, denorm=imagenet_denorm)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _get_sfs_idxs(sizes:Sizes) -> List[int]:\n", " \"Get the indexes of the layers where the size of the activation changes.\"\n", " feature_szs = [size[-1] for size in sizes]\n", " sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])\n", " if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs\n", " return sfs_idxs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "encoder = create_body(tvm.resnet50(True), -2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class LateralUpsampleMerge(nn.Module):\n", " \n", " def __init__(self, ch, ch_lat, hook):\n", " super().__init__()\n", " self.hook = hook\n", " self.conv_lat = conv2d(ch_lat, ch, ks=1, bias=True)\n", " \n", " def forward(self, x):\n", " return self.conv_lat(self.hook.stored) + F.interpolate(x, scale_factor=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class RetinaNet(nn.Module):\n", " \"Implements RetinaNet from https://arxiv.org/abs/1708.02002\"\n", " def __init__(self, encoder:Model, n_classes, final_bias=0., chs=256, n_anchors=9, flatten=True):\n", " super().__init__()\n", " self.n_classes,self.flatten = n_classes,flatten\n", " imsize = (256,256)\n", " sfs_szs,x,hooks = model_sizes(encoder, size=imsize)\n", " sfs_idxs = _get_sfs_idxs(sfs_szs)\n", " self.encoder = encoder\n", " self.c5top5 = conv2d(sfs_szs[-1][1], chs, ks=1, bias=True)\n", " self.c5top6 = conv2d(sfs_szs[-1][1], chs, stride=2, bias=True)\n", " self.p6top7 = nn.Sequential(nn.ReLU(), conv2d(chs, chs, stride=2, bias=True))\n", " self.merges = nn.ModuleList([LateralUpsampleMerge(chs, szs[1], hook) \n", " for szs,hook in zip(sfs_szs[-2:-4:-1], hooks[-2:-4:-1])])\n", " self.smoothers = nn.ModuleList([conv2d(chs, chs, 3, bias=True) for _ in range(3)])\n", " self.classifier = self._head_subnet(n_classes, n_anchors, final_bias, chs=chs)\n", " self.box_regressor = self._head_subnet(4, n_anchors, 0., chs=chs)\n", " \n", " def _head_subnet(self, n_classes, n_anchors, final_bias=0., n_conv=4, chs=256):\n", " layers = [conv2d_relu(chs, chs, bias=True) for _ in range(n_conv)]\n", " layers += [conv2d(chs, n_classes * n_anchors, bias=True)]\n", " layers[-1].bias.data.zero_().add_(final_bias)\n", " layers[-1].weight.data.fill_(0)\n", " return nn.Sequential(*layers)\n", " \n", " def _apply_transpose(self, func, p_states, n_classes):\n", " if not self.flatten: \n", " sizes = [[p.size(0), p.size(2), p.size(3)] for p in p_states]\n", " return [func(p).permute(0,2,3,1).view(*sz,-1,n_classes) for p,sz in zip(p_states,sizes)]\n", " else:\n", " return torch.cat([func(p).permute(0,2,3,1).contiguous().view(p.size(0),-1,n_classes) for p in p_states],1)\n", " \n", " def forward(self, x):\n", " c5 = self.encoder(x)\n", " p_states = [self.c5top5(c5.clone()), self.c5top6(c5)]\n", " p_states.append(self.p6top7(p_states[-1]))\n", " for merge in self.merges: p_states = [merge(p_states[0])] + p_states\n", " for i, smooth in enumerate(self.smoothers[:3]):\n", " p_states[i] = smooth(p_states[i])\n", " return [self._apply_transpose(self.classifier, p_states, self.n_classes), \n", " self._apply_transpose(self.box_regressor, p_states, 4),\n", " [[p.size(2), p.size(3)] for p in p_states]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "encoder = create_body(tvm.resnet50(True), -2)\n", "model = RetinaNet(encoder, 6, -4) \n", "model.eval()\n", "x = torch.randn(2,3,256,256)\n", "output = model(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "[y.size() for y in output[:2]], output[2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Anchors" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to create the corresponding anchors in this order:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.arange(1,17).long().view(4,4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def create_grid(size):\n", " \"Create a grid of a given `size`.\"\n", " H, W = size if is_tuple(size) else (size,size)\n", " grid = FloatTensor(H, W, 2)\n", " linear_points = torch.linspace(-1+1/W, 1-1/W, W) if W > 1 else tensor([0.])\n", " grid[:, :, 1] = torch.ger(torch.ones(H), linear_points).expand_as(grid[:, :, 0])\n", " linear_points = torch.linspace(-1+1/H, 1-1/H, H) if H > 1 else tensor([0.])\n", " grid[:, :, 0] = torch.ger(linear_points, torch.ones(W)).expand_as(grid[:, :, 1])\n", " return grid.view(-1,2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convention (-1.,-1.) to (1.,1.), first is y, second is x (like for the bboxes). -1 is left/top, 1 is right/bottom." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def show_anchors(ancs, size):\n", " _,ax = plt.subplots(1,1, figsize=(5,5))\n", " ax.set_xticks(np.linspace(-1,1, size[1]+1))\n", " ax.set_yticks(np.linspace(-1,1, size[0]+1))\n", " ax.grid()\n", " ax.scatter(ancs[:,1], ancs[:,0]) #y is first\n", " ax.set_yticklabels([])\n", " ax.set_xticklabels([])\n", " ax.set_xlim(-1,1)\n", " ax.set_ylim(1,-1) #-1 is top, 1 is bottom\n", " for i, (x, y) in enumerate(zip(ancs[:, 1], ancs[:, 0])): ax.annotate(i, xy = (x,y))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "size = (4,4)\n", "show_anchors(create_grid(size), size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def create_anchors(sizes, ratios, scales, flatten=True):\n", " \"Create anchor of `sizes`, `ratios` and `scales`.\"\n", " aspects = [[[s*math.sqrt(r), s*math.sqrt(1/r)] for s in scales] for r in ratios]\n", " aspects = torch.tensor(aspects).view(-1,2)\n", " anchors = []\n", " for h,w in sizes:\n", " #4 here to have the anchors overlap.\n", " sized_aspects = 4 * (aspects * torch.tensor([2/h,2/w])).unsqueeze(0)\n", " base_grid = create_grid((h,w)).unsqueeze(1)\n", " n,a = base_grid.size(0),aspects.size(0)\n", " ancs = torch.cat([base_grid.expand(n,a,2), sized_aspects.expand(n,a,2)], 2)\n", " anchors.append(ancs.view(h,w,a,4))\n", " return torch.cat([anc.view(-1,4) for anc in anchors],0) if flatten else anchors" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ratios = [1/2,1,2]\n", "#scales = [1,2**(-1/3), 2**(-2/3)]\n", "scales = [1,2**(1/3), 2**(2/3)]\n", "sizes = [(2**i,2**i) for i in range(5)]\n", "sizes.reverse()\n", "anchors = create_anchors(sizes, ratios, scales)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "anchors.size()\n", "#[anc.size() for anc in anchors]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.cm as cmx\n", "import matplotlib.colors as mcolors\n", "from cycler import cycler\n", "\n", "def get_cmap(N):\n", " color_norm = mcolors.Normalize(vmin=0, vmax=N-1)\n", " return cmx.ScalarMappable(norm=color_norm, cmap='Set3').to_rgba\n", "\n", "num_color = 12\n", "cmap = get_cmap(num_color)\n", "color_list = [cmap(float(x)) for x in range(num_color)]\n", "\n", "def draw_outline(o, lw):\n", " o.set_path_effects([patheffects.Stroke(\n", " linewidth=lw, foreground='black'), patheffects.Normal()])\n", "\n", "def draw_rect(ax, b, color='white'):\n", " patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2))\n", " draw_outline(patch, 4)\n", "\n", "def draw_text(ax, xy, txt, sz=14, color='white'):\n", " text = ax.text(*xy, txt,\n", " verticalalignment='top', color=color, fontsize=sz, weight='bold')\n", " draw_outline(text, 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def show_boxes(boxes):\n", " \"Show the `boxes` (size by 4)\"\n", " _, ax = plt.subplots(1,1, figsize=(5,5))\n", " ax.set_xlim(-1,1)\n", " ax.set_ylim(1,-1)\n", " for i, bbox in enumerate(boxes):\n", " bb = bbox.numpy()\n", " rect = [bb[1]-bb[3]/2, bb[0]-bb[2]/2, bb[3], bb[2]]\n", " draw_rect(ax, rect, color=color_list[i%num_color])\n", " draw_text(ax, [bb[1]-bb[3]/2,bb[0]-bb[2]/2], str(i), color=color_list[i%num_color])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_boxes(anchors[-9:])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def activ_to_bbox(acts, anchors, flatten=True):\n", " \"Extrapolate bounding boxes on anchors from the model activations.\"\n", " if flatten:\n", " acts.mul_(acts.new_tensor([[0.1, 0.1, 0.2, 0.2]]))\n", " centers = anchors[...,2:] * acts[...,:2] + anchors[...,:2]\n", " sizes = anchors[...,2:] * torch.exp(acts[...,:2])\n", " return torch.cat([centers, sizes], -1)\n", " else: return [activ_to_bbox(act,anc) for act,anc in zip(acts, anchors)]\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "size=(3,4)\n", "anchors = create_grid(size)\n", "anchors = torch.cat([anchors, torch.tensor([2/size[0],2/size[1]]).expand_as(anchors)], 1)\n", "activations = 0.1 * torch.randn(size[0]*size[1], 4)\n", "bboxes = activ_to_bbox(activations, anchors)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_boxes(bboxes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def cthw2tlbr(boxes):\n", " \"Convert center/size format `boxes` to top/left bottom/right corners.\"\n", " top_left = boxes[:,:2] - boxes[:,2:]/2\n", " bot_right = boxes[:,:2] + boxes[:,2:]/2\n", " return torch.cat([top_left, bot_right], 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def intersection(anchors, targets):\n", " \"Compute the sizes of the intersections of `anchors` by `targets`.\"\n", " ancs, tgts = cthw2tlbr(anchors), cthw2tlbr(targets)\n", " a, t = ancs.size(0), tgts.size(0)\n", " ancs, tgts = ancs.unsqueeze(1).expand(a,t,4), tgts.unsqueeze(0).expand(a,t,4)\n", " top_left_i = torch.max(ancs[...,:2], tgts[...,:2])\n", " bot_right_i = torch.min(ancs[...,2:], tgts[...,2:])\n", " sizes = torch.clamp(bot_right_i - top_left_i, min=0) \n", " return sizes[...,0] * sizes[...,1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_boxes(anchors)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "targets = torch.tensor([[0.,0.,2.,2.], [-0.5,-0.5,1.,1.], [1/3,0.5,0.5,0.5]])\n", "show_boxes(targets)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "intersection(anchors, targets)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def IoU_values(anchors, targets):\n", " \"Compute the IoU values of `anchors` by `targets`.\"\n", " inter = intersection(anchors, targets)\n", " anc_sz, tgt_sz = anchors[:,2] * anchors[:,3], targets[:,2] * targets[:,3]\n", " union = anc_sz.unsqueeze(1) + tgt_sz.unsqueeze(0) - inter\n", " return inter/(union+1e-8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IoU_values(anchors, targets)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Manually checked that those are right." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def match_anchors(anchors, targets, match_thr=0.5, bkg_thr=0.4):\n", " \"Match `anchors` to targets. -1 is match to background, -2 is ignore.\"\n", " ious = IoU_values(anchors, targets)\n", " matches = anchors.new(anchors.size(0)).zero_().long() - 2\n", " vals,idxs = torch.max(ious,1)\n", " matches[vals < bkg_thr] = -1\n", " matches[vals > match_thr] = idxs[vals > match_thr]\n", " #Overwrite matches with each target getting the anchor that has the max IoU.\n", " #vals,idxs = torch.max(ious,0)\n", " #If idxs contains repetition, this doesn't bug and only the last is considered.\n", " #matches[idxs] = targets.new_tensor(list(range(targets.size(0)))).long()\n", " return matches" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Last example" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "match_anchors(anchors, targets)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With anchors very close to the targets." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "size=(3,4)\n", "anchors = create_grid(size)\n", "anchors = torch.cat([anchors, torch.tensor([2/size[0],2/size[1]]).expand_as(anchors)], 1)\n", "activations = 0.1 * torch.randn(size[0]*size[1], 4)\n", "bboxes = activ_to_bbox(activations, anchors)\n", "match_anchors(anchors,bboxes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With anchors in the grey area." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "anchors = create_grid((2,2))\n", "anchors = torch.cat([anchors, torch.tensor([1.,1.]).expand_as(anchors)], 1)\n", "targets = anchors.clone()\n", "anchors = torch.cat([anchors, torch.tensor([[-0.5,0.,1.,1.8]])], 0)\n", "match_anchors(anchors,targets)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def tlbr2cthw(boxes):\n", " \"Convert top/left bottom/right format `boxes` to center/size corners.\"\n", " center = (boxes[:,:2] + boxes[:,2:])/2\n", " sizes = boxes[:,2:] - boxes[:,:2]\n", " return torch.cat([center, sizes], 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def bbox_to_activ(bboxes, anchors, flatten=True):\n", " \"Return the target of the model on `anchors` for the `bboxes`.\"\n", " if flatten:\n", " t_centers = (bboxes[...,:2] - anchors[...,:2]) / anchors[...,2:] \n", " t_sizes = torch.log(bboxes[...,2:] / anchors[...,2:] + 1e-8) \n", " return torch.cat([t_centers, t_sizes], -1).div_(bboxes.new_tensor([[0.1, 0.1, 0.2, 0.2]]))\n", " else: return [activ_to_bbox(act,anc) for act,anc in zip(acts, anchors)]\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def encode_class(idxs, n_classes):\n", " target = idxs.new_zeros(len(idxs), n_classes).float()\n", " mask = idxs != 0\n", " i1s = LongTensor(list(range(len(idxs))))\n", " target[i1s[mask],idxs[mask]-1] = 1\n", " return target" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "encode_class(LongTensor([1,2,0,1,3]),3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class RetinaNetFocalLoss(nn.Module):\n", " \n", " def __init__(self, gamma:float=2., alpha:float=0.25, pad_idx:int=0, scales:Collection[float]=None, \n", " ratios:Collection[float]=None, reg_loss:LossFunction=F.smooth_l1_loss):\n", " super().__init__()\n", " self.gamma,self.alpha,self.pad_idx,self.reg_loss = gamma,alpha,pad_idx,reg_loss\n", " self.scales = ifnone(scales, [1,2**(-1/3), 2**(-2/3)])\n", " self.ratios = ifnone(ratios, [1/2,1,2])\n", " \n", " def _change_anchors(self, sizes:Sizes) -> bool:\n", " if not hasattr(self, 'sizes'): return True\n", " for sz1, sz2 in zip(self.sizes, sizes):\n", " if sz1[0] != sz2[0] or sz1[1] != sz2[1]: return True\n", " return False\n", " \n", " def _create_anchors(self, sizes:Sizes, device:torch.device):\n", " self.sizes = sizes\n", " self.anchors = create_anchors(sizes, self.ratios, self.scales).to(device)\n", " \n", " def _unpad(self, bbox_tgt, clas_tgt):\n", " i = torch.min(torch.nonzero(clas_tgt-self.pad_idx))\n", " return tlbr2cthw(bbox_tgt[i:]), clas_tgt[i:]-1+self.pad_idx\n", " \n", " def _focal_loss(self, clas_pred, clas_tgt):\n", " encoded_tgt = encode_class(clas_tgt, clas_pred.size(1))\n", " ps = torch.sigmoid(clas_pred)\n", " weights = encoded_tgt * (1-ps) + (1-encoded_tgt) * ps\n", " alphas = (1-encoded_tgt) * self.alpha + encoded_tgt * (1-self.alpha)\n", " weights.pow_(self.gamma).mul_(alphas)\n", " clas_loss = F.binary_cross_entropy_with_logits(clas_pred, encoded_tgt, weights, reduction='sum')\n", " return clas_loss\n", " \n", " def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):\n", " bbox_tgt, clas_tgt = self._unpad(bbox_tgt, clas_tgt)\n", " matches = match_anchors(self.anchors, bbox_tgt)\n", " bbox_mask = matches>=0\n", " if bbox_mask.sum() != 0:\n", " bbox_pred = bbox_pred[bbox_mask]\n", " bbox_tgt = bbox_tgt[matches[bbox_mask]]\n", " bb_loss = self.reg_loss(bbox_pred, bbox_to_activ(bbox_tgt, self.anchors[bbox_mask]))\n", " else: bb_loss = 0.\n", " matches.add_(1)\n", " clas_tgt = clas_tgt + 1\n", " clas_mask = matches>=0\n", " clas_pred = clas_pred[clas_mask]\n", " clas_tgt = torch.cat([clas_tgt.new_zeros(1).long(), clas_tgt])\n", " clas_tgt = clas_tgt[matches[clas_mask]]\n", " return bb_loss + self._focal_loss(clas_pred, clas_tgt)/torch.clamp(bbox_mask.sum(), min=1.)\n", " \n", " def forward(self, output, bbox_tgts, clas_tgts):\n", " clas_preds, bbox_preds, sizes = output\n", " if self._change_anchors(sizes): self._create_anchors(sizes, clas_preds.device)\n", " n_classes = clas_preds.size(2)\n", " return sum([self._one_loss(cp, bp, ct, bt)\n", " for (cp, bp, ct, bt) in zip(clas_preds, bbox_preds, clas_tgts, bbox_tgts)])/clas_tgts.size(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternative to the L1 smooth loss used in online implementations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class SigmaL1SmoothLoss(nn.Module):\n", "\n", " def forward(self, output, target):\n", " reg_diff = torch.abs(target - output)\n", " reg_loss = torch.where(torch.le(reg_diff, 1/9), 4.5 * torch.pow(reg_diff, 2), reg_diff - 1/18)\n", " return reg_loss.mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sketch to test the loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "LongTensor([[[0,0,64,128,0], [32,64,128,128,1]], [[128,96,256,192,2], [96,192,128,256,3]]]).float().cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tgt_clas = LongTensor([[1,2], [3,4]])\n", "tgt_bbox = FloatTensor([[[0,0,128,64], [64,32,128,128]], [[96,128,192,256], [192,96,256,128]]])\n", "tgt_bbox = tgt_bbox / 128 - 1.\n", "y = [tgt_bbox.cuda(), tgt_clas.cuda()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "clas = torch.load(PATH/'models'/'tst_clas.pth')\n", "regr = torch.load(PATH/'models'/'tst_regr.pth')\n", "sizes = [[32, 32], [16, 16], [8, 8], [4, 4], [2, 2]]\n", "output = [logit(clas), regr, sizes]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "crit(output, *y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Checking the output" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def unpad(tgt_bbox, tgt_clas, pad_idx=0):\n", " i = torch.min(torch.nonzero(tgt_clas-pad_idx))\n", " return tlbr2cthw(tgt_bbox[i:]), tgt_clas[i:]-1+pad_idx" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 0\n", "clas_pred,bbox_pred,sizes = output[0][idx].cpu(), output[1][idx].cpu(), output[2]\n", "bbox_tgt, clas_tgt = y[0][idx].cpu(),y[1][idx].cpu()\n", "bbox_tgt, clas_tgt = unpad(bbox_tgt, clas_tgt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bbox_tgt" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "anchors = create_anchors(sizes, ratios, scales)\n", "ious = IoU_values(anchors, bbox_tgt)\n", "matches = match_anchors(anchors, bbox_tgt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ious[-9:]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "(matches==-2).sum(), (matches==-1).sum(), (matches>=0).sum()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bbox_mask = matches>=0\n", "bbox_pred = bbox_pred[bbox_mask]\n", "bbox_tgt = bbox_tgt[matches[bbox_mask]]\n", "bb_loss = F.smooth_l1_loss(bbox_pred, bbox_to_activ(bbox_tgt, anchors[bbox_mask]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "F.smooth_l1_loss(bbox_pred, bbox_to_activ(bbox_tgt, anchors[bbox_mask]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_loss = SigmaL1SmoothLoss()\n", "tst_loss(bbox_pred, bbox_to_activ(bbox_tgt, anchors[bbox_mask]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "crit.reg_loss(bbox_pred, bbox_to_activ(bbox_tgt, anchors[bbox_mask]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "matches.add_(1)\n", "clas_tgt += 1\n", "clas_mask = matches>=0\n", "clas_pred = clas_pred[clas_mask]\n", "clas_tgt = torch.cat([clas_tgt.new_zeros(1).long(), clas_tgt])\n", "clas_tgt = clas_tgt[matches[clas_mask]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Focal loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "alpha, gamma, n_classes = 0.25, 2., 6\n", "encoded_tgt = encode_class(clas_tgt, n_classes)\n", "ps = torch.sigmoid(clas_pred)\n", "weights = encoded_tgt * (1-ps) + (1-encoded_tgt) * ps\n", "alphas = encoded_tgt * alpha + (1-encoded_tgt) * (1-alpha)\n", "weights.pow_(gamma).mul_(alphas)\n", "clas_loss = F.binary_cross_entropy_with_logits(clas_pred, encoded_tgt, weights, reduction='sum') / bbox_mask.sum()\n", "clas_loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at the objects missclassified." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "clas_pred[clas_tgt.nonzero().squeeze()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "F.binary_cross_entropy_with_logits(clas_pred[clas_tgt.nonzero().squeeze()], encoded_tgt[clas_tgt.nonzero().squeeze()], weights[clas_tgt.nonzero().squeeze()], reduction='sum') / bbox_mask.sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "They account for half the loss!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_classes = 6\n", "encoder = create_body(tvm.resnet50(True), -2)\n", "model = RetinaNet(encoder, n_classes,final_bias=-4) \n", "crit = RetinaNetFocalLoss(scales=scales, ratios=ratios)\n", "learn = Learner(data, model, loss_fn=crit)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.split([model.encoder[6], model.c5top5])\n", "learn.freeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit_one_cycle(1, 1e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('sample')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.load('sample')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "img,target = next(iter(data.valid_dl))\n", "with torch.no_grad():\n", " output = model(img)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.save(img, PATH/'models'/'tst_input.pth')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _draw_outline(o:Patch, lw:int):\n", " \"Outline bounding box onto image `Patch`.\"\n", " o.set_path_effects([patheffects.Stroke(\n", " linewidth=lw, foreground='black'), patheffects.Normal()])\n", "\n", "def draw_rect(ax:plt.Axes, b:Collection[int], color:str='white', text=None, text_size=14):\n", " \"Draw bounding box on `ax`.\"\n", " patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2))\n", " _draw_outline(patch, 4)\n", " if text is not None:\n", " patch = ax.text(*b[:2], text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')\n", " _draw_outline(patch,1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def show_preds(img, output, idx, detect_thresh=0.3, classes=None):\n", " clas_pred,bbox_pred,sizes = output[0][idx].cpu(), output[1][idx].cpu(), output[2]\n", " anchors = create_anchors(sizes, ratios, scales)\n", " bbox_pred = activ_to_bbox(bbox_pred, anchors)\n", " clas_pred = torch.sigmoid(clas_pred)\n", " detect_mask = clas_pred.max(1)[0] > detect_thresh\n", " bbox_pred, clas_pred = bbox_pred[detect_mask], clas_pred[detect_mask]\n", " t_sz = torch.Tensor([*img.size])[None].float()\n", " bbox_pred[:,:2] = bbox_pred[:,:2] - bbox_pred[:,2:]/2\n", " bbox_pred[:,:2] = (bbox_pred[:,:2] + 1) * t_sz/2\n", " bbox_pred[:,2:] = bbox_pred[:,2:] * t_sz\n", " bbox_pred = bbox_pred.long()\n", " _, ax = plt.subplots(1,1)\n", " for bbox, c in zip(bbox_pred, clas_pred.argmax(1)):\n", " img.show(ax=ax)\n", " txt = str(c.item()) if classes is None else classes[c.item()+1]\n", " draw_rect(ax, [bbox[1],bbox[0],bbox[3],bbox[2]], text=txt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 0\n", "img = data.valid_ds[idx][0]\n", "classes = data.train_ds.classes\n", "show_preds(img, output, idx, detect_thresh=0.2, classes=classes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def nms(boxes, scores, thresh=0.5):\n", " idx_sort = scores.argsort(descending=True)\n", " boxes, scores = boxes[idx_sort], scores[idx_sort]\n", " to_keep, indexes = [], torch.LongTensor(range_of(scores))\n", " while len(scores) > 0:\n", " #pdb.set_trace()\n", " to_keep.append(idx_sort[indexes[0]])\n", " iou_vals = IoU_values(boxes, boxes[:1]).squeeze()\n", " mask_keep = iou_vals <= thresh\n", " if len(mask_keep.nonzero()) == 0: break\n", " idx_first = mask_keep.nonzero().min().item()\n", " boxes, scores, indexes = boxes[mask_keep], scores[mask_keep], indexes[mask_keep]\n", " return LongTensor(to_keep)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def process_output(output, i, detect_thresh=0.25):\n", " clas_pred,bbox_pred,sizes = output[0][i], output[1][i], output[2]\n", " anchors = create_anchors(sizes, ratios, scales).to(clas_pred.device)\n", " bbox_pred = activ_to_bbox(bbox_pred, anchors)\n", " clas_pred = torch.sigmoid(clas_pred)\n", " detect_mask = clas_pred.max(1)[0] > detect_thresh\n", " bbox_pred, clas_pred = bbox_pred[detect_mask], clas_pred[detect_mask]\n", " bbox_pred = tlbr2cthw(torch.clamp(cthw2tlbr(bbox_pred), min=-1, max=1)) \n", " scores, preds = clas_pred.max(1)\n", " return bbox_pred, scores, preds" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def show_preds(img, output, idx, detect_thresh=0.25, classes=None):\n", " bbox_pred, scores, preds = process_output(output, idx, detect_thresh)\n", " to_keep = nms(bbox_pred, scores)\n", " bbox_pred, preds, scores = bbox_pred[to_keep].cpu(), preds[to_keep].cpu(), scores[to_keep].cpu()\n", " t_sz = torch.Tensor([*img.size])[None].float()\n", " bbox_pred[:,:2] = bbox_pred[:,:2] - bbox_pred[:,2:]/2\n", " bbox_pred[:,:2] = (bbox_pred[:,:2] + 1) * t_sz/2\n", " bbox_pred[:,2:] = bbox_pred[:,2:] * t_sz\n", " bbox_pred = bbox_pred.long()\n", " _, ax = plt.subplots(1,1)\n", " for bbox, c, scr in zip(bbox_pred, preds, scores):\n", " img.show(ax=ax)\n", " txt = str(c.item()) if classes is None else classes[c.item()+1]\n", " draw_rect(ax, [bbox[1],bbox[0],bbox[3],bbox[2]], text=f'{txt} {scr:.2f}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 0\n", "img = data.valid_ds[idx][0]\n", "show_preds(img, output, idx, detect_thresh=0.2, classes=data.classes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def get_predictions(output, idx, detect_thresh=0.05):\n", " bbox_pred, scores, preds = process_output(output, idx, detect_thresh)\n", " to_keep = nms(bbox_pred, scores)\n", " return bbox_pred[to_keep], preds[to_keep], scores[to_keep]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "get_predictions(output, 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## mAP" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def compute_ap(precision, recall):\n", " \"Compute the average precision for `precision` and `recall` curve.\"\n", " recall = np.concatenate(([0.], list(recall), [1.]))\n", " precision = np.concatenate(([0.], list(precision), [0.]))\n", " for i in range(len(precision) - 1, 0, -1):\n", " precision[i - 1] = np.maximum(precision[i - 1], precision[i])\n", " idx = np.where(recall[1:] != recall[:-1])[0]\n", " ap = np.sum((recall[idx + 1] - recall[idx]) * precision[idx + 1])\n", " return ap" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def compute_class_AP(model, dl, n_classes, iou_thresh=0.5, detect_thresh=0.05, num_keep=100):\n", " tps, clas, p_scores = [], [], []\n", " classes, n_gts = LongTensor(range(n_classes)),torch.zeros(n_classes).long()\n", " with torch.no_grad():\n", " for input,target in progress_bar(dl):\n", " output = model(input)\n", " for i in range(target[0].size(0)):\n", " bbox_pred, preds, scores = get_predictions(output, i, detect_thresh)\n", " tgt_bbox, tgt_clas = unpad(target[0][i], target[1][i])\n", " ious = IoU_values(bbox_pred, tgt_bbox)\n", " max_iou, matches = ious.max(1)\n", " detected = []\n", " for i in range_of(preds):\n", " if max_iou[i] >= iou_thresh and matches[i] not in detected and tgt_clas[matches[i]] == preds[i]:\n", " detected.append(matches[i])\n", " tps.append(1)\n", " else: tps.append(0)\n", " clas.append(preds.cpu())\n", " p_scores.append(scores.cpu())\n", " n_gts += (tgt_clas.cpu()[:,None] == classes[None,:]).sum(0)\n", " tps, p_scores, clas = torch.tensor(tps), torch.cat(p_scores,0), torch.cat(clas,0)\n", " fps = 1-tps\n", " idx = p_scores.argsort(descending=True)\n", " tps, fps, clas = tps[idx], fps[idx], clas[idx]\n", " aps = []\n", " #return tps, clas\n", " for cls in range(n_classes):\n", " tps_cls, fps_cls = tps[clas==cls].float().cumsum(0), fps[clas==cls].float().cumsum(0)\n", " if tps_cls[-1] != 0:\n", " precision = tps_cls / (tps_cls + fps_cls + 1e-8)\n", " recall = tps_cls / (n_gts[cls] + 1e-8)\n", " aps.append(compute_ap(precision, recall))\n", " else: aps.append(0.)\n", " return aps" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "L = compute_class_AP(learn.model, tst_dl, 6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "L[0]" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }