{ "cells": [ { "cell_type": "markdown", "id": "0", "metadata": {}, "source": [ "# Train a YOLO-NAS model for pose estimation with SuperGradients\n", "\n", "This tutorial trains a SuperGradients YOLO-NAS model for pose estimation on the AnimalPose dataset.\n", "\n", "The input Table required for running this notebook is created in [create-custom-keypoints-table.ipynb](../1-create-tables/keypoints/create-custom-keypoints-table.ipynb).\n", "\n", "![](../images/sg-animalpose.png)\n", "\n", "\n", "\n", "This notebook is a modified version of the [SuperGradients YoloNAS Pose Fine Tuning Notebook](https://github.com/Deci-AI/super-gradients/blob/master/notebooks/YoloNAS_Pose_Fine_Tuning_Animals_Pose_Dataset.ipynb)." ] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "## Install dependencies" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": {}, "outputs": [], "source": [ "%pip install 3lc\n", "%pip install super-gradients\n", "%pip install termcolor==3.1.0\n", "%pip install git+https://github.com/3lc-ai/3lc-examples.git" ] }, { "cell_type": "markdown", "id": "3", "metadata": {}, "source": [ "## Project setup" ] }, { "cell_type": "code", "execution_count": null, "id": "4", "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "PROJECT_NAME = \"3LC Tutorials - 2D Keypoints\"\n", "DATASET_NAME = \"AnimalPose\"\n", "TABLE_NAME = \"initial\"\n", "MODEL_NAME = \"yolo_nas_pose_n\"\n", "RUN_NAME = \"fine-tune-yolo-nas-pose-n-animalpose\"\n", "BATCH_SIZE = 16\n", "NUM_WORKERS = 0\n", "DOWNLOAD_PATH = \"../../transient_data\"\n", "MAX_EPOCHS = 10\n", "IMAGE_SIZE = 640" ] }, { "cell_type": "markdown", "id": "5", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "6", "metadata": {}, "outputs": [], "source": [ "import requests\n", "from super_gradients.training import Trainer, models\n", "from super_gradients.training.datasets.pose_estimation_datasets import 
YoloNASPoseCollateFN\n", "from super_gradients.training.metrics import PoseEstimationMetrics\n", "from super_gradients.training.models.pose_estimation_models.yolo_nas_pose import YoloNASPosePostPredictionCallback\n", "from super_gradients.training.transforms.keypoints import (\n", " KeypointsBrightnessContrast,\n", " KeypointsHSV,\n", " KeypointsImageStandardize,\n", " KeypointsLongestMaxSize,\n", " KeypointsPadIfNeeded,\n", " KeypointsRandomAffineTransform,\n", " KeypointsRandomHorizontalFlip,\n", " KeypointsRemoveSmallObjects,\n", ")\n", "from super_gradients.training.utils.callbacks import Callback\n", "from tlc.core import KeypointHelper, Table\n", "from tlc.integration.super_gradients import PoseEstimationDataset, PoseEstimationMetricsCollectionCallback\n", "from torch.utils.data import DataLoader\n", "from torchmetrics.metric import Metric\n", "\n", "from tlc_tools.split import split_table" ] }, { "cell_type": "markdown", "id": "7", "metadata": {}, "source": [ "## Download pretrained model" ] }, { "cell_type": "code", "execution_count": null, "id": "8", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "MODEL_PATH = Path(DOWNLOAD_PATH) / \"yolo_nas_pose_n_coco_pose.pth\"\n", "\n", "if not MODEL_PATH.exists():\n", " MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)\n", " response = requests.get(\"https://sg-hub-nv.s3.amazonaws.com/models/yolo_nas_pose_n_coco_pose.pth\")\n", " MODEL_PATH.write_bytes(response.content)" ] }, { "cell_type": "markdown", "id": "9", "metadata": {}, "source": [ "## Load and split input tables" ] }, { "cell_type": "code", "execution_count": null, "id": "10", "metadata": {}, "outputs": [], "source": [ "initial_table = Table.from_names(TABLE_NAME, DATASET_NAME, PROJECT_NAME)\n", "\n", "\n", "def split_by(table_row):\n", " \"\"\"Callable to get the label of the first keypoint instance\n", "\n", " This allows us to do a stratified split by label, just like in the original SuperGradients notebook.\n", " 
\"\"\"\n", " return table_row[\"keypoints_2d\"][\"instances_additional_data\"][\"label\"][0]\n", "\n", "\n", "train_val_test = split_table(\n", " initial_table,\n", " splits={\"train\": 0.8, \"val_test\": 0.2},\n", " split_strategy=\"stratified\",\n", " split_by=split_by,\n", " random_seed=42,\n", " shuffle=False,\n", ")\n", "\n", "test_val = split_table(\n", " train_val_test[\"val_test\"],\n", " splits={\"val\": 0.5, \"test\": 0.5},\n", " split_strategy=\"stratified\",\n", " split_by=split_by,\n", " shuffle=False,\n", " random_seed=42,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "11", "metadata": {}, "outputs": [], "source": [ "train_table = train_val_test[\"train\"]\n", "val_table = test_val[\"val\"]\n", "test_table = test_val[\"test\"]" ] }, { "cell_type": "code", "execution_count": null, "id": "12", "metadata": {}, "outputs": [], "source": [ "print(initial_table)\n", "print(train_table)\n", "print(val_table)\n", "print(test_table)" ] }, { "cell_type": "markdown", "id": "13", "metadata": {}, "source": [ "## Prepare for training" ] }, { "cell_type": "code", "execution_count": null, "id": "14", "metadata": {}, "outputs": [], "source": [ "def create_transforms(image_size: int, flip_indices: list[int]):\n", " keypoints_random_horizontal_flip = KeypointsRandomHorizontalFlip(flip_index=flip_indices, prob=0.5)\n", " keypoints_hsv = KeypointsHSV(prob=0.5, hgain=20, sgain=20, vgain=20)\n", " keypoints_brightness_contrast = KeypointsBrightnessContrast(\n", " prob=0.5, brightness_range=[0.8, 1.2], contrast_range=[0.8, 1.2]\n", " )\n", " keypoints_random_affine_transform = KeypointsRandomAffineTransform(\n", " max_rotation=0,\n", " min_scale=0.5,\n", " max_scale=1.5,\n", " max_translate=0.1,\n", " image_pad_value=127,\n", " mask_pad_value=1,\n", " prob=0.75,\n", " interpolation_mode=[0, 1, 2, 3, 4],\n", " )\n", " keypoints_longest_max_size = KeypointsLongestMaxSize(max_height=image_size, max_width=image_size)\n", " keypoints_pad_if_needed = 
KeypointsPadIfNeeded(\n", " min_height=image_size,\n", " min_width=image_size,\n", " image_pad_value=[127, 127, 127],\n", " mask_pad_value=1,\n", " padding_mode=\"bottom_right\",\n", " )\n", " keypoints_image_standardize = KeypointsImageStandardize(max_value=255)\n", " keypoints_remove_small_objects = KeypointsRemoveSmallObjects(min_instance_area=1, min_visible_keypoints=1)\n", "\n", " train_transforms = [\n", " keypoints_random_horizontal_flip,\n", " keypoints_hsv,\n", " keypoints_brightness_contrast,\n", " keypoints_random_affine_transform,\n", " keypoints_longest_max_size,\n", " keypoints_pad_if_needed,\n", " keypoints_image_standardize,\n", " keypoints_remove_small_objects,\n", " ]\n", "\n", " val_transforms = [\n", " keypoints_longest_max_size,\n", " keypoints_pad_if_needed,\n", " keypoints_image_standardize,\n", " ]\n", "\n", " return train_transforms, val_transforms\n", "\n", "\n", "def create_training_params(max_epochs: int, callbacks: list[Callback], metrics: list[Metric], oks_sigmas: list[float]):\n", " return {\n", " \"seed\": 42,\n", " \"warmup_mode\": \"LinearBatchLRWarmup\",\n", " \"warmup_initial_lr\": 1e-8,\n", " \"lr_warmup_epochs\": 2,\n", " \"initial_lr\": 5e-4,\n", " \"lr_mode\": \"cosine\",\n", " \"cosine_final_lr_ratio\": 0.05,\n", " \"max_epochs\": max_epochs,\n", " \"zero_weight_decay_on_bias_and_bn\": True,\n", " \"batch_accumulate\": 1,\n", " \"average_best_models\": False,\n", " \"save_ckpt_epoch_list\": [],\n", " \"loss\": \"yolo_nas_pose_loss\",\n", " \"criterion_params\": {\n", " \"oks_sigmas\": oks_sigmas,\n", " \"classification_loss_weight\": 1.0,\n", " \"classification_loss_type\": \"focal\",\n", " \"regression_iou_loss_type\": \"ciou\",\n", " \"iou_loss_weight\": 2.5,\n", " \"dfl_loss_weight\": 0.01,\n", " \"pose_cls_loss_weight\": 1.0,\n", " \"pose_reg_loss_weight\": 34.0,\n", " \"pose_classification_loss_type\": \"focal\",\n", " \"rescale_pose_loss_with_assigned_score\": True,\n", " \"assigner_multiply_by_pose_oks\": True,\n", " 
},\n", " \"optimizer\": \"AdamW\",\n", " \"optimizer_params\": {\"weight_decay\": 0.000001},\n", " \"ema\": True,\n", " \"ema_params\": {\"decay\": 0.997, \"decay_type\": \"threshold\"},\n", " \"mixed_precision\": True,\n", " \"sync_bn\": False,\n", " \"valid_metrics_list\": metrics,\n", " \"phase_callbacks\": callbacks,\n", " \"pre_prediction_callback\": None,\n", " \"metric_to_watch\": \"AP\",\n", " \"greater_metric_to_watch_is_better\": True,\n", " }" ] }, { "cell_type": "code", "execution_count": null, "id": "15", "metadata": {}, "outputs": [], "source": [ "flip_indices = KeypointHelper.get_flip_indices_from_table(initial_table)\n", "oks_sigmas = KeypointHelper.get_oks_sigmas_from_table(initial_table)\n", "\n", "train_transforms, val_transforms = create_transforms(image_size=IMAGE_SIZE, flip_indices=flip_indices)" ] }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, "outputs": [], "source": [ "train_dataset = PoseEstimationDataset(train_table, transforms=train_transforms)\n", "val_dataset = PoseEstimationDataset(val_table, transforms=val_transforms)" ] }, { "cell_type": "code", "execution_count": null, "id": "17", "metadata": {}, "outputs": [], "source": [ "post_prediction_callback = YoloNASPosePostPredictionCallback(\n", " pose_confidence_threshold=0.01,\n", " nms_iou_threshold=0.7,\n", " pre_nms_max_predictions=100,\n", " post_nms_max_predictions=15,\n", ")\n", "\n", "pose_estimation_metrics = PoseEstimationMetrics(\n", " num_joints=train_dataset.num_joints,\n", " oks_sigmas=oks_sigmas,\n", " max_objects_per_image=15,\n", " post_prediction_callback=post_prediction_callback,\n", ")\n", "\n", "tlc_callback = PoseEstimationMetricsCollectionCallback(project_name=PROJECT_NAME, run_name=RUN_NAME)" ] }, { "cell_type": "code", "execution_count": null, "id": "18", "metadata": {}, "outputs": [], "source": [ "training_params = create_training_params(\n", " max_epochs=MAX_EPOCHS,\n", " callbacks=[tlc_callback],\n", " 
metrics=[pose_estimation_metrics],\n", " oks_sigmas=oks_sigmas,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "19", "metadata": {}, "outputs": [], "source": [ "train_dataloader_params = {\n", " \"shuffle\": True,\n", " \"batch_size\": BATCH_SIZE,\n", " \"drop_last\": True,\n", " \"pin_memory\": False,\n", " \"collate_fn\": YoloNASPoseCollateFN(),\n", " \"num_workers\": NUM_WORKERS,\n", " \"persistent_workers\": NUM_WORKERS > 0,\n", "}\n", "val_dataloader_params = {\n", " \"shuffle\": False,\n", " \"batch_size\": BATCH_SIZE,\n", " \"drop_last\": True,\n", " \"pin_memory\": False,\n", " \"collate_fn\": YoloNASPoseCollateFN(),\n", " \"num_workers\": NUM_WORKERS,\n", " \"persistent_workers\": NUM_WORKERS > 0,\n", "}\n", "\n", "train_dataloader = DataLoader(train_dataset, **train_dataloader_params)\n", "val_dataloader = DataLoader(val_dataset, **val_dataloader_params)" ] }, { "cell_type": "markdown", "id": "20", "metadata": {}, "source": [ "## Train model" ] }, { "cell_type": "code", "execution_count": null, "id": "21", "metadata": {}, "outputs": [], "source": [ "yolo_nas_pose = models.get(\n", " MODEL_NAME,\n", " num_classes=20,\n", " checkpoint_path=MODEL_PATH.as_posix(),\n", " checkpoint_num_classes=17,\n", ").cuda()\n", "\n", "trainer = Trainer(experiment_name=RUN_NAME, ckpt_root_dir=DOWNLOAD_PATH + \"/sg-checkpoints\")" ] }, { "cell_type": "code", "execution_count": null, "id": "22", "metadata": {}, "outputs": [], "source": [ "trainer.train(\n", " model=yolo_nas_pose, training_params=training_params, train_loader=train_dataloader, valid_loader=val_dataloader\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" }, "test_marks": [ "dependent" ] }, "nbformat": 4, 
"nbformat_minor": 5 }