Page MenuHomec4science

naive_dtw_recon.py
No OneTemporary

File Metadata

Created
Thu, May 23, 23:18

naive_dtw_recon.py

import multiprocessing
import os
import pickle
import shutil
from multiprocessing import Pool
from time import time
import matplotlib.pyplot as plt
import numpy as np
from dtaidistance import dtw
from scipy.interpolate import interp1d
from build_template import build_template
# Input directory: pickled per-file beat dictionaries (see open_file docstring).
data_beats_dir = "../data/beats/"
# Output directory: per-file RMSE logs and SVG figures (recreated on each run).
log_dir = "../data/beat_recon_logs_naive"
# Sampling frequency of the signal, in samples per unit time.
# NOTE(review): used as `k > FREQ*start_after` in open_file with start_after in
# minutes (MIN_TO_AVG) — presumably the beat keys are sample indices; the
# units only line up if keys are in seconds*FREQ — TODO confirm.
FREQ = 128
# Warm-up span (minutes) skipped at the start of each recording.
MIN_TO_AVG = 5
# Index (into the reconstructed beats) of the example beat plotted per level.
BEAT_TO_PLOT = 10
# Degradation/compression levels to reconstruct and compare.
# NOTE(review): exact semantics of a "level" are defined by the beat files.
LEVEL = [1,2,3,4,5,6,7,8,9,10,11]
# Maximum number of beats reconstructed per file.
NUM_BEAT_ANALYZED = 5000
def RMSE(v1,v2):
    """Return the root-mean-square error between two equal-length sequences."""
    diff = np.asarray(v1, dtype=float) - np.asarray(v2, dtype=float)
    return np.sqrt(np.mean(diff * diff))
def RMSE_warped_resampled(data_orig,data_warped,data_resampled):
    """Compute per-level RMSE of warped and resampled beats vs. the originals.

    Only samples whose time stamp falls inside the original beat's time span
    are compared. Returns {lvl: {'warp': [rmse...], 'resamp': [rmse...]}} with
    one entry per beat in each list.
    """
    rmse = {}
    for beat in data_orig:
        span_lo = data_orig[beat]['t'][0]
        span_hi = data_orig[beat]['t'][-1]
        for lvl in data_warped[beat]:
            per_level = rmse.setdefault(lvl, {'warp': [], 'resamp': []})
            warped_vals = []
            resampled_vals = []
            for idx, ts in enumerate(data_warped[beat][lvl]['t']):
                if ts > span_hi:
                    break  # time axis is sorted; nothing further can match
                if ts >= span_lo:
                    warped_vals.append(data_warped[beat][lvl]['v'][idx])
                    resampled_vals.append(data_resampled[beat][lvl]['v'][idx])
            per_level['warp'].append(RMSE(data_orig[beat]['v'], warped_vals))
            per_level['resamp'].append(RMSE(data_orig[beat]['v'], resampled_vals))
    return rmse
def open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = None):
    '''
    Load one pickled beats file and return (beats, levels).

    Data structure of the pickle:
        DICTIONARY:
            --> Beat_annot (key, numeric time of the beat):
                --> lvl (0 = original signal):
                    --> "t": []  (time stamps)
                    --> "v": []  (values)

    Parameters:
        file_name: file name (basename is resolved inside data_beats_dir).
        start_after: warm-up span; beats with key <= FREQ*start_after are dropped.
        get_selected_level: optional list of levels to keep (level 0 is always
            kept as the reference); None keeps every level.

    Returns:
        (data_out, levels) where levels is the list of level keys present in
        the returned beats (empty list if no beat passed the filter).
    '''
    file_name_full = os.path.join(data_beats_dir, os.path.basename(file_name))
    # NOTE: pickle is only safe on trusted, project-generated data files.
    with open(file_name_full, "rb") as f:
        data = pickle.load(f)
    data_out = {}
    for k in data.keys():
        if k > FREQ * start_after:
            if get_selected_level is not None:
                # Always carry level 0 (the original beat) plus the requested levels.
                data_out[k] = {0: data[k][0]}
                for lvl in get_selected_level:
                    data_out[k][lvl] = data[k][lvl]
            else:
                data_out[k] = data[k]
    # BUG FIX: the original returned list(data_out[k].keys()) with `k` left over
    # from the loop (the last key of `data`) — a KeyError whenever that last
    # key was filtered out above. Read the levels from an actual entry instead.
    if data_out:
        levels = list(next(iter(data_out.values())).keys())
    else:
        levels = []
    return data_out, levels
def min_max_normalization(vector, forced_avg = None):
    """Min-max scale `vector` to [0, 1], then subtract an average offset.

    Parameters:
        vector: non-empty numeric sequence.
        forced_avg: offset to subtract; if None, the mean of the scaled
            vector is used (and returned, so it can be reused on a paired
            signal via forced_avg).

    Returns:
        (normalized list, offset actually subtracted).
    """
    mi_v = min(vector)
    ma_v = max(vector)
    span = ma_v - mi_v
    if span == 0:
        # Robustness fix: a constant vector previously divided by zero
        # (producing nan/inf). Treat it as an all-zero normalized signal.
        norm = np.zeros(len(vector))
    else:
        norm = (np.asarray(vector, dtype=float) - mi_v) / span
    if forced_avg is not None:
        avg = forced_avg
    else:
        avg = np.average(norm)
    norm = norm - avg
    return list(norm), avg
def resamp_one_signal(t,v,resample_type = 'linear', min_t = None, max_t = None):
    """Resample the signal (t, v) onto the integer grid [min_t, max_t].

    Parameters:
        t, v: time stamps and values; t must be sorted and cover the grid.
        resample_type: 'linear' interpolation, or 'flat' (previous-sample hold).
        min_t, max_t: inclusive grid bounds; default to t[0] and t[-1].
            NOTE: assumes integer time stamps (the grid is built with range()).

    Returns:
        (t_new, v_new) — the integer grid and the interpolated values.

    Raises:
        ValueError: for an unknown resample_type (the original code fell
            through and raised a confusing NameError on `f` instead).
    """
    if resample_type == "linear":
        f = interp1d(t, v)
    elif resample_type == "flat":
        f = interp1d(t, v, kind='previous')
    else:
        raise ValueError(f"Unknown resample_type: {resample_type}")
    if min_t is None:
        min_t = t[0]
    if max_t is None:
        max_t = t[-1]
    t_new = list(range(min_t, max_t + 1))
    v_new = f(t_new)
    return t_new, v_new
def resample(data, resample_type = "linear", min_t = None, max_t = None):
    """Resample a beat dict {'t': ..., 'v': ...} onto an integer time grid.

    Thin wrapper around resamp_one_signal that repacks the result into the
    same {'t': ..., 'v': ...} shape.
    """
    new_t, new_v = resamp_one_signal(data['t'], data['v'],
                                     resample_type=resample_type,
                                     min_t=min_t, max_t=max_t)
    return {"t": new_t, "v": new_v}
def warp(data,template):
    """Align `template` onto the time axis of `data` using DTW.

    The DTW warping path pairs each source sample with one or more template
    samples; only the first template match per source index is kept, so the
    output has exactly one value per entry of data['t'].
    """
    source = data['v']
    path = dtw.warping_path(source, template)
    aligned_vals = []
    last_src_idx = None  # sentinel distinct from any real index
    for src_idx, temp_idx in path:
        if src_idx == last_src_idx:
            continue  # skip repeated matches for the same source sample
        last_src_idx = src_idx
        aligned_vals.append(template[temp_idx])
    return {'t': data['t'], 'v': aligned_vals}
def reconstruct_beats(file_name, level = None, normalize = True, resample_type = 'linear', num_beats = None, verbose = False, prev_pre_proc = None):
    """Reconstruct beats of one file at the requested levels via DTW warping.

    For each beat and level: optionally min-max-normalize (reusing the
    original beat's average offset), resample the degraded beat onto the
    original beat's time grid, then warp the template onto it.

    Parameters:
        file_name: beats file to process (resolved by open_file).
        level: list of levels to reconstruct; None means all levels in the file.
        normalize: min-max normalize original and degraded beats.
        resample_type: 'linear' or 'flat', forwarded to resample().
        num_beats: cap on beats processed per level; None = all beats.
        verbose: print progress.
        prev_pre_proc: pre-processing dict from a previous call; if given, its
            'template' is reused instead of rebuilding it (saves work when the
            caller iterates over levels one at a time).

    Returns:
        (reconstructed, data_orig, resamp, pre_proc) — all keyed by beat
        (and by level for reconstructed/resamp); pre_proc carries 'template'.
    """
    reconstructed = {}
    resamp = {}
    data_orig = {}
    pre_proc = {}
    if verbose:
        print(f"Extracting {file_name}")
    data,lvls = open_file(file_name, start_after = MIN_TO_AVG, get_selected_level = level)
    lvls.remove(0)  # level 0 is the reference original, not a target
    if verbose:
        print("Building template")
    if prev_pre_proc is not None:
        template = prev_pre_proc['template']
    else:
        template,_,_ = build_template(file_name, normalize = normalize)
    # BUG FIX: was `num_beats == None`; identity comparison is the correct
    # (and PEP 8 mandated) way to test for None.
    if num_beats is None:
        num_beats = len(list(data.keys()))
    if verbose:
        print("reconstructing")
    for lvl in lvls:
        i = 0
        if verbose:
            print(f"Analyzing level:{lvl}")
        for beat in data.keys():
            if i == num_beats:
                break
            i+=1
            # Progress line roughly every 5% (and for the first beat).
            if (i%(num_beats/20)==0 or i == 1) and verbose:
                print(f"File: {file_name}, Reconstructing beat {beat} ({i}/{num_beats}: {100*i/num_beats}%, LEVEL:{lvl})")
            if beat not in reconstructed.keys():
                reconstructed[beat] = {}
                resamp[beat] = {}
                data_orig[beat] = data[beat][0]
            if normalize:
                # Normalize the original beat, then force the same offset on
                # the degraded beat so both live on a comparable scale.
                data_orig[beat]['v'],avg = min_max_normalization(data_orig[beat]['v'], forced_avg = None)
                data[beat][lvl]['v'],_ = min_max_normalization(data[beat][lvl]['v'], forced_avg = avg)
            data_resampled = resample(data[beat][lvl], resample_type = resample_type, min_t = data_orig[beat]['t'][0], max_t = data_orig[beat]['t'][-1])
            reconstructed[beat][lvl] = warp(data_resampled,template)
            resamp[beat][lvl] = data_resampled
    pre_proc['template'] = template
    return reconstructed,data_orig,resamp,pre_proc
def recontruct_and_compare(file):
    # Reconstruct one file at every LEVEL, log warp/resample RMSE statistics,
    # and save example figures. Runs as a Pool worker (see process()).
    # NOTE(review): the name keeps the original "recontruct" typo; callers use it.
    pre_proc = None
    for lvl in LEVEL:
        # One level per call; pre_proc feeds the template back in so it is
        # only built once per file.
        recon,orig,resamp,pre_proc = reconstruct_beats(file, level = [lvl], resample_type = 'flat', num_beats = NUM_BEAT_ANALYZED, verbose = True, prev_pre_proc = pre_proc) #resample_type = flat vs linear
        rmse = RMSE_warped_resampled(orig,recon,resamp)
        avg_rmse_warp = np.average(rmse[lvl]['warp'])
        std_rmse_warp = np.std(rmse[lvl]['warp'])
        avg_rmse_resamp = np.average(rmse[lvl]['resamp'])
        std_rmse_resamp = np.std(rmse[lvl]['resamp'])
        # Append this level's statistics to the per-file log.
        file_name_to_save = "L_"+file.split(".")[0]+".log"
        with open(os.path.join(log_dir,file_name_to_save),"a") as f:
            f.write(f"Lvl: {lvl}\n")
            f.write(f"\tWarp: {avg_rmse_warp}, +-{std_rmse_warp}\n")
            f.write(f"\tInterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}\n")
            f.write(f"\n\n")
        print(f"File:{file_name_to_save}")
        print(f"\tLvl: {lvl}")
        print(f"\t\twarp: {avg_rmse_warp}, +-{std_rmse_warp}")
        print(f"\t\tinterpolation: {avg_rmse_resamp}, +-{std_rmse_resamp}")
        # Figure paths: one per level, plus a single template figure per file.
        file_name_to_save_fig = os.path.join(log_dir,file.split(".")[0]+"_"+str(lvl)+".svg")
        file_name_to_save_fig_template = os.path.join(log_dir,file.split(".")[0]+"_template.svg")
        # Plot one representative beat (BEAT_TO_PLOT-th reconstructed beat).
        beat_to_plot = list(recon.keys())[BEAT_TO_PLOT]
        t_o,v_o = orig[beat_to_plot]['t'],orig[beat_to_plot]['v']
        t_a,v_a = recon[beat_to_plot][lvl]['t'],recon[beat_to_plot][lvl]['v']
        t_r,v_r = resamp[beat_to_plot][lvl]['t'],resamp[beat_to_plot][lvl]['v']
        plt.figure()
        plt.plot(t_o,v_o)
        plt.plot(t_a,v_a)
        plt.plot(t_r,v_r)
        plt.legend(['original','warped template','resampled'])
        plt.title(f'File: {file}, Lvl: {lvl}, Beat time (samples):{beat_to_plot}')
        plt.savefig(file_name_to_save_fig)
        plt.close()
        # Save the template figure only once (first level that reaches here).
        if not os.path.isfile(file_name_to_save_fig_template):
            template = pre_proc['template']
            plt.figure()
            plt.plot(template)
            plt.legend(['template'])
            plt.title(f'File: {file}, Beat time (samples):{beat_to_plot}')
            plt.savefig(file_name_to_save_fig_template)
            plt.close()
def process(files, multi=True, cores=1):
    """Reset the log directory and reconstruct every file in parallel.

    Parameters:
        files: iterable of beat-file names to analyze.
        multi: kept for interface compatibility; a Pool is created regardless.
        cores: number of worker processes.
    """
    # ------------ INIT ------------
    # Wipe any previous run's logs/figures before starting fresh.
    if os.path.isdir(log_dir):
        shutil.rmtree(log_dir)
    os.mkdir(log_dir)
    # ------------ Extract DATA & ANNOTATIONS ------------
    # Fan the files out across the worker pool.
    with Pool(cores) as pool:
        pool.map(recontruct_and_compare, files)
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", help="Force to analyze one specific file instead of default one (first found)")
    parser.add_argument("--not_norm", help="Force to NOT normalize each beats", action="store_true")
    # NOTE(review): help string below is missing its closing parenthesis.
    parser.add_argument("--cores", help="Force used number of cores (default, half of the available ones")
    args = parser.parse_args()
    files = os.listdir(data_beats_dir)
    # Select which files to analyze: all, a substring match, or the first one.
    if args.file is not None:
        if args.file == 'all':
            analyzed = files
        else:
            # Keep every file whose name contains the --file substring.
            analyzed = list(filter(lambda string: True if args.file in string else False, files))
    else:
        analyzed = [files[0]]
    if args.not_norm:
        normalize = False
    else:
        normalize = True
    # NOTE(review): `normalize` is computed here but never passed to process(),
    # so --not_norm currently has no effect — verify intent.
    if args.cores is not None:
        used_cores = int(args.cores)
    else:
        # Default to half of the available cores.
        used_cores = multiprocessing.cpu_count()//2
    print(f"Analyzing files: {analyzed}")
    print(f"Extracting data with {used_cores} cores...")
    process(files = analyzed, multi=True, cores=used_cores)

Event Timeline