{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Spectro-temporal receptive field (STRF) estimation on continuous data\n\nThis demonstrates how an encoding model can be fit with multiple continuous\ninputs. In this case, we simulate the model behind a spectro-temporal receptive\nfield (or STRF). First, we create a linear filter that maps patterns in\nspectro-temporal space onto an output, representing neural activity. We fit\na receptive field model that attempts to recover the original linear filter\nthat was used to create this data.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: Chris Holdgraf \n# Eric Larson \n#\n# License: BSD-3-Clause\n# Copyright the MNE-Python contributors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\nimport numpy as np\nfrom scipy.io import loadmat\nfrom scipy.stats import multivariate_normal\nfrom sklearn.preprocessing import scale\n\nimport mne\nfrom mne.decoding import ReceptiveField, TimeDelayingRidge\n\nrng = np.random.RandomState(1337) # To make this example reproducible" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load audio data\n\nWe'll read in the audio data from :footcite:`CrosseEtAl2016` in order to\nsimulate a response.\n\nIn addition, we'll downsample the data along the time dimension in order to\nspeed up computation. Note that depending on the input values, this may\nnot be desired. 
For example, information will be lost if your input stimulus varies more quickly than\nhalf the sampling rate to which we are downsampling (the new Nyquist frequency).\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Read in audio that's been recorded in epochs.\npath_audio = mne.datasets.mtrf.data_path()\ndata = loadmat(str(path_audio / \"speech_data.mat\"))\naudio = data[\"spectrogram\"].T\nsfreq = float(data[\"Fs\"][0, 0])\nn_decim = 2\naudio = mne.filter.resample(audio, down=n_decim, npad=\"auto\")\nsfreq /= n_decim" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a receptive field\n\nWe'll simulate a linear receptive field for a theoretical neural signal. This\ndefines how the signal will respond to power in this receptive field space.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_freqs = 20\ntmin, tmax = -0.1, 0.4\n\n# To simulate the data we'll create explicit delays here\ndelays_samp = np.arange(np.round(tmin * sfreq), np.round(tmax * sfreq) + 1).astype(int)\ndelays_sec = delays_samp / sfreq\nfreqs = np.linspace(50, 5000, n_freqs)\ngrid = np.array(np.meshgrid(delays_sec, freqs))\n\n# Move the (delay, frequency) coordinate pair to the last axis so that grid\n# has shape (n_freqs, n_delays, 2), as multivariate_normal.pdf expects\ngrid = grid.swapaxes(0, -1).swapaxes(0, 1)\n\n# Simulate a spectro-temporal receptive field as the sum of a positive and a\n# negative 2D Gaussian\nmeans_high = [0.1, 500]\nmeans_low = [0.2, 2500]\ncov = [[0.001, 0], [0, 500000]]\ngauss_high = multivariate_normal.pdf(grid, means_high, cov)\ngauss_low = -1 * multivariate_normal.pdf(grid, means_low, cov)\nweights = gauss_high + gauss_low  # Combine to create the \"true\" STRF\nkwargs = dict(\n    vmax=np.abs(weights).max(),\n    vmin=-np.abs(weights).max(),\n    cmap=\"RdBu_r\",\n    shading=\"gouraud\",\n)\n\nfig, ax = plt.subplots(layout=\"constrained\")\nax.pcolormesh(delays_sec, freqs, weights, **kwargs)\nax.set(title=\"Simulated STRF\", xlabel=\"Time Lags (s)\", ylabel=\"Frequency 
(Hz)\")\nplt.setp(ax.get_xticklabels(), rotation=45)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simulate a neural response\n\nUsing this receptive field, we'll create an artificial neural response to\na stimulus.\n\nTo do this, we'll create time-delayed versions of the stimulus, and\nthen calculate the dot product between these and the receptive field. Note that this\nis effectively doing a convolution between the stimulus and the receptive\nfield. See [here](https://en.wikipedia.org/wiki/Convolution) for more\ninformation.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Reshape audio to split into epochs, then make epochs the first dimension.\nn_epochs, n_seconds = 16, 5\naudio = audio[:, : int(n_seconds * sfreq * n_epochs)]\nX = audio.reshape([n_freqs, n_epochs, -1]).swapaxes(0, 1)\nn_times = X.shape[-1]\n\n# Delay the spectrogram according to delays so it can be combined w/ the STRF\n# Lags are stacked along a new first axis, then we reshape to vectorize\ndelays = np.arange(np.round(tmin * sfreq), np.round(tmax * sfreq) + 1).astype(int)\n\n# Iterate through the delays and fill in one shifted copy of X per delay\nX_del = np.zeros((len(delays),) + X.shape)\nfor ii, ix_delay in enumerate(delays):\n    # These slices will take/put particular time indices in the data\n    take = [slice(None)] * X.ndim\n    put = [slice(None)] * X.ndim\n    if ix_delay > 0:\n        take[-1] = slice(None, -ix_delay)\n        put[-1] = slice(ix_delay, None)\n    elif ix_delay < 0:\n        take[-1] = slice(-ix_delay, None)\n        put[-1] = slice(None, ix_delay)\n    X_del[ii][tuple(put)] = X[tuple(take)]\n\n# Now move the delay axis to index 2 (after frequency) to match weights.ravel()\nX_del = np.rollaxis(X_del, 0, 3)\nX_del = X_del.reshape([n_epochs, -1, n_times])\nn_features = X_del.shape[1]\nweights_sim = weights.ravel()\n\n# Simulate a neural response to the sound, given this STRF\ny = np.zeros((n_epochs, n_times))\nfor ii, iep in enumerate(X_del):\n    # Simulate this epoch and add random noise\n    noise_amp = 0.002\n    y[ii] = np.dot(weights_sim, iep) + noise_amp * rng.randn(n_times)\n\n# Plot the first 2 trials of audio and the simulated electrode activity\nX_plt = scale(np.hstack(X[:2]).T).T\ny_plt = scale(np.hstack(y[:2]))\ntime = np.arange(X_plt.shape[-1]) / sfreq\n_, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 6), sharex=True, layout=\"constrained\")\nax1.pcolormesh(time, freqs, X_plt, vmin=0, vmax=4, cmap=\"Reds\", shading=\"gouraud\")\nax1.set_title(\"Input auditory features\")\nax1.set(ylim=[freqs.min(), freqs.max()], ylabel=\"Frequency (Hz)\")\nax2.plot(time, y_plt)\nax2.set(\n    xlim=[time.min(), time.max()],\n    title=\"Simulated response\",\n    xlabel=\"Time (s)\",\n    ylabel=\"Activity (a.u.)\",\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fit a model to recover this receptive field\n\nFinally, we'll use the :class:`mne.decoding.ReceptiveField` class to recover\nthe linear receptive field of this signal. Note that properties of the\nreceptive field (e.g. smoothness) will depend on the autocorrelation in the\ninputs and outputs.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Create training and testing data\ntrain, test = np.arange(n_epochs - 1), n_epochs - 1\nX_train, X_test, y_train, y_test = X[train], X[test], y[train], y[test]\nX_train, X_test, y_train, y_test = (\n    np.rollaxis(ii, -1, 0) for ii in (X_train, X_test, y_train, y_test)\n)\n# Model the simulated data as a function of the spectrogram input\nalphas = np.logspace(-3, 3, 7)\nscores = np.zeros_like(alphas)\nmodels = []\nfor ii, alpha in enumerate(alphas):\n    rf = ReceptiveField(tmin, tmax, sfreq, freqs, estimator=alpha)\n    rf.fit(X_train, y_train)\n\n    # Now make predictions about the model output, given input stimuli.\n    scores[ii] = rf.score(X_test, y_test).item()\n    models.append(rf)\n\ntimes = rf.delays_ / float(rf.sfreq)\n\n# Choose the model that performed best on the held out data\nix_best_alpha = 
np.argmax(scores)\nbest_mod = models[ix_best_alpha]\ncoefs = best_mod.coef_[0]\nbest_pred = best_mod.predict(X_test)[:, 0]\n\n# Plot the original STRF, and the one that we recovered with modeling.\n_, (ax1, ax2) = plt.subplots(\n    1,\n    2,\n    figsize=(6, 3),\n    sharey=True,\n    sharex=True,\n    layout=\"constrained\",\n)\nax1.pcolormesh(delays_sec, freqs, weights, **kwargs)\nax2.pcolormesh(times, rf.feature_names, coefs, **kwargs)\nax1.set_title(\"Original STRF\")\nax2.set_title(\"Best Reconstructed STRF\")\nplt.setp([iax.get_xticklabels() for iax in [ax1, ax2]], rotation=45)\n\n# Plot the actual response and the predicted response on a held out stimulus\ntime_pred = np.arange(best_pred.shape[0]) / sfreq\nfig, ax = plt.subplots()\nax.plot(time_pred, y_test, color=\"k\", alpha=0.2, lw=4)\nax.plot(time_pred, best_pred, color=\"r\", lw=1)\nax.set(title=\"Original and predicted activity\", xlabel=\"Time (s)\")\nax.legend([\"Original\", \"Predicted\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize the effects of regularization\n\nAbove we fit a :class:`mne.decoding.ReceptiveField` model for each of several\nvalues of the ridge regularization parameter. 
Here we will plot the model\nscore as well as the model coefficients for each value, in order to\nvisualize how coefficients change with different levels of regularization.\nThese issues as well as the STRF pipeline are described in detail\nin :footcite:`TheunissenEtAl2001,WillmoreSmyth2003,HoldgrafEtAl2016`.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Plot model score for each ridge parameter\nfig = plt.figure(figsize=(10, 4), layout=\"constrained\")\nax = plt.subplot2grid([2, len(alphas)], [1, 0], 1, len(alphas))\nax.plot(np.arange(len(alphas)), scores, marker=\"o\", color=\"r\")\nax.annotate(\n    \"Best parameter\",\n    (ix_best_alpha, scores[ix_best_alpha]),\n    (ix_best_alpha, scores[ix_best_alpha] - 0.1),\n    arrowprops={\"arrowstyle\": \"->\"},\n)\nplt.xticks(np.arange(len(alphas)), [f\"{ii:.0e}\" for ii in alphas])\nax.set(\n    xlabel=\"Ridge regularization value\",\n    ylabel=\"Score ($R^2$)\",\n    xlim=[-0.4, len(alphas) - 0.6],\n)\n\n# Plot the STRF of each ridge parameter\nfor ii, (rf, i_alpha) in enumerate(zip(models, alphas)):\n    ax = plt.subplot2grid([2, len(alphas)], [0, ii], 1, 1)\n    ax.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs)\n    plt.xticks([], [])\n    plt.yticks([], [])\nfig.suptitle(\"Model coefficients / scores for many ridge parameters\", y=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using different regularization types\n\nIn addition to the standard ridge regularization, the\n:class:`mne.decoding.TimeDelayingRidge` class also exposes a\n[Laplacian](https://en.wikipedia.org/wiki/Laplacian_matrix) regularization\nterm, defined by the matrix:\n\n\\begin{align}\\left[\\begin{matrix}\n 1 & -1 & & & & \\\\\n -1 & 2 & -1 & & & \\\\\n & -1 & 2 & -1 & & \\\\\n & & \\ddots & \\ddots & \\ddots & \\\\\n & & & -1 & 2 & -1 \\\\\n & & & & -1 & 1\\end{matrix}\\right]\\end{align}\n\nThis imposes a smoothness constraint on nearby time samples and/or features.\nQuoting 
:footcite:`CrosseEtAl2016` :\n\n> Tikhonov [identity] regularization (Equation 5) reduces overfitting by\n> smoothing the TRF estimate in a way that is insensitive to\n> the amplitude of the signal of interest. However, the Laplacian\n> approach (Equation 6) reduces off-sample error whilst preserving\n> signal amplitude (Lalor et al., 2006). As a result, this approach\n> usually leads to an improved estimate of the system\u2019s response (as\n> indexed by MSE) compared to Tikhonov regularization.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "scores_lap = np.zeros_like(alphas)\nmodels_lap = []\nfor ii, alpha in enumerate(alphas):\n    estimator = TimeDelayingRidge(tmin, tmax, sfreq, reg_type=\"laplacian\", alpha=alpha)\n    rf = ReceptiveField(tmin, tmax, sfreq, freqs, estimator=estimator)\n    rf.fit(X_train, y_train)\n\n    # Now make predictions about the model output, given input stimuli.\n    scores_lap[ii] = rf.score(X_test, y_test).item()\n    models_lap.append(rf)\n\nix_best_alpha_lap = np.argmax(scores_lap)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compare model performance\n\nBelow we visualize the model performance of each regularization method\n(ridge vs. Laplacian) for different levels of alpha. 
As you can see, the\nLaplacian method performs better in general, because it imposes a smoothness\nconstraint along the time and feature dimensions of the coefficients.\nThis matches the \"true\" receptive field structure and results in a better\nmodel fit.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig = plt.figure(figsize=(10, 6), layout=\"constrained\")\nax = plt.subplot2grid([3, len(alphas)], [2, 0], 1, len(alphas))\nax.plot(np.arange(len(alphas)), scores_lap, marker=\"o\", color=\"r\")\nax.plot(np.arange(len(alphas)), scores, marker=\"o\", color=\"0.5\", ls=\":\")\nax.annotate(\n    \"Best Laplacian\",\n    (ix_best_alpha_lap, scores_lap[ix_best_alpha_lap]),\n    (ix_best_alpha_lap, scores_lap[ix_best_alpha_lap] - 0.1),\n    arrowprops={\"arrowstyle\": \"->\"},\n)\nax.annotate(\n    \"Best Ridge\",\n    (ix_best_alpha, scores[ix_best_alpha]),\n    (ix_best_alpha, scores[ix_best_alpha] - 0.1),\n    arrowprops={\"arrowstyle\": \"->\"},\n)\nplt.xticks(np.arange(len(alphas)), [f\"{ii:.0e}\" for ii in alphas])\nax.set(\n    xlabel=\"Laplacian regularization value\",\n    ylabel=\"Score ($R^2$)\",\n    xlim=[-0.4, len(alphas) - 0.6],\n)\n\n# Plot the STRF for each regularization value (Laplacian on top, ridge below)\nxlim = times[[0, -1]]\nfor ii, (rf_lap, rf, i_alpha) in enumerate(zip(models_lap, models, alphas)):\n    ax = plt.subplot2grid([3, len(alphas)], [0, ii], 1, 1)\n    ax.pcolormesh(times, rf_lap.feature_names, rf_lap.coef_[0], **kwargs)\n    ax.set(xticks=[], yticks=[], xlim=xlim)\n    if ii == 0:\n        ax.set(ylabel=\"Laplacian\")\n    ax = plt.subplot2grid([3, len(alphas)], [1, ii], 1, 1)\n    ax.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs)\n    ax.set(xticks=[], yticks=[], xlim=xlim)\n    if ii == 0:\n        ax.set(ylabel=\"Ridge\")\nfig.suptitle(\"Model coefficients / scores for laplacian regularization\", y=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the original STRF alongside the best STRFs recovered with ridge and\nLaplacian regularization.\n\n" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "rf = models[ix_best_alpha]\nrf_lap = models_lap[ix_best_alpha_lap]\n_, (ax1, ax2, ax3) = plt.subplots(\n 1,\n 3,\n figsize=(9, 3),\n sharey=True,\n sharex=True,\n layout=\"constrained\",\n)\nax1.pcolormesh(delays_sec, freqs, weights, **kwargs)\nax2.pcolormesh(times, rf.feature_names, rf.coef_[0], **kwargs)\nax3.pcolormesh(times, rf_lap.feature_names, rf_lap.coef_[0], **kwargs)\nax1.set_title(\"Original STRF\")\nax2.set_title(\"Best Ridge STRF\")\nax3.set_title(\"Best Laplacian STRF\")\nplt.setp([iax.get_xticklabels() for iax in [ax1, ax2, ax3]], rotation=45)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### References\n.. footbibliography::\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 0 }