Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F83481503
gaussain_collpsing.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Tue, Sep 17, 09:24
Size
10 KB
Mime Type
text/x-python
Expires
Thu, Sep 19, 09:24 (2 d)
Engine
blob
Format
Raw Data
Handle
20845618
Attached To
R12761 LC-Sampling-Theory
gaussain_collpsing.py
View Options
from
functools
import
partial
import
os
import
jax
import
jax.numpy
as
jnp
import
matplotlib.pyplot
as
plt
import
numpy
as
np
from
jax
import
grad
,
jit
,
pmap
,
value_and_grad
,
vmap
from
jax.scipy.special
import
logsumexp
from
torch.utils.data
import
DataLoader
import
optax
import
dataset_generator
as
dataLoader
import
processing
as
proc
# ---------------------------------------------------------------------------
# Data generation / loading
# ---------------------------------------------------------------------------
# Sampling frequency passed to dataLoader.get_signal, the Nyquist estimation
# and the sinc interpolation step of proc_pipeline.
FREQ = 1_000_000
# Number of signals requested from dataLoader.get_signal in main().
NUM_SAMPLES = 50
# DataLoader batch size; main() also forces this many XLA host devices so
# that pmap can place one sample per device.
BATCH_SIZE = 10
# Forwarded to proc.get_nyquist_freq_dataset (exact meaning defined in
# `processing` — presumably an energy-based max-frequency criterion; confirm).
USE_ENERGY_MAX_FREQ_EVAL = True
# Forwarded to proc.get_nyquist_freq_dataset; semantics defined in
# `processing` — TODO confirm unit/meaning.
STOPPING_PERC = 4_000
# Not referenced in this file's visible code.
INTERPOLATOR = 'sinc'
# 'create': create a new emulated ECG signal, 'load': load previous ECG
# signal, 'random': random sinusoidal mix signal.
SIGNAL_TYPE = 'create'

# ---------------------------------------------------------------------------
# Optimisation schedule
# ---------------------------------------------------------------------------
# Adam epochs run by every call to train().
EPOCHS = 20
# Number of outer "sigma epochs"; sigma is divided by SIGMA_EPOCH_DIVISOR
# after each one.
SIGMA_EPOCH = 50
SIGMA_EPOCH_DIVISOR = 1.1
# The initial sinc frequency parameter is nyq_freq / INIT_FREQ_DIVISOR.
INIT_FREQ_DIVISOR = 4
# Number of Gaussian levels (entries of params['mus']).
NUM_LVLS = 16
# Initial Adam learning rate.
LR = 0.005

# ---------------------------------------------------------------------------
# Run mode / output
# ---------------------------------------------------------------------------
# True: run the interactive single-signal preview instead of training.
DEV_MODE = False
# Parent directory for per-run result folders.
imag_res_folder = "../img_res"
def init_pipeline(gaussian_numbers, range_min, range_max, freq):
    """Build the initial trainable parameters of the processing pipeline.

    The Gaussian centres ("mus") are spread evenly over
    [range_min, range_max], pulled in from each end by one level width,
    ``(range_max - range_min) / gaussian_numbers``.

    Args:
        gaussian_numbers: number of Gaussian levels (centres) to create;
            must be >= 1.
        range_min: lower end of the signal range.
        range_max: upper end of the signal range.
        freq: initial frequency parameter, stored unchanged under 'freq'
            (proc_pipeline passes it to the sinc interpolation step).

    Returns:
        dict with 'mus' (1-D jnp array of length ``gaussian_numbers``) and
        'freq'.

    Raises:
        ValueError: if ``gaussian_numbers`` < 1.
    """
    if gaussian_numbers < 1:
        raise ValueError(f"gaussian_numbers must be >= 1, got {gaussian_numbers}")
    # Derive the margin from the requested level count so the placement stays
    # consistent when gaussian_numbers differs from NUM_LVLS (e.g. DEV_MODE).
    margin = (range_max - range_min) / gaussian_numbers
    mu_s = jnp.linspace(range_min + margin, range_max - margin, gaussian_numbers)
    return {'mus': mu_s, 'freq': freq}
def proc_pipeline(params, x, sigma, static_params):
    """Run one signal through the Gaussian level-crossing + sinc pipeline.

    Steps: proc.mix_gaussian_lvl_crossing on ``x`` with centres
    ``params['mus']`` and width ``sigma`` -> proc.normalize ->
    proc.sinc_interpolation_freq_parametrized at frequency ``params['freq']``
    (source rate FREQ) evaluated on ``static_params['time_base']`` ->
    proc.normalize.

    Args:
        params: dict with trainable 'mus' and 'freq'.
        x: a single input signal.
        sigma: width passed to the Gaussian level-crossing step.
        static_params: mapping providing 'time_base' (its 'freq' entry is not
            read here).

    Returns:
        The normalized, resampled signal.
    """
    level_repr = proc.normalize(
        proc.mix_gaussian_lvl_crossing(x, params['mus'], sigma)
    )
    resampled = proc.sinc_interpolation_freq_parametrized(
        level_repr,
        params['freq'],
        FREQ,
        static_params['time_base'],
    )
    return proc.normalize(resampled)
# proc_pipeline mapped across devices along axis 0 of `x` (one signal per
# device). params and sigma are broadcast unmapped; static_params (positional
# argument 3) is a static broadcast argument, so it must be hashable.
batched_proc_pipeline = pmap(
    proc_pipeline,
    in_axes=(None, 0, None, None),
    static_broadcasted_argnums=(3,),
)

# proc.RMSE mapped across devices along axis 0 of both arguments: one RMSE
# value per (prediction, target) pair.
batched_RMSE = pmap(proc.RMSE)
def batched_loss_fn(params, data, nyquist_sampled_data, sigma, static_params):
    """Return the batch-mean RMSE between pipeline output and targets.

    Args:
        params: trainable pipeline parameters (differentiated by train()).
        data: batch of input signals, mapped along axis 0.
        nyquist_sampled_data: matching batch of target signals.
        sigma: Gaussian width forwarded to the pipeline.
        static_params: static (hashable) pipeline parameters.

    Returns:
        Scalar mean of the per-sample RMSE values.
    """
    predictions = batched_proc_pipeline(params, data, sigma, static_params)
    per_sample_rmse = batched_RMSE(predictions, nyquist_sampled_data)
    return jnp.mean(per_sample_rmse)
def compute_loss(dataset, params, sigma, static_params):
    """Return the mean of the per-batch losses over ``dataset``.

    Each batch contributes ``batched_loss_fn(...)`` (its mean per-sample
    RMSE) with equal weight, regardless of batch size.

    Args:
        dataset: iterable of ``(batch, objective_batch)`` pairs.
        params: pipeline parameters.
        sigma: Gaussian width forwarded to the pipeline.
        static_params: static (hashable) pipeline parameters.

    Returns:
        The average batch loss.

    Raises:
        ValueError: if ``dataset`` yields no batches (previously this
            surfaced as an unexplained ZeroDivisionError).
    """
    total = 0
    n_batches = 0
    for n_batches, (batch, objective_batch) in enumerate(dataset, start=1):
        total += batched_loss_fn(params, batch, objective_batch, sigma, static_params)
    if n_batches == 0:
        raise ValueError("compute_loss: dataset yielded no batches")
    return total / n_batches
def train(train_dataset, params, sigma, static_params, lr=LR):
    """Optimize ``params`` with Adam for EPOCHS passes over ``train_dataset``.

    A fresh ``optax.adam(lr)`` optimizer is created on every call, so no
    optimizer state carries over between calls. ``sigma`` and
    ``static_params`` stay fixed. The dataset loss (compute_loss) is printed
    before training and after every epoch.

    Args:
        train_dataset: iterable of ``(batch, objective_batch)`` pairs,
            re-iterated once per epoch.
        params: initial pipeline parameters (pytree).
        sigma: Gaussian width forwarded to the pipeline.
        static_params: static (hashable) pipeline parameters.
        lr: Adam learning rate.

    Returns:
        The updated parameters.
    """
    print(f"Loss: {compute_loss(train_dataset,params,sigma,static_params)}")
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)
    # Gradient of the batch loss w.r.t. its first argument (params).
    loss_gradient = grad(batched_loss_fn)
    for epoch_idx in range(EPOCHS):
        print(f"Epoch {epoch_idx+1}")
        for inputs, targets in train_dataset:
            param_grads = loss_gradient(params, inputs, targets, sigma, static_params)
            updates, opt_state = optimizer.update(param_grads, opt_state)
            params = optax.apply_updates(params, updates)
        print(f"Loss: {compute_loss(train_dataset,params,sigma,static_params)}")
    return params
def compute_loss_each_samples(dataset, params, sigma, static_params):
    """Return a flat list with one RMSE value per sample.

    Values are in the order ``dataset`` yields its samples. If the loader
    shuffles, the indices do not correspond to the underlying data order.

    Args:
        dataset: iterable of ``(batch, objective_batch)`` pairs.
        params: pipeline parameters.
        sigma: Gaussian width forwarded to the pipeline.
        static_params: static (hashable) pipeline parameters.
    """
    per_sample = []
    for inputs, targets in dataset:
        predictions = batched_proc_pipeline(params, inputs, sigma, static_params)
        per_sample.extend(batched_RMSE(predictions, targets))
    return per_sample
def compute_transform_each_samples(train_dataset, params, sigma, static_params, lr=LR):
    """Return the pipeline output for every sample, flattened across batches.

    Args:
        train_dataset: iterable of ``(batch, objective_batch)`` pairs; the
            objectives are ignored.
        params: pipeline parameters.
        sigma: Gaussian width forwarded to the pipeline.
        static_params: static (hashable) pipeline parameters.
        lr: unused; kept for signature compatibility with train().

    Returns:
        List with one transformed signal per sample, in dataset order.
    """
    transformed = []
    for inputs, _targets in train_dataset:
        transformed.extend(
            batched_proc_pipeline(params, inputs, sigma, static_params)
        )
    return transformed
def custom_collate_fn(batch):
    """Collate a list of ``(signal, target)`` pairs into two stacked arrays.

    Used as the DataLoader ``collate_fn``. It keeps batches as NumPy arrays
    (presumably so they can be handed straight to JAX).

    Returns:
        ``(signals, targets)``, each stacked along a new leading axis.
    """
    columns = tuple(zip(*batch))
    signals, targets = np.array(columns[0]), np.array(columns[1])
    return signals, targets
class hashabledict(dict):
    """A ``dict`` that can be hashed.

    In this script it carries the pipeline's static parameters through
    pmap's ``static_broadcasted_argnums``. The hash is built from the
    ``(key, value)`` pairs, so every value must itself be hashable. Equality
    is ordinary dict equality, so equal dicts hash equally. Mutating an
    instance after it has been hashed (e.g. used as a cache key) changes its
    hash; treat instances as read-only once in use.
    """

    def __hash__(self):
        # frozenset is order-independent and, unlike sorting the items, does
        # not require the keys to be mutually orderable (e.g. int and str).
        return hash(frozenset(self.items()))
class hashable_np_array(np.ndarray):
    """NumPy array subclass that is hashable.

    Obtain one with ``arr.view(hashable_np_array)``. This lets the array sit
    inside a hashable static argument. The hash covers shape, dtype and raw
    bytes, so arrays with identical content hash identically. ``==`` is still
    NumPy's element-wise comparison.
    """

    def __hash__(self):
        # A mean-based hash collides for any arrays with equal mean and
        # raises ValueError for NaN/empty arrays (int(nan)); hashing the raw
        # bytes avoids both.
        return hash((self.shape, self.dtype.str, self.tobytes()))
def _run_dev_preview():
    """Interactively preview the level-crossing + sinc pipeline on one signal.

    Prompts for the base frequency, the number of levels, the sigma divisor
    and the sinc-frequency divisor. It then shows (does not save) a plot of:
    the original signal, its Nyquist-resampled version, the normalized
    Gaussian level-crossing representation, its sinc-resampled version and
    the level positions.
    """
    freq_desired = int(input("What base frequency do you need your data: "))
    mus_num = int(input("How many levels: "))
    sigma_div = int(input("Sigma fraction of the signal range (1/X): "))
    nyq_freq_div = int(input("Sync freq in fraction of nyq. freq: "))
    data = dataLoader.get_signal(SIGNAL_TYPE, num_pts=1, freq=freq_desired)
    if data is None:
        print("Failed loading data")
        return
    print("Computing Nyquist frequency...")
    nyq_freq = proc.get_nyquist_freq_dataset(
        data, freq_desired, USE_ENERGY_MAX_FREQ_EVAL, STOPPING_PERC)
    t_base_nyq = np.arange(0, len(data[0]) / freq_desired, 1 / nyq_freq)
    t_base_orig = np.arange(0, len(data[0]) / freq_desired, 1 / freq_desired)
    print(f"Nyquist frequency: {nyq_freq}")
    print("Generating Nyquist sampled objective")
    dataset_nyq = proc.interpolate_dataset(data, freq_desired, nyq_freq)
    print("Initializing pipelines Parameters")
    sigma = float((jnp.max(data) - jnp.min(data)) / sigma_div)
    params = init_pipeline(mus_num, np.min(data), np.max(data), nyq_freq / nyq_freq_div)
    print(f"Parameters: {params}\nSigma: {sigma}")
    print("Computing mixed gaussian representation")
    gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[0], params['mus'], sigma)
    print("Normalizing")
    gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
    print("Computing sync filter")
    resampled_gaussians = proc.sinc_interpolation_freq_parametrized(
        gaussian_lvl_crossing_data, params['freq'], freq_desired, t_base_orig)
    print("Normalizing")
    resampled_gaussians = proc.normalize(resampled_gaussians)
    print("Plotting")
    plt.figure()
    plt.plot(t_base_orig, data[0])
    plt.plot(t_base_nyq, dataset_nyq[0], "d")
    plt.plot(t_base_orig, gaussian_lvl_crossing_data)
    plt.plot(t_base_orig, resampled_gaussians[:len(t_base_orig)], "-")
    plt.hlines(params['mus'], t_base_orig[0], t_base_orig[-1])
    plt.show()


def _save_sample_plot(signal, nyq_signal, params, sigma, t_base_orig, t_base_nyq, out_path):
    """Plot one sample's level-crossing reconstruction and save it to ``out_path``.

    Draws the original signal, its Nyquist-resampled target (diamonds), the
    normalized Gaussian level-crossing representation, and its sinc-resampled
    version on ``t_base_nyq`` (circles). It also draws horizontal lines at the
    level positions. The figure is closed afterwards.
    """
    gaussian_lvl_crossing_data = proc.normalize(
        proc.mix_gaussian_lvl_crossing(signal, params['mus'], sigma))
    resampled_gaussians = proc.normalize(
        proc.sinc_interpolation_freq_parametrized(
            gaussian_lvl_crossing_data, params['freq'], FREQ, t_base_nyq))
    plt.figure()
    plt.plot(t_base_orig, signal)
    plt.plot(t_base_nyq, nyq_signal, "d")
    plt.plot(t_base_orig, gaussian_lvl_crossing_data)
    plt.plot(t_base_nyq, resampled_gaussians[:len(t_base_nyq)], "o-")
    plt.hlines(params['mus'], t_base_orig[0], t_base_orig[-1])
    plt.savefig(out_path)
    plt.close()


def main():
    """Train the Gaussian level-crossing + sinc pipeline and save diagnostic plots.

    With DEV_MODE set, runs the interactive single-signal preview and returns.
    Otherwise it:

    - creates a per-run folder under ``imag_res_folder``;
    - loads NUM_SAMPLES signals and builds Nyquist-resampled targets;
    - runs SIGMA_EPOCH rounds. Each round trains with the current sigma and
      lr, records the dataset loss, and saves plots of the best- and
      worst-fitted samples. It then divides sigma by SIGMA_EPOCH_DIVISOR and
      lowers lr by 1/20;
    - saves the loss curve and shows a plot for one sample.
    """
    import time  # local: only used to name this run's result folder

    if DEV_MODE:
        _run_dev_preview()
        return

    # Setting env:
    os.makedirs(imag_res_folder, exist_ok=True)
    # os.times().elapsed is always 0 on Windows, which made every run try to
    # create the same folder; wall-clock time gives a unique name everywhere.
    t_stamp = str(time.time())
    res_folder_this_run = os.path.join(imag_res_folder, t_stamp)
    os.mkdir(res_folder_this_run)
    # pmap needs at least BATCH_SIZE devices. This flag only takes effect if
    # it is set before JAX initializes its backend (i.e. before the first JAX
    # computation).
    os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={BATCH_SIZE}'

    # Load/generate data.
    print("Loading data")
    data = dataLoader.get_signal(SIGNAL_TYPE, num_pts=NUM_SAMPLES, freq=FREQ)
    if data is None:
        print("Failed loading data")
        return
    data = proc.normalize_dataset(data)
    print("\n-------------------\n")
    print(f"Loaded Data:shape {data.shape}")

    # Reference (objective) signals sampled at the Nyquist frequency.
    print("Computing maximum Nyquist frequency for the Dataset...")
    nyq_freq = proc.get_nyquist_freq_dataset(data, FREQ, USE_ENERGY_MAX_FREQ_EVAL, STOPPING_PERC)
    t_base_nyq = np.arange(0, len(data[0]) / FREQ, 1 / nyq_freq)
    t_base_orig = np.arange(0, len(data[0]) / FREQ, 1 / FREQ)
    print(f"Nyquist frequency: {nyq_freq}")
    print("Generating Nyquist sampled objective dataset")
    dataset_nyq = proc.interpolate_dataset(data, FREQ, nyq_freq)
    train_dataset = [[d, o] for d, o in zip(data, dataset_nyq)]

    # Static pmap arguments must be hashable, hence the custom dict/array types.
    static_params = hashabledict({
        'freq': nyq_freq,
        'time_base': t_base_nyq.view(hashable_np_array),
    })

    print("Initializing pipelines Parameters")
    sigma = float((jnp.max(data) - jnp.min(data)) / (2.2 * NUM_LVLS))
    params = init_pipeline(NUM_LVLS, np.min(data), np.max(data), nyq_freq / INIT_FREQ_DIVISOR)
    train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True,
                              collate_fn=custom_collate_fn, drop_last=False)
    # Unshuffled loader for evaluation. Per-sample losses computed from it are
    # in the same order as `data`, so argmin/argmax are valid indices into
    # `data` / `dataset_nyq` (with the shuffled loader they were not).
    eval_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=False,
                             collate_fn=custom_collate_fn, drop_last=False)
    print(f"Parameters: {params}")

    loss_sigmas = []
    lr = LR
    for i in range(SIGMA_EPOCH):
        print(f"SIGMA EPOCH: {i+1}/{SIGMA_EPOCH}, sigma: {sigma}")
        params = train(train_loader, params, sigma, static_params, lr)
        loss = compute_loss(eval_loader, params, sigma, static_params)
        loss_sigmas.append(loss)
        losses = compute_loss_each_samples(eval_loader, params, sigma, static_params)
        best_beat = np.argmin(losses)
        worst_beat = np.argmax(losses)
        _save_sample_plot(
            data[best_beat], dataset_nyq[best_beat], params, sigma,
            t_base_orig, t_base_nyq,
            f'{res_folder_this_run}/sigma:{sigma}_best_loss:{np.min(losses)}.svg')
        # Plot the worst sample's own signal/target (previously the best
        # sample's signal was drawn here by mistake).
        _save_sample_plot(
            data[worst_beat], dataset_nyq[worst_beat], params, sigma,
            t_base_orig, t_base_nyq,
            f'{res_folder_this_run}/sigma:{sigma}_worst_loss:{np.max(losses)}.svg')
        sigma /= SIGMA_EPOCH_DIVISOR
        lr -= lr / 20
        print(f"END OF SIGMA EPOCH {i+1}, LOSS = {loss}")
    print(params)

    plt.figure()
    plt.plot(loss_sigmas)
    plt.savefig(f'{res_folder_this_run}/lossVSepoch.svg')
    plt.close()

    # Testing grounds: interactive plot of one sample with the final params.
    # NOTE(review): `sigma` was divided once more after the last training
    # round, so this uses a sigma the final params were not trained with.
    pt = 3
    # Create the Gaussian level crossing:
    mu_s = params['mus']
    gaussian_lvl_crossing_data = proc.mix_gaussian_lvl_crossing(data[pt], mu_s, sigma)
    gaussian_lvl_crossing_data = proc.normalize(gaussian_lvl_crossing_data)
    resampled_gaussians = proc.sinc_interpolation_freq_parametrized(
        gaussian_lvl_crossing_data, params['freq'], FREQ, t_base_orig)
    resampled_gaussians = proc.normalize(resampled_gaussians)
    plt.plot(t_base_orig, data[pt])
    plt.plot(t_base_nyq, dataset_nyq[pt], "d")
    plt.plot(t_base_orig, gaussian_lvl_crossing_data)
    plt.plot(t_base_orig, resampled_gaussians[:len(t_base_orig)], "-")
    plt.show()
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
Event Timeline
Log In to Comment