{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Implementing a simple deep neural network for MNIST classification\n", "\n", "First, we need to load the mxnet gem." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "true" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "require 'mxnet'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting up the configuration\n", "\n", "Initialize the variables that are used below." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "#" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@data_dir = File.expand_path(\"../data/mnist\")\n", "@data_ctx = MXNet.cpu\n", "@model_ctx = MXNet.cpu\n", "#@model_ctx = MXNet.gpu" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that if you have a CUDA-capable GPU, setting `@model_ctx = MXNet.gpu` runs all the computation below on the GPU." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data loaders\n", "\n", "Set up data loaders for both training and validation." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "num_inputs = 784\n", "num_outputs = 10\n", "batch_size = 64\n", "num_examples = 60000\n", "\n", "train_iter = MXNet::IO::MNISTIter.new(\n", " image: File.join(@data_dir, 'train-images-idx3-ubyte'),\n", " label: File.join(@data_dir, 'train-labels-idx1-ubyte'),\n", " batch_size: batch_size,\n", " shuffle: true)\n", "val_iter = MXNet::IO::MNISTIter.new(\n", " image: File.join(@data_dir, 't10k-images-idx3-ubyte'),\n", " label: File.join(@data_dir, 't10k-labels-idx1-ubyte'),\n", " batch_size: batch_size,\n", " shuffle: false)\n", "nil" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The parameters of the model\n", "\n", "Initialize the weights and biases of the neural network." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "#######################\n", "# Set some constants so it's easy to modify the network later\n", "#######################\n", "num_hidden = 256\n", "weight_scale = 0.01\n", "\n", "#######################\n", "# Allocate parameters for the first hidden layer\n", "#######################\n", "@w1 = MXNet::NDArray.random_normal(shape: [num_inputs, num_hidden], scale: weight_scale, ctx: @model_ctx)\n", "@b1 = MXNet::NDArray.random_normal(shape: [num_hidden], scale: weight_scale, ctx: @model_ctx)\n", "\n", "#######################\n", "# Allocate parameters for the second hidden layer\n", "#######################\n", "@w2 = MXNet::NDArray.random_normal(shape: [num_hidden, num_hidden], scale: weight_scale, ctx: @model_ctx)\n", "@b2 = MXNet::NDArray.random_normal(shape: [num_hidden], scale: weight_scale, ctx: @model_ctx)\n", "\n", "#######################\n", "# Allocate parameters for the output layer\n", "#######################\n", "@w3 = MXNet::NDArray.random_normal(shape: [num_hidden, num_outputs], scale: weight_scale, ctx: @model_ctx)\n", "@b3 = MXNet::NDArray.random_normal(shape: [num_outputs], 
scale: weight_scale, ctx: @model_ctx)\n", "\n", "nil" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Mark all the parameters so that their gradients are calculated automatically." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "@all_params = [@w1, @b1, @w2, @b2, @w3, @b3]\n", "@all_params.each(&:attach_grad)\n", "nil" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Activation and loss functions\n", "\n", "Define the ReLU activation function." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":relu" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def relu(x)\n", " MXNet::NDArray.maximum(x, MXNet::NDArray.zeros_like(x))\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the softmax cross-entropy function, which is used to compute the prediction loss." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":softmax_cross_entropy" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def softmax_cross_entropy(y_hat_linear, y)\n", " return -MXNet::NDArray.nansum(y * MXNet::NDArray.log_softmax(y_hat_linear), axis: 0, exclude: true)\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model definition\n", "\n", "The definition of the neural network." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":net" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def net(x)\n", " # first hidden layer\n", " h1_linear = MXNet::NDArray.dot(x, @w1) + @b1\n", " h1 = relu(h1_linear)\n", "\n", " # second hidden layer\n", " h2_linear = MXNet::NDArray.dot(h1, @w2) + @b2\n", " h2 = relu(h2_linear)\n", "\n", " # output layer\n", " y_hat_linear = MXNet::NDArray.dot(h2, @w3) + @b3\n", " return y_hat_linear\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Parameter optimizer\n", "\n", "Next, define the parameter optimizer. This notebook uses stochastic gradient descent." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":sgd" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def sgd(params, lr)\n", " params.each do |param|\n", " param[0..-1] = param - lr * param.grad\n", " end\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluator\n", "\n", "The next function calculates the prediction accuracy for the given data set." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":evaluate_accuracy" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def evaluate_accuracy(data_iter)\n", " numerator = 0.0\n", " denominator = 0.0\n", " data_iter.each_with_index do |batch, i|\n", " data = batch.data[0].as_in_context(@model_ctx).reshape([-1, 784])\n", " label = batch.label[0].as_in_context(@model_ctx)\n", " output = net(data)\n", " predictions = MXNet::NDArray.argmax(output, axis: 1)\n", " numerator += MXNet::NDArray.sum(predictions == label)\n", " denominator += data.shape[0]\n", " end\n", " return (numerator / denominator).as_scalar\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training loop\n", "\n", "Execute the training loop for 10 epochs." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0. Loss: 1.2580198553880055, Train_acc 0.8733324408531189, Val_acc 0.8747996687889099 (3.4784769 sec)\n", "Epoch 1. Loss: 0.3402405142625173, Train_acc 0.925293505191803, Val_acc 0.9240785241127014 (2.9854487 sec)\n", "Epoch 2. Loss: 0.2273845267256101, Train_acc 0.9493563175201416, Val_acc 0.9460136294364929 (3.1201378 sec)\n", "Epoch 3. Loss: 0.16553548452655475, Train_acc 0.9613627195358276, Val_acc 0.9580328464508057 (3.3423264 sec)\n", "Epoch 4. Loss: 0.1294232620060444, Train_acc 0.969266951084137, Val_acc 0.9645432829856873 (3.1851936 sec)\n", "Epoch 5. Loss: 0.10531740031341712, Train_acc 0.9744197130203247, Val_acc 0.9686498641967773 (3.008946 sec)\n", "Epoch 6. Loss: 0.08776766262004773, Train_acc 0.9786719679832458, Val_acc 0.9706530570983887 (2.9635756 sec)\n", "Epoch 7. Loss: 0.07455756265173356, Train_acc 0.9818903207778931, Val_acc 0.97265625 (2.9956695 sec)\n", "Epoch 8. 
Loss: 0.06412682893921931, Train_acc 0.9842416048049927, Val_acc 0.973557710647583 (2.9654095 sec)\n", "Epoch 9. Loss: 0.05563920784989993, Train_acc 0.9862593412399292, Val_acc 0.9748597741127014 (2.9865925 sec)\n" ] }, { "data": { "text/plain": [ "10" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "epochs = 10\n", "learning_rate = 0.001\n", "smoothing_constant = 0.01\n", "\n", "epochs.times do |e|\n", " start = Time.now\n", " cumulative_loss = 0.0\n", " train_iter.each_with_index do |batch, i|\n", " data = batch.data[0].as_in_context(@model_ctx).reshape([-1, 784])\n", " label = batch.label[0].as_in_context(@model_ctx)\n", " label_one_hot = MXNet::NDArray.one_hot(label, depth: 10)\n", " loss = MXNet::Autograd.record do\n", " output = net(data)\n", " softmax_cross_entropy(output, label_one_hot)\n", " end\n", " loss.backward()\n", " sgd(@all_params, learning_rate)\n", " cumulative_loss += MXNet::NDArray.sum(loss).as_scalar\n", " end\n", " \n", " val_accuracy = evaluate_accuracy(val_iter)\n", " train_accuracy = evaluate_accuracy(train_iter)\n", " duration = Time.now - start\n", " puts \"Epoch #{e}. Loss: #{cumulative_loss/num_examples}, Train_acc #{train_accuracy}, Val_acc #{val_accuracy} (#{duration} sec)\"\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now have a model with about 0.97 validation accuracy." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prediction with the trained model\n", "\n", "Let's use the trained model for prediction.\n", "\n", "We need the following helper function to display the input image." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":imshow" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "require 'chunky_png'\n", "require 'base64'\n", "\n", "def imshow(ary)\n", " height, width = ary.shape\n", " fig = ChunkyPNG::Image.new(width, height, ChunkyPNG::Color::TRANSPARENT)\n", " ary = ((ary - ary.min) / ary.max) * 255\n", " 0.upto(height - 1) do |i|\n", " 0.upto(width - 1) do |j|\n", " v = ary[i, j].round\n", " fig[j, i] = ChunkyPNG::Color.rgba(v, v, v, 255)\n", " end\n", " end\n", "\n", " src = 'data:image/png;base64,' + Base64.strict_encode64(fig.to_blob)\n", " IRuby.display \"<img src='#{src}' />\", mime: 'text/html'\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the function for prediction." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ ":predict" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def predict(data)\n", " output = net(data)\n", " MXNet::NDArray.argmax(output, axis: 1)\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a new data iterator for prediction,\n", "and generate predictions for the first 10 samples." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "\"\"" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "model predictions are: \n", "[0, 2, 7, 8, 7, 1, 5, 4, 3, 9]\n", "\n", "\n", "true labels: \n", "[0, 2, 3, 8, 7, 1, 5, 8, 3, 9]\n", "\n" ] } ], "source": [ "sample_size = 10\n", "sample_iter = test_iter = MXNet::IO::MNISTIter.new(\n", " image: File.join(@data_dir, 't10k-images-idx3-ubyte'),\n", " label: File.join(@data_dir, 't10k-labels-idx1-ubyte'),\n", " batch_size: sample_size,\n", " shuffle: true,\n", " seed: rand(100))\n", "sample_iter.each do |batch|\n", " data = batch.data[0].as_in_context(@model_ctx)\n", " label = batch.label[0]\n", "\n", " im = data.transpose(axes: [1, 0, 2, 3]).reshape([10*28, 28, 1])\n", " imshow(im[0..-1, 0..-1, 0].to_narray)\n", "\n", " pred = predict(data.reshape([-1, 784]))\n", " puts \"model predictions are: #{pred.inspect}\"\n", " puts\n", " puts \"true labels: #{label.inspect}\"\n", " break\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "In this notebook, a simple neural network model for MNIST classification is implemented with the MXNet NDArray API.\n", "\n", "A more complex example is available here: https://github.com/mrkn/mxnet.rb/blob/taiwan2018/example/scratch/resnet/wrn.rb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Ruby 2.5.1", "language": "ruby", "name": "ruby" }, "language_info": { "file_extension": ".rb", "mimetype": "application/x-ruby", "name": "ruby", "version": "2.5.1" } }, "nbformat": 4, "nbformat_minor": 2 }