diff --git a/src/build_template.py b/src/build_template.py index 8635c35..b981e35 100644 --- a/src/build_template.py +++ b/src/build_template.py @@ -1,156 +1,171 @@ import numpy as np import pickle import matplotlib.pyplot as plt import os +from sklearn.cluster import MeanShift, estimate_bandwidth data_beats_dir = "../data/beats/" FREQ = 128 MIN_TO_AVG = 5 def open_file(file_name): ''' Data structure: File: | DICTIONARY: | --> Beat_annot: | --> lvl: | --> "t": [] | --> "v": [] ''' file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name)) data = {} with open(file_name_full,"rb") as f: data = pickle.load(f) for k in data.keys(): data[k] = data[k][0] return data def min_max_normalization(vector): mi_v = min(vector) ma_v = max(vector) norm = (np.array(vector)-mi_v)/(ma_v-mi_v) norm -= np.average(norm) return list(norm) def get_significant_peak(signal): val = -1000 peak_max = None is_peak = lambda w,x,y,z: True if (w < x > y) or (w < x > z) else False for i in range(1,len(signal)-2): if is_peak(signal[i-1],signal[i],signal[i+1],signal[i+2]) and signal[i]> val: peak_max = i val = signal[i] return peak_max -def get_5_min_beats(data): +def get_beats(data, t_start_seconds = 0, t_stop_seconds = MIN_TO_AVG*60): beats = {} annotations = list(data.keys()) for annot in annotations[1:]: - if annot > int(MIN_TO_AVG*60*FREQ): + if annot >= int(t_start_seconds*FREQ) and annot <= int(t_stop_seconds*FREQ): + beats[annot] = data[annot] + elif annot > int(t_stop_seconds*FREQ): break - beats[annot] = data[annot] return beats def get_r_peaks_idx(beats): r_times_idx = [] time_to_look = int(0.15*FREQ) #150 ms for annot in beats.keys(): idx = 0 signal_r_peaks = [] for i,t in enumerate(beats[annot]['t']): if t == annot - time_to_look: idx = i if t >= annot - time_to_look and t<= annot + time_to_look: signal_r_peaks.append(beats[annot]['v'][i]) peak_idx = get_significant_peak(signal_r_peaks)+idx r_times_idx.append(peak_idx) return r_times_idx def allign_beats(data, allignment_idxs, normalize = True): len_bef = len_aft = 1000 data_alligned = {} for annot, allignment_idx in zip(data.keys(),allignment_idxs): time = data[annot]['t'] this_len_bef = len(time[:allignment_idx]) this_len_aft = len(time[allignment_idx+1:]) if this_len_bef < len_bef: len_bef = this_len_bef if this_len_aft < len_aft: len_aft = this_len_aft for annot, allignment_idx in zip(data.keys(),allignment_idxs): new_t = data[annot]['t'][allignment_idx-len_bef:allignment_idx+len_aft+1] new_v = data[annot]['v'][allignment_idx-len_bef:allignment_idx+len_aft+1] if normalize: new_v = min_max_normalization(new_v) data_alligned[annot] = {'t':new_t,'v':new_v} return data_alligned def RMSE(v1,v2): v1_n = np.array(v1) v2_n = np.array(v2) return np.sqrt(np.mean((v1_n-v2_n)**2)) def compute_template_lowest_distance(beats): dist = np.zeros((len(beats),len(beats))) for i in range(len(beats)): for j in range(i+1,len(beats)): dist[i][j] = RMSE(beats[i],beats[j]) dist[j][i] = dist[i][j] dist = np.sum(dist, axis=0) idx = np.argmin(dist) return beats[idx] -def build_template(file_analyzed, normalize = True, mode = 'average'): +def build_template(file_analyzed, normalize = True, mode = 'average', t_start_seconds = 0, t_stop_seconds = MIN_TO_AVG*60): template = [] beats_alligned = [] data = open_file(file_analyzed) - beats = get_5_min_beats(data) + beats = get_beats(data, t_start_seconds = t_start_seconds, t_stop_seconds = t_stop_seconds) allignment_idxs = get_r_peaks_idx(beats) alligned_beats = allign_beats(beats, allignment_idxs, normalize=normalize) for annot in
alligned_beats.keys(): beats_alligned.append(alligned_beats[annot]['v']) if mode == 'average': template = list(np.average(beats_alligned,axis=0)) elif mode == 'distance': template = compute_template_lowest_distance(beats_alligned) template_std = list(np.std(beats_alligned,axis=0)) return template, template_std, alligned_beats +def multi_template(file_analyzed, normalize = True, t_start_seconds = 0, t_stop_seconds = MIN_TO_AVG*60): + beats_alligned = [] + data = open_file(file_analyzed) + beats = get_beats(data, t_start_seconds = t_start_seconds, t_stop_seconds = t_stop_seconds) + allignment_idxs = get_r_peaks_idx(beats) + alligned_beats = allign_beats(beats, allignment_idxs, normalize=normalize) + for annot in alligned_beats.keys(): + beats_alligned.append(alligned_beats[annot]['v']) + ms = MeanShift() + ms.fit(beats_alligned) + cluster_centers = ms.cluster_centers_.tolist() + return cluster_centers, alligned_beats + if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--file", help="Force to analyze one specific file instead of default one (first found)") parser.add_argument("--not_norm", help="Force to NOT normalize each beats", action="store_true") args = parser.parse_args() files = os.listdir(data_beats_dir) if args.file is not None: analyzed = list(filter(lambda string: True if args.file in string else False, files))[0] else: analyzed = files[0] if args.not_norm: normalize = False else: normalize = True template_dist,std,alligned = build_template(analyzed, normalize = normalize, mode = 'distance') template_avg,_,_ = build_template(analyzed, normalize = normalize, mode = 'average') plt.plot(template_dist, color = 'C0') plt.plot(template_avg, color = 'C3') plt.plot(np.array(template_dist)+np.array(std), color = "C1") plt.plot(np.array(template_dist)-np.array(std), color = "C1") for beat in alligned.keys(): plt.plot(alligned[beat]['v'],color = 'C2', alpha = 0.01) plt.figure() plt.plot(template_dist, color = 'C0') plt.plot(template_avg, color = 'C3') plt.show() diff --git a/src/multiple_template_recon.py b/src/multiple_template_recon.py new file mode 100644 index 0000000..4686f3c --- /dev/null +++ b/src/multiple_template_recon.py @@ -0,0 +1,610 @@ +import multiprocessing +import os +import pickle +import shutil +from multiprocessing import Pool +from time import time + +import matplotlib.pyplot as plt +import numpy as np +from dtaidistance import dtw +from scipy.interpolate import interp1d +from copy import copy + +from build_template import build_template, multi_template + +data_beats_dir = "../data/beats/" +log_dir = "../data/beat_recon_logs_multi_" + +FREQ = 128 +MIN_TO_AVG = 5 +BEAT_TO_PLOT = 2 +LEVEL = [1,2,3,4,5,6,7,8,9,10,11] +NUM_BEAT_ANALYZED = 50 +LEN_DISTANCE_VECTOR = 50 +TEMPLATE_TYPE = 'distance' +PERC_BINS = 15 + +def RMSE(v1,v2): + v1_n = np.array(v1) + v2_n = np.array(v2) + return np.sqrt(np.mean((v1_n-v2_n)**2)) + +def RMSE_warped_resampled(data_orig,data_warped,data_resampled): + rmse = {} + template = {'warp':None,'resamp':None,'warp_left':None,'resamp_left':None,'warp_right':None,'resamp_right':None} + for beat in data_warped.keys(): + for lvl in data_warped[beat].keys(): + if lvl not in rmse.keys(): + rmse[lvl] = {} + rmse[lvl][beat] = copy(template) + idx_beat = data_warped[beat][lvl]['t'].index(beat) + l_resamp = data_resampled[beat][lvl]['v'][:idx_beat] + r_resamp = data_resampled[beat][lvl]['v'][idx_beat+1:] + l_warp = data_warped[beat][lvl]['v'][:idx_beat] + r_warp = data_warped[beat][lvl]['v'][idx_beat+1:] + l_orig = 
data_orig[beat]['v'][:idx_beat] + r_orig = data_orig[beat]['v'][idx_beat+1:] + + rmse[lvl][beat]['warp'] = RMSE(data_orig[beat]['v'],data_warped[beat][lvl]['v']) + rmse[lvl][beat]['resamp'] = RMSE(data_orig[beat]['v'],data_resampled[beat][lvl]['v']) + + rmse[lvl][beat]['warp_left'] = RMSE(l_orig,l_warp) + rmse[lvl][beat]['resamp_left'] = RMSE(l_orig,l_resamp) + + rmse[lvl][beat]['warp_right'] = RMSE(r_orig,r_warp) + rmse[lvl][beat]['resamp_right'] = RMSE(r_orig,r_resamp) + return rmse + +def open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = None): + ''' + Data structure: + File: + | + DICTIONARY: + | + --> Beat_annot: + | + --> lvl: + | + --> "t": [] + | + --> "v": [] + ''' + file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name)) + data = {} + data_out = {} + with open(file_name_full,"rb") as f: + data = pickle.load(f) + for k in data.keys(): + if k > FREQ*start_after: + data_out[k] = {} + if get_selected_level is not None: + data_out[k][0] = data[k][0] + for lvl in get_selected_level: + data_out[k][lvl] = data[k][lvl] + else: + data_out[k] = data[k] + return data_out, list(data_out[k].keys()) + +def min_max_normalization(vector, params = None): + params_out = {} + #print("\t",params) + if params is not None: + mi_v = params['min'] + ma_v = params['max'] + avg = params['avg'] + norm = (np.array(vector)-mi_v)/(ma_v-mi_v) + norm -= avg + params_out = params + else: + mi_v = min(vector) + ma_v = max(vector) + norm = (np.array(vector)-mi_v)/(ma_v-mi_v) + avg = np.average(norm) + norm -= avg + params_out['min'] = mi_v + params_out['max'] = ma_v + params_out['avg'] = avg + return list(norm),params_out + +def resamp_one_signal(t,v,resample_type = 'linear', min_t = None, max_t = None): + if resample_type == "linear": + f = interp1d(t,v) + elif resample_type == "flat": + f = interp1d(t,v, kind = 'previous') + if min_t is None: + min_t = t[0] + if max_t is None: + max_t = t[-1] + t_new = list(range(min_t,max_t+1)) + v_new = f(t_new) + return t_new,v_new + +def resample(data, resample_type = "linear", min_t = None, max_t = None): + resampled_data = {"t":None,"v":None} + t = data['t'] + v = data['v'] + t_r,v_r = resamp_one_signal(t,v,resample_type = resample_type, min_t = min_t, max_t = max_t) + resampled_data['t'] = t_r + resampled_data['v'] = v_r + + return resampled_data + +def change_sampling(vector, size_to_warp): + resampled = [] + if len(vector) < 3: #If the length of the vector is either 2 or 1 it will be mnorfed in a straight line and such line will be morfed equaly + resampled = [vector[0]]*size_to_warp + return resampled + delta = abs(size_to_warp - len(vector)) # --> number of points to insert/delete + #At least one more sample for each sample (except the last one)? 
--> insert k in each hole and recompute delta + if delta >= len(vector)-1 and delta > 0: + holes = len(vector)-1 + k = delta // holes #Number of point for each hole + time = range(0,len(vector)*(k+1),k+1) # [o1,i11,i12,...,i1k, o2,i21,i22,...,i2k, o3,..., o_fin] --> NOTE: last sampel is not taken + # 0 k+1 2k+2 l*k +l --> Hence, we get to the exact end + t_new = range(time[-1]+1) + f = interp1d(time,vector) + vector = list(f(t_new)) + delta = abs(size_to_warp - len(vector)) + if len(vector) == size_to_warp: + return vector + + + grad = np.gradient(vector) + grad[-1] = max(grad) + 1 # --> we don't want to insert anything after the last sample + grad_idxs = sorted(range(len(grad)), key = lambda idx: grad[idx]) + + idx_to_consider = grad_idxs[:delta] + #print(delta, len(idx_to_consider), idx_to_consider) + for i,sample in enumerate(vector): + resampled.append(sample) + if i in idx_to_consider: + if size_to_warp < len(vector): + resampled.pop() + elif size_to_warp > len(vector): + resampled.append((vector[i]+vector[i+1])/2) + return resampled + +def segment_warp(segment): + #print("--> ",len(segment['segment'])) + #print("--> ",segment['length_to_warp']) + resampled = np.array(change_sampling(segment['segment'], segment['length_to_warp'])) + #print("--> ",len(resampled)) + resampled += (segment['v_start'] - resampled[0]) + m = (segment['v_stop'] - resampled[-1])/(segment['length_to_warp']-1) + coef_to_sum = [m*x for x in range(segment['length_to_warp'])] + resampled += coef_to_sum + return list(resampled) + +def stitch_segments(segments): + stitched = [] + for i,segment in enumerate(segments): + if i == len(segments)-1: + stitched.extend(segment) + else: + stitched.extend(segment[0:-1]) + return stitched + +def warp(data, templates, events): + #print(events) + warped = {} + segments = [] # --> {"segment":v, "length_to_warp":length, "v_start":v_eb_0, "v_stop":v_eb_1} + segment = {"segment":None, "length_to_warp":None, "v_start":None, "v_stop":None} + segment_start = None + event_start = None + #chosen_template = [] + + #ad the first and last points of the resampled and truncated beat if not already in event + if data['t'][0] not in events['t']: + events['t'].insert(1, data['t'][0]) + events['v'].insert(1, data['v'][0]) + if data['t'][-1] not in events['t']: + events['t'].insert(-1,data['t'][-1]) + events['v'].insert(-1,data['v'][-1]) + + #print(events) + #Apply DTW for matching resampled event to template + v_src = data['v'] + #This is how its actualy done on the library when calling 'warping_path' + dist = float('inf') + paths = [] + selected_template = [] + for t in templates: + dist_this_template, paths_this_template = dtw.warping_paths(v_src, t) + if dist_this_template < dist: + dist = dist_this_template + paths = paths_this_template + selected_template = t + path = dtw.best_path(paths) + #path = dtw.warping_path(v_src, template) + + #Remove the "left" steps from the path and segment the template based on the events point + prev_idx_src = path[-1][0]+1 + for idx_src, idx_temp in path: + if prev_idx_src == idx_src: + continue + prev_idx_src = idx_src + for idx, t in enumerate(events['t']): + if data['t'][idx_src] == t: + if segment_start == None: + segment_start = idx_temp + event_start = idx + else: + #print(f"SEGMENT {events['t'][event_start]} - {events['t'][idx]} ({abs(events['t'][event_start] - events['t'][idx])})") + + segment['segment'] = selected_template[segment_start:idx_temp+1] + segment['length_to_warp'] = events['t'][idx] - events['t'][event_start] + 1 + segment['v_start'] = 
events['v'][event_start] + segment['v_stop'] = events['v'][idx] + w = segment_warp(segment) + #print(len(w)) + segments.append(w) + segment_start = idx_temp + event_start = idx + break + segment_stitched = stitch_segments(segments) + #print(len(segment_stitched), len(data['t'])) + warped = {'t':data['t'],'v':segment_stitched} + return dist, warped, selected_template + +def reconstruct_beats(file_name, level = None, normalize = True, resample_type = 'linear', num_beats = None, verbose = False, prev_pre_proc = None): + reconstructed = {} + recon_cost = {} + resamp = {} + data_orig = {} + pre_proc = {} + + num_distances_out = 0 + time_last_out_std = 0 + skip_until = 0 + + templates = [] + templates_orig = [] + + if verbose: + print(f"Extracting {file_name}") + data,lvls = open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = level) + lvls.remove(0) + + if verbose: + print("Building template") + if prev_pre_proc is not None: + templates_orig = prev_pre_proc['template'] + else: + templates_orig,_ = multi_template(file_name, normalize = normalize, t_start_seconds=0,t_stop_seconds=60*5) + print(f"Initial number of templates: {len(templates_orig)}") + + if num_beats == None: + num_beats = len(list(data.keys())) + if verbose: + print("reconstructing") + + for lvl in lvls: + i = 0 + recon_cost[lvl] = [] + if verbose: + print(f"Analyzing level:{lvl}") + distances = [] + reference_average_diatance = 0 + reference_std_distance = 0 + num_distances_out = 0 + time_last_out_std = 0 + skip_until = 0 + templates = copy(templates_orig) + for beat in data.keys(): + t_beat = beat/FREQ + if t_beat < skip_until: + continue + if i == num_beats: + break + i+=1 + if (i%(num_beats/20)==0 or i == 1) and verbose: + print(f"File: {file_name}, Reconstructing beat {beat} ({i}/{num_beats}: {100*i/num_beats}%, LEVEL:{lvl})") + if beat not in reconstructed.keys(): + reconstructed[beat] = {} + resamp[beat] = {} + data_orig[beat] = copy(data[beat][0]) + if normalize: + data_orig[beat]['v'],params_prev = min_max_normalization(data_orig[beat]['v'], params = None) + data[beat][lvl]['v'],_ = min_max_normalization(data[beat][lvl]['v'], params = params_prev) + data_resampled = resample(data[beat][lvl], resample_type = resample_type, min_t = data_orig[beat]['t'][0], max_t = data_orig[beat]['t'][-1]) + dist, reconstructed[beat][lvl],_ = warp(data_resampled,templates,data[beat][lvl]) + resamp[beat][lvl] = data_resampled + recon_cost[lvl].append(dist) + + if len(distances) >= LEN_DISTANCE_VECTOR: + avg_dist = np.average(distances) + std_dist = np.std(distances) + perc_std = std_dist/avg_dist + distances.pop(0) + if (abs(dist-avg_dist)>std_dist or # The acquired beat is outside statistic + perc_std >= 1 or # The standard deviation of the previous beat is bigger than x% + std_dist > 1.5 * reference_std_distance or # The standard deviations of the previous beats are bigger than the reference for this template + avg_dist > 1.5 * reference_average_diatance): # The averages of the previous beats are bigger than the reference for this template + + time_last_out_std = t_beat + if t_beat - time_last_out_std > 2: #Delta time since last time the distance was outside one std + num_distances_out = 0 + time_last_out_std = t_beat + num_distances_out += 1 + if num_distances_out > 40: # number of beats in wich the warping distance was too big + print(f"\nBeat num:{i}, New template needed ... ") + print(f"\t Beat outside std? {abs(dist-avg_dist)>std_dist}") + print(f"\t Std percent too high? 
{perc_std >= 1}") + print(f"\t Std too big wrt reference? {std_dist > 1.5 * reference_std_distance}") + print(f"\t Average too big wrt reference? {avg_dist > 1.5 * reference_average_diatance}") + template,_,_ = build_template(file_name, normalize = normalize, mode = TEMPLATE_TYPE,t_start_seconds=t_beat,t_stop_seconds=t_beat+40) + print(f"Template built\n") + templates.append(template) + distances = [] + skip_until = t_beat + 40 + num_distances_out = 0 + elif len(distances) == LEN_DISTANCE_VECTOR - 1: + reference_average_diatance = np.average(distances) + reference_std_distance = np.std(distances) + distances.append(dist) + pre_proc['template'] = templates_orig + + return reconstructed,data_orig,resamp,pre_proc,templates + + +def percentile_idx(vector,perc): + pcen=np.percentile(np.array(vector),perc,interpolation='nearest') + i_near=abs(np.array(vector)-pcen).argmin() + return i_near + +def correlate(v1,v2): + v1_n = np.array(v1) + v2_n = np.array(v2) + cov = np.cov(v1_n,v2_n)[0,1] + cor = cov/(np.std(v1_n)*np.std(v2_n)) + return cor + +def avg_std_for_rmse_dict(dictionary): + dict_out = {'warp':{'avg':None,'std':None},'resamp':{'avg':None,'std':None}} + wr = [] + rs = [] + for beat in dictionary.keys(): + wr.append(dictionary[beat]['warp']) + rs.append(dictionary[beat]['resamp']) + dict_out['warp']['avg'] = np.average(wr) + dict_out['warp']['std'] = np.std(wr) + dict_out['resamp']['avg'] = np.average(rs) + dict_out['resamp']['std'] = np.std(rs) + return dict_out + +def percentile_rmse_beat(rmse,perc): + vec = [] + for beat in rmse.keys(): + vec.append(rmse[beat]['warp']) + idx = percentile_idx(vec,perc) + return list(rmse.keys())[idx] + +def recontruct_and_compare(file): + pre_proc = None + recon,orig,resamp,pre_proc,templates = reconstruct_beats(file, level = LEVEL, resample_type = 'flat', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear + rmse = RMSE_warped_resampled(orig,recon,resamp) + for lvl in LEVEL: + ''' + recon,orig,resamp,pre_proc = reconstruct_beats(file, level = [lvl], resample_type = 'linear', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear + rmse = RMSE_warped_resampled(orig,recon,resamp) + ''' + stats = avg_std_for_rmse_dict(rmse[lvl]) + avg_rmse_warp = stats['warp']['avg'] + std_rmse_warp = stats['warp']['std'] + avg_rmse_resamp = stats['resamp']['avg'] + std_rmse_resamp = stats['resamp']['std'] + #cov_rmse_warp_dist = correlate(recon_cost[lvl],rmse[lvl]) + file_name_to_save = "L_"+file.split(".")[0]+".log" + with open(os.path.join(log_dir,file_name_to_save),"a") as f: + f.write(f"Lvl: {lvl}\n") + f.write(f"\tWarp: {avg_rmse_warp}, +-{std_rmse_warp}\n") + f.write(f"\tInterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n") + #f.write(f"\tCorrelation between warping cost and rmse: {cov_rmse_warp_dist}\n") + f.write(f"\n\n") + f.write(f"\n\n") + print(f"File:{file_name_to_save}") + print(f"\tLvl: {lvl}") + print(f"\t\twarp: {avg_rmse_warp}, +-{std_rmse_warp}") + print(f"\t\tinterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}") + + beat_to_plot_01 = percentile_rmse_beat(rmse[lvl],1) + beat_to_plot_25 = percentile_rmse_beat(rmse[lvl],25) + beat_to_plot_50 = percentile_rmse_beat(rmse[lvl],50) + beat_to_plot_75 = percentile_rmse_beat(rmse[lvl],75) + beat_to_plot_99 = percentile_rmse_beat(rmse[lvl],99) + + file_name_to_save_fig_01 = os.path.join(log_dir,file.split(".")[0]+"_01_perc"+str(lvl)+".svg") + file_name_to_save_fig_25 = 
os.path.join(log_dir,file.split(".")[0]+"_25_perc"+str(lvl)+".svg") + file_name_to_save_fig_50 = os.path.join(log_dir,file.split(".")[0]+"_50_perc"+str(lvl)+".svg") + file_name_to_save_fig_75 = os.path.join(log_dir,file.split(".")[0]+"_75_perc"+str(lvl)+".svg") + file_name_to_save_fig_99 = os.path.join(log_dir,file.split(".")[0]+"_99_perc"+str(lvl)+".svg") + + file_name_to_save_fig_hist = os.path.join(log_dir,file.split(".")[0]+"_hist"+str(lvl)+".svg") + file_name_to_save_fig_hist_left = os.path.join(log_dir,file.split(".")[0]+"_hist_left"+str(lvl)+".svg") + file_name_to_save_fig_hist_right = os.path.join(log_dir,file.split(".")[0]+"_hist_right"+str(lvl)+".svg") + + + # 01 percentile + t_o,v_o = orig[beat_to_plot_01]['t'],orig[beat_to_plot_01]['v'] + t_a,v_a = recon[beat_to_plot_01][lvl]['t'],recon[beat_to_plot_01][lvl]['v'] + t_r,v_r = resamp[beat_to_plot_01][lvl]['t'],resamp[beat_to_plot_01][lvl]['v'] + plt.figure() + plt.plot(t_o,v_o) + plt.plot(t_a,v_a) + plt.plot(t_r,v_r) + plt.legend(['original','warped template','resampled']) + plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_01}, 01 percentile') + plt.savefig(file_name_to_save_fig_01) + plt.close() + + # 25 percentile + t_o,v_o = orig[beat_to_plot_25]['t'],orig[beat_to_plot_25]['v'] + t_a,v_a = recon[beat_to_plot_25][lvl]['t'],recon[beat_to_plot_25][lvl]['v'] + t_r,v_r = resamp[beat_to_plot_25][lvl]['t'],resamp[beat_to_plot_25][lvl]['v'] + plt.figure() + plt.plot(t_o,v_o) + plt.plot(t_a,v_a) + plt.plot(t_r,v_r) + plt.legend(['original','warped template','resampled']) + plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_25}, 25 percentile') + plt.savefig(file_name_to_save_fig_25) + plt.close() + + # 50 percentile + t_o,v_o = orig[beat_to_plot_50]['t'],orig[beat_to_plot_50]['v'] + t_a,v_a = recon[beat_to_plot_50][lvl]['t'],recon[beat_to_plot_50][lvl]['v'] + t_r,v_r = resamp[beat_to_plot_50][lvl]['t'],resamp[beat_to_plot_50][lvl]['v'] + plt.figure() + plt.plot(t_o,v_o) + plt.plot(t_a,v_a) + plt.plot(t_r,v_r) + plt.legend(['original','warped template','resampled']) + plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_50}, 50 percentile') + plt.savefig(file_name_to_save_fig_50) + plt.close() + + # 75 percentile + t_o,v_o = orig[beat_to_plot_75]['t'],orig[beat_to_plot_75]['v'] + t_a,v_a = recon[beat_to_plot_75][lvl]['t'],recon[beat_to_plot_75][lvl]['v'] + t_r,v_r = resamp[beat_to_plot_75][lvl]['t'],resamp[beat_to_plot_75][lvl]['v'] + plt.figure() + plt.plot(t_o,v_o) + plt.plot(t_a,v_a) + plt.plot(t_r,v_r) + plt.legend(['original','warped template','resampled']) + plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_75}, 75 percentile') + plt.savefig(file_name_to_save_fig_75) + plt.close() + + # 99 percentile + t_o,v_o = orig[beat_to_plot_99]['t'],orig[beat_to_plot_99]['v'] + t_a,v_a = recon[beat_to_plot_99][lvl]['t'],recon[beat_to_plot_99][lvl]['v'] + t_r,v_r = resamp[beat_to_plot_99][lvl]['t'],resamp[beat_to_plot_99][lvl]['v'] + plt.figure() + plt.plot(t_o,v_o) + plt.plot(t_a,v_a) + plt.plot(t_r,v_r) + plt.legend(['original','warped template','resampled']) + plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_99}, 99 percentile') + plt.savefig(file_name_to_save_fig_99) + plt.close() + + #bar plots + rmse_warp = [] + rmse_resamp = [] + rmse_warp_left = [] + rmse_resamp_left = [] + rmse_warp_right = [] + rmse_resamp_right = [] + for beat in rmse[lvl].keys(): + rmse_warp.append(rmse[lvl][beat]['warp']) + 
rmse_resamp.append(rmse[lvl][beat]['resamp']) + rmse_warp_left.append(rmse[lvl][beat]['warp_left']) + rmse_resamp_left.append(rmse[lvl][beat]['resamp_left']) + rmse_warp_right.append(rmse[lvl][beat]['warp_right']) + rmse_resamp_right.append(rmse[lvl][beat]['resamp_right']) + + n_bins = len(rmse_warp)*PERC_BINS//100 + min_bin = min(min(rmse_warp),min(rmse_resamp)) + max_bin = max(max(rmse_warp),max(rmse_resamp)) + delta = (max_bin-min_bin)/n_bins + bins = np.arange(min_bin,max_bin+delta,delta) + plt.hist(rmse_warp, bins = bins, alpha=0.5) + plt.hist(rmse_resamp, bins = bins, alpha=0.5) + plt.title(f'File: {file}, Lvl: {lvl}, RMSE histogram') + plt.legend(['RMSE warp','RMSE resampled']) + plt.savefig(file_name_to_save_fig_hist) + plt.close() + + n_bins = len(rmse_warp_left)*PERC_BINS//100 + min_bin = min(min(rmse_warp_left),min(rmse_resamp_left)) + max_bin = max(max(rmse_warp_left),max(rmse_resamp_left)) + delta = (max_bin-min_bin)/n_bins + bins = np.arange(min_bin,max_bin+delta,delta) + plt.hist(rmse_warp_left, bins = bins, alpha=0.5) + plt.hist(rmse_resamp_left, bins = bins, alpha=0.5) + plt.title(f'File: {file}, Lvl: {lvl}, left RMSE histogram') + plt.legend(['RMSE warp','RMSE resampled']) + plt.savefig(file_name_to_save_fig_hist_left) + plt.close() + + n_bins = len(rmse_warp_right)*PERC_BINS//100 + min_bin = min(min(rmse_warp_right),min(rmse_resamp_right)) + max_bin = max(max(rmse_warp_right),max(rmse_resamp_right)) + delta = (max_bin-min_bin)/n_bins + bins = np.arange(min_bin,max_bin+delta,delta) + plt.hist(rmse_warp_right, bins = bins, alpha=0.5) + plt.hist(rmse_resamp_right, bins = bins, alpha=0.5) + plt.title(f'File: {file}, Lvl: {lvl}, right RMSE histogram') + plt.legend(['RMSE warp','RMSE resampled']) + plt.savefig(file_name_to_save_fig_hist_right) + plt.close() + + templates = pre_proc['template'] + for i, templ in enumerate(templates): + file_name_to_save_fig_template = os.path.join(log_dir,file.split(".")[0]+f"_template_{str(i)}.svg") + if not os.path.isfile(file_name_to_save_fig_template): + plt.figure() + plt.plot(templ) + plt.legend(['template']) + plt.title(f'File: {file}, template') + plt.savefig(file_name_to_save_fig_template) + plt.close() + else: + break + +def process(files, multi=True, cores=1): + # ------------ INIT ------------ + global log_dir + for i in range(1,1000): + tmp_log_dir = log_dir+str(i) + if not os.path.isdir(tmp_log_dir): + log_dir = tmp_log_dir + break + os.mkdir(log_dir) + + # ------------ Extract DATA & ANNOTATIONS ------------ + with Pool(cores) as pool: + pool.map(recontruct_and_compare, files) + + +if __name__ == "__main__": + import argparse + #global NUM_BEAT_ANALYZED + parser = argparse.ArgumentParser() + parser.add_argument("--file", help="Force to analyze one specific file instead of default one (first found)") + parser.add_argument("--not_norm", help="Force to NOT normalize each beats", action="store_true") + parser.add_argument("--cores", help="Force used number of cores (default, half of the available ones") + parser.add_argument("--beats", help="Number of used beats, default: 5000") + parser.add_argument("--template_type", help="Type of template, default: distance") + args = parser.parse_args() + files = os.listdir(data_beats_dir) + if args.file is not None: + if args.file == 'all': + analyzed = files + else: + analyzed = list(filter(lambda string: True if args.file in string else False, files)) + else: + analyzed = [files[0]] + if args.not_norm: + normalize = False + else: + normalize = True + if args.cores is not None: + used_cores = 
int(args.cores) + else: + used_cores = multiprocessing.cpu_count()//2 + if args.beats is not None: + NUM_BEAT_ANALYZED = int(args.beats) + else: + NUM_BEAT_ANALYZED = 5000 + if args.template_type is not None: + TEMPLATE_TYPE = 'average' + else: + TEMPLATE_TYPE = 'distance' + + print(f"Analyzing files: {analyzed}") + print(f"Extracting data with {used_cores} cores...") + process(files = analyzed, multi=True, cores=used_cores) \ No newline at end of file diff --git a/src/beat_reconstruction_template_naive.py b/src/naive_dtw_recon.py similarity index 100% rename from src/beat_reconstruction_template_naive.py rename to src/naive_dtw_recon.py diff --git a/src/progressive_recomp_recon.py b/src/progressive_recomp_recon.py new file mode 100644 index 0000000..f2ceadc --- /dev/null +++ b/src/progressive_recomp_recon.py @@ -0,0 +1,592 @@ +import multiprocessing +import os +import pickle +import shutil +from multiprocessing import Pool +from time import time + +import matplotlib.pyplot as plt +import numpy as np +from dtaidistance import dtw +from scipy.interpolate import interp1d +from copy import copy + +from build_template import build_template + +data_beats_dir = "../data/beats/" +log_dir = "../data/beat_recon_logs_prog_" + +FREQ = 128 +MIN_TO_AVG = 5 +BEAT_TO_PLOT = 2 +LEVEL = [1,2,3,4,5,6,7,8,9,10,11] +NUM_BEAT_ANALYZED = 50 +LEN_DISTANCE_VECTOR = 50 +TEMPLATE_TYPE = 'distance' +PERC_BINS = 15 + +def RMSE(v1,v2): + v1_n = np.array(v1) + v2_n = np.array(v2) + return np.sqrt(np.mean((v1_n-v2_n)**2)) + +def RMSE_warped_resampled(data_orig,data_warped,data_resampled): + rmse = {} + template = {'warp':None,'resamp':None,'warp_left':None,'resamp_left':None,'warp_right':None,'resamp_right':None} + for beat in data_warped.keys(): + for lvl in data_warped[beat].keys(): + if lvl not in rmse.keys(): + rmse[lvl] = {} + rmse[lvl][beat] = copy(template) + idx_beat = data_warped[beat][lvl]['t'].index(beat) + l_resamp = data_resampled[beat][lvl]['v'][:idx_beat] + r_resamp = data_resampled[beat][lvl]['v'][idx_beat+1:] + l_warp = data_warped[beat][lvl]['v'][:idx_beat] + r_warp = data_warped[beat][lvl]['v'][idx_beat+1:] + l_orig = data_orig[beat]['v'][:idx_beat] + r_orig = data_orig[beat]['v'][idx_beat+1:] + + rmse[lvl][beat]['warp'] = RMSE(data_orig[beat]['v'],data_warped[beat][lvl]['v']) + rmse[lvl][beat]['resamp'] = RMSE(data_orig[beat]['v'],data_resampled[beat][lvl]['v']) + + rmse[lvl][beat]['warp_left'] = RMSE(l_orig,l_warp) + rmse[lvl][beat]['resamp_left'] = RMSE(l_orig,l_resamp) + + rmse[lvl][beat]['warp_right'] = RMSE(r_orig,r_warp) + rmse[lvl][beat]['resamp_right'] = RMSE(r_orig,r_resamp) + return rmse + +def open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = None): + ''' + Data structure: + File: + | + DICTIONARY: + | + --> Beat_annot: + | + --> lvl: + | + --> "t": [] + | + --> "v": [] + ''' + file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name)) + data = {} + data_out = {} + with open(file_name_full,"rb") as f: + data = pickle.load(f) + for k in data.keys(): + if k > FREQ*start_after: + data_out[k] = {} + if get_selected_level is not None: + data_out[k][0] = data[k][0] + for lvl in get_selected_level: + data_out[k][lvl] = data[k][lvl] + else: + data_out[k] = data[k] + return data_out, list(data_out[k].keys()) + +def min_max_normalization(vector, params = None): + params_out = {} + #print("\t",params) + if params is not None: + mi_v = params['min'] + ma_v = params['max'] + avg = params['avg'] + norm = (np.array(vector)-mi_v)/(ma_v-mi_v) + norm -= avg + params_out = 
params + else: + mi_v = min(vector) + ma_v = max(vector) + norm = (np.array(vector)-mi_v)/(ma_v-mi_v) + avg = np.average(norm) + norm -= avg + params_out['min'] = mi_v + params_out['max'] = ma_v + params_out['avg'] = avg + return list(norm),params_out + +def resamp_one_signal(t,v,resample_type = 'linear', min_t = None, max_t = None): + if resample_type == "linear": + f = interp1d(t,v) + elif resample_type == "flat": + f = interp1d(t,v, kind = 'previous') + if min_t is None: + min_t = t[0] + if max_t is None: + max_t = t[-1] + t_new = list(range(min_t,max_t+1)) + v_new = f(t_new) + return t_new,v_new + +def resample(data, resample_type = "linear", min_t = None, max_t = None): + resampled_data = {"t":None,"v":None} + t = data['t'] + v = data['v'] + t_r,v_r = resamp_one_signal(t,v,resample_type = resample_type, min_t = min_t, max_t = max_t) + resampled_data['t'] = t_r + resampled_data['v'] = v_r + + return resampled_data + +def change_sampling(vector, size_to_warp): + resampled = [] + if len(vector) < 3: #If the length of the vector is either 2 or 1 it will be mnorfed in a straight line and such line will be morfed equaly + resampled = [vector[0]]*size_to_warp + return resampled + delta = abs(size_to_warp - len(vector)) # --> number of points to insert/delete + #At least one more sample for each sample (except the last one)? --> insert k in each hole and recompute delta + if delta >= len(vector)-1 and delta > 0: + holes = len(vector)-1 + k = delta // holes #Number of point for each hole + time = range(0,len(vector)*(k+1),k+1) # [o1,i11,i12,...,i1k, o2,i21,i22,...,i2k, o3,..., o_fin] --> NOTE: last sampel is not taken + # 0 k+1 2k+2 l*k +l --> Hence, we get to the exact end + t_new = range(time[-1]+1) + f = interp1d(time,vector) + vector = list(f(t_new)) + delta = abs(size_to_warp - len(vector)) + if len(vector) == size_to_warp: + return vector + + + grad = np.gradient(vector) + grad[-1] = max(grad) + 1 # --> we don't want to insert anything after the last sample + grad_idxs = sorted(range(len(grad)), key = lambda idx: grad[idx]) + + idx_to_consider = grad_idxs[:delta] + #print(delta, len(idx_to_consider), idx_to_consider) + for i,sample in enumerate(vector): + resampled.append(sample) + if i in idx_to_consider: + if size_to_warp < len(vector): + resampled.pop() + elif size_to_warp > len(vector): + resampled.append((vector[i]+vector[i+1])/2) + return resampled + +def segment_warp(segment): + #print("--> ",len(segment['segment'])) + #print("--> ",segment['length_to_warp']) + resampled = np.array(change_sampling(segment['segment'], segment['length_to_warp'])) + #print("--> ",len(resampled)) + resampled += (segment['v_start'] - resampled[0]) + m = (segment['v_stop'] - resampled[-1])/(segment['length_to_warp']-1) + coef_to_sum = [m*x for x in range(segment['length_to_warp'])] + resampled += coef_to_sum + return list(resampled) + +def stitch_segments(segments): + stitched = [] + for i,segment in enumerate(segments): + if i == len(segments)-1: + stitched.extend(segment) + else: + stitched.extend(segment[0:-1]) + return stitched + +def warp(data, template, events): + #print(events) + warped = {} + segments = [] # --> {"segment":v, "length_to_warp":length, "v_start":v_eb_0, "v_stop":v_eb_1} + segment = {"segment":None, "length_to_warp":None, "v_start":None, "v_stop":None} + segment_start = None + event_start = None + + #ad the first and last points of the resampled and truncated beat if not already in event + if data['t'][0] not in events['t']: + events['t'].insert(1, data['t'][0]) + 
events['v'].insert(1, data['v'][0]) + if data['t'][-1] not in events['t']: + events['t'].insert(-1,data['t'][-1]) + events['v'].insert(-1,data['v'][-1]) + + #print(events) + #Apply DTW for matching resampled event to template + v_src = data['v'] + #This is how its actualy done on the library when calling 'warping_path' + dist,paths = dtw.warping_paths(v_src, template) + path = dtw.best_path(paths) + #path = dtw.warping_path(v_src, template) + + #Remove the "left" steps from the path and segment the template based on the events point + prev_idx_src = path[-1][0]+1 + for idx_src, idx_temp in path: + if prev_idx_src == idx_src: + continue + prev_idx_src = idx_src + for idx, t in enumerate(events['t']): + if data['t'][idx_src] == t: + if segment_start == None: + segment_start = idx_temp + event_start = idx + else: + #print(f"SEGMENT {events['t'][event_start]} - {events['t'][idx]} ({abs(events['t'][event_start] - events['t'][idx])})") + + segment['segment'] = template[segment_start:idx_temp+1] + segment['length_to_warp'] = events['t'][idx] - events['t'][event_start] + 1 + segment['v_start'] = events['v'][event_start] + segment['v_stop'] = events['v'][idx] + w = segment_warp(segment) + #print(len(w)) + segments.append(w) + segment_start = idx_temp + event_start = idx + break + segment_stitched = stitch_segments(segments) + #print(len(segment_stitched), len(data['t'])) + warped = {'t':data['t'],'v':segment_stitched} + return dist,warped + +def reconstruct_beats(file_name, level = None, normalize = True, resample_type = 'linear', num_beats = None, verbose = False, prev_pre_proc = None): + reconstructed = {} + recon_cost = {} + resamp = {} + data_orig = {} + pre_proc = {} + + num_distances_out = 0 + time_last_out_std = 0 + skip_until = 0 + + if verbose: + print(f"Extracting {file_name}") + data,lvls = open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = level) + lvls.remove(0) + + if verbose: + print("Building template") + if prev_pre_proc is not None: + template_orig = prev_pre_proc['template'] + else: + template_orig,_,_ = build_template(file_name, normalize = normalize, mode = TEMPLATE_TYPE,t_start_seconds=0,t_stop_seconds=60*5) + + if num_beats == None: + num_beats = len(list(data.keys())) + if verbose: + print("reconstructing") + + for lvl in lvls: + i = 0 + recon_cost[lvl] = [] + if verbose: + print(f"Analyzing level:{lvl}") + template = copy(template_orig) + distances = [] + reference_average_diatance = 0 + reference_std_distance = 0 + num_distances_out = 0 + time_last_out_std = 0 + skip_until = 0 + for beat in data.keys(): + t_beat = beat/FREQ + if t_beat < skip_until: + continue + if i == num_beats: + break + i+=1 + if (i%(num_beats/20)==0 or i == 1) and verbose: + print(f"File: {file_name}, Reconstructing beat {beat} ({i}/{num_beats}: {100*i/num_beats}%, LEVEL:{lvl})") + if beat not in reconstructed.keys(): + reconstructed[beat] = {} + resamp[beat] = {} + data_orig[beat] = copy(data[beat][0]) + if normalize: + data_orig[beat]['v'],params_prev = min_max_normalization(data_orig[beat]['v'], params = None) + data[beat][lvl]['v'],_ = min_max_normalization(data[beat][lvl]['v'], params = params_prev) + data_resampled = resample(data[beat][lvl], resample_type = resample_type, min_t = data_orig[beat]['t'][0], max_t = data_orig[beat]['t'][-1]) + dist, reconstructed[beat][lvl] = warp(data_resampled,template,data[beat][lvl]) + resamp[beat][lvl] = data_resampled + recon_cost[lvl].append(dist) + + if len(distances) >= LEN_DISTANCE_VECTOR: + avg_dist = np.average(distances) + std_dist = 
np.std(distances) + perc_std = std_dist/avg_dist + distances.pop(0) + if (abs(dist-avg_dist)>std_dist or # The acquired beat is outside statistic + perc_std >= 1 or # The standard deviation of the previous beat is bigger than x% + std_dist > 1.5 * reference_std_distance or # The standard deviations of the previous beats are bigger than the reference for this template + avg_dist > 1.5 * reference_average_diatance): # The averages of the previous beats are bigger than the reference for this template + + time_last_out_std = t_beat + if t_beat - time_last_out_std > 2: #Delta time since last time the distance was outside one std + num_distances_out = 0 + time_last_out_std = t_beat + num_distances_out += 1 + if num_distances_out > 40: # number of beats in wich the warping distance was too big + print(f"\nBeat num:{i}, Rebuilding template...(old: {template[:3]})") + print(f"\t Beat outside std? {abs(dist-avg_dist)>std_dist}") + print(f"\t Std percent too high? {perc_std >= 1}") + print(f"\t Std too big wrt reference? {std_dist > 1.5 * reference_std_distance}") + print(f"\t Average too big wrt reference? {avg_dist > 1.5 * reference_average_diatance}") + template,_,_ = build_template(file_name, normalize = normalize, mode = TEMPLATE_TYPE,t_start_seconds=t_beat,t_stop_seconds=t_beat+40) + print(f"Template built: {template[:3]}\n") + distances = [] + skip_until = t_beat + 40 + num_distances_out = 0 + elif len(distances) == LEN_DISTANCE_VECTOR - 1: + reference_average_diatance = np.average(distances) + reference_std_distance = np.std(distances) + distances.append(dist) + pre_proc['template'] = template + + return reconstructed,data_orig,resamp,pre_proc,recon_cost + + +def percentile_idx(vector,perc): + pcen=np.percentile(np.array(vector),perc,interpolation='nearest') + i_near=abs(np.array(vector)-pcen).argmin() + return i_near + +def correlate(v1,v2): + v1_n = np.array(v1) + v2_n = np.array(v2) + cov = np.cov(v1_n,v2_n)[0,1] + cor = cov/(np.std(v1_n)*np.std(v2_n)) + return cor + +def avg_std_for_rmse_dict(dictionary): + dict_out = {'warp':{'avg':None,'std':None},'resamp':{'avg':None,'std':None}} + wr = [] + rs = [] + for beat in dictionary.keys(): + wr.append(dictionary[beat]['warp']) + rs.append(dictionary[beat]['resamp']) + dict_out['warp']['avg'] = np.average(wr) + dict_out['warp']['std'] = np.std(wr) + dict_out['resamp']['avg'] = np.average(rs) + dict_out['resamp']['std'] = np.std(rs) + return dict_out + +def percentile_rmse_beat(rmse,perc): + vec = [] + for beat in rmse.keys(): + vec.append(rmse[beat]['warp']) + idx = percentile_idx(vec,perc) + return list(rmse.keys())[idx] + +def recontruct_and_compare(file): + pre_proc = None + recon,orig,resamp,pre_proc,_ = reconstruct_beats(file, level = LEVEL, resample_type = 'flat', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear + rmse = RMSE_warped_resampled(orig,recon,resamp) + for lvl in LEVEL: + ''' + recon,orig,resamp,pre_proc = reconstruct_beats(file, level = [lvl], resample_type = 'linear', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear + rmse = RMSE_warped_resampled(orig,recon,resamp) + ''' + stats = avg_std_for_rmse_dict(rmse[lvl]) + avg_rmse_warp = stats['warp']['avg'] + std_rmse_warp = stats['warp']['std'] + avg_rmse_resamp = stats['resamp']['avg'] + std_rmse_resamp = stats['resamp']['std'] + #cov_rmse_warp_dist = correlate(recon_cost[lvl],rmse[lvl]) + file_name_to_save = "L_"+file.split(".")[0]+".log" + with 
open(os.path.join(log_dir,file_name_to_save),"a") as f: + f.write(f"Lvl: {lvl}\n") + f.write(f"\tWarp: {avg_rmse_warp}, +-{std_rmse_warp}\n") + f.write(f"\tInterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n") + #f.write(f"\tCorrelation between warping cost and rmse: {cov_rmse_warp_dist}\n") + f.write(f"\n\n") + f.write(f"\n\n") + print(f"File:{file_name_to_save}") + print(f"\tLvl: {lvl}") + print(f"\t\twarp: {avg_rmse_warp}, +-{std_rmse_warp}") + print(f"\t\tinterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}") + + beat_to_plot_01 = percentile_rmse_beat(rmse[lvl],1) + beat_to_plot_25 = percentile_rmse_beat(rmse[lvl],25) + beat_to_plot_50 = percentile_rmse_beat(rmse[lvl],50) + beat_to_plot_75 = percentile_rmse_beat(rmse[lvl],75) + beat_to_plot_99 = percentile_rmse_beat(rmse[lvl],99) + + file_name_to_save_fig_01 = os.path.join(log_dir,file.split(".")[0]+"_01_perc"+str(lvl)+".svg") + file_name_to_save_fig_25 = os.path.join(log_dir,file.split(".")[0]+"_25_perc"+str(lvl)+".svg") + file_name_to_save_fig_50 = os.path.join(log_dir,file.split(".")[0]+"_50_perc"+str(lvl)+".svg") + file_name_to_save_fig_75 = os.path.join(log_dir,file.split(".")[0]+"_75_perc"+str(lvl)+".svg") + file_name_to_save_fig_99 = os.path.join(log_dir,file.split(".")[0]+"_99_perc"+str(lvl)+".svg") + + file_name_to_save_fig_hist = os.path.join(log_dir,file.split(".")[0]+"_hist"+str(lvl)+".svg") + file_name_to_save_fig_hist_left = os.path.join(log_dir,file.split(".")[0]+"_hist_left"+str(lvl)+".svg") + file_name_to_save_fig_hist_right = os.path.join(log_dir,file.split(".")[0]+"_hist_right"+str(lvl)+".svg") + file_name_to_save_fig_template = os.path.join(log_dir,file.split(".")[0]+"_template.svg") + + # 01 percentile + t_o,v_o = orig[beat_to_plot_01]['t'],orig[beat_to_plot_01]['v'] + t_a,v_a = recon[beat_to_plot_01][lvl]['t'],recon[beat_to_plot_01][lvl]['v'] + t_r,v_r = resamp[beat_to_plot_01][lvl]['t'],resamp[beat_to_plot_01][lvl]['v'] + plt.figure() + plt.plot(t_o,v_o) + plt.plot(t_a,v_a) + plt.plot(t_r,v_r) + plt.legend(['original','warped template','resampled']) + plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_01}, 01 percentile') + plt.savefig(file_name_to_save_fig_01) + plt.close() + + # 25 percentile + t_o,v_o = orig[beat_to_plot_25]['t'],orig[beat_to_plot_25]['v'] + t_a,v_a = recon[beat_to_plot_25][lvl]['t'],recon[beat_to_plot_25][lvl]['v'] + t_r,v_r = resamp[beat_to_plot_25][lvl]['t'],resamp[beat_to_plot_25][lvl]['v'] + plt.figure() + plt.plot(t_o,v_o) + plt.plot(t_a,v_a) + plt.plot(t_r,v_r) + plt.legend(['original','warped template','resampled']) + plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_25}, 25 percentile') + plt.savefig(file_name_to_save_fig_25) + plt.close() + + # 50 percentile + t_o,v_o = orig[beat_to_plot_50]['t'],orig[beat_to_plot_50]['v'] + t_a,v_a = recon[beat_to_plot_50][lvl]['t'],recon[beat_to_plot_50][lvl]['v'] + t_r,v_r = resamp[beat_to_plot_50][lvl]['t'],resamp[beat_to_plot_50][lvl]['v'] + plt.figure() + plt.plot(t_o,v_o) + plt.plot(t_a,v_a) + plt.plot(t_r,v_r) + plt.legend(['original','warped template','resampled']) + plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_50}, 50 percentile') + plt.savefig(file_name_to_save_fig_50) + plt.close() + + # 75 percentile + t_o,v_o = orig[beat_to_plot_75]['t'],orig[beat_to_plot_75]['v'] + t_a,v_a = recon[beat_to_plot_75][lvl]['t'],recon[beat_to_plot_75][lvl]['v'] + t_r,v_r = resamp[beat_to_plot_75][lvl]['t'],resamp[beat_to_plot_75][lvl]['v'] + plt.figure() + plt.plot(t_o,v_o) + 
plt.plot(t_a,v_a) + plt.plot(t_r,v_r) + plt.legend(['original','warped template','resampled']) + plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_75}, 75 percentile') + plt.savefig(file_name_to_save_fig_75) + plt.close() + + # 99 percentile + t_o,v_o = orig[beat_to_plot_99]['t'],orig[beat_to_plot_99]['v'] + t_a,v_a = recon[beat_to_plot_99][lvl]['t'],recon[beat_to_plot_99][lvl]['v'] + t_r,v_r = resamp[beat_to_plot_99][lvl]['t'],resamp[beat_to_plot_99][lvl]['v'] + plt.figure() + plt.plot(t_o,v_o) + plt.plot(t_a,v_a) + plt.plot(t_r,v_r) + plt.legend(['original','warped template','resampled']) + plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_99}, 99 percentile') + plt.savefig(file_name_to_save_fig_99) + plt.close() + + #bar plots + rmse_warp = [] + rmse_resamp = [] + rmse_warp_left = [] + rmse_resamp_left = [] + rmse_warp_right = [] + rmse_resamp_right = [] + for beat in rmse[lvl].keys(): + rmse_warp.append(rmse[lvl][beat]['warp']) + rmse_resamp.append(rmse[lvl][beat]['resamp']) + rmse_warp_left.append(rmse[lvl][beat]['warp_left']) + rmse_resamp_left.append(rmse[lvl][beat]['resamp_left']) + rmse_warp_right.append(rmse[lvl][beat]['warp_right']) + rmse_resamp_right.append(rmse[lvl][beat]['resamp_right']) + + n_bins = len(rmse_warp)*PERC_BINS//100 + min_bin = min(min(rmse_warp),min(rmse_resamp)) + max_bin = max(max(rmse_warp),max(rmse_resamp)) + delta = (max_bin-min_bin)/n_bins + bins = np.arange(min_bin,max_bin+delta,delta) + plt.hist(rmse_warp, bins = bins, alpha=0.5) + plt.hist(rmse_resamp, bins = bins, alpha=0.5) + plt.title(f'File: {file}, Lvl: {lvl}, RMSE histogram') + plt.legend(['RMSE warp','RMSE resampled']) + plt.savefig(file_name_to_save_fig_hist) + plt.close() + + n_bins = len(rmse_warp_left)*PERC_BINS//100 + min_bin = min(min(rmse_warp_left),min(rmse_resamp_left)) + max_bin = max(max(rmse_warp_left),max(rmse_resamp_left)) + delta = (max_bin-min_bin)/n_bins + bins = np.arange(min_bin,max_bin+delta,delta) + plt.hist(rmse_warp_left, bins = bins, alpha=0.5) + plt.hist(rmse_resamp_left, bins = bins, alpha=0.5) + plt.title(f'File: {file}, Lvl: {lvl}, left RMSE histogram') + plt.legend(['RMSE warp','RMSE resampled']) + plt.savefig(file_name_to_save_fig_hist_left) + plt.close() + + n_bins = len(rmse_warp_right)*PERC_BINS//100 + min_bin = min(min(rmse_warp_right),min(rmse_resamp_right)) + max_bin = max(max(rmse_warp_right),max(rmse_resamp_right)) + delta = (max_bin-min_bin)/n_bins + bins = np.arange(min_bin,max_bin+delta,delta) + plt.hist(rmse_warp_right, bins = bins, alpha=0.5) + plt.hist(rmse_resamp_right, bins = bins, alpha=0.5) + plt.title(f'File: {file}, Lvl: {lvl}, right RMSE histogram') + plt.legend(['RMSE warp','RMSE resampled']) + plt.savefig(file_name_to_save_fig_hist_right) + plt.close() + + if not os.path.isfile(file_name_to_save_fig_template): + template = pre_proc['template'] + plt.figure() + plt.plot(template) + plt.legend(['template']) + plt.title(f'File: {file}, template') + plt.savefig(file_name_to_save_fig_template) + plt.close() + +def process(files, multi=True, cores=1): + # ------------ INIT ------------ + global log_dir + for i in range(1,1000): + tmp_log_dir = log_dir+str(i) + if not os.path.isdir(tmp_log_dir): + log_dir = tmp_log_dir + break + os.mkdir(log_dir) + + # ------------ Extract DATA & ANNOTATIONS ------------ + with Pool(cores) as pool: + pool.map(recontruct_and_compare, files) + + +if __name__ == "__main__": + import argparse + #global NUM_BEAT_ANALYZED + parser = argparse.ArgumentParser() + 
parser.add_argument("--file", help="Force to analyze one specific file instead of default one (first found)") + parser.add_argument("--not_norm", help="Force to NOT normalize each beats", action="store_true") + parser.add_argument("--cores", help="Force used number of cores (default, half of the available ones") + parser.add_argument("--beats", help="Number of used beats, default: 5000") + parser.add_argument("--template_type", help="Type of template, default: distance") + args = parser.parse_args() + files = os.listdir(data_beats_dir) + if args.file is not None: + if args.file == 'all': + analyzed = files + else: + analyzed = list(filter(lambda string: True if args.file in string else False, files)) + else: + analyzed = [files[0]] + if args.not_norm: + normalize = False + else: + normalize = True + if args.cores is not None: + used_cores = int(args.cores) + else: + used_cores = multiprocessing.cpu_count()//2 + if args.beats is not None: + NUM_BEAT_ANALYZED = int(args.beats) + else: + NUM_BEAT_ANALYZED = 5000 + if args.template_type is not None: + TEMPLATE_TYPE = 'average' + else: + TEMPLATE_TYPE = 'distance' + + print(f"Analyzing files: {analyzed}") + print(f"Extracting data with {used_cores} cores...") + process(files = analyzed, multi=True, cores=used_cores) \ No newline at end of file diff --git a/src/beat_reconstruction_template_segmentation.py b/src/segment_recon.py similarity index 55% rename from src/beat_reconstruction_template_segmentation.py rename to src/segment_recon.py index 677edc0..7b711a2 100644 --- a/src/beat_reconstruction_template_segmentation.py +++ b/src/segment_recon.py @@ -1,371 +1,521 @@ import multiprocessing import os import pickle import shutil from multiprocessing import Pool from time import time import matplotlib.pyplot as plt import numpy as np from dtaidistance import dtw from scipy.interpolate import interp1d from copy import copy from build_template import build_template +#plt.style.use('seaborn-deep') + data_beats_dir = "../data/beats/" log_dir = "../data/beat_recon_logs_" FREQ = 128 MIN_TO_AVG = 5 -BEAT_TO_PLOT = 2 LEVEL = [1,2,3,4,5,6,7,8,9,10,11] -NUM_BEAT_ANALYZED = 50 +NUM_BEAT_ANALYZED = 50000 TEMPLATE_TYPE = 'distance' +PERC_BINS = 15 def RMSE(v1,v2): v1_n = np.array(v1) v2_n = np.array(v2) return np.sqrt(np.mean((v1_n-v2_n)**2)) def RMSE_warped_resampled(data_orig,data_warped,data_resampled): rmse = {} - for beat in data_orig.keys(): - t_start = data_orig[beat]['t'][0] - t_stop = data_orig[beat]['t'][-1] + for beat in data_warped.keys(): for lvl in data_warped[beat].keys(): if lvl not in rmse.keys(): - rmse[lvl] = {'warp':[],'resamp':[]} - v_warped = [] - v_resampled = [] - for i,t in enumerate(data_warped[beat][lvl]['t']): - if t>=t_start and t<=t_stop: - v_warped.append(data_warped[beat][lvl]['v'][i]) - v_resampled.append(data_resampled[beat][lvl]['v'][i]) - elif t > t_stop: - break - rmse[lvl]['warp'].append(RMSE(data_orig[beat]['v'],v_warped)) - rmse[lvl]['resamp'].append(RMSE(data_orig[beat]['v'],v_resampled)) + rmse[lvl] = {'warp':[],'resamp':[],'warp_left':[],'resamp_left':[],'warp_right':[],'resamp_right':[]} + idx_beat = data_warped[beat][lvl]['t'].index(beat) + l_resamp = data_resampled[beat][lvl]['v'][:idx_beat] + r_resamp = data_resampled[beat][lvl]['v'][idx_beat+1:] + l_warp = data_warped[beat][lvl]['v'][:idx_beat] + r_warp = data_warped[beat][lvl]['v'][idx_beat+1:] + l_orig = data_orig[beat]['v'][:idx_beat] + r_orig = data_orig[beat]['v'][idx_beat+1:] + + 
rmse[lvl]['warp'].append(RMSE(data_orig[beat]['v'],data_warped[beat][lvl]['v'])) + rmse[lvl]['resamp'].append(RMSE(data_orig[beat]['v'],data_resampled[beat][lvl]['v'])) + + rmse[lvl]['warp_left'].append(RMSE(l_orig,l_warp)) + rmse[lvl]['resamp_left'].append(RMSE(l_orig,l_resamp)) + + rmse[lvl]['warp_right'].append(RMSE(r_orig,r_warp)) + rmse[lvl]['resamp_right'].append(RMSE(r_orig,r_resamp)) return rmse - def open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = None): ''' Data structure: File: | DICTIONARY: | --> Beat_annot: | --> lvl: | --> "t": [] | --> "v": [] ''' file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name)) data = {} data_out = {} with open(file_name_full,"rb") as f: data = pickle.load(f) for k in data.keys(): if k > FREQ*start_after: data_out[k] = {} if get_selected_level is not None: data_out[k][0] = data[k][0] for lvl in get_selected_level: data_out[k][lvl] = data[k][lvl] else: data_out[k] = data[k] return data_out, list(data_out[k].keys()) -def min_max_normalization(vector, forced_avg = None): - mi_v = min(vector) - ma_v = max(vector) - norm = (np.array(vector)-mi_v)/(ma_v-mi_v) - if forced_avg is not None: - avg = forced_avg +def min_max_normalization(vector, params = None): + params_out = {} + #print("\t",params) + if params is not None: + mi_v = params['min'] + ma_v = params['max'] + avg = params['avg'] + norm = (np.array(vector)-mi_v)/(ma_v-mi_v) + norm -= avg + params_out = params else: + mi_v = min(vector) + ma_v = max(vector) + norm = (np.array(vector)-mi_v)/(ma_v-mi_v) avg = np.average(norm) - norm -= avg - return list(norm),avg + norm -= avg + params_out['min'] = mi_v + params_out['max'] = ma_v + params_out['avg'] = avg + return list(norm),params_out def resamp_one_signal(t,v,resample_type = 'linear', min_t = None, max_t = None): if resample_type == "linear": f = interp1d(t,v) elif resample_type == "flat": f = interp1d(t,v, kind = 'previous') if min_t is None: min_t = t[0] if max_t is None: max_t = t[-1] t_new = list(range(min_t,max_t+1)) v_new = f(t_new) return t_new,v_new def resample(data, resample_type = "linear", min_t = None, max_t = None): resampled_data = {"t":None,"v":None} t = data['t'] v = data['v'] t_r,v_r = resamp_one_signal(t,v,resample_type = resample_type, min_t = min_t, max_t = max_t) resampled_data['t'] = t_r resampled_data['v'] = v_r return resampled_data def change_sampling(vector, size_to_warp): resampled = [] if len(vector) < 3: #If the length of the vector is either 2 or 1 it will be mnorfed in a straight line and such line will be morfed equaly resampled = [vector[0]]*size_to_warp return resampled delta = abs(size_to_warp - len(vector)) # --> number of points to insert/delete #At least one more sample for each sample (except the last one)? 
--> insert k in each hole and recompute delta if delta >= len(vector)-1 and delta > 0: holes = len(vector)-1 k = delta // holes #Number of point for each hole time = range(0,len(vector)*(k+1),k+1) # [o1,i11,i12,...,i1k, o2,i21,i22,...,i2k, o3,..., o_fin] --> NOTE: last sampel is not taken # 0 k+1 2k+2 l*k +l --> Hence, we get to the exact end t_new = range(time[-1]+1) f = interp1d(time,vector) vector = list(f(t_new)) delta = abs(size_to_warp - len(vector)) if len(vector) == size_to_warp: return vector grad = np.gradient(vector) grad[-1] = max(grad) + 1 # --> we don't want to insert anything after the last sample grad_idxs = sorted(range(len(grad)), key = lambda idx: grad[idx]) idx_to_consider = grad_idxs[:delta] #print(delta, len(idx_to_consider), idx_to_consider) for i,sample in enumerate(vector): resampled.append(sample) if i in idx_to_consider: if size_to_warp < len(vector): resampled.pop() elif size_to_warp > len(vector): resampled.append((vector[i]+vector[i+1])/2) return resampled def segment_warp(segment): #print("--> ",len(segment['segment'])) #print("--> ",segment['length_to_warp']) resampled = np.array(change_sampling(segment['segment'], segment['length_to_warp'])) #print("--> ",len(resampled)) resampled += (segment['v_start'] - resampled[0]) m = (segment['v_stop'] - resampled[-1])/(segment['length_to_warp']-1) coef_to_sum = [m*x for x in range(segment['length_to_warp'])] resampled += coef_to_sum return list(resampled) def stitch_segments(segments): stitched = [] for i,segment in enumerate(segments): if i == len(segments)-1: stitched.extend(segment) else: stitched.extend(segment[0:-1]) return stitched def warp(data, template, events): #print(events) warped = {} segments = [] # --> {"segment":v, "length_to_warp":length, "v_start":v_eb_0, "v_stop":v_eb_1} segment = {"segment":None, "length_to_warp":None, "v_start":None, "v_stop":None} segment_start = None event_start = None #ad the first and last points of the resampled and truncated beat if not already in event if data['t'][0] not in events['t']: events['t'].insert(1, data['t'][0]) events['v'].insert(1, data['v'][0]) if data['t'][-1] not in events['t']: events['t'].insert(-1,data['t'][-1]) events['v'].insert(-1,data['v'][-1]) #print(events) #Apply DTW for matching resampled event to template v_src = data['v'] - path = dtw.warping_path(v_src, template) + #This is how its actualy done on the library when calling 'warping_path', we do it like this top obtain also the distance + dist,paths = dtw.warping_paths(v_src, template) + path = dtw.best_path(paths) #Remove the "left" steps from the path and segment the template based on the events point prev_idx_src = path[-1][0]+1 for idx_src, idx_temp in path: if prev_idx_src == idx_src: continue prev_idx_src = idx_src for idx, t in enumerate(events['t']): if data['t'][idx_src] == t: if segment_start == None: segment_start = idx_temp event_start = idx else: #print(f"SEGMENT {events['t'][event_start]} - {events['t'][idx]} ({abs(events['t'][event_start] - events['t'][idx])})") segment['segment'] = template[segment_start:idx_temp+1] segment['length_to_warp'] = events['t'][idx] - events['t'][event_start] + 1 segment['v_start'] = events['v'][event_start] segment['v_stop'] = events['v'][idx] w = segment_warp(segment) #print(len(w)) segments.append(w) segment_start = idx_temp event_start = idx break segment_stitched = stitch_segments(segments) #print(len(segment_stitched), len(data['t'])) warped = {'t':data['t'],'v':segment_stitched} - return warped + return warped,dist def 
def reconstruct_beats(file_name, level = None, normalize = True, resample_type = 'linear', num_beats = None, verbose = False, prev_pre_proc = None):
    reconstructed = {}
+    recon_cost = {}
    resamp = {}
    data_orig = {}
    pre_proc = {}
    if verbose:
        print(f"Extracting {file_name}")
    data,lvls = open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = level)
    lvls.remove(0)
    if verbose:
        print("Building template")
    if prev_pre_proc is not None:
        template = prev_pre_proc['template']
    else:
        template,_,_ = build_template(file_name, normalize = normalize, mode = TEMPLATE_TYPE)
    if num_beats == None:
        num_beats = len(list(data.keys()))
    if verbose:
        print("reconstructing")
-
+    t0 = time()
    for lvl in lvls:
        i = 0
+        recon_cost[lvl] = []
        if verbose:
            print(f"Analyzing level:{lvl}")
        for beat in data.keys():
            if i == num_beats:
                break
            i+=1
            if (i%(num_beats/20)==0 or i == 1) and verbose:
                print(f"File: {file_name}, Reconstructing beat {beat} ({i}/{num_beats}: {100*i/num_beats}%, LEVEL:{lvl})")
            if beat not in reconstructed.keys():
                reconstructed[beat] = {}
                resamp[beat] = {}
-            data_orig[beat] = data[beat][0]
+            data_orig[beat] = copy(data[beat][0])
            if normalize:
-                data_orig[beat]['v'],avg = min_max_normalization(data_orig[beat]['v'], forced_avg = None)
-                data[beat][lvl]['v'],_ = min_max_normalization(data[beat][lvl]['v'], forced_avg = avg)
+                #print(max(data[beat][lvl]['v'])-min(data[beat][lvl]['v']),max(data_orig[beat]['v'])-min(data_orig[beat]['v']))
+                data_orig[beat]['v'],params_prev = min_max_normalization(data_orig[beat]['v'], params = None)
+                data[beat][lvl]['v'],_ = min_max_normalization(data[beat][lvl]['v'], params = params_prev)
+                #print(max(data[beat][lvl]['v'])-min(data[beat][lvl]['v']),max(data_orig[beat]['v'])-min(data_orig[beat]['v']))
+                #print("-"*40)
-
+            # HERE, in the resampling, all the lengths are adjusted to the original beat length:
+            # 'resample' computes the resampled beat and aligns its length to the original beat
+            # 'warp' computes the warped template; the result has the same length by construction
            data_resampled = resample(data[beat][lvl], resample_type = resample_type, min_t = data_orig[beat]['t'][0], max_t = data_orig[beat]['t'][-1])
-            reconstructed[beat][lvl] = warp(data_resampled,template,data[beat][lvl])
+            reconstructed[beat][lvl], warp_cost = warp(data_resampled,template,data[beat][lvl])
            resamp[beat][lvl] = data_resampled
+            recon_cost[lvl].append(warp_cost)
+    t_e = time()
+    print("Time elapsed: ",t_e-t0)
    pre_proc['template'] = template
-    return reconstructed,data_orig,resamp,pre_proc
+    return reconstructed,data_orig,resamp,pre_proc,recon_cost
+
+def percentile_idx(vector,perc):
+    pcen=np.percentile(np.array(vector),perc,interpolation='nearest')
+    i_near=abs(np.array(vector)-pcen).argmin()
+    return i_near
+
+def correlate(v1,v2):
+    v1_n = np.array(v1)
+    v2_n = np.array(v2)
+    cov = np.cov(v1_n,v2_n)[0,1]
+    cor = cov/(np.std(v1_n)*np.std(v2_n))
+    return cor
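Two remarks on the helpers above. percentile_idx relies on np.percentile's interpolation= keyword, which recent NumPy releases rename to method= (the old name is deprecated). correlate divides a sample covariance (np.cov, ddof=1) by population standard deviations (np.std, ddof=0), so it returns Pearson's r scaled by n/(n-1); with thousands of beats the difference is negligible, but np.corrcoef gives the unscaled value directly. A small check with toy data (not part of the diff):

import numpy as np

v1 = np.array([1.0, 2.0, 3.0, 4.0])
v2 = np.array([1.1, 1.9, 3.2, 3.9])

# Same formula as correlate(): sample covariance over population std devs.
r_helper = np.cov(v1, v2)[0, 1] / (np.std(v1) * np.std(v2))
r_pearson = np.corrcoef(v1, v2)[0, 1]

# r_helper equals r_pearson * n/(n-1); rescaling recovers the standard value.
print(r_helper, r_pearson, r_helper * (len(v1) - 1) / len(v1))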
def recontruct_and_compare(file):
    pre_proc = None
-    recon,orig,resamp,pre_proc = reconstruct_beats(file, level = LEVEL, resample_type = 'flat', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear
+    recon,orig,resamp,pre_proc,recon_cost = reconstruct_beats(file, level = LEVEL, resample_type = 'flat', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear
    rmse = RMSE_warped_resampled(orig,recon,resamp)
    for lvl in LEVEL:
        '''
        recon,orig,resamp,pre_proc = reconstruct_beats(file, level = [lvl], resample_type = 'linear', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear
        rmse = RMSE_warped_resampled(orig,recon,resamp)
        '''
        avg_rmse_warp = np.average(rmse[lvl]['warp'])
        std_rmse_warp = np.std(rmse[lvl]['warp'])
        avg_rmse_resamp = np.average(rmse[lvl]['resamp'])
        std_rmse_resamp = np.std(rmse[lvl]['resamp'])
+        cov_rmse_warp_dist = correlate(recon_cost[lvl],rmse[lvl]['warp'])
        file_name_to_save = "L_"+file.split(".")[0]+".log"
        with open(os.path.join(log_dir,file_name_to_save),"a") as f:
            f.write(f"Lvl: {lvl}\n")
            f.write(f"\tWarp: {avg_rmse_warp}, +-{std_rmse_warp}\n")
            f.write(f"\tInterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n")
+            f.write(f"\tCorrelation between warping cost and rmse: {cov_rmse_warp_dist}\n")
            f.write(f"\n\n")
        print(f"File:{file_name_to_save}")
        print(f"\tLvl: {lvl}")
        print(f"\t\twarp: {avg_rmse_warp}, +-{std_rmse_warp}")
        print(f"\t\tinterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}")
-        file_name_to_save_fig = os.path.join(log_dir,file.split(".")[0]+"_"+str(lvl)+".svg")
+        pos_01_perc = percentile_idx(rmse[lvl]['warp'],1)
+        pos_25_perc = percentile_idx(rmse[lvl]['warp'],25)
+        pos_50_perc = percentile_idx(rmse[lvl]['warp'],50)
+        pos_75_perc = percentile_idx(rmse[lvl]['warp'],75)
+        pos_99_perc = percentile_idx(rmse[lvl]['warp'],99)
+
+        file_name_to_save_fig_25 = os.path.join(log_dir,file.split(".")[0]+"_25_perc"+str(lvl)+".svg")
+        file_name_to_save_fig_50 = os.path.join(log_dir,file.split(".")[0]+"_50_perc"+str(lvl)+".svg")
+        file_name_to_save_fig_75 = os.path.join(log_dir,file.split(".")[0]+"_75_perc"+str(lvl)+".svg")
+        file_name_to_save_fig_01 = os.path.join(log_dir,file.split(".")[0]+"_01_perc"+str(lvl)+".svg")
+        file_name_to_save_fig_99 = os.path.join(log_dir,file.split(".")[0]+"_99_perc"+str(lvl)+".svg")
+        file_name_to_save_fig_hist = os.path.join(log_dir,file.split(".")[0]+"_hist"+str(lvl)+".svg")
+        file_name_to_save_fig_hist_left = os.path.join(log_dir,file.split(".")[0]+"_hist_left"+str(lvl)+".svg")
+        file_name_to_save_fig_hist_right = os.path.join(log_dir,file.split(".")[0]+"_hist_right"+str(lvl)+".svg")
        file_name_to_save_fig_template = os.path.join(log_dir,file.split(".")[0]+"_template.svg")
-        beat_to_plot = list(recon.keys())[BEAT_TO_PLOT]
-        t_o,v_o = orig[beat_to_plot]['t'],orig[beat_to_plot]['v']
-        t_a,v_a = recon[beat_to_plot][lvl]['t'],recon[beat_to_plot][lvl]['v']
-        t_r,v_r = resamp[beat_to_plot][lvl]['t'],resamp[beat_to_plot][lvl]['v']
+        beat_to_plot_25 = list(recon.keys())[pos_25_perc]
+        beat_to_plot_50 = list(recon.keys())[pos_50_perc]
+        beat_to_plot_75 = list(recon.keys())[pos_75_perc]
+        beat_to_plot_01 = list(recon.keys())[pos_01_perc]
+        beat_to_plot_99 = list(recon.keys())[pos_99_perc]
+
+        # 01 percentile
+        t_o,v_o = orig[beat_to_plot_01]['t'],orig[beat_to_plot_01]['v']
+        t_a,v_a = recon[beat_to_plot_01][lvl]['t'],recon[beat_to_plot_01][lvl]['v']
+        t_r,v_r = resamp[beat_to_plot_01][lvl]['t'],resamp[beat_to_plot_01][lvl]['v']
+        plt.figure()
+        plt.plot(t_o,v_o)
+        plt.plot(t_a,v_a)
+        plt.plot(t_r,v_r)
+        plt.legend(['original','warped template','resampled'])
+        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_01}, 01 percentile')
+        plt.savefig(file_name_to_save_fig_01)
+        plt.close()
+
+        # 25 percentile
+        t_o,v_o = orig[beat_to_plot_25]['t'],orig[beat_to_plot_25]['v']
+        t_a,v_a = recon[beat_to_plot_25][lvl]['t'],recon[beat_to_plot_25][lvl]['v']
+        t_r,v_r = resamp[beat_to_plot_25][lvl]['t'],resamp[beat_to_plot_25][lvl]['v']
+        plt.figure()
+        plt.plot(t_o,v_o)
+        plt.plot(t_a,v_a)
+        plt.plot(t_r,v_r)
+        plt.legend(['original','warped template','resampled'])
+        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_25}, 25 percentile')
+        plt.savefig(file_name_to_save_fig_25)
+        plt.close()
+        # 50 percentile
+        t_o,v_o = orig[beat_to_plot_50]['t'],orig[beat_to_plot_50]['v']
+        t_a,v_a = recon[beat_to_plot_50][lvl]['t'],recon[beat_to_plot_50][lvl]['v']
+        t_r,v_r = resamp[beat_to_plot_50][lvl]['t'],resamp[beat_to_plot_50][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
-        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot}')
-        plt.savefig(file_name_to_save_fig)
+        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_50}, 50 percentile')
+        plt.savefig(file_name_to_save_fig_50)
        plt.close()
+
+        # 75 percentile
+        t_o,v_o = orig[beat_to_plot_75]['t'],orig[beat_to_plot_75]['v']
+        t_a,v_a = recon[beat_to_plot_75][lvl]['t'],recon[beat_to_plot_75][lvl]['v']
+        t_r,v_r = resamp[beat_to_plot_75][lvl]['t'],resamp[beat_to_plot_75][lvl]['v']
+        plt.figure()
+        plt.plot(t_o,v_o)
+        plt.plot(t_a,v_a)
+        plt.plot(t_r,v_r)
+        plt.legend(['original','warped template','resampled'])
+        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_75}, 75 percentile')
+        plt.savefig(file_name_to_save_fig_75)
+        plt.close()
+
+        # 99 percentile
+        t_o,v_o = orig[beat_to_plot_99]['t'],orig[beat_to_plot_99]['v']
+        t_a,v_a = recon[beat_to_plot_99][lvl]['t'],recon[beat_to_plot_99][lvl]['v']
+        t_r,v_r = resamp[beat_to_plot_99][lvl]['t'],resamp[beat_to_plot_99][lvl]['v']
+        plt.figure()
+        plt.plot(t_o,v_o)
+        plt.plot(t_a,v_a)
+        plt.plot(t_r,v_r)
+        plt.legend(['original','warped template','resampled'])
+        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_99}, 99 percentile')
+        plt.savefig(file_name_to_save_fig_99)
+        plt.close()
+
+        #histograms of the RMSE distributions
+        n_bins = len(rmse[lvl]['warp'])*PERC_BINS//100
+        min_bin = min(min(rmse[lvl]['warp']),min(rmse[lvl]['resamp']))
+        max_bin = max(max(rmse[lvl]['warp']),max(rmse[lvl]['resamp']))
+        delta = (max_bin-min_bin)/n_bins
+        bins = np.arange(min_bin,max_bin+delta,delta)
+        plt.hist(rmse[lvl]['warp'], bins = bins, alpha=0.5)
+        plt.hist(rmse[lvl]['resamp'], bins = bins, alpha=0.5)
+        plt.title(f'File: {file}, Lvl: {lvl}, RMSE histogram')
+        plt.legend(['RMSE warp','RMSE resampled'])
+        plt.savefig(file_name_to_save_fig_hist)
+        plt.close()
+
+        n_bins = len(rmse[lvl]['warp'])*PERC_BINS//100
+        min_bin = min(min(rmse[lvl]['warp_left']),min(rmse[lvl]['resamp_left']))
+        max_bin = max(max(rmse[lvl]['warp_left']),max(rmse[lvl]['resamp_left']))
+        delta = (max_bin-min_bin)/n_bins
+        bins = np.arange(min_bin,max_bin+delta,delta)
+        plt.hist(rmse[lvl]['warp_left'], bins = bins, alpha=0.5)
+        plt.hist(rmse[lvl]['resamp_left'], bins = bins, alpha=0.5)
+        plt.title(f'File: {file}, Lvl: {lvl}, left RMSE histogram')
+        plt.legend(['RMSE warp','RMSE resampled'])
+        plt.savefig(file_name_to_save_fig_hist_left)
+        plt.close()
+
+        n_bins = len(rmse[lvl]['warp'])*PERC_BINS//100
+        min_bin = min(min(rmse[lvl]['warp_right']),min(rmse[lvl]['resamp_right']))
+        max_bin = max(max(rmse[lvl]['warp_right']),max(rmse[lvl]['resamp_right']))
+        delta = (max_bin-min_bin)/n_bins
+        bins = np.arange(min_bin,max_bin+delta,delta)
+        plt.hist(rmse[lvl]['warp_right'], bins = bins, alpha=0.5)
+        plt.hist(rmse[lvl]['resamp_right'], bins = bins, alpha=0.5)
+        plt.title(f'File: {file}, Lvl: {lvl}, right RMSE histogram')
+        plt.legend(['RMSE warp','RMSE resampled'])
+        plt.savefig(file_name_to_save_fig_hist_right)
+        plt.close()
+
        if not os.path.isfile(file_name_to_save_fig_template):
            template = pre_proc['template']
            plt.figure()
            plt.plot(template)
            plt.legend(['template'])
-            plt.title(f'File: {file}, Beat time (samples):{beat_to_plot}')
+            plt.title(f'File: {file}, template')
            plt.savefig(file_name_to_save_fig_template)
            plt.close()
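The five percentile figures and three histograms above differ only in the selected beat or RMSE slice; if the duplication becomes a maintenance burden, a small helper along these lines could replace each plotting block (hypothetical refactor, not part of the diff):

import matplotlib.pyplot as plt

def save_beat_figure(orig_beat, recon_beat, resamp_beat, title, out_path):
    # Plot one original beat against its warped-template and resampled
    # reconstructions, then save the figure to disk and close it.
    plt.figure()
    plt.plot(orig_beat['t'], orig_beat['v'])
    plt.plot(recon_beat['t'], recon_beat['v'])
    plt.plot(resamp_beat['t'], resamp_beat['v'])
    plt.legend(['original', 'warped template', 'resampled'])
    plt.title(title)
    plt.savefig(out_path)
    plt.close()

Each percentile block would then collapse to a single call, e.g. save_beat_figure(orig[b], recon[b][lvl], resamp[b][lvl], title, file_name_to_save_fig_50) with b = beat_to_plot_50.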
def process(files, multi=True, cores=1):
    # ------------ INIT ------------
    global log_dir
    for i in range(1,1000):
        tmp_log_dir = log_dir+str(i)
        if not os.path.isdir(tmp_log_dir):
            log_dir = tmp_log_dir
            break
    os.mkdir(log_dir)
    # ------------ Extract DATA & ANNOTATIONS ------------
    with Pool(cores) as pool:
        pool.map(recontruct_and_compare, files)

if __name__ == "__main__":
    import argparse
    #global NUM_BEAT_ANALYZED
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", help="Force to analyze one specific file instead of the default one (first found)")
    parser.add_argument("--not_norm", help="Force to NOT normalize each beat", action="store_true")
    parser.add_argument("--cores", help="Force the number of cores used (default: half of the available ones)")
    parser.add_argument("--beats", help="Number of beats used, default: 5000")
    parser.add_argument("--template_type", help="Type of template, default: distance")
    args = parser.parse_args()
    files = os.listdir(data_beats_dir)
    if args.file is not None:
        if args.file == 'all':
            analyzed = files
        else:
            analyzed = list(filter(lambda string: True if args.file in string else False, files))
    else:
        analyzed = [files[0]]
    if args.not_norm:
        normalize = False
    else:
        normalize = True
    if args.cores is not None:
        used_cores = int(args.cores)
    else:
        used_cores = multiprocessing.cpu_count()//2
    if args.beats is not None:
        NUM_BEAT_ANALYZED = int(args.beats)
    else:
        NUM_BEAT_ANALYZED = 5000
    if args.template_type is not None:
-        TEMPLATE_TYPE = 'avreage'
+        TEMPLATE_TYPE = 'average'
    else:
        TEMPLATE_TYPE = 'distance'
    print(f"Analyzing files: {analyzed}")
    print(f"Extracting data with {used_cores} cores...")
    process(files = analyzed, multi=True, cores=used_cores)
\ No newline at end of file
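One caveat on the CLI overrides: NUM_BEAT_ANALYZED and TEMPLATE_TYPE are rebound as module globals inside the __main__ block (and log_dir inside process()), so Pool workers only see the updated values on platforms whose default start method is 'fork' (Linux); under 'spawn' (Windows, recent macOS) each worker re-imports the module and falls back to the defaults at the top of the file. A safer pattern, sketched with a stand-in worker (function and file names are placeholders, not part of the diff), is to pass the configuration explicitly:

from functools import partial
from multiprocessing import Pool

def analyze(file_name, num_beats, template_type):
    # Stand-in for recontruct_and_compare, taking its configuration as arguments
    # instead of reading module-level globals.
    return f"{file_name}: {num_beats} beats, template={template_type}"

if __name__ == "__main__":
    worker = partial(analyze, num_beats=5000, template_type='distance')
    with Pool(2) as pool:
        print(pool.map(worker, ["record_a.pkl", "record_b.pkl"]))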