diff --git a/PySONIC/constants.py b/PySONIC/constants.py index 12de47c..e16fce8 100644 --- a/PySONIC/constants.py +++ b/PySONIC/constants.py @@ -1,80 +1,81 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2016-11-04 13:23:31 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-07-22 10:59:17 +# @Last Modified time: 2020-08-08 15:39:42 ''' Numerical constants used in the package. ''' # -------------------------- Biophysical constants -------------------------- FARADAY = 9.64853e4 # Faraday constant (C/mol) Rg = 8.31342 # Universal gas constant (Pa.m^3.mol^-1.K^-1 or J.mol^-1.K^-1) Z_Ca = 2 # Calcium valence Z_Na = 1 # Sodium valence Z_K = 1 # Potassium valence CELSIUS_2_KELVIN = 273.15 # Celsius to Kelvin conversion constant # -------------------------- Intermolecular pressure fitting -------------------------- LJFIT_PM_MAX = 1e8 # Pm value at the deflection lower bound for LJ fitting (Pa) PNET_EQ_MAX = 1e-1 # Pnet error threshold at computed equilibrium position (Pa) PMAVG_STD_ERR_MAX = 5e3 # error threshold in intermolecular pressure nonlinear fit (Pa) # -------------------------- Lookups pre-computing -------------------------- DQ_LOOKUP = 1e-5 # charge density interval step for lookup tables # -------------------------- Simulations -------------------------- MAX_RMSE_PTP_RATIO = 1e-4 # threshold RMSE / peak-to-peak ratio for periodic convergence Z_ERR_MAX = 1e-11 # periodic convergence threshold for deflection (m) NG_ERR_MAX = 1e-24 # periodic convergence threshold for gas content (mol) NCYCLES_MAX = 10 # max number of cycles in periodic simulations CHARGE_RANGE = (-300e-5, 150e-5) # physiological charge range constraining the membrane (C/m2) SOLVER_NSTEPS = 1000 # max number of steps during one ODE solver call CLASSIC_TARGET_DT = 1e-8 # target time step in output arrays of detailed simulations NPC_DENSE = 1000 # nb of samples per acoustic period in detailed simulations NPC_SPARSE = 40 # nb of samples per acoustic period in sparse simulations MIN_SPARSE_DT = 1e-12 # minimal time step used during sparse integration (s) HYBRID_UPDATE_INTERVAL = 5e-4 # time interval between two hybrid integrations (s) DT_EFFECTIVE = 5e-5 # time step for effective integration (s) MIN_SAMPLES_PER_PULSE_INTERVAL = 1 # minimal number of time points per pulse interval (TON of TOFF) MAX_NSAMPLES_EFFECTIVE = 1e5 # maximum number of time samples in effective simulations output # -------------------------- Post-processing -------------------------- +DT_MAX_REL_TOL = 1e-5 # max relative tolerance for time step irregularity SPIKE_MIN_DT = 5e-4 # minimal time interval for spike detection on charge signal (s) SPIKE_MIN_QAMP = 3e-5 # threshold amplitude for spike detection on charge signal (C/m2) SPIKE_MIN_QPROM = 20e-5 # threshold prominence for spike detection on charge signal (C/m2) SPIKE_MIN_VAMP = 3.0 # threshold amplitude for spike detection on potential signal (mV) SPIKE_MIN_VPROM = 20.0 # threshold prominence for spike detection on potential signal (mV) MIN_NSPIKES_SPECTRUM = 3 # minimum number of spikes to compute firing rate spectrum # -------------------------- Titrations -------------------------- ESTIM_AMP_UPPER_BOUND = 1e5 # initial current density upper bound for titration (mA/m2) ESTIM_AMP_INITIAL = 1e0 # initial ESTIM titration amplitude (mA/m2) ESTIM_REL_CONV_THR = 1e-2 # relative ESTIM titration convergence threshold ASTIM_AMP_INITIAL = 1e4 # initial ASTIM titration amplitude (Pa) ASTIM_ABS_CONV_THR = 1e2 # absolute ASTIM titration convergence threshold (Pa) ASTIM_REL_CONV_THR = 1e0 # relative ASTIM titration convergence threshold (Pa) # -------------------------- QSS stability analysis -------------------------- QSS_REL_OFFSET = .05 # relative state perturbation amplitude: s = s0 * (1 +/- x) QSS_HISTORY_INTERVAL = 30e-3 # recent history interval (s) QSS_INTEGRATION_INTERVAL = 1e-3 # iterative integration interval (s) QSS_MAX_INTEGRATION_DURATION = 1000e-3 # iterative integration interval (s) QSS_Q_CONV_THR = 1e-7 # max. charge deviation to infer convergence (C/m2) QSS_Q_DIV_THR = 1e-4 # min. charge deviation to infer divergence (C/m2) TMIN_STABILIZATION = 500e-3 # time window for stabilization analysis (s) def getConstantsDict(): cdict = {} for k, v in globals().items(): if not k.startswith('__') and k != 'getConstantsDict': cdict[k] = v return cdict diff --git a/PySONIC/postpro.py b/PySONIC/postpro.py index 8f3b79b..04f0bf6 100644 --- a/PySONIC/postpro.py +++ b/PySONIC/postpro.py @@ -1,409 +1,423 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-08-22 14:33:04 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-04-28 18:57:40 +# @Last Modified time: 2020-08-08 15:45:47 ''' Utility functions to detect spikes on signals and compute spiking metrics. ''' import numpy as np import pandas as pd from scipy.interpolate import interp1d from scipy.optimize import brentq -from scipy.signal import find_peaks, peak_prominences +from scipy.signal import find_peaks, peak_prominences, butter, sosfiltfilt from .constants import * from .utils import logger, isIterable, loadData def detectCrossings(x, thr=0.0, edge='both'): ''' Detect crossings of a threshold value in a 1D signal. :param x: 1D array_like data. :param edge: 'rising', 'falling', or 'both' :return: 1D array with the indices preceding the crossings ''' ine, ire, ife = np.array([[], [], []], dtype=int) x_padright = np.hstack((x, x[-1])) x_padleft = np.hstack((x[0], x)) if edge.lower() in ['falling', 'both']: ire = np.where((x_padright <= thr) & (x_padleft > thr))[0] if edge.lower() in ['rising', 'both']: ife = np.where((x_padright >= thr) & (x_padleft < thr))[0] ind = np.unique(np.hstack((ine, ire, ife))) - 1 return ind def getFixedPoints(x, dx, filter='stable', der_func=None): ''' Find fixed points in a 1D plane phase profile. :param x: variable (1D array) :param dx: derivative (1D array) :param filter: string indicating whether to consider only stable/unstable fixed points or both :param: der_func: derivative function :return: array of fixed points values (or None if none is found) ''' fps = [] edge = {'stable': 'falling', 'unstable': 'rising', 'both': 'both'}[filter] izc = detectCrossings(dx, edge=edge) if izc.size > 0: for i in izc: # If derivative function is provided, find root using iterative Brent method if der_func is not None: fps.append(brentq(der_func, x[i], x[i + 1], xtol=1e-16)) # Otherwise, approximate root by linear interpolation else: fps.append(x[i] - dx[i] * (x[i + 1] - x[i]) / (dx[i + 1] - dx[i])) return np.array(fps) else: return np.array([]) def getEqPoint1D(x, dx, x0): ''' Determine equilibrium point in a 1D plane phase profile, for a given starting point. :param x: variable (1D array) :param dx: derivative (1D array) :param x0: abscissa of starting point (float) :return: abscissa of equilibrium point (or np.nan if none is found) ''' # Find stable fixed points in 1D plane phase profile x_SFPs = getFixedPoints(x, dx, filter='stable') if x_SFPs.size == 0: return np.nan # Determine relevant stable fixed point from y0 sign y0 = np.interp(x0, x, dx, left=np.nan, right=np.nan) inds_subset = x_SFPs >= x0 ind_SFP = 0 if y0 < 0: inds_subset = ~inds_subset ind_SFP = -1 x_SFPs = x_SFPs[inds_subset] if len(x_SFPs) == 0: return np.nan return x_SFPs[ind_SFP] def convertTime2SampleCriterion(x, dt, nsamples): if isIterable(x) and len(x) == 2: return (convertTime2Sample(x[0], dt, nsamples), convertTime2Sample(x[1], dt, nsamples)) else: if isIterable(x) and len(x) == nsamples: return np.array([convertTime2Sample(item, dt, nsamples) for item in x]) elif x is None: return None else: return int(np.ceil(x / dt)) def computeTimeStep(t): ''' Compute time step based on time vector. :param t: time vector (s) :return: average time step (s) ''' - # Compute time step vector dt = np.diff(t) # s # Raise error if time step vector is not uniform - is_uniform_dt = np.allclose(np.diff(dt), np.zeros(dt.size - 1), atol=1e-5) - if not is_uniform_dt: - bounds_str = [f'{dt[i]:.2e} s (index {i})' for i in [dt.argmin(), dt.argmax()]] - raise ValueError(f'non-uniform time step: from {bounds_str[0]} to {bounds_str[1]}') + rel_dt_var = (dt.max() - dt.min()) / dt.min() + if rel_dt_var > DT_MAX_REL_TOL: + raise ValueError(f'irregular time step (rel. variance = {rel_dt_var:.2e})') # Return average dt value return np.mean(dt) # s def resample(t, y, dt): ''' Resample a dataframe at regular time step. ''' n = int(np.ptp(t) / dt) + 1 ts = np.linspace(t.min(), t.max(), n) ys = np.interp(ts, t, y) return ts, ys def resolveIndexes(indexes, y, choice='max'): if indexes.size == 0: return indexes icomp = np.array([np.floor(indexes), np.ceil(indexes)]).astype(int).T ycomp = np.array([y[i] for i in icomp]) method = {'min': np.argmin, 'max': np.argmax}[choice] ichoice = method(ycomp, axis=1) return np.array([x[ichoice[i]] for i, x in enumerate(icomp)]) def resampleDataFrame(data, dt): ''' Resample a dataframe at regular time step. ''' t = data['t'].values n = int(np.ptp(t) / dt) + 1 tnew = np.linspace(t.min(), t.max(), n) new_data = {} for key in data: kind = 'nearest' if key == 'stimstate' else 'linear' new_data[key] = interp1d(t, data[key].values, kind=kind)(tnew) return pd.DataFrame(new_data) def prependDataFrame(data, tonset=0.): ''' Add an initial value (for t = 0) to all columns of a dataframe. ''' # Repeat first row data = pd.concat([pd.DataFrame([data.iloc[0]]), data], ignore_index=True) data['t'][0] = tonset data['stimstate'][0] = 0 return data def boundDataFrame(data, tbounds): ''' Restrict all columns of a dataframe to indexes corresponding to time values within specific bounds. ''' tmin, tmax = tbounds return data[np.logical_and(data.t >= tmin, data.t <= tmax)].reset_index(drop=True) def find_tpeaks(t, y, **kwargs): ''' Wrapper around the scipy.signal.find_peaks function that provides a time vector associated to the signal, and translates time-based selection criteria into index-based criteria before calling the function. :param t: time vector :param y: signal vector :return: 2-tuple with peaks timings and properties dictionary ''' # Remove initial samples from vectors if time values are redundant ipad = 0 while t[ipad + 1] == t[ipad]: ipad += 1 if ipad > 0: ss = 'from vectors (redundant time values)' if ipad == 1: logger.debug(f'Removing index 0 {ss}') else: logger.debug(f'Removing indexes 0-{ipad - 1} {ss}') t = t[ipad:] y = y[ipad:] # If time step is irregular, resample vectors at a uniform time step try: dt = computeTimeStep(t) # s t_raw, y_raw = None, None indexes_raw = None except ValueError: new_dt = max(np.diff(t).min(), 1e-7) logger.debug(f'Resampling vector at regular time step (dt = {new_dt:.2e}s)') t_raw, y_raw = t.copy(), y.copy() indexes_raw = np.arange(t_raw.size) t, y = resample(t, y, new_dt) dt = computeTimeStep(t) # s # Convert provided time-based input criteria into samples-based criteria for key in ['distance', 'width', 'wlen', 'plateau_size']: if key in kwargs: kwargs[key] = convertTime2SampleCriterion(kwargs[key], dt, t.size) if 'width' not in kwargs: kwargs['width'] = 1 # Find peaks in the regularly sampled signal ipeaks, pps = find_peaks(y, **kwargs) # Adjust peak prominences and bases with restricted analysis window length # based on smallest peak width if len(ipeaks) > 0: wlen = 5 * min(pps['widths']) pps['prominences'], pps['left_bases'], pps['right_bases'] = peak_prominences( y, ipeaks, wlen=wlen) # If needed, re-project index-based outputs onto original sampling if t_raw is not None: logger.debug(f're-projecting index-based outputs onto original sampling') # Interpolate peak indexes and round to neighbor integer with max y value ipeaks_raw = np.interp(t[ipeaks], t_raw, indexes_raw, left=np.nan, right=np.nan) ipeaks = resolveIndexes(ipeaks_raw, y_raw, choice='max') # Interpolate peak base indexes and round to neighbor integer with min y value for key in ['left_bases', 'right_bases']: if key in pps: ibase_raw = np.interp( t[pps[key]], t_raw, indexes_raw, left=np.nan, right=np.nan) pps[key] = resolveIndexes(ibase_raw, y_raw, choice='min') # Interpolate peak half-width interpolated positions for key in ['left_ips', 'right_ips']: if key in pps: pps[key] = np.interp( dt * pps[key], t_raw, indexes_raw, left=np.nan, right=np.nan) # If original vectors were cropped, correct offset in index-based outputs if ipad > 0: logger.debug(f'offseting index-based outputs by {ipad} to compensate initial cropping') ipeaks += ipad for key in ['left_bases', 'right_bases', 'left_ips', 'right_ips']: if key in pps: pps[key] += ipad # Convert index-based peak widths into time-based widths if 'widths' in pps: pps['widths'] = np.array(pps['widths']) * dt # Return updated properties return ipeaks, pps def detectSpikes(data, key='Qm', mpt=SPIKE_MIN_DT, mph=SPIKE_MIN_QAMP, mpp=SPIKE_MIN_QPROM): ''' Detect spikes in simulation output data, by detecting peaks with specific height, prominence and distance properties on a given signal. :param data: simulation output dataframe :param key: key of signal on which to detect peaks :param mpt: minimal time interval between two peaks (s) :param mph: minimal peak height (in signal units) :param mpp: minimal peak prominence (in signal units) :return: indexes and properties of detected spikes ''' if key not in data: raise ValueError(f'{key} vector not available in dataframe') # Detect peaks return find_tpeaks( data['t'].values, data[key].values, height=mph, distance=mpt, prominence=mpp ) def convertPeaksProperties(t, properties): ''' Convert index-based peaks properties into time-based properties. :param t: time vector (s) :param properties: properties dictionary (with index-based information) :return: properties dictionary (with time-based information) ''' indexes = np.arange(t.size) for key in ['left_bases', 'right_bases', 'left_ips', 'right_ips']: if key in properties: properties[key] = np.interp(properties[key], indexes, t, left=np.nan, right=np.nan) return properties def computeFRProfile(data): ''' Compute temporal profile of firing rate from simulaton output. :param data: simulation output dataframe :return: firing rate profile interpolated along time vector ''' # Detect spikes in data ispikes, _ = detectSpikes(data) if len(ispikes) == 0: return np.ones(len(data)) * np.nan # Compute firing rate as function of spike time t = data['t'].values tspikes = t[ispikes][:-1] sr = 1 / np.diff(t[ispikes]) if len(sr) == 0: return np.ones(t.size) * np.nan # Interpolate firing rate vector along time vector return np.interp(t, tspikes, sr, left=np.nan, right=np.nan) def computeSpikingMetrics(outputs): ''' Analyze the charge density profile from a list of files and compute for each one of them the following spiking metrics: - latency (ms) - firing rate mean and standard deviation (Hz) - spike amplitude mean and standard deviation (nC/cm2) - spike width mean and standard deviation (ms) :param outputs: list / generator of simulation outputs :return: a dataframe with the computed metrics ''' # Initialize metrics dictionaries keys = [ 'latencies (ms)', 'mean firing rates (Hz)', 'std firing rates (Hz)', 'mean spike amplitudes (nC/cm2)', 'std spike amplitudes (nC/cm2)', 'mean spike widths (ms)', 'std spike widths (ms)' ] metrics = {k: [] for k in keys} # Compute spiking metrics for output in outputs: # Load data if isinstance(output, str): data, meta = loadData(output) else: data, meta = output tstim = meta['pp'].tstim t = data['t'].values # Detect spikes in data and extract features ispikes, properties = detectSpikes(data) widths = properties['widths'] prominences = properties['prominences'] if ispikes.size > 0: # Compute latency latency = t[ispikes[0]] # Select prior-offset spikes ispikes_prior = ispikes[t[ispikes] < tstim] else: latency = np.nan ispikes_prior = np.array([]) # Compute spikes widths and amplitude if ispikes_prior.size > 0: widths_prior = widths[:ispikes_prior.size] prominences_prior = prominences[:ispikes_prior.size] else: widths_prior = np.array([np.nan]) prominences_prior = np.array([np.nan]) # Compute inter-spike intervals and firing rates if ispikes_prior.size > 1: ISIs_prior = np.diff(t[ispikes_prior]) FRs_prior = 1 / ISIs_prior else: ISIs_prior = np.array([np.nan]) FRs_prior = np.array([np.nan]) # Log spiking metrics logger.debug('%u spikes detected (%u prior to offset)', ispikes.size, ispikes_prior.size) logger.debug('latency: %.2f ms', latency * 1e3) logger.debug('average spike width within stimulus: %.2f +/- %.2f ms', np.nanmean(widths_prior) * 1e3, np.nanstd(widths_prior) * 1e3) logger.debug('average spike amplitude within stimulus: %.2f +/- %.2f nC/cm2', np.nanmean(prominences_prior) * 1e5, np.nanstd(prominences_prior) * 1e5) logger.debug('average ISI within stimulus: %.2f +/- %.2f ms', np.nanmean(ISIs_prior) * 1e3, np.nanstd(ISIs_prior) * 1e3) logger.debug('average FR within stimulus: %.2f +/- %.2f Hz', np.nanmean(FRs_prior), np.nanstd(FRs_prior)) # Complete metrics dictionaries metrics['latencies (ms)'].append(latency * 1e3) metrics['mean firing rates (Hz)'].append(np.mean(FRs_prior)) metrics['std firing rates (Hz)'].append(np.std(FRs_prior)) metrics['mean spike amplitudes (nC/cm2)'].append(np.mean(prominences_prior) * 1e5) metrics['std spike amplitudes (nC/cm2)'].append(np.std(prominences_prior) * 1e5) metrics['mean spike widths (ms)'].append(np.mean(widths_prior) * 1e3) metrics['std spike widths (ms)'].append(np.std(widths_prior) * 1e3) # Return dataframe with metrics return pd.DataFrame(metrics, columns=metrics.keys()) + + +def filtfilt(y, fs, fc, order): + ''' Apply a bi-directional low-pass filter of specific order and cutoff frequency to a signal. + + :param y: signal vector + :param fs: sampling frequency + :param fc: cutoff frequency + :param order: filter order (must be even) + :return: filtered signal vector + + ..note: the filter order is divided by 2 since filtering is applied twice. + ''' + assert order % 2 == 0, 'filter order must be an even integer' + sos = butter(order // 2, fc, 'low', fs=fs, output='sos') + return sosfiltfilt(sos, y) diff --git a/scripts/plot_Cm_filtering.py b/scripts/plot_Cm_filtering.py new file mode 100644 index 0000000..e38c8b1 --- /dev/null +++ b/scripts/plot_Cm_filtering.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +# @Author: Theo Lemaire +# @Email: theo.lemaire@epfl.ch +# @Date: 2020-07-30 15:33:22 +# @Last Modified by: Theo Lemaire +# @Last Modified time: 2020-08-08 15:46:35 + +import logging +import numpy as np +from argparse import ArgumentParser +import matplotlib.pyplot as plt +from PySONIC.core import BilayerSonophore, AcousticDrive +from PySONIC.utils import logger +from PySONIC.postpro import filtfilt, computeTimeStep +from PySONIC.constants import NPC_DENSE + +logger.setLevel(logging.INFO) + +# Constants +MAX_PROFILES = 6 # max number of profiles to display simultaneously on figure + + +def invfiltfilt(y, *args, **kwargs): + ''' Inverse signal before and after filtering. ''' + return 1 / filtfilt(1 / y, *args, **kwargs) + + +def getCmProfiles(bls, drive, nreps): + ''' Simulate mechanical model with a specific drive and return extended + time, capacitance and capacitance sinusoidal approximation profiles. + ''' + data, _ = bls.simulate(drive, bls.Qm0) + logger.info('Extracting detailed capacitance profile') + Z_last = data.tail(NPC_DENSE)['Z'].values # m + Cm_last = bls.v_capacitance(Z_last) # F/m2 + Cm = np.tile(Cm_last, nreps) + t = np.linspace(0, nreps / drive.f, Cm.size) + gamma = np.ptp(Cm) / (2 * bls.Cm0) + logger.info(f'Generating corresponding pure sinusoid capacitance profile (gamma = {gamma:.2f})') + Cm_approx = bls.Cm0 * (1 + gamma * np.sin(2 * np.pi * f_US * t)) # F/m2 + return t, Cm, Cm_approx + + +def getSecondHalfAvg(x): + ''' Extract the effective capacitance from the second half of a capacitance profile. ''' + return np.squeeze(np.nanmean(x[x.shape[0] // 2:], axis=0)) + + +def plotRelCmfiltsVsCutoff(rel_fcs, rel_Cm, rel_Cmfilts, condition): + ''' Plot an original Cm profile and filtered variants at various cutoff frequencies. ''' + rsf = int(np.ceil(rel_fcs.size / MAX_PROFILES)) # potential resampling factor + colors = plt.get_cmap('tab10').colors + fig, ax = plt.subplots(figsize=(10, 4)) + ax.set_title(f'Cm profiles vs. cutoff ({condition})') + ax.set_xlabel('time (us)') + ax.set_ylabel('Cm / Cm0') + ax.plot(t * 1e6, rel_Cm, label='unfiltered', c='k') + ax.axhline(np.mean(rel_Cm), c='k', linestyle='--') + ax.axhline(1 / np.mean(1 / rel_Cm), c='k', linestyle=':') + for i, (rel_fc, rel_Cmfilt) in enumerate(zip(rel_fcs[::rsf], rel_Cmfilts[::rsf])): + ax.plot(t * 1e6, rel_Cmfilt, label=f'$f_c = {rel_fc:.2g}\\ f_{{US}}$', c=colors[i]) + ax.axhline(getSecondHalfAvg(rel_Cmfilt), c=colors[i], linestyle='--') + ax.legend() + fig.tight_layout() + return fig + + +def plotRelCmeffVsCutoff(rel_fcs, rel_Cmavgs, rel_Cmeffs, rel_Cmfilts, condition, colors=None): + ''' Plot effective capacitance as a function of cutoff frequency for various conditions. ''' + fig, ax = plt.subplots() + if colors is None: + colors = plt.get_cmap('tab10').colors + ax.set_title(f'Cmeff vs. cutoff - {condition}') + ax.set_xlabel('$f_c / f_{US}$') + ax.set_ylabel('$C_{m, eff} / C_{m0}$') + ax.set_xscale('log') + for (k, Cm), c in zip(rel_Cmfilts.items(), colors): + ax.plot(rel_fcs, getSecondHalfAvg(Cm.T), label=k, c=c) + ax.axhline(rel_Cmavgs[k], linestyle='--', c=c) + ax.axhline(rel_Cmeffs[k], linestyle=':', c=c) + # if gamma is not None: + # ax.axhline(np.sqrt(1 - gamma**2), c='k', linestyle='--', label='$\\sqrt{1 - \\gamma^2}$') + ax.legend() + fig.tight_layout() + return fig + + +if __name__ == '__main__': + + ap = ArgumentParser() + ap.add_argument('-p', '--plot', default=False, action='store_true', help='Plot profiles') + args = ap.parse_args() + + # Mechanical model + a = 32e-9 # m + Cm0 = 1e-2 # resting capacitance (F/m2) + Qm0 = 0.0 # resting charge density (C/m2) + bls = BilayerSonophore(a, Cm0, Qm0) + + # Acoustic parameters + freqs = np.array([20., 500., 4000.]) * 1e3 # Hz + amps = np.logspace(1, 3, 3) * 1e3 # Pa + + # Define colors + colors = list(plt.get_cmap('tab20c').colors) + del colors[3::4] + amps = amps[::-1] + + # Filter parameters + order = 2 # filter order + rel_fcs = np.logspace(-1, 3, 100) # relative cutoff frequencies w.r.t. US frequency + nreps = int(2 / rel_fcs.min()) # minimum number of acoustic cycles + + # Plot parameters + plot_profiles = args.plot + + variants = ['detailed', 'approx'] + rel_Cmavgs = {k: {} for k in variants} + rel_Cmeffs = {k: {} for k in variants} + rel_Cmfilts = {k: {} for k in variants} + + for f_US in freqs: + fcs = rel_fcs * f_US + for A_US in amps: + drive = AcousticDrive(f_US, A_US) + label = drive.desc + + # Get original Cm signal and sinusoidal approximation + t, *Cms = getCmProfiles(bls, drive, nreps) + + # Get sampling and Nyquist frequency from time signal + fs = 1 / computeTimeStep(t) + fnyq = fs / 2 + + # Warn if Nyquist frequency is lower than max cutoff frequency + if fcs.max() > fnyq: + logger.warning( + f'max cutoff {fcs.max() / fnyq:.2f} times higher than signal Nyquist') + + # For each Cm profile variant + for k, Cm in zip(variants, Cms): + # Normalize by resting capacitance + rel_Cm = Cm / Cm0 + + # Compute average and effective metrics + rel_Cmavgs[k][label] = rel_Cm.mean() + rel_Cmeffs[k][label] = 1 / np.mean(1 / rel_Cm) + + # Filter reciprocal at various cutoff frequencies (except those above nyquist) + rel_Cmfilts_list = [] + for fc in fcs[fcs <= fnyq]: + rel_Cmfilts_list.append(invfiltfilt(rel_Cm, fs, fc, order)) + for fc in fcs[fcs > fnyq]: + rel_Cmfilts_list.append(np.ones(rel_Cm.size) * np.nan) + rel_Cmfilts[k][label] = np.array(rel_Cmfilts_list) + + if plot_profiles: + # Plot filtered profiles for a subset of cutoff frequencies + plotRelCmfiltsVsCutoff(rel_fcs, rel_Cm, rel_Cmfilts[k][label], label) + + for k in variants: + # Plot effective capacitance as a function of cutoff frequency for various conditions + plotRelCmeffVsCutoff( + rel_fcs, rel_Cmavgs[k], rel_Cmeffs[k], rel_Cmfilts[k], k, colors=colors) + + plt.show()