{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# DJL BERT Inference Demo\n", "\n", "## Introduction\n", "\n", "In this tutorial, you walk through running inference using DJL on a [BERT](https://towardsdatascience.com/bert-explained-state-of-the-art-language-model-for-nlp-f8b21a9b6270) QA model trained with MXNet. \n", "You provide the model with a question and a paragraph that contains the answer. The model then finds the best answer within that paragraph.\n", "\n", "Example:\n", "```text\n", "Q: When did BBC Japan start broadcasting?\n", "```\n", "\n", "Answer paragraph:\n", "```text\n", "BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006.\n", "It ceased operations after its Japanese distributor folded.\n", "```\n", "The model picks the right answer:\n", "```text\n", "A: December 2004\n", "```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparation\n", "\n", "This tutorial requires the Java kernel for Jupyter. To install it, see the [README](https://github.com/awslabs/djl/blob/v0.3.0/jupyter/README.md)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%maven ai.djl:api:0.3.0\n", "%maven ai.djl.mxnet:mxnet-engine:0.3.0\n", "%maven ai.djl:repository:0.3.0\n", "%maven ai.djl.mxnet:mxnet-model-zoo:0.3.0\n", "%maven org.slf4j:slf4j-api:1.7.26\n", "%maven org.slf4j:slf4j-simple:1.7.26\n", "%maven net.java.dev.jna:jna:5.3.0\n", " \n", "// See https://github.com/awslabs/djl/blob/v0.3.0/mxnet/mxnet-engine/README.md\n", "// for more MXNet library selection options\n", "%maven ai.djl.mxnet:mxnet-native-auto:1.6.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import Java packages by running the following:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import java.io.*;\n", "import java.nio.charset.*;\n", "import java.nio.file.*;\n", "import java.util.*;\n", "import com.google.gson.*;\n", "import com.google.gson.annotations.*;\n", "import ai.djl.*;\n", "import ai.djl.inference.*;\n", "import ai.djl.metric.*;\n", "import ai.djl.mxnet.zoo.*;\n", "import ai.djl.mxnet.zoo.nlp.qa.*;\n", "import ai.djl.repository.zoo.*;\n", "import ai.djl.ndarray.*;\n", "import ai.djl.ndarray.types.*;\n", "import ai.djl.training.util.*;\n", "import ai.djl.translate.*;\n", "import ai.djl.util.*;\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that all of the prerequisites are complete, start writing code to run inference with this example.\n", "\n", "## Load the model and input\n", "\n", "Preparing the model input involves the following four items:\n", "\n", "- word indices: The index of each word in a sentence\n", "- word types: The type index of each word. Question tokens are labelled 0 and answer tokens are labelled 1.\n", "- sequence length: You need to limit the length of the input. 
In this case, the length is 384.\n", "- valid length: The actual length of the question and answer tokens\n", "\n", "**First, load the input**\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var question = \"When did BBC Japan start broadcasting?\";\n", "var resourceDocument = \"BBC Japan was a general entertainment Channel.\\n\" +\n", "    \"Which operated between December 2004 and April 2006.\\n\" +\n", "    \"It ceased operations after its Japanese distributor folded.\";\n", "\n", "QAInput input = new QAInput(question, resourceDocument, 384);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then load the model and vocabulary. Create a variable `model` by using the `ModelZoo` as shown in the following code. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Criteria<QAInput, String> criteria = Criteria.builder()\n", "        .optApplication(Application.NLP.QUESTION_ANSWER)\n", "        .setTypes(QAInput.class, String.class)\n", "        .optFilter(\"backbone\", \"bert\")\n", "        .optFilter(\"dataset\", \"book_corpus_wiki_en_uncased\")\n", "        .optProgress(new ProgressBar()).build();\n", "ZooModel<QAInput, String> model = ModelZoo.loadModel(criteria);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run inference\n", "Once the model is loaded, you can create a `Predictor` and run inference as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Predictor<QAInput, String> predictor = model.newPredictor();\n", "String answer = predictor.predict(input);\n", "answer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running inference on DJL is that easy. In the example, you use a model from the `ModelZoo`. However, you can also load the model on your own and use custom classes as the input and output. The process for that is illustrated in greater detail later in this tutorial. 
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dive deep into Translator\n", "\n", "Inference in deep learning is the process of predicting the output for a given input based on a pre-defined model. \n", "DJL abstracts away the whole process for ease of use. It can load the model, perform inference on the input, and provide \n", "output. DJL also allows you to provide user-defined inputs. The workflow looks like the following:\n", "\n", "![image](https://github.com/awslabs/djl/blob/v0.3.0/examples/docs/img/workFlow.png?raw=true)\n", "\n", "The red block (\"Images\") in the workflow is the input that DJL expects from you. The green block (\"Images \n", "bounding box\") is the output that you expect. Because DJL does not know which input to expect or which output format you prefer, DJL provides the `Translator` interface so you can define your own \n", "input and output. \n", "\n", "The `Translator` interface encompasses the two white blocks: Pre-processing and Post-processing. The pre-processing \n", "component converts the user-defined input objects into an NDList, so that the `Predictor` in DJL can understand the \n", "input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the \n", "`Predictor`. The post-processing block allows you to convert the output from the `Predictor` to the desired output \n", "format. \n", "\n", "### Pre-processing\n", "\n", "Now, you need to convert the sentences into tokens. You can use `BertDataParser.tokenizer` to convert the question and answer into tokens. Then, use `BertDataParser.formTokens` to create BERT-formatted tokens. Once you have properly formatted tokens, use `parser.token2idx` to create the indices. \n", "\n", "The following code block converts the question and answer defined earlier into BERT-formatted tokens and creates word types for the tokens. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// Create token lists for question and answer\n", "List<String> tokenQ = BertDataParser.tokenizer(question.toLowerCase());\n", "List<String> tokenA = BertDataParser.tokenizer(resourceDocument.toLowerCase());\n", "int validLength = tokenQ.size() + tokenA.size();\n", "System.out.println(\"Question Token: \" + tokenQ);\n", "System.out.println(\"Answer Token: \" + tokenA);\n", "System.out.println(\"Valid length: \" + validLength);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Normally, words and sentences are represented as indices instead of Strings for training. They typically work like a vector in an n-dimensional space. In this case, you need to map the tokens into indices. `formTokens` also pads the sentence to the required length." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// Create BERT-formatted tokens\n", "List<String> tokens = BertDataParser.formTokens(tokenQ, tokenA, 384);\n", "// Convert tokens into indices in the vocabulary\n", "BertDataParser parser = model.getArtifact(\"vocab.json\", BertDataParser::parse);\n", "List<Integer> indices = parser.token2idx(tokens);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, the model needs to understand which part is the Question and which part is the Answer. 
Mask the tokens as follows:\n", "```\n", "[Question tokens...AnswerTokens...padding tokens] => [000000...11111....0000]\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// Get token types\n", "List<Float> tokenTypes = BertDataParser.getTokenTypes(tokenQ, tokenA, 384);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To properly convert them into `float[]` for `NDArray` creation, here is the helper function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "/**\n", " * Converts a list of numbers to a float array.\n", " *\n", " * @param list the list to be converted\n", " * @return float array\n", " */\n", "public static float[] toFloatArray(List<? extends Number> list) {\n", "    float[] ret = new float[list.size()];\n", "    int idx = 0;\n", "    for (Number n : list) {\n", "        ret[idx++] = n.floatValue();\n", "    }\n", "    return ret;\n", "}\n", "\n", "float[] indicesFloat = toFloatArray(indices);\n", "float[] types = toFloatArray(tokenTypes);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that you have everything you need, you can create an NDList and populate all of the inputs you formatted earlier. You're done with pre-processing! \n", "\n", "#### Construct `Translator`\n", "\n", "You need to do this processing within an implementation of the `Translator` interface. `Translator` is designed to do pre-processing and post-processing. You must define the input and output objects. It contains the following two methods to override:\n", "- `public NDList processInput(TranslatorContext ctx, I input)`\n", "- `public O processOutput(TranslatorContext ctx, NDList list)`\n", "\n", "Every translator takes in input and returns output in the form of generic objects. In this case, the translator takes input in the form of `QAInput` (I) and returns output as a `String` (O). `QAInput` is just an object that holds the question and the answer paragraph; the input class is already prepared for you." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Armed with the needed knowledge, you can write an implementation of the `Translator` interface. `BertTranslator` uses the code snippets explained previously to implement the `processInput` method. To create the `NDArray` inputs, use the `create` methods of `NDManager`. For more information, see [`NDManager`](https://javadoc.djl.ai/api/0.3.0/ai/djl/ndarray/NDManager.html).\n", "\n", "```\n", "manager.create(Number[] data, Shape)\n", "manager.create(Number[] data)\n", "```\n", "\n", "The `Shape` for `data0` and `data1` is (num_of_batches, sequence_length). The `Shape` for `data2` is just (1)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "public class BertTranslator implements Translator<QAInput, String> {\n", "    private BertDataParser parser;\n", "    private List<String> tokens;\n", "    private int seqLength;\n", "\n", "    BertTranslator(BertDataParser parser) {\n", "        this.parser = parser;\n", "        this.seqLength = 384;\n", "    }\n", "\n", "    @Override\n", "    public Batchifier getBatchifier() {\n", "        return null;\n", "    }\n", "\n", "    @Override\n", "    public NDList processInput(TranslatorContext ctx, QAInput input) {\n", "        // Pre-processing - tokenize sentence\n", "        // Create token lists for the question and answer held by the input\n", "        List<String> tokenQ = BertDataParser.tokenizer(input.getQuestion().toLowerCase());\n", "        List<String> tokenA = BertDataParser.tokenizer(input.getAnswer().toLowerCase());\n", "\n", "        // Calculate valid length (length(Question tokens) + length(Answer tokens))\n", "        var validLength = tokenQ.size() + tokenA.size();\n", "\n", "        // Create BERT-formatted tokens, padded to the sequence length\n", "        tokens = BertDataParser.formTokens(tokenQ, tokenA, seqLength);\n", "\n", "        if (tokens == null) {\n", "            throw new IllegalStateException(\"tokens is not defined\");\n", "        }\n", "\n", "        // Convert tokens into indices in the vocabulary\n", "        List<Integer> indices = parser.token2idx(tokens);\n", "        // Get token types\n", "        List<Float>
tokenTypes = BertDataParser.getTokenTypes(tokenQ, tokenA, seqLength);\n", "\n", "        NDManager manager = ctx.getNDManager();\n", "\n", "        // Using the manager, create NDArrays for the indices, types, and valid length,\n", "        // in that order. All of the NDArrays should be of type float.\n", "        NDArray indicesNd = manager.create(toFloatArray(indices), new Shape(1, seqLength));\n", "        indicesNd.setName(\"data0\");\n", "        NDArray typesNd = manager.create(toFloatArray(tokenTypes), new Shape(1, seqLength));\n", "        typesNd.setName(\"data1\");\n", "        NDArray validLengthNd = manager.create(new float[]{validLength});\n", "        validLengthNd.setName(\"data2\");\n", "\n", "        NDList list = new NDList(3);\n", "        list.add(indicesNd);\n", "        list.add(typesNd);\n", "        list.add(validLengthNd);\n", "\n", "        return list;\n", "    }\n", "\n", "    @Override\n", "    public String processOutput(TranslatorContext ctx, NDList list) {\n", "        NDArray array = list.singletonOrThrow();\n", "        NDList output = array.split(2, 2);\n", "        // Get the formatted logits result\n", "        NDArray startLogits = output.get(0).reshape(new Shape(1, -1));\n", "        NDArray endLogits = output.get(1).reshape(new Shape(1, -1));\n", "        // Get the probability distributions\n", "        NDArray startProb = startLogits.softmax(-1);\n", "        NDArray endProb = endLogits.softmax(-1);\n", "        // The answer spans the most likely start token through the most likely end token\n", "        int startIdx = (int) startProb.argMax(1).getLong();\n", "        int endIdx = (int) endProb.argMax(1).getLong();\n", "        return tokens.subList(startIdx, endIdx + 1).toString();\n", "    }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Congrats! You have created your first Translator! The `processOutput()` method is pre-filled; it processes the `NDList` and returns the answer in the desired format. `processInput()` and `processOutput()` offer the flexibility to get the predictions from the model in any format you desire. \n", "\n", "\n", "With the Translator implemented, you need to bring up the predictor that uses your `Translator` to start making predictions. 
You can find the usage for `Predictor` in the [Predictor Javadoc](https://javadoc.djl.ai/api/0.3.0/ai/djl/inference/Predictor.html). Create a translator and use the `question` and `resourceDocument` provided previously." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "String predictResult = null;\n", "\n", "QAInput input = new QAInput(question, resourceDocument, 384);\n", "BertTranslator translator = new BertTranslator(parser);\n", "\n", "// Create a Predictor and use it to predict the output\n", "try (Predictor<QAInput, String> predictor = model.newPredictor(translator)) {\n", "    predictResult = predictor.predict(input);\n", "}\n", "\n", "System.out.println(question);\n", "System.out.println(predictResult);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on the input, the following result will be shown:\n", "```\n", "[december, 2004]\n", "```\n", "That's it! \n", "\n", "You can try with more questions and answers. Here are the samples:\n", "\n", "**Answer Material**\n", "\n", "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. 
The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.\n", "\n", "\n", "**Question**\n", "\n", "Q: When were the Normans in Normandy?\n", "A: 10th and 11th centuries\n", "\n", "Q: In what country is Normandy located?\n", "A: france" ] } ], "metadata": { "kernelspec": { "display_name": "Java", "language": "java", "name": "java" }, "language_info": { "codemirror_mode": "java", "file_extension": ".jshell", "mimetype": "text/x-java-source", "name": "Java", "pygments_lexer": "java", "version": "12.0.2+10" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 2 }