{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /Users/sseibert/.cache/torch/hub/pytorch_vision_v0.6.0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "\n",
    "torch_model = torch.hub.load('pytorch/vision:v0.6.0', 'mobilenet_v2', pretrained=True)\n",
    "torch_model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "\n",
    "# The classifier emits logits of shape (batch, num_classes), so the softmax\n",
    "# must normalise over the class axis (dim=1). With dim=0 it normalises over\n",
    "# the batch axis, which for a batch of one turns every probability into 1.0.\n",
    "full_model = torch.nn.Sequential(\n",
    "    torch_model,\n",
    "    torch.nn.Softmax(dim=1),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trace with random data\n",
    "example_input = torch.rand(1, 3, 224, 224)\n",
    "traced_model = torch.jit.trace(full_model, example_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download class labels (from a separate file)\n",
    "# Import the submodule explicitly: a bare `import urllib` does not guarantee\n",
    "# that `urllib.request` has been loaded.\n",
    "import urllib.request\n",
    "\n",
    "label_url = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'\n",
    "with urllib.request.urlopen(label_url) as response:\n",
    "    class_labels = response.read().decode(\"utf-8\").splitlines()\n",
    "\n",
    "class_labels = class_labels[1:] # remove the first class which is background\n",
    "assert len(class_labels) == 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import coremltools as ct\n",
    "\n",
    "# torchvision's pretrained models expect pixels scaled to [0, 1] and then\n",
    "# normalised per channel with mean=[0.485, 0.456, 0.406] and\n",
    "# std=[0.229, 0.224, 0.225]. ImageType applies `pixel * scale + bias` to the\n",
    "# 0-255 pixel values and takes one scalar scale, so the average std (0.226)\n",
    "# is used for the scale while the bias stays per channel.\n",
    "image_scale = 1 / (0.226 * 255.0)\n",
    "image_bias = [-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225]\n",
    "\n",
    "# Convert to Core ML using the Unified Conversion API\n",
    "model = ct.convert(\n",
    "    traced_model,\n",
    "    inputs=[ct.ImageType(name=\"image\", shape=example_input.shape,\n",
    "                         color_layout=\"RGB\",\n",
    "                         bias=image_bias,\n",
    "                         scale=image_scale)],\n",
    "    # Uses `class_labels` from the label-download cell above so that\n",
    "    # predictions carry a `classLabel` output and per-label probabilities.\n",
    "    classifier_config=ct.ClassifierConfig(class_labels),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save model\n",
    "model.save(\"MobileNet_v2.mlmodel\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the submodule explicitly: a bare `import PIL` does not load `PIL.Image`.\n",
    "import PIL.Image\n",
    "import numpy as np\n",
    "\n",
    "def load_image(path, resize_to=None):\n",
    "    \"\"\"Load an image file as RGB and optionally resize it.\n",
    "\n",
    "    Args:\n",
    "        path: Path of the image file to open.\n",
    "        resize_to: Optional (Width, Height) tuple; when given, the image is\n",
    "            resized to exactly that size.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(img_np, img)`: the pixels as a float32 NumPy array and the\n",
    "        corresponding PIL image.\n",
    "    \"\"\"\n",
    "    # convert() returns an in-memory copy, so the file can be closed right away.\n",
    "    # Forcing RGB keeps grayscale/CMYK/RGBA files compatible with the RGB model input.\n",
    "    with PIL.Image.open(path) as source:\n",
    "        img = source.convert('RGB')\n",
    "    if resize_to is not None:\n",
    "        # LANCZOS is the filter formerly exposed as ANTIALIAS (removed in Pillow 10).\n",
    "        img = img.resize(resize_to, PIL.Image.LANCZOS)\n",
    "    img_np = np.array(img).astype(np.float32)\n",
    "    return img_np, img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Translated PyTorch Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, img = load_image('Aalto_table.jpeg', resize_to=(224,224))\n",
    "out = model.predict({'image': img})\n",
    "\n",
    "# The probability output gets an auto-generated name from the traced graph\n",
    "# (it was '649' in an earlier run), so read it from the model spec instead of\n",
    "# hard-coding it.\n",
    "probs_name = model.get_spec().description.predictedProbabilitiesName\n",
    "out['classLabel'], sorted(out[probs_name].items(), key=lambda x: -x[1])[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reference Model\n",
    "\n",
    "MobileNet.mlmodel taken from https://github.com/sivu22/CoreMLCompare/tree/master/CoreMLCompare"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "ref = ct.models.MLModel('MobileNet.mlmodel')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('bannister, banister, balustrade, balusters, handrail',\n",
       " [('bannister, banister, balustrade, balusters, handrail', 0.7715724110603333),\n",
       "  ('shoji', 0.054597772657871246),\n",
       "  ('four-poster', 0.029684294015169144),\n",
       "  ('parallel bars, bars', 0.028104523196816444),\n",
       "  ('dining table, board', 0.024609141051769257),\n",
       "  ('sliding door', 0.01916561648249626),\n",
       "  ('fire screen, fireguard', 0.01576521061360836),\n",
       "  ('studio couch, day bed', 0.00406468752771616),\n",
       "  ('pedestal, plinth, footstall', 0.004033055622130632),\n",
       "  ('prayer rug, prayer mat', 0.003056224901229143)])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_, img = load_image('Aalto_table.jpeg', resize_to=(224,224))\n",
    "ref_out = ref.predict({'image': img})\n",
    "ref_out['classLabel'], sorted(ref_out['classLabelProbs'].items(), key=lambda x: -x[1])[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}