diff --git a/src/build_template.py b/src/build_template.py
index 8be5af5..c3f21bf 100644
--- a/src/build_template.py
+++ b/src/build_template.py
@@ -1,210 +1,210 @@
import numpy as np
import pickle
import matplotlib.pyplot as plt
import os
from sklearn.cluster import MeanShift, estimate_bandwidth
from scipy.signal import medfilt as median_filter

data_beats_dir = "../data/beats/"
FREQ = 128
MIN_TO_AVG = 5

def open_file(file_name):
    '''
    Data structure:
    File:
    | DICTIONARY:
    |   --> Beat_annot:
    |       --> lvl:
    |           --> "t": []
    |           --> "v": []
    '''
    file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name))
    data = {}
    with open(file_name_full,"rb") as f:
        data = pickle.load(f)
    for k in data.keys():
        data[k] = data[k][0]
    return data

def min_max_normalization(vector):
    mi_v = min(vector)
    ma_v = max(vector)
    norm = (np.array(vector)-mi_v)/(ma_v-mi_v)
    norm -= np.average(norm)
    return list(norm)

def get_significant_peak(signal):
    val = -1000
    peak_max = None
    is_peak = lambda w,x,y,z: True if (w < x > y) or (w < x > z) else False
    for i in range(1,len(signal)-2):
        if is_peak(signal[i-1],signal[i],signal[i+1],signal[i+2]) and signal[i] > val:
            peak_max = i
            val = signal[i]
    return peak_max

def get_beats(data, t_start_seconds = 0, t_stop_seconds = MIN_TO_AVG*60):
    beats = {}
    annotations = list(data.keys())
    for annot in annotations[1:]:
        if annot >= int(t_start_seconds*FREQ) and annot <= int(t_stop_seconds*FREQ):
            beats[annot] = data[annot]
        elif annot > int(t_stop_seconds*FREQ):
            break
    return beats

def get_r_peaks_idx(beats):
    r_times_idx = []
    time_to_look = int(0.15*FREQ) #150 ms
    for annot in beats.keys():
        idx = 0
        signal_r_peaks = []
        for i,t in enumerate(beats[annot]['t']):
            if t == annot - time_to_look:
                idx = i
            if t >= annot - time_to_look and t <= annot + time_to_look:
                signal_r_peaks.append(beats[annot]['v'][i])
        peak_idx = get_significant_peak(signal_r_peaks)+idx
        r_times_idx.append(peak_idx)
    return r_times_idx

def allign_beats(data, allignment_idxs, normalize = True):
    len_bef = len_aft = 1000
    data_alligned = {}
    for annot, allignment_idx in zip(data.keys(),allignment_idxs):
        time = data[annot]['t']
        this_len_bef = len(time[:allignment_idx])
        this_len_aft = len(time[allignment_idx+1:])
        if this_len_bef < len_bef:
            len_bef = this_len_bef
        if this_len_aft < len_aft:
            len_aft = this_len_aft
    for annot, allignment_idx in zip(data.keys(),allignment_idxs):
        new_t = data[annot]['t'][allignment_idx-len_bef:allignment_idx+len_aft+1]
        new_v = data[annot]['v'][allignment_idx-len_bef:allignment_idx+len_aft+1]
        if normalize:
            new_v = min_max_normalization(new_v)
        data_alligned[annot] = {'t':new_t,'v':new_v}
    return data_alligned

def RMSE(v1,v2):
    v1_n = np.array(v1)
    v2_n = np.array(v2)
    return np.sqrt(np.mean((v1_n-v2_n)**2))

def compute_template_lowest_distance(beats):
    dist = np.zeros((len(beats),len(beats)))
    for i in range(len(beats)):
        for j in range(i+1,len(beats)):
            dist[i][j] = RMSE(beats[i],beats[j])
            dist[j][i] = dist[i][j]
    dist = np.sum(dist, axis=0)
    idx = np.argmin(dist)
    return beats[idx]

def build_template(file_analyzed, normalize = True, mode = 'average', t_start_seconds = 0, t_stop_seconds = MIN_TO_AVG*60):
    template = []
    beats_alligned = []
    data = open_file(file_analyzed)
    beats = get_beats(data, t_start_seconds = t_start_seconds, t_stop_seconds = t_stop_seconds)
    allignment_idxs = get_r_peaks_idx(beats)
    alligned_beats = allign_beats(beats, allignment_idxs, normalize=normalize)
    for annot in alligned_beats.keys():
        beats_alligned.append(alligned_beats[annot]['v'])
    if mode == 'average':
        template = list(np.average(beats_alligned,axis=0))
    elif mode == 'distance':
template = compute_template_lowest_distance(beats_alligned) template_std = list(np.std(beats_alligned,axis=0)) return template, template_std, alligned_beats def prune_template(template): filtered = median_filter(template,3) noise = np.array(template) - filtered p_filtered = sum(filtered**2)/len(filtered) p_noise = sum(noise**2)/len(noise) SNR = p_filtered/p_noise if SNR > 50: return False else: return True -def multi_template(file_analyzed, normalize = True, t_start_seconds = 0, t_stop_seconds = MIN_TO_AVG*60): +def multi_template(file_analyzed, normalize = True, t_start_seconds = 0, t_stop_seconds = MIN_TO_AVG*60, percentage_each_cluster = 7): # Extract and allign beats beats_alligned = [] data = open_file(file_analyzed) beats = get_beats(data, t_start_seconds = t_start_seconds, t_stop_seconds = t_stop_seconds) allignment_idxs = get_r_peaks_idx(beats) alligned_beats = allign_beats(beats, allignment_idxs, normalize=normalize) for annot in alligned_beats.keys(): beats_alligned.append(alligned_beats[annot]['v']) # Find centroid clusters through mean shift bandwidth = estimate_bandwidth(beats_alligned, quantile=0.05) ms = MeanShift(bandwidth=bandwidth) #ms = MeanShift() ms.fit(beats_alligned) cluster_centers = ms.cluster_centers_ lbls = ms.labels_ unique, counts = np.unique(lbls, return_counts=True) lbl_counts = dict(zip(unique, counts)) #Filter the found labels and find the nearest point to centroids nearest = {} # {$lbl: "center": C, "point": P, "dist": D} for i,pt in enumerate(beats_alligned): #------------------------- lbl_this_pt = lbls[i] # The number samples in a class need to be bigger than x% of all samples #<-- UBER-IMPORTANT HERE | - if lbl_this_pt == -1 or lbl_counts[lbl_this_pt] < int(len(lbls)*7/100) or prune_template(pt): #------------------------- + if lbl_this_pt == -1 or lbl_counts[lbl_this_pt] < int(len(lbls)*percentage_each_cluster/100) or prune_template(pt): #------------------------- continue dist = np.linalg.norm(pt - cluster_centers[lbl_this_pt]) if lbl_this_pt not in nearest.keys(): nearest[lbl_this_pt] = {"center": None, "point": None, "dist": np.inf} if nearest[lbl_this_pt]['dist'] > dist: nearest[lbl_this_pt] = {"center": cluster_centers[lbl_this_pt], "point": pt, "dist": dist} rep_points = [] for lbl in nearest.keys(): rep_points.append(nearest[lbl]['point']) rep_points = np.array(rep_points) - print(f"Found labels and respective number of items (total items: {len(beats_alligned)})\n{lbl_counts}\nof which, with more than 7% associateb beats: {len(rep_points)}") + print(f"Found labels and respective number of items (total items: {len(beats_alligned)})\n{lbl_counts}\nof which, with more than {percentage_each_cluster}% associateb beats: {len(rep_points)}") return rep_points, alligned_beats if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--file", help="Force to analyze one specific file instead of default one (first found)") parser.add_argument("--not_norm", help="Force to NOT normalize each beats", action="store_true") args = parser.parse_args() files = os.listdir(data_beats_dir) if args.file is not None: analyzed = list(filter(lambda string: True if args.file in string else False, files))[0] else: analyzed = files[0] if args.not_norm: normalize = False else: normalize = True template_dist,std,alligned = build_template(analyzed, normalize = normalize, mode = 'distance') template_avg,_,_ = build_template(analyzed, normalize = normalize, mode = 'average') plt.plot(template_dist, color = 'C0') plt.plot(template_avg, 
color = 'C3') plt.plot(np.array(template_dist)+np.array(std), color = "C1") plt.plot(np.array(template_dist)-np.array(std), color = "C1") for beat in alligned.keys(): plt.plot(alligned[beat]['v'],color = 'C2', alpha = 0.01) plt.figure() plt.plot(template_dist, color = 'C0') plt.plot(template_avg, color = 'C3') plt.show() diff --git a/src/multiple_templates_beginning.py b/src/multiple_templates_beginning.py index 9ddfd0f..ec20217 100644 --- a/src/multiple_templates_beginning.py +++ b/src/multiple_templates_beginning.py @@ -1,567 +1,567 @@ import multiprocessing import os import pickle import shutil from multiprocessing import Pool from time import time import matplotlib.pyplot as plt import numpy as np from dtaidistance import dtw from scipy.interpolate import interp1d from copy import copy from build_template import build_template, multi_template data_beats_dir = "../data/beats/" log_dir = "../data/beat_recon_logs_multi_beginning_" FREQ = 128 MIN_TO_AVG = 5 BEAT_TO_PLOT = 2 LEVEL = [3,4,5,6,7,8] NUM_BEAT_ANALYZED = 50 LEN_DISTANCE_VECTOR = 50 TEMPLATE_TYPE = 'distance' PERC_BINS = 15 def RMSE(v1,v2): v1_n = np.array(v1) v2_n = np.array(v2) return np.sqrt(np.mean((v1_n-v2_n)**2)) def RMSE_warped_resampled(data_orig,data_warped,data_resampled): rmse = {} template = {'warp':None,'resamp':None,'warp_left':None,'resamp_left':None,'warp_right':None,'resamp_right':None} for beat in data_warped.keys(): for lvl in data_warped[beat].keys(): if lvl not in rmse.keys(): rmse[lvl] = {} rmse[lvl][beat] = copy(template) idx_beat = data_warped[beat][lvl]['t'].index(beat) l_resamp = data_resampled[beat][lvl]['v'][:idx_beat] r_resamp = data_resampled[beat][lvl]['v'][idx_beat+1:] l_warp = data_warped[beat][lvl]['v'][:idx_beat] r_warp = data_warped[beat][lvl]['v'][idx_beat+1:] l_orig = data_orig[beat]['v'][:idx_beat] r_orig = data_orig[beat]['v'][idx_beat+1:] rmse[lvl][beat]['warp'] = RMSE(data_orig[beat]['v'],data_warped[beat][lvl]['v']) rmse[lvl][beat]['resamp'] = RMSE(data_orig[beat]['v'],data_resampled[beat][lvl]['v']) rmse[lvl][beat]['warp_left'] = RMSE(l_orig,l_warp) rmse[lvl][beat]['resamp_left'] = RMSE(l_orig,l_resamp) rmse[lvl][beat]['warp_right'] = RMSE(r_orig,r_warp) rmse[lvl][beat]['resamp_right'] = RMSE(r_orig,r_resamp) return rmse def open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = None): ''' Data structure: File: | DICTIONARY: | --> Beat_annot: | --> lvl: | --> "t": [] | --> "v": [] ''' file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name)) data = {} data_out = {} with open(file_name_full,"rb") as f: data = pickle.load(f) for k in data.keys(): if k > FREQ*start_after: data_out[k] = {} if get_selected_level is not None: data_out[k][0] = data[k][0] for lvl in get_selected_level: data_out[k][lvl] = data[k][lvl] else: data_out[k] = data[k] return data_out, list(data_out[k].keys()) def min_max_normalization(vector, params = None): params_out = {} #print("\t",params) if params is not None: mi_v = params['min'] ma_v = params['max'] avg = params['avg'] norm = (np.array(vector)-mi_v)/(ma_v-mi_v) norm -= avg params_out = params else: mi_v = min(vector) ma_v = max(vector) norm = (np.array(vector)-mi_v)/(ma_v-mi_v) avg = np.average(norm) norm -= avg params_out['min'] = mi_v params_out['max'] = ma_v params_out['avg'] = avg return list(norm),params_out def resamp_one_signal(t,v,resample_type = 'linear', min_t = None, max_t = None): if resample_type == "linear": f = interp1d(t,v) elif resample_type == "flat": f = interp1d(t,v, kind = 'previous') if min_t is None: min_t = 
t[0] if max_t is None: max_t = t[-1] t_new = list(range(min_t,max_t+1)) v_new = f(t_new) return t_new,v_new def resample(data, resample_type = "linear", min_t = None, max_t = None): resampled_data = {"t":None,"v":None} t = data['t'] v = data['v'] t_r,v_r = resamp_one_signal(t,v,resample_type = resample_type, min_t = min_t, max_t = max_t) resampled_data['t'] = t_r resampled_data['v'] = v_r return resampled_data def change_sampling(vector, size_to_warp): resampled = [] if len(vector) < 3: #If the length of the vector is either 2 or 1 it will be mnorfed in a straight line and such line will be morfed equaly resampled = [vector[0]]*size_to_warp return resampled delta = abs(size_to_warp - len(vector)) # --> number of points to insert/delete #At least one more sample for each sample (except the last one)? --> insert k in each hole and recompute delta if delta >= len(vector)-1 and delta > 0: holes = len(vector)-1 k = delta // holes #Number of point for each hole time = range(0,len(vector)*(k+1),k+1) # [o1,i11,i12,...,i1k, o2,i21,i22,...,i2k, o3,..., o_fin] --> NOTE: last sampel is not taken # 0 k+1 2k+2 l*k +l --> Hence, we get to the exact end t_new = range(time[-1]+1) f = interp1d(time,vector) vector = list(f(t_new)) delta = abs(size_to_warp - len(vector)) if len(vector) == size_to_warp: return vector grad = np.gradient(vector) grad[-1] = max(grad) + 1 # --> we don't want to insert anything after the last sample grad_idxs = sorted(range(len(grad)), key = lambda idx: grad[idx]) idx_to_consider = grad_idxs[:delta] #print(delta, len(idx_to_consider), idx_to_consider) for i,sample in enumerate(vector): resampled.append(sample) if i in idx_to_consider: if size_to_warp < len(vector): resampled.pop() elif size_to_warp > len(vector): resampled.append((vector[i]+vector[i+1])/2) return resampled def segment_warp(segment): #print("--> ",len(segment['segment'])) #print("--> ",segment['length_to_warp']) resampled = np.array(change_sampling(segment['segment'], segment['length_to_warp'])) #print("--> ",len(resampled)) resampled += (segment['v_start'] - resampled[0]) m = (segment['v_stop'] - resampled[-1])/(segment['length_to_warp']-1) coef_to_sum = [m*x for x in range(segment['length_to_warp'])] resampled += coef_to_sum return list(resampled) def stitch_segments(segments): stitched = [] for i,segment in enumerate(segments): if i == len(segments)-1: stitched.extend(segment) else: stitched.extend(segment[0:-1]) return stitched def warp(data, templates, events): #print(events) warped = {} segments = [] # --> {"segment":v, "length_to_warp":length, "v_start":v_eb_0, "v_stop":v_eb_1} segment = {"segment":None, "length_to_warp":None, "v_start":None, "v_stop":None} segment_start = None event_start = None #chosen_template = [] #ad the first and last points of the resampled and truncated beat if not already in event if data['t'][0] not in events['t']: events['t'].insert(1, data['t'][0]) events['v'].insert(1, data['v'][0]) if data['t'][-1] not in events['t']: events['t'].insert(-1,data['t'][-1]) events['v'].insert(-1,data['v'][-1]) #print(events) #Apply DTW for matching resampled event to template v_src = data['v'] #This is how its actualy done on the library when calling 'warping_path' dist = float('inf') paths = [] selected_template = [] template_pos = None for i,t in enumerate(templates): dist_this_template, paths_this_template = dtw.warping_paths(v_src, t) if dist_this_template < dist: dist = dist_this_template paths = paths_this_template selected_template = t template_pos = i path = dtw.best_path(paths) #path = 
dtw.warping_path(v_src, template) #Remove the "left" steps from the path and segment the template based on the events point prev_idx_src = path[-1][0]+1 for idx_src, idx_temp in path: if prev_idx_src == idx_src: continue prev_idx_src = idx_src for idx, t in enumerate(events['t']): if data['t'][idx_src] == t: if segment_start == None: segment_start = idx_temp event_start = idx else: #print(f"SEGMENT {events['t'][event_start]} - {events['t'][idx]} ({abs(events['t'][event_start] - events['t'][idx])})") segment['segment'] = selected_template[segment_start:idx_temp+1] segment['length_to_warp'] = events['t'][idx] - events['t'][event_start] + 1 segment['v_start'] = events['v'][event_start] segment['v_stop'] = events['v'][idx] w = segment_warp(segment) #print(len(w)) segments.append(w) segment_start = idx_temp event_start = idx break segment_stitched = stitch_segments(segments) #print(len(segment_stitched), len(data['t'])) warped = {'t':data['t'],'v':segment_stitched} return dist, warped, template_pos def reconstruct_beats(file_name, level = None, normalize = True, resample_type = 'linear', num_beats = None, verbose = False): reconstructed = {} recon_cost = {} resamp = {} data_orig = {} templates = [] templates_orig = [] templates_usage = {} if verbose: print(f"Extracting {file_name}") data,lvls = open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = level) lvls.remove(0) if verbose: print("Building template") - templates_orig,_ = multi_template(file_name, normalize = normalize, t_start_seconds=0,t_stop_seconds=60*5) + templates_orig,_ = multi_template(file_name, normalize = normalize, t_start_seconds=0,t_stop_seconds=60*5, percentage_each_cluster = 4) if verbose: print(f"Initial number of templates: {len(templates_orig)}") if num_beats == None: num_beats = len(list(data.keys())) if verbose: print("reconstructing") for lvl in lvls: i = 0 recon_cost[lvl] = [] if verbose: print(f"Analyzing level:{lvl}") templates = copy(templates_orig) templates_usage_lvl = [0]*len(templates) for beat in data.keys(): if i == num_beats: break i+=1 if (i%(num_beats/20)==0 or i == 1) and verbose: print(f"File: {file_name}, Reconstructing beat {beat} ({i}/{num_beats}: {100*i/num_beats}%, LEVEL:{lvl})") if beat not in reconstructed.keys(): reconstructed[beat] = {} resamp[beat] = {} data_orig[beat] = copy(data[beat][0]) if normalize: data_orig[beat]['v'],params_prev = min_max_normalization(data_orig[beat]['v'], params = None) data[beat][lvl]['v'],_ = min_max_normalization(data[beat][lvl]['v'], params = params_prev) data_resampled = resample(data[beat][lvl], resample_type = resample_type, min_t = data_orig[beat]['t'][0], max_t = data_orig[beat]['t'][-1]) _, reconstructed[beat][lvl],template_pos = warp(data_resampled,templates,data[beat][lvl]) resamp[beat][lvl] = data_resampled templates_usage_lvl[template_pos] += 1 templates_usage[lvl] = templates_usage_lvl return reconstructed,data_orig,resamp,templates,templates_usage def percentile_idx(vector,perc): pcen=np.percentile(np.array(vector),perc,interpolation='nearest') i_near=abs(np.array(vector)-pcen).argmin() return i_near def correlate(v1,v2): v1_n = np.array(v1) v2_n = np.array(v2) cov = np.cov(v1_n,v2_n)[0,1] cor = cov/(np.std(v1_n)*np.std(v2_n)) return cor def avg_std_for_rmse_dict(dictionary): dict_out = {'warp':{'avg':None,'std':None},'resamp':{'avg':None,'std':None}} wr = [] rs = [] for beat in dictionary.keys(): wr.append(dictionary[beat]['warp']) rs.append(dictionary[beat]['resamp']) dict_out['warp']['avg'] = np.average(wr) dict_out['warp']['std'] = 
np.std(wr) dict_out['resamp']['avg'] = np.average(rs) dict_out['resamp']['std'] = np.std(rs) return dict_out def percentile_rmse_beat(rmse,perc): vec = [] for beat in rmse.keys(): vec.append(rmse[beat]['warp']) idx = percentile_idx(vec,perc) return list(rmse.keys())[idx] def recontruct_and_compare(file): recon,orig,resamp,templates, templates_usage = reconstruct_beats(file, level = LEVEL, resample_type = 'flat', num_beats = NUM_BEAT_ANALYZED, verbose = True) #resample_type = flat vs linear rmse = RMSE_warped_resampled(orig,recon,resamp) for lvl in LEVEL: ''' recon,orig,resamp,pre_proc = reconstruct_beats(file, level = [lvl], resample_type = 'linear', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear rmse = RMSE_warped_resampled(orig,recon,resamp) ''' stats = avg_std_for_rmse_dict(rmse[lvl]) avg_rmse_warp = stats['warp']['avg'] std_rmse_warp = stats['warp']['std'] avg_rmse_resamp = stats['resamp']['avg'] std_rmse_resamp = stats['resamp']['std'] #cov_rmse_warp_dist = correlate(recon_cost[lvl],rmse[lvl]) file_name_to_save = "L_"+file.split(".")[0]+".log" with open(os.path.join(log_dir,file_name_to_save),"a") as f: f.write(f"Lvl: {lvl}\n") f.write(f"\tWarp: {avg_rmse_warp}, +-{std_rmse_warp}\n") f.write(f"\tInterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n") for i,templates_used in enumerate(templates_usage[lvl]): f.write(f"\tTemplate {i} was used {templates_used} times out of {sum(templates_usage[lvl])}\n") f.write(f"\n\n") print(f"File:{file_name_to_save}") print(f"\tLvl: {lvl}") print(f"\t\twarp: {avg_rmse_warp}, +-{std_rmse_warp}") print(f"\t\tinterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}") for i,templates_used in enumerate(templates_usage[lvl]): print(f"\tTemplate {i} was used {templates_used} times out of {sum(templates_usage[lvl])}") beat_to_plot_01 = percentile_rmse_beat(rmse[lvl],1) beat_to_plot_25 = percentile_rmse_beat(rmse[lvl],25) beat_to_plot_50 = percentile_rmse_beat(rmse[lvl],50) beat_to_plot_75 = percentile_rmse_beat(rmse[lvl],75) beat_to_plot_99 = percentile_rmse_beat(rmse[lvl],99) file_name_to_save_fig_01 = os.path.join(log_dir,file.split(".")[0]+"_01_perc"+str(lvl)+".svg") file_name_to_save_fig_25 = os.path.join(log_dir,file.split(".")[0]+"_25_perc"+str(lvl)+".svg") file_name_to_save_fig_50 = os.path.join(log_dir,file.split(".")[0]+"_50_perc"+str(lvl)+".svg") file_name_to_save_fig_75 = os.path.join(log_dir,file.split(".")[0]+"_75_perc"+str(lvl)+".svg") file_name_to_save_fig_99 = os.path.join(log_dir,file.split(".")[0]+"_99_perc"+str(lvl)+".svg") file_name_to_save_fig_hist = os.path.join(log_dir,file.split(".")[0]+"_hist"+str(lvl)+".svg") file_name_to_save_fig_hist_left = os.path.join(log_dir,file.split(".")[0]+"_hist_left"+str(lvl)+".svg") file_name_to_save_fig_hist_right = os.path.join(log_dir,file.split(".")[0]+"_hist_right"+str(lvl)+".svg") # 01 percentile t_o,v_o = orig[beat_to_plot_01]['t'],orig[beat_to_plot_01]['v'] t_a,v_a = recon[beat_to_plot_01][lvl]['t'],recon[beat_to_plot_01][lvl]['v'] t_r,v_r = resamp[beat_to_plot_01][lvl]['t'],resamp[beat_to_plot_01][lvl]['v'] plt.figure() plt.plot(t_o,v_o) plt.plot(t_a,v_a) plt.plot(t_r,v_r) plt.legend(['original','warped template','resampled']) plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_01}, 01 percentile') plt.savefig(file_name_to_save_fig_01) plt.close() # 25 percentile t_o,v_o = orig[beat_to_plot_25]['t'],orig[beat_to_plot_25]['v'] t_a,v_a = recon[beat_to_plot_25][lvl]['t'],recon[beat_to_plot_25][lvl]['v'] t_r,v_r = 
resamp[beat_to_plot_25][lvl]['t'],resamp[beat_to_plot_25][lvl]['v'] plt.figure() plt.plot(t_o,v_o) plt.plot(t_a,v_a) plt.plot(t_r,v_r) plt.legend(['original','warped template','resampled']) plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_25}, 25 percentile') plt.savefig(file_name_to_save_fig_25) plt.close() # 50 percentile t_o,v_o = orig[beat_to_plot_50]['t'],orig[beat_to_plot_50]['v'] t_a,v_a = recon[beat_to_plot_50][lvl]['t'],recon[beat_to_plot_50][lvl]['v'] t_r,v_r = resamp[beat_to_plot_50][lvl]['t'],resamp[beat_to_plot_50][lvl]['v'] plt.figure() plt.plot(t_o,v_o) plt.plot(t_a,v_a) plt.plot(t_r,v_r) plt.legend(['original','warped template','resampled']) plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_50}, 50 percentile') plt.savefig(file_name_to_save_fig_50) plt.close() # 75 percentile t_o,v_o = orig[beat_to_plot_75]['t'],orig[beat_to_plot_75]['v'] t_a,v_a = recon[beat_to_plot_75][lvl]['t'],recon[beat_to_plot_75][lvl]['v'] t_r,v_r = resamp[beat_to_plot_75][lvl]['t'],resamp[beat_to_plot_75][lvl]['v'] plt.figure() plt.plot(t_o,v_o) plt.plot(t_a,v_a) plt.plot(t_r,v_r) plt.legend(['original','warped template','resampled']) plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_75}, 75 percentile') plt.savefig(file_name_to_save_fig_75) plt.close() # 99 percentile t_o,v_o = orig[beat_to_plot_99]['t'],orig[beat_to_plot_99]['v'] t_a,v_a = recon[beat_to_plot_99][lvl]['t'],recon[beat_to_plot_99][lvl]['v'] t_r,v_r = resamp[beat_to_plot_99][lvl]['t'],resamp[beat_to_plot_99][lvl]['v'] plt.figure() plt.plot(t_o,v_o) plt.plot(t_a,v_a) plt.plot(t_r,v_r) plt.legend(['original','warped template','resampled']) plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_99}, 99 percentile') plt.savefig(file_name_to_save_fig_99) plt.close() #bar plots rmse_warp = [] rmse_resamp = [] rmse_warp_left = [] rmse_resamp_left = [] rmse_warp_right = [] rmse_resamp_right = [] for beat in rmse[lvl].keys(): rmse_warp.append(rmse[lvl][beat]['warp']) rmse_resamp.append(rmse[lvl][beat]['resamp']) rmse_warp_left.append(rmse[lvl][beat]['warp_left']) rmse_resamp_left.append(rmse[lvl][beat]['resamp_left']) rmse_warp_right.append(rmse[lvl][beat]['warp_right']) rmse_resamp_right.append(rmse[lvl][beat]['resamp_right']) n_bins = len(rmse_warp)*PERC_BINS//100 min_bin = min(min(rmse_warp),min(rmse_resamp)) max_bin = max(max(rmse_warp),max(rmse_resamp)) delta = (max_bin-min_bin)/n_bins bins = np.arange(min_bin,max_bin+delta,delta) plt.hist(rmse_warp, bins = bins, alpha=0.5) plt.hist(rmse_resamp, bins = bins, alpha=0.5) plt.title(f'File: {file}, Lvl: {lvl}, RMSE histogram') plt.legend(['RMSE warp','RMSE resampled']) plt.savefig(file_name_to_save_fig_hist) plt.close() n_bins = len(rmse_warp_left)*PERC_BINS//100 min_bin = min(min(rmse_warp_left),min(rmse_resamp_left)) max_bin = max(max(rmse_warp_left),max(rmse_resamp_left)) delta = (max_bin-min_bin)/n_bins bins = np.arange(min_bin,max_bin+delta,delta) plt.hist(rmse_warp_left, bins = bins, alpha=0.5) plt.hist(rmse_resamp_left, bins = bins, alpha=0.5) plt.title(f'File: {file}, Lvl: {lvl}, left RMSE histogram') plt.legend(['RMSE warp','RMSE resampled']) plt.savefig(file_name_to_save_fig_hist_left) plt.close() n_bins = len(rmse_warp_right)*PERC_BINS//100 min_bin = min(min(rmse_warp_right),min(rmse_resamp_right)) max_bin = max(max(rmse_warp_right),max(rmse_resamp_right)) delta = (max_bin-min_bin)/n_bins bins = np.arange(min_bin,max_bin+delta,delta) plt.hist(rmse_warp_right, bins = bins, alpha=0.5) 
plt.hist(rmse_resamp_right, bins = bins, alpha=0.5) plt.title(f'File: {file}, Lvl: {lvl}, right RMSE histogram') plt.legend(['RMSE warp','RMSE resampled']) plt.savefig(file_name_to_save_fig_hist_right) plt.close() for i, templ in enumerate(templates): file_name_to_save_fig_template = os.path.join(log_dir,file.split(".")[0]+f"_template_{str(i)}.svg") if not os.path.isfile(file_name_to_save_fig_template): plt.figure() plt.plot(templ) plt.legend(['template']) plt.title(f'File: {file}, template') plt.savefig(file_name_to_save_fig_template) plt.close() else: break def process(files, multi=True, cores=1): # ------------ INIT ------------ global log_dir for i in range(1,1000): tmp_log_dir = log_dir+str(i) if not os.path.isdir(tmp_log_dir): log_dir = tmp_log_dir break os.mkdir(log_dir) # ------------ Extract DATA & ANNOTATIONS ------------ with Pool(cores) as pool: pool.map(recontruct_and_compare, files) if __name__ == "__main__": import argparse #global NUM_BEAT_ANALYZED parser = argparse.ArgumentParser() parser.add_argument("--file", help="Force to analyze one specific file instead of default one (first found)") parser.add_argument("--not_norm", help="Force to NOT normalize each beats", action="store_true") parser.add_argument("--cores", help="Force used number of cores (default, half of the available ones") parser.add_argument("--beats", help="Number of used beats, default: 5000") parser.add_argument("--template_type", help="Type of template, default: distance") args = parser.parse_args() files = os.listdir(data_beats_dir) if args.file is not None: if args.file == 'all': analyzed = files else: analyzed = list(filter(lambda string: True if args.file in string else False, files)) else: analyzed = [files[0]] if args.not_norm: normalize = False else: normalize = True if args.cores is not None: used_cores = int(args.cores) else: used_cores = multiprocessing.cpu_count()//2 if args.beats is not None: NUM_BEAT_ANALYZED = int(args.beats) else: NUM_BEAT_ANALYZED = 5000 if args.template_type is not None: TEMPLATE_TYPE = 'average' else: TEMPLATE_TYPE = 'distance' print(f"Analyzing files: {analyzed}") print(f"Extracting data with {used_cores} cores...") process(files = analyzed, multi=True, cores=used_cores) \ No newline at end of file diff --git a/src/multiple_templates_prog.py b/src/multiple_templates_prog.py index dd80c0d..00fb1de 100644 --- a/src/multiple_templates_prog.py +++ b/src/multiple_templates_prog.py @@ -1,660 +1,660 @@ import multiprocessing import os import pickle import shutil from multiprocessing import Pool from time import time import matplotlib.pyplot as plt import numpy as np from dtaidistance import dtw from scipy.interpolate import interp1d from copy import copy from build_template import build_template, multi_template data_beats_dir = "../data/beats/" log_dir = "../data/beat_recon_logs_multi_prog_" FREQ = 128 MIN_TO_AVG = 5 BEAT_TO_PLOT = 2 LEVEL = [3,4,5,6,7,8] NUM_BEAT_ANALYZED = 50 LEN_DISTANCE_VECTOR = 50 TEMPLATE_TYPE = 'distance' PERC_BINS = 15 def RMSE(v1,v2): v1_n = np.array(v1) v2_n = np.array(v2) return np.sqrt(np.mean((v1_n-v2_n)**2)) def RMSE_warped_resampled(data_orig,data_warped,data_resampled): rmse = {} template = {'warp':None,'resamp':None,'warp_left':None,'resamp_left':None,'warp_right':None,'resamp_right':None} for beat in data_warped.keys(): for lvl in data_warped[beat].keys(): if lvl not in rmse.keys(): rmse[lvl] = {} rmse[lvl][beat] = copy(template) idx_beat = data_warped[beat][lvl]['t'].index(beat) l_resamp = data_resampled[beat][lvl]['v'][:idx_beat] r_resamp = 
data_resampled[beat][lvl]['v'][idx_beat+1:] l_warp = data_warped[beat][lvl]['v'][:idx_beat] r_warp = data_warped[beat][lvl]['v'][idx_beat+1:] l_orig = data_orig[beat]['v'][:idx_beat] r_orig = data_orig[beat]['v'][idx_beat+1:] rmse[lvl][beat]['warp'] = RMSE(data_orig[beat]['v'],data_warped[beat][lvl]['v']) rmse[lvl][beat]['resamp'] = RMSE(data_orig[beat]['v'],data_resampled[beat][lvl]['v']) rmse[lvl][beat]['warp_left'] = RMSE(l_orig,l_warp) rmse[lvl][beat]['resamp_left'] = RMSE(l_orig,l_resamp) rmse[lvl][beat]['warp_right'] = RMSE(r_orig,r_warp) rmse[lvl][beat]['resamp_right'] = RMSE(r_orig,r_resamp) return rmse def open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = None): ''' Data structure: File: | DICTIONARY: | --> Beat_annot: | --> lvl: | --> "t": [] | --> "v": [] ''' file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name)) data = {} data_out = {} with open(file_name_full,"rb") as f: data = pickle.load(f) for k in data.keys(): if k > FREQ*start_after: data_out[k] = {} if get_selected_level is not None: data_out[k][0] = data[k][0] for lvl in get_selected_level: data_out[k][lvl] = data[k][lvl] else: data_out[k] = data[k] return data_out, list(data_out[k].keys()) def min_max_normalization(vector, params = None): params_out = {} #print("\t",params) if params is not None: mi_v = params['min'] ma_v = params['max'] avg = params['avg'] norm = (np.array(vector)-mi_v)/(ma_v-mi_v) norm -= avg params_out = params else: mi_v = min(vector) ma_v = max(vector) norm = (np.array(vector)-mi_v)/(ma_v-mi_v) avg = np.average(norm) norm -= avg params_out['min'] = mi_v params_out['max'] = ma_v params_out['avg'] = avg return list(norm),params_out def resamp_one_signal(t,v,resample_type = 'linear', min_t = None, max_t = None): if resample_type == "linear": f = interp1d(t,v) elif resample_type == "flat": f = interp1d(t,v, kind = 'previous') if min_t is None: min_t = t[0] if max_t is None: max_t = t[-1] t_new = list(range(min_t,max_t+1)) v_new = f(t_new) return t_new,v_new def resample(data, resample_type = "linear", min_t = None, max_t = None): resampled_data = {"t":None,"v":None} t = data['t'] v = data['v'] t_r,v_r = resamp_one_signal(t,v,resample_type = resample_type, min_t = min_t, max_t = max_t) resampled_data['t'] = t_r resampled_data['v'] = v_r return resampled_data def change_sampling(vector, size_to_warp): resampled = [] if len(vector) < 3: #If the length of the vector is either 2 or 1 it will be mnorfed in a straight line and such line will be morfed equaly resampled = [vector[0]]*size_to_warp return resampled delta = abs(size_to_warp - len(vector)) # --> number of points to insert/delete #At least one more sample for each sample (except the last one)? 
--> insert k in each hole and recompute delta if delta >= len(vector)-1 and delta > 0: holes = len(vector)-1 k = delta // holes #Number of point for each hole time = range(0,len(vector)*(k+1),k+1) # [o1,i11,i12,...,i1k, o2,i21,i22,...,i2k, o3,..., o_fin] --> NOTE: last sampel is not taken # 0 k+1 2k+2 l*k +l --> Hence, we get to the exact end t_new = range(time[-1]+1) f = interp1d(time,vector) vector = list(f(t_new)) delta = abs(size_to_warp - len(vector)) if len(vector) == size_to_warp: return vector grad = np.gradient(vector) grad[-1] = max(grad) + 1 # --> we don't want to insert anything after the last sample grad_idxs = sorted(range(len(grad)), key = lambda idx: grad[idx]) idx_to_consider = grad_idxs[:delta] #print(delta, len(idx_to_consider), idx_to_consider) for i,sample in enumerate(vector): resampled.append(sample) if i in idx_to_consider: if size_to_warp < len(vector): resampled.pop() elif size_to_warp > len(vector): resampled.append((vector[i]+vector[i+1])/2) return resampled def segment_warp(segment): #print("--> ",len(segment['segment'])) #print("--> ",segment['length_to_warp']) resampled = np.array(change_sampling(segment['segment'], segment['length_to_warp'])) #print("--> ",len(resampled)) resampled += (segment['v_start'] - resampled[0]) m = (segment['v_stop'] - resampled[-1])/(segment['length_to_warp']-1) coef_to_sum = [m*x for x in range(segment['length_to_warp'])] resampled += coef_to_sum return list(resampled) def stitch_segments(segments): stitched = [] for i,segment in enumerate(segments): if i == len(segments)-1: stitched.extend(segment) else: stitched.extend(segment[0:-1]) return stitched def warp(data, templates, events): #print(events) warped = {} segments = [] # --> {"segment":v, "length_to_warp":length, "v_start":v_eb_0, "v_stop":v_eb_1} segment = {"segment":None, "length_to_warp":None, "v_start":None, "v_stop":None} segment_start = None event_start = None #chosen_template = [] #ad the first and last points of the resampled and truncated beat if not already in event if data['t'][0] not in events['t']: events['t'].insert(1, data['t'][0]) events['v'].insert(1, data['v'][0]) if data['t'][-1] not in events['t']: events['t'].insert(-1,data['t'][-1]) events['v'].insert(-1,data['v'][-1]) #print(events) #Apply DTW for matching resampled event to template v_src = data['v'] #This is how its actualy done on the library when calling 'warping_path' dist = float('inf') paths = [] selected_template = [] disatances_vector = [] for t in templates: dist_this_template, paths_this_template = dtw.warping_paths(v_src, t) disatances_vector.append(dist_this_template) if dist_this_template < dist: dist = dist_this_template paths = paths_this_template selected_template = t path = dtw.best_path(paths) #path = dtw.warping_path(v_src, template) #Remove the "left" steps from the path and segment the template based on the events point prev_idx_src = path[-1][0]+1 for idx_src, idx_temp in path: if prev_idx_src == idx_src: continue prev_idx_src = idx_src for idx, t in enumerate(events['t']): if data['t'][idx_src] == t: if segment_start == None: segment_start = idx_temp event_start = idx else: #print(f"SEGMENT {events['t'][event_start]} - {events['t'][idx]} ({abs(events['t'][event_start] - events['t'][idx])})") segment['segment'] = selected_template[segment_start:idx_temp+1] segment['length_to_warp'] = events['t'][idx] - events['t'][event_start] + 1 segment['v_start'] = events['v'][event_start] segment['v_stop'] = events['v'][idx] w = segment_warp(segment) #print(len(w)) segments.append(w) 
segment_start = idx_temp event_start = idx break segment_stitched = stitch_segments(segments) #print(len(segment_stitched), len(data['t'])) warped = {'t':data['t'],'v':segment_stitched} return dist, warped, disatances_vector, selected_template def reconstruct_beats(file_name, level = None, normalize = True, resample_type = 'linear', num_beats = None, verbose = False): reconstructed = {} resamp = {} data_orig = {} num_distances_out = 0 time_last_out = 0 skip_until = 0 templates = [] templates_orig = [] max_num_template = None templates_usage = {} if verbose: print(f"Extracting {file_name}") data,lvls = open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = level) lvls.remove(0) if verbose: print("Building template") - templates_orig,_ = multi_template(file_name, normalize = normalize, t_start_seconds=0,t_stop_seconds=60*5) + templates_orig,_ = multi_template(file_name, normalize = normalize, t_start_seconds=0,t_stop_seconds=60*5, percentage_each_cluster = 4) if verbose: print(f"Initial number of templates: {len(templates_orig)}") max_num_template = len(templates_orig)*3 if num_beats == None: num_beats = len(list(data.keys())) if verbose: print("reconstructing") for lvl in lvls: i = 0 if verbose: print(f"Analyzing level:{lvl}") distances = [] reference_median_dist = 0 reference_50_75_perc_distance_delta = 0 num_perc = 0 num_perc_delta_out = 0 num_median_out = 0 num_distances_out = 0 time_last_out = 0 skip_until = 0 templates = copy(templates_orig) templates_usage[lvl] = {} dist_vector = [0]*len(templates_orig) for beat in data.keys(): t_beat = beat/FREQ if t_beat < skip_until: continue if i >= num_beats: break i+=1 if (i%(num_beats/20)==0 or i == 1) and verbose: print(f"File: {file_name}, Reconstructing beat {beat} ({i}/{num_beats}: {100*i/num_beats}%, LEVEL:{lvl})") if beat not in reconstructed.keys(): reconstructed[beat] = {} resamp[beat] = {} data_orig[beat] = copy(data[beat][0]) if normalize: data_orig[beat]['v'],params_prev = min_max_normalization(data_orig[beat]['v'], params = None) data[beat][lvl]['v'],_ = min_max_normalization(data[beat][lvl]['v'], params = params_prev) data_resampled = resample(data[beat][lvl], resample_type = resample_type, min_t = data_orig[beat]['t'][0], max_t = data_orig[beat]['t'][-1]) #WARP DRIVEEEEEEE dist, reconstructed[beat][lvl], dist_all_template, selected_template = warp(data_resampled,templates,data[beat][lvl]) dist_vector = [dist_vector[i]+dist_all_template[i] for i in range(len(dist_vector))] resamp[beat][lvl] = data_resampled #Increment the used tempalte count key = tuple(selected_template) if key not in templates_usage[lvl].keys(): templates_usage[lvl][key] = 1 else: templates_usage[lvl][key] += 1 if len(distances) >= LEN_DISTANCE_VECTOR: median_dist = np.median(distances) perc_50_75_distance_delta = np.percentile(distances,75)-median_dist delta_percentage = perc_50_75_distance_delta/median_dist distances.pop(0) if (delta_percentage >= 1 or # The standard deviation of the previous beat is bigger than x% perc_50_75_distance_delta > 1.5 * reference_50_75_perc_distance_delta or # The standard deviations of the previous beats are bigger than the reference for this template median_dist > 1.5 * reference_median_dist): # The averages of the previous beats are bigger than the reference for this template if delta_percentage >= 1 : num_perc += 1 if perc_50_75_distance_delta > 1.5 * reference_50_75_perc_distance_delta : num_perc_delta_out += 1 if median_dist > 1.5 * reference_median_dist: num_median_out += 1 time_last_out = t_beat if t_beat - 
time_last_out > 3: #Delta time since last time num_distances_out = 0 num_perc = 0 num_perc_delta_out = 0 num_median_out = 0 time_last_out = t_beat num_distances_out += 1 if num_distances_out > 40: # number of beats in wich the warping distance was too big max_accum_dist = max(dist_vector) print(f"\nBeat num:{i}, New template needed ... ") print(f"\t Delta percentile percent too high: {num_perc}") print(f"\t Delta percentile too big wrt reference: {num_perc_delta_out}") print(f"\t Median too big wrt reference: {num_median_out}") for j in range(len(dist_vector)): print(f"\tTemplate {j}, dist: {dist_vector[j]}:\t","|"*int(10*dist_vector[j]/max_accum_dist)) - new_templates,_ = multi_template(file_name, normalize = normalize, t_start_seconds=t_beat,t_stop_seconds=t_beat+120) + new_templates,_ = multi_template(file_name, normalize = normalize, t_start_seconds=t_beat,t_stop_seconds=t_beat+120, percentage_each_cluster = 7) print(f"New template built, number of new templates:{len(new_templates)}\n") if len(new_templates)>max_num_template: templates = copy(new_templates) else: to_keep = np.argsort(dist_vector)[:max_num_template-len(new_templates)] new_full_templates = [] for idx in to_keep: new_full_templates.append(templates[idx]) print(f"\tKept template: {idx}, dist: {dist_vector[idx]}:\t","|"*int(10*dist_vector[idx]/max_accum_dist)) for t in new_templates: new_full_templates.append(t) templates = copy(new_full_templates) dist_vector = [0]*len(templates) print(f"New template list length: {len(templates)}\n") distances = [] skip_until = t_beat + 120 num_distances_out = 0 elif len(distances) == LEN_DISTANCE_VECTOR - 1: reference_median_dist = np.median(distances) reference_50_75_perc_distance_delta = np.percentile(distances,75)-reference_median_dist num_perc = 0 num_perc_delta_out = 0 num_median_out = 0 distances.append(dist) return reconstructed,data_orig,resamp,templates,templates_usage def percentile_idx(vector,perc): pcen=np.percentile(np.array(vector),perc,interpolation='nearest') i_near=abs(np.array(vector)-pcen).argmin() return i_near def correlate(v1,v2): v1_n = np.array(v1) v2_n = np.array(v2) cov = np.cov(v1_n,v2_n)[0,1] cor = cov/(np.std(v1_n)*np.std(v2_n)) return cor def avg_std_for_rmse_dict(dictionary): dict_out = {'warp':{'avg':None,'std':None},'resamp':{'avg':None,'std':None}} wr = [] rs = [] for beat in dictionary.keys(): wr.append(dictionary[beat]['warp']) rs.append(dictionary[beat]['resamp']) dict_out['warp']['avg'] = np.average(wr) dict_out['warp']['std'] = np.std(wr) dict_out['resamp']['avg'] = np.average(rs) dict_out['resamp']['std'] = np.std(rs) return dict_out def percentile_rmse_beat(rmse,perc): vec = [] for beat in rmse.keys(): vec.append(rmse[beat]['warp']) idx = percentile_idx(vec,perc) return list(rmse.keys())[idx] def recontruct_and_compare(file): recon,orig,resamp,templates,templates_usage = reconstruct_beats(file, level = LEVEL, resample_type = 'flat', num_beats = NUM_BEAT_ANALYZED, verbose = True) #resample_type = flat vs linear rmse = RMSE_warped_resampled(orig,recon,resamp) for lvl in LEVEL: ''' recon,orig,resamp,pre_proc = reconstruct_beats(file, level = [lvl], resample_type = 'linear', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear rmse = RMSE_warped_resampled(orig,recon,resamp) ''' stats = avg_std_for_rmse_dict(rmse[lvl]) avg_rmse_warp = stats['warp']['avg'] std_rmse_warp = stats['warp']['std'] avg_rmse_resamp = stats['resamp']['avg'] std_rmse_resamp = stats['resamp']['std'] #cov_rmse_warp_dist = 
correlate(recon_cost[lvl],rmse[lvl]) file_name_to_save = "L_"+file.split(".")[0]+".log" with open(os.path.join(log_dir,file_name_to_save),"a") as f: f.write(f"Lvl: {lvl}\n") f.write(f"\tWarp: {avg_rmse_warp}, +-{std_rmse_warp}\n") f.write(f"\tInterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n") for i,templates_used in enumerate(templates_usage[lvl]): f.write(f"\tTemplate {i} was used {templates_usage[lvl][templates_used]} times out of {NUM_BEAT_ANALYZED}\n") f.write(f"\n\n") print(f"File:{file_name_to_save}") print(f"\tLvl: {lvl}") print(f"\t\twarp: {avg_rmse_warp}, +-{std_rmse_warp}") print(f"\t\tinterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}") for i,templates_used in enumerate(templates_usage[lvl]): print(f"\tTemplate {i} was used {templates_usage[lvl][templates_used]} times out of {NUM_BEAT_ANALYZED}\n") beat_to_plot_01 = percentile_rmse_beat(rmse[lvl],1) beat_to_plot_25 = percentile_rmse_beat(rmse[lvl],25) beat_to_plot_50 = percentile_rmse_beat(rmse[lvl],50) beat_to_plot_75 = percentile_rmse_beat(rmse[lvl],75) beat_to_plot_99 = percentile_rmse_beat(rmse[lvl],99) file_name_to_save_fig_01 = os.path.join(log_dir,file.split(".")[0]+"_01_perc"+str(lvl)+".svg") file_name_to_save_fig_25 = os.path.join(log_dir,file.split(".")[0]+"_25_perc"+str(lvl)+".svg") file_name_to_save_fig_50 = os.path.join(log_dir,file.split(".")[0]+"_50_perc"+str(lvl)+".svg") file_name_to_save_fig_75 = os.path.join(log_dir,file.split(".")[0]+"_75_perc"+str(lvl)+".svg") file_name_to_save_fig_99 = os.path.join(log_dir,file.split(".")[0]+"_99_perc"+str(lvl)+".svg") file_name_to_save_fig_hist = os.path.join(log_dir,file.split(".")[0]+"_hist"+str(lvl)+".svg") file_name_to_save_fig_hist_left = os.path.join(log_dir,file.split(".")[0]+"_hist_left"+str(lvl)+".svg") file_name_to_save_fig_hist_right = os.path.join(log_dir,file.split(".")[0]+"_hist_right"+str(lvl)+".svg") # 01 percentile t_o,v_o = orig[beat_to_plot_01]['t'],orig[beat_to_plot_01]['v'] t_a,v_a = recon[beat_to_plot_01][lvl]['t'],recon[beat_to_plot_01][lvl]['v'] t_r,v_r = resamp[beat_to_plot_01][lvl]['t'],resamp[beat_to_plot_01][lvl]['v'] plt.figure() plt.plot(t_o,v_o) plt.plot(t_a,v_a) plt.plot(t_r,v_r) plt.legend(['original','warped template','resampled']) plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_01}, 01 percentile') plt.savefig(file_name_to_save_fig_01) plt.close() # 25 percentile t_o,v_o = orig[beat_to_plot_25]['t'],orig[beat_to_plot_25]['v'] t_a,v_a = recon[beat_to_plot_25][lvl]['t'],recon[beat_to_plot_25][lvl]['v'] t_r,v_r = resamp[beat_to_plot_25][lvl]['t'],resamp[beat_to_plot_25][lvl]['v'] plt.figure() plt.plot(t_o,v_o) plt.plot(t_a,v_a) plt.plot(t_r,v_r) plt.legend(['original','warped template','resampled']) plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_25}, 25 percentile') plt.savefig(file_name_to_save_fig_25) plt.close() # 50 percentile t_o,v_o = orig[beat_to_plot_50]['t'],orig[beat_to_plot_50]['v'] t_a,v_a = recon[beat_to_plot_50][lvl]['t'],recon[beat_to_plot_50][lvl]['v'] t_r,v_r = resamp[beat_to_plot_50][lvl]['t'],resamp[beat_to_plot_50][lvl]['v'] plt.figure() plt.plot(t_o,v_o) plt.plot(t_a,v_a) plt.plot(t_r,v_r) plt.legend(['original','warped template','resampled']) plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_50}, 50 percentile') plt.savefig(file_name_to_save_fig_50) plt.close() # 75 percentile t_o,v_o = orig[beat_to_plot_75]['t'],orig[beat_to_plot_75]['v'] t_a,v_a = recon[beat_to_plot_75][lvl]['t'],recon[beat_to_plot_75][lvl]['v'] t_r,v_r = 
resamp[beat_to_plot_75][lvl]['t'],resamp[beat_to_plot_75][lvl]['v'] plt.figure() plt.plot(t_o,v_o) plt.plot(t_a,v_a) plt.plot(t_r,v_r) plt.legend(['original','warped template','resampled']) plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_75}, 75 percentile') plt.savefig(file_name_to_save_fig_75) plt.close() # 99 percentile t_o,v_o = orig[beat_to_plot_99]['t'],orig[beat_to_plot_99]['v'] t_a,v_a = recon[beat_to_plot_99][lvl]['t'],recon[beat_to_plot_99][lvl]['v'] t_r,v_r = resamp[beat_to_plot_99][lvl]['t'],resamp[beat_to_plot_99][lvl]['v'] plt.figure() plt.plot(t_o,v_o) plt.plot(t_a,v_a) plt.plot(t_r,v_r) plt.legend(['original','warped template','resampled']) plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_99}, 99 percentile') plt.savefig(file_name_to_save_fig_99) plt.close() #bar plots rmse_warp = [] rmse_resamp = [] rmse_warp_left = [] rmse_resamp_left = [] rmse_warp_right = [] rmse_resamp_right = [] for beat in rmse[lvl].keys(): rmse_warp.append(rmse[lvl][beat]['warp']) rmse_resamp.append(rmse[lvl][beat]['resamp']) rmse_warp_left.append(rmse[lvl][beat]['warp_left']) rmse_resamp_left.append(rmse[lvl][beat]['resamp_left']) rmse_warp_right.append(rmse[lvl][beat]['warp_right']) rmse_resamp_right.append(rmse[lvl][beat]['resamp_right']) n_bins = len(rmse_warp)*PERC_BINS//100 min_bin = min(min(rmse_warp),min(rmse_resamp)) max_bin = max(max(rmse_warp),max(rmse_resamp)) delta = (max_bin-min_bin)/n_bins bins = np.arange(min_bin,max_bin+delta,delta) plt.hist(rmse_warp, bins = bins, alpha=0.5) plt.hist(rmse_resamp, bins = bins, alpha=0.5) plt.title(f'File: {file}, Lvl: {lvl}, RMSE histogram') plt.legend(['RMSE warp','RMSE resampled']) plt.savefig(file_name_to_save_fig_hist) plt.close() n_bins = len(rmse_warp_left)*PERC_BINS//100 min_bin = min(min(rmse_warp_left),min(rmse_resamp_left)) max_bin = max(max(rmse_warp_left),max(rmse_resamp_left)) delta = (max_bin-min_bin)/n_bins bins = np.arange(min_bin,max_bin+delta,delta) plt.hist(rmse_warp_left, bins = bins, alpha=0.5) plt.hist(rmse_resamp_left, bins = bins, alpha=0.5) plt.title(f'File: {file}, Lvl: {lvl}, left RMSE histogram') plt.legend(['RMSE warp','RMSE resampled']) plt.savefig(file_name_to_save_fig_hist_left) plt.close() n_bins = len(rmse_warp_right)*PERC_BINS//100 min_bin = min(min(rmse_warp_right),min(rmse_resamp_right)) max_bin = max(max(rmse_warp_right),max(rmse_resamp_right)) delta = (max_bin-min_bin)/n_bins bins = np.arange(min_bin,max_bin+delta,delta) plt.hist(rmse_warp_right, bins = bins, alpha=0.5) plt.hist(rmse_resamp_right, bins = bins, alpha=0.5) plt.title(f'File: {file}, Lvl: {lvl}, right RMSE histogram') plt.legend(['RMSE warp','RMSE resampled']) plt.savefig(file_name_to_save_fig_hist_right) plt.close() for i, templ in enumerate(templates): file_name_to_save_fig_template = os.path.join(log_dir,file.split(".")[0]+f"_template_{str(i)}.svg") if not os.path.isfile(file_name_to_save_fig_template): plt.figure() plt.plot(templ) plt.legend(['template']) plt.title(f'File: {file}, template') plt.savefig(file_name_to_save_fig_template) plt.close() else: break def process(files, multi=True, cores=1): # ------------ INIT ------------ global log_dir for i in range(1,1000): tmp_log_dir = log_dir+str(i) if not os.path.isdir(tmp_log_dir): log_dir = tmp_log_dir break os.mkdir(log_dir) # ------------ Extract DATA & ANNOTATIONS ------------ if cores == 1: print("Single core") for f in files: recontruct_and_compare(f) else: with Pool(cores) as pool: pool.map(recontruct_and_compare, files) if __name__ == 
"__main__": import argparse #global NUM_BEAT_ANALYZED parser = argparse.ArgumentParser() parser.add_argument("--file", help="Force to analyze one specific file instead of default one (first found)") parser.add_argument("--not_norm", help="Force to NOT normalize each beats", action="store_true") parser.add_argument("--cores", help="Force used number of cores (default, half of the available ones") parser.add_argument("--beats", help="Number of used beats, default: 5000") parser.add_argument("--template_type", help="Type of template, default: distance") args = parser.parse_args() files = os.listdir(data_beats_dir) if args.file is not None: if args.file == 'all': analyzed = files else: analyzed = list(filter(lambda string: True if args.file in string else False, files)) else: analyzed = [files[0]] if args.not_norm: normalize = False else: normalize = True if args.cores is not None: used_cores = int(args.cores) else: used_cores = multiprocessing.cpu_count()//2 if args.beats is not None: NUM_BEAT_ANALYZED = int(args.beats) else: NUM_BEAT_ANALYZED = 5000 if args.template_type is not None: TEMPLATE_TYPE = 'average' else: TEMPLATE_TYPE = 'distance' print(f"Analyzing files: {analyzed}") print(f"Extracting data with {used_cores} cores...") process(files = analyzed, multi=True, cores=used_cores) \ No newline at end of file