# Page Menu / Home / c4science
#
# MOAT_2.py
# No One / Temporary
#
# File Metadata
#
# Created
# Tue, Sep 17, 07:43
#
# MOAT_2.py

import os
import time
from functools import partial
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from jax import grad, jacfwd, jit, jvp, vjp, vmap
import jaxopt as jopt
from torch.utils.data import DataLoader
import dataset_generator as dataLoader
import processing as proc
# ------------------------ PARAMS ------------------------
#
#
#
#---------------------------------------------------------
#ADDITIONAL FLAGS FOR EXPERIMENTATIONS
# Experiment configuration. These module-level constants are read directly by
# the functions below; str_env_to_save (end of this section) snapshots them
# into settings.txt for every run.
COMPLEX = True  # complex (True) or real (False) filter coefficients in init_pipeline
FILTER_LEN_WRT_NYQ = 5  # multiplier used in main: len_filter = int(t_base_nyq[-1] * nyq_freq_true * this)
USE_LBFGS = False  # run an L-BFGS refinement after the Adam epochs
#SIGNAL GENERATION
FREQ = 5_000  # sampling rate passed to dataLoader.get_signal and used for the time axes (presumably Hz)
NUM_SAMPLES_TRAIN = 1
NUM_SAMPLES_TEST = 1
BATCH_SIZE = NUM_SAMPLES_TRAIN  # one batch covers the whole training split
USE_ENERGY_MAX_FREQ_EVAL = True  # forwarded to proc.get_nyquist_freq_dataset
STOPPING_PERC = 5000  # forwarded to proc.get_nyquist_freq_dataset
INTERPOLATOR = 'sinc'  # NOTE(review): only recorded in settings.txt; not read by the visible code
SIGNAL_TYPE = 'create' #'create': create a new emulated ECG signal, 'load': load previous ECG signal, 'random': random sinusoidal mix signal
#GAUSSIAN PARAMETERS
INITIAL_SIGMA_DIVISOR = 6  # initial sigma = data range / (INITIAL_SIGMA_DIVISOR * NUM_LVLS)
SIGMA_EPOCH_DIVISOR = 1.6  # sigma is divided by this after every sigma epoch
NUM_LVLS = 3  # number of levels (one Gaussian per level)
#EPOCHS
SIGMA_EPOCH = 6  # number of sigma-annealing rounds
ADAM_EPOCHS = 2000  # Adam epochs per sigma round
ADAM_EPOCHS_FREQUENCY_LEARNING = 8000  # Adam epochs of the filter-only phase
#LBFGS PARAMETERS
LINE_SEARCH = 'zoom' #backtracking, zoom (default), hager-zhang (not working)
HISTORY_SIZE = 20  # L-BFGS history_size
TOLLERANCE = 0.05  # initial L-BFGS tol (name spelled "TOLLERANCE" throughout the file)
BFGS_ITER = 10  # L-BFGS maxiter
#LR
LR = 1e-5  # Adam learning rate of the joint level/filter phase
LR_DIVISOR = 1  # lr is divided by this after every sigma round (1 keeps it constant)
LR_FREQ_LEARNING = 5e-5  # Adam learning rate of the filter-only phase
#REGULARIZER
LAMBDA_MAXIMA = 5e-9  # weight of the maxima-count penalty added in loss()
LAMBDA_FREQUENCY = 1e-3  # NOTE(review): stored as static_params['lambda_freq'] but no visible loss reads it
#RESULTS SAVE BACK
res_folder = "../res"  # parent folder; main() creates one time-stamped subfolder per run
#ENVIRONMENT STRING TO SAVEBACK
# Text snapshot of the settings above, written to settings.txt by save_env().
# String values are interpolated without quotes, so the file is a human-readable
# record rather than importable Python.
str_env_to_save = \
f'''
#ADDITIONAL FLAGS FOR EXPERIMENTATIONS
COMPLEX = {COMPLEX}
FILTER_LEN_WRT_NYQ = {FILTER_LEN_WRT_NYQ}
USE_LBFGS = {USE_LBFGS}
#SIGNAL GENERATION
FREQ = {FREQ}
NUM_SAMPLES_TRAIN = {NUM_SAMPLES_TRAIN}
NUM_SAMPLES_TEST = {NUM_SAMPLES_TEST}
BATCH_SIZE = {BATCH_SIZE}
USE_ENERGY_MAX_FREQ_EVAL = {USE_ENERGY_MAX_FREQ_EVAL}
STOPPING_PERC = {STOPPING_PERC}
INTERPOLATOR = {INTERPOLATOR}
SIGNAL_TYPE = {SIGNAL_TYPE} #'create': create a new emulated ECG signal, 'load': load previous ECG signal, 'random': random sinusoidal mix signal
#GAUSSIAN PARAMETERS
INITIAL_SIGMA_DIVISOR = {INITIAL_SIGMA_DIVISOR}
SIGMA_EPOCH_DIVISOR = {SIGMA_EPOCH_DIVISOR}
NUM_LVLS = {NUM_LVLS}
#EPOCHS
SIGMA_EPOCH = {SIGMA_EPOCH}
ADAM_EPOCHS = {ADAM_EPOCHS}
ADAM_EPOCHS_FREQUENCY_LEARNING = {ADAM_EPOCHS_FREQUENCY_LEARNING}
#LBFGS PARAMETERS
LINE_SEARCH = {LINE_SEARCH} #backtracking, zoom (default), hager-zhang
HISTORY_SIZE = {HISTORY_SIZE}
TOLLERANCE = {TOLLERANCE}
BFGS_ITER = {BFGS_ITER}
#LR
LR = {LR}
LR_DIVISOR = {LR_DIVISOR}
LR_FREQ_LEARNING = {LR_FREQ_LEARNING}
#REGULARIZER
LAMBDA_MAXIMA = {LAMBDA_MAXIMA}
LAMBDA_FREQUENCY = {LAMBDA_FREQUENCY}
#RESULTS SAVE BACK
res_folder = {res_folder}
'''
# ------------------------ PROC. PIPELINE DEFINITION ------------------------
#
#
#
#----------------------------------------------------------------------------
def init_pipeline(gaussian_numbers,range_min,range_max,len_filter,len_signal,complex_coeffs=None):
    """Build the initial trainable parameters of the processing pipeline.

    Args:
        gaussian_numbers: number of levels to place (must be >= 1).
        range_min: lower end of the amplitude range the levels are placed in.
        range_max: upper end of that range.
        len_filter: number of filter coefficients per row.
        len_signal: signal length; the coefficient matrix gets len_signal // 2 rows.
        complex_coeffs: True for complex, False for real coefficients. None (the
            default) falls back to the module-level COMPLEX flag.

    Returns:
        dict with
          'mus': gaussian_numbers levels, evenly spaced from
                 range_min + span/n to range_max - span/n (span = range_max - range_min);
          'freq_coeffs': all-zero matrix of shape (len_signal // 2, len_filter).
    """
    use_complex = COMPLEX if complex_coeffs is None else complex_coeffs
    margin = (range_max - range_min) / gaussian_numbers
    # NOTE(review): with gaussian_numbers == 2 both levels land on the midpoint,
    # and with gaussian_numbers == 1 the single level sits at range_max.
    # Confirm this spacing is intended.
    mus = jnp.linspace(range_min + margin, range_max - margin, gaussian_numbers)
    # np.zeros fixes shape and dtype up front. Building the matrix from nested
    # Python lists silently produced a real-valued or 1-D array whenever
    # len_filter or len_signal // 2 was 0.
    freq_coeffs = np.zeros((len_signal // 2, len_filter), dtype=complex if use_complex else float)
    params = {'mus': mus, 'freq_coeffs': jnp.array(freq_coeffs)}
    return params
# TODO: the author suspected a normalization problem somewhere in this pipeline.
#@partial(jit,static_argnums=(2,3))
def proc_pipeline(params,x,sigma,static_params):
    """Apply the differentiable level-crossing + learnable spectral filter to one signal.

    Args:
        params: dict with 'mus' (levels) and 'freq_coeffs' (filter coefficients).
        x: a single input signal.
        sigma: width forwarded to proc.mix_gaussian_lvl_crossing.
        static_params: not used here.

    Returns:
        (filtered, num_maxima): the output of
        proc.linear_transform_fourier_domain_no_back_transform (presumably still
        in the Fourier domain, since the losses compare it with an FFT) and
        proc.count_maxima of the level-crossing signal.
    """
    smoothed_crossings = proc.mix_gaussian_lvl_crossing(x, params['mus'], sigma)
    # normalization (proc.normalize) is intentionally not applied here
    maxima_count = proc.count_maxima(smoothed_crossings)
    filtered = proc.linear_transform_fourier_domain_no_back_transform(smoothed_crossings, params['freq_coeffs'])
    return filtered, maxima_count
def wrapper_proc_pipeline(static_params):
    """Return a jit-compiled proc_pipeline with static_params bound.

    The returned callable is invoked as f(params, x, sigma).
    """
    # Must wrap proc_pipeline. Wrapping proc_pipeline_freq_learning here made this
    # an exact duplicate of wrapper_proc_pipeline_freq_learning.
    return jit(partial(proc_pipeline, static_params=static_params))
#@partial(jit,static_argnums=(2,))
def proc_pipeline_freq_learning(params,x,static_params):
    """Apply hard level crossing at fixed levels, then the learnable spectral filter.

    Args:
        params: dict with 'freq_coeffs'.
        x: a single input signal.
        static_params: dict providing the fixed levels under 'mus'.

    Returns:
        The output of proc.linear_transform_fourier_domain_no_back_transform.
    """
    crossings = proc.lvl_crossing_stupid(x, static_params['mus'])
    return proc.linear_transform_fourier_domain_no_back_transform(crossings, params['freq_coeffs'])
def wrapper_proc_pipeline_freq_learning(static_params):
    """Return jit-compiled proc_pipeline_freq_learning with static_params bound; call as f(params, x)."""
    bound_pipeline = partial(proc_pipeline_freq_learning, static_params=static_params)
    return jit(bound_pipeline)
#@partial(jit,static_argnums=(3,4))
def loss(params,x,ground_truth,sigma,static_params):
    """Per-sample objective of the joint level/filter learning phase.

    Returns abs(proc.RMSE(pipeline output, FFT of ground_truth)) plus
    static_params['lambda_maxima'] times the maxima count reported by proc_pipeline.
    """
    prediction, maxima_count = proc_pipeline(params, x, sigma, static_params)
    target_spectrum = jnp.fft.fft(ground_truth)
    reconstruction_error = abs(proc.RMSE(prediction, target_spectrum))
    return reconstruction_error + static_params['lambda_maxima'] * maxima_count
def wrapper_loss(static_params):
    """Return jit-compiled `loss` with static_params bound; call as f(params, x, ground_truth, sigma)."""
    bound_loss = partial(loss, static_params=static_params)
    return jit(bound_loss)
#@partial(jit,static_argnums=(3,))
def loss_freq(params,x,ground_truth,static_params):
    """Per-sample objective of the filter-only phase: abs(proc.RMSE(output, FFT of ground_truth))."""
    prediction = proc_pipeline_freq_learning(params, x, static_params)
    return abs(proc.RMSE(prediction, jnp.fft.fft(ground_truth)))
def wrapper_loss_freq(static_params):
    """Return jit-compiled `loss_freq` with static_params bound; call as f(params, x, ground_truth)."""
    bound_loss_freq = partial(loss_freq, static_params=static_params)
    return jit(bound_loss_freq)
# ------------------------ BATCHING ------------------------
#
#
#
#-----------------------------------------------------------
#------------------------LVL LEARNING------------------------
# vmap-ed versions of the per-sample functions, bound once at import time.
# In in_axes, None passes an argument unchanged to every sample and 0 maps over
# its leading (batch) axis. So only the signals x (and, for the losses, the
# ground_truth objectives) are batched; params, sigma and static_params are shared.
# These are not jit-compiled themselves; the wrapper_* helpers add jit where used.
batched_proc_pipeline = vmap(proc_pipeline,in_axes=(None,0,None,None))
batched_proc_pipeline_freq = vmap(proc_pipeline_freq_learning,in_axes=(None,0,None))
# maps over a batch of signals with a shared set of levels (not used in the visible code)
batched_lvl_crossing = vmap(proc.lvl_crossing_stupid,in_axes=(0,None))
batched_loss = vmap(loss,in_axes=(None,0,0,None,None))
batched_loss_freq = vmap(loss_freq,in_axes=(None,0,0,None))
#@partial(jit,static_argnums=(3,4))
def avg_batch_loss(params,x,ground_truth,sigma,static_params):
    """Mean of `loss` over a batch (x and ground_truth batched along axis 0)."""
    per_sample_losses = batched_loss(params, x, ground_truth, sigma, static_params)
    return jnp.average(per_sample_losses)
def wrapper_avg_batch_loss(static_params):
    """Return jit-compiled avg_batch_loss with static_params bound; call as f(params, x, ground_truth, sigma)."""
    bound_avg_loss = partial(avg_batch_loss, static_params=static_params)
    return jit(bound_avg_loss)
def dataset_loss(params,dataset,sigma,static_params):
    """Mean per-sample `loss` over every (batch, objective_batch) the dataset yields.

    Raises ZeroDivisionError if the dataset yields no samples.
    """
    running_total, sample_count = 0, 0
    for samples, targets in dataset:
        sample_count += len(samples)
        running_total += jnp.sum(batched_loss(params, samples, targets, sigma, static_params))
    return running_total / sample_count
def loss_each_samples(params,dataset,sigma,static_params):
    """Return a flat list with the `loss` of every sample, in dataset iteration order."""
    return [
        sample_loss
        for samples, targets in dataset
        for sample_loss in batched_loss(params, samples, targets, sigma, static_params)
    ]
#------------------------FREQ LEARNING------------------------
#@partial(jit,static_argnums=(3,))
def avg_batch_loss_freq(params,x,ground_truth,static_params):
    """Mean of `loss_freq` over a batch (x and ground_truth batched along axis 0)."""
    per_sample_losses = batched_loss_freq(params, x, ground_truth, static_params)
    return jnp.average(per_sample_losses)
def wrapper_avg_batch_loss_freq(static_params):
    """Return jit-compiled avg_batch_loss_freq with static_params bound; call as f(params, x, ground_truth)."""
    bound_avg_loss = partial(avg_batch_loss_freq, static_params=static_params)
    return jit(bound_avg_loss)
def compute_dataset_loss_freq(params,dataset,static_params):
    """Mean per-sample `loss_freq` over every (batch, objective_batch) the dataset yields.

    Raises ZeroDivisionError if the dataset yields no samples.
    """
    running_total, sample_count = 0, 0
    for samples, targets in dataset:
        sample_count += len(samples)
        running_total += jnp.sum(batched_loss_freq(params, samples, targets, static_params))
    return running_total / sample_count
def loss_each_samples_freq(params,dataset,static_params):
    """Return a flat list with the `loss_freq` of every sample, in dataset iteration order."""
    return [
        sample_loss
        for samples, targets in dataset
        for sample_loss in batched_loss_freq(params, samples, targets, static_params)
    ]
# ------------------------ GAUSSIAN/SPECTRA OPTIMIZATION ------------------------
#
#
#
#--------------------------------------------------------------------------------
def train_and_test(dataset,params,sigma,static_params, lr=LR):
    """Jointly optimise levels and filter at a fixed sigma (Adam, then optional L-BFGS).

    Args:
        dataset: [train_loader, test_loader]; each yields (batch, objective_batch).
        params: dict with 'mus' and 'freq_coeffs' (see init_pipeline).
        sigma: Gaussian width forwarded to the loss.
        static_params: non-trainable settings bound into the loss.
        lr: Adam learning rate.

    Returns:
        The optimised params. If USE_LBFGS is set and every L-BFGS attempt gives
        a NaN training loss, the Adam result is returned unchanged.
    """
    this_avg_batch_loss = wrapper_avg_batch_loss(static_params)
    # Build the compiled gradient once instead of re-wrapping it on every step.
    grad_fn = jit(grad(this_avg_batch_loss))
    train_dataset = dataset[0]
    test_dataset = dataset[1]
    print(f"Train Loss: {dataset_loss(params,train_dataset,sigma,static_params)}")
    print(f"Test loss:{dataset_loss(params,test_dataset,sigma,static_params)}")
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)
    # max(1, ...) avoids a modulo-by-zero when ADAM_EPOCHS < 10.
    log_every = max(1, ADAM_EPOCHS // 10)
    for e in range(ADAM_EPOCHS):
        for (train_batch,objective_batch) in train_dataset:
            grads = grad_fn(params, train_batch, objective_batch, sigma)
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
        if e % log_every == 0:
            print(f"\nEpoch {e+1}")
            print(f"\tTrain Loss: {dataset_loss(params,train_dataset,sigma,static_params)}")
            print(f"\tTest loss:{dataset_loss(params,test_dataset,sigma,static_params)}")
    # ADAM_EPOCHS equals the last e+1, and stays defined when no epoch ran.
    print(f"\nEpoch {ADAM_EPOCHS}")
    print(f"\tTrain Loss: {dataset_loss(params,train_dataset,sigma,static_params)}")
    print(f"\tTest loss:{dataset_loss(params,test_dataset,sigma,static_params)}")
    if USE_LBFGS:
        max_attempts = 10  # bound the NaN retries; an unbounded retry loop could spin forever
        # Initialised once, outside the retry loop, so the loosening below accumulates.
        tollerance = TOLLERANCE
        for _ in range(max_attempts):
            print("\nNow using LBFGS")
            solver = jopt.LBFGS(fun=this_avg_batch_loss, maxiter=BFGS_ITER, jit = True, linesearch=LINE_SEARCH, tol=tollerance, history_size=HISTORY_SIZE)
            # Each attempt restarts from the Adam result; across batches the
            # solution of one batch is the starting point of the next.
            params_lbfgs = params
            for (train_batch,objective_batch) in train_dataset:
                params_lbfgs, _ = solver.run(params_lbfgs, x=train_batch, ground_truth=objective_batch, sigma=sigma)
            loss_train = dataset_loss(params_lbfgs,train_dataset,sigma,static_params)
            if not jnp.isnan(loss_train):
                break
            print('LBFGS returned nan, re-running...')
            tollerance += tollerance*0.5
        else:
            print('LBFGS kept returning nan, keeping the Adam parameters')
            params_lbfgs = params
        print("\nLBFGS results:")
        print(f"\tTrain Loss: {dataset_loss(params_lbfgs,train_dataset,sigma,static_params)}")
        print(f"\tTest loss:{dataset_loss(params_lbfgs,test_dataset,sigma,static_params)}")
        params = params_lbfgs
    return params
def compute_transform_each_samples(train_dataset,params,sigma,static_params):
    """Return the filtered pipeline output of every sample as a flat list.

    Only the first output of batched_proc_pipeline (the transformed signals) is
    kept; the maxima counts are discarded.
    """
    return [
        transformed
        for samples, _ in train_dataset
        for transformed in batched_proc_pipeline(params, samples, sigma, static_params)[0]
    ]
# ------------------------ SPECTRA ONLY OPTIMIZATION ---------------------------
#
#
#
#--------------------------------------------------------------------------------
def train_and_test_freq(dataset,params,static_params, lr=LR):
    """Optimise only the filter coefficients, levels fixed in static_params['mus'].

    Adam runs first; an L-BFGS refinement follows when USE_LBFGS is set.

    Args:
        dataset: [train_loader, test_loader]; each yields (batch, objective_batch).
        params: dict with 'freq_coeffs'.
        static_params: non-trainable settings (including 'mus') bound into the loss.
        lr: Adam learning rate.

    Returns:
        The optimised params. If every L-BFGS attempt gives a NaN training loss,
        the Adam result is returned unchanged.
    """
    this_avg_batch_loss_freq = wrapper_avg_batch_loss_freq(static_params)
    # Build the compiled gradient once instead of re-wrapping it on every step.
    grad_fn = jit(grad(this_avg_batch_loss_freq))
    train_dataset = dataset[0]
    test_dataset = dataset[1]
    print(f"Train Loss: {compute_dataset_loss_freq(params,train_dataset,static_params)}")
    print(f"Test Loss: {compute_dataset_loss_freq(params,test_dataset,static_params)}")
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)
    # max(1, ...) avoids a modulo-by-zero when ADAM_EPOCHS_FREQUENCY_LEARNING < 10.
    log_every = max(1, ADAM_EPOCHS_FREQUENCY_LEARNING // 10)
    for e in range(ADAM_EPOCHS_FREQUENCY_LEARNING):
        for (train_batch,objective_batch) in train_dataset:
            grads = grad_fn(params, train_batch, objective_batch)
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
        if e % log_every == 0:
            print(f"\nEpoch {e+1}")
            print(f"\tTrain Loss: {compute_dataset_loss_freq(params,train_dataset,static_params)}")
            print(f"\tTest Loss: {compute_dataset_loss_freq(params,test_dataset,static_params)}")
    # ADAM_EPOCHS_FREQUENCY_LEARNING equals the last e+1, and stays defined when no epoch ran.
    print(f"\nEpoch {ADAM_EPOCHS_FREQUENCY_LEARNING}")
    print(f"\tTrain Loss: {compute_dataset_loss_freq(params,train_dataset,static_params)}")
    print(f"\tTest loss:{compute_dataset_loss_freq(params,test_dataset,static_params)}")
    if USE_LBFGS:
        max_attempts = 10  # bound the NaN retries; an unbounded retry loop could spin forever
        # Initialised once, outside the retry loop, so the loosening below accumulates.
        tollerance = TOLLERANCE
        for _ in range(max_attempts):
            print("\nNow using LBFGS")
            solver = jopt.LBFGS(fun=this_avg_batch_loss_freq, maxiter=BFGS_ITER, jit = True, linesearch=LINE_SEARCH,tol=tollerance,history_size=HISTORY_SIZE)
            # Each attempt restarts from the Adam result; across batches the
            # solution of one batch is the starting point of the next.
            params_lbfgs = params
            for (train_batch,objective_batch) in train_dataset:
                params_lbfgs, _ = solver.run(params_lbfgs, x=train_batch, ground_truth=objective_batch)
            loss_train = compute_dataset_loss_freq(params_lbfgs,train_dataset,static_params)
            if not jnp.isnan(loss_train):
                break
            print('LBFGS returned nan, re-running...')
            tollerance += tollerance*0.5
        else:
            print('LBFGS kept returning nan, keeping the Adam parameters')
            params_lbfgs = params
        print("\nLBFGS results:")
        print(f"\tTrain Loss: {compute_dataset_loss_freq(params_lbfgs,train_dataset,static_params)}")
        print(f"\tTest loss:{compute_dataset_loss_freq(params_lbfgs,test_dataset,static_params)}")
        params = params_lbfgs
    return params
def compute_transform_each_samples_freq(dataset,params,static_params):
    """Return the filter-only pipeline output of every sample as a flat list."""
    return [
        transformed
        for samples, _ in dataset
        for transformed in batched_proc_pipeline_freq(params, samples, static_params)
    ]
# ------------------------ HELPER FUNCTIONS ------------------------
#
#
#
#-------------------------------------------------------------------
def custom_collate_fn(batch):
    """Collate a list of (signal, objective) pairs into two stacked numpy arrays.

    Used as the DataLoader collate_fn so batches stay numpy arrays rather than
    torch tensors.
    """
    columns = tuple(zip(*batch))
    signals, objectives = columns[0], columns[1]
    return np.array(signals), np.array(objectives)
class hashabledict(dict):
    """dict that can be hashed (presumably for use as a jit static argument).

    The hash is taken over the current (key, value) pairs in sorted order, so
    every value must itself be hashable. Mutating the dict changes its hash;
    treat instances as read-only once they have been hashed.
    """

    def __hash__(self):
        ordered_pairs = sorted(self.items())
        return hash(tuple(ordered_pairs))
class hashable_np_array(np.ndarray):
    """numpy array view that can be hashed; create it with ``arr.view(hashable_np_array)``.

    The hash is the element mean scaled by 1e9 and truncated to int. Arrays with
    the same mean therefore collide. An empty or NaN-containing array makes
    __hash__ raise, since int() of NaN fails. Equality is still numpy's
    element-wise ``==``.
    """

    def __hash__(self):
        scaled_mean = self.mean() * 1_000_000_000
        return int(scaled_mean)
def compute_transform_and_plot_beat(beat,ground_truth,params,sigma,static_params,t_hf,t_lf,name):
    """Run one beat through the Gaussian level-crossing pipeline and save three figures.

    Files written:
        name                  - beat, ground truth, level-crossing signal, filtered
                                output and the level lines;
        name + "_samples.svg" - same, with the sinc-interpolated output on t_lf;
        name + "_f.svg"       - magnitude of the first half of the FFT of the beat
                                and of the filtered output.

    Args:
        beat: input signal sampled on t_hf.
        ground_truth: objective signal sampled on t_lf.
        params: dict with 'mus' and 'freq_coeffs'.
        sigma: width forwarded to proc.mix_gaussian_lvl_crossing.
        static_params: must provide 'time_base' for the interpolation.
        t_hf: time axis of beat.
        t_lf: time axis of ground_truth.
        name: path of the first figure, also used as prefix of the other two.
    """
    levels = params['mus']
    crossings = proc.mix_gaussian_lvl_crossing(beat, levels, sigma)
    # normalization (proc.normalize) is intentionally not applied
    filtered = proc.linear_transform_fourier_domain(crossings, params['freq_coeffs'])

    def save_time_plot(out_path, x_values, y_values, *fmt):
        """Draw beat, ground truth, crossings, one extra trace and the levels; save and close."""
        plt.figure()
        plt.plot(t_hf, beat)
        plt.plot(t_lf, ground_truth, "d")
        plt.plot(t_hf, crossings)
        plt.plot(x_values, y_values, *fmt)
        plt.hlines(levels, t_hf[0], t_hf[-1])
        plt.savefig(out_path)
        plt.close()

    save_time_plot(name, t_hf[:len(filtered)], filtered)
    resampled = proc.sinc_interpolation_freq_parametrized(filtered, FREQ, FREQ, static_params['time_base'])
    save_time_plot(name + "_samples.svg", t_lf, resampled, "o")

    beat_spectrum = np.fft.fft(beat)[:len(beat) // 2]
    output_spectrum = np.fft.fft(filtered)[:len(filtered) // 2]
    plt.figure()
    for spectrum in (beat_spectrum, output_spectrum):
        plt.plot(abs(spectrum))
    plt.savefig(name + "_f.svg")
    plt.close()
def compute_pure_lc_transform_and_plot_beat(beat,ground_truth,params,static_params,t_hf,t_lf,name):
    """Run one beat through hard level crossing + learned filter and save three figures.

    Files written:
        name                  - beat, ground truth, level-crossing signal, filtered
                                output and the level lines;
        name + "_samples.svg" - same, with the sinc-interpolated output on t_lf;
        name + "_f.svg"       - magnitude of the first half of the FFT of the beat
                                and of the filtered output.

    Args:
        beat: input signal sampled on t_hf.
        ground_truth: objective signal sampled on t_lf.
        params: dict with 'freq_coeffs'.
        static_params: must provide 'mus' (levels) and 'time_base'.
        t_hf: time axis of beat.
        t_lf: time axis of ground_truth.
        name: path of the first figure, also used as prefix of the other two.
    """
    levels = static_params['mus']
    crossings = proc.lvl_crossing_stupid(beat, levels)
    # normalization (proc.normalize) is intentionally not applied
    filtered = proc.linear_transform_fourier_domain(crossings, params['freq_coeffs'])

    def save_time_plot(out_path, x_values, y_values, *fmt):
        """Draw beat, ground truth, crossings, one extra trace and the levels; save and close."""
        plt.figure()
        plt.plot(t_hf, beat)
        plt.plot(t_lf, ground_truth, "d")
        plt.plot(t_hf, crossings)
        plt.plot(x_values, y_values, *fmt)
        plt.hlines(levels, t_hf[0], t_hf[-1])
        plt.savefig(out_path)
        plt.close()

    save_time_plot(name, t_hf[:len(filtered)], filtered)
    resampled = proc.sinc_interpolation_freq_parametrized(filtered, FREQ, FREQ, static_params['time_base'])
    save_time_plot(name + "_samples.svg", t_lf, resampled, "o")

    beat_spectrum = np.fft.fft(beat)[:len(beat) // 2]
    output_spectrum = np.fft.fft(filtered)[:len(filtered) // 2]
    plt.figure()
    for spectrum in (beat_spectrum, output_spectrum):
        plt.plot(abs(spectrum))
    plt.savefig(name + "_f.svg")
    plt.close()
def save_env(res_folder_this_run):
    """Write the str_env_to_save settings snapshot to <res_folder_this_run>/settings.txt, overwriting it."""
    settings_path = os.path.join(res_folder_this_run, 'settings.txt')
    with open(settings_path, 'w') as settings_file:
        settings_file.write(str_env_to_save)
def save_params_state(res_folder_this_run,params):
    """Write one "<key>: <str(value)>" line per entry of params to <res_folder_this_run>/params.txt."""
    params_path = os.path.join(res_folder_this_run, 'params.txt')
    with open(params_path, 'w') as params_file:
        params_file.writelines(f"{key}: {str(value)}\n" for key, value in params.items())
# ------------------------ MAIN LOOP ------------------------
#
#
#
#------------------------------------------------------------
def _save_filter_plots(freq_coeffs, path_prefix):
    """Save magnitude and phase heat-maps of a filter-coefficient matrix.

    Writes '<path_prefix>_filter_abs.svg' and '<path_prefix>_filter_phs.svg'.
    """
    for suffix, transform in (('abs', abs), ('phs', jnp.angle)):
        # plt.matshow opens a figure of its own. An explicit plt.figure() before
        # it would leave an empty, never-closed figure behind on every call.
        plt.matshow(transform(freq_coeffs))
        plt.colorbar()
        plt.savefig(f'{path_prefix}_filter_{suffix}.svg')
        plt.close()


def _plot_best_and_worst(losses, split_dataset, name_prefix, plot_beat):
    """Plot the lowest- and highest-loss samples of one dataset split.

    losses[k] must be the loss of split_dataset[k]. main() ensures this by
    computing losses from non-shuffled loaders built over the same lists.
    plot_beat(beat, ground_truth, name=...) draws and saves one sample.
    """
    for label, index, value in (('best', np.argmin(losses), np.min(losses)),
                                ('worst', np.argmax(losses), np.max(losses))):
        beat, beat_nyq = split_dataset[index]
        plot_beat(beat, beat_nyq, name=f'{name_prefix}_{label}_loss:{value}.svg')


def main():
    """Run the whole experiment.

    1. Create a time-stamped run folder under res_folder and save the settings.
    2. Load or generate the signals and build the train/test loaders.
    3. For SIGMA_EPOCH rounds, jointly learn levels and filter with the Gaussian
       level-crossing pipeline, dividing sigma (and lr) after every round.
    4. Freeze the learned levels and re-learn the filter on hard level crossings.
    Plots and the final parameters are written into the run folder.
    """
    # Output folders.
    os.makedirs(res_folder, exist_ok=True)
    # NOTE(review): time.ctime() contains ':' and spaces, which some file systems reject.
    res_folder_this_run = os.path.join(res_folder, time.ctime())
    os.mkdir(res_folder_this_run)
    save_env(res_folder_this_run)
    # NOTE(review): XLA reads XLA_FLAGS when the JAX backend initialises, so this
    # only takes effect if nothing has touched a device before this line.
    os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={BATCH_SIZE}'

    # Load or generate the data.
    print("Loading data")
    data = dataLoader.get_signal(SIGNAL_TYPE,num_pts=NUM_SAMPLES_TEST+NUM_SAMPLES_TRAIN,freq=FREQ)
    if data is None:
        print("Failed loading data")
        return
    # Only an even number of points per signal is supported. No reason was
    # recorded; note that init_pipeline sizes the filter with len_signal // 2 rows.
    if data.shape[1]%2!=0:
        data = data[:,:-1]
    data = proc.normalize_dataset(data)
    t_base_orig = np.arange(0,(len(data[0]))/FREQ,1/FREQ)[:len(data[0])]
    print("\n-------------------\n")
    print(f"Loaded Data:shape {data.shape}")

    # Reference (objective) signals.
    print("Computing maximum Nyquist frequency for the Dataset...")
    nyq_freq_true = proc.get_nyquist_freq_dataset(data,FREQ,USE_ENERGY_MAX_FREQ_EVAL,STOPPING_PERC)
    print(f"Nyquist frequency: {nyq_freq_true}")
    print("Generating Nyquist sampled objective dataset")
    # Nyquist resampling is deliberately bypassed: the objective is the input
    # itself at FREQ. The bypassed call was proc.interpolate_dataset(data, FREQ, nyq_freq).
    dataset_nyq = data
    nyq_freq = FREQ
    t_base_nyq = np.arange(0,len(dataset_nyq[0])/nyq_freq,1/nyq_freq)[:len(dataset_nyq[0])]
    train_dataset = [[d,o] for d,o in zip(data[:NUM_SAMPLES_TRAIN],dataset_nyq[:NUM_SAMPLES_TRAIN])]
    test_dataset = [[d,o] for d,o in zip(data[NUM_SAMPLES_TRAIN:],dataset_nyq[NUM_SAMPLES_TRAIN:])]
    # shuffle=False keeps per-sample losses aligned with train_dataset/test_dataset,
    # which the best/worst plots index into.
    train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn, drop_last=False)
    test_loader = DataLoader(test_dataset, BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn, drop_last=False)
    dataset = [train_loader,test_loader]

    # ---------------- GAUSSIAN/FREQUENCY LEARNING WITH TRUE LVL CROSSING ----------------
    print("Initializing pipeline Parameters")
    # Initial sigma is a fraction of the amplitude range per level.
    sigma = float((jnp.max(data)-jnp.min(data))/(INITIAL_SIGMA_DIVISOR*NUM_LVLS))
    params = init_pipeline(NUM_LVLS,np.min(data),np.max(data), len_filter = int(t_base_nyq[-1]*nyq_freq_true*FILTER_LEN_WRT_NYQ),len_signal= len(data[0]))
    static_params = {
        'time_base' : t_base_nyq,
        'lambda_maxima' : LAMBDA_MAXIMA,
        'lambda_freq' : LAMBDA_FREQUENCY,
        'ideal_filter_len' : int(t_base_nyq[-1]*nyq_freq_true//2),
    }
    print(f"Parameters: {params}\n")
    loss_train_sigmas = []
    loss_test_sigmas = []
    lr = LR
    for sigma_iter in range(SIGMA_EPOCH):
        print("%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/")
        print(f"SIGMA EPOCH: {sigma_iter+1}/{SIGMA_EPOCH}, sigma: {sigma}\n")
        params = train_and_test(dataset,params,sigma,static_params,lr)
        loss_train = dataset_loss(params,train_loader,sigma,static_params)
        loss_test = dataset_loss(params,test_loader,sigma,static_params)
        loss_train_sigmas.append(loss_train)
        loss_test_sigmas.append(loss_test)
        plot_beat = partial(compute_transform_and_plot_beat, params=params, sigma=sigma,
                            static_params=static_params, t_hf=t_base_orig, t_lf=t_base_nyq)
        epoch_tag = f'{sigma_iter}:sigma:{sigma}'
        _plot_best_and_worst(loss_each_samples(params,train_loader,sigma,static_params),
                             train_dataset, f'{res_folder_this_run}/TRAIN_{epoch_tag}', plot_beat)
        _plot_best_and_worst(loss_each_samples(params,test_loader,sigma,static_params),
                             test_dataset, f'{res_folder_this_run}/TEST_{epoch_tag}', plot_beat)
        _save_filter_plots(params['freq_coeffs'], f'{res_folder_this_run}/{epoch_tag}')
        sigma /= SIGMA_EPOCH_DIVISOR
        lr /= LR_DIVISOR
        print('\n--------------------------------------------------------------------')
        print(f"END OF SIGMA EPOCH {sigma_iter+1}, TRAIN LOSS = {loss_train}, TEST LOSS = {loss_test}")
        print(f"Levels:{params['mus']}")
        print('--------------------------------------------------------------------\n')
    plt.figure()
    plt.plot(loss_train_sigmas)
    plt.plot(loss_test_sigmas)
    plt.legend(['train','test'])
    plt.savefig(f'{res_folder_this_run}/lossVSepoch.svg')
    plt.close()

    # ---------------- FREQUENCY LEARNING WITH TRUE LVL CROSSING ----------------
    print("SPECTRAL LEARNING FOR THE FOUND LEVELS")
    print("Initializing pipeline Parameters")
    # The learned levels become fixed; only the filter is trained from here on.
    static_params = {
        'mus' : params['mus'],
        'time_base' : static_params['time_base'],
        'lambda_freq' : static_params['lambda_freq'],
        'ideal_filter_len' : static_params['ideal_filter_len']}
    params = {'freq_coeffs':params['freq_coeffs']}
    params = train_and_test_freq(dataset,params,static_params,LR_FREQ_LEARNING)
    plot_beat = partial(compute_pure_lc_transform_and_plot_beat, params=params,
                        static_params=static_params, t_hf=t_base_orig, t_lf=t_base_nyq)
    _plot_best_and_worst(loss_each_samples_freq(params,train_loader,static_params),
                         train_dataset, f'{res_folder_this_run}/TRAIN_freq_learning_level_crossing', plot_beat)
    _plot_best_and_worst(loss_each_samples_freq(params,test_loader,static_params),
                         test_dataset, f'{res_folder_this_run}/TEST_freq_learning_level_crossing', plot_beat)
    _save_filter_plots(params['freq_coeffs'], f'{res_folder_this_run}/freq_learning')
    save_params_state(res_folder_this_run,{'filter_coefficeints':params['freq_coeffs'],'levels':static_params['mus']})


if __name__ == "__main__":
    main()

# Event Timeline