Cross-Session on Multiple Datasets

This example shows how to perform a cross-session analysis on two motor imagery (MI) datasets using a CSP+LDA pipeline.

The cross-session evaluation context evaluates performance using leave-one-session-out cross-validation. For each session in the dataset, a model is trained on all the other sessions of the same subject and its performance is evaluated on the held-out session.
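The splitting scheme described above can be sketched with scikit-learn's LeaveOneGroupOut, using the session label as the group. This is only an illustration on hypothetical toy data (the arrays and session names are made up), not a description of how CrossSessionEvaluation is implemented internally.

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

# Hypothetical toy data for one subject: 6 trials, 2 per session.
sessions = np.array(["0", "0", "1", "1", "2", "2"])
X = np.arange(12).reshape(6, 2)  # placeholder features
y = np.array([0, 1, 0, 1, 0, 1])  # placeholder labels

folds = []
for train_idx, test_idx in LeaveOneGroupOut().split(X, y, groups=sessions):
    held_out = str(sessions[test_idx][0])
    train_sessions = sorted(set(sessions[train_idx].tolist()))
    folds.append((held_out, train_sessions))
    print(f"test on session {held_out}, train on sessions {train_sessions}")
```

Each session is held out exactly once, so a subject with three sessions yields three scores.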

# Authors: Sylvain Chevallier <sylvain.chevallier@uvsq.fr>
#
# License: BSD (3-clause)

import warnings

import matplotlib.pyplot as plt
import seaborn as sns
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

import moabb
from moabb.datasets import BNCI2014_001, Zhou2016
from moabb.evaluations import CrossSessionEvaluation
from moabb.paradigms import LeftRightImagery


warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)
moabb.set_log_level("info")

Loading Dataset

Load 2 subjects from the BNCI2014-001 and Zhou2016 datasets. BNCI2014-001 has 2 sessions per subject and Zhou2016 has 3.

Choose Paradigm

We select the MI paradigm, applying a band-pass filter (8-35 Hz) to the data and keeping only left- and right-hand motor imagery.

paradigm = LeftRightImagery(fmin=8, fmax=35)
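To illustrate what an 8-35 Hz band-pass does to a signal, here is a small sketch using SciPy on synthetic sine waves. The sampling rate, filter order, and Butterworth design are assumptions chosen for the illustration; the filter that the paradigm applies internally may be designed differently.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

fs = 250.0  # assumed sampling rate in Hz
t = np.arange(0, 4, 1 / fs)
slow = np.sin(2 * np.pi * 2 * t)  # 2 Hz component, below the band
mu = np.sin(2 * np.pi * 12 * t)  # 12 Hz component, inside the band

# 4th-order Butterworth band-pass, applied forward and backward (zero phase).
sos = butter(4, [8, 35], btype="bandpass", fs=fs, output="sos")
filtered_slow = sosfiltfilt(sos, slow)
filtered_mu = sosfiltfilt(sos, mu)


def rms(x):
    # Root mean square over the middle half, to ignore edge effects.
    mid = x[len(x) // 4 : 3 * len(x) // 4]
    return float(np.sqrt(np.mean(mid**2)))


print(f"2 Hz kept:  {rms(filtered_slow) / rms(slow):.3f}")
print(f"12 Hz kept: {rms(filtered_mu) / rms(mu):.3f}")
```

The in-band component passes almost unchanged while the slow drift is strongly attenuated.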

Create Pipelines

Use Common Spatial Patterns (CSP) with 8 components, followed by a Linear Discriminant Analysis (LDA) classifier.

pipeline = {}
pipeline["CSP+LDA"] = make_pipeline(CSP(n_components=8), LDA())
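To give an idea of what the CSP step computes, the sketch below derives spatial filters from a generalized eigenvalue problem on the two class covariance matrices, then feeds log-variance features to LDA. It is a simplified illustration on hypothetical synthetic data, not MNE's implementation, and the accuracy is measured on the training data only to show that the pieces fit together.

```python
import numpy as np
from scipy.linalg import eigh
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

rng = np.random.default_rng(0)
n_trials, n_channels, n_times = 40, 4, 200

# Hypothetical synthetic epochs: class 1 has extra variance on channel 0.
X = rng.standard_normal((n_trials, n_channels, n_times))
y = np.repeat([0, 1], n_trials // 2)
X[y == 1, 0, :] *= 3.0


def class_cov(trials):
    # Average of the per-trial spatial covariance matrices.
    return np.mean([trial @ trial.T / trial.shape[1] for trial in trials], axis=0)


cov0, cov1 = class_cov(X[y == 0]), class_cov(X[y == 1])

# Generalized eigenproblem cov0 w = lambda (cov0 + cov1) w. Eigenvalues far
# from 0.5 mark filters whose output variance differs most between classes.
eigvals, eigvecs = eigh(cov0, cov0 + cov1)
order = np.argsort(np.abs(eigvals - 0.5))[::-1]
filters = eigvecs[:, order[:2]].T  # shape (n_components, n_channels)

# Log-variance of the spatially filtered trials, then LDA on those features.
features = np.log(np.var(filters @ X, axis=2))
clf = LinearDiscriminantAnalysis().fit(features, y)
accuracy = clf.score(features, y)
print(f"training accuracy: {accuracy:.2f}")
```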

Get Data (optional)

To access the EEG signals downloaded from a dataset, you can use dataset.get_data(subjects=[subject_id]) to obtain the EEG in MNE format, stored in a dictionary of sessions and runs. Alternatively, paradigm.get_data(dataset=dataset, subjects=[subject_id]) returns the EEG data in scikit-learn format, together with the labels and the meta information. With paradigm.get_data, the data are preprocessed according to the paradigm requirements.

# X_all, labels_all, meta_all = [], [], []
# for d in datasets:
#     # sessions = d.get_data(subjects=[2])
#     X, labels, meta = paradigm.get_data(dataset=d, subjects=[2])
#     X_all.append(X)
#     labels_all.append(labels)
#     meta_all.append(meta)

Evaluation

The evaluation returns a DataFrame containing a single AUC score for each subject and session of each dataset, and for each pipeline.

overwrite = True  # set to True if we want to overwrite cached results

evaluation = CrossSessionEvaluation(
    paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite
)
results = evaluation.process(pipeline)

print(results.head())
Zhou2016-CrossSession:   0%|          | 0/2 [00:00<?, ?it/s]/home/runner/work/moabb/moabb/moabb/datasets/preprocessing.py:279: UserWarning: warnEpochs <Epochs | 60 events (all good), 0 – 5 s (baseline off), ~8.0 MB, data loaded,
 'left_hand': 30
 'right_hand': 30>
  warn(f"warnEpochs {epochs}")
Zhou2016-CrossSession: 100%|##########| 2/2 [00:07<00:00,  3.69s/it]
BNCI2014-001-CrossSession: 100%|##########| 2/2 [00:06<00:00,  3.11s/it]
      score      time  samples subject  ... channels  n_sessions   dataset pipeline
0  0.851412  0.338031    200.0       1  ...       14           3  Zhou2016  CSP+LDA
1  0.935200  0.259145    219.0       1  ...       14           3  Zhou2016  CSP+LDA
2  0.939200  0.322441    219.0       1  ...       14           3  Zhou2016  CSP+LDA
3  0.908000  0.228958    190.0       2  ...       14           3  Zhou2016  CSP+LDA
4  0.770370  0.251072    200.0       2  ...       14           3  Zhou2016  CSP+LDA

[5 rows x 9 columns]
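Since there is one row per subject and session, the scores can be aggregated with pandas. The sketch below rebuilds only the five rows and the columns visible in the printed output, so subject 2 is missing its third session; with the real results DataFrame the same groupby call would apply to all rows.

```python
import pandas as pd

# The five printed rows, restricted to the columns visible in the output.
results_head = pd.DataFrame(
    {
        "score": [0.851412, 0.935200, 0.939200, 0.908000, 0.770370],
        "subject": [1, 1, 1, 2, 2],
        "dataset": ["Zhou2016"] * 5,
        "pipeline": ["CSP+LDA"] * 5,
    }
)

# Mean, spread, and number of held-out sessions per subject.
summary = (
    results_head.groupby(["dataset", "pipeline", "subject"])["score"]
    .agg(["mean", "std", "count"])
    .reset_index()
)
print(summary)
```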

Plot Results

Here we plot the results, showing the score for each session and subject, with one panel per dataset.

sns.catplot(
    data=results,
    x="session",
    y="score",
    hue="subject",
    col="dataset",
    kind="bar",
    palette="viridis",
)
plt.show()
[Figure: bar plots of score per session, colored by subject, with one panel per dataset (Zhou2016 and BNCI2014-001).]

Total running time of the script: ( 0 minutes 20.481 seconds)

Estimated memory usage: 205 MB
