{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: Convert a PyTorch image segmentation model to CoreML\n", "\n", "In this example, we will convert a pretrained PyTorch model for image segmentation. We will assume you have read through the [PyTorch image classification example](pytorch.ipynb) already. This notebook is based on [this example in the coremltools documentation](https://coremltools.readme.io/docs/pytorch-conversion-examples).\n", "\n", "
\n", "NOTE: This example requires more memory than Binder allows. To run this example you will need to download the coreml-demo repository to your system and create a conda environment with \n", "
\n",
    "conda env create -n cml_demo -f binder/environment.yml\n",
    "conda activate cml_demo\n",
    "
\n", "
\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf # temp workaround for protobuf library version issue\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torchvision\n", "import json\n", "\n", "from torchvision import transforms\n", "from PIL import Image\n", "\n", "import coremltools as ct" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create the PyTorch Model\n", "\n", "For this example, we will load the [DeepLabV3](https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/) model with pretrained weights." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load the model (deeplabv3)\n", "model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True).eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test the Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load a sample image (cat_dog.jpg)\n", "input_image = Image.open(\"cat_dog.jpg\")\n", "input_image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As described in the documentation for the DeepLabV3 model, the input pixel data needs to be rescaled for the model. We will do that for testing using PyTorch transforms, but when we convert the model, we will use the coremltools `ImageType` to describe the input scaling and offset required for the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "preprocess = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize(\n", " mean=[0.485, 0.456, 0.406],\n", " std=[0.229, 0.224, 0.225],\n", " ),\n", "])\n", "\n", "input_tensor = preprocess(input_image)\n", "input_batch = input_tensor.unsqueeze(0)\n", "\n", "with torch.no_grad():\n", " output = model(input_batch)['out'][0]\n", "torch_predictions = output.argmax(0)\n", "torch_predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result is a category assignment for every pixel, but that isn't easy to visualize in tensor form. Instead, we can use a helper function to overlap the categories as color tinting on the original image." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def display_segmentation(input_image, output_predictions):\n", " # Create a color pallette, selecting a color for each class\n", " palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])\n", " colors = torch.as_tensor([i for i in range(21)])[:, None] * palette\n", " colors = (colors % 255).numpy().astype(\"uint8\")\n", "\n", " # Plot the semantic segmentation predictions of 21 classes in each color\n", " r = Image.fromarray(\n", " output_predictions.byte().cpu().numpy()\n", " ).resize(input_image.size)\n", " r.putpalette(colors)\n", "\n", " # Overlay the segmentation mask on the original image\n", " alpha_image = input_image.copy()\n", " alpha_image.putalpha(255)\n", " r = r.convert(\"RGBA\")\n", " r.putalpha(128)\n", " seg_image = Image.alpha_composite(alpha_image, r)\n", " return seg_image\n", "\n", "display_segmentation(input_image, torch_predictions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Translate the PyTorch model\n", "\n", "Because this model returns a dictionary, and CoreML only allows tensors and tuples of tensors, we need to extract the result from the dictionary while tracing. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Translate the PyTorch Model\n", "\n", "Because this model returns a dictionary, and CoreML only allows tensors and tuples of tensors, we need to extract the result from the dictionary while tracing. We can do this by wrapping the original model in a PyTorch `nn.Module` class that does the dictionary access after the model runs, and then trace this wrapped model instead:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class WrappedDeeplabv3Resnet101(nn.Module):\n", "\n", "    def __init__(self):\n", "        super(WrappedDeeplabv3Resnet101, self).__init__()\n", "        self.model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True).eval()\n", "\n", "    def forward(self, x):\n", "        res = self.model(x)\n", "        x = res[\"out\"]\n", "        return x\n", "\n", "traceable_model = WrappedDeeplabv3Resnet101().eval()\n", "trace = torch.jit.trace(traceable_model, input_batch)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Converting the model is very similar to the classification example. We use an `ImageType` input to set the scale and offset for the input pixels, using the values worked out above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Convert the model, folding the pixel normalization into the CoreML input\n", "mlmodel = ct.convert(\n", "    trace,\n", "    inputs=[ct.ImageType(name=\"input\", color_layout='RGB', scale=1.0/255.0/0.226,\n", "                         bias=(-0.485/0.226, -0.456/0.226, -0.406/0.226),\n", "                         shape=input_batch.shape)],\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "And then we can save the model to disk for use in our project." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mlmodel.save(\"SegmentationModel.mlmodel\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Model Previews\n", "\n", "CoreML models can carry [arbitrary metadata](https://developer.apple.com/documentation/coreml/mlmodeldescription/2879386-metadata) for use by the application. There are some special metadata keys that allow Xcode to display a preview of the model and let the developer interact with it. For more details on how to set these keys, see [this page in the coremltools documentation](https://coremltools.readme.io/docs/xcode-model-preview-types)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "labels_json = {\"labels\": [\"background\", \"aeroplane\", \"bicycle\", \"bird\", \"boat\", \"bottle\", \"bus\", \"car\", \"cat\", \"chair\", \"cow\", \"diningTable\", \"dog\", \"horse\", \"motorbike\", \"person\", \"pottedPlant\", \"sheep\", \"sofa\", \"train\", \"tvOrMonitor\"]}\n", "\n", "mlmodel.user_defined_metadata[\"com.apple.coreml.model.preview.type\"] = \"imageSegmenter\"\n", "mlmodel.user_defined_metadata['com.apple.coreml.model.preview.params'] = json.dumps(labels_json)\n", "\n", "mlmodel.save(\"SegmentationModel_with_metadata.mlmodel\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "If you have access to a Mac with Xcode >= 12.3 (and macOS >= 11), you can preview this model by downloading [the model file](SegmentationModel_with_metadata.mlmodel) and [the test image](cat_dog.jpg). Double-click the .mlmodel file to open it in Xcode; the \"Preview\" tab will let you drag and drop a test image into the preview window and see the result." ] },
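{ "cell_type": "markdown", "metadata": {}, "source": [ "Even without a Mac, one way to double-check what was written into the saved file is to load its protobuf specification with coremltools and inspect the model description and the user-defined metadata we just set:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load the raw model specification (protobuf) to verify inputs, outputs, and metadata.\n", "# This works on any platform; it does not require the CoreML runtime.\n", "spec = ct.models.utils.load_spec(\"SegmentationModel_with_metadata.mlmodel\")\n", "\n", "print([inp.name for inp in spec.description.input])\n", "print([out.name for out in spec.description.output])\n", "print(dict(spec.description.metadata.userDefined))" ] },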
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Using the CoreML Model\n", "\n", "Running the model is the same as for the PyTorch classification example, but note that running this segmentation model requires the CoreML framework in macOS 11.0 or later. The CoreML output is the raw array of class scores with shape (1, 21, height, width), so we take the argmax over the class dimension before handing it to our display helper." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "IS_MACOS = sys.platform == 'darwin'\n", "\n", "if IS_MACOS:\n", "    loaded_model = ct.models.MLModel('SegmentationModel.mlmodel')\n", "    # predict() returns a dict keyed by the converter's auto-generated output name\n", "    prediction = loaded_model.predict({'input': input_image})\n", "    coreml_output = torch.from_numpy(list(prediction.values())[0])\n", "    # Reduce the (1, 21, H, W) class scores to an (H, W) map of class indices\n", "    coreml_predictions = coreml_output[0].argmax(0)\n", "    result = display_segmentation(input_image, coreml_predictions)\n", "else:\n", "    prediction = 'Skipping prediction on non-macOS system'\n", "    result = None\n", "result" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 4 }