diff --git a/src/dataset_generator.py b/src/dataset_generator.py
index efc653b..86e4919 100644
--- a/src/dataset_generator.py
+++ b/src/dataset_generator.py
@@ -1,118 +1,142 @@
import os
from typing import List

import pandas as pd
import numpy as np
import jax.numpy as jnp

DATA_PATH = "../ecg_syn/ecgsyn.dat"
SAVE_PATH = "../dataset/beats.npy"
OPT_FILE = "../ecg_syn/ecgsyn.opt"
HR = 60
T_SPAN_RANDOM_SIGNAL_SECONDS = 2
+MEM_FOR_ECGSYN = 32e9  # assumed memory budget (bytes) for a single ECGSYN run
+BYTES_ECGSYN_POINT = 1001  # assumed bytes ECGSYN needs per generated point

def run_ECGSYN(data_path, freq, num_samples):
    dt = 1/freq
    if os.path.isfile(OPT_FILE):
        os.remove(OPT_FILE)
    # num_samples+2 because the first and last heartbeats may be cut off
    command = f'cd ../ecg_syn/ ; ./ecgsyn -n {num_samples+2} -s {freq} -S {freq} -h {HR} %%'
    os.system(command)
    data = pd.read_csv(data_path, delimiter=" ", header=None)
    return data

def separate_beats(vs: np.ndarray, ms: List) -> List[np.ndarray]:
    # Cut the signal into beats: ECGSYN annotates peaks 1-5 (P,Q,R,S,T),
    # and each beat is split at the signal minimum between a T peak
    # (mark 5) and the following P peak (mark 1).
    out: List[np.ndarray] = []
    min_value_idx: int = 0
    min_value_idx_old: int = 0
    min_value: float = np.inf
    in_t_p: bool = False
    for i, (v, m) in enumerate(zip(vs, ms)):
        if m == 5:
            in_t_p = True
        if m == 1:
            in_t_p = False
            out.append(vs[min_value_idx_old:min_value_idx])
            min_value_idx_old = min_value_idx
            min_value = np.inf
        if in_t_p:
            if v < min_value:
                min_value = v
                min_value_idx = i
    return out

def normalize_length(windows: List[np.ndarray]) -> List[np.ndarray]:
    # Trim every beat to the length of the shortest one and shift it so it starts at zero
    out: List[np.ndarray] = []
    small_len = min([len(w) for w in windows])
    for w in windows:
        len_diff = len(w)-small_len
        out.append(w[len_diff:]-min(w[len_diff:]))
    return out

def load_signal(num_pts, freq):
    tot_num_pts = 0
    freq_this_file = 0
    dataset = None
    if os.path.isfile(OPT_FILE):
        with open(OPT_FILE) as f:
            for l in f.readlines():
                if "-s" in l:
                    freq_this_file = int(l[3:12])
                if "-n" in l:
                    tot_num_pts = int(l[3:12])
        if tot_num_pts >= num_pts and freq == freq_this_file:
            dataset = np.load(SAVE_PATH)[:num_pts]
+            if tot_num_pts > num_pts:
+                print(f"Loaded {num_pts} points, the dataset contains {tot_num_pts} points")
        else:
            print(f"Present dataset does not match the given parameters (f: {freq_this_file}, pts: {tot_num_pts})")
    else:
        print("No signal to load / missing config file")
    return dataset

def create_random_signal(coefs, ws, dt):
    t = np.linspace(0, T_SPAN_RANDOM_SIGNAL_SECONDS, int(T_SPAN_RANDOM_SIGNAL_SECONDS/dt))
    x = np.zeros(len(t))
    for w, c in zip(ws, coefs):
        x += c*np.sin(w*t)
    return x

def create_positive_random_dataset(num_pts, freq):
    out = []
    rng = np.random.default_rng(31415926514)
    # Draw from freq*10 integers to get a large enough integer search space
    # for rng.choice, then divide by 100 so the maximum frequency is 1/10 of
    # the sampling frequency.
    ws = rng.choice(int(freq*10), size=100, replace=False)/100
    dt = 1/freq
    coefs = rng.choice(3000, size=100)
    for _ in range(num_pts):
        x = create_random_signal(coefs, ws, dt)
        x -= min(x)-0.01  # shift so the minimum sits at +0.01, keeping the signal strictly positive
        out.append(x)
    out = np.array(out)
    return out
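+# A worked example of the chunking budget used below (assuming MEM_FOR_ECGSYN
+# is the byte budget per ECGSYN run and BYTES_ECGSYN_POINT the bytes ECGSYN
+# needs per generated point): at freq = 1_000_000 Hz and HR = 60 bpm,
+#   int(32e9/(1_000_000*1001)*60/60) = 31 beats per run,
+# so generating, say, 50 beats takes one full run of 31 beats plus a tail
+# run of 19 beats.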
def create_ECG_emulated_dataset(num_pts, freq):
-    data = run_ECGSYN(data_path=DATA_PATH,freq=freq, num_samples=num_pts)
-    v = data[1].to_numpy()
-    marks = data[2].to_list()
-    windows = separate_beats(v,marks)[:num_pts]
+    windows = []
+    max_num_beat_this_freq = int(MEM_FOR_ECGSYN/(freq*BYTES_ECGSYN_POINT)*60/HR)
+    print(f"Maximum number of beats at this freq (per ECGSYN run): {max_num_beat_this_freq}")
+    print(f"Beats desired: {num_pts}")
+    print(f"Running ECGSYN {num_pts//max_num_beat_this_freq} times")
+    for i in range(num_pts//max_num_beat_this_freq):
+        print("\n###########################")
+        print(f"#Generating {i+1}/{num_pts//max_num_beat_this_freq} datasets")
+        print("###########################\n")
+        data = run_ECGSYN(data_path=DATA_PATH, freq=freq, num_samples=max_num_beat_this_freq)
+        v = data[1].to_numpy()
+        marks = data[2].to_list()
+        windows.extend(separate_beats(v, marks)[:max_num_beat_this_freq])
+    # Tail: beats left over after the full-size runs
+    if num_pts % max_num_beat_this_freq != 0:
+        num_beats_remaining = num_pts % max_num_beat_this_freq
+        print("\n###########################")
+        print(f"#Running ECGSYN tail: {num_beats_remaining} beats")
+        print("###########################\n")
+        data = run_ECGSYN(data_path=DATA_PATH, freq=freq, num_samples=num_beats_remaining)
+        v = data[1].to_numpy()
+        marks = data[2].to_list()
+        windows.extend(separate_beats(v, marks)[:num_beats_remaining])
    windows_length_norm = normalize_length(windows)
    dataset_np = np.array(windows_length_norm)
    if not os.path.isdir("../dataset"):
        os.mkdir("../dataset")
    np.save(SAVE_PATH, dataset_np)
    return dataset_np

def get_signal(type='load', num_pts=1000, freq=256):
    dataset = None
    if type == 'random':
        dataset = create_positive_random_dataset(num_pts, freq)
    elif type == 'load':
        dataset = load_signal(num_pts, freq)
    elif type == 'create':
        dataset = create_ECG_emulated_dataset(num_pts, freq)
    else:
        print("Dataset type not recognized in 'get_signal()'")
    if dataset is None:
        return None
    return jnp.array(dataset)

def main() -> None:
    pass

if __name__ == "__main__":
    main()
\ No newline at end of file
diff --git a/src/gaussain_collpsing.py b/src/gaussain_collpsing.py
index 5d96743..bb72a62 100644
--- a/src/gaussain_collpsing.py
+++ b/src/gaussain_collpsing.py
@@ -1,265 +1,264 @@
from functools import partial
import os

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import grad, jit, pmap, value_and_grad, vmap
from jax.scipy.special import logsumexp
from torch.utils.data import DataLoader
import optax

import dataset_generator as dataLoader
import processing as proc

-FREQ = 150_000
-NUM_SAMPLES = 250
-BATCH_SIZE = 50
+FREQ = 1_000_000
+NUM_SAMPLES = 50
+BATCH_SIZE = 10
USE_ENERGY_MAX_FREQ_EVAL = True
-STOPPING_PERC = 3_000
+STOPPING_PERC = 4_000
INTERPOLATOR = 'sinc'
SIGNAL_TYPE = 'create'  # 'create': create a new emulated ECG signal, 'load': load a previous ECG signal, 'random': random sinusoidal mix signal
EPOCHS = 20
SIGMA_EPOCH = 50
-SIGMA_INIT_DIVISOR = 5
-SIGMA_EPOCH_DIVISOR = 1.2
+SIGMA_EPOCH_DIVISOR = 1.1
INIT_FREQ_DIVISOR = 4
-NUM_LVLS = 8
-LR = 0.01
+NUM_LVLS = 16
+LR = 0.005
DEV_MODE = False

imag_res_folder = "../img_res"

def init_pipeline(gaussian_numbers, range_min, range_max, freq):
    # Spread the level means evenly over the signal range, inset from both ends
-    mu_s = jnp.linspace(range_min+(range_max-range_min)/3,range_max-(range_max-range_min)/3,gaussian_numbers)
+    mu_s = jnp.linspace(range_min+(range_max-range_min)/NUM_LVLS, range_max-(range_max-range_min)/NUM_LVLS, gaussian_numbers)
    params = {'mus': mu_s, 'freq': freq}
    return params
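+# Example of the spacing above (values from the formula, not from a run):
+# with gaussian_numbers = NUM_LVLS = 16 and a signal normalized to [0, 1],
+# init_pipeline places 16 evenly spaced level means from 1/16 = 0.0625 up
+# to 15/16 = 0.9375.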
def proc_pipeline(params, x, sigma, static_params):
    # Gaussian level-crossing encoding -> normalize -> sinc reconstruction -> normalize
    signal = x
    signal = proc.mix_gaussian_lvl_crossing(signal, params['mus'], sigma)
    signal = proc.normalize(signal)
    signal = proc.sinc_interpolation_freq_parametrized(signal, params['freq'], FREQ, static_params['time_base'])
    signal = proc.normalize(signal)
    return signal

batched_proc_pipeline = pmap(proc_pipeline, in_axes=(None, 0, None, None), static_broadcasted_argnums=[3])
batched_RMSE = pmap(proc.RMSE)

def batched_loss_fn(params, data, nyquist_sampled_data, sigma, static_params):
    proc_results = batched_proc_pipeline(params, data, sigma, static_params)
    return jnp.mean(batched_RMSE(proc_results, nyquist_sampled_data))

def compute_loss(dataset, params, sigma, static_params):
    loss = 0
    i = 0
    for (batch, objective_batch) in dataset:
        i += 1
        loss += batched_loss_fn(params, batch, objective_batch, sigma, static_params)
    return loss/i

def train(train_dataset, params, sigma, static_params, lr=LR):
    print(f"Loss: {compute_loss(train_dataset,params,sigma,static_params)}")
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)
    for e in range(EPOCHS):
        print(f"Epoch {e+1}")
        for (train_batch, objective_batch) in train_dataset:
            grads = grad(batched_loss_fn)(params, train_batch, objective_batch, sigma, static_params)
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
        print(f"Loss: {compute_loss(train_dataset,params,sigma,static_params)}")
    return params

# Final loss, per sample
def compute_loss_each_samples(dataset, params, sigma, static_params):
    loss = []
    for (batch, objective_batch) in dataset:
        proc_results = batched_proc_pipeline(params, batch, sigma, static_params)
        loss.extend(batched_RMSE(proc_results, objective_batch))
    return loss

# Final transform, per sample
def compute_transform_each_samples(train_dataset, params, sigma, static_params):
    proc_results = []
    for (train_batch, _) in train_dataset:
        proc_results.extend(batched_proc_pipeline(params, train_batch, sigma, static_params))
    return proc_results

def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))
    data = np.array(transposed_data[0])
    obj = np.array(transposed_data[1])
    return data, obj

class hashabledict(dict):
    def __hash__(self):
        return hash(tuple(sorted(self.items())))

class hashable_np_array(np.ndarray):
    def __hash__(self):
        return int(self.mean()*1_000_000_000)
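+# Note on the two wrappers above: pmap treats static_broadcasted_argnums
+# arguments as compile-time constants, and JAX requires those arguments to
+# be hashable, which plain dicts and ndarrays are not. hashable_np_array
+# hashes only the (scaled) array mean, so two different time bases with
+# equal means would collide; that risk is assumed acceptable here.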
freq: ")) data = dataLoader.get_signal(SIGNAL_TYPE,num_pts=1,freq=freq_desired) print("Computing Nyquist frequency...") nyq_freq = proc.get_nyquist_freq_dataset(data,freq_desired,USE_ENERGY_MAX_FREQ_EVAL,STOPPING_PERC) t_base_nyq = np.arange(0,len(data[0])/freq_desired,1/nyq_freq) t_base_orig = np.arange(0,(len(data[0]))/freq_desired,1/freq_desired) print(f"Nyquist frequency: {nyq_freq}") print("Generating Nyquist sampled objective") dataset_nyq = proc.interpolate_dataset(data, freq_desired, nyq_freq) print("Initializing pipelines Parameters") sigma = float((jnp.max(data)-jnp.min(data))/sigma_div) params = init_pipeline(mus_num,np.min(data),np.max(data),nyq_freq/nyq_freq_div) print(f"Parameters: {params}\nSigma: {sigma}") print("Computing mixed gaussian representation") gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[0],params['mus'],sigma) print("Normalizing") gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data) print("Computing sync filter") resampled_gaussians = proc.sinc_interpolation_freq_parametrized(gaussian_lvl_crossing_data,params['freq'],freq_desired,t_base_orig) print("Normalizing") resampled_gaussians = proc.normalize(resampled_gaussians) print("Plotting") plt.figure() plt.plot(t_base_orig, data[0]) plt.plot(t_base_nyq,dataset_nyq[0],"d") plt.plot(t_base_orig,gaussian_lvl_crossing_data) plt.plot(t_base_orig,resampled_gaussians[:len(t_base_orig)],"-") plt.hlines(params['mus'],t_base_orig[0],t_base_orig[-1]) plt.show() return #Setting env: if not os.path.isdir(imag_res_folder): os.mkdir(imag_res_folder) t_stamp = str(os.times().elapsed) res_folder_this_run = os.path.join(imag_res_folder,t_stamp) os.mkdir(res_folder_this_run) os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={BATCH_SIZE}' #load/generate data print("Loading data") data = dataLoader.get_signal(SIGNAL_TYPE,num_pts=NUM_SAMPLES,freq=FREQ) if data is None: print("Failed loading data") return data = proc.normalize_dataset(data) print("\n-------------------\n") print(f"Loaded Data:shape {data.shape}") #reference print("Computing maximum Nyquist frequency for the Dataset...") nyq_freq = proc.get_nyquist_freq_dataset(data,FREQ,USE_ENERGY_MAX_FREQ_EVAL,STOPPING_PERC) t_base_nyq = np.arange(0,len(data[0])/FREQ,1/nyq_freq) t_base_orig = np.arange(0,(len(data[0]))/FREQ,1/FREQ) print(f"Nyquist frequency: {nyq_freq}") print("Generating Nyquist sampled objective dataset") dataset_nyq = proc.interpolate_dataset(data, FREQ, nyq_freq) train_dataset = [[d,o] for d,o in zip(data,dataset_nyq)] static_params = hashabledict({'freq' : None,'time_base' : None}) static_params['freq'] = nyq_freq static_params['time_base'] = t_base_nyq.view(hashable_np_array) #static_params = t_base_nyq.view(hashable_np_array) #INIT print("Initializing pipelines Parameters") - sigma = float((jnp.max(data)-jnp.min(data))/SIGMA_INIT_DIVISOR) + sigma = float((jnp.max(data)-jnp.min(data))/(2.2*NUM_LVLS)) params = init_pipeline(NUM_LVLS,np.min(data),np.max(data),nyq_freq/INIT_FREQ_DIVISOR) train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn, drop_last=False) print(f"Parameters: {params}") loss_sigmas = [] lr = LR for i in range(SIGMA_EPOCH): print(f"SIGMA EPOCH: {i+1}/{SIGMA_EPOCH}, sigma: {sigma}") params = train(train_loader,params,sigma,static_params,lr) loss = compute_loss(train_loader,params,sigma,static_params) loss_sigmas.append(loss) losses = compute_loss_each_samples(train_loader,params,sigma,static_params) best_beat = np.argmin(losses) worst_beat = 
    for i in range(SIGMA_EPOCH):
        print(f"SIGMA EPOCH: {i+1}/{SIGMA_EPOCH}, sigma: {sigma}")
        params = train(train_loader, params, sigma, static_params, lr)
        loss = compute_loss(train_loader, params, sigma, static_params)
        loss_sigmas.append(loss)
        losses = compute_loss_each_samples(train_loader, params, sigma, static_params)
        best_beat = np.argmin(losses)
        worst_beat = np.argmax(losses)
        # Plot the best-reconstructed beat
        gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[best_beat], params['mus'], sigma)
        gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
        resampled_gaussians = proc.sinc_interpolation_freq_parametrized(gaussian_lvl_crossing_data, params['freq'], FREQ, t_base_nyq)
        resampled_gaussians = proc.normalize(resampled_gaussians)
        plt.figure()
        plt.plot(t_base_orig, data[best_beat])
        plt.plot(t_base_nyq, dataset_nyq[best_beat], "d")
        plt.plot(t_base_orig, gaussian_lvl_crossing_data)
        plt.plot(t_base_nyq, resampled_gaussians[:len(t_base_nyq)], "o-")
        plt.hlines(params['mus'], t_base_orig[0], t_base_orig[-1])
        plt.savefig(f'{res_folder_this_run}/sigma:{sigma}_best_loss:{np.min(losses)}.svg')
        plt.close()
        # Plot the worst-reconstructed beat
        gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[worst_beat], params['mus'], sigma)
        gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
        resampled_gaussians = proc.sinc_interpolation_freq_parametrized(gaussian_lvl_crossing_data, params['freq'], FREQ, t_base_nyq)
        resampled_gaussians = proc.normalize(resampled_gaussians)
        plt.figure()
        plt.plot(t_base_orig, data[worst_beat])
        plt.plot(t_base_nyq, dataset_nyq[worst_beat], "d")
        plt.plot(t_base_orig, gaussian_lvl_crossing_data)
        plt.plot(t_base_nyq, resampled_gaussians[:len(t_base_nyq)], "o-")
        plt.hlines(params['mus'], t_base_orig[0], t_base_orig[-1])
        plt.savefig(f'{res_folder_this_run}/sigma:{sigma}_worst_loss:{np.max(losses)}.svg')
        plt.close()
        sigma /= SIGMA_EPOCH_DIVISOR
        lr -= lr/20
        print(f"END OF SIGMA EPOCH {i+1}, LOSS = {loss}")
    print(params)
    plt.figure()
    plt.plot(loss_sigmas)
    plt.savefig(f'{res_folder_this_run}/lossVSepoch.svg')
    plt.close()

    # Create the Gaussian level crossing for one sample beat
    pt = 3
    mu_s = params['mus']
    gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[pt], mu_s, sigma)
    gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
    resampled_gaussians = proc.sinc_interpolation_freq_parametrized(gaussian_lvl_crossing_data, params['freq'], FREQ, t_base_orig)
    resampled_gaussians = proc.normalize(resampled_gaussians)
    # Testing grounds
    plt.plot(t_base_orig, data[pt])
    plt.plot(t_base_nyq, dataset_nyq[pt], "d")
    plt.plot(t_base_orig, gaussian_lvl_crossing_data)
    plt.plot(t_base_orig, resampled_gaussians[:len(t_base_orig)], "-")
    plt.show()

if __name__ == "__main__":
    #test_signal()
    main()