From 3a047e04601a37ffb700c8aef631cc24f267e2b2 Mon Sep 17 00:00:00 2001
From: Frida Heskebeck <frida.heskebeck@control.lth.se>
Date: Thu, 12 Sep 2024 11:55:57 +0200
Subject: [PATCH] Add data_generation.py

---
 data_generation/data_generation.py | 1560 ++++++++++++++++++++++++++++
 1 file changed, 1560 insertions(+)
 create mode 100644 data_generation/data_generation.py

diff --git a/data_generation/data_generation.py b/data_generation/data_generation.py
new file mode 100644
index 0000000..7cd58da
--- /dev/null
+++ b/data_generation/data_generation.py
@@ -0,0 +1,1560 @@
+#!/usr/bin/env python3
+
+# -----------------
+# IMPORTS
+# -----------------
+
+# General Libraries
+import os
+import sys
+import time
+import math
+import random
+import datetime
+import subprocess
+import pickle
+import copy
+import warnings
+import shutil
+
+
+# Making numpy single threaded, so I can multithread this program myself.
+# These must be set BEFORE numpy is imported, or the BLAS backends will ignore them:
+os.environ["OMP_NUM_THREADS"] = "1"  # export OMP_NUM_THREADS=1
+os.environ["OPENBLAS_NUM_THREADS"] = "1"  # export OPENBLAS_NUM_THREADS=1
+os.environ["MKL_NUM_THREADS"] = "1"  # export MKL_NUM_THREADS=1
+os.environ["VECLIB_MAXIMUM_THREADS"] = "1"  # export VECLIB_MAXIMUM_THREADS=1
+os.environ["NUMEXPR_NUM_THREADS"] = "1"  # export NUMEXPR_NUM_THREADS=1
+
+# Numpy Library
+import numpy as np
+rng = np.random.default_rng(seed=42)
+
+# Pandas Library
+import pandas as pd
+
+# pyRiemann library
+from pyriemann.estimation import Covariances
+from pyriemann.classification import MDM, TSclassifier, SVC, KNearestNeighbor, MeanField
+from pyriemann.utils.distance import distance
+from pyriemann.utils.mean import mean_covariance
+from pyriemann.embedding import SpectralEmbedding, LocallyLinearEmbedding
+from pyriemann.datasets.simulated import make_classification_transfer
+from pyriemann.transfer import decode_domains, encode_domains, TLCenter, TLStretch, TLRotate, TLClassifier, TLDummy
+from pyriemann.utils.viz import plot_embedding
+from pyriemann.tangentspace import TangentSpace
+
+# Scikit-learn library
+from sklearn.pipeline import make_pipeline
+from sklearn.metrics import accuracy_score, confusion_matrix
+from sklearn.model_selection import LeaveOneGroupOut, KFold, StratifiedKFold, StratifiedShuffleSplit
+from sklearn import svm
+from sklearn.neural_network import MLPClassifier
+from sklearn import decomposition
+from sklearn.linear_model import LogisticRegression
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+
+# Matplotlib Library
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+
+
+# Multiprocessing
+from multiprocessing import Pool
+
+# MNE library
+from mne import Epochs, pick_types, events_from_annotations
+from mne.io import concatenate_raws
+from mne.io.edf import read_raw_edf
+from mne.datasets import eegbci
+from mne import set_log_level
+
+
+
+# -----------------
+# GLOBAL PARAMETERS
+# -----------------
+
+NBR_THREADS = 8  # Number of threads for parallelisation.
+
+NBR_FOLDS = 5  # Number of folds for cross-validation. [BEWARE! The source selection script assumes this is 5.]
+
+
+# Name for the data folder.
+# DATA_FOLDER = 'all_users_left_right'
+# DATA_FOLDER = 'all_users_hands_feet'
+DATA_FOLDER = time.strftime("data_%Y_%m_%d-%H_%M_%S")
+
+
+# PhysioNet: 46 samples in 2 classes.
+DATA_SET = 'PhysioNet'
+
+
+nbr_samples_per_class = 17  # How many samples to use from each class as training data from the target user.
+ALL_TRAINING_SIZE = [2*nbr_samples_per_class]  # PhysioNet has two classes.
+
+# Which subjects to use.
+# SUBJECTS_TO_USE = np.arange(1,110)  # all
+# SUBJECTS_TO_USE = [1,109]  # First and last
+SUBJECTS_TO_USE = [1, 2, 4, 7, 8, 15, 20, 109]  # Some.
+
+
+# Which classes to use. 
+# CLASSES_PHYSIONET = 'left_right'
+CLASSES_PHYSIONET = 'hands_feet'
+
+
+# Automatic plotting of subresults. No plotting if more than 40 source users. 
+if len(SUBJECTS_TO_USE) > 40:
+    PLOTTING_SUBRESULTS = False
+else:
+    PLOTTING_SUBRESULTS = True
+
+
+random_state = 42
+
+# Classifiers used for MI task classification after transfer learning.
+ALL_CLFS = [
+            ('MDM',make_pipeline(MDM(metric=dict(mean='riemann', distance='riemann')))),
+            ]
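+# Other pyriemann classifiers could be appended in the same (name, pipeline) format.
+# A sketch using the imported tangent-space classifier (not enabled by default here):
+# ALL_CLFS.append(('TSclassifier',
+#                  make_pipeline(TSclassifier(clf=LogisticRegression()))))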
+
+
+print(__doc__)
+
+class Logger(object):
+    """For saving everything printed to file as well."""
+    def __init__(self, filename="console_log.txt"):
+        """Initialize the class."""
+        self.terminal = sys.stdout
+        self.log = open(filename, "a")
+    def write(self, message):
+        """Write both to stdout and to file."""
+        self.terminal.write(message)
+        self.log.write(message)
+    def flush(self):
+        """We don't do flush."""
+        pass
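+# Typical usage of Logger (a sketch, assuming a run_logdir variable is in scope):
+# sys.stdout = Logger('{}/console_log.txt'.format(run_logdir))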
+
+
+def print_git_version():
+    """We want to know what version of the source code we're dealing with."""
+    # Ideally, commit to git _every time_ before you run the code.
+    try:
+        git_output = subprocess.check_output("git log -n 1", shell=True, stderr=subprocess.STDOUT)
+        print("\n\n### git log -n 1\n%s\n\n" % git_output.decode("utf-8"))
+        git_output2 = subprocess.check_output("git status", shell=True, stderr=subprocess.STDOUT)
+        print("### git status\n%s\n\n" % git_output2.decode("utf-8"))
+        git_output3 = subprocess.check_output("git diff", shell=True, stderr=subprocess.STDOUT)
+        print("### git diff\n%s\n\n\n" % git_output3.decode("utf-8"))
+    except subprocess.CalledProcessError:
+        print("\n\n### WARNING: The code you're running is NOT under GIT version control. You can do better. Behave.\n\n")
+    return
+
+def print_source_code():
+    """
+    Print the python source code to the terminal.
+    """
+    print("\n\n###\n### Listing the python source code used, from the file '%s':\n###\n\n" % __file__)
+    with open(__file__, 'r') as f:
+        print(f.read())
+    print("\n\n###\n### End of python source code from the file '%s'.\n###\n\n" % __file__)
+    return
+
+def save_source_code(run_logdir):
+    """
+    Save the python source code to the run_logdir folder.
+    """
+    print("\n\n###\n### Saving the python source code used, from the file '%s':\n###" % __file__)
+    shutil.copy2(__file__, f'{run_logdir}/_script.py')
+    return
+
+# -----------------------------------------------
+
+
+
+# -----------------------------------------------
+# PhysioNet : Loading data
+# -----------------------------------------------
+def get_subject_dataset_physio_net(subject, classes='hands_feet'):
+    """
+    Loads and preprocesses EEG data for a specific subject from the PhysioNet EEG Motor Movement/Imagery Dataset.
+
+    This function downloads the EEG data for a given subject, filters it, selects relevant channels,
+    extracts motor imagery events, and computes covariance matrices based on the selected epochs.
+    The function supports two types of motor imagery tasks: 'hands vs feet' and 'left hand vs right hand'.
+
+    Parameters
+    ----------
+    subject : int
+        The subject ID from the PhysioNet EEG Motor Movement/Imagery dataset.
+
+    classes : str, optional
+        The motor imagery task to extract. Either 'hands_feet' (default) or 'left_right'.
+        - 'hands_feet' extracts events related to imagining hand and foot movements (runs 6, 10, 14).
+        - 'left_right' extracts events related to imagining left and right hand movements (runs 4, 8, 12).
+
+    Returns
+    -------
+    covs : ndarray of shape (n_epochs, n_channels, n_channels)
+        Covariance matrices computed from the EEG data for each epoch.
+
+    labels : ndarray of shape (n_epochs,)
+        Encoded labels corresponding to each epoch. For 'hands_feet', 0 corresponds to 'hands' and 1 to 'feet'. 
+        For 'left_right', 0 corresponds to 'left hand' and 1 to 'right hand'.
+
+    str_labels : ndarray of shape (n_epochs,)
+        String labels for each epoch, corresponding to the classes provided ('hands', 'feet', etc.).
+
+    Notes
+    -----
+    The data is preprocessed by selecting EEG channels, applying a bandpass filter (7-35 Hz), and 
+    extracting epochs from 1s to 2s after cue onset. The covariance matrices are computed on scaled data.
+    The dataset used is from PhysioNet: https://www.physionet.org/content/eegmmidb/1.0.0/
+    """
+
+    # Dataset description: https://www.physionet.org/content/eegmmidb/1.0.0/
+    # Processing template: https://pyriemann.readthedocs.io/en/latest/auto_examples/transfer/plot_bci_example.html#sphx-glr-auto-examples-transfer-plot-bci-example-py
+
+    # Consider epochs that start 1s after cue onset.
+    tmin, tmax = 1., 2.
+    if classes == 'hands_feet':
+        event_id = dict(hands=2, feet=3)
+        str_classes = ['hands','feet']
+        runs = [6, 10, 14]  # motor imagery: hands vs feet
+    elif classes == 'left_right':
+        event_id = dict(left_hand=2, right_hand=3)
+        runs = [4, 8, 12]   # Motor imagery: left vs right hand
+        str_classes = ['left_hand','right_hand']
+    else:
+        raise ValueError("Unknown classes: {!r} (expected 'hands_feet' or 'left_right')".format(classes))
+
+    # Download data with MNE
+    raw_files = [
+        read_raw_edf(f, preload=True) for f in eegbci.load_data(subject, runs)
+    ]
+    raw = concatenate_raws(raw_files)
+
+    # Select only EEG channels
+    picks = pick_types(
+        raw.info, meg=False, eeg=True, stim=False, eog=False, exclude='bads')
+    # select only nine electrodes: F3, Fz, F4, C3, Cz, C4, P3, Pz, P4
+    picks = picks[[31, 33, 35, 8, 10, 12, 48, 50, 52]]
+
+    # Apply band-pass filter
+    raw.filter(7., 35., method='iir', picks=picks)
+
+    # Check the events
+    events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))
+
+    # Define the epochs
+    epochs = Epochs(
+        raw,
+        events,
+        event_id,
+        tmin,
+        tmax,
+        proj=True,
+        picks=picks,
+        baseline=None,
+        preload=True,
+        verbose=False)
+
+    # Extract the labels for each event
+    labels = epochs.events[:, -1] - 2
+    str_labels = np.array([str_classes[i] for i in labels])
+
+    # Compute covariance matrices on scaled data
+    covs = Covariances().fit_transform(1e6 * epochs.get_data())
+    del raw
+
+    return covs, labels, str_labels
+
+
+
+def load_data_physio_net(subjects, classes='hands_feet'):
+    """
+    Loads and preprocesses EEG data for multiple subjects from the PhysioNet EEG Motor Movement/Imagery Dataset.
+
+    For each subject, this function downloads, processes, and extracts the covariance matrices and labels.
+    The function returns a list of clusters, where each cluster contains the covariance matrices, labels, 
+    and metadata for a given subject.
+
+    Parameters
+    ----------
+    subjects : list of int
+        A list of subject IDs from the PhysioNet EEG Motor Movement/Imagery dataset.
+
+    classes : str, optional
+        The motor imagery task to extract. Either 'hands_feet' (default) or 'left_right'.
+        - 'hands_feet' extracts events related to imagining hand and foot movements.
+        - 'left_right' extracts events related to imagining left and right hand movements.
+
+    Returns
+    -------
+    clusters : list of dict
+        A list where each dictionary contains the following keys:
+        - 'X_cov': ndarray of shape (n_epochs, n_channels, n_channels)
+          Covariance matrices for each subject.
+        - 'y': ndarray of shape (n_epochs,)
+          String labels for each epoch, corresponding to the classes provided ('hands', 'feet', etc.).
+        - 'y_enc': ndarray of shape (n_epochs,)
+          Encoded labels with subject information. For example, 'subj_X/hand'.
+        - 'subject': int
+          The subject ID.
+        - 'cluster_name': str
+          The cluster name, in the format 'subj_X', where X is the subject ID.
+
+    Notes
+    -----
+    This function calls `get_subject_dataset_physio_net` internally to process the data for each subject.
+    """
+    clusters = []
+
+    for subj in subjects:
+        covs, labels, str_labels = get_subject_dataset_physio_net(subj, classes)
+
+        cluster_data = {
+            'X_cov' : covs,
+            'y': str_labels,
+            'y_enc' : np.array(['subj_{}/'.format(subj) + str(item) for item in str_labels]),
+            'subject': subj,
+            'cluster_name': 'subj_{}'.format(subj)
+        }
+
+        clusters.append(cluster_data)
+
+    return clusters
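+# Example usage (a sketch based on the globals defined above):
+# clusters = load_data_physio_net(SUBJECTS_TO_USE, classes=CLASSES_PHYSIONET)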
+
+
+
+# -----------------------------------------------
+# Innermost function. Calculate n-fold accuracy of source and target data. 
+# -----------------------------------------------
+def performance_on_one_cluster_vs_one_cluster(source_cluster_data, target_cluster_data, clf, nbr_folds=5, training_size=8, run_logdir='.', plotting=False, df_data=None):
+    """
+    Evaluates the performance of a classifier using transfer learning between a source and target data cluster. 
+
+    The classifier is trained on the source cluster and tested on the target cluster using transfer learning techniques 
+    (re-centering and rotation). It performs a StratifiedShuffleSplit cross-validation on the target data, optionally 
+    plotting the embeddings and confusion matrices for each fold.
+
+    Parameters
+    ----------
+    source_cluster_data : dict
+        A dictionary containing data from the source cluster. 
+        Expected to have 'X_cov' (covariance matrices), 'y' (labels), and 'y_enc' (encoded labels).
+
+    target_cluster_data : dict
+        A dictionary containing data from the target cluster. 
+        Expected to have 'X_cov' (covariance matrices), 'y' (labels), 'y_enc' (encoded labels), and 'cluster_name' (metadata).
+
+    clf : classifier object
+        A classifier object implementing the 'fit' and 'predict' methods.
+
+    nbr_folds : int, optional
+        The number of folds for cross-validation. Default is 5.
+
+    training_size : int, optional
+        The size of the training set for each fold. Default is 8.
+
+    run_logdir : str, optional
+        Directory to save the generated plots. Default is the current directory ('.').
+
+    plotting : bool, optional
+        Whether to plot embeddings and confusion matrices during the transfer learning process. Default is False.
+
+    df_data : pandas.DataFrame, optional
+        A dataframe to append the results. If None, a new dataframe will be created. Default is None.
+
+    Returns
+    -------
+    df_data : pandas.DataFrame
+        A DataFrame containing performance metrics and distances computed between source and target clusters for 
+        each fold before rotation in the transfer learning algorithm.
+
+    benefit_of_transfer_learning : ndarray of shape (nbr_folds,)
+        The benefit of transfer learning for each fold, calculated as the difference in accuracy between 
+        after transfer learning and without transfer learning.
+
+    results_test_data : list of float
+        A list of accuracy scores for the classifier on the target test data after applying transfer learning, for each fold.
+
+    results_no_transfer_learning_target_data : list of float
+        A list of accuracy scores for the classifier on the target test data without transfer learning, for each fold.
+
+    Notes
+    -----
+    The function applies two transfer learning techniques: TLCenter (re-centering) and TLRotate (rotation using Riemannian metrics). 
+    If plotting is enabled, embeddings of the source and target clusters before and after transfer learning are displayed. 
+    It also generates confusion matrices for each fold and saves them to the specified log directory.
+    """
+
+    # Extract data to variables. 
+    X_source = source_cluster_data['X_cov']
+    y_source = source_cluster_data['y']
+    y_source_enc = source_cluster_data['y_enc']
+    X_target = target_cluster_data['X_cov']
+    y_target = target_cluster_data['y']
+    y_target_enc = target_cluster_data['y_enc']
+    name_target_domain = target_cluster_data['cluster_name']
+
+    # Placeholders for results
+    results_test_data = []
+    results_no_transfer_learning_target_data = []
+    distances_all_folds = [] # [target:center-class1,target:center-class2,target:class1-class2,source:center-class1,source:center-class2,source:class1-class2]
+
+
+
+    if plotting:
+        # Create figure for plotting of data before and after transfer learning. 
+        ncols = 5
+        nrows = 3
+        fig, axs = plt.subplots(figsize=(35,21), ncols=ncols, nrows=nrows, sharey=True,sharex=True,layout='constrained')
+        classes = np.unique(y_target)
+        confusion_matrices = []
+
+    # Splitter function for test and training split. 
+    splitter = StratifiedShuffleSplit(n_splits=nbr_folds, test_size=training_size, random_state=42)  # Shuffles data. Test size = training size in our case.
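+    # NOTE: the split is deliberately unpacked as (test_index, train_index) below:
+    # since test_size above equals the target training size, the shuffle split's
+    # small "test" portion serves as the target training set, while evaluation is
+    # done on the full target set.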
+    
+    # Loop over folds for test and training splits. 
+    for i, (test_index, train_index) in enumerate(splitter.split(X_target, y_target)):
+        # Get data
+        X_train_source = copy.deepcopy(X_source)
+        y_train_source = copy.deepcopy(y_source)
+        y_enc_source = copy.deepcopy(y_source_enc)
+
+        # Split target data into training and test subsets
+        X_train_target = copy.deepcopy(X_target[train_index])
+        y_train_target = copy.deepcopy(y_target[train_index])
+        y_enc_target_train = copy.deepcopy(y_target_enc[train_index])
+
+        X_test_target = copy.deepcopy(X_target)
+        y_test_target = copy.deepcopy(y_target)
+        y_enc_target_test = copy.deepcopy(y_target_enc)
+
+
+        # Plot data before transfer learning
+        if plotting:
+            ax = axs[0,i]
+            X_training_list = np.concatenate([X_train_source,X_train_target])
+            y_enc_training_list = np.concatenate([y_enc_source, y_enc_target_train])
+            plot_embeddings_source_target_train_test(X_training_list,X_test_target,y_enc_training_list,y_enc_target_test,ax,target_cluster_name=name_target_domain)
+
+
+
+        # objects for transfer learning
+        tl_recenter = TLCenter(target_domain=name_target_domain)
+        tl_rotate = TLRotate(target_domain=name_target_domain, metric='riemann')
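+        # TLCenter re-centers each domain so that its mean covariance matrix becomes
+        # the identity; TLRotate then rotates the re-centered source matrices so that
+        # the source class means align with the target class means.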
+
+
+        # ==================================
+        # ==== TRANSFER LEARNING ======
+        # ============================
+        # Format data in correct way for transfer learning
+        X_train = np.concatenate((X_train_source,X_train_target))
+        y_enc = np.concatenate((y_enc_source,y_enc_target_train))
+        y_train = np.concatenate((y_train_source,y_train_target))
+
+
+        
+        # === RECENTER ===
+        # Train the classifier after transfer learning and evaluate on the target test data
+        X_train = tl_recenter.fit_transform(X_train, y_enc)
+        X_test_target = tl_recenter.transform(X_test_target)
+        if plotting:
+            # Plot data after recentering
+            plot_embeddings_source_target_train_test(X_train,X_test_target,y_enc,y_enc_target_test,axs[1,i],target_cluster_name=name_target_domain)
+
+        # INTRA SUBJECT TRAINING - TARGET
+        nbr_training_data = len(y_train_target)
+        clf.fit(X_train[-nbr_training_data:],y_train[-nbr_training_data:])
+        y_pred = clf.predict(X_test_target)
+        score = accuracy_score(y_test_target, y_pred)
+        results_no_transfer_learning_target_data.append(score)
+
+
+
+        # Finding means and then distances between means.
+        classes = np.unique(y_train_source)
+        metric = 'riemann'
+        means_target = {'center':mean_covariance(X_train[-nbr_training_data:], metric=metric),
+                        'class1':mean_covariance(X_train[-nbr_training_data:][y_train_target==classes[0]], metric=metric),
+                        'class2':mean_covariance(X_train[-nbr_training_data:][y_train_target==classes[1]], metric=metric)}
+
+        means_source = {'center':mean_covariance(X_train[:-nbr_training_data], metric=metric),
+                        'class1':mean_covariance(X_train[:-nbr_training_data][y_train_source==classes[0]], metric=metric),
+                        'class2':mean_covariance(X_train[:-nbr_training_data][y_train_source==classes[1]], metric=metric)}
+
+        columns_data = ['Target:Class1-Target:Class2',
+                        'Target:Center-Target:Class1',
+                        'Target:Center-Target:Class2',
+                        'Source:Class1-Source:Class2',
+                        'Source:Center-Source:Class1',
+                        'Source:Center-Source:Class2',
+                        'Target:Center-Source:Center',
+                        'Target:Center-Source:Class1',
+                        'Target:Center-Source:Class2',
+                        'Target:Class1-Source:Center',
+                        'Target:Class1-Source:Class1',
+                        'Target:Class1-Source:Class2',
+                        'Target:Class2-Source:Center',
+                        'Target:Class2-Source:Class1',
+                        'Target:Class2-Source:Class2',
+                        'Source:Dispersion:all',
+                        'Source:Dispersion:Class1',
+                        'Source:Dispersion:Class2',
+                        'Target:Dispersion:all',
+                        'Target:Dispersion:Class1',
+                        'Target:Dispersion:Class2',
+                        ]
+
+
+
+        metric = 'riemann'
+        distances = [distance(means_target['class1'], means_target['class2'], metric=metric),
+                    distance(means_target['center'], means_target['class1'], metric=metric),
+                    distance(means_target['center'], means_target['class2'], metric=metric),
+                    distance(means_source['class1'], means_source['class2'], metric=metric),
+                    distance(means_source['center'], means_source['class1'], metric=metric),
+                    distance(means_source['center'], means_source['class2'], metric=metric),
+                    distance(means_target['center'], means_source['center'], metric=metric),
+                    distance(means_target['center'], means_source['class1'], metric=metric),
+                    distance(means_target['center'], means_source['class2'], metric=metric),
+                    distance(means_target['class1'], means_source['center'], metric=metric),
+                    distance(means_target['class1'], means_source['class1'], metric=metric),
+                    distance(means_target['class1'], means_source['class2'], metric=metric),
+                    distance(means_target['class2'], means_source['center'], metric=metric),
+                    distance(means_target['class2'], means_source['class1'], metric=metric),
+                    distance(means_target['class2'], means_source['class2'], metric=metric),
+                    np.sum([distance(cov, means_source['center'], metric=metric,squared=True) for cov in X_train[:-nbr_training_data]]) / len(X_train[:-nbr_training_data]),
+                    np.sum([distance(cov, means_source['class1'], metric=metric,squared=True) for cov in X_train[:-nbr_training_data][y_train_source==classes[0]]]) / len(X_train[:-nbr_training_data][y_train_source==classes[0]]),
+                    np.sum([distance(cov, means_source['class2'], metric=metric,squared=True) for cov in X_train[:-nbr_training_data][y_train_source==classes[1]]]) / len(X_train[:-nbr_training_data][y_train_source==classes[1]]),
+                    np.sum([distance(cov, means_target['center'], metric=metric,squared=True) for cov in X_train[-nbr_training_data:]]) / len(X_train[-nbr_training_data:]),
+                    np.sum([distance(cov, means_target['class1'], metric=metric,squared=True) for cov in X_train[-nbr_training_data:][y_train_target==classes[0]]]) / len(X_train[-nbr_training_data:][y_train_target==classes[0]]),
+                    np.sum([distance(cov, means_target['class2'], metric=metric,squared=True) for cov in X_train[-nbr_training_data:][y_train_target==classes[1]]]) / len(X_train[-nbr_training_data:][y_train_target==classes[1]])
+                    ]
+
+        distances_all_folds.append(distances)
+
+        # === ROTATE ===
+        X_train = tl_rotate.fit_transform(X_train, y_enc)
+        X_test_target = tl_rotate.transform(X_test_target)
+        if plotting:
+            plot_embeddings_source_target_train_test(X_train,X_test_target,y_enc,y_enc_target_test,axs[2,i],target_cluster_name=name_target_domain)
+
+        # Fit classifier on data after transfer learning and evaluate performance
+        # clf.fit(X_train,y_train) # With data from target
+        clf.fit(X_train[:-nbr_training_data],y_train[:-nbr_training_data]) # Without data from target.
+        y_pred = clf.predict(X_test_target)
+        score = accuracy_score(y_test_target, y_pred)
+        results_test_data.append(score)
+
+        if plotting:
+            # Print text about transfer learning performance in figure
+            axs[0,i].text(0.25,0.05,f'Accuracy no transfer learning: {results_no_transfer_learning_target_data[-1]:.4f}',transform=axs[0,i].transAxes)
+            axs[2,i].text(0.25,0.08,f'Accuracy transfer learning: {results_test_data[-1]:.4f}',transform=axs[2,i].transAxes)
+            axs[2,i].text(0.25,0.05,f'Benefit transfer learning: {results_test_data[-1]-results_no_transfer_learning_target_data[-1]:.4f}',transform=axs[2,i].transAxes,color='b' if results_test_data[-1]-results_no_transfer_learning_target_data[-1]>=0 else 'r')
+            # Calculate confusion matrix 
+            conf_matrix = confusion_matrix(y_test_target, y_pred,labels=classes)
+            confusion_matrices.append(conf_matrix)
+
+    if plotting:
+        # save figure
+        plt.figure(fig)
+        axs[0,0].set_ylabel('Before transfer learning')
+        axs[1,0].set_ylabel('After Recenter')
+        axs[2,0].set_ylabel('After recenter and then rotate')
+        plt.suptitle('Embeddings for the different folds during transfer learning.')
+        print('--> Saving transfer learning plot')
+        save_string = '{}_data_analysis_clusters_target_{}_and_source_{}_training_size_{}'.format(target_cluster_data['cluster_name'],target_cluster_data['cluster_name'],source_cluster_data['cluster_name'],training_size)  # Filename
+        plt.savefig('{}/png_{}.png'.format(run_logdir,save_string))  # Save the figure in the specified directory
+        plt.close(fig)
+        print('--> Saving complete')
+
+
+        # Plot confusion matrix
+        conf_matrix = np.mean(confusion_matrices,axis=0)
+        fig, ax = plt.subplots(figsize=(7,7),layout='constrained')
+
+        plot_confusion_matrix(conf_matrix, classes,ax,fig)
+
+        print('--> Saving confusion matrix')
+        save_string = '{}_confusion_matrix_target_{}_and_source_{}_training_size_{}'.format(target_cluster_data['cluster_name'],target_cluster_data['cluster_name'],source_cluster_data['cluster_name'],training_size)  # Filename
+        plt.savefig('{}/png_{}.png'.format(run_logdir,save_string))  # Save the figure in the specified directory
+        plt.close(fig)
+        print('--> Saving complete')
+
+
+    # calculate benefit of transfer learning.
+    benefit_of_transfer_learning = np.array(results_test_data) - np.array(results_no_transfer_learning_target_data) 
+
+    # add data to dataframe
+    df_here = pd.DataFrame(distances_all_folds,columns=columns_data)
+    df_here.insert(len(df_here.columns),'Source_cluster',np.repeat(source_cluster_data['cluster_name'],nbr_folds))
+    df_here.insert(len(df_here.columns),'Target_cluster',np.repeat(target_cluster_data['cluster_name'],nbr_folds))
+    df_here.insert(len(df_here.columns),'Target_accuracy',np.array(results_no_transfer_learning_target_data))
+    df_here.insert(len(df_here.columns),'Benefit_transfer_learning',benefit_of_transfer_learning)
+    df_here.insert(len(df_here.columns),'Accuracy_after_transfer_learning',results_test_data)
+    df_data = pd.concat([df_data,df_here],ignore_index=True)
+
+    return df_data, benefit_of_transfer_learning, results_test_data, results_no_transfer_learning_target_data
+
+    
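+# Example call (a sketch; mirrors how the function is invoked further down):
+# df, benefit, acc_tl, acc_no_tl = performance_on_one_cluster_vs_one_cluster(
+#     clusters[0], clusters[1], ALL_CLFS[0][1], nbr_folds=NBR_FOLDS,
+#     training_size=ALL_TRAINING_SIZE[0])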
+
+
+def plot_embeddings_source_target_train_test(X_training_list, X_test, y_enc_training_list, y_enc_test, ax, target_cluster_name):
+    """
+    Plots 2D spectral embeddings of source and target data clusters before and after transfer learning.
+
+    This function combines the training and test data, performs spectral embedding using the Riemannian 
+    metric, and visualizes the embeddings. The target cluster is highlighted, and points are color-coded 
+    based on the classes and clusters.
+
+    Parameters
+    ----------
+    X_training_list : ndarray of shape (n_samples_train, n_channels, n_channels)
+        Covariance matrices from the training data (source and target).
+
+    X_test : ndarray of shape (n_samples_test, n_channels, n_channels)
+        Covariance matrices from the test data (target).
+
+    y_enc_training_list : ndarray of shape (n_samples_train,)
+        Encoded labels for the training data, including source and target data.
+
+    y_enc_test : ndarray of shape (n_samples_test,)
+        Encoded labels for the test data (target).
+
+    ax : matplotlib.axes.Axes
+        The axes object where the embeddings will be plotted.
+
+    target_cluster_name : str
+        The name of the target cluster (domain) used to distinguish it from the source cluster.
+
+    Returns
+    -------
+    ax : matplotlib.axes.Axes
+        The axes object with the plotted embeddings.
+
+    Notes
+    -----
+    The function uses spectral embedding to reduce the dimensionality of the covariance matrices to 2D. 
+    The embeddings for different classes and clusters are plotted using different colors and markers. 
+    The identity matrix (used as a reference) is also plotted with a special marker.
+    """
+    
+    # Extract data to variables
+    X = np.concatenate((X_training_list,X_test))
+    y_enc_test = ['test_'+item for item in y_enc_test]
+    y_enc = np.concatenate([y_enc_training_list,y_enc_test])
+
+    X, y, domain = decode_domains(X, y_enc)
+    # instantiate object for doing spectral embeddings
+    emb = SpectralEmbedding(n_components=2, metric='riemann')
+    # emb = LocallyLinearEmbedding(n_components=2,metric='riemann',n_neighbors=10)
+
+    # embed the original source and target datasets
+    points = np.concatenate([X, np.eye(X.shape[1])[None, :, :]])  # stack the identity
+    embedded_points = emb.fit_transform(points)
+    embedded_points = embedded_points - embedded_points[-1]
+    identity = embedded_points[-1]
+    embedded_points = embedded_points[:-1]
+
+    markers_list = ['o', 's', '^', 'D', 'v', 'p', '*', 'H', 'x', '.', ',', '<', '>', '1', '2', '3', '4', 'h', '+', 'd', '|', '_']
+    colors_list = ['b', 'g', 'r', 'c', 'm', 'y', 'k', '#FFA07A', '#FF5733', '#33FFCE', '#8E44AD', '#3498DB', '#2ECC71', '#F1C40F', '#E74C3C', '#34495E', '#1ABC9C', '#9B59B6', '#34495E', '#95A5A6', '#7F8C8D']
+
+    unique_labels = np.unique(y)
+    unique_clusters = np.unique(domain)
+    # Move target index cluster last.
+    index_of_target = np.where(unique_clusters == target_cluster_name)[0]
+    unique_clusters = np.delete(unique_clusters,index_of_target)
+    unique_clusters = np.append(unique_clusters,target_cluster_name)
+
+    # Plotting
+    for i in range(len(unique_clusters)):
+        for j in range(len(unique_labels)):
+            data_filter = (y==unique_labels[j]) & (domain==unique_clusters[i])
+            ax.scatter(
+                embedded_points[data_filter][:, 0],
+                embedded_points[data_filter][:, 1],
+                c=colors_list[j], s=50, alpha=0.50,marker=markers_list[i])
+            ax.scatter([],[],c=colors_list[j],marker=markers_list[i], label='{} - {}'.format(unique_clusters[i],unique_labels[j])) # Add label
+    ax.scatter(identity[0], identity[1], c='k', s=80, marker="*")
+    ax.legend()
+    
+    return ax
+
+
+def plot_confusion_matrix(conf_mat, classes, ax, fig):
+    """
+    Plots a confusion matrix using matplotlib, with color-coded values and class labels.
+
+    Parameters
+    ----------
+    conf_mat : ndarray of shape (n_classes, n_classes)
+        The confusion matrix to be plotted.
+
+    classes : list of str
+        A list of class names corresponding to the rows and columns of the confusion matrix.
+
+    ax : matplotlib.axes.Axes
+        The axes object on which to draw the confusion matrix.
+
+    fig : matplotlib.figure.Figure
+        The figure object for adding the color bar.
+
+    Returns
+    -------
+    ax : matplotlib.axes.Axes
+        The axes object with the plotted confusion matrix.
+
+    Notes
+    -----
+    The confusion matrix is displayed with different colors to indicate higher or lower values, and the 
+    color of the text within the cells is chosen based on a threshold to maintain readability. The matrix 
+    has labeled x and y ticks for predicted and true labels, respectively, and a color bar to indicate the 
+    intensity of the values.
+    """
+    # Plot matrix
+    cax = ax.matshow(conf_mat, cmap='Greys')
+
+    # Set labels
+    plt.xticks(np.arange(len(classes)), classes, rotation=45)
+    plt.yticks(np.arange(len(classes)), classes)
+
+    threshold = np.max(conf_mat)/2 + np.min(conf_mat)/2  # threshold for switching the text colour.
+    # Loop over data dimensions and create text annotations.
+    for i in range(len(classes)):
+        for j in range(len(classes)):
+            color = "white" if conf_mat[i, j] > threshold else "black"
+            ax.text(j, i, conf_mat[i, j], ha="center", va="center", color=color)
+
+    ax.grid(False)
+    plt.xlabel('Predicted labels')
+    plt.ylabel('True labels')
+    plt.title('Confusion Matrix')
+    fig.colorbar(cax)
+    return ax
+
+
+
+# -----------------------------------------------
+# For one cluster, test performance vs all the other clusters. 
+# -----------------------------------------------
+def performance_for_one_cluster_vs_all_cluster(clusters, clf, target_cluster_idx, nbr_folds=5, run_logdir='.', training_size=8, plotting=False, data_folder='data'):
+    """
+    Evaluates and plots the performance of a classifier for one target cluster against all other source clusters using transfer learning.
+
+    This function iterates over all clusters, using each one as a source cluster and the target cluster for testing. 
+    It computes the classifier's performance for each source-target combination and, if specified, generates a bar plot 
+    showing the mean accuracy and standard deviation for transfer learning performance.
+
+    Parameters
+    ----------
+    clusters : list of dict
+        A list of dictionaries, where each dictionary represents a data cluster with 'X_cov', 'y', 'y_enc', 'subject', 
+        and 'cluster_name' keys.
+
+    clf : classifier object
+        A classifier object implementing 'fit' and 'predict' methods.
+
+    target_cluster_idx : int
+        The index of the cluster to be used as the target cluster for testing.
+
+    nbr_folds : int, optional
+        The number of cross-validation folds. Default is 5.
+
+    run_logdir : str, optional
+        The directory where plots and results will be saved. Default is the current directory ('.').
+
+    training_size : int, optional
+        The size of the training data used from the target cluster in each fold. Default is 8.
+
+    plotting : bool, optional
+        Whether to plot the results (bar plots and other visuals). Default is False.
+
+    data_folder : str, optional
+        The folder where the results will be saved as pickle files. Default is 'data'.
+
+    Returns
+    -------
+    None
+        This function does not return any values, but it saves the computed results and optionally generated plots to disk.
+
+    Notes
+    -----
+    - This function uses transfer learning by training the classifier on each source cluster and testing it on the target cluster.
+    - If plotting is enabled, a bar plot is generated to visualize the classifier's performance with and without transfer learning 
+      for each source-target combination.
+    - Results are saved in pickle files, including transfer learning benefits, performance data, and processed DataFrames.
+    """
+
+
+
+    i = target_cluster_idx
+    
+    # Placeholders for results
+    all_benefit_of_transfer_learning = []  # Store results for this target cluster against all other clusters
+    all_results_test_data = []
+    all_results_no_transfer_learning_target_data = []
+
+    cluster_target = clusters[i]       # Target cluster for testing
+
+    # If run has to be restarted, see if data is already analyzed.
+    save_folder_pickle = './{}/{}'.format(data_folder,cluster_target['cluster_name'])
+    if os.path.exists(save_folder_pickle):
+        # The files are already saved to disc
+        print('!!! {} already analyzed'.format(cluster_target['cluster_name']))
+        return 
+    elif not os.path.exists('./{}/'.format(data_folder)):
+        print('creating folder {}'.format(data_folder))
+        os.makedirs('./{}/'.format(data_folder))
+    
+    # If things should be plotted, save in specified folder.
+    if plotting:
+        save_folder = '{}/{}_training_size_{}'.format(run_logdir,cluster_target['cluster_name'],training_size)
+        os.mkdir(save_folder)
+        run_logdir = save_folder
+        x_ticks = []  # To store labels for the x-axis in the plot
+
+    target_subject = cluster_target['subject']  # Extract the subject of the target cluster
+    df_data = pd.DataFrame()
+
+    # Iterate over all clusters/subjects to use them as source clusters
+    for j in range(len(clusters)):
+        print('Target cluster: {}, analyzing cluster {} of {}'.format(cluster_target['cluster_name'],j+1,len(clusters)))
+        cluster_source = clusters[j]  # Source cluster for training
+        
+        # Evaluate the performance of one cluster against another
+        df_data, benefit_of_transfer_learning, results_test_data, results_no_transfer_learning_target_data = performance_on_one_cluster_vs_one_cluster(source_cluster_data=cluster_source, target_cluster_data=cluster_target, clf=clf,df_data=df_data,nbr_folds=nbr_folds,training_size=training_size,run_logdir=run_logdir,plotting=plotting)
+        
+        # Store data. 
+        all_benefit_of_transfer_learning.append(benefit_of_transfer_learning)  # Store results for this target cluster against all other clusters
+        all_results_test_data.append(results_test_data)
+        all_results_no_transfer_learning_target_data.append(results_no_transfer_learning_target_data)
+        if plotting:
+            x_ticks.append(cluster_source['cluster_name'])  # Store the label for the plot
+
+
+    if plotting: # Barplot of performance with source subject. 
+        # Calculate the mean and standard deviation for the performance results
+        means_test = np.mean(all_results_test_data, axis=1)
+        stds_test = np.std(all_results_test_data, axis=1)   
+
+        means_no_transfer_target = np.mean(all_results_no_transfer_learning_target_data, axis=1)
+        stds_no_transfer_target = np.std(all_results_no_transfer_learning_target_data, axis=1)
+
+        # Plotting the results
+        fig, ax = plt.subplots(figsize=(28,7),layout='constrained')
+        x = np.arange(len(means_test))  # X-axis values
+        width = 0.25
+        # Create a bar plot with error bars showing standard deviation
+        # Test data
+        bars = ax.bar(x-1*width, means_test, yerr=stds_test, capsize=5,width=width,color='dimgray')
+        bars[i].set_hatch('x')
+
+        bars = ax.bar(x+0*width, means_no_transfer_target, yerr=stds_no_transfer_target, capsize=5,width=width,color='silver',hatch='.')
+        bars[i].set_hatch('x.')
+
+        # Pretty plotting
+        ax.set_xlabel('Source cluster')
+        ax.set_ylabel('Mean Accuracy')
+        ax.set_title('Performance of Each Cluster, target cluster = {}'.format(cluster_target['cluster_name']))
+        ax.set_xticks(x)
+        ax.set_ylim([0,1])
+        ax.set_xticklabels(x_ticks, rotation=90)  # Set the labels for the x-axis
+        label1 = mpatches.Patch( facecolor='dimgray',label='Source data for training')
+        label2 = mpatches.Patch(facecolor='silver',label='Target training data for training')
+        label3 = mpatches.Patch(edgecolor='black',hatch='..',label='No transfer learning')
+        label4 = mpatches.Patch(edgecolor='black',hatch='xxx',label='Same subject - same session')
+        ax.legend(handles = [label1,label2,label3,label4],loc=0)
+        plt.grid(False,axis='x')
+
+        # Save the plot to a file
+        print('--> Now Plotting barplot')
+        clf_name = list(clf.named_steps.keys())[0]
+        save_string = '{}_oracle_performance_target_cluster_training_size_{}'.format(cluster_target['cluster_name'],training_size)  # Filename
+        plt.savefig('{}/png_{}.png'.format(run_logdir,save_string))  # Save the figure in the specified directory
+        plt.close(fig)
+        print('--> Plotting complete')
+
+    # Save data to disk. 
+    os.mkdir(save_folder_pickle)
+    # all_benefit_of_transfer_learning
+    with open('{}/all_benefit_of_transfer_learning.pickle'.format(save_folder_pickle), 'wb') as file:
+        pickle.dump(all_benefit_of_transfer_learning, file, protocol=pickle.HIGHEST_PROTOCOL)
+    # all_results_test_data
+    with open('{}/all_results_test_data.pickle'.format(save_folder_pickle), 'wb') as file:
+        pickle.dump(all_results_test_data, file, protocol=pickle.HIGHEST_PROTOCOL)
+    # all_results_no_transfer_learning_target_data
+    with open('{}/all_results_no_transfer_learning_target_data.pickle'.format(save_folder_pickle), 'wb') as file:
+        pickle.dump(all_results_no_transfer_learning_target_data, file, protocol=pickle.HIGHEST_PROTOCOL) 
+    # df_data
+    with open('{}/df_data.pickle'.format(save_folder_pickle), 'wb') as file:
+        pickle.dump(df_data, file, protocol=pickle.HIGHEST_PROTOCOL)     
+
+    # Remove variables to free space. 
+    del all_benefit_of_transfer_learning,all_results_test_data, all_results_no_transfer_learning_target_data,df_data
+    return 
+
+
+
+
+# -----------------------------------------------
+# Multiprocessing helper function:
+# -----------------------------------------------
+def worker_function(target_cluster_idx, clusters, clf, nbr_folds=5, run_logdir='.', training_size=8, plotting=False, data_folder='data'):
+    """
+    A wrapper function to evaluate the performance of a classifier for one target cluster against all other clusters, 
+    designed to be used with multiprocessing.
+
+    This function is intended to be passed to a multiprocessing Pool, allowing for parallel evaluation of transfer learning 
+    performance across multiple target clusters. It calls `performance_for_one_cluster_vs_all_cluster` to perform the actual 
+    evaluation and result generation.
+
+    Parameters
+    ----------
+    target_cluster_idx : int
+        The index of the cluster to be used as the target cluster for testing.
+
+    clusters : list of dict
+        A list of dictionaries, where each dictionary represents a data cluster with 'X_cov', 'y', 'y_enc', 'subject', and 'cluster_name' keys.
+
+    clf : classifier object
+        A classifier object implementing 'fit' and 'predict' methods.
+
+    nbr_folds : int, optional
+        The number of cross-validation folds. Default is 5.
+
+    run_logdir : str, optional
+        The directory where plots and results will be saved. Default is the current directory ('.').
+
+    training_size : int, optional
+        The size of the training data used from the target cluster in each fold. Default is 8.
+
+    plotting : bool, optional
+        Whether to plot the results (bar plots and other visuals). Default is False.
+
+    data_folder : str, optional
+        The folder where the results will be saved as pickle files. Default is 'data'.
+
+    Returns
+    -------
+    None
+        This function does not return any values, but it triggers `performance_for_one_cluster_vs_all_cluster` for the 
+        specified target cluster index which stores data to disk.
+
+    Notes
+    -----
+    This function is designed to be used in a multiprocessing context to allow parallel evaluation of multiple target clusters.
+    """
+    print("### worker_function for target cluster index = {}".format(target_cluster_idx))
+    return performance_for_one_cluster_vs_all_cluster(clusters, clf,target_cluster_idx,nbr_folds,run_logdir,training_size,plotting,data_folder)
+
+
+
+# -----------------------------------------------
+# Test all clusters vs all clusters
+# -----------------------------------------------
+def run_all_clusters(clusters, clf, nbr_folds=5, run_logdir='.', training_size=8, plotting_subresults=False, data_folder='data', clf_name=None):
+    """
+    Runs transfer learning evaluations for all clusters and generates performance reports and plots.
+
+    This function iterates through all clusters, evaluating the classifier's performance for each one as the target cluster 
+    using the rest as source clusters. It can run in single-threaded or multi-threaded mode depending on the number of threads (NBR_THREADS).
+    Results for each evaluation are saved to disk, and overall performance statistics are generated.
+
+    Parameters
+    ----------
+    clusters : list of dict
+        A list of dictionaries, where each dictionary represents a data cluster with 'X_cov', 'y', 'y_enc', 'subject', and 'cluster_name' keys.
+
+    clf : classifier object
+        A classifier object implementing 'fit' and 'predict' methods.
+
+    nbr_folds : int, optional
+        The number of cross-validation folds. Default is 5.
+
+    run_logdir : str, optional
+        The directory where plots and results will be saved. Default is the current directory ('.').
+
+    training_size : int, optional
+        The size of the training data used from the target cluster in each fold. Default is 8.
+
+    plotting_subresults : bool, optional
+        Whether to generate plots for the results of each cluster evaluation. Default is False.
+
+    data_folder : str, optional
+        The folder where the results will be saved as pickle files. Default is 'data'.
+
+    clf_name : str, optional
+        The name of the classifier, which can be used for labeling and saving plots. Default is None.
+
+    Returns
+    -------
+    average_intra_subject_performance : float
+        The average performance of the classifier when trained and tested on data from the same subject (intra-subject performance).
+
+    Notes
+    -----
+    - This function can either run in single-threaded or multi-threaded mode. If multi-threading is enabled, it uses the `worker_function` 
+      and multiprocessing Pool for parallel processing.
+    - Results for each target cluster, including transfer learning benefits and performance data, are saved as pickle files in the 
+      specified data folder from the subcalled functions.
+    - It generates several plots visualizing the transfer learning performance across all clusters, both in subject order and sorted by 
+      intra-subject performance.
+    - This function also saves metadata such as the number of clusters and cross-validation folds via the subcalled functions.
+    """
+    nbr_of_clusters = len(clusters)
+    list_of_cluster_names = [cluster['cluster_name'] for cluster in clusters]
+
+    # Create folder for data. 
+    if not os.path.exists('./{}/'.format(data_folder)):
+        print('creating data folder {}'.format(data_folder))
+        os.makedirs('./{}/'.format(data_folder))
+
+    # If multithreading or not. 
+    if NBR_THREADS == 1:
+        print('\n-------------------------------')
+        print(  '--> Single thread started < ---')
+        print(  '-------------------------------\n')
+        for target_cluster_idx in range(len(clusters)):
+            performance_for_one_cluster_vs_all_cluster(clusters, clf,target_cluster_idx,nbr_folds,run_logdir,training_size,plotting_subresults,data_folder)
+    else:
+        # Multithreaded version:
+        print('\n---------------------------------')
+        print(  '--> Multiprocessing started < ---')
+        print(  '---------------------------------\n')
+        
+        pool = Pool(processes=NBR_THREADS)
+        results = [pool.apply_async(worker_function, args=(target_cluster_idx, clusters, clf, nbr_folds,run_logdir,training_size,plotting_subresults,data_folder)) for target_cluster_idx in range(nbr_of_clusters)]
+        # for i, aresult in enumerate(results):
+        #     all_results[i, :],base_accuracy[i] = aresult.get()
+        print('---> Closing pools\n')
+        pool.close()
+        pool.join()
+        print('---> Pools closed \n')
+
+
+    # Placeholders
+    all_results_no_transfer_learning_target_data = np.zeros((nbr_of_clusters,nbr_of_clusters))
+    all_benefit_of_transfer_learning = np.zeros((nbr_of_clusters,nbr_of_clusters))
+    all_results_test_data = np.zeros((nbr_of_clusters,nbr_of_clusters))
+
+    # Load stored data from disk. 
+    data_folder_path = './{}/'.format(data_folder)
+
+    for i, subdir in enumerate(list_of_cluster_names):
+        print(subdir)
+        subdir_path = os.path.join(data_folder_path, subdir)
+        
+        # Check if the path is indeed a directory
+        if not os.path.isdir(subdir_path):
+            print('Not a folder: {}'.format(subdir))
+            continue # There might be a .DS_store. 
+        
+        
+        # == all_results_no_transfer_learning_target_data
+        pickle_file_path = os.path.join(subdir_path, 'all_results_no_transfer_learning_target_data.pickle')
+        
+        # Check if the pickle file exists
+        if not os.path.exists(pickle_file_path):
+            raise FileNotFoundError('Missing result file: {}'.format(pickle_file_path))
+        
+        with open(pickle_file_path, 'rb') as file:
+            data_here = pickle.load(file)
+        data_here = np.array(data_here)
+        data_here = np.mean(data_here,axis=1)
+        all_results_no_transfer_learning_target_data[i] = data_here
+
+
+        # == all_results_test_data
+        pickle_file_path = os.path.join(subdir_path, 'all_results_test_data.pickle')
+        
+        # Check if the pickle file exists
+        if not os.path.exists(pickle_file_path):
+            raise FileNotFoundError('Missing result file: {}'.format(pickle_file_path))
+        
+        with open(pickle_file_path, 'rb') as file:
+            data_here = pickle.load(file)
+        data_here = np.array(data_here)  
+
+        data_here = np.mean(data_here,axis=1)
+        all_results_test_data[i] = data_here
+
+
+        # == all_benefit_of_transfer_learning
+        pickle_file_path = os.path.join(subdir_path, 'all_benefit_of_transfer_learning.pickle')
+        
+        # Check if the pickle file exists
+        if not os.path.exists(pickle_file_path):
+            raise FileNotFoundError('Missing result file: {}'.format(pickle_file_path))
+        
+        with open(pickle_file_path, 'rb') as file:
+            data_here = pickle.load(file)
+        data_here = np.array(data_here)
+        data_here = np.mean(data_here,axis=1)
+        all_benefit_of_transfer_learning[i] = data_here
+
+    # Save some general metadata for the run.
+    with open('{}/nbr_of_clusters.pickle'.format(data_folder_path), 'wb') as file:
+        pickle.dump(nbr_of_clusters, file, protocol=pickle.HIGHEST_PROTOCOL)   
+    with open('{}/nbr_folds.pickle'.format(data_folder_path), 'wb') as file:
+        pickle.dump(nbr_folds, file, protocol=pickle.HIGHEST_PROTOCOL)
+    with open('{}/list_of_cluster_names.pickle'.format(data_folder_path), 'wb') as file:
+        pickle.dump(list_of_cluster_names, file, protocol=pickle.HIGHEST_PROTOCOL) 
+
+    # Verify that the data has been processed correctly.
+    print(all_results_test_data - all_results_no_transfer_learning_target_data - all_benefit_of_transfer_learning)
+    print('The above matrix should be all zero')
+
+    # Plot transfer learning data. Setup figures.
+    print('--> Saving plot_of_all_data')
+    figsize = (nbr_of_clusters*1.5,nbr_of_clusters*1.5)
+    if nbr_of_clusters > 20:
+        plt.rcParams.update({'font.size': 20})
+
+    # Plot matrix of transfer learning performance in subject order.
+    fig, ax = plt.subplots(figsize=figsize,layout='constrained')
+    plot_all_accuracy_matrix(all_results_test_data, list_of_cluster_names,ax,fig,np.mean(all_results_no_transfer_learning_target_data,axis=1))
+    save_string = 'all_accuracy_training_size_{}'.format(training_size) # Filename
+    plt.savefig(f'{run_logdir}/png_{save_string}.png')  # Save the figure in the specified directory
+    plt.savefig(f'{run_logdir}/pdf_{save_string}.pdf')  # Save the figure in the specified directory
+    plt.close(fig)
+
+    # Plot matrix of transfer learning performance sorted by intra-subject performance.
+    fig, ax = plt.subplots(figsize=figsize,layout='constrained')
+    plot_all_accuracy_matrix_sorted(all_results_test_data, list_of_cluster_names,ax,fig,np.mean(all_results_no_transfer_learning_target_data,axis=1))
+    save_string = 'all_accuracy_sorted_training_size_{}'.format(training_size) # Filename
+    plt.savefig(f'{run_logdir}/png_{save_string}.png')  # Save the figure in the specified directory
+    plt.savefig(f'{run_logdir}/pdf_{save_string}.pdf')  # Save the figure in the specified directory
+    plt.close(fig)
+
+    print('--> Saving complete')
+
+    average_intra_subject_performance = np.mean(all_results_no_transfer_learning_target_data)
+    return average_intra_subject_performance
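+# Example call (a sketch; clusters and clf as defined earlier in this script):
+# avg_acc = run_all_clusters(clusters, ALL_CLFS[0][1], nbr_folds=NBR_FOLDS,
+#                            run_logdir=DATA_FOLDER, training_size=ALL_TRAINING_SIZE[0],
+#                            plotting_subresults=PLOTTING_SUBRESULTS, data_folder=DATA_FOLDER)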
+
+
+# def plot_all_results_matrix_sorted(data_matrix, labels, ax, fig, base_accuracy=None):
+#     """
+#     Plots a sorted matrix of transfer learning results, with annotations and color-coded performance differences.
+
+#     This function creates a matrix plot that shows the performance (or benefits) of transfer learning for different source-target 
+#     cluster pairs. The rows and columns are sorted based on the base accuracy (intra-cluster performance) in descending order. 
+#     Cells in the matrix are color-coded using a diverging colormap, with performance values annotated within each cell.
+
+#     Parameters
+#     ----------
+#     data_matrix : ndarray of shape (n_clusters, n_clusters)
+#         A matrix containing transfer learning performance values for each source-target cluster pair.
+
+#     labels : list of str
+#         A list of labels for the clusters, used for the x and y axes.
+
+#     ax : matplotlib.axes.Axes
+#         The axes object where the matrix will be plotted.
+
+#     fig : matplotlib.figure.Figure
+#         The figure object to which the color bar will be added.
+
+#     base_accuracy : ndarray of shape (n_clusters,), optional
+#         The intra-cluster performance (e.g., accuracy) for each cluster, used for sorting the clusters. 
+#         If None, the matrix will not be sorted and all values will be treated as zero for the diagonal. 
+#         Default is None.
+
+#     Returns
+#     -------
+#     ax : matplotlib.axes.Axes
+#         The axes object with the plotted matrix.
+
+#     Notes
+#     -----
+#     - The clusters are sorted by their intra-cluster performance (base_accuracy) in descending order.
+#     - A diverging colormap is used to highlight positive and negative transfer learning benefits, with color intensity proportional 
+#       to the magnitude of the values.
+#     - The diagonal represents intra-cluster performance, highlighted with a different color and thicker borders.
+#     """
+#     if base_accuracy is None:
+#         base_accuracy = np.zeros(len(labels))
+
+#     sorted_indices = np.argsort(-base_accuracy)  # '-' for descending order
+#     base_accuracy = base_accuracy[sorted_indices]
+#     data_matrix = data_matrix[sorted_indices, :][:, sorted_indices]
+#     labels = np.array(labels)[sorted_indices]
+#     vmin = np.min(data_matrix)
+#     vmax = np.max(data_matrix)
+#     limit = max(abs(vmin),abs(vmax))
+#     cax = ax.matshow(data_matrix, cmap='RdBu',vmin=-limit,vmax=limit)
+
+#     plt.xticks(np.arange(len(labels)), labels, rotation=45)
+#     plt.yticks(np.arange(len(labels)), labels)
+
+
+#     threshold = limit*2/3
+#     # Loop over data dimensions and create text annotations.
+#     for i in range(len(labels)):
+#         for j in range(len(labels)):
+#             if i==j:
+#                 text = '{:.2f}'.format(base_accuracy[i])
+#                 highlight_cell(i,j,edgecolor="black", linewidth=3,ax=ax,fill=True,facecolor='silver')
+#             else:
+#                 text = '{:.2f}'.format(data_matrix[i, j])
+#             color = "white" if (data_matrix[i, j] > threshold) or (data_matrix[i, j] < -threshold) else "black"
+#             ax.text(j, i, text, ha="center", va="center", color=color)
+
+#     ax.grid(False)
+#     plt.xlabel('Source cluster')
+#     plt.ylabel('Target cluster')
+#     plt.title('Transfer learning benefits')
+#     fig.colorbar(cax)
+
+#     return ax
+
+
+
+# def plot_all_results_matrix(data_matrix, labels, ax, fig, base_accuracy=None):
+#     """
+#     Plots a matrix of transfer learning results, with annotations and color-coded performance differences.
+
+#     This function creates a matrix plot that shows the performance (or benefits) of transfer learning for different source-target 
+#     cluster pairs. Cells in the matrix are color-coded using a diverging colormap, with performance values annotated within each cell.
+#     The diagonal represents intra-cluster performance and is highlighted with a different color and thicker borders.
+
+#     Parameters
+#     ----------
+#     data_matrix : ndarray of shape (n_clusters, n_clusters)
+#         A matrix containing transfer learning performance values for each source-target cluster pair.
+
+#     labels : list of str
+#         A list of labels for the clusters, used for the x and y axes.
+
+#     ax : matplotlib.axes.Axes
+#         The axes object where the matrix will be plotted.
+
+#     fig : matplotlib.figure.Figure
+#         The figure object to which the color bar will be added.
+
+#     base_accuracy : ndarray of shape (n_clusters,), optional
+#         The intra-cluster performance (e.g., accuracy) for each cluster, used for annotating the diagonal of the matrix.
+#         If None, the diagonal values will be treated as zero. Default is None.
+
+#     Returns
+#     -------
+#     ax : matplotlib.axes.Axes
+#         The axes object with the plotted matrix.
+
+#     Notes
+#     -----
+#     - A diverging colormap is used to highlight positive and negative transfer learning benefits, with color intensity 
+#       proportional to the magnitude of the values.
+#     - The diagonal represents intra-cluster performance and is highlighted with a different color and thicker borders.
+#     - The matrix cells are annotated with the performance values, and the text color is adjusted for readability based on 
+#       the cell's value.
+#     """
+#     if base_accuracy is None:
+#         base_accuracy = np.zeros(len(labels))
+#     vmin = np.min(data_matrix)
+#     vmax = np.max(data_matrix)
+#     limit = max(abs(vmin),abs(vmax))
+#     cax = ax.matshow(data_matrix, cmap='RdBu',vmin=-limit,vmax=limit)
+
+#     plt.xticks(np.arange(len(labels)), labels, rotation=45)
+#     plt.yticks(np.arange(len(labels)), labels)
+
+
+#     threshold = limit*2/3
+#     # Loop over data dimensions and create text annotations.
+#     for i in range(len(labels)):
+#         for j in range(len(labels)):
+#             if i==j:
+#                 text = '{:.2f}'.format(base_accuracy[i])
+#                 highlight_cell(i,j,edgecolor="black", linewidth=3,ax=ax,fill=True,facecolor='silver')
+#             else:
+#                 text = '{:.2f}'.format(data_matrix[i, j])
+#             color = "white" if (data_matrix[i, j] > threshold) or (data_matrix[i, j] < -threshold) else "black"
+#             ax.text(j, i, text, ha="center", va="center", color=color)
+
+#     ax.grid(False)
+#     plt.xlabel('Source cluster')
+#     plt.ylabel('Target cluster')
+#     plt.title('Transfer learning benefits')
+#     fig.colorbar(cax)
+
+#     return ax
+
+
+def plot_all_accuracy_matrix(data_matrix, labels, ax, fig, diagonal_values=None):
+    """
+    Plots a matrix of transfer learning accuracy, with annotations and color-coded values.
+
+    This function creates a matrix plot that shows the accuracy of transfer learning for different source-target cluster pairs.
+    The diagonal values represent intra-cluster accuracy and are highlighted with a different border. Each cell in the matrix is 
+    color-coded based on accuracy, and the values are annotated within the cells for better clarity.
+
+    Parameters
+    ----------
+    data_matrix : ndarray of shape (n_clusters, n_clusters)
+        A matrix containing accuracy values for each source-target cluster pair.
+
+    labels : list of str
+        A list of labels for the clusters, used for the x and y axes.
+
+    ax : matplotlib.axes.Axes
+        The axes object where the matrix will be plotted.
+
+    fig : matplotlib.figure.Figure
+        The figure object to which the color bar will be added.
+
+    diagonal_values : ndarray of shape (n_clusters,), optional
+        The values to display on the diagonal of the matrix, representing intra-cluster accuracy. If None, the diagonal values 
+        will be extracted from the data matrix. Default is None.
+
+    Returns
+    -------
+    ax : matplotlib.axes.Axes
+        The axes object with the plotted matrix.
+
+    Notes
+    -----
+    - The diagonal values (intra-cluster accuracy) are highlighted with a thicker border for emphasis.
+    - A diverging colormap is used to visualize accuracy values, with color intensity proportional to the magnitude of the accuracy.
+    - The matrix cells are annotated with the accuracy values, and the text color is adjusted for readability based on the cell's value.
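+
+    Examples
+    --------
+    A minimal sketch with random accuracies and hypothetical cluster labels:
+
+    >>> import numpy as np
+    >>> import matplotlib.pyplot as plt
+    >>> acc = np.random.default_rng(0).uniform(0.5, 1.0, size=(3, 3))
+    >>> fig, ax = plt.subplots()
+    >>> ax = plot_all_accuracy_matrix(acc, ['c0', 'c1', 'c2'], ax, fig)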
+    """
+    if diagonal_values is None:
+        diagonal_values = np.diag(data_matrix)
+    vmin = 0.5
+    vmax = 1
+
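+    # Note: np.fill_diagonal writes diagonal_values into the caller's array in place.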
+    np.fill_diagonal(data_matrix,diagonal_values)
+    
+    cax = ax.matshow(data_matrix, cmap='RdBu',vmin=vmin,vmax=vmax)
+
+    plt.xticks(np.arange(len(labels)), labels, rotation=45)
+    plt.yticks(np.arange(len(labels)), labels)
+
+    threshold = 0.25*2/3  # Distance from the colormap midpoint (0.75) beyond which cell text turns white.
+    # Loop over data dimensions and create text annotations.
+    for i in range(len(labels)):
+        for j in range(len(labels)):
+            if i==j:
+                highlight_cell(i,j,edgecolor="black", linewidth=4,ax=ax,fill=False)
+            text = '{:.2f}'.format(data_matrix[i, j])
+            color = "white" if (data_matrix[i, j] > 0.75+threshold) or (data_matrix[i, j] < 0.75-threshold) else "black"
+            ax.text(j, i, text, ha="center", va="center", color=color)
+
+    ax.grid(False)
+    plt.xlabel('Source cluster')
+    plt.ylabel('Target cluster')
+    plt.title('Transfer learning accuracies')
+    fig.colorbar(cax)
+
+    return ax
+
+
+
+def plot_all_accuracy_matrix_sorted(data_matrix, labels, ax, fig, diagonal_values=None):
+    """
+    Plots a sorted matrix of transfer learning accuracy, with annotations and color-coded values.
+
+    This function creates a matrix plot showing the accuracy of transfer learning for different source-target cluster pairs.
+    The rows and columns are sorted based on the diagonal values (intra-cluster accuracy) in descending order. Each cell in the matrix 
+    is color-coded based on accuracy, and the values are annotated within the cells for clarity.
+
+    Parameters
+    ----------
+    data_matrix : ndarray of shape (n_clusters, n_clusters)
+        A matrix containing accuracy values for each source-target cluster pair.
+
+    labels : list of str
+        A list of labels for the clusters, used for the x and y axes.
+
+    ax : matplotlib.axes.Axes
+        The axes object where the matrix will be plotted.
+
+    fig : matplotlib.figure.Figure
+        The figure object to which the color bar will be added.
+
+    diagonal_values : ndarray of shape (n_clusters,), optional
+        The values to display on the diagonal of the matrix, representing intra-cluster accuracy. If None, the diagonal values 
+        will be extracted from the data matrix. The matrix will be sorted based on these values. Default is None.
+
+    Returns
+    -------
+    ax : matplotlib.axes.Axes
+        The axes object with the plotted matrix.
+
+    Notes
+    -----
+    - The diagonal values (intra-cluster accuracy) are used to sort the matrix in descending order.
+    - A diverging colormap is used to visualize accuracy values, with color intensity proportional to the magnitude of the accuracy.
+    - The matrix cells are annotated with the accuracy values, and the text color is adjusted for readability based on the cell's value.
+    - The diagonal values are highlighted with a thicker border for emphasis.
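+
+    Examples
+    --------
+    A minimal sketch; rows and columns come back sorted by the diagonal:
+
+    >>> import numpy as np
+    >>> import matplotlib.pyplot as plt
+    >>> acc = np.random.default_rng(0).uniform(0.5, 1.0, size=(3, 3))
+    >>> fig, ax = plt.subplots()
+    >>> ax = plot_all_accuracy_matrix_sorted(acc, ['c0', 'c1', 'c2'], ax, fig)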
+    """
+    if diagonal_values is None:
+        diagonal_values = np.diag(data_matrix)
+
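+    # Note: np.fill_diagonal modifies data_matrix in place before sorting.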
+    np.fill_diagonal(data_matrix,diagonal_values)
+
+    # Sort matrix and data.
+    sorted_indices = np.argsort(-diagonal_values)  # '-' for descending order
+    diagonal_values = diagonal_values[sorted_indices]
+    data_matrix = data_matrix[sorted_indices, :][:, sorted_indices]
+    labels = np.array(labels)[sorted_indices]
+
+    # Limits for plotting colors. 
+    vmin = 0.5
+    vmax = 1
+    
+    # Plot matrix.
+    cax = ax.matshow(data_matrix, cmap='RdBu',vmin=vmin,vmax=vmax)
+
+    # Legends
+    plt.xticks(np.arange(len(labels)), labels, rotation=45)
+    plt.yticks(np.arange(len(labels)), labels)
+
+    threshold = 0.25*1/3  # Distance from the colormap midpoint (0.75) beyond which cell text turns white.
+    # Loop over data dimensions and create text annotations.
+    for i in range(len(labels)):
+        for j in range(len(labels)):
+            if i==j:
+                highlight_cell(i,j,edgecolor="black", linewidth=4,ax=ax,fill=False)
+            text = '{:.2f}'.format(data_matrix[i, j])
+            color = "white" if (data_matrix[i, j] > 0.75+threshold) or (data_matrix[i, j] < 0.75-threshold) else "black"
+            ax.text(j, i, text, ha="center", va="center", color=color)
+
+    ax.grid(False)
+    plt.xlabel('Source cluster')
+    plt.ylabel('Target cluster')
+    plt.title('Transfer learning accuracies')
+    fig.colorbar(cax)
+
+    return ax
+
+
+def highlight_cell(row, col, ax=None, fill=False, **kwargs):
+    """
+    Highlights a specific cell in a matrix plot by drawing a rectangle around it.
+
+    This function adds a rectangular patch to highlight a specific cell in a plot (typically a heatmap or matrix plot).
+    The cell is defined by its row and column indices, and the rectangle can optionally be filled with a color.
+
+    Parameters
+    ----------
+    row : int
+        The row index of the cell to highlight.
+
+    col : int
+        The column index of the cell to highlight.
+
+    ax : matplotlib.axes.Axes, optional
+        The axes object on which to draw the rectangle. If None, the current axes will be used. Default is None.
+
+    fill : bool, optional
+        Whether to fill the rectangle with color. Default is False.
+
+    **kwargs : dict, optional
+        Additional keyword arguments to pass to `matplotlib.patches.Rectangle` (e.g., edgecolor, facecolor, linewidth).
+
+    Returns
+    -------
+    rect : matplotlib.patches.Rectangle
+        The rectangle patch added to the plot.
+
+    Notes
+    -----
+    This function is typically used to highlight specific cells in heatmaps or matrix plots, where each cell represents 
+    a data point. The rectangle can be customized using the `**kwargs` to control the appearance of the highlight.
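+
+    Examples
+    --------
+    Outline the top-left cell of a small matrix plot:
+
+    >>> import numpy as np
+    >>> import matplotlib.pyplot as plt
+    >>> fig, ax = plt.subplots()
+    >>> _ = ax.matshow(np.eye(3))
+    >>> rect = highlight_cell(0, 0, ax=ax, edgecolor='red', linewidth=2)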
+    """
+    # Matrix indices map to plot coordinates: column -> x, row -> y.
+    x, y = col, row
+    rect = plt.Rectangle((x - .5, y - .5), 1, 1, fill=fill, **kwargs)
+    ax = ax or plt.gca()
+    ax.add_patch(rect)
+    return rect
+
+
+# -----------------------------------------------
+# Main 
+# -----------------------------------------------
+# execute only if run as a script
+if __name__ == "__main__":
+    begin_time = datetime.datetime.now()
+    log_folder = "my_logs"
+    if not os.path.exists(log_folder):
+        os.makedirs(log_folder)
+    # Create run directory to store models and logs in:
+    root_logdir = os.path.join(os.curdir, log_folder)
+    run_id = "%s_on_%s" % (time.strftime("run_%Y_%m_%d-%H_%M_%S"), os.uname()[1])
+    run_logdir = os.path.join(root_logdir, run_id)
+    output_folder = run_logdir
+    os.mkdir(run_logdir)
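+    # Redirect stdout/stderr to log files in the run directory (via the Logger class defined earlier in this script).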
+    sys.stderr = Logger("{}/console_err.txt".format(run_logdir))
+    sys.stdout = Logger("{}/console_log.txt".format(run_logdir))
+    print(
+        "\n### data_generation.py was started on %s in %s with logs in %s"
+        % (os.uname()[1], os.getcwd(), run_logdir)
+    )
+    print(sys.argv)
+    print_git_version()
+    # print_source_code()
+    save_source_code(run_logdir)
+    now = begin_time
+    now_string = str(now).replace(" ", "_").replace(":","_")
+
+    plt.rcParams.update({'font.size': 15})
+    # Save a file summarizing the parameters used in this run.
+    with open('{}/description.txt'.format(run_logdir), 'a') as f:
+        f.write("Parameter summary for this run.\n")
+
+        f.write('NBR_THREADS: {}\n'.format(NBR_THREADS))
+        f.write('\nNBR_FOLDS: {}\n'.format(NBR_FOLDS))
+        f.write('\nDATA_FOLDER: {}\n'.format(DATA_FOLDER))
+        f.write('\nDATA_SET: {}\n'.format(DATA_SET))
+        if DATA_SET == 'PhysioNet':
+            f.write('CLASSES_PHYSIONET: {}\n'.format(CLASSES_PHYSIONET))
+            f.write('Nbr of classes: 2\n')
+
+
+        f.write('nbr_samples_per_class: {}\n'.format(nbr_samples_per_class))
+        f.write('\nSUBJECTS_TO_USE: ')
+        f.write(', '.join(str(i) for i in SUBJECTS_TO_USE))
+        f.write('\n\nClassifiers:\n')
+        for name,_ in ALL_CLFS:
+            f.write('{}\n'.format(name))
+        f.write('\n\nPLOTTING_SUBRESULTS: {}\n'.format(PLOTTING_SUBRESULTS))
+
+
+    # -------------------------------------------------------------------------------
+
+    # -----------------------------------------------
+    # Load data
+    # -----------------------------------------------
+
+    clusters = load_data_physio_net(SUBJECTS_TO_USE,CLASSES_PHYSIONET)
+
+    # -----------------------------------------------
+    # Run the experiments: sweep over training sizes and classifiers.
+    # -----------------------------------------------
+    all_results = [] # Summary strings, one per (training size, classifier) combination.
+    for training_size in ALL_TRAINING_SIZE:
+        for name,clf in ALL_CLFS:
+            print('=============================')
+            print('===== Training size: {} ====='.format(training_size))
+            print('===== Classifier: {} ====='.format(name))
+            print('=============================')
+
+            data_folder = DATA_FOLDER
+            results = run_all_clusters(clusters,clf,NBR_FOLDS,run_logdir,training_size,PLOTTING_SUBRESULTS,data_folder,name)
+            all_results.append('Average intra-subject performance {:4.4f} clf: {}'.format(results,name))
+
+    for res in all_results:
+        print(res)
+
+    plt.close('all')
+    print("\n...done!")
+    print("### Total run time: %s" % (datetime.datetime.now() - begin_time))
-- 
GitLab