Page Menu | Home | c4science

DTW_allignment.py
No One | Temporary

File Metadata

Created
Sat, Jun 1, 19:03

DTW_allignment.py

from dtaidistance import dtw
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
# Directory holding the pickled beat files; open_file() joins the basename
# of the requested file onto this path.
data_beats_dir = "../data/beats/"

# Index of an arbitrary beat used while experimenting.
# NOTE(review): not referenced anywhere else in this file.
BEAT_NUM = 156

# Layout of each pickled file, as consumed by the functions below:
#
#   dict
#    |
#    +--> beat annotation  (the __main__ block passes it to DTW_match_beat
#         |                 as the time at which the beat is split)
#         |
#         +--> lvl  (level 0 is read as the uniformly sampled reference)
#              |
#              +--> "t": [...]  sample times
#              |
#              +--> "v": [...]  sample values
def open_file(file_name, directory=None):
    """Load one pickled beat file.

    Only the basename of *file_name* is used. It is joined onto *directory*,
    so callers may pass either a bare file name or any path ending in it.

    Args:
        file_name: Name (or path) of the pickle file to load.
        directory: Directory to read from. When omitted, the module-level
            ``data_beats_dir`` is used.

    Returns:
        The unpickled object. It is presumably the nested
        ``{beat: {lvl: {"t": [...], "v": [...]}}}`` dictionary consumed by
        the rest of this module; this is not checked here.

    Raises:
        FileNotFoundError: if the file does not exist in *directory*.
    """
    if directory is None:
        # Resolved at call time so later changes to data_beats_dir are honoured.
        directory = data_beats_dir
    file_name_full = os.path.join(directory, os.path.basename(file_name))
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load beat files from a trusted source.
    with open(file_name_full, "rb") as f:
        return pickle.load(f)
def split_beat(data, split_time):
    """Cut every level's trace into two halves at *split_time*.

    Args:
        data: ``{lvl: {"t": [...], "v": [...]}}``, i.e. one trace per level.
        split_time: Time at which each trace is cut. The cut index is the one
            returned by find_time_index_in_eb (default mode "before").

    Returns:
        ``{lvl: {'t0', 't1', 'v0', 'v1'}}``. The "0" halves end at the cut
        sample and the "1" halves start at it, so the cut sample appears in
        both halves.
    """
    halves = {}
    for level, trace in data.items():
        times = trace['t']
        # The scan starts one third into the trace. This is only a hint:
        # find_time_index_in_eb wraps around to the beginning if needed.
        cut = find_time_index_in_eb(times, split_time, prior=len(times) // 3)
        halves[level] = {
            't0': times[:cut + 1],
            't1': times[cut:],
            'v0': trace['v'][:cut + 1],
            'v1': trace['v'][cut:],
        }
    return halves
def find_time_index_in_eb(data_time, time, prior=0, mode="before"):
    """Locate *time* within the sample times *data_time*.

    The function searches for the interval ``data_time[i] <= time <
    data_time[i + 1]``. The scan starts at index *prior* and, if nothing is
    found, wraps around to the start of the sequence.

    Args:
        data_time: Non-empty indexable sequence of sample times. The interval
            search presumably expects it to be sorted ascending.
        time: Time to locate.
        prior: Index at which to start scanning. It is clamped into the valid
            range, so any int is accepted.
        mode: ``"before"`` returns ``i``. ``"after"`` returns ``i + 1``,
            unless *time* equals ``data_time[i]`` exactly, in which case it
            returns ``i``.

    Returns:
        * 0 if *time* is at or before the first sample.
        * ``len(data_time) - 1`` if *time* is at or after the last sample
          (in either mode).
        * Otherwise the index chosen by *mode*.
        * ``None`` if no interval matches, which can only happen for
          unsorted data or a NaN *time*.

    Raises:
        ValueError: if *mode* is neither ``"before"`` nor ``"after"``.
        IndexError: if *data_time* is empty.
    """
    if mode not in ("before", "after"):
        raise ValueError(f"mode must be 'before' or 'after', got {mode!r}")
    if time <= data_time[0]:
        return 0
    if time >= data_time[-1]:
        return len(data_time) - 1
    last = len(data_time) - 1
    # Clamp the hint so both scan ranges stay within [0, last - 1]. This
    # keeps data_time[i + 1] a valid index.
    prior = min(max(prior, 0), last)
    # First scan forward from the hint, then wrap around to the start.
    for start, stop in ((prior, last), (0, prior)):
        for i in range(start, stop):
            if data_time[i] <= time < data_time[i + 1]:
                if mode == "after" and time != data_time[i]:
                    return i + 1
                return i
    return None
def DTW_Matching(v_uniform, v_eb):
    """Align two value sequences with DTW and resample both along the path.

    ``dtw.warping_path`` yields ``(i, j)`` index pairs. For each pair, this
    function emits ``v_uniform[i]`` and ``v_eb[j]``, so the two returned
    lists have equal length (the length of the warping path).

    Args:
        v_uniform: Values of the reference (uniformly sampled) trace.
        v_eb: Values of the event-based trace.

    Returns:
        ``(v_uniform_warped, v_eb_warped)`` as two lists.
    """
    matched_pairs = [(v_uniform[i], v_eb[j])
                     for i, j in dtw.warping_path(v_uniform, v_eb)]
    uniform_on_path = [u for u, _ in matched_pairs]
    eb_on_path = [e for _, e in matched_pairs]
    return uniform_on_path, eb_on_path
def DTW_match_beat(data, beat_time):
    """DTW-align every non-reference level of one beat against level 0.

    Every level is split at *beat_time* with split_beat(), and each half is
    aligned separately with DTW_Matching(). This forces the split samples of
    the two traces to be matched with each other. The two warped halves are
    then concatenated.

    Args:
        data: ``{lvl: {"t": [...], "v": [...]}}`` for a single beat. Level
            ``0`` is the uniformly sampled reference; every other level is
            treated as an event-based (EB) trace.
        beat_time: Time at which all levels are split before alignment.

    Returns:
        A dict with three entries:

        * ``"uniform_original"``: ``{'t_uniform', 'v_uniform'}`` of level 0.
        * ``'warped'``: ``{lvl: {'v_uniform_warped', 'v_eb_warped'}}``.
        * ``"EB_original"``: ``{lvl: {'t_EB', 'v_EB'}}``.

        The two per-level dicts cover every level except 0.
    """
    data_uniform = data[0]
    data_splitted = split_beat(data, beat_time)
    v_uniform_0 = data_splitted[0]['v0']
    v_uniform_1 = data_splitted[0]['v1']
    eb_levels = [lvl for lvl in data_splitted if lvl != 0]

    data_matched = {
        "uniform_original": {'t_uniform': data_uniform['t'],
                             'v_uniform': data_uniform['v']},
        'warped': {},
        "EB_original": {},
    }
    for k in eb_levels:
        v_uniform_0_warped, v_eb_0_warped = DTW_Matching(v_uniform_0, data_splitted[k]['v0'])
        v_uniform_1_warped, v_eb_1_warped = DTW_Matching(v_uniform_1, data_splitted[k]['v1'])
        # split_beat() puts the split sample at the end of the first half and
        # at the start of the second. A DTW path runs from (0, 0) to the last
        # index pair, so that sample pair closes the first path and opens the
        # second. It is dropped once from the first path so the concatenation
        # does not duplicate it.
        data_matched['warped'][k] = {
            'v_uniform_warped': [*v_uniform_0_warped[:-1], *v_uniform_1_warped],
            'v_eb_warped': [*v_eb_0_warped[:-1], *v_eb_1_warped],
        }
        data_matched["EB_original"][k] = {'t_EB': data[k]['t'], 'v_EB': data[k]['v']}
    return data_matched
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="DTW-align the event-based levels of one beat against its "
                    "uniformly sampled level (level 0) and plot the result.")
    parser.add_argument("--file", help="Force to analyze one specific file instead of default one (first found)")
    parser.add_argument("--beat", type=int, default=0,
                        help="Index of the beat to analyze within the file (default: 0)")
    parser.add_argument("--max-level", type=int, default=6,
                        help="Only plot levels up to and including this value (default: 6)")
    args = parser.parse_args()

    # os.listdir() order is arbitrary. Sort the names so the default
    # ("first") file is the same on every run.
    files = sorted(os.listdir(data_beats_dir))
    if args.file is not None:
        # Substring match: pick the first file whose name contains --file.
        files = [name for name in files if args.file in name]
    if not files:
        parser.error(f"no matching file found in {data_beats_dir!r}")
    analyzed = files[0]

    data = open_file(analyzed)
    beats = list(data.keys())
    if not -len(beats) <= args.beat < len(beats):
        parser.error(f"--beat {args.beat} is out of range: {analyzed} has {len(beats)} beats")
    beat_key = beats[args.beat]
    # The beat key doubles as the split time for the two-part alignment.
    data_DTW = DTW_match_beat(data[beat_key], beat_key)

    uniform_original = data_DTW["uniform_original"]
    for lvl, warped in data_DTW['warped'].items():
        if lvl > args.max_level:
            continue
        eb_original = data_DTW["EB_original"][lvl]
        # Three stacked subplots:
        #   1. both originals on their time axes;
        #   2. the EB values plotted by sample index;
        #   3. the DTW-warped pair.
        f, (ax1, ax2, ax3) = plt.subplots(3, 1)
        ax1.plot(uniform_original['t_uniform'], uniform_original['v_uniform'])
        ax1.plot(eb_original['t_EB'], eb_original['v_EB'])
        ax1.set_title('ORIGINAL and time-matched EB')
        ax2.plot(eb_original['v_EB'])
        ax2.set_title("Event based")
        ax3.plot(warped['v_uniform_warped'])
        ax3.plot(warped['v_eb_warped'])
        ax3.set_title(f"Warped against {lvl} bits sampled signal")
        ax3.legend(["uniform", "level crossing"])
        plt.show()

    # Toy example kept for reference (dtaidistance warping path between two
    # short sequences):
    #   s1 = np.array([0., 0, 1, 2, 1, 0, 1, 0, 0, 2, 1, 0, 0])
    #   s2 = np.array([0., 1, 2, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 1, 0])
    #   path = dtw.warping_path(s1, s2)
    #   w1 = [s1[i] for i, _ in path]; w2 = [s2[j] for _, j in path]
    # RESULTS (path):
    # [(0, 0), (1, 0), (2, 1), (3, 2), (3, 3), (4, 4), (5, 5), (5, 6), (5, 7), (5, 8), (5, 9), (5, 10), (5, 11),
    #  (5, 12), (5, 13), (5, 14), (5, 15), (5, 16), (6, 17), (7, 18), (8, 18), (9, 19), (10, 20), (11, 21), (12, 21)]

Event Timeline