# Concise Implementation of Multilayer Perceptrons
:label:`sec_mlp_concise`

As you might expect, we can implement multilayer perceptrons even more concisely with `DJL`.


In [None]:
%load ../utils/djl-imports
%load ../utils/plot-utils

In [None]:
import ai.djl.metric.*;
import ai.djl.basicdataset.cv.classification.*;
import org.apache.commons.lang3.ArrayUtils;

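## Model

Compared with the concise implementation of softmax regression (:numref:`sec_softmax_concise`), the only difference is that we add two fully connected layers (previously we added only one). The first is the hidden layer, which contains 256 hidden units and applies the ReLU activation function. The second is the output layer.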


In [None]:
SequentialBlock net = new SequentialBlock();
net.add(Blocks.batchFlattenBlock(784));
net.add(Linear.builder().setUnits(256).build());
net.add(Activation::relu);
net.add(Linear.builder().setUnits(10).build());
net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
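
To make the shapes concrete, the following plain-Java sketch mirrors the computation the network above performs on a single flattened image: a 784-to-256 linear map, ReLU, then a 256-to-10 linear map. It does not use DJL. The class `MlpForwardSketch`, its method names, and the standard deviation of 0.01 used for the random weights are assumptions made for illustration only.

```java
import java.util.Random;

/** Plain-Java sketch of the forward pass: Linear(256) -> ReLU -> Linear(10). */
final class MlpForwardSketch {

    /** Computes y = W x + b for a weight matrix of shape (outputs, inputs). */
    static double[] linear(double[][] w, double[] b, double[] x) {
        double[] y = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double sum = b[i];
            for (int j = 0; j < x.length; j++) {
                sum += w[i][j] * x[j];
            }
            y[i] = sum;
        }
        return y;
    }

    /** Applies ReLU element-wise: max(0, v). */
    static double[] relu(double[] v) {
        double[] out = new double[v.length];
        for (int i = 0; i < v.length; i++) {
            out[i] = Math.max(0.0, v[i]);
        }
        return out;
    }

    /** Draws weights from a zero-mean Gaussian with standard deviation sigma. */
    static double[][] randomNormal(int rows, int cols, double sigma, Random rng) {
        double[][] w = new double[rows][cols];
        for (double[] row : w) {
            for (int j = 0; j < cols; j++) {
                row[j] = sigma * rng.nextGaussian();
            }
        }
        return w;
    }

    /** Maps one flattened image to the output logits through the two layers. */
    static double[] forward(double[] x, double[][] w1, double[] b1,
                            double[][] w2, double[] b2) {
        double[] hidden = relu(linear(w1, b1, x));
        return linear(w2, b2, hidden);
    }

    public static void main(String[] args) {
        Random rng = new Random(0);
        double[][] w1 = randomNormal(256, 784, 0.01, rng); // hidden layer weights
        double[][] w2 = randomNormal(10, 256, 0.01, rng);  // output layer weights
        double[] x = new double[784];                      // stand-in for one image
        for (int i = 0; i < x.length; i++) {
            x[i] = rng.nextDouble();
        }
        double[] logits = forward(x, w1, new double[256], w2, new double[10]);
        System.out.println("number of logits: " + logits.length); // prints 10
    }
}
```

Without the ReLU between the two `linear` calls, the composition would collapse into a single linear map, which is why the hidden layer needs a nonlinearity.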

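The training process is implemented exactly as it was for softmax regression. This modular design lets us separate matters concerning the model architecture from all other considerations.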


In [None]:
int batchSize = 256;
int numEpochs = Integer.getInteger("MAX_EPOCH", 10);
double[] trainLoss;
double[] testAccuracy;
double[] epochCount;
double[] trainAccuracy;

trainLoss = new double[numEpochs];
trainAccuracy = new double[numEpochs];
testAccuracy = new double[numEpochs];
epochCount = new double[numEpochs];

FashionMnist trainIter = FashionMnist.builder()
        .optUsage(Dataset.Usage.TRAIN)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

FashionMnist testIter = FashionMnist.builder()
        .optUsage(Dataset.Usage.TEST)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

trainIter.prepare();
testIter.prepare();

for(int i = 0; i < epochCount.length; i++) {
    epochCount[i] = (i + 1);
}

Map<String, double[]> evaluatorMetrics = new HashMap<>();

In [None]:
Tracker lrt = Tracker.fixed(0.5f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

Loss loss = Loss.softmaxCrossEntropyLoss();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
                .optOptimizer(sgd) // Optimizer (SGD with a fixed learning rate)
                .optDevices(Engine.getInstance().getDevices(1)) // at most one GPU, or the CPU if none is available
                .addEvaluator(new Accuracy()) // Model Accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

try (Model model = Model.newInstance("mlp")) {
    model.setBlock(net);

    try (Trainer trainer = model.newTrainer(config)) {
        trainer.initialize(new Shape(1, 784));
        trainer.setMetrics(new Metrics());

        EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

        // Collect the per-epoch results from the evaluators
        Metrics metrics = trainer.getMetrics();

        trainer.getEvaluators().stream()
                .forEach(evaluator -> {
                    evaluatorMetrics.put("train_epoch_" + evaluator.getName(),
                            metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                    evaluatorMetrics.put("validate_epoch_" + evaluator.getName(),
                            metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                });
    }
}
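
The training cell above relies on two ingredients that are easy to state by hand: the softmax cross-entropy loss and the SGD update. The plain-Java sketch below shows both for a single example, using the log-sum-exp trick for numerical stability. It is not DJL's implementation; the class `SoftmaxCrossEntropySketch` and its method names are hypothetical, and batching and averaging over examples are left out.

```java
/** Plain-Java sketch of softmax cross-entropy and one SGD step (single example). */
final class SoftmaxCrossEntropySketch {

    /** Returns -log(softmax(logits)[label]) using the log-sum-exp trick. */
    static double loss(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sumExp = 0.0;
        for (double v : logits) {
            sumExp += Math.exp(v - max); // subtracting max avoids overflow
        }
        double logSumExp = max + Math.log(sumExp);
        return logSumExp - logits[label];
    }

    /** Updates params in place: param <- param - lr * grad. */
    static void sgdStep(double[] params, double[] grads, double lr) {
        for (int i = 0; i < params.length; i++) {
            params[i] -= lr * grads[i];
        }
    }

    public static void main(String[] args) {
        // With 10 equal logits every class has probability 1/10,
        // so the loss equals ln(10).
        double uniform = loss(new double[10], 3);
        System.out.println(Math.abs(uniform - Math.log(10)) < 1e-12); // prints true
    }
}
```

A loss close to ln(10) ≈ 2.30 in the first training logs therefore indicates a 10-class model that is still predicting almost uniformly.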

In [None]:
trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

// Labels must follow the order of the values in the "loss" column:
// test accuracy, then train accuracy, then train loss
Arrays.fill(lossLabel, 0, testAccuracy.length, "test acc");
Arrays.fill(lossLabel, testAccuracy.length, testAccuracy.length + trainAccuracy.length, "train acc");
Arrays.fill(lossLabel, testAccuracy.length + trainAccuracy.length,
                testAccuracy.length + trainAccuracy.length + trainLoss.length, "train loss");

Table data = Table.create("Data").addColumns(
            DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
            DoubleColumn.create("loss", ArrayUtils.addAll(testAccuracy , ArrayUtils.addAll(trainAccuracy, trainLoss))),
            StringColumn.create("lossLabel", lossLabel)
);

render(LinePlot.create("", data, "epochCount", "loss", "lossLabel"),"text/html");
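
The plotting cell stacks three series into one long column and labels each row with the series it came from, so the labels must be generated in the same order as the values. One way to keep the two in step is sketched below in plain Java. The class `LongFormatSketch` and its methods are hypothetical helpers, not part of DJL or Tablesaw.

```java
import java.util.Arrays;

/** Hypothetical helpers for stacking several series into long-format columns. */
final class LongFormatSketch {

    /** Concatenates the series in the order given. */
    static double[] concat(double[]... series) {
        int total = 0;
        for (double[] s : series) {
            total += s.length;
        }
        double[] out = new double[total];
        int offset = 0;
        for (double[] s : series) {
            System.arraycopy(s, 0, out, offset, s.length);
            offset += s.length;
        }
        return out;
    }

    /** Repeats names[i] once per element of series[i], in the same order as concat. */
    static String[] labels(String[] names, double[]... series) {
        if (names.length != series.length) {
            throw new IllegalArgumentException("one name per series is required");
        }
        int total = 0;
        for (double[] s : series) {
            total += s.length;
        }
        String[] out = new String[total];
        int offset = 0;
        for (int i = 0; i < series.length; i++) {
            Arrays.fill(out, offset, offset + series[i].length, names[i]);
            offset += series[i].length;
        }
        return out;
    }
}
```

Under these assumptions, calling `concat(testAccuracy, trainAccuracy, trainLoss)` and `labels(new String[] {"test acc", "train acc", "train loss"}, testAccuracy, trainAccuracy, trainLoss)` would produce the value and label columns in matching order, even if the series had different lengths.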

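## Exercises

1. Try adding different numbers of hidden layers (you may also modify the learning rate). Which setting works best?
1. Try different activation functions. Which one works best?
1. Try different schemes for initializing the weights. Which method works best?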
