diff --git a/connectivty_group.py b/connectivty_group.py new file mode 100644 index 0000000..9e2fd07 --- /dev/null +++ b/connectivty_group.py @@ -0,0 +1,398 @@ +import os +import numpy as np +import matplotlib.pyplot as plt +import json +import pandas as pd +import hashlib +import csv +import mne +import scipy.io + +from statsmodels.stats.multitest import multipletests +from scipy.stats import mannwhitneyu, shapiro, hmean, pearsonr, spearmanr + +from functions import compute_distances, plot_statistical_test_results, \ + get_biosemi_to_1020_mapping, compute_connectivity_matrix + +param_file = "npm_8-13Hz_0.7s_R.json" +#param_file = "npm_8-13Hz_0.7s_P.json" + +#param_file = "npm_14-30Hz_0.7s_R.json" +#param_file = "npm_14-30Hz_0.7s_P.json" + +with open(param_file) as json_file: + param_dict = json.load(json_file) + +subjects = np.array(sorted(param_dict['subjects'])).astype(str) + +n_of_plots_in_line = 3 +show_electrodes = 0 +do_rename_channels = True + +results_base_dir = param_dict['result_dir'] +data_base_dir = param_dict['datapath'] +reference_raw_file = data_base_dir + "000_P1.vhdr" +conditions = param_dict['conditions'] +methods = param_dict['methods'] +franges = param_dict['frequency_band'] +tmin_tmax = param_dict['tmin_tmax'] +csd_params = param_dict['CSD_parameters'] +add_label = param_dict['add_label'] + +min_edge_pvals = [] + +epochs_tmin = -2 +epochs_tmax = 4 + +reference_raw = mne.io.read_raw_brainvision(reference_raw_file, preload=False) +if do_rename_channels: + names_dict = {} + ch_ind = 0 + for prefix in ['A', 'B', 'C', 'D']: + for i in range(0, 32): + name = prefix + str(i + 1) + if name != reference_raw.info['ch_names'][ch_ind]: + names_dict.update({reference_raw.info['ch_names'][ch_ind]: name}) + ch_ind += 1 + reference_raw.rename_channels(names_dict) + + w = csv.writer(open(results_base_dir + "electrodes_dict.csv", "w")) + for key, val in names_dict.items(): + w.writerow([key, val]) + + types_dict = {} + for name in reference_raw.info['ch_names']: + if name[0] != 'E': + types_dict.update({name: 'eeg'}) + else: + types_dict.update({name: 'eog'}) + reference_raw.set_channel_types(types_dict) +picks = mne.pick_types(reference_raw.info, eeg=True) +n_all_el = picks.shape[0] +my_montage = mne.channels.read_custom_montage("biosemi128.sfp") +reference_raw.set_montage(my_montage) +print("Reference raw loaded") + +for condition in conditions: + for f_range in franges: + f_range_low = f_range[0] + f_range_high = f_range[1] + for tmin_tmax_range in tmin_tmax: + if tmin_tmax_range is None: + conn_tmin = -0.2 + conn_tmax = 0.5 + else: + conn_tmin = tmin_tmax_range[0] + conn_tmax = tmin_tmax_range[1] + for csd_param in csd_params: + if csd_param is None: + do_csd = False + else: + do_csd = True + stiffnes = csd_param[0] + lambd = csd_param[1] + for method in methods: + connectivity_full = np.empty(shape=(n_all_el, n_all_el, len(subjects))) + labels_true = [] + subj_pos = 0 + n_controls = 0 + n_patients = 0 + suffix_base = "" + suffix_base += "_" + str(f_range_low) + "-" + str(f_range_high) + "_" + method + "_T_" + \ + str(conn_tmin) + "_" + str(conn_tmax) + "s" + if do_csd: + suffix_base += "_csd" + str(stiffnes) + "_" + str(lambd) + else: + suffix_base += "_nocsd" + if condition == "proposer": + suffix_base += "P" + elif condition == "responder": + suffix_base += "R" + + for subj in subjects: + if subj[0] == '0': + n_controls += 1 + labels_true.append(1) + if condition == "responder": + file_name = subj + "_R1" + elif condition == "proposer": + file_name = subj + "_P1" + + else: + raise ValueError("Unrecognised condition" + condition) + + elif subj[0] == '1': + n_patients += 1 + labels_true.append(2) + if condition == "responder": + file_name = subj + "_R1" + elif condition == "proposer": + file_name = subj + "_P1" + else: + raise ValueError("Unrecognised condition" + condition) + + ############################# Loading/computing connectivity for each subject ######################### + if os.path.exists(results_base_dir + "conn_matrices/" + 'con' + file_name + suffix_base + '.npy'): + print( + "Reading precomputed connectivity matrix:\t" + results_base_dir + "conn_matrices/" + 'con' + file_name + suffix_base + '.npy') + c_matrix = np.load(results_base_dir + "conn_matrices/" + 'con' + file_name + suffix_base + '.npy') + + else: + c_matrix = compute_connectivity_matrix(f_range, tmin_tmax_range, csd_param, method, + condition, subj, epochs_tmin, epochs_tmax, + results_base_dir + "conn_matrices/", data_base_dir) + connectivity_full[:, :, subj_pos] = c_matrix + n_channels = c_matrix.shape[0] + subj_pos += 1 + + labels_true = np.array(labels_true).astype(int) + + el_set_index = -1 + for selected_channels_param in param_dict['electrode_subset']: + el_set_index += 1 + if type(selected_channels_param) is list: + selected_channels = selected_channels_param + elif type(selected_channels_param) is str: + if selected_channels_param == 'All': + selected_channels = ['A1', 'A2', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9', 'A10', 'A11', + 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19', 'A20', 'A21', + 'A22', 'A23', 'A24', 'A25', 'A26', 'A27', 'A28', 'A29', 'A30', 'A31', + 'A32', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', + 'B11', 'B12', 'B13', 'B14', 'B15', 'B16', 'B17', 'B18', 'B19', 'B20', + 'B21', 'B22', 'B23', 'B24', 'B25', 'B26', 'B27', 'B28', 'B29', 'B30', + 'B31', 'B32', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', + 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', + 'C20', 'C21', 'C22', 'C23', 'C24', 'C25', 'C26', 'C27', 'C28', 'C29', + 'C30', 'C31', 'C32', 'D1', 'D2', 'D3', 'D4', 'D5', 'D6', 'D7', 'D8', + 'D9', 'D10', 'D11', 'D12', 'D13', 'D14', 'D15', 'D16', 'D17', 'D18', + 'D19', 'D20', 'D21', 'D22', 'D23', 'D24', 'D25', 'D26', 'D27', 'D28', + 'D29', 'D30', 'D31', 'D32'] + elif selected_channels_param == '32el10-20': + electrode_mapping = get_biosemi_to_1020_mapping(32) + selected_channels = list(electrode_mapping.keys()) + + elif selected_channels_param == '64el10-20': + electrode_mapping = get_biosemi_to_1020_mapping(64) + selected_channels = list(electrode_mapping.keys()) + else: + raise ValueError("Wrong channel selection parameter: should be a list of channel labels or 'All'") + else: + raise ValueError("Wrong channel selection parameter: should be a list of channel labels or 'All'") + + if selected_channels_param == '32el10-20' or selected_channels_param == '64el10-20': + el_labels = [electrode_mapping[key] for key in selected_channels] + else: + el_labels = selected_channels + + suffix = suffix_base + "_" + str(len(selected_channels)) + "el_" + text_el = "".join(selected_channels) + suffix += hashlib.sha256(text_el.encode()).hexdigest()[:8] + suffix += add_label + + electrode_dist_m = compute_distances(selected_channels, ref_data_file=reference_raw_file, + vis=show_electrodes) + + n_channels = len(selected_channels) + selected_ch_ind = [mne.pick_channels(reference_raw.info['ch_names'], include=[ch])[0] for ch in selected_channels] + connectivity_subset = connectivity_full[selected_ch_ind, :, :][:, selected_ch_ind, :] + + for i in range(connectivity_subset.shape[2]): + connectivity_subset[:, :, i][np.triu_indices_from(connectivity_subset[:, :, i], 1)] = 0 + + xyz = np.empty((n_channels, 3)) + ch_names = np.empty(n_channels) + for i, rawind in enumerate(selected_ch_ind): + xyz[i, :] = reference_raw.info['chs'][rawind]['loc'][0:3] + xyz[:, 1] += -0.017 + + connectivity_dict = {'conn': connectivity_subset, 'subjects': subjects, + 'xyz': xyz, 'el_labels': el_labels, + 'design_matrix': np.vstack([labels_true - 1, -1 * (labels_true - 2)])} + scipy.io.savemat(results_base_dir + 'con' + suffix_base + add_label + '.mat', connectivity_dict) + + #Non parametric test, gaussianity test + ustats = np.empty_like(connectivity_subset[:, :, 0]) + ustats[:] = np.nan + upvalue = np.empty_like(connectivity_subset[:, :, 0]) + upvalue[:] = np.nan + isgaussian_stats = np.empty_like(connectivity_subset[:, :, 0]) + isgaussian_stats[:] = np.nan + isgaussian_pvalue = np.empty_like(connectivity_subset[:, :, 0]) + isgaussian_pvalue[:] = np.nan + flat_mwupvalues = np.empty(shape=(int(n_channels * (n_channels - 1) / 2))) + + ind = 0 + for i in range(0, n_channels): + for j in range(i + 1, n_channels): + ustats[j, i], upvalue[j, i] = mannwhitneyu(connectivity_subset[j, i, :n_controls], connectivity_subset[j, i, n_controls:]) + isgaussian_stats[j, i], isgaussian_pvalue[j, i] = shapiro(connectivity_subset[j, i, :n_controls]) + flat_mwupvalues[ind] = upvalue[j, i] + ind += 1 + + #Multiple compatison correction can only be done on a vector + h_true, corrected_mwupvalues_flat, als, alb = multipletests(flat_mwupvalues, method='fdr_bh') + + corrected_mwupvalue = np.empty_like(upvalue) + corrected_mwupvalue[:] = np.nan + ind = 0 + for i in range(0, n_channels): + for j in range(i + 1, n_channels): + corrected_mwupvalue[j, i] = corrected_mwupvalues_flat[ind] + ind += 1 + + df_mwustats = pd.DataFrame(ustats, columns=el_labels, index=el_labels) + df_mwustats.to_csv(results_base_dir + "mwustats" + suffix + '.csv', index=False, header=True, float_format='%.5f') + df_mwupval = pd.DataFrame(upvalue, columns=el_labels, index=el_labels) + df_mwupval.to_csv(results_base_dir + "mwupval" + suffix + '.csv', index=False, header=True, float_format='%.5f') + df_corr_mwupval = pd.DataFrame(corrected_mwupvalue, columns=el_labels, index=el_labels) + df_corr_mwupval.to_csv(results_base_dir + "mwupval_corrected_" + suffix + '.csv', index=False, + header=True, float_format='%.5f') + + min_edge_pvals.append([len(selected_channels), suffix[-8:], method, f_range[0], f_range[1], + conn_tmin, conn_tmax, stiffnes, lambd, -1 * np.log10(np.nanmin(corrected_mwupvalue))]) + + mean_pval = np.zeros(upvalue.shape[0]) + symm_upvalue = np.copy(upvalue) + symm_upvalue = np.nan_to_num(symm_upvalue) + symm_upvalue += symm_upvalue.T + for i in range(corrected_mwupvalue.shape[0]): + if i == 0: + mean_pval[i] = hmean(symm_upvalue[i, i + 1:]) + elif 0 < i < symm_upvalue.shape[0] - 1: + mean_pval[i] = hmean(np.hstack([symm_upvalue[i, 0:i], symm_upvalue[i, i + 1:]])) + else: + mean_pval[i] = hmean(symm_upvalue[i, 0:i]) + + df_mean_mwupval = pd.DataFrame(mean_pval, index=selected_channels) + df_mean_mwupval.to_csv(results_base_dir + "mean_u_pvalues" + suffix + '.csv', + index=selected_channels, header=["mean_pval"], float_format='%.5f') + + plot_statistical_test_results(isgaussian_stats, isgaussian_pvalue, labels=el_labels, + writepath=results_base_dir + "gaussianity_test" + suffix + ".png", + title='Gaussianity test, controls, ' + method) + + plot_statistical_test_results(ustats, upvalue, labels=el_labels, + writepath=results_base_dir + "mannwhitney" + suffix + ".png", + title='Mann-Whitney rank test, ' + method, left_title="Mann-Whitney U statistic") + + plot_statistical_test_results(ustats, corrected_mwupvalue, labels=el_labels, + writepath=results_base_dir + "mannwhitney_corr" + suffix + ".png", + title='Mann-Whitney rank test, corrected, ' + method, left_title="Mann-Whitney U statistic") + + plot_statistical_test_results(ustats, -1*np.log10(corrected_mwupvalue), labels=el_labels, + writepath=results_base_dir + "mannwhitney_corr_log" + suffix + ".png", + title='-log Mann-Whitney rank test, corrected, ' + method, + left_title="Mann-Whitney U statistic") + + + results_dir = results_base_dir + suffix[1:] + "/" + if not os.path.exists(results_dir): + os.mkdir(results_dir) + print("Results saved at: " + results_dir) + + el_dist_flat = electrode_dist_m[np.tril_indices(connectivity_subset.shape[0], -1)] + + ############## plotting distance matrix ############################ + fig = plt.figure() + im = plt.imshow(electrode_dist_m, cmap='hot') + step = max(1, int(np.log2(len(selected_channels)+1) / 2)) + plt.xticks(np.arange(len(selected_channels), step=step), selected_channels[::step], rotation=270) + plt.yticks(np.arange(len(selected_channels), step=step), selected_channels[::step]) + plt.colorbar(im, fraction=0.046, pad=0.04) + plt.savefig(results_dir + "distance_matrix" + suffix + ".png", dpi=600) + plt.close(fig) + + result_flat = np.zeros(shape=(len(subjects), len(np.tril_indices(connectivity_subset.shape[0], -1)[0]))) + for i in range(0, len(subjects)): + result_flat[i, :] = connectivity_subset[:, :, i][ + np.tril_indices(connectivity_subset.shape[0], -1)] + + for i in range(connectivity_subset.shape[2]): + connectivity_subset[:, :, i][np.triu_indices_from(connectivity_subset[:, :, i], 1)] = np.nan + + for p_number in range(0, int(max(n_patients, n_controls) / n_of_plots_in_line)): + + fig = plt.figure('Connectivity ' + str(p_number) + ', ' + method, figsize=[12, 6]) + for i in range(0, n_of_plots_in_line): + ax = plt.subplot(2, n_of_plots_in_line, i + 1) + if p_number * n_of_plots_in_line + i < n_controls: + plt.title(subjects[p_number * n_of_plots_in_line + i]) + im = plt.imshow(connectivity_subset[:, :, p_number * n_of_plots_in_line + i], cmap='twilight_shifted') + plt.xticks(np.arange(len(el_labels), step=step), el_labels[::step], rotation=270) + plt.yticks(np.arange(len(el_labels), step=step), el_labels[::step]) + + for i in range(0, n_of_plots_in_line): + ax = plt.subplot(2, n_of_plots_in_line, i + 1 + n_of_plots_in_line) + if p_number * n_of_plots_in_line + i < n_patients: + plt.title(subjects[n_controls + p_number * n_of_plots_in_line + i]) + im = plt.imshow(connectivity_subset[:, :, n_controls + p_number * n_of_plots_in_line + i], cmap='twilight_shifted') + plt.xticks(np.arange(len(el_labels), step=step), el_labels[::step], rotation=270) + plt.yticks(np.arange(len(el_labels), step=step), el_labels[::step]) + plt.subplots_adjust(bottom=0.1, right=0.8, top=0.9) + cax = plt.axes([0.85, 0.12, 0.03, 0.75]) + cbar = plt.colorbar(cax=cax) + cbar.ax.set_ylabel('Connectivity strength') + plt.subplots_adjust(top=0.935, bottom=0.075, left=0.04, right=0.835, hspace=0.3, wspace=0.254) + plt.savefig(results_dir + "connectiviy" + str(p_number) + suffix + ".png", dpi=600) + plt.close(fig) + + fig = plt.figure( + 'Connectivity (y-axis) vs distance (x-axis) correlation ' + str(p_number) + ', ' + method, figsize=[12, 6]) + for i in range(0, n_of_plots_in_line): + ax = plt.subplot(2, n_of_plots_in_line, i + 1) + if p_number * n_of_plots_in_line + i < n_controls: + corr, pcorr = spearmanr(100*el_dist_flat, result_flat[p_number * n_of_plots_in_line + i, :]) + plt.title(subjects[p_number * n_of_plots_in_line + i] + "\ncorr=" + + format(corr, '.2f') + " p=" + format(pcorr, '.2f')) + plt.scatter(100*el_dist_flat, result_flat[p_number * n_of_plots_in_line + i, :]) + #plt.xlabel("Distance (cm)") + plt.ylabel("ImCoh") + axes = plt.gca() + + for i in range(0, n_of_plots_in_line): + ax = plt.subplot(2, n_of_plots_in_line, i + 1 + n_of_plots_in_line) + if p_number * n_of_plots_in_line + i < n_patients: + corr, pcorr = pearsonr(100 * el_dist_flat, + result_flat[p_number * n_of_plots_in_line + i, :]) + plt.title(subjects[n_controls+p_number * n_of_plots_in_line + i] + + "\ncorr=" + format(corr, '.2f') + ", p=" + format(pcorr, '.2f')) + plt.scatter(100*el_dist_flat, result_flat[n_controls + p_number * n_of_plots_in_line + i, :]) + plt.xlabel("Distance (cm)") + plt.ylabel("ImCoh") + axes = plt.gca() + plt.tight_layout(pad=1) + plt.savefig(results_dir + "conn_vs_dist" + str(p_number) + suffix + ".png", dpi=600) + plt.close(fig) + + for p_number in range(0, int(max(n_patients, n_controls) / n_of_plots_in_line)): + fig = plt.figure('Weight distribution, ' + str(p_number) + ', ' + method, figsize=[12, 6]) + + for i in range(0, n_of_plots_in_line): + ax = plt.subplot(2, n_of_plots_in_line, i + 1) + if p_number * n_of_plots_in_line + i < n_controls: + plt.title(subjects[p_number * n_of_plots_in_line + i]) + im = plt.hist(connectivity_subset[:, :, p_number * n_of_plots_in_line + i][np.where( + connectivity_subset[:, :, p_number * n_of_plots_in_line + i] != 0)].flatten()) + axes = plt.gca() + axes.set_ylim([0, 100]) + axes.set_xlim([np.nanmin(connectivity_subset), np.nanmax(connectivity_subset)]) + + for i in range(0, n_of_plots_in_line): + ax = plt.subplot(2, n_of_plots_in_line, i + 1 + n_of_plots_in_line) + if p_number * n_of_plots_in_line + i < n_patients: + plt.title(subjects[n_controls + p_number * n_of_plots_in_line + i]) + im = plt.hist(connectivity_subset[:, :, n_controls + p_number * n_of_plots_in_line + i][np.where( + connectivity_subset[:, :, n_controls + p_number * n_of_plots_in_line + i] != 0)].flatten()) + axes = plt.gca() + axes.set_ylim([0, 100]) + axes.set_xlim([np.nanmin(connectivity_subset), np.nanmax(connectivity_subset)]) + + plt.tight_layout(pad=1) + plt.savefig(results_dir + "weight_distr" + str(p_number) + suffix + ".png", dpi=600) + plt.close(fig) + +min_edge_pvals_df = pd.DataFrame(min_edge_pvals, columns=['n-el', 'el-hash', 'method', 'freq-min', 'freq-max', + "tmin", "tmax", "csd-st", "csd-lam", '-log10(min-pval)']) +np.savetxt(results_base_dir + param_file[:-5] + 'min_edge_pvals.csv', min_edge_pvals_df.values, + fmt=['%d', '%s', '%s', '%d', '%d', '%.2f', '%.2f', '%.2f', '%.2e', '%.2f'], delimiter=', ', + header=''+','.join(min_edge_pvals_df.columns), comments='') \ No newline at end of file diff --git a/functions.py b/functions.py new file mode 100644 index 0000000..23b1965 --- /dev/null +++ b/functions.py @@ -0,0 +1,453 @@ +import mne as mne +import numpy as np +from statsmodels.stats.multitest import multipletests +from scipy.stats import ttest_ind + +import matplotlib.pyplot as plt +from itertools import combinations +from scipy.sparse import csr_matrix +from scipy.sparse.csgraph import connected_components + + +def compute_connectivity_matrix(f_range, tmin_tmax_range, csd_param, method, condition, subj, epochs_tmin, epochs_tmax, + results_base_dir, data_base_dir, do_rename_channels=True): + f_range_low = f_range[0] + f_range_high = f_range[1] + if tmin_tmax_range is None: + conn_tmin = -0.2 + conn_tmax = 0.5 + else: + conn_tmin = tmin_tmax_range[0] + conn_tmax = tmin_tmax_range[1] + if csd_param is None: + do_csd = False + else: + do_csd = True + stiffnes = csd_param[0] + lambd = csd_param[1] + print("Starting connectivity computation: " + condition + " condition, " + str(f_range_low) + "-" + + str(f_range_high) + " Hz: " + method + ", tmin: " + str(conn_tmin) + "s, tmax: " + str(conn_tmax) + "s") + suffix_base = "" + suffix_base += "_" + str(f_range_low) + "-" + str(f_range_high) + "_" + method + "_T_" + \ + str(conn_tmin) + "_" + str(conn_tmax) + "s" + if do_csd: + suffix_base += "_csd" + str(stiffnes) + "_" + str(lambd) + else: + suffix_base += "_nocsd" + if condition == "proposer": + suffix_base += "P" + elif condition == "responder": + suffix_base += "R" + + if subj[0] == '0': + if condition == "responder": + file_name = subj + "_R1" + event_name = 'Stimulus/S160' + event_num = 160 + elif condition == "proposer": + file_name = subj + "_P1" + event_name = 'Stimulus/S128' + event_num = 128 + else: + raise ValueError("Unrecognised condition" + condition) + + elif subj[0] == '1': + if condition == "responder": + file_name = subj + "_R1" + event_name = 'Stimulus/S160' + event_num = 160 + elif condition == "proposer": + file_name = subj + "_P1" + event_name = 'Stimulus/S128' + event_num = 128 + else: + raise ValueError("Unrecognised condition" + condition) + + header_file = data_base_dir + file_name + ".vhdr" + marker_file = data_base_dir + file_name + ".vmrk" + + print("Reading:\t" + header_file) + raw = mne.io.read_raw_brainvision(header_file, preload=True) + + # rename channels to a 128biosemi layout names + picks = mne.pick_channels(raw.info['ch_names'], include=[]) + if do_rename_channels: + names_dict = {} + ch_ind = 0 + for prefix in ['A', 'B', 'C', 'D']: + for i in range(0, 32): + name = prefix + str(i + 1) + if name != raw.info['ch_names'][ch_ind]: + names_dict.update({raw.info['ch_names'][ch_ind]: name}) + ch_ind += 1 + raw.rename_channels(names_dict) + print("Updated " + str(len(names_dict)) + " electrodes names") + + # Resetting channel types + types_dict = {} + for name in raw.info['ch_names']: + if name[0] != 'E': + types_dict.update({name: 'eeg'}) + else: + types_dict.update({name: 'eog'}) + + raw.set_channel_types(types_dict) + + my_montage = mne.channels.read_custom_montage("biosemi128.sfp") + raw.set_montage(my_montage) + + annot = mne.read_annotations(marker_file) + raw.set_annotations(annot) + + events, event_ids = mne.events_from_annotations(raw) + events_start_only = mne.pick_events(events, include=event_num) + event_ids_start_only = {key: event_ids[key] for key in event_ids.keys() & {event_name}} + sfreq = raw.info['sfreq'] + + raw.set_eeg_reference('average', projection=True) + if do_csd: + raw_post = mne.preprocessing.compute_current_source_density(raw, stiffness=stiffnes, lambda2=lambd) + else: + raw_post = raw + + epochs = mne.Epochs(raw_post, events_start_only, event_ids_start_only, + tmin=epochs_tmin, tmax=epochs_tmax, picks='data') + + (con, freqs, times, n_epochs, n_tapers) = mne.connectivity.spectral_connectivity(epochs, method=method, + mode='multitaper', sfreq=sfreq, + fmin=f_range_low, + fmax=f_range_high, + tmin=conn_tmin, tmax=conn_tmax, + faverage=True, mt_adaptive=False, + n_jobs=1) + con[:, :, 0] += con[:, :, 0].T + np.save(results_base_dir + 'con' + file_name + suffix_base + '.npy', con[:, :, 0]) + return con[:, :, 0] + + +def compute_distances(el_subset = 'All', ref_data_file = "data/000_P1.vhdr", vis=False): + # needs checking the picks order + raw = mne.io.read_raw_brainvision(ref_data_file, preload=True) + if el_subset == 'All': + el_subset = ['A1', 'A2', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9', 'A10', 'A11','A12', 'A13', 'A14', 'A15', + 'A16', 'A17', 'A18', 'A19', 'A20','A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27', 'A28', 'A29', + 'A30', 'A31', 'A32', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11', 'B12', + 'B13', 'B14', 'B15', 'B16', 'B17', 'B18', 'B19', 'B20', 'B21', 'B22', 'B23', 'B24', 'B25', 'B26', + 'B27', 'B28', 'B29', 'B30', 'B31', 'B32', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', + 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21', 'C22', 'C23', + 'C24', 'C25', 'C26', 'C27', 'C28', 'C29', 'C30', 'C31', 'C32', 'D1', 'D2', 'D3', 'D4', 'D5', 'D6', + 'D7', 'D8', 'D9', 'D10', 'D11', 'D12', 'D13', 'D14', 'D15', 'D16', 'D17', 'D18', 'D19', 'D20', + 'D21', 'D22', 'D23', 'D24', 'D25', 'D26', 'D27', 'D28', 'D29', 'D30', 'D31', 'D32'] + n_electrodes = len(el_subset) + + + names_dict = {} + ch_ind = 0 + for prefix in ['A', 'B', 'C', 'D']: + for i in range(0, 32): + name = prefix + str(i + 1) + if name != raw.info['ch_names'][ch_ind]: + names_dict.update({raw.info['ch_names'][ch_ind]: name}) + ch_ind += 1 + raw.rename_channels(names_dict) + print("Updated " + str(len(names_dict)) + " electrodes names") + + # Resetting channel types + types_dict = {} + for name in raw.info['ch_names']: + if name[0] != 'E': + types_dict.update({name: 'eeg'}) + else: + types_dict.update({name: 'eog'}) + + raw.set_channel_types(types_dict) + + raw.set_montage('biosemi128') + picks = mne.pick_channels(raw.info['ch_names'], include=el_subset) + #raw = raw.pick(picks) + + dist_m = np.zeros(shape=(n_electrodes, n_electrodes)) + for i in range(0, n_electrodes): + for j in range(i + 1, n_electrodes): + dist_m[j, i] = np.linalg.norm(raw.info['chs'][picks[i]]['loc'][0:3] - raw.info['chs'][picks[j]]['loc'][0:3]) + + if vis: + #raw = raw.pick(picks) + raw.info['bads'] = el_subset + raw.plot_sensors(ch_type='eeg', kind='3d', show_names=True) + + plt.figure() + plt.imshow(dist_m, cmap='hot') + #plt.tick_params(labelsize=6) + + step = max(1, int(np.log2(n_electrodes)/2)) + plt.xticks(np.arange(n_electrodes,step=step), el_subset[::step], rotation=270) + plt.yticks(np.arange(n_electrodes,step=step), el_subset[::step]) + plt.colorbar() + plt.show() + return dist_m + +def select_el_subset_short(n_of_electrodes : int, connectivity_matrix, subjects, el_subset='All', reference_raw = None): + """ + + Parameters + ---------- + n_of_electrodes: desired number of n best electrodes + connectivity_matrix: connectivity matrix: electrodes x electrodes x subjects + el_subset: which subset of electrodes to start from + n_controls: how many controls in the connectivity matrix + reference_raw: reference eeg file for electrode information + + Returns + ------- + + """ + + n_controls = 0 + n_patients = 0 + for subj in subjects: + if subj[0] == '0': + n_controls += 1 + elif subj[0] == '1': + n_patients += 1 + + if el_subset == 'All': + el_subset = ['A1', 'A2', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9', 'A10', 'A11','A12', 'A13', 'A14', 'A15', + 'A16', 'A17', 'A18', 'A19', 'A20','A21', 'A22', 'A23', 'A24', 'A25', 'A26', 'A27', 'A28', 'A29', + 'A30', 'A31', 'A32', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11', 'B12', + 'B13', 'B14', 'B15', 'B16', 'B17', 'B18', 'B19', 'B20', 'B21', 'B22', 'B23', 'B24', 'B25', 'B26', + 'B27', 'B28', 'B29', 'B30', 'B31', 'B32', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', + 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21', 'C22', 'C23', + 'C24', 'C25', 'C26', 'C27', 'C28', 'C29', 'C30', 'C31', 'C32', 'D1', 'D2', 'D3', 'D4', 'D5', 'D6', + 'D7', 'D8', 'D9', 'D10', 'D11', 'D12', 'D13', 'D14', 'D15', 'D16', 'D17', 'D18', 'D19', 'D20', + 'D21', 'D22', 'D23', 'D24', 'D25', 'D26', 'D27', 'D28', 'D29', 'D30', 'D31', 'D32'] + + reference_raw_file = "data/000_P1.vhdr" + do_rename_channels = True + + if reference_raw is None: + reference_raw = mne.io.read_raw_brainvision(reference_raw_file, preload=False) + if do_rename_channels: + names_dict = {} + ch_ind = 0 + for prefix in ['A', 'B', 'C', 'D']: + for i in range(0, 32): + name = prefix + str(i + 1) + if name != reference_raw.info['ch_names'][ch_ind]: + names_dict.update({reference_raw.info['ch_names'][ch_ind]: name}) + ch_ind += 1 + reference_raw.rename_channels(names_dict) + types_dict = {} + for name in reference_raw.info['ch_names']: + if name[0] != 'E': + types_dict.update({name: 'eeg'}) + else: + types_dict.update({name: 'eog'}) + reference_raw.set_channel_types(types_dict) + picks = mne.pick_channels(reference_raw.info['ch_names'], include=el_subset) + + cm_selected = connectivity_matrix[picks, :, :][:, picks, :] + + n_channels = picks.shape[0] + + tstats, pvalue = ttest_ind(cm_selected[:, :, :n_controls], cm_selected[:, :, n_controls:], axis=2) + flat_pvalues = np.empty(shape=(int(n_channels*(n_channels-1)/2))) + ind = 0 + for i in range(0, n_channels): + for j in range(i + 1, n_channels): + flat_pvalues[ind] = pvalue[j, i] + ind += 1 + + h_true, corrected_pvalues_flat, als, alb = multipletests(flat_pvalues, method='fdr_bh') + + corrected_pvalue = np.empty_like(pvalue) + corrected_pvalue[:] = np.nan + ind = 0 + for i in range(0, n_channels): + for j in range(i + 1, n_channels): + corrected_pvalue[j, i] = corrected_pvalues_flat[ind] + ind += 1 + tstats_corrected_sorted = np.sort(np.tril(tstats, -1).flatten()) + i = 0 + result_subset = [] + while len(result_subset) < n_of_electrodes: + inds = np.where(tstats == tstats_corrected_sorted[-1 * (i+1)]) + if el_subset[inds[0][0]] not in result_subset: + result_subset.append(el_subset[inds[0][0]]) + if el_subset[inds[1][0]] not in result_subset: + result_subset.append(el_subset[inds[1][0]]) + i += 1 + result_subset = sorted(result_subset) + + return result_subset + + +def plot_statistical_test_results(stats, pvalues, labels, writepath, title, show=False, left_title="Test results"): + step = max(1, int(np.log2(len(labels) + 1) / 2)) + fig = plt.figure(title, figsize=[12, 6]) + ax = plt.subplot(1, 2, 1) + plt.title(left_title) + im = plt.imshow(stats, cmap='hot') + plt.xticks(np.arange(len(labels), step=step), labels[::step], rotation=270) + plt.yticks(np.arange(len(labels), step=step), labels[::step]) + plt.colorbar(im, fraction=0.046, pad=0.04) + ax = plt.subplot(1, 2, 2) + plt.title("p values, min = " + str('{0:.4f}'.format(np.nanmin(pvalues)))) + im = plt.imshow(pvalues, cmap='hot') + plt.xticks(np.arange(len(labels), step=step), labels[::step], rotation=270) + plt.yticks(np.arange(len(labels), step=step), labels[::step]) + plt.colorbar(im, fraction=0.046, pad=0.04) + plt.tight_layout(pad=1) + if show: + plt.show() + else: + plt.savefig(writepath, dpi=600) + plt.close(fig) + return + +def get_biosemi_to_1020_mapping(num_of_electrodes): + electrode_map32 = {'C29': 'Fp1', 'C16': 'Fp2', 'C27': 'AF3', 'C14': 'AF4', 'D7': 'F7', 'D4': 'F3', 'C21': 'Fz', + 'C4': 'F4', 'C7': 'F8', 'D10': 'FC5', 'D2': 'FC1', 'C2': 'FC2', 'B29': 'FC6', 'D23': 'T7', + 'D19': 'C3', 'A1': 'Cz', 'B22': 'C4', 'B26': 'T8', 'D26': 'CP5', 'D16': 'CP1', 'B2': 'CP2', + 'B16': 'CP6', 'D31': 'P7', 'A7': 'P3', 'A19': 'Pz', 'B4': 'P4', 'B11': 'P8', 'A17': 'PO3', + 'A30': 'PO4', 'A15': 'O1', 'A23': 'Oz', 'A28': 'O2'} + electrode_map64 = {'C29': 'Fp1', 'C17': 'Fpz', 'C16': 'Fp2', 'C30': 'AF7', 'C27': 'AF3', 'C19': 'AFz', 'C14': 'AF4', + 'C8': 'AF8', 'D7': 'F7', 'D5': 'F5', 'D4': 'F3', 'C25': 'F1', 'C21': 'Fz', 'C12': 'F2', + 'C4': 'F4', 'C5': 'F6', 'C7': 'F8', 'D8': 'FT7', 'D10': 'FC5', 'D12': 'FC3', 'D2': 'FC1', + 'C23': 'FCz', 'C2': 'FC2', 'B31': 'FC4', 'B29': 'FC6', 'B27': 'FT8', 'D23': 'T7', 'D21': 'C5', + 'D19': 'C3', 'D14': 'C1', 'A1': 'Cz', 'B20': 'C2', 'B22': 'C4', 'B24': 'C6', 'B26': 'T8', + 'D24': 'TP7', 'D26': 'CP5', 'D28': 'CP3', 'D16': 'CP1', 'A3': 'CPz', 'B2': 'CP2', 'B18': 'CP4', + 'B16': 'CP6', 'B14': 'TP8', 'D32': 'P9', 'D31': 'P7', 'D29': 'P5', 'A7': 'P3', 'A5': 'P1', + 'A19': 'Pz', 'A32': 'P2', 'B4': 'P4', 'B13': 'P6', 'B11': 'P8', 'B10': 'P10', 'A10': 'PO7', + 'A16': 'PO3', 'A21': 'POz', 'A29': 'PO4', 'B7': 'PO8', 'A15': 'O1', 'A23': 'Oz', 'A28': 'O2', + 'A25': 'Iz'} + if num_of_electrodes == 32: + return electrode_map32 + elif num_of_electrodes == 64: + return electrode_map64 + else: + raise ValueError("Wrong parameter: should be 32 or 64") + +def get_biosemi128_names(): + n = ['A1', 'A2', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9', 'A10', 'A11', + 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19', 'A20', 'A21', + 'A22', 'A23', 'A24', 'A25', 'A26', 'A27', 'A28', 'A29', 'A30', 'A31', + 'A32', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', + 'B11', 'B12', 'B13', 'B14', 'B15', 'B16', 'B17', 'B18', 'B19', 'B20', + 'B21', 'B22', 'B23', 'B24', 'B25', 'B26', 'B27', 'B28', 'B29', 'B30', + 'B31', 'B32', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', + 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', + 'C20', 'C21', 'C22', 'C23', 'C24', 'C25', 'C26', 'C27', 'C28', 'C29', + 'C30', 'C31', 'C32', 'D1', 'D2', 'D3', 'D4', 'D5', 'D6', 'D7', 'D8', + 'D9', 'D10', 'D11', 'D12', 'D13', 'D14', 'D15', 'D16', 'D17', 'D18', + 'D19', 'D20', 'D21', 'D22', 'D23', 'D24', 'D25', 'D26', 'D27', 'D28', + 'D29', 'D30', 'D31', 'D32'] + return n + +def get_biosemi128_region_mapping(): + map = {'A1': 'CP', 'A2': 'CP', 'A3': 'CP', 'A4': 'CP', + 'A5': 'CP', 'A6': 'CP', 'A7': 'CP', 'A8': 'CP', + 'A9': 'CP', 'A10': 'O', 'A11': 'O', 'A12': 'O', + 'A13': 'O', 'A14': 'O', 'A15': 'O', 'A16': 'O', + 'A17': 'O', 'A18': 'CP', 'A19': 'CP', 'A20': 'CP', + 'A21': 'O', 'A22': 'O', 'A23': 'O', 'A24': 'O', + 'A25': 'O', 'A26': 'O', 'A27': 'O', 'A28': 'O', + 'A29': 'O', 'A30': 'O', 'A31': 'CP', 'A32': 'CP', + 'B1': 'CP', 'B2': 'CP', 'B3': 'CP', 'B4': 'CP', + 'B5': 'CP', 'B6': 'CP', 'B7': 'O', 'B8': 'O', + 'B9': 'O', 'B10': 'RT', 'B11': 'RT', 'B12': 'RT', + 'B13': 'CP', 'B14': 'RT', 'B15': 'RT', 'B16': 'RT', + 'B17': 'CP', 'B18': 'CP', 'B19': 'CP', 'B20': 'CP', + 'B21': 'CP', 'B22': 'CP', 'B23': 'RT', 'B24': 'RT', + 'B25': 'RT', 'B26': 'RT', 'B27': 'RT', 'B28': 'RT', + 'B29': 'RT', 'B30': 'CP', 'B31': 'CP', 'B32': 'CP', + 'C1': 'CP', 'C2': 'CP', 'C3': 'CP', 'C4': 'F', + 'C5': 'F', 'C6': 'RT', 'C7': 'RT', 'C8': 'F', + 'C9': 'F', 'C10': 'F', 'C11': 'CP', 'C12': 'F', + 'C13': 'F', 'C14': 'F', 'C15': 'F', 'C16': 'F', + 'C17': 'F', 'C18': 'F', 'C19': 'F', 'C20': 'F', + 'C21': 'F', 'C22': 'F', 'C23': 'CP', 'C24': 'CP', + 'C25': 'F', 'C26': 'F', 'C27': 'F', 'C28': 'F', + 'C29': 'F', 'C30': 'F', 'C31': 'F', 'C32': 'F', + 'D1': 'CP', 'D2': 'CP', 'D3': 'F', 'D4': 'F', + 'D5': 'F', 'D6': 'LT', 'D7': 'LT', 'D8': 'LT', + 'D9': 'LT', 'D10': 'LT', 'D11': 'CP', 'D12': 'CP', + 'D13': 'CP', 'D14': 'CP', 'D15': 'CP', 'D16': 'CP', + 'D17': 'CP', 'D18': 'CP', 'D19': 'CP', 'D20': 'LT', + 'D21': 'LT', 'D22': 'LT', 'D23': 'LT', 'D24': 'LT', + 'D25': 'LT', 'D26': 'LT', 'D27': 'CP', 'D28': 'CP', + 'D29': 'CP', 'D30': 'LT', 'D31': 'LT', 'D32': 'LT'} + return map + + +def xy_1020_coord(el=None): + xy_dict = { + 'Fp1': (-90, 277), + 'Fp2': (90, 277), + 'AF3': (-109, 223), + 'AF4': (109, 223), + 'F7': (-234, 170), + 'F3': (-127, 154), + 'Fz': (0, 146), + 'F4': (127, 154), + 'F8': (234, 170), + 'FC5': (-207, 80), + 'FC1': (-68, 75), + 'FC2': (68, 75), + 'FC6': (207, 80), + 'T7': (-287, 0), + 'C3': (-143, 0), + 'Cz': (0, 0), + 'C4': (143, 0), + 'T8': (287, 0), + 'CP5': (-207, -80), + 'CP1': (-68, -75), + 'CP2': (68, -75), + 'CP6': (207, -80), + 'P7': (-234, -170), + 'P3': (-127, -154), + 'Pz': (0, -146), + 'P4': (127, -154), + 'P8': (234, -170), + 'PO3': (-109, -223), + 'PO4': (109, -223), + 'O1': (-90, -277), + 'Oz': (0, -287), + 'O2': (90, -277)} + if el is None: + return xy_dict + else: + return xy_dict[el] + + +def symm_matrix(input_matrix, binary=False): + if binary: + result_matrix = input_matrix.astype(int) + for (i, j) in combinations(range(len(input_matrix)), 2): + if input_matrix[i, j] or input_matrix[j, i]: + result_matrix[i, j] = 1 + result_matrix[j, i] = 1 + else: + result_matrix = (input_matrix + input_matrix.T) / 2.0 + + return result_matrix + + +def threshold_matrix(input_matrix, percentage_of_connections, is_symmetrical=True): + #if is_symmetrical: + # border_val = np.sort(input_matrix.flatten())[-1*int(percentage_of_connections*input_matrix.flatten().shape[0])] + #else: + # border_val = np.sort(input_matrix.flatten())[-1*int(percentage_of_connections * input_matrix.shape[0]*(input_matrix.shape[0]-1)/2)] + conn_values = np.sort(input_matrix[np.triu_indices_from(input_matrix, 1)]) + border_val = conn_values[-1*int(len(conn_values)*percentage_of_connections)] + output_matrix = input_matrix + output_matrix[np.where(input_matrix <= border_val)] = 0 + return output_matrix + +def is_connected(adj: np.ndarray): + graph = csr_matrix(adj) + n_components, labels = connected_components(csgraph=graph, directed=False, return_labels=True) + if n_components == 1: + return True + else: + return False \ No newline at end of file diff --git a/makeparam.py b/makeparam.py new file mode 100644 index 0000000..8259ac2 --- /dev/null +++ b/makeparam.py @@ -0,0 +1,58 @@ +import json +''' + 'methods': imcoh, pli, coh, plv, etc. Any method available + for mne.connectivity.spectral_connectivity() + 'electrode_subset': A list of electrodes: ['A1', 'A2'] or + a pre-defined set of '32el10-20' or '64el10-20' + 'select_x_best_electrodes': int or 'All'. If int is given, it will select the most statistically + significant set of the corresponding 'electrode_subset' (see above). + 'frequency_band': A bandwidth (pair of [fmin, fmax]) in which connectivity will be computed + 'tmin_tmax': time around the stimulus (pair of [tmin, tmax])in seconds. If None,, [tmin, tmax] = [-0.2, 0.5] + 'CSD_parameters': None or a pair of [stiffnes, lambda^2]: stiffness: 2-5, lambda^2: 0-10e-5 + 'conditions': 'responder' and/or 'proposer' + 'subjects': list of subjects + 'add_label': additional label + + methods, electrode_subset, select_x_best_electrodes, CSD_parameters are mutli-loop. + subjects and add_label are fixed for all loops (1 value) +''' +filename = "npm_8-13Hz_0.7s_P.json" +#filename = "npm_8-13Hz_0.7s_R.json" + +#filename = "npm_14-30Hz_0.7s_P.json" +#filename = "npm_14-30Hz_0.7s_R.json" + + +param_dict = { + 'methods': ['imcoh'], + 'conditions': ['proposer'], + 'frequency_band': [[8, 13]], + 'electrode_subset': ['32el10-20'], + 'select_x_best_electrodes': ["All"], #int or 'All'. If int is given, it will select the most statistically + # significant set of the corresponding 'electrode_subset' (see above). + 'tmin_tmax': [[-2.25, -1.55], [-1.55, -0.85], [-0.7, 0.0], [-0.35, 0.35], [0.0, 0.7], [0.85, 1.55], [1.55, 2.25]], + 'CSD_parameters': [[3, 0.001]], #stiffness 2-5, lambda^2 0-10e-5 + 'datapath': "/mnt/data/NoCSD/", + 'subjects': + ['000', '007', '008', '009', '010', '011', '012', '013', '014', '015', + '016', '017', '018', '020', '021', '023', '025', '027', '028', '029', + '030', '031', '032', '034', + '101', '102', '106', '109', '110', '111', '112', '113', '114', '115', + '116', '117', '121', '124', '125', '126', '127', '128', '199'], + 'add_label': "", + 'result_dir': "/mnt/data/conn_results/results/" +} +while len(param_dict['electrode_subset']) > len(param_dict['select_x_best_electrodes']): + param_dict['select_x_best_electrodes'].append('All') + +if param_dict['datapath'][-1] is not "/": + param_dict['datapath'] += "/" +if param_dict['result_dir'][-1] is not "/": + param_dict['result_dir'] += "/" + +json_dict = json.dumps(param_dict) +f = open(filename, "w") +f.write(json_dict) +f.close() + +print(param_dict) diff --git a/npm_14-30Hz_0.7s_P.json b/npm_14-30Hz_0.7s_P.json new file mode 100644 index 0000000..513c77c --- /dev/null +++ b/npm_14-30Hz_0.7s_P.json @@ -0,0 +1 @@ +{"methods": ["imcoh"], "conditions": ["proposer"], "frequency_band": [[14, 30]], "electrode_subset": ["32el10-20"], "select_x_best_electrodes": ["All"], "tmin_tmax": [[-2.25, -1.55], [-1.55, -0.85], [-0.7, 0.0], [-0.35, 0.35], [0.0, 0.7], [0.85, 1.55], [1.55, 2.25]], "CSD_parameters": [[3, 0.001]], "datapath": "/mnt/data/NoCSD/", "subjects": ["000", "007", "008", "009", "010", "011", "012", "013", "014", "015", "016", "017", "018", "020", "021", "023", "025", "027", "028", "029", "030", "031", "032", "034", "101", "102", "106", "109", "110", "111", "112", "113", "114", "115", "116", "117", "121", "124", "125", "126", "127", "128", "199"], "add_label": "", "result_dir": "/mnt/data/conn_results/results/"} \ No newline at end of file diff --git a/npm_14-30Hz_0.7s_R.json b/npm_14-30Hz_0.7s_R.json new file mode 100644 index 0000000..1645879 --- /dev/null +++ b/npm_14-30Hz_0.7s_R.json @@ -0,0 +1 @@ +{"methods": ["imcoh"], "electrode_subset": ["32el10-20"], "select_x_best_electrodes": ["All"], "frequency_band": [[14, 30]], "tmin_tmax": [[-2.25, -1.55], [-1.55, -0.85], [-0.7, 0.0], [-0.35, 0.35], [0.0, 0.7], [0.85, 1.55], [1.55, 2.25]], "CSD_parameters": [[3, 0.001]], "datapath": "/mnt/data/NoCSD/", "subjects": ["000", "007", "008", "009", "010", "011", "012", "013", "014", "015", "016", "017", "018", "020", "021", "023", "025", "027", "028", "029", "030", "031", "032", "034", "101", "102", "106", "109", "110", "111", "112", "113", "114", "115", "116", "117", "121", "124", "125", "126", "127", "128", "199"], "conditions": ["responder"], "add_label": "", "result_dir": "/mnt/data/conn_results/results/"} \ No newline at end of file diff --git a/npm_8-13Hz_0.7s_P.json b/npm_8-13Hz_0.7s_P.json new file mode 100644 index 0000000..ef5d7b4 --- /dev/null +++ b/npm_8-13Hz_0.7s_P.json @@ -0,0 +1 @@ +{"methods": ["imcoh"], "conditions": ["proposer"], "frequency_band": [[8, 13]], "electrode_subset": ["32el10-20"], "select_x_best_electrodes": ["All"], "tmin_tmax": [[-2.25, -1.55], [-1.55, -0.85], [-0.7, 0.0], [-0.35, 0.35], [0.0, 0.7], [0.85, 1.55], [1.55, 2.25]], "CSD_parameters": [[3, 0.001]], "datapath": "/mnt/data/NoCSD/", "subjects": ["000", "007", "008", "009", "010", "011", "012", "013", "014", "015", "016", "017", "018", "020", "021", "023", "025", "027", "028", "029", "030", "031", "032", "034", "101", "102", "106", "109", "110", "111", "112", "113", "114", "115", "116", "117", "121", "124", "125", "126", "127", "128", "199"], "add_label": "", "result_dir": "/mnt/data/conn_results/results/"} \ No newline at end of file diff --git a/npm_8-13Hz_0.7s_R.json b/npm_8-13Hz_0.7s_R.json new file mode 100644 index 0000000..7821847 --- /dev/null +++ b/npm_8-13Hz_0.7s_R.json @@ -0,0 +1 @@ +{"methods": ["imcoh"], "electrode_subset": ["32el10-20"], "select_x_best_electrodes": ["All"], "frequency_band": [[8, 13]], "tmin_tmax": [[-2.25, -1.55], [-1.55, -0.85], [-0.7, 0.0], [-0.35, 0.35], [0.0, 0.7], [0.85, 1.55], [1.55, 2.25]], "connectivity_threshold": [0.2], "CSD_parameters": [[3, 0.001]], "datapath": "/mnt/data/NoCSD/", "subjects": ["000", "007", "008", "009", "010", "011", "012", "013", "014", "015", "016", "017", "018", "020", "021", "023", "025", "026", "027", "028", "029", "030", "031", "032", "034", "101", "102", "105", "106", "109", "110", "111", "112", "113", "114", "115", "116", "117", "118", "121", "124", "125", "126", "127", "128", "199"], "conditions": ["responder"], "add_label": "", "result_dir": "/mnt/data/conn_results/results/"} \ No newline at end of file diff --git a/region_average.py b/region_average.py new file mode 100644 index 0000000..d40a5c0 --- /dev/null +++ b/region_average.py @@ -0,0 +1,456 @@ +import os +import json +import hashlib +import csv +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt + +import mne +from scipy.stats import mannwhitneyu, shapiro +import seaborn as sns +from statsmodels.stats.multitest import multipletests +from visbrain.objects import SourceObj, BrainObj, TopoObj +from visbrain.gui import Brain, Figure + +from functions import get_biosemi_to_1020_mapping, compute_connectivity_matrix, get_biosemi128_region_mapping, xy_1020_coord + +#Pairs of regions to average connectivity. 'F', 'O', 'LT', 'RT', 'CP' +regions = [['LT', 'RT'], ['F', 'O']] + + +param_file = "npm_8-13Hz_0.7s_R.json" +param_file = "npm_8-13Hz_0.7s_P.json" + +#param_file = "npm_14-30Hz_0.7s_R.json" +#param_file = "npm_14-30Hz_0.7s_P.json" + + +with open(param_file) as json_file: + param_dict = json.load(json_file) + +subjects = np.array(sorted(param_dict['subjects'])).astype(str) + +do_rename_channels = True + +results_base_dir = param_dict['result_dir'] +data_base_dir = param_dict['datapath'] +reference_raw_file = data_base_dir + "000_P1.vhdr" +conditions = param_dict['conditions'] +methods = param_dict['methods'] +franges = param_dict['frequency_band'] +tmin_tmax = param_dict['tmin_tmax'] +csd_params = param_dict['CSD_parameters'] +add_label = param_dict['add_label'] + +min_edge_pvals = [] + +epochs_tmin = -2 +epochs_tmax = 4 +xyz = None + +region_mapping = get_biosemi128_region_mapping() + +#Load one raw file to get electrodes information, positions etc +reference_raw = mne.io.read_raw_brainvision(reference_raw_file, preload=False) +if do_rename_channels: + names_dict = {} + ch_ind = 0 + for prefix in ['A', 'B', 'C', 'D']: + for i in range(0, 32): + name = prefix + str(i + 1) + if name != reference_raw.info['ch_names'][ch_ind]: + names_dict.update({reference_raw.info['ch_names'][ch_ind]: name}) + ch_ind += 1 + reference_raw.rename_channels(names_dict) + + w = csv.writer(open(results_base_dir + "electrodes_dict.csv", "w")) + for key, val in names_dict.items(): + w.writerow([key, val]) + + types_dict = {} + for name in reference_raw.info['ch_names']: + if name[0] != 'E': + types_dict.update({name: 'eeg'}) + else: + types_dict.update({name: 'eog'}) + reference_raw.set_channel_types(types_dict) +picks = mne.pick_types(reference_raw.info, eeg=True) +n_all_el = picks.shape[0] +my_montage = mne.channels.read_custom_montage("biosemi128.sfp") +reference_raw.set_montage(my_montage) +print("Reference raw loaded") + +pvals_results = [] +sw_results = [] + +for region_pair in regions: + region1 = region_pair[0] + region2 = region_pair[1] + for condition in conditions: + for f_range in franges: + f_range_low = f_range[0] + f_range_high = f_range[1] + for tmin_tmax_ins in tmin_tmax: + if tmin_tmax_ins is None: + conn_tmin = -0.2 + conn_tmax = 0.5 + else: + conn_tmin = tmin_tmax_ins[0] + conn_tmax = tmin_tmax_ins[1] + for csd_param in csd_params: + if csd_param is None: + do_csd = False + else: + do_csd = True + stiffnes = csd_param[0] + lambd = csd_param[1] + for method in methods: + connectivity_full = np.empty(shape=(n_all_el, n_all_el, len(subjects))) + labels_true = [] + subj_pos = 0 + n_controls = 0 + n_patients = 0 + suffix_base = "" + suffix_base += "_" + str(f_range_low) + "-" + str(f_range_high) + "_" + method + "_T_" + \ + str(conn_tmin) + "_" + str(conn_tmax) + "s" + if do_csd: + suffix_base += "_csd" + str(stiffnes) + "_" + str(lambd) + else: + suffix_base += "_nocsd" + if condition == "proposer": + suffix_base += "P" + elif condition == "responder": + suffix_base += "R" + + for subj in subjects: + if subj[0] == '0': + n_controls += 1 + labels_true.append(1) + if condition == "responder": + file_name = subj + "_R1" + elif condition == "proposer": + file_name = subj + "_P1" + + else: + raise ValueError("Unrecognised condition" + condition) + + elif subj[0] == '1': + n_patients += 1 + labels_true.append(2) + if condition == "responder": + file_name = subj + "_R1" + elif condition == "proposer": + file_name = subj + "_P1" + else: + raise ValueError("Unrecognised condition" + condition) + + ############################# Loading/computing connectivity for each subject ######################### + if os.path.exists(results_base_dir + "conn_matrices/" + 'con' + file_name + suffix_base + '.npy'): + print( + "Reading precomputed connectivity matrix:\t" + results_base_dir + "conn_matrices/" + 'con' + file_name + suffix_base + '.npy') + c_matrix = np.load(results_base_dir + "conn_matrices/" + 'con' + file_name + suffix_base + '.npy') + + else: + c_matrix = compute_connectivity_matrix(f_range, tmin_tmax_ins, csd_param, method, + condition, subj, epochs_tmin, epochs_tmax, + results_base_dir + "conn_matrices/", data_base_dir) + connectivity_full[:, :, subj_pos] = c_matrix + n_channels = c_matrix.shape[0] + subj_pos += 1 + + labels_true = np.array(labels_true).astype(int) + + #Getting subset of electrodes + el_set_index = -1 + for selected_channels_param in param_dict['electrode_subset']: + el_set_index += 1 + if type(selected_channels_param) is list: + selected_channels = selected_channels_param + elif type(selected_channels_param) is str: + if selected_channels_param == 'All': + selected_channels = ['A1', 'A2', 'A3', 'A4', 'A5', 'A6', 'A7', 'A8', 'A9', 'A10', + 'A11', 'A12', 'A13', 'A14', 'A15', 'A16', 'A17', 'A18', 'A19', 'A20', 'A21', + 'A22', 'A23', 'A24', 'A25', 'A26', 'A27', 'A28', 'A29', 'A30', 'A31', + 'A32', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', + 'B11', 'B12', 'B13', 'B14', 'B15', 'B16', 'B17', 'B18', 'B19', 'B20', + 'B21', 'B22', 'B23', 'B24', 'B25', 'B26', 'B27', 'B28', 'B29', 'B30', + 'B31', 'B32', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', + 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', + 'C20', 'C21', 'C22', 'C23', 'C24', 'C25', 'C26', 'C27', 'C28', 'C29', + 'C30', 'C31', 'C32', 'D1', 'D2', 'D3', 'D4', 'D5', 'D6', 'D7', 'D8', + 'D9', 'D10', 'D11', 'D12', 'D13', 'D14', 'D15', 'D16', 'D17', 'D18', + 'D19', 'D20', 'D21', 'D22', 'D23', 'D24', 'D25', 'D26', 'D27', 'D28', + 'D29', 'D30', 'D31', 'D32'] + elif selected_channels_param == '32el10-20': + electrode_mapping = get_biosemi_to_1020_mapping(32) + selected_channels = list(electrode_mapping.keys()) + + elif selected_channels_param == '64el10-20': + electrode_mapping = get_biosemi_to_1020_mapping(64) + selected_channels = list(electrode_mapping.keys()) + else: + raise ValueError("Wrong channel selection parameter: should be a list of " + "channel labels, '32el10-20', '64el10-20' or 'All'") + else: + raise ValueError("Wrong channel selection parameter: should be a list of " + "channel labels, '32el10-20', '64el10-20' or 'All'") + + #Electrode names conversion from biosemi ABCD to 10-20 convention + if selected_channels_param == '32el10-20' or selected_channels_param == '64el10-20': + el_labels = [electrode_mapping[key] for key in selected_channels] + else: + el_labels = selected_channels + + suffix = suffix_base + "_" + str(len(selected_channels)) + "el_" + text_el = "".join(selected_channels) + suffix += hashlib.sha256(text_el.encode()).hexdigest()[:8] + suffix += add_label + + n_channels = len(selected_channels) + selected_ch_ind = [mne.pick_channels(reference_raw.info['ch_names'], include=[ch])[0] + for ch in selected_channels] + connectivity_subset = connectivity_full[selected_ch_ind, :, :][:, selected_ch_ind, :] + + if xyz is None: + xyz = np.empty((n_channels, 3)) + xy = np.empty((n_channels, 2)) + for i in range(n_channels): + pick = mne.pick_channels(reference_raw.info['ch_names'], [selected_channels[i]]) + xyz[i, :] = reference_raw.info['chs'][pick[0]]['loc'][0:3] + if selected_channels_param == '32el10-20': + xy[i, :] = xy_1020_coord(electrode_mapping[selected_channels[i]]) + else: + xy[i, :] = xyz[i, 0:2] + + + #Sum connections between two regions of interest + avg_conn_abs = np.zeros(len(subjects)) + avg_conn_sgn = np.zeros(len(subjects)) + select = np.zeros((n_channels, n_channels)).astype(bool) + avg_conn_matrix_HC = np.zeros((n_channels, n_channels)) + avg_conn_matrix_FEP = np.zeros((n_channels, n_channels)) + n_connections = 0 + for i, el1_name in enumerate(selected_channels): + for j, el2_name in enumerate(selected_channels[:i]): + avg_conn_matrix_HC[(i, j), (j, i)] = np.mean(connectivity_subset[i, j, :n_controls]) + avg_conn_matrix_FEP[(i, j), (j, i)] = np.mean(connectivity_subset[i, j, n_controls:]) + if (region_mapping[el1_name] == region1 and region_mapping[el2_name] == region2) or \ + (region_mapping[el2_name] == region1 and region_mapping[el1_name] == region2): + avg_conn_abs += np.abs(connectivity_subset[i, j, :]) + avg_conn_sgn += connectivity_subset[i, j, :] + n_connections += 1 + select[(i, j), (j, i)] = True + + #Average connection strength + avg_conn_sgn = avg_conn_sgn / n_connections + avg_conn_abs = avg_conn_abs / n_connections + + #Between groups test + sw_stat_HC, sw_pval_HC = shapiro(avg_conn_sgn[:n_controls]) + sw_stat_FEP, sw_pval_FEP = shapiro(avg_conn_sgn[n_controls:]) + stat, pval = mannwhitneyu(avg_conn_sgn[:n_controls], avg_conn_sgn[n_controls:]) + pvals_results.append([pval, str(f_range[0]) + "-" + str(f_range[1]), + str(conn_tmin) + "_" + str(conn_tmax), + condition, region1 + "-" + region2]) + + sw_results.append( + [sw_pval_FEP, str(f_range[0]) + "-" + str(f_range[1]), str(conn_tmin) + "_" + str(conn_tmax), + condition, region1 + "-" + region2, "FEP"]) + + sw_results.append([sw_pval_HC, str(f_range[0]) + "-" + str(f_range[1]), + str(conn_tmin) + "_" + str(conn_tmax), condition, + region1 + "-" + region2, "HC"]) + + if pval < 0.05: + ed_cmap = 'RdBu_r' + tp_cmap = "Greys" + tp_point = 0.3 + cmin = np.minimum(np.min(avg_conn_matrix_FEP[select]), np.min(avg_conn_matrix_HC[select])) + cmax = np.maximum(np.max(avg_conn_matrix_FEP[select]), np.max(avg_conn_matrix_HC[select])) + cmin = float(format(-1 * np.maximum(np.abs(cmin), cmax), '.2f')) + cmax = float(format(np.maximum(np.abs(cmin), cmax), '.2f')) + + kw_top = dict(margin=15 / 100, chan_offset=(0., -0.81, 0.), chan_size=15, + system='cartesian', cbtxtsz=60, txtsz=155.) + t_obj = TopoObj('topo', tp_point*np.ones(n_channels), + channels=[electrode_mapping[i] for i in selected_channels], + xyz=xy, cmap=tp_cmap, **kw_top, clim=(0, 1)) + t_obj.connect(avg_conn_matrix_HC, select=select, cmap=ed_cmap, + antialias=True, line_width=4., clim=(cmin, cmax)) + t_obj.screenshot(results_base_dir + suffix_base[1:] + "_HC_edge.png", autocrop=True, dpi=600., + bgcolor='white') + + t_obj = TopoObj('topo', tp_point*np.ones(n_channels), + channels=[electrode_mapping[i] for i in selected_channels], + xyz=xy, cmap=tp_cmap, **kw_top, clim=(0, 1)) + t_obj.connect(avg_conn_matrix_FEP, select=select, cmap=ed_cmap, antialias=True, + line_width=4., clim=(cmin, cmax)) + t_obj.screenshot(results_base_dir + suffix_base[1:] + "_FEP_edge.png", autocrop=True, + dpi=600., bgcolor='white') + + fig_agg = Figure([results_base_dir + suffix_base[1:] + "_FEP_edge.png", + results_base_dir + suffix_base[1:] + "_HC_edge.png"], + titles=['Patients', 'Controls'], figtitle='Group average', + xlabels=[None, None], ylabels=[None, None], grid=[1, 2], + ax_bgcolor='white', fig_bgcolor='white', + subspace={'left': 0., 'right': 1., 'bottom': 0., 'top': .9, + 'wspace': 0., 'hspace': 0.07}, figsize=(12, 6), + text_color='black', autocrop=True) #y=1., + + fig_agg.shared_colorbar((cmin, cmax), t_obj._connect._cmap, + fz_title=20, vmin=cmin, vmax=cmax, + under='olive', over='firebrick', position='right', + title='Average imcoh', fz_ticks=20, + pltmargin=.001, figmargin=.001) + + fig_agg.save(results_base_dir + suffix_base[1:] + "_edge.png", dpi=300) + + avg_conn = pd.DataFrame([avg_conn_abs, avg_conn_sgn, + ["HC" if label == 1 else "FEP" for label in labels_true]], + index=["avg_abs", "avg_sgn", "Group"], columns=subjects).T + + if pval < 0.05: + fig = plt.figure() + h = sns.histplot(avg_conn[avg_conn["Group"] == 'HC'], x="avg_sgn", element="step", bins=15) + h.set_title( + "Time window " + str(conn_tmin) + " - " + str(conn_tmax) + "s, frequency " + str( + f_range_low) + "-" + str( + f_range_high) + "Hz,\n" + region1 + "-" + region2 + " connectivity", + fontsize=12) + h.set_ylabel("Number of participants", fontsize=18) + h.set_xlabel("Average imcoh", fontsize=18) + axes = plt.gca() + axes.set_xlim([np.min(avg_conn['avg_sgn']), np.max(avg_conn['avg_sgn'])]) + axes.set_ylim([0, 7.2]) + + plt.savefig( + results_base_dir + "histograms/" + "HC_avg_coh_sgn_hist_" + region1 + "-" + + region2 + "_" + suffix[1:] + ".png", dpi=400) + plt.close() + fig = plt.figure() + h = sns.histplot(avg_conn[avg_conn["Group"] == 'FEP'], x="avg_sgn", element="step", + bins=15) + print(suffix[1:] + "\nFEP:\t") + print(np.mean(avg_conn[avg_conn["Group"] == 'FEP']['avg_sgn'])) + print(np.std(avg_conn[avg_conn["Group"] == 'FEP']['avg_sgn'])) + print(suffix[1:] + "\nHC:\t") + print(np.mean(avg_conn[avg_conn["Group"] == 'HC']['avg_sgn'])) + print(np.std(avg_conn[avg_conn["Group"] == 'HC']['avg_sgn'])) + h.set_title( + "Time window " + str(conn_tmin) + " - " + str(conn_tmax) + "s, frequency " + str( + f_range_low) + "-" + str( + f_range_high) + "Hz,\n" + region1 + "-" + region2 + " connectivity", + fontsize=12) + h.set_ylabel("Number of participants", fontsize=18) + h.set_xlabel("Average imcoh", fontsize=18) + axes = plt.gca() + axes.set_xlim([np.min(avg_conn['avg_sgn']), np.max(avg_conn['avg_sgn'])]) + axes.set_ylim([0, 7.2]) + plt.savefig( + results_base_dir + "histograms/" + "FEP_avg_coh_sgn_hist_" + region1 + "-" + + region2 + "_" + suffix[1:] + ".png", + dpi=400) + plt.close() + + if not os.path.exists(results_base_dir + "histograms/"): + os.mkdir(results_base_dir + "histograms/") + print("Results saved at: " + results_base_dir + "histograms/") + + h = sns.histplot(avg_conn, x="avg_abs", hue="Group", element="step") + h.set_title("Time window " + str(conn_tmin) + " - " + str(conn_tmax) + "s, frequency " + + str(f_range_low) + "-" + str(f_range_high) + "Hz, " + + region1 + " - " + region2 + "connectivity", fontsize=12) + h.set_ylabel("Number of participants)", fontsize=10) + h.set_xlabel("Average |imcoh|", fontsize=10) + plt.savefig(results_base_dir + "histograms/" + "avg_coh_abs_" + region1 + "-" + region2 + "_" + suffix[1:] + ".png", dpi=400) + plt.close() + + h = sns.histplot(avg_conn, x="avg_sgn", hue="Group", element="step") + h.set_title("Time window " + str(conn_tmin) + " - " + str(conn_tmax) + "s, frequency " + + str(f_range_low) + "-" + str(f_range_high) + "Hz,\n" + + region1 + "-" + region2 + " connectivity, pvalue = " + '%.3f' % pval, fontsize=12) + h.set_ylabel("Number of participants", fontsize=10) + h.set_xlabel("Average imcoh", fontsize=10) + #put most interesting plots on the top + if pval < 0.05: + mrk = "0" + else: + mrk = "" + plt.savefig(results_base_dir + "histograms/" + mrk + "avg_coh_sgn_" + region1 + "-" + region2 + "_" + suffix[1:] + ".png", dpi=400) + plt.close() + +pvals_results = pd.DataFrame(pvals_results, columns=['Pval', 'Freq', 'Window', 'Condition', 'Region']) +pvals_results = pvals_results.assign(Pval_corr=multipletests(pvals_results['Pval'], method='fdr_bh')[1]) +pvals_results.to_csv(results_base_dir + "histograms/" + "avg_coh_sgn_pvalues" + param_file[:-5] + ".csv") + +sw_pval = pd.DataFrame(sw_results, columns=["pval", "freq", "twin", "condition", "regions", "Group"]) +sw_pval.to_csv(results_base_dir + "histograms/" + "avg_coh_sgn_gaussianity_test" + param_file[:-5] + ".csv") + +for condition in conditions: + for f_range in franges: + pvals_results_subset = pvals_results[(pvals_results['Condition'] == condition) & + (pvals_results['Freq'] == str(f_range[0]) + "-" + str(f_range[1]))] + pvals_results_subset['Pval'] = pvals_results_subset['Pval'].apply(lambda x: -1*np.log10(x)) + pvals_results_subset['Pval_corr'] = pvals_results_subset['Pval_corr'].apply(lambda x: -1 * np.log10(x)) + pvals_results_subset['Window'] = pvals_results_subset['Window'].apply(lambda x: 0.5 * float(x.split("_")[1]) + + 0.5 * float(x.split("_")[0])) + sns.set(rc={'figure.figsize': (7, 5.5)}) + sns.set_style('white') + b = sns.lineplot(x="Window", y="Pval", hue="Region", style="Region", + dashes=False, markers=True, data=pvals_results_subset) + b.set_title("Frequency " + str(f_range_low) + "-" + str(f_range_high) + "Hz, " + + condition + " condition", fontsize=12) + sns.despine(offset=-5, trim=False) + b.set_ylabel("-log (non-corrected pval)", fontsize=10) + b.set_xlabel("Time window" + r'$ (\pm $' + '%d' % (500*(conn_tmax-conn_tmin)) + "ms)", fontsize=10) + b.tick_params(axis='y', labelsize=10) + b.tick_params(axis='x', labelsize=10, rotation=30) + plt.axhline(-1 * np.log10(0.05), color="xkcd:light gray", ls='--') + plt.text(tmin_tmax[-1][0], -1 * np.log10(0.05) + 0.05, "p=0.05", fontsize=8) + plt.axhline(-1 * np.log10(0.01), color="xkcd:light gray", ls='--') + plt.text(tmin_tmax[1][0], -1 * np.log10(0.01) + 0.05, "p=0.01", fontsize=8) + plt.savefig(results_base_dir + "region_pval_for_windows_" + str(f_range[0]) + "-" + str(f_range[1]) + "Hz_" + + condition + ".png", dpi=400) + plt.close() + + sns.set(rc={'figure.figsize': (7, 5.5)}) + sns.set_style('white') + b = sns.lineplot(x="Window", y="Pval_corr", hue="Region", style="Region", dashes=False, markers=True, + data=pvals_results_subset) + b.set_title("Frequency " + str(f_range_low) + "-" + str(f_range_high) + "Hz, " + condition + " condition", + fontsize=12) + sns.despine(offset=-5, trim=False) + b.set(ylim=(0, 1.75)) + b.set_ylabel("-log (corrected pval)", fontsize=10) + b.set_xlabel("Time window" + r'$ (\pm $' + '%d' % (500 * (conn_tmax - conn_tmin)) + "ms)", fontsize=10) + b.tick_params(axis='y', labelsize=10) + b.tick_params(axis='x', labelsize=10, rotation=30) + plt.axhline(-1 * np.log10(0.05), color="xkcd:light gray", ls='--') + plt.text(tmin_tmax[-1][0], -1 * np.log10(0.05) + 0.02, "p=0.05", fontsize=8) + plt.axvline(0, color="xkcd:light grey blue", ls='--') + plt.text(tmin_tmax[1][0], -1 * np.log10(0.01) + 0.05, "p=0.01", fontsize=8) + plt.savefig(results_base_dir + "region_corr_pval_for_windows_" + str(f_range[0]) + "-" + str( + f_range[1]) + "Hz_" + condition + ".png", dpi=400) + plt.close() + +region_ind = {'LT': 1, 'RT': 2, 'CP': 3, 'F': 4, 'O': 5} +reg_labels = np.zeros(len(selected_channels)) +for i, el_name in enumerate(selected_channels): + reg_labels[i] = region_ind[region_mapping[el_name]] + +x_max = np.max(np.abs(xyz[:, 0])) +y_max = np.max(np.abs(xyz[:, 1])) +xyz[:, 0] = xy[:, 0]/np.max(xy[:, 0]) * x_max +xyz[:, 1] = xy[:, 1]/np.max(xy[:, 1]) * y_max - 0.017 +xyz *= 1000 +s_obj = SourceObj('Basic', np.array(xyz), radius_min=50, + text=[electrode_mapping[i] for i in selected_channels], text_size=10., text_color='black', text_bold=True) +s_obj.color_sources(data=reg_labels, cmap='plasma') +b_obj = BrainObj('B2') +vb = Brain(source_obj=s_obj, brain_obj=b_obj, bgcolor='white') +vb.brain_control(translucent=True) +vb.menuDispCbar.setChecked(True) +vb.rotate('axial_0') +#vb.show() +vb.screenshot(results_base_dir + "region_mapping.png", autocrop=True, print_size=(24, 24), unit='centimeter', dpi=300.)