{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## TensorFlow.js\n", "\n", "In this notebook, we build and train [a deep CNN model](https://developers.google.com/machine-learning/glossary/#convolutional_neural_network) using TensorFlow.js and visualize the trained model's predictions on the MNIST dataset.\n", "With tslab and TensorFlow.js, you can build and train deep neural network models without using Python.\n", "\n", "### TensorFlow.js references\n", "\n", "- [TensorFlow.js in Node](https://www.tensorflow.org/js/guide/nodejs)\n", "- [Get Started | TensorFlow.js](https://www.tensorflow.org/js/tutorials)\n", "\n", "### Disclaimer\n", "\n", "Don't run this notebook on mybinder.org.\n", "Training the CNN model in this notebook is computationally heavy and will not finish on mybinder.org.\n", "Please run this notebook in a local environment with sufficient CPU power." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "/**\n", " * Copyright 2018 Google LLC. 
All Rights Reserved.\n", " * Licensed under the Apache License, Version 2.0 (the \"License\");\n", " * you may not use this file except in compliance with the License.\n", " * You may obtain a copy of the License at\n", " *\n", " * http://www.apache.org/licenses/LICENSE-2.0\n", " *\n", " * Unless required by applicable law or agreed to in writing, software\n", " * distributed under the License is distributed on an \"AS IS\" BASIS,\n", " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", " * See the License for the specific language governing permissions and\n", " * limitations under the License.\n", " * =============================================================================\n", " *\n", " * This code was branched from\n", " * https://github.com/tensorflow/tfjs-examples/blob/master/mnist-node/\n", " * to demonstrate TensorFlow.js in tslab.\n", " */\n", "\n", "import * as tf from '@tensorflow/tfjs-node'\n", "import Jimp from 'jimp';\n", "import {promisify} from 'util';\n", "import {dataset as mnist} from '../lib/mnist';\n", "import {display} from 'tslab';\n", "import * as tslab from 'tslab';" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "

TensorFlow.js versions

" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \u001b[32m'tfjs-core'\u001b[39m: \u001b[32m'1.3.1'\u001b[39m,\n", " \u001b[32m'tfjs-data'\u001b[39m: \u001b[32m'1.3.1'\u001b[39m,\n", " \u001b[32m'tfjs-layers'\u001b[39m: \u001b[32m'1.3.1'\u001b[39m,\n", " \u001b[32m'tfjs-converter'\u001b[39m: \u001b[32m'1.3.1'\u001b[39m,\n", " tfjs: \u001b[32m'1.3.1'\u001b[39m,\n", " \u001b[32m'tfjs-node'\u001b[39m: \u001b[32m'1.3.1'\u001b[39m\n", "}\n" ] }, { "data": { "text/html": [ "

tslab versions

" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{ tslab: \u001b[32m'1.0.7'\u001b[39m, typescript: \u001b[32m'3.7.3'\u001b[39m, node: \u001b[32m'v12.13.0'\u001b[39m }\n" ] } ], "source": [ "display.html('

TensorFlow.js versions

')\n", "console.log(tf.version)\n", "display.html('

tslab versions

')\n", "console.log(tslab.versions);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "await mnist.loadData();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "async function toPng(images: tf.Tensor4D, start: number, size: number): Promise {\n", " // Note: mnist.getTrainData().images.slice([index], [1]) is slow.\n", " let arry = images.slice([start], [size]).flatten().arraySync();\n", " let ret: Buffer[] = [];\n", " for (let i = 0; i < size; i++) {\n", " let raw = [];\n", " for (const v of arry.slice(i * 28 * 28, (i+1)*28*28)) {\n", " raw.push(...[v*255, v*255, v*255, 255])\n", " }\n", " let img = await promisify((cb: (err, v: Jimp)=>any) => {\n", " new Jimp({ data: Buffer.from(raw), width: 28, height: 28 }, cb);\n", " })();\n", " ret.push(await img.getBufferAsync(Jimp.MIME_PNG));\n", " }\n", " return ret;\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "const model = tf.sequential();\n", "model.add(tf.layers.conv2d({\n", " inputShape: [28, 28, 1],\n", " filters: 32,\n", " kernelSize: 3,\n", " activation: 'relu',\n", "}));\n", "model.add(tf.layers.conv2d({\n", " filters: 32,\n", " kernelSize: 3,\n", " activation: 'relu',\n", "}));\n", "model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));\n", "model.add(tf.layers.conv2d({\n", " filters: 64,\n", " kernelSize: 3,\n", " activation: 'relu',\n", "}));\n", "model.add(tf.layers.conv2d({\n", " filters: 64,\n", " kernelSize: 3,\n", " activation: 'relu',\n", "}));\n", "model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));\n", "model.add(tf.layers.flatten());\n", "model.add(tf.layers.dropout({rate: 0.25}));\n", "model.add(tf.layers.dense({units: 512, activation: 'relu'}));\n", "model.add(tf.layers.dropout({rate: 0.5}));\n", "model.add(tf.layers.dense({units: 10, activation: 'softmax'}));\n", "\n", "const optimizer = 'rmsprop';\n", "model.compile({\n", " optimizer: 
optimizer,\n", " loss: 'categoricalCrossentropy',\n", " metrics: ['accuracy'],\n", "});" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "async function train(epochs, batchSize, modelSavePath) {\n", " const {images: trainImages, labels: trainLabels} = mnist.getTrainData();\n", " model.summary();\n", "\n", " let epochBeginTime;\n", " let millisPerStep;\n", " const validationSplit = 0.15;\n", " const numTrainExamplesPerEpoch =\n", " trainImages.shape[0] * (1 - validationSplit);\n", " const numTrainBatchesPerEpoch =\n", " Math.ceil(numTrainExamplesPerEpoch / batchSize);\n", " await model.fit(trainImages, trainLabels, {\n", " epochs,\n", " batchSize,\n", " validationSplit\n", " });\n", "\n", " const {images: testImages, labels: testLabels} = mnist.getTestData();\n", " const evalOutput = model.evaluate(testImages, testLabels);\n", "\n", " console.log(\n", " `\\nEvaluation result:\\n` +\n", " ` Loss = ${evalOutput[0].dataSync()[0].toFixed(3)}; `+\n", " `Accuracy = ${evalOutput[1].dataSync()[0].toFixed(3)}`);\n", "\n", " if (modelSavePath != null) {\n", " await model.save(`file://${modelSavePath}`);\n", " console.log(`Saved model to path: ${modelSavePath}`);\n", " }\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output shape Param # \n", "=================================================================\n", "conv2d_Conv2D1 (Conv2D) [null,26,26,32] 320 \n", "_________________________________________________________________\n", "conv2d_Conv2D2 (Conv2D) [null,24,24,32] 9248 \n", "_________________________________________________________________\n", "max_pooling2d_MaxPooling2D1 [null,12,12,32] 0 \n", "_________________________________________________________________\n", "conv2d_Conv2D3 (Conv2D) [null,10,10,64] 18496 \n", 
"_________________________________________________________________\n", "conv2d_Conv2D4 (Conv2D) [null,8,8,64] 36928 \n", "_________________________________________________________________\n", "max_pooling2d_MaxPooling2D2 [null,4,4,64] 0 \n", "_________________________________________________________________\n", "flatten_Flatten1 (Flatten) [null,1024] 0 \n", "_________________________________________________________________\n", "dropout_Dropout1 (Dropout) [null,1024] 0 \n", "_________________________________________________________________\n", "dense_Dense1 (Dense) [null,512] 524800 \n", "_________________________________________________________________\n", "dropout_Dropout2 (Dropout) [null,512] 0 \n", "_________________________________________________________________\n", "dense_Dense2 (Dense) [null,10] 5130 \n", "=================================================================\n", "Total params: 594922\n", "Trainable params: 594922\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "Epoch 1 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "40717ms 798us/step - acc=0.920 loss=0.245 val_acc=0.979 val_loss=0.0735 \n", "Epoch 2 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "39984ms 784us/step - acc=0.980 loss=0.0674 val_acc=0.990 val_loss=0.0360 \n", "Epoch 3 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "40109ms 786us/step - acc=0.985 loss=0.0491 val_acc=0.990 val_loss=0.0371 \n", "Epoch 4 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "42172ms 827us/step - acc=0.988 loss=0.0379 val_acc=0.992 val_loss=0.0294 \n", "Epoch 5 / 20\n" ] }, { "name": "stderr", "output_type": 
"stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "42451ms 832us/step - acc=0.990 loss=0.0320 val_acc=0.992 val_loss=0.0285 \n", "Epoch 6 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "42674ms 837us/step - acc=0.991 loss=0.0283 val_acc=0.987 val_loss=0.0481 \n", "Epoch 7 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "42504ms 833us/step - acc=0.993 loss=0.0234 val_acc=0.992 val_loss=0.0263 \n", "Epoch 8 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "43120ms 845us/step - acc=0.993 loss=0.0218 val_acc=0.993 val_loss=0.0263 \n", "Epoch 9 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "42818ms 840us/step - acc=0.994 loss=0.0191 val_acc=0.993 val_loss=0.0274 \n", "Epoch 10 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "43198ms 847us/step - acc=0.994 loss=0.0177 val_acc=0.994 val_loss=0.0213 \n", "Epoch 11 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "43481ms 853us/step - acc=0.995 loss=0.0150 val_acc=0.994 val_loss=0.0253 \n", "Epoch 12 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "43164ms 846us/step - acc=0.995 loss=0.0154 val_acc=0.994 val_loss=0.0263 \n", "Epoch 13 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "42980ms 843us/step - acc=0.995 loss=0.0135 val_acc=0.994 val_loss=0.0251 \n", "Epoch 14 / 20\n" ] }, { "name": "stderr", "output_type": 
"stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "43289ms 849us/step - acc=0.996 loss=0.0126 val_acc=0.994 val_loss=0.0255 \n", "Epoch 15 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "43104ms 845us/step - acc=0.996 loss=0.0113 val_acc=0.992 val_loss=0.0333 \n", "Epoch 16 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "43385ms 851us/step - acc=0.997 loss=0.0102 val_acc=0.993 val_loss=0.0320 \n", "Epoch 17 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "43223ms 848us/step - acc=0.996 loss=0.0106 val_acc=0.993 val_loss=0.0308 \n", "Epoch 18 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "43164ms 846us/step - acc=0.997 loss=9.44e-3 val_acc=0.994 val_loss=0.0329 \n", "Epoch 19 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "43324ms 849us/step - acc=0.997 loss=8.45e-3 val_acc=0.994 val_loss=0.0319 \n", "Epoch 20 / 20\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "42775ms 839us/step - acc=0.997 loss=8.62e-3 val_acc=0.994 val_loss=0.0270 \n", "\n", "Evaluation result:\n", " Loss = 0.021; Accuracy = 0.994\n", "Saved model to path: mnist\n", "\u001b[90mundefined\u001b[39m\n" ] } ], "source": [ "// Hack to suppress the progress bar\n", "process.stderr.isTTY = false;\n", "\n", "const epochs = 20;\n", "const batchSize = 128;\n", "const modelSavePath = 'mnist'\n", "await train(epochs, batchSize, modelSavePath);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "const predicted =\n", " 
tf.argMax(model.predict(mnist.getTestData().images) as tf.Tensor, 1).arraySync() as number[];" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "\n", "
6
\n", "
\n", "
\n", "\n", "
0
\n", "
\n", "
\n", "\n", "
5
\n", "
\n", "
\n", "\n", "
4
\n", "
\n", "
\n", "\n", "
9
\n", "
\n", "
\n", "\n", "
9
\n", "
\n", "
\n", "\n", "
2
\n", "
\n", "
\n", "\n", "
1
\n", "
\n", "
\n", "\n", "
9
\n", "
\n", "
\n", "\n", "
4
\n", "
\n", "
\n", "\n", "
8
\n", "
\n", "
\n", "\n", "
7
\n", "
\n", "
\n", "\n", "
3
\n", "
\n", "
\n", "\n", "
9
\n", "
\n", "
\n", "\n", "
7
\n", "
\n", "
\n", "\n", "
4
\n", "
\n", "
\n", "\n", "
4
\n", "
\n", "
\n", "\n", "
4
\n", "
\n", "
\n", "\n", "
9
\n", "
\n", "
\n", "\n", "
2
\n", "
\n", "
\n", "\n", "
5
\n", "
\n", "
\n", "\n", "
4
\n", "
\n", "
\n", "\n", "
7
\n", "
\n", "
\n", "\n", "
6
\n", "
\n", "
\n", "\n", "
7
\n", "
\n", "
\n", "\n", "
9
\n", "
\n", "
\n", "\n", "
0
\n", "
\n", "
\n", "\n", "
5
\n", "
\n", "
\n", "\n", "
8
\n", "
\n", "
\n", "\n", "
5
\n", "
\n", "
\n", "\n", "
6
\n", "
\n", "
\n", "\n", "
6
\n", "
\n", "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "{\n", " let start = 100;\n", " let size = 32;\n", " const html: string[] = [];\n", " const pngs = await toPng(mnist.getTestData().images, start, size);\n", " html.push('
')\n", " for (let i = 0; i < size; i++) {\n", " const pred = predicted[i + start];\n", " html.push('
');\n", " html.push(``);\n", " html.push(`
${pred}
`)\n", " html.push('
');\n", " }\n", " html.push('
')\n", " display.html(html.join('\\n'));\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "\n", "
Label: 4, Prediction: 2
\n", "
\n", "
\n", "\n", "
Label: 6, Prediction: 0
\n", "
\n", "
\n", "\n", "
Label: 3, Prediction: 5
\n", "
\n", "
\n", "\n", "
Label: 8, Prediction: 2
\n", "
\n", "
\n", "\n", "
Label: 2, Prediction: 1
\n", "
\n", "
\n", "\n", "
Label: 8, Prediction: 9
\n", "
\n", "
\n", "\n", "
Label: 6, Prediction: 5
\n", "
\n", "
\n", "\n", "
Label: 7, Prediction: 1
\n", "
\n", "
\n", "\n", "
Label: 7, Prediction: 2
\n", "
\n", "
\n", "\n", "
Label: 9, Prediction: 4
\n", "
\n", "
\n", "\n", "
Label: 9, Prediction: 5
\n", "
\n", "
\n", "\n", "
Label: 5, Prediction: 3
\n", "
\n", "
\n", "\n", "
Label: 5, Prediction: 3
\n", "
\n", "
\n", "\n", "
Label: 0, Prediction: 7
\n", "
\n", "
\n", "\n", "
Label: 2, Prediction: 7
\n", "
\n", "
\n", "\n", "
Label: 9, Prediction: 4
\n", "
\n", "
\n", "\n", "
Label: 1, Prediction: 2
\n", "
\n", "
\n", "\n", "
Label: 5, Prediction: 3
\n", "
\n", "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "// Sow examples the model failed to predict correct labels.\n", "{\n", " let start = 100;\n", " let size = 2000;\n", " const html: string[] = [];\n", " const pngs = await toPng(mnist.getTestData().images, start, size);\n", " const labels = tf.argMax(mnist.getTestData().labels, 1).arraySync() as number[];\n", " html.push('
')\n", " for (let i = 0; i < size; i++) {\n", " const pred = predicted[i + start];\n", " const label = labels[i + start];\n", " if (pred === label) {\n", " continue;\n", " }\n", " html.push('
');\n", " html.push(``);\n", " html.push(`
Label: ${label}, Prediction: ${pred}
`)\n", " html.push('
');\n", " }\n", " html.push('
')\n", " display.html(html.join('\\n'));\n", "}" ] } ], "metadata": { "kernelspec": { "display_name": "TypeScript", "language": "typescript", "name": "tslab" }, "language_info": { "codemirror_mode": { "mode": "typescript", "name": "javascript", "typescript": true }, "file_extension": ".ts", "mimetype": "text/typescript", "name": "typescript", "version": "3.7.2" } }, "nbformat": 4, "nbformat_minor": 4 }