{ "cells": [ { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "raw" } }, "source": [ "# Auto-segment images using SAM\n", "\n", "This notebook creates a 3LC Table with auto-generated segmentation masks from\n", "SAM using grid-based point prompting.\n", "\n", "![img](../images/sam-autosegment.jpg)\n", "\n", "\n", "\n", "Unlike the [bounding box-based approach](transform-bbs-to-segs.ipynb), this\n", "method automatically discovers objects in images without requiring ground truth\n", "annotations." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install 3lc\n", "%pip install git+https://github.com/3lc-ai/3lc-examples.git\n", "%pip install git+https://github.com/facebookresearch/segment-anything\n", "%pip install matplotlib" ] }, { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "raw" } }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import cv2\n", "import numpy as np\n", "import requests\n", "import tlc\n", "from segment_anything import SamAutomaticMaskGenerator, sam_model_registry\n", "from tqdm import tqdm\n", "\n", "from tlc_tools.common import infer_torch_device" ] }, { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "raw" } }, "source": [ "## Project Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "PROJECT_NAME = \"3LC Tutorials - COCO128\"\n", "DATASET_NAME = \"AutoSegmented Images\"\n", "TABLE_NAME = \"autosegmented_images\"\n", "MODEL_TYPE = \"vit_b\"\n", "DOWNLOAD_PATH = \"../../transient_data\"\n", "\n", "# Image dataset configuration\n", "DATA_PATH = \"../../data\"\n", "MAX_IMAGES = 20 # Limit for initial testing - set to None for all images\n", "\n", "# Segmentation filtering parameters\n", "MIN_AREA_THRESHOLD = 1000 # Minimum area in pixels to keep a segment\n", "MIN_CONFIDENCE_THRESHOLD = 0.7 # Minimum confidence score to keep a segment\n", "POINTS_PER_SIDE = 32\n", "PRED_IOU_THRESHOLD = 0.88" ] }, { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "raw" } }, "source": [ "## Load Images" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Get list of image files\n", "image_dir = (Path(DATA_PATH) / \"coco128\" / \"images\").resolve()\n", "image_extensions = {\".jpg\", \".jpeg\", \".png\", \".bmp\", \".tiff\"}\n", "image_files = [f for f in image_dir.glob(\"*\") if f.suffix.lower() in image_extensions]\n", "\n", "if MAX_IMAGES is not None:\n", " image_files = image_files[:MAX_IMAGES]\n", "\n", "print(f\"Found {len(image_files)} images to process\")\n", "print(f\"Sample images: {[f.name for f in image_files[:5]]}\")" ] }, { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "raw" } }, "source": [ "## Initialize SAM Automatic Mask Generator\n", "\n", "We'll use SAM's `SamAutomaticMaskGenerator` which uses automatic point grid prompting to segment all objects in each image without requiring any input prompts." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Download checkpoint\n", "\n", "model_url = \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth\"\n", "# Download the SAM model checkpoint if it doesn't already exist\n", "checkpoint_path = Path(DOWNLOAD_PATH) / \"sam_vit_b_01ec64.pth\"\n", "\n", "if not checkpoint_path.exists():\n", " print(f\"Downloading SAM checkpoint to {checkpoint_path}...\")\n", " checkpoint_path.parent.mkdir(parents=True, exist_ok=True)\n", " response = requests.get(model_url, stream=True)\n", " total_size = int(response.headers.get(\"content-length\", 0))\n", " with (\n", " open(checkpoint_path, \"wb\") as f,\n", " tqdm(\n", " desc=\"Downloading\",\n", " total=total_size,\n", " unit=\"B\",\n", " unit_scale=True,\n", " unit_divisor=1024,\n", " ) as bar,\n", " ):\n", " for chunk in response.iter_content(chunk_size=8192):\n", " if chunk:\n", " f.write(chunk)\n", " bar.update(len(chunk))\n", " print(\"Download completed.\")\n", "else:\n", " print(f\"Checkpoint already exists at {checkpoint_path}.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load SAM model\n", "device = infer_torch_device()\n", "sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint_path)\n", "sam.to(device=device)\n", "\n", "# Initialize the automatic mask generator\n", "mask_generator = SamAutomaticMaskGenerator(\n", " model=sam,\n", " points_per_side=POINTS_PER_SIDE, # Grid size for point prompts\n", " pred_iou_thresh=PRED_IOU_THRESHOLD, # IoU threshold for mask quality filtering\n", " stability_score_thresh=0.92, # Stability score threshold\n", " crop_n_layers=1, # Number of crop layers\n", " crop_n_points_downscale_factor=2, # Downscale factor for crop points\n", " min_mask_region_area=MIN_AREA_THRESHOLD, # Minimum area in pixels\n", ")\n", "\n", "print(f\"Initialized SAM Automatic Mask Generator on device: {device}\")\n", "print(f\"Using model type: {MODEL_TYPE}\")\n", "print(f\"Checkpoint: {checkpoint_path}\")" ] }, { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "raw" } }, "source": [ "## Process Images and Generate Masks\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Process each image and collect segmentations\n", "segmentations_data = []\n", "\n", "for image_path in tqdm(image_files, desc=\"Processing images\", total=len(image_files)):\n", " # Load image\n", " image = cv2.imread(str(image_path))\n", " image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", "\n", " # Generate masks for this image\n", " masks = mask_generator.generate(image_rgb)\n", "\n", " if masks:\n", " # Convert masks to the format expected by 3LC\n", " h, w = image_rgb.shape[:2]\n", "\n", " # Stack all masks for this image into a single array\n", " mask_array = np.stack([mask[\"segmentation\"] for mask in masks], axis=2).astype(np.uint8)\n", "\n", " # Create instance properties (scores, areas, etc.)\n", " instance_properties = {\n", " \"score\": [mask[\"stability_score\"] for mask in masks],\n", " \"area\": [mask[\"area\"] for mask in masks],\n", " \"predicted_iou\": [mask[\"predicted_iou\"] for mask in masks],\n", " \"keep\": [False] * len(masks),\n", " }\n", "\n", " segmentation_data = {\n", " \"image_height\": h,\n", " \"image_width\": w,\n", " \"masks\": mask_array,\n", " \"instance_properties\": instance_properties,\n", " }\n", "\n", " row_data = {\n", " \"image\": tlc.Url(image_path).to_relative().to_str(),\n", " 
\"segments\": segmentation_data,\n", " }\n", "\n", " segmentations_data.append(row_data)\n", "\n", "print(f\"\\\\nProcessed {len(segmentations_data)} images with valid segmentations\")" ] }, { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "raw" } }, "source": [ "## Create 3LC Table\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create the 3LC table with auto-generated segmentations\n", "table_writer = tlc.TableWriter(\n", " project_name=PROJECT_NAME,\n", " dataset_name=DATASET_NAME,\n", " table_name=TABLE_NAME,\n", " column_schemas={\n", " \"image\": tlc.ImageUrlSchema(),\n", " \"segments\": tlc.InstanceSegmentationMasks(\n", " \"segments\",\n", " instance_properties_structure={\n", " \"score\": tlc.Schema(value=tlc.Float32Value(0, 1), writable=False),\n", " \"area\": tlc.Schema(value=tlc.Int32Value(), writable=False),\n", " \"predicted_iou\": tlc.Schema(value=tlc.Float32Value(0, 1), writable=False),\n", " \"keep\": tlc.Schema(value=tlc.BoolValue(), writable=True),\n", " },\n", " ),\n", " },\n", ")\n", "\n", "# Add all the segmentation data to the table\n", "for row_data in segmentations_data:\n", " table_writer.add_row(row_data)\n", "\n", "# Finalize the table\n", "table = table_writer.finalize()\n", "\n", "print(f\"\\nCreated 3LC table: {table.name}\")\n", "print(f\"Table URL: {table.url}\")\n", "print(f\"Total rows: {len(table)}\")\n", "\n", "# Display some statistics\n", "total_segments = sum(len(row[\"segments\"][\"instance_properties\"][\"score\"]) for row in segmentations_data)\n", "avg_segments_per_image = total_segments / len(segmentations_data) if segmentations_data else 0\n", "\n", "print(\"\\nSegmentation Statistics:\")\n", "print(f\"Total segments collected: {total_segments}\")\n", "print(f\"Average segments per image: {avg_segments_per_image:.1f}\")" ] }, { "cell_type": "raw", "metadata": { "vscode": { "languageId": "raw" } }, "source": [ "## Visualize Sample Results\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "if len(table) > 0:\n", " # Get a sample from the table\n", " sample_idx = 0\n", " sample = table[sample_idx]\n", "\n", " # Load the original image\n", " image_path = sample[\"image\"]\n", " image = cv2.imread(image_path)\n", " image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", "\n", " # Get the segmentation masks\n", " masks = sample[\"segments\"][\"masks\"]\n", " scores = sample[\"segments\"][\"instance_properties\"][\"score\"]\n", " areas = sample[\"segments\"][\"instance_properties\"][\"area\"]\n", "\n", " print(f\"Sample image: {Path(image_path).name}\")\n", " print(f\"Number of segments: {masks.shape[2]}\")\n", " print(f\"Score range: {min(scores):.3f} - {max(scores):.3f}\")\n", " print(f\"Area range: {min(areas)} - {max(areas)} pixels\")\n", "\n", " # Create visualization\n", " fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n", "\n", " # Original image\n", " axes[0].imshow(image_rgb)\n", " axes[0].set_title(\"Original Image\")\n", " axes[0].axis(\"off\")\n", "\n", " # All masks overlay\n", " axes[1].imshow(image_rgb)\n", " combined_mask = np.zeros((masks.shape[0], masks.shape[1]))\n", " colors = plt.cm.tab20(np.linspace(0, 1, min(20, masks.shape[2])))\n", "\n", " for i in range(min(masks.shape[2], 20)): # Show up to 20 masks\n", " mask = masks[:, :, i]\n", " if mask.sum() > 0:\n", " # Create colored overlay\n", " colored_mask = np.zeros((*mask.shape, 4))\n", " colored_mask[mask == 1] = 
] }, { "cell_type": "markdown", "metadata": { "vscode": { "languageId": "raw" } }, "source": [ "## Visualize Sample Results\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "if len(table) > 0:\n", "    # Get a sample from the table\n", "    sample_idx = 0\n", "    sample = table[sample_idx]\n", "\n", "    # Load the original image\n", "    image_path = sample[\"image\"]\n", "    image = cv2.imread(image_path)\n", "    if image is None:\n", "        # The stored URL may be relative; try resolving it to an absolute path\n", "        image_path = tlc.Url(image_path).to_absolute().to_str()\n", "        image = cv2.imread(image_path)\n", "    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", "\n", "    # Get the segmentation masks and per-instance properties\n", "    masks = sample[\"segments\"][\"masks\"]\n", "    scores = sample[\"segments\"][\"instance_properties\"][\"score\"]\n", "    areas = sample[\"segments\"][\"instance_properties\"][\"area\"]\n", "\n", "    print(f\"Sample image: {Path(image_path).name}\")\n", "    print(f\"Number of segments: {masks.shape[2]}\")\n", "    print(f\"Score range: {min(scores):.3f} - {max(scores):.3f}\")\n", "    print(f\"Area range: {min(areas)} - {max(areas)} pixels\")\n", "\n", "    # Create visualization\n", "    fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n", "\n", "    # Original image\n", "    axes[0].imshow(image_rgb)\n", "    axes[0].set_title(\"Original Image\")\n", "    axes[0].axis(\"off\")\n", "\n", "    # All masks overlay (up to 20 shown)\n", "    axes[1].imshow(image_rgb)\n", "    colors = plt.cm.tab20(np.linspace(0, 1, min(20, masks.shape[2])))\n", "\n", "    for i in range(min(masks.shape[2], 20)):\n", "        mask = masks[:, :, i]\n", "        if mask.sum() > 0:\n", "            # Create a colored RGBA overlay for this mask\n", "            colored_mask = np.zeros((*mask.shape, 4))\n", "            colored_mask[mask == 1] = colors[i % len(colors)]\n", "            axes[1].imshow(colored_mask, alpha=0.7)\n", "\n", "    axes[1].set_title(f\"All Segments Overlay ({min(masks.shape[2], 20)} shown)\")\n", "    axes[1].axis(\"off\")\n", "\n", "    # Individual high-quality masks: show only the top 5 by score\n", "    axes[2].imshow(image_rgb)\n", "    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:5]\n", "\n", "    for i, mask_idx in enumerate(top_indices):\n", "        mask = masks[:, :, mask_idx]\n", "        if mask.sum() > 0:\n", "            colored_mask = np.zeros((*mask.shape, 4))\n", "            colored_mask[mask == 1] = colors[i % len(colors)]\n", "            axes[2].imshow(colored_mask, alpha=0.8)\n", "\n", "    axes[2].set_title(\"Top 5 Segments by Score\")\n", "    axes[2].axis(\"off\")\n", "\n", "    plt.tight_layout()\n", "    plt.show()\n", "else:\n", "    print(\"No images processed successfully\")" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" }, "test_marks": [ "slow" ] }, "nbformat": 4, "nbformat_minor": 2 }