Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F83469684
MOAT_2.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Tue, Sep 17, 07:43
Size
25 KB
Mime Type
text/x-python
Expires
Thu, Sep 19, 07:43 (2 d)
Engine
blob
Format
Raw Data
Handle
20843638
Attached To
R12761 LC-Sampling-Theory
MOAT_2.py
View Options
import
os
import
time
from
functools
import
partial
import
jax.numpy
as
jnp
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
optax
from
jax
import
grad
,
jacfwd
,
jit
,
jvp
,
vjp
,
vmap
import
jaxopt
as
jopt
from
torch.utils.data
import
DataLoader
import
dataset_generator
as
dataLoader
import
processing
as
proc
# ------------------------ PARAMS ------------------------
#
#
#
#---------------------------------------------------------
#ADDITIONAL FLAGS FOR EXPERIMENTATIONS
COMPLEX = True  # use complex-valued filter coefficients in init_pipeline
FILTER_LEN_WRT_NYQ = 5  # filter-length multiplier w.r.t. the Nyquist-derived length
USE_LBFGS = False  # when True, refine the Adam result with jaxopt LBFGS
#SIGNAL GENERATION
FREQ = 5_000  # sampling frequency passed to the dataset generator
NUM_SAMPLES_TRAIN = 1
NUM_SAMPLES_TEST = 1
BATCH_SIZE = NUM_SAMPLES_TRAIN  # full-batch training: one batch holds the whole train set
USE_ENERGY_MAX_FREQ_EVAL = True
STOPPING_PERC = 5000
INTERPOLATOR = 'sinc'
SIGNAL_TYPE = 'create' #'create': create a new emulated ECG signal, 'load': load previous ECG signal, 'random': random sinusoidal mix signal
#GAUSSIAN PARAMETERS
INITIAL_SIGMA_DIVISOR = 6  # initial sigma = signal range / (this * NUM_LVLS)
SIGMA_EPOCH_DIVISOR = 1.6  # sigma is divided by this after every sigma epoch (annealing)
NUM_LVLS = 3  # number of crossing levels (Gaussian means)
#EPOCHS
SIGMA_EPOCH = 6
ADAM_EPOCHS = 2000
ADAM_EPOCHS_FREQUENCY_LEARNING = 8000
#LBFGS PARAMETERS
LINE_SEARCH = 'zoom'  #backtracking, zoom (default), hager-zhang (not working)
HISTORY_SIZE = 20
TOLLERANCE = 0.05  # NOTE: misspelling of "tolerance" kept — the name is referenced elsewhere
BFGS_ITER = 10
#LR
LR = 1e-5
LR_DIVISOR = 1
LR_FREQ_LEARNING = 5e-5
#REGULARIZER
LAMBDA_MAXIMA = 5e-9
LAMBDA_FREQUENCY = 1e-3
#RESULTS SAVE BACK
res_folder = "../res"
#ENVIRONMENT STRING TO SAVEBACK
# Snapshot of the configuration above, written to settings.txt by save_env().
str_env_to_save = \
f'''
#ADDITIONAL FLAGS FOR EXPERIMENTATIONS
COMPLEX = {COMPLEX}
FILTER_LEN_WRT_NYQ = {FILTER_LEN_WRT_NYQ}
USE_LBFGS = {USE_LBFGS}
#SIGNAL GENERATION
FREQ = {FREQ}
NUM_SAMPLES_TRAIN = {NUM_SAMPLES_TRAIN}
NUM_SAMPLES_TEST = {NUM_SAMPLES_TEST}
BATCH_SIZE = {BATCH_SIZE}
USE_ENERGY_MAX_FREQ_EVAL = {USE_ENERGY_MAX_FREQ_EVAL}
STOPPING_PERC = {STOPPING_PERC}
INTERPOLATOR = {INTERPOLATOR}
SIGNAL_TYPE = {SIGNAL_TYPE} #'create': create a new emulated ECG signal, 'load': load previous ECG signal, 'random': random sinusoidal mix signal
#GAUSSIAN PARAMETERS
INITIAL_SIGMA_DIVISOR = {INITIAL_SIGMA_DIVISOR}
SIGMA_EPOCH_DIVISOR = {SIGMA_EPOCH_DIVISOR}
NUM_LVLS = {NUM_LVLS}
#EPOCHS
SIGMA_EPOCH = {SIGMA_EPOCH}
ADAM_EPOCHS = {ADAM_EPOCHS}
ADAM_EPOCHS_FREQUENCY_LEARNING = {ADAM_EPOCHS_FREQUENCY_LEARNING}
#LBFGS PARAMETERS
LINE_SEARCH = {LINE_SEARCH} #backtracking, zoom (default), hager-zhang
HISTORY_SIZE = {HISTORY_SIZE}
TOLLERANCE = {TOLLERANCE}
BFGS_ITER = {BFGS_ITER}
#LR
LR = {LR}
LR_DIVISOR = {LR_DIVISOR}
LR_FREQ_LEARNING = {LR_FREQ_LEARNING}
#REGULARIZER
LAMBDA_MAXIMA = {LAMBDA_MAXIMA}
LAMBDA_FREQUENCY = {LAMBDA_FREQUENCY}
#RESULTS SAVE BACK
res_folder = {res_folder}
'''
# ------------------------ PROC. PIPELINE DEFINITION ------------------------
#
#
#
#----------------------------------------------------------------------------
def init_pipeline(gaussian_numbers, range_min, range_max, len_filter, len_signal):
    """Build the initial learnable parameter dict.

    'mus' are `gaussian_numbers` evenly spaced crossing levels kept one
    spacing away from both range edges; 'freq_coeffs' is a zeroed
    (len_signal // 2, len_filter) coefficient matrix, complex when the
    module flag COMPLEX is set, float otherwise.
    """
    margin = (range_max - range_min) / gaussian_numbers
    mus = jnp.linspace(range_min + margin, range_max - margin, gaussian_numbers)
    rows = len_signal // 2
    if COMPLEX:
        freq_coeffs = np.zeros((rows, len_filter), dtype=complex)
    else:
        freq_coeffs = np.zeros((rows, len_filter), dtype=float)
    return {'mus': mus, 'freq_coeffs': jnp.array(freq_coeffs)}
#TODO I feel like there's some normalization problem
#@partial(jit,static_argnums=(2,3))
def proc_pipeline(params, x, sigma, static_params):
    """Gaussian level-crossing followed by the learned spectral transform.

    Returns a (transformed_signal, num_maxima) pair, where num_maxima is
    counted on the level-crossed signal and feeds the maxima regularizer.
    """
    crossed = proc.mix_gaussian_lvl_crossing(x, params['mus'], sigma)
    #crossed = proc.normalize(crossed)
    num_maxima = proc.count_maxima(crossed)
    transformed = proc.linear_transform_fourier_domain_no_back_transform(
        crossed, params['freq_coeffs'])
    return transformed, num_maxima
def wrapper_proc_pipeline(static_params):
    """Return a jitted proc_pipeline with static_params frozen in.

    BUG FIX: the original wrapped proc_pipeline_freq_learning here (an
    apparent copy-paste from wrapper_proc_pipeline_freq_learning). That
    function has no `sigma` parameter, so calling the returned wrapper as
    (params, x, sigma) would pass `sigma` positionally into the slot already
    bound to the static_params keyword and raise a TypeError. The wrapper is
    named after proc_pipeline and must bind it.
    """
    return jit(partial(proc_pipeline, static_params=static_params))
#@partial(jit,static_argnums=(2,))
def proc_pipeline_freq_learning(params, x, static_params):
    """Hard (non-Gaussian) level crossing followed by the learned spectral transform.

    Levels come from static_params (frozen); only the filter coefficients
    in `params` are learnable here.
    """
    crossed = proc.lvl_crossing_stupid(x, static_params['mus'])
    return proc.linear_transform_fourier_domain_no_back_transform(
        crossed, params['freq_coeffs'])
def wrapper_proc_pipeline_freq_learning(static_params):
    """Bind static_params into proc_pipeline_freq_learning and jit-compile it."""
    bound = partial(proc_pipeline_freq_learning, static_params=static_params)
    return jit(bound)
#@partial(jit,static_argnums=(3,4))
def loss(params, x, ground_truth, sigma, static_params):
    """Single-sample loss for level learning.

    RMSE between the pipeline output and the FFT of the ground truth, plus a
    regularization term proportional to the number of maxima of the
    level-crossed signal.
    """
    transformed, maxima = proc_pipeline(params, x, sigma, static_params)
    rmse = abs(proc.RMSE(transformed, jnp.fft.fft(ground_truth)))
    regularizer = static_params['lambda_maxima'] * maxima
    return rmse + regularizer
def wrapper_loss(static_params):
    """Bind static_params into `loss` and jit-compile the result."""
    bound = partial(loss, static_params=static_params)
    return jit(bound)
#@partial(jit,static_argnums=(3,))
def loss_freq(params, x, ground_truth, static_params):
    """Single-sample loss for frequency-only learning.

    RMSE between the hard level-crossing pipeline output and the FFT of the
    ground truth (no maxima regularizer here).
    """
    transformed = proc_pipeline_freq_learning(params, x, static_params)
    return abs(proc.RMSE(transformed, jnp.fft.fft(ground_truth)))
def wrapper_loss_freq(static_params):
    """Bind static_params into `loss_freq` and jit-compile the result."""
    bound = partial(loss_freq, static_params=static_params)
    return jit(bound)
# ------------------------ BATCHING ------------------------
#
#
#
#-----------------------------------------------------------
#------------------------LVL LEARNING------------------------
# vmapped versions of the pipelines/losses: parameters and static arguments
# are broadcast (in_axes=None) while the sample axis (0) is mapped over the
# data / ground-truth batch arguments.
batched_proc_pipeline = vmap(proc_pipeline, in_axes=(None, 0, None, None))
batched_proc_pipeline_freq = vmap(proc_pipeline_freq_learning, in_axes=(None, 0, None))
batched_lvl_crossing = vmap(proc.lvl_crossing_stupid, in_axes=(0, None))
batched_loss = vmap(loss, in_axes=(None, 0, 0, None, None))
batched_loss_freq = vmap(loss_freq, in_axes=(None, 0, 0, None))
#@partial(jit,static_argnums=(3,4))
def avg_batch_loss(params, x, ground_truth, sigma, static_params):
    """Mean of the per-sample level-learning losses over one batch."""
    per_sample = batched_loss(params, x, ground_truth, sigma, static_params)
    return jnp.average(per_sample)
def wrapper_avg_batch_loss(static_params):
    """Bind static_params into avg_batch_loss and jit-compile the result."""
    bound = partial(avg_batch_loss, static_params=static_params)
    return jit(bound)
def dataset_loss(params, dataset, sigma, static_params):
    """Average per-sample level-learning loss over every batch of `dataset`."""
    total = 0
    n_samples = 0
    for batch, objective_batch in dataset:
        n_samples += len(batch)
        batch_losses = batched_loss(params, batch, objective_batch, sigma, static_params)
        total += jnp.sum(batch_losses)
    return total / n_samples
def loss_each_samples(params, dataset, sigma, static_params):
    """Per-sample level-learning losses over the whole dataset, as a flat list."""
    per_sample = []
    for batch, objective_batch in dataset:
        per_sample.extend(
            batched_loss(params, batch, objective_batch, sigma, static_params))
    return per_sample
#------------------------FREQ LEARNING------------------------
#@partial(jit,static_argnums=(3,))
def avg_batch_loss_freq(params, x, ground_truth, static_params):
    """Mean of the per-sample frequency-learning losses over one batch."""
    per_sample = batched_loss_freq(params, x, ground_truth, static_params)
    return jnp.average(per_sample)
def wrapper_avg_batch_loss_freq(static_params):
    """Bind static_params into avg_batch_loss_freq and jit-compile the result."""
    bound = partial(avg_batch_loss_freq, static_params=static_params)
    return jit(bound)
def compute_dataset_loss_freq(params, dataset, static_params):
    """Average per-sample frequency-learning loss over every batch of `dataset`."""
    total = 0
    n_samples = 0
    for batch, objective_batch in dataset:
        n_samples += len(batch)
        batch_losses = batched_loss_freq(params, batch, objective_batch, static_params)
        total += jnp.sum(batch_losses)
    return total / n_samples
def loss_each_samples_freq(params, dataset, static_params):
    """Per-sample frequency-learning losses over the whole dataset, as a flat list."""
    per_sample = []
    for batch, objective_batch in dataset:
        per_sample.extend(
            batched_loss_freq(params, batch, objective_batch, static_params))
    return per_sample
# ------------------------ GAUSSIAN/SPECTRA OPTIMIZATION ------------------------
#
#
#
#--------------------------------------------------------------------------------
def train_and_test(dataset, params, sigma, static_params, lr=LR):
    """Optimize levels + filter coefficients with Adam, optionally refined by LBFGS.

    dataset: [train_loader, test_loader] pair of DataLoaders.
    Returns the updated params dict. Train/test losses are printed before,
    during (every 10% of the epochs), and after optimization.
    """
    this_avg_batch_loss = wrapper_avg_batch_loss(static_params)
    train_dataset = dataset[0]
    test_dataset = dataset[1]
    # Reference losses before any optimization.
    print(f"Train Loss: {dataset_loss(params,train_dataset,sigma,static_params)}")
    print(f"Test loss:{dataset_loss(params,test_dataset,sigma,static_params)}")
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)
    for e in range(ADAM_EPOCHS):
        for (train_batch, objective_batch) in train_dataset:
            #grads = jacfwd(batched_loss_fn)(params, train_batch, objective_batch, sigma, static_params)
            #grads = vjp(batched_loss_fn, params, train_batch, objective_batch, sigma, static_params)[1](1.0)
            grads = grad(this_avg_batch_loss)(params, train_batch, objective_batch, sigma)
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
        # Progress report every 10% of the epochs.
        if e % (ADAM_EPOCHS // 10) == 0:
            print(f"\nEpoch {e+1}")
            print(f"\tTrain Loss: {dataset_loss(params,train_dataset,sigma,static_params)}")
            print(f"\tTest loss:{dataset_loss(params,test_dataset,sigma,static_params)}")
    # Final report after the last Adam epoch.
    print(f"\nEpoch {e+1}")
    print(f"\tTrain Loss: {dataset_loss(params,train_dataset,sigma,static_params)}")
    print(f"\tTest loss:{dataset_loss(params,test_dataset,sigma,static_params)}")
    # Optional LBFGS refinement; re-run while the refined loss comes out NaN.
    lbfgs_run = USE_LBFGS
    while lbfgs_run:
        print("\nNow using LBFGS")
        params_lbfgs = params
        # NOTE(review): tollerance is reset to TOLLERANCE at the top of every
        # retry, so the *0.5 increase in the NaN branch below never carries
        # over to the next solver instance — confirm intent.
        tollerance = TOLLERANCE
        solver = jopt.LBFGS(fun=this_avg_batch_loss, maxiter=BFGS_ITER, jit=True,
                            linesearch=LINE_SEARCH, tol=tollerance, history_size=HISTORY_SIZE)
        for (train_batch, objective_batch) in train_dataset:
            args = {'x': train_batch, 'ground_truth': objective_batch, 'sigma': sigma}
            # NOTE(review): each batch restarts from the pre-LBFGS `params`,
            # keeping only the last batch's result — verify this is intended.
            params_lbfgs, _ = solver.run(params, **args)
        loss_train = dataset_loss(params_lbfgs, train_dataset, sigma, static_params)
        if not jnp.isnan(loss_train):
            lbfgs_run = False
        else:
            print('LBFGS returned nan, re-running...')
            tollerance += tollerance * 0.5
        # NOTE(review): these prints evaluate `params` (pre-LBFGS), not
        # params_lbfgs — the reported "LBFGS results" are the pre-refinement
        # losses.
        print("\nLBFGS results:")
        print(f"\tTrain Loss: {dataset_loss(params,train_dataset,sigma,static_params)}")
        print(f"\tTest loss:{dataset_loss(params,test_dataset,sigma,static_params)}")
        params = params_lbfgs
    return params
def compute_transform_each_samples(train_dataset, params, sigma, static_params):
    """Run the batched Gaussian pipeline on every batch and collect the
    transformed signals (first element of each pipeline result) in a flat list."""
    transformed = []
    for train_batch, _ in train_dataset:
        batch_out = batched_proc_pipeline(params, train_batch, sigma, static_params)
        transformed.extend(batch_out[0])
    return transformed
# ------------------------ SPECTRA ONLY OPTIMIZATION ---------------------------
#
#
#
#--------------------------------------------------------------------------------
def train_and_test_freq(dataset, params, static_params, lr=LR):
    """Optimize only the filter coefficients (levels frozen in static_params)
    with Adam, optionally refined by LBFGS.

    dataset: [train_loader, test_loader] pair of DataLoaders.
    Returns the updated params dict.
    """
    this_avg_batch_loss_freq = wrapper_avg_batch_loss_freq(static_params)
    train_dataset = dataset[0]
    test_dataset = dataset[1]
    # Reference losses before any optimization.
    print(f"Train Loss: {compute_dataset_loss_freq(params,train_dataset,static_params)}")
    print(f"Test Loss: {compute_dataset_loss_freq(params,test_dataset,static_params)}")
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)
    for e in range(ADAM_EPOCHS_FREQUENCY_LEARNING):
        for (train_batch, objective_batch) in train_dataset:
            #grads = jacfwd(batched_loss_fn_freq_learning)(params, train_batch, objective_batch, static_params)
            #grads = vjp(batched_loss_fn_freq_learning, params, train_batch, objective_batch, static_params)[1](1.0)
            grads = grad(this_avg_batch_loss_freq)(params, train_batch, objective_batch)
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
        # Progress report every 10% of the epochs.
        if e % (ADAM_EPOCHS_FREQUENCY_LEARNING // 10) == 0:
            print(f"\nEpoch {e+1}")
            print(f"\tTrain Loss: {compute_dataset_loss_freq(params,train_dataset,static_params)}")
            print(f"\tTest Loss: {compute_dataset_loss_freq(params,test_dataset,static_params)}")
    # Final report after the last Adam epoch.
    print(f"\nEpoch {e+1}")
    print(f"\tTrain Loss: {compute_dataset_loss_freq(params,train_dataset,static_params)}")
    print(f"\tTest loss:{compute_dataset_loss_freq(params,test_dataset,static_params)}")
    # Optional LBFGS refinement; re-run while the refined loss comes out NaN.
    lbfgs_run = USE_LBFGS
    while lbfgs_run:
        print("\nNow using LBFGS")
        params_lbfgs = params
        # NOTE(review): tollerance is reset to TOLLERANCE every retry, so the
        # *0.5 increase in the NaN branch never takes effect — confirm intent.
        tollerance = TOLLERANCE
        solver = jopt.LBFGS(fun=this_avg_batch_loss_freq, maxiter=BFGS_ITER, jit=True,
                            linesearch=LINE_SEARCH, tol=tollerance, history_size=HISTORY_SIZE)
        for (train_batch, objective_batch) in train_dataset:
            args = {'x': train_batch, 'ground_truth': objective_batch}
            # NOTE(review): each batch restarts from the pre-LBFGS `params`,
            # keeping only the last batch's result — verify this is intended.
            params_lbfgs, _ = solver.run(params, **args)
        loss_train = compute_dataset_loss_freq(params_lbfgs, train_dataset, static_params)
        if not jnp.isnan(loss_train):
            lbfgs_run = False
        else:
            print('LBFGS returned nan, re-running...')
            tollerance += tollerance * 0.5
        # NOTE(review): these prints evaluate `params` (pre-LBFGS), not
        # params_lbfgs — the reported "LBFGS results" are the pre-refinement
        # losses.
        print("\nLBFGS results:")
        print(f"\tTrain Loss: {compute_dataset_loss_freq(params,train_dataset,static_params)}")
        print(f"\tTest loss:{compute_dataset_loss_freq(params,test_dataset,static_params)}")
        params = params_lbfgs
    return params
def compute_transform_each_samples_freq(dataset, params, static_params):
    """Run the batched hard level-crossing pipeline on every batch and collect
    the transformed signals in a flat list."""
    transformed = []
    for train_batch, _ in dataset:
        transformed.extend(
            batched_proc_pipeline_freq(params, train_batch, static_params))
    return transformed
# ------------------------ HELPER FUNCTIONS ------------------------
#
#
#
#-------------------------------------------------------------------
def custom_collate_fn(batch):
    """Collate a list of (sample, objective) pairs into two stacked numpy arrays.

    Used as the DataLoader collate_fn so batches arrive as plain ndarrays
    instead of torch tensors.
    """
    columns = list(zip(*batch))
    stacked_data = np.array(columns[0])
    stacked_obj = np.array(columns[1])
    return stacked_data, stacked_obj
class hashabledict(dict):
    """Dict whose hash is derived from its sorted (key, value) pairs.

    Lets a dict-like object be used where hashability is required; keys must
    be sortable and both keys and values hashable.
    """

    def __hash__(self):
        ordered_items = sorted(self.items())
        return hash(tuple(ordered_items))
class hashable_np_array(np.ndarray):
    """ndarray view type that is hashable via a coarse content fingerprint.

    The hash is the array mean scaled by 1e9 and truncated to int — a weak
    but cheap identity for use as a (quasi-)static argument.
    """

    def __hash__(self):
        fingerprint = self.mean() * 1_000_000_000
        return int(fingerprint)
def compute_transform_and_plot_beat(beat, ground_truth, params, sigma, static_params, t_hf, t_lf, name):
    """Plot one beat through the Gaussian level-crossing pipeline.

    Saves three figures:
      - `name`: original beat, objective samples, Gaussian level-crossed
        signal, filtered output, and the level lines;
      - `name + "_samples.svg"`: same, with the sinc-interpolated samples
        instead of the raw filtered trace;
      - `name + "_f.svg"`: magnitude spectra (positive-frequency half) of the
        beat vs the filtered output.
    t_hf / t_lf are the high-frequency and objective-rate time bases.
    """
    gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(beat, params['mus'], sigma)
    #gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
    filtered = proc.linear_transform_fourier_domain(gaussian_lvl_crossing_data, params['freq_coeffs'])
    #interpolated = proc.sinc_interpolation_freq_parametrized(filtered,FREQ,FREQ,static_params['time_base'])
    #normed = interpolated #proc.normalize(interpolated)
    # Figure 1: time-domain overview with the filtered trace.
    plt.figure()
    plt.plot(t_hf, beat)
    plt.plot(t_lf, ground_truth, "d")
    plt.plot(t_hf, gaussian_lvl_crossing_data)
    plt.plot(t_hf[:len(filtered)], filtered)
    plt.hlines(params['mus'], t_hf[0], t_hf[-1])
    plt.savefig(name)
    plt.close()
    interpolated = proc.sinc_interpolation_freq_parametrized(filtered, FREQ, FREQ, static_params['time_base'])
    normed = interpolated #proc.normalize(interpolated)
    # Figure 2: same overview, with the interpolated samples on the low-rate base.
    plt.figure()
    plt.plot(t_hf, beat)
    plt.plot(t_lf, ground_truth, "d")
    plt.plot(t_hf, gaussian_lvl_crossing_data)
    plt.plot(t_lf, normed, "o")
    plt.hlines(params['mus'], t_hf[0], t_hf[-1])
    plt.savefig(name + "_samples.svg")
    plt.close()
    # Figure 3: positive-frequency magnitude spectra, beat vs filtered output.
    beat_freq = np.fft.fft(beat)[:len(beat) // 2]
    trans_freq = np.fft.fft(filtered)[:len(filtered) // 2]
    plt.figure()
    plt.plot(abs(beat_freq))
    plt.plot(abs(trans_freq))
    plt.savefig(name + "_f.svg")
    plt.close()
def compute_pure_lc_transform_and_plot_beat(beat, ground_truth, params, static_params, t_hf, t_lf, name):
    """Plot one beat through the hard (non-Gaussian) level-crossing pipeline.

    Same three figures as compute_transform_and_plot_beat (`name`,
    `name + "_samples.svg"`, `name + "_f.svg"`), but using
    proc.lvl_crossing_stupid with the frozen levels in static_params['mus'].
    """
    lvl_crossing_data = proc.lvl_crossing_stupid(beat, static_params['mus'])
    #lvl_crossing_data = proc.normalize(lvl_crossing_data)
    filtered = proc.linear_transform_fourier_domain(lvl_crossing_data, params['freq_coeffs'])
    # Figure 1: time-domain overview with the filtered trace.
    plt.figure()
    plt.plot(t_hf, beat)
    plt.plot(t_lf, ground_truth, "d")
    plt.plot(t_hf, lvl_crossing_data)
    plt.plot(t_hf[:len(filtered)], filtered)
    plt.hlines(static_params['mus'], t_hf[0], t_hf[-1])
    plt.savefig(name)
    plt.close()
    interpolated = proc.sinc_interpolation_freq_parametrized(filtered, FREQ, FREQ, static_params['time_base'])
    normed = interpolated #proc.normalize(interpolated)
    # Figure 2: same overview, with the interpolated samples on the low-rate base.
    plt.figure()
    plt.plot(t_hf, beat)
    plt.plot(t_lf, ground_truth, "d")
    plt.plot(t_hf, lvl_crossing_data)
    plt.plot(t_lf, normed, "o")
    plt.hlines(static_params['mus'], t_hf[0], t_hf[-1])
    plt.savefig(name + "_samples.svg")
    plt.close()
    # Figure 3: positive-frequency magnitude spectra, beat vs filtered output.
    beat_freq = np.fft.fft(beat)[:len(beat) // 2]
    trans_freq = np.fft.fft(filtered)[:len(filtered) // 2]
    plt.figure()
    plt.plot(abs(beat_freq))
    plt.plot(abs(trans_freq))
    plt.savefig(name + "_f.svg")
    plt.close()
def save_env(res_folder_this_run):
    """Dump the run-configuration snapshot (str_env_to_save) to settings.txt
    inside the given results folder."""
    settings_path = os.path.join(res_folder_this_run, 'settings.txt')
    with open(settings_path, 'w') as f:
        f.write(str_env_to_save)
def save_params_state(res_folder_this_run, params):
    """Write every parameter as a 'name: value' line to params.txt inside the
    given results folder."""
    out_path = os.path.join(res_folder_this_run, 'params.txt')
    with open(out_path, 'w') as f:
        for key in params:
            line = f"{key}: {str(params[key])}\n"
            f.write(line)
# ------------------------ MAIN LOOP ------------------------
#
#
#
#------------------------------------------------------------
def main():
    """Run the full experiment.

    Phase 1: joint level/filter learning with a Gaussian relaxation of level
    crossing, annealing sigma over SIGMA_EPOCH rounds. Phase 2: filter-only
    learning with the hard level crossing at the found levels. All figures
    and final parameters are saved under a time-stamped folder in res_folder.
    """
    #Setting env:
    if not os.path.isdir(res_folder):
        os.mkdir(res_folder)
    t_stamp = time.ctime()
    res_folder_this_run = os.path.join(res_folder, t_stamp)
    os.mkdir(res_folder_this_run)
    save_env(res_folder_this_run)
    # NOTE(review): jax is already imported at module load, so setting
    # XLA_FLAGS here is likely too late to change the device count — confirm.
    os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={BATCH_SIZE}'
    #load/generate data
    print("Loading data")
    data = dataLoader.get_signal(SIGNAL_TYPE, num_pts=NUM_SAMPLES_TEST + NUM_SAMPLES_TRAIN, freq=FREQ)
    if data is None:
        print("Failed loading data")
        return
    #Only even number of points allowed, why? 'cause reasons
    if data.shape[1] % 2 != 0:
        data = data[:, :-1]
    data = proc.normalize_dataset(data)
    # Time base at the original (high) sampling rate, truncated to the signal length.
    t_base_orig = np.arange(0, (len(data[0])) / FREQ, 1 / FREQ)[:len(data[0])]
    print("\n-------------------\n")
    print(f"Loaded Data:shape {data.shape}")
    #reference
    print("Computing maximum Nyquist frequency for the Dataset...")
    nyq_freq_true = proc.get_nyquist_freq_dataset(data, FREQ, USE_ENERGY_MAX_FREQ_EVAL, STOPPING_PERC)
    print(f"Nyquist frequency: {nyq_freq_true}")
    print("Generating Nyquist sampled objective dataset")
    #CHANGING THIS TO BASICALLY IGNORE ALL NQ. SHIT
    dataset_nyq = data #proc.interpolate_dataset(data, FREQ, nyq_freq)
    nyq_freq = FREQ #... ugly (objective kept at the original rate)
    t_base_nyq = np.arange(0, len(dataset_nyq[0]) / nyq_freq, 1 / nyq_freq)[:len(dataset_nyq[0])]
    # Pair every input signal with its objective (here identical, see above).
    train_dataset = [[d, o] for d, o in zip(data[:NUM_SAMPLES_TRAIN], dataset_nyq[:NUM_SAMPLES_TRAIN])]
    test_dataset = [[d, o] for d, o in zip(data[NUM_SAMPLES_TRAIN:], dataset_nyq[NUM_SAMPLES_TRAIN:])]
    train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn, drop_last=False)
    test_loader = DataLoader(test_dataset, BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn, drop_last=False)
    dataset = [train_loader, test_loader]
    #------------------------------------ GAUSSIAN/FREQUENCY LEARNING WITH TRUE LVL CROSSING ------------------------------------
    #
    #
    #
    #----------------------------------------------------------------------------------------------------------------------------
    #INIT
    print("Initializing pipeline Parameters")
    sigma = float((jnp.max(data) - jnp.min(data)) / (INITIAL_SIGMA_DIVISOR * NUM_LVLS))
    #We can pass the length of the nyquist samples signals as length of the filter
    params = init_pipeline(NUM_LVLS, np.min(data), np.max(data), len_filter=int(t_base_nyq[-1] * nyq_freq_true * FILTER_LEN_WRT_NYQ), len_signal=len(data[0])) # len(data[0])
    static_params = {'time_base': None, 'lambda_maxima': None, 'lambda_freq': None, 'ideal_filter_len': None} #hashabledict({'freq' : None, 'time_base' : None, 'lambda' : None})
    static_params['time_base'] = t_base_nyq #.view(hashable_np_array)
    static_params['lambda_maxima'] = LAMBDA_MAXIMA
    static_params['lambda_freq'] = LAMBDA_FREQUENCY
    static_params['ideal_filter_len'] = int(t_base_nyq[-1] * nyq_freq_true // 2)
    #static_params = hashabledict(static_params)
    print(f"Parameters: {params}\n")
    loss_train_sigmas = []
    loss_test_sigmas = []
    lr = LR
    sigma_iter = 0
    # Sigma-annealing loop: train, plot best/worst beats, shrink sigma and lr.
    for i in range(SIGMA_EPOCH):
        print("%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/%/")
        print(f"SIGMA EPOCH: {i+1}/{SIGMA_EPOCH}, sigma: {sigma}\n")
        params = train_and_test(dataset, params, sigma, static_params, lr)
        loss_train = dataset_loss(params, train_loader, sigma, static_params)
        loss_test = dataset_loss(params, test_loader, sigma, static_params)
        loss_train_sigmas.append(loss_train)
        loss_test_sigmas.append(loss_test)
        # -------------------------- TRAIN ------------------------------
        losses = loss_each_samples(params, train_loader, sigma, static_params)
        best_beat = train_dataset[np.argmin(losses)][0]
        best_beat_nyq = train_dataset[np.argmin(losses)][1]
        worst_beat = train_dataset[np.argmax(losses)][0]
        worst_beat_nyq = train_dataset[np.argmax(losses)][1]
        name = f'{res_folder_this_run}/TRAIN_{sigma_iter}:sigma:{sigma}_best_loss:{np.min(losses)}.svg'
        compute_transform_and_plot_beat(best_beat, best_beat_nyq, params, sigma, static_params, t_base_orig, t_base_nyq, name)
        name = f'{res_folder_this_run}/TRAIN_{sigma_iter}:sigma:{sigma}_worst_loss:{np.max(losses)}.svg'
        compute_transform_and_plot_beat(worst_beat, worst_beat_nyq, params, sigma, static_params, t_base_orig, t_base_nyq, name)
        # -------------------------- TEST ------------------------------
        losses = loss_each_samples(params, test_loader, sigma, static_params)
        best_beat = test_dataset[np.argmin(losses)][0]
        best_beat_nyq = test_dataset[np.argmin(losses)][1]
        worst_beat = test_dataset[np.argmax(losses)][0]
        worst_beat_nyq = test_dataset[np.argmax(losses)][1]
        name = f'{res_folder_this_run}/TEST_{sigma_iter}:sigma:{sigma}_best_loss:{np.min(losses)}.svg'
        compute_transform_and_plot_beat(best_beat, best_beat_nyq, params, sigma, static_params, t_base_orig, t_base_nyq, name)
        name = f'{res_folder_this_run}/TEST_{sigma_iter}:sigma:{sigma}_worst_loss:{np.max(losses)}.svg'
        compute_transform_and_plot_beat(worst_beat, worst_beat_nyq, params, sigma, static_params, t_base_orig, t_base_nyq, name)
        # Snapshot of the learned filter (magnitude and phase).
        plt.figure()
        plt.matshow(abs(params['freq_coeffs']))
        plt.colorbar()
        plt.savefig(f'{res_folder_this_run}/{sigma_iter}:sigma:{sigma}_filter_abs.svg')
        plt.close()
        plt.figure()
        plt.matshow(jnp.angle(params['freq_coeffs']))
        plt.colorbar()
        plt.savefig(f'{res_folder_this_run}/{sigma_iter}:sigma:{sigma}_filter_phs.svg')
        plt.close()
        # Anneal: narrower Gaussians (harder level crossing) and lower lr.
        sigma /= SIGMA_EPOCH_DIVISOR
        lr /= LR_DIVISOR
        sigma_iter += 1
        print('\n--------------------------------------------------------------------')
        print(f"END OF SIGMA EPOCH {i+1}, TRAIN LOSS = {loss_train}, TEST LOSS = {loss_test}")
        print(f"Levels:{params['mus']}")
        print('--------------------------------------------------------------------\n')
    # Loss vs sigma-epoch summary plot.
    plt.figure()
    plt.plot(loss_train_sigmas)
    plt.plot(loss_test_sigmas)
    plt.legend(['train', 'test'])
    plt.savefig(f'{res_folder_this_run}/lossVSepoch.svg')
    #plt.show()
    #----------------------------------------- FREQUENCY LEARNING WITH TRUE LVL CROSSING -----------------------------------------
    #
    #
    #
    #
    #-----------------------------------------------------------------------------------------------------------------------------
    print(f"SPECTRAL LEARNING FOR THE FOUND LEVELS")
    #INIT
    print("Initializing pipeline Parameters")
    # Freeze the learned levels into static_params; only the filter stays learnable.
    static_params = {'mus': params['mus'], #.view(hashable_np_array),
                     'time_base': static_params['time_base'],
                     'lambda_freq': static_params['lambda_freq'],
                     'ideal_filter_len': static_params['ideal_filter_len']}
    #static_params = hashabledict(static_params)
    params = {'freq_coeffs': params['freq_coeffs']}
    params = train_and_test_freq(dataset, params, static_params, LR_FREQ_LEARNING)
    # -------------------------- TRAIN ------------------------------
    losses = loss_each_samples_freq(params, train_loader, static_params)
    best_beat = train_dataset[np.argmin(losses)][0]
    best_beat_nyq = train_dataset[np.argmin(losses)][1]
    worst_beat = train_dataset[np.argmax(losses)][0]
    worst_beat_nyq = train_dataset[np.argmax(losses)][1]
    #print(f"TRAIN: Average number of events: {np.average(num_samples)}+-{np.std(num_samples)}")
    name = f'{res_folder_this_run}/TRAIN_freq_learning_level_crossing_best_loss:{np.min(losses)}.svg'
    compute_pure_lc_transform_and_plot_beat(best_beat, best_beat_nyq, params, static_params, t_base_orig, t_base_nyq, name)
    name = f'{res_folder_this_run}/TRAIN_freq_learning_level_crossing_worst_loss:{np.max(losses)}.svg'
    compute_pure_lc_transform_and_plot_beat(worst_beat, worst_beat_nyq, params, static_params, t_base_orig, t_base_nyq, name)
    # -------------------------- TEST ------------------------------
    losses = loss_each_samples_freq(params, test_loader, static_params)
    best_beat = test_dataset[np.argmin(losses)][0]
    best_beat_nyq = test_dataset[np.argmin(losses)][1]
    worst_beat = test_dataset[np.argmax(losses)][0]
    worst_beat_nyq = test_dataset[np.argmax(losses)][1]
    #print(f"TEST: Average number of events: {np.average(num_samples)}+-{np.std(num_samples)}")
    name = f'{res_folder_this_run}/TEST_freq_learning_level_crossing_best_loss:{np.min(losses)}.svg'
    compute_pure_lc_transform_and_plot_beat(best_beat, best_beat_nyq, params, static_params, t_base_orig, t_base_nyq, name)
    name = f'{res_folder_this_run}/TEST_freq_learning_level_crossing_worst_loss:{np.max(losses)}.svg'
    compute_pure_lc_transform_and_plot_beat(worst_beat, worst_beat_nyq, params, static_params, t_base_orig, t_base_nyq, name)
    # Final filter snapshot after spectral-only learning.
    plt.figure()
    plt.matshow(abs(params['freq_coeffs']))
    plt.colorbar()
    plt.savefig(f'{res_folder_this_run}/freq_learning_filter_abs.svg')
    plt.close()
    plt.figure()
    plt.matshow(jnp.angle(params['freq_coeffs']))
    plt.colorbar()
    plt.savefig(f'{res_folder_this_run}/freq_learning_filter_phs.svg')
    plt.close()
    # Persist the final filter and the frozen levels (key typo kept: it is a
    # runtime string written to params.txt).
    save_params_state(res_folder_this_run, {'filter_coefficeints': params['freq_coeffs'], 'levels': static_params['mus']})
# Script entry point.
if __name__ == "__main__":
    #test_signal()
    main()
Event Timeline
Log In to Comment