moabb.evaluations.WithinSubjectSplitter
- class moabb.evaluations.WithinSubjectSplitter(n_folds: int = 5, shuffle: bool = True, random_state: int | None = None, cv_class: type[sklearn.model_selection.BaseCrossValidator] = sklearn.model_selection.StratifiedKFold, **cv_kwargs)
Data splitter for within subject evaluation.
Within-subject evaluation uses k-fold cross-validation to determine train and test sets for each subject across all their sessions. This splitter assumes that all data from all subjects is already known and loaded.
Unlike WithinSessionSplitter, which performs cross-validation within each session, this splitter pools all of a subject's sessions and performs cross-validation across them.
The inner cross-validation strategy can be changed by passing the cv_class and cv_kwargs arguments. By default, it uses StratifiedKFold.
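The pooling described above can be sketched with scikit-learn alone. This is an illustrative emulation, not MOABB's implementation: the toy arrays `subjects`, `sessions`, and `labels` and the helper `within_subject_folds` are hypothetical names invented for this example.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical toy data: 2 subjects, 2 sessions each, 8 trials per session.
subjects = np.repeat([1, 2], 16)
sessions = np.tile(np.repeat(["0", "1"], 8), 2)
labels = np.tile(["left", "right"], 16)


def within_subject_folds(labels, subjects, n_folds=4, random_state=0):
    """Yield (subject, train_idx, test_idx); indices refer to the full arrays."""
    for subject in np.unique(subjects):
        # All trials of this subject, regardless of session.
        subject_idx = np.flatnonzero(subjects == subject)
        cv = StratifiedKFold(
            n_splits=n_folds, shuffle=True, random_state=random_state
        )
        for train, test in cv.split(subject_idx, labels[subject_idx]):
            # Map fold-local positions back to positions in the full arrays.
            yield subject, subject_idx[train], subject_idx[test]


folds = list(within_subject_folds(labels, subjects))

# Because sessions are pooled, every training set mixes both sessions.
spans_both_sessions = all(
    len(np.unique(sessions[train])) == 2 for _, train, _ in folds
)
```

Each fold stays inside one subject, while trials from different sessions of that subject can land on either side of the split.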
- Parameters:
n_folds (int, default=5) – Number of folds. Must be at least 2.
shuffle (bool, default=True) – Whether to shuffle each class’s samples before splitting into batches. Note that the samples within each split will not be shuffled.
random_state (int, RandomState instance or None, default=None) – Controls the randomness of splits. Only used when shuffle is True. Pass an int for reproducible output across multiple function calls.
cv_class (cross-validation class, default=StratifiedKFold) – Inner cross-validation strategy for splitting within each subject.
cv_kwargs (dict) – Additional arguments to pass to the inner cross-validation strategy.
- get_n_splits(metadata)
Return the number of splits for the cross-validation.
The number of splits is the number of subjects times the number of folds.
We try to keep the same behaviour as the sklearn cross-validation classes.
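The count stated above is simple arithmetic. A minimal sketch, assuming the metadata carries one subject label per trial; `expected_n_splits` and `subject_column` are hypothetical names and not part of MOABB:

```python
def expected_n_splits(subject_column, n_folds=5):
    """Number of (train, test) pairs: one k-fold run per distinct subject."""
    return len(set(subject_column)) * n_folds


# Hypothetical subject labels, one entry per trial (3 distinct subjects).
subject_column = [1, 1, 1, 2, 2, 3, 3, 3, 3]

# 3 subjects x 5 folds = 15 splits.
n_splits = expected_n_splits(subject_column, n_folds=5)
```

With a pandas DataFrame, the subject column of the metadata would take the place of the plain list used here.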
- Parameters:
metadata (pd.DataFrame) – The metadata containing the subject and session information.
- Returns:
n_splits – The number of splits for the cross-validation
- Return type:
int
- split(y, metadata)
Generate indices to split data into training and test sets.
- Parameters:
y (array-like of shape (n_samples,)) – The target variable for supervised learning problems.
metadata (pd.DataFrame) – The metadata containing the subject and session information.
- Yields:
train (ndarray) – The training set indices for that split.
test (ndarray) – The testing set indices for that split.