{ "cells": [ { "cell_type": "markdown", "id": "a79eaad5-5e70-47ba-8527-f5df267d1de4", "metadata": {}, "source": [ "# Cross validation with `scikit-learn`\n", "\n", "This example shows how `pykoop` works with `scikit-learn` to allow cross-validation over all parameters in the Koopman pipeline." ] }, { "cell_type": "code", "execution_count": 1, "id": "53030cd1-edcc-4cea-b2bb-d24989ee2a2c", "metadata": {}, "outputs": [], "source": [ "# Imports\n", "import numpy as np\n", "import pandas\n", "import sklearn.model_selection\n", "import sklearn.preprocessing\n", "from matplotlib import pyplot as plt\n", "\n", "import pykoop\n", "import pykoop.dynamic_models\n", "\n", "# Set plot defaults\n", "plt.rc('lines', linewidth=2)\n", "plt.rc('axes', grid=True)\n", "plt.rc('grid', linestyle='--')" ] }, { "cell_type": "markdown", "id": "d8e334c3-9094-4a19-a961-f5a4f179d6c8", "metadata": {}, "source": [ "Load example data from the library. `eg` is a `dict` containing training data, validation data, and a few related parameters." ] }, { "cell_type": "code", "execution_count": 2, "id": "0d762dcd-4e7c-44bf-a5a4-06c6675fe2b4", "metadata": {}, "outputs": [], "source": [ "eg = pykoop.example_data_msd()" ] }, { "cell_type": "markdown", "id": "540f9278-1122-4959-8131-4f3031eca4b8", "metadata": {}, "source": [ "Create the Koopman pipeline. We don't need to set any lifting functions since they'll be set during the cross-validation." ] }, { "cell_type": "code", "execution_count": 3, "id": "a3a6f0e6-4047-4762-8b25-226bb4b30d83", "metadata": {}, "outputs": [], "source": [ "kp = pykoop.KoopmanPipeline(regressor=pykoop.Edmd())" ] }, { "cell_type": "markdown", "id": "b5d89a1c-2438-479a-b225-c48613b01995", "metadata": {}, "source": [ "Set up the cross-validation splitting. Here we use the episode feature to ensure that the split keeps individual experiments intact." 
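, "\n", "\n", "The grouping property can be checked on a small example. The cell below is a minimal sketch on hypothetical toy data (`X_toy`, `n_episodes`, and `n_samples` are illustrative names, not part of `pykoop`): column 0 holds an episode index, and `GroupShuffleSplit` should assign every episode entirely to either the training or the test side of each split." ] }, { "cell_type": "code", "execution_count": null, "id": "3f1c2a7e-5b0d-4e8a-9c61-7d2e4b9a0f11", "metadata": {}, "outputs": [], "source": [ "# Sketch on synthetic data, not the example dataset\n", "import numpy as np\n", "import sklearn.model_selection\n", "\n", "# Hypothetical toy data: 6 episodes of 10 samples, episode index in column 0\n", "n_episodes, n_samples = 6, 10\n", "episodes = np.repeat(np.arange(n_episodes), n_samples)\n", "rng = np.random.default_rng(0)\n", "X_toy = np.column_stack([episodes, rng.normal(size=episodes.size)])\n", "\n", "splitter = sklearn.model_selection.GroupShuffleSplit(n_splits=3, random_state=1234)\n", "for train_idx, test_idx in splitter.split(X_toy, groups=X_toy[:, 0]):\n", "    train_eps = np.unique(X_toy[train_idx, 0])\n", "    test_eps = np.unique(X_toy[test_idx, 0])\n", "    # No episode contributes samples to both sides of the split\n", "    assert np.intersect1d(train_eps, test_eps).size == 0\n", "    print('test episodes:', test_eps.astype(int))" ] }, { "cell_type": "markdown", "id": "8a4d6e20-1c3b-4f57-a2e9-5b7c0d1e3f22", "metadata": {}, "source": [ "The assertion holds for every split, which is the property the split in the next cell relies on: an experiment is never divided between training and test data."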
] }, { "cell_type": "code", "execution_count": 4, "id": "e2096c1d-3817-41a5-ba8b-3b0f2bb2340a", "metadata": {}, "outputs": [], "source": [ "episode_feature = eg['X_train'][:, 0]\n", "cv = sklearn.model_selection.GroupShuffleSplit(\n", "    random_state=1234,\n", "    n_splits=3,\n", ").split(eg['X_train'], groups=episode_feature)" ] }, { "cell_type": "markdown", "id": "80831684-dc4a-41a3-9753-a9c80427f459", "metadata": {}, "source": [ "Choose the cross-validation parameters. We can cross-validate over entire lists of lifting functions, or over individual regressor or lifting function parameters using the `scikit-learn` `component__parameter` naming convention. For example, `regressor__alpha` refers to the `alpha` parameter of the pipeline's `regressor`." ] }, { "cell_type": "code", "execution_count": 5, "id": "802b9497-a1a5-426e-b089-07d778207149", "metadata": {}, "outputs": [], "source": [ "params = {\n", "    # Lifting functions to try\n", "    'lifting_functions': [\n", "        [(\n", "            'ss',\n", "            pykoop.SkLearnLiftingFn(\n", "                sklearn.preprocessing.StandardScaler()),\n", "        )],\n", "        [\n", "            (\n", "                'ma',\n", "                pykoop.SkLearnLiftingFn(\n", "                    sklearn.preprocessing.MaxAbsScaler()),\n", "            ),\n", "            (\n", "                'pl',\n", "                pykoop.PolynomialLiftingFn(order=2),\n", "            ),\n", "            (\n", "                'ss',\n", "                pykoop.SkLearnLiftingFn(\n", "                    sklearn.preprocessing.StandardScaler()),\n", "            ),\n", "        ],\n", "    ],\n", "    # Regressor parameters to try\n", "    'regressor__alpha': [0, 0.1, 1, 10],\n", "}" ] }, { "cell_type": "markdown", "id": "b8629847-03ab-4b85-91ea-d64b1813b53c", "metadata": {}, "source": [ "Set up the grid search cross-validation. Each candidate is scored over two prediction horizons, the full episode and ten steps ahead, but only the ten-step-ahead score is used for ranking and for refitting the best pipeline."
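, "\n", "\n", "To illustrate how multi-metric scoring and `refit` interact, the cell below is a minimal sketch that uses a plain `scikit-learn` `Pipeline` with a `Ridge` regressor on synthetic data in place of the Koopman pipeline. The step names `scaler` and `regressor` are hypothetical, chosen only to mirror the `regressor__alpha` naming used in the parameter grid." ] }, { "cell_type": "code", "execution_count": null, "id": "5c9e1b34-7a2f-4d08-b6e3-2f4a8c1d9e33", "metadata": {}, "outputs": [], "source": [ "# Sketch with standard scikit-learn estimators, not pykoop\n", "import numpy as np\n", "import sklearn.linear_model\n", "import sklearn.model_selection\n", "import sklearn.pipeline\n", "import sklearn.preprocessing\n", "\n", "rng = np.random.default_rng(0)\n", "X_demo = rng.normal(size=(100, 3))\n", "y_demo = X_demo @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=100)\n", "\n", "pipe = sklearn.pipeline.Pipeline([\n", "    ('scaler', sklearn.preprocessing.StandardScaler()),\n", "    ('regressor', sklearn.linear_model.Ridge()),\n", "])\n", "search = sklearn.model_selection.GridSearchCV(\n", "    pipe,\n", "    {'regressor__alpha': [0.1, 1, 10]},  # step__parameter naming\n", "    cv=3,\n", "    # Two scores are recorded for every candidate...\n", "    scoring={'mse': 'neg_mean_squared_error', 'r2': 'r2'},\n", "    # ...but only 'mse' decides best_index_ and the refit\n", "    refit='mse',\n", ")\n", "search.fit(X_demo, y_demo)\n", "# The refit candidate is the one ranked first on the refit metric\n", "assert search.cv_results_['rank_test_mse'][search.best_index_] == 1\n", "print(search.best_params_)" ] }, { "cell_type": "markdown", "id": "0e7b3d51-9f4c-4a6e-8d12-6c3f5a7b1e44", "metadata": {}, "source": [ "Both scores appear in `cv_results_`, but `best_index_`, `best_params_`, and `best_estimator_` follow the metric named in `refit`. The grid search in the next cell applies the same idea with a full-episode score and a ten-step score."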
] }, { "cell_type": "code", "execution_count": 6, "id": "21376df6-a3c4-42c4-ab60-8e94e052ba0b", "metadata": {}, "outputs": [], "source": [ "gs = sklearn.model_selection.GridSearchCV(\n", "    kp,\n", "    params,\n", "    cv=cv,\n", "    # Score using short and long prediction time frames\n", "    scoring={\n", "        'full_episode': pykoop.KoopmanPipeline.make_scorer(),\n", "        'ten_steps': pykoop.KoopmanPipeline.make_scorer(n_steps=10),\n", "    },\n", "    # Rank and refit using the short time frame\n", "    refit='ten_steps',\n", ")" ] }, { "cell_type": "markdown", "id": "0667f03a-e4a8-4ee4-b74f-430a54e81eeb", "metadata": {}, "source": [ "Perform the cross-validation and pick the winner." ] }, { "cell_type": "code", "execution_count": 7, "id": "07bcb7bc-0881-49da-a99c-8408028c9f86", "metadata": {}, "outputs": [], "source": [ "gs.fit(\n", "    eg['X_train'],\n", "    n_inputs=eg['n_inputs'],\n", "    episode_feature=eg['episode_feature'],\n", ")\n", "best_estimator = gs.best_estimator_" ] }, { "cell_type": "markdown", "id": "a6236025-6f40-4d87-96a9-a1d2f4285124", "metadata": {}, "source": [ "Below is the matrix approximation of the Koopman operator. The regressor's `coef_` attribute has to be transposed because `scikit-learn` puts time on the first axis of the data and features on the second."
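, "\n", "\n", "The transpose follows from the data layout. The cell below is a minimal sketch, assuming a hypothetical two-state linear system `A` with no inputs and no lifting (so it is not what `Edmd` does internally): with snapshots stored as rows, least squares solves `X_next ~= X_now @ coef`, so `coef.T` is the matrix that advances a column state vector, `x[k+1] = coef.T @ x[k]`." ] }, { "cell_type": "code", "execution_count": null, "id": "9b2f4c68-3e1a-4c7d-a5f0-8d6e2b4c7a55", "metadata": {}, "outputs": [], "source": [ "# Sketch: recover a known linear system from row-wise snapshot data\n", "import numpy as np\n", "\n", "# Hypothetical system acting on column vectors: x[k+1] = A @ x[k]\n", "A = np.array([[0.9, 0.1], [-0.2, 0.8]])\n", "X = np.zeros((50, 2))\n", "X[0] = [1.0, 0.5]\n", "for k in range(49):\n", "    X[k + 1] = A @ X[k]\n", "\n", "# scikit-learn layout: time on axis 0, features on axis 1\n", "X_now, X_next = X[:-1], X[1:]\n", "# Row-vector least squares: X_next ~= X_now @ coef\n", "coef, *_ = np.linalg.lstsq(X_now, X_next, rcond=None)\n", "# coef is A transposed, so coef.T acts on column vectors like A\n", "print(np.allclose(coef.T, A))" ] }, { "cell_type": "markdown", "id": "6d1a8e79-2b5c-4f3e-9a47-1e8c3d5f6b66", "metadata": {}, "source": [ "Because the sketch has no inputs, its matrix is square. The `pykoop` matrix printed next has more columns than rows, which is consistent with it also holding input terms, since the pipeline was fit with `n_inputs=eg['n_inputs']`."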
] }, { "cell_type": "code", "execution_count": 8, "id": "d506a09e-e5d4-4fdb-90e1-4d3699c51b5a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 9.91883440e-01, 8.94273868e-02, 1.79097934e-04,\n", " -3.13640303e-04, -2.95548469e-03, 9.70314627e-03,\n", " -9.98717276e-05, 4.42764949e-03, -1.60312780e-03],\n", " [-1.43815421e-01, 8.73130869e-01, -7.37399270e-04,\n", " -5.08492758e-03, -8.62861954e-03, 1.48763701e-01,\n", " 4.77555077e-03, 1.52077930e-02, -7.06918061e-03],\n", " [ 2.31449054e-04, 4.66948082e-06, 9.82608512e-01,\n", " 1.41282965e-01, 9.42406644e-03, 8.19992260e-05,\n", " 1.93299275e-02, 1.27826191e-03, -1.11644849e-03],\n", " [-1.77900783e-04, -1.75636014e-03, -1.76376390e-01,\n", " 8.48278785e-01, 9.60773702e-02, 1.96343500e-03,\n", " 1.43159804e-01, 3.21235551e-02, -2.91522952e-03],\n", " [ 4.36902662e-04, 1.68738238e-03, 2.18075712e-02,\n", " -1.93181000e-01, 7.61323258e-01, -2.19676979e-03,\n", " -3.54741406e-02, 2.35737972e-01, 1.35329508e-02]])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_estimator.regressor_.coef_.T" ] }, { "cell_type": "markdown", "id": "c8567565-3e8b-45b2-9847-90fcb9ba39f4", "metadata": {}, "source": [ "These are the cross-validation results summarized in a table. Scores are negated mean squared errors, so higher (closer to zero) is better." ] }, { "cell_type": "code", "execution_count": 9, "id": "7dc8a2d6-4c51-4186-9992-811b53832648", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " mean_fit_time std_fit_time mean_score_time std_score_time \\\n", "0 0.002129 0.001062 0.094716 0.012392 \n", "1 0.001522 0.000390 0.096214 0.015442 \n", "2 0.001714 0.000255 0.094566 0.005554 \n", "3 0.001659 0.000584 0.089189 0.012520 \n", "4 0.002978 0.000635 0.185943 0.011038 \n", "5 0.003803 0.001685 0.240437 0.058409 \n", "6 0.003041 0.000620 0.240746 0.083650 \n", "7 0.005639 0.002675 0.237306 0.016316 \n", "\n", " param_lifting_functions param_regressor__alpha \\\n", "0 [(ss, SkLearnLiftingFn(transformer=StandardSca... 0 \n", "1 [(ss, SkLearnLiftingFn(transformer=StandardSca... 0.1 \n", "2 [(ss, SkLearnLiftingFn(transformer=StandardSca... 1 \n", "3 [(ss, SkLearnLiftingFn(transformer=StandardSca... 10 \n", "4 [(ma, SkLearnLiftingFn(transformer=MaxAbsScale... 0 \n", "5 [(ma, SkLearnLiftingFn(transformer=MaxAbsScale... 0.1 \n", "6 [(ma, SkLearnLiftingFn(transformer=MaxAbsScale... 1 \n", "7 [(ma, SkLearnLiftingFn(transformer=MaxAbsScale... 10 \n", "\n", " params \\\n", "0 {'lifting_functions': [('ss', SkLearnLiftingFn... \n", "1 {'lifting_functions': [('ss', SkLearnLiftingFn... \n", "2 {'lifting_functions': [('ss', SkLearnLiftingFn... \n", "3 {'lifting_functions': [('ss', SkLearnLiftingFn... \n", "4 {'lifting_functions': [('ma', SkLearnLiftingFn... \n", "5 {'lifting_functions': [('ma', SkLearnLiftingFn... \n", "6 {'lifting_functions': [('ma', SkLearnLiftingFn... \n", "7 {'lifting_functions': [('ma', SkLearnLiftingFn... 
\n", "\n", " split0_test_full_episode split1_test_full_episode \\\n", "0 -0.000396 -0.000421 \n", "1 -0.000427 -0.000373 \n", "2 -0.000745 -0.000440 \n", "3 -0.001322 -0.004338 \n", "4 -0.000314 -0.000411 \n", "5 -0.000467 -0.000384 \n", "6 -0.002445 -0.001038 \n", "7 -0.004469 -0.004848 \n", "\n", " split2_test_full_episode mean_test_full_episode std_test_full_episode \\\n", "0 -0.000683 -0.000500 0.000130 \n", "1 -0.000653 -0.000485 0.000121 \n", "2 -0.000645 -0.000610 0.000127 \n", "3 -0.003608 -0.003089 0.001285 \n", "4 -0.000644 -0.000456 0.000138 \n", "5 -0.000628 -0.000493 0.000101 \n", "6 -0.002982 -0.002155 0.000819 \n", "7 -0.003980 -0.004432 0.000355 \n", "\n", " rank_test_full_episode split0_test_ten_steps split1_test_ten_steps \\\n", "0 4 -0.000142 -0.000153 \n", "1 2 -0.000138 -0.000110 \n", "2 5 -0.000140 -0.000507 \n", "3 7 -0.000290 -0.003897 \n", "4 1 -0.000230 -0.000217 \n", "5 3 -0.000199 -0.000166 \n", "6 6 -0.000429 -0.003733 \n", "7 8 -0.001622 -0.007135 \n", "\n", " split2_test_ten_steps mean_test_ten_steps std_test_ten_steps \\\n", "0 -0.000406 -0.000233 0.000122 \n", "1 -0.000361 -0.000203 0.000112 \n", "2 -0.000124 -0.000257 0.000177 \n", "3 -0.000205 -0.001464 0.001721 \n", "4 -0.000369 -0.000272 0.000069 \n", "5 -0.000120 -0.000162 0.000032 \n", "6 -0.000212 -0.001458 0.001611 \n", "7 -0.000664 -0.003140 0.002852 \n", "\n", " rank_test_ten_steps \n", "0 3 \n", "1 2 \n", "2 4 \n", "3 7 \n", "4 5 \n", "5 1 \n", "6 6 \n", "7 8 " ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pandas.DataFrame(gs.cv_results_)" ] }, { "cell_type": "markdown", "id": "991822f5-4611-4f47-8881-9b604d16e13e", "metadata": {}, "source": [ "This was the winner, ranked first on the ten-step score." ] }, { "cell_type": "code", "execution_count": 10, "id": "0e802339-af81-4774-9bdd-8b7ce38b66d8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "KoopmanPipeline(lifting_functions=[('ma',\n", " SkLearnLiftingFn(transformer=MaxAbsScaler())),\n", " ('pl', 
PolynomialLiftingFn(order=2)),\n", " ('ss',\n", " SkLearnLiftingFn(transformer=StandardScaler()))],\n", " regressor=Edmd(alpha=0.1))" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_estimator" ] }, { "cell_type": "markdown", "id": "5f7ad1de-9ac8-480a-bda9-2ac5ec527fd0", "metadata": {}, "source": [ "Predict the validation trajectory from its initial condition and inputs using the best pipeline." ] }, { "cell_type": "code", "execution_count": 11, "id": "c8cda953-c792-4af9-a3c3-784bac40d208", "metadata": {}, "outputs": [], "source": [ "X_pred = best_estimator.predict_trajectory(eg['x0_valid'], eg['u_valid'])" ] }, { "cell_type": "markdown", "id": "e5112b9c-f7c6-43e0-9764-22b58a97b1d1", "metadata": {}, "source": [ "Score the best pipeline on the validation trajectory." ] }, { "cell_type": "code", "execution_count": 12, "id": "a38d17a3-76cb-4dd7-b906-2212263664c1", "metadata": {}, "outputs": [], "source": [ "score = best_estimator.score(eg['X_valid'])" ] }, { "cell_type": "markdown", "id": "b8803df6-7afe-4dea-a3ae-d8ac7eafdf2c", "metadata": {}, "source": [ "Plot the true and predicted trajectories." ] }, { "cell_type": "code", "execution_count": 13, "id": "93cbb4c7-d3de-43a1-9d3b-a108c87b7aa6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(\n", " best_estimator.n_states_in_ + best_estimator.n_inputs_in_,\n", " 1,\n", " constrained_layout=True,\n", " sharex=True,\n", " figsize=(6, 6),\n", ")\n", "# Plot true trajectory\n", "ax[0].plot(eg['t'], eg['X_valid'][:, 1], label='True trajectory')\n", "ax[1].plot(eg['t'], eg['X_valid'][:, 2])\n", "ax[2].plot(eg['t'], eg['X_valid'][:, 3])\n", "# Plot predicted trajectory\n", "ax[0].plot(eg['t'], X_pred[:, 1], '--', label='Predicted trajectory')\n", "ax[1].plot(eg['t'], X_pred[:, 2], '--')\n", "# Add labels\n", "ax[-1].set_xlabel('$t$')\n", "ax[0].set_ylabel('$x(t)$')\n", "ax[1].set_ylabel(r'$\\dot{x}(t)$')\n", "ax[2].set_ylabel('$u$')\n", "ax[0].set_title(f'True and predicted states; MSE={-1 * score:.2e}')\n", "ax[0].legend(loc='upper right')" ] } ], "metadata": { "kernelspec": { "display_name": "pykoop", "language": "python", "name": "pykoop" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }