diff --git a/src/beat_reconstruction_template_segmentation.py b/src/beat_reconstruction_template_segmentation.py
index 8dc1371..677edc0 100644
--- a/src/beat_reconstruction_template_segmentation.py
+++ b/src/beat_reconstruction_template_segmentation.py
@@ -1,362 +1,371 @@
 import multiprocessing
 import os
 import pickle
 import shutil
 from multiprocessing import Pool
 from time import time
 import matplotlib.pyplot as plt
 import numpy as np
 from dtaidistance import dtw
 from scipy.interpolate import interp1d
 from copy import copy
 from build_template import build_template
 
 data_beats_dir = "../data/beats/"
 log_dir = "../data/beat_recon_logs_"
 FREQ = 128
 MIN_TO_AVG = 5
 BEAT_TO_PLOT = 2
 LEVEL = [1,2,3,4,5,6,7,8,9,10,11]
-NUM_BEAT_ANALYZED = 5000
+NUM_BEAT_ANALYZED = 50
+TEMPLATE_TYPE = 'distance'
 
 def RMSE(v1,v2):
     v1_n = np.array(v1)
     v2_n = np.array(v2)
     return np.sqrt(np.mean((v1_n-v2_n)**2))
 
 def RMSE_warped_resampled(data_orig,data_warped,data_resampled):
     rmse = {}
     for beat in data_orig.keys():
         t_start = data_orig[beat]['t'][0]
         t_stop = data_orig[beat]['t'][-1]
         for lvl in data_warped[beat].keys():
             if lvl not in rmse.keys():
                 rmse[lvl] = {'warp':[],'resamp':[]}
             v_warped = []
             v_resampled = []
             for i,t in enumerate(data_warped[beat][lvl]['t']):
                 if t >= t_start and t <= t_stop:
                     v_warped.append(data_warped[beat][lvl]['v'][i])
                     v_resampled.append(data_resampled[beat][lvl]['v'][i])
                 elif t > t_stop:
                     break
             rmse[lvl]['warp'].append(RMSE(data_orig[beat]['v'],v_warped))
             rmse[lvl]['resamp'].append(RMSE(data_orig[beat]['v'],v_resampled))
     return rmse
 
 def open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = None):
     '''
     Data structure:
     File:
       |
       DICTIONARY:
         |
         --> Beat_annot:
               |
               --> lvl:
                     |
                     --> "t": []
                     --> "v": []
     '''
     file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name))
     data = {}
     data_out = {}
     with open(file_name_full,"rb") as f:
         data = pickle.load(f)
     for k in data.keys():
         if k > FREQ*start_after:
             data_out[k] = {}
             if get_selected_level is not None:
                 data_out[k][0] = data[k][0]
                 for lvl in get_selected_level:
                     data_out[k][lvl] = data[k][lvl]
             else:
                 data_out[k] = data[k]
     return data_out, list(data_out[k].keys())
 
 def min_max_normalization(vector, forced_avg = None):
     mi_v = min(vector)
     ma_v = max(vector)
     norm = (np.array(vector)-mi_v)/(ma_v-mi_v)
     if forced_avg is not None:
         avg = forced_avg
     else:
         avg = np.average(norm)
     norm -= avg
     return list(norm),avg
 
 def resamp_one_signal(t,v,resample_type = 'linear', min_t = None, max_t = None):
     if resample_type == "linear":
         f = interp1d(t,v)
     elif resample_type == "flat":
         f = interp1d(t,v, kind = 'previous')
     if min_t is None:
         min_t = t[0]
     if max_t is None:
         max_t = t[-1]
     t_new = list(range(min_t,max_t+1))
     v_new = f(t_new)
     return t_new,v_new
 
 def resample(data, resample_type = "linear", min_t = None, max_t = None):
     resampled_data = {"t":None,"v":None}
     t = data['t']
     v = data['v']
     t_r,v_r = resamp_one_signal(t,v,resample_type = resample_type, min_t = min_t, max_t = max_t)
     resampled_data['t'] = t_r
     resampled_data['v'] = v_r
     return resampled_data
 
 def change_sampling(vector, size_to_warp):
     resampled = []
     if len(vector) < 3:
         # A vector of only 1 or 2 samples gets morphed into a straight line,
         # and such a line gets morphed uniformly
         resampled = [vector[0]]*size_to_warp
         return resampled
     delta = abs(size_to_warp - len(vector)) # --> number of points to insert/delete
     # At least one more sample needed for each sample (except the last one)?
     # --> insert k in each hole and recompute delta
     if delta >= len(vector)-1 and delta > 0:
         holes = len(vector)-1
         k = delta // holes # number of points for each hole
         time = range(0,len(vector)*(k+1),k+1)
         # [o1,i11,i12,...,i1k, o2,i21,i22,...,i2k, o3,..., o_fin] --> NOTE: the last sample is not taken
         #  0       k+1          2k+2                    l*k+l     --> hence we land exactly on the end
         t_new = range(time[-1]+1)
         f = interp1d(time,vector)
         vector = list(f(t_new))
         delta = abs(size_to_warp - len(vector))
     if len(vector) == size_to_warp:
         return vector
     grad = np.gradient(vector)
     grad[-1] = max(grad) + 1 # --> we don't want to insert anything after the last sample
     grad_idxs = sorted(range(len(grad)), key = lambda idx: grad[idx])
     idx_to_consider = grad_idxs[:delta]
     #print(delta, len(idx_to_consider), idx_to_consider)
     for i,sample in enumerate(vector):
         resampled.append(sample)
         if i in idx_to_consider:
             if size_to_warp < len(vector):
                 resampled.pop()
             elif size_to_warp > len(vector):
-                resampled.append((vector[grad_idxs[i]]+vector[grad_idxs[i]+1])/2)
+                resampled.append((vector[i]+vector[i+1])/2)
     return resampled
 
 def segment_warp(segment):
     #print("--> ",len(segment['segment']))
     #print("--> ",segment['length_to_warp'])
     resampled = np.array(change_sampling(segment['segment'], segment['length_to_warp']))
     #print("--> ",len(resampled))
     resampled += (segment['v_start'] - resampled[0])
     m = (segment['v_stop'] - resampled[-1])/(segment['length_to_warp']-1)
     coef_to_sum = [m*x for x in range(segment['length_to_warp'])]
     resampled += coef_to_sum
     return list(resampled)
 
 def stitch_segments(segments):
     stitched = []
     for i,segment in enumerate(segments):
         if i == len(segments)-1:
             stitched.extend(segment)
         else:
             stitched.extend(segment[0:-1])
     return stitched
 
 def warp(data, template, events):
     #print(events)
     warped = {}
     segments = [] # --> {"segment":v, "length_to_warp":length, "v_start":v_eb_0, "v_stop":v_eb_1}
     segment = {"segment":None, "length_to_warp":None, "v_start":None, "v_stop":None}
     segment_start = None
     event_start = None
     # add the first and last points of the resampled and truncated beat if they are not already events
     if data['t'][0] not in events['t']:
         events['t'].insert(1, data['t'][0])
         events['v'].insert(1, data['v'][0])
     if data['t'][-1] not in events['t']:
         events['t'].insert(-1,data['t'][-1])
         events['v'].insert(-1,data['v'][-1])
     #print(events)
     # apply DTW to match the resampled beat to the template
     v_src = data['v']
     path = dtw.warping_path(v_src, template)
     # remove the "left" steps from the path and segment the template based on the event points
     prev_idx_src = path[-1][0]+1
     for idx_src, idx_temp in path:
         if prev_idx_src == idx_src:
             continue
         prev_idx_src = idx_src
         for idx, t in enumerate(events['t']):
             if data['t'][idx_src] == t:
                 if segment_start == None:
                     segment_start = idx_temp
                     event_start = idx
                 else:
                     #print(f"SEGMENT {events['t'][event_start]} - {events['t'][idx]} ({abs(events['t'][event_start] - events['t'][idx])})")
                     segment['segment'] = template[segment_start:idx_temp+1]
                     segment['length_to_warp'] = events['t'][idx] - events['t'][event_start] + 1
                     segment['v_start'] = events['v'][event_start]
                     segment['v_stop'] = events['v'][idx]
                     w = segment_warp(segment)
                     #print(len(w))
                     segments.append(w)
                     segment_start = idx_temp
                     event_start = idx
                 break
     segment_stitched = stitch_segments(segments)
     #print(len(segment_stitched), len(data['t']))
     warped = {'t':data['t'],'v':segment_stitched}
     return warped
 
 def reconstruct_beats(file_name, level = None, normalize = True, resample_type = 'linear', num_beats = None, verbose = False, prev_pre_proc = None):
     reconstructed = {}
     resamp = {}
     data_orig = {}
     pre_proc = {}
     if verbose:
         print(f"Extracting {file_name}")
     data,lvls = open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = level)
     lvls.remove(0)
     if verbose:
         print("Building template")
     if prev_pre_proc is not None:
         template = prev_pre_proc['template']
     else:
-        template,_,_ = build_template(file_name, normalize = normalize)
+        template,_,_ = build_template(file_name, normalize = normalize, mode = TEMPLATE_TYPE)
     if num_beats == None:
         num_beats = len(list(data.keys()))
     if verbose:
         print("reconstructing")
     for lvl in lvls:
         i = 0
         if verbose:
             print(f"Analyzing level:{lvl}")
         for beat in data.keys():
             if i == num_beats:
                 break
             i+=1
             if (i%(num_beats/20)==0 or i == 1) and verbose:
                 print(f"File: {file_name}, Reconstructing beat {beat} ({i}/{num_beats}: {100*i/num_beats}%, LEVEL:{lvl})")
             if beat not in reconstructed.keys():
                 reconstructed[beat] = {}
                 resamp[beat] = {}
                 data_orig[beat] = data[beat][0]
             if normalize:
                 data_orig[beat]['v'],avg = min_max_normalization(data_orig[beat]['v'], forced_avg = None)
                 data[beat][lvl]['v'],_ = min_max_normalization(data[beat][lvl]['v'], forced_avg = avg)
             data_resampled = resample(data[beat][lvl], resample_type = resample_type, min_t = data_orig[beat]['t'][0], max_t = data_orig[beat]['t'][-1])
             reconstructed[beat][lvl] = warp(data_resampled,template,data[beat][lvl])
             resamp[beat][lvl] = data_resampled
     pre_proc['template'] = template
     return reconstructed,data_orig,resamp,pre_proc
 
 def recontruct_and_compare(file):
     pre_proc = None
+    recon,orig,resamp,pre_proc = reconstruct_beats(file, level = LEVEL, resample_type = 'flat', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear
+    rmse = RMSE_warped_resampled(orig,recon,resamp)
     for lvl in LEVEL:
+        '''
         recon,orig,resamp,pre_proc = reconstruct_beats(file, level = [lvl], resample_type = 'linear', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear
         rmse = RMSE_warped_resampled(orig,recon,resamp)
-
+        '''
         avg_rmse_warp = np.average(rmse[lvl]['warp'])
         std_rmse_warp = np.std(rmse[lvl]['warp'])
         avg_rmse_resamp = np.average(rmse[lvl]['resamp'])
         std_rmse_resamp = np.std(rmse[lvl]['resamp'])
         file_name_to_save = "L_"+file.split(".")[0]+".log"
         with open(os.path.join(log_dir,file_name_to_save),"a") as f:
             f.write(f"Lvl: {lvl}\n")
             f.write(f"\tWarp: {avg_rmse_warp}, +-{std_rmse_warp}\n")
             f.write(f"\tInterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n")
             f.write(f"\n\n")
         print(f"File:{file_name_to_save}")
         print(f"\tLvl: {lvl}")
         print(f"\t\twarp: {avg_rmse_warp}, +-{std_rmse_warp}")
         print(f"\t\tinterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}")
         file_name_to_save_fig = os.path.join(log_dir,file.split(".")[0]+"_"+str(lvl)+".svg")
         file_name_to_save_fig_template = os.path.join(log_dir,file.split(".")[0]+"_template.svg")
         beat_to_plot = list(recon.keys())[BEAT_TO_PLOT]
         t_o,v_o = orig[beat_to_plot]['t'],orig[beat_to_plot]['v']
         t_a,v_a = recon[beat_to_plot][lvl]['t'],recon[beat_to_plot][lvl]['v']
         t_r,v_r = resamp[beat_to_plot][lvl]['t'],resamp[beat_to_plot][lvl]['v']
         plt.figure()
         plt.plot(t_o,v_o)
         plt.plot(t_a,v_a)
         plt.plot(t_r,v_r)
         plt.legend(['original','warped template','resampled'])
         plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot}')
         plt.savefig(file_name_to_save_fig)
         plt.close()
         if not os.path.isfile(file_name_to_save_fig_template):
             template = pre_proc['template']
             plt.figure()
             plt.plot(template)
             plt.legend(['template'])
             plt.title(f'File: {file}, Beat time (samples):{beat_to_plot}')
             plt.savefig(file_name_to_save_fig_template)
             plt.close()
 
 def process(files, multi=True, cores=1):
     # ------------ INIT ------------
     global log_dir
     for i in range(1,1000):
         tmp_log_dir = log_dir+str(i)
         if not os.path.isdir(tmp_log_dir):
             log_dir = tmp_log_dir
             break
     os.mkdir(log_dir)
     # ------------ Extract DATA & ANNOTATIONS ------------
     with Pool(cores) as pool:
         pool.map(recontruct_and_compare, files)
 
 if __name__ == "__main__":
     import argparse
     #global NUM_BEAT_ANALYZED
     parser = argparse.ArgumentParser()
     parser.add_argument("--file", help="Force the analysis of one specific file instead of the default one (first found)")
     parser.add_argument("--not_norm", help="Force NOT normalizing each beat", action="store_true")
     parser.add_argument("--cores", help="Force the number of cores used (default: half of the available ones)")
     parser.add_argument("--beats", help="Number of beats analyzed, default: 5000")
+    parser.add_argument("--template_type", help="Type of template, default: distance")
     args = parser.parse_args()
     files = os.listdir(data_beats_dir)
     if args.file is not None:
         if args.file == 'all':
             analyzed = files
         else:
             analyzed = list(filter(lambda string: True if args.file in string else False, files))
     else:
         analyzed = [files[0]]
     if args.not_norm:
         normalize = False
     else:
         normalize = True
     if args.cores is not None:
         used_cores = int(args.cores)
     else:
         used_cores = multiprocessing.cpu_count()//2
     if args.beats is not None:
         NUM_BEAT_ANALYZED = int(args.beats)
     else:
         NUM_BEAT_ANALYZED = 5000
+    if args.template_type is not None:
+        TEMPLATE_TYPE = 'average'
+    else:
+        TEMPLATE_TYPE = 'distance'
     print(f"Analyzing files: {analyzed}")
     print(f"Extracting data with {used_cores} cores...")
     process(files = analyzed, multi=True, cores=used_cores)
\ No newline at end of file
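Reviewer note, not part of the patch: two of the changes above are easiest to check on toy data. The standalone Python sketch below (all values and variable names are illustrative) shows why the midpoint insertion in change_sampling() must average the neighbours of position i rather than read vector[grad_idxs[i]], and what dtaidistance's dtw.warping_path() returns, since warp() consumes its index pairs. Note also that, as wired here, passing any value to --template_type switches the run to the 'average' template, while omitting the flag keeps the new 'distance' default.

    # Standalone sketch, toy values only.
    import numpy as np
    from dtaidistance import dtw

    vector = [0.0, 0.2, 1.0, 0.4, 0.1]
    i = 1  # suppose the gradient criterion selected position 1 for an insertion
    # The old code read vector[grad_idxs[i]], i.e. the sample with the i-th
    # smallest gradient; the fix averages the two samples surrounding position i:
    inserted = (vector[i] + vector[i + 1]) / 2  # -> 0.6, the midpoint of the two neighbours

    # warp() relies on dtw.warping_path() returning the alignment between the
    # resampled beat (first series) and the template (second series):
    path = dtw.warping_path([0.0, 0.5, 1.0, 0.5, 0.0], [0.0, 1.0, 0.0])
    print(path)  # a list of (idx_src, idx_template) pairs, e.g. [(0, 0), (1, 0), (2, 1), (3, 2), (4, 2)]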
diff --git a/src/build_template.py b/src/build_template.py
index 2d37a72..8635c35 100644
--- a/src/build_template.py
+++ b/src/build_template.py
@@ -1,133 +1,156 @@
 import numpy as np
 import pickle
 import matplotlib.pyplot as plt
 import os
 
 data_beats_dir = "../data/beats/"
 FREQ = 128
 MIN_TO_AVG = 5
 
 def open_file(file_name):
     '''
     Data structure:
     File:
       |
       DICTIONARY:
         |
         --> Beat_annot:
               |
               --> lvl:
                     |
                     --> "t": []
                     --> "v": []
     '''
     file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name))
     data = {}
     with open(file_name_full,"rb") as f:
         data = pickle.load(f)
     for k in data.keys():
         data[k] = data[k][0]
     return data
 
 def min_max_normalization(vector):
     mi_v = min(vector)
     ma_v = max(vector)
     norm = (np.array(vector)-mi_v)/(ma_v-mi_v)
     norm -= np.average(norm)
     return list(norm)
 
 def get_significant_peak(signal):
     val = -1000
     peak_max = None
     is_peak = lambda w,x,y,z: True if (w<x and x>y) or (w<x and x>z) else False
     for i in range(1,len(signal)-2):
         if is_peak(signal[i-1],signal[i],signal[i+1],signal[i+2]) and signal[i] > val:
             peak_max = i
             val = signal[i]
     return peak_max
 
 def get_5_min_beats(data):
     beats = {}
     annotations = list(data.keys())
     for annot in annotations[1:]:
         if annot > int(MIN_TO_AVG*60*FREQ):
             break
         beats[annot] = data[annot]
     return beats
 
 def get_r_peaks_idx(beats):
     r_times_idx = []
     time_to_look = int(0.15*FREQ) # 150 ms
     for annot in beats.keys():
         idx = 0
         signal_r_peaks = []
         for i,t in enumerate(beats[annot]['t']):
             if t == annot - time_to_look:
                 idx = i
             if t >= annot - time_to_look and t <= annot + time_to_look:
                 signal_r_peaks.append(beats[annot]['v'][i])
         peak_idx = get_significant_peak(signal_r_peaks)+idx
         r_times_idx.append(peak_idx)
     return r_times_idx
 
 def allign_beats(data, allignment_idxs, normalize = True):
     len_bef = len_aft = 1000
     data_alligned = {}
     for annot, allignment_idx in zip(data.keys(),allignment_idxs):
         time = data[annot]['t']
         this_len_bef = len(time[:allignment_idx])
         this_len_aft = len(time[allignment_idx+1:])
         if this_len_bef < len_bef:
             len_bef = this_len_bef
         if this_len_aft < len_aft:
             len_aft = this_len_aft
     for annot, allignment_idx in zip(data.keys(),allignment_idxs):
         new_t = data[annot]['t'][allignment_idx-len_bef:allignment_idx+len_aft+1]
         new_v = data[annot]['v'][allignment_idx-len_bef:allignment_idx+len_aft+1]
         if normalize:
             new_v = min_max_normalization(new_v)
         data_alligned[annot] = {'t':new_t,'v':new_v}
     return data_alligned
 
-def build_template(file_analyzed, normalize = True):
+def RMSE(v1,v2):
+    v1_n = np.array(v1)
+    v2_n = np.array(v2)
+    return np.sqrt(np.mean((v1_n-v2_n)**2))
+
+def compute_template_lowest_distance(beats):
+    # Build the symmetric matrix of pairwise RMSE distances, then pick the
+    # beat with the lowest summed distance to all the others (the medoid)
+    dist = np.zeros((len(beats),len(beats)))
+    for i in range(len(beats)):
+        for j in range(i+1,len(beats)):
+            dist[i][j] = RMSE(beats[i],beats[j])
+            dist[j][i] = dist[i][j]
+    dist = np.sum(dist, axis=0)
+    idx = np.argmin(dist)
+    return beats[idx]
+
+def build_template(file_analyzed, normalize = True, mode = 'average'):
     template = []
     beats_alligned = []
     data = open_file(file_analyzed)
     beats = get_5_min_beats(data)
     allignment_idxs = get_r_peaks_idx(beats)
     alligned_beats = allign_beats(beats, allignment_idxs, normalize=normalize)
     for annot in alligned_beats.keys():
         beats_alligned.append(alligned_beats[annot]['v'])
-    template = list(np.average(beats_alligned,axis=0))
+    if mode == 'average':
+        template = list(np.average(beats_alligned,axis=0))
+    elif mode == 'distance':
+        template = compute_template_lowest_distance(beats_alligned)
     template_std = list(np.std(beats_alligned,axis=0))
     return template, template_std, alligned_beats
 
 if __name__ == "__main__":
     import argparse
     parser = argparse.ArgumentParser()
     parser.add_argument("--file", help="Force the analysis of one specific file instead of the default one (first found)")
     parser.add_argument("--not_norm", help="Force NOT normalizing each beat", action="store_true")
     args = parser.parse_args()
     files = os.listdir(data_beats_dir)
     if args.file is not None:
         analyzed = list(filter(lambda string: True if args.file in string else False, files))[0]
     else:
         analyzed = files[0]
     if args.not_norm:
         normalize = False
     else:
         normalize = True
-    template,std,alligned = build_template(analyzed, normalize = normalize)
+    template_dist,std,alligned = build_template(analyzed, normalize = normalize, mode = 'distance')
+    template_avg,_,_ = build_template(analyzed, normalize = normalize, mode = 'average')
 
-    plt.plot(template, color = 'C0')
-    plt.plot(np.array(template)+np.array(std), color = "C1")
-    plt.plot(np.array(template)-np.array(std), color = "C1")
+    plt.plot(template_dist, color = 'C0')
+    plt.plot(template_avg, color = 'C3')
+    plt.plot(np.array(template_dist)+np.array(std), color = "C1")
+    plt.plot(np.array(template_dist)-np.array(std), color = "C1")
     for beat in alligned.keys():
         plt.plot(alligned[beat]['v'],color = 'C2', alpha = 0.01)
+    plt.figure()
+    plt.plot(template_dist, color = 'C0')
+    plt.plot(template_avg, color = 'C3')
     plt.show()
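Reviewer note, not part of the patch: the new 'distance' mode can be sanity-checked without the ECG data. A minimal sketch of the medoid selection that compute_template_lowest_distance() performs, on three hand-made toy "beats" (all names and values here are illustrative):

    # Standalone sketch, toy values only.
    import numpy as np

    def rmse(v1, v2):
        return np.sqrt(np.mean((np.array(v1) - np.array(v2)) ** 2))

    beats = [[0.0, 1.0, 0.0],   # summed RMSE to the others: 0.1 + 0.4 = 0.5
             [0.1, 1.1, 0.1],   # summed RMSE: 0.1 + 0.3 = 0.4 -> the medoid
             [0.4, 1.4, 0.4]]   # summed RMSE: 0.4 + 0.3 = 0.7
    dist = np.zeros((len(beats), len(beats)))
    for i in range(len(beats)):
        for j in range(i + 1, len(beats)):
            dist[i][j] = dist[j][i] = rmse(beats[i], beats[j])
    template = beats[int(np.argmin(dist.sum(axis=0)))]
    print(template)  # -> [0.1, 1.1, 0.1], the beat closest to all the others

Unlike the 'average' mode, the medoid is an actual recorded beat, so its morphology is preserved; the trade-off is the O(n^2) pairwise RMSE computation over the beats of the five-minute window.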