/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference;
import ai.djl.Application;
import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
/**
 * An example of inference using BertQA.
 *
 * <p>See:
 *
 * <ul>
 *   <li>the BERTQA jupyter demo in the DJL repository for more information about BERT.
 *   <li>the BERT question-and-answer docs in the DJL examples for information about running this
 *       example.
 * </ul>
 */
public final class BertQaInference {

    private static final Logger logger = LoggerFactory.getLogger(BertQaInference.class);

    /** Utility class; not instantiable. */
    private BertQaInference() {}

    /**
     * Runs a single BertQA prediction and logs the answer.
     *
     * @param args unused command line arguments
     * @throws IOException if the model could not be loaded
     * @throws TranslateException if pre- or post-processing of the input fails
     * @throws ModelException if the model could not be found or is malformed
     */
    public static void main(String[] args) throws IOException, TranslateException, ModelException {
        String answer = BertQaInference.predict();
        logger.info("Answer: {}", answer);
    }

    /**
     * Loads a PyTorch BERT question-answering model from the model zoo and predicts the answer to
     * a hard-coded question against a hard-coded paragraph.
     *
     * @return the predicted answer text
     * @throws IOException if the model could not be loaded
     * @throws TranslateException if pre- or post-processing of the input fails
     * @throws ModelException if the model could not be found or is malformed
     */
    public static String predict() throws IOException, TranslateException, ModelException {
        String question = "When did BBC Japan start broadcasting?";
        String paragraph =
                "BBC Japan was a general entertainment Channel. "
                        + "Which operated between December 2004 and April 2006. "
                        + "It ceased operations after its Japanese distributor folded.";

        QAInput input = new QAInput(question, paragraph);
        logger.info("Paragraph: {}", input.getParagraph());
        logger.info("Question: {}", input.getQuestion());

        // Parameterize Criteria with <QAInput, String> (matching setTypes) instead of using the
        // raw type: this makes loadModel()/newPredictor()/predict() fully type-safe and removes
        // unchecked-conversion warnings.
        Criteria<QAInput, String> criteria =
                Criteria.builder()
                        .optApplication(Application.NLP.QUESTION_ANSWER)
                        .setTypes(QAInput.class, String.class)
                        .optFilter("backbone", "bert")
                        .optEngine("PyTorch")
                        .optDevice(Device.cpu())
                        .optProgress(new ProgressBar())
                        .build();

        // try-with-resources ensures both the model and the predictor are closed even if
        // predict(...) throws.
        try (ZooModel<QAInput, String> model = criteria.loadModel();
                Predictor<QAInput, String> predictor = model.newPredictor()) {
            return predictor.predict(input);
        }
    }
}