{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Inference with your model\n", "\n", "In a [previous tutorial](train_your_first_model.ipynb), you successfully trained your model.\n", "In this tutorial, you will use that model to classify an image of a handwritten digit. You will also learn how to implement the `Translator` interface, which converts between Java objects (POJOs) and `NDArray`s.\n", "\n", "\n", "## Preparation\n", "\n", "This tutorial requires the Java kernel for Jupyter. To install it, see the [README](https://github.com/awslabs/djl/blob/master/jupyter/README.md)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", "\n", "%maven ai.djl:api:0.2.1\n", "%maven ai.djl:repository:0.2.1\n", "%maven ai.djl:model-zoo:0.2.1\n", "%maven ai.djl.mxnet:mxnet-engine:0.2.1\n", "%maven ai.djl.mxnet:mxnet-model-zoo:0.2.1\n", "%maven org.slf4j:slf4j-api:1.7.26\n", "%maven org.slf4j:slf4j-simple:1.7.26\n", "%maven net.java.dev.jna:jna:5.3.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Include MXNet engine dependency\n", "\n", "This tutorial uses the MXNet engine as its backend. MXNet comes in several [build flavors](https://mxnet.apache.org/get_started?version=v1.5.1&platform=linux&language=python&environ=pip&processor=cpu), and each one is platform specific.\n", "See [engine selection](https://github.com/awslabs/djl/blob/master/examples/README.md#engine-selection) for how to choose an MXNet engine flavor." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "String osName = System.getProperty(\"os.name\");\n", "String classifier = osName.startsWith(\"Mac\") ? \"osx-x86_64\" : osName.startsWith(\"Win\") ? 
\"win-x86_64\" : \"linux-x86_64\";\n", "\n", "%maven ai.djl.mxnet:mxnet-native-mkl:jar:${classifier}:1.6.0-b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import java.awt.image.*;\n", "import java.nio.file.*;\n", "import java.util.*;\n", "import java.util.stream.*;\n", "import ai.djl.*;\n", "import ai.djl.inference.*;\n", "import ai.djl.ndarray.*;\n", "import ai.djl.ndarray.index.*;\n", "import ai.djl.modality.*;\n", "import ai.djl.modality.cv.*;\n", "import ai.djl.modality.cv.util.*;\n", "import ai.djl.modality.cv.util.NDImageUtils.Flag;\n", "import ai.djl.mxnet.zoo.*;\n", "import ai.djl.translate.*;\n", "import ai.djl.util.*;\n", "import ai.djl.zoo.cv.classification.*;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Load your handwritten digit image" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var img = BufferedImageUtils.fromUrl(\"https://djl-ai.s3.amazonaws.com/resources/images/0.png\");\n", "img" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2: Load your model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Path modelDir = Paths.get(\"build/mlp\");\n", "Model model = Model.newInstance();\n", "model.setBlock(new Mlp(28, 28));\n", "model.load(modelDir);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3: Create a `Translator`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Translator<BufferedImage, Classifications> translator = new Translator<BufferedImage, Classifications>() {\n", "\n", "    @Override\n", "    public NDList processInput(TranslatorContext ctx, BufferedImage input) {\n", "        // Convert BufferedImage to NDArray\n", "        NDArray array = BufferedImageUtils.toNDArray(ctx.getNDManager(), input, NDImageUtils.Flag.GRAYSCALE);\n", "        return new NDList(NDImageUtils.toTensor(array));\n", "    }\n", "\n", "    @Override\n", "    public Classifications 
processOutput(TranslatorContext ctx, NDList list) {\n", "        // Apply softmax to turn the raw model output into probabilities\n", "        NDArray probabilities = list.singletonOrThrow().softmax(0);\n", "        // The class labels are the digits 0 through 9\n", "        List<String> indices = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());\n", "        return new Classifications(indices, probabilities);\n", "    }\n", "};\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4: Run inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var predictor = model.newPredictor(translator);\n", "var classifications = predictor.predict(img);\n", "\n", "classifications" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "\n", "You can find the full source in the [examples project](https://github.com/awslabs/djl/blob/master/examples/docs/image_classification.md)." ] } ], "metadata": { "kernelspec": { "display_name": "Java", "language": "java", "name": "java" }, "language_info": { "codemirror_mode": "java", "file_extension": ".jshell", "mimetype": "text/x-java-source", "name": "Java", "pygments_lexer": "java", "version": "12.0.2+10" } }, "nbformat": 4, "nbformat_minor": 2 }