diff --git a/PySONIC/core/nbls.py b/PySONIC/core/nbls.py index 688091a..cc51833 100644 --- a/PySONIC/core/nbls.py +++ b/PySONIC/core/nbls.py @@ -1,603 +1,604 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2016-09-29 16:16:19 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-07-15 21:09:16 +# @Last Modified time: 2019-07-15 21:38:23 from copy import deepcopy import logging import numpy as np import pandas as pd from .simulators import PWSimulator, HybridSimulator, PeriodicSimulator from .bls import BilayerSonophore from .pneuron import PointNeuron from .model import Model from .batches import Batch from ..utils import * from ..constants import * from ..postpro import getFixedPoints from .lookups import SmartLookup NEURONS_LOOKUP_DIR = os.path.abspath(os.path.split(__file__)[0] + "/../neurons/") class NeuronalBilayerSonophore(BilayerSonophore): ''' This class inherits from the BilayerSonophore class and receives an PointNeuron instance at initialization, to define the electro-mechanical NICE model and its SONIC variant. ''' tscale = 'ms' # relevant temporal scale of the model simkey = 'ASTIM' # keyword used to characterize simulations made with this model def __init__(self, a, pneuron, Fdrive=None, embedding_depth=0.0): ''' Constructor of the class. :param a: in-plane radius of the sonophore structure within the membrane (m) :param pneuron: point-neuron model :param Fdrive: frequency of acoustic perturbation (Hz) :param embedding_depth: depth of the embedding tissue around the membrane (m) ''' # Check validity of input parameters if not isinstance(pneuron, PointNeuron): raise ValueError('Invalid neuron type: "{}" (must inherit from PointNeuron class)' .format(pneuron.name)) self.pneuron = pneuron # Initialize BilayerSonophore parent object BilayerSonophore.__init__(self, a, pneuron.Cm0, pneuron.Qm0(), embedding_depth) def __repr__(self): s = '{}({:.1f} nm, {}'.format(self.__class__.__name__, self.a * 1e9, self.pneuron) if self.d > 0.: s += ', d={}m'.format(si_format(self.d, precision=1, space=' ')) return s + ')' def params(self): return {**super().params(), **self.pneuron.params()} def getPltVars(self, wrapleft='df["', wrapright='"]'): return {**super().getPltVars(wrapleft, wrapright), **self.pneuron.getPltVars(wrapleft, wrapright)} def getPltScheme(self): return self.pneuron.getPltScheme() def filecode(self, *args): return Model.filecode(self, *args) @staticmethod def inputs(): # Get parent input vars and supress irrelevant entries bls_vars = BilayerSonophore.inputs() pneuron_vars = PointNeuron.inputs() del bls_vars['Qm'] del pneuron_vars['Astim'] # Fill in current input vars in appropriate order inputvars = bls_vars inputvars.update(pneuron_vars) inputvars['fs'] = { 'desc': 'sonophore membrane coverage fraction', 'label': 'f_s', 'unit': '\%', 'factor': 1e2, 'precision': 0 } inputvars['method'] = None return inputvars def filecodes(self, Fdrive, Adrive, tstim, toffset, PRF, DC, fs, method): # Get parent codes and supress irrelevant entries bls_codes = super().filecodes(Fdrive, Adrive, 0.0) pneuron_codes = self.pneuron.filecodes(0.0, tstim, toffset, PRF, DC) for x in [bls_codes, pneuron_codes]: del x['simkey'] del bls_codes['Qm'] del pneuron_codes['Astim'] # Fill in current codes in appropriate order codes = { 'simkey': self.simkey, 'neuron': pneuron_codes.pop('neuron'), 'nature': pneuron_codes.pop('nature') } codes.update(bls_codes) codes.update(pneuron_codes) codes['fs'] = 'fs{:.0f}%'.format(fs * 1e2) if fs < 1 else None codes['method'] = method return codes @staticmethod def interpOnOffVariable(key, Qm, stim, lkp): ''' Interpolate Q-dependent effective variable along ON and OFF periods of a solution. :param key: lookup variable key :param Qm: charge density solution vector :param stim: stimulation state solution vector :param lkp: dictionary of lookups for ON and OFF states :return: interpolated effective variable vector ''' x = np.zeros(stim.size) x[stim == 0] = lkp['OFF'].interpVar(Qm[stim == 0], 'Q', key) x[stim == 1] = lkp['ON'].interpVar(Qm[stim == 1], 'Q', key) return x @staticmethod def spatialAverage(fs, x, x0): ''' fs-modulated spatial averaging. ''' return fs * x + (1 - x) * x0 def computeEffVars(self, Fdrive, Adrive, Qm, fs): ''' Compute "effective" coefficients of the HH system for a specific combination of stimulus frequency, stimulus amplitude and charge density. A short mechanical simulation is run while imposing the specific charge density, until periodic stabilization. The HH coefficients are then averaged over the last acoustic cycle to yield "effective" coefficients. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param Qm: imposed charge density (C/m2) :param fs: list of sonophore membrane coverage fractions :return: list with computation time and a list of dictionaries of effective variables ''' # Run simulation and retrieve deflection and gas content vectors from last cycle data, meta = BilayerSonophore.simulate(self, Fdrive, Adrive, Qm) Z_last = data.loc[-NPC_DENSE:, 'Z'].values # m Cm_last = self.v_capacitance(Z_last) # F/m2 # For each coverage fraction effvars = [] for x in fs: # Compute membrane capacitance and membrane potential vectors Cm = self.spatialAverage(x, Cm_last, self.Cm0) # F/m2 Vm = Qm / Cm * 1e3 # mV # Compute average cycle value for membrane potential and rate constants effvars.append({**{'V': np.mean(Vm)}, **self.pneuron.getEffRates(Vm)}) # Log process log = '{}: lookups @ {}Hz, {}Pa, {:.2f} nC/cm2'.format( self, *si_format([Fdrive, Adrive], precision=1, space=' '), Qm * 1e5) if len(fs) > 1: log += ', fs = {:.0f} - {:.0f}%'.format(fs.min() * 1e2, fs.max() * 1e2) log += ', tcomp = {:.3f} s'.format(meta['tcomp']) logger.info(log) # Return effective coefficients return [meta['tcomp'], effvars] def fullDerivatives(self, t, y, Fdrive, Adrive, phi, fs): ''' Compute the full system derivatives. :param t: specific instant in time (s) :param y: vector of state variables :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param phi: acoustic drive phase (rad) :param fs: sonophore membrane coevrage fraction (-) :return: vector of derivatives ''' dydt_mech = BilayerSonophore.derivatives( self, t, y[:3], Fdrive, Adrive, y[3], phi) dydt_elec = self.pneuron.derivatives( t, y[3:], Cm=self.spatialAverage(fs, self.capacitance(y[1]), self.Cm0)) return dydt_mech + dydt_elec def effDerivatives(self, t, y, lkp1d): ''' Compute the derivatives of the n-ODE effective HH system variables, based on 1-dimensional linear interpolation of "effective" coefficients that summarize the system's behaviour over an acoustic cycle. :param t: specific instant in time (s) :param y: vector of HH system variables at time t :param lkp: dictionary of 1D data points of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: vector of effective system derivatives at time t ''' Qm, *states = y states_dict = dict(zip(self.pneuron.statesNames(), states)) lkp0d = lkp1d.interpolate1D('Q', Qm) dQmdt = - self.pneuron.iNet(lkp0d['V'], states_dict) * 1e-3 return [dQmdt, *self.pneuron.getDerEffStates(lkp0d, states_dict)] def _simFull(self, Fdrive, Adrive, tstim, toffset, PRF, DC, fs, phi=np.pi): # Determine time step dt = 1 / (NPC_DENSE * Fdrive) # Compute non-zero deflection value for a small perturbation (solving quasi-steady equation) Pac = self.Pacoustic(dt, Adrive, Fdrive, phi) Z0 = self.balancedefQS(self.ng0, self.Qm0, Pac) # Set initial conditions y0 = np.concatenate(( [0., Z0, self.ng0, self.Qm0], self.pneuron.getSteadyStates(self.pneuron.Vm0))) # Initialize simulator and compute solution logger.debug('Computing detailed solution') simulator = PWSimulator( lambda t, y: self.fullDerivatives(t, y, Fdrive, Adrive, phi, fs), lambda t, y: self.fullDerivatives(t, y, 0., 0., 0., fs)) t, y, stim = simulator( y0, dt, tstim, toffset, PRF, DC, print_progress=logger.getEffectiveLevel() <= logging.INFO, target_dt=CLASSIC_TARGET_DT) # Store output in dataframe and return data = pd.DataFrame({ 't': t, 'stimstate': stim, 'Z': y[:, 1], 'ng': y[:, 2], 'Qm': y[:, 3] }) data['Vm'] = data['Qm'].values / self.spatialAverage( fs, self.v_capacitance(data['Z'].values), self.Cm0) * 1e3 # mV for i in range(len(self.pneuron.states)): data[self.pneuron.statesNames()[i]] = y[:, i + 4] return data def _simHybrid(self, Fdrive, Adrive, tstim, toffset, PRF, DC, fs, phi=np.pi): # Determine time steps dt_dense, dt_sparse = [1. / (n * Fdrive) for n in [NPC_DENSE, NPC_SPARSE]] # Compute non-zero deflection value for a small perturbation (solving quasi-steady equation) Pac = self.Pacoustic(dt_dense, Adrive, Fdrive, phi) Z0 = self.balancedefQS(self.ng0, self.Qm0, Pac) # Set initial conditions y0 = np.concatenate(( [0., Z0, self.ng0, self.Qm0], self.pneuron.getSteadyStates(self.pneuron.Vm0))) is_dense_var = np.array([True] * 3 + [False] * (len(self.pneuron.states) + 1)) # Initialize simulator and compute solution logger.debug('Computing hybrid solution') simulator = HybridSimulator( lambda t, y: self.fullDerivatives(t, y, Fdrive, Adrive, phi, fs), lambda t, y: self.fullDerivatives(t, y, 0., 0., 0., fs), lambda t, y, Cm: self.pneuron.derivatives( t, y, Cm=self.spatialAverage(fs, Cm, self.Cm0)), lambda yref: self.capacitance(yref[1]), is_dense_var, ivars_to_check=[1, 2]) t, y, stim = simulator(y0, dt_dense, dt_sparse, 1. / Fdrive, tstim, toffset, PRF, DC) # Store output in dataframe and return data = pd.DataFrame({ 't': t, 'stimstate': stim, 'Z': y[:, 1], 'ng': y[:, 2], 'Qm': y[:, 3] }) data['Vm'] = data['Qm'].values / self.spatialAverage( fs, self.v_capacitance(data['Z'].values), self.Cm0) * 1e3 # mV for i in range(len(self.pneuron.states)): data[self.pneuron.statesNames()[i]] = y[:, i + 4] return data def _simSonic(self, Fdrive, Adrive, tstim, toffset, PRF, DC, fs): # Load appropriate 2D lookups lkp2d = self.getLookup2D(Fdrive, fs) # Interpolate 2D lookups at zero and US amplitude logger.debug('Interpolating lookups at A = %.2f kPa and A = 0', Adrive * 1e-3) lkps1d = {'ON': lkp2d.project('A', Adrive), 'OFF': lkp2d.project('A', 0.)} # Set initial conditions y0 = np.concatenate(([self.Qm0], self.pneuron.getSteadyStates(self.pneuron.Vm0))) # Initialize simulator and compute solution logger.debug('Computing effective solution') simulator = PWSimulator( lambda t, y: self.effDerivatives(t, y, lkps1d['ON']), lambda t, y: self.effDerivatives(t, y, lkps1d['OFF'])) t, y, stim = simulator(y0, DT_EFFECTIVE, tstim, toffset, PRF, DC) # Store output in dataframe and return data = pd.DataFrame({ 't': t, 'stimstate': stim, 'Qm': y[:, 0] }) data['Vm'] = self.interpOnOffVariable('V', data['Qm'].values, stim, lkps1d) for key in ['Z', 'ng']: data[key] = np.full(t.size, np.nan) for i in range(len(self.pneuron.states)): data[self.pneuron.statesNames()[i]] = y[:, i + 1] return data @classmethod def simQueue(cls, freqs, amps, durations, offsets, PRFs, DCs, fs, methods, outputdir=None): ''' Create a serialized 2D array of all parameter combinations for a series of individual parameter sweeps, while avoiding repetition of CW protocols for a given PRF sweep. :param freqs: list (or 1D-array) of US frequencies :param amps: list (or 1D-array) of acoustic amplitudes :param durations: list (or 1D-array) of stimulus durations :param offsets: list (or 1D-array) of stimulus offsets (paired with durations array) :param PRFs: list (or 1D-array) of pulse-repetition frequencies :param DCs: list (or 1D-array) of duty cycle values :param fs: sonophore membrane coverage fractions (-) :params methods: integration methods :return: list of parameters (list) for each simulation ''' method_ids = list(range(len(methods))) if ('full' in methods or 'hybrid' in methods) and outputdir is None: logger.warning('Running cumbersome simulation(s) without file saving') if amps is None: amps = [np.nan] DCs = np.array(DCs) queue = [] if 1.0 in DCs: queue += Batch.createQueue( freqs, amps, durations, offsets, min(PRFs), 1.0, fs, method_ids) if np.any(DCs != 1.0): queue += Batch.createQueue( freqs, amps, durations, offsets, PRFs, DCs[DCs != 1.0], fs, method_ids) for item in queue: if np.isnan(item[1]): item[1] = None item[-1] = methods[int(item[-1])] return cls.checkOutputDir(queue, outputdir) @Model.logNSpikes @Model.checkTitrate('Adrive') @Model.addMeta def simulate(self, Fdrive, Adrive, tstim, toffset, PRF=100., DC=1., fs=1., method='sonic'): ''' Simulate the electro-mechanical model for a specific set of US stimulation parameters, and return output data in a dataframe. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param fs: sonophore membrane coverage fraction (-) :param method: selected integration method :return: 2-tuple with the output dataframe and computation time. ''' logger.info( '%s: simulation @ f = %sHz, A = %sPa, t = %ss (%ss offset)%s%s', self, si_format(Fdrive, 0, space=' '), si_format(Adrive, 2, space=' '), *si_format([tstim, toffset], 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format( si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else ''), ', fs = {:.2f}%'.format(fs * 1e2) if fs < 1.0 else '') # Check validity of stimulation parameters BilayerSonophore.checkInputs(Fdrive, Adrive, 0.0, 0.0) PointNeuron.checkInputs(Adrive, tstim, toffset, PRF, DC) # Call appropriate simulation function and return try: simfunc = { 'full': self._simFull, 'hybrid': self._simHybrid, 'sonic': self._simSonic }[method] except KeyError: raise ValueError('Invalid integration method: "{}"'.format(method)) return simfunc(Fdrive, Adrive, tstim, toffset, PRF, DC, fs) def meta(self, Fdrive, Adrive, tstim, toffset, PRF, DC, fs, method): return { 'simkey': self.simkey, 'neuron': self.pneuron.name, 'a': self.a, 'd': self.d, 'Fdrive': Fdrive, 'Adrive': Adrive, 'tstim': tstim, 'toffset': toffset, 'PRF': PRF, 'DC': DC, 'fs': fs, 'method': method } @staticmethod def getNSpikes(data): return PointNeuron.getNSpikes(data) @logCache(os.path.join(os.path.split(__file__)[0], 'astim_titrations.log')) def titrate(self, Fdrive, tstim, toffset, PRF=100., DC=1., fs=1., method='sonic', xfunc=None, Arange=None): ''' Use a binary search to determine the threshold amplitude needed to obtain neural excitation for a given frequency, duration, PRF and duty cycle. :param Fdrive: US frequency (Hz) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param fs: sonophore membrane coverage fraction (-) :param method: integration method :param xfunc: function determining whether condition is reached from simulation output :param Arange: search interval for Adrive, iteratively refined :return: determined threshold amplitude (Pa) ''' # Default output function if xfunc is None: xfunc = self.pneuron.titrationFunc # Default amplitude interval if Arange is None: Arange = [0., self.getLookup().refs['A'].max()] return binarySearch( lambda x: xfunc(self.simulate(*x)[0]), [Fdrive, tstim, toffset, PRF, DC, fs, method], 1, Arange, THRESHOLD_CONV_RANGE_ASTIM ) - def quasiSteadyStates(self, Fdrive, amps=None, charges=None, DCs=1.0, squeeze_output=False): + def getQuasiSteadyStates(self, Fdrive, amps=None, charges=None, DCs=1.0, squeeze_output=False): ''' Compute the quasi-steady state values of the neuron's gating variables for a combination of US amplitudes, charge densities and duty cycles, at a specific US frequency. :param Fdrive: US frequency (Hz) :param amps: US amplitudes (Pa) :param charges: membrane charge densities (C/m2) :param DCs: duty cycle value(s) :return: 4-tuple with reference values of US amplitude and charge density, as well as interpolated Vmeff and QSS gating variables ''' # Get DC-averaged lookups interpolated at the appropriate amplitudes and charges lkp = self.getLookup().projectDCs(amps=amps, DCs=DCs).projectN({'a': self.a, 'f': Fdrive}) if charges is not None: lkp = lkp.project('Q', charges) # Specify dimensions with A and DC as the first two axes A_axis = lkp.getAxisIndex('A') lkp.move('A', 0) lkp.move('DC', 1) nA, nDC = lkp.dims()[:2] # Compute QSS states using these lookups QSS = {k: np.empty(lkp.dims()) for k in self.pneuron.statesNames()} for iA in range(nA): for iDC in range(nDC): - QSS_1D = self.pneuron.quasiSteadyStates({k: v[iA, iDC] for k, v in lkp.items()}) + lkp1d = {k: v[iA, iDC] for k, v in lkp.items()} + QSS_1D = {k: v(lkp1d) for k, v in self.pneuron.quasiSteadyStates().items()} for k in QSS.keys(): QSS[k][iA, iDC] = QSS_1D[k] QSS = SmartLookup(lkp.refs, QSS) for item in [lkp, QSS]: item.move('A', A_axis) item.move('DC', -1) # Compress outputs if needed if squeeze_output: QSS = QSS.squeeze() lkp = lkp.squeeze() return lkp, QSS def iNetQSS(self, Qm, Fdrive, Adrive, DC): ''' Compute quasi-steady state net membrane current for a given combination of US parameters and a given membrane charge density. :param Qm: membrane charge density (C/m2) :param Fdrive: US frequency (Hz) :param Adrive: US amplitude (Pa) :param DC: duty cycle (-) :return: net membrane current (mA/m2) ''' - lkp, QSS = self.quasiSteadyStates( + lkp, QSS = self.getQuasiSteadyStates( Fdrive, amps=Adrive, charges=Qm, DCs=DC, squeeze_output=True) return self.pneuron.iNet(lkp['V'], QSS) # mA/m2 def fixedPointsQSS(self, Fdrive, Adrive, DC, lkp, dQdt): ''' Compute QSS fixed points along the charge dimension for a given combination of US parameters, and determine their stability. :param Fdrive: US frequency (Hz) :param Adrive: US amplitude (Pa) :param DC: duty cycle (-) :param lkp: lookup dictionary for effective variables along charge dimension :param dQdt: charge derivative profile along charge dimension :return: 2-tuple with values of stable and unstable fixed points ''' pltvars = self.getPltVars() logger.debug('A = {:.2f} kPa, DC = {:.0f}%'.format(Adrive * 1e-3, DC * 1e2)) # Extract fixed points from QSS charge variation profile def dfunc(Qm): return - self.iNetQSS(Qm, Fdrive, Adrive, DC) SFP_candidates = getFixedPoints( lkp.refs['Q'], dQdt, filter='both', der_func=dfunc).tolist() SFPs, UFPs = [], [] dfunc = lambda x: self.effDerivatives(_, x, lkp) # For each fixed point for i, Qm in enumerate(SFP_candidates): # Re-compute QSS - *_, QSS_FP = self.quasiSteadyStates(Fdrive, amps=Adrive, charges=Qm, DCs=DC, - squeeze_output=True) + *_, QSS_FP = self.getQuasiSteadyStates(Fdrive, amps=Adrive, charges=Qm, DCs=DC, + squeeze_output=True) # Approximate the system's Jacobian matrix at the fixed-point and compute its eigenvalues J = jacobian([Qm, *QSS_FP.tables.values()], dfunc) lambdas, _ = np.linalg.eig(J) print(lambdas.real) # Determine fixed point stability based on eigenvalues is_stable_FP = np.all(lambdas.real < 0) s = 'fixed point @ Q = {:.2f} nC/cm2'.format(Qm * 1e5) if is_stable_FP: SFPs.append(Qm) logger.debug('stable ' + s) else: UFPs.append(Qm) logger.debug('unstable ' + s) return SFPs, UFPs def isStableQSS(self, Fdrive, Adrive, DC): - lookups, QSS = self.quasiSteadyStates( + lookups, QSS = self.getQuasiSteadyStates( Fdrive, amps=Adrive, DCs=DC, squeeze_output=True) dQdt = -self.pneuron.iNet(lookups['V'], QSS.tables) # mA/m2 SFPs, _ = self.fixedPointsQSS(Fdrive, Adrive, DC, lookups, dQdt) return len(SFPs) > 0 def titrateQSS(self, Fdrive, DC=1., Arange=None): # Default amplitude interval if Arange is None: Arange = [0., self.getLookup().refs['A'].max()] # Titration function def xfunc(x): if self.pneuron.name == 'STN': return self.isStableQSS(*x) else: return not self.isStableQSS(*x) return binarySearch( xfunc, [Fdrive, DC], 1, Arange, THRESHOLD_CONV_RANGE_ASTIM) def getLookupFileName(self, a=None, Fdrive=None, Adrive=None, fs=False): fname = '{}_lookups'.format(self.pneuron.name) if a is not None: fname += '_{:.0f}nm'.format(a * 1e9) if Fdrive is not None: fname += '_{:.0f}kHz'.format(Fdrive * 1e-3) if Adrive is not None: fname += '_{:.0f}kPa'.format(Adrive * 1e-3) if fs is True: fname += '_fs' return '{}.pkl'.format(fname) def getLookupFilePath(self, *args, **kwargs): return os.path.join(NEURONS_LOOKUP_DIR, self.getLookupFileName(*args, **kwargs)) def getLookup(self, *args, **kwargs): lookup_path = self.getLookupFilePath(*args, **kwargs) if not os.path.isfile(lookup_path): raise FileNotFoundError('Missing lookup file: "{}"'.format(lookup_path)) with open(lookup_path, 'rb') as fh: frame = pickle.load(fh) if 'ng' in frame['lookup']: del frame['lookup']['ng'] refs = frame['input'] # Move fs to last reference dimension keys = list(refs.keys()) if 'fs' in keys and keys.index('fs') < len(keys) - 1: del keys[keys.index('fs')] keys.append('fs') refs = {k: refs[k] for k in keys} return SmartLookup(refs, frame['lookup']) def getLookup2D(self, Fdrive, fs): if fs < 1: lkp2d = self.getLookup(a=self.a, Fdrive=Fdrive, fs=True).project('fs', fs) else: lkp2d = self.getLookup().projectN({'a': self.a, 'f': Fdrive}) return lkp2d diff --git a/PySONIC/core/pneuron.py b/PySONIC/core/pneuron.py index 83db1e1..3f27d6a 100644 --- a/PySONIC/core/pneuron.py +++ b/PySONIC/core/pneuron.py @@ -1,596 +1,589 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-08-03 11:53:04 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-07-01 17:34:06 +# @Last Modified time: 2019-07-15 21:38:13 import abc import inspect import numpy as np import pandas as pd from .batches import Batch from .model import Model from .lookups import SmartLookup from .simulators import PWSimulator from ..postpro import findPeaks, computeFRProfile from ..constants import * from ..utils import * class PointNeuron(Model): ''' Generic point-neuron model interface. ''' tscale = 'ms' # relevant temporal scale of the model simkey = 'ESTIM' # keyword used to characterize simulations made with this model def __repr__(self): return self.__class__.__name__ @property @classmethod @abc.abstractmethod def name(cls): ''' Neuron name. ''' raise NotImplementedError @property @classmethod @abc.abstractmethod def Cm0(cls): ''' Neuron's resting capacitance (F/cm2). ''' raise NotImplementedError @property @classmethod @abc.abstractmethod def Vm0(cls): ''' Neuron's resting membrane potential(mV). ''' raise NotImplementedError @classmethod def Qm0(cls): return cls.Cm0 * cls.Vm0 * 1e-3 # C/cm2 @staticmethod def inputs(): return { 'Astim': { 'desc': 'current density amplitude', 'label': 'A', 'unit': 'mA/m2', 'factor': 1e0, 'precision': 1 }, 'tstim': { 'desc': 'stimulus duration', 'label': 't_{stim}', 'unit': 'ms', 'factor': 1e3, 'precision': 0 }, 'toffset': { 'desc': 'offset duration', 'label': 't_{offset}', 'unit': 'ms', 'factor': 1e3, 'precision': 0 }, 'PRF': { 'desc': 'pulse repetition frequency', 'label': 'PRF', 'unit': 'Hz', 'factor': 1e0, 'precision': 0 }, 'DC': { 'desc': 'duty cycle', 'label': 'DC', 'unit': '%', 'factor': 1e2, 'precision': 2 } } @classmethod def filecodes(cls, Astim, tstim, toffset, PRF, DC): is_CW = DC == 1. return { 'simkey': cls.simkey, 'neuron': cls.name, 'nature': 'CW' if is_CW else 'PW', 'Astim': '{:.1f}mAm2'.format(Astim), 'tstim': '{:.0f}ms'.format(tstim * 1e3), 'toffset': None, 'PRF': 'PRF{:.2f}Hz'.format(PRF) if not is_CW else None, 'DC': 'DC{:.2f}%'.format(DC * 1e2) if not is_CW else None } @classmethod def getPltVars(cls, wrapleft='df["', wrapright='"]'): pltvars = { 'Qm': { 'desc': 'membrane charge density', 'label': 'Q_m', 'unit': 'nC/cm^2', 'factor': 1e5, 'bounds': (-100, 50) }, 'Vm': { 'desc': 'membrane potential', 'label': 'V_m', 'unit': 'mV', 'y0': cls.Vm0, 'bounds': (-150, 70) }, 'ELeak': { 'constant': 'obj.ELeak', 'desc': 'non-specific leakage current resting potential', 'label': 'V_{leak}', 'unit': 'mV', 'ls': '--', 'color': 'k' } } for cname in cls.getCurrentsNames(): cfunc = getattr(cls, cname) cargs = inspect.getargspec(cfunc)[0][1:] pltvars[cname] = { 'desc': inspect.getdoc(cfunc).splitlines()[0], 'label': 'I_{{{}}}'.format(cname[1:]), 'unit': 'A/m^2', 'factor': 1e-3, 'func': '{}({})'.format(cname, ', '.join(['{}{}{}'.format(wrapleft, a, wrapright) for a in cargs])) } for var in cargs: if var != 'Vm': pltvars[var] = { 'desc': cls.states[var], 'label': var, 'bounds': (-0.1, 1.1) } pltvars['iNet'] = { 'desc': inspect.getdoc(getattr(cls, 'iNet')).splitlines()[0], 'label': 'I_{net}', 'unit': 'A/m^2', 'factor': 1e-3, 'func': 'iNet({0}Vm{1}, {2}{3}{4})'.format( wrapleft, wrapright, wrapleft[:-1], cls.statesNames(), wrapright[1:]), 'ls': '--', 'color': 'black' } pltvars['dQdt'] = { 'desc': inspect.getdoc(getattr(cls, 'dQdt')).splitlines()[0], 'label': 'dQ_m/dt', 'unit': 'A/m^2', 'factor': 1e-3, 'func': 'dQdt({0}Vm{1}, {2}{3}{4})'.format( wrapleft, wrapright, wrapleft[:-1], cls.statesNames(), wrapright[1:]), 'ls': '--', 'color': 'black' } for rate in cls.rates: if 'alpha' in rate: prefix, suffix = 'alpha', rate[5:] else: prefix, suffix = 'beta', rate[4:] pltvars['{}'.format(rate)] = { 'label': '\\{}_{{{}}}'.format(prefix, suffix), 'unit': 'ms^{-1}', 'factor': 1e-3 } pltvars['FR'] = { 'desc': 'riring rate', 'label': 'FR', 'unit': 'Hz', 'factor': 1e0, # 'bounds': (0, 1e3), 'func': 'firingRateProfile({0}t{1}.values, {0}Qm{1}.values)'.format(wrapleft, wrapright) } return pltvars @classmethod def getPltScheme(cls): pltscheme = { 'Q_m': ['Qm'], 'V_m': ['Vm'] } pltscheme['I'] = cls.getCurrentsNames() + ['iNet'] for cname in cls.getCurrentsNames(): if 'Leak' not in cname: key = 'i_{{{}}}\ kin.'.format(cname[1:]) cargs = inspect.getargspec(getattr(cls, cname))[0][1:] pltscheme[key] = [var for var in cargs if var not in ['Vm', 'Cai']] return pltscheme @classmethod def statesNames(cls): ''' Return a list of names of all state variables of the model. ''' return list(cls.states.keys()) @classmethod @abc.abstractmethod def derStates(cls): ''' Dictionary of states derivatives functions ''' raise NotImplementedError @classmethod def getDerStates(cls, Vm, states): ''' Compute states derivatives array given a membrane potential and states dictionary ''' return np.array([cls.derStates()[k](Vm, states) for k in cls.statesNames()]) @classmethod @abc.abstractmethod def steadyStates(cls): ''' Return a dictionary of steady-states functions ''' raise NotImplementedError @classmethod def getSteadyStates(cls, Vm): ''' Compute array of steady-states for a given membrane potential ''' return np.array([cls.steadyStates()[k](Vm) for k in cls.statesNames()]) @classmethod def getDerEffStates(cls, lkp, states): ''' Compute effective states derivatives array given lookups and states dictionaries. ''' return np.array([ cls.derEffStates()[k](lkp, states) for k in cls.statesNames()]) @classmethod def getEffRates(cls, Vm): ''' Compute array of effective rate constants for a given membrane potential vector. ''' return {k: np.mean(np.vectorize(v)(Vm)) for k, v in cls.effRates().items()} @classmethod def getLookup(cls): ''' Get lookup of membrane potential rate constants interpolated along the neuron's charge physiological range. ''' Qref = np.arange(*cls.Qbounds(), 1e-5) # C/m2 Vref = Qref / cls.Cm0 * 1e3 # mV tables = {k: np.vectorize(v)(Vref) for k, v in cls.effRates().items()} return SmartLookup({'Q': Qref}, {**{'V': Vref}, **tables}) @classmethod @abc.abstractmethod def currents(cls): ''' Dictionary of ionic currents functions (returning current densities in mA/m2) ''' @classmethod def iNet(cls, Vm, states): ''' net membrane current :param Vm: membrane potential (mV) :states: states of ion channels gating and related variables :return: current per unit area (mA/m2) ''' return sum([cfunc(Vm, states) for cfunc in cls.currents().values()]) @classmethod def dQdt(cls, Vm, states): ''' membrane charge density variation rate :param Vm: membrane potential (mV) :states: states of ion channels gating and related variables :return: variation rate (mA/m2) ''' return -cls.iNet(Vm, states) @classmethod def titrationFunc(cls, *args, **kwargs): ''' Default titration function. ''' return cls.isExcited(*args, **kwargs) @staticmethod def currentToConcentrationRate(z_ion, depth): ''' Compute the conversion factor from a specific ionic current (in mA/m2) into a variation rate of submembrane ion concentration (in M/s). :param: z_ion: ion valence :param depth: submembrane depth (m) :return: conversion factor (Mmol.m-1.C-1) ''' return 1e-6 / (z_ion * depth * FARADAY) @staticmethod def nernst(z_ion, Cion_in, Cion_out, T): ''' Nernst potential of a specific ion given its intra and extracellular concentrations. :param z_ion: ion valence :param Cion_in: intracellular ion concentration :param Cion_out: extracellular ion concentration :param T: temperature (K) :return: ion Nernst potential (mV) ''' return (Rg * T) / (z_ion * FARADAY) * np.log(Cion_out / Cion_in) * 1e3 @staticmethod def vtrap(x, y): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x / y) - 1) @staticmethod def efun(x): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x) - 1) @classmethod def ghkDrive(cls, Vm, Z_ion, Cion_in, Cion_out, T): ''' Use the Goldman-Hodgkin-Katz equation to compute the electrochemical driving force of a specific ion species for a given membrane potential. :param Vm: membrane potential (mV) :param Cin: intracellular ion concentration (M) :param Cout: extracellular ion concentration (M) :param T: temperature (K) :return: electrochemical driving force of a single ion particle (mC.m-3) ''' x = Z_ion * FARADAY * Vm / (Rg * T) * 1e-3 # [-] eCin = Cion_in * cls.efun(-x) # M eCout = Cion_out * cls.efun(x) # M return FARADAY * (eCin - eCout) * 1e6 # mC/m3 @classmethod def getCurrentsNames(cls): return list(cls.currents().keys()) def firingRateProfile(*args, **kwargs): return computeFRProfile(*args, **kwargs) @classmethod def Qbounds(cls): ''' Determine bounds of membrane charge physiological range for a given neuron. ''' return np.array([np.round(cls.Vm0 - 25.0), 50.0]) * cls.Cm0 * 1e-3 # C/m2 @classmethod def isVoltageGated(cls, state): ''' Determine whether a given state is purely voltage-gated or not.''' return 'alpha{}'.format(state.lower()) in cls.rates @staticmethod - def qsStates(lkp, states): - ''' Compute a collection of quasi steady states using the standard - xinf = ax / (ax + Bx) equation. + def qsState(x): + ''' Create a function that returns a given quasi steady state given a lookup table, + using the standard xinf = ax / (ax + Bx) equation. - :param lkp: dictionary of 1D vectors of "effective" coefficients - over the charge domain, for specific frequency and amplitude values. - :return: dictionary of quasi-steady states + :param x: state name. + :return: quasi-steady state function ''' - return { - x: lkp['alpha{}'.format(x)] / (lkp['alpha{}'.format(x)] + lkp['beta{}'.format(x)]) - for x in states - } + return lambda lkp: lkp[f'alpha{x}'] / (lkp[f'alpha{x}'] + lkp[f'beta{x}']) @classmethod - def quasiSteadyStates(cls, lkp): - ''' Compute the quasi-steady states of a neuron for a range of membrane charge densities, - based on 1-dimensional lookups interpolated at a given sonophore diameter, US frequency, - US amplitude and duty cycle. - - :param lkp: dictionary of 1D vectors of "effective" coefficients - over the charge domain, for specific frequency and amplitude values. - :return: dictionary of quasi-steady states + def quasiSteadyStates(cls): + ''' Create a dictionary of functions computing quasi-steady states of a neuron + for a range of membrane charge densities, based on 1-dimensional lookups + interpolated at a given sonophore diameter, US frequency, US amplitude and duty cycle. + + :return: dictionary of quasi-steady state functions ''' - return cls.qsStates(lkp, cls.statesNames()) - # return {k: func(lkp['Vm']) for k, func in self.steadyStates().items()} + return {k: cls.qsState(k) for k in cls.statesNames()} @classmethod def simQueue(cls, amps, durations, offsets, PRFs, DCs, outputdir=None): ''' Create a serialized 2D array of all parameter combinations for a series of individual parameter sweeps, while avoiding repetition of CW protocols for a given PRF sweep. :param amps: list (or 1D-array) of acoustic amplitudes :param durations: list (or 1D-array) of stimulus durations :param offsets: list (or 1D-array) of stimulus offsets (paired with durations array) :param PRFs: list (or 1D-array) of pulse-repetition frequencies :param DCs: list (or 1D-array) of duty cycle values :return: list of parameters (list) for each simulation ''' if amps is None: amps = [np.nan] DCs = np.array(DCs) queue = [] if 1.0 in DCs: queue += Batch.createQueue(amps, durations, offsets, min(PRFs), 1.0) if np.any(DCs != 1.0): queue += Batch.createQueue(amps, durations, offsets, PRFs, DCs[DCs != 1.0]) for item in queue: if np.isnan(item[0]): item[0] = None return cls.checkOutputDir(queue, outputdir) @staticmethod def checkInputs(Astim, tstim, toffset, PRF, DC): ''' Check validity of electrical stimulation parameters. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) ''' # Check validity of stimulation parameters if not all(isinstance(param, float) for param in [Astim, tstim, toffset, DC]): raise TypeError('Invalid stimulation parameters (must be float typed)') if tstim <= 0: raise ValueError('Invalid stimulus duration: {} ms (must be strictly positive)' .format(tstim * 1e3)) if toffset < 0: raise ValueError('Invalid stimulus offset: {} ms (must be positive or null)' .format(toffset * 1e3)) if DC <= 0.0 or DC > 1.0: raise ValueError('Invalid duty cycle: {} (must be within ]0; 1])'.format(DC)) if DC < 1.0: if not isinstance(PRF, float): raise TypeError('Invalid PRF value (must be float typed)') if PRF is None: raise AttributeError('Missing PRF value (must be provided when DC < 1)') if PRF < 1 / tstim: raise ValueError('Invalid PRF: {} Hz (PR interval exceeds stimulus duration)' .format(PRF)) @classmethod def derivatives(cls, t, y, Cm=None, Iinj=0.): ''' Compute system derivatives for a given mambrane capacitance and injected current. :param t: specific instant in time (s) :param y: vector of HH system variables at time t :param Cm: membrane capacitance (F/m2) :param Iinj: injected current (mA/m2) :return: vector of system derivatives at time t ''' if Cm is None: Cm = cls.Cm0 Qm, *states = y Vm = Qm / Cm * 1e3 # mV states_dict = dict(zip(cls.statesNames(), states)) dQmdt = (Iinj - cls.iNet(Vm, states_dict)) * 1e-3 # A/m2 return [dQmdt, *cls.getDerStates(Vm, states_dict)] @Model.logNSpikes @Model.checkTitrate('Astim') @Model.addMeta def simulate(self, Astim, tstim, toffset, PRF=100., DC=1.0): ''' Simulate a specific neuron model for a specific set of electrical parameters, and return output data in a dataframe. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :return: 2-tuple with the output dataframe and computation time. ''' logger.info( '%s: simulation @ A = %sA/m2, t = %ss (%ss offset)%s', self, si_format(Astim, 2, space=' '), *si_format([tstim, toffset], 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format( si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) # Check validity of stimulation parameters self.checkInputs(Astim, tstim, toffset, PRF, DC) # Set initial conditions y0 = np.array((self.Qm0(), *self.getSteadyStates(self.Vm0))) # Initialize simulator and compute solution logger.debug('Computing solution') simulator = PWSimulator( lambda t, y: self.derivatives(t, y, Iinj=Astim), lambda t, y: self.derivatives(t, y, Iinj=0.)) t, y, stim = simulator( y0, DT_EFFECTIVE, tstim, toffset, PRF, DC) # Store output in dataframe and return data = pd.DataFrame({ 't': t, 'stimstate': stim, 'Qm': y[:, 0], 'Vm': y[:, 0] / self.Cm0 * 1e3, }) for i in range(len(self.states)): data[self.statesNames()[i]] = y[:, i + 1] return data @classmethod def meta(cls, Astim, tstim, toffset, PRF, DC): return { 'simkey': cls.simkey, 'neuron': cls.name, 'Astim': Astim, 'tstim': tstim, 'toffset': toffset, 'PRF': PRF, 'DC': DC } @staticmethod def getNSpikes(data): ''' Compute number of spikes in charge profile of simulation output. :param data: dataframe containing output time series :return: number of detected spikes ''' dt = np.diff(data.ix[:1, 't'].values)[0] ipeaks, *_ = findPeaks( data['Qm'].values, SPIKE_MIN_QAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_QPROM ) return ipeaks.size @staticmethod def getStabilizationValue(data): ''' Determine stabilization value from the charge profile of a simulation output. :param data: dataframe containing output time series :return: charge stabilization value (or np.nan if no stabilization detected) ''' # Extract charge signal posterior to observation window t, Qm = [data[key].values for key in ['t', 'Qm']] if t.max() <= TMIN_STABILIZATION: raise ValueError('solution length is too short to assess stabilization') Qm = Qm[t > TMIN_STABILIZATION] # Compute variation range Qm_range = np.ptp(Qm) logger.debug('%.2f nC/cm2 variation range over the last %.0f ms, Qmf = %.2f nC/cm2', Qm_range * 1e5, TMIN_STABILIZATION * 1e3, Qm[-1] * 1e5) # Return final value only if stabilization is detected if np.ptp(Qm) < QSS_Q_DIV_THR: return Qm[-1] else: return np.nan @classmethod def isExcited(cls, data): ''' Determine if neuron is excited from simulation output. :param data: dataframe containing output time series :return: boolean stating whether neuron is excited or not ''' return cls.getNSpikes(data) > 0 @classmethod def isSilenced(cls, data): ''' Determine if neuron is silenced from simulation output. :param data: dataframe containing output time series :return: boolean stating whether neuron is silenced or not ''' return not np.isnan(cls.getStabilizationValue(data)) def titrate(self, tstim, toffset, PRF, DC, xfunc=None, Arange=(0., 2 * AMP_UPPER_BOUND_ESTIM)): ''' Use a binary search to determine the threshold amplitude needed to obtain neural excitation for a given duration, PRF and duty cycle. :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param xfunc: function determining whether condition is reached from simulation output :param Arange: search interval for Astim, iteratively refined :return: excitation threshold amplitude (mA/m2) ''' # Default output function if xfunc is None: xfunc = self.titrationFunc return binarySearch( lambda x: xfunc(self.simulate(*x)[0]), [tstim, toffset, PRF, DC], 0, Arange, THRESHOLD_CONV_RANGE_ESTIM ) diff --git a/PySONIC/plt/QSS.py b/PySONIC/plt/QSS.py index 60cb5e2..18ad94e 100644 --- a/PySONIC/plt/QSS.py +++ b/PySONIC/plt/QSS.py @@ -1,505 +1,505 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2019-06-04 18:24:29 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-07-15 20:34:54 +# @Last Modified time: 2019-07-15 21:17:51 import inspect import logging import pandas as pd import numpy as np import matplotlib.pyplot as plt from ..core import NeuronalBilayerSonophore, Batch from .pltutils import * from ..utils import logger, fileCache root = '../../../QSS analysis/data' def plotVarQSSDynamics(pneuron, a, Fdrive, Adrive, charges, varname, varrange, fs=12): ''' Plot the QSS-approximated derivative of a specific variable as function of the variable itself, as well as equilibrium values, for various membrane charge densities at a given acoustic amplitude. :param pneuron: point-neuron model :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param Adrive: US amplitude (Pa) :param charges: charge density vector (C/m2) :param varname: name of variable to plot :param varrange: range over which to compute the derivative :return: figure handle ''' # Extract information about variable to plot pltvar = pneuron.getPltVars()[varname] # Get methods to compute derivative and steady-state of variable of interest derX_func = getattr(pneuron, 'der{}{}'.format(varname[0].upper(), varname[1:])) Xinf_func = getattr(pneuron, '{}inf'.format(varname)) derX_args = inspect.getargspec(derX_func)[0][1:] Xinf_args = inspect.getargspec(Xinf_func)[0][1:] # Get dictionary of charge and amplitude dependent QSS variables nbls = NeuronalBilayerSonophore(a, pneuron, Fdrive) - _, Qref, lookups, QSS = nbls.quasiSteadyStates( + _, Qref, lookups, QSS = nbls.getQuasiSteadyStates( Fdrive, amps=Adrive, charges=charges, squeeze_output=True) df = QSS df['Vm'] = lookups['V'] # Create figure fig, ax = plt.subplots(figsize=(6, 4)) ax.set_title('{} neuron - QSS {} dynamics @ {:.2f} kPa'.format( pneuron.name, pltvar['desc'], Adrive * 1e-3), fontsize=fs) ax.set_xscale('log') for key in ['top', 'right']: ax.spines[key].set_visible(False) ax.set_xlabel('$\\rm {}\ ({})$'.format(pltvar['label'], pltvar.get('unit', '')), fontsize=fs) ax.set_ylabel('$\\rm QSS\ d{}/dt\ ({}/s)$'.format(pltvar['label'], pltvar.get('unit', '1')), fontsize=fs) ax.set_ylim(-40, 40) ax.axhline(0, c='k', linewidth=0.5) y0_str = '{}0'.format(varname) if hasattr(pneuron, y0_str): ax.axvline(getattr(pneuron, y0_str) * pltvar.get('factor', 1), label=y0_str, c='k', linewidth=0.5) # For each charge value icolor = 0 for j, Qm in enumerate(charges): lbl = 'Q = {:.0f} nC/cm2'.format(Qm * 1e5) # Compute variable derivative as a function of its value, as well as equilibrium value, # keeping other variables at quasi steady-state derX_inputs = [varrange if arg == varname else df[arg][j] for arg in derX_args] Xinf_inputs = [df[arg][j] for arg in Xinf_args] dX_QSS = pneuron.derCai(*derX_inputs) Xeq_QSS = pneuron.Caiinf(*Xinf_inputs) # Plot variable derivative and its root as a function of the variable itself c = 'C{}'.format(icolor) ax.plot(varrange * pltvar.get('factor', 1), dX_QSS * pltvar.get('factor', 1), c=c, label=lbl) ax.axvline(Xeq_QSS * pltvar.get('factor', 1), linestyle='--', c=c) icolor += 1 ax.legend(frameon=False, fontsize=fs - 3) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) fig.tight_layout() fig.canvas.set_window_title('{}_QSS_{}_dynamics_{:.2f}kPa'.format( pneuron.name, varname, Adrive * 1e-3)) return fig def plotQSSdynamics(pneuron, a, Fdrive, Adrive, DC=1., fs=12): ''' Plot effective membrane potential, quasi-steady states and resulting membrane currents as a function of membrane charge density, for a given acoustic amplitude. :param pneuron: point-neuron model :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param Adrive: US amplitude (Pa) :return: figure handle ''' # Get neuron-specific pltvars pltvars = pneuron.getPltVars() # Compute neuron-specific charge and amplitude dependent QS states at this amplitude nbls = NeuronalBilayerSonophore(a, pneuron, Fdrive) - lookups, QSS = nbls.quasiSteadyStates(Fdrive, amps=Adrive, DCs=DC, squeeze_output=True) + lookups, QSS = nbls.getQuasiSteadyStates(Fdrive, amps=Adrive, DCs=DC, squeeze_output=True) Qref = lookups.refs['Q'] Vmeff = lookups['V'] # Compute QSS currents and 1D charge variation array states = {k: QSS[k] for k in pneuron.states} currents = {name: cfunc(Vmeff, states) for name, cfunc in pneuron.currents().items()} iNet = sum(currents.values()) dQdt = -iNet # Compute stable and unstable fixed points Q_SFPs, Q_UFPs = nbls.fixedPointsQSS(Fdrive, Adrive, DC, lookups, dQdt) # Extract dimensionless states norm_QSS = {} for x in pneuron.states: if 'unit' not in pltvars[x]: norm_QSS[x] = QSS[x] # Create figure fig, axes = plt.subplots(3, 1, figsize=(7, 9)) axes[-1].set_xlabel('$\\rm Q_m\ (nC/cm^2)$', fontsize=fs) for ax in axes: for skey in ['top', 'right']: ax.spines[skey].set_visible(False) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) for item in ax.get_xticklabels(minor=True): item.set_visible(False) fig.suptitle('{} neuron - QSS dynamics @ {:.2f} kPa, {:.0f}%DC'.format( pneuron.name, Adrive * 1e-3, DC * 1e2), fontsize=fs) # Subplot: Vmeff ax = axes[0] ax.set_ylabel('$V_m^*$ (mV)', fontsize=fs) ax.plot(Qref * 1e5, Vmeff, color='k') ax.axhline(pneuron.Vm0, linewidth=0.5, color='k') # Subplot: dimensionless quasi-steady states cset = plt.get_cmap('Dark2').colors + plt.get_cmap('tab10').colors ax = axes[1] ax.set_ylabel('QSS gating variables (-)', fontsize=fs) ax.set_yticks([0, 0.5, 1]) ax.set_ylim([-0.05, 1.05]) for i, (label, QS_state) in enumerate(norm_QSS.items()): ax.plot(Qref * 1e5, QS_state, label=label, c=cset[i]) # Subplot: currents ax = axes[2] cset = plt.get_cmap('tab10').colors ax.set_ylabel('QSS currents ($\\rm A/m^2$)', fontsize=fs) for i, (k, I) in enumerate(currents.items()): ax.plot(Qref * 1e5, -I * 1e-3, '--', c=cset[i], label='$\\rm -{}$'.format(pneuron.getPltVars()[k]['label'])) ax.plot(Qref * 1e5, -iNet * 1e-3, color='k', label='$\\rm -I_{Net}$') ax.axhline(0, color='k', linewidth=0.5) if len(Q_SFPs) > 0: ax.scatter(np.array(Q_SFPs) * 1e5, np.zeros(len(Q_SFPs)), marker='.', s=200, facecolors='g', edgecolors='none', label='QSS stable FPs', zorder=3) if len(Q_UFPs) > 0: ax.scatter(np.array(Q_UFPs) * 1e5, np.zeros(len(Q_UFPs)), marker='.', s=200, facecolors='r', edgecolors='none', label='QSS unstable FPs', zorder=3) fig.tight_layout() fig.subplots_adjust(right=0.8) for ax in axes[1:]: ax.legend(loc='center right', fontsize=fs, frameon=False, bbox_to_anchor=(1.3, 0.5)) for ax in axes[:-1]: ax.set_xticklabels([]) fig.canvas.set_window_title( '{}_QSS_dynamics_vs_Qm_{:.2f}kPa_DC{:.0f}%'.format(pneuron.name, Adrive * 1e-3, DC * 1e2)) return fig def plotQSSVarVsQm(pneuron, a, Fdrive, varname, amps=None, DC=1., fs=12, cmap='viridis', yscale='lin', zscale='lin', mpi=False, loglevel=logging.INFO): ''' Plot a specific QSS variable (state or current) as a function of membrane charge density, for various acoustic amplitudes. :param pneuron: point-neuron model :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param amps: US amplitudes (Pa) :param DC: duty cycle (-) :param varname: extraction key for variable to plot :return: figure handle ''' # Extract information about variable to plot pltvar = pneuron.getPltVars()[varname] Qvar = pneuron.getPltVars()['Qm'] Afactor = 1e-3 logger.info('plotting %s neuron QSS %s vs. Qm for various amplitudes @ %.0f%% DC', pneuron.name, pltvar['desc'], DC * 1e2) nbls = NeuronalBilayerSonophore(a, pneuron, Fdrive) # Get reference dictionaries for zero amplitude - lookups0, QSS0 = nbls.quasiSteadyStates(Fdrive, amps=0., squeeze_output=True) + lookups0, QSS0 = nbls.getQuasiSteadyStates(Fdrive, amps=0., squeeze_output=True) Vmeff0 = lookups0['V'] Qref = lookups0.refs['Q'] df0 = QSS0.tables df0['Vm'] = Vmeff0 # Create figure fig, ax = plt.subplots(figsize=(6, 4)) title = '{} neuron - QSS {} vs. Qm - {:.0f}% DC'.format(pneuron.name, varname, DC * 1e2) ax.set_title(title, fontsize=fs) ax.set_xlabel('$\\rm {}\ ({})$'.format(Qvar['label'], Qvar['unit']), fontsize=fs) ax.set_ylabel('$\\rm QSS\ {}\ ({})$'.format(pltvar['label'], pltvar.get('unit', '')), fontsize=fs) if yscale == 'log': ax.set_yscale('log') for key in ['top', 'right']: ax.spines[key].set_visible(False) # Plot y-variable reference line, if any y0 = None y0_str = '{}0'.format(varname) if hasattr(pneuron, y0_str): y0 = getattr(pneuron, y0_str) * pltvar.get('factor', 1) elif varname in pneuron.getCurrentsNames() + ['iNet', 'dQdt']: y0 = 0. y0_str = '' if y0 is not None: ax.axhline(y0, label=y0_str, c='k', linewidth=0.5) # Plot reference QSS profile of variable as a function of charge density var0 = extractPltVar( pneuron, pltvar, pd.DataFrame({k: df0[k] for k in df0.keys()}), name=varname) ax.plot(Qref * Qvar['factor'], var0, '--', c='k', zorder=1, label='A = 0') if varname == 'dQdt': # Plot charge SFPs and UFPs for each acoustic amplitude SFPs, UFPs = getQSSFixedPointsvsAdrive( nbls, Fdrive, amps, DC, mpi=mpi, loglevel=loglevel) if len(SFPs) > 0: _, Q_SFPs = np.array(SFPs).T ax.scatter(np.array(Q_SFPs) * 1e5, np.zeros(len(Q_SFPs)), marker='.', s=100, facecolors='g', edgecolors='none', label='QSS stable fixed points') if len(UFPs) > 0: _, Q_UFPs = np.array(UFPs).T ax.scatter(np.array(Q_UFPs) * 1e5, np.zeros(len(Q_UFPs)), marker='.', s=100, facecolors='r', edgecolors='none', label='QSS unstable fixed points') # Define color code mymap = plt.get_cmap(cmap) zref = amps * Afactor norm, sm = setNormalizer(mymap, (zref.min(), zref.max()), zscale) # Get amplitude-dependent QSS dictionary - lookups, QSS = nbls.quasiSteadyStates( + lookups, QSS = nbls.getQuasiSteadyStates( Fdrive, amps=amps, DCs=DC, squeeze_output=True) df = QSS.tables df['Vm'] = lookups['V'] # Plot QSS profiles for various amplitudes for i, A in enumerate(amps): var = extractPltVar( pneuron, pltvar, pd.DataFrame({k: df[k][i] for k in df.keys()}), name=varname) ax.plot(Qref * Qvar['factor'], var, c=sm.to_rgba(A * Afactor), zorder=0) # Add legend and adjust layout ax.legend(frameon=False, fontsize=fs) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) fig.tight_layout() fig.subplots_adjust(bottom=0.15, top=0.9, right=0.80, hspace=0.5) # Plot amplitude colorbar if amps is not None: cbarax = fig.add_axes([0.85, 0.15, 0.03, 0.75]) fig.colorbar(sm, cax=cbarax) cbarax.set_ylabel('Amplitude (kPa)', fontsize=fs) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) fig.canvas.set_window_title('{}_QSS_{}_vs_Qm_{}A_{:.2f}-{:.2f}kPa_DC{:.0f}%'.format( pneuron.name, varname, zscale, amps.min() * 1e-3, amps.max() * 1e-3, DC * 1e2)) return fig @fileCache( root, lambda nbls, Fdrive, amps, DC: '{}_QSS_FPs_{:.0f}kHz_{:.2f}-{:.2f}kPa_DC{:.0f}%'.format( nbls.pneuron.name, Fdrive * 1e-3, amps.min() * 1e-3, amps.max() * 1e-3, DC * 1e2) ) def getQSSFixedPointsvsAdrive(nbls, Fdrive, amps, DC, mpi=False, loglevel=logging.INFO): # Compute 2D QSS charge variation array - lkp2d, QSS = nbls.quasiSteadyStates( + lkp2d, QSS = nbls.getQuasiSteadyStates( Fdrive, amps=amps, DCs=DC, squeeze_output=True) dQdt = -nbls.pneuron.iNet(lkp2d['V'], QSS.tables) # mA/m2 # Generate batch queue queue = [] for iA, Adrive in enumerate(amps): lkp1d = lkp2d.project('A', Adrive) queue.append([Fdrive, Adrive, DC, lkp1d, dQdt[iA, :]]) # Run batch to find stable and unstable fixed points at each amplitude batch = Batch(nbls.fixedPointsQSS, queue) output = batch(mpi=mpi, loglevel=loglevel) # Sort points by amplitude SFPs, UFPs = [], [] for i, Adrive in enumerate(amps): SFPs += [(Adrive, Qm) for Qm in output[i][0]] UFPs += [(Adrive, Qm) for Qm in output[i][1]] return SFPs, UFPs def runAndGetStab(nbls, *args): args = list(args[:-1]) + [1., args[-1]] # hacking coverage fraction into args return nbls.pneuron.getStabilizationValue(nbls.getOutput(*args)[0]) @fileCache( root, lambda nbls, Fdrive, amps, tstim, toffset, PRF, DC: '{}_sim_FPs_{:.0f}kHz_{:.0f}ms_offset{:.0f}ms_PRF{:.0f}Hz_{:.2f}-{:.2f}kPa_DC{:.0f}%'.format( nbls.pneuron.name, Fdrive * 1e-3, tstim * 1e3, toffset * 1e3, PRF, amps.min() * 1e-3, amps.max() * 1e-3, DC * 1e2) ) def getSimFixedPointsvsAdrive(nbls, Fdrive, amps, tstim, toffset, PRF, DC, outputdir=None, mpi=False, loglevel=logging.INFO): # Run batch to find stabilization point from simulations (if any) at each amplitude queue = [[nbls, outputdir, Fdrive, Adrive, tstim, toffset, PRF, DC, 'sonic'] for Adrive in amps] batch = Batch(runAndGetStab, queue) output = batch(mpi=mpi, loglevel=loglevel) return list(zip(amps, output)) def plotEqChargeVsAmp(pneuron, a, Fdrive, amps=None, tstim=None, toffset=None, PRF=None, DC=1., fs=12, xscale='lin', compdir=None, mpi=False, loglevel=logging.INFO): ''' Plot the equilibrium membrane charge density as a function of acoustic amplitude, given an initial value of membrane charge density. :param pneuron: point-neuron model :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param amps: US amplitudes (Pa) :return: figure handle ''' logger.info('plotting equilibrium charges for various amplitudes') # Create figure fig, ax = plt.subplots(figsize=(6, 4)) figname = '{} neuron - charge stability vs. amplitude @ {:.0f}%DC'.format( pneuron.name, DC * 1e2) ax.set_title(figname) ax.set_xlabel('Amplitude (kPa)', fontsize=fs) ax.set_ylabel('$\\rm Q_m\ (nC/cm^2)$', fontsize=fs) if xscale == 'log': ax.set_xscale('log') for skey in ['top', 'right']: ax.spines[skey].set_visible(False) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) nbls = NeuronalBilayerSonophore(a, pneuron, Fdrive) Afactor = 1e-3 # Plot charge SFPs and UFPs for each acoustic amplitude SFPs, UFPs = getQSSFixedPointsvsAdrive( nbls, Fdrive, amps, DC, mpi=mpi, loglevel=loglevel) if len(SFPs) > 0: A_SFPs, Q_SFPs = np.array(SFPs).T ax.scatter(np.array(A_SFPs) * Afactor, np.array(Q_SFPs) * 1e5, marker='.', s=20, facecolors='g', edgecolors='none', label='QSS stable fixed points') if len(UFPs) > 0: A_UFPs, Q_UFPs = np.array(UFPs).T ax.scatter(np.array(A_UFPs) * Afactor, np.array(Q_UFPs) * 1e5, marker='.', s=20, facecolors='r', edgecolors='none', label='QSS unstable fixed points') # Plot charge asymptotic stabilization points from simulations for each acoustic amplitude if compdir is not None: stab_points = getSimFixedPointsvsAdrive( nbls, Fdrive, amps, tstim, toffset, PRF, DC, outputdir=compdir, mpi=mpi, loglevel=loglevel) if len(stab_points) > 0: A_stab, Q_stab = np.array(stab_points).T ax.scatter(np.array(A_stab) * Afactor, np.array(Q_stab) * 1e5, marker='o', s=20, facecolors='none', edgecolors='k', label='stabilization points from simulations') # Post-process figure ax.set_ylim(np.array([pneuron.Qm0() - 10e-5, 0]) * 1e5) ax.legend(frameon=False, fontsize=fs) fig.tight_layout() fig.canvas.set_window_title('{}_QSS_Qstab_vs_{}A_{:.0f}%DC{}'.format( pneuron.name, xscale, DC * 1e2, '_with_comp' if compdir is not None else '' )) return fig @fileCache( root, lambda nbls, Fdrive, DCs: '{}_QSS_threshold_curve_{:.0f}kHz_DC{:.2f}-{:.2f}%'.format( nbls.pneuron.name, Fdrive * 1e-3, DCs.min() * 1e2, DCs.max() * 1e2), ext='csv' ) def getQSSThresholdAmps(nbls, Fdrive, DCs, mpi=False, loglevel=logging.INFO): queue = [[Fdrive, DC] for DC in DCs] batch = Batch(nbls.titrateQSS, queue) return batch(mpi=mpi, loglevel=loglevel) @fileCache( root, lambda nbls, Fdrive, tstim, toffset, PRF, DCs: '{}_sim_threshold_curve_{:.0f}kHz_{:.0f}ms_offset{:.0f}ms_PRF{:.0f}Hz_DC{:.2f}-{:.2f}%'.format( nbls.pneuron.name, Fdrive * 1e-3, tstim * 1e3, toffset * 1e3, PRF, DCs.min() * 1e2, DCs.max() * 1e2), ext='csv' ) def getSimThresholdAmps(nbls, Fdrive, tstim, toffset, PRF, DCs, mpi=False, loglevel=logging.INFO): # Run batch to find threshold amplitude from titrations at each DC queue = [[Fdrive, tstim, toffset, PRF, DC, 'sonic'] for DC in DCs] batch = Batch(nbls.titrate, queue) return batch(mpi=mpi, loglevel=loglevel) def plotQSSThresholdCurve(pneuron, a, Fdrive, tstim=None, toffset=None, PRF=None, DCs=None, fs=12, Ascale='lin', comp=False, mpi=False, loglevel=logging.INFO): logger.info('plotting %s neuron threshold curve', pneuron.name) if pneuron.name == 'STN': raise ValueError('cannot compute threshold curve for STN neuron') # Create figure fig, ax = plt.subplots(figsize=(6, 4)) figname = '{} neuron - threshold amplitude vs. duty cycle'.format(pneuron.name) ax.set_title(figname) ax.set_xlabel('Duty cycle (%)', fontsize=fs) ax.set_ylabel('Amplitude (kPa)', fontsize=fs) if Ascale == 'log': ax.set_yscale('log') for skey in ['top', 'right']: ax.spines[skey].set_visible(False) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) nbls = NeuronalBilayerSonophore(a, pneuron, Fdrive) Athrs_QSS = np.array(getQSSThresholdAmps(nbls, Fdrive, DCs, mpi=mpi, loglevel=loglevel)) ax.plot(DCs * 1e2, Athrs_QSS * 1e-3, '-', c='k', label='QSS curve') if comp: Athrs_sim = np.array(getSimThresholdAmps( nbls, Fdrive, tstim, toffset, PRF, DCs, mpi=mpi, loglevel=loglevel)) ax.plot(DCs * 1e2, Athrs_sim * 1e-3, '--', c='k', label='sim curve') # Post-process figure ax.set_xlim([0, 100]) ax.set_ylim([10, 600]) ax.legend(frameon=False, fontsize=fs) fig.tight_layout() fig.canvas.set_window_title('{}_QSS_threhold_curve_{:.0f}-{:.0f}%DC_{}A{}'.format( pneuron.name, DCs.min() * 1e2, DCs.max() * 1e2, Ascale, '_with_comp' if comp else '' )) return fig