Page MenuHomec4science

functions.py
No OneTemporary

File Metadata

Created
Thu, Mar 28, 12:16

functions.py

import mne as mne
import numpy as np
import h5py
from statsmodels.stats.multitest import multipletests
from scipy.stats import ttest_ind
import matplotlib.pyplot as plt
from itertools import combinations
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
def compute_connectivity_matrix(f_range, tmin_tmax_range, csd_param, method,
                                condition, subj, epochs_tmin, epochs_tmax,
                                results_base_dir, data_base_dir,
                                scores_range=None, do_rename_channels=True):
    """
    Compute the band-averaged connectivity matrix of one recording and save it.

    The BrainVision recording ``<subj>_P1`` / ``<subj>_R1`` is read from
    ``data_base_dir``, low-pass filtered at 40 Hz, average-referenced,
    optionally converted to current source density and epoched around the
    start marker of the condition. Connectivity is computed over the epochs
    whose behavioral score lies inside ``scores_range`` and the result is
    written to an HDF5 file in ``results_base_dir``.

    Parameters
    ----------
    f_range : frequency range, tuple (float, float), in Hz
    tmin_tmax_range: time relative to the stimulus, tuple (float, float), in
        seconds; None selects (-0.2, 0.5)
    csd_param: CSD parameters, tuple (float, float), stiffness 2-5,
        lambda^2 0-10e-5; None disables the CSD transform
    method: connectivity computation method, str
    condition: "proposer" or "responder"
    subj: subject id, str starting with '0' or '1'
    epochs_tmin: start of the epochs relative to the marker, in seconds
    epochs_tmax: end of the epochs relative to the marker, in seconds
    results_base_dir: path prefix where to write the result
    data_base_dir: path prefix of the .vhdr / .vmrk files
    scores_range: behavioral scores range, which epochs to use for
        connectivity, tuple (int, int), both ends inclusive; None selects
        (1, 9)
    do_rename_channels: restore biosemi electrode names, bool

    Returns
    -------
    con[:, :, 0]: connectivity matrix, symmetrised by adding its transpose
    num_epochs: number of epochs used for the connectivity estimate
    b_scores: behavioral score of every epoch that survived ``drop_bad``
    b_accept: parity (0/1) of the first marker in 1..4 that follows each
        trial start (accepted/rejected offers)

    Raises
    ------
    ValueError
        If ``condition`` is not "proposer"/"responder" or ``subj`` does not
        start with '0' or '1'.
    """
    # Validate the arguments first so that a bad call fails before any file
    # is read, and with a meaningful error instead of an unbound local.
    if condition == "proposer":
        condition_tag = "P"
        event_name = 'Stimulus/S128'
        event_num = 128
        ev_offset = 192
    elif condition == "responder":
        condition_tag = "R"
        event_name = 'Stimulus/S160'
        event_num = 160
        ev_offset = 224
    else:
        raise ValueError("Unrecognised condition: " + str(condition))
    if not subj.startswith(('0', '1')):
        raise ValueError("Unrecognised subject id: " + str(subj))

    f_range_low, f_range_high = f_range[0], f_range[1]
    if tmin_tmax_range is None:
        conn_tmin, conn_tmax = -0.2, 0.5
    else:
        conn_tmin, conn_tmax = tmin_tmax_range[0], tmin_tmax_range[1]
    do_csd = csd_param is not None
    if do_csd:
        stiffnes, lambd = csd_param[0], csd_param[1]
    if scores_range is None:
        scores_range = (1, 9)

    print("Starting connectivity computation: " + condition + " condition, " +
          str(f_range_low) + "-" + str(f_range_high) + " Hz: " + method +
          ", tmin: " + str(conn_tmin) + "s, tmax: " + str(conn_tmax) + "s")

    # The suffix encodes every analysis parameter in the output file name.
    suffix_base = ("_" + str(f_range_low) + "-" + str(f_range_high) + "_" +
                   method + "_T_" + str(conn_tmin) + "_" + str(conn_tmax) +
                   "s_" + str(scores_range[0]) + "_" + str(scores_range[1]) +
                   "chf")
    if do_csd:
        suffix_base += "_csd" + str(stiffnes) + "_" + str(lambd)
    else:
        suffix_base += "_nocsd"
    suffix_base += condition_tag

    file_name = subj + "_" + condition_tag + "1"
    header_file = data_base_dir + file_name + ".vhdr"
    marker_file = data_base_dir + file_name + ".vmrk"
    print("Reading:\t" + header_file)
    raw = mne.io.read_raw_brainvision(header_file, preload=True)

    # rename channels to a 128biosemi layout names
    if do_rename_channels:
        biosemi_names = [prefix + str(i + 1)
                         for prefix in ('A', 'B', 'C', 'D')
                         for i in range(32)]
        names_dict = {}
        for ch_ind, name in enumerate(biosemi_names):
            current_name = raw.info['ch_names'][ch_ind]
            if name != current_name:
                names_dict[current_name] = name
        raw.rename_channels(names_dict)
        print("Updated " + str(len(names_dict)) + " electrodes names")

    # Resetting channel types: names starting with 'E' are EOG, the rest EEG
    types_dict = {name: ('eog' if name[0] == 'E' else 'eeg')
                  for name in raw.info['ch_names']}
    raw.set_channel_types(types_dict)
    # The montage file is looked up relative to the current working directory.
    my_montage = mne.channels.read_custom_montage("biosemi128.sfp")
    raw.set_montage(my_montage)
    raw.filter(l_freq=None, h_freq=40)

    annot = mne.read_annotations(marker_file)
    raw.set_annotations(annot)
    events, event_ids = mne.events_from_annotations(raw)
    # Only the trial-start markers of this condition define the epochs.
    events_start_only = mne.pick_events(events, include=event_num)
    event_ids_start_only = {key: event_ids[key]
                            for key in event_ids.keys() & {event_name}}
    sfreq = raw.info['sfreq']
    raw.set_eeg_reference('average', projection=True)
    if do_csd:
        raw_post = mne.preprocessing.compute_current_source_density(
            raw, stiffness=stiffnes, lambda2=lambd)
    else:
        raw_post = raw

    epochs = mne.Epochs(raw_post, events_start_only, event_ids_start_only,
                        tmin=epochs_tmin, tmax=epochs_tmax, picks='data')
    epochs.drop_bad()

    n_kept = epochs.events.shape[0]
    b_scores = np.zeros(n_kept, dtype=int)
    b_accept = np.zeros(n_kept, dtype=int)
    for nm_ev, ev in enumerate(epochs.events):
        # Markers from the start of this trial onwards, in recording order.
        ev_ind = np.where(events[:, 0] == ev[0])[0][0]
        later_codes = events[ev_ind:, 2]
        # First score marker (ev_offset + 1 .. ev_offset + 9) after the start.
        is_score = (later_codes >= ev_offset + 1) & (later_codes <= ev_offset + 9)
        b_scores[nm_ev] = later_codes[np.where(is_score)[0][0]] - ev_offset
        # First decision marker (1 .. 4) after the start; only parity is kept.
        is_decision = (later_codes >= 1) & (later_codes <= 4)
        b_accept[nm_ev] = later_codes[np.where(is_decision)[0][0]] % 2

    epochs_ind = np.where(
        (b_scores >= scores_range[0]) & (b_scores <= scores_range[1]))[0]
    selected_epochs = epochs[epochs_ind]
    num_epochs = len(selected_epochs)
    (con, freqs,
     times, n_epochs,
     n_tapers) = mne.connectivity.spectral_connectivity(
        selected_epochs, method=method,
        mode='multitaper', sfreq=sfreq,
        fmin=f_range_low, fmax=f_range_high,
        tmin=conn_tmin, tmax=conn_tmax,
        faverage=True, mt_adaptive=False,
        n_jobs=1)
    # Mirror the matrix by adding its transpose so both triangles are filled.
    con[:, :, 0] += con[:, :, 0].T

    # The context manager closes the file even if a dataset write fails.
    with h5py.File(results_base_dir + 'con' + file_name + suffix_base + '.h5', "w") as wf:
        wf.create_dataset('con_matrix', data=con[:, :, 0])
        wf.create_dataset('bsc_values', data=b_scores)
        wf.create_dataset('bsc_acc', data=b_accept)
    return con[:, :, 0], num_epochs, b_scores, b_accept
def compute_distances(el_subset='All', ref_data_file="data/000_P1.vhdr", vis=False):
    """
    Compute pairwise Euclidean distances between electrodes of the biosemi128
    montage.

    Parameters
    ----------
    el_subset: 'All' for the 128 electrodes A1..D32, or a list of electrode
        names; rows/columns of the result follow the order of this list
    ref_data_file: BrainVision header file used only for its channel list
    vis: if True, show the sensors in 3D (the subset is flagged as bad
        channels for the plot) and the distance matrix as an image

    Returns
    -------
    dist_m: array (n_electrodes, n_electrodes); only the strict lower
        triangle is filled (``dist_m[j, i]`` for ``j > i``), the rest is zero

    Raises
    ------
    ValueError
        If a name in ``el_subset`` is not a channel of the recording.
    """
    # Only channel metadata is needed, so the signal itself is not loaded.
    raw = mne.io.read_raw_brainvision(ref_data_file, preload=False)
    biosemi_names = [prefix + str(i + 1)
                     for prefix in ('A', 'B', 'C', 'D')
                     for i in range(32)]
    if isinstance(el_subset, str) and el_subset == 'All':
        el_subset = list(biosemi_names)
    n_electrodes = len(el_subset)

    # Restore the biosemi names on the first 128 channels of the recording.
    names_dict = {}
    for ch_ind, name in enumerate(biosemi_names):
        current_name = raw.info['ch_names'][ch_ind]
        if name != current_name:
            names_dict[current_name] = name
    raw.rename_channels(names_dict)
    print("Updated " + str(len(names_dict)) + " electrodes names")

    # Resetting channel types: names starting with 'E' are EOG, the rest EEG
    types_dict = {name: ('eog' if name[0] == 'E' else 'eeg')
                  for name in raw.info['ch_names']}
    raw.set_channel_types(types_dict)
    raw.set_montage('biosemi128')

    # Resolve the picks by name, in el_subset order, so that row i of the
    # matrix (and tick label i of the plot) always refers to el_subset[i].
    ch_names = list(raw.info['ch_names'])
    missing = [name for name in el_subset if name not in ch_names]
    if missing:
        raise ValueError("Electrodes not found in the recording: " + ", ".join(missing))
    picks = [ch_names.index(name) for name in el_subset]
    positions = [raw.info['chs'][pick]['loc'][0:3] for pick in picks]

    dist_m = np.zeros(shape=(n_electrodes, n_electrodes))
    for i in range(n_electrodes):
        for j in range(i + 1, n_electrodes):
            dist_m[j, i] = np.linalg.norm(positions[i] - positions[j])

    if vis:
        # Flagging the subset as "bad" makes plot_sensors highlight it.
        raw.info['bads'] = el_subset
        raw.plot_sensors(ch_type='eeg', kind='3d', show_names=True)
        plt.figure()
        plt.imshow(dist_m, cmap='hot')
        # Thin out the tick labels as the number of electrodes grows.
        step = max(1, int(np.log2(n_electrodes) / 2))
        plt.xticks(np.arange(n_electrodes, step=step), el_subset[::step], rotation=270)
        plt.yticks(np.arange(n_electrodes, step=step), el_subset[::step])
        plt.colorbar()
        plt.show()
    return dist_m
def select_el_subset_short(n_of_electrodes: int, connectivity_matrix, subjects, el_subset='All', reference_raw=None):
    """
    Select the electrodes taking part in the connections that differ most
    between the two subject groups.

    An independent two-sample t-test (first group minus second group) is run
    on every electrode pair. Pairs are visited from the largest t-statistic
    downwards and both electrodes of a pair are added to the selection until
    at least ``n_of_electrodes`` electrodes are collected, so the result may
    contain one electrode more than requested.

    Parameters
    ----------
    n_of_electrodes: desired number of n best electrodes
    connectivity_matrix: connectivity matrix: electrodes x electrodes x subjects
    subjects: subject ids, list of str; ids starting with '0' are counted as
        the first group and are assumed to occupy the first slices of the
        subject axis of ``connectivity_matrix``
    el_subset: which subset of electrodes to start from, 'All' or a list of
        electrode names
    reference_raw: reference eeg file for electrode information; when None,
        "data/000_P1.vhdr" is read. A passed object is modified in place
        (channels renamed, channel types reset).

    Returns
    -------
    result_subset: alphabetically sorted list of the selected electrode names

    Raises
    ------
    ValueError
        If more electrodes are requested than ``el_subset`` contains, if a
        name of ``el_subset`` is not a channel of the reference recording, or
        if too few electrode pairs have a finite t-statistic.
    """
    biosemi_names = [prefix + str(i + 1)
                     for prefix in ('A', 'B', 'C', 'D')
                     for i in range(32)]
    if isinstance(el_subset, str) and el_subset == 'All':
        el_subset = list(biosemi_names)
    if n_of_electrodes > len(el_subset):
        raise ValueError("Cannot select " + str(n_of_electrodes) + " electrodes out of " +
                         str(len(el_subset)))

    n_controls = sum(1 for subj in subjects if subj.startswith('0'))

    if reference_raw is None:
        reference_raw = mne.io.read_raw_brainvision("data/000_P1.vhdr", preload=False)
    # Restore the biosemi names on the first 128 channels of the recording.
    names_dict = {}
    for ch_ind, name in enumerate(biosemi_names):
        current_name = reference_raw.info['ch_names'][ch_ind]
        if name != current_name:
            names_dict[current_name] = name
    reference_raw.rename_channels(names_dict)
    types_dict = {name: ('eog' if name[0] == 'E' else 'eeg')
                  for name in reference_raw.info['ch_names']}
    reference_raw.set_channel_types(types_dict)

    # Resolve the picks by name, in el_subset order: index k of the selected
    # matrix is translated back to a name through el_subset[k] below.
    ch_names = list(reference_raw.info['ch_names'])
    missing = [name for name in el_subset if name not in ch_names]
    if missing:
        raise ValueError("Electrodes not found in the recording: " + ", ".join(missing))
    picks = [ch_names.index(name) for name in el_subset]
    n_channels = len(picks)

    cm_selected = connectivity_matrix[picks, :, :][:, picks, :]
    tstats, _ = ttest_ind(cm_selected[:, :, :n_controls], cm_selected[:, :, n_controls:], axis=2)

    # Rank each electrode pair once (strict lower triangle), ignoring pairs
    # whose statistic is undefined, from the largest t value downwards.
    rows, cols = np.tril_indices(n_channels, k=-1)
    pair_stats = np.asarray(tstats)[rows, cols]
    finite = ~np.isnan(pair_stats)
    rows, cols, pair_stats = rows[finite], cols[finite], pair_stats[finite]
    ranking = np.argsort(-pair_stats, kind='stable')

    result_subset = []
    for pair in ranking:
        if len(result_subset) >= n_of_electrodes:
            break
        for channel in (cols[pair], rows[pair]):
            name = el_subset[channel]
            if name not in result_subset:
                result_subset.append(name)
    if len(result_subset) < n_of_electrodes:
        raise ValueError("Could not select " + str(n_of_electrodes) +
                         " electrodes: too few electrode pairs have a finite t-statistic")
    return sorted(result_subset)
def plot_statistical_test_results(stats, pvalues, labels, writepath, title, show=False, left_title="Test results"):
    """
    Draw a test-statistic matrix and its p-value matrix side by side.

    Parameters
    ----------
    stats: matrix shown in the left panel
    pvalues: matrix shown in the right panel; its smallest non-NaN value is
        printed in the panel title
    labels: tick labels used on both axes of both panels
    writepath: where the figure is saved when ``show`` is False
    title: name given to the matplotlib figure
    show: display the figure instead of saving it
    left_title: title of the left panel
    """
    # Fewer tick labels are drawn as the number of labels grows.
    tick_step = max(1, int(np.log2(len(labels) + 1) / 2))
    figure = plt.figure(title, figsize=[12, 6])

    for panel, matrix in ((1, stats), (2, pvalues)):
        plt.subplot(1, 2, panel)
        if panel == 1:
            plt.title(left_title)
        else:
            plt.title("p values, min = " + '{0:.4f}'.format(np.nanmin(pvalues)))
        image = plt.imshow(matrix, cmap='hot')
        plt.xticks(np.arange(len(labels), step=tick_step), labels[::tick_step], rotation=270)
        plt.yticks(np.arange(len(labels), step=tick_step), labels[::tick_step])
        plt.colorbar(image, fraction=0.046, pad=0.04)

    plt.tight_layout(pad=1)
    if not show:
        plt.savefig(writepath, dpi=600)
        plt.close(figure)
        return
    plt.show()
    return
def get_biosemi_to_1020_mapping(num_of_electrodes):
    """
    Return a new dict translating biosemi128 labels to 10-20 labels.

    Parameters
    ----------
    num_of_electrodes: size of the 10-20 layout, 32 or 64

    Raises
    ------
    ValueError
        For any other value of ``num_of_electrodes``.
    """
    # Layouts are spelled as whitespace-separated "biosemi:10-20" entries.
    if num_of_electrodes == 32:
        layout = """
            C29:Fp1 C16:Fp2 C27:AF3 C14:AF4 D7:F7 D4:F3 C21:Fz
            C4:F4 C7:F8 D10:FC5 D2:FC1 C2:FC2 B29:FC6 D23:T7
            D19:C3 A1:Cz B22:C4 B26:T8 D26:CP5 D16:CP1 B2:CP2
            B16:CP6 D31:P7 A7:P3 A19:Pz B4:P4 B11:P8 A17:PO3
            A30:PO4 A15:O1 A23:Oz A28:O2
        """
    elif num_of_electrodes == 64:
        layout = """
            C29:Fp1 C17:Fpz C16:Fp2 C30:AF7 C27:AF3 C19:AFz C14:AF4
            C8:AF8 D7:F7 D5:F5 D4:F3 C25:F1 C21:Fz C12:F2
            C4:F4 C5:F6 C7:F8 D8:FT7 D10:FC5 D12:FC3 D2:FC1
            C23:FCz C2:FC2 B31:FC4 B29:FC6 B27:FT8 D23:T7 D21:C5
            D19:C3 D14:C1 A1:Cz B20:C2 B22:C4 B24:C6 B26:T8
            D24:TP7 D26:CP5 D28:CP3 D16:CP1 A3:CPz B2:CP2 B18:CP4
            B16:CP6 B14:TP8 D32:P9 D31:P7 D29:P5 A7:P3 A5:P1
            A19:Pz A32:P2 B4:P4 B13:P6 B11:P8 B10:P10 A10:PO7
            A16:PO3 A21:POz A29:PO4 B7:PO8 A15:O1 A23:Oz A28:O2
            A25:Iz
        """
    else:
        raise ValueError("Wrong parameter: should be 32 or 64")
    return dict(entry.split(':') for entry in layout.split())
def get_biosemi128_names():
    """Return a new list with the 128 biosemi labels: A1..A32, B1..B32, C1..C32, D1..D32."""
    return [bank + str(number) for bank in 'ABCD' for number in range(1, 33)]
def get_biosemi128_region_mapping():
    """
    Return a new dict assigning each biosemi128 label (A1..D32) a region code.

    The codes used are 'CP', 'O', 'RT', 'LT' and 'F' (presumably
    centro-parietal, occipital, right/left temporal and frontal -- the code
    itself does not spell them out).
    """
    # Within each bank, consecutive electrode numbers share a region; every
    # entry is (region code, how many consecutive electrodes carry it).
    runs_per_bank = (
        ('A', (('CP', 9), ('O', 8), ('CP', 3), ('O', 10), ('CP', 2))),
        ('B', (('CP', 6), ('O', 3), ('RT', 3), ('CP', 1), ('RT', 3),
               ('CP', 6), ('RT', 7), ('CP', 3))),
        ('C', (('CP', 3), ('F', 2), ('RT', 2), ('F', 3), ('CP', 1),
               ('F', 11), ('CP', 2), ('F', 8))),
        ('D', (('CP', 2), ('F', 3), ('LT', 5), ('CP', 9), ('LT', 7),
               ('CP', 3), ('LT', 3))),
    )
    region_of = {}
    for bank, runs in runs_per_bank:
        number = 0
        for region, count in runs:
            for _ in range(count):
                number += 1
                region_of[bank + str(number)] = region
    return region_of
def xy_1020_coord(el=None):
    """
    Look up the 2D (x, y) coordinates of the 32 10-20 electrodes.

    Parameters
    ----------
    el: 10-20 electrode name, or None

    Returns
    -------
    The (x, y) tuple of ``el``, or the whole name -> (x, y) dict when ``el``
    is None. An unknown name raises KeyError.
    """
    # One scalp row per line, listed from the front row to the back row.
    coordinates = {
        'Fp1': (-90, 277), 'Fp2': (90, 277),
        'AF3': (-109, 223), 'AF4': (109, 223),
        'F7': (-234, 170), 'F3': (-127, 154), 'Fz': (0, 146), 'F4': (127, 154), 'F8': (234, 170),
        'FC5': (-207, 80), 'FC1': (-68, 75), 'FC2': (68, 75), 'FC6': (207, 80),
        'T7': (-287, 0), 'C3': (-143, 0), 'Cz': (0, 0), 'C4': (143, 0), 'T8': (287, 0),
        'CP5': (-207, -80), 'CP1': (-68, -75), 'CP2': (68, -75), 'CP6': (207, -80),
        'P7': (-234, -170), 'P3': (-127, -154), 'Pz': (0, -146), 'P4': (127, -154), 'P8': (234, -170),
        'PO3': (-109, -223), 'PO4': (109, -223),
        'O1': (-90, -277), 'Oz': (0, -287), 'O2': (90, -277),
    }
    return coordinates if el is None else coordinates[el]

Event Timeline