{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert bounding boxes to instance segmentation masks using SAM\n",
    "\n",
    "This notebook creates a derived Table with an added column containing instance segmentation masks generated by the Segment Anything Model (SAM), using the Table's existing bounding boxes as prompts.\n",
    "\n",
    "![Bounding boxes converted to SAM instance segmentation masks](../images/bb2seg.jpg)\n",
    "\n",
    ""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install 3lc\n",
    "%pip install git+https://github.com/3lc-ai/3lc-examples.git\n",
    "%pip install git+https://github.com/facebookresearch/segment-anything\n",
    "%pip install matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import requests\n",
    "import tlc\n",
    "from tqdm import tqdm\n",
    "\n",
    "from tlc_tools.sam_autosegment import bbs_to_segments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Project Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "PROJECT_NAME = \"3LC Tutorials - COCO128\"\n",
    "MODEL_TYPE = \"vit_b\"  # SAM backbone: \"vit_b\" (smallest), \"vit_l\" or \"vit_h\" (largest)\n",
    "DOWNLOAD_PATH = \"../../transient_data\"  # directory where the SAM checkpoint is cached"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the input table"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the input table, created in [create-table-from-coco-detection.ipynb](../1-create-tables/object%20detection/create-table-from-coco-detection.ipynb)."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "input_table = tlc.Table.from_names(\"initial\", \"COCO128\", PROJECT_NAME)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download the SAM model checkpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "checkpoint_path = Path(DOWNLOAD_PATH) / \"sam_vit_b_01ec64.pth\"\n", "\n", "if not checkpoint_path.exists():\n", " model_url = \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth\"\n", "\n", " print(f\"Downloading SAM checkpoint to {checkpoint_path}...\")\n", " checkpoint_path.parent.mkdir(parents=True, exist_ok=True)\n", " response = requests.get(model_url, stream=True)\n", " total_size = int(response.headers.get(\"content-length\", 0))\n", " with (\n", " open(checkpoint_path, \"wb\") as f,\n", " tqdm(\n", " desc=\"Downloading\",\n", " total=total_size,\n", " unit=\"B\",\n", " unit_scale=True,\n", " unit_divisor=1024,\n", " ) as bar,\n", " ):\n", " for chunk in response.iter_content(chunk_size=8192):\n", " if chunk:\n", " f.write(chunk)\n", " bar.update(len(chunk))\n", " print(\"Download completed.\")\n", "else:\n", " print(f\"Checkpoint already exists at {checkpoint_path}.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run SAM model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "out_table = bbs_to_segments(\n", " input_table,\n", " sam_model_type=MODEL_TYPE,\n", " checkpoint=checkpoint_path.as_posix(),\n", " description=\"Added segmentation column from bounding boxes\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize an example mask" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "example_mask = out_table[3][\"segments\"][\"masks\"]\n", "print(example_mask.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": 
{},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Masks appear to be stacked along the last axis, one per instance; show the first one.\n",
    "if example_mask.shape[-1] == 0:\n",
    "    print(\"The example row has no instance masks to display.\")\n",
    "else:\n",
    "    fig, ax = plt.subplots(figsize=(10, 10))\n",
    "    ax.imshow(example_mask[:, :, 0], cmap=\"gray\")\n",
    "    ax.set_title(\"First instance mask of the example row\")\n",
    "    ax.axis(\"off\")\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  },
  "test_marks": [
   "slow"
  ]
 },
 "nbformat": 4,
 "nbformat_minor": 2
}