diff --git a/src/multiple_templates.py b/src/multiple_templates.py
index fa16d65..b4fb604 100644
--- a/src/multiple_templates.py
+++ b/src/multiple_templates.py
@@ -1,563 +1,563 @@
import multiprocessing
import os
import pickle
import shutil
from multiprocessing import Pool
from time import time

import matplotlib.pyplot as plt
import numpy as np
from dtaidistance import dtw
from scipy.interpolate import interp1d
from copy import copy

from build_template import build_template, multi_template

data_beats_dir = "../data/beats/"
-log_dir = "../data/beat_recon_logs_multi_"
+log_dir = "../data/beat_recon_logs_multi_beginning_"
FREQ = 128
MIN_TO_AVG = 5
BEAT_TO_PLOT = 2
LEVEL = [1,2,3,4,5,6,7,8,9,10,11]
NUM_BEAT_ANALYZED = 50
LEN_DISTANCE_VECTOR = 50
TEMPLATE_TYPE = 'distance'
PERC_BINS = 15

def RMSE(v1,v2):
    v1_n = np.array(v1)
    v2_n = np.array(v2)
    return np.sqrt(np.mean((v1_n-v2_n)**2))

def RMSE_warped_resampled(data_orig,data_warped,data_resampled):
    rmse = {}
    template = {'warp':None,'resamp':None,'warp_left':None,'resamp_left':None,'warp_right':None,'resamp_right':None}
    for beat in data_warped.keys():
        for lvl in data_warped[beat].keys():
            if lvl not in rmse.keys():
                rmse[lvl] = {}
            rmse[lvl][beat] = copy(template)
            idx_beat = data_warped[beat][lvl]['t'].index(beat)
            l_resamp = data_resampled[beat][lvl]['v'][:idx_beat]
            r_resamp = data_resampled[beat][lvl]['v'][idx_beat+1:]
            l_warp = data_warped[beat][lvl]['v'][:idx_beat]
            r_warp = data_warped[beat][lvl]['v'][idx_beat+1:]
            l_orig = data_orig[beat]['v'][:idx_beat]
            r_orig = data_orig[beat]['v'][idx_beat+1:]
            rmse[lvl][beat]['warp'] = RMSE(data_orig[beat]['v'],data_warped[beat][lvl]['v'])
            rmse[lvl][beat]['resamp'] = RMSE(data_orig[beat]['v'],data_resampled[beat][lvl]['v'])
            rmse[lvl][beat]['warp_left'] = RMSE(l_orig,l_warp)
            rmse[lvl][beat]['resamp_left'] = RMSE(l_orig,l_resamp)
            rmse[lvl][beat]['warp_right'] = RMSE(r_orig,r_warp)
            rmse[lvl][beat]['resamp_right'] = RMSE(r_orig,r_resamp)
    return rmse

def open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = None):
    '''
    Data structure:
    File:
    |
    DICTIONARY:
        |
        --> Beat_annot:
            |
            --> lvl:
                |
                --> "t": []
                --> "v": []
    '''
    file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name))
    data = {}
    data_out = {}
    with open(file_name_full,"rb") as f:
        data = pickle.load(f)
    for k in data.keys():
        if k > FREQ*start_after:
            data_out[k] = {}
            if get_selected_level is not None:
                data_out[k][0] = data[k][0]
                for lvl in get_selected_level:
                    data_out[k][lvl] = data[k][lvl]
            else:
                data_out[k] = data[k]
    return data_out, list(data_out[k].keys())

def min_max_normalization(vector, params = None):
    params_out = {}
    #print("\t",params)
    if params is not None:
        mi_v = params['min']
        ma_v = params['max']
        avg = params['avg']
        norm = (np.array(vector)-mi_v)/(ma_v-mi_v)
        norm -= avg
        params_out = params
    else:
        mi_v = min(vector)
        ma_v = max(vector)
        norm = (np.array(vector)-mi_v)/(ma_v-mi_v)
        avg = np.average(norm)
        norm -= avg
        params_out['min'] = mi_v
        params_out['max'] = ma_v
        params_out['avg'] = avg
    return list(norm),params_out

def resamp_one_signal(t,v,resample_type = 'linear', min_t = None, max_t = None):
    if resample_type == "linear":
        f = interp1d(t,v)
    elif resample_type == "flat":
        f = interp1d(t,v, kind = 'previous')
    if min_t is None:
        min_t = t[0]
    if max_t is None:
        max_t = t[-1]
    t_new = list(range(min_t,max_t+1))
    v_new = f(t_new)
    return t_new,v_new

def resample(data, resample_type = "linear", min_t = None, max_t = None):
    resampled_data = {"t":None,"v":None}
    t = data['t']
    v = data['v']
    t_r,v_r = resamp_one_signal(t,v,resample_type = resample_type, min_t = min_t, max_t = max_t)
    resampled_data['t'] = t_r
    resampled_data['v'] = v_r
    return resampled_data
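For intuition, a minimal self-contained sketch (toy values, not from the dataset) of what the two resample_type options above produce: 'linear' ramps between the kept samples, while 'flat' holds the previous value, matching interp1d(..., kind = 'previous').

# Toy demonstration of the two resample_type options used by resamp_one_signal.
from scipy.interpolate import interp1d

t = [0, 4, 8]        # sparse sample times (hypothetical)
v = [0.0, 1.0, 0.5]  # values at those times
f_lin = interp1d(t, v)                    # 'linear'
f_flat = interp1d(t, v, kind='previous')  # 'flat'
t_new = list(range(t[0], t[-1] + 1))
print(list(f_lin(t_new)))   # ramps between samples
print(list(f_flat(t_new)))  # holds the last seen value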
def change_sampling(vector, size_to_warp):
    resampled = []
    if len(vector) < 3: #If the length of the vector is either 2 or 1 it will be morphed into a straight line, and such a line will be morphed equally
        resampled = [vector[0]]*size_to_warp
        return resampled
    delta = abs(size_to_warp - len(vector)) # --> number of points to insert/delete
    #At least one more sample for each sample (except the last one)? --> insert k in each hole and recompute delta
    if delta >= len(vector)-1 and delta > 0:
        holes = len(vector)-1
        k = delta // holes #Number of points for each hole
        time = range(0,len(vector)*(k+1),k+1) # [o1,i11,i12,...,i1k, o2,i21,i22,...,i2k, o3,..., o_fin] --> NOTE: the last sample is not taken
        #                                         0    k+1           2k+2                 l*k+l --> Hence, we get to the exact end
        t_new = range(time[-1]+1)
        f = interp1d(time,vector)
        vector = list(f(t_new))
        delta = abs(size_to_warp - len(vector))
    if len(vector) == size_to_warp:
        return vector
    grad = np.gradient(vector)
    grad[-1] = max(grad) + 1 # --> we don't want to insert anything after the last sample
    grad_idxs = sorted(range(len(grad)), key = lambda idx: grad[idx])
    idx_to_consider = grad_idxs[:delta]
    #print(delta, len(idx_to_consider), idx_to_consider)
    for i,sample in enumerate(vector):
        resampled.append(sample)
        if i in idx_to_consider:
            if size_to_warp < len(vector):
                resampled.pop()
            elif size_to_warp > len(vector):
                resampled.append((vector[i]+vector[i+1])/2)
    return resampled

def segment_warp(segment):
    #print("--> ",len(segment['segment']))
    #print("--> ",segment['length_to_warp'])
    resampled = np.array(change_sampling(segment['segment'], segment['length_to_warp']))
    #print("--> ",len(resampled))
    resampled += (segment['v_start'] - resampled[0])
    m = (segment['v_stop'] - resampled[-1])/(segment['length_to_warp']-1)
    coef_to_sum = [m*x for x in range(segment['length_to_warp'])]
    resampled += coef_to_sum
    return list(resampled)

def stitch_segments(segments):
    stitched = []
    for i,segment in enumerate(segments):
        if i == len(segments)-1:
            stitched.extend(segment)
        else:
            stitched.extend(segment[0:-1])
    return stitched
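A quick sketch of the endpoint correction performed by segment_warp above: after resizing, the segment is shifted so it starts at v_start, then a linear ramp is added so it ends exactly at v_stop while the interior samples follow (toy numbers, hypothetical):

import numpy as np

resampled = np.array([0.2, 0.6, 0.4, 0.3])  # a segment already resized to length_to_warp
v_start, v_stop = 0.0, 1.0
L = len(resampled)
resampled += v_start - resampled[0]      # pin the left endpoint
m = (v_stop - resampled[-1]) / (L - 1)   # slope of the corrective ramp
resampled += [m * x for x in range(L)]   # pin the right endpoint
print(resampled[0], resampled[-1])       # -> 0.0 1.0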
def warp(data, templates, events):
    #print(events)
    warped = {}
    segments = [] # --> {"segment":v, "length_to_warp":length, "v_start":v_eb_0, "v_stop":v_eb_1}
    segment = {"segment":None, "length_to_warp":None, "v_start":None, "v_stop":None}
    segment_start = None
    event_start = None
    #chosen_template = []
    #add the first and last points of the resampled and truncated beat if not already in the events
    if data['t'][0] not in events['t']:
        events['t'].insert(1, data['t'][0])
        events['v'].insert(1, data['v'][0])
    if data['t'][-1] not in events['t']:
        events['t'].insert(-1,data['t'][-1])
        events['v'].insert(-1,data['v'][-1])
    #print(events)
    #Apply DTW to match the resampled events to the template
    v_src = data['v']
    #This is how it's actually done in the library when calling 'warping_path'
    dist = float('inf')
    paths = []
    selected_template = []
    for t in templates:
        dist_this_template, paths_this_template = dtw.warping_paths(v_src, t)
        if dist_this_template < dist:
            dist = dist_this_template
            paths = paths_this_template
            selected_template = t
    path = dtw.best_path(paths)
    #path = dtw.warping_path(v_src, template)
    #Remove the "left" steps from the path and segment the template based on the event points
    prev_idx_src = path[-1][0]+1
    for idx_src, idx_temp in path:
        if prev_idx_src == idx_src:
            continue
        prev_idx_src = idx_src
        for idx, t in enumerate(events['t']):
            if data['t'][idx_src] == t:
                if segment_start is None:
                    segment_start = idx_temp
                    event_start = idx
                else:
                    #print(f"SEGMENT {events['t'][event_start]} - {events['t'][idx]} ({abs(events['t'][event_start] - events['t'][idx])})")
                    segment['segment'] = selected_template[segment_start:idx_temp+1]
                    segment['length_to_warp'] = events['t'][idx] - events['t'][event_start] + 1
                    segment['v_start'] = events['v'][event_start]
                    segment['v_stop'] = events['v'][idx]
                    w = segment_warp(segment)
                    #print(len(w))
                    segments.append(w)
                    segment_start = idx_temp
                    event_start = idx
                break
    segment_stitched = stitch_segments(segments)
    #print(len(segment_stitched), len(data['t']))
    warped = {'t':data['t'],'v':segment_stitched}
    return dist, warped, selected_template

def reconstruct_beats(file_name, level = None, normalize = True, resample_type = 'linear', num_beats = None, verbose = False, prev_pre_proc = None):
    reconstructed = {}
    recon_cost = {}
    resamp = {}
    data_orig = {}
    pre_proc = {}
    templates = []
    templates_orig = []
    if verbose:
        print(f"Extracting {file_name}")
    data,lvls = open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = level)
    lvls.remove(0)
    if verbose:
        print("Building template")
    if prev_pre_proc is not None:
        templates_orig = prev_pre_proc['template']
    else:
        templates_orig,_ = multi_template(file_name, normalize = normalize, t_start_seconds=0,t_stop_seconds=60*5)
        print(f"Initial number of templates: {len(templates_orig)}")
    if num_beats is None:
        num_beats = len(list(data.keys()))
    if verbose:
        print("reconstructing")
    for lvl in lvls:
        i = 0
        recon_cost[lvl] = []
        if verbose:
            print(f"Analyzing level:{lvl}")
        templates = copy(templates_orig)
        for beat in data.keys():
            if i == num_beats:
                break
            i+=1
            if (i%(num_beats/20)==0 or i == 1) and verbose:
                print(f"File: {file_name}, Reconstructing beat {beat} ({i}/{num_beats}: {100*i/num_beats}%, LEVEL:{lvl})")
            if beat not in reconstructed.keys():
                reconstructed[beat] = {}
                resamp[beat] = {}
            data_orig[beat] = copy(data[beat][0])
            if normalize:
                data_orig[beat]['v'],params_prev = min_max_normalization(data_orig[beat]['v'], params = None)
                data[beat][lvl]['v'],_ = min_max_normalization(data[beat][lvl]['v'], params = params_prev)
            data_resampled = resample(data[beat][lvl], resample_type = resample_type, min_t = data_orig[beat]['t'][0], max_t = data_orig[beat]['t'][-1])
            _, reconstructed[beat][lvl],_ = warp(data_resampled,templates,data[beat][lvl])
            resamp[beat][lvl] = data_resampled
    pre_proc['template'] = templates_orig
    return reconstructed,data_orig,resamp,pre_proc,templates

def percentile_idx(vector,perc):
    pcen=np.percentile(np.array(vector),perc,interpolation='nearest')
    i_near=abs(np.array(vector)-pcen).argmin()
    return i_near

def correlate(v1,v2):
    v1_n = np.array(v1)
    v2_n = np.array(v2)
    cov = np.cov(v1_n,v2_n)[0,1]
    cor = cov/(np.std(v1_n)*np.std(v2_n))
    return cor

def avg_std_for_rmse_dict(dictionary):
    dict_out = {'warp':{'avg':None,'std':None},'resamp':{'avg':None,'std':None}}
    wr = []
    rs = []
    for beat in dictionary.keys():
        wr.append(dictionary[beat]['warp'])
        rs.append(dictionary[beat]['resamp'])
    dict_out['warp']['avg'] = np.average(wr)
    dict_out['warp']['std'] = np.std(wr)
    dict_out['resamp']['avg'] = np.average(rs)
    dict_out['resamp']['std'] = np.std(rs)
    return dict_out

def percentile_rmse_beat(rmse,perc):
    vec = []
    for beat in rmse.keys():
        vec.append(rmse[beat]['warp'])
    idx = percentile_idx(vec,perc)
    return list(rmse.keys())[idx]
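The template selection inside warp above simply keeps the template with the smallest DTW distance and reuses its accumulated-cost matrix to extract the best path; a minimal standalone sketch with toy sequences (not real beats):

import numpy as np
from dtaidistance import dtw

v_src = [0.0, 0.4, 1.0, 0.4, 0.0]
templates = [[0.0, 0.5, 1.0, 0.5, 0.0],
             [1.0, 0.5, 0.0, 0.5, 1.0]]
dist = float('inf')
for tmpl in templates:
    d, paths = dtw.warping_paths(v_src, tmpl)
    if d < dist:
        dist, best_paths, selected = d, paths, tmpl
path = dtw.best_path(best_paths)  # list of (src_idx, template_idx) pairs
print(dist, selected)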
def recontruct_and_compare(file):
    pre_proc = None
    recon,orig,resamp,pre_proc,templates = reconstruct_beats(file, level = LEVEL, resample_type = 'flat', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear
    rmse = RMSE_warped_resampled(orig,recon,resamp)
    for lvl in LEVEL:
        '''
        recon,orig,resamp,pre_proc = reconstruct_beats(file, level = [lvl], resample_type = 'linear', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear
        rmse = RMSE_warped_resampled(orig,recon,resamp)
        '''
        stats = avg_std_for_rmse_dict(rmse[lvl])
        avg_rmse_warp = stats['warp']['avg']
        std_rmse_warp = stats['warp']['std']
        avg_rmse_resamp = stats['resamp']['avg']
        std_rmse_resamp = stats['resamp']['std']
        #cov_rmse_warp_dist = correlate(recon_cost[lvl],rmse[lvl])
        file_name_to_save = "L_"+file.split(".")[0]+".log"
        with open(os.path.join(log_dir,file_name_to_save),"a") as f:
            f.write(f"Lvl: {lvl}\n")
            f.write(f"\tWarp: {avg_rmse_warp}, +-{std_rmse_warp}\n")
            f.write(f"\tInterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n")
            #f.write(f"\tCorrelation between warping cost and rmse: {cov_rmse_warp_dist}\n")
            f.write(f"\n\n")
            f.write(f"\n\n")
        print(f"File:{file_name_to_save}")
        print(f"\tLvl: {lvl}")
        print(f"\t\twarp: {avg_rmse_warp}, +-{std_rmse_warp}")
        print(f"\t\tinterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}")
        beat_to_plot_01 = percentile_rmse_beat(rmse[lvl],1)
        beat_to_plot_25 = percentile_rmse_beat(rmse[lvl],25)
        beat_to_plot_50 = percentile_rmse_beat(rmse[lvl],50)
        beat_to_plot_75 = percentile_rmse_beat(rmse[lvl],75)
        beat_to_plot_99 = percentile_rmse_beat(rmse[lvl],99)
        file_name_to_save_fig_01 = os.path.join(log_dir,file.split(".")[0]+"_01_perc"+str(lvl)+".svg")
        file_name_to_save_fig_25 = os.path.join(log_dir,file.split(".")[0]+"_25_perc"+str(lvl)+".svg")
        file_name_to_save_fig_50 = os.path.join(log_dir,file.split(".")[0]+"_50_perc"+str(lvl)+".svg")
        file_name_to_save_fig_75 = os.path.join(log_dir,file.split(".")[0]+"_75_perc"+str(lvl)+".svg")
        file_name_to_save_fig_99 = os.path.join(log_dir,file.split(".")[0]+"_99_perc"+str(lvl)+".svg")
        file_name_to_save_fig_hist = os.path.join(log_dir,file.split(".")[0]+"_hist"+str(lvl)+".svg")
        file_name_to_save_fig_hist_left = os.path.join(log_dir,file.split(".")[0]+"_hist_left"+str(lvl)+".svg")
        file_name_to_save_fig_hist_right = os.path.join(log_dir,file.split(".")[0]+"_hist_right"+str(lvl)+".svg")
        # 01 percentile
        t_o,v_o = orig[beat_to_plot_01]['t'],orig[beat_to_plot_01]['v']
        t_a,v_a = recon[beat_to_plot_01][lvl]['t'],recon[beat_to_plot_01][lvl]['v']
        t_r,v_r = resamp[beat_to_plot_01][lvl]['t'],resamp[beat_to_plot_01][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_01}, 01 percentile')
        plt.savefig(file_name_to_save_fig_01)
        plt.close()
        # 25 percentile
        t_o,v_o = orig[beat_to_plot_25]['t'],orig[beat_to_plot_25]['v']
        t_a,v_a = recon[beat_to_plot_25][lvl]['t'],recon[beat_to_plot_25][lvl]['v']
        t_r,v_r = resamp[beat_to_plot_25][lvl]['t'],resamp[beat_to_plot_25][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_25}, 25 percentile')
        plt.savefig(file_name_to_save_fig_25)
        plt.close()
        # 50 percentile
        t_o,v_o = orig[beat_to_plot_50]['t'],orig[beat_to_plot_50]['v']
        t_a,v_a = recon[beat_to_plot_50][lvl]['t'],recon[beat_to_plot_50][lvl]['v']
        t_r,v_r = resamp[beat_to_plot_50][lvl]['t'],resamp[beat_to_plot_50][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_50}, 50 percentile')
        plt.savefig(file_name_to_save_fig_50)
        plt.close()
        # 75 percentile
        t_o,v_o = orig[beat_to_plot_75]['t'],orig[beat_to_plot_75]['v']
        t_a,v_a = recon[beat_to_plot_75][lvl]['t'],recon[beat_to_plot_75][lvl]['v']
        t_r,v_r = resamp[beat_to_plot_75][lvl]['t'],resamp[beat_to_plot_75][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_75}, 75 percentile')
        plt.savefig(file_name_to_save_fig_75)
        plt.close()
        # 99 percentile
        t_o,v_o = orig[beat_to_plot_99]['t'],orig[beat_to_plot_99]['v']
        t_a,v_a = recon[beat_to_plot_99][lvl]['t'],recon[beat_to_plot_99][lvl]['v']
        t_r,v_r = resamp[beat_to_plot_99][lvl]['t'],resamp[beat_to_plot_99][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_99}, 99 percentile')
        plt.savefig(file_name_to_save_fig_99)
        plt.close()
        #histograms
        rmse_warp = []
        rmse_resamp = []
        rmse_warp_left = []
        rmse_resamp_left = []
        rmse_warp_right = []
        rmse_resamp_right = []
        for beat in rmse[lvl].keys():
            rmse_warp.append(rmse[lvl][beat]['warp'])
            rmse_resamp.append(rmse[lvl][beat]['resamp'])
            rmse_warp_left.append(rmse[lvl][beat]['warp_left'])
            rmse_resamp_left.append(rmse[lvl][beat]['resamp_left'])
            rmse_warp_right.append(rmse[lvl][beat]['warp_right'])
            rmse_resamp_right.append(rmse[lvl][beat]['resamp_right'])
        n_bins = len(rmse_warp)*PERC_BINS//100
        min_bin = min(min(rmse_warp),min(rmse_resamp))
        max_bin = max(max(rmse_warp),max(rmse_resamp))
        delta = (max_bin-min_bin)/n_bins
        bins = np.arange(min_bin,max_bin+delta,delta)
        plt.hist(rmse_warp, bins = bins, alpha=0.5)
        plt.hist(rmse_resamp, bins = bins, alpha=0.5)
        plt.title(f'File: {file}, Lvl: {lvl}, RMSE histogram')
        plt.legend(['RMSE warp','RMSE resampled'])
        plt.savefig(file_name_to_save_fig_hist)
        plt.close()
        n_bins = len(rmse_warp_left)*PERC_BINS//100
        min_bin = min(min(rmse_warp_left),min(rmse_resamp_left))
        max_bin = max(max(rmse_warp_left),max(rmse_resamp_left))
        delta = (max_bin-min_bin)/n_bins
        bins = np.arange(min_bin,max_bin+delta,delta)
        plt.hist(rmse_warp_left, bins = bins, alpha=0.5)
        plt.hist(rmse_resamp_left, bins = bins, alpha=0.5)
        plt.title(f'File: {file}, Lvl: {lvl}, left RMSE histogram')
        plt.legend(['RMSE warp','RMSE resampled'])
        plt.savefig(file_name_to_save_fig_hist_left)
        plt.close()
        n_bins = len(rmse_warp_right)*PERC_BINS//100
        min_bin = min(min(rmse_warp_right),min(rmse_resamp_right))
        max_bin = max(max(rmse_warp_right),max(rmse_resamp_right))
        delta = (max_bin-min_bin)/n_bins
        bins = np.arange(min_bin,max_bin+delta,delta)
        plt.hist(rmse_warp_right, bins = bins, alpha=0.5)
        plt.hist(rmse_resamp_right, bins = bins, alpha=0.5)
        plt.title(f'File: {file}, Lvl: {lvl}, right RMSE histogram')
        plt.legend(['RMSE warp','RMSE resampled'])
        plt.savefig(file_name_to_save_fig_hist_right)
        plt.close()
    templates = pre_proc['template']
    for i, templ in enumerate(templates):
        file_name_to_save_fig_template = os.path.join(log_dir,file.split(".")[0]+f"_template_{str(i)}.svg")
        if not os.path.isfile(file_name_to_save_fig_template):
            plt.figure()
            plt.plot(templ)
            plt.legend(['template'])
            plt.title(f'File: {file}, template')
            plt.savefig(file_name_to_save_fig_template)
            plt.close()
        else:
            break
def process(files, multi=True, cores=1):
    # ------------ INIT ------------
    global log_dir
    for i in range(1,1000):
        tmp_log_dir = log_dir+str(i)
        if not os.path.isdir(tmp_log_dir):
            log_dir = tmp_log_dir
            break
    os.mkdir(log_dir)
    # ------------ Extract DATA & ANNOTATIONS ------------
    with Pool(cores) as pool:
        pool.map(recontruct_and_compare, files)

if __name__ == "__main__":
    import argparse
    #global NUM_BEAT_ANALYZED
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", help="Force to analyze one specific file instead of the default one (first found)")
    parser.add_argument("--not_norm", help="Force to NOT normalize each beat", action="store_true")
    parser.add_argument("--cores", help="Force the number of cores used (default: half of the available ones)")
    parser.add_argument("--beats", help="Number of used beats, default: 5000")
    parser.add_argument("--template_type", help="Type of template, default: distance")
    args = parser.parse_args()
    files = os.listdir(data_beats_dir)
    if args.file is not None:
        if args.file == 'all':
            analyzed = files
        else:
            analyzed = list(filter(lambda string: args.file in string, files))
    else:
        analyzed = [files[0]]
    if args.not_norm:
        normalize = False
    else:
        normalize = True
    if args.cores is not None:
        used_cores = int(args.cores)
    else:
        used_cores = multiprocessing.cpu_count()//2
    if args.beats is not None:
        NUM_BEAT_ANALYZED = int(args.beats)
    else:
        NUM_BEAT_ANALYZED = 5000
    if args.template_type is not None:
        TEMPLATE_TYPE = 'average'
    else:
        TEMPLATE_TYPE = 'distance'
    print(f"Analyzing files: {analyzed}")
    print(f"Extracting data with {used_cores} cores...")
    process(files = analyzed, multi=True, cores=used_cores)
\ No newline at end of file
diff --git a/src/multiple_templates_prog.py b/src/multiple_templates_prog.py
index a49eabb..9c7a028 100644
--- a/src/multiple_templates_prog.py
+++ b/src/multiple_templates_prog.py
@@ -1,634 +1,654 @@
import multiprocessing
import os
import pickle
import shutil
from multiprocessing import Pool
from time import time

import matplotlib.pyplot as plt
import numpy as np
from dtaidistance import dtw
from scipy.interpolate import interp1d
from copy import copy

from build_template import build_template, multi_template

data_beats_dir = "../data/beats/"
log_dir = "../data/beat_recon_logs_multi_prog_"
FREQ = 128
MIN_TO_AVG = 5
BEAT_TO_PLOT = 2
LEVEL = [1,2,3,4,5,6,7,8,9,10,11]
NUM_BEAT_ANALYZED = 50
LEN_DISTANCE_VECTOR = 50
TEMPLATE_TYPE = 'distance'
PERC_BINS = 15

def RMSE(v1,v2):
    v1_n = np.array(v1)
    v2_n = np.array(v2)
    return np.sqrt(np.mean((v1_n-v2_n)**2))

def RMSE_warped_resampled(data_orig,data_warped,data_resampled):
    rmse = {}
    template = {'warp':None,'resamp':None,'warp_left':None,'resamp_left':None,'warp_right':None,'resamp_right':None}
    for beat in data_warped.keys():
        for lvl in data_warped[beat].keys():
            if lvl not in rmse.keys():
                rmse[lvl] = {}
            rmse[lvl][beat] = copy(template)
            idx_beat = data_warped[beat][lvl]['t'].index(beat)
            l_resamp = data_resampled[beat][lvl]['v'][:idx_beat]
            r_resamp = data_resampled[beat][lvl]['v'][idx_beat+1:]
            l_warp = data_warped[beat][lvl]['v'][:idx_beat]
            r_warp = data_warped[beat][lvl]['v'][idx_beat+1:]
            l_orig = data_orig[beat]['v'][:idx_beat]
            r_orig = data_orig[beat]['v'][idx_beat+1:]
            rmse[lvl][beat]['warp'] = RMSE(data_orig[beat]['v'],data_warped[beat][lvl]['v'])
            rmse[lvl][beat]['resamp'] = RMSE(data_orig[beat]['v'],data_resampled[beat][lvl]['v'])
            rmse[lvl][beat]['warp_left'] = RMSE(l_orig,l_warp)
            rmse[lvl][beat]['resamp_left'] = RMSE(l_orig,l_resamp)
            rmse[lvl][beat]['warp_right'] = RMSE(r_orig,r_warp)
            rmse[lvl][beat]['resamp_right'] = RMSE(r_orig,r_resamp)
    return rmse
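RMSE_warped_resampled above splits each beat at the sample whose time equals the beat annotation and scores the two halves separately; a toy illustration of that split (hypothetical values):

import numpy as np

def rmse(v1, v2):
    return np.sqrt(np.mean((np.array(v1) - np.array(v2))**2))

t = [100, 101, 102, 103, 104]
orig = [0.0, 0.2, 1.0, 0.3, 0.1]
recon = [0.0, 0.1, 1.0, 0.2, 0.1]
idx_beat = t.index(102)                             # the annotation falls on sample 102
print(rmse(orig[:idx_beat], recon[:idx_beat]))      # left half
print(rmse(orig[idx_beat+1:], recon[idx_beat+1:]))  # right half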
def open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = None):
    '''
    Data structure:
    File:
    |
    DICTIONARY:
        |
        --> Beat_annot:
            |
            --> lvl:
                |
                --> "t": []
                --> "v": []
    '''
    file_name_full = os.path.join(data_beats_dir,os.path.basename(file_name))
    data = {}
    data_out = {}
    with open(file_name_full,"rb") as f:
        data = pickle.load(f)
    for k in data.keys():
        if k > FREQ*start_after:
            data_out[k] = {}
            if get_selected_level is not None:
                data_out[k][0] = data[k][0]
                for lvl in get_selected_level:
                    data_out[k][lvl] = data[k][lvl]
            else:
                data_out[k] = data[k]
    return data_out, list(data_out[k].keys())

def min_max_normalization(vector, params = None):
    params_out = {}
    #print("\t",params)
    if params is not None:
        mi_v = params['min']
        ma_v = params['max']
        avg = params['avg']
        norm = (np.array(vector)-mi_v)/(ma_v-mi_v)
        norm -= avg
        params_out = params
    else:
        mi_v = min(vector)
        ma_v = max(vector)
        norm = (np.array(vector)-mi_v)/(ma_v-mi_v)
        avg = np.average(norm)
        norm -= avg
        params_out['min'] = mi_v
        params_out['max'] = ma_v
        params_out['avg'] = avg
    return list(norm),params_out

def resamp_one_signal(t,v,resample_type = 'linear', min_t = None, max_t = None):
    if resample_type == "linear":
        f = interp1d(t,v)
    elif resample_type == "flat":
        f = interp1d(t,v, kind = 'previous')
    if min_t is None:
        min_t = t[0]
    if max_t is None:
        max_t = t[-1]
    t_new = list(range(min_t,max_t+1))
    v_new = f(t_new)
    return t_new,v_new

def resample(data, resample_type = "linear", min_t = None, max_t = None):
    resampled_data = {"t":None,"v":None}
    t = data['t']
    v = data['v']
    t_r,v_r = resamp_one_signal(t,v,resample_type = resample_type, min_t = min_t, max_t = max_t)
    resampled_data['t'] = t_r
    resampled_data['v'] = v_r
    return resampled_data
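Note how min_max_normalization above returns its parameters so that a subsampled beat can be normalized with the statistics of the corresponding original beat; a short sketch of that reuse (toy vectors, not meant to run at import time):

# Normalize the full beat, then reuse its min/max/avg for the subsampled one.
orig_v = [2.0, 4.0, 6.0]
sub_v = [2.0, 6.0]
norm_orig, params = min_max_normalization(orig_v)            # params from the full beat
norm_sub, _ = min_max_normalization(sub_v, params = params)  # reuse min/max/avg
print(params, norm_sub)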
def change_sampling(vector, size_to_warp):
    resampled = []
    if len(vector) < 3: #If the length of the vector is either 2 or 1 it will be morphed into a straight line, and such a line will be morphed equally
        resampled = [vector[0]]*size_to_warp
        return resampled
    delta = abs(size_to_warp - len(vector)) # --> number of points to insert/delete
    #At least one more sample for each sample (except the last one)? --> insert k in each hole and recompute delta
    if delta >= len(vector)-1 and delta > 0:
        holes = len(vector)-1
        k = delta // holes #Number of points for each hole
        time = range(0,len(vector)*(k+1),k+1) # [o1,i11,i12,...,i1k, o2,i21,i22,...,i2k, o3,..., o_fin] --> NOTE: the last sample is not taken
        #                                         0    k+1           2k+2                 l*k+l --> Hence, we get to the exact end
        t_new = range(time[-1]+1)
        f = interp1d(time,vector)
        vector = list(f(t_new))
        delta = abs(size_to_warp - len(vector))
    if len(vector) == size_to_warp:
        return vector
    grad = np.gradient(vector)
    grad[-1] = max(grad) + 1 # --> we don't want to insert anything after the last sample
    grad_idxs = sorted(range(len(grad)), key = lambda idx: grad[idx])
    idx_to_consider = grad_idxs[:delta]
    #print(delta, len(idx_to_consider), idx_to_consider)
    for i,sample in enumerate(vector):
        resampled.append(sample)
        if i in idx_to_consider:
            if size_to_warp < len(vector):
                resampled.pop()
            elif size_to_warp > len(vector):
                resampled.append((vector[i]+vector[i+1])/2)
    return resampled

def segment_warp(segment):
    #print("--> ",len(segment['segment']))
    #print("--> ",segment['length_to_warp'])
    resampled = np.array(change_sampling(segment['segment'], segment['length_to_warp']))
    #print("--> ",len(resampled))
    resampled += (segment['v_start'] - resampled[0])
    m = (segment['v_stop'] - resampled[-1])/(segment['length_to_warp']-1)
    coef_to_sum = [m*x for x in range(segment['length_to_warp'])]
    resampled += coef_to_sum
    return list(resampled)

def stitch_segments(segments):
    stitched = []
    for i,segment in enumerate(segments):
        if i == len(segments)-1:
            stitched.extend(segment)
        else:
            stitched.extend(segment[0:-1])
    return stitched
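stitch_segments above relies on consecutive segments sharing their boundary sample, so it drops the last point of every segment except the final one; a toy check:

segments = [[0, 1, 2], [2, 3, 4], [4, 5]]
stitched = []
for i, seg in enumerate(segments):
    stitched.extend(seg if i == len(segments) - 1 else seg[:-1])
print(stitched)  # -> [0, 1, 2, 3, 4, 5]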
def warp(data, templates, events):
    #print(events)
    warped = {}
    segments = [] # --> {"segment":v, "length_to_warp":length, "v_start":v_eb_0, "v_stop":v_eb_1}
    segment = {"segment":None, "length_to_warp":None, "v_start":None, "v_stop":None}
    segment_start = None
    event_start = None
    #chosen_template = []
    #add the first and last points of the resampled and truncated beat if not already in the events
    if data['t'][0] not in events['t']:
        events['t'].insert(1, data['t'][0])
        events['v'].insert(1, data['v'][0])
    if data['t'][-1] not in events['t']:
        events['t'].insert(-1,data['t'][-1])
        events['v'].insert(-1,data['v'][-1])
    #print(events)
    #Apply DTW to match the resampled events to the template
    v_src = data['v']
    #This is how it's actually done in the library when calling 'warping_path'
    dist = float('inf')
    paths = []
    selected_template = []
    distances_vector = []
    for t in templates:
        dist_this_template, paths_this_template = dtw.warping_paths(v_src, t)
        distances_vector.append(dist_this_template)
        if dist_this_template < dist:
            dist = dist_this_template
            paths = paths_this_template
            selected_template = t
    path = dtw.best_path(paths)
    #path = dtw.warping_path(v_src, template)
    #Remove the "left" steps from the path and segment the template based on the event points
    prev_idx_src = path[-1][0]+1
    for idx_src, idx_temp in path:
        if prev_idx_src == idx_src:
            continue
        prev_idx_src = idx_src
        for idx, t in enumerate(events['t']):
            if data['t'][idx_src] == t:
                if segment_start is None:
                    segment_start = idx_temp
                    event_start = idx
                else:
                    #print(f"SEGMENT {events['t'][event_start]} - {events['t'][idx]} ({abs(events['t'][event_start] - events['t'][idx])})")
                    segment['segment'] = selected_template[segment_start:idx_temp+1]
                    segment['length_to_warp'] = events['t'][idx] - events['t'][event_start] + 1
                    segment['v_start'] = events['v'][event_start]
                    segment['v_stop'] = events['v'][idx]
                    w = segment_warp(segment)
                    #print(len(w))
                    segments.append(w)
                    segment_start = idx_temp
                    event_start = idx
                break
    segment_stitched = stitch_segments(segments)
    #print(len(segment_stitched), len(data['t']))
    warped = {'t':data['t'],'v':segment_stitched}
    return dist, warped, distances_vector
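The trigger logic in reconstruct_beats below replaces the earlier mean/std tests with the median and the 50th-to-75th percentile spread of the recent DTW distances, which are less sensitive to a few very bad beats; a toy illustration of the three tests with hypothetical numbers (thresholds as in the code):

import numpy as np

distances = [1.0, 1.1, 0.9, 4.0, 5.0]  # recent DTW distances, two outliers
reference_median = 1.0                 # captured when the window first fills
reference_spread = 0.2
median_dist = np.median(distances)                   # 1.1
spread = np.percentile(distances, 75) - median_dist  # 75th percentile - median
print(spread / median_dist >= 1)             # relative spread test -> True
print(spread > 1.5 * reference_spread)       # spread vs reference  -> True
print(median_dist > 1.5 * reference_median)  # median vs reference  -> False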
def reconstruct_beats(file_name, level = None, normalize = True, resample_type = 'linear', num_beats = None, verbose = False, prev_pre_proc = None):
    reconstructed = {}
    resamp = {}
    data_orig = {}
    pre_proc = {}
    num_distances_out = 0
-    time_last_out_std = 0
+    time_last_out = 0
    skip_until = 0
    templates = []
    templates_orig = []
    max_num_template = None
    if verbose:
        print(f"Extracting {file_name}")
    data,lvls = open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = level)
    lvls.remove(0)
    if verbose:
        print("Building template")
    if prev_pre_proc is not None:
        templates_orig = prev_pre_proc['template']
    else:
        templates_orig,_ = multi_template(file_name, normalize = normalize, t_start_seconds=0,t_stop_seconds=60*5)
        print(f"Initial number of templates: {len(templates_orig)}")
-    max_num_template = len(templates_orig)
+    max_num_template = len(templates_orig)//3
    if num_beats is None:
        num_beats = len(list(data.keys()))
    if verbose:
        print("reconstructing")
    for lvl in lvls:
        i = 0
        if verbose:
            print(f"Analyzing level:{lvl}")
        distances = []
-        reference_average_diatance = 0
-        reference_std_distance = 0
+        reference_median_dist = 0
+        reference_50_75_perc_distance_delta = 0
+
+        num_perc = 0
+        num_perc_delta_out = 0
+        num_median_out = 0
+
        num_distances_out = 0
-        time_last_out_std = 0
+        time_last_out = 0
        skip_until = 0
        templates = copy(templates_orig)
        dist_vector = [0]*len(templates_orig)
        for beat in data.keys():
            t_beat = beat/FREQ
            if t_beat < skip_until:
                continue
            if i >= num_beats:
                break
            i+=1
            if (i%(num_beats/20)==0 or i == 1) and verbose:
                print(f"File: {file_name}, Reconstructing beat {beat} ({i}/{num_beats}: {100*i/num_beats}%, LEVEL:{lvl})")
            if beat not in reconstructed.keys():
                reconstructed[beat] = {}
                resamp[beat] = {}
            data_orig[beat] = copy(data[beat][0])
            if normalize:
                data_orig[beat]['v'],params_prev = min_max_normalization(data_orig[beat]['v'], params = None)
                data[beat][lvl]['v'],_ = min_max_normalization(data[beat][lvl]['v'], params = params_prev)
            data_resampled = resample(data[beat][lvl], resample_type = resample_type, min_t = data_orig[beat]['t'][0], max_t = data_orig[beat]['t'][-1])
            dist, reconstructed[beat][lvl], dist_all_template = warp(data_resampled,templates,data[beat][lvl])
            dist_vector = [dist_vector[i]+dist_all_template[i] for i in range(len(dist_vector))]
            resamp[beat][lvl] = data_resampled
            if len(distances) >= LEN_DISTANCE_VECTOR:
-                avg_dist = np.average(distances)
-                std_dist = np.std(distances)
-                perc_std = std_dist/avg_dist
+                median_dist = np.median(distances)
+                perc_50_75_distance_delta = np.percentile(distances,75)-median_dist
+                delta_percentage = perc_50_75_distance_delta/median_dist
                distances.pop(0)
-                if (abs(dist-avg_dist)>std_dist or # The acquired beat is outside statistic
-                    perc_std >= 1 or # The standard deviation of the previous beat is bigger than x%
-                    std_dist > 1.5 * reference_std_distance or # The standard deviations of the previous beats are bigger than the reference for this template
-                    avg_dist > 1.5 * reference_average_diatance): # The averages of the previous beats are bigger than the reference for this template
-
-                    time_last_out_std = t_beat
-                    if t_beat - time_last_out_std > 2: #Delta time since last time the distance was outside one std
+                if (delta_percentage >= 1 or # The 50-75 percentile spread is bigger than the median itself
+                    perc_50_75_distance_delta > 1.5 * reference_50_75_perc_distance_delta or # The spread of the previous beats is bigger than the reference for this template
+                    median_dist > 1.5 * reference_median_dist): # The median of the previous beats is bigger than the reference for this template
+
+                    if delta_percentage >= 1:
+                        num_perc += 1
+                    if perc_50_75_distance_delta > 1.5 * reference_50_75_perc_distance_delta:
+                        num_perc_delta_out += 1
+                    if median_dist > 1.5 * reference_median_dist:
+                        num_median_out += 1
+
+                    if t_beat - time_last_out > 3: #Delta time since the last outlier: reset the counters
                        num_distances_out = 0
-                        time_last_out_std = t_beat
+                        num_perc = 0
+                        num_perc_delta_out = 0
+                        num_median_out = 0
+
+                    time_last_out = t_beat
                    num_distances_out += 1
                    if num_distances_out > 40: # number of beats in which the warping distance was too big
                        max_accum_dist = max(dist_vector)
                        print(f"\nBeat num:{i}, New template needed ... ")
-                        print(f"\t Beat outside std? {abs(dist-avg_dist)>std_dist}")
-                        print(f"\t Std percent too high? {perc_std >= 1}")
-                        print(f"\t Std too big wrt reference? {std_dist > 1.5 * reference_std_distance}")
-                        print(f"\t Average too big wrt reference? {avg_dist > 1.5 * reference_average_diatance}")
+                        print(f"\t Delta percentile percent too high: {num_perc}")
+                        print(f"\t Delta percentile too big wrt reference: {num_perc_delta_out}")
+                        print(f"\t Median too big wrt reference: {num_median_out}")
                        for j in range(len(dist_vector)):
                            print(f"\tTemplate {j}, dist: {dist_vector[j]}:\t","|"*int(10*dist_vector[j]/max_accum_dist))
-                        new_templates,_ = multi_template(file_name, normalize = normalize, t_start_seconds=t_beat,t_stop_seconds=t_beat+60)
+                        new_templates,_ = multi_template(file_name, normalize = normalize, t_start_seconds=t_beat,t_stop_seconds=t_beat+40)
                        print(f"New template built, number of new templates:{len(new_templates)}\n")
                        if len(new_templates)>max_num_template:
                            templates = copy(new_templates)
                        else:
                            to_keep = np.argsort(dist_vector)[:max_num_template-len(new_templates)]
                            new_full_templates = []
                            for idx in to_keep:
                                new_full_templates.append(templates[idx])
+                                print(f"\tKept template: {idx}, dist: {dist_vector[idx]}:\t","|"*int(10*dist_vector[idx]/max_accum_dist))
                            for t in new_templates:
                                new_full_templates.append(t)
                            templates = copy(new_full_templates)
                        dist_vector = [0]*len(templates)
                        print(f"New template list length: {len(templates)}\n")
                        distances = []
-                        skip_until = t_beat + 60
+                        skip_until = t_beat + 40
                        num_distances_out = 0
            elif len(distances) == LEN_DISTANCE_VECTOR - 1:
-                reference_average_diatance = np.average(distances)
-                reference_std_distance = np.std(distances)
+                reference_median_dist = np.median(distances)
+                reference_50_75_perc_distance_delta = np.percentile(distances,75)-reference_median_dist
+
+                num_perc = 0
+                num_perc_delta_out = 0
+                num_median_out = 0
+
            distances.append(dist)
    pre_proc['template'] = templates_orig
    return reconstructed,data_orig,resamp,pre_proc,templates

def percentile_idx(vector,perc):
    pcen=np.percentile(np.array(vector),perc,interpolation='nearest')
    i_near=abs(np.array(vector)-pcen).argmin()
    return i_near

def correlate(v1,v2):
    v1_n = np.array(v1)
    v2_n = np.array(v2)
    cov = np.cov(v1_n,v2_n)[0,1]
    cor = cov/(np.std(v1_n)*np.std(v2_n))
    return cor

def avg_std_for_rmse_dict(dictionary):
    dict_out = {'warp':{'avg':None,'std':None},'resamp':{'avg':None,'std':None}}
    wr = []
    rs = []
    for beat in dictionary.keys():
        wr.append(dictionary[beat]['warp'])
        rs.append(dictionary[beat]['resamp'])
    dict_out['warp']['avg'] = np.average(wr)
    dict_out['warp']['std'] = np.std(wr)
    dict_out['resamp']['avg'] = np.average(rs)
    dict_out['resamp']['std'] = np.std(rs)
    return dict_out
def percentile_rmse_beat(rmse,perc):
    vec = []
    for beat in rmse.keys():
        vec.append(rmse[beat]['warp'])
    idx = percentile_idx(vec,perc)
    return list(rmse.keys())[idx]

def recontruct_and_compare(file):
    pre_proc = None
    recon,orig,resamp,pre_proc,templates = reconstruct_beats(file, level = LEVEL, resample_type = 'flat', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear
    rmse = RMSE_warped_resampled(orig,recon,resamp)
    for lvl in LEVEL:
        '''
        recon,orig,resamp,pre_proc = reconstruct_beats(file, level = [lvl], resample_type = 'linear', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear
        rmse = RMSE_warped_resampled(orig,recon,resamp)
        '''
        stats = avg_std_for_rmse_dict(rmse[lvl])
        avg_rmse_warp = stats['warp']['avg']
        std_rmse_warp = stats['warp']['std']
        avg_rmse_resamp = stats['resamp']['avg']
        std_rmse_resamp = stats['resamp']['std']
        #cov_rmse_warp_dist = correlate(recon_cost[lvl],rmse[lvl])
        file_name_to_save = "L_"+file.split(".")[0]+".log"
        with open(os.path.join(log_dir,file_name_to_save),"a") as f:
            f.write(f"Lvl: {lvl}\n")
            f.write(f"\tWarp: {avg_rmse_warp}, +-{std_rmse_warp}\n")
            f.write(f"\tInterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n")
            #f.write(f"\tCorrelation between warping cost and rmse: {cov_rmse_warp_dist}\n")
            f.write(f"\n\n")
            f.write(f"\n\n")
        print(f"File:{file_name_to_save}")
        print(f"\tLvl: {lvl}")
        print(f"\t\twarp: {avg_rmse_warp}, +-{std_rmse_warp}")
        print(f"\t\tinterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}")
        beat_to_plot_01 = percentile_rmse_beat(rmse[lvl],1)
        beat_to_plot_25 = percentile_rmse_beat(rmse[lvl],25)
        beat_to_plot_50 = percentile_rmse_beat(rmse[lvl],50)
        beat_to_plot_75 = percentile_rmse_beat(rmse[lvl],75)
        beat_to_plot_99 = percentile_rmse_beat(rmse[lvl],99)
        file_name_to_save_fig_01 = os.path.join(log_dir,file.split(".")[0]+"_01_perc"+str(lvl)+".svg")
        file_name_to_save_fig_25 = os.path.join(log_dir,file.split(".")[0]+"_25_perc"+str(lvl)+".svg")
        file_name_to_save_fig_50 = os.path.join(log_dir,file.split(".")[0]+"_50_perc"+str(lvl)+".svg")
        file_name_to_save_fig_75 = os.path.join(log_dir,file.split(".")[0]+"_75_perc"+str(lvl)+".svg")
        file_name_to_save_fig_99 = os.path.join(log_dir,file.split(".")[0]+"_99_perc"+str(lvl)+".svg")
        file_name_to_save_fig_hist = os.path.join(log_dir,file.split(".")[0]+"_hist"+str(lvl)+".svg")
        file_name_to_save_fig_hist_left = os.path.join(log_dir,file.split(".")[0]+"_hist_left"+str(lvl)+".svg")
        file_name_to_save_fig_hist_right = os.path.join(log_dir,file.split(".")[0]+"_hist_right"+str(lvl)+".svg")
        # 01 percentile
        t_o,v_o = orig[beat_to_plot_01]['t'],orig[beat_to_plot_01]['v']
        t_a,v_a = recon[beat_to_plot_01][lvl]['t'],recon[beat_to_plot_01][lvl]['v']
        t_r,v_r = resamp[beat_to_plot_01][lvl]['t'],resamp[beat_to_plot_01][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_01}, 01 percentile')
        plt.savefig(file_name_to_save_fig_01)
        plt.close()
        # 25 percentile
        t_o,v_o = orig[beat_to_plot_25]['t'],orig[beat_to_plot_25]['v']
        t_a,v_a = recon[beat_to_plot_25][lvl]['t'],recon[beat_to_plot_25][lvl]['v']
        t_r,v_r = resamp[beat_to_plot_25][lvl]['t'],resamp[beat_to_plot_25][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_25}, 25 percentile')
        plt.savefig(file_name_to_save_fig_25)
        plt.close()
        # 50 percentile
        t_o,v_o = orig[beat_to_plot_50]['t'],orig[beat_to_plot_50]['v']
        t_a,v_a = recon[beat_to_plot_50][lvl]['t'],recon[beat_to_plot_50][lvl]['v']
        t_r,v_r = resamp[beat_to_plot_50][lvl]['t'],resamp[beat_to_plot_50][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_50}, 50 percentile')
        plt.savefig(file_name_to_save_fig_50)
        plt.close()
        # 75 percentile
        t_o,v_o = orig[beat_to_plot_75]['t'],orig[beat_to_plot_75]['v']
        t_a,v_a = recon[beat_to_plot_75][lvl]['t'],recon[beat_to_plot_75][lvl]['v']
        t_r,v_r = resamp[beat_to_plot_75][lvl]['t'],resamp[beat_to_plot_75][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_75}, 75 percentile')
        plt.savefig(file_name_to_save_fig_75)
        plt.close()
        # 99 percentile
        t_o,v_o = orig[beat_to_plot_99]['t'],orig[beat_to_plot_99]['v']
        t_a,v_a = recon[beat_to_plot_99][lvl]['t'],recon[beat_to_plot_99][lvl]['v']
        t_r,v_r = resamp[beat_to_plot_99][lvl]['t'],resamp[beat_to_plot_99][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot_99}, 99 percentile')
        plt.savefig(file_name_to_save_fig_99)
        plt.close()
        #histograms
        rmse_warp = []
        rmse_resamp = []
        rmse_warp_left = []
        rmse_resamp_left = []
        rmse_warp_right = []
        rmse_resamp_right = []
        for beat in rmse[lvl].keys():
            rmse_warp.append(rmse[lvl][beat]['warp'])
            rmse_resamp.append(rmse[lvl][beat]['resamp'])
            rmse_warp_left.append(rmse[lvl][beat]['warp_left'])
            rmse_resamp_left.append(rmse[lvl][beat]['resamp_left'])
            rmse_warp_right.append(rmse[lvl][beat]['warp_right'])
            rmse_resamp_right.append(rmse[lvl][beat]['resamp_right'])
        n_bins = len(rmse_warp)*PERC_BINS//100
        min_bin = min(min(rmse_warp),min(rmse_resamp))
        max_bin = max(max(rmse_warp),max(rmse_resamp))
        delta = (max_bin-min_bin)/n_bins
        bins = np.arange(min_bin,max_bin+delta,delta)
        plt.hist(rmse_warp, bins = bins, alpha=0.5)
        plt.hist(rmse_resamp, bins = bins, alpha=0.5)
        plt.title(f'File: {file}, Lvl: {lvl}, RMSE histogram')
        plt.legend(['RMSE warp','RMSE resampled'])
        plt.savefig(file_name_to_save_fig_hist)
        plt.close()
        n_bins = len(rmse_warp_left)*PERC_BINS//100
        min_bin = min(min(rmse_warp_left),min(rmse_resamp_left))
        max_bin = max(max(rmse_warp_left),max(rmse_resamp_left))
        delta = (max_bin-min_bin)/n_bins
        bins = np.arange(min_bin,max_bin+delta,delta)
        plt.hist(rmse_warp_left, bins = bins, alpha=0.5)
        plt.hist(rmse_resamp_left, bins = bins, alpha=0.5)
        plt.title(f'File: {file}, Lvl: {lvl}, left RMSE histogram')
        plt.legend(['RMSE warp','RMSE resampled'])
        plt.savefig(file_name_to_save_fig_hist_left)
        plt.close()
        n_bins = len(rmse_warp_right)*PERC_BINS//100
        min_bin = min(min(rmse_warp_right),min(rmse_resamp_right))
        max_bin = max(max(rmse_warp_right),max(rmse_resamp_right))
        delta = (max_bin-min_bin)/n_bins
        bins = np.arange(min_bin,max_bin+delta,delta)
        plt.hist(rmse_warp_right, bins = bins, alpha=0.5)
        plt.hist(rmse_resamp_right, bins = bins, alpha=0.5)
        plt.title(f'File: {file}, Lvl: {lvl}, right RMSE histogram')
        plt.legend(['RMSE warp','RMSE resampled'])
        plt.savefig(file_name_to_save_fig_hist_right)
        plt.close()
    templates = pre_proc['template']
    for i, templ in enumerate(templates):
        file_name_to_save_fig_template = os.path.join(log_dir,file.split(".")[0]+f"_template_{str(i)}.svg")
        if not os.path.isfile(file_name_to_save_fig_template):
            plt.figure()
            plt.plot(templ)
            plt.legend(['template'])
            plt.title(f'File: {file}, template')
            plt.savefig(file_name_to_save_fig_template)
            plt.close()
        else:
            break

def process(files, multi=True, cores=1):
    # ------------ INIT ------------
    global log_dir
    for i in range(1,1000):
        tmp_log_dir = log_dir+str(i)
        if not os.path.isdir(tmp_log_dir):
            log_dir = tmp_log_dir
            break
    os.mkdir(log_dir)
    # ------------ Extract DATA & ANNOTATIONS ------------
    if cores == 1:
        print("Single core")
        for f in files:
            recontruct_and_compare(f)
    else:
        with Pool(cores) as pool:
            pool.map(recontruct_and_compare, files)

if __name__ == "__main__":
    import argparse
    #global NUM_BEAT_ANALYZED
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", help="Force to analyze one specific file instead of the default one (first found)")
    parser.add_argument("--not_norm", help="Force to NOT normalize each beat", action="store_true")
    parser.add_argument("--cores", help="Force the number of cores used (default: half of the available ones)")
    parser.add_argument("--beats", help="Number of used beats, default: 5000")
    parser.add_argument("--template_type", help="Type of template, default: distance")
    args = parser.parse_args()
    files = os.listdir(data_beats_dir)
    if args.file is not None:
        if args.file == 'all':
            analyzed = files
        else:
            analyzed = list(filter(lambda string: args.file in string, files))
    else:
        analyzed = [files[0]]
    if args.not_norm:
        normalize = False
    else:
        normalize = True
    if args.cores is not None:
        used_cores = int(args.cores)
    else:
        used_cores = multiprocessing.cpu_count()//2
    if args.beats is not None:
        NUM_BEAT_ANALYZED = int(args.beats)
    else:
        NUM_BEAT_ANALYZED = 5000
    if args.template_type is not None:
        TEMPLATE_TYPE = 'average'
    else:
        TEMPLATE_TYPE = 'distance'
    print(f"Analyzing files: {analyzed}")
    print(f"Extracting data with {used_cores} cores...")
    process(files = analyzed, multi=True, cores=used_cores)
\ No newline at end of file
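For reference, a typical invocation of either script, assuming it is run from src/ with pickled beats available under ../data/beats/ (paths as configured at the top of the files; the record name is hypothetical):

# python multiple_templates.py --file 100 --cores 4 --beats 2000
# python multiple_templates_prog.py --file all --not_norm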