{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 多层感知机的简洁实现\n",
    ":label:`sec_mlp_concise`\n",
    "\n",
    "正如你所期待的,我们可以使用`DJL`更简洁地实现多层感知机。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.metric.*;\n",
    "import ai.djl.basicdataset.cv.classification.*;\n",
    "import org.apache.commons.lang3.ArrayUtils;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型\n",
    "\n",
    "与softmax回归的简洁实现(:numref:`sec_softmax_concise`)相比,唯一的区别是我们添加了2个全连接层(之前我们只添加了1个全连接层)。第一层是隐藏层,它包含256个隐藏单元,并使用了ReLU激活函数。第二层是输出层。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SequentialBlock net = new SequentialBlock();\n",
    "net.add(Blocks.batchFlattenBlock(784));\n",
    "net.add(Linear.builder().setUnits(256).build());\n",
    "net.add(Activation::relu);\n",
    "net.add(Linear.builder().setUnits(10).build());\n",
    "net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);"
   ]
  },
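  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (an illustrative sketch added here, not part of the original text), we can rebuild the same architecture, initialize it, and pass a random minibatch through it to confirm that the output has shape `(2, 10)` for a batch of two examples. Working on a throwaway copy leaves `net` itself uninitialized for the trainer below; this assumes the classes brought in by `djl-imports` are available.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Illustrative sketch: verify the output shape of the architecture defined above.\n",
    "try (NDManager manager = NDManager.newBaseManager()) {\n",
    "    // rebuild the same block so that `net` stays untouched\n",
    "    SequentialBlock check = new SequentialBlock();\n",
    "    check.add(Blocks.batchFlattenBlock(784));\n",
    "    check.add(Linear.builder().setUnits(256).build());\n",
    "    check.add(Activation::relu);\n",
    "    check.add(Linear.builder().setUnits(10).build());\n",
    "    check.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);\n",
    "    check.initialize(manager, DataType.FLOAT32, new Shape(2, 784));\n",
    "\n",
    "    // forward a random batch of two flattened 28x28 images\n",
    "    NDArray X = manager.randomUniform(0f, 1f, new Shape(2, 784));\n",
    "    NDList out = check.forward(new ParameterStore(manager, false), new NDList(X), false);\n",
    "    System.out.println(out.singletonOrThrow().getShape()); // expect (2, 10)\n",
    "}"
   ]
  },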
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练过程的实现与我们实现softmax回归时完全相同,这种模块化设计使我们能够将与和模型架构有关的内容独立出来。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int batchSize = 256;\n",
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 10);\n",
    "double[] trainLoss;\n",
    "double[] testAccuracy;\n",
    "double[] epochCount;\n",
    "double[] trainAccuracy;\n",
    "\n",
    "trainLoss = new double[numEpochs];\n",
    "trainAccuracy = new double[numEpochs];\n",
    "testAccuracy = new double[numEpochs];\n",
    "epochCount = new double[numEpochs];\n",
    "\n",
    "FashionMnist trainIter = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TRAIN)\n",
    "        .setSampling(batchSize, true)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "FashionMnist testIter = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TEST)\n",
    "        .setSampling(batchSize, true)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "trainIter.prepare();\n",
    "testIter.prepare();\n",
    "\n",
    "for(int i = 0; i < epochCount.length; i++) {\n",
    "    epochCount[i] = (i + 1);\n",
    "}\n",
    "\n",
    "Map<String, double[]> evaluatorMetrics = new HashMap<>();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Tracker lrt = Tracker.fixed(0.5f);\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "                .optOptimizer(sgd) // Optimizer (loss function)\n",
    "                .optDevices(Engine.getInstance().getDevices(1)) // single GPU\n",
    "                .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "    try (Model model = Model.newInstance(\"mlp\")) {\n",
    "        model.setBlock(net);\n",
    "\n",
    "        try (Trainer trainer = model.newTrainer(config)) {\n",
    "\n",
    "            trainer.initialize(new Shape(1, 784));\n",
    "            trainer.setMetrics(new Metrics());\n",
    "\n",
    "            EasyTrain.fit(trainer, numEpochs, trainIter, testIter);\n",
    "            // collect results from evaluators\n",
    "            Metrics metrics = trainer.getMetrics();\n",
    "\n",
    "            trainer.getEvaluators().stream()\n",
    "                    .forEach(evaluator -> {\n",
    "                        evaluatorMetrics.put(\"train_epoch_\" + evaluator.getName(), metrics.getMetric(\"train_epoch_\" + evaluator.getName()).stream()\n",
    "                                            .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n",
    "                        evaluatorMetrics.put(\"validate_epoch_\" + evaluator.getName(), metrics.getMetric(\"validate_epoch_\" + evaluator.getName()).stream()\n",
    "                                            .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n",
    "            });\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainLoss = evaluatorMetrics.get(\"train_epoch_SoftmaxCrossEntropyLoss\");\n",
    "trainAccuracy = evaluatorMetrics.get(\"train_epoch_Accuracy\");\n",
    "testAccuracy = evaluatorMetrics.get(\"validate_epoch_Accuracy\");\n",
    "\n",
    "String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n",
    "\n",
    "Arrays.fill(lossLabel, 0, trainLoss.length, \"test acc\");\n",
    "Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, \"train acc\");\n",
    "Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,\n",
    "                trainLoss.length + testAccuracy.length + trainAccuracy.length, \"train loss\");\n",
    "\n",
    "Table data = Table.create(\"Data\").addColumns(\n",
    "            DoubleColumn.create(\"epochCount\", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),\n",
    "            DoubleColumn.create(\"loss\", ArrayUtils.addAll(testAccuracy , ArrayUtils.addAll(trainAccuracy, trainLoss))),\n",
    "            StringColumn.create(\"lossLabel\", lossLabel)\n",
    ");\n",
    "\n",
    "render(LinePlot.create(\"\", data, \"epochCount\", \"loss\", \"lossLabel\"),\"text/html\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 练习\n",
    "\n",
    "1. 尝试添加不同数量的隐藏层(也可以修改学习率)。怎么样设置效果最好?\n",
    "1. 尝试不同的激活函数。哪个效果最好?\n",
    "1. 尝试不同的方案来初始化权重。什么方法效果最好?\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}