{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install ultralytics --quiet" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import time\n", "from ultralytics.models.yolo.detect.val import DetectionValidator, check_requirements, LOGGER, Path\n", "\n", "class BaseCustomDetectionValidator(DetectionValidator):\n", " def eval_json_orig(self, stats):\n", " \"\"\"Evaluates YOLO output in JSON format and returns performance statistics.\"\"\"\n", " if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):\n", " pred_json = self.save_dir / \"predictions.json\" # predictions\n", " anno_json = (\n", " self.data[\"path\"]\n", " / \"annotations\"\n", " / (\"instances_val2017.json\" if self.is_coco else f\"lvis_v1_{self.args.split}.json\")\n", " ) # annotations\n", " pkg = \"pycocotools\" if self.is_coco else \"lvis\"\n", " LOGGER.info(f\"\\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...\")\n", " try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb\n", " for x in pred_json, anno_json:\n", " assert x.is_file(), f\"{x} file not found\"\n", " check_requirements(\"pycocotools>=2.0.6\" if self.is_coco else \"lvis>=0.5.3\")\n", " if self.is_coco:\n", " from pycocotools.coco import COCO # noqa\n", " from pycocotools.cocoeval import COCOeval # noqa\n", "\n", " anno = COCO(str(anno_json)) # init annotations api\n", " pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)\n", " val = COCOeval(anno, pred, \"bbox\")\n", " else:\n", " from lvis import LVIS, LVISEval\n", "\n", " anno = LVIS(str(anno_json)) # init annotations api\n", " pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path)\n", " val = LVISEval(anno, pred, \"bbox\")\n", " val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval\n", " val.evaluate()\n", " val.accumulate()\n", " val.summarize()\n", " if self.is_lvis:\n", " val.print_results() # explicitly call print_results\n", " # update mAP50-95 and mAP50\n", " stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = (\n", " val.stats[:2] if self.is_coco else [val.results[\"AP50\"], val.results[\"AP\"]]\n", " )\n", " except Exception as e:\n", " LOGGER.warning(f\"{pkg} unable to run: {e}\")\n", " return stats\n", " \n", " def eval_json_faster(self, stats):\n", " \"\"\"Evaluates YOLO output in JSON format and returns performance statistics.\"\"\"\n", " if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):\n", " pred_json = self.save_dir / \"predictions.json\" # predictions\n", " anno_json = (\n", " self.data[\"path\"]\n", " / \"annotations\"\n", " / (\"instances_val2017.json\" if self.is_coco else f\"lvis_v1_{self.args.split}.json\")\n", " ) # annotations\n", " pkg = \"faster_coco_eval\"\n", " LOGGER.info(f\"\\nEvaluating {pkg} mAP using {pred_json} and {anno_json}...\")\n", " try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb\n", " for x in pred_json, anno_json:\n", " assert x.is_file(), f\"{x} file not found\"\n", " \n", " from faster_coco_eval import COCO, COCOeval_faster\n", "\n", " if self.is_coco:\n", " anno = COCO(str(anno_json)) # init annotations api\n", " pred = anno.loadRes(str(pred_json)) # init predictions api (must 
pass string, not Path)\n", " val = COCOeval_faster(anno, pred, \"bbox\", print_function=print)\n", " else:\n", " anno = COCO(str(anno_json)) # init annotations api\n", " pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path)\n", " val = COCOeval_faster(anno, pred, \"bbox\", lvis_style=True, print_function=print)\n", " val.params.maxDets = [300]\n", " \n", " val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval\n", " val.evaluate()\n", " val.accumulate()\n", " val.summarize()\n", " \n", " # update mAP50-95 and mAP50\n", " stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = val.stats[:2]\n", " except Exception as e:\n", " LOGGER.warning(f\"{pkg} unable to run: {e}\")\n", " return stats\n", " \n", " def eval_json(self, stats):\n", " tic_faster = time.time()\n", " self.eval_json_faster(stats)\n", " toc_faster = time.time()\n", " print(f\"Faster eval took {toc_faster - tic_faster:.2f}s\")\n", "\n", " \n", " tic_orig = time.time()\n", " stats = self.eval_json_orig(stats)\n", " toc_orig = time.time()\n", " print(f\"Original eval took {toc_orig - tic_orig:.2f}s\")\n", " \n", " return stats" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Ultralytics YOLOv8.1.47 🚀 Python-3.10.12 torch-2.1.2+cu118 CUDA:0 (NVIDIA GeForce RTX 3080 Ti, 12288MiB)\n", "YOLOv8n summary (fused): 168 layers, 3151904 parameters, 0 gradients, 8.7 GFLOPs\n", "\n", "Dataset 'coco_val_only.yaml' images not found ⚠️, missing path '/home/mixaill76/faster_coco_eval/examples/ultralytics/datasets/coco/val2017.txt'\n", "Downloading https://github.com/ultralytics/yolov5/releases/download/v1.0/coco2017labels-segments.zip to '/home/mixaill76/faster_coco_eval/examples/ultralytics/datasets/coco2017labels-segments.zip'...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/mixaill76/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)\n", " return F.conv2d(input, weight, bias, self.stride,\n", "100%|██████████| 169M/169M [00:06<00:00, 25.8MB/s] \n", "Unzipping /home/mixaill76/faster_coco_eval/examples/ultralytics/datasets/coco2017labels-segments.zip to /home/mixaill76/faster_coco_eval/examples/ultralytics/datasets/coco...: 100%|██████████| 122232/122232 [00:07<00:00, 16920.33file/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Downloading http://images.cocodataset.org/zips/val2017.zip to '/home/mixaill76/faster_coco_eval/examples/ultralytics/datasets/coco/images/val2017.zip'...\n", "Dataset download success ✅ (100.7s), saved to \u001b[1m/home/mixaill76/faster_coco_eval/examples/ultralytics/datasets\u001b[0m\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mval: \u001b[0mScanning /home/mixaill76/faster_coco_eval/examples/ultralytics/datasets/coco/labels/val2017... 
4952 images, 48 backgrounds, 0 corrupt: 100%|██████████| 5000/5000 [00:05<00:00, 864.61it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mval: \u001b[0mNew cache created: /home/mixaill76/faster_coco_eval/examples/ultralytics/datasets/coco/labels/val2017.cache\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 313/313 [00:31<00:00, 9.86it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " all 5000 36335 0.629 0.476 0.521 0.37\n", " person 5000 10777 0.751 0.678 0.745 0.515\n", " bicycle 5000 314 0.694 0.411 0.466 0.269\n", " car 5000 1918 0.656 0.527 0.566 0.364\n", " motorcycle 5000 367 0.71 0.573 0.654 0.412\n", " airplane 5000 143 0.755 0.776 0.845 0.654\n", " bus 5000 283 0.73 0.664 0.739 0.621\n", " train 5000 190 0.795 0.774 0.833 0.648\n", " truck 5000 414 0.519 0.384 0.45 0.301\n", " boat 5000 424 0.562 0.297 0.373 0.209\n", " traffic light 5000 634 0.641 0.352 0.415 0.213\n", " fire hydrant 5000 101 0.859 0.693 0.78 0.627\n", " stop sign 5000 75 0.679 0.627 0.676 0.615\n", " parking meter 5000 60 0.686 0.511 0.578 0.449\n", " bench 5000 411 0.553 0.275 0.297 0.197\n", " bird 5000 427 0.662 0.365 0.425 0.28\n", " cat 5000 202 0.766 0.832 0.847 0.648\n", " dog 5000 218 0.684 0.693 0.727 0.587\n", " horse 5000 272 0.687 0.658 0.69 0.521\n", " sheep 5000 354 0.616 0.669 0.669 0.466\n", " cow 5000 372 0.714 0.604 0.674 0.48\n", " elephant 5000 252 0.698 0.843 0.815 0.628\n", " bear 5000 71 0.816 0.749 0.833 0.673\n", " zebra 5000 266 0.802 0.807 0.879 0.661\n", " giraffe 5000 232 0.857 0.836 0.884 0.686\n", " backpack 5000 371 0.493 0.164 0.2 0.105\n", " umbrella 5000 407 0.61 0.521 0.538 0.359\n", " handbag 5000 540 0.474 0.122 0.161 0.0815\n", " tie 5000 252 0.636 0.377 0.429 0.267\n", " suitcase 5000 299 0.558 0.425 0.488 0.334\n", " frisbee 5000 115 0.727 0.757 0.763 0.58\n", " skis 5000 241 0.632 0.34 0.377 0.194\n", " snowboard 5000 69 0.534 0.348 0.381 0.267\n", " sports ball 5000 260 0.702 0.442 0.481 0.331\n", " kite 5000 327 0.612 0.526 0.556 0.379\n", " baseball bat 5000 145 0.555 0.372 0.411 0.214\n", " baseball glove 5000 148 0.649 0.486 0.516 0.304\n", " skateboard 5000 179 0.659 0.592 0.645 0.456\n", " surfboard 5000 267 0.599 0.476 0.5 0.312\n", " tennis racket 5000 225 0.676 0.596 0.661 0.403\n", " bottle 5000 1013 0.603 0.382 0.454 0.297\n", " wine glass 5000 341 0.667 0.328 0.407 0.263\n", " cup 5000 895 0.571 0.437 0.485 0.346\n", " fork 5000 215 0.596 0.312 0.375 0.257\n", " knife 5000 325 0.448 0.16 0.166 0.0963\n", " spoon 5000 253 0.437 0.129 0.162 0.0973\n", " bowl 5000 623 0.586 0.485 0.526 0.393\n", " banana 5000 370 0.554 0.319 0.374 0.228\n", " apple 5000 236 0.427 0.231 0.221 0.151\n", " sandwich 5000 177 0.563 0.467 0.475 0.359\n", " orange 5000 285 0.472 0.421 0.361 0.274\n", " broccoli 5000 312 0.507 0.359 0.367 0.21\n", " carrot 5000 365 0.458 0.285 0.307 0.192\n", " hot dog 5000 125 0.718 0.406 0.489 0.36\n", " pizza 5000 284 0.655 0.616 0.658 0.502\n", " donut 5000 328 0.611 0.491 0.516 0.413\n", " cake 5000 310 0.559 0.406 0.45 0.3\n", " chair 5000 1771 0.578 0.344 0.404 0.259\n", " couch 5000 261 0.612 0.567 0.588 0.429\n", " potted plant 5000 342 0.508 0.374 0.377 0.223\n", " bed 5000 163 0.555 0.558 0.6 0.443\n", " dining table 5000 695 0.524 0.43 0.428 0.287\n", " toilet 5000 179 0.73 0.725 0.78 0.645\n", " tv 5000 288 0.738 0.628 0.724 0.551\n", " laptop 5000 231 0.69 0.662 0.699 0.578\n", " mouse 5000 106 0.662 
0.647 0.704 0.522\n", " remote 5000 283 0.427 0.212 0.284 0.165\n", " keyboard 5000 153 0.592 0.569 0.65 0.49\n", " cell phone 5000 262 0.545 0.37 0.406 0.275\n", " microwave 5000 55 0.661 0.564 0.624 0.499\n", " oven 5000 143 0.643 0.497 0.54 0.361\n", " toaster 5000 9 0.593 0.222 0.433 0.311\n", " sink 5000 225 0.582 0.452 0.504 0.327\n", " refrigerator 5000 126 0.684 0.595 0.659 0.506\n", " book 5000 1129 0.458 0.108 0.191 0.0946\n", " clock 5000 267 0.727 0.61 0.672 0.459\n", " vase 5000 274 0.574 0.474 0.471 0.33\n", " scissors 5000 36 0.74 0.333 0.342 0.277\n", " teddy bear 5000 190 0.64 0.574 0.605 0.413\n", " hair drier 5000 11 1 0 0.00606 0.00426\n", " toothbrush 5000 57 0.434 0.211 0.218 0.137\n", "Speed: 0.1ms preprocess, 1.2ms inference, 0.0ms loss, 0.9ms postprocess per image\n", "Saving runs/detect/train/predictions.json...\n", "\n", "Evaluating faster_coco_eval mAP using runs/detect/train/predictions.json and /home/mixaill76/faster_coco_eval/examples/ultralytics/datasets/coco/annotations/instances_val2017.json...\n", "Evaluate annotation type *bbox*\n", "COCOeval_opt.evaluate() finished...\n", "DONE (t=5.09s).\n", "Accumulating evaluation results...\n", "COCOeval_opt.accumulate() finished...\n", "DONE (t=0.00s).\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.373\n", " Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.526\n", " Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.405\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.187\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.410\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.533\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.320\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.536\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.592\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.362\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.657\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.772\n", " Average Recall (AR) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.814\n", " Average Recall (AR) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.644\n", "Faster eval took 7.34s\n", "\n", "Evaluating pycocotools mAP using runs/detect/train/predictions.json and /home/mixaill76/faster_coco_eval/examples/ultralytics/datasets/coco/annotations/instances_val2017.json...\n", "loading annotations into memory...\n", "Done (t=0.16s)\n", "creating index...\n", "index created!\n", "Loading and preparing results...\n", "DONE (t=0.96s)\n", "creating index...\n", "index created!\n", "Running per image evaluation...\n", "Evaluate annotation type *bbox*\n", "DONE (t=25.35s).\n", "Accumulating evaluation results...\n", "DONE (t=6.93s).\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.373\n", " Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.526\n", " Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.405\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.187\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.410\n", " Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.533\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.320\n", " Average Recall (AR) @[ IoU=0.50:0.95 | 
area= all | maxDets= 10 ] = 0.536\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.592\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.362\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.657\n", " Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.772\n", "Original eval took 34.35s\n", "Results saved to \u001b[1mruns/detect/train\u001b[0m\n" ] }, { "data": { "text/plain": [ "{'metrics/precision(B)': 0.6292564172491455,\n", " 'metrics/recall(B)': 0.47631066232459646,\n", " 'metrics/mAP50(B)': 0.5255884728465783,\n", " 'metrics/mAP50-95(B)': 0.3731985247705999,\n", " 'fitness': 0.38551831525427444}" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "!rm -rf runs/\n", "\n", "args = dict(model='yolov8n.pt', data='./coco_val_only.yaml')\n", "validator = BaseCustomDetectionValidator(args=args)\n", "validator()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Faster eval took 7.34s" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Original eval took 34.35s" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import time\n", "from ultralytics.models.yolo.segment.val import SegmentationValidator, check_requirements, LOGGER, Path\n", "\n", "class BaseCustomSegmentationValidator(SegmentationValidator):\n", " def eval_json_orig(self, stats):\n", " \"\"\"Return COCO-style object detection evaluation metrics.\"\"\"\n", " if self.args.save_json and self.is_coco and len(self.jdict):\n", " anno_json = self.data[\"path\"] / \"annotations/instances_val2017.json\" # annotations\n", " pred_json = self.save_dir / \"predictions.json\" # predictions\n", " LOGGER.info(f\"\\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...\")\n", " try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb\n", " check_requirements(\"pycocotools>=2.0.6\")\n", " from pycocotools.coco import COCO # noqa\n", " from pycocotools.cocoeval import COCOeval # noqa\n", "\n", " for x in anno_json, pred_json:\n", " assert x.is_file(), f\"{x} file not found\"\n", " anno = COCO(str(anno_json)) # init annotations api\n", " pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)\n", " for i, eval in enumerate([COCOeval(anno, pred, \"bbox\"), COCOeval(anno, pred, \"segm\")]):\n", " if self.is_coco:\n", " eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval\n", " eval.evaluate()\n", " eval.accumulate()\n", " eval.summarize()\n", " idx = i * 4 + 2\n", " stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[\n", " :2\n", " ] # update mAP50-95 and mAP50\n", " except Exception as e:\n", " LOGGER.warning(f\"pycocotools unable to run: {e}\")\n", " return stats\n", " \n", " def eval_json_faster(self, stats):\n", " \"\"\"Return COCO-style object detection evaluation metrics.\"\"\"\n", " if self.args.save_json and self.is_coco and len(self.jdict):\n", " anno_json = self.data[\"path\"] / \"annotations/instances_val2017.json\" # annotations\n", " pred_json = self.save_dir / \"predictions.json\" # predictions\n", " LOGGER.info(f\"\\nEvaluating faster_coco_eval mAP using {pred_json} and {anno_json}...\")\n", " try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb\n", " from faster_coco_eval import COCO, COCOeval_faster\n", "\n", " for x in anno_json, pred_json:\n", " 
assert x.is_file(), f\"{x} file not found\"\n", " anno = COCO(str(anno_json)) # init annotations api\n", " pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)\n", " for i, eval in enumerate([COCOeval_faster(anno, pred, \"bbox\", print_function=print), COCOeval_faster(anno, pred, \"segm\", print_function=print)]):\n", " if self.is_coco:\n", " eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval\n", " eval.evaluate()\n", " eval.accumulate()\n", " eval.summarize()\n", " idx = i * 4 + 2\n", " stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[\n", " :2\n", " ] # update mAP50-95 and mAP50\n", " except Exception as e:\n", " LOGGER.warning(f\"faster_coco_eval unable to run: {e}\")\n", " return stats\n", "\n", " def eval_json(self, stats):\n", " tic_faster = time.time()\n", " self.eval_json_faster(stats)\n", " toc_faster = time.time()\n", " print(f\"Faster eval took {toc_faster - tic_faster:.2f}s\")\n", "\n", " \n", " tic_orig = time.time()\n", " stats = self.eval_json_orig(stats)\n", " toc_orig = time.time()\n", " print(f\"Original eval took {toc_orig - tic_orig:.2f}s\")\n", " \n", " return stats" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Ultralytics YOLOv8.1.47 🚀 Python-3.10.12 torch-2.1.2+cu118 CUDA:0 (NVIDIA GeForce RTX 3080 Ti, 12288MiB)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/mixaill76/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)\n", " return F.conv2d(input, weight, bias, self.stride,\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "YOLOv8n-seg summary (fused): 195 layers, 3404320 parameters, 0 gradients, 12.6 GFLOPs\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mval: \u001b[0mScanning /home/mixaill76/faster_coco_eval/examples/ultralytics/datasets/coco/labels/val2017.cache... 
4952 images, 48 backgrounds, 0 corrupt: 100%|██████████| 5000/5000\n" ] } ], "source": [ "args = dict(model='yolov8n-seg.pt', data='./coco_val_only.yaml')\n", "validator = BaseCustomSegmentationValidator(args=args)\n", "validator()" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import time\n", "from ultralytics.models.yolo.pose.val import PoseValidator, check_requirements, LOGGER, Path\n", "\n", "class BaseCustomPoseValidator(PoseValidator):\n", " def eval_json_orig(self, stats):\n", " \"\"\"Evaluates object detection model using COCO JSON format.\"\"\"\n", " if self.args.save_json and self.is_coco and len(self.jdict):\n", " anno_json = self.data[\"path\"] / \"annotations/person_keypoints_val2017.json\" # annotations\n", " pred_json = self.save_dir / \"predictions.json\" # predictions\n", " LOGGER.info(f\"\\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...\")\n", " try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb\n", " check_requirements(\"pycocotools>=2.0.6\")\n", " from pycocotools.coco import COCO # noqa\n", " from pycocotools.cocoeval import COCOeval # noqa\n", "\n", " for x in anno_json, pred_json:\n", " assert x.is_file(), f\"{x} file not found\"\n", " anno = COCO(str(anno_json)) # init annotations api\n", " pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)\n", " for i, eval in enumerate([COCOeval(anno, pred, \"bbox\"), COCOeval(anno, pred, \"keypoints\")]):\n", " if self.is_coco:\n", " eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval\n", " eval.evaluate()\n", " eval.accumulate()\n", " eval.summarize()\n", " idx = i * 4 + 2\n", " stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[\n", " :2\n", " ] # update mAP50-95 and mAP50\n", " except Exception as e:\n", " LOGGER.warning(f\"pycocotools unable to run: {e}\")\n", " return stats\n", " \n", " def eval_json_faster(self, stats):\n", " \"\"\"Evaluates object detection model using COCO JSON format.\"\"\"\n", " if self.args.save_json and self.is_coco and len(self.jdict):\n", " anno_json = self.data[\"path\"] / \"annotations/person_keypoints_val2017.json\" # annotations\n", " pred_json = self.save_dir / \"predictions.json\" # predictions\n", " LOGGER.info(f\"\\nEvaluating faster_coco_eval mAP using {pred_json} and {anno_json}...\")\n", " try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb\n", " from faster_coco_eval import COCO, COCOeval_faster\n", "\n", " for x in anno_json, pred_json:\n", " assert x.is_file(), f\"{x} file not found\"\n", " anno = COCO(str(anno_json)) # init annotations api\n", " pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path)\n", " for i, eval in enumerate([COCOeval_faster(anno, pred, \"bbox\", print_function=print), COCOeval_faster(anno, pred, \"keypoints\", print_function=print)]):\n", " if self.is_coco:\n", " eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval\n", " eval.evaluate()\n", " eval.accumulate()\n", " eval.summarize()\n", " idx = i * 4 + 2\n", " stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[\n", " :2\n", " ] # update mAP50-95 and mAP50\n", " except Exception as e:\n", " LOGGER.warning(f\"faster_coco_eval unable to run: {e}\")\n", " return stats\n", " \n", " def eval_json(self, stats):\n", " tic_faster = time.time()\n", " self.eval_json_faster(stats)\n", " toc_faster = time.time()\n", " print(f\"Faster eval took {toc_faster - tic_faster:.2f}s\")\n", "\n", " \n", " tic_orig = time.time()\n", " stats = self.eval_json_orig(stats)\n", " toc_orig = time.time()\n", " print(f\"Original eval took {toc_orig - tic_orig:.2f}s\")\n", " \n", " return stats" ] },
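{ "cell_type": "markdown", "metadata": {}, "source": [ "All three validators follow the same pattern: replace `pycocotools.coco.COCO`/`COCOeval` with `faster_coco_eval`'s `COCO`/`COCOeval_faster` and leave the surrounding evaluate/accumulate/summarize loop untouched (for LVIS-style datasets, the detection validator additionally passes `lvis_style=True` and sets `maxDets=[300]`). The cell below is a minimal standalone sketch of that swap, using only calls exercised above; the annotation and prediction paths are placeholders, not files produced by this notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal drop-in sketch of the pycocotools -> faster_coco_eval swap shown above.\n", "# NOTE: both paths below are placeholders for illustration.\n", "from faster_coco_eval import COCO, COCOeval_faster\n", "\n", "anno = COCO(\"annotations/instances_val2017.json\") # ground-truth annotations\n", "pred = anno.loadRes(\"predictions.json\") # predictions in COCO results format\n", "\n", "val = COCOeval_faster(anno, pred, \"bbox\", print_function=print)\n", "val.evaluate() # match predictions to ground truth\n", "val.accumulate() # aggregate per-image results\n", "val.summarize() # print the AP/AR table\n", "\n", "map50_95, map50 = val.stats[:2] # same stats layout as pycocotools' COCOeval" ] },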
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Ultralytics YOLOv8.1.47 🚀 Python-3.10.12 torch-2.1.2+cu118 CUDA:0 (NVIDIA GeForce RTX 3080 Ti, 12288MiB)\n", "YOLOv8n-pose summary (fused): 187 layers, 3289964 parameters, 0 gradients, 9.2 GFLOPs\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/mixaill76/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)\n", " return F.conv2d(input, weight, bias, self.stride,\n", "\u001b[34m\u001b[1mval: \u001b[0mScanning /home/mixaill76/faster_coco_eval/examples/ultralytics/datasets/coco-pose/labels/val2017.cache... 2346 images, 0 backgrounds, 0 corrupt: 100%|██████████| 2346/2346\n" ] } ], "source": [ "args = dict(model='yolov8n-pose.pt', data='coco-pose.yaml')\n", "validator = BaseCustomPoseValidator(args=args)\n", "validator()" ] } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 2 }