Source code for bciflow.modules.core.kfold

'''
Description
-----------

This module contains the method kfold, which is used to perform a stratified k-fold cross-validation. 
The method is designed to work with eegdata dictionary.

'''
from sklearn.model_selection import StratifiedKFold
import numpy as np
import pandas as pd
import inspect
from ..core.util import util
from ..sf.ea import ea

[docs] def kfold(target, start_window=0, start_test_window=None, window_size=2, pre_folding={}, pos_folding={}): ''' This method is used to perform a stratified k-fold cross-validation. The method is designed to work with eegdata dictionary. Parameters ---------- target : dict Input EEG data. start_window : int start_test_window : int pre_folding : dict A dictionary containing the preprocessing functions to be applied to the data before the cross-validation. The keys are the names of the preprocessing functions, and the values are tuples containing the function and its parameters. pos_folding : dict A dictionary containing the postprocessing functions to be applied to the data before the cross-validation. The keys are the names of the postprocessing functions, and the values are the functions. window_size : float The size of the window to be used in the crop method of eegdata. source : list List of Eeg data from anothers subjects to be used as a source for the Transfer Learning modules Returns ------- results : pandas.DataFrame A pandas dataframe containing the results of the cross-validation. The columns are 'fold', 'tmin', 'true_label', and the labels of the events in the target object. ''' if type(start_window) is float: start_window = [start_window] if start_test_window is None: start_test_window = start_window target_dict = {} for tmin_ in start_test_window: target_dict[tmin_] = util.crop(data=target, tmin=tmin_, window_size=window_size, inplace=False) for tmin_ in start_test_window: for name, pre_func in pre_folding.items(): if inspect.isfunction(pre_func[0]): target_dict[tmin_] = util.apply_to_trials(data=target_dict[tmin_], func=pre_func[0], func_param=pre_func[1], inplace=False) else: target_dict[tmin_] = util.apply_to_trials(data=target_dict[tmin_], func=pre_func[0].transform, func_param=pre_func[1], inplace=False) skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) fold_id = 0 results = [] for train_index, test_index in skf.split(target["y"], target["y"]): fold_id += 1 target_train = [] for tmin_ in start_window: target_train.append(util.get_trial(data=target_dict[tmin_], ids=train_index)) target_train = util.concatenate(target_train) target_test = {} for tmin_ in start_test_window: target_test[tmin_] = util.get_trial(data=target_dict[tmin_], ids=test_index) for name, pos_func in pos_folding.items(): if name != 'clf': if inspect.isfunction(pos_func[0]): target_train = pos_func[0](target_train, **pos_func[1]) else: target_train = pos_func[0].fit_transform(target_train, **pos_func[1]) for tmin_ in start_test_window: if inspect.isfunction(pos_func[0]): target_test[tmin_] = pos_func[0](target_test[tmin_], **pos_func[1]) else: target_test[tmin_] = pos_func[0].transform(target_test[tmin_]) clf, clf_param = pos_folding['clf'] if not inspect.isfunction(clf): clf = clf.fit(target_train['X'], target_train['y'], **clf_param) for tmin_ in start_test_window: try: y_pred = clf.predict_proba(target_test[tmin_]['X']) except: y_pred = np.zeros((len(target_test[tmin_]['y']), len(target['y_dict']))) y_pred = np.round(y_pred, 4) for trial_ in range(len(y_pred)): results.append([fold_id, tmin_, find_key_with_value(target['y_dict'], target_test[tmin_]['y'][trial_]), *y_pred[trial_]]) results = np.array(results) results = pd.DataFrame(results, columns=['fold', 'tmin', 'true_label', *target['y_dict'].keys()]) return results
[docs] def find_key_with_value(dictionary, i): ''' This function returns the key of a dictionary given a value. Parameters ---------- dictionary : dict The dictionary to be searched. i : any The value to be searched for. Returns ------- key : any The key of the dictionary that contains the value i. If the value is not found, returns None. ''' for key, value in dictionary.items(): if value == i: return key return None