Within Session P300
This example shows how to perform a within-session analysis on a P300 dataset (BNCI2014009, restricted below to its first two subjects to keep the run short).
We will compare two pipelines:
Riemannian geometry (Xdawn covariance matrices mapped to the tangent space, then LDA)
Xdawn spatial filtering with Linear Discriminant Analysis (LDA)
We will use the P300 paradigm, which scores classifiers with the ROC AUC (area under the receiver operating characteristic curve).
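The ROC AUC can be read as the probability that a randomly chosen Target epoch receives a higher classifier score than a randomly chosen NonTarget epoch. Because it depends neither on a decision threshold nor on the Target/NonTarget ratio, it suits P300 data, where NonTarget epochs far outnumber Target ones. A minimal NumPy sketch of that pairwise definition (the function name `roc_auc` and the toy labels and scores are illustrative, not part of MOABB):

```python
import numpy as np


def roc_auc(y_true, scores):
    """Probability that a random positive is scored above a random negative (ties count 1/2)."""
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    # Compare every positive (Target) score with every negative (NonTarget) score
    wins = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (wins + 0.5 * ties) / (pos.size * neg.size)


# Illustrative labels (1 = Target, 0 = NonTarget) and classifier scores
y = np.array([1, 0, 1, 0, 0, 1])
s = np.array([0.9, 0.3, 0.4, 0.5, 0.1, 0.8])
print(roc_auc(y, s))  # 8 of the 9 Target/NonTarget pairs are ranked correctly -> 0.888...
```

For binary labels this agrees with `sklearn.metrics.roc_auc_score`; the pairwise comparison is quadratic in the number of epochs, so it is only meant as an illustration.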
# Authors: Pedro Rodrigues <pedro.rodrigues01@gmail.com>
#
# License: BSD (3-clause)
import warnings
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pyriemann.estimation import Xdawn, XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline
import moabb
from moabb.datasets import BNCI2014009
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import P300
Suppress FutureWarning and RuntimeWarning messages, and set the MOABB log level to "info":
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)
moabb.set_log_level("info")
This is an auxiliary transformer that flattens each trial so the data can be passed to estimators that expect 2D input. For instance, for an X with dimensions Nt x Nc x Ns (trials x channels x samples), it returns a new array with dimensions Nt x (Nc * Ns).
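class Vectorizer(BaseEstimator, TransformerMixin):
    """Flatten each trial of a 3D array into a single feature vector."""

    def __init__(self):
        pass

    def fit(self, X, y=None):
        """Stateless: there is nothing to learn, so return self."""
        return self

    def transform(self, X):
        """Reshape (Nt, Nc, Ns) into (Nt, Nc * Ns)."""
        return np.reshape(X, (X.shape[0], -1))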
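To see what the reshape does, here is a small check on a toy array, together with a stateless scikit-learn alternative built on `FunctionTransformer`. The helper name `flatten_epochs` and the toy shapes are illustrative only; the custom class above remains the example's own approach.

```python
import numpy as np
from sklearn.preprocessing import FunctionTransformer


def flatten_epochs(A):
    """Keep the trial axis and flatten all remaining axes (same as Vectorizer.transform)."""
    return np.reshape(A, (A.shape[0], -1))


# Hypothetical toy epochs: 2 trials (Nt) x 3 channels (Nc) x 4 samples (Ns)
X = np.arange(2 * 3 * 4).reshape(2, 3, 4)
X_flat = flatten_epochs(X)
print(X_flat.shape)  # (2, 12)

# In C order, each row concatenates that trial's channels one after another
print(np.array_equal(X_flat[0], np.concatenate([X[0, 0], X[0, 1], X[0, 2]])))  # True

# A stateless scikit-learn alternative to the custom Vectorizer class
vectorizer = FunctionTransformer(flatten_epochs)
print(vectorizer.fit_transform(X).shape)  # (2, 12)
```

A named module-level function is used instead of a lambda so that the resulting pipeline can still be pickled with the standard `pickle` module.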
Create Pipelines
Pipelines must be given as a dict mapping each pipeline name to a scikit-learn pipeline.
pipelines = {}
The classes are named 'Target' and 'NonTarget', but the evaluation encodes labels with a LabelEncoder, which maps them to 1 and 0 respectively. XdawnCovariances therefore has to refer to the Target class by its encoded value:
labels_dict = {"Target": 1, "NonTarget": 0}
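The encoding can be checked directly with scikit-learn's `LabelEncoder`, which orders classes by sorting them, so "NonTarget" comes before "Target". A small sketch on a few illustrative labels:

```python
from sklearn.preprocessing import LabelEncoder

labels_dict = {"Target": 1, "NonTarget": 0}

le = LabelEncoder()
encoded = le.fit_transform(["Target", "NonTarget", "NonTarget", "Target"])
print(le.classes_.tolist())  # ['NonTarget', 'Target']
print(encoded)  # [1 0 0 1]

# The encoded value of "Target" matches the value passed to XdawnCovariances(classes=...)
print(le.transform(["Target"])[0] == labels_dict["Target"])  # True
```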
pipelines["RG+LDA"] = make_pipeline(
    XdawnCovariances(
        nfilter=2, classes=[labels_dict["Target"]], estimator="lwf", xdawn_estimator="scm"
    ),
    TangentSpace(),
    LDA(solver="lsqr", shrinkage="auto"),
)
pipelines["Xdw+LDA"] = make_pipeline(
    Xdawn(nfilter=2, estimator="scm"), Vectorizer(), LDA(solver="lsqr", shrinkage="auto")
)
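In the Xdw+LDA pipeline, Xdawn learns spatial filters that enhance the evoked (P300) response relative to the rest of the signal, and the filtered epochs are then flattened by Vectorizer. As a rough illustration, here is a simplified one-class version based on a generalized eigendecomposition with NumPy and SciPy. It is not pyriemann's implementation (which, for instance, computes filters for every class by default and lets you choose the covariance estimator); the function name `xdawn_filters` and the synthetic data are hypothetical.

```python
import numpy as np
from scipy.linalg import eigh


def xdawn_filters(X, y, target, nfilter=2):
    """Simplified Xdawn-style spatial filters for one class (sketch).

    X: epochs of shape (n_trials, n_channels, n_samples); y: integer labels.
    Returns filters of shape (n_channels, nfilter), sorted by decreasing
    evoked-to-signal variance ratio, and the class-average ERP.
    """
    evoked = X[y == target].mean(axis=0)  # class-average ERP, (n_channels, n_samples)
    C_evoked = np.cov(evoked)  # spatial covariance of the ERP
    # Concatenate all epochs in time to estimate the covariance of the whole signal
    signal = X.transpose(1, 0, 2).reshape(X.shape[1], -1)
    C_signal = np.cov(signal)
    # Generalized eigenproblem C_evoked w = lambda C_signal w (eigenvalues ascending)
    eigvals, eigvecs = eigh(C_evoked, C_signal)
    order = np.argsort(eigvals)[::-1][:nfilter]
    return eigvecs[:, order], evoked


# Hypothetical synthetic data: 60 epochs, 4 channels, 32 samples of white noise,
# with a Gaussian-shaped "ERP" added to channel 0 of the Target (label 1) epochs.
rng = np.random.default_rng(0)
y = np.r_[np.ones(30, dtype=int), np.zeros(30, dtype=int)]
X = rng.standard_normal((60, 4, 32))
erp = 3.0 * np.exp(-0.5 * ((np.arange(32) - 16) / 3.0) ** 2)
X[y == 1, 0, :] += erp

W, evoked = xdawn_filters(X, y, target=1, nfilter=2)
X_filtered = W.T @ X  # (60, 2, 32): the spatially filtered epochs
print(W.shape, X_filtered.shape)  # (4, 2) (60, 2, 32)
```

Since the synthetic ERP lives on channel 0 only, the first filter should put most of its weight on that channel; the filtered epochs are what a Vectorizer step would then flatten for LDA.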
Evaluation
We define the paradigm (P300, with the data resampled to 128 Hz) and evaluate on the BNCI2014009 dataset, keeping only its first two subjects to limit download and computation time. The evaluation returns a DataFrame with one AUC score per subject, session and pipeline.
Results are cached in a database, so adding a new pipeline does not re-run the evaluation of the existing ones unless a parameter has changed. Setting overwrite=True, as below, recomputes and overwrites the cached results.
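WithinSessionEvaluation scores each pipeline separately on every session, using cross-validation inside that session. The sketch below shows the idea for a single session with plain scikit-learn on synthetic features; the 5-fold stratified split, the random seed and the data are assumptions chosen for illustration, not necessarily MOABB's exact settings.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Hypothetical single-session data: 20 Target and 100 NonTarget epochs,
# already reduced to 20 features (e.g. the output of Xdawn + Vectorizer)
rng = np.random.default_rng(42)
y = np.r_[np.ones(20, dtype=int), np.zeros(100, dtype=int)]
X = rng.standard_normal((120, 20))
X[y == 1, :5] += 1.0  # hypothetical class difference on the first 5 features

# Stratified folds keep the Target/NonTarget ratio in every split
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(
    LDA(solver="lsqr", shrinkage="auto"), X, y, cv=cv, scoring="roc_auc"
)
print(scores.mean())  # one summary AUC for this session and pipeline
```

Repeating this for every subject, session and pipeline yields a table with one score per combination, which is the shape of the `results` DataFrame returned below.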
paradigm = P300(resample=128)
dataset = BNCI2014009()
dataset.subject_list = dataset.subject_list[:2]
datasets = [dataset]
overwrite = True # set to True if we want to overwrite cached results
evaluation = WithinSessionEvaluation(
    paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite
)
results = evaluation.process(pipelines)
Plot Results
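Here we plot the results to compare the two pipelines: each dot is the score of one subject/session, and the point plot overlays the mean score of each pipeline.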
fig, ax = plt.subplots(facecolor="white", figsize=[8, 4])
sns.stripplot(
    data=results,
    y="score",
    x="pipeline",
    ax=ax,
    jitter=True,
    alpha=0.5,
    zorder=1,
    palette="Set1",
)
sns.pointplot(data=results, y="score", x="pipeline", ax=ax, zorder=1, palette="Set1")
ax.set_ylabel("ROC AUC")
ax.set_ylim(0.5, 1)
fig.show()

Total running time of the script: ( 0 minutes 42.327 seconds)