{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# How to use this notebook\n", "\n", "This Jupyter notebook and the other files in this directory contain the code for the Medium blog post \"What is Giotto and Why James Bond Should Use It to Extract Secret Messages\". The idea is to use topological data analysis (TDA) to predict the change from a chaotic to a non-chaotic regime in time series with different levels of noise. For further information, please refer to the blog post.\n", "\n", "\n", "Because feature creation takes a long time (around 45 minutes on a MacBook Pro with 16 cores), this notebook loads a precomputed dataset. In the 'Plot Features' section, the features are computed directly for a smaller time series for presentation purposes. To create the features for all the time series yourself, run the bash script 'create_features.sh'." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Library Imports and Some Utility Functions" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Imports from scikit-learn and XGBoost\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.ensemble import GradientBoostingClassifier\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.metrics import balanced_accuracy_score, make_scorer\n", "from sklearn.linear_model import SGDClassifier\n", "from xgboost import XGBClassifier\n", "\n", "# For bootstrap confidence intervals\n", "from numpy.random import seed\n", "from numpy.random import rand\n", "from numpy.random import randint\n", "\n", "# Others\n", "import pandas as pd\n", "import numpy as np\n", "from datetime import datetime\n", "from itertools import product\n", "import os\n", "from pandarallel import pandarallel\n", "from joblib import Parallel, delayed\n", "from functools import reduce\n", "from scipy.fftpack import rfft\n", "import openml\n", 
"from openml.datasets.functions import get_dataset\n", "\n", "# Giotto\n", "import giotto as gt\n", "import giotto.diagrams as diag\n", "import giotto.homology as hl\n", "\n", "# Plotting functions\n", "from plotting import plot_diagram, plot_landscapes\n", "from plotting import plot_betti_surfaces, plot_betti_curves\n", "from plotting import plot_point_cloud\n", "import matplotlib.pyplot as plt\n", "import plotly.express as px\n", "\n", "# Our own feature creation and plotting functions\n", "from chaos_detection import *" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Make balanced accuracy scorer\n", "bal_acc_score = make_scorer(balanced_accuracy_score)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Plot Features\n", "Here we create the features for a small time series in order to present the TDA features." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "New pandarallel memory created - Size: 2000 MB\n", "Pandarallel will run on 16 workers\n", "Optimal embedding time delay based on mutual information: 5\n", "Optimal embedding dimension based on false nearest neighbors: 14\n" ] }, { "data": { "text/html": [ "
\n", "
| | time | y | x | x_dot | max_10 | max_20 | max_50 | mean_10 | mean_20 | mean_50 | ... | fourier_w_1 | fourier_w_2 | num_holes | avg_lifetime | betti_0 | betti_1 | betti_2 | betti_argmax_1 | betti_argmax_2 | amplitude |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 133 | 13.30133 | 0 | 0.919056 | 1.520114 | 0.919056 | 0.919056 | 0.919056 | 0.895947 | 0.871955 | 0.801546 | ... | -0.086702 | 1.198089 | 100 | 0.011424 | 0.0 | 0.0 | 0.0 | 0 | 0 | 0.080962 |
| 134 | 13.40134 | 0 | 0.924506 | 1.521954 | 0.924506 | 0.924506 | 0.924506 | 0.900792 | 0.877009 | 0.806352 | ... | -0.085955 | 1.199410 | 100 | 0.011424 | 0.0 | 0.0 | 0.0 | 0 | 0 | 0.080962 |
| 135 | 13.50135 | 0 | 0.922732 | 1.517234 | 0.924506 | 0.924506 | 0.924506 | 0.904971 | 0.881477 | 0.810933 | ... | -0.091852 | 1.199812 | 100 | 0.011424 | 0.0 | 0.0 | 0.0 | 0 | 0 | 0.080962 |
| 136 | 13.60136 | 0 | 0.933383 | 1.522403 | 0.933383 | 0.933383 | 0.933383 | 0.910297 | 0.886202 | 0.815624 | ... | -0.086803 | 1.201020 | 100 | 0.011424 | 0.0 | 0.0 | 0.0 | 0 | 0 | 0.080962 |
| 137 | 13.70137 | 0 | 0.938219 | 1.523874 | 0.938219 | 0.938219 | 0.938219 | 0.915080 | 0.891244 | 0.820432 | ... | -0.086749 | 1.202251 | 100 | 0.011424 | 0.0 | 0.0 | 0.0 | 0 | 0 | 0.080962 |

5 rows × 23 columns
\n", "\n", "
| | index | time | y | x | coord_1 | max_10 | max_20 | max_50 | mean_10 | mean_20 | ... | num_holes | avg_lifetime | betti_0 | betti_1 | betti_2 | betti_argmax_1 | betti_argmax_2 | amplitude | noise_level | test_set |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1246600 | 150 | 15.00150 | 0 | 0.994617 | 1.524023 | 0.995729 | 0.995729 | 0.995729 | 0.976636 | 0.953035 | ... | 100 | 0.012137 | 0.0 | 0.0 | 0.0 | 0 | 0 | 0.085996 | 0.0 | 1 |
| 1246601 | 151 | 15.10151 | 0 | 1.005455 | 1.529092 | 1.005455 | 1.005455 | 1.005455 | 0.982045 | 0.957831 | ... | 100 | 0.012137 | 0.0 | 0.0 | 0.0 | 0 | 0 | 0.085996 | 0.0 | 1 |
| 1246602 | 152 | 15.20152 | 0 | 1.010750 | 1.528463 | 1.010750 | 1.010750 | 1.010750 | 0.986904 | 0.962936 | ... | 100 | 0.012137 | 0.0 | 0.0 | 0.0 | 0 | 0 | 0.085996 | 0.0 | 1 |
| 1246603 | 153 | 15.30153 | 0 | 1.008770 | 1.523554 | 1.010750 | 1.010750 | 1.010750 | 0.991085 | 0.967422 | ... | 100 | 0.012137 | 0.0 | 0.0 | 0.0 | 0 | 0 | 0.085996 | 0.0 | 1 |
| 1246604 | 154 | 15.40154 | 0 | 1.019839 | 1.529516 | 1.019839 | 1.019839 | 1.019839 | 0.996478 | 0.972188 | ... | 100 | 0.012137 | 0.0 | 0.0 | 0.0 | 0 | 0 | 0.085996 | 0.0 | 1 |

5 rows × 26 columns
\n", "