{ "cells": [
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%reload_ext autoreload\n", "%autoreload 2" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.conv_learner import *\n", "from fastai.dataset import *\n", "\n", "import json, pdb\n", "from PIL import ImageDraw, ImageFont\n", "from matplotlib import patches, patheffects\n", "torch.cuda.set_device(0)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.backends.cudnn.benchmark=True" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "PATH = Path('data/pascal')\n",
  "# Use a context manager so the annotation file handle is closed after parsing.\n",
  "with (PATH / 'pascal_train2007.json').open() as anno_file:\n",
  "    trn_j = json.load(anno_file)\n",
  "IMAGES,ANNOTATIONS,CATEGORIES = ['images', 'annotations', 'categories']\n",
  "FILE_NAME,ID,IMG_ID,CAT_ID,BBOX = 'file_name','id','image_id','category_id','bbox'\n",
  "\n",
  "cats = {o[ID]: o['name'] for o in trn_j[CATEGORIES]}\n",
  "trn_fns = {o[ID]: o[FILE_NAME] for o in trn_j[IMAGES]}\n",
  "trn_ids = [o[ID] for o in trn_j[IMAGES]]\n",
  "\n",
  "JPEGS = 'VOCdevkit/VOC2007/JPEGImages'\n",
  "IMG_PATH = PATH/JPEGS"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "def get_trn_anno():\n",
  "    \"\"\"Group the non-ignored training boxes by image id.\n",
  "\n",
  "    Returns a defaultdict mapping image id -> list of (bbox, category_id).\n",
  "    Each bbox is converted from the JSON's [x, y, width, height] into an\n",
  "    inclusive [top, left, bottom, right] numpy array (bb_hw inverts this).\n",
  "    \"\"\"\n",
  "    trn_anno = collections.defaultdict(list)\n",
  "    for o in trn_j[ANNOTATIONS]:\n",
  "        if not o['ignore']:\n",
  "            bb = o[BBOX]\n",
  "            # -1 makes bottom/right the last pixel inside the box (inclusive bounds).\n",
  "            bb = np.array([bb[1], bb[0], bb[3]+bb[1]-1, bb[2]+bb[0]-1])\n",
  "            trn_anno[o[IMG_ID]].append((bb,o[CAT_ID]))\n",
  "    return trn_anno\n",
  "\n",
  "trn_anno = get_trn_anno()"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "def show_img(im, figsize=None, ax=None):\n",
  "    \"\"\"Show `im` on `ax` (a new figure of `figsize` if None) and return the axes.\n",
  "\n",
  "    Draws grid lines at 8 evenly spaced positions from 0 to 224 px and hides\n",
  "    the tick labels.\n",
  "    \"\"\"\n",
  "    if ax is None: _, ax = plt.subplots(figsize=figsize)\n",
  "    ax.imshow(im)\n",
  "    ax.set_xticks(np.linspace(0, 224, 8))\n",
  "    ax.set_yticks(np.linspace(0, 224, 8))\n",
  "    ax.grid()\n",
  "    ax.set_yticklabels([])\n",
  "    ax.set_xticklabels([])\n",
  "    return ax\n",
  "\n",
  "def draw_outline(o, lw):\n",
  "    \"\"\"Add a black stroke of width `lw` behind matplotlib artist `o`.\"\"\"\n",
  "    o.set_path_effects([patheffects.Stroke(\n",
  "        linewidth=lw, foreground='black'), patheffects.Normal()])\n",
  "\n",
  "def draw_rect(ax, b, color='white'):\n",
  "    \"\"\"Draw an unfilled, outlined rectangle for box `b` = [x, y, width, height].\"\"\"\n",
  "    patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2))\n",
  "    draw_outline(patch, 4)\n",
  "\n",
  "def draw_text(ax, xy, txt, sz=14, color='white'):\n",
  "    \"\"\"Draw bold, outlined text `txt` whose top edge sits at `xy`.\"\"\"\n",
  "    text = ax.text(*xy, txt,\n",
  "        verticalalignment='top', color=color, fontsize=sz, weight='bold')\n",
  "    draw_outline(text, 1)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "def bb_hw(a):\n",
  "    \"\"\"Convert an inclusive [top, left, bottom, right] box to [x, y, width, height].\"\"\"\n",
  "    return np.array([a[1],a[0],a[3]-a[1]+1,a[2]-a[0]+1])\n",
  "\n",
  "def draw_im(im, ann):\n",
  "    \"\"\"Show `im` with every (bbox, category_id) in `ann` drawn as a labelled box.\"\"\"\n",
  "    ax = show_img(im, figsize=(16,8))\n",
  "    for b,c in ann:\n",
  "        b = bb_hw(b)\n",
  "        draw_rect(ax, b)\n",
  "        draw_text(ax, b[:2], cats[c], sz=16)\n",
  "\n",
  "def draw_idx(i):\n",
  "    \"\"\"Load training image with id `i` and draw its annotations.\"\"\"\n",
  "    im_a = trn_anno[i]\n",
  "    im = open_image(IMG_PATH/trn_fns[i])\n",
  "    draw_im(im, im_a)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Multi class" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "MC_CSV = PATH/'tmp/mc.csv'" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(array([ 96, 155, 269, 350]), 7)]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trn_anno[12]" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "mc = [{cats[p[1]] for p in trn_anno[o]} for o in trn_ids]\n",
  "# Sort each label set: set iteration order of strings varies between runs\n",
  "# (hash randomization), which would otherwise make mc.csv non-reproducible.\n",
  "mcs = [' '.join(sorted(str(p) for p in o)) for o in mc]"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "df = pd.DataFrame({'fn': [trn_fns[o] for o in trn_ids], 'clas': mcs}, columns=['fn','clas'])\n",
  "# to_csv does not create missing directories, so make sure the tmp dir exists.\n",
  "MC_CSV.parent.mkdir(parents=True, exist_ok=True)\n",
  "df.to_csv(MC_CSV, index=False)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f_model=resnet34\n", "sz=224\n", "bs=64" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tfms = tfms_from_model(f_model, sz, crop_type=CropType.NO)\n", "md = ImageClassifierData.from_csv(PATH, JPEGS, MC_CSV, tfms=tfms, bs=bs)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner.pretrained(f_model, md)\n", "learn.opt_fn = optim.Adam" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9ccdb0e1cae54d27b5d85d30a5c37b4f", "version_major": 2, "version_minor": 0 }, "text/html": [ "
Failed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_lossFailed to display Jupyter Widget of type HBox
.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=1), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " 81%|█████████████████████████████████████████████████████████▋ | 26/32 [00:22<00:05, 1.15it/s, loss=0.33]" ] }, { "data": { "image/png": "\n", "text/plain": [ "\n", " | fn | \n", "bbox | \n", "
---|---|---|
0 | \n", "000012.jpg | \n", "96 155 269 350 | \n", "
1 | \n", "000017.jpg | \n", "61 184 198 278 77 89 335 402 | \n", "
2 | \n", "000023.jpg | \n", "229 8 499 244 219 229 499 333 0 1 368 116 1 2 ... | \n", "
3 | \n", "000026.jpg | \n", "124 89 211 336 | \n", "
4 | \n", "000032.jpg | \n", "77 103 182 374 87 132 122 196 179 194 228 212 ... | \n", "