{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Demo of ROCKET transform\n", "\n", "## Overview\n", "\n", "ROCKET [1] transforms time series using random convolutional kernels (random length, weights, bias, dilation, and padding). For each kernel, ROCKET computes two features from the resulting feature map: the maximum value and the proportion of positive values (ppv). With the default 10,000 kernels, each time series is thus transformed into 20,000 features, which are then used to train a linear classifier.\n", "\n", "[1] Dempster A, Petitjean F, Webb GI (2019) ROCKET: Exceptionally fast and accurate time series classification using random convolutional kernels. [arXiv:1910.13051](https://arxiv.org/abs/1910.13051)\n", "\n", "***\n", "\n", "## Contents\n", "\n", "1. Imports\n", "2. Univariate Time Series\n", "3. Multivariate Time Series\n", "4. Pipeline Example\n", "\n", "***\n", "\n", "## 1 Imports\n", "\n", "Import example data, ROCKET, and a classifier (`RidgeClassifierCV` from scikit-learn), as well as NumPy and `make_pipeline` from scikit-learn.\n", "\n", "**Note**: ROCKET compiles (via Numba) on import, which may take a few seconds." 
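, "\n",
"\n",
"The overview above can be made concrete with a minimal NumPy sketch. It is not the aeon implementation: it samples a handful of random kernels, convolves each one with a series, and keeps the proportion of positive values (ppv) and the maximum of each resulting feature map. The sampling ranges (kernel lengths 7, 9 and 11, mean-centred normal weights, bias in [-1, 1], exponentially scaled dilation, zero padding on about half of the kernels) and the helper names `random_kernel`, `apply_kernel` and `transform` are assumptions made for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def random_kernel(series_length, rng):\n",
"    # Illustrative sampling ranges (assumptions, not aeon's exact code)\n",
"    length = int(rng.choice([7, 9, 11]))\n",
"    weights = rng.normal(0.0, 1.0, length)\n",
"    weights -= weights.mean()  # mean-centred weights\n",
"    bias = rng.uniform(-1.0, 1.0)\n",
"    # dilation on an exponential scale, bounded so the kernel fits the series\n",
"    max_exponent = np.log2((series_length - 1) / (length - 1))\n",
"    dilation = int(2 ** rng.uniform(0.0, max_exponent))\n",
"    # zero padding on roughly half of the kernels\n",
"    padding = ((length - 1) * dilation) // 2 if rng.integers(2) else 0\n",
"    return weights, bias, dilation, padding\n",
"\n",
"\n",
"def apply_kernel(x, weights, bias, dilation, padding):\n",
"    # Slide one dilated kernel over x; return (ppv, max) of the feature map\n",
"    x = np.pad(x, padding)\n",
"    span = (len(weights) - 1) * dilation\n",
"    feature_map = np.array([\n",
"        bias + weights @ x[i : i + span + 1 : dilation]\n",
"        for i in range(len(x) - span)\n",
"    ])\n",
"    return (feature_map > 0).mean(), feature_map.max()\n",
"\n",
"\n",
"def transform(X, kernels):\n",
"    # X has shape (n_cases, n_timepoints); output has two columns per kernel\n",
"    return np.array([\n",
"        [value for kernel in kernels for value in apply_kernel(x, *kernel)]\n",
"        for x in X\n",
"    ])\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X = rng.normal(size=(5, 100))  # 5 synthetic univariate series of length 100\n",
"kernels = [random_kernel(X.shape[1], rng) for _ in range(10)]\n",
"features = transform(X, kernels)\n",
"print(features.shape)  # (5, 20)\n",
"```\n",
"\n",
"Each of the 5 series becomes 20 features, one (ppv, max) pair per kernel; the column order is a choice made in this sketch."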
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2020-12-19T14:32:46.441933Z", "iopub.status.busy": "2020-12-19T14:32:46.441213Z", "iopub.status.idle": "2020-12-19T14:32:46.443225Z", "shell.execute_reply": "2020-12-19T14:32:46.444014Z" } }, "outputs": [], "source": [ "# !pip install --upgrade numba" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2020-12-19T14:32:46.448396Z", "iopub.status.busy": "2020-12-19T14:32:46.447602Z", "iopub.status.idle": "2020-12-19T14:32:51.904418Z", "shell.execute_reply": "2020-12-19T14:32:51.905034Z" } }, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.linear_model import RidgeClassifierCV\n", "from sklearn.pipeline import make_pipeline\n", "\n", "from aeon.datasets import load_arrow_head # univariate dataset\n", "from aeon.datasets import load_basic_motions # multivariate dataset\n", "from aeon.transformations.collection.convolution_based import Rocket" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2 Univariate Time Series\n", "\n", "We can transform the data using ROCKET and separately fit a classifier, or we can use ROCKET together with a classifier in a pipeline (section 4, below).\n", "\n", "### 2.1 Load the Training Data\n", "For more details on the data set, see the [univariate time series classification\n", "notebook](https://github.com/aeon-toolkit/aeon/tree/main/examples/classification.ipynb)." 
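, "\n",
"\n",
"The two routes mentioned at the start of this section, fitting the transform and the classifier separately or chaining them in a pipeline, should give identical predictions. The minimal scikit-learn sketch below checks this on synthetic data. To stay self-contained it replaces ROCKET with a hypothetical stand-in, `toy_features` (the per-series maximum and proportion of positive values, wrapped in `FunctionTransformer`); the data, the `alphas` grid and the variable names are assumptions for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.linear_model import RidgeClassifierCV\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import FunctionTransformer\n",
"\n",
"\n",
"def toy_features(X):\n",
"    # Hypothetical stand-in for ROCKET: max and ppv of each raw series\n",
"    return np.column_stack([X.max(axis=1), (X > 0).mean(axis=1)])\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"y_train = np.repeat([0, 1], 20)\n",
"X_train = rng.normal(size=(40, 50)) + 0.5 * y_train[:, None]  # shift class 1\n",
"X_test = rng.normal(size=(10, 50))\n",
"alphas = np.logspace(-3, 3, 10)\n",
"\n",
"# Route 1: transform, then fit the classifier on the transformed data\n",
"transformer = FunctionTransformer(toy_features)\n",
"classifier = RidgeClassifierCV(alphas=alphas)\n",
"classifier.fit(transformer.fit_transform(X_train), y_train)\n",
"pred_separate = classifier.predict(transformer.transform(X_test))\n",
"\n",
"# Route 2: the same two steps chained with make_pipeline\n",
"pipeline = make_pipeline(FunctionTransformer(toy_features), RidgeClassifierCV(alphas=alphas))\n",
"pipeline.fit(X_train, y_train)\n",
"pred_pipeline = pipeline.predict(X_test)\n",
"\n",
"print(np.array_equal(pred_separate, pred_pipeline))  # True\n",
"```\n",
"\n",
"The pipeline form is convenient because a single `fit` or `predict` call applies the transform consistently to training and test data."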
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2020-12-19T14:32:51.908710Z", "iopub.status.busy": "2020-12-19T14:32:51.908101Z", "iopub.status.idle": "2020-12-19T14:32:51.918987Z", "shell.execute_reply": "2020-12-19T14:32:51.919508Z" } }, "outputs": [], "source": [ "X_train, y_train = load_arrow_head(split=\"train\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 Initialise ROCKET and Transform the Training Data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2020-12-19T14:32:51.923023Z", "iopub.status.busy": "2020-12-19T14:32:51.922451Z", "iopub.status.idle": "2020-12-19T14:32:52.164365Z", "shell.execute_reply": "2020-12-19T14:32:52.164864Z" } }, "outputs": [], "source": [ "rocket = Rocket()  # by default, ROCKET uses 10,000 kernels\n", "rocket.fit(X_train)\n", "X_train_transform = rocket.transform(X_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.3 Fit a Classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We recommend `RidgeClassifierCV` from scikit-learn for smaller datasets (fewer than approximately 20,000 training examples), and logistic regression trained with stochastic gradient descent for larger datasets." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2020-12-19T14:32:52.168847Z", "iopub.status.busy": "2020-12-19T14:32:52.168155Z", "iopub.status.idle": "2020-12-19T14:32:52.284816Z", "shell.execute_reply": "2020-12-19T14:32:52.285506Z" } }, "outputs": [ { "data": { "text/plain": "RidgeClassifierCV(alphas=array([1.00000000e-03, 4.64158883e-03, 2.15443469e-02, 1.00000000e-01,\n 4.64158883e-01, 2.15443469e+00, 1.00000000e+01, 4.64158883e+01,\n 2.15443469e+02, 1.00000000e+03]))", "text/html": "
RidgeClassifierCV(alphas=array([1.00000000e-03, 4.64158883e-03, 2.15443469e-02, 1.00000000e-01,\n 4.64158883e-01, 2.15443469e+00, 1.00000000e+01, 4.64158883e+01,\n 2.15443469e+02, 1.00000000e+03]))
Pipeline(steps=[('rocket', Rocket()),\n ('ridgeclassifiercv',\n RidgeClassifierCV(alphas=array([1.00000000e-03, 4.64158883e-03, 2.15443469e-02, 1.00000000e-01,\n 4.64158883e-01, 2.15443469e+00, 1.00000000e+01, 4.64158883e+01,\n 2.15443469e+02, 1.00000000e+03])))])