build_template.py

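"""
Build representative heartbeat templates from pickled, beat-segmented ECG
recordings sampled at FREQ Hz: beats are cropped around their R peak,
optionally normalized, then combined either by point-wise averaging, by
picking the beat closest to all the others, or by MeanShift clustering
(one template per cluster).
"""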
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth

data_beats_dir = "../data/beats/"
FREQ = 128       # sampling frequency of the recordings, in Hz
MIN_TO_AVG = 5   # minutes of signal used to build the template
def open_file(file_name):
    '''
    Load one pickled beats file.

    Data structure:
    File:
     |
     DICTIONARY:
      |
      --> Beat_annot:
           |
           --> lvl:
                |
                --> "t": []
                |
                --> "v": []
    '''
    file_name_full = os.path.join(data_beats_dir, os.path.basename(file_name))
    with open(file_name_full, "rb") as f:
        data = pickle.load(f)
    # Keep only the first level ("lvl") entry for every beat annotation
    for k in data.keys():
        data[k] = data[k][0]
    return data
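# Illustrative shape of the returned dict after the [0] unwrapping above
# (the annotation keys shown here are hypothetical sample indices):
#   {12873: {'t': [...], 'v': [...]}, 13012: {'t': [...], 'v': [...]}, ...}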
def min_max_normalization(vector):
    # Scale to [0, 1], then remove the mean so the beat is centered around 0
    mi_v = min(vector)
    ma_v = max(vector)
    norm = (np.array(vector) - mi_v) / (ma_v - mi_v)
    norm -= np.average(norm)
    return list(norm)
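# Worked example (values are illustrative):
#   min_max_normalization([2, 4, 6]) -> scaled to [0, 0.5, 1], mean 0.5
#   -> returns [-0.5, 0.0, 0.5]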
def get_significant_peak(signal):
    # A peak is a strict local maximum, or a two-sample plateau that then drops
    is_peak = lambda w, x, y, z: (w < x and x > y) or (w < x and x == y and y > z)
    val = -np.inf
    peak_max = None
    for i in range(1, len(signal) - 2):
        if is_peak(signal[i-1], signal[i], signal[i+1], signal[i+2]) and signal[i] > val:
            peak_max = i
            val = signal[i]
    return peak_max
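# Example: get_significant_peak([0, 1, 3, 2, 0]) -> 2
# (index 2 is the only local maximum, hence the highest one found)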
def get_beats(data, t_start_seconds=0, t_stop_seconds=MIN_TO_AVG*60):
    # Keep the beats whose annotation (in samples) falls inside the time window
    beats = {}
    annotations = list(data.keys())
    for annot in annotations[1:]:  # the first annotation is skipped
        if annot >= int(t_start_seconds*FREQ) and annot <= int(t_stop_seconds*FREQ):
            beats[annot] = data[annot]
        elif annot > int(t_stop_seconds*FREQ):
            break
    return beats
def get_r_peaks_idx(beats):
    # For every beat, look for the R peak in a +/- 150 ms window around the annotation
    r_times_idx = []
    time_to_look = int(0.15*FREQ)  # 150 ms expressed in samples
    for annot in beats.keys():
        idx = 0
        signal_r_peaks = []
        for i, t in enumerate(beats[annot]['t']):
            if t == annot - time_to_look:
                idx = i  # index where the search window starts
            if t >= annot - time_to_look and t <= annot + time_to_look:
                signal_r_peaks.append(beats[annot]['v'][i])
        peak = get_significant_peak(signal_r_peaks)
        if peak is None:  # no local maximum in the window: fall back to its maximum
            peak = int(np.argmax(signal_r_peaks))
        r_times_idx.append(peak + idx)
    return r_times_idx
def allign_beats(data, allignment_idxs, normalize=True):
    # First pass: find the shortest number of samples available before and
    # after the alignment point across all beats
    len_bef = len_aft = 1000
    data_alligned = {}
    for annot, allignment_idx in zip(data.keys(), allignment_idxs):
        time = data[annot]['t']
        this_len_bef = len(time[:allignment_idx])
        this_len_aft = len(time[allignment_idx+1:])
        if this_len_bef < len_bef:
            len_bef = this_len_bef
        if this_len_aft < len_aft:
            len_aft = this_len_aft
    # Second pass: crop every beat to the common window around its R peak
    for annot, allignment_idx in zip(data.keys(), allignment_idxs):
        new_t = data[annot]['t'][allignment_idx-len_bef:allignment_idx+len_aft+1]
        new_v = data[annot]['v'][allignment_idx-len_bef:allignment_idx+len_aft+1]
        if normalize:
            new_v = min_max_normalization(new_v)
        data_alligned[annot] = {'t': new_t, 'v': new_v}
    return data_alligned
def RMSE(v1, v2):
    v1_n = np.array(v1)
    v2_n = np.array(v2)
    return np.sqrt(np.mean((v1_n - v2_n)**2))
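# Example: RMSE([0, 0], [3, 4]) = sqrt((9 + 16) / 2) ~= 3.54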
def compute_template_lowest_distance(beats):
    # Pairwise RMSE between all beats; the template is the beat whose summed
    # distance to every other beat is the lowest (the medoid)
    dist = np.zeros((len(beats), len(beats)))
    for i in range(len(beats)):
        for j in range(i+1, len(beats)):
            dist[i][j] = RMSE(beats[i], beats[j])
            dist[j][i] = dist[i][j]
    dist = np.sum(dist, axis=0)
    idx = np.argmin(dist)
    return beats[idx]
def build_template(file_analyzed, normalize=True, mode='average', t_start_seconds=0, t_stop_seconds=MIN_TO_AVG*60):
    template = []
    beats_alligned = []
    data = open_file(file_analyzed)
    beats = get_beats(data, t_start_seconds=t_start_seconds, t_stop_seconds=t_stop_seconds)
    allignment_idxs = get_r_peaks_idx(beats)
    alligned_beats = allign_beats(beats, allignment_idxs, normalize=normalize)
    for annot in alligned_beats.keys():
        beats_alligned.append(alligned_beats[annot]['v'])
    if mode == 'average':
        # Template as the point-wise average of all aligned beats
        template = list(np.average(beats_alligned, axis=0))
    elif mode == 'distance':
        # Template as the aligned beat closest to all the others
        template = compute_template_lowest_distance(beats_alligned)
    template_std = list(np.std(beats_alligned, axis=0))
    return template, template_std, alligned_beats
def multi_template(file_analyzed, normalize=True, t_start_seconds=0, t_stop_seconds=MIN_TO_AVG*60):
    # Cluster the aligned beats with MeanShift and return one template per cluster
    beats_alligned = []
    data = open_file(file_analyzed)
    beats = get_beats(data, t_start_seconds=t_start_seconds, t_stop_seconds=t_stop_seconds)
    allignment_idxs = get_r_peaks_idx(beats)
    alligned_beats = allign_beats(beats, allignment_idxs, normalize=normalize)
    for annot in alligned_beats.keys():
        beats_alligned.append(alligned_beats[annot]['v'])
    ms = MeanShift()
    ms.fit(beats_alligned)
    cluster_centers = ms.cluster_centers_.tolist()
    return cluster_centers, alligned_beats
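# Note: estimate_bandwidth is imported but not used above. Should the default
# MeanShift bandwidth prove too coarse, one option (a sketch, not part of the
# original code) would be:
#   bw = estimate_bandwidth(np.array(beats_alligned), quantile=0.3)
#   ms = MeanShift(bandwidth=bw)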
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", help="Analyze one specific file instead of the default one (first found)")
    parser.add_argument("--not_norm", help="Do NOT normalize each beat", action="store_true")
    args = parser.parse_args()
    files = os.listdir(data_beats_dir)
    if args.file is not None:
        # Pick the first file whose name contains the given string
        analyzed = list(filter(lambda string: args.file in string, files))[0]
    else:
        analyzed = files[0]
    normalize = not args.not_norm
    template_dist, std, alligned = build_template(analyzed, normalize=normalize, mode='distance')
    template_avg, _, _ = build_template(analyzed, normalize=normalize, mode='average')
    # First figure: both templates, a +/- one-standard-deviation band around
    # the distance template, and every aligned beat as a faint trace
    plt.plot(template_dist, color='C0')
    plt.plot(template_avg, color='C3')
    plt.plot(np.array(template_dist) + np.array(std), color="C1")
    plt.plot(np.array(template_dist) - np.array(std), color="C1")
    for beat in alligned.keys():
        plt.plot(alligned[beat]['v'], color='C2', alpha=0.01)
    # Second figure: the two templates alone
    plt.figure()
    plt.plot(template_dist, color='C0')
    plt.plot(template_avg, color='C3')
    plt.show()
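# Example invocation (assuming pickled beat files exist under ../data/beats/):
#   python build_template.py --file <name_substring> --not_norm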
