{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build a small ImageNet-LOC subset\n",
    "\n",
    "Copies the classes listed in `pull_classes` (training folders, validation files and the matching lines of the solution / synset-mapping text files) from `IMGNET` into a parallel directory tree under `IMGNET_SMALL`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import random\n",
    "import shutil\n",
    "from functools import partial\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import PIL.Image\n",
    "from fastprogress import progress_bar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# every source location is derived from IMGNET so the dataset root is set in one place\n",
    "IMGNET = Path('/DATA/kaggle/imgnetloc/')\n",
    "IMAGES_TRAIN = IMGNET/'ILSVRC/Data/CLS-LOC/train'\n",
    "IMAGES_VAL = IMGNET/'ILSVRC/Data/CLS-LOC/val'\n",
    "TRAIN_SOLUTION_CSV = IMGNET/'LOC_train_solution.csv'\n",
    "VALID_SOLUTION_CSV = IMGNET/'LOC_val_solution.csv'\n",
    "ANNO_TRAIN = IMGNET/'ILSVRC/Annotations/CLS-LOC/train'\n",
    "ANNO_VAL = IMGNET/'ILSVRC/Annotations/CLS-LOC/val'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Class names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_class_line(l):\n",
    "    \"\"\"Split one synset-mapping line into (class id, first description).\"\"\"\n",
    "    class_id, _, descriptions = l.strip().partition(' ')\n",
    "    # a class can list several comma-separated descriptions; keep only the first\n",
    "    return class_id, descriptions.split(',')[0].strip()\n",
    "\n",
    "def read_classes(fn):\n",
    "    \"\"\"Read the class file `fn` into a {class id: description} dict.\"\"\"\n",
    "    # `with` closes the file handle; blank lines are skipped\n",
    "    with open(fn, 'r') as f:\n",
    "        return dict(parse_class_line(line) for line in f if line.strip())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = read_classes(IMGNET/'LOC_synset_mapping.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preview the selected classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_img_fns(img_train_path, class_id):\n",
    "    \"\"\"Return the sorted list of entries in the `class_id` folder under `img_train_path`.\"\"\"\n",
    "    return sorted((img_train_path/class_id).iterdir())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_samples(clsid, n_samples=3):\n",
    "    \"\"\"Show up to `n_samples` randomly chosen training images of class `clsid` in one row.\"\"\"\n",
    "    img_fns = get_img_fns(IMAGES_TRAIN, clsid)\n",
    "    # sample without replacement so the same image is never shown twice\n",
    "    sample_fns = random.sample(img_fns, min(n_samples, len(img_fns)))\n",
    "    fig, axes = plt.subplots(1, n_samples, figsize=(4*n_samples, 3), squeeze=False)\n",
    "    fig.suptitle(f'{classes.get(clsid, clsid)} ({clsid})')\n",
    "    for ax, fn in zip(axes.flat, sample_fns):\n",
    "        # close each image file once matplotlib has drawn it\n",
    "        with PIL.Image.open(fn) as img:\n",
    "            ax.imshow(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pull_classes = [\n",
    "    'n01443537', 'n01669191', 'n01774750', 'n01641577', 'n01882714',\n",
    "    'n01983481', 'n02114367', 'n02115641', 'n02317335', 'n01806143',\n",
    "    'n01484850', 'n03063689', 'n03272010', 'n03124170', 'n02799071',\n",
    "    'n03400231', 'n03452741', 'n02802426', 'n02692877', 'n02787622',\n",
    "    'n03785016', 'n04252077', 'n02088466', 'n04254680', 'n02504458',\n",
    "    'n03345487', 'n03642806', 'n03063599'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for k in pull_classes:\n",
    "    plot_samples(k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_images = 0\n",
    "for clsid in pull_classes:\n",
    "    num_images = len(get_img_fns(IMAGES_TRAIN, clsid))\n",
    "    total_images += num_images\n",
    "    print(classes[clsid], num_images)\n",
    "print('total images:', total_images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Select the rows of the chosen classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_df = pd.read_csv(VALID_SOLUTION_CSV)\n",
    "train_df = pd.read_csv(TRAIN_SOLUTION_CSV)\n",
    "\n",
    "len(train_df), len(valid_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_prediction_string(s):\n",
    "    \"\"\"Return the first whitespace-separated token of `s`, used as the row's class id.\n",
    "\n",
    "    Raises ValueError when `s` contains no tokens.\n",
    "    \"\"\"\n",
    "    items = s.split()\n",
    "    if not items:\n",
    "        raise ValueError('empty PredictionString')\n",
    "    return items[0]\n",
    "\n",
    "# training rows: the class id is the part of ImageId before the first underscore\n",
    "train_df['classid'] = train_df.ImageId.str.split('_').str[0]\n",
    "valid_df['classid'] = valid_df.PredictionString.apply(parse_prediction_string)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "small_train_df = train_df.loc[train_df.classid.isin(pull_classes)]\n",
    "small_valid_df = valid_df.loc[valid_df.classid.isin(pull_classes)]\n",
    "len(pull_classes), small_train_df.shape, small_valid_df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Copy images and annotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMGNET_SMALL = Path('/DATA/kaggle/imgnetloc_small/')\n",
    "SMALL_DATA = IMGNET_SMALL/'ILSVRC/Data/CLS-LOC'\n",
    "SMALL_ANNO = IMGNET_SMALL/'ILSVRC/Annotations/CLS-LOC'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create the train/ and val/ folders under both the data and the annotation root\n",
    "for small_root in (SMALL_DATA, SMALL_ANNO):\n",
    "    for split in ('train', 'val'):\n",
    "        (small_root/split).mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# copy training directories (images, then annotations), replacing any previous copy\n",
    "train_roots = ((IMAGES_TRAIN, SMALL_DATA/'train'), (ANNO_TRAIN, SMALL_ANNO/'train'))\n",
    "for k in progress_bar(pull_classes):\n",
    "    for src_root, dest_root in train_roots:\n",
    "        dest_path = dest_root/k\n",
    "        # copytree needs a destination that does not exist yet\n",
    "        if dest_path.exists():\n",
    "            shutil.rmtree(dest_path)\n",
    "        shutil.copytree(src_root/k, dest_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# copy validation files, starting from empty destination directories\n",
    "dest_val_data = SMALL_DATA/'val'\n",
    "dest_val_anno = SMALL_ANNO/'val'\n",
    "for dest_dir in (dest_val_data, dest_val_anno):\n",
    "    if dest_dir.exists():\n",
    "        shutil.rmtree(dest_dir)\n",
    "    dest_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "for image_id in progress_bar(list(small_valid_df.ImageId)):\n",
    "    shutil.copyfile(IMAGES_VAL/f'{image_id}.JPEG', dest_val_data/f'{image_id}.JPEG')\n",
    "    shutil.copyfile(ANNO_VAL/f'{image_id}.xml', dest_val_anno/f'{image_id}.xml')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Copy the text files, keeping only the chosen classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def copy_file_with_filter(src_file, dst_file, filter_func, has_header=True):\n",
    "    \"\"\"Write to `dst_file` the lines of `src_file` for which `filter_func(line)` is true.\n",
    "\n",
    "    When `has_header` is true the first line is copied without being filtered.\n",
    "    \"\"\"\n",
    "    # the source is read completely before the destination is opened for writing\n",
    "    with open(src_file, 'r') as rf:\n",
    "        src_lines = rf.readlines()\n",
    "\n",
    "    # slicing (rather than indexing) keeps an empty source file from raising\n",
    "    header_lines = src_lines[:1] if has_header else []\n",
    "    body_lines = src_lines[1:] if has_header else src_lines\n",
    "    with open(dst_file, 'w') as wf:\n",
    "        wf.writelines(header_lines)\n",
    "        wf.writelines(line for line in body_lines if filter_func(line))\n",
    "\n",
    "def valid_is_desired_class(line, classes):\n",
    "    \"\"\"True when the first token of the second comma-separated field of `line` is in `classes`.\"\"\"\n",
    "    fields = line.strip().split(',')\n",
    "    # blank or malformed lines are dropped instead of raising\n",
    "    if len(fields) < 2:\n",
    "        return False\n",
    "    tokens = fields[1].split()\n",
    "    return bool(tokens) and tokens[0] in classes\n",
    "\n",
    "def is_desired_class(line, classes):\n",
    "    \"\"\"True when the first 9 characters of `line` are one of `classes`.\"\"\"\n",
    "    return line[:9] in classes\n",
    "\n",
    "def copy_filtered_csv(src_path, dst_path, fn, classes, has_header, check_func):\n",
    "    \"\"\"Copy `src_path/fn` to `dst_path/fn`, keeping the lines accepted by `check_func(line, classes=...)`.\"\"\"\n",
    "    wanted = frozenset(classes)  # constant-time membership test for every line\n",
    "    copy_file_with_filter(src_path/fn, dst_path/fn, partial(check_func, classes=wanted), has_header=has_header)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (file name, has header line, line filter)\n",
    "text_files = [\n",
    "    ('LOC_val_solution.csv', True, valid_is_desired_class),\n",
    "    ('LOC_train_solution.csv', True, is_desired_class),\n",
    "    ('LOC_synset_mapping.txt', False, is_desired_class)\n",
    "]\n",
    "\n",
    "for fn, has_header, func in text_files:\n",
    "    copy_filtered_csv(IMGNET, IMGNET_SMALL, fn, pull_classes, has_header, func)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}