"""Check DTW alignment of an event-based (level-crossing) sampled ECG beat.

Aligns each level-crossing representation of a beat against its uniformly
sampled counterpart (level key ``0``) using Dynamic Time Warping, half a
beat at a time (before/after the split instant), then stitches the warped
halves back together for point-by-point comparison.

Expected pickle layout of one file in ``data_beats_dir``::

    {beat_index: {lvl: {"t": [...], "v": [...]}}}

where ``lvl == 0`` is the uniformly sampled reference signal.
Run as a script: ``python3 DTW_allignment.py [--file NAME] [--beat N]``.
"""
import os
import pickle

data_beats_dir = "../data/beats/"
BEAT_NUM = 156  # arbitrary beat index kept for ad-hoc experiments


def open_file(file_name):
    """Load one pickled beats file from ``data_beats_dir``.

    Only the basename of ``file_name`` is used, so callers may pass either a
    bare name or a path.  Returns the beat dictionary described in the
    module docstring.
    """
    file_name_full = os.path.join(data_beats_dir, os.path.basename(file_name))
    with open(file_name_full, "rb") as f:
        return pickle.load(f)


def split_beat(data, split_time):
    """Split every level of one beat at ``split_time``.

    ``data`` maps lvl -> {"t": [...], "v": [...]}.  Returns
    lvl -> {"t0", "v0", "t1", "v1"}; the sample at the split index is
    deliberately shared by both halves so they can be re-joined after
    warping (see DTW_match_beat).
    """
    data_splitted = {}
    for lvl, series in data.items():
        # Start the index search one third into the record: a cheap prior
        # for where the split instant usually falls.
        split_idx = find_time_index_in_eb(
            series["t"], split_time, prior=len(series["t"]) // 3
        )
        data_splitted[lvl] = {
            "t0": series["t"][: split_idx + 1],
            "t1": series["t"][split_idx:],
            "v0": series["v"][: split_idx + 1],
            "v1": series["v"][split_idx:],
        }
    return data_splitted


def find_time_index_in_eb(data_time, time, prior=0, mode="before"):
    """Locate ``time`` inside the sorted timestamp sequence ``data_time``.

    The scan starts at index ``prior`` (a hint) and wraps around to the
    beginning if the target lies before the hint.  With ``mode="before"``
    the index at or immediately before ``time`` is returned; with
    ``mode="after"`` the index immediately after it.  Out-of-range times
    clamp to the first/last index.
    """
    if time <= data_time[0]:
        return 0
    if time >= data_time[-1]:
        return len(data_time) - 1

    idx = None
    for this_idx in range(prior, len(data_time) - 1):
        if data_time[this_idx] <= time < data_time[this_idx + 1]:
            if time == data_time[this_idx] or mode == "before":
                idx = this_idx
            elif mode == "after":
                idx = this_idx + 1
            break
    if idx is None:
        # The hint overshot the target: retry from the start of the record.
        for this_idx in range(0, prior):
            if data_time[this_idx] <= time < data_time[this_idx + 1]:
                if time == data_time[this_idx] or mode == "before":
                    idx = this_idx
                elif mode == "after":
                    idx = this_idx + 1
                break
    return idx


def DTW_Matching(v_uniform, v_eb):
    """Expand two value sequences along their optimal DTW warping path.

    Returns ``(v_uniform_warped, v_eb_warped)``: both sequences replicated
    point-by-point along the path so they have equal length and can be
    compared sample-for-sample.
    """
    # Imported lazily so the module stays importable without dtaidistance
    # (only this function needs it).
    from dtaidistance import dtw

    path = dtw.warping_path(v_uniform, v_eb)
    v_uniform_warped = [v_uniform[i] for i, _ in path]
    v_eb_warped = [v_eb[j] for _, j in path]
    return v_uniform_warped, v_eb_warped


def DTW_match_beat(data, beat_time):
    """DTW-align every non-uniform level of one beat against level ``0``.

    Each level is split at ``beat_time``, the two halves are warped
    independently against the matching uniform halves, then stitched back
    together.  Because split_beat duplicates the split sample in both
    halves, the last sample of the first warped half is dropped before
    joining.

    Returns ``{"uniform_original": {"t_uniform", "v_uniform"},
               "warped": {lvl: {"v_uniform_warped", "v_eb_warped"}}}``.
    """
    data_uniform = data[0]
    data_splitted = split_beat(data, beat_time)
    data_uniform_splitted = data_splitted[0]
    eb_levels = [lvl for lvl in data_splitted if lvl != 0]

    data_matched = {
        "uniform_original": {
            "t_uniform": data_uniform["t"],
            "v_uniform": data_uniform["v"],
        },
        "warped": {},
    }
    v_uniform_0 = data_uniform_splitted["v0"]
    v_uniform_1 = data_uniform_splitted["v1"]
    for lvl in eb_levels:
        v_uni_0_w, v_eb_0_w = DTW_Matching(v_uniform_0, data_splitted[lvl]["v0"])
        v_uni_1_w, v_eb_1_w = DTW_Matching(v_uniform_1, data_splitted[lvl]["v1"])
        # Drop the duplicated split sample once when re-joining the halves.
        data_matched["warped"][lvl] = {
            "v_uniform_warped": v_uni_0_w[:-1] + v_uni_1_w,
            "v_eb_warped": v_eb_0_w[:-1] + v_eb_1_w,
        }
    return data_matched


if __name__ == "__main__":
    import argparse
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()
    parser.add_argument("--file", help="Force to analyze one specific file instead of default one (first found)")
    parser.add_argument("--beat", help="Force to analyze one specific beat instead of default one (first found)")
    args = parser.parse_args()

    files = os.listdir(data_beats_dir)
    if args.file is not None:
        analyzed = [name for name in files if args.file in name][0]
    else:
        analyzed = files[0]
    beat_analyzed = int(args.beat) if args.beat is not None else 0

    data = open_file(analyzed)
    beat_key = list(data.keys())[beat_analyzed]
    data_DTW = DTW_match_beat(data[beat_key], beat_key)

    plt.figure()
    plt.plot(data_DTW["uniform_original"]["t_uniform"],
             data_DTW["uniform_original"]["v_uniform"])
    plt.title("ORIGINAL")
    for lvl in data_DTW["warped"]:
        plt.figure()
        plt.plot(data_DTW["warped"][lvl]["v_uniform_warped"])
        plt.plot(data_DTW["warped"][lvl]["v_eb_warped"])
        plt.title(f"Warped against {lvl} bits sampled signal")
        plt.legend(["uniform", "level crossing"])
    plt.show()