diff --git a/PySONIC/constants.py b/PySONIC/constants.py index 9db9734..c51f3a0 100644 --- a/PySONIC/constants.py +++ b/PySONIC/constants.py @@ -1,55 +1,55 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-11-04 13:23:31 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-04-29 11:48:22 +# @Last Modified time: 2019-05-21 16:23:32 ''' Algorithmic constants used in the package. ''' # Biophysical constants FARADAY = 9.64853e4 # Faraday constant (C/mol) Rg = 8.31342 # Universal gas constant (Pa.m^3.mol^-1.K^-1 or J.mol^-1.K^-1) Z_Ca = 2 # Calcium valence Z_Na = 1 # Sodium valence Z_K = 1 # Potassium valence CELSIUS_2_KELVIN = 273.15 # Celsius to Kelvin conversion constant # Fitting and pre-processing LJFIT_PM_MAX = 1e8 # intermolecular pressure at the deflection lower bound for LJ fitting (Pa) PNET_EQ_MAX = 1e-1 # error threshold for net pressure at computed equilibrium position (Pa) PMAVG_STD_ERR_MAX = 3000 # error threshold in nonlinear fit of molecular pressure (Pa) # Mechanical simulations Z_ERR_MAX = 1e-11 # periodic convergence threshold for deflection (m) NG_ERR_MAX = 1e-24 # periodic convergence threshold for gas content (mol) NCYCLES_MAX = 10 # max number of acoustic cycles in mechanical simulations CHARGE_RANGE = (-200e-5, 150e-5) # physiological charge range constraining the membrane (C/m2) # E-STIM simulations -DT_ESTIM = 1e-5 +DT_ESTIM = 5e-5 # A-STIM simulations SOLVER_NSTEPS = 1000 # maximum number of steps allowed during one call to the LSODA/DOP853 solvers CLASSIC_TARGET_DT = 1e-8 # target temporal resolution for output arrays of classic simulations NPC_FULL = 1000 # nb of samples per acoustic period in full system NPC_HH = 40 # nb of samples per acoustic period in HH system DQ_UPDATE = 1e-5 # charge evolution threshold between two hybrid integrations (C/m2) DT_UPDATE = 5e-4 # time interval between two hybrid integrations (s) DT_EFF = 5e-5 # time step for effective integration (s) MIN_SAMPLES_PER_PULSE_INT = 1 # minimal number of time points per pulse interval (TON of TOFF) # Spike detection SPIKE_MIN_QAMP = 5e-5 # threshold amplitude for spike detection on charge signal (C/m2) SPIKE_MIN_QPROM = 20e-5 # threshold prominence for spike detection on charge signal (C/m2) SPIKE_MIN_VAMP = 10.0 # threshold amplitude for spike detection on potential signal (mV) SPIKE_MIN_VPROM = 20.0 # threshold prominence for spike detection on potential signal (mV) SPIKE_MIN_DT = 5e-4 # minimal time interval for spike detection on charge signal (s) MIN_NSPIKES_SPECTRUM = 3 # minimum number of spikes to compute firing rate spectrum # Titrations TITRATION_T_OFFSET = 200e-3 # offset period for titration procedures (s) TITRATION_ASTIM_DA_MAX = 1e2 # acoustic pressure search range threshold for titration (Pa) TITRATION_ESTIM_A_MAX = 50.0 # initial current density upper bound for titration (mA/m2) TITRATION_ESTIM_DA_MAX = 0.1 # current density search range threshold for titration (mA/m2) diff --git a/PySONIC/core/nbls.py b/PySONIC/core/nbls.py index 5402334..40e6e1a 100644 --- a/PySONIC/core/nbls.py +++ b/PySONIC/core/nbls.py @@ -1,858 +1,905 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-09-29 16:16:19 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-05-16 16:39:44 +# @Last Modified time: 2019-05-21 17:59:33 import os -import inspect import time import logging import pickle import progressbar as pb import numpy as np import pandas as pd -from scipy.integrate import ode, odeint +from scipy.integrate import ode, odeint, solve_ivp from scipy.interpolate import interp1d from .bls import BilayerSonophore from .pneuron import PointNeuron from ..utils import * from ..constants import * from ..postpro import findPeaks from ..batches import xlslog class NeuronalBilayerSonophore(BilayerSonophore): ''' This class inherits from the BilayerSonophore class and receives an PointNeuron instance at initialization, to define the electro-mechanical NICE model and its SONIC variant. ''' tscale = 'ms' # relevant temporal scale of the model defvar = 'Q' # default plot variable def __init__(self, a, neuron, Fdrive=None, embedding_depth=0.0): ''' Constructor of the class. :param a: in-plane radius of the sonophore structure within the membrane (m) :param neuron: neuron object :param Fdrive: frequency of acoustic perturbation (Hz) :param embedding_depth: depth of the embedding tissue around the membrane (m) ''' # Check validity of input parameters if not isinstance(neuron, PointNeuron): raise ValueError('Invalid neuron type: "{}" (must inherit from PointNeuron class)' .format(neuron.name)) self.neuron = neuron # Initialize BilayerSonophore parent object BilayerSonophore.__init__(self, a, neuron.Cm0, neuron.Cm0 * neuron.Vm0 * 1e-3, embedding_depth) def __repr__(self): return 'NeuronalBilayerSonophore({}m, {})'.format( si_format(self.a, precision=1, space=' '), self.neuron) def pprint(self): return '{}m radius NBLS - {} neuron'.format( si_format(self.a, precision=0, space=' '), self.neuron.name) def getPltVars(self, wrapleft='df["', wrapright='"]'): pltvars = super().getPltVars(wrapleft, wrapright) pltvars.update(self.neuron.getPltVars(wrapleft, wrapright)) return pltvars def getPltScheme(self): return self.neuron.getPltScheme() def fullDerivatives(self, y, t, Adrive, Fdrive, phi): ''' Compute the derivatives of the (n+3) ODE full NBLS system variables. :param y: vector of state variables :param t: specific instant in time (s) :param Adrive: acoustic drive amplitude (Pa) :param Fdrive: acoustic drive frequency (Hz) :param phi: acoustic drive phase (rad) :return: vector of derivatives ''' dydt_mech = BilayerSonophore.derivatives(self, y[:3], t, Adrive, Fdrive, y[3], phi) dydt_elec = self.neuron.Qderivatives(y[3:], t, self.Capct(y[1])) return dydt_mech + dydt_elec - def effDerivatives(self, y, t, interp_data): + def effDerivatives(self, y, t, lkp): ''' Compute the derivatives of the n-ODE effective HH system variables, based on 1-dimensional linear interpolation of "effective" coefficients that summarize the system's behaviour over an acoustic cycle. :param y: vector of HH system variables at time t :param t: specific instant in time (s) - :param interp_data: dictionary of 1D data points of "effective" coefficients + :param lkp: dictionary of 1D data points of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: vector of effective system derivatives at time t ''' # Split input vector explicitly Qm, *states = y # Compute charge and channel states variation - Vm = np.interp(Qm, interp_data['Q'], interp_data['V'], left=np.nan, right=np.nan) # mV - dQmdt = - self.neuron.iNet(Vm, states) * 1e-3 - dstates = self.neuron.derEffStates(Qm, states, interp_data) + Vmeff = self.neuron.interpVmeff(Qm, lkp) + dQmdt = - self.neuron.iNet(Vmeff, states) * 1e-3 + dstates = self.neuron.derEffStates(Qm, states, lkp) # Return derivatives vector return [dQmdt, *[dstates[k] for k in self.neuron.states]] + def evaluateStability(self, tint, Qm0, states0, lkp, Q_conv_thr, Q_div_thr): + ''' Integrate the effective differential system from a given starting point, + until clear convergence or clear divergence is found. + + :param tint: iterative integration interval (s) + :param Qm0: initial membrane charge density (C/m2) + :param states0: dictionary of initial states values + :param lkp: dictionary of 1D data points of "effective" coefficients + over the charge domain, for specific frequency and amplitude values. + :param Q_conv_thr: membrane charge density difference within an interval span + below which convergence is assumed + :param Q_div_thr: membrane charge density difference from initial value + above which divergence is assumed + :return: 2-tuple with convergence state and final charge density value + ''' + + # Initialize y0 vector + yf = np.array([Qm0] + list(states0.values())) + tf = 0. + conv = False + div = False + + # As long as there is no clear convergence or clear divergence w.r.t. Qm + while not conv and not div: + + # Re-initialize start time and initial states with previous end points + t0, y0 = tf, yf + + # Integrate system for small interval and retrieve results + sol = solve_ivp(lambda t, y: self.effDerivatives(y, t, lkp), [t0, t0 + tint], y0) + tf, yf = sol.t[-1], sol.y[:, -1] + Qmf = yf[0] + + # If charge deviation within last interval is small enough -> convergence + if np.abs(Qmf - sol.y[0, 0]) < Q_conv_thr: + conv = True + + # If charge deviation from the beginning is too large -> divergence + dQ = Qmf - Qm0 + if np.abs(dQ) > Q_div_thr: + div = True + + logger.debug('{}vergence after {:.0f} ms: dQ = {:.5f} nC/cm2'.format( + {True: 'con', False: 'di'}[conv], tf * 1e3, dQ * 1e5)) + + return Qmf, conv + + def runFull(self, Fdrive, Adrive, tstim, toffset, PRF, DC, phi=np.pi): ''' Compute solutions of the full electro-mechanical system for a specific set of US stimulation parameters, using a classic integration scheme. The first iteration uses the quasi-steady simplification to compute the initiation of motion from a flat leaflet configuration. Afterwards, the ODE system is solved iteratively until completion. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param phi: acoustic drive phase (rad) :return: 3-tuple with the time profile, the effective solution matrix and a state vector ''' # Determine system time step Tdrive = 1 / Fdrive dt = Tdrive / NPC_FULL # if CW stimulus: divide integration during stimulus into 100 intervals if DC == 1.0: PRF = 100 / tstim # Compute vector sizes npulses = int(np.round(PRF * tstim)) Tpulse_on = DC / PRF Tpulse_off = (1 - DC) / PRF n_pulse_on = int(np.round(Tpulse_on / dt)) n_pulse_off = int(np.round(Tpulse_off / dt)) n_off = int(np.round(toffset / dt)) # Solve quasi-steady equation to compute first deflection value Z0 = 0.0 ng0 = self.ng0 Qm0 = self.Qm0 Pac1 = self.Pacoustic(dt, Adrive, Fdrive, phi) Z1 = self.balancedefQS(ng0, Qm0, Pac1) # Initialize global arrays stimstate = np.array([1, 1]) t = np.array([0., dt]) y_membrane = np.array([[0., (Z1 - Z0) / dt], [Z0, Z1], [ng0, ng0], [Qm0, Qm0]]) steady_states = self.neuron.steadyStates(self.neuron.Vm0) y_channels = np.tile(np.array([steady_states[k] for k in self.neuron.states]), (2, 1)).T y = np.vstack((y_membrane, y_channels)) nvar = y.shape[0] # Initialize pulse time and stimstate vectors t_pulse0 = np.linspace(0, Tpulse_on + Tpulse_off, n_pulse_on + n_pulse_off) stimstate_pulse = np.concatenate((np.ones(n_pulse_on), np.zeros(n_pulse_off))) logger.debug('Computing detailed solution') # Initialize progress bar if logger.getEffectiveLevel() <= logging.INFO: widgets = ['Running: ', pb.Percentage(), ' ', pb.Bar(), ' ', pb.ETA()] pbar = pb.ProgressBar(widgets=widgets, max_value=int(npulses * (toffset + tstim) / tstim)) pbar.start() # Loop through all pulse (ON and OFF) intervals for i in range(npulses): # Construct and initialize arrays t_pulse = t_pulse0 + t[-1] y_pulse = np.empty((nvar, n_pulse_on + n_pulse_off)) # Integrate ON system y_pulse[:, :n_pulse_on] = odeint( self.fullDerivatives, y[:, -1], t_pulse[:n_pulse_on], args=(Adrive, Fdrive, phi)).T # Integrate OFF system if n_pulse_off > 0: y_pulse[:, n_pulse_on:] = odeint( self.fullDerivatives, y_pulse[:, n_pulse_on - 1], t_pulse[n_pulse_on:], args=(0.0, 0.0, 0.0)).T # Append pulse arrays to global arrays stimstate = np.concatenate([stimstate, stimstate_pulse[1:]]) t = np.concatenate([t, t_pulse[1:]]) y = np.concatenate([y, y_pulse[:, 1:]], axis=1) # Update progress bar if logger.getEffectiveLevel() <= logging.INFO: pbar.update(i) # Integrate offset interval if n_off > 0: t_off = np.linspace(0, toffset, n_off) + t[-1] stimstate_off = np.zeros(n_off) y_off = odeint(self.fullDerivatives, y[:, -1], t_off, args=(0.0, 0.0, 0.0)).T # Concatenate offset arrays to global arrays stimstate = np.concatenate([stimstate, stimstate_off[1:]]) t = np.concatenate([t, t_off[1:]]) y = np.concatenate([y, y_off[:, 1:]], axis=1) # Terminate progress bar if logger.getEffectiveLevel() <= logging.INFO: pbar.finish() # Downsample arrays in time-domain according to target temporal resolution ds_factor = int(np.round(CLASSIC_TARGET_DT / dt)) if ds_factor > 1: Fs = 1 / (dt * ds_factor) logger.info('Downsampling output arrays by factor %u (Fs = %.2f MHz)', ds_factor, Fs * 1e-6) t = t[::ds_factor] y = y[:, ::ds_factor] stimstate = stimstate[::ds_factor] # Compute membrane potential vector (in mV) Vm = y[3, :] / self.v_Capct(y[1, :]) * 1e3 # mV # Return output variables with Vm return (t, np.vstack([y[1:4, :], Vm, y[4:, :]]), stimstate) def runSONIC(self, Fdrive, Adrive, tstim, toffset, PRF, DC, dt=DT_EFF): ''' Compute solutions of the system for a specific set of US stimulation parameters, using charge-predicted "effective" coefficients to solve the HH equations at each step. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param dt: integration time step (s) :return: 3-tuple with the time profile, the effective solution matrix and a state vector ''' # Load appropriate 2D lookups Aref, Qref, lookups2D, _ = getLookups2D(self.neuron.name, a=self.a, Fdrive=Fdrive) # Check that acoustic amplitude is within lookup range Adrive = isWithin('amplitude', Adrive, (Aref.min(), Aref.max())) # Interpolate 2D lookups at zero and US amplitude logger.debug('Interpolating lookups at A = %.2f kPa and A = 0', Adrive * 1e-3) lookups_on = {key: interp1d(Aref, y2D, axis=0)(Adrive) for key, y2D in lookups2D.items()} lookups_off = {key: interp1d(Aref, y2D, axis=0)(0.0) for key, y2D in lookups2D.items()} # Add reference charge vector to 1D lookup dictionaries lookups_on['Q'] = Qref lookups_off['Q'] = Qref # if CW stimulus: change PRF to have exactly one integration interval during stimulus if DC == 1.0: PRF = 1 / tstim # Compute vector sizes npulses = int(np.round(PRF * tstim)) Tpulse_on = DC / PRF Tpulse_off = (1 - DC) / PRF # For high-PRF pulsed protocols: adapt time step to ensure minimal # number of samples during TON or TOFF dt_warning_msg = 'high-PRF protocol: lowering time step to %.2e s to properly integrate %s' for key, Tpulse in {'TON': Tpulse_on, 'TOFF': Tpulse_off}.items(): if Tpulse > 0 and Tpulse / dt < MIN_SAMPLES_PER_PULSE_INT: dt = Tpulse / MIN_SAMPLES_PER_PULSE_INT logger.warning(dt_warning_msg, dt, key) n_pulse_on = int(np.round(Tpulse_on / dt)) + 1 n_pulse_off = int(np.round(Tpulse_off / dt)) # Compute offset size n_off = int(np.round(toffset / dt)) # Initialize global arrays stimstate = np.array([1]) t = np.array([0.0]) steady_states = self.neuron.steadyStates(self.neuron.Vm0) y = np.atleast_2d(np.insert( np.array([steady_states[k] for k in self.neuron.states]), 0, self.Qm0)).T nvar = y.shape[0] # Initializing accurate pulse time vector t_pulse_on = np.linspace(0, Tpulse_on, n_pulse_on) t_pulse_off = np.linspace(dt, Tpulse_off, n_pulse_off) + Tpulse_on t_pulse0 = np.concatenate([t_pulse_on, t_pulse_off]) stimstate_pulse = np.concatenate((np.ones(n_pulse_on), np.zeros(n_pulse_off))) logger.debug('Computing effective solution') # Loop through all pulse (ON and OFF) intervals for i in range(npulses): # Construct and initialize arrays t_pulse = t_pulse0 + t[-1] y_pulse = np.empty((nvar, n_pulse_on + n_pulse_off)) y_pulse[:, 0] = y[:, -1] # Integrate ON system y_pulse[:, :n_pulse_on] = odeint( self.effDerivatives, y[:, -1], t_pulse[:n_pulse_on], args=(lookups_on, )).T # Integrate OFF system if n_pulse_off > 0: y_pulse[:, n_pulse_on:] = odeint( self.effDerivatives, y_pulse[:, n_pulse_on - 1], t_pulse[n_pulse_on:], args=(lookups_off, )).T # Append pulse arrays to global arrays stimstate = np.concatenate([stimstate[:-1], stimstate_pulse]) t = np.concatenate([t, t_pulse[1:]]) y = np.concatenate([y, y_pulse[:, 1:]], axis=1) # Integrate offset interval if n_off > 0: t_off = np.linspace(0, toffset, n_off) + t[-1] y_off = odeint(self.effDerivatives, y[:, -1], t_off, args=(lookups_off, )).T # Concatenate offset arrays to global arrays stimstate = np.concatenate([stimstate, np.zeros(n_off - 1)]) t = np.concatenate([t, t_off[1:]]) y = np.concatenate([y, y_off[:, 1:]], axis=1) # Compute effective gas content vector ngeff = np.zeros(stimstate.size) ngeff[stimstate == 0] = np.interp(y[0, stimstate == 0], lookups_on['Q'], lookups_on['ng'], left=np.nan, right=np.nan) # mole ngeff[stimstate == 1] = np.interp(y[0, stimstate == 1], lookups_off['Q'], lookups_off['ng'], left=np.nan, right=np.nan) # mole # Compute quasi-steady deflection vector Zeff = np.array([self.balancedefQS(ng, Qm) for ng, Qm in zip(ngeff, y[0, :])]) # m # Compute membrane potential vector (in mV) Vm = np.zeros(stimstate.size) Vm[stimstate == 1] = np.interp(y[0, stimstate == 1], lookups_on['Q'], lookups_on['V'], left=np.nan, right=np.nan) # mV Vm[stimstate == 0] = np.interp(y[0, stimstate == 0], lookups_off['Q'], lookups_off['V'], left=np.nan, right=np.nan) # mV # Add Zeff, ngeff and Vm to solution matrix y = np.vstack([Zeff, ngeff, y[0, :], Vm, y[1:, :]]) # return output variables return (t, y, stimstate) def runHybrid(self, Fdrive, Adrive, tstim, toffset, phi=np.pi): ''' Compute solutions of the system for a specific set of US stimulation parameters, using a hybrid integration scheme. The first iteration uses the quasi-steady simplification to compute the initiation of motion from a flat leaflet configuration. Afterwards, the NBLS ODE system is solved iteratively for "slices" of N-microseconds, in a 2-steps scheme: - First, the full (n+3) ODE system is integrated for a few acoustic cycles until Z and ng reach a stable periodic solution (limit cycle) - Second, the signals of the 3 mechanical variables over the last acoustic period are selected and resampled to a far lower sampling rate - Third, the HH n-ODE system is integrated for the remaining time of the slice, using periodic expansion of the mechanical signals to precompute the values of capacitance. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param phi: acoustic drive phase (rad) :return: 3-tuple with the time profile, the solution matrix and a state vector .. warning:: This method cannot handle pulsed stimuli ''' # Initialize full and HH systems solvers solver_full = ode( lambda t, y, Adrive, Fdrive, phi: self.fullDerivatives(y, t, Adrive, Fdrive, phi)) solver_full.set_f_params(Adrive, Fdrive, phi) solver_full.set_integrator('lsoda', nsteps=SOLVER_NSTEPS) solver_hh = ode(lambda t, y, Cm: self.neuron.Qderivatives(y, t, Cm)) solver_hh.set_integrator('dop853', nsteps=SOLVER_NSTEPS, atol=1e-12) # Determine full and HH systems time steps Tdrive = 1 / Fdrive dt_full = Tdrive / NPC_FULL dt_hh = Tdrive / NPC_HH n_full_per_hh = int(NPC_FULL / NPC_HH) t_full_cycle = np.linspace(0, Tdrive - dt_full, NPC_FULL) t_hh_cycle = np.linspace(0, Tdrive - dt_hh, NPC_HH) # Determine number of samples in prediction vectors npc_pred = NPC_FULL - n_full_per_hh + 1 # Solve quasi-steady equation to compute first deflection value Z0 = 0.0 ng0 = self.ng0 Qm0 = self.Qm0 Pac1 = self.Pacoustic(dt_full, Adrive, Fdrive, phi) Z1 = self.balancedefQS(ng0, Qm0, Pac1) # Initialize global arrays stimstate = np.array([1, 1]) t = np.array([0., dt_full]) y_membrane = np.array([[0., (Z1 - Z0) / dt_full], [Z0, Z1], [ng0, ng0], [Qm0, Qm0]]) steady_states = self.neuron.steadyStates(self.neuron.Vm0) y_channels = np.tile(np.array([steady_states[k] for k in self.neuron.states]), (2, 1)).T y = np.vstack((y_membrane, y_channels)) nvar = y.shape[0] # Initialize progress bar if logger.getEffectiveLevel() == logging.DEBUG: widgets = ['Running: ', pb.Percentage(), ' ', pb.Bar(), ' ', pb.ETA()] pbar = pb.ProgressBar(widgets=widgets, max_value=1000) pbar.start() # For each hybrid integration interval irep = 0 sim_error = False while not sim_error and t[-1] < tstim + toffset: # Integrate full system for a few acoustic cycles until stabilization periodic_conv = False j = 0 ng_last = None Z_last = None while not sim_error and not periodic_conv: if t[-1] > tstim: solver_full.set_f_params(0.0, 0.0, 0.0) t_full = t_full_cycle + t[-1] + dt_full y_full = np.empty((nvar, NPC_FULL)) y0_full = y[:, -1] solver_full.set_initial_value(y0_full, t[-1]) k = 0 while solver_full.successful() and k <= NPC_FULL - 1: solver_full.integrate(t_full[k]) y_full[:, k] = solver_full.y k += 1 # Compare Z and ng signals over the last 2 acoustic periods if j > 0 and rmse(Z_last, y_full[1, :]) < Z_ERR_MAX \ and rmse(ng_last, y_full[2, :]) < NG_ERR_MAX: periodic_conv = True # Update last vectors for next comparison Z_last = y_full[1, :] ng_last = y_full[2, :] # Concatenate time and solutions to global vectors stimstate = np.concatenate([stimstate, np.ones(NPC_FULL)], axis=0) t = np.concatenate([t, t_full], axis=0) y = np.concatenate([y, y_full], axis=1) # Increment loop index j += 1 # Retrieve last period of the 3 mechanical variables to propagate in HH system t_last = t[-npc_pred:] mech_last = y[0:3, -npc_pred:] # Downsample signals to specified HH system time step (_, mech_pred) = downsample(t_last, mech_last, NPC_HH) # Integrate HH system until certain dQ or dT is reached Q0 = y[3, -1] dQ = 0.0 t0_interval = t[-1] dt_interval = 0.0 j = 0 if t[-1] < tstim: tlim = tstim else: tlim = tstim + toffset while (not sim_error and t[-1] < tlim and (np.abs(dQ) < DQ_UPDATE or dt_interval < DT_UPDATE)): t_hh = t_hh_cycle + t[-1] + dt_hh y_hh = np.empty((nvar - 3, NPC_HH)) y0_hh = y[3:, -1] solver_hh.set_initial_value(y0_hh, t[-1]) k = 0 while solver_hh.successful() and k <= NPC_HH - 1: solver_hh.set_f_params(self.Capct(mech_pred[1, k])) solver_hh.integrate(t_hh[k]) y_hh[:, k] = solver_hh.y k += 1 # Concatenate time and solutions to global vectors stimstate = np.concatenate([stimstate, np.zeros(NPC_HH)], axis=0) t = np.concatenate([t, t_hh], axis=0) y = np.concatenate([y, np.concatenate([mech_pred, y_hh], axis=0)], axis=1) # Compute charge variation from interval beginning dQ = y[3, -1] - Q0 dt_interval = t[-1] - t0_interval # Increment loop index j += 1 # Update progress bar if logger.getEffectiveLevel() == logging.DEBUG: pbar.update(int(1000 * (t[-1] / (tstim + toffset)))) irep += 1 # Terminate progress bar if logger.getEffectiveLevel() == logging.DEBUG: pbar.finish() # Compute membrane potential vector (in mV) Vm = y[3, :] / self.v_Capct(y[1, :]) * 1e3 # mV # Return output variables with Vm return (t, np.vstack([y[1:4, :], Vm, y[4:, :]]), stimstate) def checkInputsFull(self, Fdrive, Adrive, tstim, toffset, PRF, DC, method): ''' Check validity of simulation parameters. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param method: selected integration method :return: 3-tuple with the time profile, the solution matrix and a state vector ''' BilayerSonophore.checkInputs(self, Fdrive, Adrive, 0.0, 0.0) self.neuron.checkInputs(Adrive, tstim, toffset, PRF, DC) # Check validity of simulation type if method not in ('full', 'hybrid', 'sonic'): raise ValueError('Invalid integration method: "{}"'.format(method)) def simulate(self, Fdrive, Adrive, tstim, toffset, PRF=None, DC=1.0, method='sonic'): ''' Run simulation of the system for a specific set of US stimulation parameters. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param method: selected integration method :return: 3-tuple with the time profile, the solution matrix and a state vector ''' # Check validity of stimulation parameters self.checkInputsFull(Fdrive, Adrive, tstim, toffset, PRF, DC, method) # Call appropriate simulation function if method == 'full': return self.runFull(Fdrive, Adrive, tstim, toffset, PRF, DC) elif method == 'sonic': return self.runSONIC(Fdrive, Adrive, tstim, toffset, PRF, DC) elif method == 'hybrid': if DC < 1.0: raise ValueError('Pulsed protocol incompatible with hybrid integration method') return self.runHybrid(Fdrive, Adrive, tstim, toffset) def nSpikes(self, Adrive, Fdrive, tstim, toffset, PRF, DC, method): ''' Run a simulation and determine number of spikes in the response. :param Adrive: acoustic amplitude (Pa) :param Fdrive: US frequency (Hz) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :return: number of spikes found in response ''' t, y, _ = self.simulate(Fdrive, Adrive, tstim, toffset, PRF, DC, method=method) dt = t[1] - t[0] ipeaks, *_ = findPeaks(y[2, :], SPIKE_MIN_QAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_QPROM) nspikes = ipeaks.size logger.debug('A = %sPa ---> %s spike%s detected', si_format(Adrive, 2, space=' '), nspikes, "s" if nspikes > 1 else "") return nspikes def titrate(self, Fdrive, tstim, toffset, PRF=None, DC=1.0, Arange=None, method='sonic'): ''' Use a binary search to determine the threshold amplitude needed to obtain neural excitation for a given frequency, duration, PRF and duty cycle. :param Fdrive: US frequency (Hz) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param Arange: search interval for Adrive, iteratively refined :return: determined threshold amplitude (Pa) ''' # Determine amplitude interval if needed if Arange is None: Arange = (0, getLookups2D(self.neuron.name, a=self.a, Fdrive=Fdrive)[0].max()) # Titrate return titrate(self.nSpikes, (Fdrive, tstim, toffset, PRF, DC, method), Arange, TITRATION_ASTIM_DA_MAX) def runAndSave(self, outdir, Fdrive, tstim, toffset, PRF=None, DC=1.0, Adrive=None, method='sonic'): ''' Run a simulation of the full electro-mechanical system for a given neuron type with specific parameters, and save the results in a PKL file. :param outdir: full path to output directory :param Fdrive: US frequency (Hz) :param tstim: stimulus duration (s) :param toffset: stimulus offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: stimulus duty cycle (-) :param Adrive: acoustic pressure amplitude (Pa) :param method: integration method ''' # Get date and time info date_str = time.strftime("%Y.%m.%d") daytime_str = time.strftime("%H:%M:%S") logger.info( '%s: %s @ f = %sHz, %st = %ss (%ss offset)%s', self, 'titration' if Adrive is None else 'simulation', si_format(Fdrive, 0, space=' '), 'A = {}Pa, '.format(si_format(Adrive, 2, space=' ')) if Adrive is not None else '', *si_format([tstim, toffset], 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format(si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) if Adrive is None: Adrive = self.titrate(Fdrive, tstim, toffset, PRF, DC, method=method) if np.isnan(Adrive): logger.error('Could not find threshold excitation amplitude') return None # Run simulation tstart = time.time() t, y, stimstate = self.simulate(Fdrive, Adrive, tstim, toffset, PRF, DC, method=method) tcomp = time.time() - tstart Z, ng, Qm, Vm, *channels = y # Detect spikes on Qm signal dt = t[1] - t[0] ipeaks, *_ = findPeaks(Qm, SPIKE_MIN_QAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_QPROM) nspikes = ipeaks.size lat = t[ipeaks[0]] if nspikes > 0 else 'N/A' outstr = '{} spike{} detected'.format(nspikes, 's' if nspikes > 1 else '') logger.debug('completed in %ss, %s', si_format(tcomp, 1), outstr) sr = np.mean(1 / np.diff(t[ipeaks])) if nspikes > 1 else None # Store dataframe and metadata U = np.insert(np.diff(Z) / np.diff(t), 0, 0.0) df = pd.DataFrame({ 't': t, 'stimstate': stimstate, 'U': U, 'Z': Z, 'ng': ng, 'Qm': Qm, 'Vm': Vm }) for j in range(len(self.neuron.states)): df[self.neuron.states[j]] = channels[j] meta = { 'neuron': self.neuron.name, 'a': self.a, 'd': self.d, 'Fdrive': Fdrive, 'Adrive': Adrive, 'phi': np.pi, 'tstim': tstim, 'toffset': toffset, 'PRF': PRF, 'DC': DC, 'tcomp': tcomp, 'method': method } # Export into to PKL file simcode = ASTIM_filecode(self.neuron.name, self.a, Fdrive, Adrive, tstim, PRF, DC, method) outpath = '{}/{}.pkl'.format(outdir, simcode) with open(outpath, 'wb') as fh: pickle.dump({'meta': meta, 'data': df}, fh) logger.debug('simulation data exported to "%s"', outpath) # Export key metrics to log file logpath = os.path.join(outdir, 'log_ASTIM.xlsx') logentry = { 'Date': date_str, 'Time': daytime_str, 'Neuron Type': self.neuron.name, 'Radius (nm)': self.a * 1e9, 'Thickness (um)': self.d * 1e6, 'Fdrive (kHz)': Fdrive * 1e-3, 'Adrive (kPa)': Adrive * 1e-3, 'Tstim (ms)': tstim * 1e3, 'PRF (kHz)': PRF * 1e-3 if DC < 1 else 'N/A', 'Duty factor': DC, 'Sim. Type': method, '# samples': t.size, 'Comp. time (s)': round(tcomp, 2), '# spikes': nspikes, 'Latency (ms)': lat * 1e3 if isinstance(lat, float) else 'N/A', 'Spike rate (sp/ms)': sr * 1e-3 if isinstance(sr, float) else 'N/A' } if xlslog(logpath, logentry) == 1: logger.debug('log exported to "%s"', logpath) else: logger.error('log export to "%s" aborted', self.logpath) return outpath def quasiSteadyStates(self, Fdrive, amps=None, charges=None, DCs=1.0, squeeze_output=False): ''' Compute the quasi-steady state values of the neuron's gating variables for a combination of US amplitudes, charge densities and duty cycles, at a specific US frequency. :param Fdrive: US frequency (Hz) :param amps: US amplitudes (Pa) :param charges: membrane charge densities (C/m2) :param DCs: duty cycle value(s) :return: 4-tuple with reference values of US amplitude and charge density, as well as interpolated Vmeff and QSS gating variables ''' # Get DC-averaged lookups interpolated at the appropriate amplitudes and charges amps, charges, lookups = getLookupsDCavg( self.neuron.name, self.a, Fdrive, amps, charges, DCs) # Compute QSS states using these lookups nA, nQ, nDC = lookups['V'].shape QSS = {k: np.empty((nA, nQ, nDC)) for k in self.neuron.states} for iA in range(nA): for iDC in range(nDC): QSS_1D = self.neuron.quasiSteadyStates( {k: v[iA, :, iDC] for k, v in lookups.items()}) for k in QSS.keys(): QSS[k][iA, :, iDC] = QSS_1D[k] # Compress outputs if needed if squeeze_output: QSS = {k: v.squeeze() for k, v in QSS.items()} lookups = {k: v.squeeze() for k, v in lookups.items()} # Return reference inputs and outputs return amps, charges, lookups, QSS def findRheobaseAmps(self, DCs, Fdrive, Vthr): ''' Find the rheobase amplitudes (i.e. threshold acoustic amplitudes of infinite duration that would result in excitation) of a specific neuron for various duty cycles. :param DCs: duty cycles vector (-) :param Fdrive: acoustic drive frequency (Hz) :param Vthr: threshold membrane potential above which the neuron necessarily fires (mV) :return: rheobase amplitudes vector (Pa) ''' # Get threshold charge from neuron's spike threshold parameter Qthr = self.neuron.Cm0 * Vthr * 1e-3 # C/m2 # Get QSS variables for each amplitude at threshold charge Aref, _, Vmeff, QS_states = self.quasiSteadyStates(Fdrive, charges=Qthr, DCs=DCs) if DCs.size == 1: QS_states = QS_states.reshape((*QS_states.shape, 1)) Vmeff = Vmeff.reshape((*Vmeff.shape, 1)) # Compute 2D QSS charge variation array at Qthr dQdt = -self.neuron.iNet(Vmeff, QS_states) # Find the threshold amplitude that cancels dQdt for each duty cycle Arheobase = np.array([np.interp(0, dQdt[:, i], Aref, left=0., right=np.nan) for i in range(DCs.size)]) # Check if threshold amplitude is found for all DCs inan = np.where(np.isnan(Arheobase))[0] if inan.size > 0: if inan.size == Arheobase.size: logger.error( 'No rheobase amplitudes within [%s - %sPa] for the provided duty cycles', *si_format((Aref.min(), Aref.max()))) else: minDC = DCs[inan.max() + 1] logger.warning( 'No rheobase amplitudes within [%s - %sPa] below %.1f%% duty cycle', *si_format((Aref.min(), Aref.max())), minDC * 1e2) return Arheobase, Aref def computeEffVars(self, Fdrive, Adrive, Qm, fs): ''' Compute "effective" coefficients of the HH system for a specific combination of stimulus frequency, stimulus amplitude and charge density. A short mechanical simulation is run while imposing the specific charge density, until periodic stabilization. The HH coefficients are then averaged over the last acoustic cycle to yield "effective" coefficients. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param Qm: imposed charge density (C/m2) :param fs: list of sonophore membrane coverage fractions :return: list with computation time and a list of dictionaries of effective variables ''' tstart = time.time() # Run simulation and retrieve deflection and gas content vectors from last cycle _, [Z, ng], _ = BilayerSonophore.simulate(self, Fdrive, Adrive, Qm) Z_last = Z[-NPC_FULL:] # m Cm_last = self.v_Capct(Z_last) # F/m2 # For each coverage fraction effvars = [] for x in fs: # Compute membrane capacitance and membrane potential vectors Cm = x * Cm_last + (1 - x) * self.Cm0 # F/m2 Vm = Qm / Cm * 1e3 # mV # Compute average cycle value for membrane potential and rate constants effvars.append({'V': np.mean(Vm)}) effvars[-1].update(self.neuron.computeEffRates(Vm)) tcomp = time.time() - tstart # Log process log = '{}: lookups @ {}Hz, {}Pa, {:.2f} nC/cm2'.format( self, *si_format([Fdrive, Adrive], precision=1, space=' '), Qm * 1e5) if len(fs) > 1: log += ', fs = {:.0f} - {:.0f}%'.format(fs.min() * 1e2, fs.max() * 1e2) log += ', tcomp = {:.3f} s'.format(tcomp) logger.info(log) # Return effective coefficients return [tcomp, effvars] diff --git a/PySONIC/core/pneuron.py b/PySONIC/core/pneuron.py index cbb2b48..4183ca9 100644 --- a/PySONIC/core/pneuron.py +++ b/PySONIC/core/pneuron.py @@ -1,656 +1,657 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-08-03 11:53:04 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-05-17 16:14:20 +# @Last Modified time: 2019-05-21 14:13:22 import os import time import pickle import abc import inspect import re import numpy as np from scipy.integrate import odeint import pandas as pd from ..postpro import findPeaks from ..constants import * from ..utils import si_format, logger, ESTIM_filecode, titrate from ..batches import xlslog class PointNeuron(metaclass=abc.ABCMeta): ''' Abstract class defining the common API (i.e. mandatory attributes and methods) of all subclasses implementing the channels mechanisms of specific point neurons. ''' tscale = 'ms' # relevant temporal scale of the model defvar = 'V' # default plot variable def __repr__(self): return self.__class__.__name__ def pprint(self): return '{} neuron'.format(self.__class__.__name__) @property @abc.abstractmethod def name(self): return 'Should never reach here' @property @abc.abstractmethod def Cm0(self): return 'Should never reach here' @property @abc.abstractmethod def Vm0(self): return 'Should never reach here' @abc.abstractmethod def currents(self, Vm, states): ''' Compute all ionic currents per unit area. :param Vm: membrane potential (mV) :states: state probabilities of the ion channels :return: dictionary of ionic currents per unit area (mA/m2) ''' def iNet(self, Vm, states): ''' net membrane current :param Vm: membrane potential (mV) :states: states of ion channels gating and related variables :return: current per unit area (mA/m2) ''' return sum(self.currents(Vm, states).values()) def dQdt(self, Vm, states): ''' membrane charge density variation rate :param Vm: membrane potential (mV) :states: states of ion channels gating and related variables :return: variation rate (mA/m2) ''' return -self.iNet(Vm, states) def isTitratable(self): ''' Simple method returning whether the neuron can be titrated (defaults to True). ''' return True def currentToConcentrationRate(self, z_ion, depth): ''' Compute the conversion factor from a specific ionic current (in mA/m2) into a variation rate of submembrane ion concentration (in M/s). :param: z_ion: ion valence :param depth: submembrane depth (m) :return: conversion factor (Mmol.m-1.C-1) ''' return 1e-6 / (z_ion * depth * FARADAY) def nernst(self, z_ion, Cion_in, Cion_out, T): ''' Nernst potential of a specific ion given its intra and extracellular concentrations. :param z_ion: ion valence :param Cion_in: intracellular ion concentration :param Cion_out: extracellular ion concentration :param T: temperature (K) :return: ion Nernst potential (mV) ''' return (Rg * T) / (z_ion * FARADAY) * np.log(Cion_out / Cion_in) * 1e3 def vtrap(self, x, y): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x / y) - 1) def efun(self, x): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x) - 1) def ghkDrive(self, Vm, Z_ion, Cion_in, Cion_out, T): ''' Use the Goldman-Hodgkin-Katz equation to compute the electrochemical driving force of a specific ion species for a given membrane potential. :param Vm: membrane potential (mV) :param Cin: intracellular ion concentration (M) :param Cout: extracellular ion concentration (M) :param T: temperature (K) :return: electrochemical driving force of a single ion particle (mC.m-3) ''' x = Z_ion * FARADAY * Vm / (Rg * T) * 1e-3 # [-] eCin = Cion_in * self.efun(-x) # M eCout = Cion_out * self.efun(x) # M return FARADAY * (eCin - eCout) * 1e6 # mC/m3 def getDesc(self): return inspect.getdoc(self).splitlines()[0] def getCurrentsNames(self): return list(self.currents(np.nan, [np.nan] * len(self.states)).keys()) def getPltScheme(self): pltscheme = { 'Q_m': ['Qm'], 'V_m': ['Vm'] } pltscheme['I'] = self.getCurrentsNames() + ['iNet'] for cname in self.getCurrentsNames(): if 'Leak' not in cname: key = 'i_{{{}}}\ kin.'.format(cname[1:]) cargs = inspect.getargspec(getattr(self, cname))[0][1:] pltscheme[key] = [var for var in cargs if var not in ['Vm', 'Cai']] return pltscheme def getPltVars(self, wrapleft='df["', wrapright='"]'): ''' Return a dictionary with information about all plot variables related to the neuron. ''' pltvars = { 'Qm': { 'desc': 'membrane charge density', 'label': 'Q_m', 'unit': 'nC/cm^2', 'factor': 1e5, 'bounds': (-100, 50) }, 'Vm': { 'desc': 'membrane potential', 'label': 'V_m', 'unit': 'mV', 'y0': self.Vm0, 'bounds': (-150, 70) }, 'ELeak': { 'constant': 'obj.ELeak', 'desc': 'non-specific leakage current resting potential', 'label': 'V_{leak}', 'unit': 'mV', 'ls': '--', 'color': 'k' } } for cname in self.getCurrentsNames(): cfunc = getattr(self, cname) cargs = inspect.getargspec(cfunc)[0][1:] pltvars[cname] = { 'desc': inspect.getdoc(cfunc).splitlines()[0], 'label': 'I_{{{}}}'.format(cname[1:]), 'unit': 'A/m^2', 'factor': 1e-3, 'func': '{}({})'.format(cname, ', '.join(['{}{}{}'.format(wrapleft, a, wrapright) for a in cargs])) } for var in cargs: if var not in ['Vm', 'Cai']: vfunc = getattr(self, 'der{}{}'.format(var[0].upper(), var[1:])) desc = cname + re.sub('^Evolution of', '', inspect.getdoc(vfunc).splitlines()[0]) pltvars[var] = { 'desc': desc, 'label': var, 'bounds': (-0.1, 1.1) } pltvars['iNet'] = { 'desc': inspect.getdoc(getattr(self, 'iNet')).splitlines()[0], 'label': 'I_{net}', 'unit': 'A/m^2', 'factor': 1e-3, 'func': 'iNet({0}Vm{1}, {2}{3}{4}.values.T)'.format( wrapleft, wrapright, wrapleft[:-1], self.states, wrapright[1:]), 'ls': '--', 'color': 'black' } pltvars['dQdt'] = { 'desc': inspect.getdoc(getattr(self, 'dQdt')).splitlines()[0], 'label': 'dQ_m/dt', 'unit': 'A/m^2', 'factor': 1e-3, 'func': 'dQdt({0}Vm{1}, {2}{3}{4}.values.T)'.format( wrapleft, wrapright, wrapleft[:-1], self.states, wrapright[1:]), 'ls': '--', 'color': 'black' } for x in self.getGates(): for rate in ['alpha', 'beta']: pltvars['{}{}'.format(rate, x)] = { 'label': '\\{}_{{{}}}'.format(rate, x), 'unit': 'ms^{-1}', 'factor': 1e-3 } return pltvars def getRatesNames(self, states): return list(sum( [['alpha{}'.format(x.lower()), 'beta{}'.format(x.lower())] for x in states], [] )) def Qm0(self): ''' Return the resting charge density (in C/m2). ''' return self.Cm0 * self.Vm0 * 1e-3 # C/cm2 @abc.abstractmethod def steadyStates(self, Vm): ''' Compute the steady-state values for a specific membrane potential value. :param Vm: membrane potential (mV) :return: dictionary of steady-states ''' @abc.abstractmethod def derStates(self, Vm, states): ''' Compute the derivatives of channel states. :param Vm: membrane potential (mV) :states: state probabilities of the ion channels :return: current per unit area (mA/m2) ''' @abc.abstractmethod def computeEffRates(self, Vm): ''' Get the effective rate constants of ion channels, averaged along an acoustic cycle, for future use in effective simulations. :param Vm: array of membrane potential values for an acoustic cycle (mV) :return: a dictionary of rate average constants (s-1) ''' def interpEffRates(self, Qm, lkp, keys=None): ''' Interpolate effective rate constants for a given charge density using reference lookup vectors. :param Qm: membrane charge density (C/m2) :states: state probabilities of the ion channels :param lkp: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: dictionary of interpolated rate constants ''' if keys is None: keys = self.rates return {k: np.interp(Qm, lkp['Q'], lkp[k], left=np.nan, right=np.nan) for k in keys} def interpVmeff(self, Qm, lkp): ''' Interpolate the effective membrane potential for a given charge density using reference lookup vectors. :param Qm: membrane charge density (C/m2) :param lkp: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: dictionary of interpolated rate constants ''' return np.interp(Qm, lkp['Q'], lkp['V'], left=np.nan, right=np.nan) @abc.abstractmethod def derEffStates(self, Qm, states, lkp): ''' Compute the effective derivatives of channel states, based on 1-dimensional linear interpolation of "effective" coefficients that summarize the system's behaviour over an acoustic cycle. :param Qm: membrane charge density (C/m2) :states: state probabilities of the ion channels :param lkp: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. ''' def Qbounds(self): ''' Determine bounds of membrane charge physiological range for a given neuron. ''' return np.array([np.round(self.Vm0 - 25.0), 50.0]) * self.Cm0 * 1e-3 # C/m2 def isVoltageGated(self, state): ''' Determine whether a given state is purely voltage-gated or not.''' return 'alpha{}'.format(state.lower()) in self.rates def getGates(self): ''' Retrieve the names of the neuron's states that match an ion channel gating. ''' gates = [] for x in self.states: if self.isVoltageGated(x): gates.append(x) return gates def qsStates(self, lkp, states): ''' Compute a collection of quasi steady states using the standard xinf = ax / (ax + Bx) equation. :param lkp: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: dictionary of quasi-steady states ''' return { x: lkp['alpha{}'.format(x)] / (lkp['alpha{}'.format(x)] + lkp['beta{}'.format(x)]) for x in states } @abc.abstractmethod def quasiSteadyStates(self, lkp): ''' Compute the quasi-steady states of a neuron for a range of membrane charge densities, based on 1-dimensional lookups interpolated at a given sonophore diameter, US frequency, US amplitude and duty cycle. :param lkp: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: dictionary of quasi-steady states ''' def getRates(self, Vm): ''' Compute the ion channels rate constants for a given membrane potential. :param Vm: membrane potential (mV) :return: a dictionary of rate constants and their values at the given potential. ''' rates = {} for x in self.getGates(): x = x.lower() alpha_str, beta_str = ['{}{}'.format(s, x.lower()) for s in ['alpha', 'beta']] inf_str, tau_str = ['{}inf'.format(x.lower()), 'tau{}'.format(x.lower())] if hasattr(self, 'alpha{}'.format(x)): alphax = getattr(self, alpha_str)(Vm) betax = getattr(self, beta_str)(Vm) elif hasattr(self, '{}inf'.format(x)): xinf = getattr(self, inf_str)(Vm) taux = getattr(self, tau_str)(Vm) alphax = xinf / taux betax = 1 / taux - alphax rates[alpha_str] = alphax rates[beta_str] = betax return rates def Vderivatives(self, y, t, Iinj): ''' Compute the derivatives of a V-cast HH system for a specific value of injected current. :param y: vector of HH system variables at time t :param t: time value (s, unused) :param Iinj: injected current (mA/m2) :return: vector of HH system derivatives at time t ''' Vm, *states = y Iionic = self.iNet(Vm, states) # mA/m2 dVmdt = (- Iionic + Iinj) / self.Cm0 # mV/s dstates = self.derStates(Vm, states) return [dVmdt, *[dstates[k] for k in self.states]] def Qderivatives(self, y, t, Cm=None): ''' Compute the derivatives of the n-ODE HH system variables, based on a value of membrane capacitance. :param y: vector of HH system variables at time t :param t: specific instant in time (s) :param Cm: membrane capacitance (F/m2) :return: vector of HH system derivatives at time t ''' if Cm is None: Cm = self.Cm0 Qm, *states = y Vm = Qm / Cm * 1e3 # mV dQmdt = - self.iNet(Vm, states) * 1e-3 # A/m2 dstates = self.derStates(Vm, states) return [dQmdt, *[dstates[k] for k in self.states]] + def checkInputs(self, Astim, tstim, toffset, PRF, DC): ''' Check validity of electrical stimulation parameters. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) ''' # Check validity of stimulation parameters if not all(isinstance(param, float) for param in [Astim, tstim, toffset, DC]): raise TypeError('Invalid stimulation parameters (must be float typed)') if tstim <= 0: raise ValueError('Invalid stimulus duration: {} ms (must be strictly positive)' .format(tstim * 1e3)) if toffset < 0: raise ValueError('Invalid stimulus offset: {} ms (must be positive or null)' .format(toffset * 1e3)) if DC <= 0.0 or DC > 1.0: raise ValueError('Invalid duty cycle: {} (must be within ]0; 1])'.format(DC)) if DC < 1.0: if not isinstance(PRF, float): raise TypeError('Invalid PRF value (must be float typed)') if PRF is None: raise AttributeError('Missing PRF value (must be provided when DC < 1)') if PRF < 1 / tstim: raise ValueError('Invalid PRF: {} Hz (PR interval exceeds stimulus duration)' .format(PRF)) def simulate(self, Astim, tstim, toffset, PRF=None, DC=1.0): ''' Compute solutions of a neuron's HH system for a specific set of electrical stimulation parameters, using a classic integration scheme. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :return: 3-tuple with the time profile and solution matrix and a state vector ''' # Check validity of stimulation parameters self.checkInputs(Astim, tstim, toffset, PRF, DC) # Determine system time step dt = DT_ESTIM # if CW stimulus: divide integration during stimulus into single interval if DC == 1.0: PRF = 1 / tstim # Compute vector sizes npulses = int(np.round(PRF * tstim)) Tpulse_on = DC / PRF Tpulse_off = (1 - DC) / PRF # For high-PRF pulsed protocols: adapt time step to ensure minimal # number of samples during TON or TOFF dt_warning_msg = 'high-PRF protocol: lowering time step to %.2e s to properly integrate %s' for key, Tpulse in {'TON': Tpulse_on, 'TOFF': Tpulse_off}.items(): if Tpulse > 0 and Tpulse / dt < MIN_SAMPLES_PER_PULSE_INT: dt = Tpulse / MIN_SAMPLES_PER_PULSE_INT logger.warning(dt_warning_msg, dt, key) n_pulse_on = int(np.round(Tpulse_on / dt)) n_pulse_off = int(np.round(Tpulse_off / dt)) # Compute offset size n_off = int(np.round(toffset / dt)) # Set initial conditions steady_states = self.steadyStates(self.Vm0) y0 = [self.Vm0, *[steady_states[k] for k in self.states]] nvar = len(y0) # Initialize global arrays t = np.array([0.]) stimstate = np.array([1]) y = np.array([y0]).T # Initialize pulse time and stimstate vectors t_pulse0 = np.linspace(0, Tpulse_on + Tpulse_off, n_pulse_on + n_pulse_off) stimstate_pulse = np.concatenate((np.ones(n_pulse_on), np.zeros(n_pulse_off))) # Loop through all pulse (ON and OFF) intervals for i in range(npulses): # Construct and initialize arrays t_pulse = t_pulse0 + t[-1] y_pulse = np.empty((nvar, n_pulse_on + n_pulse_off)) # Integrate ON system y_pulse[:, :n_pulse_on] = odeint( self.Vderivatives, y[:, -1], t_pulse[:n_pulse_on], args=(Astim,)).T # Integrate OFF system if n_pulse_off > 0: y_pulse[:, n_pulse_on:] = odeint( self.Vderivatives, y_pulse[:, n_pulse_on - 1], t_pulse[n_pulse_on:], args=(0.0,)).T # Append pulse arrays to global arrays stimstate = np.concatenate([stimstate, stimstate_pulse[1:]]) t = np.concatenate([t, t_pulse[1:]]) y = np.concatenate([y, y_pulse[:, 1:]], axis=1) # Integrate offset interval if n_off > 0: t_off = np.linspace(0, toffset, n_off) + t[-1] stimstate_off = np.zeros(n_off) y_off = odeint(self.Vderivatives, y[:, -1], t_off, args=(0.0, )).T # Concatenate offset arrays to global arrays stimstate = np.concatenate([stimstate, stimstate_off[1:]]) t = np.concatenate([t, t_off[1:]]) y = np.concatenate([y, y_off[:, 1:]], axis=1) # Return output variables return (t, y, stimstate) def nSpikes(self, Astim, tstim, toffset, PRF, DC): ''' Run a simulation and determine number of spikes in the response. :param Astim: current amplitude (mA/m2) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :return: number of spikes found in response ''' t, y, _ = self.simulate(Astim, tstim, toffset, PRF, DC) dt = t[1] - t[0] ipeaks, *_ = findPeaks(y[0, :], SPIKE_MIN_VAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_VPROM) nspikes = ipeaks.size logger.debug('A = %sA/m2 ---> %s spike%s detected', si_format(Astim * 1e-3, 2, space=' '), nspikes, "s" if nspikes > 1 else "") return nspikes def titrate(self, tstim, toffset, PRF=None, DC=1.0, Arange=(0., 2 * TITRATION_ESTIM_A_MAX)): ''' Use a binary search to determine the threshold amplitude needed to obtain neural excitation for a given duration, PRF and duty cycle. :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param Arange: search interval for Astim, iteratively refined :return: excitation threshold amplitude (mA/m2) ''' return titrate(self.nSpikes, (tstim, toffset, PRF, DC), Arange, TITRATION_ESTIM_DA_MAX) def runAndSave(self, outdir, tstim, toffset, PRF=None, DC=1.0, Astim=None): ''' Run a simulation of the point-neuron Hodgkin-Huxley system with specific parameters, and save the results in a PKL file. :param outdir: full path to output directory :param tstim: stimulus duration (s) :param toffset: stimulus offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: stimulus duty cycle (-) :param Astim: stimulus amplitude (mA/m2) ''' # Get date and time info date_str = time.strftime("%Y.%m.%d") daytime_str = time.strftime("%H:%M:%S") logger.info( '%s: %s @ %st = %ss (%ss offset)%s', self, 'titration' if Astim is None else 'simulation', 'A = {}A/m2, '.format(si_format(Astim, 2, space=' ')) if Astim is not None else '', *si_format([tstim, toffset], 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format(si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) if Astim is None: Astim = self.titrate(tstim, toffset, PRF, DC) if np.isnan(Astim): logger.error('Could not find threshold excitation amplitude') return None # Run simulation tstart = time.time() t, y, stimstate = self.simulate(Astim, tstim, toffset, PRF, DC) Vm, *channels = y tcomp = time.time() - tstart # Detect spikes on Vm signal dt = t[1] - t[0] ipeaks, *_ = findPeaks(Vm, SPIKE_MIN_VAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_VPROM) nspikes = ipeaks.size lat = t[ipeaks[0]] if nspikes > 0 else 'N/A' outstr = '{} spike{} detected'.format(nspikes, 's' if nspikes > 1 else '') logger.debug('completed in %ss, %s', si_format(tcomp, 1), outstr) sr = np.mean(1 / np.diff(t[ipeaks])) if nspikes > 1 else None # Store dataframe and metadata df = pd.DataFrame({ 't': t, 'stimstate': stimstate, 'Vm': Vm, 'Qm': Vm * self.Cm0 * 1e-3 }) for j in range(len(self.states)): df[self.states[j]] = channels[j] meta = { 'neuron': self.name, 'Astim': Astim, 'tstim': tstim, 'toffset': toffset, 'PRF': PRF, 'DC': DC, 'tcomp': tcomp } # Export into to PKL file simcode = ESTIM_filecode(self.name, Astim, tstim, PRF, DC) outpath = '{}/{}.pkl'.format(outdir, simcode) with open(outpath, 'wb') as fh: pickle.dump({'meta': meta, 'data': df}, fh) logger.debug('simulation data exported to "%s"', outpath) # Export key metrics to log file logpath = os.path.join(outdir, 'log_ESTIM.xlsx') logentry = { 'Date': date_str, 'Time': daytime_str, 'Neuron Type': self.name, 'Astim (mA/m2)': Astim, 'Tstim (ms)': tstim * 1e3, 'PRF (kHz)': PRF * 1e-3 if DC < 1 else 'N/A', 'Duty factor': DC, '# samples': t.size, 'Comp. time (s)': round(tcomp, 2), '# spikes': nspikes, 'Latency (ms)': lat * 1e3 if isinstance(lat, float) else 'N/A', 'Spike rate (sp/ms)': sr * 1e-3 if isinstance(sr, float) else 'N/A' } if xlslog(logpath, logentry) == 1: logger.debug('log exported to "%s"', logpath) else: logger.error('log export to "%s" aborted', self.logpath) return outpath diff --git a/PySONIC/plt/QSS.py b/PySONIC/plt/QSS.py index 37707f6..e2881a4 100644 --- a/PySONIC/plt/QSS.py +++ b/PySONIC/plt/QSS.py @@ -1,528 +1,525 @@ import inspect from copy import deepcopy import pandas as pd import numpy as np import matplotlib.pyplot as plt from matplotlib import cm, colors from matplotlib.colors import ListedColormap from ..postpro import getFixedPoints from ..core import NeuronalBilayerSonophore from .pltutils import * from ..constants import TITRATION_T_OFFSET from ..utils import logger def plotVarQSSDynamics(neuron, a, Fdrive, Adrive, charges, varname, varrange, fs=12): ''' Plot the QSS-approximated derivative of a specific variable as function of the variable itself, as well as equilibrium values, for various membrane charge densities at a given acoustic amplitude. :param neuron: neuron object :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param Adrive: US amplitude (Pa) :param charges: charge density vector (C/m2) :param varname: name of variable to plot :param varrange: range over which to compute the derivative :return: figure handle ''' # Extract information about variable to plot pltvar = neuron.getPltVars()[varname] # Get methods to compute derivative and steady-state of variable of interest derX_func = getattr(neuron, 'der{}{}'.format(varname[0].upper(), varname[1:])) Xinf_func = getattr(neuron, '{}inf'.format(varname)) derX_args = inspect.getargspec(derX_func)[0][1:] Xinf_args = inspect.getargspec(Xinf_func)[0][1:] # Get dictionary of charge and amplitude dependent QSS variables nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) _, Qref, lookups, QSS = nbls.quasiSteadyStates( Fdrive, amps=Adrive, charges=charges, squeeze_output=True) df = QSS df['Vm'] = lookups['V'] # Create figure fig, ax = plt.subplots(figsize=(6, 4)) ax.set_title('{} neuron - QSS {} dynamics @ {:.2f} kPa'.format( neuron.name, pltvar['desc'], Adrive * 1e-3), fontsize=fs) ax.set_xscale('log') for key in ['top', 'right']: ax.spines[key].set_visible(False) ax.set_xlabel('$\\rm {}\ ({})$'.format(pltvar['label'], pltvar.get('unit', '')), fontsize=fs) ax.set_ylabel('$\\rm QSS\ d{}/dt\ ({}/s)$'.format(pltvar['label'], pltvar.get('unit', '1')), fontsize=fs) ax.set_ylim(-40, 40) ax.axhline(0, c='k', linewidth=0.5) y0_str = '{}0'.format(varname) if hasattr(neuron, y0_str): ax.axvline(getattr(neuron, y0_str) * pltvar.get('factor', 1), label=y0_str, c='k', linewidth=0.5) # For each charge value icolor = 0 for j, Qm in enumerate(charges): lbl = 'Q = {:.0f} nC/cm2'.format(Qm * 1e5) # Compute variable derivative as a function of its value, as well as equilibrium value, # keeping other variables at quasi steady-state derX_inputs = [varrange if arg == varname else df[arg][j] for arg in derX_args] Xinf_inputs = [df[arg][j] for arg in Xinf_args] dX_QSS = neuron.derCai(*derX_inputs) Xeq_QSS = neuron.Caiinf(*Xinf_inputs) # Plot variable derivative and its root as a function of the variable itself c = 'C{}'.format(icolor) ax.plot(varrange * pltvar.get('factor', 1), dX_QSS * pltvar.get('factor', 1), c=c, label=lbl) ax.axvline(Xeq_QSS * pltvar.get('factor', 1), linestyle='--', c=c) icolor += 1 ax.legend(frameon=False, fontsize=fs - 3) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) fig.tight_layout() fig.canvas.set_window_title('{}_QSS_{}_dynamics_{:.2f}kPa'.format( neuron.name, varname, Adrive * 1e-3)) return fig def plotQSSvars(neuron, a, Fdrive, Adrive, fs=12): ''' Plot effective membrane potential, quasi-steady states and resulting membrane currents as a function of membrane charge density, for a given acoustic amplitude. :param neuron: neuron object :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param Adrive: US amplitude (Pa) :return: figure handle ''' # Get neuron-specific pltvars pltvars = neuron.getPltVars() # Compute neuron-specific charge and amplitude dependent QS states at this amplitude nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) _, Qref, lookups, QSS = nbls.quasiSteadyStates(Fdrive, amps=Adrive, squeeze_output=True) Vmeff = lookups['V'] # Compute QSS currents currents = neuron.currents(Vmeff, np.array([QSS[k] for k in neuron.states])) iNet = sum(currents.values()) # Compute fixed points in dQdt profile dQdt = -iNet Q_SFPs = getFixedPoints(Qref, dQdt, filter='stable') Q_UFPs = getFixedPoints(Qref, dQdt, filter='unstable') # Extract dimensionless states norm_QSS = {} for x in neuron.states: if 'unit' not in pltvars[x]: norm_QSS[x] = QSS[x] # Create figure fig, axes = plt.subplots(3, 1, figsize=(7, 9)) axes[-1].set_xlabel('$\\rm Q_m\ (nC/cm^2)$', fontsize=fs) for ax in axes: for skey in ['top', 'right']: ax.spines[skey].set_visible(False) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) for item in ax.get_xticklabels(minor=True): item.set_visible(False) figname = '{} neuron - QSS dynamics @ {:.2f} kPa'.format(neuron.name, Adrive * 1e-3) fig.suptitle(figname, fontsize=fs) # Subplot: Vmeff ax = axes[0] ax.set_ylabel('$V_m^*$ (mV)', fontsize=fs) ax.plot(Qref * 1e5, Vmeff, color='k') ax.axhline(neuron.Vm0, linewidth=0.5, color='k') # Subplot: dimensionless quasi-steady states cset = plt.get_cmap('Dark2').colors + plt.get_cmap('tab10').colors ax = axes[1] ax.set_ylabel('QSS gating variables (-)', fontsize=fs) ax.set_yticks([0, 0.5, 1]) ax.set_ylim([-0.05, 1.05]) for i, (label, QS_state) in enumerate(norm_QSS.items()): ax.plot(Qref * 1e5, QS_state, label=label, c=cset[i]) # Subplot: currents ax = axes[2] cset = plt.get_cmap('tab10').colors ax.set_ylabel('QSS currents ($\\rm A/m^2$)', fontsize=fs) for i, (k, I) in enumerate(currents.items()): ax.plot(Qref * 1e5, -I * 1e-3, '--', c=cset[i], label='$\\rm -{}$'.format(neuron.getPltVars()[k]['label'])) ax.plot(Qref * 1e5, -iNet * 1e-3, color='k', label='$\\rm -I_{Net}$') ax.axhline(0, color='k', linewidth=0.5) if Q_SFPs.size > 0: ax.plot(Q_SFPs * 1e5, np.zeros(Q_SFPs.size), 'o', c='k', markersize=5, zorder=2) if Q_SFPs.size > 0: ax.plot(Q_UFPs * 1e5, np.zeros(Q_UFPs.size), 'o', c='k', markersize=5, mfc='none', zorder=2) fig.tight_layout() fig.subplots_adjust(right=0.8) for ax in axes[1:]: ax.legend(loc='center right', fontsize=fs, frameon=False, bbox_to_anchor=(1.3, 0.5)) for ax in axes[:-1]: ax.set_xticklabels([]) fig.canvas.set_window_title( '{}_QSS_states_vs_Qm_{:.2f}kPa'.format(neuron.name, Adrive * 1e-3)) return fig def plotQSSVarVsAmp(neuron, a, Fdrive, varname, amps=None, DC=1., fs=12, cmap='viridis', yscale='lin', zscale='lin'): ''' Plot a specific QSS variable (state or current) as a function of membrane charge density, for various acoustic amplitudes. :param neuron: neuron object :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param amps: US amplitudes (Pa) :param DC: duty cycle (-) :param varname: extraction key for variable to plot :return: figure handle ''' # Determine stimulation modality if a is None and Fdrive is None: stim_type = 'elec' a = 32e-9 Fdrive = 500e3 else: stim_type = 'US' # Extract information about variable to plot pltvar = neuron.getPltVars()[varname] Qvar = neuron.getPltVars()['Qm'] Afactor = {'US': 1e-3, 'elec': 1.}[stim_type] Q_SFPs = [] Q_UFPs = [] log = 'plotting {} neuron QSS {} vs. amp for {} stimulation @ {:.0f}% DC'.format( neuron.name, varname, stim_type, DC * 1e2) logger.info(log) nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) # Get reference dictionaries for zero amplitude _, Qref, lookups0, QSS0 = nbls.quasiSteadyStates(Fdrive, amps=0., squeeze_output=True) Vmeff0 = lookups0['V'] if stim_type == 'elec': # if E-STIM case, compute steady states with constant capacitance Vmeff0 = Qref / neuron.Cm0 * 1e3 QSS0 = neuron.steadyStates(Vmeff0) df0 = QSS0 df0['Vm'] = Vmeff0 # Create figure fig, ax = plt.subplots(figsize=(6, 4)) title = '{} neuron - {}steady-state {}'.format( neuron.name, 'quasi-' if amps is not None else '', pltvar['desc']) if amps is not None: title += '\nvs. {} amplitude @ {:.0f}% DC'.format(stim_type, DC * 1e2) ax.set_title(title, fontsize=fs) ax.set_xlabel('$\\rm {}\ ({})$'.format(Qvar['label'], Qvar['unit']), fontsize=fs) ax.set_ylabel('$\\rm QSS\ {}\ ({})$'.format(pltvar['label'], pltvar.get('unit', '')), fontsize=fs) if yscale == 'log': ax.set_yscale('log') for key in ['top', 'right']: ax.spines[key].set_visible(False) # Plot y-variable reference line, if any y0 = None y0_str = '{}0'.format(varname) if hasattr(neuron, y0_str): y0 = getattr(neuron, y0_str) * pltvar.get('factor', 1) elif varname in neuron.getCurrentsNames() + ['iNet', 'dQdt']: y0 = 0. y0_str = '' if y0 is not None: ax.axhline(y0, label=y0_str, c='k', linewidth=0.5) # Plot reference QSS profile of variable as a function of charge density var0 = extractPltVar( neuron, pltvar, pd.DataFrame({k: df0[k] for k in df0.keys()}), name=varname) ax.plot(Qref * Qvar['factor'], var0, '--', c='k', zorder=1, label='$\\rm A_{{{}}}=0$'.format(stim_type)) if varname == 'dQdt': Q_SFPs += getFixedPoints(Qref, var0, filter='stable').tolist() Q_UFPs += getFixedPoints(Qref, var0, filter='unstable').tolist() # Define color code mymap = plt.get_cmap(cmap) zref = amps * Afactor if zscale == 'lin': norm = colors.Normalize(zref.min(), zref.max()) elif zscale == 'log': norm = colors.LogNorm(zref.min(), zref.max()) sm = cm.ScalarMappable(norm=norm, cmap=mymap) sm._A = [] # Get amplitude-dependent QSS dictionary if stim_type == 'US': # Get dictionary of charge and amplitude dependent QSS variables _, Qref, lookups, QSS = nbls.quasiSteadyStates( Fdrive, amps=amps, DCs=DC, squeeze_output=True) df = QSS df['Vm'] = lookups['V'] else: # Repeat zero-amplitude QSS dictionary for all amplitudes df = {k: np.tile(df0[k], (amps.size, 1)) for k in df0} # Plot QSS profiles for various amplitudes for i, A in enumerate(amps): var = extractPltVar( neuron, pltvar, pd.DataFrame({k: df[k][i] for k in df.keys()}), name=varname) if varname == 'dQdt' and stim_type == 'elec': var += A * DC * pltvar['factor'] ax.plot(Qref * Qvar['factor'], var, c=sm.to_rgba(A * Afactor), zorder=0) if varname == 'dQdt': # mark eq. point if starting point provided, otherwise mark all FPs Q_SFPs += getFixedPoints(Qref, var, filter='stable').tolist() Q_UFPs += getFixedPoints(Qref, var, filter='unstable').tolist() # Plot fixed-points, if any if len(Q_SFPs) > 0: ax.plot(np.array(Q_SFPs) * Qvar['factor'], np.zeros(len(Q_SFPs)), 'o', c='k', markersize=5, zorder=2) if len(Q_UFPs) > 0: ax.plot(np.array(Q_UFPs) * Qvar['factor'], np.zeros(len(Q_UFPs)), 'x', c='k', markersize=5, zorder=2) # Add legend and adjust layout ax.legend(frameon=False, fontsize=fs) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) fig.tight_layout() fig.subplots_adjust(bottom=0.15, top=0.9, right=0.80, hspace=0.5) # Plot amplitude colorbar if amps is not None: cbarax = fig.add_axes([0.85, 0.15, 0.03, 0.75]) fig.colorbar(sm, cax=cbarax) cbarax.set_ylabel( 'Amplitude ({})'.format({'US': 'kPa', 'elec': 'mA/m2'}[stim_type]), fontsize=fs) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) title = '{}_{}SS_{}'.format(neuron.name, 'Q' if amps is not None else '', varname) if amps is not None: title += '_vs_{}A_{}_{:.0f}%DC'.format(zscale, stim_type, DC * 1e2) fig.canvas.set_window_title(title) return fig def plotEqChargeVsAmp(neurons, a, Fdrive, amps=None, tstim=250e-3, PRF=100.0, DCs=[1.], fs=12, xscale='lin', titrate=False): ''' Plot the equilibrium membrane charge density as a function of acoustic amplitude, given an initial value of membrane charge density. :param neurons: neuron objects :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param amps: US amplitudes (Pa) :return: figure handle ''' # Determine stimulation modality if a is None and Fdrive is None: stim_type = 'elec' a = 32e-9 Fdrive = 500e3 else: stim_type = 'US' logger.info('plotting equilibrium charges for %s stimulation', stim_type) # Create figure fig, ax = plt.subplots(figsize=(6, 4)) figname = 'charge stability vs. amplitude' ax.set_title(figname) ax.set_xlabel('Amplitude ({})'.format({'US': 'kPa', 'elec': 'mA/m2'}[stim_type]), fontsize=fs) ax.set_ylabel('$\\rm Q_m\ (nC/cm^2)$', fontsize=fs) if xscale == 'log': ax.set_xscale('log') for skey in ['top', 'right']: ax.spines[skey].set_visible(False) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) Qrange = (np.inf, -np.inf) icolor = 0 for i, neuron in enumerate(neurons): nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) # Compute reference charge variation array for zero amplitude _, Qref, lookups0, QSS0 = nbls.quasiSteadyStates(Fdrive, amps=0., squeeze_output=True) Qrange = (min(Qrange[0], Qref.min()), max(Qrange[1], Qref.max())) Vmeff0 = lookups0['V'] if stim_type == 'elec': # if E-STIM case, compute steady states with constant capacitance Vmeff0 = Qref / neuron.Cm0 * 1e3 QSS0 = neuron.steadyStates(Vmeff0) dQdt0 = -neuron.iNet(Vmeff0, np.array([QSS0[k] for k in neuron.states])) # mA/m2 # Compute 3D QSS charge variation array if stim_type == 'US': _, _, lookups, QSS = nbls.quasiSteadyStates(Fdrive, amps=amps, DCs=DCs) dQdt = -neuron.iNet(lookups['V'], np.array([QSS[k] for k in neuron.states])) # mA/m2 Afactor = 1e-3 else: Afactor = 1. dQdt = np.empty((amps.size, Qref.size, DCs.size)) for iA, A in enumerate(amps): for iDC, DC in enumerate(DCs): dQdt[iA, :, iDC] = dQdt0 + A * DC + # Stability parameters + rel_dx = .5 + Q_conv_thr = 1e-8 # C/m2 + Q_div_thr = 3e-5 # C/m2 + tint = 1e-3 # s + # For each duty cycle for iDC, DC in enumerate(DCs): color = 'k' if len(neurons) * len(DCs) == 1 else 'C{}'.format(icolor) # Initialize containers for stable and unstable fixed points SFPs = [] UFPs = [] # For each acoustic amplitude for iA, Adrive in enumerate(amps): logger.debug('-- A = {:.2f} kPa'.format(Adrive * 1e-3)) # Extract stable and unstable fixed points from QSS charge variation profile dQ_profile = dQdt[iA, :, iDC] sfp = getFixedPoints(Qref, dQ_profile, filter='stable').tolist() ufp = getFixedPoints(Qref, dQ_profile, filter='unstable').tolist() for Qpoint in ufp: UFPs.append((Adrive, Qpoint)) - # Re-compute QSS and corresponding derivatives at the stable Q-points + # Re-compute QSS at the stable Q-points # !!! -------------------------- Warning -------------------------- !!! # : QSS cannot be predicted accurately from linear interpolation # along any dimension, since they are non-linear functions of the lookups. # Hence, they must be re-computed for each new point in the parameter space # !!! ------------------------------------------------------------- !!! lookups1D = {k: v[iA, :, iDC] for k, v in lookups.items()} lookups1D['Q'] = Qref _, _, _, QSS_sfp = nbls.quasiSteadyStates( Fdrive, amps=Adrive, charges=np.array(sfp), DCs=DC) QSS_sfp = {k: v[0, :, 0] for k, v in QSS_sfp.items()} - dQSS_sfp = neuron.derEffStates(sfp, QSS_sfp.values(), lookups1D) # For each stable Q-point for ipoint, Qpoint in enumerate(sfp): logger.debug('---- Q-SFP = {:.2f} nC/cm2'.format(Qpoint * 1e5)) + QSS_Qpoint = {k: v[ipoint] for k, v in QSS_sfp.items()} + + # Simulate from unperturbed QSS and evaluate stability + Qmf, conv = nbls.evaluateStability( + tint, Qpoint, QSS_Qpoint, lookups1D, Q_conv_thr, Q_div_thr) + if not conv: + logger.warning( + 'diverging system at ({:.2f} kPa, {:.2f} nC/cm2)'.format( + Adrive * 1e-3, Qpoint * 1e5)) + UFPs.append((Adrive, Qpoint)) + else: - # Check that QSS derivatives are indeed zero (or close enough) - abs_tol = 1e-9 - for k, v in dQSS_sfp.items(): - if np.abs(v[ipoint] >= abs_tol): - raise ValueError( - 'non-zero {0} derivative (d{0}/dt = {1})'.format(k, v[ipoint])) - - # For each QSS - rel_dx = .5 - is_stable = [] - for x in neuron.states: + # For each state is_stable_state = [] - for sign, operator in zip([-1, +1], [np.greater, np.less]): - - # Perturb state with small offset - QSS_Qpoint = {k: v[ipoint] for k, v in QSS_sfp.items()} - QSS_perturbed = deepcopy(QSS_Qpoint) - QSS_perturbed[x] *= (1 + sign * rel_dx) - QSS_perturbed[x] = np.clip(QSS_perturbed[x], 0., 1.) - logger.debug('------ {}: {} -> {}'.format( - x, QSS_Qpoint[x], QSS_perturbed[x])) - - # Recompute all states derivatives (some are inter-dependent) - dQSS_perturbed = neuron.derEffStates( - Qpoint, QSS_perturbed.values(), lookups1D) - - # Define stability threshold - st_thr = sign * abs_tol - - # Print states whose derivatives have changed, and their stability - for k, v in dQSS_perturbed.items(): - if v != dQSS_sfp[k][ipoint]: - logger.debug('-------- d{}/dt = {} -> {}stable'.format( - k, v, '' if operator(v, st_thr) else 'un')) - - # Check if all states show stability upon x-state perturbation - # (i.e. negative derivative for positive offset and vice-versa) - is_stable_state.append( - np.all(operator(list(dQSS_perturbed.values()), st_thr))) - - # Check if all states show stability upon x-state perturbation - # in both directions - is_stable.append(np.all(is_stable_state)) - - # Classify fixed point as stable only if all states show stability - if np.all(is_stable): - SFPs.append((Adrive, Qpoint)) - else: - logger.warning('Unstable fixed-point at ({:.2f} kPa, {:.2f} nC.cm2)'.format( - Adrive * 1e-3, Qpoint * 1e5)) - UFPs.append((Adrive, Qpoint)) + for x in neuron.states: + is_stable_direction = [] + for sign in [-1, +1]: + + # Perturb state with small offset + QSS_perturbed = deepcopy(QSS_Qpoint) + QSS_perturbed[x] *= (1 + sign * rel_dx) + QSS_perturbed[x] = np.clip(QSS_perturbed[x], 0., 1.) + logger.debug('----- {}: {:.5f} -> {:.5f}'.format( + x, QSS_Qpoint[x], QSS_perturbed[x])) + + # Simulate from perturbed QSS and evaluate stability + Qmf, conv = nbls.evaluateStability( + tint, Qpoint, QSS_perturbed, lookups1D, Q_conv_thr, Q_div_thr) + is_stable_direction.append(conv) + + # Check if system shows stability upon x-state perturbation + # in both directions + is_stable_state.append(np.all(is_stable_direction)) + + # Check if system shows stability upon perturbation of all states + is_stable_fixed_point = np.all(is_stable_state) + logger.info('{}stable fixed-point at ({:.2f} kPa, {:.2f} nC/cm2)'.format( + '' if is_stable_fixed_point else 'un', Adrive * 1e-3, Qpoint * 1e5)) + + # Classify fixed point as stable only if all states show stability + if np.all(is_stable_state): + SFPs.append((Adrive, Qpoint)) + else: + UFPs.append((Adrive, Qpoint)) # Plot charge SFPs and UFPs for each acoustic amplitude A_SFPs, Q_SFPs = np.array(SFPs).T A_UFPs, Q_UFPs = np.array(UFPs).T ax.plot(np.array(A_SFPs) * Afactor, np.array(Q_SFPs) * 1e5, 'o', c=color, markersize=3, label='{} neuron - SFPs @ {:.0f} % DC'.format(neuron.name, DC * 1e2)) ax.plot(np.array(A_UFPs) * Afactor, np.array(Q_UFPs) * 1e5, 'x', c=color, markersize=3, label='{} neuron - UFPs @ {:.0f} % DC'.format(neuron.name, DC * 1e2)) # If specified, compute and plot the threshold excitation amplitude if titrate: if stim_type == 'US': Athr = nbls.titrate(Fdrive, tstim, TITRATION_T_OFFSET, PRF=PRF, DC=DC, Arange=(amps.min(), amps.max())) # Pa ax.axvline(Athr * Afactor, c=color, linestyle='--') else: Athr_pos = neuron.titrate(tstim, TITRATION_T_OFFSET, PRF=PRF, DC=DC, Arange=(0., amps.max())) # mA/m2 ax.axvline(Athr_pos * Afactor, c=color, linestyle='--') Athr_neg = neuron.titrate(tstim, TITRATION_T_OFFSET, PRF=PRF, DC=DC, Arange=(amps.min(), 0.)) # mA/m2 ax.axvline(Athr_neg * Afactor, c=color, linestyle='-.') icolor += 1 if len(neurons) * len(DCs) == 1: dQdt_sign = np.sign(np.squeeze(dQdt)) cmap = ListedColormap(plt.get_cmap('Pastel2').colors[:2]) # x = computeMeshEdges(amps, scale=xscale) * Afactor # y = computeMeshEdges(Qref) * 1e5 # xx, yy = np.meshgrid(x, y) # print(xx.shape, yy.shape) # ax.pcolormesh(xx.T, yy.T, dQdt_sign, cmap=cmap) ax.contourf(amps * Afactor, Qref * 1e5, dQdt_sign.T, cmap=cmap) # Post-process figure ax.set_ylim(np.array([Qrange[0], 0]) * 1e5) ax.legend(frameon=False, fontsize=fs) fig.tight_layout() fig.canvas.set_window_title('QSS_Qstab_vs_{}A_{}_{}_{}%DC{}'.format( xscale, '_'.join([n.name for n in neurons]), stim_type, '_'.join(['{:.0f}'.format(DC * 1e2) for DC in DCs]), '_with_thresholds' if titrate else '' )) return fig diff --git a/scripts/plot_QSS_IQ.py b/scripts/plot_QSS_IQ.py index f75105a..9b46a2d 100644 --- a/scripts/plot_QSS_IQ.py +++ b/scripts/plot_QSS_IQ.py @@ -1,119 +1,123 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2018-09-28 16:13:34 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-05-20 16:36:04 +# @Last Modified time: 2019-05-21 18:03:57 ''' Phase-plane analysis of neuron behavior under quasi-steady state approximation. ''' import os import numpy as np import matplotlib.pyplot as plt from argparse import ArgumentParser import logging from PySONIC.utils import logger, selectDirDialog from PySONIC.neurons import getNeuronsDict from PySONIC.plt import plotQSSvars, plotQSSVarVsAmp, plotEqChargeVsAmp def main(): ap = ArgumentParser() # Stimulation parameters ap.add_argument('-n', '--neurons', type=str, nargs='+', default=None, help='Neuron types') ap.add_argument('-o', '--outputdir', type=str, default=None, help='Output directory') ap.add_argument('-c', '--cmap', type=str, default='viridis', help='Colormap name') ap.add_argument('-v', '--verbose', default=False, action='store_true', help='Increase verbosity') ap.add_argument('-s', '--save', default=False, action='store_true', help='Save output figures') ap.add_argument('--titrate', default=False, action='store_true', help='Titrate excitation threshold') ap.add_argument('-A', '--amp', type=float, default=None, help='Amplitude (kPa or mA/m2)') ap.add_argument('--tstim', type=float, default=500., help='Stimulus duration for titration (ms)') ap.add_argument('--PRF', type=float, default=100., help='Pulse-repetition-frequency for titration (Hz)') ap.add_argument('--DC', type=float, nargs='+', default=None, help='Duty cycle (%)') ap.add_argument('--Ascale', type=str, default='lin', help='Scale type for acoustic amplitude ("lin" or "log")') ap.add_argument('--stim', type=str, default='US', help='Stimulation type ("US" or "elec")') ap.add_argument('--vars', type=str, nargs='+', default=None, help='Variables to plot') # Parse arguments args = ap.parse_args() logger.setLevel(logging.DEBUG if args.verbose else logging.INFO) neurons = ['RS', 'LTS'] if args.neurons is None else args.neurons neurons = [getNeuronsDict()[n]() for n in neurons] # US parameters a = 32e-9 # m Fdrive = 500e3 # Hz Arange = (1., 600.) # kPa nA = 10 US_amps = { 'lin': np.linspace(Arange[0], Arange[1], nA), 'log': np.logspace(np.log10(Arange[0]), np.log10(Arange[1]), nA) }[args.Ascale] * 1e3 # Pa # E-STIM parameters Irange = (-20., 20.) # mA/m2 nI = 100 Iinjs = np.linspace(Irange[0], Irange[1], nI) # mA/m2 # Pulsing parameters tstim = args.tstim * 1e-3 # s PRF = args.PRF # Hz DCs = [100.] if args.DC is None else args.DC # % DCs = np.array(DCs) * 1e-2 # (-) if args.stim == 'US': amps = US_amps if args.amp is None else np.array([args.amp * 1e3]) cmap = args.cmap else: a = None Fdrive = None amps = Iinjs if args.amp is None else np.array([args.amp]) cmap = 'RdBu_r' if args.vars is None: args.vars = ['dQdt'] figs = [] # Plot iNet vs Q for different amplitudes for each neuron and DC for i, neuron in enumerate(neurons): for DC in DCs: if amps.size == 1: figs.append( plotQSSvars(neuron, a, Fdrive, amps[0])) else: for var in args.vars: figs.append(plotQSSVarVsAmp( neuron, a, Fdrive, var, amps=amps, DC=DC, cmap=cmap, zscale=args.Ascale)) # Plot equilibrium charge as a function of amplitude for each neuron if amps.size > 1 and 'dQdt' in args.vars: - figs.append( - plotEqChargeVsAmp( - neurons, a, Fdrive, amps=amps, tstim=tstim, PRF=PRF, DCs=DCs, - xscale=args.Ascale, titrate=args.titrate)) + try: + figs.append( + plotEqChargeVsAmp( + neurons, a, Fdrive, amps=amps, tstim=tstim, PRF=PRF, DCs=DCs, + xscale=args.Ascale, titrate=args.titrate)) + except ValueError as err: + logger.error(err) + quit() if args.save: outputdir = args.outputdir if args.outputdir is not None else selectDirDialog() if outputdir == '': logger.error('no output directory') else: for fig in figs: s = fig.canvas.get_window_title() s = s.replace('(', '- ').replace('/', '_').replace(')', '') figname = '{}.png'.format(s) fig.savefig(os.path.join(outputdir, figname), transparent=True) else: plt.show() if __name__ == '__main__': main()