Page MenuHomec4science

processing.py
No OneTemporary

File Metadata

Created
Wed, Aug 14, 17:33

processing.py

import numpy as np
import numpy.fft as fft
import jax.numpy as jnp
from jax import lax,vmap
from functools import partial
# Default interpolator name used by `interpolate` and `interpolate_dataset`.
INTERPOLATOR: str = "sinc"
# Source of the ECG signal:
#   'create' -> synthesize a new emulated ECG signal
#   'load'   -> load a previously stored ECG signal
#   'random' -> random mix of sinusoids
SIGNAL_TYPE: str = "create"
def get_max_freq(signal, freq, use_energy, stopping_prec):
    """Estimate the highest significant frequency present in ``signal``.

    The mean is removed and the one-sided spectrum (the first
    ``len(signal)//2`` DFT bins, all at non-negative frequencies) is taken.
    Per-bin weights are accumulated from the highest bin downwards. The first
    bin at which the running total exceeds a threshold is reported.

    Args:
        signal: 1-D numpy array of samples.
        freq: Sampling frequency of ``signal``.
        use_energy: If True, the weights are ``|X_k|**2`` and the threshold is
            the total one-sided energy divided by ``stopping_prec``.
            Otherwise the weights are ``|X_k|`` and the threshold is the
            largest single magnitude divided by ``stopping_prec``.
        stopping_prec: Divisor applied to the threshold. A larger value
            lowers the threshold, so the reported frequency tends to rise.

    Returns:
        Frequency (same units as ``freq``) of the selected bin. If the
        threshold is never exceeded, this is the frequency of bin 0 (0.0).
    """
    spectrum = fft.fft(signal - np.mean(signal))
    half = len(spectrum) // 2
    spectrum = spectrum[:half]
    bin_freqs = np.fft.fftfreq(signal.size, d=1 / freq)[:half]

    if use_energy:
        weights = np.abs(spectrum) ** 2
        threshold = sum(weights) / stopping_prec
    else:
        weights = np.abs(spectrum)
        threshold = max(weights) / stopping_prec

    # Running total of the weights, starting at the highest-frequency bin.
    tail_totals = np.cumsum(weights[::-1])
    exceeded = tail_totals > threshold
    if not exceeded.any():
        return bin_freqs[0]
    first_hit = int(np.argmax(exceeded))
    return bin_freqs[half - 1 - first_hit]
def get_nyquist_freq_dataset(data, freq, use_energy, stopping_perc):
    """Return a sampling frequency large enough for every signal in ``data``.

    ``get_max_freq`` is evaluated on each signal. The largest result, floored
    at 0, is multiplied by 2.1 (= 2 x 1.05, presumably the Nyquist factor
    plus a 5% safety margin).

    Args:
        data: Iterable of 1-D signals sampled at ``freq``.
        freq: Sampling frequency of the signals.
        use_energy: Forwarded to ``get_max_freq``.
        stopping_perc: Forwarded to ``get_max_freq`` as ``stopping_prec``.
    """
    per_signal = (get_max_freq(pt, freq, use_energy, stopping_perc) for pt in data)
    return 2.1 * max([0, *per_signal])
def normalize(signal):
    """Scale ``signal`` so that its peak-to-peak range becomes 1.

    The signal is divided by ``max - min``. It is NOT shifted, so the result
    is not necessarily in [0, 1]. A constant signal (zero range) is returned
    unchanged instead of being divided by zero, which would yield NaN for an
    all-zero input or +/-inf otherwise.
    """
    span = jnp.max(signal) - jnp.min(signal)
    # jnp.where (not a Python `if`) keeps this traceable under jit/vmap.
    return signal / jnp.where(span == 0, 1, span)
def normalize_dataset(data):
    """Apply ``normalize`` to each signal in ``data`` and stack the results.

    Args:
        data: Iterable of equally shaped signals.

    Returns:
        A ``jnp`` array whose leading axis indexes the signals.
    """
    return jnp.array([normalize(pt) for pt in data])
# ----------------------------------
def gaussian(x, mu, sig):
    """Unnormalized Gaussian bump ``exp(-(x - mu)**2 / (2 * sig**2))``.

    Equals 1 at ``x == mu``. No ``1 / (sig * sqrt(2*pi))`` prefactor is
    applied. Inputs broadcast element-wise.
    """
    sq_dist = jnp.power(x - mu, 2.)
    two_var = 2 * jnp.power(sig, 2.)
    return jnp.exp(-sq_dist / two_var)
def gaussian_weight_diff(x, mu, sig):
    """Gaussian bump whose variance is scaled by the forward difference of ``x``.

    Computes, element-wise,
    ``exp(-(x - mu)**2 / (2 * 4e12 * sig**2 * dx))`` with
    ``dx = jnp.diff(x, append=0)``. The last entry of ``dx`` is therefore
    ``0 - x[-1]``.

    NOTE(review): ``dx`` is a signed absolute difference, not a relative one
    (an earlier TODO already questioned this). Where ``dx`` is negative the
    exponent turns positive, so values can exceed 1 or overflow. Where ``dx``
    is zero the division yields +/-inf or NaN. Confirm this is intended.
    """
    dx = jnp.diff(x, append=0)
    sq_dist = jnp.power(x - mu, 2.)
    scale = 2 * 4e12 * jnp.power(sig, 2.) * dx
    return jnp.exp(-sq_dist / scale)
def mix_gaussian_lvl_crossing(signal, mu_s, sigma):
    """Weight ``signal`` by a sum of Gaussian bumps centred on the levels ``mu_s``.

    Each sample is multiplied by ``sum_k gaussian(signal, mu_s[k], sigma)``.
    Samples whose value is close to one of the levels keep weight near 1 or
    more; the others are attenuated toward 0. This presumably serves as a
    smooth counterpart of ``lvl_crossing_stupid`` (TODO confirm).

    Args:
        signal: Signal values.
        mu_s: 1-D array of levels (mapped over with ``vmap``).
        sigma: Bump width, shared by all levels.
    """
    def bump(level):
        return gaussian(signal, level, sigma)

    # Leading axis of `weights` indexes the levels in mu_s.
    weights = vmap(bump)(mu_s)
    return signal * weights.sum(axis=0)
def mix_gaussian_lvl_crossing_weight_diff(signal, mu_s, sigma):
    """Like ``mix_gaussian_lvl_crossing`` but with ``gaussian_weight_diff`` bumps.

    Each sample is multiplied by
    ``sum_k gaussian_weight_diff(signal, mu_s[k], sigma)``.

    Args:
        signal: Signal values.
        mu_s: 1-D array of levels (mapped over with ``vmap``).
        sigma: Bump width, shared by all levels.
    """
    def bump(level):
        return gaussian_weight_diff(signal, level, sigma)

    # Leading axis of `weights` indexes the levels in mu_s.
    weights = vmap(bump)(mu_s)
    return signal * weights.sum(axis=0)
def lvl_crossing_stupid(x, mus):
    """Keep only the samples of ``x`` at which some level in ``mus`` is crossed.

    Sample ``i`` (``i >= 1``) is kept when, for at least one level ``L``,
    either ``x[i-1] < L <= x[i]`` (upward crossing) or
    ``x[i-1] > L >= x[i]`` (downward crossing). Sample 0 and every
    non-crossing sample are replaced by 0.

    Args:
        x: 1-D signal.
        mus: 1-D array of crossing levels.

    Returns:
        Array of the same length as ``x`` with non-crossing samples zeroed.
    """
    signal = jnp.asarray(x)
    levels = jnp.asarray(mus)[:, None]   # (n_levels, 1) -> broadcasts over time
    before, after = signal[:-1], signal[1:]
    upward = (before < levels) & (after >= levels)
    downward = (before > levels) & (after <= levels)
    crossed = jnp.any(upward | downward, axis=0)
    # The first sample has no predecessor, so it can never be a crossing.
    crossed = jnp.concatenate((jnp.array([False]), crossed))
    return jnp.where(crossed, signal, jnp.zeros((len(x),)))
def mix_gaussian_lvl_crossing_dataset(data, mu_s, sigma):
    """Apply ``mix_gaussian_lvl_crossing`` to every row of ``data``.

    Only ``data`` is mapped over its leading axis. ``mu_s`` and ``sigma`` are
    shared by all rows.
    """
    per_row = vmap(mix_gaussian_lvl_crossing, in_axes=(0, None, None))
    return per_row(data, mu_s, sigma)
def sinc_interpolation_freq_parametrized(samples, sinc_freq, signal_freq, time_base):
    """Evaluate a sum of shifted sincs through ``samples`` at the times ``time_base``.

    Sample ``n`` is placed at time ``n * (1 / signal_freq)``. The result is
    ``sum_n samples[n] * sinc((time_base - n / signal_freq) * sinc_freq)``,
    using the normalized ``jnp.sinc``. With ``sinc_freq == signal_freq``
    this is classic Whittaker-Shannon interpolation.

    Returns:
        Array with one value per entry of ``time_base``.
    """
    step = 1 / signal_freq
    # Rows: sample index n; columns: output time.
    sample_times = jnp.arange(len(samples))[:, None] * step
    kernel = jnp.sinc((jnp.asarray(time_base)[None, :] - sample_times) * sinc_freq)
    return jnp.dot(samples, kernel)
def multi_sinc_interpolation(samples, sinc_freqs, ampls, signal_freq, time_base):
    """Evaluate an amplitude-weighted mix of sinc bases through ``samples``.

    Component ``k`` uses bandwidth ``sinc_freqs[k]`` and weight ``ampls[k]``.
    All components are summed:

        out(t) = sum_k sum_n samples[n] * ampls[k]
                 * sinc((t - n / signal_freq) * sinc_freqs[k])

    where ``jnp.sinc`` is the normalized sinc. The full
    (len(samples), len(time_base), len(ampls)) basis tensor is built in a
    single broadcast expression. This avoids rebuilding (copying) the tensor
    once per component.

    Returns:
        Array with one value per entry of ``time_base``.
    """
    step = 1 / signal_freq
    amps = jnp.asarray(ampls)
    # One bandwidth per amplitude; any extra entries in sinc_freqs are ignored.
    freqs = jnp.asarray(sinc_freqs)[:amps.shape[0]]
    # Axes: 0 -> sample index n, 1 -> output time t, 2 -> component k.
    sample_times = jnp.arange(len(samples))[:, None, None] * step
    out_times = jnp.asarray(time_base)[None, :, None]
    basis = amps * jnp.sinc((out_times - sample_times) * freqs)
    # Contract over the sample axis -> (len(time_base), len(ampls)).
    per_component = jnp.tensordot(samples, basis, axes=1)
    return jnp.sum(per_component, axis=1)
def fourier_filtering(signal, coeffs):
    """Filter ``signal`` by multiplying its DFT with a mirrored coefficient vector.

    The frequency response is ``[coeffs, conj(reverse(coeffs))]``, so
    ``2 * len(coeffs)`` must equal ``len(signal)``. Only the real part of the
    inverse DFT is returned.

    NOTE(review): a real-valued filter needs ``H[N - k] == conj(H[k])``.
    This layout instead places ``conj(coeffs[k - 1])`` at bin ``N - k``
    (one bin off, and ``coeffs[0]`` reappears at bin ``N - 1``). The
    filtered spectrum is therefore generally not Hermitian, and ``jnp.real``
    discards the resulting imaginary part. Confirm this is intended.
    """
    spectrum = jnp.fft.fft(signal)
    response = jnp.concatenate((coeffs, jnp.conjugate(jnp.flip(coeffs))))
    return jnp.real(jnp.fft.ifft(spectrum * response))
def fourier_filtering_0_padding(signal, coeffs):
    """Like ``fourier_filtering``, but zero-pads ``coeffs`` to ``len(signal)//2`` first.

    Only the first ``len(coeffs)`` bins, and their mirrored counterparts at
    the end of the spectrum, get a non-zero response. All other bins are
    zeroed. The real part of the inverse DFT is returned.

    NOTE(review): same mirror layout as ``fourier_filtering``, i.e. bin
    ``N - k`` receives ``conj(padded[k - 1])`` rather than
    ``conj(padded[k])``. For odd ``len(signal)`` the response is one bin
    shorter than the spectrum and the multiply fails to broadcast.
    """
    n_missing = len(signal) // 2 - len(coeffs)
    half_response = jnp.concatenate((coeffs, jnp.zeros(n_missing)))
    full_response = jnp.concatenate(
        (half_response, jnp.conjugate(jnp.flip(half_response)))
    )
    return jnp.real(jnp.fft.ifft(jnp.fft.fft(signal) * full_response))
def fourier_filtering_no_back_transform(signal, coeffs):
    """Return the filtered spectrum of ``signal`` without inverting it.

    This is the same product as ``fourier_filtering``: the DFT of ``signal``
    times ``[coeffs, conj(reverse(coeffs))]``. It is returned as a complex
    array in the frequency domain.
    """
    response = jnp.concatenate((coeffs, jnp.conjugate(jnp.flip(coeffs))))
    return jnp.fft.fft(signal) * response
def conv(signal, coeffs):
    """Circular, centre-aligned convolution of ``signal`` with ``coeffs``.

    ``signal`` is wrapped around by ``len(coeffs)//2`` samples on the left
    and ``len(coeffs) - len(coeffs)//2 - 1`` samples on the right. The result
    is then convolved in ``'valid'`` mode, so the output has the same length
    as ``signal``. The wrap-around makes this a circular convolution. It
    therefore matches pointwise multiplication in the DFT domain, up to a
    circular shift of the kernel.
    """
    n_left = len(coeffs) // 2
    n_right = len(coeffs) - n_left - 1
    tail = signal[len(signal) - n_left:]   # empty when n_left == 0
    head = signal[:n_right]
    wrapped = jnp.concatenate((tail, signal, head))
    return jnp.convolve(wrapped, coeffs, mode='valid')
def linear_transform_fourier_domain(signal, coeffs):
    """Apply the matrix ``coeffs`` to the first half of the DFT of ``signal``.

    Steps:
      1. Take the first ``H = len(signal)//2`` DFT bins of ``signal``.
      2. Right-pad ``coeffs`` (shape ``(H, K)`` with ``K <= H``) with zero
         columns up to ``(H, H)`` and compute ``half @ padded``.
      3. Mirror the result as ``[h, conj(reverse(h))]`` and return the real
         part of its inverse DFT.

    Raises:
        AssertionError: if ``coeffs`` has more columns than there are bins.

    NOTE(review): the mirror places ``conj(h[k - 1])`` at bin ``N - k``
    (not the Hermitian ``conj(h[k])``). For odd ``len(signal)`` the output
    has length ``len(signal) - 1``.
    """
    half = jnp.fft.fft(signal)[:len(signal) // 2]
    n_extra = len(half) - coeffs.shape[1]
    assert n_extra >= 0
    padded = jnp.hstack((coeffs, jnp.zeros((coeffs.shape[0], n_extra))))
    mapped = half @ padded
    spectrum = jnp.concatenate((mapped, jnp.conjugate(jnp.flip(mapped))))
    return jnp.real(jnp.fft.ifft(spectrum))
def linear_transform_fourier_domain_no_back_transform(signal, coeffs):
    """Spectrum produced by ``linear_transform_fourier_domain``, before the inverse DFT.

    Takes the first ``H = len(signal)//2`` DFT bins, multiplies them by
    ``coeffs`` right-padded with zero columns to ``(H, H)``, and returns the
    mirrored complex spectrum ``[h, conj(reverse(h))]``.

    Raises:
        AssertionError: if ``coeffs`` has more columns than there are bins.
    """
    half = jnp.fft.fft(signal)[:len(signal) // 2]
    n_extra = len(half) - coeffs.shape[1]
    assert n_extra >= 0
    padded = jnp.hstack((coeffs, jnp.zeros((coeffs.shape[0], n_extra))))
    mapped = half @ padded
    return jnp.concatenate((mapped, jnp.conjugate(jnp.flip(mapped))))
def count_maxima(signal):
    """Count the strict interior local maxima of a 1-D ``signal``.

    A sample counts only if it is strictly greater than both neighbours.
    Endpoints and flat-topped plateaus are never counted.
    """
    centre = signal[1:-1]
    is_peak = (centre > signal[:-2]) & (centre > signal[2:])
    return jnp.sum(is_peak)
def aggregated_average_complex_error(v1, v2):
    """Mean modulus plus mean absolute phase of the difference ``v1 - v2``.

    With ``d = v1 - v2`` this returns
    ``(sum(|d|) + sum(|angle(d)|)) / len(v1)``.

    NOTE(review): the second term is the phase angle of ``d`` in radians,
    not its imaginary part, so the two terms mix magnitude and angle units.
    Confirm this is the intended error measure.
    """
    d = v1 - v2
    modulus_total = jnp.sum(jnp.abs(d))
    phase_total = jnp.sum(jnp.abs(jnp.angle(d)))
    return (modulus_total + phase_total) / len(v1)
def RMSE(v1, v2):
    """Root-mean-square error between ``v1`` and ``v2``, averaged over ``len(v1)``."""
    residual = v1 - v2
    mean_sq = jnp.sum(residual ** 2) / len(v1)
    return jnp.sqrt(mean_sq)
# ----------------------------------
def sinc_interp(samples, freq_in, freq_out):
    """Resample ``samples`` from ``freq_in`` to ``freq_out`` by sinc interpolation.

    The output grid is ``np.arange(0, len(samples) * (1/freq_in), 1/freq_out)``.
    It spans the same duration as the input, with the new spacing. The sinc
    bandwidth equals ``freq_in``.

    NOTE(review): ``np.arange`` with a float step can include or drop the
    final point due to rounding. The output length may therefore differ by
    one from ``len(samples) * freq_out / freq_in``.
    """
    duration = len(samples) * (1 / freq_in)
    out_times = np.arange(0, duration, 1 / freq_out)
    return sinc_interpolation_freq_parametrized(samples, freq_in, freq_in, out_times)
def interpolate(signal, freq_in, freq_out, type=INTERPOLATOR):
    """Resample ``signal`` from ``freq_in`` to ``freq_out``.

    Args:
        signal: 1-D signal sampled at ``freq_in``.
        freq_in: Original sampling frequency.
        freq_out: Target sampling frequency.
        type: Interpolator name. Only ``"sinc"`` is supported; any other
            value prints a message and returns ``signal`` unchanged. (The
            parameter shadows the builtin ``type`` but is kept for callers.)
    """
    if type != "sinc":
        print("No interpolator recognized")
        return signal
    return sinc_interp(signal, freq_in, freq_out)
def interpolate_dataset(data, freq_in, freq_out, type=INTERPOLATOR):
    """Resample every signal in ``data`` with ``interpolate``.

    Args:
        data: Iterable of 1-D signals sampled at ``freq_in``.
        freq_in: Original sampling frequency.
        freq_out: Target sampling frequency.
        type: Interpolator name, forwarded to ``interpolate``. (The
            parameter shadows the builtin ``type`` but is kept for callers.)

    Returns:
        ``np.array`` stacking the resampled signals.
    """
    # Forward `type` so a caller-selected interpolator is actually honoured.
    return np.array([interpolate(pt, freq_in, freq_out, type) for pt in data])

Event Timeline