{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Learning on Tangent Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lead author: Nicolas Guigui.\n", "\n", "In this notebook, we demonstrate how any standard machine learning algorithm can be used on data that live on a manifold while respecting its geometry. In the previous notebooks we saw that linear operations (mean, linear weighting) are not well defined on a manifold. However, each point on a manifold has an associated tangent space, which is a vector space where all our off-the-shelf ML operations are well defined!\n", "\n", "We will use the [logarithm map](02_from_vector_spaces_to_manifolds.ipynb#From-substraction-to-logarithm-map) to go from points of the manifold to vectors in the tangent space at a reference point. This will enable us to use a simple logistic regression to classify our data." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We import the backend that will be used for geomstats computations and set a seed for reproducibility of the results." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO: Using numpy backend\n" ] } ], "source": [ "import geomstats.backend as gs\n", "\n", "gs.random.seed(2020)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We import the visualization tools." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use data from the [MLSP 2014 Schizophrenia Challenge](https://www.kaggle.com/c/mlsp-2014-mri/data). The dataset corresponds to the Functional Connectivity Networks (FCN) extracted from resting-state fMRIs of 86 patients at 28 Regions Of Interest (ROIs). 
Roughly, an FCN corresponds to a correlation matrix and can be seen as a point on the manifold of Symmetric Positive-Definite (SPD) matrices. Patients are separated into two classes: schizophrenic and control. The goal will be to classify them.\n", "\n", "First we load the data (reshaped as matrices):" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import geomstats.datasets.utils as data_utils\n", "\n", "data, patient_ids, labels = data_utils.load_connectomes()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We plot the first two connectomes from the MLSP dataset with their corresponding labels." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [ "nbsphinx-thumbnail" ] }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "labels_str = [\"Healthy\", \"Schizophrenic\"]\n", "\n", "fig = plt.figure(figsize=(8, 4))\n", "\n", "ax = fig.add_subplot(121)\n", "imgplot = ax.imshow(data[0])\n", "ax.set_title(labels_str[labels[0]])\n", "\n", "ax = fig.add_subplot(122)\n", "imgplot = ax.imshow(data[1])\n", "ax.set_title(labels_str[labels[1]])\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to compare with a standard Euclidean method, we also flatten the data:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(86, 378)\n" ] } ], "source": [ "flat_data, _, _ = data_utils.load_connectomes(as_vectors=True)\n", "print(flat_data.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Manifold" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As mentioned above, correlation matrices are SPD matrices. Because several metrics can be used on SPD matrices, we also import two of the most commonly used ones: the Log-Euclidean metric and the Affine-Invariant metric [[PFA2006]](#References). 
We can use the SPD module from `geomstats` to handle all the geometry, and check that our data indeed belongs to the manifold of SPD matrices:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "from geomstats.geometry.spd_matrices import (\n", "    SPDMatrices,\n", "    SPDAffineMetric,\n", "    SPDLogEuclideanMetric,\n", ")\n", "\n", "manifold = SPDMatrices(28, equip=False)\n", "\n", "spd_ai = SPDMatrices(28, equip=False)\n", "spd_ai.equip_with_metric(SPDAffineMetric)\n", "\n", "spd_le = SPDMatrices(28, equip=False)\n", "spd_le.equip_with_metric(SPDLogEuclideanMetric)\n", "\n", "print(gs.all(manifold.belongs(data)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Transformer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Great! Now, although the sum of two SPD matrices is an SPD matrix, their difference, or a linear combination of them with non-positive weights, is not necessarily SPD! Therefore we need to work in a tangent space to perform simple machine learning. But worry not, all the geometry is handled by geomstats, thanks to the preprocessing module." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from geomstats.learning.preprocessing import ToTangentSpace" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What `ToTangentSpace` does is simple: it computes the Fréchet mean of the data set (covered in the previous tutorial), then takes the log of each data point from the mean. This results in a set of tangent vectors, and in the case of the SPD manifold, these are simply symmetric matrices. 
It then flattens each of them into a 1-D vector of size `dim = 28 * (28 + 1) / 2`, and thus outputs an array of shape `[n_patients, dim]`, which can be fed to your favorite scikit-learn algorithm.\n", "\n", "Because the mean of the input data is computed, `ToTangentSpace` should be used in a pipeline (like, e.g., scikit-learn's `StandardScaler`) so that no information from the test set leaks at train time." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import Pipeline\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.model_selection import cross_validate\n", "\n", "pipeline = Pipeline(\n", "    steps=[\n", "        (\"feature_ext\", ToTangentSpace(space=spd_ai)),\n", "        (\"classifier\", LogisticRegression(C=2)),\n", "    ]\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now have all the material to classify connectomes, and we evaluate the model with cross-validation. With the affine-invariant metric we obtain:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.7098039215686274\n" ] } ], "source": [ "result = cross_validate(pipeline, data, labels)\n", "print(result[\"test_score\"].mean())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And with the log-Euclidean metric:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.6862745098039216\n" ] } ], "source": [ "pipeline = Pipeline(\n", "    steps=[\n", "        (\"feature_ext\", ToTangentSpace(space=spd_le)),\n", "        (\"classifier\", LogisticRegression(C=2)),\n", "    ]\n", ")\n", "\n", "result = cross_validate(pipeline, data, labels)\n", "print(result[\"test_score\"].mean())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But wait, why do the results depend on the metric used? 
You may remember from the previous notebooks that the Riemannian metric defines the notion of geodesics and distance on the manifold. Both notions are used to compute the Fréchet mean and the logarithms, so changing the metric changes the results, and some metrics may be more suitable than others for different applications.\n", "\n", "We can finally compare to a standard Euclidean logistic regression on the flattened data:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.7333333333333334\n" ] } ], "source": [ "flat_result = cross_validate(LogisticRegression(), flat_data, labels)\n", "print(flat_result[\"test_score\"].mean())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, using Riemannian geometry does not make a big difference compared to applying logistic regression in the ambient Euclidean space, but there are published results that show how useful geometry can be with this type of data (e.g. [[NDV2014]](#References), [[WAZ2018]](#References)). We saw how to use the representation of points on the manifold as tangent vectors at a reference point to fit any machine learning algorithm, and compared the effect of different metrics on the space of symmetric positive-definite matrices." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ ".. [PFA2006] Pennec, X., Fillard, P. & Ayache, N. A Riemannian Framework for Tensor Computing. Int J Comput Vision 66, 41–66 (2006). https://doi.org/10.1007/s11263-005-3222-z\n", "\n", ".. [NDV2014] Bernard Ng, Martin Dressler, Gaël Varoquaux, Jean-Baptiste Poline, Michael Greicius, et al. Transport on Riemannian Manifold for Functional Connectivity-based Classification. 
MICCAI - 17th International Conference on Medical Image Computing and Computer Assisted Intervention, Polina Golland, Sep 2014, Boston, United States. hal-01058521\n", "\n", ".. [WAZ2018] Wong E., Anderson J.S., Zielinski B.A., Fletcher P.T. (2018) Riemannian Regression and Classification Models of Brain Networks Applied to Autism. In: Wu G., Rekik I., Schirmer M., Chung A., Munsell B. (eds) Connectomics in NeuroImaging. CNI 2018. Lecture Notes in Computer Science, vol 11083. Springer, Cham" ] } ], "metadata": { "backends": [ "numpy", "autograd", "pytorch" ], "celltoolbar": "Tags", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" } }, "nbformat": 4, "nbformat_minor": 4 }