{ "cells": [ { "cell_type": "markdown", "id": "2792dc86", "metadata": {}, "source": [ "# Integrating Cosmos-Transfer for Data Scarcity\n", "\n", "## Description of the Integration\n", "\n", "This tutorial demonstrates how to integrate **NVIDIA Cosmos-Transfer 2.5**, a state-of-the-art **world foundation model (WFM)** for Physical AI, with **FiftyOne**, an open-source tool for visual dataset exploration and model evaluation. The integration enables you to seamlessly **curate, visualize, and process multimodal datasets** (RGB, depth, segmentation, edge maps, etc.) through Cosmos-Transfer’s multi-control generation capabilities, all within the FiftyOne ecosystem.\n", "\n", "By combining Cosmos’s **autoregressive world generation** with FiftyOne’s **dataset management and visualization tools**, this workflow bridges simulation and real-world analysis. It helps developers explore augmented data, evaluate generation quality, and build robust datasets for robotics, autonomous vehicles, video analytics AI agents, and other kind of solutions we will showcase in this tutorial.\n", "\n", "![](https://cdn.voxel51.com/tutorials/cosmos-transfer2_5/biotrove-scarcity.webp)\n", "\n", "\n", "### So, what’s the takeaway?\n", "\n", "With this integration, you can:\n", "- Streamline **data curation and augmentation** pipelines using Cosmos-Transfer and FiftyOne. \n", "- Automate **control map generation** and inference on large video datasets. \n", "- **Visualize and compare** original and generated outputs side-by-side in FiftyOne. \n", "- Prepare your datasets for **physical AI research** and **real-world deployment**.\n", "\n" ] }, { "cell_type": "markdown", "id": "62eaa991", "metadata": {}, "source": [ "## Setup\n", "### 1. Install FiftyOne\n", "\n", "You’ll need Python 3.9+ and other libraries to work with FiftyOne Brain. \n", "\n", "```bash\n", "pip install fiftyone umap-learn\n", "```\n", "\n", "### 2. Install Cosmos-Transfer 2.5 and Dependencies\n", "\n", "Clone the [official repository](https://github.com/nvidia-cosmos/cosmos-transfer2.5) and set up the environment. 
You can also follow the instructions on this [Cosmos Transfer 2.5 Recipe](https://github.com/nvidia-cosmos/cosmos-cookbook/blob/main/docs/recipes/inference/transfer2_5/inference-biotrove-augmentation_w_FiftyOne/inference.md) and the [Setup Process](https://github.com/nvidia-cosmos/cosmos-cookbook/blob/main/docs/recipes/inference/transfer2_5/inference-biotrove-augmentation_w_FiftyOne/setup.md)\n", "\n", "### System requirements\n", "\n", "- NVIDIA GPUs with Ampere architecture (RTX 30 Series, A100) or newer\n", "- NVIDIA driver >=570.124.06 compatible with CUDA 12.8.1\n", "- Linux x86-64\n", "- glibc>=2.35 (e.g., Ubuntu >=22.04)\n", "- Python 3.10\n", "\n", "#### Installation\n", "\n", "After you have your machine ready, follow these [instructions](https://github.com/nvidia-cosmos/cosmos-transfer2.5/blob/main/docs/setup.md#installation).\n", "\n", "```bash\n", "git clone https://github.com/nvidia-cosmos/cosmos-transfer2.5.git\n", "cd cosmos-transfer2.5\n", "pip install -e .\n", "```\n", "\n", "Ensure the following dependencies are met:\n", "- **PyTorch ≥ 2.5**\n", "- **TorchVision**\n", "- **CUDA 12+** (Blackwell optimized)\n", "- **FFmpeg** (for video conversion)\n", "- **Gradio** (for optional interface)\n", "- **EasyIO** (for multi-storage backend)\n", "- **json5**, **pyrefly**, and **torch-compile** (for tokenizer optimization)\n", "\n" ] }, { "cell_type": "markdown", "id": "97a94062", "metadata": {}, "source": [ "## Load Your Dataset into FiftyOne\n", "\n", "For this tutorial, we’ll use a subset of the **BioTrove dataset**, which includes samples of moths in a variety of scenarios, most of them not captured in their natural environments." ] }, { "cell_type": "code", "execution_count": null, "id": "1f98068a", "metadata": {}, "outputs": [], "source": [ "import fiftyone as fo\n", "import fiftyone.utils.huggingface as fouh\n", "\n", "# Your source dataset (images)\n", "dataset_src = fouh.load_from_hub(\n", " \"pjramg/moth_biotrove\",\n", " persistent=True,\n", " overwrite=True, # set to False if you don't want to re-download\n", " max_samples=2,\n", ")\n", "\n", "print(dataset_src.name, dataset_src.media_type, len(dataset_src))" ] }, { "cell_type": "code", "execution_count": null, "id": "53db0818", "metadata": {}, "outputs": [], "source": [ "import os\n", "# Get the filepath of the first sample\n", "first_sample = dataset_src.first()\n", "sample_filepath = first_sample.filepath\n", "# Extract the directory where the dataset is stored\n", "dataset_directory = os.path.dirname(sample_filepath)\n", "print(f\"Dataset directory: {dataset_directory}\")" ] }, { "cell_type": "markdown", "id": "550a152b", "metadata": {}, "source": [ "### Create a grouped dataset\n", "\n", "For educational purposes, we will create one slice per step of our Cosmos-Transfer integration." ] }, { "cell_type": "code", "execution_count": null, "id": "e8bdf1bc", "metadata": {}, "outputs": [], "source": [ "import fiftyone as fo\n", "\n", "dataset_grp_name = \"moth_biotrove_grouped\"\n", "if fo.dataset_exists(dataset_grp_name):\n", " fo.delete_dataset(dataset_grp_name)\n", "\n", "# Create the dataset WITHOUT a media_type (grouped datasets can contain multiple media types)\n", "dataset_grp = fo.Dataset(dataset_grp_name)\n", "\n", "# Make the dataset persistent so it survives across sessions\n", "dataset_grp.persistent = True\n", "\n", "\n", "dataset_grp\n" ] }, { "cell_type": "markdown", "id": "af3d3947", "metadata": {}, "source": [ "## Check Dataset Format and Add Videos\n", "\n", "Cosmos-Transfer works on **videos**. 
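You can confirm what the loaded dataset contains before converting anything; a quick check reusing `dataset_src` from the cells above:\n", "\n", "```python\n", "# media_type is \"image\" for this BioTrove subset\n", "print(dataset_src.media_type)\n", "```\n", "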
\n", "If your dataset consists of **images**, you must convert them into short MP4 video clips before using them as inputs.\n", "\n", "In this example, the FiftyOne dataset is downloaded from the HuggingFace Hub into the local FiftyOne directory, typically located at:\n", "\n", "```~/fiftyone/huggingface/hub///```\n", "\n", "Each sample image will be converted into a **1-second video** with the same filename.\n", "\n", "---\n", "\n", "### Convert Images to Videos (Python Version)\n", "\n", "Instead of using a shell loop, we programmatically:\n", "\n", "1. Locate the directory where FiftyOne stores the downloaded dataset media. \n", "2. Create a new `videos/` directory next to the images. \n", "3. Walk through every image in the dataset. \n", "4. Convert each JPG into a 1-second MP4 video using FFmpeg. \n", "\n", "This method works consistently across environments and avoids issues with shell globbing or missing directories.\n", "\n", "Below is the Python code used in this notebook:\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7a0547ca", "metadata": {}, "outputs": [], "source": [ "import os\n", "import subprocess\n", "from pathlib import Path\n", "\n", "# 1) Where your dataset images live\n", "images_root = Path(dataset_directory)\n", "\n", "# 2) Create videos folder in the same parent directory as the data folder\n", "videos_root = images_root.parent / \"videos\"\n", "videos_root.mkdir(parents=True, exist_ok=True)\n", "\n", "print(f\"Images root: {images_root.resolve()}\")\n", "print(f\"Videos root: {videos_root.resolve()}\")\n", "\n", "# 3) Walk all subdirectories under images_root\n", "for dirpath, _dirnames, filenames in os.walk(images_root):\n", " jpgs = sorted([f for f in filenames if f.lower().endswith(\".jpg\")])\n", " if not jpgs:\n", " # No JPGs in this folder, skip\n", " continue\n", "\n", " dirpath = Path(dirpath)\n", " \n", " # Process each image individually\n", " for jpg_file in jpgs:\n", " input_image = dirpath / jpg_file\n", " \n", " # Get the image name without extension\n", " image_name = Path(jpg_file).stem\n", " \n", " # Create output video with same name as image\n", " output_video = videos_root / f\"{image_name}.mp4\"\n", "\n", " cmd = [\n", " \"ffmpeg\",\n", " \"-y\",\n", " \"-loop\", \"1\", # Loop the single image\n", " \"-i\", str(input_image),\n", " \"-t\", \"1\", # Duration of 1 second\n", " \"-vf\", \"pad=ceil(iw/2)*2:ceil(ih/2)*2\", # Pad to even dimensions\n", " \"-c:v\", \"libx264\",\n", " \"-pix_fmt\", \"yuv420p\",\n", " str(output_video),\n", " ]\n", "\n", " print(f\"\\nProcessing image: {jpg_file}\")\n", " print(f\"Output video: {output_video}\")\n", " \n", " try:\n", " subprocess.run(cmd, check=True, capture_output=True, text=True)\n", " print(f\"✓ Created: {output_video}\")\n", " except subprocess.CalledProcessError as e:\n", " print(f\"✗ Error processing {jpg_file}:\")\n", " print(f\"stderr: {e.stderr}\")\n", " continue\n", "\n", "print(\"\\nDone!\")" ] }, { "cell_type": "code", "execution_count": null, "id": "aff048eb", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "import fiftyone as fo\n", "\n", "# Where the MP4s were written by your FFmpeg loop\n", "VIDEO_ROOT = videos_root\n", "\n", "def guess_video_path(image_path: str) -> str | None:\n", " \"\"\"\n", " Given an image path, try to infer the matching MP4 path.\n", " Strategy:\n", " - Take the image filename stem\n", " - Expect an MP4 named .mp4 under VIDEO_ROOT\n", " \"\"\"\n", " p = Path(image_path)\n", " image_stem = p.stem # filename without 
extension\n", " candidate = VIDEO_ROOT / f\"{image_stem}.mp4\"\n", " return str(candidate) if candidate.exists() else None\n", "\n", "# Add slice to the grouped dataset\n", "dataset_grp.add_group_field(\"group\", default=\"image\")\n", "\n", "# Optional: copy over selected label fields from the source sample to the new grouped sample\n", "# Exclude system fields\n", "exclude = {\"id\", \"filepath\", \"group\", \"tags\", \"metadata\"}\n", "schema = dataset_src.get_field_schema()\n", "\n", "# Build grouped samples\n", "samples_to_add = []\n", "\n", "for s in dataset_src.iter_samples(progress=True):\n", " video_fp = guess_video_path(s.filepath)\n", "\n", " # Always create the group anchor\n", " g = fo.Group()\n", "\n", " # --- IMAGE SLICE ---\n", " img_sample = fo.Sample(\n", " filepath=s.filepath,\n", " group=g.element(\"image\"),\n", " tags=list(s.tags) if s.tags else None,\n", " metadata=s.metadata,\n", " )\n", " # copy user fields\n", " for field in schema:\n", " if field not in exclude and hasattr(s, field):\n", " img_sample[field] = getattr(s, field)\n", "\n", " samples_to_add.append(img_sample)\n", "\n", " # --- VIDEO SLICE (optional if present) ---\n", " if video_fp is not None:\n", " vid_sample = fo.Sample(\n", " filepath=video_fp,\n", " group=g.element(\"video\"),\n", " tags=list(s.tags) if s.tags else None,\n", " )\n", " samples_to_add.append(vid_sample)\n", "\n", "dataset_grp.add_samples(samples_to_add)\n", "\n", "print(dataset_grp.name, dataset_grp.media_type, len(dataset_grp))" ] }, { "cell_type": "markdown", "id": "0d1d0b90", "metadata": {}, "source": [ "## Full Python Batch Pipeline\n", "### Cosmos-Transfer 2.5 + FiftyOne Integration\n", "\n", "The pipeline automates the complete process of dataset augmentation and model inference:\n", "\n", "- Collect input videos from the assets directory or a custom list file.\n", "- Generate Canny edge videos using OpenCV to serve as control maps for Cosmos-Transfer 2.5.\n", "- Write JSON spec files defining prompts, guidance, and control paths for each video.\n", "- Run Cosmos-Transfer inference by invoking the official examples/inference.py script through subprocess, leveraging your current Python environment.\n", "- Extract the final frame from each output video to create a static “output last” thumbnail image.\n", "- Attach all derived data—edge videos, generated outputs, and last-frame thumbnails—as new slices (edge, output, output_last) in your existing grouped FiftyOne dataset.\n", "\n", "Once complete, you can open the dataset in the FiftyOne App to visually compare the original images, control maps, generated videos, and last-frame outputs side-by-side. 
This all-in-Python workflow provides a fully reproducible, GPU-accelerated, and data-centric way to evaluate Cosmos-Transfer 2.5 results directly within your notebook environment.\n", "\n", "\n", "### Configuration Variables\n", "\n", "- `images_root`: Path to the original dataset images downloaded from Hugging Face\n", "- `base_dir`: Parent directory containing all pipeline data (`images_root.parent`)\n", "- `ASSETS_DIR`: Directory containing input videos for processing\n", "- `OUT_DIR`: Directory where Cosmos inference results are saved\n", "- `SPECS_DIR`: Directory for JSON specification files used by Cosmos\n", "- `EDGE_DIR`: Directory for Canny edge detection control videos\n", "- `LAST_FRAMES_DIR`: Directory for extracted last frames from output videos\n", "- `MAX_VIDS`: Maximum number of videos to process (default: 100)\n", "- `COSMOS_DIR`: Root directory of the Cosmos repository\n", "- `INFER_SCRIPT`: Path to the Cosmos inference script\n", "\n", "### Inference Parameters\n", "\n", "- `MOTH_PROMPT`: Detailed prompt describing the desired output style\n", "- `NEG_PROMPT`: Negative prompt to avoid unwanted artifacts\n", "- `GUIDANCE`: Guidance scale for inference (default: 7)\n", "- `RESOLUTION`: Output resolution (default: \"720\")\n", "- `NUM_STEPS`: Number of inference steps (default: 38)\n", "\n", "### FiftyOne Integration\n", "\n", "- `GROUPED_DATASET_NAME`: Name of the FiftyOne grouped dataset to update with new slices\n", "\n", "This configuration ensures all pipeline outputs are organized in a consistent structure alongside your original Hugging Face dataset, making it easy to manage and integrate with FiftyOne's grouped dataset functionality." ] }, { "cell_type": "code", "execution_count": null, "id": "a2474186", "metadata": {}, "outputs": [], "source": [ "# --- Python-only batch pipeline: Cosmos-Transfer2.5 + FiftyOne ---\n", "\n", "import os\n", "import sys\n", "import json\n", "import subprocess\n", "from pathlib import Path\n", "from typing import List, Optional\n", "from datetime import datetime\n", "\n", "import cv2\n", "import fiftyone as fo\n", "\n", "# Set the environment variable for PyTorch in case of memory issues\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"expandable_segments:True\"\n", "\n", "# -------------------- CONFIG --------------------\n", "# Get the dataset directory from your earlier code\n", "images_root = Path(dataset_directory) # This is ~/fiftyone/huggingface/hub/<org>/<dataset>/data\n", "base_dir = images_root.parent # This is ~/fiftyone/huggingface/hub/<org>/<dataset>\n", "\n", "# Set where your local cosmos-transfer2.5 repo lives\n", "os.environ[\"COSMOS_DIR\"] = \"/path/to/cosmos-transfer2.5/folder\"\n", "\n", "# Set all directories relative to the base directory\n", "ASSETS_DIR = base_dir / \"videos\" # Where your input videos are\n", "OUT_DIR = base_dir / \"cosmos_result\" # Where inference outputs go\n", "SPECS_DIR = base_dir / \"specs\" # JSON specs\n", "EDGE_DIR = base_dir / \"edge\" # Edge control videos\n", "LAST_FRAMES_DIR = base_dir / \"last_frame\" # Last frames extracted from outputs\n", "LIST_FILE = Path(os.environ.get(\"LIST_FILE\", str(ASSETS_DIR / \"video_list.txt\")))\n", "MAX_VIDS = int(os.environ.get(\"MAX_VIDS\", \"100\"))\n", "COSMOS_DIR = Path(os.environ.get(\"COSMOS_DIR\", \".\"))\n", "INFER_SCRIPT = COSMOS_DIR / \"examples\" / \"inference.py\"\n", "\n", "GROUPED_DATASET_NAME = \"moth_biotrove_grouped\"\n", "\n", "print(f\"Base directory: {base_dir}\")\n", "print(f\"Videos directory: {ASSETS_DIR}\")\n", "print(f\"Output directory: {OUT_DIR}\")\n", 
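"print(f\"Specs directory: {SPECS_DIR}\") # per-video spec JSONs are written here\n", 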
"print(f\"Edge directory: {EDGE_DIR}\")\n", "print(f\"Last frames directory: {LAST_FRAMES_DIR}\")\n", "\n", "# Prompts / inference params\n", "MOTH_PROMPT = os.environ.get(\"MOTH_PROMPT\", \"\"\"\n", "The video depicts a realistic outdoor scene captured during daytime in a natural environment. \n", "A single moth is the primary subject, sharply focused, with crisp wing edges, fine scale texture, and natural coloration. \n", "The background consists of softly blurred green leaves and sunlit foliage, replacing the original scene. \n", "Lighting is natural and diffused, as if under mild sunlight or partial shade, producing realistic contrast, soft shadows, and gentle highlights on the moth and nearby leaves. \n", "The overall tone is photographic and lifelike, with balanced exposure and true-to-life colors. \n", "Fine details such as subtle motion blur, tiny airborne particles, and depth of field contribute to a high-quality, authentic nature documentary aesthetic. \n", "Avoid stylization; the output should appear like a professional macro wildlife recording captured in natural daylight.\n", "\"\"\").strip()\n", "\n", "NEG_PROMPT = os.environ.get(\n", " \"NEG_PROMPT\",\n", " \"blurry, motion blur, defocus, low-detail, oversmoothed, painterly, cartoon, glow, haze, \"\n", " \"halos, banding, ghosting, soft edges, unrealistic lighting, watercolor, low contrast\"\n", ")\n", "\n", "GUIDANCE = int(os.environ.get(\"GUIDANCE\", \"7\"))\n", "RESOLUTION = os.environ.get(\"RESOLUTION\", \"720\")\n", "NUM_STEPS = int(os.environ.get(\"NUM_STEPS\", \"38\"))\n", "\n", "# -------------------- UTILITIES --------------------\n", "def list_videos(assets_dir: Path, list_file: Path, max_vids: int) -> List[Path]:\n", " \"\"\"Collect input .mp4 videos, respecting an optional list file.\"\"\"\n", " vids: List[Path] = []\n", " if list_file.exists():\n", " with list_file.open() as f:\n", " for line in f:\n", " name = line.strip()\n", " if not name:\n", " continue\n", " vids.append((assets_dir / name).with_suffix(\".mp4\") if not name.endswith(\".mp4\") else assets_dir / name)\n", " else:\n", " vids = sorted(assets_dir.glob(\"*.mp4\"))\n", " if max_vids > 0:\n", " vids = vids[:max_vids]\n", " return vids\n", "\n", "def ensure_dirs():\n", " SPECS_DIR.mkdir(parents=True, exist_ok=True)\n", " EDGE_DIR.mkdir(parents=True, exist_ok=True)\n", " OUT_DIR.mkdir(parents=True, exist_ok=True)\n", "\n", "def write_spec_json(spec_path: Path, video_abs: Path, edge_abs: Path, name: str):\n", " obj = {\n", " \"name\": name,\n", " \"prompt\": MOTH_PROMPT,\n", " \"negative_prompt\": NEG_PROMPT,\n", " \"video_path\": str(video_abs),\n", " \"guidance\": GUIDANCE,\n", " \"resolution\": RESOLUTION,\n", " \"num_steps\": NUM_STEPS,\n", " \"edge\": {\n", " \"control_weight\": 1.0,\n", " \"control_path\": str(edge_abs),\n", " },\n", " }\n", " spec_path.parent.mkdir(parents=True, exist_ok=True)\n", " spec_path.write_text(json.dumps(obj, indent=2))\n", "\n", "def make_edge_video(input_video: Path, output_video: Path) -> bool:\n", " \"\"\"Generate Canny edge control video (grayscale) for an input video.\"\"\"\n", " cap = cv2.VideoCapture(str(input_video))\n", " if not cap.isOpened():\n", " print(f\"[ERROR] Cannot open video: {input_video}\")\n", " return False\n", " fps = cap.get(cv2.CAP_PROP_FPS) or 24\n", " w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n", " h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n", " fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n", " out = cv2.VideoWriter(str(output_video), fourcc, fps, (w, h), isColor=False)\n", " ok = 
True\n", " while True:\n", " ret, frame = cap.read()\n", " if not ret:\n", " break\n", " gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)\n", " edges = cv2.Canny(gray, 80, 180)\n", " out.write(edges)\n", " cap.release() \n", " out.release()\n", " if not output_video.exists():\n", " print(f\"[ERROR] Edge video not created: {output_video}\")\n", " ok = False\n", " return ok\n", "\n", "def run_cosmos_inference(spec_json: Path, out_dir: Path) -> bool:\n", " \"\"\"Call the Cosmos inference script via the current Python interpreter.\"\"\"\n", " if not INFER_SCRIPT.exists():\n", " print(f\"[ERROR] inference script not found: {INFER_SCRIPT}\")\n", " return False\n", " cmd = [sys.executable, str(INFER_SCRIPT), \"-i\", str(spec_json), \"-o\", str(out_dir)]\n", " try:\n", " subprocess.run(cmd, check=True)\n", " return True\n", " except subprocess.CalledProcessError as e:\n", " print(f\"[ERROR] inference failed ({spec_json.name}): {e}\")\n", " return False\n", "\n", "def extract_last_frame(video_path: Path, dst_dir: Path) -> Optional[Path]:\n", " \"\"\"Extract last frame as PNG from a video.\"\"\"\n", " dst_dir.mkdir(parents=True, exist_ok=True)\n", " cap = cv2.VideoCapture(str(video_path))\n", " if not cap.isOpened():\n", " print(f\"[warn] cannot open {video_path}\")\n", " return None\n", " total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)\n", " out_png = dst_dir / f\"{video_path.stem}_last.png\"\n", "\n", " if total > 0:\n", " cap.set(cv2.CAP_PROP_POS_FRAMES, max(total - 1, 0))\n", " ret, frame = cap.read()\n", " if not ret or frame is None:\n", " cap.set(cv2.CAP_PROP_POS_FRAMES, max(total - 2, 0))\n", " ret, frame = cap.read()\n", " cap.release()\n", " if not ret or frame is None:\n", " print(f\"[warn] could not read last frame for {video_path}\")\n", " return None\n", " cv2.imwrite(str(out_png), frame)\n", " return out_png\n", "\n", " # Fallback for weird metadata\n", " last = None\n", " while True:\n", " ret, frame = cap.read()\n", " if not ret:\n", " break\n", " last = frame\n", " cap.release()\n", " if last is None:\n", " return None\n", " cv2.imwrite(str(out_png), last)\n", " return out_png\n", "\n", "def base_key(path: Path) -> str:\n", " \"\"\"Derive a key shared across slices; here we use filename stem.\"\"\"\n", " return path.stem\n", "\n", "# -------------------- PIPELINE --------------------\n", "ensure_dirs()\n", "videos = list_videos(ASSETS_DIR, LIST_FILE, MAX_VIDS)\n", "print(f\"[info] found {len(videos)} videos under {ASSETS_DIR}\")\n", "\n", "ok = fail = 0\n", "for idx, v in enumerate(videos, 1):\n", " if not v.exists():\n", " print(f\"[warn] missing: {v}\")\n", " fail += 1\n", " continue\n", "\n", " name = base_key(v)\n", " edge_mp4 = EDGE_DIR / f\"{name}_edge.mp4\"\n", " spec_json = SPECS_DIR / f\"{name}.json\"\n", "\n", " print(f\"[{idx}/{len(videos)}] Edge gen: {name}\")\n", " if not make_edge_video(v, edge_mp4):\n", " fail += 1\n", " continue\n", "\n", " write_spec_json(spec_json, v.resolve(), edge_mp4.resolve(), name)\n", "\n", " print(f\"[{idx}/{len(videos)}] Inference: {name}\")\n", " if run_cosmos_inference(spec_json, OUT_DIR):\n", " ok += 1\n", " else:\n", " fail += 1\n", "\n", "print(f\"Done. 
Success: {ok} Failed: {fail}\")\n", "print(f\"Specs: {SPECS_DIR}\")\n", "print(f\"Edges: {EDGE_DIR}\")\n", "print(f\"Outputs: {OUT_DIR}\")" ] }, { "cell_type": "markdown", "id": "a776783a", "metadata": {}, "source": [ "### FiftyOne Integration: Adding Cosmos Pipeline Results to Grouped Dataset\n", "\n", "This cell integrates the Cosmos-Transfer2.5 pipeline outputs into your existing FiftyOne grouped dataset, adding three new slices for each processed video.\n", "\n", "### Process Overview\n", "\n", "1. **Load Existing Grouped Dataset**\n", " - Loads the grouped dataset created earlier (containing `image` and `video` slices)\n", " - Raises an error if the dataset doesn't exist\n", "\n", "2. **Index Groups by Image Keys**\n", " - Iterates through all samples in the `image` slice\n", " - Creates a mapping from filename stems to their corresponding group objects\n", " - This enables matching Cosmos outputs back to their original groups\n", "\n", "3. **Add New Slices to Groups**\n", " \n", " For each Cosmos output video, the following slices are added to the matching group:\n", " \n", " - `edge`: Last frame of the Canny edge-detection control video (stored as an image)\n", " - `output`: Cosmos-generated output video\n", " - `output_last`: Last frame extracted from the output video (as an image)\n", "\n", "4. **Error Handling**\n", " - Tracks unmatched outputs (videos without corresponding groups)\n", " - Warns about missing edge videos or failed frame extractions\n", " - Reports summary statistics\n", "\n", "5. **Dataset Finalization**\n", " - Adds all new samples to the dataset in batch\n", " - Reloads the dataset to ensure changes are reflected\n", " - Displays final group slices and media types\n", "\n", "6. **Launch FiftyOne App**\n", " - Opens the FiftyOne App for interactive visualization\n", " - Displays the URL for accessing the App\n", "\n", "### Expected Group Structure\n", "\n", "After this cell completes, each group will contain up to 5 slices:\n", "\n", "- `image`: Original image from Hugging Face dataset\n", "- `video`: Video created from the image\n", "- `edge`: Last frame of Canny edge detection video\n", "- `output`: Cosmos-generated output video\n", "- `output_last`: Last frame of the Cosmos output (image)\n", "\n", "This grouped structure allows synchronized visualization and comparison of all related data in the FiftyOne App." ] }, { "cell_type": "code", "execution_count": null, "id": "fb138868-b6d4-4027-9053-c0316df0b220", "metadata": {}, "outputs": [], "source": [ "# -------------------- FIFTYONE INTEGRATION --------------------\n", "import fiftyone as fo\n", "from pathlib import Path\n", "\n", "def base_key(path: Path) -> str:\n", " return path.stem\n", "\n", "if fo.dataset_exists(GROUPED_DATASET_NAME):\n", " dataset_grp = fo.load_dataset(GROUPED_DATASET_NAME)\n", "else:\n", " raise RuntimeError(\n", " f\"Grouped dataset '{GROUPED_DATASET_NAME}' not found. 
\"\n", " \"Create it earlier (image/video slices) before running this cell.\"\n", " )\n", "\n", "# --- Index groups by the IMAGE slice ---\n", "image_groups = {}\n", "for s in dataset_grp.match({\"group.name\": \"image\"}).iter_samples(progress=True):\n", " image_groups[base_key(Path(s.filepath))] = s.group\n", "\n", "new_slices = []\n", "\n", "# Folder where last-frame PNGs will go\n", "edge_last_dir = OUT_DIR / \"edge_last_frames\"\n", "output_last_dir = OUT_DIR / \"output_last_frames\"\n", "edge_last_dir.mkdir(parents=True, exist_ok=True)\n", "output_last_dir.mkdir(parents=True, exist_ok=True)\n", "\n", "# Loop over ALL outputs (*.mp4) but skip *_control_edge.mp4\n", "for out_vid in sorted(OUT_DIR.glob(\"*.mp4\")):\n", " stem = out_vid.stem\n", " if stem.endswith(\"_control_edge\"):\n", " continue\n", "\n", " key = stem\n", " grp = image_groups.get(key)\n", "\n", " if grp is None:\n", " print(f\"[warn] No matching group found for output: {key}\")\n", " continue\n", "\n", " # --- 1. Edge last frame slice ---\n", " cosmos_edge_fp = OUT_DIR / f\"{key}_control_edge.mp4\"\n", " if cosmos_edge_fp.exists():\n", " edge_png = extract_last_frame(cosmos_edge_fp, edge_last_dir)\n", " if edge_png and edge_png.exists():\n", " new_slices.append(\n", " fo.Sample(\n", " filepath=str(edge_png),\n", " group=grp.element(\"edge\"),\n", " )\n", " )\n", " else:\n", " print(f\"[warn] Could not extract last frame for edge: {cosmos_edge_fp}\")\n", " else:\n", " print(f\"[warn] Cosmos edge video not found: {cosmos_edge_fp}\")\n", "\n", " # --- 2. Output video slice ---\n", " new_slices.append(\n", " fo.Sample(\n", " filepath=str(out_vid),\n", " group=grp.element(\"output\"),\n", " )\n", " )\n", "\n", " # --- 3. Output last frame slice ---\n", " output_png = extract_last_frame(out_vid, output_last_dir)\n", " if output_png and output_png.exists():\n", " new_slices.append(\n", " fo.Sample(\n", " filepath=str(output_png),\n", " group=grp.element(\"output_last\"),\n", " )\n", " )\n", " else:\n", " print(f\"[warn] Could not extract last frame for output: {out_vid}\")\n", "\n", "# --- Add slices to dataset ---\n", "if new_slices:\n", " dataset_grp.add_samples(new_slices)\n", " print(f\"[OK] Added {len(new_slices)} new slices (edge_last/output/output_last)\")\n", " print(f\"[info] Group slices now: {dataset_grp.group_slices}\")\n", " print(f\"[info] Group media types: {dataset_grp.group_media_types}\")\n", "else:\n", " print(\"[warn] No new slices were added\")\n", "\n", "dataset_grp.reload()\n", "\n", "print(f\"\\n[info] Final dataset summary:\")\n", "print(f\" Total samples: {len(dataset_grp)}\")\n", "print(f\" Available group slices: {dataset_grp.group_slices}\")\n", "\n", "session = fo.launch_app(dataset_grp, port=5151, auto=False)\n", "print(f\"[info] FiftyOne App launched at: {session.url}\")\n" ] }, { "cell_type": "markdown", "id": "64f0be24", "metadata": {}, "source": [ "### Adding Slice Identifier Field to Grouped Dataset\n", "\n", "This cell adds a `slice_name` field to the grouped dataset to identify which group slice each sample belongs to. This is useful for filtering, visualization, and analysis based on slice origin." 
] }, { "cell_type": "code", "execution_count": null, "id": "26db1711", "metadata": {}, "outputs": [], "source": [ "# Add a field to store the slice name\n", "dataset_grp.add_sample_field(\"slice_name\", fo.StringField)\n", "\n", "# Iterate through each slice and set the slice_name field\n", "for slice_name in [\"image\", \"output_last\"]:\n", " # Get samples from this slice\n", " slice_view = dataset_grp.select_group_slices(slice_name)\n", " \n", " # Set the slice_name field for all samples in this slice\n", " for sample in slice_view:\n", " sample[\"slice_name\"] = slice_name\n", " sample.save()" ] }, { "cell_type": "markdown", "id": "93434e17", "metadata": {}, "source": [ "### Computing Embeddings and Similarity Index for `image` and `output_last` Slices in FiftyOne\n", "\n", "This cell demonstrates how to compute embeddings and build a similarity index for both `image` and `output_last` slices in a grouped dataset using the CLIP model from the FiftyOne Model Zoo.\n", "\n", "**Workflow**\n", "\n", "1. **Select multiple slices:** Create a flattened view containing samples from both `image` and `output_last` slices using `select_group_slices([\"image\", \"output_last\"])`. \n", "2. **Load the model:** Retrieve the `clip-vit-base32-torch` model from the FiftyOne Model Zoo, which can generate embeddings for images. \n", "3. **Compute embeddings:** Use the model to generate embeddings for all samples in the flattened view and store them in a field called `embeddings`. \n", "4. **Build a similarity index:** Use FiftyOne Brain's `compute_similarity()` to index the embeddings and enable similarity search or ranking directly in the FiftyOne App under the brain key `key_sim`.\"" ] }, { "cell_type": "code", "execution_count": null, "id": "762a793d", "metadata": {}, "outputs": [], "source": [ "import fiftyone.zoo as foz\n", "import fiftyone.brain as fob\n", "\n", "# Select the two slices you need (e.g., \"image\" and \"output_last\")\n", "flattened_view = dataset_grp.select_group_slices([\"image\", \"output_last\"])\n", "\n", "# Load a model and compute embeddings\n", "model = foz.load_zoo_model(\"clip-vit-base32-torch\")\n", "flattened_view.compute_embeddings(model, embeddings_field=\"embeddings\")\n", "\n", "# Create similarity index for crops\n", "fob.compute_similarity(\n", " flattened_view,\n", " model=\"clip-vit-base32-torch\",\n", " embeddings=\"embeddings\",\n", " brain_key=\"key_sim\",\n", ")\n", "\n", "print(\"[INFO] output_last_view embeddings computed successfully!\")" ] }, { "cell_type": "markdown", "id": "7ac10967", "metadata": {}, "source": [ "### Visualizing Embeddings for `image` and `output_last` Slices in FiftyOne\n", "\n", "This cell creates an interactive 2D visualization of the embeddings computed for both **`image`** and **`output_last`** slices using dimensionality reduction.\n", "\n", "**What it does**\n", "1. **Use existing embeddings:** Works with the embeddings already computed and stored in the `\"embeddings\"` field on the `flattened_view` (containing both `image` and `output_last` slices). \n", "2. **Apply dimensionality reduction:** Uses UMAP to reduce the high-dimensional embeddings to 2D for visualization. You can also use `\"tsne\"` or `\"pca\"` as alternatives. \n", "3. **Store results:** Saves the visualization under the brain key `\"slice_embeddings_viz\"` for access in the FiftyOne App. \n", "4. 
**Explore:** Open the FiftyOne App's Embeddings panel to interactively explore clusters, filter by the `slice_name` field to distinguish between `image` and `output_last` samples, and select points of interest.\n", "\n", "**Tip:** Color the visualization by `slice_name` to see how the two slices compare in the embedding space:\n", "```python\n", "plot = results.visualize(labels=\"slice_name\")\n", "plot.show()\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "95a7cd7e", "metadata": {}, "outputs": [], "source": [ "import fiftyone.brain as fob\n", "# Compute visualization using the embeddings you already computed\n", "results = fob.compute_visualization(\n", " flattened_view,\n", " embeddings=\"embeddings\", # The field where your embeddings are stored\n", " method=\"umap\", # You can also use \"tsne\" or \"pca\"\n", " brain_key=\"slice_embeddings_viz\",\n", " num_dims=2,\n", " verbose=True\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "fc4fb3b8", "metadata": {}, "outputs": [], "source": [ "session = fo.launch_app(flattened_view, port=5151, auto=False)" ] }, { "cell_type": "markdown", "id": "7abf9f7b", "metadata": {}, "source": [ "![image](https://cdn.voxel51.com/tutorials/cosmos-transfer2_5/cosmos.webp)" ] }, { "cell_type": "markdown", "id": "ae8f9996", "metadata": {}, "source": [ "## Summary\n", "\n", "In this tutorial, we built a **complete integration pipeline** between **Cosmos-Transfer 2.5** and **FiftyOne**, entirely within a Python notebook. \n", "\n", "You learned how to:\n", "\n", "- **Set up and install** both Cosmos-Transfer 2.5 and FiftyOne, ensuring all required dependencies (PyTorch, CUDA, OpenCV) are ready for GPU-accelerated inference. \n", "- **Load and prepare datasets** (such as the BioTrove subset) from Hugging Face, convert images to video format when needed, and organize everything into **grouped datasets** in FiftyOne for multimodal exploration. \n", "- **Generate control maps** automatically using Canny edge detection to guide Cosmos-Transfer’s multi-ControlNet inference process. \n", "- **Run Cosmos-Transfer inference in pure Python**, without shell scripts, by dynamically creating JSON spec files and invoking the model through `subprocess`. \n", "- **Extract the last frame** from each output video to create a static visualization slice for comparison and analysis. \n", "- **Explore results in FiftyOne**, side-by-side, across slices (`image`, `video`, `edge`, `output`, `output_last`) to visualize model performance, quality, and domain alignment.\n", "\n", "This end-to-end workflow demonstrates how to connect **synthetic data generation** and **real-world evaluation** into a single reproducible, data-centric loop—helping you accelerate experimentation, debug model behavior, and evaluate **Physical AI systems** across robotics, autonomous driving, and environmental perception tasks.\n" ] }, { "cell_type": "markdown", "id": "ee17bf6e", "metadata": {}, "source": [ "### Would you like to know more about an easy integration between Cosmos-Transfer and FiftyOne?\n", "\n", "You can run all of this on your own using the magic of open source. If you want to take it to the next level, I encourage you to book a demo and see the full capabilities of this integration. 
Visit: [Physical AI Workbench](https://voxel51.com/physical-ai)" ] } ], "metadata": { "kernelspec": { "display_name": "Python (cosmos-transfer2)", "language": "python", "name": "cosmos-transfer2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }