{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# Create Custom Keypoints Table\n", "\n", "Create a 3LC keypoints table from the Animal-Pose dataset with custom schema definitions for pose estimation tasks.\n", "\n", "![img](../../images/animalpose.png)\n", "\n", "\n", "\n", "Animal pose estimation requires specialized keypoint definitions that differ from human pose datasets. Custom tables allow you to define domain-specific keypoint structures and handle non-standard annotation formats that standard loaders can't process.\n", "\n", "This notebook demonstrates creating a keypoints table from scratch using the Animal-Pose dataset. We manually extract COCO-like JSON annotations and convert them to 3LC keypoint format, showing how to handle custom keypoint definitions and visibility flags. The Animal Pose Dataset contains diverse animal species with specialized pose annotations from a Kaggle-hosted version with COCO-like formatting that requires manual processing to create proper 3LC keypoint structures.\n" ] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "## Project setup" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "PROJECT_NAME = \"3LC Tutorials - 2D Keypoints\"\n", "DATASET_NAME = \"AnimalPose\"\n", "TABLE_NAME = \"initial\"\n", "DOWNLOAD_PATH = \"../../transient_data\"" ] }, { "cell_type": "markdown", "id": "3", "metadata": {}, "source": [ "## Install dependencies" ] }, { "cell_type": "code", "execution_count": null, "id": "4", "metadata": {}, "outputs": [], "source": [ "%pip install 3lc\n", "%pip install kagglehub" ] }, { "cell_type": "markdown", "id": "5", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "6", "metadata": {}, "outputs": [], "source": [ "import json\n", "from pathlib import Path\n", "\n", "import numpy as np\n", "import tlc\n", "from PIL import Image\n", "from tqdm import tqdm" ] }, { "cell_type": "markdown", "id": "7", "metadata": {}, "source": [ "## Prepare data\n", "\n", "The following cell downloads the dataset from Kaggle. The dataset requires 350MB of disk space, as well as a [Kaggle account](https://www.kaggle.com/docs/api#authentication)." 
] }, { "cell_type": "code", "execution_count": null, "id": "8", "metadata": {}, "outputs": [], "source": [ "import kagglehub\n", "\n", "DATASET_ROOT = kagglehub.dataset_download(\"bloodaxe/animal-pose-dataset\")\n", "DATASET_ROOT = Path(DATASET_ROOT)\n", "\n", "print(\"Path to dataset files:\", DATASET_ROOT)" ] }, { "cell_type": "code", "execution_count": null, "id": "9", "metadata": {}, "outputs": [], "source": [ "ANNOTATIONS_FILE = DATASET_ROOT / \"keypoints.json\"\n", "IMAGE_ROOT = DATASET_ROOT / \"images\" / \"images\"" ] }, { "cell_type": "code", "execution_count": null, "id": "10", "metadata": {}, "outputs": [], "source": [ "# Register the dataset root as a project URL alias - this enables to easily share the table or move the source data\n", "# tlc.register_project_url_alias(\"ANIMAL_POSE_DATA\", DATASET_ROOT, project=PROJECT_NAME)" ] }, { "cell_type": "markdown", "id": "11", "metadata": {}, "source": [ "## Load annotations / metadata" ] }, { "cell_type": "code", "execution_count": null, "id": "12", "metadata": {}, "outputs": [], "source": [ "with open(ANNOTATIONS_FILE) as f:\n", " data = json.load(f)\n", "\n", "# Load metadata from the annotations file\n", "NUM_KEYPOINTS = 20\n", "KEYPOINT_NAMES = data[\"categories\"][0][\"keypoints\"]\n", "CLASSES = {cat[\"id\"]: cat[\"name\"] for cat in data[\"categories\"]}\n", "SKELETON = np.array(data[\"categories\"][0][\"skeleton\"]).reshape(-1).tolist()" ] }, { "cell_type": "markdown", "id": "13", "metadata": {}, "source": [ "Some metadata is not stored in the annotations file, so we need to define it manually.\n", " These values were taken from the SuperGradients example notebook [YoloNAS_Pose_Fine_Tuning_Animals_Pose_Dataset](https://github.com/Deci-AI/super-gradients/blob/master/notebooks/YoloNAS_Pose_Fine_Tuning_Animals_Pose_Dataset.ipynb).\n" ] }, { "cell_type": "code", "execution_count": null, "id": "14", "metadata": {}, "outputs": [], "source": [ "OKS_SIGMAS = [0.07] * 20\n", "FLIP_INDEXES = [1, 0, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 17, 18, 19]\n", "\n", "KEYPOINT_COLORS = [\n", " [148, 0, 211],\n", " [75, 0, 130],\n", " [0, 0, 255],\n", " [0, 255, 0],\n", " [255, 255, 0],\n", " [255, 165, 0],\n", " [255, 69, 0],\n", " [255, 0, 0],\n", " [139, 0, 0],\n", " [128, 0, 128],\n", " [238, 130, 238],\n", " [186, 85, 211],\n", " [148, 0, 211],\n", " [0, 255, 255],\n", " [0, 128, 128],\n", " [0, 0, 139],\n", " [0, 0, 255],\n", " [0, 255, 0],\n", " [255, 69, 0],\n", " [255, 20, 147],\n", "]\n", "\n", "KEYPOINT_ATTRIBUTES = [\n", " tlc.MapElement(internal_name=kpt_name, display_color=tlc.rgb_tuple_to_hex(color))\n", " for kpt_name, color in zip(KEYPOINT_NAMES, KEYPOINT_COLORS)\n", "]\n", "\n", "# A roughly drawn default pose suitable for annotating missing animals\n", "KEYPOINTS_DEFAULT_POSE = [\n", " [0.87, 0.22],\n", " [0.78, 0.27],\n", " [0.9, 0.41],\n", " [0.83, 0.1],\n", " [0.72, 0.13],\n", " [0.75, 0.46],\n", " [0.58, 0.46],\n", " [0.29, 0.42],\n", " [0.16, 0.4],\n", " [0.75, 0.69],\n", " [0.59, 0.67],\n", " [0.29, 0.64],\n", " [0.12, 0.63],\n", " [0.78, 0.88],\n", " [0.59, 0.9],\n", " [0.3, 0.87],\n", " [0.1, 0.87],\n", " [0.77, 0.36],\n", " [0.6, 0.2],\n", " [0.25, 0.2],\n", "]\n", "\n", "EDGE_COLORS = [\n", " [127, 0, 255],\n", " [91, 56, 253],\n", " [55, 109, 248],\n", " [19, 157, 241],\n", " [18, 199, 229],\n", " [54, 229, 215],\n", " [90, 248, 199],\n", " [128, 254, 179],\n", " [164, 248, 158],\n", " [200, 229, 135],\n", " [236, 199, 110],\n", " [255, 157, 83],\n", " [255, 109, 56],\n", " [255, 56, 28],\n", " [255, 
0, 0],\n", "]\n", "\n", "LINE_ATTRIBUTES = [\n", "    tlc.MapElement(internal_name=\"edge\", display_color=tlc.rgb_tuple_to_hex(color)) for color in EDGE_COLORS\n", "]" ] }, { "cell_type": "markdown", "id": "15", "metadata": {}, "source": [ "## Convert the annotations" ] }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, "outputs": [], "source": [ "annotations = data[\"annotations\"]\n", "images = data[\"images\"]\n", "\n", "row_data = {\n", "    \"image\": [],\n", "    \"keypoints_2d\": [],\n", "}\n", "\n", "# Pre-compute mapping from image_id to annotations for faster lookup\n", "image_id_2_anns = {}\n", "for ann in annotations:\n", "    image_id_2_anns.setdefault(str(ann[\"image_id\"]), []).append(ann)\n", "\n", "for image_id, image_path in tqdm(images.items(), total=len(images), desc=\"Loading annotations\"):\n", "    image_path = IMAGE_ROOT / image_path\n", "    if not image_path.exists():\n", "        print(f\"Image {image_path} does not exist\")\n", "        continue\n", "\n", "    with Image.open(image_path) as img:\n", "        width, height = img.size\n", "\n", "    keypoints = tlc.Keypoints2DInstances.create_empty(image_height=height, image_width=width)\n", "\n", "    # Images without annotations keep an empty set of instances\n", "    for ann in image_id_2_anns.get(image_id, []):\n", "        # The annotation file uses a 0/1 visibility flag; 3LC uses COCO's three-state visibility (0-1-2)\n", "        ann[\"keypoints\"] = [[x, y, 2 if v else 0] for (x, y, v) in ann[\"keypoints\"]]\n", "\n", "        keypoints.add_instance(\n", "            keypoints=ann[\"keypoints\"],\n", "            bbox=ann[\"bbox\"],\n", "            label=ann[\"category_id\"],\n", "        )\n", "\n", "    row_data[\"image\"].append(tlc.Url(image_path).to_relative().to_str())  # Url.to_relative applies aliases\n", "    row_data[\"keypoints_2d\"].append(keypoints.to_row())" ] }, { "cell_type": "markdown", "id": "17", "metadata": {}, "source": [ "## Create table\n", "\n", "We create a Table with `Table.from_dict`, specifying the attributes and metadata of the keypoints column through a `Keypoints2DSchema`." ] }, { "cell_type": "code", "execution_count": null, "id": "18", "metadata": {}, "outputs": [], "source": [ "keypoints_schema = tlc.Keypoints2DSchema(\n", "    num_keypoints=NUM_KEYPOINTS,\n", "    classes=CLASSES,\n", "    points=KEYPOINTS_DEFAULT_POSE,\n", "    lines=SKELETON,\n", "    line_attributes=LINE_ATTRIBUTES,\n", "    point_attributes=KEYPOINT_ATTRIBUTES,\n", "    include_per_point_visibility=True,\n", "    flip_indices=FLIP_INDEXES,\n", "    oks_sigmas=OKS_SIGMAS,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "19", "metadata": {}, "outputs": [], "source": [ "table = tlc.Table.from_dict(\n", "    data=row_data,\n", "    structure={\n", "        \"image\": tlc.ImageUrlSchema(),\n", "        \"keypoints_2d\": keypoints_schema,\n", "    },\n", "    table_name=TABLE_NAME,\n", "    dataset_name=DATASET_NAME,\n", "    project_name=PROJECT_NAME,\n", "    if_exists=\"rename\",\n", ")" ] }, { "cell_type": "markdown", "id": "20", "metadata": {}, "source": [ "## Inspect the table\n", "\n", "We can use the `KeypointHelper` class to read the keypoint metadata (OKS sigmas, flip indices, skeleton, and point/line attributes) back from the table."
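, "\n", "\n", "These helpers are handy when passing the keypoint metadata on to a downstream training setup. A minimal sketch that gathers the values shown below into a plain dictionary (the dictionary and its key names are our own choice, not a 3LC convention):\n", "\n", "```python\n", "# Collect the keypoint metadata stored in the table into one place\n", "keypoint_metadata = {\n", "    \"oks_sigmas\": tlc.KeypointHelper.get_oks_sigmas_from_table(table),\n", "    \"flip_indices\": tlc.KeypointHelper.get_flip_indices_from_table(table),\n", "    \"skeleton\": tlc.KeypointHelper.get_lines_from_table(table),\n", "}\n", "```"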
] }, { "cell_type": "code", "execution_count": null, "id": "21", "metadata": {}, "outputs": [], "source": [ "table" ] }, { "cell_type": "code", "execution_count": null, "id": "22", "metadata": {}, "outputs": [], "source": [ "# Get the oks sigmas from the table\n", "tlc.KeypointHelper.get_oks_sigmas_from_table(table)" ] }, { "cell_type": "code", "execution_count": null, "id": "23", "metadata": {}, "outputs": [], "source": [ "# Get the flip indices from the table\n", "tlc.KeypointHelper.get_flip_indices_from_table(table)" ] }, { "cell_type": "code", "execution_count": null, "id": "24", "metadata": {}, "outputs": [], "source": [ "# Get the skeleton from the table\n", "tlc.KeypointHelper.get_lines_from_table(table)" ] }, { "cell_type": "code", "execution_count": null, "id": "25", "metadata": {}, "outputs": [], "source": [ "# Get the keypoint attributes from the table\n", "tlc.KeypointHelper.get_keypoint_attributes_from_table(table)" ] }, { "cell_type": "code", "execution_count": null, "id": "26", "metadata": {}, "outputs": [], "source": [ "# Get the line attributes from the table\n", "tlc.KeypointHelper.get_line_attributes_from_table(table)" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.10" }, "test_marks": [ "dependent" ] }, "nbformat": 4, "nbformat_minor": 5 }