diff --git a/src/dataset_generator.py b/src/dataset_generator.py
index 173467d..112e21e 100644
--- a/src/dataset_generator.py
+++ b/src/dataset_generator.py
@@ -1,144 +1,146 @@
 import os
 from typing import List
 
 import pandas as pd
 import numpy as np
 import jax.numpy as jnp
 
 DATA_PATH = "../ecg_syn/ecgsyn.dat"
 SAVE_PATH = "../dataset/beats.npy"
 OPT_FILE = "../ecg_syn/ecgsyn.opt"
 HR = 60
 T_SPAN_RANDOM_SIGNAL_SECONDS = 2
 MEM_FOR_ECGSYN = 32e9
 BYTES_ECGSYN_BOINT = 1001
 
 def run_ECGSYN(data_path,freq,num_samples):
     dt = 1/freq
     if os.path.isfile(OPT_FILE):
         os.remove(OPT_FILE)
     command = f'cd ../ecg_syn/ ; ./ecgsyn -n {num_samples+2} -s {freq} -S {freq} -h {HR} %%'  # num_samples+2 since the first and last heartbeats may be corrupted
     os.system(command)
     data = pd.read_csv(data_path,delimiter=" ",header=None)
     return data
 
 def separate_beats(vs: np.ndarray, ms: List) -> List[np.ndarray]:
     # ms carries ecgsyn's peak annotations (1 = P peak, 5 = T peak); each beat
     # boundary is placed at the signal minimum inside the T-P segment.
     out: List[np.ndarray] = []
     min_value_idx: int = 0
     min_value_idx_old: int = 0
     min_value: float = np.inf
     in_t_p: bool = False
     for i,(v,m) in enumerate(zip(vs,ms)):
         if m == 5:
             in_t_p = True
         if m == 1:
             in_t_p = False
             out.append(vs[min_value_idx_old:min_value_idx])
             min_value_idx_old = min_value_idx
             min_value = np.inf
         if in_t_p:
             if v < min_value:
                 min_value = v
                 min_value_idx = i
     return out
 
 def trim_windows(windows: List[np.ndarray]) -> List[np.ndarray]:
     out: List[np.ndarray] = []
     small_len = min([len(w) for w in windows])
+    if small_len % 2 != 0: #ENSURE EVEN LENGTH WINDOWS
+        small_len -= 1
     for w in windows:
         len_diff = len(w)-small_len
-        out.append(w[len_diff:]-min(w[len_diff:]))
+        out.append(w[len_diff:])
     return out
 
 def load_signal(num_pts,freq):
     tot_num_pts = 0
     freq_this_file = 0
     dataset = None
     if os.path.isfile(OPT_FILE):
         with open(OPT_FILE) as f:
             for l in f.readlines():
                 if "-s" in l:
                     freq_this_file = int(l[3:12])
                 if "-n" in l:
                     tot_num_pts = int(l[3:12])
         if freq == freq_this_file:
             dataset = np.load(SAVE_PATH)[:num_pts]
             if tot_num_pts > num_pts:
                 print(f"Loaded {num_pts} points, the dataset contains {tot_num_pts} points")
             elif tot_num_pts < num_pts:
                 ...
 
 def main() -> None:
     pass
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
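Review note on the even-length trim added above: it is load-bearing for the rest of this diff. fourier_filtering in processing.py (further down) builds its filter as concatenate((coeffs, flip(coeffs))) with len(coeffs) == signal_len//2, so the filter has length 2*(signal_len//2) and only matches the FFT length when the signal length is even. A standalone sketch of the mismatch (numpy only; the helper name is illustrative):

    import numpy as np

    def mirrored_filter(n: int) -> np.ndarray:
        # Same construction as fourier_filtering: n//2 coefficients, mirrored.
        coeffs = np.ones(n // 2)
        return np.concatenate((coeffs, coeffs[::-1]))

    for n in (10, 11):
        spectrum = np.fft.fft(np.zeros(n))
        filt = mirrored_filter(n)
        # For odd n the filter is one element short, so spectrum * filt would fail.
        print(n, spectrum.shape, filt.shape)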
diff --git a/src/gaussain_collpsing_freq.py b/src/gaussain_collpsing_freq.py
new file mode 100644
index 0000000..47ddd2c
--- /dev/null
+++ b/src/gaussain_collpsing_freq.py
@@ -0,0 +1,270 @@
+from functools import partial
+import os
+
+
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+from jax import grad, jit, pmap, value_and_grad, vmap
+from jax.scipy.special import logsumexp
+from torch.utils.data import DataLoader
+import optax
+
+import dataset_generator as dataLoader
+import processing as proc
+
+FREQ = 100_000
+NUM_SAMPLES = 200
+BATCH_SIZE = 20
+USE_ENERGY_MAX_FREQ_EVAL = True
+STOPPING_PERC = 2_000
+INTERPOLATOR = 'sinc'
+SIGNAL_TYPE = 'create' #'create': create a new emulated ECG signal, 'load': load a previous ECG signal, 'random': random sinusoidal mix signal
+EPOCHS = 40
+SIGMA_EPOCH = 20
+SIGMA_EPOCH_DIVISOR = 1.2
+INIT_FREQ_DIVISOR = 10
+NUM_LVLS = 16
+LR = 0.005
+
+
+DEV_MODE = False
+imag_res_folder = "../img_res"
+
+
+def init_pipeline(gaussian_numbers,signal_len,range_min,range_max):
+    mu_s = jnp.linspace(range_min+(range_max-range_min)/gaussian_numbers,range_max-(range_max-range_min)/gaussian_numbers,gaussian_numbers)
+    freq_coeffs = jnp.array([1.]*(signal_len//2)) #The signal length must always be even!
+    params = {'mus':mu_s,'freq_coeffs':freq_coeffs}
+    return params
+
+
+def proc_pipeline(params,x,sigma,static_params):
+    signal = x
+    signal = proc.mix_gaussian_lvl_crossing(signal,params['mus'],sigma)
+    signal = proc.normalize(signal)
+    signal = proc.fourier_filtering(signal,params['freq_coeffs'])
+    #TO CHECK
+    signal = proc.sinc_interpolation_freq_parametrized(signal,FREQ,FREQ,static_params['time_base'])
+    signal = proc.normalize(signal)
+    return signal
+
+
+batched_proc_pipeline = vmap(proc_pipeline,in_axes=(None,0,None,None))
+batched_RMSE = vmap(proc.RMSE)
+
+@partial(jit, static_argnums=(4,))
+def batched_loss_fn(params, data, nyquist_sampled_data, sigma, static_params):
+    proc_results = batched_proc_pipeline(params, data, sigma, static_params)
+    return jnp.mean(batched_RMSE(proc_results,nyquist_sampled_data))
+'''
+batched_proc_pipeline = pmap(proc_pipeline,in_axes=(None,0,None,None), static_broadcasted_argnums=[3])
+batched_RMSE = pmap(proc.RMSE)
+
+def batched_loss_fn(params, data, nyquist_sampled_data, sigma, static_params):
+    proc_results = batched_proc_pipeline(params, data, sigma, static_params)
+    return jnp.mean(batched_RMSE(proc_results,nyquist_sampled_data))
+'''
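+
+# Note on the jit above: arguments marked static via static_argnums must be
+# hashable, because JAX uses them as part of the compilation-cache key. Plain
+# dicts and ndarrays are not hashable, which is what the hashabledict and
+# hashable_np_array wrappers below work around. A minimal sketch of the
+# pattern (names here are illustrative only):
+#
+#   @partial(jit, static_argnums=(1,))
+#   def scale(x, cfg):  # passing a plain dict as cfg raises a TypeError
+#       return x * cfg['gain']
+#
+#   scale(jnp.ones(3), hashabledict({'gain': 2.0}))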
+
+def compute_dataset_loss(dataset,params,sigma,static_params):
+    loss = 0
+    i = 0
+    for (batch,objective_batch) in dataset:
+        i += 1
+        loss += batched_loss_fn(params, batch, objective_batch, sigma, static_params)
+    return loss/i
+
+def train(train_dataset,params,sigma,static_params, lr=LR):
+    print(f"Loss: {compute_dataset_loss(train_dataset,params,sigma,static_params)}")
+    optimizer = optax.adam(lr)
+    opt_state = optimizer.init(params)
+    for e in range(EPOCHS):
+        print(f"Epoch {e+1}")
+        for (train_batch,objective_batch) in train_dataset:
+            grads = grad(batched_loss_fn)(params, train_batch, objective_batch, sigma, static_params)
+            updates, opt_state = optimizer.update(grads, opt_state)
+            params = optax.apply_updates(params, updates)
+        print(f"Loss: {compute_dataset_loss(train_dataset,params,sigma,static_params)}")
+    return params
+
+def compute_loss_each_samples(dataset,params,sigma,static_params):
+    loss = []
+    for (batch,objective_batch) in dataset:
+        proc_results = batched_proc_pipeline(params, batch, sigma, static_params)
+        loss.extend(batched_RMSE(proc_results, objective_batch))
+    return loss
+
+def compute_transform_each_samples(train_dataset,params,sigma,static_params, lr=LR):
+    proc_results = []
+    for (train_batch,_) in train_dataset:
+        proc_results.extend(batched_proc_pipeline(params, train_batch, sigma, static_params))
+    return proc_results
+
+
+
+def custom_collate_fn(batch):
+    transposed_data = list(zip(*batch))
+    data = np.array(transposed_data[0])
+    obj = np.array(transposed_data[1])
+    return data, obj
+
+class hashabledict(dict):
+    def __hash__(self):
+        return hash(tuple(sorted(self.items())))
+
+class hashable_np_array(np.ndarray):
+    def __hash__(self):
+        return int(self.mean()*1_000_000_000)
+
+
+
+def main():
+    if DEV_MODE:
+        freq_desired = int(input("What base frequency do you need your data: "))
+        mus_num = int(input("How many levels: "))
+        sigma_div = int(input("Sigma fraction of the signal range (1/X): "))
+
+        data = dataLoader.get_signal(SIGNAL_TYPE,num_pts=1,freq=freq_desired)
+        print("Computing Nyquist frequency...")
+        nyq_freq = proc.get_nyquist_freq_dataset(data,freq_desired,USE_ENERGY_MAX_FREQ_EVAL,STOPPING_PERC)
+        t_base_nyq = np.arange(0,len(data[0])/freq_desired,1/nyq_freq)
+        t_base_orig = np.arange(0,(len(data[0]))/freq_desired,1/freq_desired)
+        print(f"Nyquist frequency: {nyq_freq}")
+
+        print("Generating Nyquist sampled objective")
+        dataset_nyq = proc.interpolate_dataset(data, freq_desired, nyq_freq)
+        print("Initializing pipeline parameters")
+        sigma = float((jnp.max(data)-jnp.min(data))/sigma_div)
+        params = init_pipeline(mus_num,len(data[0]),np.min(data),np.max(data))
+        print(f"Parameters: {params}\nSigma: {sigma}")
+
+        print("Computing mixed gaussian representation")
+        gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[0],params['mus'],sigma)
+        print("Normalizing")
+        gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
+        print("Computing filter")
+        filtered = proc.fourier_filtering(gaussian_lvl_crossing_data,params['freq_coeffs'])
+        #TO CHECK SINC RESAMPLING
+        resampled = proc.sinc_interpolation_freq_parametrized(filtered,freq_desired,freq_desired,t_base_orig)
+        print("Normalizing")
+        normed = proc.normalize(resampled)
+
+        print("Plotting")
+        plt.figure()
+        plt.plot(t_base_orig, data[0])
+        plt.plot(t_base_nyq,dataset_nyq[0],"d")
+        plt.plot(t_base_orig,gaussian_lvl_crossing_data)
+        plt.plot(t_base_orig,normed[:len(t_base_orig)],"-")
+        plt.hlines(params['mus'],t_base_orig[0],t_base_orig[-1])
+        plt.show()
+        return
+
+    #Setting env:
+    if not os.path.isdir(imag_res_folder):
+        os.mkdir(imag_res_folder)
+    t_stamp = str(os.times().elapsed)
+    res_folder_this_run = os.path.join(imag_res_folder,t_stamp)
+    os.mkdir(res_folder_this_run)
+    os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={BATCH_SIZE}'
+
+    #load/generate data
+    print("Loading data")
+    data = dataLoader.get_signal(SIGNAL_TYPE,num_pts=NUM_SAMPLES,freq=FREQ)
+    if data is None:
+        print("Failed loading data")
+        return
+    data = proc.normalize_dataset(data)
+    t_base_orig = np.arange(0,(len(data[0]))/FREQ,1/FREQ)[:len(data[0])]
+    print("\n-------------------\n")
+    print(f"Loaded data: shape {data.shape}")
+
+    #reference
+    print("Computing maximum Nyquist frequency for the dataset...")
+    nyq_freq = proc.get_nyquist_freq_dataset(data,FREQ,USE_ENERGY_MAX_FREQ_EVAL,STOPPING_PERC)
+    print(f"Nyquist frequency: {nyq_freq}")
+
+    print("Generating Nyquist sampled objective dataset")
+    dataset_nyq = proc.interpolate_dataset(data, FREQ, nyq_freq)
+    t_base_nyq = np.arange(0,len(dataset_nyq[0])/nyq_freq,1/nyq_freq)[:len(dataset_nyq[0])]
+
+    train_dataset = [[d,o] for d,o in zip(data,dataset_nyq)]
+
+    static_params = hashabledict({'freq' : None,'time_base' : None})
+    static_params['freq'] = nyq_freq
+    static_params['time_base'] = t_base_nyq.view(hashable_np_array)
+    #static_params = t_base_nyq.view(hashable_np_array)
+
+    #INIT
+    print("Initializing pipeline parameters")
+    sigma = float((jnp.max(data)-jnp.min(data))/(NUM_LVLS))
+    params = init_pipeline(NUM_LVLS,len(data[0]),np.min(data),np.max(data))
+    train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn, drop_last=False)
+    print(f"Parameters: {params}")
+
+
+    loss_sigmas = []
+    lr = LR
+    sigma_iter = 0
+    for i in range(SIGMA_EPOCH):
+        print(f"SIGMA EPOCH: {i+1}/{SIGMA_EPOCH}, sigma: {sigma}")
+
+        params = train(train_loader,params,sigma,static_params,lr)
+        loss = compute_dataset_loss(train_loader,params,sigma,static_params)
+        loss_sigmas.append(loss)
+
+        losses = compute_loss_each_samples(train_loader,params,sigma,static_params)
+        best_beat = np.argmin(losses)
+        worst_beat = np.argmax(losses)
+
+        gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[best_beat],params['mus'],sigma)
+        gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
+        filtered = proc.fourier_filtering(gaussian_lvl_crossing_data,params['freq_coeffs'])
+        #TO CHECK
+        interpolated = proc.sinc_interpolation_freq_parametrized(filtered,FREQ,FREQ,static_params['time_base'])
+        normed = proc.normalize(interpolated)
+
+        plt.figure()
+        plt.plot(t_base_orig, data[best_beat])
+        plt.plot(t_base_nyq,dataset_nyq[best_beat],"d")
+        plt.plot(t_base_orig,gaussian_lvl_crossing_data)
+        plt.plot(t_base_nyq,normed,"o-")
+        plt.hlines(params['mus'],t_base_orig[0],t_base_orig[-1])
+        plt.savefig(f'{res_folder_this_run}/{sigma_iter}:sigma:{sigma}_best_loss:{np.min(losses)}.svg')
+        plt.close()
+
+        gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[worst_beat],params['mus'],sigma)
+        gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
+        filtered = proc.fourier_filtering(gaussian_lvl_crossing_data,params['freq_coeffs'])
+        #TO CHECK
+        interpolated = proc.sinc_interpolation_freq_parametrized(filtered,FREQ,FREQ,static_params['time_base'])
+        normed = proc.normalize(interpolated)
+
+        plt.figure()
+        plt.plot(t_base_orig, data[worst_beat])
+        plt.plot(t_base_nyq,dataset_nyq[worst_beat],"d")
+        plt.plot(t_base_orig,gaussian_lvl_crossing_data)
+        plt.plot(t_base_nyq,normed,"o-")
+        plt.hlines(params['mus'],t_base_orig[0],t_base_orig[-1])
+        plt.savefig(f'{res_folder_this_run}/{sigma_iter}:sigma:{sigma}_worst_loss:{np.max(losses)}.svg')
+        plt.close()
+
+        sigma /= SIGMA_EPOCH_DIVISOR
+        lr -= lr/20
+        sigma_iter += 1
+        print(f"END OF SIGMA EPOCH {i+1}, LOSS = {loss}")
+    print(params)
+    plt.figure()
+    plt.plot(loss_sigmas)
+    plt.savefig(f'{res_folder_this_run}/lossVSepoch.svg')
+    plt.close()
+
+if __name__ == "__main__":
+    #test_signal()
+    main()
+
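Review note on the transform at the core of the file above: the signal is multiplied by a mixture of Gaussians centred on the crossing levels mus, so samples near a level pass through while everything else is suppressed, and the SIGMA_EPOCH loop shrinks sigma so the soft crossing anneals toward a hard one. A self-contained sketch of that behaviour (mirrors gaussian and mix_gaussian_lvl_crossing from processing.py below; the tone and level values are illustrative):

    import jax.numpy as jnp
    from jax import vmap

    def gaussian(x, mu, sig):
        return jnp.exp(-((x - mu) ** 2) / (2 * sig ** 2))

    def mix_gaussian_lvl_crossing(signal, mu_s, sigma):
        # One Gaussian bump per level: samples close to some level keep their value.
        weights = vmap(lambda mu: gaussian(signal, mu, sigma))(mu_s)
        return signal * jnp.sum(weights, axis=0)

    t = jnp.linspace(0.0, 1.0, 200)
    signal = jnp.sin(2 * jnp.pi * 3 * t)
    mus = jnp.linspace(-0.8, 0.8, 8)
    for sigma in (0.5, 0.1, 0.02):  # annealing: smaller sigma -> harder crossing
        out = mix_gaussian_lvl_crossing(signal, mus, sigma)
        print(f"sigma={sigma}: retained energy {float(jnp.sum(out**2)/jnp.sum(signal**2)):.3f}")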
diff --git a/src/processing.py b/src/processing.py
index 14d6088..fa0a851 100644
--- a/src/processing.py
+++ b/src/processing.py
@@ -1,109 +1,116 @@
 import numpy as np
 import numpy.fft as fft
 import jax.numpy as jnp
 from jax import lax,vmap
 from functools import partial
 
 INTERPOLATOR = 'sinc'
 SIGNAL_TYPE = 'create' #'create': create a new emulated ECG signal, 'load': load a previous ECG signal, 'random': random sinusoidal mix signal
 
 def get_max_freq(signal, freq, use_energy, stopping_prec):
     dt = 1/freq
     signal_0_mean = signal-np.mean(signal)
     f_sig = fft.fft(signal_0_mean)[0:len(signal_0_mean)//2]
     freq_sig = np.fft.fftfreq(signal.size, d=dt)[0:len(signal_0_mean)//2]
     if use_energy:
         threshold = sum(abs(f_sig)**2)/stopping_prec
         eval_coeff = lambda x: abs(x)**2
     else:
         threshold = max(abs(f_sig))/stopping_prec
         eval_coeff = lambda x: abs(x)
     last_significant_coef_pos = 0
     magn_f_c = 0
     # Walk the spectrum from the highest frequency down, accumulating magnitude
     # until the threshold is crossed; that bin is the last significant one.
     for i,f_c in enumerate(f_sig[-1::-1]):
         magn_f_c += eval_coeff(f_c)
         if magn_f_c > threshold:
             last_significant_coef_pos = len(f_sig)-1-i
             break
     max_freq = freq_sig[last_significant_coef_pos]
     return max_freq
 
 def get_nyquist_freq_dataset(data,freq,use_energy,stopping_perc):
     max_f = 0
     for pt in data:
         f = get_max_freq(pt,freq,use_energy,stopping_perc)
         max_f = max(max_f,f)
     return 2.1*max_f
 
 def normalize(signal):
     return signal/(jnp.max(signal)-jnp.min(signal))
 
 def normalize_dataset(data):
     out = []
     for pt in data:
         out.append(normalize(pt))
     return jnp.array(out)
 
 # ----------------------------------
 
 def gaussian(x, mu, sig):
     return jnp.exp(-jnp.power(x - mu, 2.) / (2 * jnp.power(sig, 2.)))
 
 def mix_gaussian_lvl_crossing(signal,mu_s,sigma):
     gaussian_matr = vmap(lambda mu: gaussian(signal,mu,sigma))(mu_s)
     return signal*jnp.sum(gaussian_matr,axis=0)
 
 def lvl_crossing_stupid(x,mus):
     out = jnp.zeros((len(x)-1,))
     for i in range(1,len(x)):
         for mu in mus:
             if (x[i-1]<mu and x[i]>=mu) or (x[i-1]>mu and x[i]<=mu):
                 out = out.at[i].set(mu) #GRAD IS 0 IF I USE X[I] BUT NOT IF I USE MU
     return out
 
 def mix_gaussian_lvl_crossing_dataset(data,mu_s,sigma):
     return vmap(mix_gaussian_lvl_crossing,in_axes=(0,None,None))(data,mu_s,sigma)
 
 def sinc_interpolation_freq_parametrized(samples,sinc_freq,signal_freq,time_base):
     dt_samples = 1/signal_freq
     base_sinc = lambda n: jnp.sinc((time_base-n*dt_samples)*sinc_freq)
     sinc_matr_v = vmap(base_sinc)(jnp.array(range(len(samples))))
     return jnp.dot(samples,sinc_matr_v)
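 
 # The function above is Whittaker-Shannon reconstruction:
 #     x(t) = sum_n x[n] * sinc((t - n/signal_freq) * sinc_freq)
 # Quick self-check (illustrative values): evaluating on the original time base
 # with sinc_freq == signal_freq reproduces the samples, since sinc(k - n) is 1
 # for k == n and 0 otherwise:
 #
 #   t = jnp.arange(0, 1, 1/100)
 #   x = jnp.sin(2 * jnp.pi * 5 * t)
 #   assert jnp.allclose(sinc_interpolation_freq_parametrized(x, 100, 100, t), x, atol=1e-4)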
 
 #TODO: inefficient; being tested in test.py
 def multi_sinc_interpolation(samples,sinc_freqs,ampls,signal_freq,time_base):
     dt_samples = 1/signal_freq
     tensor = jnp.zeros((len(samples),len(time_base),len(ampls)))
     for i in range(len(ampls)):
         base_sinc = lambda n: ampls[i]*jnp.sinc((time_base-n*dt_samples)*sinc_freqs[i])
         tensor = tensor.at[:,:,i].set(vmap(base_sinc)(jnp.array(range(len(samples)))))
     sincs_projections = jnp.tensordot(samples,tensor,axes=1) #axes=1 contracts the last axis of samples with the first axis of tensor (numpy defaults to axes=2)
     return jnp.sum(sincs_projections,axis=1)
 
+def fourier_filtering(signal,coeffs):
+    signal_fft = jnp.fft.fft(signal)
+    filter = jnp.concatenate((coeffs,jnp.flip(coeffs)))
+    filtered = jnp.multiply(signal_fft,filter)
+    return jnp.real(jnp.fft.ifft(filtered))
+
+
 def RMSE(v1,v2):
     return jnp.sqrt(jnp.sum((v1-v2)**2)/len(v1))
 
 # ----------------------------------
 
 def sinc_interp(samples,freq_in,freq_out):
     dt_samples = 1/freq_in
     dt_new = 1/freq_out
     new_time_base = np.arange(0,len(samples)*dt_samples,dt_new)
     return sinc_interpolation_freq_parametrized(samples,freq_in,freq_in,new_time_base)
 
 def interpolate(signal,freq_in,freq_out,type = INTERPOLATOR):
     if type == "sinc":
         resampled = sinc_interp(signal,freq_in,freq_out)
     else:
         print("No interpolator recognized")
         resampled = signal
     return resampled
 
 def interpolate_dataset(data,freq_in,freq_out,type = INTERPOLATOR):
     out = []
     for pt in data:
         out.append(interpolate(pt,freq_in,freq_out))
     return np.array(out)
\ No newline at end of file
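End-of-diff usage note: a sketch of how the processing helpers compose for the resampling objective, assuming jax is installed and processing.py is importable; the test frequencies are illustrative:

    import numpy as np
    import processing as proc

    FS = 1000
    t = np.arange(0, 1, 1 / FS)
    data = np.stack([np.sin(2 * np.pi * f * t) for f in (5.0, 12.0)])

    # Estimated dataset Nyquist rate: 2.1x the highest significant frequency.
    nyq = proc.get_nyquist_freq_dataset(data, FS, use_energy=True, stopping_perc=2_000)
    print(f"estimated Nyquist frequency: {nyq:.1f} Hz")

    # Sinc-resample every signal down to the estimated rate.
    resampled = proc.interpolate_dataset(data, FS, nyq)
    print(resampled.shape)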