{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from nb_006a import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Pascal" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PATH = Path('data/pascal')\n", "JPEG_PATH = PATH/'VOCdevkit'/'VOC2007'/'JPEGImages'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "trn_j = json.load((PATH / 'pascal_train2007.json').open())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "classes = {o['id']:o['name'] for o in trn_j['categories']}\n", "filenames = {o['id']:JPEG_PATH/o['file_name'] for o in trn_j['images']}\n", "annotations = [{'img_id': o['image_id'], \n", " 'class': classes[o['category_id']], \n", " 'bbox':o['bbox']} for o in trn_j['annotations'] if not o['ignore']]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(annotations)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "annot_by_img = collections.defaultdict(list)\n", "for annot in annotations:\n", " annot_by_img[annot['img_id']].append({'class': annot['class'], 'bbox': annot['bbox']})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(annot_by_img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's do build a model finding the biggest bbox." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "biggest_bb = {}\n", "for id in filenames.keys():\n", " size,best = 0,0\n", " for i,o in enumerate(annot_by_img[id]):\n", " o_sz = o['bbox'][2] * o['bbox'][3]\n", " if size < o_sz:\n", " size,best = o_sz,i\n", " biggest_bb[id] = annot_by_img[id][best]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ids = np.array(list(filenames.keys()))\n", "ids = np.random.permutation(ids)\n", "split = int(len(filenames) * 0.2)\n", "train_fns = [filenames[i] for i in ids[split:]]\n", "valid_fns = [filenames[i] for i in ids[:split]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bboxes = {}\n", "for i in filenames.keys():\n", " bb = biggest_bb[i]['bbox']\n", " bboxes[i] = [[bb[1],bb[0], bb[3]+bb[1], bb[2]+bb[0]]]\n", "\n", "train_bbs = [bboxes[i] for i in ids[split:]]\n", "valid_bbs = [bboxes[i] for i in ids[:split]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_bboxes = collections.defaultdict(list)\n", "for i in filenames.keys():\n", " for o in annot_by_img[i]:\n", " bb = o['bbox']\n", " all_bboxes[i].append([bb[1],bb[0], bb[3]+bb[1], bb[2]+bb[0]])\n", " \n", "train_all_bbs = [all_bboxes[i] for i in ids[split:]]\n", "valid_all_bbs = [all_bboxes[i] for i in ids[:split]]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ImageBBox(ImageMask):\n", " \"Image class for bbox-style annotations\"\n", " def clone(self):\n", " return self.__class__(self.px.clone())\n", " \n", " @classmethod\n", " def create(cls, bboxes:Collection[Collection[int]], h:int, w:int) -> 'ImageBBox':\n", " \"Creates an ImageBBox object from bboxes\"\n", " pxls = torch.zeros(len(bboxes),h, w).long()\n", " for i,bbox in enumerate(bboxes):\n", " pxls[i,bbox[0]:bbox[2]+1,bbox[1]:bbox[3]+1] = 1\n", " return cls(pxls.float())\n", " \n", " @property\n", " def data(self) -> LongTensor:\n", " bboxes = []\n", " for i in range(self.px.size(0)):\n", " idxs = torch.nonzero(self.px[i])\n", " if len(idxs) != 0:\n", " bboxes.append(torch.tensor([idxs[:,0].min(), idxs[:,1].min(), idxs[:,0].max(), idxs[:,1].max()])[None])\n", " return torch.cat(bboxes, 0).squeeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from matplotlib import patches, patheffects\n", "from matplotlib.patches import Patch\n", "\n", "def bb2hw(a:Collection[int]) -> np.ndarray: \n", " \"Converts bounding box points from (width,height,center) to (height,width,top,left)\"\n", " return np.array([a[1],a[0],a[3]-a[1],a[2]-a[0]])\n", "\n", "def draw_outline(o:Patch, lw:int):\n", " \"Outlines bounding box onto image `Patch`\"\n", " o.set_path_effects([patheffects.Stroke(\n", " linewidth=lw, foreground='black'), patheffects.Normal()])\n", "\n", "def draw_rect(ax:plt.Axes, b:Collection[int], color:str='white'):\n", " \"Draws bounding box on `ax`\"\n", " patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2))\n", " draw_outline(patch, 4)\n", "\n", "def _show_image(img:Image, ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=True, cmap:str='binary', \n", " alpha:float=None) -> plt.Axes:\n", " if ax is None: fig,ax = plt.subplots(figsize=figsize)\n", " ax.imshow(image2np(img), cmap=cmap, alpha=alpha)\n", " if hide_axis: ax.axis('off')\n", " return ax\n", "\n", "def 
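 { "cell_type": "markdown", "metadata": {}, "source": [ "A quick round-trip check (a minimal sketch with made-up coordinates): `create` burns each box into its own binary mask, and `data` reads the box back from the nonzero pixels, which is what lets boxes survive the spatial transforms applied to the underlying masks." ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bbox = ImageBBox.create([[10, 20, 30, 60]], 100, 100)\n", "bbox.data   # expected: tensor([10, 20, 30, 60])" ] },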
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from matplotlib import patches, patheffects\n", "from matplotlib.patches import Patch\n", "\n", "def bb2hw(a:Collection[int]) -> np.ndarray:\n", "    \"Converts a bounding box from (top,left,bottom,right) to (x,y,width,height) for matplotlib\"\n", "    return np.array([a[1], a[0], a[3]-a[1], a[2]-a[0]])\n", "\n", "def draw_outline(o:Patch, lw:int):\n", "    \"Draws a black outline around `o` so it stays visible on any background\"\n", "    o.set_path_effects([patheffects.Stroke(\n", "        linewidth=lw, foreground='black'), patheffects.Normal()])\n", "\n", "def draw_rect(ax:plt.Axes, b:Collection[int], color:str='white'):\n", "    \"Draws bounding box `b`, given as (x,y,width,height), on `ax`\"\n", "    patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2))\n", "    draw_outline(patch, 4)\n", "\n", "def _show_image(img:Image, ax:plt.Axes=None, figsize:tuple=(3,3), hide_axis:bool=True, cmap:str='binary',\n", "                alpha:float=None) -> plt.Axes:\n", "    if ax is None: fig,ax = plt.subplots(figsize=figsize)\n", "    ax.imshow(image2np(img), cmap=cmap, alpha=alpha)\n", "    if hide_axis: ax.axis('off')\n", "    return ax\n", "\n", "def show_image(x:Image, y:Image=None, ax:plt.Axes=None, figsize:tuple=(3,3), alpha:float=0.5,\n", "               hide_axis:bool=True, cmap:str='viridis'):\n", "    ax1 = _show_image(x, ax=ax, hide_axis=hide_axis, cmap=cmap)\n", "    if y is not None: _show_image(y, ax=ax1, alpha=alpha, hide_axis=hide_axis, cmap=cmap)\n", "    if hide_axis: ax1.axis('off')\n", "\n", "def _show(self:Image, ax:plt.Axes=None, y:Image=None, **kwargs):\n", "    if y is None: return show_image(self.data, ax=ax, **kwargs)\n", "    is_bb = isinstance(y, ImageBBox)\n", "    y = y.data\n", "    if not is_bb: return show_image(self.data, ax=ax, y=y, **kwargs)\n", "    ax = _show_image(self.data, ax=ax)\n", "    if len(y.size()) == 1: draw_rect(ax, bb2hw(y))\n", "    else:\n", "        for i in range(y.size(0)): draw_rect(ax, bb2hw(y[i]))\n", "\n", "Image.show = _show" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@dataclass\n", "class CoordTargetDataset(Dataset):\n", "    \"A dataset of images annotated with bounding boxes\"\n", "    x_fns:Collection[Path]\n", "    bbs:Collection[Collection[int]]\n", "    def __post_init__(self): assert len(self.x_fns)==len(self.bbs)\n", "    def __repr__(self) -> str: return f'{type(self).__name__} of len {len(self.x_fns)}'\n", "    def __len__(self) -> int: return len(self.x_fns)\n", "    def __getitem__(self, i:int) -> Tuple[Image,ImageBBox]:\n", "        x = open_image(self.x_fns[i])\n", "        return x, ImageBBox.create(self.bbs[i], *x.size)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds = CoordTargetDataset(train_fns, train_all_bbs)\n", "valid_ds = CoordTargetDataset(valid_fns, valid_all_bbs)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_fns.index(Path(JPEG_PATH/'000012.jpg'))" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = train_ds[1477]" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y.data" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x.show(y=y)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x.show(y=ImageMask(y.px[0].unsqueeze(0)))" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y.data, valid_all_bbs[1]" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_tds = DatasetTfm(train_ds, tfms=tfms[0], tfm_y=True, size=128, padding_mode='border')" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = train_tds[0]" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig,axs = plt.subplots(4,4, figsize=(10,10))\n", "for ax in axs.flatten():\n", "    x,y = train_tds[0]\n", "    x.show(ax=ax, y=y)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs,sz = 4,224\n", "tfms = get_transforms(do_flip=True, max_rotate=4, max_lighting=0.2)\n", "train_ds = CoordTargetDataset(train_fns, train_bbs)\n", "valid_ds = CoordTargetDataset(valid_fns, valid_bbs)\n", "data = DataBunch.create(train_ds, valid_ds, path=PATH, bs=bs, num_workers=0, ds_tfms=tfms, size=sz, tfms=imagenet_norm,\n", "                        tfm_y=True, padding_mode='border')" ] },
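 { "cell_type": "markdown", "metadata": {}, "source": [ "With `tfm_y=True` the targets follow the image transforms. As a sanity check (a hedged sketch; the commented shapes assume `bs=4`, `sz=224` and one box per image), one batch should look like this:" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = next(iter(data.train_dl))\n", "x.size(), y.size()   # expected: torch.Size([4, 3, 224, 224]), torch.Size([4, 4])" ] },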
"source": [ "## Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We take a pretrained resnet34 with a custom head." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "arch = tvm.resnet34" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = create_body(arch(), -2)\n", "num_features(model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def custom_loss(output, target):\n", " target = target.float().div_(sz)\n", " return F.l1_loss(output, target)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "arch = tvm.resnet34\n", "head_reg4 = nn.Sequential(Flatten(), nn.Linear(512 * 7*7,4), nn.Sigmoid())\n", "learn = ConvLearner(data, arch, metrics=accuracy, custom_head=head_reg4)\n", "learn.loss_fn = custom_loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }