{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# \"Converting CRAFT to TFLite: A Guide to PyTorch-TFLite Conversion\"\n", "> \"Learn how to convert a PyTorch pretrained model to TFLite format\"\n", "\n", "- toc: true\n", "- branch: master\n", "- badges: true\n", "- comments: true\n", "- hide: false\n", "- categories: [tflite, optimization, onnx, craft, text-detector]\n", "- image: images/flow_resized.png\n", "- author: Tulasi Ram\n", "- permalink : /craft-in-tflite" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an end-to-end tutorial on how to convert a PyTorch model to TensorFlow Lite (TFLite) using ONNX. Specifically, we will be using the CRAFT model (proposed in [this paper](https://arxiv.org/pdf/1904.01941)), which is a text detector. Above is an overview of what’s covered in the tutorial." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# hide\n", "# - [Brief Overview of Craft Model](#Brief-Overview-of-the-CRAFT-Model):\n", "# - [TFLite Conversion Flow](#Conversion-Flow):\n", "# - [PyTorch Model to ONNX Model](#PyTorch-Model-to-ONNX):\n", "# - [ONNX Model to TensorFlow SavedModel](#ONNX-Model-to-TensorFlow-SavedModel):\n", "# - [TensorFlow SavedModel to TFLite Model](#TensorFlow-SavedModel-to-TFLite):\n", "# - Dynamic Range Quantization\n", "# - Float16 Quantization\n", "# - [Running Inference with TFLite Model](#Running-inference-with-TFLite-Models):\n", "# - [Results](#Results)\n", "# - [Conclusion](#CONCLUSION)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Please open the notebooks included in this [repository](https://github.com/tulasiram58827/craft_tflite) and follow along with this blog post." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Important: You can also directly download the already converted TFLite models from this [repository](https://github.com/tulasiram58827/craft_tflite). 
Pre-converted models are also available on [TensorFlow Hub](https://tfhub.dev/tulasiram58827/lite-model/craft-text-detector/dr/1)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Brief Overview of the CRAFT Model\n", "\n", "Character Region Awareness for Text Detection, **CRAFT** for short, was proposed in [this paper](https://arxiv.org/pdf/1904.01941) and is known for its efficiency as well as its precise detection performance.\n", "The main principle of **CRAFT** is to localize the individual character regions and link the detected characters into text instances." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**CRAFT** produces two scores for each character: a *character region score* and an *affinity score*.\n", "\n", "- **The character region score is used to localize individual characters.**\n", "- **The affinity score is used to group the characters into a single instance.**\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](https://lh5.googleusercontent.com/t3J001eXsOz1ZHcRw-csq2veMrsRK_SNJ8xOVmFFOomffcIBbqEZ00oVrGNcbId6Hg2PQRox1SCGCW5IA8T8L9IY6fTcx2ZqGyOt2xe4XSOItzgV5nIT-eNR1MCwUM9Wx4p8m26q)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Like many image detectors, **CRAFT** uses **VGG16** as its feature extractor, and its decoding part is similar to UNet.\n", "\n", "![](https://lh5.googleusercontent.com/eq-ksP-g33SDXlOCNXpgGPwF2iW-03-VAmM-v9iM13mCkBPt15uXSTqLbv_TsXPFFfHVg3jcNgPPxnmcK9G_TQPGFWMWTkBRryMZUaEyRYPe0PJ0rCEgUNBgrdsPFdEcMe7n3_dm)\n", "\n", "\n", "The above diagram is taken from the original [paper](https://arxiv.org/pdf/1904.01941.pdf)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TFLite Conversion Flow\n", "\n", "![](https://paper-attachments.dropbox.com/s_26EA9AD010A0BB211BCF8D0337C8A342D7ED67947FA9F0B435D128AAF3F2C824_1606055863622_flow.png)\n", "\n", "\n", "> Note: Currently, integer quantization errors out, and this has been reported to the 
TensorFlow Lite team.\n", "\n", "*Update from TFLite team: Currently support for NCHW image format(like those converted from PyTorch) is quite limited at this moment, which caused this issue with full integer quantized model.*\n", "\n", "You can find the full reply from the TensorFlow Lite team [here](https://github.com/tulasiram58827/craft_tflite/issues/1#issuecomment-734015155)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The [Clova AI](https://github.com/clovaai/CRAFT-pytorch) team has already provided pre-trained weights that we can use for inference on images. But the framework (PyTorch) is not ideal for mobile applications, for low latency devices like the Raspberry Pi, or for fully integer devices like Google Coral and microcontrollers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TensorFlow Lite is a framework that is well suited for running deep learning models on edge devices and mobile devices. Nowadays, edge devices have become popular mainly for three reasons:\n", "\n", "- Lower latency\n", "- No need for an Internet connection\n", "- Privacy protection\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is why we first convert these pre-trained weights to TFLite, which is more suitable for low latency devices and mobile applications." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PyTorch Model to ONNX Model\n", "\n", "\n", "Refer to this [notebook](https://github.com/tulasiram58827/craft_tflite/blob/main/colabs/pytorch_to_onnx.ipynb) for the complete code for this section.\n", "\n", "\n", "[Open Neural Network Exchange](https://github.com/onnx/onnx), **ONNX** for short, is an open format built to represent machine learning models.\n", "The best thing about ONNX is interoperability. 
You can develop in your preferred framework without worrying about downstream inference applications.\n", "Exporting a model to ONNX format takes the following parameters:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1. Pre-trained model\n", "2. Sample input\n", "3. Path to save the model\n", "4. Input and output node names" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "!pip install onnx\n", "!pip install onnxruntime\n", "!pip install git+https://github.com/onnx/onnx-tensorflow.git\n", "\n", "\n", "import gdown\n", "\n", "import numpy as np\n", "\n", "from pathlib import Path\n", "from datetime import datetime\n", "\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.nn.init as init\n", "import torch.backends.cudnn as cudnn\n", "from torchvision import models\n", "from torchvision.models.vgg import model_urls\n", "from collections import namedtuple\n", "from collections import OrderedDict\n", "import onnx\n", "import onnxruntime\n", "from onnx_tf.backend import prepare" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "def copyStateDict(state_dict):\n", "    # Strip the leading \"module.\" prefix left by DataParallel checkpoints\n", "    if list(state_dict.keys())[0].startswith(\"module\"):\n", "        start_idx = 1\n", "    else:\n", "        start_idx = 0\n", "    new_state_dict = OrderedDict()\n", "    for k, v in state_dict.items():\n", "        name = \".\".join(k.split(\".\")[start_idx:])\n", "        new_state_dict[name] = v\n", "    return new_state_dict\n", "\n", "def init_weights(modules):\n", "    for m in modules:\n", "        if isinstance(m, nn.Conv2d):\n", "            init.xavier_uniform_(m.weight.data)\n", "            if m.bias is not None:\n", "                m.bias.data.zero_()\n", "        elif isinstance(m, nn.BatchNorm2d):\n", "            m.weight.data.fill_(1)\n", "            m.bias.data.zero_()\n", "        elif isinstance(m, nn.Linear):\n", "            
m.weight.data.normal_(0, 0.01)\n", "            m.bias.data.zero_()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "class vgg16_bn(torch.nn.Module):\n", "    def __init__(self, pretrained=True, freeze=True):\n", "        super(vgg16_bn, self).__init__()\n", "        model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')\n", "        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features\n", "        self.slice1 = torch.nn.Sequential()\n", "        self.slice2 = torch.nn.Sequential()\n", "        self.slice3 = torch.nn.Sequential()\n", "        self.slice4 = torch.nn.Sequential()\n", "        self.slice5 = torch.nn.Sequential()\n", "        for x in range(12):         # conv2_2\n", "            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n", "        for x in range(12, 19):     # conv3_3\n", "            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n", "        for x in range(19, 29):     # conv4_3\n", "            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n", "        for x in range(29, 39):     # conv5_3\n", "            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n", "\n", "        # fc6, fc7 without atrous conv\n", "        self.slice5 = torch.nn.Sequential(\n", "            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),\n", "            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),\n", "            nn.Conv2d(1024, 1024, kernel_size=1)\n", "        )\n", "\n", "        if not pretrained:\n", "            init_weights(self.slice1.modules())\n", "            init_weights(self.slice2.modules())\n", "            init_weights(self.slice3.modules())\n", "            init_weights(self.slice4.modules())\n", "\n", "        init_weights(self.slice5.modules())        # no pretrained model for fc6 and fc7\n", "\n", "        if freeze:\n", "            for param in self.slice1.parameters():      # only first conv\n", "                param.requires_grad = False\n", "\n", "    def forward(self, X):\n", "        h = self.slice1(X)\n", "        h_relu2_2 = h\n", "        h = self.slice2(h)\n", "        h_relu3_2 = h\n", "        h = self.slice3(h)\n", "        h_relu4_3 = h\n", "        h = self.slice4(h)\n", "        h_relu5_3 = h\n", "        h = self.slice5(h)\n", "        h_fc7 = 
h\n", "        vgg_outputs = namedtuple(\"VggOutputs\", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])\n", "        out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)\n", "        return out\n", "\n", "class double_conv(nn.Module):\n", "    def __init__(self, in_ch, mid_ch, out_ch):\n", "        super(double_conv, self).__init__()\n", "        self.conv = nn.Sequential(\n", "            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),\n", "            nn.BatchNorm2d(mid_ch),\n", "            nn.ReLU(inplace=True),\n", "            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),\n", "            nn.BatchNorm2d(out_ch),\n", "            nn.ReLU(inplace=True)\n", "        )\n", "\n", "    def forward(self, x):\n", "        x = self.conv(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "class CRAFT(nn.Module):\n", "    def __init__(self, pretrained=False, freeze=False):\n", "        super(CRAFT, self).__init__()\n", "\n", "        \"\"\" Base network \"\"\"\n", "        self.basenet = vgg16_bn(pretrained, freeze)\n", "\n", "        \"\"\" U network \"\"\"\n", "        self.upconv1 = double_conv(1024, 512, 256)\n", "        self.upconv2 = double_conv(512, 256, 128)\n", "        self.upconv3 = double_conv(256, 128, 64)\n", "        self.upconv4 = double_conv(128, 64, 32)\n", "\n", "        num_class = 2\n", "        self.conv_cls = nn.Sequential(\n", "            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),\n", "            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),\n", "            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),\n", "            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),\n", "            nn.Conv2d(16, num_class, kernel_size=1),\n", "        )\n", "\n", "        init_weights(self.upconv1.modules())\n", "        init_weights(self.upconv2.modules())\n", "        init_weights(self.upconv3.modules())\n", "        init_weights(self.upconv4.modules())\n", "        init_weights(self.conv_cls.modules())\n", "\n", "    def forward(self, x):\n", "        \"\"\" Base network \"\"\"\n", "        sources = self.basenet(x)\n", "\n", "        \"\"\" U network \"\"\"\n", "        y = 
torch.cat([sources[0], sources[1]], dim=1)\n", "        y = self.upconv1(y)\n", "\n", "        y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)\n", "        y = torch.cat([y, sources[2]], dim=1)\n", "        y = self.upconv2(y)\n", "\n", "        y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)\n", "        y = torch.cat([y, sources[3]], dim=1)\n", "        y = self.upconv3(y)\n", "\n", "        y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)\n", "        y = torch.cat([y, sources[4]], dim=1)\n", "        feature = self.upconv4(y)\n", "\n", "        y = self.conv_cls(feature)\n", "\n", "        return y.permute(0,2,3,1), feature" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# hide\n", "# Link to the pretrained model.\n", "# https://drive.google.com/uc?export=download&id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "pytorch_model = CRAFT()\n", "pytorch_model.load_state_dict(copyStateDict(torch.load('/home/ram/Projects/OCR/craft_tflite/models/craft_mlt_25k.pth', map_location='cpu')))\n", "#net.load_state_dict(copyStateDict(torch.load('.EasyOCR/model/craft_mlt_25k.pth', map_location='cuda')))\n", "#net = torch.nn.DataParallel(net).to('cuda')\n", "#cudnn.benchmark = False\n", "pytorch_model.eval()\n", "print(\"Model loaded\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "batch_size = 1\n", "# Input to the model\n", "x = torch.randn(batch_size, 3, 224, 224, requires_grad=True)\n", "onnx_runtime_input = x.detach().numpy()\n", "t1 = datetime.now()\n", "torch_out = pytorch_model(x)\n", "t2 = datetime.now()\n", "print(\"Time taken for PyTorch model\", str(t2-t1))\n", "pytorch_output = torch_out[0].detach().numpy()\n", "print(\"Output size\", torch_out[0].size())\n", "print(\"Model ran successfully\")" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "# The input is NCHW; the first output is NHWC because of the final permute.\n", "shape_dict = {'input': {0: 'batch_size',\n", "                        2: 'height',\n", "                        3: 'width'},  # variable length axes\n", "              'output': {0: 'batch_size',\n", "                         1: 'height',\n", "                         2: 'width'}}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.onnx.export(pytorch_model,\n", "                  x,\n", "                  'craft.onnx',\n", "                  opset_version=10,\n", "                  do_constant_folding=True,\n", "                  input_names=['input'],\n", "                  output_names=['output'],\n", "                  dynamic_axes=shape_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are some details about the above code snippet:\n", "\n", "- The `export()` function executes the model and records a trace of the operators that are used to compute the output.\n", "\n", "- To execute the model we need to provide an *input*. Its values can be random as long as the type and dimensions match, because the export function only runs the model to trace the operators used to compute the output.\n", "\n", "- The exported ONNX model has fixed dimensions unless axes are listed in the `dynamic_axes` parameter. In the above code we marked the batch size, width and height of the image as dynamic; the channel axis, which is not listed in `dynamic_axes`, is fixed according to the input dimensions.\n", "\n", "- To visualize the exported ONNX model you can use this [tool](https://netron.app/)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once the model is exported, load it, verify its structure and confirm that it has a valid schema.\n", "\n", "The below code snippet checks whether the exported ONNX model has a valid schema." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "onnx_model = onnx.load('craft.onnx')\n", "onnx.checker.check_model(onnx_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Expected Output:\n", " \n", "Raises `onnx.checker.ValidationError` if the model is not valid. If it is valid, there is no output." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compare ONNX Output with PyTorch Model Output\n", "\n", "\n", "To check whether the exported ONNX model is faulty, follow these steps:\n", "\n", "- Create a sample input\n", "- Run the pre-trained PyTorch model and save the output\n", "- Run the exported ONNX model and save the output\n", "- Compare the PyTorch output with the ONNX model output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Below is the code snippet required to implement the above steps:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ort_session = onnxruntime.InferenceSession('craft.onnx')\n", "ort_inputs = {ort_session.get_inputs()[0].name: onnx_runtime_input}\n", "ort_outs = ort_session.run(None, ort_inputs)\n", "np.testing.assert_allclose(pytorch_output, ort_outs[0], rtol=1e-03, atol=1e-05)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above code snippet compares the PyTorch model output with the ONNX model output and raises an error if they do not match within the given tolerance.\n", "\n", "It checks that the absolute difference between the PyTorch output and the ONNX output is at most \n", " `atol+rtol*abs(onnx output)`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can refer to this [documentation](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html) for more details about this function.\n", "If the ONNX conversion were faulty, the assertion would have failed.\n", "\n", "\n", "Great! 
We have converted the model to **ONNX**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ONNX Model to TensorFlow SavedModel\n", "\n", "Refer to this [notebook](https://github.com/tulasiram58827/craft_tflite/blob/main/colabs/onnx_to_tflite.ipynb) for the complete code for this section.\n", "\n", "\n", "As mentioned earlier, the best feature of ONNX is interoperability. Once we have the ONNX model, we can convert it to other popular frameworks fairly easily." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let’s see how to convert the ONNX model to a TensorFlow SavedModel." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import onnx\n", "from onnx_tf.backend import prepare\n", "onnx_model = onnx.load('craft.onnx')\n", "tf_rep = prepare(onnx_model)\n", "tf_rep.export_graph('craft_tf_graph')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After exporting to a TensorFlow graph we can inspect it using the same [tool](https://netron.app/)\n", "that we used to visualize the ONNX model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Warning: Please refer to the installation [instructions](https://github.com/tulasiram58827/craft_tflite/blob/main/README.md#installation) for the valid `onnx` and `onnx_tf` versions." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A SavedModel contains all the information about the TensorFlow program, including its weights and computation. 
As we don’t require any extra code to build the model, it is very easy to share or deploy TensorFlow SavedModels.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The file structure of the SavedModel `craft_tf_graph` will be as follows:\n", "```\n", "craft_tf_graph\n", "  |---- saved_model.pb\n", "  |---- assets\n", "  |---- variables\n", "        |---- variables.data-00000-of-00001\n", "        |---- variables.index\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `saved_model.pb` file contains the actual model and a set of named signatures, each identifying a function that accepts input tensors and produces output tensors.\n", "\n", "The `variables` directory contains a standard checkpoint, and the `assets` directory contains files used by the TensorFlow graph. The `assets` directory is unused in this example because the SavedModel needs no extra files.\n", "\n", "To know more about TensorFlow SavedModel please refer to this [guide](https://www.tensorflow.org/guide/saved_model)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can load the SavedModel as if it were a Keras SavedModel. Below is the code\n", "snippet to load it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "model = tf.keras.models.load_model('craft_tf_graph')\n", "# or model = tf.saved_model.load('craft_tf_graph')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can convert the SavedModel to a TFLite model directly. To change an input dimension, however, load the concrete function from the SavedModel and set the shape on it.\n", "\n", "Below is the code snippet to set the input shape required for the TFLite format." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n", "concrete_func.inputs[0].set_shape([None, 3, 800, 600])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### TensorFlow SavedModel to TFLite\n", "\n", "A TensorFlow model can be converted into a TensorFlow Lite model in 3 ways:\n", "\n", "- From a SavedModel\n", "- From a Keras model\n", "- From a concrete function\n", "\n", "You can refer to this [guide](https://www.tensorflow.org/lite/performance/post_training_quantization) for various conversion techniques. We will convert to TFLite from the concrete function.\n", "\n", "Below is the code snippet to load the concrete function into the TFLiteConverter." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "While converting to TFLite we can choose among several quantization methods. Refer to this [guide](https://www.tensorflow.org/lite/performance/post_training_quantization) for the various post-training quantization techniques. We cover two of them here:\n", "\n", "- Dynamic Range Quantization\n", "- Float16 Quantization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Dynamic Range Quantization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default optimization is **Dynamic Range** quantization." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "converter.optimizations = [tf.lite.Optimize.DEFAULT]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Float16 Quantization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For **Float16**, everything else stays the same; we just need to add this line:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "converter.target_spec.supported_types = [tf.float16]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Convert and Store the Model:**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tf_lite_model = converter.convert()\n", "# Use a context manager so the file is flushed and closed\n", "with open('craft.tflite', 'wb') as f:\n", "    f.write(tf_lite_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "| **Quantization Type** | **Model Size** |\n", "| --------------------- | -------------- |\n", "| Dynamic Range | 20MB |\n", "| Float16 | 40MB |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The original PyTorch model is around **80MB**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Running inference with TFLite Models\n", "\n", "Refer to this [notebook](https://github.com/tulasiram58827/craft_tflite/blob/main/colabs/tflite_inference.ipynb) for the complete code for this section.\n", "\n", "Once the TFLite models are generated we need to make sure they work as expected. So let’s run inference on a real image and check the output.\n", "\n", "Run the preprocessing steps mentioned in this [notebook](https://github.com/tulasiram58827/craft_tflite/blob/main/colabs/tflite_inference.ipynb) before feeding the image to the TFLite model.\n", "\n", "Below is the code snippet to run inference with the TFLite model." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "# input_data is the preprocessed image (see the preprocessing steps in the linked notebook)\n", "interpreter = tf.lite.Interpreter(model_path='craft.tflite')\n", "interpreter.allocate_tensors()\n", "input_details = interpreter.get_input_details()\n", "output_details = interpreter.get_output_details()\n", "input_shape = input_details[0]['shape']\n", "interpreter.set_tensor(input_details[0]['index'], input_data)\n", "interpreter.invoke()\n", "y = interpreter.get_tensor(output_details[0]['index'])\n", "feature = interpreter.get_tensor(output_details[1]['index'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After the post-processing steps mentioned in this [notebook](https://github.com/tulasiram58827/craft_tflite/blob/main/colabs/tflite_inference.ipynb), the output image (with the dynamic range quantized model) looks like the one shown below, alongside the output of the original model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " **Output with Dynamic Range Quantized Model:** \n", "\n", "![](https://paper-attachments.dropbox.com/s_26EA9AD010A0BB211BCF8D0337C8A342D7ED67947FA9F0B435D128AAF3F2C824_1606194525784_CRAFT_TFLITE_CONVERSION.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Output with Float16 Quantized Model:**\n", "\n", "![](https://paper-attachments.dropbox.com/s_26EA9AD010A0BB211BCF8D0337C8A342D7ED67947FA9F0B435D128AAF3F2C824_1606194562401_CRAFT_TFLITE_CONVERSION+1.jpg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The results of the Float16 quantized model are visibly better than those of the Dynamic Range quantized model, but at the cost of model size." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "In this post we have covered all the steps required to convert a PyTorch pre-trained model to TFLite format. 
If you want a single notebook that covers all of the mentioned steps, you can use this [notebook](https://github.com/tulasiram58827/craft_tflite/blob/main/colabs/CRAFT_TFLITE.ipynb).\n", "\n", "\n", "Wondering how the CRAFT model performs on a mobile device? Refer to this [blog post](https://sayak.dev/optimizing-text-detectors) that compares the CRAFT model with the [EAST model](https://arxiv.org/abs/1704.03155) w.r.t. many useful metrics such as memory, inference latency, performance and so on." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Acknowledgments:**\n", "\n", "*Thanks to [Sayak Paul](https://twitter.com/RisingSayak) and [Le Viet Gia Khanh](https://twitter.com/khanhlvg) (from the TFLite team) for their constant guidance.*" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }