# Segment Anything Model (SAM) using Ortex

```elixir
Mix.install([
  {:ortex, "~> 0.1.9"},
  {:image, "~> 0.47"},
  {:nx_image, "~> 0.1.2"},
  {:exla, "~> 0.7.2"},
  {:kino, "~> 0.12.3"},
  {:req, "~> 0.4"}
])
```

## Recap

This is an implementation by @herisson of the [Segment Anything Model](https://github.com/facebookresearch/segment-anything) (SAM) from Facebook Research, using the [Ortex](https://github.com/elixir-nx/ortex) library to run ONNX models.

It is approximately a port of this Jupyter notebook example to Livebook: https://colab.research.google.com/drive/1wmjHHcrZ_s8iFuVFh9iHo6GbUS_xH5xq

It uses the ONNX "mobile" models found [here](https://huggingface.co/vietanhdev/segment-anything-onnx-models/tree/main).

```elixir
Nx.global_default_backend(EXLA.Backend)
Nx.default_backend()
```

## Loading the models

```elixir
encoder_url =
  "https://huggingface.co/ginkgoo/SegmentAnythingModel-Elixir-Ortex/resolve/main/vision_encoder_quantized.onnx"

encoder_path = "/tmp/vision_encoder_quantized.onnx"

# Download the encoder weights and load them with Ortex
%{body: body} = Req.get!(encoder_url)
File.write!(encoder_path, body)

model = Ortex.load(encoder_path)
```

```elixir
decoder_url =
  "https://huggingface.co/ginkgoo/SegmentAnythingModel-Elixir-Ortex/resolve/main/vit_b_decoder.onnx"

decoder_path = "/tmp/vit_b_decoder.onnx"

# Download the decoder weights and load them with Ortex
%{body: body} = Req.get!(decoder_url)
File.write!(decoder_path, body)

decoder = Ortex.load(decoder_path)
```

## Getting the image

```elixir
image_input = Kino.Input.image("Uploaded Image")
# https://ibb.co/Hr8588Y the image I'm using
```

```elixir
image =
  Kino.Input.read(image_input)
  |> Image.from_kino!()

# The demo image is already these dimensions
# but other images may not be.
resized = Image.thumbnail!(image, "1024x1024")

original_label = Kino.Markdown.new("**Original image**")
resized_label = Kino.Markdown.new("**Resized image**")

Kino.Layout.grid(
  [
    Kino.Layout.grid([image, original_label], boxed: true),
    Kino.Layout.grid([resized, resized_label], boxed: true)
  ],
  columns: 2
)
```

```elixir
# Send the image to Nx. This is a zero-copy
# operation: Nx is simply handed a pointer
# to the image's memory area.
tensor =
  resized
  |> Image.to_nx!()
  |> Nx.as_type(:f32)

# Mean and std values copied from transformer.js
mean = Nx.tensor([123.675, 116.28, 103.53])
std = Nx.tensor([58.395, 57.12, 57.375])

normalized_tensor =
  tensor
  |> NxImage.normalize(mean, std)
  |> Nx.tensor(names: [:height, :width, :bands])
  |> Nx.transpose(axes: [:bands, :height, :width])

# Run the image encoder. The number of outputs depends on the ONNX model.
{image_embeddings, _} =
  Ortex.run(model, Nx.broadcast(normalized_tensor, {1, 3, 1024, 1024}))

# {image_embeddings} = Ortex.run(model, Nx.broadcast(normalized_tensor, {1024, 1024, 3}))
```

## Prompt encoding & mask generation

```elixir
# Prepare the decoder inputs.

# Box prompt: xy coordinates, in our image, of the object we want to cut out
# input_point = Nx.tensor([[345, 272], [640, 760]]) |> Nx.as_type(:f32) |> Nx.reshape({1, 2, 2})
# Labels 2 and 3 mark the box's start and end points
# input_label = Nx.tensor([2, 3]) |> Nx.reshape({1, 2}) |> Nx.as_type(:f32)

# Single point prompt: label 1 is the point itself, label -1 is a padding point
input_point = Nx.tensor([[514, 514], [0, 0]]) |> Nx.as_type(:f32) |> Nx.reshape({1, 2, 2})
input_label = Nx.tensor([1, -1]) |> Nx.reshape({1, 2}) |> Nx.as_type(:f32)

# Filled with 0, not used here
mask_input = Nx.broadcast(0, {1, 1, 256, 256}) |> Nx.as_type(:f32)

# 0 because we are not using mask_input
has_mask = Nx.broadcast(0, {1}) |> Nx.as_type(:f32)

original_image_dim = Nx.tensor([1024, 1024]) |> Nx.as_type(:f32)

{mask, scores, low_res} =
  Ortex.run(decoder, {
    Nx.broadcast(image_embeddings, {1, 256, 64, 64}),
    Nx.broadcast(input_point, {1, 2, 2}),
    Nx.broadcast(input_label, {1, 2}),
    Nx.broadcast(mask_input, {1, 1, 256, 256}),
    Nx.broadcast(has_mask, {1}),
    Nx.broadcast(original_image_dim, {2})
  })
```

### Transform output to a black and white image

```elixir
use Image.Math

mask = mask[0][2] |> Nx.reshape({1024, 1024, 1}) |> Image.from_nx!()
low_res = low_res[0][2] |> Nx.reshape({256, 256, 1}) |> Image.from_nx!()

# Threshold the masks to produce black and white.
# This is one NIF call only. Then take only the
# first band, which we'll use as the alpha band.
mask = Image.if_then_else!(mask > 0, 255, 0)[0]

# Threshold the low-res output itself (not the already-thresholded mask)
low_res_mask = Image.if_then_else!(low_res > 0, 255, 0)[0]

mask_label = Kino.Markdown.new("**Mask**")
low_res_label = Kino.Markdown.new("**Low Res**")

Kino.Layout.grid(
  [
    Kino.Layout.grid([mask, mask_label], boxed: true),
    Kino.Layout.grid([low_res_mask, low_res_label], boxed: true)
  ],
  columns: 2
)
```

### Showing the masks

As you can see, the high-res mask is distorted and badly placed. The low-res mask, however, appears to fit correctly; the reason for the difference is not yet understood.

```elixir
new_image = Image.add_alpha!(image, mask)

mask_label = Kino.Markdown.new("**Image mask**")
new_image_label = Kino.Markdown.new("**New image**")

Kino.Layout.grid(
  [
    Kino.Layout.grid([image, original_label], boxed: true),
    Kino.Layout.grid([mask, mask_label], boxed: true),
    Kino.Layout.grid([new_image, new_image_label], boxed: true)
  ],
  columns: 3
)
```
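
The prompt coordinates above are given in the pixel space of the 1024x1024 demo image. For an image of another size, the reference SAM preprocessing resizes the longest side to 1024 and scales the prompt points by the same factors. The sketch below illustrates that arithmetic in plain Elixir. `PromptScale` and its functions are hypothetical helpers written for this notebook, not part of Ortex, Image or Nx, and the 2048x1536 image size is only an example.

```elixir
defmodule PromptScale do
  # Hypothetical helper: not part of Ortex, Image or Nx.
  # Longest side expected by the encoder used in this notebook.
  @target 1024

  # Size of the image after its longest side is resized to @target
  # while keeping the aspect ratio. Takes and returns {width, height}.
  def resized_size({width, height}) do
    scale = @target / max(width, height)
    {round(width * scale), round(height * scale)}
  end

  # Map one {x, y} point from original-image pixels to resized-image pixels.
  def scale_point({x, y}, {width, height} = size) do
    {new_w, new_h} = resized_size(size)
    {x * new_w / width, y * new_h / height}
  end
end

# A 2048x1536 image is resized to 1024x768, so coordinates are halved.
PromptScale.resized_size({2048, 1536})
# => {1024, 768}

PromptScale.scale_point({690, 544}, {2048, 1536})
# => {345.0, 272.0}
```

The scaled points would then be placed in `input_point` instead of the raw pixel coordinates. For the 1024x1024 demo image the scale factor is 1, so this step changes nothing there.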