{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Concise Implementation of Multilayer Perceptrons\n", ":label:`sec_mlp_concise`\n", "\n", "As you might expect, `DJL` lets us implement multilayer perceptrons even more concisely.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load ../utils/djl-imports\n", "%load ../utils/plot-utils" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import ai.djl.metric.*;\n", "import ai.djl.basicdataset.cv.classification.*;\n", "import org.apache.commons.lang3.ArrayUtils;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model\n", "\n", "Compared with the concise implementation of softmax regression (:numref:`sec_softmax_concise`), the only difference is that the network now has two fully connected layers instead of one. The first is the hidden layer, which contains 256 hidden units and applies the ReLU activation function. The second is the output layer.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "SequentialBlock net = new SequentialBlock();\n", "// Flatten each 28x28 image into a 784-dimensional vector\n", "net.add(Blocks.batchFlattenBlock(784));\n", "// Hidden layer: 256 units followed by ReLU\n", "net.add(Linear.builder().setUnits(256).build());\n", "net.add(Activation::relu);\n", "// Output layer: one unit per class (10 classes)\n", "net.add(Linear.builder().setUnits(10).build());\n", "// Draw the initial weights from a normal distribution\n", "net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The training procedure is exactly the same as for softmax regression. This modular design lets us isolate everything related to the model architecture from the rest of the training code.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "int batchSize = 256;\n", "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 10);\n", "double[] trainLoss;\n", "double[] testAccuracy;\n", "double[] epochCount;\n", "double[] trainAccuracy;\n", "\n", "trainLoss = new double[numEpochs];\n", "trainAccuracy = new double[numEpochs];\n", "testAccuracy = new double[numEpochs];\n", "epochCount = new double[numEpochs];\n", "\n", "FashionMnist trainIter = FashionMnist.builder()\n", "    .optUsage(Dataset.Usage.TRAIN)\n", "    .setSampling(batchSize, true)\n", "    .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n", "    .build();\n", "\n", "FashionMnist testIter = FashionMnist.builder()\n", 
" .optUsage(Dataset.Usage.TEST)\n", " .setSampling(batchSize, true)\n", " .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n", " .build();\n", "\n", "trainIter.prepare();\n", "testIter.prepare();\n", "\n", "for(int i = 0; i < epochCount.length; i++) {\n", " epochCount[i] = (i + 1);\n", "}\n", "\n", "Map<String, double[]> evaluatorMetrics = new HashMap<>();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Tracker lrt = Tracker.fixed(0.5f);\n", "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n", "\n", "Loss loss = Loss.softmaxCrossEntropyLoss();\n", "\n", "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n", " .optOptimizer(sgd) // Optimizer (loss function)\n", " .optDevices(Engine.getInstance().getDevices(1)) // single GPU\n", " .addEvaluator(new Accuracy()) // Model Accuracy\n", " .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n", "\n", " try (Model model = Model.newInstance(\"mlp\")) {\n", " model.setBlock(net);\n", "\n", " try (Trainer trainer = model.newTrainer(config)) {\n", "\n", " trainer.initialize(new Shape(1, 784));\n", " trainer.setMetrics(new Metrics());\n", "\n", " EasyTrain.fit(trainer, numEpochs, trainIter, testIter);\n", " // collect results from evaluators\n", " Metrics metrics = trainer.getMetrics();\n", "\n", " trainer.getEvaluators().stream()\n", " .forEach(evaluator -> {\n", " evaluatorMetrics.put(\"train_epoch_\" + evaluator.getName(), metrics.getMetric(\"train_epoch_\" + evaluator.getName()).stream()\n", " .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n", " evaluatorMetrics.put(\"validate_epoch_\" + evaluator.getName(), metrics.getMetric(\"validate_epoch_\" + evaluator.getName()).stream()\n", " .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n", " });\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainLoss = 
evaluatorMetrics.get(\"train_epoch_SoftmaxCrossEntropyLoss\");\n", "trainAccuracy = evaluatorMetrics.get(\"train_epoch_Accuracy\");\n", "testAccuracy = evaluatorMetrics.get(\"validate_epoch_Accuracy\");\n", "\n", "String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n", "\n", "// Labels follow the order of the concatenated values below: test acc, train acc, train loss\n", "Arrays.fill(lossLabel, 0, testAccuracy.length, \"test acc\");\n", "Arrays.fill(lossLabel, testAccuracy.length, testAccuracy.length + trainAccuracy.length, \"train acc\");\n", "Arrays.fill(lossLabel, testAccuracy.length + trainAccuracy.length,\n", "    trainLoss.length + testAccuracy.length + trainAccuracy.length, \"train loss\");\n", "\n", "Table data = Table.create(\"Data\").addColumns(\n", "    DoubleColumn.create(\"epochCount\", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),\n", "    DoubleColumn.create(\"loss\", ArrayUtils.addAll(testAccuracy, ArrayUtils.addAll(trainAccuracy, trainLoss))),\n", "    StringColumn.create(\"lossLabel\", lossLabel)\n", ");\n", "\n", "render(LinePlot.create(\"\", data, \"epochCount\", \"loss\", \"lossLabel\"), \"text/html\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exercises\n", "\n", "1. Try adding different numbers of hidden layers (you can also change the learning rate). Which setting works best?\n", "1. Try different activation functions. Which one works best?\n", "1. Try different schemes for initializing the weights. Which method works best?\n" ] } ], "metadata": { "kernelspec": { "display_name": "Java", "language": "java", "name": "java" }, "language_info": { "codemirror_mode": "java", "file_extension": ".jshell", "mimetype": "text/x-java-source", "name": "Java", "pygments_lexer": "java", "version": "14.0.2+12" } }, "nbformat": 4, "nbformat_minor": 4 }