Page MenuHomec4science

gaussain_collpsing_one_sinc.py
No OneTemporary

File Metadata

Created
Tue, Jul 16, 07:45

gaussain_collpsing_one_sinc.py

from functools import partial
import os
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import grad, jit, pmap, value_and_grad, vmap
from jax.scipy.special import logsumexp
from torch.utils.data import DataLoader
import optax
import dataset_generator as dataLoader
import processing as proc
# --- Experiment configuration constants ---
FREQ = 50_000  # base sampling frequency [Hz] of the generated/loaded signals
NUM_SAMPLES = 50  # number of signals drawn for the training dataset
BATCH_SIZE = 10  # mini-batch size for the DataLoader
USE_ENERGY_MAX_FREQ_EVAL = True  # use the energy-based criterion in the Nyquist-frequency estimate
STOPPING_PERC = 1_000  # stopping threshold forwarded to proc.get_nyquist_freq_dataset
INTERPOLATOR = 'sinc'  # interpolator label (informational; pipeline calls the sinc variant directly)
SIGNAL_TYPE = 'create' #'create': create a new emulated ECG signal, 'load': load previous ECG signal, 'random': random sinusoidal mix signal
EPOCHS = 10  # Adam epochs per sigma level (inner training loop)
SIGMA_EPOCH = 20  # number of sigma-annealing outer iterations
SIGMA_EPOCH_DIVISOR = 1.2  # factor by which sigma shrinks after each outer iteration
INIT_FREQ_DIVISOR = 4  # initial sinc frequency = nyquist_freq / INIT_FREQ_DIVISOR
NUM_LVLS = 16  # number of Gaussian level-crossing levels (mus)
LR = 0.01  # initial Adam learning rate (decayed across sigma epochs)
DEV_MODE = False  # True: interactive single-signal exploration; False: full training run
imag_res_folder = "../img_res"  # root folder where per-run result figures are written
def init_pipeline(gaussian_numbers, range_min, range_max, freq):
    """Build the initial trainable parameter dict for the pipeline.

    Places `gaussian_numbers` level means (`mus`) evenly inside the signal
    range, leaving a margin of (range_max - range_min) / gaussian_numbers at
    each end, and stores the initial sinc interpolation frequency.

    Fix: the margin previously divided by the global NUM_LVLS even when the
    caller requested a different number of gaussians (DEV_MODE asks the user
    for `mus_num`), making the end margins inconsistent with the actual level
    count. The margin now scales with `gaussian_numbers`; behavior is
    unchanged whenever gaussian_numbers == NUM_LVLS (the non-dev path).

    Args:
        gaussian_numbers: number of level-crossing means to create.
        range_min / range_max: min/max of the signal amplitude range.
        freq: initial sinc interpolation frequency [Hz].

    Returns:
        dict with keys 'mus' (jnp array of level means) and 'freq' (float).
    """
    margin = (range_max - range_min) / gaussian_numbers
    mu_s = jnp.linspace(range_min + margin, range_max - margin, gaussian_numbers)
    return {'mus': mu_s, 'freq': freq}
def proc_pipeline(params, x, sigma, static_params):
    """Run one signal through the differentiable compression pipeline.

    Steps: mixed-Gaussian level-crossing encoding -> normalize ->
    sinc interpolation at the learned frequency -> normalize.

    Args:
        params: dict with trainable 'mus' (level means) and 'freq'.
        x: one input signal sampled at FREQ.
        sigma: width of the Gaussians used for the level-crossing mix.
        static_params: hashable dict; 'time_base' is the output time base.
    """
    encoded = proc.mix_gaussian_lvl_crossing(x, params['mus'], sigma)
    encoded = proc.normalize(encoded)
    resampled = proc.sinc_interpolation_freq_parametrized(
        encoded, params['freq'], FREQ, static_params['time_base'])
    return proc.normalize(resampled)
# Vectorize the pipeline over the batch axis (axis 0 of the data argument);
# params, sigma and static_params are broadcast across the batch.
batched_proc_pipeline = vmap(proc_pipeline,in_axes=(None,0,None,None))
# Element-wise RMSE over a batch of (pipeline output, objective) pairs.
batched_RMSE = vmap(proc.RMSE)
@partial(jit, static_argnums=(4,))
def batched_loss_fn(params, data, nyquist_sampled_data, sigma, static_params):
    """Mean RMSE between pipeline outputs and the Nyquist-sampled objectives.

    static_params is declared static for jit, hence it must be hashable
    (see hashabledict below).
    """
    outputs = batched_proc_pipeline(params, data, sigma, static_params)
    per_sample_error = batched_RMSE(outputs, nyquist_sampled_data)
    return jnp.mean(per_sample_error)
# NOTE: disabled pmap-based alternative to the vmap pipeline above,
# kept for reference only (it is a bare string literal, never executed).
'''
batched_proc_pipeline = pmap(proc_pipeline,in_axes=(None,0,None,None), static_broadcasted_argnums=[3])
batched_RMSE = pmap(proc.RMSE)
def batched_loss_fn(params, data, nyquist_sampled_data, sigma, static_params):
proc_results = batched_proc_pipeline(params, data,sigma, static_params)
return jnp.mean(batched_RMSE(proc_results,nyquist_sampled_data))
'''
def compute_dataset_loss(dataset, params, sigma, static_params):
    """Average of the batched loss over every batch in `dataset`.

    Args:
        dataset: iterable yielding (batch, objective_batch) pairs.

    Returns:
        Mean of per-batch losses (raises ZeroDivisionError on an empty
        dataset, same as the original behavior).
    """
    total = 0
    num_batches = 0
    for batch, objective_batch in dataset:
        num_batches += 1
        total += batched_loss_fn(params, batch, objective_batch, sigma, static_params)
    return total / num_batches
def train(train_dataset, params, sigma, static_params, lr=LR):
    """Optimize `params` with Adam for EPOCHS passes over `train_dataset`.

    Prints the dataset loss before training and after each epoch.

    Args:
        train_dataset: iterable of (batch, objective_batch) pairs.
        params: trainable parameter dict ('mus', 'freq').
        sigma: current Gaussian width (held fixed during training).
        static_params: jit-static pipeline constants.
        lr: Adam learning rate.

    Returns:
        The updated parameter dict.
    """
    initial_loss = compute_dataset_loss(train_dataset, params, sigma, static_params)
    print(f"Loss: {initial_loss}")
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch+1}")
        for train_batch, objective_batch in train_dataset:
            grads = grad(batched_loss_fn)(params, train_batch, objective_batch, sigma, static_params)
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
        epoch_loss = compute_dataset_loss(train_dataset, params, sigma, static_params)
        print(f"Loss: {epoch_loss}")
    return params
def compute_loss_each_samples(dataset, params, sigma, static_params):
    """Per-sample RMSE for every sample in `dataset`, in iteration order.

    Returns:
        Flat list with one RMSE value per sample.
    """
    return [rmse
            for batch, objective_batch in dataset
            for rmse in batched_RMSE(
                batched_proc_pipeline(params, batch, sigma, static_params),
                objective_batch)]
def compute_transform_each_samples(train_dataset, params, sigma, static_params, lr=LR):
    """Pipeline output for every sample in `train_dataset`, flattened.

    NOTE(review): `lr` is unused; kept only for signature compatibility
    with existing callers.
    """
    return [transformed
            for train_batch, _ in train_dataset
            for transformed in batched_proc_pipeline(params, train_batch, sigma, static_params)]
def custom_collate_fn(batch):
    """Collate a list of (data, objective) pairs into two stacked numpy arrays.

    Used by the torch DataLoader so batches come out as plain numpy arrays
    that jax can consume directly.
    """
    data_items, obj_items = zip(*batch)
    return np.array(data_items), np.array(obj_items)
class hashabledict(dict):
    """A dict that can be hashed, so it can be a jit static argument.

    NOTE(review): the hash is key-order-insensitive (items are sorted),
    but every value must itself be hashable.
    """
    def __hash__(self):
        sorted_items = tuple(sorted(self.items()))
        return hash(sorted_items)
class hashable_np_array(np.ndarray):
    """ndarray view whose hash is derived from its mean value.

    NOTE(review): weak, collision-prone hash (any two arrays with the same
    mean collide) — adequate here only for jit static-argument caching.
    """
    def __hash__(self):
        scaled_mean = self.mean() * 1_000_000_000
        return int(scaled_mean)
def main():
    """Entry point: either interactive single-signal exploration (DEV_MODE)
    or the full sigma-annealed training run with figures saved per epoch.

    Fix: the "worst beat" figure previously plotted data[best_beat] /
    dataset_nyq[best_beat] while overlaying the gaussian transform of the
    worst beat — a copy-paste bug; it now plots the worst beat throughout.
    """
    if DEV_MODE:
        # Interactive one-shot exploration of the pipeline on a single signal.
        freq_desired = int(input("What base frequency do you need your data: "))
        mus_num = int(input("How many levels: "))
        sigma_div = int(input("Sigma fraction of the signal range (1/X): "))
        nyq_freq_div = int(input("Sync freq in fraction of nyq. freq: "))
        data = dataLoader.get_signal(SIGNAL_TYPE, num_pts=1, freq=freq_desired)
        print("Computing Nyquist frequency...")
        nyq_freq = proc.get_nyquist_freq_dataset(data, freq_desired, USE_ENERGY_MAX_FREQ_EVAL, STOPPING_PERC)
        t_base_nyq = np.arange(0, len(data[0]) / freq_desired, 1 / nyq_freq)
        t_base_orig = np.arange(0, (len(data[0])) / freq_desired, 1 / freq_desired)
        print(f"Nyquist frequency: {nyq_freq}")
        print("Generating Nyquist sampled objective")
        dataset_nyq = proc.interpolate_dataset(data, freq_desired, nyq_freq)
        print("Initializing pipelines Parameters")
        sigma = float((jnp.max(data) - jnp.min(data)) / sigma_div)
        params = init_pipeline(mus_num, np.min(data), np.max(data), nyq_freq / nyq_freq_div)
        print(f"Parameters: {params}\nSigma: {sigma}")
        print("Computing mixed gaussian representation")
        gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[0], params['mus'], sigma)
        print("Normalizing")
        gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
        print("Computing sync filter")
        resampled_gaussians = proc.sinc_interpolation_freq_parametrized(gaussian_lvl_crossing_data, params['freq'], freq_desired, t_base_orig)
        print("Normalizing")
        resampled_gaussians = proc.normalize(resampled_gaussians)
        print("Plotting")
        plt.figure()
        plt.plot(t_base_orig, data[0])
        plt.plot(t_base_nyq, dataset_nyq[0], "d")
        plt.plot(t_base_orig, gaussian_lvl_crossing_data)
        plt.plot(t_base_orig, resampled_gaussians[:len(t_base_orig)], "-")
        plt.hlines(params['mus'], t_base_orig[0], t_base_orig[-1])
        plt.show()
        return
    #Setting env:
    # Create a per-run output folder keyed by elapsed process time.
    if not os.path.isdir(imag_res_folder):
        os.mkdir(imag_res_folder)
    t_stamp = str(os.times().elapsed)
    res_folder_this_run = os.path.join(imag_res_folder, t_stamp)
    os.mkdir(res_folder_this_run)
    # NOTE(review): XLA_FLAGS is set after jax is imported, so it likely has
    # no effect at this point — confirm whether it is still needed.
    os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={BATCH_SIZE}'
    #load/generate data
    print("Loading data")
    data = dataLoader.get_signal(SIGNAL_TYPE, num_pts=NUM_SAMPLES, freq=FREQ)
    if data is None:
        print("Failed loading data")
        return
    data = proc.normalize_dataset(data)
    t_base_orig = np.arange(0, (len(data[0])) / FREQ, 1 / FREQ)[:len(data[0])]
    print("\n-------------------\n")
    print(f"Loaded Data:shape {data.shape}")
    #reference
    print("Computing maximum Nyquist frequency for the Dataset...")
    nyq_freq = proc.get_nyquist_freq_dataset(data, FREQ, USE_ENERGY_MAX_FREQ_EVAL, STOPPING_PERC)
    print(f"Nyquist frequency: {nyq_freq}")
    print("Generating Nyquist sampled objective dataset")
    dataset_nyq = proc.interpolate_dataset(data, FREQ, nyq_freq)
    t_base_nyq = np.arange(0, len(dataset_nyq[0]) / nyq_freq, 1 / nyq_freq)[:len(dataset_nyq[0])]
    train_dataset = [[d, o] for d, o in zip(data, dataset_nyq)]
    # static_params must be hashable because batched_loss_fn marks it static for jit.
    static_params = hashabledict({'freq': None, 'time_base': None})
    static_params['freq'] = nyq_freq
    static_params['time_base'] = t_base_nyq.view(hashable_np_array)
    #INIT
    print("Initializing pipelines Parameters")
    sigma = float((jnp.max(data) - jnp.min(data)) / (2.2 * NUM_LVLS))
    params = init_pipeline(NUM_LVLS, np.min(data), np.max(data), nyq_freq / INIT_FREQ_DIVISOR)
    train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn, drop_last=False)
    print(f"Parameters: {params}")
    loss_sigmas = []
    lr = LR
    sigma_iter = 0
    # Sigma annealing: train at each sigma, then shrink sigma and the lr.
    for i in range(SIGMA_EPOCH):
        print(f"SIGMA EPOCH: {i+1}/{SIGMA_EPOCH}, sigma: {sigma}")
        params = train(train_loader, params, sigma, static_params, lr)
        loss = compute_dataset_loss(train_loader, params, sigma, static_params)
        loss_sigmas.append(loss)
        losses = compute_loss_each_samples(train_loader, params, sigma, static_params)
        best_beat = np.argmin(losses)
        worst_beat = np.argmax(losses)
        # Figure 1: best-reconstructed beat at the current sigma.
        gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[best_beat], params['mus'], sigma)
        gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
        resampled_gaussians = proc.sinc_interpolation_freq_parametrized(gaussian_lvl_crossing_data, params['freq'], FREQ, t_base_nyq)
        resampled_gaussians = proc.normalize(resampled_gaussians)
        plt.figure()
        plt.plot(t_base_orig, data[best_beat])
        plt.plot(t_base_nyq, dataset_nyq[best_beat], "d")
        plt.plot(t_base_orig, gaussian_lvl_crossing_data)
        plt.plot(t_base_nyq, resampled_gaussians, "o-")
        plt.hlines(params['mus'], t_base_orig[0], t_base_orig[-1])
        plt.savefig(f'{res_folder_this_run}/{sigma_iter}:sigma:{sigma}_best_loss:{np.min(losses)}.svg')
        plt.close()
        # Figure 2: worst-reconstructed beat at the current sigma.
        gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[worst_beat], params['mus'], sigma)
        gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
        resampled_gaussians = proc.sinc_interpolation_freq_parametrized(gaussian_lvl_crossing_data, params['freq'], FREQ, t_base_nyq)
        resampled_gaussians = proc.normalize(resampled_gaussians)
        plt.figure()
        # BUGFIX: was plotting data[best_beat] / dataset_nyq[best_beat] here.
        plt.plot(t_base_orig, data[worst_beat])
        plt.plot(t_base_nyq, dataset_nyq[worst_beat], "d")
        plt.plot(t_base_orig, gaussian_lvl_crossing_data)
        plt.plot(t_base_nyq, resampled_gaussians, "o-")
        plt.hlines(params['mus'], t_base_orig[0], t_base_orig[-1])
        plt.savefig(f'{res_folder_this_run}/{sigma_iter}:sigma:{sigma}_worst_loss:{np.max(losses)}.svg')
        plt.close()
        sigma /= SIGMA_EPOCH_DIVISOR
        lr -= lr / (20)  # geometric lr decay: 5% per sigma epoch
        sigma_iter += 1
        print(f"END OF SIGMA EPOCH {i+1}, LOSS = {loss}")
    print(params)
    # Loss-vs-epoch summary curve for the whole annealing run.
    plt.figure()
    plt.plot(loss_sigmas)
    plt.savefig(f'{res_folder_this_run}/lossVSepoch.svg')
    plt.close()
# Script entry point; runs main() only when executed directly.
if __name__ == "__main__":
    #test_signal()
    main()

Event Timeline