Page MenuHomec4science

region_average.py
No OneTemporary

File Metadata

Created
Sun, May 22, 06:06

region_average.py

import os
import json
import hashlib
import csv
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import mne
from scipy.stats import mannwhitneyu, shapiro
import seaborn as sns
from statsmodels.stats.multitest import multipletests
from visbrain.objects import SourceObj, BrainObj, TopoObj
from visbrain.gui import Brain, Figure
from functions import get_biosemi_to_1020_mapping, compute_connectivity_matrix, get_biosemi128_region_mapping, xy_1020_coord
# Pairs of regions between which connectivity is averaged.
# Region codes used in this script: 'F', 'O', 'LT', 'RT', 'CP'.
regions = [['LT', 'RT'], ['F', 'O']]

# Parameter file driving the whole analysis. The alternatives are kept as
# comments so another band/condition can be selected by moving the marker
# (the _R/_P suffix presumably stands for responder/proposer - confirm).
#param_file = "npm_8-13Hz_0.7s_R.json"
param_file = "npm_8-13Hz_0.7s_P.json"
#param_file = "npm_14-30Hz_0.7s_R.json"
#param_file = "npm_14-30Hz_0.7s_P.json"
with open(param_file) as json_file:
    param_dict = json.load(json_file)

# Subject ids are sorted as strings; later code splits groups by position
# (ids starting with '0' first, then ids starting with '1').
subjects = np.array(sorted(param_dict['subjects'])).astype(str)
do_rename_channels = True
results_base_dir = param_dict['result_dir']
data_base_dir = param_dict['datapath']
# Recording used only to obtain channel names and electrode positions.
reference_raw_file = data_base_dir + "000_P1.vhdr"
conditions = param_dict['conditions']
methods = param_dict['methods']
franges = param_dict['frequency_band']
tmin_tmax = param_dict['tmin_tmax']
csd_params = param_dict['CSD_parameters']
add_label = param_dict['add_label']
min_edge_pvals = []
# Epoch limits forwarded to compute_connectivity_matrix.
epochs_tmin = -2
epochs_tmax = 4
# Electrode coordinates; filled lazily the first time a channel subset is processed.
xyz = None
region_mapping = get_biosemi128_region_mapping()
# Load one raw file to get electrode information (names, positions, ...).
reference_raw = mne.io.read_raw_brainvision(reference_raw_file, preload=False)
if do_rename_channels:
    # Map the first 128 channels, in file order, onto the labels A1..A32,
    # B1..B32, C1..C32, D1..D32; only channels whose name differs are renamed.
    names_dict = {}
    ch_ind = 0
    for prefix in ['A', 'B', 'C', 'D']:
        for i in range(0, 32):
            name = prefix + str(i + 1)
            current_name = reference_raw.info['ch_names'][ch_ind]
            if name != current_name:
                names_dict[current_name] = name
            ch_ind += 1
    reference_raw.rename_channels(names_dict)
    # Keep a record of the renaming. The context manager guarantees the file
    # is flushed and closed; newline="" is what the csv module expects.
    with open(results_base_dir + "electrodes_dict.csv", "w", newline="") as csv_file:
        w = csv.writer(csv_file)
        for key, val in names_dict.items():
            w.writerow([key, val])

# Channels whose name starts with 'E' are typed as EOG, everything else as EEG.
types_dict = {name: ('eeg' if name[0] != 'E' else 'eog')
              for name in reference_raw.info['ch_names']}
reference_raw.set_channel_types(types_dict)
picks = mne.pick_types(reference_raw.info, eeg=True)
# Number of EEG channels; used as the size of the per-subject connectivity matrices.
n_all_el = picks.shape[0]
my_montage = mne.channels.read_custom_montage("biosemi128.sfp")
reference_raw.set_montage(my_montage)
print("Reference raw loaded")
# Sweep over every combination of region pair, condition, frequency band, time
# window, CSD setting, connectivity method and electrode subset. For each one:
# load (or compute) the per-subject connectivity matrices, average the
# connections linking the two regions, compare the groups with a Mann-Whitney U
# test and save histograms (plus topographic edge plots when pval < 0.05).
# Rows of [pval, freq, window, condition, region].
pvals_results = []
# Rows of [pval, freq, window, condition, region, group] from the Shapiro-Wilk tests.
sw_results = []
# The output folder must exist before the first figure is saved into it.
histogram_dir = results_base_dir + "histograms/"
os.makedirs(histogram_dir, exist_ok=True)
for region_pair in regions:
    region1 = region_pair[0]
    region2 = region_pair[1]
    region_label = region1 + "-" + region2
    for condition in conditions:
        for f_range in franges:
            f_range_low = f_range[0]
            f_range_high = f_range[1]
            freq_label = str(f_range[0]) + "-" + str(f_range[1])
            for tmin_tmax_ins in tmin_tmax:
                # A missing window falls back to -0.2 .. 0.5 s.
                if tmin_tmax_ins is None:
                    conn_tmin = -0.2
                    conn_tmax = 0.5
                else:
                    conn_tmin = tmin_tmax_ins[0]
                    conn_tmax = tmin_tmax_ins[1]
                window_label = str(conn_tmin) + "_" + str(conn_tmax)
                for csd_param in csd_params:
                    do_csd = csd_param is not None
                    if do_csd:
                        stiffnes = csd_param[0]
                        lambd = csd_param[1]
                    for method in methods:
                        connectivity_full = np.empty(shape=(n_all_el, n_all_el, len(subjects)))
                        labels_true = []
                        subj_pos = 0
                        n_controls = 0
                        n_patients = 0
                        # Suffix identifying this configuration in file names.
                        suffix_base = "_" + str(f_range_low) + "-" + str(f_range_high) + "_" + method + \
                                      "_T_" + str(conn_tmin) + "_" + str(conn_tmax) + "s"
                        if do_csd:
                            suffix_base += "_csd" + str(stiffnes) + "_" + str(lambd)
                        else:
                            suffix_base += "_nocsd"
                        if condition == "proposer":
                            suffix_base += "P"
                        elif condition == "responder":
                            suffix_base += "R"
                        for subj in subjects:
                            # Ids starting with '0' are labelled 1 (controls, "HC"),
                            # ids starting with '1' are labelled 2 (patients, "FEP").
                            if subj[0] == '0':
                                n_controls += 1
                                labels_true.append(1)
                            elif subj[0] == '1':
                                n_patients += 1
                                labels_true.append(2)
                            else:
                                # Previously such an id silently reused the file
                                # name of the preceding subject.
                                raise ValueError("Unrecognised subject group prefix: " + subj)
                            if condition == "responder":
                                file_name = subj + "_R1"
                            elif condition == "proposer":
                                file_name = subj + "_P1"
                            else:
                                raise ValueError("Unrecognised condition" + condition)
                            ############################# Loading/computing connectivity for each subject #########################
                            conn_file = results_base_dir + "conn_matrices/" + 'con' + file_name + suffix_base + '.npy'
                            if os.path.exists(conn_file):
                                print("Reading precomputed connectivity matrix:\t" + conn_file)
                                c_matrix = np.load(conn_file)
                            else:
                                c_matrix = compute_connectivity_matrix(f_range, tmin_tmax_ins, csd_param, method,
                                                                       condition, subj, epochs_tmin, epochs_tmax,
                                                                       results_base_dir + "conn_matrices/",
                                                                       data_base_dir)
                            connectivity_full[:, :, subj_pos] = c_matrix
                            subj_pos += 1
                        labels_true = np.array(labels_true).astype(int)
                        # Getting subset of electrodes
                        for el_set_index, selected_channels_param in enumerate(param_dict['electrode_subset']):
                            if isinstance(selected_channels_param, list):
                                selected_channels = selected_channels_param
                            elif isinstance(selected_channels_param, str):
                                if selected_channels_param == 'All':
                                    # A1..A32, B1..B32, C1..C32, D1..D32
                                    selected_channels = [prefix + str(number)
                                                         for prefix in ['A', 'B', 'C', 'D']
                                                         for number in range(1, 33)]
                                elif selected_channels_param == '32el10-20':
                                    electrode_mapping = get_biosemi_to_1020_mapping(32)
                                    selected_channels = list(electrode_mapping.keys())
                                elif selected_channels_param == '64el10-20':
                                    electrode_mapping = get_biosemi_to_1020_mapping(64)
                                    selected_channels = list(electrode_mapping.keys())
                                else:
                                    raise ValueError("Wrong channel selection parameter: should be a list of "
                                                     "channel labels, '32el10-20', '64el10-20' or 'All'")
                            else:
                                raise ValueError("Wrong channel selection parameter: should be a list of "
                                                 "channel labels, '32el10-20', '64el10-20' or 'All'")
                            # Electrode names conversion from biosemi ABCD to 10-20 convention
                            if selected_channels_param == '32el10-20' or selected_channels_param == '64el10-20':
                                el_labels = [electrode_mapping[key] for key in selected_channels]
                            else:
                                el_labels = selected_channels
                            # The short hash keeps file names unique per channel selection.
                            suffix = suffix_base + "_" + str(len(selected_channels)) + "el_"
                            text_el = "".join(selected_channels)
                            suffix += hashlib.sha256(text_el.encode()).hexdigest()[:8]
                            suffix += add_label
                            n_channels = len(selected_channels)
                            selected_ch_ind = [mne.pick_channels(reference_raw.info['ch_names'], include=[ch])[0]
                                               for ch in selected_channels]
                            connectivity_subset = connectivity_full[selected_ch_ind, :, :][:, selected_ch_ind, :]
                            # NOTE(review): coordinates are computed only once, for the first
                            # electrode subset processed; with several subsets of different
                            # content xyz/xy would be stale - confirm this is intended.
                            if xyz is None:
                                xyz = np.empty((n_channels, 3))
                                xy = np.empty((n_channels, 2))
                                for i in range(n_channels):
                                    pick = mne.pick_channels(reference_raw.info['ch_names'], [selected_channels[i]])
                                    xyz[i, :] = reference_raw.info['chs'][pick[0]]['loc'][0:3]
                                    if selected_channels_param == '32el10-20':
                                        xy[i, :] = xy_1020_coord(electrode_mapping[selected_channels[i]])
                                    else:
                                        xy[i, :] = xyz[i, 0:2]
                            # Sum connections between two regions of interest
                            avg_conn_abs = np.zeros(len(subjects))
                            avg_conn_sgn = np.zeros(len(subjects))
                            select = np.zeros((n_channels, n_channels)).astype(bool)
                            avg_conn_matrix_HC = np.zeros((n_channels, n_channels))
                            avg_conn_matrix_FEP = np.zeros((n_channels, n_channels))
                            n_connections = 0
                            # Group slices rely on the sorted subject order: the first
                            # n_controls entries are controls, the rest patients.
                            for i, el1_name in enumerate(selected_channels):
                                for j, el2_name in enumerate(selected_channels[:i]):
                                    # Fill both (i, j) and (j, i) to keep the matrices symmetric.
                                    avg_conn_matrix_HC[(i, j), (j, i)] = np.mean(
                                        connectivity_subset[i, j, :n_controls])
                                    avg_conn_matrix_FEP[(i, j), (j, i)] = np.mean(
                                        connectivity_subset[i, j, n_controls:])
                                    if (region_mapping[el1_name] == region1 and region_mapping[el2_name] == region2) or \
                                            (region_mapping[el2_name] == region1 and region_mapping[el1_name] == region2):
                                        avg_conn_abs += np.abs(connectivity_subset[i, j, :])
                                        avg_conn_sgn += connectivity_subset[i, j, :]
                                        n_connections += 1
                                        select[(i, j), (j, i)] = True
                            if n_connections == 0:
                                # Nothing links the two regions in this subset: the averages
                                # would be 0/0 and the tests meaningless, so skip it.
                                print("No connections between " + region1 + " and " + region2 +
                                      " for " + suffix[1:] + ", skipping")
                                continue
                            # Average connection strength
                            avg_conn_sgn = avg_conn_sgn / n_connections
                            avg_conn_abs = avg_conn_abs / n_connections
                            # Between groups test
                            sw_stat_HC, sw_pval_HC = shapiro(avg_conn_sgn[:n_controls])
                            sw_stat_FEP, sw_pval_FEP = shapiro(avg_conn_sgn[n_controls:])
                            stat, pval = mannwhitneyu(avg_conn_sgn[:n_controls], avg_conn_sgn[n_controls:])
                            pvals_results.append([pval, freq_label, window_label, condition, region_label])
                            sw_results.append([sw_pval_FEP, freq_label, window_label, condition, region_label, "FEP"])
                            sw_results.append([sw_pval_HC, freq_label, window_label, condition, region_label, "HC"])
                            if pval < 0.05:
                                ed_cmap = 'RdBu_r'
                                tp_cmap = "Greys"
                                tp_point = 0.3
                                cmin = np.minimum(np.min(avg_conn_matrix_FEP[select]),
                                                  np.min(avg_conn_matrix_HC[select]))
                                cmax = np.maximum(np.max(avg_conn_matrix_FEP[select]),
                                                  np.max(avg_conn_matrix_HC[select]))
                                # Symmetric colour limits rounded to two decimals.
                                c_lim = float(format(np.maximum(np.abs(cmin), cmax), '.2f'))
                                cmin = -1 * c_lim
                                cmax = c_lim
                                kw_top = dict(margin=15 / 100, chan_offset=(0., -0.81, 0.), chan_size=15,
                                              system='cartesian', cbtxtsz=60, txtsz=155.)
                                # One topographic edge plot per group; el_labels (not
                                # electrode_mapping) so non-10-20 subsets work too.
                                for group_matrix, group_file in ((avg_conn_matrix_HC, "_HC_edge.png"),
                                                                 (avg_conn_matrix_FEP, "_FEP_edge.png")):
                                    t_obj = TopoObj('topo', tp_point * np.ones(n_channels),
                                                    channels=el_labels,
                                                    xyz=xy, cmap=tp_cmap, **kw_top, clim=(0, 1))
                                    t_obj.connect(group_matrix, select=select, cmap=ed_cmap,
                                                  antialias=True, line_width=4., clim=(cmin, cmax))
                                    t_obj.screenshot(results_base_dir + suffix_base[1:] + group_file,
                                                     autocrop=True, dpi=600., bgcolor='white')
                                fig_agg = Figure([results_base_dir + suffix_base[1:] + "_FEP_edge.png",
                                                  results_base_dir + suffix_base[1:] + "_HC_edge.png"],
                                                 titles=['Patients', 'Controls'], figtitle='Group average',
                                                 xlabels=[None, None], ylabels=[None, None], grid=[1, 2],
                                                 ax_bgcolor='white', fig_bgcolor='white',
                                                 subspace={'left': 0., 'right': 1., 'bottom': 0., 'top': .9,
                                                           'wspace': 0., 'hspace': 0.07}, figsize=(12, 6),
                                                 text_color='black', autocrop=True)  # y=1.,
                                fig_agg.shared_colorbar((cmin, cmax), t_obj._connect._cmap,
                                                        fz_title=20, vmin=cmin, vmax=cmax,
                                                        under='olive', over='firebrick', position='right',
                                                        title='Average imcoh', fz_ticks=20,
                                                        pltmargin=.001, figmargin=.001)
                                fig_agg.save(results_base_dir + suffix_base[1:] + "_edge.png", dpi=300)
                            # One row per subject: averaged |connectivity|, signed connectivity, group.
                            avg_conn = pd.DataFrame([avg_conn_abs, avg_conn_sgn,
                                                     ["HC" if label == 1 else "FEP" for label in labels_true]],
                                                    index=["avg_abs", "avg_sgn", "Group"], columns=subjects).T
                            if pval < 0.05:
                                # Separate per-group histograms sharing the same axis limits.
                                for group in ("HC", "FEP"):
                                    fig = plt.figure()
                                    h = sns.histplot(avg_conn[avg_conn["Group"] == group], x="avg_sgn",
                                                     element="step", bins=15)
                                    h.set_title(
                                        "Time window " + str(conn_tmin) + " - " + str(conn_tmax) +
                                        "s, frequency " + str(f_range_low) + "-" + str(f_range_high) + "Hz,\n" +
                                        region1 + "-" + region2 + " connectivity",
                                        fontsize=12)
                                    h.set_ylabel("Number of participants", fontsize=18)
                                    h.set_xlabel("Average imcoh", fontsize=18)
                                    axes = plt.gca()
                                    axes.set_xlim([np.min(avg_conn['avg_sgn']), np.max(avg_conn['avg_sgn'])])
                                    axes.set_ylim([0, 7.2])
                                    plt.savefig(
                                        histogram_dir + group + "_avg_coh_sgn_hist_" + region_label +
                                        "_" + suffix[1:] + ".png", dpi=400)
                                    plt.close()
                                # Group mean and standard deviation of the signed average.
                                print(suffix[1:] + "\nFEP:\t")
                                print(np.mean(avg_conn[avg_conn["Group"] == 'FEP']['avg_sgn']))
                                print(np.std(avg_conn[avg_conn["Group"] == 'FEP']['avg_sgn']))
                                print(suffix[1:] + "\nHC:\t")
                                print(np.mean(avg_conn[avg_conn["Group"] == 'HC']['avg_sgn']))
                                print(np.std(avg_conn[avg_conn["Group"] == 'HC']['avg_sgn']))
                            print("Results saved at: " + results_base_dir + "histograms/")
                            h = sns.histplot(avg_conn, x="avg_abs", hue="Group", element="step")
                            h.set_title("Time window " + str(conn_tmin) + " - " + str(conn_tmax) + "s, frequency "
                                        + str(f_range_low) + "-" + str(f_range_high) + "Hz, "
                                        + region1 + " - " + region2 + "connectivity", fontsize=12)
                            h.set_ylabel("Number of participants)", fontsize=10)
                            h.set_xlabel("Average |imcoh|", fontsize=10)
                            plt.savefig(histogram_dir + "avg_coh_abs_" + region_label + "_" + suffix[1:] + ".png",
                                        dpi=400)
                            plt.close()
                            h = sns.histplot(avg_conn, x="avg_sgn", hue="Group", element="step")
                            h.set_title("Time window " + str(conn_tmin) + " - " + str(conn_tmax) + "s, frequency "
                                        + str(f_range_low) + "-" + str(f_range_high) + "Hz,\n"
                                        + region1 + "-" + region2 + " connectivity, pvalue = " + '%.3f' % pval,
                                        fontsize=12)
                            h.set_ylabel("Number of participants", fontsize=10)
                            h.set_xlabel("Average imcoh", fontsize=10)
                            # Put most interesting plots on the top: a leading "0" sorts
                            # significant results first in a directory listing.
                            mrk = "0" if pval < 0.05 else ""
                            plt.savefig(histogram_dir + mrk + "avg_coh_sgn_" + region_label + "_" + suffix[1:] +
                                        ".png", dpi=400)
                            plt.close()

# Collect all p-values, add the Benjamini-Hochberg corrected column and save
# both the group-comparison and the normality-test tables.
pvals_results = pd.DataFrame(pvals_results, columns=['Pval', 'Freq', 'Window', 'Condition', 'Region'])
pvals_results = pvals_results.assign(Pval_corr=multipletests(pvals_results['Pval'], method='fdr_bh')[1])
pvals_results.to_csv(histogram_dir + "avg_coh_sgn_pvalues" + param_file[:-5] + ".csv")
sw_pval = pd.DataFrame(sw_results, columns=["pval", "freq", "twin", "condition", "regions", "Group"])
sw_pval.to_csv(histogram_dir + "avg_coh_sgn_gaussianity_test" + param_file[:-5] + ".csv")
# For every condition and frequency band, plot -log10(p) against the centre of
# the time window, one line per region pair: first the uncorrected p-values,
# then the FDR-corrected ones.
for condition in conditions:
    for f_range in franges:
        # Work on a copy: assigning columns on a slice of pvals_results triggers
        # pandas' SettingWithCopyWarning and may not modify what is plotted.
        pvals_results_subset = pvals_results[(pvals_results['Condition'] == condition) &
                                             (pvals_results['Freq'] == str(f_range[0]) + "-" + str(f_range[1]))].copy()
        pvals_results_subset['Pval'] = pvals_results_subset['Pval'].apply(lambda x: -1 * np.log10(x))
        pvals_results_subset['Pval_corr'] = pvals_results_subset['Pval_corr'].apply(lambda x: -1 * np.log10(x))
        # "tmin_tmax" label -> midpoint of the window.
        pvals_results_subset['Window'] = pvals_results_subset['Window'].apply(
            lambda x: 0.5 * float(x.split("_")[1]) + 0.5 * float(x.split("_")[0]))

        # --- uncorrected p-values ---
        sns.set(rc={'figure.figsize': (7, 5.5)})
        sns.set_style('white')
        b = sns.lineplot(x="Window", y="Pval", hue="Region", style="Region",
                         dashes=False, markers=True, data=pvals_results_subset)
        # Title uses the band of this iteration (not the value left over from the sweep).
        b.set_title("Frequency " + str(f_range[0]) + "-" + str(f_range[1]) + "Hz, " +
                    condition + " condition", fontsize=12)
        sns.despine(offset=-5, trim=False)
        b.set_ylabel("-log (non-corrected pval)", fontsize=10)
        # NOTE(review): conn_tmin/conn_tmax are whatever the sweep left behind,
        # i.e. this assumes all windows have the same length - confirm.
        b.set_xlabel("Time window" + r'$ (\pm $' + '%d' % (500 * (conn_tmax - conn_tmin)) + "ms)", fontsize=10)
        b.tick_params(axis='y', labelsize=10)
        b.tick_params(axis='x', labelsize=10, rotation=30)
        # Reference lines at p = 0.05 and p = 0.01.
        # NOTE(review): tmin_tmax[1] requires at least two (non-None) windows.
        plt.axhline(-1 * np.log10(0.05), color="xkcd:light gray", ls='--')
        plt.text(tmin_tmax[-1][0], -1 * np.log10(0.05) + 0.05, "p=0.05", fontsize=8)
        plt.axhline(-1 * np.log10(0.01), color="xkcd:light gray", ls='--')
        plt.text(tmin_tmax[1][0], -1 * np.log10(0.01) + 0.05, "p=0.01", fontsize=8)
        plt.savefig(results_base_dir + "region_pval_for_windows_" + str(f_range[0]) + "-" + str(f_range[1]) +
                    "Hz_" + condition + ".png", dpi=400)
        plt.close()

        # --- FDR-corrected p-values ---
        sns.set(rc={'figure.figsize': (7, 5.5)})
        sns.set_style('white')
        b = sns.lineplot(x="Window", y="Pval_corr", hue="Region", style="Region", dashes=False, markers=True,
                         data=pvals_results_subset)
        b.set_title("Frequency " + str(f_range[0]) + "-" + str(f_range[1]) + "Hz, " + condition + " condition",
                    fontsize=12)
        sns.despine(offset=-5, trim=False)
        b.set(ylim=(0, 1.75))
        b.set_ylabel("-log (corrected pval)", fontsize=10)
        b.set_xlabel("Time window" + r'$ (\pm $' + '%d' % (500 * (conn_tmax - conn_tmin)) + "ms)", fontsize=10)
        b.tick_params(axis='y', labelsize=10)
        b.tick_params(axis='x', labelsize=10, rotation=30)
        plt.axhline(-1 * np.log10(0.05), color="xkcd:light gray", ls='--')
        plt.text(tmin_tmax[-1][0], -1 * np.log10(0.05) + 0.02, "p=0.05", fontsize=8)
        # Vertical marker at time 0.
        plt.axvline(0, color="xkcd:light grey blue", ls='--')
        # NOTE(review): this label sits at y = 2.05, above the ylim set above,
        # and has no matching horizontal line - confirm it is wanted.
        plt.text(tmin_tmax[1][0], -1 * np.log10(0.01) + 0.05, "p=0.01", fontsize=8)
        plt.savefig(results_base_dir + "region_corr_pval_for_windows_" + str(f_range[0]) + "-" + str(
            f_range[1]) + "Hz_" + condition + ".png", dpi=400)
        plt.close()
# Render the electrodes coloured by region on a brain template and save the
# image. Uses selected_channels, el_labels, xyz and xy as left by the last
# electrode subset processed in the sweep above.
region_ind = {'LT': 1, 'RT': 2, 'CP': 3, 'F': 4, 'O': 5}
# Numeric region code per selected electrode, used as the colour value.
reg_labels = np.zeros(len(selected_channels))
for i, el_name in enumerate(selected_channels):
    reg_labels[i] = region_ind[region_mapping[el_name]]
# Replace the x/y coordinates by the 2-D layout, rescaled to the original
# extent of the 3-D positions (xyz is modified in place).
x_max = np.max(np.abs(xyz[:, 0]))
y_max = np.max(np.abs(xyz[:, 1]))
xyz[:, 0] = xy[:, 0] / np.max(xy[:, 0]) * x_max
# The -0.017 offset presumably aligns the layout with the brain template - confirm.
xyz[:, 1] = xy[:, 1] / np.max(xy[:, 1]) * y_max - 0.017
# Presumably a metres -> millimetres conversion for visbrain - confirm.
xyz *= 1000
# el_labels holds the display names for every kind of electrode subset;
# electrode_mapping only exists for the 10-20 subsets.
s_obj = SourceObj('Basic', np.array(xyz), radius_min=50,
                  text=el_labels, text_size=10., text_color='black', text_bold=True)
s_obj.color_sources(data=reg_labels, cmap='plasma')
b_obj = BrainObj('B2')
vb = Brain(source_obj=s_obj, brain_obj=b_obj, bgcolor='white')
vb.brain_control(translucent=True)
vb.menuDispCbar.setChecked(True)
vb.rotate('axial_0')
#vb.show()
vb.screenshot(results_base_dir + "region_mapping.png", autocrop=True, print_size=(24, 24), unit='centimeter',
              dpi=300.)

Event Timeline