diff --git a/src/precise_multiple_templates_prog.py b/src/precise_multiple_templates_prog.py index a57050a..06a147e 100644 --- a/src/precise_multiple_templates_prog.py +++ b/src/precise_multiple_templates_prog.py @@ -1,1003 +1,1009 @@ import multiprocessing import os import pickle import shutil import sys from multiprocessing import Pool from time import time import matplotlib import matplotlib.pyplot as plt from numpy.core.numeric import base_repr # configure backend here matplotlib.use('SVG') import numpy as np from dtaidistance import dtw from numpy.core.fromnumeric import argmin scripts = '../helper_scripts' if scripts not in sys.path: sys.path.insert(0,scripts) import warnings from copy import copy from dtw.dtw import dtw_std from scipy import stats from scipy.interpolate import CubicSpline, interp1d from scipy.signal.signaltools import fftconvolve from build_template import build_template, multi_template data_beats_dir = "../data/beats/" log_dir = "../results/beat_recon_logs_multi_prog_" # Signal freqeuncy variables FREQ = 128*4 MULTIPLIER_FREQ = 5 #Final freq: 512*5 256 KHz # Parameters fixed for result evaluation PERC_BINS = 15 # Variable fixed by arguments (eg. arguments) CLUSTER_PERCENTAGE = 3 NUM_BEAT_ANALYZED = 50 INTERPOLATION_TYPE = 'flat' FILES_SELECTED = ["17052.pickle"] PARALLELIZE_ALONG = 'files' LEVELS = [3,4,5,6,7,8,9,10,11] SEC_FOR_INITIAL_TEMPLATES = 3*60 #5*60 LEN_DISTANCE_VECTOR = 60 #80 LEN_DISTANCE_VECTOR_REF = 400 #500 SEC_FOR_NEW_TEMPLATES = 40 #2*60 +TIME_MODE = 'short' def z_score_filter(vector,threshold = 3): out = None z = stats.zscore(vector) out = [vector[idx] for idx in range(len(vector)) if z[idx] Beat_annot: | --> "t": [] | --> "v": [] ''' file_name_full = os.path.join(data_beats_dir, str(level), os.path.basename(file_name)) data = {} with open(file_name_full,"rb") as f: data = pickle.load(f) for k in data.keys(): if k <= FREQ*start_after: data.pop(k) return data def upsample_uniform_beat(beat,t_QRS,multiplier): out_beat = {} upsampled = [] v = beat['v'] new_QRS_pos = beat['t'].index(t_QRS)*multiplier last_new_t = (len(v)-1)*multiplier t = np.arange(0,last_new_t+1,multiplier,dtype = np.int64) new_base = np.arange(0,last_new_t+1,1,dtype = np.int64) f = interp1d(t, v) upsampled = f(new_base) out_beat = {'t':new_base,'v':upsampled,'QRS_pos':new_QRS_pos} return out_beat def upsample_uniform_beats_in_beats_dict(beats,multiplier): up_samp_beats = {} i = 0 for beat in beats: i += 1 up_samp_beats[beat] = upsample_uniform_beat(beats[beat], beat, multiplier) return up_samp_beats def get_beats_in_time_span(data, lvl = 0 ,t_start_seconds = 0, t_stop_seconds = SEC_FOR_INITIAL_TEMPLATES): beats = {} for annot in data.keys(): if annot >= int(t_start_seconds*FREQ) and annot <= int(t_stop_seconds*FREQ): beats[annot] = data[annot] elif annot > int(t_stop_seconds*FREQ): break return beats def min_max_normalization_one_beat(beat,param = None): params_out = {} if 'QRS_pos' in list(beat.keys()): normalized_beat = {'t':beat['t'],'v':[], 'QRS_pos': beat['QRS_pos']} else: normalized_beat = {'t':beat['t'],'v':[]} vector = beat['v'] if param is not None: mi_v = param['min'] ma_v = param['max'] avg = param['avg'] norm = (np.array(vector)-mi_v)/(ma_v-mi_v) norm -= avg else: mi_v = min(vector) ma_v = max(vector) norm = (np.array(vector)-mi_v)/(ma_v-mi_v) avg = np.average(norm) norm -= avg normalized_beat['v'] = norm.tolist() params_out['min'] = mi_v params_out['max'] = ma_v params_out['avg'] = avg return normalized_beat,params_out def min_max_normalization_beats_chunk(beats, params 
= None): params_out = {} normalized_beats = {} #print("\t",params) for beat in beats: normalized_beats[beat],params_out[beat] = min_max_normalization_one_beat(beats[beat],params) return normalized_beats,params_out def resamp_one_signal(t,v,resample_type = INTERPOLATION_TYPE, min_t = None, max_t = None): if min_t is None: min_t = t[0] if max_t is None: max_t = t[-1] t_extended = copy(t) v_extended = copy(v) if max_t not in t: t_extended.insert(len(t_extended), max_t) v_extended.insert(len(v_extended), 0) if min_t not in t: # This is not needed in this implementation as the first sample is always an events for each beat t_extended.insert(0, min_t) # Still, we write it for clearity and consistency v_extended.insert(0, 0) if resample_type == "linear": f = interp1d(t_extended,v_extended, bounds_error = False, fill_value = (v[0],v[-1])) elif resample_type == "flat": f = interp1d(t_extended,v_extended, kind = 'previous', bounds_error = False, fill_value = (v[0],v[-1])) elif resample_type == "spline": f = CubicSpline(t_extended,v_extended, bc_type="natural") t_new = list(range(min_t,max_t+1)) v_new = f(t_new) return t_new,v_new def percentile_idx(vector,perc): pcen=np.percentile(np.array(vector),perc,interpolation='nearest') i_near=abs(np.array(vector)-pcen).argmin() return i_near def correlate(v1,v2): v1_n = np.array(v1) v2_n = np.array(v2) cov = np.cov(v1_n,v2_n)[0,1] cor = cov/(np.std(v1_n)*np.std(v2_n)) return cor def avg_std_rmse(rmse_descriptor): dict_out = {'warp':{'avg':None,'std':None},'resamp':{'avg':None,'std':None}} wr = [] rs = [] dict_out['warp']['avg'] = np.average(rmse_descriptor['rmse_warp']) dict_out['warp']['std'] = np.std(rmse_descriptor['rmse_warp']) dict_out['resamp']['avg'] = np.average(rmse_descriptor['rmse_resamp']) dict_out['resamp']['std'] = np.std(rmse_descriptor['rmse_resamp']) return dict_out def percentile_rmse_beat_rel(rmse,perc): vec = [] for beat in rmse.keys(): vec.append(np.array(rmse['rmse_warp']) - np.array(rmse['rmse_resamp'])) idx = percentile_idx(vec,perc) return rmse['beats'][idx],rmse['ids'][idx] def percentile_rmse_beat_abs(rmse,perc): vec = [] for beat in rmse.keys(): vec.append(rmse['rmse_warp']) idx = percentile_idx(vec,perc) return rmse['beats'][idx],rmse['ids'][idx] def find_all_connceted_template(id_temp,templates_info): id_collected = [id_temp] for id_connected in id_collected: for id_to_add in templates_info[id_connected]["connected_template"]: if id_to_add not in id_collected: id_collected.append(id_to_add) return id_collected def plot_perc(orig,warp,resamp,template,title,out_file_name): fig, (ax1, ax2) = plt.subplots(2) fig.suptitle(title) ax1.plot(orig[0],orig[1]) ax1.plot(warp[0],warp[1]) ax1.plot(resamp[0],resamp[1]) ax1.legend(['original','warped template','resampled']) ax2.plot(template) ax2.legend(['template used']) fig.savefig(out_file_name) plt.close(fig) def plot_beat_rmse_percentile(orig, rmse_id_coupling, templates_collection, lvl, perc, log_dir_this_lvl, file, interpolation_type = INTERPOLATION_TYPE): # rmse_id_coupling = {"rmse_resamp":[],"rmse_right_resamp":[],"rmse_left_resamp":[], "rmse_warp": [],"rmse_warp_right": [],"rmse_warp_left": [], "ids":[], "beats": []} #Relative template_sel = {"template":None} beat_id, template_id = percentile_rmse_beat_rel(rmse_id_coupling, perc) file_name_to_save_fig_rel = os.path.join(log_dir_this_lvl,file.split(".")[0]+"_"+str(perc)+"_perc"+str(lvl)+"_relative.svg") template_sel['template'] = templates_collection[template_id] up_samp_beat = upsample_uniform_beat(orig[beat_id], beat_id, 
MULTIPLIER_FREQ) #TODO: keep track of QRS events = ADC(up_samp_beat,lvl) up_samp_beat,params = min_max_normalization_one_beat(up_samp_beat) events,_ = min_max_normalization_one_beat(events,params) resampled = resample(events, resample_type = interpolation_type, min_t = up_samp_beat['t'][0], max_t = up_samp_beat['t'][-1]) _, reconstructed, _, _ = warp(resampled,template_sel,events) orig_rel = up_samp_beat['t'],up_samp_beat['v'] recon_rel = reconstructed['t'],reconstructed['v'] resamp_rel = resampled['t'],resampled['v'] title_rel = f'File: {file}, Lvl: {lvl}, Beat time (samples): {beat_id}, {str(perc)} percentile, Relative' plot_perc(orig_rel,recon_rel,resamp_rel,template_sel['template']['point'],title_rel,file_name_to_save_fig_rel) #Absolute beat_id, template_id = percentile_rmse_beat_abs(rmse_id_coupling, perc) file_name_to_save_fig_abs = os.path.join(log_dir_this_lvl,file.split(".")[0]+"_"+str(perc)+"_perc"+str(lvl)+"_absolute.svg") template_sel['template'] = templates_collection[template_id] up_samp_beat = upsample_uniform_beat(orig[beat_id], beat_id, MULTIPLIER_FREQ) #TODO: keep track of QRS events = ADC(up_samp_beat,lvl) up_samp_beat,params = min_max_normalization_one_beat(up_samp_beat) events,_ = min_max_normalization_one_beat(events,params) resampled = resample(events, resample_type = interpolation_type, min_t = up_samp_beat['t'][0], max_t = up_samp_beat['t'][-1]) _, reconstructed, _, _ = warp(resampled,template_sel,events) orig_abs = up_samp_beat['t'],up_samp_beat['v'] recon_abs = reconstructed['t'],reconstructed['v'] resamp_abs = resampled['t'],resampled['v'] title_abs = f'File: {file}, Lvl: {lvl}, Beat time (samples): {beat_id}, {str(perc)} percentile, Absolute' plot_perc(orig_abs,recon_abs,resamp_abs,template_sel['template']['point'],title_abs,file_name_to_save_fig_abs) def stitch_segments(segments): stitched = [] for i,segment in enumerate(segments): if i == len(segments)-1: stitched.extend(segment) else: stitched.extend(segment[0:-1]) return stitched def change_sampling(vector, size_to_warp): resampled = [] if len(vector) < 3: #If the length of the vector is either 2 or 1 it will be mnorfed in a straight line and such line will be morfed equaly resampled = [vector[0]]*size_to_warp return resampled delta = abs(size_to_warp - len(vector)) # --> number of points to insert/delete #At least one more sample for each sample (except the last one)? 
--> insert k in each hole and recompute delta if delta >= len(vector)-1 and delta > 0: holes = len(vector)-1 k = delta // holes #Number of point for each hole time = range(0,len(vector)*(k+1),k+1) # [o1,i11,i12,...,i1k, o2,i21,i22,...,i2k, o3,..., o_fin] --> NOTE: last sampel is not taken # 0 k+1 2k+2 l*k +l --> Hence, we get to the exact end t_new = range(time[-1]+1) f = interp1d(time,vector, bounds_error = False, fill_value = "extrapolate") vector = list(f(t_new)) delta = abs(size_to_warp - len(vector)) if len(vector) == size_to_warp: return vector grad = np.gradient(vector) grad[-1] = max(grad) + 1 # --> we don't want to insert anything after the last sample grad_idxs = sorted(range(len(grad)), key = lambda idx: grad[idx]) idx_to_consider = grad_idxs[:delta] #print(delta, len(idx_to_consider), idx_to_consider) for i,sample in enumerate(vector): resampled.append(sample) if i in idx_to_consider: if size_to_warp < len(vector): resampled.pop() elif size_to_warp > len(vector): resampled.append((vector[i]+vector[i+1])/2) return resampled def segment_warp(segment): #print("--> ",len(segment['segment'])) #print("--> ",segment['length_to_warp']) resampled = np.array(change_sampling(segment['segment'], segment['length_to_warp'])) #print("--> ",len(resampled)) resampled += (segment['v_start'] - resampled[0]) m = (segment['v_stop'] - resampled[-1])/(segment['length_to_warp']-1) coef_to_sum = [m*x for x in range(segment['length_to_warp'])] resampled += coef_to_sum return list(resampled) def warp(resamp_event, templates, events): #print(events) warped = {} segments = [] # --> {"segment":v, "length_to_warp":length, "v_start":v_eb_0, "v_stop":v_eb_1} segment = {"segment":None, "length_to_warp":None, "v_start":None, "v_stop":None} segment_start = None event_start = None #chosen_template = [] #ad the first and last points of the resampled and truncated beat if not already in event if resamp_event['t'][0] not in events['t']: events['t'].insert(0, resamp_event['t'][0]) events['v'].insert(0, 0) if resamp_event['t'][-1] not in events['t']: events['t'].insert(len(events['t']),resamp_event['t'][-1]) events['v'].insert(len(events['v']),0) #print(events) #Apply DTW for matching resampled event to template v_src = resamp_event['v'] #This is how its actualy done on the library when calling 'warping_path' dist = float('inf') path = [] selected_template = [] disatances_vector = [] template_id = None for id in templates: t = templates[id]['point'] #dist_this_template, paths_this_template = dtw.warping_paths(v_src, t) with warnings.catch_warnings(): warnings.simplefilter("ignore") dist_this_template, _, path_this_template = dtw_std(v_src, t, dist_only=False) disatances_vector.append(dist_this_template) if dist_this_template < dist: dist = dist_this_template path = path_this_template selected_template = t template_id = id #path = dtw.best_path(paths) path = [(v1,v2) for v1,v2 in zip(path[0],path[1])] #Remove the "left" steps from the path and segment the template based on the events point + #TODO: Problem here in the warping (some beats does not include the ending of the template, especially with low levels), check why !!!! With only 10 beats the effect is not visible + # RESULTS: ALL THE TEMPLATES POINTS ARE IN THE LAST POINT OF THE WARP!! 
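# Editorial note (annotation only, not part of the patch): the loop below walks
# the DTW path as (source_index, template_index) pairs and skips pairs that
# repeat a source index, so every resampled sample keeps a single matching
# template index. Whenever a source sample coincides with an event time, the
# selected template is cut at the matched template index; segment_warp() then
# stretches or compresses the cut piece to the inter-event spacing and pins its
# endpoints to the event amplitudes, and stitch_segments() joins the warped
# pieces into the reconstructed beat.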
prev_idx_src = path[-1][0]+1 for idx_src, idx_temp in path: if prev_idx_src == idx_src: continue prev_idx_src = idx_src for idx, t in enumerate(events['t']): if resamp_event['t'][idx_src] == t: if segment_start == None: segment_start = idx_temp event_start = idx else: #print(f"SEGMENT {events['t'][event_start]} - {events['t'][idx]} ({abs(events['t'][event_start] - events['t'][idx])})") segment['segment'] = np.array(selected_template[segment_start:idx_temp+1], dtype=np.float64) - segment['length_to_warp'] = events['t'][idx] - events['t'][event_start] + 1 #TODO: Change this from timing information to length information + segment['length_to_warp'] = events['t'][idx] - events['t'][event_start] + 1 segment['v_start'] = events['v'][event_start] segment['v_stop'] = events['v'][idx] w = segment_warp(segment) - #print(len(w)) segments.append(w) segment_start = idx_temp event_start = idx break segment_stitched = stitch_segments(segments) - #print(len(segment_stitched), len(resamp_event['t'])) warped = {'t':resamp_event['t'],'v':segment_stitched} return dist, warped, disatances_vector, template_id def dtw_dist(v1,v2): with warnings.catch_warnings(): warnings.simplefilter("ignore") dist = dtw_std(v1, v2) return dist # TODO: /!\ TEST WHAT HAPPEN WITH LOWER TIME TO COLLECT BEATS FOR NEW TEMPLATS: This version is with shorter times def compute_new_templates(data, t_start, t_stop, old_templates, templates_info): beats_for_new_template = get_beats_in_time_span(data, t_start_seconds=t_start,t_stop_seconds=t_stop) #beats_for_new_template = upsample_uniform_beats_in_beats_dict(beats_for_new_template,MULTIPLIER_FREQ) beats_for_new_template,_ = min_max_normalization_beats_chunk(beats_for_new_template) _, new_templates_descriptor = multi_template(beats_for_new_template, percentage_each_cluster = CLUSTER_PERCENTAGE*2, freq= FREQ)#*MULTIPLIER_FREQ) lbls_new_templates = list(new_templates_descriptor.keys()) old_ids = list(old_templates.keys()) if old_ids == []: next_id = 0 else: next_id = max(old_ids) + 1 cluster_representatives = {} new_template_set = {} print(f"\nNew template built, number of new templates:{len(lbls_new_templates)}\n") old_templates_kept = 0 old_templates_substituted = 0 new_templates_kept = 0 new_templates_substituted = 0 # We search which of the old templates can be considered clustered with the newly founded ones if len(lbls_new_templates) > 0: for id_old_template in old_templates: t_o = old_templates[id_old_template]['point'] dists = np.zeros((len(lbls_new_templates))) for k,t_n in enumerate(lbls_new_templates): dists[k] = dtw_dist(t_o,new_templates_descriptor[t_n]['center']) min_dist_pos = np.argmin(dists) min_val = dists[min_dist_pos] this_lbls_new_templates = lbls_new_templates[min_dist_pos] distances_intra_cluster = new_templates_descriptor[this_lbls_new_templates]['dist_all_pt_cluster'] dist_threshold = np.average(distances_intra_cluster)+np.std(distances_intra_cluster) if min_val < dist_threshold: if this_lbls_new_templates not in cluster_representatives.keys(): cluster_representatives[this_lbls_new_templates] = {"ids": [], "dists": []} cluster_representatives[this_lbls_new_templates]["ids"].append(id_old_template) cluster_representatives[this_lbls_new_templates]["dists"].append(min_val) else: #This templates minimum distances to the newly founded templates are too big to be considered the same cluster new_template_set[id_old_template] = old_templates[id_old_template] old_templates_kept += 1 # Now we check if the old templates, clustered with the newly founded one, are representative of 
the cluster of if we should use the new templates for local_new_id in lbls_new_templates: if local_new_id in cluster_representatives.keys(): connected_ids = cluster_representatives[local_new_id]["ids"] arg_min_dist = np.argmin(cluster_representatives[local_new_id]["dists"]) old_template_id = cluster_representatives[local_new_id]["ids"][arg_min_dist] old_template_dist = cluster_representatives[local_new_id]["dists"][arg_min_dist] # The old template is nearer to the found cluster center than the beat found by the clustering alg if old_template_dist < new_templates_descriptor[local_new_id]["dist"]: new_template_set[old_template_id] = old_templates[old_template_id] connected_ids.remove(old_template_id) templates_info[old_template_id]["connected_template"].extend(connected_ids) new_templates_substituted += 1 # The beat found by the clustering alg is nearer to the found cluster center than the old template else: new_template_set[next_id] = new_templates_descriptor[local_new_id] templates_info[next_id] = {"used": 0, "deceased": False, "connected_template": connected_ids} next_id += 1 old_templates_substituted += len(cluster_representatives[local_new_id]["ids"]) # The templates in this cluster that are not the center are now deceased for id_removed in connected_ids: templates_info[id_removed]["deceased"] = True else: new_template_set[next_id] = new_templates_descriptor[local_new_id] templates_info[next_id] = {"used": 0, "deceased": False, "connected_template": []} next_id += 1 new_templates_kept += 1 print(f"Old templates kept untuched: {old_templates_kept}/{len(old_templates)}") print(f"Old templates kept, representative of new clusters (but already present): {new_templates_substituted}/{len(old_templates)}") print(f"Old templates removed for old clusters (with new defined params): {len(old_templates) - (old_templates_kept+old_templates_substituted+new_templates_substituted)}/{len(old_templates)}") print(f"Old templates removed for new clusters: {old_templates_substituted}/{len(old_templates)}") print(f"\nNew templates kept untuched: {new_templates_kept}/{len(lbls_new_templates)}") print(f"New templates kept, representative of old clusters (with new params): {len(lbls_new_templates) - (new_templates_kept+new_templates_substituted)}/{len(lbls_new_templates)}") print(f"New templates removed for old clusters: {new_templates_substituted}/{len(lbls_new_templates)}") print(f"\nFinal lenght of the new tempalte set: {len(new_template_set)}") else: new_template_set = old_templates print("No new template found, keeping the previously computed ones") return new_template_set def resample(data, resample_type = INTERPOLATION_TYPE, min_t = None, max_t = None): resampled_data = {"t":None,"v":None} t = data['t'] v = data['v'] t_r,v_r = resamp_one_signal(t,v,resample_type = resample_type, min_t = min_t, max_t = max_t) resampled_data['t'] = t_r resampled_data['v'] = v_r return resampled_data def ADC(beat, nBits, hist = 5, original_bits = 11): #ADC stats delta = 2**original_bits dV = (delta)/(2**nBits) hist = hist/100*dV min_val = -delta//2 events = {'t':[],'v':[], 'QRS_pos': beat['QRS_pos']} #init value, first sample (we assume we always sample the nearest level at start) and ADC status v_0 = beat['v'][0] lowTh = min_val+((v_0-min_val)//dV)*dV highTh = lowTh + dV events['t'].append(beat['t'][0]) events['v'].append(int(lowTh if v_0-lowTh < highTh - v_0 else highTh)) for val,time in zip(beat['v'],beat['t']): #print(f"Value: {val}, time: {time}, low_th = {lowTh - hist}, high_th = { highTh + hist}") if val > highTh + 
hist or val < lowTh - hist: direction = 1 if val > highTh else -1 lowTh = min_val+((val-min_val)//dV)*dV #Delta from the bottom: (val-min_val)//dV*dV then compute the actual level summin min_val highTh = lowTh + dV events['t'].append(time) events['v'].append(int(lowTh if direction == 1 else highTh)) return events def reconstruct_beats(data_orig, lvl_number, init_templates = None, start_after = SEC_FOR_INITIAL_TEMPLATES, resample_type = INTERPOLATION_TYPE, num_beats_analyzed = None, verbose = False): beat_seq_number = 0 distances = [] distances_ref = [] num_distances_out = 0 skip_until = start_after time_info = {'beats_low_res_num': 0, 'tot_beats_num': 0} # COMPUTE INITIAL TEMPLATES #Init templet info templates_info = {} #{$templet_id: {used: 0, deceased: False, "connected_template": []}} all_templates_collection = {} if init_templates is None: templates = compute_new_templates(data_orig, 0, start_after, {}, templates_info) else: templates = init_templates for id_new in templates: templates_info[id_new] = {"used": 0, "deceased": False, "connected_template": []} all_templates_collection = copy (templates) init_templates = copy(templates) rmse_id_coupling = {"rmse_resamp":[],"rmse_right_resamp":[],"rmse_left_resamp":[], "rmse_warp": [],"rmse_warp_right": [],"rmse_warp_left": [], "ids":[], "beats":[]} dist_vector = [0]*len(templates) for beat in data_orig.keys(): t_beat = beat/FREQ time_info['tot_beats_num'] += 1 if t_beat < skip_until: continue time_info['beats_low_res_num'] += 1 if beat_seq_number >= num_beats_analyzed: break beat_seq_number +=1 if (beat_seq_number %(num_beats_analyzed/20)==0 or beat_seq_number == 1) and verbose: print(f"Reconstructing beat {beat} ({beat_seq_number}/{num_beats_analyzed}: {100*beat_seq_number /num_beats_analyzed}%, LEVELS:{lvl_number}, resample type: {resample_type})") up_samp_beat = upsample_uniform_beat(data_orig[beat], beat, MULTIPLIER_FREQ) #TODO: keep track of QRS events = ADC(up_samp_beat,lvl_number) up_samp_beat,params = min_max_normalization_one_beat(up_samp_beat) events,_ = min_max_normalization_one_beat(events,params) resampled = resample(events, resample_type = resample_type, min_t = up_samp_beat['t'][0], max_t = up_samp_beat['t'][-1]) #WARP DRIVEEEEEEE dist, reconstructed, dist_all_template, local_template_id = warp(resampled,templates,events) dist_vector = [dist_vector[i]+dist_all_template[i] for i in range(len(dist_vector))] #Update templates usage info template_id = list(templates_info.keys())[-len(templates)+local_template_id] templates_info[template_id]["used"] += 1 #Compute RMSE #{'warp':None,'resamp':None,'warp_left':None,'resamp_left':None,'warp_right':None,'resamp_right':None} rmse_this_beat = RMSE_one_beat(up_samp_beat,resampled,reconstructed) rmse_id_coupling['ids'].append(local_template_id) rmse_id_coupling['beats'].append(beat) rmse_id_coupling['rmse_resamp'].append(rmse_this_beat['resamp']) rmse_id_coupling['rmse_right_resamp'].append(rmse_this_beat['resamp_right']) rmse_id_coupling['rmse_left_resamp'].append(rmse_this_beat['resamp_left']) rmse_id_coupling['rmse_warp'].append(rmse_this_beat['warp']) rmse_id_coupling['rmse_warp_right'].append(rmse_this_beat['warp_right']) rmse_id_coupling['rmse_warp_left'].append(rmse_this_beat['warp_left']) if len(distances_ref) < LEN_DISTANCE_VECTOR_REF: distances_ref.append(dist) if len(distances_ref) >= LEN_DISTANCE_VECTOR_REF: distances.append(dist) if len(distances) >= LEN_DISTANCE_VECTOR: with warnings.catch_warnings(): warnings.simplefilter("ignore") p = 
stats.anderson_ksamp([distances_ref,distances])[2] distances = [] if (p>= 0.05): num_distances_out = 0 else: # p value les than 0.05: null hypothesis (same distribution) rejected (please Kolmogorov forgive me) num_distances_out += 1 if num_distances_out >= 3: # The acquired vector of distances was out for 2 times max_accum_dist = max(dist_vector) print(f"\n################################################################") print(f"\nBeat num:{beat_seq_number}, New template needed ... ") print(f"\t p-value: {p}") for j in range(len(dist_vector)): print(f"\tTemplate {j}, dist: {dist_vector[j]}:\t","|"*int(20*dist_vector[j]/max_accum_dist)) templates = compute_new_templates(data_orig, t_beat, t_beat+SEC_FOR_NEW_TEMPLATES, templates, templates_info) for t in templates: all_templates_collection[t] = templates[t] dist_vector = [0]*len(templates) print(f"\n################################################################\n") distances_ref = [] skip_until = t_beat + SEC_FOR_NEW_TEMPLATES num_distances_out = 0 return all_templates_collection,rmse_id_coupling,init_templates,time_info def reconstruction_one_lvl_one_file_and_compare(data_orig,lvl,file,initial_templates,verbose = True): log_dir_this_file = os.path.join(log_dir,file.split(".")[0]) if verbose: print(f"Analyzing level:{lvl}") interpolation_type_list = None if INTERPOLATION_TYPE == 'all': interpolation_type_list = ['flat','spline','linear'] else: interpolation_type_list = [INTERPOLATION_TYPE] rmse_each_interpolation_type = {} for interpolation_type in interpolation_type_list: if verbose: print(f"Using interpolation: {interpolation_type}") all_templates_collection, rmse_id_coupling, initial_templates, time_info = reconstruct_beats(data_orig, lvl, init_templates = initial_templates, start_after = SEC_FOR_INITIAL_TEMPLATES, resample_type = interpolation_type, num_beats_analyzed = NUM_BEAT_ANALYZED, verbose = True) #resample_type = flat vs linear log_dir_this_lvl = os.path.join(log_dir_this_file,str(lvl),interpolation_type) os.makedirs(log_dir_this_lvl, exist_ok=True) stats = avg_std_rmse(rmse_id_coupling) avg_rmse_warp = stats['warp']['avg'] std_rmse_warp = stats['warp']['std'] avg_rmse_resamp = stats['resamp']['avg'] std_rmse_resamp = stats['resamp']['std'] file_name_to_save = "L_"+interpolation_type+"_"+file.split(".")[0]+".log" # Particular log for each lvl with open(os.path.join(log_dir_this_lvl,file_name_to_save),"a") as f: f.write(f"Lvl: {lvl}, using the {interpolation_type} interpolation:\n") f.write(f"\tWarp: {avg_rmse_warp}, +-{std_rmse_warp}\n") f.write(f"\tInterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n") f.write(f"\tTime (percentage) passed in low-sampling mode: {time_info['beats_low_res_num']/time_info['tot_beats_num']*100}%\n") f.write(f"\n\n") # General log (the same but all toghether: more confusing but with all infos) with open(os.path.join(log_dir_this_file,file_name_to_save),"a") as f: f.write(f"Lvl: {lvl}, using the {interpolation_type} interpolation:\n") f.write(f"\tWarp: {avg_rmse_warp}, +-{std_rmse_warp}\n") f.write(f"\tInterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n") ''' for id_temp in templates_info: usage = templates_info[id_temp]["used"] deceased = templates_info[id_temp]["deceased"] f.write(f"\t\tTemplate {id_temp} was used {usage} times out of {NUM_BEAT_ANALYZED}, \ deceased? 
{deceased}\n") ''' f.write(f"\tTime (percentage) passed in low-sampling mode: {time_info['beats_low_res_num']/time_info['tot_beats_num']*100}%\n") f.write(f"\n\n") print(f"File:{file_name_to_save}, using the {interpolation_type} interpolation:") print(f"\tLvl: {lvl}") print(f"\t\twarp: {avg_rmse_warp}, +-{std_rmse_warp}") print(f"\t\tinterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}") print(f"\t\tTime (percentage) passed in low-sampling mode: {time_info['beats_low_res_num']/time_info['tot_beats_num']*100}%") print("\n") # 01 percentile plot_beat_rmse_percentile(data_orig, rmse_id_coupling, all_templates_collection, lvl, 1, log_dir_this_lvl, file, interpolation_type = interpolation_type) # 25 percentile plot_beat_rmse_percentile(data_orig, rmse_id_coupling, all_templates_collection, lvl, 25, log_dir_this_lvl, file, interpolation_type = interpolation_type) # 50 percentile plot_beat_rmse_percentile(data_orig, rmse_id_coupling, all_templates_collection, lvl, 50, log_dir_this_lvl, file, interpolation_type = interpolation_type) # 75 percentile plot_beat_rmse_percentile(data_orig, rmse_id_coupling, all_templates_collection, lvl, 75, log_dir_this_lvl, file, interpolation_type = interpolation_type) # 99 percentile plot_beat_rmse_percentile(data_orig, rmse_id_coupling, all_templates_collection, lvl, 99, log_dir_this_lvl, file, interpolation_type = interpolation_type) #Filter and save back the RMSE values for better plots (the filtered data is still accounted for before) rmse_id_coupling['rmse_warp'] = z_score_filter(rmse_id_coupling['rmse_warp']) rmse_id_coupling['rmse_resamp'] = z_score_filter(rmse_id_coupling['rmse_resamp']) rmse_id_coupling['rmse_warp_left'] = z_score_filter(rmse_id_coupling['rmse_warp_left']) rmse_id_coupling['rmse_left_resamp'] = z_score_filter(rmse_id_coupling['rmse_left_resamp']) rmse_id_coupling['rmse_warp_right'] = z_score_filter(rmse_id_coupling['rmse_warp_right']) rmse_id_coupling['rmse_right_resamp'] = z_score_filter(rmse_id_coupling['rmse_right_resamp']) rmse_each_interpolation_type[interpolation_type] = rmse_id_coupling # Histograms file_name_to_save_fig_hist = os.path.join(log_dir_this_lvl,file.split(".")[0]+"_hist"+str(lvl)+".svg") file_name_to_save_fig_hist_left = os.path.join(log_dir_this_lvl,file.split(".")[0]+"_hist_left"+str(lvl)+".svg") file_name_to_save_fig_hist_right = os.path.join(log_dir_this_lvl,file.split(".")[0]+"_hist_right"+str(lvl)+".svg") n_bins = min(len(rmse_id_coupling['rmse_warp']),len(rmse_id_coupling['rmse_resamp']))*PERC_BINS//100 min_bin = min(min(rmse_id_coupling['rmse_warp']),min(rmse_id_coupling['rmse_resamp'])) max_bin = max(max(rmse_id_coupling['rmse_warp']),max(rmse_id_coupling['rmse_resamp'])) delta = (max_bin-min_bin)/n_bins bins = np.arange(min_bin,max_bin+delta,delta) plt.figure() plt.hist(rmse_id_coupling['rmse_warp'], bins = bins, alpha=0.5) plt.hist(rmse_id_coupling['rmse_resamp'], bins = bins, alpha=0.5) plt.title(f'File: {file}, Lvl: {lvl}, RMSE histogram') plt.legend(['RMSE warp','RMSE resampled']) plt.savefig(file_name_to_save_fig_hist) plt.close() n_bins = min(len(rmse_id_coupling['rmse_warp_left']),len(rmse_id_coupling['rmse_left_resamp']))*PERC_BINS//100 min_bin = min(min(rmse_id_coupling['rmse_warp_left']),min(rmse_id_coupling['rmse_left_resamp'])) max_bin = max(max(rmse_id_coupling['rmse_warp_left']),max(rmse_id_coupling['rmse_left_resamp'])) delta = (max_bin-min_bin)/n_bins bins = np.arange(min_bin,max_bin+delta,delta) plt.figure() plt.hist(rmse_id_coupling['rmse_warp_left'], bins = bins, alpha=0.5) 
plt.hist(rmse_id_coupling['rmse_left_resamp'], bins = bins, alpha=0.5) plt.title(f'File: {file}, Lvl: {lvl}, left RMSE histogram') plt.legend(['RMSE warp','RMSE resampled']) plt.savefig(file_name_to_save_fig_hist_left) plt.close() n_bins = min(len(rmse_id_coupling['rmse_warp_right']),len(rmse_id_coupling['rmse_right_resamp']))*PERC_BINS//100 min_bin = min(min(rmse_id_coupling['rmse_warp_right']),min(rmse_id_coupling['rmse_right_resamp'])) max_bin = max(max(rmse_id_coupling['rmse_warp_right']),max(rmse_id_coupling['rmse_right_resamp'])) delta = (max_bin-min_bin)/n_bins bins = np.arange(min_bin,max_bin+delta,delta) plt.figure() plt.hist(rmse_id_coupling['rmse_warp_right'], bins = bins, alpha=0.5) plt.hist(rmse_id_coupling['rmse_right_resamp'], bins = bins, alpha=0.5) plt.title(f'File: {file}, Lvl: {lvl}, right RMSE histogram') plt.legend(['RMSE warp','RMSE resampled']) plt.savefig(file_name_to_save_fig_hist_right) plt.close() if INTERPOLATION_TYPE == 'all': log_dir_this_lvl_combined = os.path.join(log_dir_this_file,str(lvl),"combined_results") os.makedirs(log_dir_this_lvl_combined, exist_ok=True) file_name_to_save = "L_"+file.split(".")[0]+".log" global_file_name_to_save = "L_all_interpolations_"+file.split(".")[0]+".log" with open(os.path.join(log_dir_this_lvl_combined,file_name_to_save),"a") as f: f.write(f"Lvl: {lvl}\n") with open(os.path.join(log_dir_this_file,global_file_name_to_save),"a") as f: f.write(f"Lvl: {lvl}\n") # Histograms plt.figure() file_name_to_save_fig_hist = os.path.join(log_dir_this_lvl_combined,file.split(".")[0]+"_hist"+str(lvl)+".svg") n_bins = np.inf min_bin = np.inf max_bin = -np.inf legend_str = [] for interp_type in rmse_each_interpolation_type: r_w = rmse_each_interpolation_type[interp_type]['rmse_warp'] r_r = rmse_each_interpolation_type[interp_type]['rmse_resamp'] n_bins = min(min(len(r_w),len(r_r))*PERC_BINS//100,n_bins) min_bin = min(min(r_w),min(r_r),min_bin) max_bin = max(max(r_w),max(r_r),max_bin) legend_str.extend([f'RMSE warp {interp_type}',f'RMSE resampled {interp_type}']) delta = (max_bin-min_bin)/n_bins bins = np.arange(min_bin,max_bin+delta,delta) for interp_type in rmse_each_interpolation_type: stats = avg_std_rmse(rmse_each_interpolation_type[interp_type]) avg_rmse_warp = stats['warp']['avg'] std_rmse_warp = stats['warp']['std'] avg_rmse_resamp = stats['resamp']['avg'] std_rmse_resamp = stats['resamp']['std'] # Particular log for each lvl with open(os.path.join(log_dir_this_lvl_combined,file_name_to_save),"a") as f: f.write(f"\tWarp using the {interp_type} interpolation: {avg_rmse_warp}, +-{std_rmse_warp}\n") f.write(f"\tInterpolation using the {interp_type} interpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n") # General log (the same but all toghether: more confusing but with all infos) with open(os.path.join(log_dir_this_file,global_file_name_to_save),"a") as f: f.write(f"\tWarp using the {interp_type} interpolation: {avg_rmse_warp}, +-{std_rmse_warp}\n") f.write(f"\tInterpolation using the {interp_type} interpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n") plt.hist(rmse_each_interpolation_type[interp_type]['rmse_warp'], bins = bins, alpha=0.5) plt.hist(rmse_each_interpolation_type[interp_type]['rmse_resamp'], bins = bins, alpha=0.5) plt.legend(legend_str) plt.title(f'File: {file}, Lvl: {lvl}, RMSE histogram') plt.savefig(file_name_to_save_fig_hist) plt.close() with open(os.path.join(log_dir_this_lvl_combined,file_name_to_save),"a") as f: f.write(f"\n\n") with open(os.path.join(log_dir_this_file,global_file_name_to_save),"a") as f: 
f.write(f"\n\n") return initial_templates def reconstruct_and_compare_level_parallel(lvl): print(lvl) for file in FILES_SELECTED: verbose = True log_dir_this_file = os.path.join(log_dir,file.split(".")[0]) os.makedirs(log_dir_this_file,exist_ok=True) init_templates = None if verbose: print("Extracting original data") data_orig = open_file(file, level = "original", start_after = 0) init_templates = reconstruction_one_lvl_one_file_and_compare(data_orig,lvl,file,init_templates,verbose = True) def recontruct_and_compare_file_parallel(file): verbose = True log_dir_this_file = os.path.join(log_dir,file.split(".")[0]) os.mkdir(log_dir_this_file) init_templates = None if verbose: print("Extracting original data") data_orig = open_file(file, level = "original", start_after = 0) for lvl in LEVELS: init_templates = reconstruction_one_lvl_one_file_and_compare(data_orig,lvl,file,init_templates,verbose = True) def process(files, levels, parallelize_along = PARALLELIZE_ALONG, cores=1): # ------------ INIT ------------ global log_dir for i in range(1,10000): tmp_log_dir = log_dir+str(i) if not os.path.isdir(tmp_log_dir): log_dir = tmp_log_dir break os.makedirs(log_dir, exist_ok=True) with open(os.path.join(log_dir,"specs.txt"), "w") as f: f.write(f"Results generated by script: {sys.argv[0]}\n") f.write(f"Time: {time.ctime(time.time())}\n\n") f.write(f"Files: {files}\n") f.write(f"Levels: {levels}\n") f.write(f"Parallelize along: {parallelize_along}\n") f.write(f"Cores: {cores}\n") f.write(f"Beats: {NUM_BEAT_ANALYZED}\n") f.write(f"Cluster percentage: {CLUSTER_PERCENTAGE}\n") f.write(f"Interpolation type: {INTERPOLATION_TYPE}\n") + f.write(f"Timing mode: {TIME_MODE}\n") # ------------ Extract DATA & ANNOTATIONS ------------ if cores == 1: print("Single core") if parallelize_along == 'levels': for lvl in levels: reconstruct_and_compare_level_parallel(lvl) elif parallelize_along == 'files': for f in files: recontruct_and_compare_file_parallel(f) else: with Pool(cores) as pool: if parallelize_along == 'levels': print(f"parallelizing along levels: {levels}") pool.map(reconstruct_and_compare_level_parallel, levels) elif parallelize_along == 'files': print("parallelizing along files") pool.map(recontruct_and_compare_file_parallel, files) if __name__ == "__main__": import argparse import time seconds_start = time.time() local_time_start = time.ctime(seconds_start) print("\nStarted at:", local_time_start,"\n\n") #global NUM_BEAT_ANALYZED parser = argparse.ArgumentParser() parser.add_argument("--file", help="Force to analyze one specific file instead of default one (first found)") parser.add_argument("--levels", help="Decide how many bits to use in the ADC: options:\n\t->1: [3]\n\t->2: [3,4]\n\t->3: [3,4,5]\n\t->4: ...") parser.add_argument("--cores", help="Force used number of cores (default, half of the available ones") parser.add_argument("--parallelize_along", help="describe if to parallelize on the number of 'files'\n or on the number of 'levels'") parser.add_argument("--beats", help="Number of used beats, default: 5000") parser.add_argument("--cluster_opt", help="Percentage of points for a cluster to be considered") parser.add_argument("--interpolation_type", help="Chose between: spline, flat, linear, and all. Default: falt") parser.add_argument("--acquisition_time_mode", help="Menage the time length the algorithm acquire the data at \ full speed ant the time horizon used for template recomputation. \nChose between: short, normal, and long. 
Default: normal") args = parser.parse_args() files = os.listdir(os.path.join(data_beats_dir,"original")) if args.file is not None: if args.file == 'all': FILES_SELECTED = files else: FILES_SELECTED = list(filter(lambda string: True if args.file in string else False, files)) else: FILES_SELECTED = [files[0]] if args.cores is not None: used_cores = int(args.cores) else: used_cores = multiprocessing.cpu_count()//3 if args.beats is not None: NUM_BEAT_ANALYZED = int(args.beats) else: NUM_BEAT_ANALYZED = 5000 if args.cluster_opt is not None: CLUSTER_PERCENTAGE = int(args.cluster_opt) else: CLUSTER_PERCENTAGE = 3 if args.interpolation_type is not None: INTERPOLATION_TYPE = args.interpolation_type else: INTERPOLATION_TYPE = "flat" if args.parallelize_along is not None: PARALLELIZE_ALONG = args.parallelize_along else: PARALLELIZE_ALONG = "levels" if args.levels is not None: - LEVELS = LEVELS[:args.levels] + LEVELS = LEVELS[:int(args.levels)] else: LEVELS = [3,4,5,6,7,8] if args.acquisition_time_mode is not None: if args.acquisition_time_mode == "short": SEC_FOR_INITIAL_TEMPLATES = 3*60 #5*60 LEN_DISTANCE_VECTOR = 60 #80 LEN_DISTANCE_VECTOR_REF = 400 #500 SEC_FOR_NEW_TEMPLATES = 40 #2*60 + TIME_MODE = 'short' elif args.acquisition_time_mode == "long": SEC_FOR_INITIAL_TEMPLATES = 8*60 LEN_DISTANCE_VECTOR = 110 LEN_DISTANCE_VECTOR_REF = 650 SEC_FOR_NEW_TEMPLATES = 3*60 + TIME_MODE = 'long' else: SEC_FOR_INITIAL_TEMPLATES = 5*60 LEN_DISTANCE_VECTOR = 80 LEN_DISTANCE_VECTOR_REF = 500 SEC_FOR_NEW_TEMPLATES = 2*60 + TIME_MODE = 'medium' else: SEC_FOR_INITIAL_TEMPLATES = 5*60 LEN_DISTANCE_VECTOR = 80 LEN_DISTANCE_VECTOR_REF = 500 SEC_FOR_NEW_TEMPLATES = 2*60 + TIME_MODE = 'medium' print(f"Analyzing files: {FILES_SELECTED}") print(f"Extracting data with {used_cores} cores...") process(files = FILES_SELECTED, levels = LEVELS, parallelize_along = PARALLELIZE_ALONG, cores=used_cores) seconds_stop = time.time() local_time_stop = time.ctime(seconds_stop) elapsed = seconds_stop - seconds_start hours = elapsed//60//60 minutes = (elapsed - hours * 60 * 60) // 60 seconds = (elapsed - hours * 60 * 60 - minutes * 60) // 1 print("\n\n\n-----------------------------------------------------------------------------------------------------------------") print(f"Finished at: {local_time_stop}, elapsed: {elapsed} seconds ({hours} hours, {minutes} minutes, {seconds} seconds)")