From 3a047e04601a37ffb700c8aef631cc24f267e2b2 Mon Sep 17 00:00:00 2001
From: Frida Heskebeck <frida.heskebeck@control.lth.se>
Date: Thu, 12 Sep 2024 11:55:57 +0200
Subject: [PATCH] Add data_generation.py

---
 data_generation/data_generation.py | 1560 ++++++++++++++++++++++++++++
 1 file changed, 1560 insertions(+)
 create mode 100644 data_generation/data_generation.py

diff --git a/data_generation/data_generation.py b/data_generation/data_generation.py
new file mode 100644
index 0000000..7cd58da
--- /dev/null
+++ b/data_generation/data_generation.py
@@ -0,0 +1,1560 @@
+#!/usr/bin/env python
+"""Generate transfer-learning performance data on the PhysioNet MI dataset."""
+
+# -----------------
+# IMPORTS
+# -----------------
+
+# General Libraries
+import os
+import sys
+import time
+import math
+import random
+import datetime
+import subprocess
+import pickle
+import copy
+import warnings
+import shutil
+
+# Make numpy single threaded, so this program can be multithreaded explicitly.
+# NOTE: these variables must be set BEFORE numpy is imported, otherwise the
+# BLAS thread pools are already initialized and the caps have no effect.
+os.environ["OMP_NUM_THREADS"] = "1"  # export OMP_NUM_THREADS=1
+os.environ["OPENBLAS_NUM_THREADS"] = "1"  # export OPENBLAS_NUM_THREADS=1
+os.environ["MKL_NUM_THREADS"] = "1"  # export MKL_NUM_THREADS=1
+os.environ["VECLIB_MAXIMUM_THREADS"] = "1"  # export VECLIB_MAXIMUM_THREADS=1
+os.environ["NUMEXPR_NUM_THREADS"] = "1"  # export NUMEXPR_NUM_THREADS=1
+
+# Numpy Library
+import numpy as np
+rng = np.random.default_rng(seed=42)
+
+# Pandas Library
+import pandas as pd
+
+# Pyriemann library
+from pyriemann.estimation import Covariances
+from pyriemann.classification import MDM, TSclassifier, SVC, KNearestNeighbor, MeanField
+from pyriemann.utils.distance import distance
+from pyriemann.utils.mean import mean_covariance
+from pyriemann.embedding import SpectralEmbedding, LocallyLinearEmbedding
+from pyriemann.datasets.simulated import make_classification_transfer
+from pyriemann.transfer import decode_domains, encode_domains, TLCenter, TLStretch, TLRotate, TLClassifier, TLDummy
+from pyriemann.utils.viz import plot_embedding
+from pyriemann.tangentspace import TangentSpace
+
+# Scikit-learn library
+from sklearn.pipeline import make_pipeline
+from sklearn.metrics import accuracy_score, confusion_matrix
+from sklearn.model_selection import LeaveOneGroupOut, KFold, StratifiedKFold, StratifiedShuffleSplit
+from sklearn import svm
+from sklearn.neural_network import MLPClassifier
+from sklearn import decomposition
+from sklearn.linear_model import LogisticRegression
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+
+# Matplotlib Library
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+
+# Multiprocessing
+from multiprocessing import Pool
+
+# MNE library
+from mne import Epochs, pick_types, events_from_annotations
+from mne.io import concatenate_raws
+from mne.io.edf import read_raw_edf
+from mne.datasets import eegbci
+from mne import set_log_level
+
+
+# -----------------
+# GLOBAL PARAMETERS
+# -----------------
+
+NBR_THREADS = 8  # Nbr of threads for parallelisation
+
+NBR_FOLDS = 5  # Number of folds for cross validation. [BEWARE! The source selection script assumes this is 5.]
+
+# Name for the data folder.
+# DATA_FOLDER = 'all_users_left_right'
+# DATA_FOLDER = 'all_users_hands_feet'
+DATA_FOLDER = time.strftime("data_%Y_%m_%d-%H_%M_%S")
+
+# PhysioNet: 46 samples in 2 classes
+DATA_SET = 'PhysioNet'
+
+nbr_samples_per_class = 17  # How many samples to use from each class as training data from the target user.
+ALL_TRAINING_SIZE = [2*nbr_samples_per_class]  # PhysioNet has two classes.
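+# (Added note on sizes) With nbr_samples_per_class = 17, each fold draws
+# 2*17 = 34 labelled trials from the target user for calibration, and
+# NBR_FOLDS such draws are made per source/target pair. Evaluation is done on
+# the full target session; see performance_on_one_cluster_vs_one_cluster.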
+
+# Which subjects to use.
+# SUBJECTS_TO_USE = np.arange(1,110)  # all
+# SUBJECTS_TO_USE = [1,109]  # First and last
+SUBJECTS_TO_USE = [1, 2, 4, 7, 8, 15, 20, 109]  # Some.
+
+
+# Which classes to use.
+# CLASSES_PHYSIONET = 'left_right'
+CLASSES_PHYSIONET = 'hands_feet'
+
+
+# Automatic plotting of subresults. No plotting if more than 40 source users.
+if len(SUBJECTS_TO_USE) > 40:
+    PLOTTING_SUBRESULTS = False
+else:
+    PLOTTING_SUBRESULTS = True
+
+
+random_state = 42
+
+# Classifiers used for MI task classification after transfer learning.
+ALL_CLFS = [
+    ('MDM', make_pipeline(MDM(metric=dict(mean='riemann', distance='riemann')))),
+    ]
+
+
+print(__doc__)
+
+class Logger(object):
+    """For saving everything printed to file as well."""
+    def __init__(self, filename="console_log.txt"):
+        """Initialize the class."""
+        self.terminal = sys.stdout
+        self.log = open(filename, "a")
+    def write(self, message):
+        """Write both to stdout and to file."""
+        self.terminal.write(message)
+        self.log.write(message)
+    def flush(self):
+        """We don't do flush."""
+        pass
+
+
+def print_git_version():
+    """We want to know what version of the source code we're dealing with."""
+    # Ideally, commit to git _every time_ before you run the code.
+    try:
+        git_output = subprocess.check_output("git log -n 1", shell=True, stderr=subprocess.STDOUT)
+        print("\n\n### git log -n 1\n%s\n\n" % git_output.decode("utf-8"))
+        git_output2 = subprocess.check_output("git status", shell=True, stderr=subprocess.STDOUT)
+        print("### git status\n%s\n\n" % git_output2.decode("utf-8"))
+        git_output3 = subprocess.check_output("git diff", shell=True, stderr=subprocess.STDOUT)
+        print("### git diff\n%s\n\n\n" % git_output3.decode("utf-8"))
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        print("\n\n### WARNING: The code you're running is NOT under GIT version control. You can do better. Behave.\n\n")
+    return
+
+def print_source_code():
+    """Print the python source code used to the terminal."""
+    print("\n\n###\n### Listing the python source code used, from the file '%s':\n###\n\n" % __file__)
+    with open(__file__, 'r') as f:
+        print(f.read())
+    print("\n\n###\n### End of python source code from the file '%s'.\n###\n\n" % __file__)
+    return
+
+def save_source_code(run_logdir):
+    """Save the python source code used to the run_logdir folder."""
+    print("\n\n###\n### Saving the python source code used, from the file '%s':\n###" % __file__)
+    shutil.copy2(__file__, f'{run_logdir}/_script.py')
+    return
+
+# -----------------------------------------------
+
+
+# -----------------------------------------------
+# PhysioNet : Loading data
+# -----------------------------------------------
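+# (Added note) eegbci.load_data() downloads the recordings on first use and
+# caches them in the MNE data directory (by default ~/mne_data); subsequent
+# runs reuse the cached files.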
+def get_subject_dataset_physio_net(subject, classes='hands_feet'):
+    """
+    Loads and preprocesses EEG data for a specific subject from the PhysioNet EEG Motor Movement/Imagery Dataset.
+
+    This function downloads the EEG data for a given subject, filters it, selects relevant channels,
+    extracts motor imagery events, and computes covariance matrices based on the selected epochs.
+    The function supports two types of motor imagery tasks: 'hands vs feet' and 'left hand vs right hand'.
+
+    Parameters
+    ----------
+    subject : int
+        The subject ID from the PhysioNet EEG Motor Movement/Imagery dataset.
+
+    classes : str, optional
+        The motor imagery task to extract. Either 'hands_feet' (default) or 'left_right'.
+        - 'hands_feet' extracts events related to imagining hand and foot movements (runs 6, 10, 14).
+        - 'left_right' extracts events related to imagining left and right hand movements (runs 4, 8, 12).
+
+    Returns
+    -------
+    covs : ndarray of shape (n_epochs, n_channels, n_channels)
+        Covariance matrices computed from the EEG data for each epoch.
+
+    labels : ndarray of shape (n_epochs,)
+        Encoded labels corresponding to each epoch. For 'hands_feet', 0 corresponds to 'hands' and 1 to 'feet'.
+        For 'left_right', 0 corresponds to 'left hand' and 1 to 'right hand'.
+
+    str_labels : ndarray of shape (n_epochs,)
+        String labels for each epoch, corresponding to the classes provided ('hands', 'feet', etc.).
+
+    Notes
+    -----
+    The data is preprocessed by selecting EEG channels, applying a bandpass filter (7-35 Hz), and
+    extracting epochs from 1s to 2s after cue onset. The covariance matrices are computed on scaled data.
+    The dataset used is from PhysioNet: https://www.physionet.org/content/eegmmidb/1.0.0/
+    """
+
+    # Dataset description: https://www.physionet.org/content/eegmmidb/1.0.0/
+    # Processing template: https://pyriemann.readthedocs.io/en/latest/auto_examples/transfer/plot_bci_example.html#sphx-glr-auto-examples-transfer-plot-bci-example-py
+
+    # Consider epochs that start 1s after cue onset.
+    tmin, tmax = 1., 2.
+    if classes == 'hands_feet':
+        event_id = dict(hands=2, feet=3)
+        str_classes = ['hands', 'feet']
+        runs = [6, 10, 14]  # Motor imagery: hands vs feet
+    elif classes == 'left_right':
+        event_id = dict(left_hand=2, right_hand=3)
+        runs = [4, 8, 12]  # Motor imagery: left vs right hand
+        str_classes = ['left_hand', 'right_hand']
+    else:
+        raise ValueError("Unknown classes: {}".format(classes))
+
+    # Download data with MNE
+    raw_files = [
+        read_raw_edf(f, preload=True) for f in eegbci.load_data(subject, runs)
+    ]
+    raw = concatenate_raws(raw_files)
+
+    # Select only EEG channels
+    picks = pick_types(
+        raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads')
+    # Select only nine electrodes: F3, Fz, F4, C3, Cz, C4, P3, Pz, P4
+    picks = picks[[31, 33, 35, 8, 10, 12, 48, 50, 52]]
+
+    # Apply band-pass filter
+    raw.filter(7., 35., method='iir', picks=picks)
+
+    # Check the events
+    events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))
+
+    # Define the epochs
+    epochs = Epochs(
+        raw,
+        events,
+        event_id,
+        tmin,
+        tmax,
+        proj=True,
+        picks=picks,
+        baseline=None,
+        preload=True,
+        verbose=False)
+
+    # Extract the labels for each event
+    labels = epochs.events[:, -1] - 2
+    str_labels = np.array([str_classes[i] for i in labels])
+
+    # Compute covariance matrices on scaled data
+    covs = Covariances().fit_transform(1e6 * epochs.get_data())
+    del raw
+
+    return covs, labels, str_labels
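+
+# (Added) Minimal usage sketch for the loader above (illustrative only; the
+# call triggers a download on first use):
+#   covs, labels, str_labels = get_subject_dataset_physio_net(1, classes='hands_feet')
+#   covs.shape   # -> (n_epochs, 9, 9), one SPD matrix per epoch (nine electrodes)
+#   str_labels   # -> array(['hands', 'feet', ...])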
+
+
+def load_data_physio_net(subjects, classes='hands_feet'):
+    """
+    Loads and preprocesses EEG data for multiple subjects from the PhysioNet EEG Motor Movement/Imagery Dataset.
+
+    For each subject, this function downloads and processes the data and extracts the covariance matrices and labels.
+    The function returns a list of clusters, where each cluster contains the covariance matrices, labels,
+    and metadata for a given subject.
+
+    Parameters
+    ----------
+    subjects : list of int
+        A list of subject IDs from the PhysioNet EEG Motor Movement/Imagery dataset.
+
+    classes : str, optional
+        The motor imagery task to extract. Either 'hands_feet' (default) or 'left_right'.
+        - 'hands_feet' extracts events related to imagining hand and foot movements.
+        - 'left_right' extracts events related to imagining left and right hand movements.
+
+    Returns
+    -------
+    clusters : list of dict
+        A list where each dictionary contains the following keys:
+        - 'X_cov': ndarray of shape (n_epochs, n_channels, n_channels)
+            Covariance matrices for each subject.
+        - 'y': ndarray of shape (n_epochs,)
+            String labels for each epoch, corresponding to the classes provided ('hands', 'feet', etc.).
+        - 'y_enc': ndarray of shape (n_epochs,)
+            Encoded labels with subject information. For example, 'subj_X/hands'.
+        - 'subject': int
+            The subject ID.
+        - 'cluster_name': str
+            The cluster name, in the format 'subj_X', where X is the subject ID.
+
+    Notes
+    -----
+    This function calls `get_subject_dataset_physio_net` internally to process the data for each subject.
+    """
+    clusters = []
+
+    for subj in subjects:
+        covs, labels, str_labels = get_subject_dataset_physio_net(subj, classes)
+
+        cluster_data = {
+            'X_cov': covs,
+            'y': str_labels,
+            'y_enc': np.array(['subj_{}/'.format(subj) + str(item) for item in str_labels]),
+            'subject': subj,
+            'cluster_name': 'subj_{}'.format(subj)
+        }
+
+        clusters.append(cluster_data)
+
+    return clusters
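+
+# (Added) Sketch of the resulting structure (illustrative):
+#   clusters = load_data_physio_net([1, 2], classes='hands_feet')
+#   clusters[0]['cluster_name']   # -> 'subj_1'
+#   clusters[0]['X_cov'].shape    # -> (n_epochs, 9, 9)
+#   clusters[0]['y_enc'][0]       # -> e.g. 'subj_1/hands'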
+
+
+# -----------------------------------------------
+# Innermost function. Calculate n-fold accuracy of source and target data.
+# -----------------------------------------------
+def performance_on_one_cluster_vs_one_cluster(source_cluster_data, target_cluster_data, clf, nbr_folds=5, training_size=8, run_logdir='.', plotting=False, df_data=None):
+    """
+    Evaluates the performance of a classifier using transfer learning between a source and target data cluster.
+
+    The classifier is trained on the source cluster and tested on the target cluster using transfer learning techniques
+    (re-centering and rotation). It performs a StratifiedShuffleSplit cross-validation on the target data, optionally
+    plotting the embeddings and confusion matrices for each fold.
+
+    Parameters
+    ----------
+    source_cluster_data : dict
+        A dictionary containing data from the source cluster.
+        Expected to have 'X_cov' (covariance matrices), 'y' (labels), and 'y_enc' (encoded labels).
+
+    target_cluster_data : dict
+        A dictionary containing data from the target cluster.
+        Expected to have 'X_cov' (covariance matrices), 'y' (labels), 'y_enc' (encoded labels), and 'cluster_name' (metadata).
+
+    clf : classifier object
+        A classifier object implementing the 'fit' and 'predict' methods.
+
+    nbr_folds : int, optional
+        The number of folds for cross-validation. Default is 5.
+
+    training_size : int, optional
+        The size of the training set for each fold. Default is 8.
+
+    run_logdir : str, optional
+        Directory to save the generated plots. Default is the current directory ('.').
+
+    plotting : bool, optional
+        Whether to plot embeddings and confusion matrices during the transfer learning process. Default is False.
+
+    df_data : pandas.DataFrame, optional
+        A dataframe to append the results to. If None, a new dataframe will be created. Default is None.
+
+    Returns
+    -------
+    df_data : pandas.DataFrame
+        A DataFrame containing performance metrics and distances computed between source and target clusters for
+        each fold before rotation in the transfer learning algorithm.
+
+    benefit_of_transfer_learning : ndarray of shape (nbr_folds,)
+        The benefit of transfer learning for each fold, calculated as the difference in accuracy
+        with and without transfer learning.
+
+    results_test_data : list of float
+        A list of accuracy scores for the classifier on the target test data after applying transfer learning, for each fold.
+
+    results_no_transfer_learning_target_data : list of float
+        A list of accuracy scores for the classifier on the target test data without transfer learning, for each fold.
+
+    Notes
+    -----
+    The function applies two transfer learning techniques: TLCenter (re-centering) and TLRotate (rotation using Riemannian metrics).
+    If plotting is enabled, embeddings of the source and target clusters before and after transfer learning are displayed.
+    It also generates confusion matrices for each fold and saves them to the specified log directory.
+    """
+
+    # Extract data to variables.
+    X_source = source_cluster_data['X_cov']
+    y_source = source_cluster_data['y']
+    y_source_enc = source_cluster_data['y_enc']
+    X_target = target_cluster_data['X_cov']
+    y_target = target_cluster_data['y']
+    y_target_enc = target_cluster_data['y_enc']
+    name_target_domain = target_cluster_data['cluster_name']
+
+    # Placeholders for results.
+    results_test_data = []
+    results_no_transfer_learning_target_data = []
+    distances_all_folds = []  # One row per fold; see columns_data below for the 21 entries.
+
+    if plotting:
+        # Create figure for plotting of data before and after transfer learning.
+        ncols = 5  # NOTE: assumes nbr_folds == 5, one column per fold.
+        nrows = 3
+        fig, axs = plt.subplots(figsize=(35, 21), ncols=ncols, nrows=nrows, sharey=True, sharex=True, layout='constrained')
+        classes = np.unique(y_target)
+        confusion_matrices = []
+
+    # Splitter function for test and training split. Shuffles the data.
+    # NOTE: the 'test' chunk of size training_size is used below as the target
+    # training set, hence the swapped index names in the loop.
+    splitter = StratifiedShuffleSplit(n_splits=nbr_folds, test_size=training_size, random_state=42)
+
+    # Loop over folds for test and training splits.
+    for i, (test_index, train_index) in enumerate(splitter.split(X_target, y_target)):
+        # Get data
+        X_train_source = copy.deepcopy(X_source)
+        y_train_source = copy.deepcopy(y_source)
+        y_enc_source = copy.deepcopy(y_source_enc)
+
+        # Take the training subset of the target data.
+        X_train_target = copy.deepcopy(X_target[train_index])
+        y_train_target = copy.deepcopy(y_target[train_index])
+        y_enc_target_train = copy.deepcopy(y_target_enc[train_index])
+
+        # NOTE: evaluation uses the full target session (including the training
+        # trials), so test_index is deliberately not used.
+        X_test_target = copy.deepcopy(X_target)
+        y_test_target = copy.deepcopy(y_target)
+        y_enc_target_test = copy.deepcopy(y_target_enc)
+
+        # Plot data before transfer learning.
+        if plotting:
+            ax = axs[0, i]
+            X_training_list = np.concatenate([X_train_source, X_train_target])
+            y_enc_training_list = np.concatenate([y_enc_source, y_enc_target_train])
+            plot_embeddings_source_target_train_test(X_training_list, X_test_target, y_enc_training_list, y_enc_target_test, ax, target_cluster_name=name_target_domain)
+
+        # Objects for transfer learning.
+        tl_recenter = TLCenter(target_domain=name_target_domain)
+        tl_rotate = TLRotate(target_domain=name_target_domain, metric='riemann')
+
+        # ==================================
+        # ====== TRANSFER LEARNING =========
+        # ==================================
+        # Format data in the correct way for transfer learning.
+        X_train = np.concatenate((X_train_source, X_train_target))
+        y_enc = np.concatenate((y_enc_source, y_enc_target_train))
+        y_train = np.concatenate((y_train_source, y_train_target))
+
+        # === RECENTER ===
+        # Re-center source and target data; the classifier is trained and
+        # evaluated on the transformed target test data further below.
+        X_train = tl_recenter.fit_transform(X_train, y_enc)
+        X_test_target = tl_recenter.transform(X_test_target)
+        if plotting:
+            # Plot data after recentering.
+            plot_embeddings_source_target_train_test(X_train, X_test_target, y_enc, y_enc_target_test, axs[1, i], target_cluster_name=name_target_domain)
+
+        # INTRA SUBJECT TRAINING - TARGET
+        nbr_training_data = len(y_train_target)
+        clf.fit(X_train[-nbr_training_data:], y_train[-nbr_training_data:])
+        y_pred = clf.predict(X_test_target)
+        score = accuracy_score(y_test_target, y_pred)
+        results_no_transfer_learning_target_data.append(score)
+
+        # Finding means and then distances between means.
+        # (classes is recomputed here from the source labels; source and target
+        # share the same two classes.)
+        classes = np.unique(y_train_source)
+        metric = 'riemann'
+        means_target = {'center': mean_covariance(X_train[-nbr_training_data:], metric=metric),
+                        'class1': mean_covariance(X_train[-nbr_training_data:][y_train_target==classes[0]], metric=metric),
+                        'class2': mean_covariance(X_train[-nbr_training_data:][y_train_target==classes[1]], metric=metric)}
+
+        means_source = {'center': mean_covariance(X_train[:-nbr_training_data], metric=metric),
+                        'class1': mean_covariance(X_train[:-nbr_training_data][y_train_source==classes[0]], metric=metric),
+                        'class2': mean_covariance(X_train[:-nbr_training_data][y_train_source==classes[1]], metric=metric)}
+
+        columns_data = ['Target:Class1-Target:Class2',
+                        'Target:Center-Target:Class1',
+                        'Target:Center-Target:Class2',
+                        'Source:Class1-Source:Class2',
+                        'Source:Center-Source:Class1',
+                        'Source:Center-Source:Class2',
+                        'Target:Center-Source:Center',
+                        'Target:Center-Source:Class1',
+                        'Target:Center-Source:Class2',
+                        'Target:Class1-Source:Center',
+                        'Target:Class1-Source:Class1',
+                        'Target:Class1-Source:Class2',
+                        'Target:Class2-Source:Center',
+                        'Target:Class2-Source:Class1',
+                        'Target:Class2-Source:Class2',
+                        'Source:Dispersion:all',
+                        'Source:Dispersion:Class1',
+                        'Source:Dispersion:Class2',
+                        'Target:Dispersion:all',
+                        'Target:Dispersion:Class1',
+                        'Target:Dispersion:Class2',
+                        ]
+
+        metric = 'riemann'
+        distances = [distance(means_target['class1'], means_target['class2'], metric=metric),
+                     distance(means_target['center'], means_target['class1'], metric=metric),
+                     distance(means_target['center'], means_target['class2'], metric=metric),
+                     distance(means_source['class1'], means_source['class2'], metric=metric),
+                     distance(means_source['center'], means_source['class1'], metric=metric),
+                     distance(means_source['center'], means_source['class2'], metric=metric),
+                     distance(means_target['center'], means_source['center'], metric=metric),
+                     distance(means_target['center'], means_source['class1'], metric=metric),
+                     distance(means_target['center'], means_source['class2'], metric=metric),
+                     distance(means_target['class1'], means_source['center'], metric=metric),
+                     distance(means_target['class1'], means_source['class1'], metric=metric),
+                     distance(means_target['class1'], means_source['class2'], metric=metric),
+                     distance(means_target['class2'], means_source['center'], metric=metric),
+                     distance(means_target['class2'], means_source['class1'], metric=metric),
+                     distance(means_target['class2'], means_source['class2'], metric=metric),
+                     np.sum([distance(cov, means_source['center'], metric=metric, squared=True) for cov in X_train[:-nbr_training_data]]) / len(X_train[:-nbr_training_data]),
+                     np.sum([distance(cov, means_source['class1'], metric=metric, squared=True) for cov in X_train[:-nbr_training_data][y_train_source==classes[0]]]) / len(X_train[:-nbr_training_data][y_train_source==classes[0]]),
+                     np.sum([distance(cov, means_source['class2'], metric=metric, squared=True) for cov in X_train[:-nbr_training_data][y_train_source==classes[1]]]) / len(X_train[:-nbr_training_data][y_train_source==classes[1]]),
+                     np.sum([distance(cov, means_target['center'], metric=metric, squared=True) for cov in X_train[-nbr_training_data:]]) / len(X_train[-nbr_training_data:]),
+                     np.sum([distance(cov, means_target['class1'], metric=metric, squared=True) for cov in X_train[-nbr_training_data:][y_train_target==classes[0]]]) / len(X_train[-nbr_training_data:][y_train_target==classes[0]]),
+                     np.sum([distance(cov, means_target['class2'], metric=metric, squared=True) for cov in X_train[-nbr_training_data:][y_train_target==classes[1]]]) / len(X_train[-nbr_training_data:][y_train_target==classes[1]])
+                     ]
+
+        distances_all_folds.append(distances)
+
+        # === ROTATE ===
+        X_train = tl_rotate.fit_transform(X_train, y_enc)
+        X_test_target = tl_rotate.transform(X_test_target)
+        if plotting:
+            plot_embeddings_source_target_train_test(X_train, X_test_target, y_enc, y_enc_target_test, axs[2, i], target_cluster_name=name_target_domain)
+
+        # Fit classifier on data after transfer learning and evaluate performance.
+        # clf.fit(X_train, y_train)  # With data from target
+        clf.fit(X_train[:-nbr_training_data], y_train[:-nbr_training_data])  # Without data from target.
+        y_pred = clf.predict(X_test_target)
+        score = accuracy_score(y_test_target, y_pred)
+        results_test_data.append(score)
+
+        if plotting:
+            # Print text about transfer learning performance in the figure.
+            axs[0, i].text(0.25, 0.05, f'Accuracy no transfer learning: {results_no_transfer_learning_target_data[-1]:.4f}', transform=axs[0, i].transAxes)
+            axs[2, i].text(0.25, 0.08, f'Accuracy transfer learning: {results_test_data[-1]:.4f}', transform=axs[2, i].transAxes)
+            axs[2, i].text(0.25, 0.05, f'Benefit transfer learning: {results_test_data[-1]-results_no_transfer_learning_target_data[-1]:.4f}', transform=axs[2, i].transAxes, color='b' if results_test_data[-1]-results_no_transfer_learning_target_data[-1] >= 0 else 'r')
+            # Calculate confusion matrix.
+            conf_matrix = confusion_matrix(y_test_target, y_pred, labels=classes)
+            confusion_matrices.append(conf_matrix)
+
+    if plotting:
+        # Save figure.
+        plt.figure(fig)
+        axs[0, 0].set_ylabel('Before transfer learning')
+        axs[1, 0].set_ylabel('After recenter')
+        axs[2, 0].set_ylabel('After recenter and then rotate')
+        plt.suptitle('Embeddings for the different folds during transfer learning.')
+        print('--> Saving transfer learning plot')
+        save_string = '{}_data_analysis_clusters_target_{}_and_source_{}_training_size_{}'.format(target_cluster_data['cluster_name'], target_cluster_data['cluster_name'], source_cluster_data['cluster_name'], training_size)  # Filename
+        plt.savefig('{}/png_{}.png'.format(run_logdir, save_string))  # Save the figure in the specified directory
+        plt.close(fig)
+        print('--> Saving complete')
+
+        # Plot confusion matrix (mean over folds).
+        conf_matrix = np.mean(confusion_matrices, axis=0)
+        fig, ax = plt.subplots(figsize=(7, 7), layout='constrained')
+
+        plot_confusion_matrix(conf_matrix, classes, ax, fig)
+
+        print('--> Saving confusion matrix')
+        save_string = '{}_confusion_matrix_target_{}_and_source_{}_training_size_{}'.format(target_cluster_data['cluster_name'], target_cluster_data['cluster_name'], source_cluster_data['cluster_name'], training_size)  # Filename
+        plt.savefig('{}/png_{}.png'.format(run_logdir, save_string))  # Save the figure in the specified directory
+        plt.close(fig)
+        print('--> Saving complete')
+
+    # Calculate benefit of transfer learning.
+    benefit_of_transfer_learning = np.array(results_test_data) - np.array(results_no_transfer_learning_target_data)
+
+    # Add data to dataframe.
+    df_here = pd.DataFrame(distances_all_folds, columns=columns_data)
+    df_here.insert(len(df_here.columns), 'Source_cluster', np.repeat(source_cluster_data['cluster_name'], nbr_folds))
+    df_here.insert(len(df_here.columns), 'Target_cluster', np.repeat(target_cluster_data['cluster_name'], nbr_folds))
+    df_here.insert(len(df_here.columns), 'Target_accuracy', np.array(results_no_transfer_learning_target_data))
+    df_here.insert(len(df_here.columns), 'Benefit_transfer_learning', benefit_of_transfer_learning)
+    df_here.insert(len(df_here.columns), 'Accuracy_after_transfer_learning', results_test_data)
+    df_data = pd.concat([df_data, df_here], ignore_index=True)
+
+    return df_data, benefit_of_transfer_learning, results_test_data, results_no_transfer_learning_target_data
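+
+# (Added sketch) The alignment above uses pyriemann's transfer-learning API;
+# in isolation the two steps look like this (illustrative, the domain name
+# 'subj_1' is an assumption):
+#   X_enc, y_enc = encode_domains(X, y, domains)
+#   rct = TLCenter(target_domain='subj_1')
+#   rot = TLRotate(target_domain='subj_1', metric='riemann')
+#   X_aligned = rot.fit_transform(rct.fit_transform(X_enc, y_enc), y_enc)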
+
+
+def plot_embeddings_source_target_train_test(X_training_list, X_test, y_enc_training_list, y_enc_test, ax, target_cluster_name):
+    """
+    Plots 2D spectral embeddings of source and target data clusters before and after transfer learning.
+
+    This function combines the training and test data, performs spectral embedding using the Riemannian
+    metric, and visualizes the embeddings. The target cluster is highlighted, and points are color-coded
+    based on the classes and clusters.
+
+    Parameters
+    ----------
+    X_training_list : ndarray of shape (n_samples_train, n_channels, n_channels)
+        Covariance matrices from the training data (source and target).
+
+    X_test : ndarray of shape (n_samples_test, n_channels, n_channels)
+        Covariance matrices from the test data (target).
+
+    y_enc_training_list : ndarray of shape (n_samples_train,)
+        Encoded labels for the training data, including source and target data.
+
+    y_enc_test : ndarray of shape (n_samples_test,)
+        Encoded labels for the test data (target).
+
+    ax : matplotlib.axes.Axes
+        The axes object where the embeddings will be plotted.
+
+    target_cluster_name : str
+        The name of the target cluster (domain) used to distinguish it from the source cluster.
+
+    Returns
+    -------
+    ax : matplotlib.axes.Axes
+        The axes object with the plotted embeddings.
+
+    Notes
+    -----
+    The function uses spectral embedding to reduce the dimensionality of the covariance matrices to 2D.
+    The embeddings for different classes and clusters are plotted using different colors and markers.
+    The identity matrix (used as a reference) is also plotted with a special marker.
+    """
+
+    # Extract data to variables.
+    X = np.concatenate((X_training_list, X_test))
+    y_enc_test = ['test_' + item for item in y_enc_test]
+    y_enc = np.concatenate([y_enc_training_list, y_enc_test])
+
+    X, y, domain = decode_domains(X, y_enc)
+    # Instantiate object for doing spectral embeddings.
+    emb = SpectralEmbedding(n_components=2, metric='riemann')
+    # emb = LocallyLinearEmbedding(n_components=2, metric='riemann', n_neighbors=10)
+
+    # Embed the original source and target datasets.
+    points = np.concatenate([X, np.eye(X.shape[1])[None, :, :]])  # Stack the identity.
+    embedded_points = emb.fit_transform(points)
+    embedded_points = embedded_points - embedded_points[-1]  # Shift the identity to the origin.
+    identity = embedded_points[-1]
+    embedded_points = embedded_points[:-1]
+
+    markers_list = ['o', 's', '^', 'D', 'v', 'p', '*', 'H', 'x', '.', ',', '<', '>', '1', '2', '3', '4', 'h', '+', 'd', '|', '_']
+    colors_list = ['b', 'g', 'r', 'c', 'm', 'y', 'k', '#FFA07A', '#FF5733', '#33FFCE', '#8E44AD', '#3498DB', '#2ECC71', '#F1C40F', '#E74C3C', '#34495E', '#1ABC9C', '#9B59B6', '#34495E', '#95A5A6', '#7F8C8D']
+
+    unique_labels = np.unique(y)
+    unique_clusters = np.unique(domain)
+    # Move the target cluster last.
+    index_of_target = np.where(unique_clusters == target_cluster_name)[0]
+    unique_clusters = np.delete(unique_clusters, index_of_target)
+    unique_clusters = np.append(unique_clusters, target_cluster_name)
+
+    # Plotting
+    for i in range(len(unique_clusters)):
+        for j in range(len(unique_labels)):
+            data_filter = (y==unique_labels[j]) & (domain==unique_clusters[i])
+            ax.scatter(
+                embedded_points[data_filter][:, 0],
+                embedded_points[data_filter][:, 1],
+                c=colors_list[j], s=50, alpha=0.50, marker=markers_list[i])
+            ax.scatter([], [], c=colors_list[j], marker=markers_list[i], label='{} - {}'.format(unique_clusters[i], unique_labels[j]))  # Add label
+    ax.scatter(identity[0], identity[1], c='k', s=80, marker="*")
+    ax.legend()
+
+    return ax
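+
+# (Added note) SpectralEmbedding is re-fitted on every call, so coordinates
+# are only meaningful within a single panel; embeddings from different folds
+# or different alignment stages are not directly comparable.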
+
+
+def plot_confusion_matrix(conf_mat, classes, ax, fig):
+    """
+    Plots a confusion matrix using matplotlib, with color-coded values and class labels.
+
+    Parameters
+    ----------
+    conf_mat : ndarray of shape (n_classes, n_classes)
+        The confusion matrix to be plotted.
+
+    classes : list of str
+        A list of class names corresponding to the rows and columns of the confusion matrix.
+
+    ax : matplotlib.axes.Axes
+        The axes object on which to draw the confusion matrix.
+
+    fig : matplotlib.figure.Figure
+        The figure object for adding the color bar.
+
+    Returns
+    -------
+    ax : matplotlib.axes.Axes
+        The axes object with the plotted confusion matrix.
+
+    Notes
+    -----
+    The confusion matrix is displayed with different colors to indicate higher or lower values, and the
+    color of the text within the cells is chosen based on a threshold to maintain readability. The matrix
+    has labeled x and y ticks for predicted and true labels, respectively, and a color bar to indicate the
+    intensity of the values.
+    """
+    # Plot matrix
+    cax = ax.matshow(conf_mat, cmap='Greys')
+
+    # Set labels
+    plt.xticks(np.arange(len(classes)), classes, rotation=45)
+    plt.yticks(np.arange(len(classes)), classes)
+
+    threshold = np.max(conf_mat)/2 + np.min(conf_mat)/2  # Threshold for switching the text color.
+    # Loop over data dimensions and create text annotations.
+    for i in range(len(classes)):
+        for j in range(len(classes)):
+            color = "white" if conf_mat[i, j] > threshold else "black"
+            ax.text(j, i, conf_mat[i, j], ha="center", va="center", color=color)
+
+    ax.grid(False)
+    plt.xlabel('Predicted labels')
+    plt.ylabel('True labels')
+    plt.title('Confusion Matrix')
+    fig.colorbar(cax)
+    return ax
+
+
+# -----------------------------------------------
+# For one cluster, test performance vs all the other clusters.
+# -----------------------------------------------
+def performance_for_one_cluster_vs_all_cluster(clusters, clf, target_cluster_idx, nbr_folds=5, run_logdir='.', training_size=8, plotting=False, data_folder='data'):
+    """
+    Evaluates and plots the performance of a classifier for one target cluster against all other source clusters using transfer learning.
+
+    This function iterates over all clusters, using each one as a source cluster and the target cluster for testing.
+    It computes the classifier's performance for each source-target combination and, if specified, generates a bar plot
+    showing the mean accuracy and standard deviation for transfer learning performance.
+
+    Parameters
+    ----------
+    clusters : list of dict
+        A list of dictionaries, where each dictionary represents a data cluster with 'X_cov', 'y', 'y_enc', 'subject',
+        and 'cluster_name' keys.
+
+    clf : classifier object
+        A classifier object implementing 'fit' and 'predict' methods.
+
+    target_cluster_idx : int
+        The index of the cluster to be used as the target cluster for testing.
+
+    nbr_folds : int, optional
+        The number of cross-validation folds. Default is 5.
+
+    run_logdir : str, optional
+        The directory where plots and results will be saved. Default is the current directory ('.').
+
+    training_size : int, optional
+        The size of the training data used from the target cluster in each fold. Default is 8.
+
+    plotting : bool, optional
+        Whether to plot the results (bar plots and other visuals). Default is False.
+
+    data_folder : str, optional
+        The folder where the results will be saved as pickle files. Default is 'data'.
+
+    Returns
+    -------
+    None
+        This function does not return any values, but it saves the computed results and optionally generated plots to disk.
+
+    Notes
+    -----
+    - This function uses transfer learning by training the classifier on each source cluster and testing it on the target cluster.
+    - If plotting is enabled, a bar plot is generated to visualize the classifier's performance with and without transfer learning
+      for each source-target combination.
+    - Results are saved in pickle files, including transfer learning benefits, performance data, and processed DataFrames.
+    """
+
+    i = target_cluster_idx
+
+    # Placeholders for results
+    all_benefit_of_transfer_learning = []  # Store results for this target cluster against all other clusters
+    all_results_test_data = []
+    all_results_no_transfer_learning_target_data = []
+
+    cluster_target = clusters[i]  # Target cluster for testing
+
+    # If the run has to be restarted, check whether the data is already analyzed.
+    save_folder_pickle = './{}/{}'.format(data_folder, cluster_target['cluster_name'])
+    if os.path.exists(save_folder_pickle):
+        # The files are already saved to disk.
+        print('!!! {} already analyzed'.format(cluster_target['cluster_name']))
+        return
+    elif not os.path.exists('./{}/'.format(data_folder)):
+        print('creating folder {}'.format(data_folder))
+        os.makedirs('./{}/'.format(data_folder))
+
+    # If things should be plotted, save in the specified folder.
+    if plotting:
+        save_folder = '{}/{}_training_size_{}'.format(run_logdir, cluster_target['cluster_name'], training_size)
+        os.mkdir(save_folder)
+        run_logdir = save_folder
+        x_ticks = []  # To store labels for the x-axis in the plot
+
+    df_data = pd.DataFrame()
+
+    # Iterate over all clusters/subjects to use them as source clusters.
+    for j in range(len(clusters)):
+        print('Target cluster: {}, analyzing cluster {} of {}'.format(cluster_target['cluster_name'], j+1, len(clusters)))
+        cluster_source = clusters[j]  # Source cluster for training
+
+        # Evaluate the performance of one cluster against another.
+        df_data, benefit_of_transfer_learning, results_test_data, results_no_transfer_learning_target_data = performance_on_one_cluster_vs_one_cluster(source_cluster_data=cluster_source, target_cluster_data=cluster_target, clf=clf, df_data=df_data, nbr_folds=nbr_folds, training_size=training_size, run_logdir=run_logdir, plotting=plotting)
+
+        # Store data.
+        all_benefit_of_transfer_learning.append(benefit_of_transfer_learning)
+        all_results_test_data.append(results_test_data)
+        all_results_no_transfer_learning_target_data.append(results_no_transfer_learning_target_data)
+        if plotting:
+            x_ticks.append(cluster_source['cluster_name'])  # Store the label for the plot
+
+    if plotting:  # Barplot of performance per source subject.
+        # Calculate the mean and standard deviation for the performance results.
+        means_test = np.mean(all_results_test_data, axis=1)
+        stds_test = np.std(all_results_test_data, axis=1)
+
+        means_no_transfer_target = np.mean(all_results_no_transfer_learning_target_data, axis=1)
+        stds_no_transfer_target = np.std(all_results_no_transfer_learning_target_data, axis=1)
+
+        # Plotting the results
+        fig, ax = plt.subplots(figsize=(28, 7), layout='constrained')
+        x = np.arange(len(means_test))  # X-axis values
+        width = 0.25
+        # Create a bar plot with error bars showing standard deviation.
+        # Test data
+        bars = ax.bar(x-1*width, means_test, yerr=stds_test, capsize=5, width=width, color='dimgray')
+        bars[i].set_hatch('x')
+
+        bars = ax.bar(x+0*width, means_no_transfer_target, yerr=stds_no_transfer_target, capsize=5, width=width, color='silver', hatch='.')
+        bars[i].set_hatch('x.')
+
+        # Pretty plotting
+        ax.set_xlabel('Source cluster')
+        ax.set_ylabel('Mean Accuracy')
+        ax.set_title('Performance of Each Cluster, target cluster = {}'.format(cluster_target['cluster_name']))
+        ax.set_xticks(x)
+        ax.set_ylim([0, 1])
+        ax.set_xticklabels(x_ticks, rotation=90)  # Set the labels for the x-axis
+        label1 = mpatches.Patch(facecolor='dimgray', label='Source data for training')
+        label2 = mpatches.Patch(facecolor='silver', label='Target training data for training')
+        label3 = mpatches.Patch(edgecolor='black', hatch='..', label='No transfer learning')
+        label4 = mpatches.Patch(edgecolor='black', hatch='xxx', label='Same subject - same session')
+        ax.legend(handles=[label1, label2, label3, label4], loc=0)
+        plt.grid(False, axis='x')
+
+        # Save the plot to a file.
+        print('--> Now Plotting barplot')
+        save_string = '{}_oracle_performance_target_cluster_training_size_{}'.format(cluster_target['cluster_name'], training_size)  # Filename
+        plt.savefig('{}/png_{}.png'.format(run_logdir, save_string))  # Save the figure in the specified directory
+        plt.close(fig)
+        print('--> Plotting complete')
+
+    # Save data to disk.
+    os.mkdir(save_folder_pickle)
+    # all_benefit_of_transfer_learning
+    with open('{}/all_benefit_of_transfer_learning.pickle'.format(save_folder_pickle), 'wb') as file:
+        pickle.dump(all_benefit_of_transfer_learning, file, protocol=pickle.HIGHEST_PROTOCOL)
+    # all_results_test_data
+    with open('{}/all_results_test_data.pickle'.format(save_folder_pickle), 'wb') as file:
+        pickle.dump(all_results_test_data, file, protocol=pickle.HIGHEST_PROTOCOL)
+    # all_results_no_transfer_learning_target_data
+    with open('{}/all_results_no_transfer_learning_target_data.pickle'.format(save_folder_pickle), 'wb') as file:
+        pickle.dump(all_results_no_transfer_learning_target_data, file, protocol=pickle.HIGHEST_PROTOCOL)
+    # df_data
+    with open('{}/df_data.pickle'.format(save_folder_pickle), 'wb') as file:
+        pickle.dump(df_data, file, protocol=pickle.HIGHEST_PROTOCOL)
+
+    # Remove variables to free space.
+    del all_benefit_of_transfer_learning, all_results_test_data, all_results_no_transfer_learning_target_data, df_data
+    return
+
+
+# -----------------------------------------------
+# Multiprocessing helper function:
+# -----------------------------------------------
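+# (Added note) The wrapper below must stay a module-level function: callables
+# submitted to a multiprocessing.Pool are pickled, so locally defined
+# closures or lambdas would fail here.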
+def worker_function(target_cluster_idx, clusters, clf, nbr_folds=5, run_logdir='.', training_size=8, plotting=False, data_folder='data'):
+    """
+    A wrapper function to evaluate the performance of a classifier for one target cluster against all other clusters,
+    designed to be used with multiprocessing.
+
+    This function is intended to be passed to a multiprocessing Pool, allowing for parallel evaluation of transfer learning
+    performance across multiple target clusters. It calls `performance_for_one_cluster_vs_all_cluster` to perform the actual
+    evaluation and result generation.
+
+    Parameters
+    ----------
+    target_cluster_idx : int
+        The index of the cluster to be used as the target cluster for testing.
+
+    clusters : list of dict
+        A list of dictionaries, where each dictionary represents a data cluster with 'X_cov', 'y', 'y_enc', 'subject', and 'cluster_name' keys.
+
+    clf : classifier object
+        A classifier object implementing 'fit' and 'predict' methods.
+
+    nbr_folds : int, optional
+        The number of cross-validation folds. Default is 5.
+
+    run_logdir : str, optional
+        The directory where plots and results will be saved. Default is the current directory ('.').
+
+    training_size : int, optional
+        The size of the training data used from the target cluster in each fold. Default is 8.
+
+    plotting : bool, optional
+        Whether to plot the results (bar plots and other visuals). Default is False.
+
+    data_folder : str, optional
+        The folder where the results will be saved as pickle files. Default is 'data'.
+
+    Returns
+    -------
+    None
+        This function does not return any values, but it triggers `performance_for_one_cluster_vs_all_cluster` for the
+        specified target cluster index, which stores data to disk.
+
+    Notes
+    -----
+    This function is designed to be used in a multiprocessing context to allow parallel evaluation of multiple target clusters.
+    """
+    print("### worker_function for target cluster index = {}".format(target_cluster_idx))
+    return performance_for_one_cluster_vs_all_cluster(clusters, clf, target_cluster_idx, nbr_folds, run_logdir, training_size, plotting, data_folder)
+
+
+# -----------------------------------------------
+# Test all clusters vs all clusters
+# -----------------------------------------------
+def run_all_clusters(clusters, clf, nbr_folds=5, run_logdir='.', training_size=8, plotting_subresults=False, data_folder='data', clf_name=None):
+    """
+    Runs transfer learning evaluations for all clusters and generates performance reports and plots.
+
+    This function iterates through all clusters, evaluating the classifier's performance for each one as the target cluster
+    using the rest as source clusters. It can run in single-threaded or multi-threaded mode depending on the number of threads (NBR_THREADS).
+    Results for each evaluation are saved to disk, and overall performance statistics are generated.
+
+    Parameters
+    ----------
+    clusters : list of dict
+        A list of dictionaries, where each dictionary represents a data cluster with 'X_cov', 'y', 'y_enc', 'subject', and 'cluster_name' keys.
+
+    clf : classifier object
+        A classifier object implementing 'fit' and 'predict' methods.
+
+    nbr_folds : int, optional
+        The number of cross-validation folds. Default is 5.
+
+    run_logdir : str, optional
+        The directory where plots and results will be saved. Default is the current directory ('.').
+
+    training_size : int, optional
+        The size of the training data used from the target cluster in each fold. Default is 8.
+
+    plotting_subresults : bool, optional
+        Whether to generate plots for the results of each cluster evaluation. Default is False.
+
+    data_folder : str, optional
+        The folder where the results will be saved as pickle files. Default is 'data'.
+
+    clf_name : str, optional
+        The name of the classifier, which can be used for labeling and saving plots. Default is None.
+
+    Returns
+    -------
+    average_intra_subject_performance : float
+        The average performance of the classifier when trained and tested on data from the same subject (intra-subject performance).
+
+    Notes
+    -----
+    - This function can run in either single-threaded or multi-threaded mode. If multi-threading is enabled, it uses `worker_function`
+      and a multiprocessing Pool for parallel processing.
+    - Results for each target cluster, including transfer learning benefits and performance data, are saved as pickle files in the
+      specified data folder by the subcalled functions.
+    - It generates several plots visualizing the transfer learning performance across all clusters, both in subject order and sorted by
+      intra-subject performance.
+    - This function also saves metadata such as the number of clusters and cross-validation folds.
+    """
+    nbr_of_clusters = len(clusters)
+    list_of_cluster_names = [cluster['cluster_name'] for cluster in clusters]
+
+    # Create folder for data.
+    if not os.path.exists('./{}/'.format(data_folder)):
+        print('creating data folder {}'.format(data_folder))
+        os.makedirs('./{}/'.format(data_folder))
+    # Single- or multi-threaded execution.
+    if NBR_THREADS == 1:
+        print('\n-------------------------------')
+        print( '--> Single thread started < ---')
+        print( '-------------------------------\n')
+        for target_cluster_idx in range(len(clusters)):
+            performance_for_one_cluster_vs_all_cluster(clusters, clf, target_cluster_idx, nbr_folds, run_logdir, training_size, plotting_subresults, data_folder)
+    else:
+        # Multithreaded version:
+        print('\n---------------------------------')
+        print( '--> Multiprocessing started < ---')
+        print( '---------------------------------\n')
+
+        pool = Pool(processes=NBR_THREADS)
+        results = [pool.apply_async(worker_function, args=(target_cluster_idx, clusters, clf, nbr_folds, run_logdir, training_size, plotting_subresults, data_folder)) for target_cluster_idx in range(nbr_of_clusters)]
+        # The workers' return values are deliberately not collected: each worker
+        # writes its results to disk, which also makes interrupted runs restartable.
+        print('---> Closing pools\n')
+        pool.close()
+        pool.join()
+        print('---> Pools closed \n')
+
+    # Placeholders
+    all_results_no_transfer_learning_target_data = np.zeros((nbr_of_clusters, nbr_of_clusters))
+    all_benefit_of_transfer_learning = np.zeros((nbr_of_clusters, nbr_of_clusters))
+    all_results_test_data = np.zeros((nbr_of_clusters, nbr_of_clusters))
+
+    # Load stored data from disk.
+    data_folder_path = './{}/'.format(data_folder)
+
+    for i, subdir in enumerate(list_of_cluster_names):
+        print(subdir)
+        subdir_path = os.path.join(data_folder_path, subdir)
+
+        # Check if the path is indeed a directory.
+        if not os.path.isdir(subdir_path):
+            print('Not a folder: {}'.format(subdir))
+            continue  # There might be a .DS_Store.
+
+        # == all_results_no_transfer_learning_target_data
+        pickle_file_path = os.path.join(subdir_path, 'all_results_no_transfer_learning_target_data.pickle')
+
+        # Check that the pickle file exists.
+        if not os.path.exists(pickle_file_path):
+            raise FileNotFoundError('Missing result file: {}'.format(pickle_file_path))
+
+        with open(pickle_file_path, 'rb') as file:
+            data_here = pickle.load(file)
+        data_here = np.array(data_here)
+        data_here = np.mean(data_here, axis=1)  # Mean over folds.
+        all_results_no_transfer_learning_target_data[i] = data_here
+
+        # == all_results_test_data
+        pickle_file_path = os.path.join(subdir_path, 'all_results_test_data.pickle')
+
+        # Check that the pickle file exists.
+        if not os.path.exists(pickle_file_path):
+            raise FileNotFoundError('Missing result file: {}'.format(pickle_file_path))
+
+        with open(pickle_file_path, 'rb') as file:
+            data_here = pickle.load(file)
+        data_here = np.array(data_here)
+        data_here = np.mean(data_here, axis=1)  # Mean over folds.
+        all_results_test_data[i] = data_here
+
+        # == all_benefit_of_transfer_learning
+        pickle_file_path = os.path.join(subdir_path, 'all_benefit_of_transfer_learning.pickle')
+
+        # Check that the pickle file exists.
+        if not os.path.exists(pickle_file_path):
+            raise FileNotFoundError('Missing result file: {}'.format(pickle_file_path))
+
+        with open(pickle_file_path, 'rb') as file:
+            data_here = pickle.load(file)
+        data_here = np.array(data_here)
+        data_here = np.mean(data_here, axis=1)  # Mean over folds.
+        all_benefit_of_transfer_learning[i] = data_here
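+    # (Added sketch) The per-target pickles can also be reloaded post hoc for
+    # further analysis, e.g. (path is illustrative):
+    #   with open('./<data_folder>/subj_1/df_data.pickle', 'rb') as f:
+    #       df = pickle.load(f)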
+
+    # Save some general metadata for the run.
+    with open('{}/nbr_of_clusters.pickle'.format(data_folder_path), 'wb') as file:
+        pickle.dump(nbr_of_clusters, file, protocol=pickle.HIGHEST_PROTOCOL)
+    with open('{}/nbr_folds.pickle'.format(data_folder_path), 'wb') as file:
+        pickle.dump(nbr_folds, file, protocol=pickle.HIGHEST_PROTOCOL)
+    with open('{}/list_of_cluster_names.pickle'.format(data_folder_path), 'wb') as file:
+        pickle.dump(list_of_cluster_names, file, protocol=pickle.HIGHEST_PROTOCOL)
+
+    # Verify that the data has been processed correctly.
+    print(all_results_test_data - all_results_no_transfer_learning_target_data - all_benefit_of_transfer_learning)
+    print('The above matrix should be all zero')
+
+    # Plot transfer learning data. Set up figures.
+    print('--> Saving plot_of_all_data')
+    figsize = (nbr_of_clusters*1.5, nbr_of_clusters*1.5)
+    if nbr_of_clusters > 20:
+        plt.rcParams.update({'font.size': 20})
+
+    # Plot matrix of transfer learning performance in subject order.
+    fig, ax = plt.subplots(figsize=figsize, layout='constrained')
+    plot_all_accuracy_matrix(all_results_test_data, list_of_cluster_names, ax, fig, np.mean(all_results_no_transfer_learning_target_data, axis=1))
+    save_string = 'all_accuracy_training_size_{}'.format(training_size)  # Filename
+    plt.savefig(f'{run_logdir}/png_{save_string}.png')  # Save the figure in the specified directory
+    plt.savefig(f'{run_logdir}/pdf_{save_string}.pdf')  # Save the figure in the specified directory
+    plt.close(fig)
+
+    # Plot matrix of transfer learning performance sorted by intra-subject performance.
+    fig, ax = plt.subplots(figsize=figsize, layout='constrained')
+    plot_all_accuracy_matrix_sorted(all_results_test_data, list_of_cluster_names, ax, fig, np.mean(all_results_no_transfer_learning_target_data, axis=1))
+    save_string = 'all_accuracy_sorted_training_size_{}'.format(training_size)  # Filename
+    plt.savefig(f'{run_logdir}/png_{save_string}.png')  # Save the figure in the specified directory
+    plt.savefig(f'{run_logdir}/pdf_{save_string}.pdf')  # Save the figure in the specified directory
+    plt.close(fig)
+
+    print('--> Saving complete')
+
+    average_intra_subject_performance = np.mean(all_results_no_transfer_learning_target_data)
+    return average_intra_subject_performance
+
+
+def plot_all_accuracy_matrix(data_matrix, labels, ax, fig, diagonal_values=None):
+    """
+    Plots a matrix of transfer learning accuracy, with annotations and color-coded values.
+
+    This function creates a matrix plot that shows the accuracy of transfer learning for different source-target cluster pairs.
+    The diagonal values represent intra-cluster accuracy and are highlighted with a different border. Each cell in the matrix is
+    color-coded based on accuracy, and the values are annotated within the cells for better clarity.
+
+    Parameters
+    ----------
+    data_matrix : ndarray of shape (n_clusters, n_clusters)
+        A matrix containing accuracy values for each source-target cluster pair.
+
+    labels : list of str
+        A list of labels for the clusters, used for the x and y axes.
+
+    ax : matplotlib.axes.Axes
+        The axes object where the matrix will be plotted.
+
+    fig : matplotlib.figure.Figure
+        The figure object to which the color bar will be added.
+
+    diagonal_values : ndarray of shape (n_clusters,), optional
+        The values to display on the diagonal of the matrix, representing intra-cluster accuracy. If None, the diagonal values
+        will be extracted from the data matrix. Default is None.
+
+    Returns
+    -------
+    ax : matplotlib.axes.Axes
+        The axes object with the plotted matrix.
+
+    Notes
+    -----
+    - The diagonal values (intra-cluster accuracy) are highlighted with a thicker border for emphasis.
+    - A diverging colormap is used to visualize accuracy values, with color intensity proportional to the magnitude of the accuracy.
+    - The matrix cells are annotated with the accuracy values, and the text color is adjusted for readability based on the cell's value.
+    """
+    if diagonal_values is None:
+        diagonal_values = np.diag(data_matrix)
+    vmin = 0.5
+    vmax = 1
+
+    np.fill_diagonal(data_matrix, diagonal_values)
+
+    cax = ax.matshow(data_matrix, cmap='RdBu', vmin=vmin, vmax=vmax)
+
+    plt.xticks(np.arange(len(labels)), labels, rotation=45)
+    plt.yticks(np.arange(len(labels)), labels)
+
+    threshold = 0.25*2/3  # For when to change the color of the text.
+    # Loop over data dimensions and create text annotations.
+    for i in range(len(labels)):
+        for j in range(len(labels)):
+            if i == j:
+                highlight_cell(i, j, edgecolor="black", linewidth=4, ax=ax, fill=False)
+            text = '{:.2f}'.format(data_matrix[i, j])
+            color = "white" if (data_matrix[i, j] > 0.75+threshold) or (data_matrix[i, j] < 0.75-threshold) else "black"
+            ax.text(j, i, text, ha="center", va="center", color=color)
+
+    ax.grid(False)
+    plt.xlabel('Source cluster')
+    plt.ylabel('Target cluster')
+    plt.title('Transfer learning accuracies')
+    fig.colorbar(cax)
+
+    return ax
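+
+# (Added note) The colour scale in both matrix plots is fixed to [0.5, 1.0]:
+# 0.5 is chance level for the two-class task, so the midpoint of the diverging
+# colormap sits at 0.75 accuracy.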
+
+
+def plot_all_accuracy_matrix_sorted(data_matrix, labels, ax, fig, diagonal_values=None):
+    """
+    Plots a sorted matrix of transfer learning accuracy, with annotations and color-coded values.
+
+    This function creates a matrix plot showing the accuracy of transfer learning for different source-target cluster pairs.
+    The rows and columns are sorted based on the diagonal values (intra-cluster accuracy) in descending order. Each cell in the matrix
+    is color-coded based on accuracy, and the values are annotated within the cells for clarity.
+
+    Parameters
+    ----------
+    data_matrix : ndarray of shape (n_clusters, n_clusters)
+        A matrix containing accuracy values for each source-target cluster pair.
+
+    labels : list of str
+        A list of labels for the clusters, used for the x and y axes.
+
+    ax : matplotlib.axes.Axes
+        The axes object where the matrix will be plotted.
+
+    fig : matplotlib.figure.Figure
+        The figure object to which the color bar will be added.
+
+    diagonal_values : ndarray of shape (n_clusters,), optional
+        The values to display on the diagonal of the matrix, representing intra-cluster accuracy. If None, the diagonal values
+        will be extracted from the data matrix. The matrix will be sorted based on these values. Default is None.
+
+    Returns
+    -------
+    ax : matplotlib.axes.Axes
+        The axes object with the plotted matrix.
+
+    Notes
+    -----
+    - The diagonal values (intra-cluster accuracy) are used to sort the matrix in descending order.
+    - A diverging colormap is used to visualize accuracy values, with color intensity proportional to the magnitude of the accuracy.
+    - The matrix cells are annotated with the accuracy values, and the text color is adjusted for readability based on the cell's value.
+    - The diagonal values are highlighted with a thicker border for emphasis.
+    """
+    if diagonal_values is None:
+        diagonal_values = np.diag(data_matrix)
+
+    np.fill_diagonal(data_matrix, diagonal_values)
+
+    # Sort matrix and data.
+    sorted_indices = np.argsort(-diagonal_values)  # '-' for descending order
+    diagonal_values = diagonal_values[sorted_indices]
+    data_matrix = data_matrix[sorted_indices, :][:, sorted_indices]
+    labels = np.array(labels)[sorted_indices]
+
+    # Limits for plotting colors.
+    vmin = 0.5
+    vmax = 1
+
+    # Plot matrix.
+    cax = ax.matshow(data_matrix, cmap='RdBu', vmin=vmin, vmax=vmax)
+
+    # Legends
+    plt.xticks(np.arange(len(labels)), labels, rotation=45)
+    plt.yticks(np.arange(len(labels)), labels)
+
+    threshold = 0.25*1/3  # For when to change the color of the text.
+    # Loop over data dimensions and create text annotations.
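+    # (Added) Usage sketch for the highlight_cell helper defined below
+    # (illustrative):
+    #   fig, ax = plt.subplots()
+    #   ax.matshow(np.eye(3), cmap='Greys')
+    #   highlight_cell(0, 2, ax=ax, edgecolor='red', linewidth=2)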
+    for i in range(len(labels)):
+        for j in range(len(labels)):
+            if i==j:
+                highlight_cell(i, j, edgecolor="black", linewidth=4, ax=ax, fill=False)
+            text = '{:.2f}'.format(data_matrix[i, j])
+            color = "white" if (data_matrix[i, j] > 0.75+threshold) or (data_matrix[i, j] < 0.75-threshold) else "black"
+            ax.text(j, i, text, ha="center", va="center", color=color)
+
+    ax.grid(False)
+    plt.xlabel('Source cluster')
+    plt.ylabel('Target cluster')
+    plt.title('Transfer learning accuracies')
+    fig.colorbar(cax)
+
+    return ax
+
+
+def highlight_cell(row, col, ax=None, fill=False, **kwargs):
+    """
+    Highlights a specific cell in a matrix plot by drawing a rectangle around it.
+
+    This function adds a rectangular patch to highlight a specific cell in a plot (typically a heatmap or matrix plot).
+    The cell is defined by its row and column indices, and the rectangle can optionally be filled with a color.
+
+    Parameters
+    ----------
+    row : int
+        The row index of the cell to highlight.
+
+    col : int
+        The column index of the cell to highlight.
+
+    ax : matplotlib.axes.Axes, optional
+        The axes object on which to draw the rectangle. If None, the current axes will be used. Default is None.
+
+    fill : bool, optional
+        Whether to fill the rectangle with color. Default is False.
+
+    **kwargs : dict, optional
+        Additional keyword arguments to pass to `matplotlib.patches.Rectangle` (e.g., edgecolor, facecolor, linewidth).
+
+    Returns
+    -------
+    rect : matplotlib.patches.Rectangle
+        The rectangle patch added to the plot.
+
+    Notes
+    -----
+    This function is typically used to highlight specific cells in heatmaps or matrix plots, where each cell represents
+    a data point. The appearance of the highlight can be customized via `**kwargs`.
+    """
+    y = row
+    x = col
+    # Cell centers sit on integer coordinates, so the rectangle is offset by 0.5.
+    rect = plt.Rectangle((x-.5, y-.5), 1, 1, fill=fill, **kwargs)
+    ax = ax or plt.gca()
+    ax.add_patch(rect)
+    return rect
+
+
+# -----------------------------------------------
+# Main
+# -----------------------------------------------
+# Execute only if run as a script.
+if __name__ == "__main__":
+    begin_time = datetime.datetime.now()
+    log_folder = "my_logs"
+    if not os.path.exists(log_folder):
+        os.makedirs(log_folder)
+    # Create run directory to store models and logs in:
+    root_logdir = os.path.join(os.curdir, log_folder)
+    run_id = "%s_on_%s" % (time.strftime("run_%Y_%m_%d-%H_%M_%S"), os.uname()[1])
+    run_logdir = os.path.join(root_logdir, run_id)
+    output_folder = run_logdir
+    os.mkdir(run_logdir)
+    sys.stderr = Logger("{}/console_err.txt".format(run_logdir))
+    sys.stdout = Logger("{}/console_log.txt".format(run_logdir))
+    print(
+        "\n### %s was started on %s in %s with logs in %s"
+        % (__file__, os.uname()[1], os.getcwd(), run_logdir)
+    )
+    print(sys.argv)
+    print_git_version()
+    # print_source_code()
+    save_source_code(run_logdir)
+    now = begin_time
+    now_string = str(now).replace(" ", "_").replace(":", "_")
+
+    plt.rcParams.update({'font.size': 15})
+    # Save a file summarizing the parameters used in this run.
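+    # The file is opened in append mode ('a'); run_logdir is a fresh, timestamped
+    # directory for every run, so in practice this creates a new file.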
+    with open('{}/description.txt'.format(run_logdir), 'a') as f:
+        # Write the run parameters to the file.
+        f.write("Description of the parameters used for this run.\n")
+
+        f.write('NBR_THREADS: {}\n'.format(NBR_THREADS))
+        f.write('\nNBR_FOLDS: {}\n'.format(NBR_FOLDS))
+        f.write('\nDATA_FOLDER: {}\n'.format(DATA_FOLDER))
+        f.write('\nDATA_SET: {}\n'.format(DATA_SET))
+        if DATA_SET == 'PhysioNet':
+            f.write('CLASSES_PHYSIONET: {}\n'.format(CLASSES_PHYSIONET))
+            f.write('Nbr of classes: 2\n')
+
+        f.write('nbr_samples_per_class: {}\n'.format(nbr_samples_per_class))
+        f.write('\nSUBJECTS_TO_USE: ')
+        for i in SUBJECTS_TO_USE:
+            f.write(' {},'.format(i))
+        f.write('\n\nClassifiers:\n')
+        for name, _ in ALL_CLFS:
+            f.write('{}\n'.format(name))
+        f.write('\n\nPLOTTING_SUBRESULTS: {}\n'.format(PLOTTING_SUBRESULTS))
+
+
+    # -------------------------------------------------------------------------------
+
+    # -----------------------------------------------
+    # Load data
+    # -----------------------------------------------
+
+    clusters = load_data_physio_net(SUBJECTS_TO_USE, CLASSES_PHYSIONET)
+
+
+    # -----------------------------------------------
+    # Run the code.
+    # -----------------------------------------------
+    all_results = []  # Placeholder
+    for training_size in ALL_TRAINING_SIZE:
+        for name, clf in ALL_CLFS:
+            print('=============================')
+            print('===== Training size: {} ====='.format(training_size))
+            print('classifier: {}'.format(name))
+            print('=============================')
+
+            data_folder = DATA_FOLDER
+            results = run_all_clusters(clusters, clf, NBR_FOLDS, run_logdir, training_size, PLOTTING_SUBRESULTS, data_folder, name)
+            all_results.append('Average intra-subject performance {:4.4f} clf: {}'.format(results, name))
+
+    for res in all_results:
+        print(res)
+
+    plt.close('all')
+    print("\n...done!")
+    print("### Total run time: %s" % (datetime.datetime.now() - begin_time))
-- 
GitLab