# Classification on the Iris dataset with sklearn and DJL

In this notebook, you will run a pre-trained sklearn model on DJL for a general classification task. The model was trained on the [Iris flower dataset](https://en.wikipedia.org/wiki/Iris_flower_data_set).

## Background 

### Iris Dataset

The dataset contains 150 records with five attributes: sepal length, sepal width, petal length, petal width, and species.

Iris setosa                |  Iris versicolor          | Iris virginica
:-------------------------:|:-------------------------:|:-------------------------:
![](https://upload.wikimedia.org/wikipedia/commons/5/56/Kosaciec_szczecinkowaty_Iris_setosa.jpg)  |  ![](https://upload.wikimedia.org/wikipedia/commons/4/41/Iris_versicolor_3.jpg) | ![](https://upload.wikimedia.org/wikipedia/commons/9/9f/Iris_virginica.jpg) 

The table above shows the three Iris species in the dataset.

We will use sepal length, sepal width, petal length, and petal width as the features and species as the label to train the model.

### Sklearn Model

The model is trained with sklearn and then converted to ONNX format so that DJL can run inference on it; you can find more information about the conversion in the [sklearn-onnx documentation](http://onnx.ai/sklearn-onnx/). You can use the sklearn built-in iris dataset to load the data, then define a [RandomForestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html) and train it. The following code is a sample classification setup using sklearn:

```python
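# Train a model.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# Load the built-in Iris dataset: X holds the four features, y the species index.
iris = load_iris()
X, y = iris.data, iris.target
# Hold out part of the data for evaluation (default split: 75% train, 25% test).
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Fit a random forest classifier on the training split.
clr = RandomForestClassifier()
clr.fit(X_train, y_train)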
```


## Preparation

This tutorial requires the Java kernel for Jupyter. To install it, see the [README](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md).

The cell below declares the dependencies we will use: the DJL API, the ONNX Runtime engine, and an SLF4J logger. If you need more NDArray operations than ONNX Runtime provides, you can load the PyTorch engine alongside it, as described in the [hybrid engine guide](https://github.com/deepjavalibrary/djl/blob/master/docs/hybrid_engine.md).

In [None]:
// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.22.1
%maven ai.djl.onnxruntime:onnxruntime-engine:0.22.1
%maven org.slf4j:slf4j-simple:1.7.32

In [None]:
import ai.djl.inference.*;
import ai.djl.modality.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.*;
import ai.djl.repository.zoo.*;
import ai.djl.translate.*;
import java.util.*;

## Step 1 Create a Translator

Inference in machine learning is the process of predicting the output for a given input based on a pre-defined model.
DJL abstracts away the whole process for ease of use. It can load the model, perform inference on the input, and provide
output. DJL also allows you to provide user-defined inputs. The workflow looks like the following:

![DJL inference workflow](https://github.com/deepjavalibrary/djl/blob/master/examples/docs/img/workFlow.png?raw=true)

The [`Translator`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html) interface encompasses the two white blocks: Pre-processing and Post-processing. The pre-processing
component converts the user-defined input objects into an NDList, so that the [`Predictor`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/inference/Predictor.html) in DJL can understand the
input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the
`Predictor`. The post-processing block allows you to convert the output from the `Predictor` to the desired output
format.
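
To make the two blocks concrete, here is a minimal plain-Java sketch of the same flow: pre-process, run the model, post-process. It does not use DJL. `SketchTranslator`, `runPipeline`, `CSV_SUM`, and `sumModel` are hypothetical names invented for this illustration, `float[]` stands in for `NDList`, and the "model" is a fake function that only adds up its inputs.

```java
import java.util.function.Function;

public class TranslatorSketch {

    /** Hypothetical stand-in for a translator; float[] plays the role of NDList. */
    public interface SketchTranslator<I, O> {
        /** Pre-processing: user object -> tensor-like data. */
        float[] processInput(I input);

        /** Post-processing: raw model output -> user object. */
        O processOutput(float[] output);
    }

    /** Chains pre-processing, the model call, and post-processing. */
    public static <I, O> O runPipeline(
            SketchTranslator<I, O> translator, Function<float[], float[]> model, I input) {
        float[] modelInput = translator.processInput(input);
        float[] modelOutput = model.apply(modelInput);
        return translator.processOutput(modelOutput);
    }

    /** Example translator: parses comma-separated numbers and formats the result as text. */
    public static final SketchTranslator<String, String> CSV_SUM =
            new SketchTranslator<String, String>() {
                @Override
                public float[] processInput(String input) {
                    String[] parts = input.split(",");
                    float[] values = new float[parts.length];
                    for (int i = 0; i < parts.length; i++) {
                        values[i] = Float.parseFloat(parts[i].trim());
                    }
                    return values;
                }

                @Override
                public String processOutput(float[] output) {
                    return "sum=" + output[0];
                }
            };

    /** Fake "model" that adds up its inputs; a real model would run inference here. */
    public static float[] sumModel(float[] in) {
        float total = 0f;
        for (float v : in) {
            total += v;
        }
        return new float[] {total};
    }

    public static void main(String[] args) {
        // Parses "1.5, 2.5", sums the values in the fake model, and formats the output.
        System.out.println(runPipeline(CSV_SUM, TranslatorSketch::sumModel, "1.5, 2.5"));
    }
}
```

The caller only ever sees its own input and output types (`String` here); the tensor-like representation stays inside the pipeline, which is the separation the `Translator` interface provides.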

In our use case, we use a class named `IrisFlower` as the input type and [`Classifications`](https://javadoc.io/doc/ai.djl/api/0.22.1/ai/djl/modality/Classifications.html) as the output type.

In [None]:
public static class IrisFlower {

    public float sepalLength;
    public float sepalWidth;
    public float petalLength;
    public float petalWidth;

    public IrisFlower(float sepalLength, float sepalWidth, float petalLength, float petalWidth) {
        this.sepalLength = sepalLength;
        this.sepalWidth = sepalWidth;
        this.petalLength = petalLength;
        this.petalWidth = petalWidth;
    }
}

Let's create a translator:

In [None]:
public static class MyTranslator implements NoBatchifyTranslator<IrisFlower, Classifications> {

    private final List<String> synset;

    public MyTranslator() {
        // species names, ordered by the class index used in the sklearn iris dataset
        synset = Arrays.asList("setosa", "versicolor", "virginica");
    }

    @Override
    public NDList processInput(TranslatorContext ctx, IrisFlower input) {
        float[] data = {input.sepalLength, input.sepalWidth, input.petalLength, input.petalWidth};
        NDArray array = ctx.getNDManager().create(data, new Shape(1, 4));
        return new NDList(array);
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        // the second output (index 1) holds the per-class probabilities
        float[] data = list.get(1).toFloatArray();
        List<Double> probabilities = new ArrayList<>(data.length);
        for (float f : data) {
            probabilities.add((double) f);
        }
        return new Classifications(synset, probabilities);
    }
}
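
`Classifications` pairs each class name with its probability, so the caller can ask for the most likely species. The following plain-Java sketch approximates that idea without DJL. `RankingSketch`, `Scored`, and `rank` are hypothetical names, and the probabilities are made-up values for illustration, not output from the model.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class RankingSketch {

    /** Hypothetical holder pairing a class name with its probability. */
    public static final class Scored {
        public final String className;
        public final double probability;

        public Scored(String className, double probability) {
            this.className = className;
            this.probability = probability;
        }
    }

    /** Pairs names with probabilities and sorts them from most to least likely. */
    public static List<Scored> rank(List<String> classNames, float[] probabilities) {
        if (classNames.size() != probabilities.length) {
            throw new IllegalArgumentException(
                    "names and probabilities must have the same length");
        }
        List<Scored> ranked = new ArrayList<>(probabilities.length);
        for (int i = 0; i < probabilities.length; i++) {
            ranked.add(new Scored(classNames.get(i), probabilities[i]));
        }
        ranked.sort(Comparator.comparingDouble((Scored s) -> s.probability).reversed());
        return ranked;
    }

    public static void main(String[] args) {
        List<String> synset = Arrays.asList("setosa", "versicolor", "virginica");
        float[] probabilities = {0.1f, 0.2f, 0.7f}; // made-up values for illustration
        Scored best = rank(synset, probabilities).get(0);
        System.out.println("Most likely species: " + best.className);
    }
}
```

The sketch depends on the label list and the probability array sharing the same order, which is why the order of `synset` in the translator has to match the model's class indices.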

## Step 2 Prepare your model

We will load a pretrained sklearn model into DJL. DJL defines a [`ModelZoo`](https://javadoc.io/doc/ai.djl/api/0.22.1/ai/djl/repository/zoo/ModelZoo.html) concept that lets you load models from a variety of locations, such as a remote URL, local files, or the DJL pretrained model zoo. We need to define a [`Criteria`](https://javadoc.io/doc/ai.djl/api/0.22.1/ai/djl/repository/zoo/Criteria.html) object to help the model zoo locate the model and attach the translator. In this example, we download a compressed ONNX model from S3.

In [None]:
String modelUrl = "https://mlrepo.djl.ai/model/tabular/softmax_regression/ai/djl/onnxruntime/iris_flowers/0.0.1/iris_flowers.zip";
Criteria<IrisFlower, Classifications> criteria = Criteria.builder()
        .setTypes(IrisFlower.class, Classifications.class)
        .optModelUrls(modelUrl)
        .optTranslator(new MyTranslator())
        .optEngine("OnnxRuntime") // explicitly select the OnnxRuntime engine
        .build();
ZooModel<IrisFlower, Classifications> model = criteria.loadModel();

## Step 3 Run inference

You only need to create a `Predictor` from the model to run inference.

In [None]:
Predictor<IrisFlower, Classifications> predictor = model.newPredictor();
IrisFlower info = new IrisFlower(1.0f, 2.0f, 3.0f, 4.0f);
predictor.predict(info);