{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import numpy as np\n", "from sklearn.preprocessing import StandardScaler\n", "from keras.datasets import mnist\n", "from keras.models import Sequential\n", "from keras.layers import Dense\n", "from keras.utils import np_utils" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "'2.0.5'" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import keras\n", "keras.__version__" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "np.random.seed(42)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz\n", "10903552/11490434 [===========================>..] - ETA: 0s" ] } ], "source": [ "(X_train, y_train), (X_test, y_test) = mnist.load_data()" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Картинки 28х28, развернем каждую в вектор длины 784." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "data": { "text/plain": [ "((60000, 28, 28), (10000, 28, 28))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.shape, X_test.shape" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "X_train = X_train.reshape(X_train.shape[0], 784).astype('float32')\n", "X_test = X_test.reshape(X_test.shape[0], 784).astype('float32')" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Масштабируем данные." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "scaler = StandardScaler()\n", "X_train_scaled = scaler.fit_transform(X_train)\n", "X_test_scaled = scaler.transform(X_test)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "y_train = np_utils.to_categorical(y_train, 10)\n", "y_test = np_utils.to_categorical(y_test, 10)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "data": { "text/plain": [ "array([[ 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n", " [ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train[:3, :]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "# Build the whole network in a single cell: re-running it always starts from\n", "# a fresh Sequential model, so layers can never be appended twice by accident.\n", "model = Sequential()\n", "# Hidden layer: 784 input pixels -> 800 ReLU units.\n", "model.add(Dense(800, input_dim=784, kernel_initializer='normal', activation='relu'))\n", "# Output layer: one softmax unit per digit class.\n", "model.add(Dense(10, kernel_initializer='normal', activation='softmax'))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "model.compile(loss='categorical_crossentropy', optimizer='SGD', metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "dense_1 (Dense) (None, 800) 628000 \n", "_________________________________________________________________\n", "dense_2 (Dense) (None, 10) 8010 \n", "_________________________________________________________________\n", "dense_3 (Dense) (None, 800) 8800 \n", "_________________________________________________________________\n", "dense_4 (Dense) (None, 10) 8010 \n", "=================================================================\n", "Total params: 652,820\n", "Trainable params: 652,820\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "# summary() prints the table itself and returns None, so it is not wrapped in print().\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Визуализация модели (нужно выполнить pip install pydot-ng)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "G\n", "\n", "\n", "139891033145040\n", "\n", "dense_1_input: InputLayer\n", "\n", "input:\n", "\n", "output:\n", "\n", "(None, 784)\n", "\n", "(None, 784)\n", "\n", "\n", "139891033144464\n", "\n", "dense_1: Dense\n", "\n", "input:\n", "\n", "output:\n", "\n", "(None, 784)\n", "\n", "(None, 800)\n", "\n", "\n", "139891033145040->139891033144464\n", "\n", "\n", "\n", "\n", "139890994950288\n", "\n", "dense_2: Dense\n", "\n", "input:\n", "\n", "output:\n", "\n", "(None, 800)\n", "\n", "(None, 10)\n", "\n", "\n", "139891033144464->139890994950288\n", "\n", "\n", "\n", "\n", "139890995173520\n", "\n", "dense_3: Dense\n", "\n", "input:\n", "\n", "output:\n", "\n", "(None, 10)\n", "\n", "(None, 800)\n", "\n", "\n", "139890994950288->139890995173520\n", "\n", "\n", "\n", "\n", "139890995174288\n", "\n", 
"dense_4: Dense\n", "\n", "input:\n", "\n", "output:\n", "\n", "(None, 800)\n", "\n", "(None, 10)\n", "\n", "\n", "139890995173520->139890995174288\n", "\n", "\n", "\n", "\n", "" ], "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import SVG\n", "from keras.utils.vis_utils import model_to_dot\n", "\n", "SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Callbacks\n", "### Ранняя остановка\n", "Нужно также сказать несколько слов о такой важной особенности Keras, как колбеки. Через них реализовано много полезной функциональности. Например, если вы тренируете сеть в течение очень долгого времени, вам нужно понять, когда пора остановиться, если ошибка на вашем датасете перестала уменьшаться. По-английски описываемая функциональность называется \"early stopping\" (\"ранняя остановка\")." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "from keras.callbacks import EarlyStopping \n", "# Monitor the held-out validation loss (fit() is called with validation_split),\n", "# not the training loss: the training loss keeps falling while the model overfits.\n", "# patience tolerates a few noisy epochs before training is stopped.\n", "early_stopping = EarlyStopping(monitor='val_loss', patience=3)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Tensorboard\n", "Еще в качестве колбека можно использовать сохранение логов в формате, удобном для Tensorboard — это специальная утилита для обработки и визуализации информации из логов Tensorflow.\n", "После того, как обучение закончится (или даже в процессе!), вы можете запустить Tensorboard, указав абсолютный путь к директории с логами:\n", "tensorboard --logdir=/path/to/logs" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "from keras.callbacks import TensorBoard \n", "tensorboard = TensorBoard(log_dir='../logs/', write_graph=True)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 48000 samples, validate on 12000 samples\n", "Epoch 1/100\n", "4s - loss: 0.5569 - acc: 0.8703 - val_loss: 0.6286 - val_acc: 0.8392\n", "Epoch 2/100\n", "4s - loss: 0.5331 - acc: 0.8825 - val_loss: 0.6521 - val_acc: 0.8056\n", "Epoch 3/100\n", "4s - loss: 0.5124 - acc: 0.8879 - val_loss: 0.5859 - val_acc: 0.8623\n", "Epoch 4/100\n", "4s - loss: 0.4892 - acc: 0.8970 - val_loss: 0.5624 - val_acc: 0.8705\n", "Epoch 5/100\n", "4s - loss: 0.4695 - acc: 0.9018 - val_loss: 0.5454 - val_acc: 0.8783\n", "Epoch 6/100\n", "4s - loss: 0.4479 - acc: 0.9068 - val_loss: 0.5229 - val_acc: 0.8814\n", "Epoch 7/100\n", "4s - loss: 0.4141 - acc: 0.9184 - val_loss: 0.5296 - val_acc: 0.8748\n", "Epoch 8/100\n", "4s - loss: 0.3972 - acc: 0.9211 - val_loss: 0.5017 - val_acc: 0.8826\n", "Epoch 9/100\n", "4s - loss: 
0.3762 - acc: 0.9268 - val_loss: 0.4880 - val_acc: 0.8828\n", "Epoch 10/100\n", "4s - loss: 0.3564 - acc: 0.9335 - val_loss: 0.4799 - val_acc: 0.8811\n", "Epoch 11/100\n", "4s - loss: 0.3405 - acc: 0.9365 - val_loss: 0.4492 - val_acc: 0.8964\n", "Epoch 12/100\n", "4s - loss: 0.3232 - acc: 0.9414 - val_loss: 0.4373 - val_acc: 0.8974\n", "Epoch 13/100\n", "4s - loss: 0.3100 - acc: 0.9435 - val_loss: 0.4296 - val_acc: 0.8984\n", "Epoch 14/100\n", "4s - loss: 0.3022 - acc: 0.9439 - val_loss: 0.4260 - val_acc: 0.8979\n", "Epoch 15/100\n", "4s - loss: 0.2862 - acc: 0.9485 - val_loss: 0.4138 - val_acc: 0.9018\n", "Epoch 16/100\n", "4s - loss: 0.2721 - acc: 0.9518 - val_loss: 0.4070 - val_acc: 0.9023\n", "Epoch 17/100\n", "4s - loss: 0.2571 - acc: 0.9555 - val_loss: 0.4056 - val_acc: 0.9024\n", "Epoch 18/100\n", "4s - loss: 0.2493 - acc: 0.9563 - val_loss: 0.4482 - val_acc: 0.8796\n", "Epoch 19/100\n", "4s - loss: 0.2418 - acc: 0.9570 - val_loss: 0.3796 - val_acc: 0.9093\n", "Epoch 20/100\n", "5s - loss: 0.2295 - acc: 0.9616 - val_loss: 0.3772 - val_acc: 0.9117\n", "Epoch 21/100\n", "4s - loss: 0.2222 - acc: 0.9615 - val_loss: 0.4037 - val_acc: 0.8967\n", "Epoch 22/100\n", "4s - loss: 0.2116 - acc: 0.9644 - val_loss: 0.3631 - val_acc: 0.9149\n", "Epoch 23/100\n", "4s - loss: 0.2025 - acc: 0.9651 - val_loss: 0.3645 - val_acc: 0.9127\n", "Epoch 24/100\n", "4s - loss: 0.1950 - acc: 0.9668 - val_loss: 0.3583 - val_acc: 0.9146\n", "Epoch 25/100\n", "4s - loss: 0.1936 - acc: 0.9666 - val_loss: 0.3562 - val_acc: 0.9158\n", "Epoch 26/100\n", "4s - loss: 0.1830 - acc: 0.9693 - val_loss: 0.3564 - val_acc: 0.9127\n", "Epoch 27/100\n", "4s - loss: 0.1747 - acc: 0.9707 - val_loss: 0.3471 - val_acc: 0.9167\n", "Epoch 28/100\n", "4s - loss: 0.1705 - acc: 0.9712 - val_loss: 0.3482 - val_acc: 0.9179\n", "Epoch 29/100\n", "4s - loss: 0.1626 - acc: 0.9732 - val_loss: 0.3422 - val_acc: 0.9197\n", "Epoch 30/100\n", "4s - loss: 0.1581 - acc: 0.9746 - val_loss: 0.3423 - val_acc: 0.9192\n", 
"Epoch 31/100\n", "4s - loss: 0.1523 - acc: 0.9753 - val_loss: 0.3464 - val_acc: 0.9158\n", "Epoch 32/100\n", "4s - loss: 0.1467 - acc: 0.9762 - val_loss: 0.3396 - val_acc: 0.9194\n", "Epoch 33/100\n", "4s - loss: 0.1430 - acc: 0.9763 - val_loss: 0.3377 - val_acc: 0.9190\n", "Epoch 34/100\n", "4s - loss: 0.1371 - acc: 0.9783 - val_loss: 0.3325 - val_acc: 0.9211\n", "Epoch 35/100\n", "4s - loss: 0.1348 - acc: 0.9787 - val_loss: 0.3368 - val_acc: 0.9192\n", "Epoch 36/100\n", "4s - loss: 0.1287 - acc: 0.9800 - val_loss: 0.3310 - val_acc: 0.9213\n", "Epoch 37/100\n", "5s - loss: 0.1279 - acc: 0.9793 - val_loss: 0.3345 - val_acc: 0.9198\n", "Epoch 38/100\n", "5s - loss: 0.1243 - acc: 0.9800 - val_loss: 0.3408 - val_acc: 0.9179\n", "Epoch 39/100\n", "4s - loss: 0.1193 - acc: 0.9814 - val_loss: 0.3307 - val_acc: 0.9202\n", "Epoch 40/100\n", "4s - loss: 0.1152 - acc: 0.9823 - val_loss: 0.3295 - val_acc: 0.9209\n", "Epoch 41/100\n", "5s - loss: 0.1143 - acc: 0.9823 - val_loss: 0.3320 - val_acc: 0.9186\n", "Epoch 42/100\n", "5s - loss: 0.1097 - acc: 0.9831 - val_loss: 0.3315 - val_acc: 0.9213\n", "Epoch 43/100\n", "5s - loss: 0.1085 - acc: 0.9833 - val_loss: 0.3339 - val_acc: 0.9207\n", "Epoch 44/100\n", "5s - loss: 0.1060 - acc: 0.9836 - val_loss: 0.3398 - val_acc: 0.9178\n", "Epoch 45/100\n", "4s - loss: 0.1031 - acc: 0.9846 - val_loss: 0.3299 - val_acc: 0.9214\n", "Epoch 46/100\n", "4s - loss: 0.1003 - acc: 0.9850 - val_loss: 0.3347 - val_acc: 0.9179\n", "Epoch 47/100\n", "5s - loss: 0.0979 - acc: 0.9856 - val_loss: 0.3342 - val_acc: 0.9204\n", "Epoch 48/100\n", "5s - loss: 0.0970 - acc: 0.9858 - val_loss: 0.3282 - val_acc: 0.9215\n", "Epoch 49/100\n", "5s - loss: 0.0946 - acc: 0.9860 - val_loss: 0.3316 - val_acc: 0.9200\n", "Epoch 50/100\n", "4s - loss: 0.0924 - acc: 0.9861 - val_loss: 0.3309 - val_acc: 0.9210\n", "Epoch 51/100\n", "4s - loss: 0.0898 - acc: 0.9875 - val_loss: 0.3294 - val_acc: 0.9222\n", "Epoch 52/100\n", "5s - loss: 0.0878 - acc: 0.9874 - val_loss: 
0.3385 - val_acc: 0.9185\n", "Epoch 53/100\n", "4s - loss: 0.0859 - acc: 0.9879 - val_loss: 0.3258 - val_acc: 0.9207\n", "Epoch 54/100\n", "4s - loss: 0.0853 - acc: 0.9878 - val_loss: 0.3314 - val_acc: 0.9211\n", "Epoch 55/100\n", "5s - loss: 0.0824 - acc: 0.9885 - val_loss: 0.3322 - val_acc: 0.9214\n", "Epoch 56/100\n", "5s - loss: 0.0826 - acc: 0.9883 - val_loss: 0.3485 - val_acc: 0.9147\n", "CPU times: user 23min 40s, sys: 10min 57s, total: 34min 38s\n", "Wall time: 4min 38s\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "model.fit(X_train_scaled, y_train, batch_size=200, epochs=100, \n", " validation_split=0.2, callbacks=[early_stopping, tensorboard], verbose=2);" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "C помощью Tensorboard удобно отслеживать процесс обучения нейронной сети. " ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Оцениваем качество обучения сети на тестовых данных" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Доля верных ответов на тестовых данных: 91.04%\n" ] } ], "source": [ "scores = model.evaluate(X_test_scaled, y_test, verbose=0)\n", "print(\"Доля верных ответов на тестовых данных: %.2f%%\" % (scores[1]*100))" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Сохраняем сеть\n", "**Архитектуру – в JSON-файл**" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "model_json = model.to_json()\n", "with open(\"mnist_model.json\", \"w\") as json_file:\n", " 
json_file.write(model_json)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "import pprint\n", "import json" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{u'backend': u'tensorflow',\n", " u'class_name': u'Sequential',\n", " u'config': [{u'class_name': u'Dense',\n", " u'config': {u'activation': u'relu',\n", " u'activity_regularizer': None,\n", " u'batch_input_shape': [None, 784],\n", " u'bias_constraint': None,\n", " u'bias_initializer': {u'class_name': u'Zeros',\n", " u'config': {}},\n", " u'bias_regularizer': None,\n", " u'dtype': u'float32',\n", " u'kernel_constraint': None,\n", " u'kernel_initializer': {u'class_name': u'RandomNormal',\n", " u'config': {u'mean': 0.0,\n", " u'seed': None,\n", " u'stddev': 0.05}},\n", " u'kernel_regularizer': None,\n", " u'name': u'dense_1',\n", " u'trainable': True,\n", " u'units': 800,\n", " u'use_bias': True}},\n", " {u'class_name': u'Dense',\n", " u'config': {u'activation': u'softmax',\n", " u'activity_regularizer': None,\n", " u'bias_constraint': None,\n", " u'bias_initializer': {u'class_name': u'Zeros',\n", " u'config': {}},\n", " u'bias_regularizer': None,\n", " u'kernel_constraint': None,\n", " u'kernel_initializer': {u'class_name': u'RandomNormal',\n", " u'config': {u'mean': 0.0,\n", " u'seed': None,\n", " u'stddev': 0.05}},\n", " u'kernel_regularizer': None,\n", " u'name': u'dense_2',\n", " u'trainable': True,\n", " u'units': 10,\n", " u'use_bias': True}},\n", " {u'class_name': u'Dense',\n", " u'config': {u'activation': u'relu',\n", " u'activity_regularizer': None,\n", " u'batch_input_shape': [None, 784],\n", " u'bias_constraint': None,\n", " u'bias_initializer': {u'class_name': u'Zeros',\n", " u'config': {}},\n", " u'bias_regularizer': None,\n", " u'dtype': u'float32',\n", " 
u'kernel_constraint': None,\n", " u'kernel_initializer': {u'class_name': u'RandomNormal',\n", " u'config': {u'mean': 0.0,\n", " u'seed': None,\n", " u'stddev': 0.05}},\n", " u'kernel_regularizer': None,\n", " u'name': u'dense_3',\n", " u'trainable': True,\n", " u'units': 800,\n", " u'use_bias': True}},\n", " {u'class_name': u'Dense',\n", " u'config': {u'activation': u'softmax',\n", " u'activity_regularizer': None,\n", " u'bias_constraint': None,\n", " u'bias_initializer': {u'class_name': u'Zeros',\n", " u'config': {}},\n", " u'bias_regularizer': None,\n", " u'kernel_constraint': None,\n", " u'kernel_initializer': {u'class_name': u'RandomNormal',\n", " u'config': {u'mean': 0.0,\n", " u'seed': None,\n", " u'stddev': 0.05}},\n", " u'kernel_regularizer': None,\n", " u'name': u'dense_4',\n", " u'trainable': True,\n", " u'units': 10,\n", " u'use_bias': True}}],\n", " u'keras_version': u'2.0.5'}\n" ] } ], "source": [ "with open(\"mnist_model.json\", \"r\") as json_file:\n", " pprint.pprint(json.loads(json_file.read()))" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "**Веса сохраняем в бинарный HDF5-файл**" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "model.save_weights(\"mnist_model.h5\")" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "**Теперь сеть можно восстановить и использовать**" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "from keras.models import model_from_json\n", "\n", "with open(\"mnist_model.json\") as json_file:\n", " loaded_model_json = json_file.read()\n", "\n", "# Создаем модель на основе загруженных данных\n", "loaded_model = model_from_json(loaded_model_json)\n", "# Загружаем веса в модель\n", "loaded_model.load_weights(\"mnist_model.h5\")" ] }, 
{ "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Перед использованием модели ее обязательно нужно скомпилировать. " ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "# Compile the network restored from disk (not the original in-memory `model`).\n", "loaded_model.compile(loss='categorical_crossentropy', optimizer='SGD', metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Доля верных ответов на тестовых данных: 91.04%\n" ] } ], "source": [ "# Evaluate the restored network to verify that the save/load round trip worked.\n", "scores = loaded_model.evaluate(X_test_scaled, y_test, verbose=0)\n", "print(\"Доля верных ответов на тестовых данных: %.2f%%\" % (scores[1]*100))" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda root]", "language": "python", "name": "conda-root-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.13" } }, "nbformat": 4, "nbformat_minor": 2 }