{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install MMDetection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The versions below are pinned to the CUDA 12.1 builds of PyTorch 2.1.0 and MMCV 2.1.0. If a different CUDA or PyTorch build is needed, change the versions and the wheel index URLs together."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9d74c68",
   "metadata": {},
   "source": [
    "#### Simple install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c3e601a",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121\n",
    "%pip install openmim pycocotools faster-coco-eval\n",
    "%pip install mmcv==2.1.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1/index.html\n",
    "!python3 -m mim install mmdet"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7009f6a",
   "metadata": {},
   "source": [
    "## Download COCO VAL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc2a4389",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -nc (no-clobber): skip archives that are already present so the cell can be re-run.\n",
    "!wget -nc -P COCO/DIR/ http://images.cocodataset.org/annotations/annotations_trainval2017.zip\n",
    "!wget -nc -P COCO/DIR/ http://images.cocodataset.org/zips/val2017.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b94cb3d",
   "metadata": {},
   "source": [
    "## Unzip COCO VAL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94d9a6c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# -n: never overwrite files that were already extracted, so a re-run does not prompt.\n",
    "!unzip -qq -n COCO/DIR/annotations_trainval2017.zip -d COCO/DIR/\n",
    "!unzip -qq -n COCO/DIR/val2017.zip -d COCO/DIR/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee83ac7b",
   "metadata": {},
   "source": [
    "## Download model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6f6860f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import mmdet\n",
    "import mmengine\n",
    "import os.path as osp\n",
    "\n",
    "# The config is read from the \".mim\" directory inside the installed mmdet package.\n",
    "config_dir = osp.dirname(mmdet.__file__)\n",
    "sub_config = \"configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py\"\n",
    "config_file = osp.join(config_dir, \".mim\", sub_config)\n",
    "cfg = mmengine.Config.fromfile(config_file)\n",
    "\n",
    "model_file = \"https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth\"\n",
    "\n",
    "print(f\"{config_file=}\")\n",
    "print(f\"{model_file=}\")\n",
    "\n",
    "!mkdir -p -m 777 model\n",
    "\n",
    "# Store the loaded config next to the checkpoint; the \"Init model\" cell reads both from ./model.\n",
    "cfg.dump(osp.join(\"model\", osp.basename(config_file)))\n",
    "# -nc: skip the download when the checkpoint already exists (avoids \"*.pth.1\" duplicates).\n",
    "!wget -nc -P model/ {model_file}\n",
    "\n",
    "!ls -lah model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "313f1978",
   "metadata": {},
   "source": [
    "## Validate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eebcff84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import copy\n",
    "import json\n",
    "import os\n",
    "import os.path as osp\n",
    "import pathlib\n",
    "import time\n",
    "\n",
    "# Third-party (relative order kept as in the original notebook)\n",
    "from mmdet.apis import inference_detector, init_detector\n",
    "from mmengine.registry import init_default_scope\n",
    "from mmdet.datasets import CocoDataset\n",
    "import tqdm\n",
    "import torch\n",
    "from mmdet.evaluation import CocoMetric\n",
    "from mmdet.structures.mask import encode_mask_results\n",
    "from pycocotools.coco import COCO as pycocotools_COCO\n",
    "from pycocotools.cocoeval import COCOeval as pycocotools_COCOeval\n",
    "from faster_coco_eval import COCO as COCO_faster, COCOeval_faster\n",
    "import pandas as pd\n",
    "from IPython.display import display, Markdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2f5fedd",
   "metadata": {},
   "outputs": [],
   "source": [
    "init_default_scope(\"mmdet\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "493db05e",
   "metadata": {},
   "source": [
    "### Select first 100 images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1674a30",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"./COCO/DIR/annotations/instances_val2017.json\") as fd:\n",
    "    instances_val2017 = json.load(fd)\n",
    "\n",
    "image_id_for_eval = [image[\"id\"] for image in instances_val2017[\"images\"]]\n",
    "image_id_for_eval = image_id_for_eval[:100]  # Select first 100 images\n",
    "\n",
    "# Membership is tested once per annotation, so use a set instead of scanning the list.\n",
    "image_ids_to_keep = set(image_id_for_eval)\n",
    "\n",
    "annotations = [\n",
    "    ann\n",
    "    for ann in instances_val2017[\"annotations\"]\n",
    "    if ann[\"image_id\"] in image_ids_to_keep\n",
    "]\n",
    "images = [\n",
    "    image for image in instances_val2017[\"images\"] if image[\"id\"] in image_ids_to_keep\n",
    "]\n",
    "\n",
    "# The dict is re-loaded at the top of this cell, so replacing these two keys is re-run safe.\n",
    "instances_val2017[\"annotations\"] = annotations\n",
    "instances_val2017[\"images\"] = images\n",
    "\n",
    "with open(\"./COCO/DIR/annotations/instances_val2017_first_100.json\", \"w\") as fd:\n",
    "    json.dump(instances_val2017, fd)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bacd42d",
   "metadata": {},
   "source": [
    "## Init model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1318e70c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Config and checkpoint are the files written into ./model by the \"Download model\" cell.\n",
    "model = init_detector(\n",
    "    \"./model/rtmdet-ins_tiny_8xb32-300e_coco.py\",\n",
    "    \"./model/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth\",\n",
    "    device=(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7606215",
   "metadata": {},
   "source": [
    "## Init dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7878d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline = [\n",
    "    dict(type=\"LoadImageFromFile\"),\n",
    "    dict(type=\"mmdet.LoadAnnotations\", with_bbox=True),\n",
    "]\n",
    "\n",
    "# Uses the 100-image annotation subset written by the previous section.\n",
    "dataset = CocoDataset(\n",
    "    data_root=\"./COCO/DIR/\",\n",
    "    ann_file=\"annotations/instances_val2017_first_100.json\",\n",
    "    data_prefix=dict(img=\"val2017/\"),\n",
    "    pipeline=pipeline,\n",
    ")\n",
    "len(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "379869cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `metric` is only used below for its results2json() serialiser.\n",
    "metric = CocoMetric(metric=[\"bbox\", \"segm\"])\n",
    "metric.dataset_meta = model.dataset_meta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34328dac",
   "metadata": {},
   "outputs": [],
   "source": [
    "_coco_api = COCO_faster(dataset.ann_file)\n",
    "metric.cat_ids = _coco_api.get_cat_ids(cat_names=metric.dataset_meta[\"classes\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb2f2270",
   "metadata": {},
   "source": [
    "## Process images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0590adeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove result files left over from a previous run so they are not picked up again later.\n",
    "images_path = pathlib.Path(dataset.data_prefix[\"img\"])\n",
    "\n",
    "files = list(images_path.rglob(\"*.segm.json\"))\n",
    "files += list(images_path.rglob(\"*.bbox.json\"))\n",
    "\n",
    "for file in tqdm.tqdm(files):\n",
    "    os.remove(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24d1832a",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_images = len(dataset)\n",
    "\n",
    "# Tracked explicitly so later cells do not depend on a leaked loop variable\n",
    "# (which would be undefined for an empty dataset).\n",
    "include_segm = False\n",
    "\n",
    "for i in tqdm.tqdm(range(max_images)):\n",
    "    item = dataset[i]\n",
    "    result = inference_detector(model, item[\"img_path\"])\n",
    "\n",
    "    # Bring every prediction field to the CPU before it is serialised.\n",
    "    for key in result.pred_instances.all_keys():\n",
    "        result.pred_instances[key] = result.pred_instances[key].detach().cpu()\n",
    "\n",
    "    dict_result = dict(\n",
    "        result.pred_instances.to_dict(), **{\"img_id\": item[\"img_id\"]}\n",
    "    )\n",
    "\n",
    "    if \"masks\" in dict_result:\n",
    "        include_segm = True\n",
    "        dict_result[\"masks\"] = encode_mask_results(\n",
    "            dict_result[\"masks\"].detach().cpu().numpy()\n",
    "        )\n",
    "\n",
    "    # The prefix is the image path without its extension; the \"*.bbox.json\" /\n",
    "    # \"*.segm.json\" files globbed in the next section are expected to come from here.\n",
    "    metric.results2json(\n",
    "        [dict_result], outfile_prefix=osp.splitext(item[\"img_path\"])[0]\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44914d0e",
   "metadata": {},
   "source": [
    "## Convert results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e51afc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set by the inference loop above: True when the predictions carried masks.\n",
    "print(f\"{include_segm=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac60463b",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset.data_prefix[\"img\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c90430cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "images_path = pathlib.Path(dataset.data_prefix[\"img\"])\n",
    "\n",
    "# Only one family of result files is loaded: segm files when masks exist, bbox files otherwise.\n",
    "result_pattern = \"*.segm.json\" if include_segm else \"*.bbox.json\"\n",
    "files = list(images_path.rglob(result_pattern))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37eb0a66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Concatenate the per-image result files into one flat detection list.\n",
    "result_data = []\n",
    "\n",
    "for file in tqdm.tqdm(files):\n",
    "    result_data += COCO_faster.load_json(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "123b828c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _run_timed_eval(cocoEval):\n",
    "    \"\"\"Run evaluate(), accumulate() and summarize() on ``cocoEval``.\n",
    "\n",
    "    Returns:\n",
    "        float: elapsed seconds for the three calls, measured with the\n",
    "        monotonic ``time.perf_counter`` clock.\n",
    "    \"\"\"\n",
    "    ts = time.perf_counter()\n",
    "    cocoEval.evaluate()\n",
    "    cocoEval.accumulate()\n",
    "    cocoEval.summarize()\n",
    "    return time.perf_counter() - ts\n",
    "\n",
    "\n",
    "def load_faster_data(ann_file, result_data):\n",
    "    \"\"\"Build faster-coco-eval ground-truth and detection objects.\n",
    "\n",
    "    Args:\n",
    "        ann_file: path to the COCO annotation JSON.\n",
    "        result_data: list of detection records; a deep copy is handed to\n",
    "            ``loadRes`` so the shared list is left untouched between runs.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(cocoGt, cocoDt)``.\n",
    "    \"\"\"\n",
    "    cocoGt = COCO_faster(ann_file)\n",
    "    cocoDt = cocoGt.loadRes(copy.deepcopy(result_data))\n",
    "    return cocoGt, cocoDt\n",
    "\n",
    "\n",
    "def process_faster(cocoGt, cocoDt, iouType):\n",
    "    \"\"\"Evaluate with faster-coco-eval and return the elapsed seconds.\"\"\"\n",
    "    cocoEval = COCOeval_faster(cocoGt, cocoDt, iouType, print_function=print)\n",
    "    return _run_timed_eval(cocoEval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "679a6679",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_pycocotools_data(ann_file, result_data):\n",
    "    \"\"\"Build pycocotools ground-truth and detection objects.\n",
    "\n",
    "    Args:\n",
    "        ann_file: path to the COCO annotation JSON.\n",
    "        result_data: list of detection records; a deep copy is handed to\n",
    "            ``loadRes`` so the shared list is left untouched between runs.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(cocoGt, cocoDt)``.\n",
    "    \"\"\"\n",
    "    cocoGt = pycocotools_COCO(ann_file)\n",
    "    cocoDt = cocoGt.loadRes(copy.deepcopy(result_data))\n",
    "    return cocoGt, cocoDt\n",
    "\n",
    "\n",
    "def process_pycocotools(cocoGt, cocoDt, iouType):\n",
    "    \"\"\"Evaluate with pycocotools and return the elapsed seconds.\n",
    "\n",
    "    Relies on ``_run_timed_eval`` defined in the previous cell.\n",
    "    \"\"\"\n",
    "    cocoEval = pycocotools_COCOeval(cocoGt, cocoDt, iouType)\n",
    "    return _run_timed_eval(cocoEval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6144eac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each entry: [display name, loader, timed evaluator].\n",
    "processors = [\n",
    "    [\"faster-coco-eval\", load_faster_data, process_faster],\n",
    "    [\"pycocotools\", load_pycocotools_data, process_pycocotools],\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e267bc7",
   "metadata": {},
   "source": [
    "## Process eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18a0740f",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_table = {}\n",
    "iou_types = [\"bbox\", \"segm\"] if include_segm else [\"bbox\"]\n",
    "\n",
    "# The loop variable is deliberately not named `metric`: that name holds the CocoMetric\n",
    "# instance used by the inference cell, and overwriting it breaks re-running that cell.\n",
    "for iou_type in iou_types:\n",
    "    result_table[iou_type] = {}\n",
    "\n",
    "    for _name, _load, _process in processors:\n",
    "        print(f\"metric={iou_type!r}; {_name=}\")\n",
    "        # Ground truth and detections are rebuilt for every run so that each\n",
    "        # library starts from freshly loaded objects.\n",
    "        cocoGt, cocoDt = _load(dataset.ann_file, result_data)\n",
    "        result_table[iou_type][_name] = _process(cocoGt, cocoDt, iou_type)\n",
    "        print()\n",
    "        print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92aef956",
   "metadata": {},
   "source": [
    "## Display results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "023d1b07",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e31ebc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows: IoU type; columns: evaluation time in seconds per library.\n",
    "df = pd.DataFrame(result_table).T.round(3)\n",
    "df.index.name = \"Type\"\n",
    "# \"Profit\" is the speed-up factor: pycocotools time divided by faster-coco-eval time.\n",
    "df[\"Profit\"] = (df[\"pycocotools\"] / df[\"faster-coco-eval\"]).round(3)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c76a1e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain-text markdown, convenient for copy-pasting the table elsewhere.\n",
    "print(df.to_markdown())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c819d20c",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(df.to_markdown()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "591f042a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh objects for the visualisation cells below.\n",
    "cocoGt, cocoDt = load_faster_data(dataset.ann_file, result_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75d62196",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imported here rather than at the top: NOTE(review) the `extra` module may need\n",
    "# optional plotting dependencies - confirm before moving it to the import cell.\n",
    "from faster_coco_eval.extra import Curves\n",
    "\n",
    "cur = Curves(cocoGt, cocoDt, iou_tresh=0.5, iouType=\"bbox\", useCats=False)\n",
    "cur.plot_pre_rec()\n",
    "cur.plot_f1_confidence()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e2cc855",
   "metadata": {},
   "outputs": [],
   "source": [
    "from faster_coco_eval.extra import PreviewResults\n",
    "\n",
    "image_preview_count = 1\n",
    "# Fall back to boxes when the predictions carry no masks.\n",
    "preview_iou_type = \"segm\" if include_segm else \"bbox\"\n",
    "preview = PreviewResults(\n",
    "    cocoGt, cocoDt, iouType=preview_iou_type, iou_tresh=0.5, min_score=0.3\n",
    ")\n",
    "preview.display_tp_fp_fn(\n",
    "    data_folder=dataset.data_prefix[\"img\"],\n",
    "    image_ids=list(cocoGt.imgs.keys())[:image_preview_count],\n",
    "    display_gt=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c42047ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "preview.display_matrix(normalize=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}