{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# DJL BERT Inference Demo\n", "\n", "## Introduction\n", "\n", "In this tutorial, you walk through running inference using DJL on a [BERT](https://towardsdatascience.com/bert-explained-state-of-the-art-language-model-for-nlp-f8b21a9b6270) QA model trained with MXNet and PyTorch. \n", "You provide the model with a question and a paragraph that contains the answer. The model then finds the best answer in that paragraph.\n", "\n", "Example:\n", "```text\n", "Q: When did BBC Japan start broadcasting?\n", "```\n", "\n", "Answer paragraph:\n", "```text\n", "BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006.\n", "It ceased operations after its Japanese distributor folded.\n", "```\n", "The model picks the right answer:\n", "```text\n", "A: December 2004\n", "```\n", "\n", "One of the most powerful features of DJL is that it is engine-agnostic, so you can run the same code on different backend engines. We showcase BERT QA first with an MXNet pre-trained model, then with a PyTorch model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparation\n", "\n", "This tutorial requires the Java kernel. To install it, see the [README](https://github.com/awslabs/djl/blob/master/jupyter/README.md)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n", "\n", "%maven ai.djl:api:0.7.0\n", "%maven ai.djl.mxnet:mxnet-engine:0.7.0\n", "%maven ai.djl.mxnet:mxnet-model-zoo:0.7.0\n", "%maven ai.djl.pytorch:pytorch-engine:0.7.0\n", "%maven ai.djl.pytorch:pytorch-model-zoo:0.7.0\n", "%maven org.slf4j:slf4j-api:1.7.26\n", "%maven org.slf4j:slf4j-simple:1.7.26\n", "%maven net.java.dev.jna:jna:5.3.0\n", "\n", "// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md\n", "// and https://github.com/awslabs/djl/blob/master/pytorch/pytorch-engine/README.md\n", "// for more engine library selection options\n", "%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport\n", "%maven ai.djl.pytorch:pytorch-native-auto:1.6.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import Java packages by running the following:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import ai.djl.*;\n", "import ai.djl.engine.*;\n", "import ai.djl.modality.nlp.qa.*;\n", "import ai.djl.repository.zoo.*;\n", "import ai.djl.training.util.*;\n", "import ai.djl.inference.*;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that all of the prerequisites are complete, start writing code to run inference with this example.\n", "\n", "\n", "## Load the model and input\n", "\n", "**First, load the input**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var question = \"When did BBC Japan start broadcasting?\";\n", "var resourceDocument = \"BBC Japan was a general entertainment Channel.\\n\" +\n", " \"Which operated between December 2004 and April 2006.\\n\" +\n", " \"It ceased operations after its Japanese distributor folded.\";\n", "\n", "QAInput input = new QAInput(question, resourceDocument);" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then load the model and vocabulary. Create a variable `model` by using the `ModelZoo` as shown in the following code." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Criteria<QAInput, String> criteria = Criteria.builder()\n", " .optApplication(Application.NLP.QUESTION_ANSWER)\n", " .setTypes(QAInput.class, String.class)\n", " .optFilter(\"backbone\", \"bert\")\n", " .optEngine(\"MXNet\") // For DJL to use MXNet engine\n", " .optProgress(new ProgressBar()).build();\n", "ZooModel<QAInput, String> model = ModelZoo.loadModel(criteria);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run inference\n", "Once the model is loaded, you can create a `Predictor` and run inference as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Predictor<QAInput, String> predictor = model.newPredictor();\n", "String answer = predictor.predict(input);\n", "answer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running inference on DJL is that easy. Now, let's try the PyTorch engine by passing `\"PyTorch\"` to `optEngine` in the `Criteria` builder, then rerun the inference code." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "var question = \"When did BBC Japan start broadcasting?\";\n", "var resourceDocument = \"BBC Japan was a general entertainment Channel.\\n\" +\n", " \"Which operated between December 2004 and April 2006.\\n\" +\n", " \"It ceased operations after its Japanese distributor folded.\";\n", "\n", "QAInput input = new QAInput(question, resourceDocument);\n", "\n", "Criteria<QAInput, String> criteria = Criteria.builder()\n", " .optApplication(Application.NLP.QUESTION_ANSWER)\n", " .setTypes(QAInput.class, String.class)\n", " .optFilter(\"backbone\", \"bert\")\n", " .optEngine(\"PyTorch\") // Use PyTorch engine\n", " .optProgress(new ProgressBar()).build();\n", "ZooModel<QAInput, String> model = ModelZoo.loadModel(criteria);\n", "Predictor<QAInput, String> predictor = model.newPredictor();\n", "String answer = predictor.predict(input);\n", "answer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "The PyTorch and MXNet code snippets differ only in the engine name passed to `optEngine`. \n", "This is the power of DJL. We define a unified API where you can switch to different backend engines on the fly.\n", "Next chapter: Inference with your own BERT: [MXNet](mxnet/load_your_own_mxnet_bert.ipynb) [PyTorch](pytorch/load_your_own_pytorch_bert.ipynb)." ] } ], "metadata": { "kernelspec": { "display_name": "Java", "language": "java", "name": "java" }, "language_info": { "codemirror_mode": "java", "file_extension": ".jshell", "mimetype": "text/x-java-source", "name": "Java", "pygments_lexer": "java", "version": "12.0.2+10" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 2 }