"""
.. _ex-spoc-cmc:

====================================
Continuous Target Decoding with SPoC
====================================

Source Power Comodulation (SPoC) :footcite:`DahneEtAl2014` allows one to
identify a set of orthogonal spatial filters whose output power maximally
correlates with a continuous target.

SPoC can be seen as an extension of Common Spatial Patterns (CSP) to continuous
target variables.

Here, SPoC is applied to decode the (continuous) fluctuation of an
electromyogram from MEG beta activity using data from the
`Cortico-Muscular Coherence example of FieldTrip
<http://www.fieldtriptoolbox.org/tutorial/coherence>`_.
"""

# Author: Alexandre Barachant <alexandre.barachant@gmail.com>
#         Jean-Rémi King <jeanremi.king@gmail.com>
#
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

# %%
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold, cross_val_predict
from sklearn.pipeline import make_pipeline

import mne
from mne import Epochs
from mne.datasets.fieldtrip_cmc import data_path
from mne.decoding import SPoC, get_spatial_filter_from_estimator

# Load the CTF recording and keep only the 50-200 s segment to limit memory use
fname = data_path() / "SubjectCMC.ds"
raw = mne.io.read_raw_ctf(fname).crop(50.0, 200.0)

# Regression target: the "EMGlft" channel, high-passed at 20 Hz so that only
# the high-frequency muscular activity is kept
emg = raw.copy().pick(["EMGlft"]).load_data().filter(20.0, None)

# Features: MEG and MEG reference channels, band-passed to the beta band
# (15-30 Hz). This modifies ``raw`` in place.
raw.pick(picks=["meg", "ref_meg"]).load_data().filter(15.0, 30.0)

# Cut both signals into sliding windows: a window starts every 0.75 s and
# spans tmin=0 to tmax=1.5 s, so consecutive windows overlap by about half
events = mne.make_fixed_length_events(raw, id=1, duration=0.75)
epoch_kwargs = dict(tmin=0.0, tmax=1.5, baseline=None)

# MEG windows are linearly detrended and decimated by a factor of 12;
# the EMG windows are kept at the original sampling rate
meg_epochs = Epochs(raw, events, detrend=1, decim=12, **epoch_kwargs)
emg_epochs = Epochs(emg, events, **epoch_kwargs)

# Regression data: MEG windows are the features; the continuous target is
# the per-window variance (power) of the single EMG channel
X = meg_epochs.get_data()
emg_power = emg_epochs.get_data().var(axis=2)
y = emg_power[:, 0]

# Regression pipeline: SPoC spatial filtering (2 components, log-power
# outputs, "oas" covariance regularization) followed by Ridge regression
spoc = SPoC(n_components=2, log=True, reg="oas", rank="full")
clf = make_pipeline(spoc, Ridge())

# Cross-validated predictions from two contiguous (unshuffled) folds: each
# window's prediction comes from the model fit on the other fold
y_preds = cross_val_predict(clf, X, y, cv=KFold(n_splits=2, shuffle=False))

# Compare the measured EMG power with the power predicted from MEG data.
# Each window is drawn at its onset time relative to the cropped recording.
window_onsets = meg_epochs.events[:, 0] - raw.first_samp
times = raw.times[window_onsets]

fig, ax = plt.subplots(1, 1, figsize=[10, 4], layout="constrained")
ax.plot(times, y_preds, color="b", label="Predicted EMG")
ax.plot(times, y, color="r", label="True EMG")
ax.set(xlabel="Time (s)", ylabel="EMG Power", title="SPoC MEG Predictions")
ax.legend()
plt.show()

##############################################################################
# Plot the contributions to the detected components (i.e., the forward model)

# Refit SPoC on all windows (no cross-validation) to inspect its spatial
# filters and patterns
spf = get_spatial_filter_from_estimator(spoc.fit(X, y), info=meg_epochs.info)
spf.plot_scree()

# Plot the patterns of the first three components, i.e. those with the
# largest absolute generalized eigenvalues as shown by the scree plot
spf.plot_patterns(components=list(range(3)))


##############################################################################
# References
# ----------
# .. footbibliography::

# %%
