{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from nb_006c import *\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Imagenet Object Localization\n", "\n", "see https://www.kaggle.com/c/imagenet-object-localization-challenge
\n", "we are using a reduced dataset (only 28 of the 1000 classes in the challenge above)
\n", "You can download it here:\n", "https://www.kaggle.com/fm313v/imgnet-obj-loc-small (4.34GB)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PATH = Path('data/imgnetloc-small')\n", "TRAIN_CSV = PATH/'LOC_train_solution.csv'\n", "VALID_CSV = PATH/'LOC_val_solution.csv'\n", "CLASSES_TXT = PATH/'LOC_synset_mapping.txt'\n", "ILSVRC = PATH/'ILSVRC'\n", "IMG_PATH = ILSVRC/'Data/CLS-LOC'\n", "TRAIN_IMG = IMG_PATH/'train'\n", "VALID_IMG = IMG_PATH/'val'\n", "ANNO_PATH = ILSVRC/'Annotations/CLS-LOC'\n", "TRAIN_ANNO = ANNO_PATH/'train'\n", "VALID_ANNO = ANNO_PATH/'val'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df = pd.read_csv(TRAIN_CSV)\n", "valid_df = pd.read_csv(VALID_CSV)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "valid_df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def read_classes():\n", " classes = {}\n", " with open(CLASSES_TXT, 'r') as class_file:\n", " lines = class_file.readlines()\n", " for line in lines:\n", " classes[line[0:9]] = line[10:].strip().split(',')[0] # strip extra items after ','\n", " return classes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from matplotlib import patches, patheffects\n", "\n", "def bb_hw(a): return np.array([a[1],a[0],a[3]-a[1]+1,a[2]-a[0]+1])\n", "\n", "def draw_outline(o, lw):\n", " o.set_path_effects([patheffects.Stroke(\n", " linewidth=lw, foreground='black'), patheffects.Normal()])\n", "\n", "def draw_rect(ax, b, color='white'):\n", " patch = ax.add_patch(patches.Rectangle(b[:2], *b[2:], fill=False, edgecolor=color, lw=2))\n", " draw_outline(patch, 4)\n", "\n", "def draw_text(ax, xy, txt, sz=14):\n", " text = ax.text(*xy, txt,\n", " verticalalignment='top', color='white', fontsize=sz, weight='bold')\n", " draw_outline(text, 1)\n", "\n", "def show_img_annos(img, annos, lbl_to_txt=None, ax=None):\n", " if not ax: fig,ax = plt.subplots()\n", " ax.imshow(img.numpy().transpose(1,2,0))\n", " for anno in annos: draw_anno(ax, anno, lbl_to_txt=lbl_to_txt)\n", "\n", "def show_img_anno(img, anno, lbl_to_txt=None, ax=None):\n", " if not ax: fig,ax = plt.subplots()\n", " ax.imshow(img.numpy().transpose(1,2,0))\n", " draw_anno(ax, anno, lbl_to_txt=lbl_to_txt)\n", "\n", "def draw_anno(ax, anno, lbl_to_txt=None):\n", " c, bb = anno\n", " b = bb_hw(bb)\n", " draw_rect(ax, b)\n", " if lbl_to_txt: draw_text(ax, b[:2], lbl_to_txt[c], sz=16)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class_to_text = read_classes()\n", "lbl_to_class = dict(enumerate(class_to_text.keys()))\n", "class_to_lbl = {v:k for k,v in lbl_to_class.items()}\n", "lbl_to_text = { i:class_to_text[c] for i,c in lbl_to_class.items()}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def pull_class_id(x): return x.split(' ')[0]\n", "\n", "def train_to_image_path(x):\n", " class_id = x.split('_')[0]\n", " return TRAIN_IMG/class_id/f'{x}.JPEG'\n", "\n", "def train_to_anno_path(x):\n", " class_id = pull_class_id(x)\n", " return TRAIN_ANNO/class_id/f'{x}.xml'\n", "\n", "def valid_to_image_path(x): return VALID_IMG/f'{x}.JPEG'\n", "def 
valid_to_anno_path(x): return VALID_ANNO/f'{x}.xml'\n", "\n", "\n", "train_df['image_fn'] = train_df.ImageId.apply(train_to_image_path)\n", "train_df['anno_fn'] = train_df.ImageId.apply(train_to_anno_path)\n", "train_df['class_id'] = train_df.PredictionString.apply(pull_class_id)\n", "\n", "valid_df['image_fn'] = valid_df.ImageId.apply(valid_to_image_path)\n", "valid_df['anno_fn'] = valid_df.ImageId.apply(valid_to_anno_path)\n", "valid_df['class_id'] = valid_df.PredictionString.apply(pull_class_id)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def to_preds(x):\n", "    # PredictionString is 'class_id xmin ymin xmax ymax' repeated once per object;\n", "    # each box is stored as (label, [top, left, bottom, right])\n", "    boxes = []\n", "    items = x.strip().split(' ')\n", "    for i in range(0,len(items),5):\n", "        class_id, left, top, right, bottom = items[i:(i+5)]\n", "        c = class_to_lbl[class_id]\n", "        boxes.append((c, [float(top), float(left), float(bottom), float(right)]))\n", "    return boxes\n", "\n", "train_fns = list(train_df.image_fn)\n", "train_annos = list(train_df.PredictionString.apply(to_preds))\n", "valid_fns = list(valid_df.image_fn)\n", "valid_annos = list(valid_df.PredictionString.apply(to_preds))\n", "\n", "def get_biggest_annos(img_annos):\n", "    # keep only the annotation with the largest box area for each image\n", "    biggest_annos = []\n", "    for annos in img_annos:\n", "        size,best = 0,0\n", "        for i, anno in enumerate(annos):\n", "            c, bb = anno\n", "            b = bb_hw(bb)\n", "            o_sz = b[2] * b[3]\n", "            if size < o_sz: size,best = o_sz,i\n", "        biggest_annos.append(annos[best])\n", "    return biggest_annos\n", "\n", "train_annos_lrg = get_biggest_annos(train_annos)\n", "valid_annos_lrg = get_biggest_annos(valid_annos)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 15460\n", "img = open_image(train_df.image_fn[idx])\n", "annos = train_annos[idx]\n", "show_img_annos(img, annos, lbl_to_text)\n", "show_img_anno(img, train_annos_lrg[idx], lbl_to_text)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Largest item classifier" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class AnnoTargetDataset(Dataset):\n", "    x_fns:List[Path]; bbs:List[Tuple[int, List[float]]]\n", "    def __post_init__(self): assert len(self.x_fns)==len(self.bbs)\n", "    def __repr__(self): return f'{type(self).__name__} of len {len(self.x_fns)}'\n", "    def __len__(self): return len(self.x_fns)\n", "    def __getitem__(self, i):\n", "        return open_image(self.x_fns[i]), self.bbs[i]\n", "\n", "train_ds = AnnoTargetDataset(train_fns, train_annos_lrg)\n", "valid_ds = AnnoTargetDataset(valid_fns, valid_annos_lrg)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x, y = next(iter(train_ds))\n", "show_img_anno(x, y, lbl_to_text)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torchvision.models import resnet18, resnet34\n", "arch = resnet34\n", "\n", "# imagenet mean / std\n", "data_mean, data_std = map(tensor, ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))\n", "data_norm,data_denorm = normalize_funcs(data_mean,data_std)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs = 128\n", "size = 128\n", "workers = 0\n", "\n", "def get_data(bs, size):\n", "    tfms = get_transforms(do_flip=True, max_rotate=10, max_zoom=1.2, max_lighting=0.3, max_warp=0.15)\n", "    tds = transform_datasets(train_ds, valid_ds, tfms, size=size)\n", "    data = DataBunch.create(*tds, bs=bs, num_workers=workers, tfms=data_norm)\n", "    return data\n", "\n", "data = get_data(bs, size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = next(iter(data.train_dl))" ] },
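{ "cell_type": "markdown", "metadata": {}, "source": [ "A possible next step, sketched here rather than taken from the original notebook: attach a classification head to the `resnet34` backbone chosen above so it predicts the class of the largest object. The sketch below uses plain PyTorch instead of the fastai `Learner` API used elsewhere in this series, and it assumes the class label is the first element of each target tuple." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch only (not part of the original notebook): a plain-PyTorch classification\n", "# head on the resnet34 backbone for predicting the largest-item class.\n", "import torch.nn as nn\n", "\n", "def make_largest_item_classifier(num_classes):\n", "    body = arch(pretrained=True)  # torchvision resnet34 selected above\n", "    body.fc = nn.Linear(body.fc.in_features, num_classes)\n", "    return body\n", "\n", "# model = make_largest_item_classifier(len(class_to_lbl))\n", "# logits = model(x)  # x is the normalized batch grabbed above" ] },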
data\n", "\n", "data = get_data(bs, size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x,y = next(iter(data.train_dl))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tensor([[5,6],[1,2,3]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "type(b)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }