# Example: Convert a PyTorch image segmentation model to CoreML

In this example, we will convert a pretrained PyTorch model for image segmentation.  We will assume you have read through the [PyTorch image classification example](pytorch.ipynb) already.  This notebook is based on [this example in the coremltools documentation](https://coremltools.readme.io/docs/pytorch-conversion-examples).

<div class="alert alert-block alert-warning">
<b>NOTE:</b> This example requires more memory than Binder allows.  To run this example you will need to download the <a href="https://github.com/ContinuumIO/coreml-demo/">coreml-demo</a> repository to your system and create a conda environment with 
<pre>
conda env create -n cml_demo -f binder/environment.yml
conda activate cml_demo
</pre>
</div>


In [None]:
import tensorflow as tf  # temp workaround for protobuf library version issue

import torch
import torch.nn as nn
import torchvision
import json

from torchvision import transforms
from PIL import Image

import coremltools as ct

## Create the PyTorch Model

For this example, we will load the [DeepLabV3](https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/) model with pretrained weights.

In [None]:
# Load the model (deeplabv3)
model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True).eval()

### Test the Model

In [None]:
# Load a sample image (cat_dog.jpg)
input_image = Image.open("cat_dog.jpg")
input_image

As described in the documentation for the DeepLabV3 model, the input pixel values must first be scaled to the range [0, 1] and then normalized with the per-channel mean `[0.485, 0.456, 0.406]` and standard deviation `[0.229, 0.224, 0.225]`. We will do that for testing using PyTorch transforms, but when we convert the model, we will use the coremltools `ImageType` to describe the input scaling and offset the model requires.
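
A minimal NumPy sketch of what `ToTensor()` followed by `Normalize()` computes for a single 8-bit RGB pixel. `normalize_pixel` is a hypothetical helper for illustration, not part of torchvision. A white pixel, for example, maps to roughly 2.25, 2.43, and 2.64 in the three channels, and a pixel close to the channel means maps to values near zero.

```python
import numpy as np

# ImageNet per-channel statistics used by the DeepLabV3 preprocessing
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def normalize_pixel(rgb):
    """Mimic ToTensor() + Normalize() for one 8-bit RGB pixel (hypothetical helper)."""
    x = np.asarray(rgb, dtype=np.float64) / 255.0  # ToTensor: [0, 255] -> [0, 1]
    return (x - MEAN) / STD                         # Normalize: per-channel standardization

print(normalize_pixel((255, 255, 255)))  # brightest possible pixel
print(normalize_pixel((124, 116, 104)))  # close to the channel means, so near zero
```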

In [None]:
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)

with torch.no_grad():
    output = model(input_batch)['out'][0]
torch_predictions = output.argmax(0)
torch_predictions

The result is a category assignment for every pixel, but that isn't easy to interpret in tensor form. Instead, we can use a helper function to overlay the categories as a color tint on the original image.
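
Before plotting, it may help to see what the per-pixel assignment means on a tiny example. This NumPy sketch uses made-up scores for 3 classes on a 2x2 image (laid out as classes, height, width, like `output`), takes the argmax over the class axis as `output.argmax(0)` does, and counts the pixels assigned to each class.

```python
import numpy as np

# Toy stand-in for the model output: scores for 3 classes on a 2x2 image,
# shaped (classes, height, width)
scores = np.array([
    [[0.9, 0.1], [0.2, 0.0]],    # class 0 (background)
    [[0.05, 0.8], [0.1, 0.3]],   # class 1
    [[0.05, 0.1], [0.7, 0.7]],   # class 2
])

# Pick the best-scoring class for each pixel
class_map = scores.argmax(axis=0)

# Count how many pixels were assigned to each class
classes, counts = np.unique(class_map, return_counts=True)
pixel_counts = dict(zip(classes.tolist(), counts.tolist()))

print(class_map)     # [[0 1]
                     #  [2 2]]
print(pixel_counts)  # {0: 1, 1: 1, 2: 2}
```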

In [None]:
def display_segmentation(input_image, output_predictions):
    # Create a color palette, selecting a color for each class
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")

    # Plot the semantic segmentation predictions of 21 classes in each color
    r = Image.fromarray(
        output_predictions.byte().cpu().numpy()
    ).resize(input_image.size)
    r.putpalette(colors)

    # Overlay the segmentation mask on the original image
    alpha_image = input_image.copy()
    alpha_image.putalpha(255)
    r = r.convert("RGBA")
    r.putalpha(128)
    seg_image = Image.alpha_composite(alpha_image, r)
    return seg_image

display_segmentation(input_image, torch_predictions)
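
The palette constants in `display_segmentation` look arbitrary, but the arithmetic is simple: since 2^8 = 256 leaves a remainder of 1 when divided by 255, the multipliers 2^25 - 1, 2^15 - 1, and 2^21 - 1 reduce to 1, 127, and 31 modulo 255, so class `i` gets the color `(i, 127*i, 31*i) % 255`. This NumPy sketch reproduces the computation and prints the colors for a few classes; the class indices (cat = 8, dog = 12, person = 15) follow the label order used later in this notebook.

```python
import numpy as np

# Same constants as display_segmentation: one large multiplier per RGB channel
palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1], dtype=np.int64)

# Class index i gets color (i * palette) % 255, giving one RGB triple per class
colors = (np.arange(21, dtype=np.int64)[:, None] * palette) % 255
colors = colors.astype(np.uint8)

for name, idx in [("background", 0), ("cat", 8), ("dog", 12), ("person", 15)]:
    print(name, colors[idx].tolist())
# background [0, 0, 0]
# cat [8, 251, 248]
# dog [12, 249, 117]
# person [15, 120, 210]
```

Because the red channel is just `i` for the 21 classes, every class is guaranteed a distinct color.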

## Translate the PyTorch model

Because this model returns a dictionary, and the coremltools converter only accepts model outputs that are tensors or tuples of tensors, we need to extract the result from the dictionary while tracing. We can do this by wrapping the original model in a PyTorch `nn.Module` subclass that performs the dictionary lookup after the model runs, and then tracing this wrapped model instead:

In [None]:
class WrappedDeeplabv3Resnet101(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True).eval()

    def forward(self, x):
        res = self.model(x)
        # Return only the "out" tensor so the traced graph has a plain tensor output
        x = res["out"]
        return x

traceable_model = WrappedDeeplabv3Resnet101().eval()
trace = torch.jit.trace(traceable_model, input_batch)

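Converting the model is very similar to the classification example. We use an `ImageType` input to set the scale and offset for the input pixels. Because `ImageType` takes a single `scale` shared by all three channels, the per-channel standard deviations (0.229, 0.224, 0.225) are approximated by their average, 0.226, while each channel's mean goes into the per-channel `bias`.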
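
To see how much the single shared scale changes what the model receives, here is a minimal NumPy sketch. It assumes Core ML applies `scale * pixel + bias` to each channel, which is the preprocessing formula the coremltools `ImageType` documentation describes; the helper names are hypothetical. The ratio between the two results works out to each channel's standard deviation divided by 0.226, so the inputs differ by at most about 1.3%.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])
AVG_STD = 0.226  # the single std used in the ImageType call below

def exact_normalize(pixel):
    """What transforms.Normalize computes, with a per-channel std (hypothetical helper)."""
    return (np.asarray(pixel, dtype=np.float64) / 255.0 - MEAN) / STD

def imagetype_normalize(pixel):
    """Assumed ImageType preprocessing: scale * pixel + bias (hypothetical helper)."""
    scale = 1.0 / 255.0 / AVG_STD
    bias = -MEAN / AVG_STD
    return scale * np.asarray(pixel, dtype=np.float64) + bias

pixel = (255, 0, 128)
print(exact_normalize(pixel))
print(imagetype_normalize(pixel))
print(STD / AVG_STD)  # per-channel ratio between the two versions
```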

In [None]:
# Convert the model
mlmodel = ct.convert(
    trace,
    inputs=[ct.ImageType(name="input", color_layout='RGB', scale=1.0/255.0/0.226,
                         bias=(-0.485/0.226, -0.456/0.226, -0.406/0.226),
                         shape=input_batch.shape)],
)

And then we can save the model to disk for use in our project.

In [None]:
mlmodel.save("SegmentationModel.mlmodel")

### Model Previews

CoreML models can carry [arbitrary metadata](https://developer.apple.com/documentation/coreml/mlmodeldescription/2879386-metadata) for use by the application. Some special metadata keys let Xcode display an interactive preview of the model that the developer can experiment with. For more details on how to set these keys, see [this page in the coremltools documentation](https://coremltools.readme.io/docs/xcode-model-preview-types).

In [None]:
labels_json = {"labels": ["background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningTable", "dog", "horse", "motorbike", "person", "pottedPlant", "sheep", "sofa", "train", "tvOrMonitor"]}

mlmodel.user_defined_metadata["com.apple.coreml.model.preview.type"] = "imageSegmenter"
mlmodel.user_defined_metadata['com.apple.coreml.model.preview.params'] = json.dumps(labels_json)

mlmodel.save("SegmentationModel_with_metadata.mlmodel")

If you have access to a Mac with Xcode 12.3 or later (and macOS 11 or later), you can preview this model by [downloading it](SegmentationModel_with_metadata.mlmodel) along with [the test image](cat_dog.jpg). Double-click the .mlmodel file to open it in Xcode. In the "Preview" tab, you can drag and drop a test image into the preview window to see the result.

## Using the CoreML Model

Running the model works the same way as in the PyTorch classification example, but note that this segmentation model requires CoreML on macOS 11.0 or later.

In [None]:
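import sys
IS_MACOS = sys.platform == 'darwin'

if IS_MACOS:
    loaded_model = ct.models.MLModel('SegmentationModel.mlmodel')
    prediction = loaded_model.predict({'input': input_image})
    # predict() returns a dict mapping output names to NumPy arrays. The single
    # output holds the (1, 21, H, W) class scores, so take the argmax over the
    # class dimension to get the per-pixel class map display_segmentation expects.
    output_name = loaded_model.get_spec().description.output[0].name
    coreml_scores = torch.from_numpy(prediction[output_name][0])
    coreml_predictions = coreml_scores.argmax(0)
    result = display_segmentation(input_image, coreml_predictions)
else:
    prediction = 'Skipping prediction on non-macOS system'
    result = None
result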
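
Because the `ImageType` preprocessing only approximates the PyTorch normalization, the CoreML and PyTorch class maps may not match exactly. One simple way to quantify this is the fraction of pixels on which two class maps agree. The sketch below uses a hypothetical helper, `pixel_agreement`, and toy arrays rather than real model output.

```python
import numpy as np

def pixel_agreement(pred_a, pred_b):
    """Fraction of pixels where two (H, W) class maps assign the same class (hypothetical helper)."""
    pred_a = np.asarray(pred_a)
    pred_b = np.asarray(pred_b)
    if pred_a.shape != pred_b.shape:
        raise ValueError(f"shape mismatch: {pred_a.shape} vs {pred_b.shape}")
    return float(np.mean(pred_a == pred_b))

# Toy class maps standing in for the PyTorch and CoreML predictions
torch_like = np.array([[0, 8, 8], [12, 12, 0]])
coreml_like = np.array([[0, 8, 8], [12, 0, 0]])
print(pixel_agreement(torch_like, coreml_like))  # 5 of 6 pixels agree
```

On macOS, passing `torch_predictions.numpy()` and the argmax over classes of the CoreML output array would give a single agreement score for the real image; this notebook does not report such a number.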