{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Inference with your model\n", "\n", "This is the third and final tutorial of our [beginner tutorial series](https://github.com/deepjavalibrary/djl/tree/master/jupyter/tutorial) that will take you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to execute your image classification model for a production system.\n", "\n", "In the [previous tutorial](02_train_your_first_model.ipynb), you successfully trained your model. Now, we will learn how to implement a `Translator` to convert between POJO and `NDArray` as well as a `Predictor` to run inference.\n", "\n", "\n", "## Preparation\n", "\n", "This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the [Jupyter README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// Add the snapshot repository to get the DJL snapshot artifacts\n", "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", "\n", "// Add the maven dependencies\n", "%maven ai.djl:api:0.16.0\n", "%maven ai.djl:model-zoo:0.16.0\n", "%maven ai.djl.mxnet:mxnet-engine:0.16.0\n", "%maven ai.djl.mxnet:mxnet-model-zoo:0.16.0\n", "%maven org.slf4j:slf4j-simple:1.7.32" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import java.awt.image.*;\n", "import java.nio.file.*;\n", "import java.util.*;\n", "import java.util.stream.*;\n", "import ai.djl.*;\n", "import ai.djl.basicmodelzoo.basic.*;\n", "import ai.djl.ndarray.*;\n", "import ai.djl.modality.*;\n", "import ai.djl.modality.cv.*;\n", "import ai.djl.modality.cv.util.NDImageUtils;\n", "import ai.djl.translate.*;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Load your handwritten digit image\n", "\n", "We will start by loading 
the image that we want our model to classify." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var img = ImageFactory.getInstance().fromUrl(\"https://resources.djl.ai/images/0.png\");\n", "img.getWrappedImage();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2: Load your model\n", "\n", "Next, we need to load the model that we will run inference with. This model should have been saved to the `build/mlp` directory when running the [previous tutorial](02_train_your_first_model.ipynb)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Path modelDir = Paths.get(\"build/mlp\");\n", "Model model = Model.newInstance(\"mlp\");\n", "model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));\n", "model.load(modelDir);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to loading a local model, you can also find pretrained models within our [model zoo](http://docs.djl.ai/docs/model-zoo.html). See more options in our [model loading documentation](http://docs.djl.ai/docs/load_model.html).\n", "\n", "## Step 3: Create a `Translator`\n", "\n", "The [`Translator`](https://javadoc.io/static/ai.djl/api/0.16.0/index.html?ai/djl/translate/Translator.html) is used to encapsulate the pre-processing and post-processing functionality of your application. The inputs to `processInput` and `processOutput` should be single data items, not batches."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Translator<Image, Classifications> translator = new Translator<Image, Classifications>() {\n", "\n", "    @Override\n", "    public NDList processInput(TranslatorContext ctx, Image input) {\n", "        // Convert Image to NDArray\n", "        NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);\n", "        return new NDList(NDImageUtils.toTensor(array));\n", "    }\n", "\n", "    @Override\n", "    public Classifications processOutput(TranslatorContext ctx, NDList list) {\n", "        // Create a Classifications with the output probabilities\n", "        NDArray probabilities = list.singletonOrThrow().softmax(0);\n", "        List<String> classNames = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());\n", "        return new Classifications(classNames, probabilities);\n", "    }\n", "\n", "    @Override\n", "    public Batchifier getBatchifier() {\n", "        // The Batchifier describes how to combine a batch together\n", "        // Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array\n", "        return Batchifier.STACK;\n", "    }\n", "};" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4: Create a `Predictor`\n", "\n", "Using the translator, we will create a new [`Predictor`](https://javadoc.io/static/ai.djl/api/0.16.0/index.html?ai/djl/inference/Predictor.html). The predictor is the main class to orchestrate the inference process. During inference, a trained model is used to predict values, often for production use cases. The predictor is NOT thread-safe, so if you want to do prediction in parallel, you should call `newPredictor` multiple times to create a separate predictor object for each thread."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var predictor = model.newPredictor(translator);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5: Run inference\n", "\n", "With our predictor, we can simply call the [`predict`](https://javadoc.io/static/ai.djl/api/0.16.0/ai/djl/inference/Predictor.html#predict-I-) method to run inference. For better performance, you can also call [`batchPredict`](https://javadoc.io/static/ai.djl/api/0.16.0/ai/djl/inference/Predictor.html#batchPredict-java.util.List-) with a list of input items. Afterwards, reuse the same predictor for further inference calls." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var classifications = predictor.predict(img);\n", "\n", "classifications" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "Now, you've successfully built a model, trained it, and run inference. Congratulations on finishing the [beginner tutorial series](https://github.com/deepjavalibrary/djl/tree/master/jupyter/tutorial). After this, you should read our other [examples](https://github.com/deepjavalibrary/djl/tree/master/examples) and [Jupyter notebooks](https://github.com/deepjavalibrary/djl/tree/master/jupyter) to learn more about DJL.\n", "\n", "You can find the complete source code for this tutorial in the [examples project](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/ImageClassification.java)." ] } ], "metadata": { "kernelspec": { "display_name": "Java", "language": "java", "name": "java" }, "language_info": { "codemirror_mode": "java", "file_extension": ".jshell", "mimetype": "text/x-java-source", "name": "Java", "pygments_lexer": "java", "version": "14.0.2+12" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 2 }