{ "cells": [ { "cell_type": "markdown", "id": "a79eaad5-5e70-47ba-8527-f5df267d1de4", "metadata": {}, "source": [ "# Cross validation with `scikit-learn`\n", "\n", "This example shows how `pykoop` works with `scikit-learn` to allow cross-validation over all parameters in the Koopman pipeline." ] }, { "cell_type": "code", "execution_count": 1, "id": "53030cd1-edcc-4cea-b2bb-d24989ee2a2c", "metadata": {}, "outputs": [], "source": [ "# Imports\n", "import numpy as np\n", "import pandas\n", "import sklearn.model_selection\n", "import sklearn.preprocessing\n", "from matplotlib import pyplot as plt\n", "\n", "import pykoop\n", "import pykoop.dynamic_models\n", "\n", "# Set plot defaults\n", "plt.rc('lines', linewidth=2)\n", "plt.rc('axes', grid=True)\n", "plt.rc('grid', linestyle='--')" ] }, { "cell_type": "markdown", "id": "d8e334c3-9094-4a19-a961-f5a4f179d6c8", "metadata": {}, "source": [ "Load example data from the library. `eg` is a `dict` containing training data, validation data, and a few related parameters." ] }, { "cell_type": "code", "execution_count": 2, "id": "0d762dcd-4e7c-44bf-a5a4-06c6675fe2b4", "metadata": {}, "outputs": [], "source": [ "eg = pykoop.example_data_msd()" ] }, { "cell_type": "markdown", "id": "540f9278-1122-4959-8131-4f3031eca4b8", "metadata": {}, "source": [ "Create the Koopman pipeline. We don't need to set any lifting functions since they'll be set during the cross-validation." ] }, { "cell_type": "code", "execution_count": 3, "id": "a3a6f0e6-4047-4762-8b25-226bb4b30d83", "metadata": {}, "outputs": [], "source": [ "kp = pykoop.KoopmanPipeline(regressor=pykoop.Edmd())" ] }, { "cell_type": "markdown", "id": "b5d89a1c-2438-479a-b225-c48613b01995", "metadata": {}, "source": [ "Set up the cross-validation splitting. Here we use the episode feature to ensure that the split keeps individual experiments intact." 
] }, { "cell_type": "code", "execution_count": 4, "id": "e2096c1d-3817-41a5-ba8b-3b0f2bb2340a", "metadata": {}, "outputs": [], "source": [ "episode_feature = eg['X_train'][:, 0]\n", "cv = sklearn.model_selection.GroupShuffleSplit(\n", " random_state=1234,\n", " n_splits=3,\n", ").split(eg['X_train'], groups=episode_feature)" ] }, { "cell_type": "markdown", "id": "80831684-dc4a-41a3-9753-a9c80427f459", "metadata": {}, "source": [ "Choose the cross-validation parameters. We can cross-validate over entire lifting functions, or just specific regressor or lifting function parameters using the `scikit-learn` parameter setting conventions." ] }, { "cell_type": "code", "execution_count": 5, "id": "802b9497-a1a5-426e-b089-07d778207149", "metadata": {}, "outputs": [], "source": [ "params = {\n", " # Lifting functions to try\n", " 'lifting_functions': [\n", " [(\n", " 'ss',\n", " pykoop.SkLearnLiftingFn(\n", " sklearn.preprocessing.StandardScaler()),\n", " )],\n", " [\n", " (\n", " 'ma',\n", " pykoop.SkLearnLiftingFn(\n", " sklearn.preprocessing.MaxAbsScaler()),\n", " ),\n", " (\n", " 'pl',\n", " pykoop.PolynomialLiftingFn(order=2),\n", " ),\n", " (\n", " 'ss',\n", " pykoop.SkLearnLiftingFn(\n", " sklearn.preprocessing.StandardScaler()),\n", " ),\n", " ],\n", " ],\n", " # Regressor parameters to try\n", " 'regressor__alpha': [0, 0.1, 1, 10],\n", "}" ] }, { "cell_type": "markdown", "id": "b8629847-03ab-4b85-91ea-d64b1813b53c", "metadata": {}, "source": [ "Set up the grid search cross-validation. We have multiple scoring metrics over different time windows, but we use the \"ten step ahead\" prediction score for ranking and refitting." 
] }, { "cell_type": "code", "execution_count": 6, "id": "21376df6-a3c4-42c4-ab60-8e94e052ba0b", "metadata": {}, "outputs": [], "source": [ "gs = sklearn.model_selection.GridSearchCV(\n", " kp,\n", " params,\n", " cv=cv,\n", " # Score using short and long prediction time frames\n", " scoring={\n", " 'full_episode': pykoop.KoopmanPipeline.make_scorer(),\n", " 'ten_steps': pykoop.KoopmanPipeline.make_scorer(n_steps=10),\n", " },\n", " # Rank according to short time frame\n", " refit='ten_steps',\n", ")" ] }, { "cell_type": "markdown", "id": "0667f03a-e4a8-4ee4-b74f-430a54e81eeb", "metadata": {}, "source": [ "Perform the cross-validation and pick the winner." ] }, { "cell_type": "code", "execution_count": 7, "id": "07bcb7bc-0881-49da-a99c-8408028c9f86", "metadata": {}, "outputs": [], "source": [ "gs.fit(\n", " eg['X_train'],\n", " n_inputs=eg['n_inputs'],\n", " episode_feature=eg['episode_feature'],\n", ")\n", "best_estimator = gs.best_estimator_" ] }, { "cell_type": "markdown", "id": "a6236025-6f40-4d87-96a9-a1d2f4285124", "metadata": {}, "source": [ "This is the matrix approximation of the Koopman operator. It needs to be transposed because `scikit-learn` puts time on the first axis and features on the second." 
] }, { "cell_type": "code", "execution_count": 8, "id": "d506a09e-e5d4-4fdb-90e1-4d3699c51b5a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 9.91883440e-01, 8.94273868e-02, 1.79097934e-04,\n", " -3.13640303e-04, -2.95548469e-03, 9.70314627e-03,\n", " -9.98717276e-05, 4.42764949e-03, -1.60312780e-03],\n", " [-1.43815421e-01, 8.73130869e-01, -7.37399270e-04,\n", " -5.08492758e-03, -8.62861954e-03, 1.48763701e-01,\n", " 4.77555077e-03, 1.52077930e-02, -7.06918061e-03],\n", " [ 2.31449054e-04, 4.66948082e-06, 9.82608512e-01,\n", " 1.41282965e-01, 9.42406644e-03, 8.19992260e-05,\n", " 1.93299275e-02, 1.27826191e-03, -1.11644849e-03],\n", " [-1.77900783e-04, -1.75636014e-03, -1.76376390e-01,\n", " 8.48278785e-01, 9.60773702e-02, 1.96343500e-03,\n", " 1.43159804e-01, 3.21235551e-02, -2.91522952e-03],\n", " [ 4.36902662e-04, 1.68738238e-03, 2.18075712e-02,\n", " -1.93181000e-01, 7.61323258e-01, -2.19676979e-03,\n", " -3.54741406e-02, 2.35737972e-01, 1.35329508e-02]])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_estimator.regressor_.coef_.T" ] }, { "cell_type": "markdown", "id": "c8567565-3e8b-45b2-9847-90fcb9ba39f4", "metadata": {}, "source": [ "These are the cross-validation results summarized in a big table. Scores are negated mean squared errors." ] }, { "cell_type": "code", "execution_count": 9, "id": "7dc8a2d6-4c51-4186-9992-811b53832648", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_lifting_functions | param_regressor__alpha | params | split0_test_full_episode | split1_test_full_episode | split2_test_full_episode | mean_test_full_episode | std_test_full_episode | rank_test_full_episode | split0_test_ten_steps | split1_test_ten_steps | split2_test_ten_steps | mean_test_ten_steps | std_test_ten_steps | rank_test_ten_steps |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.002129 | 0.001062 | 0.094716 | 0.012392 | [(ss, SkLearnLiftingFn(transformer=StandardSca... | 0 | {'lifting_functions': [('ss', SkLearnLiftingFn... | -0.000396 | -0.000421 | -0.000683 | -0.000500 | 0.000130 | 4 | -0.000142 | -0.000153 | -0.000406 | -0.000233 | 0.000122 | 3 |
| 1 | 0.001522 | 0.000390 | 0.096214 | 0.015442 | [(ss, SkLearnLiftingFn(transformer=StandardSca... | 0.1 | {'lifting_functions': [('ss', SkLearnLiftingFn... | -0.000427 | -0.000373 | -0.000653 | -0.000485 | 0.000121 | 2 | -0.000138 | -0.000110 | -0.000361 | -0.000203 | 0.000112 | 2 |
| 2 | 0.001714 | 0.000255 | 0.094566 | 0.005554 | [(ss, SkLearnLiftingFn(transformer=StandardSca... | 1 | {'lifting_functions': [('ss', SkLearnLiftingFn... | -0.000745 | -0.000440 | -0.000645 | -0.000610 | 0.000127 | 5 | -0.000140 | -0.000507 | -0.000124 | -0.000257 | 0.000177 | 4 |
| 3 | 0.001659 | 0.000584 | 0.089189 | 0.012520 | [(ss, SkLearnLiftingFn(transformer=StandardSca... | 10 | {'lifting_functions': [('ss', SkLearnLiftingFn... | -0.001322 | -0.004338 | -0.003608 | -0.003089 | 0.001285 | 7 | -0.000290 | -0.003897 | -0.000205 | -0.001464 | 0.001721 | 7 |
| 4 | 0.002978 | 0.000635 | 0.185943 | 0.011038 | [(ma, SkLearnLiftingFn(transformer=MaxAbsScale... | 0 | {'lifting_functions': [('ma', SkLearnLiftingFn... | -0.000314 | -0.000411 | -0.000644 | -0.000456 | 0.000138 | 1 | -0.000230 | -0.000217 | -0.000369 | -0.000272 | 0.000069 | 5 |
| 5 | 0.003803 | 0.001685 | 0.240437 | 0.058409 | [(ma, SkLearnLiftingFn(transformer=MaxAbsScale... | 0.1 | {'lifting_functions': [('ma', SkLearnLiftingFn... | -0.000467 | -0.000384 | -0.000628 | -0.000493 | 0.000101 | 3 | -0.000199 | -0.000166 | -0.000120 | -0.000162 | 0.000032 | 1 |
| 6 | 0.003041 | 0.000620 | 0.240746 | 0.083650 | [(ma, SkLearnLiftingFn(transformer=MaxAbsScale... | 1 | {'lifting_functions': [('ma', SkLearnLiftingFn... | -0.002445 | -0.001038 | -0.002982 | -0.002155 | 0.000819 | 6 | -0.000429 | -0.003733 | -0.000212 | -0.001458 | 0.001611 | 6 |
| 7 | 0.005639 | 0.002675 | 0.237306 | 0.016316 | [(ma, SkLearnLiftingFn(transformer=MaxAbsScale... | 10 | {'lifting_functions': [('ma', SkLearnLiftingFn... | -0.004469 | -0.004848 | -0.003980 | -0.004432 | 0.000355 | 8 | -0.001622 | -0.007135 | -0.000664 | -0.003140 | 0.002852 | 8 |
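
Because the scores are negated mean squared errors, a score closer to zero is better, and rank 1 goes to the candidate with the highest mean score across splits. The following is a minimal sketch of that ranking logic using only the standard library. The per-split `ten_steps` values are copied from the table above (rounded to six decimals), and the variable names are illustrative rather than part of `pykoop` or `scikit-learn`. It is a simplified stand-in for what the grid search does internally, not its actual implementation, and it assumes no ties.

```python
from statistics import mean

# Per-split "ten_steps" scores (negated MSE), copied from the results table.
# One tuple per candidate, in table order (candidates 0 through 7).
ten_steps_scores = [
    (-0.000142, -0.000153, -0.000406),
    (-0.000138, -0.000110, -0.000361),
    (-0.000140, -0.000507, -0.000124),
    (-0.000290, -0.003897, -0.000205),
    (-0.000230, -0.000217, -0.000369),
    (-0.000199, -0.000166, -0.000120),
    (-0.000429, -0.003733, -0.000212),
    (-0.001622, -0.007135, -0.000664),
]

# Average each candidate's score over the three splits.
mean_scores = [mean(scores) for scores in ten_steps_scores]

# Higher (closer to zero) is better, so sort in descending order.
order = sorted(
    range(len(mean_scores)),
    key=lambda i: mean_scores[i],
    reverse=True,
)

# Convert the sorted order into a 1-based rank per candidate.
ranks = [0] * len(mean_scores)
for position, candidate in enumerate(order, start=1):
    ranks[candidate] = position

best_candidate = order[0]
print(ranks)           # [3, 2, 4, 7, 5, 1, 6, 8]
print(best_candidate)  # 5
```

The computed ranks match the `rank_test_ten_steps` column. Candidate 5 (the three-stage lifting function list with `regressor__alpha` of 0.1) ranks first, so with `refit='ten_steps'` it is the configuration the grid search refits.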