{
 "metadata": {
  "name": "TEST_theanoml_formula"
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "## Imports: everything the notebook needs, in one place\n",
      "import cPickle\n",
      "\n",
      "from sklearn.cross_validation import train_test_split\n",
      "\n",
      "import theanoml\n",
      "\n",
      "# Re-execute the package so source edits made during this session take effect.\n",
      "reload(theanoml)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "pyout",
       "prompt_number": 10,
       "text": [
        ""
       ]
      }
     ],
     "prompt_number": 10
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "## Load MNIST data\n",
      "reload(theanoml.formula)\n",
      "\n",
      "# Fixed seed so the train/validation split is identical on every run.\n",
      "SPLIT_SEED = 0\n",
      "\n",
      "# NOTE: unpickling can execute arbitrary code -- only load a trusted file.\n",
      "# The context manager closes the file handle as soon as the data is read.\n",
      "with open('data/digits.pkl', 'rb') as digits_file:\n",
      "    X, y = cPickle.load(digits_file)\n",
      "print X.shape, y.shape\n",
      "\n",
      "# Hold out 20% of the samples for validation.\n",
      "train_X, validation_X, train_y, validation_y = train_test_split(\n",
      "    X, y, test_size = 0.2, random_state = SPLIT_SEED)\n",
      "\n",
      "# Pass each split through share_data; the y arrays are requested as int32.\n",
      "v_train_X = theanoml.formula.share_data(train_X)\n",
      "v_validation_X = theanoml.formula.share_data(validation_X)\n",
      "v_train_y = theanoml.formula.share_data(train_y, dtype = 'int32')\n",
      "v_validation_y = theanoml.formula.share_data(validation_y, dtype = 'int32')"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "(42000, 784) (42000,)\n"
       ]
      }
     ],
     "prompt_number": 27
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "## TEST theanoml LogisticRegression MLP\n",
      "reload(theanoml.formula)\n",
      "\n",
      "# n_in = 28 * 28 matches the 784 columns of X printed by the loading cell.\n",
      "mlp = theanoml.formula.MLPClassifierFormula(n_in = 28 * 28, n_hidden = 500, n_out = 10)\n",
      "theanoml.formula.sgd(v_train_X, v_train_y, v_validation_X, v_validation_y, mlp)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "epoch 1, minibatch 1680 / 1680, validation error 10.511905 %\n",
        "epoch 2, minibatch 1680 / 1680, validation error 10.238095 %"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "epoch 3, minibatch 1680 / 1680, validation error 9.928571 %"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "epoch 4, minibatch 1680 / 1680, validation error 9.261905 %"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "epoch 5, minibatch 1680 / 1680, validation error 9.511905 %"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "epoch 6, minibatch 1680 / 1680, validation error 9.226190 %"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "epoch 7, minibatch 1680 / 1680, validation error 9.202381 %"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n",
        "optimization complete with best valiation error 9.202381 %"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "\n"
       ]
      }
     ],
     "prompt_number": 34
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [],
     "language": "python",
     "metadata": {},
     "outputs": []
    }
   ],
   "metadata": {}
  }
 ]
}