{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from nb_006c import *\n",
"import pandas as pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imagenet Object Localization\n",
"\n",
"see https://www.kaggle.com/c/imagenet-object-localization-challenge
\n",
"we are using a reduced dataset (only 28 of the 1000 classes in the challenge above)
\n",
"You can download it here:\n",
"https://www.kaggle.com/fm313v/imgnet-obj-loc-small (4.34GB)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"PATH = Path('data/imgnetloc-small')\n",
"TRAIN_CSV = PATH/'LOC_train_solution.csv'\n",
"VALID_CSV = PATH/'LOC_val_solution.csv'\n",
"CLASSES_TXT = PATH/'LOC_synset_mapping.txt'\n",
"ILSVRC = PATH/'ILSVRC'\n",
"IMG_PATH = ILSVRC/'Data/CLS-LOC'\n",
"TRAIN_IMG = IMG_PATH/'train'\n",
"VALID_IMG = IMG_PATH/'val'\n",
"ANNO_PATH = ILSVRC/'Annotations/CLS-LOC'\n",
"TRAIN_ANNO = ANNO_PATH/'train'\n",
"VALID_ANNO = ANNO_PATH/'val'"
]
},
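{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check that the dataset is laid out as expected. This assumes you have downloaded and extracted the Kaggle archive into `data/imgnetloc-small`; adjust `PATH` if yours lives elsewhere."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# top-level contents of the dataset folder: the two csvs, the synset mapping and ILSVRC/\n",
"sorted(p.name for p in PATH.iterdir())"
]
},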
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_df = pd.read_csv(TRAIN_CSV)\n",
"valid_df = pd.read_csv(VALID_CSV)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"valid_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def read_classes():\n",
" classes = {}\n",
" with open(CLASSES_TXT, 'r') as class_file:\n",
" lines = class_file.readlines()\n",
" for line in lines:\n",
" classes[line[0:9]] = line[10:].strip().split(',')[0] # strip extra items after ','\n",
" return classes"
]
},
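{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick peek at what `read_classes` returns. This is just a sanity check; it assumes the mapping file follows the standard ILSVRC format (a 9-character WordNet id, a space, then the class names), which is what the slicing above relies on."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"classes = read_classes()\n",
"len(classes), list(classes.items())[:3]"
]
},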
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import patches, patheffects\n",
"\n",
"def bb_hw(a): return np.array([a[1],a[0],a[3]-a[1]+1,a[2]-a[0]+1])\n",
"\n",
"def draw_outline(o, lw):\n",
" o.set_path_effects([patheffects.Stroke(\n",
" linewidth=lw, foreground='black'), patheffects.Normal()])\n",
"\n",
"def draw_rect(ax, b, color='white'):\n",
" patch = ax.add_patch(patches.Rectangle(b[:2], *b[2:], fill=False, edgecolor=color, lw=2))\n",
" draw_outline(patch, 4)\n",
"\n",
"def draw_text(ax, xy, txt, sz=14):\n",
" text = ax.text(*xy, txt,\n",
" verticalalignment='top', color='white', fontsize=sz, weight='bold')\n",
" draw_outline(text, 1)\n",
"\n",
"def show_img_annos(img, annos, lbl_to_txt=None, ax=None):\n",
" if not ax: fig,ax = plt.subplots()\n",
" ax.imshow(img.numpy().transpose(1,2,0))\n",
" for anno in annos: draw_anno(ax, anno, lbl_to_txt=lbl_to_txt)\n",
"\n",
"def show_img_anno(img, anno, lbl_to_txt=None, ax=None):\n",
" if not ax: fig,ax = plt.subplots()\n",
" ax.imshow(img.numpy().transpose(1,2,0))\n",
" draw_anno(ax, anno, lbl_to_txt=lbl_to_txt)\n",
"\n",
"def draw_anno(ax, anno, lbl_to_txt=None):\n",
" c, bb = anno\n",
" b = bb_hw(bb)\n",
" draw_rect(ax, b)\n",
" if lbl_to_txt: draw_text(ax, b[:2], lbl_to_txt[c], sz=16)"
]
},
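{
"cell_type": "markdown",
"metadata": {},
"source": [
"`bb_hw` converts a box from the `[top, left, bottom, right]` corner form used by the annotations into the `[x, y, width, height]` form that matplotlib's `Rectangle` expects. A tiny check with made-up coordinates (illustrative values only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# [top, left, bottom, right] -> [x, y, width, height]\n",
"bb_hw(np.array([96, 155, 269, 350]))  # expected: [155, 96, 196, 174]"
]
},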
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class_to_text = read_classes()\n",
"lbl_to_class = dict(enumerate(class_to_text.keys()))\n",
"class_to_lbl = {v:k for k,v in lbl_to_class.items()}\n",
"lbl_to_text = { i:class_to_text[c] for i,c in lbl_to_class.items()}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def pull_class_id(x): return x.split(' ')[0]\n",
"\n",
"def train_to_image_path(x):\n",
" class_id = x.split('_')[0]\n",
" return TRAIN_IMG/class_id/f'{x}.JPEG'\n",
"\n",
"def train_to_anno_path(x):\n",
" class_id = pull_class_id(x)\n",
" return TRAIN_ANNO/class_id/f'{x}.xml'\n",
"\n",
"def valid_to_image_path(x): return VALID_IMG/f'{x}.JPEG'\n",
"def valid_to_anno_path(x): return VALID_IMG/f'{x}.xml'\n",
"\n",
"\n",
"train_df['image_fn'] = train_df.ImageId.apply(train_to_image_path)\n",
"train_df['anno_fn'] = train_df.ImageId.apply(train_to_anno_path)\n",
"train_df['class_id'] = train_df.PredictionString.apply(pull_class_id)\n",
"\n",
"valid_df['image_fn'] = valid_df.ImageId.apply(valid_to_image_path)\n",
"valid_df['anno_fn'] = valid_df.ImageId.apply(valid_to_anno_path)\n",
"valid_df['class_id'] = valid_df.PredictionString.apply(pull_class_id)"
]
},
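{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another small sanity check: the paths we just built should point at files that actually exist on disk (this assumes the reduced dataset keeps the standard ILSVRC directory layout)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"row = train_df.iloc[0]\n",
"row.image_fn.exists(), row.anno_fn.exists(), row.class_id"
]
},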
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def to_preds(x):\n",
" boxes = []\n",
" items = x.strip().split(' ')\n",
" for i in range(0,len(items),5):\n",
" class_id, left, top, right, bottom = items[i:(i+5)]\n",
" c = class_to_lbl[class_id]\n",
" boxes.append((c, [float(top), float(left), float(bottom), float(right)]))\n",
" return boxes\n",
"\n",
"train_fns = list(train_df.image_fn)\n",
"train_annos = list(train_df.PredictionString.apply(to_preds))\n",
"valid_fns = list(valid_df.image_fn)\n",
"valid_annos = list(valid_df.PredictionString.apply(to_preds))\n",
"\n",
"def get_biggest_annos(img_annos):\n",
" biggest_annos = []\n",
" \n",
" j = 0\n",
" for annos in img_annos:\n",
" size,best = 0,0\n",
" for i, anno in enumerate(annos):\n",
" c, bb = anno\n",
" b = bb_hw(bb)\n",
" o_sz = b[2] * b[3]\n",
" if size < o_sz: size,best = o_sz,i\n",
" biggest_annos.append(annos[best])\n",
" j += 1\n",
" return biggest_annos\n",
"\n",
"train_annos_lrg = get_biggest_annos(train_annos)\n",
"valid_annos_lrg= get_biggest_annos(valid_annos)"
]
},
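{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `to_preds` and `get_biggest_annos` do, here is a minimal example on a hand-written `PredictionString`. The box coordinates are made up; the class id is borrowed from the data so the lookup in `class_to_lbl` works."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# two boxes for the same class: a small one and a much larger one\n",
"sample_class = train_df.class_id[0]\n",
"sample = f'{sample_class} 10 20 60 70 {sample_class} 30 40 230 240'\n",
"preds = to_preds(sample)\n",
"preds, get_biggest_annos([preds])  # the second (larger) box should be picked"
]
},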
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idx = 15460\n",
"img = open_image(train_df.image_fn[idx])\n",
"annos = train_annos[idx]\n",
"show_img_annos(img, annos, lbl_to_text)\n",
"show_img_anno(img, train_annos_lrg[idx], lbl_to_text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Largest item classifier"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class AnnoTargetDataset(Dataset):\n",
" x_fns:List[Path]; bbs:Tuple[int, List[float]]\n",
" def __post_init__(self): assert len(self.x_fns)==len(self.bbs)\n",
" def __repr__(self): return f'{type(self).__name__} of len {len(self.x_fns)}'\n",
" def __len__(self): return len(self.x_fns)\n",
" def __getitem__(self, i): \n",
" return open_image(self.x_fns[i]), self.bbs[i]\n",
"\n",
"train_ds = AnnoTargetDataset(train_fns, train_annos_lrg)\n",
"valid_ds = AnnoTargetDataset(valid_fns, valid_annos_lrg)"
]
},
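{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `__repr__` we defined makes it easy to confirm that both datasets line up with their dataframes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# one (image, biggest-annotation) pair per row of the corresponding csv\n",
"train_ds, valid_ds, len(train_ds) == len(train_df), len(valid_ds) == len(valid_df)"
]
},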
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x, y = next(iter(train_ds))\n",
"show_img_anno(x, y, lbl_to_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torchvision.models import resnet18, resnet34\n",
"arch = resnet34\n",
"\n",
"# imagenet mean / std\n",
"data_mean, data_std = map(tensor, ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))\n",
"data_norm,data_denorm = normalize_funcs(data_mean,data_std)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bs = 128\n",
"size=128\n",
"workers=0\n",
"\n",
"def get_data(bs, size):\n",
" tfms = get_transforms(do_flip=True, max_rotate=10, max_zoom=1.2, max_lighting=0.3, max_warp=0.15)\n",
" tds = transform_datasets(train_ds, valid_ds, tfms, size=size)\n",
" data = DataBunch.create(*tds, bs=bs, num_workers=workers, tfms=data_norm)\n",
" return data\n",
"\n",
"data = get_data(bs, size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x,y = next(iter(data.train_dl))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}