{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Recognizing hand-written digits\n", "\n", "## Introduction\n", "\n", "This notebook adapts the scikit-learn example of support vector classification ([https://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html#sphx-glr-auto-examples-classification-plot-digits-classification-py](https://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html#sphx-glr-auto-examples-classification-plot-digits-classification-py)) to PyRCN to demonstrate how PyRCN can be used to classify hand-written digits.\n", "\n", "The tutorial is based on NumPy, SciPy, scikit-learn and PyRCN." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import time\n", "from sklearn.base import clone\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.model_selection import (\n", " ParameterGrid, RandomizedSearchCV, cross_validate)\n", "from scipy.stats import uniform, loguniform\n", "from sklearn.metrics import make_scorer\n", "\n", "from pyrcn.model_selection import SequentialSearchCV\n", "from pyrcn.echo_state_network import ESNClassifier\n", "from pyrcn.metrics import accuracy_score\n", "from pyrcn.datasets import load_digits" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the dataset\n", "\n", "The dataset is already part of scikit-learn and consists of 1797 8x8 images.\n", "\n", "We use the PyRCN data loader, which is derived from scikit-learn's data loader and returns arrays of 8x8 sequences and the corresponding labels." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of digits: 1797\n", "Shape of digits (8, 8)\n" ] } ], "source": [ "X, y = load_digits(return_X_y=True, as_sequence=True)\n", "print(\"Number of digits: {0}\".format(len(X)))\n", "print(\"Shape of digits {0}\".format(X[0].shape))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split the dataset into training and test sets\n", "\n", "Next, we split the dataset into training and test sets. We train the ESN using 80% of the digits and test it using the remaining 20%." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of digits in training set: 1437\n", "Shape of digits in training set: (8, 8)\n", "Number of digits in test set: 360\n", "Shape of digits in test set: (8, 8)\n" ] } ], "source": [ "# One class label per sequence, used to stratify the split\n", "stratify = np.asarray([np.unique(yt) for yt in y]).flatten()\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.2, stratify=stratify, random_state=42)\n", "X_tr = np.copy(X_train)\n", "y_tr = np.copy(y_train)\n", "X_te = np.copy(X_test)\n", "y_te = np.copy(y_test)\n", "# Repeat each label 8 times so that every time step of a sequence has a target\n", "for k, _ in enumerate(y_tr):\n", " y_tr[k] = np.repeat(y_tr[k], 8, 0)\n", "for k, _ in enumerate(y_te):\n", " y_te[k] = np.repeat(y_te[k], 8, 0)\n", "\n", "print(\"Number of digits in training set: {0}\".format(len(X_train)))\n", "print(\"Shape of digits in training set: {0}\".format(X_train[0].shape))\n", "print(\"Number of digits in test set: {0}\".format(len(X_test)))\n", "print(\"Shape of digits in test set: {0}\".format(X_test[0].shape))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set up an ESN\n", "\n", "To develop an ESN model for digit recognition, we need to tune several hyper-parameters, e.g., input_scaling, spectral_radius, bias_scaling and leakage.\n", "\n", "We follow the approach proposed in the 
introductory paper of PyRCN to optimize hyper-parameters sequentially.\n", "\n", "We define the search spaces for each step together with the type of search (a randomized search in this context).\n", "\n", "Finally, we initialize an ESNClassifier with the desired decision strategy and the initially fixed parameters." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "initially_fixed_params = {\n", " 'hidden_layer_size': 50, 'input_activation': 'identity', 'k_in': 5,\n", " 'bias_scaling': 0.0, 'reservoir_activation': 'tanh', 'leakage': 1.0,\n", " 'bidirectional': False, 'k_rec': 10, 'continuation': False, 'alpha': 1e-5,\n", " 'random_state': 42, 'decision_strategy': \"winner_takes_all\"}\n", "\n", "step1_esn_params = {'input_scaling': uniform(loc=1e-2, scale=1),\n", " 'spectral_radius': uniform(loc=0, scale=2)}\n", "\n", "step2_esn_params = {'leakage': loguniform(1e-5, 1e0)}\n", "step3_esn_params = {'bias_scaling': uniform(loc=0, scale=2)}\n", "step4_esn_params = {'alpha': loguniform(1e-5, 1e0)}\n", "\n", "kwargs_step1 = {'n_iter': 200, 'random_state': 42, 'verbose': 10, 'n_jobs': -1,\n", " 'scoring': make_scorer(accuracy_score)}\n", "kwargs_step2 = {'n_iter': 50, 'random_state': 42, 'verbose': 1, 'n_jobs': -1,\n", " 'scoring': make_scorer(accuracy_score)}\n", "kwargs_step3 = {'verbose': 1, 'n_jobs': -1,\n", " 'scoring': make_scorer(accuracy_score)}\n", "kwargs_step4 = {'n_iter': 50, 'random_state': 42, 'verbose': 1, 'n_jobs': -1,\n", " 'scoring': make_scorer(accuracy_score)}\n", "\n", "# The searches are defined similarly to the steps of a sklearn.pipeline.Pipeline:\n", "searches = [('step1', RandomizedSearchCV, step1_esn_params, kwargs_step1),\n", " ('step2', RandomizedSearchCV, step2_esn_params, kwargs_step2),\n", " ('step3', RandomizedSearchCV, step3_esn_params, kwargs_step3),\n", " ('step4', RandomizedSearchCV, step4_esn_params, kwargs_step4)]\n", "\n", "base_esn = ESNClassifier(**initially_fixed_params)" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "## Optimization\n", "\n", "PyRCN provides SequentialSearchCV, which iterates through the list of searches defined above. It can be combined with any model selection tool from scikit-learn." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 5 folds for each of 200 candidates, totalling 1000 fits\n", "Fitting 5 folds for each of 50 candidates, totalling 250 fits\n", "Fitting 5 folds for each of 10 candidates, totalling 50 fits\n", "Fitting 5 folds for each of 50 candidates, totalling 250 fits\n" ] } ], "source": [ "sequential_search = SequentialSearchCV(base_esn, searches=searches).fit(\n", " X_tr, y_tr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Use the ESN with final hyper-parameters\n", "\n", "After the optimization, we extract the ESN with the final hyper-parameters." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "base_esn = sequential_search.best_estimator_" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'bias_scaling': 1.500499867548985,\n", " 'bias_shift': 0.0,\n", " 'hidden_layer_size': 50,\n", " 'input_activation': 'identity',\n", " 'input_scaling': 0.016952130531190705,\n", " 'input_shift': 0.0,\n", " 'k_in': 5,\n", " 'predefined_bias_weights': None,\n", " 'predefined_input_weights': None,\n", " 'random_state': 42,\n", " 'sparsity': 0.2,\n", " 'bidirectional': False,\n", " 'k_rec': 10,\n", " 'leakage': 0.026373339933815243,\n", " 'predefined_recurrent_weights': None,\n", " 'reservoir_activation': 'tanh',\n", " 'spectral_radius': 1.0214946051551315,\n", " 'alpha': 1.2674255898937214e-05}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "base_esn.get_params()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'mean_fit_time': array([2.6532074 , 2.78561602, 2.60994244, 2.60310826, 2.60697007,\n", " 2.64569302, 2.61672802, 2.65451388, 2.60384569, 2.65692558,\n", " 2.65675936, 2.82678428, 2.80295968, 2.59712319, 2.6227159 ,\n", " 2.84535923, 2.58692093, 2.6632936 , 2.57469478, 2.56074982,\n", " 2.59788303, 2.588311 , 2.58742948, 2.59293227, 2.65385418,\n", " 2.59408164, 2.53729944, 2.56438246, 2.56678057, 2.5659205 ,\n", " 2.57332883, 2.58767653, 2.5798398 , 2.57694831, 2.61877756,\n", " 2.71321197, 2.86083097, 2.80925045, 3.00275102, 2.93799133,\n", " 2.87340975, 2.7778234 , 2.76509657, 2.81323166, 2.90686755,\n", " 2.84235563, 2.84695301, 2.85544257, 2.98605399, 2.61845927]),\n", " 'std_fit_time': array([0.14858599, 0.12348577, 0.03306796, 0.08086142, 0.03734891,\n", " 0.0695286 , 0.08522908, 0.07412123, 0.03764825, 0.05583214,\n", " 0.10757 , 0.1773657 , 0.16576643, 0.06928259, 0.13255758,\n", " 0.19700425, 0.0754604 
, 0.09600206, 0.07056154, 0.09768716,\n", " 0.08431972, 0.04275456, 0.04707553, 0.06725929, 0.05697506,\n", " 0.0577874 , 0.04623621, 0.03177068, 0.01874206, 0.05587149,\n", " 0.05580291, 0.04256553, 0.03625484, 0.04998939, 0.0781173 ,\n", " 0.11202149, 0.16487397, 0.10515313, 0.22132628, 0.14726933,\n", " 0.03476583, 0.0293329 , 0.03798615, 0.09745271, 0.07042488,\n", " 0.04373445, 0.07296759, 0.08742276, 0.10569645, 0.22623521]),\n", " 'mean_score_time': array([0.45047684, 0.45714602, 0.4183434 , 0.44898777, 0.45384693,\n", " 0.46315026, 0.43620358, 0.46164112, 0.4391418 , 0.4501318 ,\n", " 0.45846133, 0.43599329, 0.44054184, 0.46960073, 0.46084428,\n", " 0.49376464, 0.4295187 , 0.42602949, 0.43847399, 0.42765632,\n", " 0.43435025, 0.42474437, 0.44880104, 0.43577037, 0.43333535,\n", " 0.42483249, 0.43919163, 0.42167859, 0.42816215, 0.45055833,\n", " 0.41359997, 0.45228491, 0.44458857, 0.43079457, 0.44092717,\n", " 0.45220747, 0.47940674, 0.45597787, 0.47957048, 0.46685867,\n", " 0.44076443, 0.45997396, 0.46912503, 0.49164128, 0.4674027 ,\n", " 0.49251075, 0.45494628, 0.51233678, 0.45821381, 0.34683628]),\n", " 'std_score_time': array([0.01720297, 0.04384708, 0.02255728, 0.03739307, 0.01739511,\n", " 0.02808892, 0.01232231, 0.03649716, 0.02363613, 0.02076977,\n", " 0.042372 , 0.01075888, 0.00828188, 0.0283921 , 0.04988519,\n", " 0.03591959, 0.01403997, 0.00931752, 0.01194943, 0.01462739,\n", " 0.01634563, 0.01704705, 0.03384312, 0.03549023, 0.03455489,\n", " 0.0084412 , 0.01340478, 0.00975513, 0.01267895, 0.02809188,\n", " 0.0162326 , 0.03091601, 0.01952567, 0.00924466, 0.01567532,\n", " 0.01916811, 0.02552566, 0.00673686, 0.04554566, 0.03245472,\n", " 0.01283099, 0.03262466, 0.02411342, 0.02191526, 0.00963087,\n", " 0.03399785, 0.01723422, 0.08456728, 0.02803731, 0.0318799 ]),\n", " 'param_alpha': masked_array(data=[0.0007459343285726547, 0.566984951147885,\n", " 0.0457056309980145, 0.009846738873614562,\n", " 6.0268891286825045e-05, 6.025215736203858e-05,\n", " 
1.951722464144947e-05, 0.21423021757741043,\n", " 0.010129197956845729, 0.034702669886504146,\n", " 1.2674255898937214e-05, 0.7072114131472232,\n", " 0.1452824663751603, 0.00011526449540315612,\n", " 8.111941985431919e-05, 8.260808399079601e-05,\n", " 0.00033205591037519585, 0.00420515645091387,\n", " 0.0014445251022763056, 0.0002858549394196192,\n", " 0.01146210740342503, 4.98275235707645e-05,\n", " 0.00028888383623653144, 0.0006789053271698484,\n", " 0.0019069966103000427, 0.08431013932082466,\n", " 9.962513222055098e-05, 0.0037253938395788847,\n", " 0.009163741808778781, 1.7070728830306648e-05,\n", " 0.010907475835157693, 7.122305833333864e-05,\n", " 2.1147447960615714e-05, 0.5551721685244719,\n", " 0.6732248920775336, 0.11015056790269633,\n", " 0.0003334792728637581, 3.0786517836196185e-05,\n", " 0.026373339933815243, 0.0015876781526923994,\n", " 4.075596440072869e-05, 0.0029914693021302154,\n", " 1.485739280627924e-05, 0.3520481045526039,\n", " 0.00019674328025306103, 0.020540519425388443,\n", " 0.00036187233309596217, 0.003984190594434687,\n", " 0.005414413211338522, 8.3998644459575e-05],\n", " mask=[False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False,\n", " False, False],\n", " fill_value='?',\n", " dtype=object),\n", " 'params': [{'alpha': 0.0007459343285726547},\n", " {'alpha': 0.566984951147885},\n", " {'alpha': 0.0457056309980145},\n", " {'alpha': 0.009846738873614562},\n", " {'alpha': 6.0268891286825045e-05},\n", " {'alpha': 6.025215736203858e-05},\n", " {'alpha': 1.951722464144947e-05},\n", " {'alpha': 0.21423021757741043},\n", " {'alpha': 0.010129197956845729},\n", " {'alpha': 0.034702669886504146},\n", " {'alpha': 1.2674255898937214e-05},\n", 
" {'alpha': 0.7072114131472232},\n", " {'alpha': 0.1452824663751603},\n", " {'alpha': 0.00011526449540315612},\n", " {'alpha': 8.111941985431919e-05},\n", " {'alpha': 8.260808399079601e-05},\n", " {'alpha': 0.00033205591037519585},\n", " {'alpha': 0.00420515645091387},\n", " {'alpha': 0.0014445251022763056},\n", " {'alpha': 0.0002858549394196192},\n", " {'alpha': 0.01146210740342503},\n", " {'alpha': 4.98275235707645e-05},\n", " {'alpha': 0.00028888383623653144},\n", " {'alpha': 0.0006789053271698484},\n", " {'alpha': 0.0019069966103000427},\n", " {'alpha': 0.08431013932082466},\n", " {'alpha': 9.962513222055098e-05},\n", " {'alpha': 0.0037253938395788847},\n", " {'alpha': 0.009163741808778781},\n", " {'alpha': 1.7070728830306648e-05},\n", " {'alpha': 0.010907475835157693},\n", " {'alpha': 7.122305833333864e-05},\n", " {'alpha': 2.1147447960615714e-05},\n", " {'alpha': 0.5551721685244719},\n", " {'alpha': 0.6732248920775336},\n", " {'alpha': 0.11015056790269633},\n", " {'alpha': 0.0003334792728637581},\n", " {'alpha': 3.0786517836196185e-05},\n", " {'alpha': 0.026373339933815243},\n", " {'alpha': 0.0015876781526923994},\n", " {'alpha': 4.075596440072869e-05},\n", " {'alpha': 0.0029914693021302154},\n", " {'alpha': 1.485739280627924e-05},\n", " {'alpha': 0.3520481045526039},\n", " {'alpha': 0.00019674328025306103},\n", " {'alpha': 0.020540519425388443},\n", " {'alpha': 0.00036187233309596217},\n", " {'alpha': 0.003984190594434687},\n", " {'alpha': 0.005414413211338522},\n", " {'alpha': 8.3998644459575e-05}],\n", " 'split0_test_score': array([0.64236111, 0.37890625, 0.42751736, 0.51128472, 0.66666667,\n", " 0.66666667, 0.67100694, 0.390625 , 0.50954861, 0.44010417,\n", " 0.67491319, 0.37934028, 0.39366319, 0.66189236, 0.66536458,\n", " 0.66493056, 0.6484375 , 0.55902778, 0.61588542, 0.64973958,\n", " 0.50173611, 0.66796875, 0.64930556, 0.64409722, 0.60590278,\n", " 0.40581597, 0.6640625 , 0.56857639, 0.51432292, 0.671875 ,\n", " 0.50347222, 0.66666667, 0.67057292, 
0.37890625, 0.37847222,\n", " 0.40017361, 0.6484375 , 0.66883681, 0.453125 , 0.61241319,\n", " 0.66710069, 0.58029514, 0.67274306, 0.38107639, 0.65451389,\n", " 0.46397569, 0.64713542, 0.56467014, 0.55251736, 0.66493056]),\n", " 'split1_test_score': array([0.64322917, 0.41666667, 0.4609375 , 0.52690972, 0.66970486,\n", " 0.66970486, 0.67881944, 0.43098958, 0.52473958, 0.47178819,\n", " 0.68098958, 0.41102431, 0.43967014, 0.66536458, 0.66840278,\n", " 0.66883681, 0.65321181, 0.57942708, 0.63064236, 0.65625 ,\n", " 0.52039931, 0.671875 , 0.65538194, 0.64322917, 0.61848958,\n", " 0.44444444, 0.66710069, 0.58723958, 0.53211806, 0.67925347,\n", " 0.52300347, 0.66883681, 0.67708333, 0.41710069, 0.41189236,\n", " 0.44314236, 0.65321181, 0.67534722, 0.48263889, 0.62890625,\n", " 0.67491319, 0.59852431, 0.6796875 , 0.42361111, 0.66102431,\n", " 0.49435764, 0.65234375, 0.58246528, 0.56206597, 0.66883681]),\n", " 'split2_test_score': array([0.6141115 , 0.38022648, 0.43205575, 0.50827526, 0.65156794,\n", " 0.65156794, 0.65766551, 0.40374564, 0.50609756, 0.44337979,\n", " 0.66376307, 0.37412892, 0.41463415, 0.64329268, 0.64764808,\n", " 0.64764808, 0.62630662, 0.54703833, 0.5966899 , 0.62979094,\n", " 0.4956446 , 0.65243902, 0.62979094, 0.61716028, 0.58667247,\n", " 0.42116725, 0.64503484, 0.55662021, 0.51393728, 0.65897213,\n", " 0.50391986, 0.64982578, 0.65635889, 0.38066202, 0.37587108,\n", " 0.41811847, 0.62630662, 0.65374564, 0.45513937, 0.59320557,\n", " 0.65287456, 0.56445993, 0.66114983, 0.38850174, 0.6337108 ,\n", " 0.46689895, 0.625 , 0.55182927, 0.53571429, 0.64764808]),\n", " 'split3_test_score': array([0.6271777 , 0.43031359, 0.47343206, 0.52439024, 0.64634146,\n", " 0.64634146, 0.6533101 , 0.44642857, 0.52351916, 0.47778746,\n", " 0.65679443, 0.42595819, 0.45121951, 0.64764808, 0.64372822,\n", " 0.64416376, 0.63719512, 0.56794425, 0.61367596, 0.64155052,\n", " 0.5152439 , 0.646777 , 0.64155052, 0.63022648, 0.60191638,\n", " 0.46341463, 0.64634146, 0.57317073, 
0.5283101 , 0.6533101 ,\n", " 0.51872822, 0.64503484, 0.6533101 , 0.43074913, 0.42595819,\n", " 0.45949477, 0.63719512, 0.64982578, 0.48562718, 0.60932056,\n", " 0.64939024, 0.58188153, 0.65418118, 0.44033101, 0.64416376,\n", " 0.49433798, 0.63632404, 0.57012195, 0.55444251, 0.64416376]),\n", " 'split4_test_score': array([0.60409408, 0.41419861, 0.45862369, 0.50566202, 0.63545296,\n", " 0.63545296, 0.64067944, 0.42857143, 0.5043554 , 0.46602787,\n", " 0.64067944, 0.4033101 , 0.43423345, 0.62848432, 0.6337108 ,\n", " 0.6337108 , 0.6184669 , 0.54268293, 0.58580139, 0.61890244,\n", " 0.50130662, 0.63675958, 0.61890244, 0.60670732, 0.58057491,\n", " 0.44163763, 0.62891986, 0.54834495, 0.50783972, 0.63937282,\n", " 0.50348432, 0.63414634, 0.6402439 , 0.41463415, 0.4033101 ,\n", " 0.43728223, 0.6184669 , 0.63850174, 0.47212544, 0.58362369,\n", " 0.6380662 , 0.55879791, 0.6402439 , 0.42203833, 0.62369338,\n", " 0.47996516, 0.61672474, 0.54442509, 0.53222997, 0.63327526]),\n", " 'mean_test_score': array([0.62619471, 0.40406232, 0.45051327, 0.51530439, 0.65394678,\n", " 0.65394678, 0.66029629, 0.42007205, 0.51365206, 0.4598175 ,\n", " 0.66342794, 0.39875236, 0.42668409, 0.64933641, 0.65177089,\n", " 0.651858 , 0.63672359, 0.55922407, 0.60853901, 0.6392467 ,\n", " 0.50686611, 0.65516387, 0.63898628, 0.62828409, 0.59871122,\n", " 0.43529599, 0.65029187, 0.56679037, 0.51930562, 0.6605567 ,\n", " 0.51052162, 0.65290209, 0.65951383, 0.40441045, 0.39910079,\n", " 0.43164229, 0.63672359, 0.65725144, 0.46973117, 0.60549385,\n", " 0.65646898, 0.57679176, 0.66160109, 0.41111172, 0.64342123,\n", " 0.47990708, 0.63550559, 0.56270234, 0.54739402, 0.65177089]),\n", " 'std_test_score': array([0.01540736, 0.02074472, 0.01771556, 0.00866923, 0.01277211,\n", " 0.01277211, 0.01340428, 0.02009865, 0.00872502, 0.01525516,\n", " 0.01415579, 0.01946625, 0.02031353, 0.01333222, 0.01318527,\n", " 0.01315612, 0.01304699, 0.01346361, 0.01566089, 0.01347457,\n", " 0.00934325, 0.01311266, 0.01318936, 
0.01460465, 0.0136164 ,\n", " 0.01991613, 0.01394397, 0.01345909, 0.00927742, 0.01401038,\n", " 0.00855507, 0.01316445, 0.01304424, 0.02085063, 0.0193278 ,\n", " 0.0205572 , 0.01304699, 0.01327147, 0.01351821, 0.01574781,\n", " 0.01307603, 0.01404518, 0.01387552, 0.02255094, 0.01354214,\n", " 0.0129622 , 0.01328559, 0.01342966, 0.01146787, 0.01327687]),\n", " 'rank_test_score': array([24, 48, 41, 34, 9, 9, 4, 45, 35, 40, 1, 50, 44, 16, 13, 12, 20,\n", " 31, 25, 18, 37, 8, 19, 23, 27, 42, 15, 29, 33, 3, 36, 11, 5, 47,\n", " 49, 43, 20, 6, 39, 26, 7, 28, 2, 46, 17, 38, 22, 30, 32, 13])}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sequential_search.all_cv_results_[\"step4\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test the ESN\n", "\n", "Finally, we increase the reservoir size and compare uni- and bidirectional ESNs. Notice that the ESN benefits strongly from both a larger reservoir and the bidirectional working mode." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CV results\tFit time\tInference time\tAccuracy score\tSize[Bytes]\n", "{'fit_time': array([1.20268822, 1.31980658, 1.17576218, 1.22178006, 1.24531698]), 'score_time': array([0.20427084, 0.1787076 , 0.19799089, 0.21658278, 0.17672634]), 'test_score': array([0.91319444, 0.91666667, 0.89547038, 0.91289199, 0.87108014])}\t1.092219591140747\t0.16528797149658203\t0.925\t29892\n", "{'fit_time': array([1.34596086, 1.32152963, 1.32597208, 1.34012151, 1.38048863]), 'score_time': array([0.22114849, 0.2336669 , 0.17909884, 0.2270937 , 0.20272541]), 'test_score': array([0.94444444, 0.93402778, 0.93728223, 0.92334495, 0.91289199])}\t1.1797807216644287\t0.16232609748840332\t0.9416666666666667\t99092\n", "{'fit_time': array([1.71792817, 1.72016811, 1.70596719, 1.65824962, 1.68945742]), 'score_time': array([0.23715806, 0.23139477, 0.24706268, 0.26664233, 0.25166321]), 'test_score': array([0.95138889, 0.94791667, 0.94773519, 0.94425087, 0.91986063])}\t1.5785596370697021\t0.2425389289855957\t0.9611111111111111\t357492\n", "{'fit_time': array([4.3435142 , 4.2941072 , 4.47814965, 4.45681834, 4.38268828]), 'score_time': array([0.47901511, 0.49599099, 0.41987824, 0.42973256, 0.44941235]), 'test_score': array([0.94791667, 0.9375 , 0.94425087, 0.95470383, 0.92334495])}\t2.862438201904297\t0.22818946838378906\t0.9638888888888889\t1354292\n", "{'fit_time': array([5.95386696, 6.16443706, 6.10848475, 6.10720062, 6.09413409]), 'score_time': array([0.45308208, 0.40109324, 0.41994405, 0.4118824 , 0.42954731]), 'test_score': array([0.94791667, 0.94791667, 0.94773519, 0.95470383, 0.92334495])}\t3.914069175720215\t0.3445422649383545\t0.9611111111111111\t2092692\n", "{'fit_time': array([3.28861213, 3.14525843, 2.96652579, 3.16590023, 3.28161478]), 'score_time': array([0.70867753, 0.73422313, 0.88020873, 0.7299583 , 0.66681576]), 'test_score': 
array([0.94791667, 0.95138889, 0.94076655, 0.91986063, 0.8989547 ])}\t1.941004991531372\t0.3332488536834717\t0.9083333333333333\t98692\n", "{'fit_time': array([3.2477181 , 3.34564877, 3.25544882, 3.29266453, 3.40397859]), 'score_time': array([0.66511965, 0.68270802, 0.64036489, 0.65441465, 0.61603808]), 'test_score': array([0.97222222, 0.95833333, 0.95818815, 0.92682927, 0.93379791])}\t1.932464599609375\t0.3669440746307373\t0.9305555555555556\t356692\n", "{'fit_time': array([5.44010997, 5.43559408, 5.36100101, 5.35568452, 5.48375392]), 'score_time': array([0.6066854 , 0.59092546, 0.6019001 , 0.60729814, 0.56611323]), 'test_score': array([0.97569444, 0.96527778, 0.96864111, 0.95121951, 0.94425087])}\t3.334501028060913\t0.49420928955078125\t0.9583333333333334\t1352692\n", "{'fit_time': array([14.56774807, 14.57170606, 14.46315455, 14.58913779, 14.54430723]), 'score_time': array([0.93357658, 0.9349041 , 0.92975593, 0.94193864, 0.91288638]), 'test_score': array([0.97916667, 0.96180556, 0.97560976, 0.95818815, 0.94425087])}\t7.724725008010864\t0.6274893283843994\t0.9583333333333334\t5264692\n", "{'fit_time': array([20.9128716 , 20.93491912, 20.86150718, 21.00014853, 20.94348502]), 'score_time': array([0.81963921, 0.8159306 , 0.84687257, 0.80614424, 0.81322145]), 'test_score': array([0.97569444, 0.96527778, 0.97212544, 0.95121951, 0.95470383])}\t11.71635103225708\t0.8147547245025635\t0.9694444444444444\t8180692\n" ] } ], "source": [ "param_grid = {'hidden_layer_size': [50, 100, 200, 400, 500],\n", " 'bidirectional': [False, True]}\n", "\n", "print(\"CV results\\tFit time\\tInference time\\tAccuracy score\\tSize[Bytes]\")\n", "for params in ParameterGrid(param_grid):\n", " esn_cv = cross_validate(clone(base_esn).set_params(**params), X=X_train, y=y_train,\n", " scoring=make_scorer(accuracy_score), n_jobs=-1)\n", " t1 = time.time()\n", " esn = clone(base_esn).set_params(**params).fit(X_train, y_train)\n", " t_fit = time.time() - t1\n", " t1 = time.time()\n", " esn_par = 
clone(base_esn).set_params(**params).fit(X_train, y_train, n_jobs=-1)\n", " t_fit_par = time.time() - t1\n", " mem_size = esn.__sizeof__()\n", " t1 = time.time()\n", " acc_score = accuracy_score(y_test, esn.predict(X_test))\n", " t_inference = time.time() - t1\n", " print(f\"{esn_cv}\\t{t_fit}\\t{t_inference}\\t{acc_score}\\t{mem_size}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternatively, we can use the PyTorch implementation of the ESN model from `pyrcn.nn`, configured here with the hyper-parameters found above." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from pyrcn.nn import ESN\n", "import torch" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "base_esn = ESN(input_size=8, hidden_size=50, num_layers=1, nonlinearity='tanh',\n", " bias=True, input_scaling=0.016952130531190705,\n", " spectral_radius=1.0214946051551315, bias_scaling=1.500499867548985,\n", " bias_shift=0., input_sparsity=0.1, recurrent_sparsity=0.1,\n", " bidirectional=False)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 50])\n" ] } ], "source": [ "# Feed the first training sequence to the ESN and print the output shape\n", "# (time steps, hidden size); the label is unused, so `y` is not overwritten.\n", "for x, _ in zip(X_train, y_train):\n", " x = torch.Tensor(x).float()\n", " print(base_esn(x)[0].shape)\n", " break" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 4 }