{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "CL089-JvFAXe" }, "source": [ "# Fine-tuning for Semantic Segmentation with 🤗 Transformers\n", "\n", "In this notebook, you'll learn how to fine-tune a pretrained vision model for Semantic Segmentation on a custom dataset in PyTorch. The idea is to add a randomly initialized segmentation head on top of a pre-trained encoder, and fine-tune the model altogether on a labeled dataset. You can find an accompanying blog post [here](https://huggingface.co/blog/fine-tune-segformer). " ] }, { "cell_type": "markdown", "metadata": { "id": "2UEtIWt0sGkR" }, "source": [ "## Model\n", "\n", "This notebook is built for the [SegFormer model](https://huggingface.co/docs/transformers/model_doc/segformer#transformers.SegformerForSemanticSegmentation) and is supposed to run on any semantic segmentation dataset. You can adapt this notebook to other supported semantic segmentation models such as [MobileViT](https://huggingface.co/docs/transformers/model_doc/mobilevit)." ] }, { "cell_type": "markdown", "metadata": { "id": "j9Fe4GuFtRMs" }, "source": [ "## Data augmentation\n", "\n", "This notebook leverages `torchvision`'s [`transforms` module](https://pytorch.org/vision/stable/transforms.html) for applying data augmentation. Using other augmentation libraries like `albumentations` is also [supported](https://github.com/huggingface/notebooks/blob/main/examples/image_classification_albumentations.ipynb).\n", "\n", "---\n", "\n", "Depending on the model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set those two parameters, then the rest of the notebook should run smoothly.\n", "\n", "In this notebook, we'll fine-tune from the https://huggingface.co/nvidia/mit-b0 checkpoint, but note that there are others [available on the hub](https://huggingface.co/models?pipeline_tag=image-segmentation)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6EGlFGE8uyBS" }, "outputs": [], "source": [ "model_checkpoint = \"nvidia/mit-b0\" # pre-trained model from which to fine-tune\n", "batch_size = 4 # batch size for training and evaluation" ] }, { "cell_type": "markdown", "metadata": { "id": "d32sBZeq_HQ2" }, "source": [ "Before we start, let's install the `datasets`, `transformers`, and `evaluate` libraries. We also install Git-LFS to upload the model checkpoints to Hub.\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vAvgC48er3Ih" }, "outputs": [], "source": [ "!pip -q install datasets transformers evaluate\n", "\n", "!git lfs install\n", "!git config --global credential.helper store" ] }, { "cell_type": "markdown", "metadata": { "id": "qSq03UMRvjRS" }, "source": [ "If you're opening this notebook locally, make sure your environment has an install from the last version of those libraries or run the `pip install` command above with the `--upgrade` flag.\n", "\n", "You can share the resulting model with the community. By pushing the model to the Hub, others can discover your model and build on top of it. You also get an automatically generated model card that documents how the model works and a widget that will allow anyone to try out the model directly in the browser. To enable this, you'll need to login to your account." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PlSGaFpFuQSW" }, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"semantic_segmentation_notebook\", framework=\"pytorch\")" ] }, { "cell_type": "markdown", "metadata": { "id": "XalxdrirGkLl" }, "source": [ "## Fine-tuning a model on a semantic segmentation task\n", "\n", "Given an image, the goal is to associate each and every pixel to a particular category (such as table). The screenshot below is taken from a [SegFormer fine-tuned on ADE20k](https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512) - try out the inference widget!\n", "\n", "\"drawing\"" ] }, { "cell_type": "markdown", "metadata": { "id": "mcE455KaG687" }, "source": [ "### Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "RD_G2KJgG_bU" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download our custom dataset into a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict).\n", "\n", "We're using the Sidewalk dataset which is dataset of sidewalk images gathered in Belgium in the summer of 2021. You can learn more about the dataset [here](https://huggingface.co/datasets/segments/sidewalk-semantic)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 281, "referenced_widgets": [ "6de4fc2807bc4bb0825befbbaca366c9", "77c817c830314013805fbd0d203ea8b8", "f1f9875ed07b480591c975222fc95950", "3a36ff18a13e46768228e7128e02002b", "c7d020044efb4725a24a0a84ae7409fa", "ad66d4abe2e2450cb7c8a573dc50d9ff", "3e3d572c197e464da6a3ce0cfc4cc89e", "59c9c5f2ef6d4625a97a4bc9bdc331b2", "9725bcfbf9d54824974e8e69c64ca6fd", "83896f5e610144adb992c63e9cbd56ff", "ef4a5405c1d34d5badf0a813a8ea1fa5", "289120fb74224df4a9d2aa2821432ba5", "f7e2d0f3ee6e4231b22616fbf933bf02", "78a551a0db66493f876fbe18030c3174", "f3174b816bec4c918cd3ab4b10ec0efe", "4fa96e02fc4b44cb9996bbef089fd74c", "0031937aefa0465683e219549c935dd7", "d1b7953db7a6462f831fb075199542ef", "1c4348e789704fe2b25cb8698e88ca8e", "bff07d8078824308b8aabf71faf16e88", "12a94046754c49868ad14d688c48d88f", "5ad48a5ce1dc426faa6e790064418c9b", "3de3cc234bda43c1ad0b354dbc474ef6", "d9f56ee4b21f4e0e8693232cbb490cdd", "05621203ac434c5f93d355bb6ce6db3d", "d134802237f84c73a2717eed6cc1ecb4", "afce9206edbb444a944c291d5282c217", "a60fbc5906394df9858d898670a7f819", "1ee064e37e61448cb9a34fa7e66a3e34", "01ecfb419b3f460b8377c79793740982", "43f8dc298b7941bf8ca651015dec533c", "32988ad67af149489bfd676c0942dbe4", "a687b33460fa461eafec017738c863de", "c5030cc8ebff4e628da62e8835965f33", "33ad7766b2c146afb54a3470f352b8ee", "915fad5434a545b086baf639459d4558", "79894b070903437584f7aca3052e50de", "ad87c3f60ac7492ab8aaa471630c41c7", "e74bfbec17f94789bfb8e367983aa193", "dccebc6c988d49c887484b78f4ee95e9", "d04bc0eadaf242b6a81a9ade962c4de2", "4ab875034a084c338994bc7909053241", "3e2ff151429e49f5aaffd0a940d36805", "0459a4d9fcbd4478a52481ac016ddd78", "23735d2bbcf24587859d5d36bf901fd0", "ca4341aa40384d228a99074543527b31", "974c53ab4fb94ee3a86f650bba9562c1", "3d40886e38454a618a8a9a756f5c458d", "8db4e36ccc6a4d4492b78a50c7ec2409", "ab102484e15d44a4b1182029637dfc30", "389469716c4e42c884ca47c6659d5612", "4d6b9ad04eab4a70bf1df3f43a7f5f94", "b969b475e1b149dda5b119e37d7761c7", "c298ba164a35417b9b1fff7a97bd7f7f", "c8cd775a94a944d9a0e0ebeb2c3645f8", "2cbb298d54ac4771be87a85cbcb5dff4", "1a5bc3a601a949b5ab59b4a4b776f2f0", "8c9c506bf6dd47fd882e8d70581e54c2", "8652d0f8599843e98edfe6c21de5be9c", "3b280d74e111402d94a18f61fe8c4478", "998feff5e78b491dacc43d045c73cb8c", "9747349d77e14ab3a47fd4c479970a0a", "47cf642995f44f04a56be02be4ab0a87", "dadc0338ea954e478246a588b2aa1dc7", "219ab526698c43f895ca102befa0135c", "ff2adb06e63e47798f7130e9630c1255", "f596082d8bda42458543179605fb2dc2", "af8c7c09c55341a6a14b4799ae2dff44", "396a68b4f1b947b9876a9587b6958fce", "a3a68016faa148158bd9f4d805863295", "3bea3ea92ee8481fb036fa4e07c08009", "0d05591c5dd74c26b9399ece8cea12b0", "ea01151241174b9da977485761a27463", "b2985292e9d24f8ab6c6523056d8371a", "4af03fff36ed4c569c86908f228f7954", "ae1bd290c1474450a169593db137974f", "097c5b79a85642efa389e68aa5df428c" ] }, "id": "U10po8Q3w9ZJ", "outputId": "292c707c-2cae-4818-fe23-f9c71efe0f37" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6de4fc2807bc4bb0825befbbaca366c9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading metadata: 0%| | 0.00/635 [00:00" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "example = ds[\"train\"][10]\n", "example[\"pixel_values\"].resize((200, 200))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 217 }, "id": 
"-p7R9KP0yUOa", "outputId": "c0487523-1ffe-4cb6-bfe7-dec0033a5d67" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADICAAAAACIM/FCAAALGklEQVR4nN2d25ajOg6Gf9mEpKqnuvfMrHn/B9zd1V2nhAC25oJDgNgcjJyQrYtaKULAH5JsWT5A/0Ejak8AYAuD5fKk/d+xBav6s2UoCrj+pKjLR2urI/sk4Dp2/OsoZe9JVWhSRKQaqJQHOtHalhPXMTsADOJBkQksUcxpSQCtFfVvn2a9u6cJse08cqXZyWUB8PDh2/jKAAAkSNKrg7TPOv/pBIwkA9UWkqQAFdeXYgCMgUrYjjiPpCS0cxxVaV5/slYlBDBygEA7qqxRe0DuJx7HTrguacEoc5VoIgaDixRKAVDqyrsbjqGXuM8SF18NlZQMAJYBewZBVSWwZVKV0w8yPMzjZFKi2F0AqkybAYDoUr0arr4YFu7y//B6DHsLo1PwNH8dkK6wcT9f5TjGzJ7HFEESX1OmqKqFhjLa7FP1h5kAOucEIAWm/EZEEp9tky7hAqmPuEp2NCBFQKphSbEBQ2sCbgTiEwVghmWwYQAoDoeiaUZP+mD5SVkgSQcKj2dpfpCkGLtrotgwM8NUjTxnh6J97uaU4pwy9HVbG00SXzxHSZm7NLKr3VpBA2xRtzdky+JyJctkz6Abcjhrm/qLJ+d3puyQk27bEz51qwEDMHY3CrMA1A/WJQT3N/bj1HFdKlsfyLvaswB2Id2BYFFPOyL3k/NZxtdb3v6Az+5zbGGSHTCrupCRqhPiRPHVmWQ+myieitxzkjGuYDSiNH5ARA7NkLNfRPZcn8n+bqtHzdGkb8f9m1PCibKZAQ++odGaGQCgnjv/MHR0ExtzSNoDhL1la0pbdi2tiWjJ+MrHvQuTgvdMIZmqWRikCQCX59xctznG50jWXK6sQCC6MwhQxxXJzhTHcr7l27bkmkLyS0tlfl3P6pCWWWkI87I7LchtauAljRZTmtosM1AHNRn+mVwTqnrvKkMUQ5a1vgx6PmQfRXl6SonGykSn08sOYKayoFvEKl4Q62kimJ7Tj5M5a5XsNPnzdhavCgol2Oj/eSM6OfGCePq0BGD3fIK1yACovSEy7LhOacCg1DLwlHAvwRdFEn9Osz1Ozb8WyG1RoIkUiRUYMKauowdCddf3uH8y/aMR3D857jwsdcks2QJcwp7BpuqyEynUtZKBYTDATZqoL1WDbrMnx4WFJfmld/tkf3EIRlVh0tnCGpgcTaKXQFR1xBjMXNE3FlMqjTrR2wny695K0SSAiaPFYAnZLIPWB53SGUCZA2QtcIkpiAhNQMktwVAYXMdglDZl1XVerDA7BkBNsB2jYUkAAqzJoQZ5AiIQNXbOsD6CWhgwxAyAz6rqyZDi+pJ52s3Sx/GR5tq9AYGagABmCzvnxkY36WIwV6NGuv3dm/2m6vswc1bsnsVRqFfFEwiq1oHfiqYloZJBCZVtmJL+OwVbzs2ZreHv/xIHqUes6udfDSKuT3WWAKDpchXKPw4mN2yrhzU1/hUgiWr9AGDLLBfj9ZyOTkd9uVGEPFFSp0rmusJ80YMRUmpHdqFigICZpRkAgPQwPc72kiCQr7aSfPqcIFHXQ9ZRh3xixaWkHc4Wc8QnFsjQQyqpSZIId40EQg7LAprg5DmCiUUCcSsEYAuAdhFMLBKIRyERMxFxQLS/xBYcw0XigJDPsgDA6O8x7hkFRGEEhA9R8vQxQEiP9p3yKG4Sy1z9WVIqoiRUIoDQiKsDgPUNDq2SSG3s2EPnbOTLVXeVFdIYTy9QFiM7Lw+iMBUdmuwhQhSFyXzP8RFCFE2YmjxLeSmvEnkQTE4CBp+k7yoPooFpEJzkbUsYpKqypopJ5VnctoRBZqZ2+Sh7W8QBmQ5B6CzelMiCVEPyczLF4k2JKAjNc3UggruLglQXmwNCeSGsEnmQeUG6lY4cJUHqqbHzTj4J90okQaprzQOhUrhXIgii69GVeWdLhymCIPNdHQDoJDvaIwdSe8jsatXKhilyIMsUAummRAxEtRM9Zopwr0QMpF5vsuApW9HIUQqkVsiixiGTbEqkQOox1SVFk21KhEAWewgg3JQIgdQKWdbLoJNg5CgD0ihkYYVqPUsEgoogcpUADwEAyakcIiDNtLXFLVwp1yaKgDTTQJYWi4L2AHCLBEiwQmA9yzNDCiFwjWYJ0/L2jeUmkAiANAoJmaBxFGvcBUAahQRwUCkWOK4HaT0k5OHKNe7rQVYoBEAmVQGvBmmnPgeVSC5wXA3SrroMe7RitrUW5KKQQBs5C9Vba0FahQSWR8y2VoJcFgeEPlipUfeVICs9BHKj7utA1isEELKtlSDtpxXNgYxtrQK5rBBYMRFWaBhuFchlhdia9lnGttaAdBb4rWoMRHKnq0DaT6s46HxnECXi6hCaibYGpP0UFMFfRKRNDAchKYWAjgKj7uEgnc1G1to4v55WkwSDdFeZrQ5g7Z/VOgkG6ShkfSBO5vfafknoitPuHnwy+cLnv1b9PFQj3d/JdLu/Xldt6Riqkc66NSuU9+TdX/sVIVsYiO64iFgmmunlJXg9XKBpdX4ml1AnvP8KTgaHgXT3E5Bco0PZz9BwJQyk+yvRWT5U/PwKIwnyEdXZqsK90ewaef6uA5QcpJHuVmjyi0G+foWMkYaA9HZMkZ+KTPnfAaFXEEjnc4zleMSvb4v36gjwka6HiLWGA+HDj4WrLgNAessIBcdle8Lq5WXRpZebFsW2rOou/PZ70arx5RrpKcRE3BKI0x8LYq/FIP09tB1bY8vJothrMUhPIfKtYV94/yOdSbIUpK+QmJYFAGD98m2eUpaC9Fc+x6qzOsJPP5I5d1kI0ldIbMuqbqL/O8e8Fla//dNjGxYAgMz7HI1sHwSUHWeYzTKQ/gZtt+EA6H3GRI9lIP2zb7WVL8zbdA94EUh/x8OVqesFQqefkzpZBNLfW/pmCgGo+HtqScASkMEWlDcEAZmpPPcSkMFm3zezLAAg+zpedy0AGSjkphwA8Pt9jGQByD0VAgB4/zPy5XyQYTBzSxepS/D56dfJfJC7KwSgD/940GwQGpx5DxDYD68dzAYZKOT2hgUAdPYmVOeCDBVyu3fW9Ivx7stCzgUZnncXywJgfTH9TBAaWtZ9FALQydMuzgTZikIAend3SwNB7qUQAObNqZJ5IMMdr+9mWfB2GGeBDD3kzu+ne3MlzueBDA/cz0UAwLw7Ds4CuVLIfTVCX45e1hyQqz3978sBp3HNAbl6e9J9LQug4uNKJTNArhVyd43Q8cq4ZoBsTiEA7NvwYU6DxHmf70qhfNjHmgEyPHC7fNaI0MdgrsckyDB+30CdBeC6jzUJcv3Cty0oBKBT37imQBwK2YZGQJ+9NOoUyPX3G+EAzJ/ufxMgV+HiViwLAJ27YfAEiEMhm9EI8NYxrgmQa4VsiAPm8/J5HMTxjsrNWBYA+roY12KQLW
kE/Nk+11GQjSsEoLzN0I+COL7cFgjoqxn/GQNxhItbqrMAAPxeP9oxkK17CABQ8VqVaQTEFb9vDgSUVfMfR+aiON4ccJPJJ0uF0++HsTeWKceL+SJNxVwpTN9etB/E9SqH6POzAoWTFy8Iud6AEHXm3zrxOrvrDdob1QcAP8h1hwqPCeJSyOZaw654NeI4tmUOH4jzxZkPCOL0kE27iAfEeXTTClkCsslWvRVnkV0vyX1IjTyeh7jL7M6/b1shThBXY/iIIG6FxHyXoYQ4QB5SIQ4QT2C/cV93gLjf2L51jmsQj0K2blnXIB4PeTiNuMPF7SvkCsStkMcD8Sjk8UB8HI8G4hgyBPAAle8QxNeD/6eAPABHv+juDtUjuHofxFdlPQTI/wHFcdIicwcncQAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "example[\"label\"].resize((200, 200))" ] }, { "cell_type": "markdown", "metadata": { "id": "vp-fIQzvyTHL" }, "source": [ "Each of the pixels above can be associated to a particular category. Let's load all the categories that are associated with the dataset. Let's also create an `id2label` dictionary to decode them back to strings and see what they are. The inverse `label2id` will be useful too, when we load the model later." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "c71c1a959c43465e9fa63a94e72e0681", "22d86a1d282448a190ff5f81d5086f16", "108d9a14e9bd46088a1b3526e9607c48", "f53b548093704a4483f8825f2c7eb3a6", "2be22416cdaf4ab1a48145dcdbafd57f", "31288c6294644f4c890374a08c7ec1ec", "abd78c14e61843ef89699b2ad68d970c", "3dbd784ed0724b7b8f3bf3b5d3d87919", "3f71c41ea8524623b129bfafddf293db", "90d261d919a54a91bacba98abaf6b35c", "d82c57f3e3484f149b5a8d63c97ac582" ] }, "id": "Op_AGsW0y5C3", "outputId": "59041978-a8fa-4aad-ef94-80013c58c288" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c71c1a959c43465e9fa63a94e72e0681", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/852 [00:00\n", " \n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "QXY3erc6g8LP" }, "source": [ "### Use the model from the Hub\n", "\n", "We'll first load the model from the Hub using `SegformerForSemanticSegmentation.from_pretrained()`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gfbdz9Hpg9mI" }, "outputs": [], "source": [ "from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation\n", "\n", "feature_extractor = SegformerFeatureExtractor.from_pretrained(model_checkpoint)\n", "hf_username = \"segments-tobias\"\n", "model = SegformerForSemanticSegmentation.from_pretrained(f\"{hf_username}/{hub_model_id}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "SQJkEqGxQwz6" }, "source": [ "Next, we'll load an image from our test dataset and its associated ground truth segmentation label." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R57X_iNkqv6H" }, "outputs": [], "source": [ "image = test_ds[0]['pixel_values']\n", "gt_seg = test_ds[0]['label']\n", "image" ] }, { "cell_type": "markdown", "metadata": { "id": "7m7IfMv6R3_5" }, "source": [ "To segment this test image, we first need to prepare the image using the feature extractor. Then we'll forward it through the model.\n", "\n", "We also need to remember to upscale the output logits to the original image size. In order to get the actual category predictions, we just have to apply an `argmax` on the logits." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8nNSSqEUBS2v" }, "outputs": [], "source": [ "from torch import nn\n", "\n", "inputs = feature_extractor(images=image, return_tensors=\"pt\")\n", "outputs = model(**inputs)\n", "logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)\n", "\n", "# First, rescale logits to original image size\n", "upsampled_logits = nn.functional.interpolate(\n", " logits,\n", " size=image.size[::-1], # (height, width)\n", " mode='bilinear',\n", " align_corners=False\n", ")\n", "\n", "# Second, apply argmax on the class dimension\n", "pred_seg = upsampled_logits.argmax(dim=1)[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "oyHddde_SOgv" }, "source": [ "Now it's time to display the result. The next cell defines the colors for each category, so that they match the \"category coloring\" on Segments.ai." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "Ky_8gHCRCJHj" }, "outputs": [], "source": [ "#@title `def sidewalk_palette()`\n", "\n", "def sidewalk_palette():\n", " \"\"\"Sidewalk palette that maps each class to RGB values.\"\"\"\n", " return [\n", " [0, 0, 0],\n", " [216, 82, 24],\n", " [255, 255, 0],\n", " [125, 46, 141],\n", " [118, 171, 47],\n", " [161, 19, 46],\n", " [255, 0, 0],\n", " [0, 128, 128],\n", " [190, 190, 0],\n", " [0, 255, 0],\n", " [0, 0, 255],\n", " [170, 0, 255],\n", " [84, 84, 0],\n", " [84, 170, 0],\n", " [84, 255, 0],\n", " [170, 84, 0],\n", " [170, 170, 0],\n", " [170, 255, 0],\n", " [255, 84, 0],\n", " [255, 170, 0],\n", " [255, 255, 0],\n", " [33, 138, 200],\n", " [0, 170, 127],\n", " [0, 255, 127],\n", " [84, 0, 127],\n", " [84, 84, 127],\n", " [84, 170, 127],\n", " [84, 255, 127],\n", " [170, 0, 127],\n", " [170, 84, 127],\n", " [170, 170, 127],\n", " [170, 255, 127],\n", " [255, 0, 127],\n", " [255, 84, 127],\n", " [255, 170, 127],\n", " ]" ] }, { "cell_type": "markdown", "metadata": { "id": "f4BzL0ISSePY" }, "source": [ "The next function overlays the output segmentation map on the original image." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G3HqZXyQB7gJ" }, "outputs": [], "source": [ "import numpy as np\n", "\n", "def get_seg_overlay(image, seg):\n", " color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3\n", " palette = np.array(sidewalk_palette())\n", " for label, color in enumerate(palette):\n", " color_seg[seg == label, :] = color\n", "\n", " # Show image + mask\n", " img = np.array(image) * 0.5 + color_seg * 0.5\n", " img = img.astype(np.uint8)\n", "\n", " return img" ] }, { "cell_type": "markdown", "metadata": { "id": "-yEXFytLSkht" }, "source": [ "We'll display the result next to the ground-truth mask." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vnSn2A2U0RMw" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "pred_img = get_seg_overlay(image, pred_seg)\n", "gt_img = get_seg_overlay(image, np.array(gt_seg))\n", "\n", "f, axs = plt.subplots(1, 2)\n", "f.set_figheight(30)\n", "f.set_figwidth(50)\n", "\n", "axs[0].set_title(\"Prediction\", {'fontsize': 40})\n", "axs[0].imshow(pred_img)\n", "axs[1].set_title(\"Ground truth\", {'fontsize': 40})\n", "axs[1].imshow(gt_img)" ] }, { "cell_type": "markdown", "metadata": { "id": "r3Chx4bXaCYa" }, "source": [ "What do you think? 
Would you send a pizza delivery robot on the road with this segmentation information?\n", "\n", "The result might not be perfect yet, but we can always expand our dataset to make the model more robust. We can also train a larger SegFormer model and see how it stacks up. If you want to explore further, here are some things you can try next:\n", "\n", "* Train the model for longer. \n", "* Try out the different segmentation-specific training augmentations from libraries like [`albumentations`](https://albumentations.ai/docs/getting_started/mask_augmentation/). \n", "* Try out a larger variant of the SegFormer model family or try an entirely new model family like MobileViT. " ] } ], "metadata": { "accelerator": "GPU", "colab": { "machine_shape": "hm", "provenance": [] }, "gpuClass": "premium", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" } }, "nbformat": 4, "nbformat_minor": 1 }