---
name: segment-anything-model
description: Foundation model for image segmentation with zero-shot transfer. Use when you need to segment any object in images using points, boxes, or masks as prompts, or automatically generate all object masks in an image.
version: 1.0.0
author: Orchestra Research
license: MIT
tags: [Multimodal, Image Segmentation, Computer Vision, SAM, Zero-Shot]
dependencies: [segment-anything, transformers>=4.30.0, torch>=1.7.0]
---

# Segment Anything Model (SAM)

Comprehensive guide to using Meta AI's Segment Anything Model for zero-shot image segmentation.

## When to use SAM

**Use SAM when:**
- Segmenting any object in an image without task-specific training
- Building interactive annotation tools with point/box prompts
- Generating training data for other vision models
- Transferring zero-shot to new image domains
- Building object detection/segmentation pipelines
- Processing medical, satellite, or other domain-specific images

**Key features:**
- **Zero-shot segmentation**: Works on any image domain without fine-tuning
- **Flexible prompts**: Points, bounding boxes, or previous masks
- **Automatic segmentation**: Generate all object masks automatically
- **High quality**: Trained on 1.1 billion masks from 11 million images
- **Multiple model sizes**: ViT-B (fastest), ViT-L, ViT-H (most accurate)
- **ONNX export**: Deploy in browsers and edge devices

**Use alternatives instead:**
- **YOLO/Detectron2**: For real-time object detection with classes
- **Mask2Former**: For semantic/panoptic segmentation with categories
- **GroundingDINO + SAM**: For text-prompted segmentation
- **SAM 2**: For video segmentation tasks

## Quick start

### Installation

```bash
# From GitHub
pip install git+https://github.com/facebookresearch/segment-anything.git

# Optional dependencies
pip install opencv-python pycocotools matplotlib

# Or use HuggingFace transformers
pip install transformers
```

### Download checkpoints

```bash
# ViT-H (largest, most accurate) - 2.4 GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# ViT-L (medium) - 1.2 GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth

# ViT-B (smallest, fastest) - 375 MB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
```

### Basic usage with SamPredictor

```python
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")

# Create predictor
predictor = SamPredictor(sam)

# Set image (computes embeddings once)
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# Predict with point prompts
input_point = np.array([[500, 375]])  # (x, y) coordinates
input_label = np.array([1])           # 1 = foreground, 0 = background

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True  # Returns 3 mask options
)

# Select best mask
best_mask = masks[np.argmax(scores)]
```
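The returned masks are boolean arrays at the original image resolution, so they can be overlaid directly for inspection. A minimal visualization sketch (the `show_mask` helper is illustrative, not part of the SAM API; it assumes `image` and `best_mask` from the snippet above):

```python
import matplotlib.pyplot as plt
import numpy as np

def show_mask(mask, ax, color=(0.12, 0.56, 1.0, 0.6)):
    """Overlay a binary mask on a matplotlib axis as a translucent color."""
    h, w = mask.shape[-2:]
    ax.imshow(mask.reshape(h, w, 1) * np.array(color).reshape(1, 1, 4))

# Assumes `image` (RGB array) and `best_mask` from the example above
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(image)
show_mask(best_mask, ax)
ax.axis("off")
plt.show()
```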
v.to("cuda") for k, v in inputs.items()} # Generate masks with torch.no_grad(): outputs = model(**inputs) # Post-process masks to original size masks = processor.image_processor.post_process_masks( outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() ) ``` ## Core concepts ### Model architecture ``` SAM Architecture: ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │ Image Encoder │────▶│ Prompt Encoder │────▶│ Mask Decoder │ │ (ViT) │ │ (Points/Boxes) │ │ (Transformer) │ └─────────────────┘ └─────────────────┘ └─────────────────┘ │ │ │ Image Embeddings Prompt Embeddings Masks + IoU (computed once) (per prompt) predictions ``` ### Model variants | Model | Checkpoint | Size | Speed | Accuracy | |-------|------------|------|-------|----------| | ViT-H | `vit_h` | 2.4 GB | Slowest | Best | | ViT-L | `vit_l` | 1.2 GB | Medium | Good | | ViT-B | `vit_b` | 375 MB | Fastest | Good | ### Prompt types | Prompt | Description | Use Case | |--------|-------------|----------| | Point (foreground) | Click on object | Single object selection | | Point (background) | Click outside object | Exclude regions | | Bounding box | Rectangle around object | Larger objects | | Previous mask | Low-res mask input | Iterative refinement | ## Interactive segmentation ### Point prompts ```python # Single foreground point input_point = np.array([[500, 375]]) input_label = np.array([1]) masks, scores, logits = predictor.predict( point_coords=input_point, point_labels=input_label, multimask_output=True ) # Multiple points (foreground + background) input_points = np.array([[500, 375], [600, 400], [450, 300]]) input_labels = np.array([1, 1, 0]) # 2 foreground, 1 background masks, scores, logits = predictor.predict( point_coords=input_points, point_labels=input_labels, multimask_output=False # Single mask when prompts are clear ) ``` ### Box prompts ```python # Bounding box [x1, y1, x2, y2] input_box = np.array([425, 600, 700, 875]) masks, scores, logits = predictor.predict( box=input_box, multimask_output=False ) ``` ### Combined prompts ```python # Box + points for precise control masks, scores, logits = predictor.predict( point_coords=np.array([[500, 375]]), point_labels=np.array([1]), box=np.array([400, 300, 700, 600]), multimask_output=False ) ``` ### Iterative refinement ```python # Initial prediction masks, scores, logits = predictor.predict( point_coords=np.array([[500, 375]]), point_labels=np.array([1]), multimask_output=True ) # Refine with additional point using previous mask masks, scores, logits = predictor.predict( point_coords=np.array([[500, 375], [550, 400]]), point_labels=np.array([1, 0]), # Add background point mask_input=logits[np.argmax(scores)][None, :, :], # Use best mask multimask_output=False ) ``` ## Automatic mask generation ### Basic automatic segmentation ```python from segment_anything import SamAutomaticMaskGenerator # Create generator mask_generator = SamAutomaticMaskGenerator(sam) # Generate all masks masks = mask_generator.generate(image) # Each mask contains: # - segmentation: binary mask # - bbox: [x, y, w, h] # - area: pixel count # - predicted_iou: quality score # - stability_score: robustness score # - point_coords: generating point ``` ### Customized generation ```python mask_generator = SamAutomaticMaskGenerator( model=sam, points_per_side=32, # Grid density (more = more masks) pred_iou_thresh=0.88, # Quality threshold stability_score_thresh=0.95, # Stability threshold crop_n_layers=1, # Multi-scale crops 
### Customized generation

```python
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,               # Grid density (more = more masks)
    pred_iou_thresh=0.88,             # Quality threshold
    stability_score_thresh=0.95,      # Stability threshold
    crop_n_layers=1,                  # Multi-scale crops
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,         # Remove tiny masks
)

masks = mask_generator.generate(image)
```

### Filtering masks

```python
# Sort by area (largest first)
masks = sorted(masks, key=lambda x: x['area'], reverse=True)

# Filter by predicted IoU
high_quality = [m for m in masks if m['predicted_iou'] > 0.9]

# Filter by stability score
stable_masks = [m for m in masks if m['stability_score'] > 0.95]
```

## Batched inference

### Multiple images

```python
# Process multiple images efficiently
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]

all_masks = []
for image in images:
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # SAM expects RGB
    predictor.set_image(image)
    masks, _, _ = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
        multimask_output=True
    )
    all_masks.append(masks)
```

### Multiple prompts per image

```python
# Process multiple prompts efficiently (one image encoding)
predictor.set_image(image)

# Batch of point prompts
points = [
    np.array([[100, 100]]),
    np.array([[200, 200]]),
    np.array([[300, 300]])
]

all_masks = []
for point in points:
    masks, scores, _ = predictor.predict(
        point_coords=point,
        point_labels=np.array([1]),
        multimask_output=True
    )
    all_masks.append(masks[np.argmax(scores)])
```

## ONNX deployment

### Export model

```bash
python scripts/export_onnx_model.py \
    --checkpoint sam_vit_h_4b8939.pth \
    --model-type vit_h \
    --output sam_onnx.onnx \
    --return-single-mask
```

### Use ONNX model

```python
import onnxruntime

# Load ONNX model
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")

# Run inference (image embeddings computed separately)
masks, iou_predictions, low_res_logits = ort_session.run(
    None,
    {
        "image_embeddings": image_embeddings,
        "point_coords": point_coords,
        "point_labels": point_labels,
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.array([0], dtype=np.float32),
        "orig_im_size": np.array([h, w], dtype=np.float32)
    }
)
```
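The exported ONNX model contains only the prompt encoder and mask decoder, so `image_embeddings` still come from the PyTorch image encoder, and prompt coordinates must be mapped into the model's resized input frame. A sketch of preparing the inputs used above, following the pattern of the official ONNX example notebook (hedged; exact preprocessing may differ across versions):

```python
import numpy as np

# Image embeddings from the PyTorch image encoder (image is an RGB array)
predictor.set_image(image)
image_embeddings = predictor.get_image_embedding().cpu().numpy()

# A single foreground point; a padding point with label -1 is appended
# because no box prompt is supplied
input_point = np.array([[500, 375]])
input_label = np.array([1])
point_coords = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
point_labels = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

# Map coordinates from the original image frame to the model's input frame
point_coords = predictor.transform.apply_coords(point_coords, image.shape[:2]).astype(np.float32)

h, w = image.shape[:2]  # used for "orig_im_size" in the ONNX inputs
```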
## Common workflows

### Workflow 1: Annotation tool

```python
import cv2

# Load model
predictor = SamPredictor(sam)
predictor.set_image(image)

def on_click(event, x, y, flags, param):
    if event == cv2.EVENT_LBUTTONDOWN:
        # Foreground point
        masks, scores, _ = predictor.predict(
            point_coords=np.array([[x, y]]),
            point_labels=np.array([1]),
            multimask_output=True
        )
        # Display best mask (display_mask is a user-supplied helper)
        display_mask(masks[np.argmax(scores)])
```

### Workflow 2: Object extraction

```python
def extract_object(image, point):
    """Extract object at point with transparent background."""
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([point]),
        point_labels=np.array([1]),
        multimask_output=True
    )
    best_mask = masks[np.argmax(scores)]

    # Create RGBA output
    rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
    rgba[:, :, :3] = image
    rgba[:, :, 3] = best_mask.astype(np.uint8) * 255
    return rgba
```

### Workflow 3: Medical image segmentation

```python
# Process medical images (grayscale to RGB)
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)

predictor.set_image(rgb_image)

# Segment region of interest
masks, scores, _ = predictor.predict(
    box=np.array([x1, y1, x2, y2]),  # ROI bounding box
    multimask_output=True
)
```

## Output format

### Mask data structure

```python
# SamAutomaticMaskGenerator output
{
    "segmentation": np.ndarray,   # H×W binary mask
    "bbox": [x, y, w, h],         # Bounding box
    "area": int,                  # Pixel count
    "predicted_iou": float,       # 0-1 quality score
    "stability_score": float,     # 0-1 robustness score
    "crop_box": [x, y, w, h],     # Generation crop region
    "point_coords": [[x, y]],     # Input point
}
```

### COCO RLE format

```python
from pycocotools import mask as mask_utils

# Encode mask to RLE
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
rle["counts"] = rle["counts"].decode("utf-8")

# Decode RLE to mask
decoded_mask = mask_utils.decode(rle)
```

## Performance optimization

### GPU memory

```python
# Use smaller model for limited VRAM
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# Process images in batches
# Clear CUDA cache between large batches
torch.cuda.empty_cache()
```

### Speed optimization

```python
# Use half precision
sam = sam.half()

# Reduce points for automatic generation
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,  # Default is 32
)

# Use ONNX for deployment
# Export with --return-single-mask for faster inference
```

## Common issues

| Issue | Solution |
|-------|----------|
| Out of memory | Use ViT-B model, reduce image size |
| Slow inference | Use ViT-B, reduce points_per_side |
| Poor mask quality | Try different prompts, use box + points |
| Edge artifacts | Use stability_score filtering |
| Small objects missed | Increase points_per_side |

## References

- **[Advanced Usage](references/advanced-usage.md)** - Batching, fine-tuning, integration
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions

## Resources

- **GitHub**: https://github.com/facebookresearch/segment-anything
- **Paper**: https://arxiv.org/abs/2304.02643
- **Demo**: https://segment-anything.com
- **SAM 2 (Video)**: https://github.com/facebookresearch/segment-anything-2
- **HuggingFace**: https://huggingface.co/facebook/sam-vit-huge