{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Load MXNet model\n", "\n", "In this tutorial, you learn how to load an existing MXNet model and use it to run a prediction task.\n", "\n", "\n", "## Preparation\n", "\n", "This tutorial requires the Java Kernel. For installation instructions, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", "\n", "%maven ai.djl:api:0.15.0\n", "%maven ai.djl:model-zoo:0.15.0\n", "%maven ai.djl.mxnet:mxnet-engine:0.15.0\n", "%maven ai.djl.mxnet:mxnet-model-zoo:0.15.0\n", "%maven org.slf4j:slf4j-simple:1.7.32" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import java.awt.image.*;\n", "import java.nio.file.*;\n", "import ai.djl.*;\n", "import ai.djl.inference.*;\n", "import ai.djl.ndarray.*;\n", "import ai.djl.modality.*;\n", "import ai.djl.modality.cv.*;\n", "import ai.djl.modality.cv.util.*;\n", "import ai.djl.modality.cv.transform.*;\n", "import ai.djl.modality.cv.translator.*;\n", "import ai.djl.translate.*;\n", "import ai.djl.training.util.*;\n", "import ai.djl.util.*;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Prepare your MXNet model\n", "\n", "This tutorial assumes that you have an MXNet model trained using Python. An MXNet symbolic model usually contains the following files:\n", "* Symbol file: {MODEL_NAME}-symbol.json - a JSON file that describes the network architecture of the model\n", "* Parameters file: {MODEL_NAME}-{EPOCH}.params - a binary file that stores the model's weights and biases\n", "* Synset file: synset.txt - an optional text file that stores the classification class labels\n", "\n", "This tutorial uses a pre-trained MXNet `resnet18_v1` model." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use `DownloadUtils` to download the model files from the internet." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-symbol.json\", \"build/resnet/resnet18_v1-symbol.json\", new ProgressBar());\n", "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-0000.params.gz\", \"build/resnet/resnet18_v1-0000.params\", new ProgressBar());\n", "DownloadUtils.download(\"https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset.txt\", \"build/resnet/synset.txt\", new ProgressBar());\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2: Load your model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Path modelDir = Paths.get(\"build/resnet\");\n", "Model model = Model.newInstance(\"resnet\");\n", "model.load(modelDir, \"resnet18_v1\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3: Create a `Translator`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Pipeline pipeline = new Pipeline();\n", "pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());\n", "Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()\n", "    .setPipeline(pipeline)\n", "    .optSynsetArtifactName(\"synset.txt\")\n", "    .optApplySoftmax(true)\n", "    .build();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4: Load image for classification" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/kitten.jpg\");\n", "img.getWrappedImage()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5: Run 
inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Predictor<Image, Classifications> predictor = model.newPredictor(translator);\n", "Classifications classifications = predictor.predict(img);\n", "\n", "classifications" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "Now, you can load any MXNet symbolic model and run inference.\n", "\n", "You might also want to check out [load_pytorch_model.ipynb](https://github.com/deepjavalibrary/djl/blob/master/jupyter/load_pytorch_model.ipynb), which demonstrates loading a local model using the ModelZoo API." ] } ], "metadata": { "kernelspec": { "display_name": "Java", "language": "java", "name": "java" }, "language_info": { "codemirror_mode": "java", "file_extension": ".jshell", "mimetype": "text/x-java-source", "name": "Java", "pygments_lexer": "java", "version": "14.0.2+12" } }, "nbformat": 4, "nbformat_minor": 2 }