build_template.py

import os
import pickle
import sys
import warnings
from copy import copy
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import medfilt as median_filter
from sklearn.cluster import AffinityPropagation, MeanShift, estimate_bandwidth
scripts = '../helper_scripts'
if scripts not in sys.path:
    sys.path.insert(0,scripts)
from dtw import dtw_std, ddtw_tv
data_beats_dir = "../data/beats/"
FREQ = 128
MULTIPLICATOR_FREQ = 4
MIN_TO_AVG = 5
def open_file(file_name, t_start_seconds = None, t_stop_seconds = None):
    '''
    Data structure:
    File:
     |
     DICTIONARY:
      |
      --> Beat_annot:
           |
           --> lvl:
                |
                --> "t": []
                |
                --> "v": []
    '''
    file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name))
    data = {}
    data_out = {}
    with open(file_name_full,"rb") as f:
        data = pickle.load(f)
    if t_start_seconds is None:
        # The beat annotations (dictionary keys) are expressed in samples: convert the
        # first and last annotation to seconds, since they are multiplied back by FREQ below.
        t_start_seconds = list(data.keys())[0]/FREQ
        t_stop_seconds = list(data.keys())[-1]/FREQ
    for k in data.keys():
        if k >= int(t_start_seconds*FREQ) and k <= int(t_stop_seconds*FREQ):
            data_out[k] = data[k][0]
        elif k > int(t_stop_seconds*FREQ):
            break
    return data_out
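# Illustrative usage sketch (not part of the original script): traversing the dictionary
# returned by open_file(). The record name below is a hypothetical placeholder;
# open_file() keeps only the first level ("data[k][0]") of each beat, so every entry is
# a {"t": [...], "v": [...]} pair keyed by the beat annotation (in samples).
def _example_iterate_beats(file_name="record_00.pickle"):
    data = open_file(file_name, t_start_seconds=0, t_stop_seconds=60)
    for beat_annot, beat in data.items():
        print(beat_annot, len(beat['t']), len(beat['v']))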
def dtw_dist(v1,v2):
    dist = dtw_std(v1, v2)
    return dist
def ddtw_tv_dist(t1,v1,t2,v2, use_diff_dtw, time_weight):
    dist = ddtw_tv(t1, v1, t2, v2, use_diff = use_diff_dtw, time_weight = time_weight)
    return dist
def compute_affinity_matr(vectors, use_diff_dtw, time_weight):
    # Pairwise similarity matrix for AffinityPropagation: the negative squared DTW
    # distance, so that more similar beats get larger (less negative) entries.
    A = np.zeros((len(vectors),len(vectors)))
    for i in range(len(vectors)):
        for j in range(i,len(vectors)):
            t1 = list(range(len(vectors[i])))
            v1 = vectors[i]
            t2 = list(range(len(vectors[j])))
            v2 = vectors[j]
            dist = -ddtw_tv_dist(t1, v1, t2 , v2, use_diff_dtw, time_weight)**2
            A[i,j] = dist
            A[j,i] = dist
    return A
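# Illustrative sketch (not part of the original script) of how the affinity matrix is
# consumed further down: AffinityPropagation with affinity='precomputed' expects
# similarities, hence the negative squared DTW distances above, and the preference
# (the 5th percentile, as in multi_template/single_template) controls how many clusters emerge.
def _example_cluster_beats(beats_v):
    affinity = compute_affinity_matr(beats_v, use_diff_dtw=False, time_weight=1)
    af = AffinityPropagation(damping=0.7, preference=np.percentile(affinity, 5),
                             affinity='precomputed').fit(affinity)
    return af.labels_, af.cluster_centers_indices_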
def min_max_normalization(vector):
    mi_v = min(vector)
    ma_v = max(vector)
    norm = (np.array(vector)-mi_v)/(ma_v-mi_v)
    norm -= np.average(norm)
    return list(norm)
def get_significant_peak(signal):
    val = -1000
    peak_max = None
    is_peak = lambda w,x,y,z: (w<x and x>y) or (w<x and x==y and y>z)
    for i in range(1,len(signal)-2):
        if is_peak(signal[i-1],signal[i],signal[i+1],signal[i+2]) and signal[i]> val:
            peak_max = i
            val = signal[i]
    return peak_max
def get_beats(data, t_start_seconds = 0, t_stop_seconds = MIN_TO_AVG*60):
    beats = {}
    annotations = list(data.keys())
    for annot in annotations[1:]:
        if annot >= int(t_start_seconds*FREQ) and annot <= int(t_stop_seconds*FREQ):
            beats[annot] = data[annot]
        elif annot > int(t_stop_seconds*FREQ):
            break
    return beats
def get_r_peaks_idx(beats):
    r_times_idx = []
    time_to_look = int(0.15*FREQ) #150 ms
    for annot in beats.keys():
        idx = 0
        signal_r_peaks = []
        for i,t in enumerate(beats[annot]['t']):
            if t == annot - time_to_look:
                idx = i
            if t >= annot - time_to_look and t <= annot + time_to_look:
                signal_r_peaks.append(beats[annot]['v'][i])
        peak_idx = get_significant_peak(signal_r_peaks)+idx
        r_times_idx.append(peak_idx)
    return r_times_idx
def allign_beats(data, allignment_idxs, normalize = True):
    len_bef = len_aft = 1000
    data_alligned = {}
    for annot, allignment_idx in zip(data.keys(),allignment_idxs):
        time = data[annot]['t']
        this_len_bef = len(time[:allignment_idx])
        this_len_aft = len(time[allignment_idx+1:])
        if this_len_bef < len_bef:
            len_bef = this_len_bef
        if this_len_aft < len_aft:
            len_aft = this_len_aft
    for annot, allignment_idx in zip(data.keys(),allignment_idxs):
        new_t = data[annot]['t'][allignment_idx-len_bef:allignment_idx+len_aft+1]
        new_v = data[annot]['v'][allignment_idx-len_bef:allignment_idx+len_aft+1]
        if normalize:
            new_v = min_max_normalization(new_v)
        data_alligned[annot] = {'t':new_t,'v':new_v}
    return data_alligned
def RMSE(v1,v2):
    v1_n = np.array(v1)
    v2_n = np.array(v2)
    return np.sqrt(np.mean((v1_n-v2_n)**2))
def compute_template_lowest_distance(beats):
    dist = np.zeros((len(beats),len(beats)))
    for i in range(len(beats)):
        for j in range(i+1,len(beats)):
            dist[i][j] = RMSE(beats[i],beats[j])
            dist[j][i] = dist[i][j]
    dist = np.sum(dist, axis=0)
    idx = np.argmin(dist)
    return beats[idx]
def build_template(file_analyzed, normalize = True, mode = 'average', t_start_seconds = 0, t_stop_seconds = MIN_TO_AVG*60):
    template = []
    beats_alligned = []
    data = open_file(file_analyzed)
    beats = get_beats(data, t_start_seconds = t_start_seconds, t_stop_seconds = t_stop_seconds)
    allignment_idxs = get_r_peaks_idx(beats)
    alligned_beats = allign_beats(beats, allignment_idxs, normalize=normalize)
    for annot in alligned_beats.keys():
        beats_alligned.append(alligned_beats[annot]['v'])
    if mode == 'average':
        template = list(np.average(beats_alligned,axis=0))
    elif mode == 'distance':
        template = compute_template_lowest_distance(beats_alligned)
    template_std = list(np.std(beats_alligned,axis=0))
    return template, template_std, alligned_beats
def is_odd(num):
    is_odd_num = True
    if num%2 == 0:
        is_odd_num = False
    return is_odd_num
def is_even(num):
    return not is_odd(num)
def euristic_smoothing_routine(signal,freq, SNR_obj = 50):
    ms = 24e-3
    K = int(ms * freq)
    # medfilt requires an odd kernel size
    if is_even(K):
        if K > 4:
            K -= 1
        else:
            K += 1
    filtered = median_filter(signal,K)
    noise = np.array(signal) - filtered
    p_filtered = sum(filtered**2)/len(filtered)
    p_noise = sum(noise**2)/len(noise)
    SNR = p_filtered/p_noise
    smooth = False
    if SNR > SNR_obj:
        smooth = True
    return smooth
def signal_processing_smoothing_routine(signal,freq):
    '''
    For future development: TODO
    '''
    return euristic_smoothing_routine(signal,freq)
def is_smooth(signal, freq, mode = 2, SNR_obj = 50):
    '''
    This method can be implemented in two major ways:
    1: Complex (FFT, signal processing)
    2: Using a heuristic, assuming low SNR and that the noise can be isolated with median filtering:
       - Kernel size (left/right) = K, with a time length of K/(frequency*multiplicator factor)
         -> K = time_length * frequency * multiplicator factor
    '''
    smooth = False
    if mode == 2:
        smooth = euristic_smoothing_routine(signal,freq, SNR_obj = SNR_obj)
    if mode == 1:
        smooth = signal_processing_smoothing_routine(signal,freq)
    return smooth
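# Illustrative sketch (not part of the original script) of the heuristic above on
# synthetic data: the 24 ms window gives K = int(0.024 * freq) samples, forced to an
# odd value so that scipy's medfilt accepts it, and a beat is declared smooth only
# when the filtered-signal to residual power ratio exceeds SNR_obj.
def _example_smoothness_check(freq=FREQ * MULTIPLICATOR_FREQ):
    t = np.linspace(0, 1, freq)
    clean = np.sin(2 * np.pi * 5 * t)               # smooth synthetic "beat"
    noisy = clean + 0.5 * np.random.randn(len(t))   # heavily corrupted copy
    return is_smooth(list(clean), freq), is_smooth(list(noisy), freq)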
def multi_template(data, percentage_each_cluster = 7, freq = FREQ * MULTIPLICATOR_FREQ, use_diff_dtw = False, time_weight = 1):
    '''
    Output:
        raw_templates: list
        descriptor: {$lbl:
                        {"center": C,
                         "point": P,
                         "dist": D,
                         "dist_all_pt_cluster": []}}
    '''
    descriptor = {} # {$lbl: {"center": C, "point": P, "dist": D, "dist_all_pt_cluster":[]}}
    raw_templates = []
    beats_v = []
    #print(f"Computing templates: num vectors: {len(data)}")
    for annot in data.keys():
        beats_v.append(data[annot]['v'])
    #print(f"Computing templates: num vectors: {len(beats_v)}")
    affinity_matr = compute_affinity_matr(beats_v, use_diff_dtw = use_diff_dtw, time_weight = time_weight)
    pref = np.percentile(affinity_matr, 5)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        af = AffinityPropagation(damping = 0.7, preference = pref, affinity = 'precomputed').fit(affinity_matr)
    cluster_centers_indices = af.cluster_centers_indices_
    labels = af.labels_
    unique, counts = np.unique(labels, return_counts=True)
    lbl_counts = dict(zip(unique, counts))
    for i,pt in enumerate(beats_v):
        lbl_this_pt = labels[i] # The number of samples in a class needs to be bigger than x% of all samples
        if lbl_this_pt == -1 or lbl_counts[lbl_this_pt] < int(len(labels)*percentage_each_cluster/100) or not is_smooth(pt,freq, SNR_obj = 40):
            continue
        cluster_center = beats_v[cluster_centers_indices[lbl_this_pt]]
        dist = ddtw_tv_dist(list(range(len(pt))), pt, list(range(len(cluster_center))), cluster_center, use_diff_dtw = use_diff_dtw, time_weight = time_weight)
        if lbl_this_pt not in descriptor.keys():
            descriptor[lbl_this_pt] = {"center": cluster_center, "point": None, "dist": np.inf, "dist_all_pt_cluster": []}
        if dist < descriptor[lbl_this_pt]['dist']:
            descriptor[lbl_this_pt]['point'] = pt
            descriptor[lbl_this_pt]['dist'] = dist
        descriptor[lbl_this_pt]["dist_all_pt_cluster"].append(dist)
    #If no template is produced, try again with a lower SNR objective
    if len(descriptor.keys()) == 0:
        for i,pt in enumerate(beats_v):
            lbl_this_pt = labels[i] # The number of samples in a class needs to be bigger than x% of all samples
            if lbl_this_pt == -1 or lbl_counts[lbl_this_pt] < int(len(labels)*percentage_each_cluster/100) or not is_smooth(pt,freq, SNR_obj = 20):
                continue
            cluster_center = beats_v[cluster_centers_indices[lbl_this_pt]]
            dist = ddtw_tv_dist(list(range(len(pt))), pt, list(range(len(cluster_center))), cluster_center, use_diff_dtw = use_diff_dtw, time_weight = time_weight)
            if lbl_this_pt not in descriptor.keys():
                descriptor[lbl_this_pt] = {"center": cluster_center, "point": None, "dist": np.inf, "dist_all_pt_cluster": []}
            if dist < descriptor[lbl_this_pt]['dist']:
                descriptor[lbl_this_pt]['point'] = pt
                descriptor[lbl_this_pt]['dist'] = dist
            descriptor[lbl_this_pt]["dist_all_pt_cluster"].append(dist)
    #If still nothing is kept, take the centroid of the biggest labeled group
    if len(descriptor.keys()) == 0:
        max_len = 0
        desc_sel = None
        lbl_sel = None
        for lbl in lbl_counts.keys():
            center_beat = beats_v[cluster_centers_indices[lbl]]
            if lbl_counts[lbl]>max_len and lbl != -1:
                max_len = lbl_counts[lbl]
                desc_sel = {"center": center_beat, "point": center_beat, "dist": 0, "dist_all_pt_cluster": []}
                lbl_sel = lbl
        descriptor[lbl_sel] = desc_sel
        center = desc_sel['center']
        for i,pt in enumerate(beats_v):
            lbl_this_pt = labels[i]
            if lbl_this_pt == lbl_sel:
                dist = ddtw_tv_dist(list(range(len(pt))), pt, list(range(len(center))), center, use_diff_dtw = use_diff_dtw, time_weight = time_weight)
                descriptor[lbl_sel]["dist_all_pt_cluster"].append(dist)
    lbls_kept = list(descriptor.keys())
    descriptor_seq = {}
    for i,lbl in enumerate(lbls_kept):
        raw_templates.append(descriptor[lbl]["point"])
        descriptor_seq[i] = descriptor[lbl] #make all the ids of the templates sequential
    print(f"\tFound labels and respective number of items (total items: {len(data)})\n\t{lbl_counts}\n\tof which, with more than {percentage_each_cluster}% associated beats: {len(descriptor)}")
    return raw_templates,descriptor_seq
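# Illustrative usage sketch (not part of the original script): chaining build_template()
# and multi_template() and reading the per-cluster descriptor. The record name is a
# hypothetical placeholder.
def _example_multi_template(file_name="record_00.pickle"):
    _, _, alligned_beats = build_template(file_name, normalize=True, mode='average')
    raw_templates, descriptor = multi_template(alligned_beats)
    for lbl, desc in descriptor.items():
        # "point" is the beat closest to the cluster center, "dist" its DTW distance
        print(lbl, desc["dist"], len(desc["dist_all_pt_cluster"]))
    return raw_templates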
def single_template(data, freq = FREQ * MULTIPLICATOR_FREQ, use_diff_dtw = False, time_weight = 1):
    '''
    Output:
        raw_templates: list
        descriptor: {$lbl:
                        {"center": C,
                         "point": P,
                         "dist": D,
                         "dist_all_pt_cluster": []}}
    '''
    descriptor = {} # {$lbl: {"center": C, "point": P, "dist": D, "dist_all_pt_cluster":[]}}
    raw_templates = []
    beats_v = []
    #print(f"Computing templates: num vectors: {len(data)}")
    for annot in data.keys():
        beats_v.append(data[annot]['v'])
    #print(f"Computing templates: num vectors: {len(beats_v)}")
    affinity_matr = compute_affinity_matr(beats_v, use_diff_dtw = use_diff_dtw, time_weight = time_weight)
    pref = np.percentile(affinity_matr, 5)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        af = AffinityPropagation(damping = 0.7, preference = pref, affinity = 'precomputed').fit(affinity_matr)
    cluster_centers_indices = af.cluster_centers_indices_
    labels = af.labels_
    unique, counts = np.unique(labels, return_counts=True)
    for i,pt in enumerate(beats_v):
        lbl_this_pt = labels[i]
        if lbl_this_pt == -1 or not is_smooth(pt,freq):
            continue
        cluster_center = beats_v[cluster_centers_indices[lbl_this_pt]]
        dist = ddtw_tv_dist(list(range(len(pt))), pt, list(range(len(cluster_center))), cluster_center, use_diff_dtw = use_diff_dtw, time_weight = time_weight)
        if lbl_this_pt not in descriptor.keys():
            descriptor[lbl_this_pt] = {"center": cluster_center, "point": None, "dist": np.inf, "dist_all_pt_cluster": []}
        if dist < descriptor[lbl_this_pt]['dist']:
            descriptor[lbl_this_pt]['point'] = pt
            descriptor[lbl_this_pt]['dist'] = dist
        descriptor[lbl_this_pt]["dist_all_pt_cluster"].append(dist)
    lbl_counts = dict(zip(unique, counts))
    # Keep only the most populated cluster that produced a descriptor
    for _,l in sorted(zip(counts, unique), reverse=True):
        if l in descriptor.keys():
            descriptor_seq = {0:descriptor[l]}
            raw_templates = descriptor[l]['point']
            break
    print(f"\tFound labels and respective number of items (total items: {len(data)})\n\t{lbl_counts}\n\tKept: {l}")
    return raw_templates,descriptor_seq
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", help="Force to analyze one specific file instead of the default one (first found)")
    parser.add_argument("--not_norm", help="Force to NOT normalize each beat", action="store_true")
    args = parser.parse_args()
    files = os.listdir(data_beats_dir)
    if args.file is not None:
        analyzed = list(filter(lambda string: args.file in string, files))[0]
    else:
        analyzed = files[0]
    if args.not_norm:
        normalize = False
    else:
        normalize = True
    template_dist,std,alligned = build_template(analyzed, normalize = normalize, mode = 'distance')
    template_avg,_,_ = build_template(analyzed, normalize = normalize, mode = 'average')
    plt.plot(template_dist, color = 'C0')
    plt.plot(template_avg, color = 'C3')
    plt.plot(np.array(template_dist)+np.array(std), color = "C1")
    plt.plot(np.array(template_dist)-np.array(std), color = "C1")
    for beat in alligned.keys():
        plt.plot(alligned[beat]['v'],color = 'C2', alpha = 0.01)
    plt.figure()
    plt.plot(template_dist, color = 'C0')
    plt.plot(template_avg, color = 'C3')
    plt.show()
