Within Session P300

This example shows how to perform a within-session analysis on a P300 dataset (BNCI2014009, restricted to its first two subjects to keep the run short).

We will compare two pipelines:

  • Riemannian geometry

  • XDAWN with Linear Discriminant Analysis

We will use the P300 paradigm, which uses the ROC AUC as its metric.
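For intuition, ROC AUC on a binary Target/NonTarget problem equals the probability that a randomly chosen Target epoch receives a higher classifier score than a randomly chosen NonTarget epoch, so 0.5 is chance level. The pure-Python sketch below computes it directly from that definition; the labels and scores are made-up toy values, and in practice sklearn.metrics.roc_auc_score computes the same quantity.

```python
def roc_auc(y_true, scores):
    """ROC AUC as P(score of a positive > score of a negative).

    Labels are assumed to be 1 (Target) and 0 (NonTarget); ties count as half.
    """
    pos = [s for s, y in zip(scores, y_true) if y == 1]
    neg = [s for s, y in zip(scores, y_true) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    # Fraction of (positive, negative) pairs that are ranked correctly
    return wins / (len(pos) * len(neg))


# Toy example: 2 Target and 3 NonTarget epochs with illustrative scores
y = [1, 0, 1, 0, 0]
s = [0.9, 0.3, 0.4, 0.5, 0.1]
print(roc_auc(y, s))
```

Here 5 of the 6 Target/NonTarget pairs are ranked correctly, giving an AUC of 5/6.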

# Authors: Pedro Rodrigues <pedro.rodrigues01@gmail.com>
#
# License: BSD (3-clause)

import warnings

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pyriemann.estimation import Xdawn, XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline

import moabb
from moabb.datasets import BNCI2014009
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import P300


Suppress FutureWarning and RuntimeWarning messages to keep the output readable.

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

moabb.set_log_level("info")
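The filters above apply to the whole script. If you would rather silence a warning category only around a particular call, the standard library's warnings.catch_warnings context manager restores the previous filters when the block exits. A minimal sketch; noisy() is a hypothetical stand-in for any call that emits a FutureWarning:

```python
import warnings


def noisy():
    # Hypothetical function that emits a FutureWarning
    warnings.warn("this API will change", FutureWarning)
    return 42


# FutureWarning is ignored only inside this block; the previous
# filter configuration is restored automatically on exit.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=FutureWarning)
    value = noisy()

# Outside the scoped block the warning is emitted again
# (recorded here so the demo can inspect it).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    noisy()
```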

This is an auxiliary transformer that vectorizes (flattens) data structures inside a pipeline. For instance, given an X with dimensions Nt x Nc x Ns (trials x channels x samples), one might want a new data structure with dimensions Nt x (Nc·Ns).

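class Vectorizer(BaseEstimator, TransformerMixin):
    """Flatten every dimension except the first (trials) into one feature axis."""

    def fit(self, X, y=None):
        """Nothing to learn; return self for scikit-learn compatibility."""
        return self

    def transform(self, X):
        """Reshape X from (Nt, Nc, Ns) to (Nt, Nc * Ns)."""
        return np.reshape(X, (X.shape[0], -1))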
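A quick check of what Vectorizer.transform does to an epochs array, using random toy data with arbitrary sizes: the trial axis is kept and the channel and sample axes are flattened in C order, so each row holds channel 0's samples, then channel 1's, and so on.

```python
import numpy as np

# Toy epochs array: Nt trials x Nc channels x Ns samples (random values)
n_trials, n_channels, n_samples = 4, 3, 5
X = np.random.default_rng(0).standard_normal((n_trials, n_channels, n_samples))

# Same operation as Vectorizer.transform: keep axis 0, flatten the rest
X_flat = np.reshape(X, (X.shape[0], -1))
print(X_flat.shape)  # (4, 15)

# In C order, the first Ns entries of a row are that trial's channel 0
assert np.array_equal(X_flat[1, :n_samples], X[1, 0])
```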

Create Pipelines

Pipelines must be provided as a dict mapping a name to a scikit-learn pipeline.

pipelines = {}

We need this mapping because the classes are called ‘Target’ and ‘NonTarget’, but the evaluation uses a LabelEncoder, which transforms them to integers (NonTarget → 0, Target → 1).

labels_dict = {"Target": 1, "NonTarget": 0}
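LabelEncoder assigns integer codes by sorting the unique class names, and ‘NonTarget’ sorts before ‘Target’. The pure-Python emulation below (not scikit-learn itself) shows why labels_dict maps NonTarget to 0 and Target to 1:

```python
# Emulate LabelEncoder: sort the unique classes, code = index in sorted order
labels = ["Target", "NonTarget", "NonTarget", "Target", "NonTarget"]
classes = sorted(set(labels))  # ['NonTarget', 'Target']
code_of = {name: i for i, name in enumerate(classes)}
encoded = [code_of[lab] for lab in labels]

# The mapping used in this example agrees with the sorted-order codes
labels_dict = {"Target": 1, "NonTarget": 0}
assert all(code_of[name] == code for name, code in labels_dict.items())
print(encoded)  # [1, 0, 0, 1, 0]
```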

pipelines["RG+LDA"] = make_pipeline(
    XdawnCovariances(
        nfilter=2, classes=[labels_dict["Target"]], estimator="lwf", xdawn_estimator="scm"
    ),
    TangentSpace(),
    LDA(solver="lsqr", shrinkage="auto"),
)

pipelines["Xdw+LDA"] = make_pipeline(
    Xdawn(nfilter=2, estimator="scm"), Vectorizer(), LDA(solver="lsqr", shrinkage="auto")
)

Evaluation

We define the paradigm (P300) and, to keep this example fast, use only the BNCI2014009 dataset restricted to its first two subjects. The evaluation returns a DataFrame containing a single AUC score for each subject / session of the dataset, and for each pipeline.
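To make “one score per subject / session / pipeline” concrete, here is a rough numpy sketch of a within-session evaluation: the epochs of a single session are split into stratified folds, a classifier is trained on the remaining folds and scored with ROC AUC on the held-out fold, and the fold AUCs are averaged. Several parts are assumptions made for illustration: we believe WithinSessionEvaluation uses 5-fold stratified cross-validation, but check the MOABB documentation; the nearest-class-mean classifier stands in for the real RG+LDA and Xdw+LDA pipelines; and the data are synthetic.

```python
import numpy as np


def auc(y, s):
    # Mann-Whitney formulation of ROC AUC (ties count 1/2)
    pos, neg = s[y == 1], s[y == 0]
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size


def stratified_folds(y, k, rng):
    # Shuffle each class separately, then deal its epochs round-robin into k folds
    folds = np.empty(len(y), dtype=int)
    for cls in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == cls))
        folds[idx] = np.arange(len(idx)) % k
    return folds


def nearest_mean_scores(X_tr, y_tr, X_te):
    # Stand-in classifier: project onto the difference of class means
    w = X_tr[y_tr == 1].mean(axis=0) - X_tr[y_tr == 0].mean(axis=0)
    return X_te @ w


def within_session_auc(X, y, k=5, seed=0):
    rng = np.random.default_rng(seed)
    folds = stratified_folds(y, k, rng)
    fold_scores = []
    for f in range(k):
        test = folds == f
        s = nearest_mean_scores(X[~test], y[~test], X[test])
        fold_scores.append(auc(y[test], s))
    # A single AUC for this (subject, session, pipeline)
    return float(np.mean(fold_scores))


# Synthetic session: 100 NonTarget and 20 Target epochs, 10 features each
rng = np.random.default_rng(1)
y = np.r_[np.zeros(100, int), np.ones(20, int)]
X = rng.standard_normal((120, 10)) + y[:, None] * 1.0
print(within_session_auc(X, y))
```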

Results are saved into the database, so that if you add a new pipeline, the evaluation is not rerun for the existing pipelines unless one of their parameters has changed. Results can be overwritten if necessary; here overwrite = True forces them to be recomputed.
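The caching idea can be illustrated with a toy cache keyed on a fingerprint of a pipeline's parameters, so that an unchanged configuration is a cache hit and any parameter change forces a recomputation. This is a hypothetical sketch of the concept only, not how MOABB actually stores results; ResultsCache, fake_evaluate and the session name are made up for the example.

```python
import hashlib


class ResultsCache:
    """Hypothetical toy cache illustrating the idea; NOT MOABB's implementation."""

    def __init__(self):
        self._store = {}

    @staticmethod
    def key(pipeline_name, params, subject, session):
        # Fingerprint the configuration: any parameter change yields a new key
        fingerprint = hashlib.sha1(repr(sorted(params.items())).encode()).hexdigest()
        return (pipeline_name, fingerprint, subject, session)

    def get_or_compute(self, pipeline_name, params, subject, session, compute, overwrite=False):
        k = self.key(pipeline_name, params, subject, session)
        # Recompute on a cache miss, or always when overwrite is requested
        if overwrite or k not in self._store:
            self._store[k] = compute()
        return self._store[k]


calls = []


def fake_evaluate():
    # Hypothetical stand-in for an expensive cross-validation run
    calls.append(1)
    return len(calls)


cache = ResultsCache()
cache.get_or_compute("Xdw+LDA", {"nfilter": 2}, 1, "session_0", fake_evaluate)  # computed
cache.get_or_compute("Xdw+LDA", {"nfilter": 2}, 1, "session_0", fake_evaluate)  # cache hit
cache.get_or_compute("Xdw+LDA", {"nfilter": 4}, 1, "session_0", fake_evaluate)  # params changed
print(len(calls))  # 2
```

The overwrite flag mirrors the role of overwrite in the evaluation below: when it is set, cached entries are ignored and recomputed.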

paradigm = P300(resample=128)
dataset = BNCI2014009()
dataset.subject_list = dataset.subject_list[:2]  # keep only the first two subjects for speed
datasets = [dataset]
overwrite = True  # set to True if we want to overwrite cached results
evaluation = WithinSessionEvaluation(
    paradigm=paradigm, datasets=datasets, suffix="examples", overwrite=overwrite
)
results = evaluation.process(pipelines)

Out:

009-2014-WithinSession: 100%|##########| 2/2 [00:42<00:00, 21.05s/it]

Plot Results

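Here we plot the results to compare the two pipelines. The strip plot shows the individual scores and the point plot shows their mean with a confidence interval.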

fig, ax = plt.subplots(facecolor="white", figsize=[8, 4])

sns.stripplot(
    data=results,
    y="score",
    x="pipeline",
    ax=ax,
    jitter=True,
    alpha=0.5,
    zorder=1,
    palette="Set1",
)
sns.pointplot(data=results, y="score", x="pipeline", ax=ax, zorder=1, palette="Set1")

ax.set_ylabel("ROC AUC")
ax.set_ylim(0.5, 1)

fig.show()
[Figure: plot within session p300]

Total running time of the script: (0 minutes 42.327 seconds)
