diff --git a/PySONIC/core/model.py b/PySONIC/core/model.py index a495fb8..de13266 100644 --- a/PySONIC/core/model.py +++ b/PySONIC/core/model.py @@ -1,82 +1,96 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-08-03 11:53:04 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-05-31 16:50:06 +# @Last Modified time: 2019-06-01 16:27:32 import pickle import abc import inspect +import numpy as np -from ..utils import logger +from ..utils import logger, debug class Model(metaclass=abc.ABCMeta): ''' Generic model interface. ''' @property @abc.abstractmethod def tscale(self): ''' Relevant temporal scale of the model. ''' raise NotImplementedError @property @abc.abstractmethod def __repr__(self): raise NotImplementedError @property @abc.abstractmethod def pprint(self): raise NotImplementedError @property @abc.abstractmethod def filecode(self, *args): raise NotImplementedError def getDesc(self): return inspect.getdoc(self).splitlines()[0] @property @abc.abstractmethod def getPltScheme(self): raise NotImplementedError @property @abc.abstractmethod def getPltVars(self, *args, **kwargs): raise NotImplementedError @property @abc.abstractmethod def checkInputs(self, *args): raise NotImplementedError @property @abc.abstractmethod def simulate(self, *args, **kwargs): raise NotImplementedError @property @abc.abstractmethod def meta(self, *args): raise NotImplementedError @property @abc.abstractmethod def createQueue(self, *args): raise NotImplementedError def runAndSave(self, outdir, *args): ''' Simulate system and save results in a PKL file. ''' + + # If no amplitude provided, perform titration to find it + if None in args: + iA = args.index(None) + new_args = [x for x in args if x is not None] + Athr = self.titrate(*new_args) + if np.isnan(Athr): + logger.error('Could not find threshold excitation amplitude') + return None + new_args.insert(iA, Athr) + args = new_args + + # Simulate model, save inf file and return file path data, tcomp = self.simulate(*args) meta = self.meta(*args) meta['tcomp'] = tcomp outpath = '{}/{}.pkl'.format(outdir, self.filecode(*args)) with open(outpath, 'wb') as fh: pickle.dump({'meta': meta, 'data': data}, fh) logger.debug('simulation data exported to "%s"', outpath) return outpath diff --git a/PySONIC/core/nbls.py b/PySONIC/core/nbls.py index 955ec04..b8373cb 100644 --- a/PySONIC/core/nbls.py +++ b/PySONIC/core/nbls.py @@ -1,704 +1,697 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-09-29 16:16:19 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-05-31 16:59:17 +# @Last Modified time: 2019-06-01 16:39:15 from copy import deepcopy import logging import numpy as np import pandas as pd from scipy.integrate import solve_ivp from scipy.interpolate import interp1d from .simulators import PWSimulator, HybridSimulator from .bls import BilayerSonophore from .pneuron import PointNeuron from .batches import createQueue from ..utils import * from ..constants import * from ..postpro import getFixedPoints class NeuronalBilayerSonophore(BilayerSonophore): ''' This class inherits from the BilayerSonophore class and receives an PointNeuron instance at initialization, to define the electro-mechanical NICE model and its SONIC variant. ''' tscale = 'ms' # relevant temporal scale of the model defvar = 'Q' # default plot variable def __init__(self, a, neuron, Fdrive=None, embedding_depth=0.0): ''' Constructor of the class. :param a: in-plane radius of the sonophore structure within the membrane (m) :param neuron: neuron object :param Fdrive: frequency of acoustic perturbation (Hz) :param embedding_depth: depth of the embedding tissue around the membrane (m) ''' # Check validity of input parameters if not isinstance(neuron, PointNeuron): raise ValueError('Invalid neuron type: "{}" (must inherit from PointNeuron class)' .format(neuron.name)) self.neuron = neuron # Initialize BilayerSonophore parent object BilayerSonophore.__init__(self, a, neuron.Cm0, neuron.Cm0 * neuron.Vm0 * 1e-3, embedding_depth) def __repr__(self): return 'NeuronalBilayerSonophore({}m, {})'.format( si_format(self.a, precision=1, space=' '), self.neuron) def pprint(self): return '{}m radius NBLS - {} neuron'.format( si_format(self.a, precision=0, space=' '), self.neuron.name) def getPltVars(self, wrapleft='df["', wrapright='"]'): pltvars = super().getPltVars(wrapleft, wrapright) pltvars.update(self.neuron.getPltVars(wrapleft, wrapright)) return pltvars def getPltScheme(self): return self.neuron.getPltScheme() def filecode(self, Fdrive, Adrive, tstim, toffset, PRF, DC, method): return 'ASTIM_{}_{}_{:.0f}nm_{:.0f}kHz_{:.2f}kPa_{:.0f}ms_{}{}'.format( self.neuron.name, 'CW' if DC == 1 else 'PW', self.a * 1e9, Fdrive * 1e-3, Adrive * 1e-3, tstim * 1e3, 'PRF{:.2f}Hz_DC{:.2f}%_'.format(PRF, DC * 1e2) if DC < 1. else '', method) def fullDerivatives(self, y, t, Adrive, Fdrive, phi): ''' Compute the derivatives of the (n+3) ODE full NBLS system variables. :param y: vector of state variables :param t: specific instant in time (s) :param Adrive: acoustic drive amplitude (Pa) :param Fdrive: acoustic drive frequency (Hz) :param phi: acoustic drive phase (rad) :return: vector of derivatives ''' dydt_mech = BilayerSonophore.derivatives(self, y[:3], t, Adrive, Fdrive, y[3], phi) dydt_elec = self.neuron.Qderivatives(y[3:], t, self.Capct(y[1])) return dydt_mech + dydt_elec def effDerivatives(self, y, t, lkp): ''' Compute the derivatives of the n-ODE effective HH system variables, based on 1-dimensional linear interpolation of "effective" coefficients that summarize the system's behaviour over an acoustic cycle. :param y: vector of HH system variables at time t :param t: specific instant in time (s) :param lkp: dictionary of 1D data points of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: vector of effective system derivatives at time t ''' # Split input vector explicitly Qm, *states = y # Compute charge and channel states variation Vmeff = self.neuron.interpVmeff(Qm, lkp) dQmdt = - self.neuron.iNet(Vmeff, states) * 1e-3 dstates = self.neuron.derEffStates(Qm, states, lkp) # Return derivatives vector return [dQmdt, *[dstates[k] for k in self.neuron.states]] def interpEffVariable(self, key, Qm, stim, lkp_on, lkp_off): ''' Interpolate Q-dependent effective variable along solution. :param key: lookup variable key :param Qm: charge density solution vector :param stim: stimulation state solution vector :param lkp_on: lookups for ON states :param lkp_off: lookups for OFF states :return: interpolated effective variable vector ''' x = np.zeros(stim.size) x[stim == 0] = np.interp( Qm[stim == 0], lkp_on['Q'], lkp_on[key], left=np.nan, right=np.nan) x[stim == 1] = np.interp( Qm[stim == 1], lkp_off['Q'], lkp_off[key], left=np.nan, right=np.nan) return x def runFull(self, Fdrive, Adrive, tstim, toffset, PRF, DC, phi=np.pi): ''' Compute solutions of the full electro-mechanical system for a specific set of US stimulation parameters, using a classic integration scheme. The first iteration uses the quasi-steady simplification to compute the initiation of motion from a flat leaflet configuration. Afterwards, the ODE system is solved iteratively until completion. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param phi: acoustic drive phase (rad) :return: 2-tuple with the output dataframe and computation time. ''' # Determine time step dt = 1 / (NPC_FULL * Fdrive) # Compute non-zero deflection value for a small perturbation (solving quasi-steady equation) Pac = self.Pacoustic(dt, Adrive, Fdrive, phi) Z0 = self.balancedefQS(self.ng0, self.Qm0, Pac) # Set initial conditions steady_states = self.neuron.steadyStates(self.neuron.Vm0) y0 = np.concatenate(( [0., Z0, self.ng0, self.Qm0], [steady_states[k] for k in self.neuron.states])) # Initialize simulator and compute solution logger.debug('Computing detailed solution') simulator = PWSimulator( lambda y, t: self.fullDerivatives(y, t, Adrive, Fdrive, phi), lambda y, t: self.fullDerivatives(y, t, 0., 0., 0.)) (t, y, stim), tcomp = simulator( y0, dt, tstim, toffset, PRF, DC, print_progress=logger.getEffectiveLevel() <= logging.INFO, target_dt=CLASSIC_TARGET_DT, monitor_time=True) logger.debug('completed in %ss', si_format(tcomp, 1)) # Store output in dataframe data = pd.DataFrame({ 't': t, 'stimstate': stim, 'Z': y[:, 1], 'ng': y[:, 2], 'Qm': y[:, 3] }) data['Vm'] = data['Qm'].values / self.v_Capct(data['Z'].values) * 1e3 # mV for i in range(len(self.neuron.states)): data[self.neuron.states[i]] = y[:, i + 4] # Return dataframe and computation time return data, tcomp - def runSONIC(self, Fdrive, Adrive, tstim, toffset, PRF, DC, dt=DT_EFF): + def runSONIC(self, Fdrive, Adrive, tstim, toffset, PRF, DC): ''' Compute solutions of the system for a specific set of US stimulation parameters, using charge-predicted "effective" coefficients to solve the HH equations at each step. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) - :param dt: integration time step (s) :return: 3-tuple with the time profile, the effective solution matrix and a state vector ''' # Load appropriate 2D lookups Aref, Qref, lookups2D, _ = getLookups2D(self.neuron.name, a=self.a, Fdrive=Fdrive) # Check that acoustic amplitude is within lookup range Adrive = isWithin('amplitude', Adrive, (Aref.min(), Aref.max())) # Interpolate 2D lookups at zero and US amplitude logger.debug('Interpolating lookups at A = %.2f kPa and A = 0', Adrive * 1e-3) lookups_on = {key: interp1d(Aref, y2D, axis=0)(Adrive) for key, y2D in lookups2D.items()} lookups_off = {key: interp1d(Aref, y2D, axis=0)(0.0) for key, y2D in lookups2D.items()} # Add reference charge vector to 1D lookup dictionaries lookups_on['Q'] = Qref lookups_off['Q'] = Qref # Set initial conditions steady_states = self.neuron.steadyStates(self.neuron.Vm0) y0 = np.insert( np.array([steady_states[k] for k in self.neuron.states]), 0, self.Qm0) # Initialize simulator and compute solution logger.debug('Computing effective solution') simulator = PWSimulator( lambda y, t: self.effDerivatives(y, t, lookups_on), lambda y, t: self.effDerivatives(y, t, lookups_off)) - (t, y, stim), tcomp = simulator(y0, dt, tstim, toffset, PRF, DC, monitor_time=True) + (t, y, stim), tcomp = simulator(y0, DT_EFF, tstim, toffset, PRF, DC, monitor_time=True) logger.debug('completed in %ss', si_format(tcomp, 1)) # Store output in dataframe data = pd.DataFrame({ 't': t, 'stimstate': stim, 'Qm': y[:, 0] }) for key in ['ng', 'V']: data[key] = self.interpEffVariable( key, data['Qm'].values, stim, lookups_on, lookups_off) data['Z'] = np.array([self.balancedefQS(ng, Qm) for ng, Qm in zip( data['ng'].values, data['Qm'].values)]) # m data['Vm'] = data['Qm'].values / self.v_Capct(data['Z'].values) * 1e3 # mV for i in range(len(self.neuron.states)): data[self.neuron.states[i]] = y[:, i + 1] # Return dataframe and computation time return data, tcomp def runHybrid(self, Fdrive, Adrive, tstim, toffset, PRF, DC, phi=np.pi): ''' Compute solutions of the system for a specific set of US stimulation parameters, using a hybrid integration scheme. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param phi: acoustic drive phase (rad) :return: 3-tuple with the time profile, the solution matrix and a state vector ''' # Determine time step dt_dense = 1 / (NPC_FULL * Fdrive) dt_sparse = 1 / (NPC_HH * Fdrive) # Compute non-zero deflection value for a small perturbation (solving quasi-steady equation) Pac = self.Pacoustic(dt_dense, Adrive, Fdrive, phi) Z0 = self.balancedefQS(self.ng0, self.Qm0, Pac) # Set initial conditions steady_states = self.neuron.steadyStates(self.neuron.Vm0) y0 = np.concatenate(( [0., Z0, self.ng0, self.Qm0], [steady_states[k] for k in self.neuron.states], )) is_dense_var = np.array([True] * 3 + [False] * (len(self.neuron.states) + 1)) # Initialize simulator and compute solution logger.debug('Computing hybrid solution') simulator = HybridSimulator( lambda y, t: self.fullDerivatives(y, t, Adrive, Fdrive, phi), lambda y, t: self.fullDerivatives(y, t, 0., 0., 0.), lambda t, y, Cm: self.neuron.Qderivatives(y, t, Cm), lambda yref: self.Capct(yref[1]), is_dense_var, ivars_to_check=[1, 2]) (t, y, stim), tcomp = simulator( y0, dt_dense, dt_sparse, Fdrive, tstim, toffset, PRF, DC, monitor_time=True) logger.debug('completed in %ss', si_format(tcomp, 1)) # Store output in dataframe data = pd.DataFrame({ 't': t, 'stimstate': stim, 'Z': y[:, 1], 'ng': y[:, 2], 'Qm': y[:, 3] }) data['Vm'] = data['Qm'].values / self.v_Capct(data['Z'].values) * 1e3 # mV for i in range(len(self.neuron.states)): data[self.neuron.states[i]] = y[:, i + 4] # Return dataframe and computation time return data, tcomp - def simulate(self, Fdrive, Adrive, tstim, toffset, PRF=None, DC=1.0, method='sonic'): + def simulate(self, Fdrive, Adrive, tstim, toffset, PRF=100., DC=1.0, method='sonic'): ''' Simulate the electro-mechanical model for a specific set of US stimulation parameters, and return output data in a dataframe. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param method: selected integration method :return: 2-tuple with the output dataframe and computation time. ''' logger.info( '%s: %s @ f = %sHz, %st = %ss (%ss offset)%s', self, 'titration' if Adrive is None else 'simulation', si_format(Fdrive, 0, space=' '), 'A = {}Pa, '.format(si_format(Adrive, 2, space=' ')) if Adrive is not None else '', *si_format([tstim, toffset], 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format( si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) - # TODO: If no amplitude provided, perform titration - if Adrive is None: - Adrive = self.titrate(Fdrive, tstim, toffset, PRF, DC, method=method) - if np.isnan(Adrive): - logger.error('Could not find threshold excitation amplitude') - return None - # Check validity of stimulation parameters BilayerSonophore.checkInputs(self, Fdrive, Adrive, 0.0, 0.0) self.neuron.checkInputs(Adrive, tstim, toffset, PRF, DC) # Call appropriate simulation function try: simfunc = { 'full': self.runFull, 'hybrid': self.runHybrid, 'sonic': self.runSONIC }[method] except KeyError: raise ValueError('Invalid integration method: "{}"'.format(method)) data, tcomp = simfunc(Fdrive, Adrive, tstim, toffset, PRF, DC) # Log number of detected spikes nspikes = self.neuron.getNSpikes(data) logger.debug('{} spike{} detected'.format(nspikes, plural(nspikes))) # Return dataframe and computation time return data, tcomp def meta(self, Fdrive, Adrive, tstim, toffset, PRF, DC, method): ''' Return information about object and simulation parameters. :param Fdrive: US frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: stimulus duration (s) :param toffset: stimulus offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: stimulus duty cycle (-) :param method: integration method :return: meta-data dictionary ''' return { 'neuron': self.neuron.name, 'a': self.a, 'd': self.d, 'Fdrive': Fdrive, 'Adrive': Adrive, 'tstim': tstim, 'toffset': toffset, 'PRF': PRF, 'DC': DC, 'method': method } - def titrate(self, Fdrive, tstim, toffset, PRF=None, DC=1.0, Arange=None, method='sonic'): + def titrate(self, Fdrive, tstim, toffset, PRF=100., DC=1., method='sonic', + xfunc=None, Arange=None): ''' Use a binary search to determine the threshold amplitude needed to obtain neural excitation for a given frequency, duration, PRF and duty cycle. :param Fdrive: US frequency (Hz) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) + :param method: integration method + :param xfunc: function determining whether condition is reached from simulation output :param Arange: search interval for Adrive, iteratively refined :return: determined threshold amplitude (Pa) ''' + # Default output function + if xfunc is None: + xfunc = self.neuron.titrationFunc - # Determine amplitude interval if needed + # Default amplitude interval if Arange is None: Arange = (0, getLookups2D(self.neuron.name, a=self.a, Fdrive=Fdrive)[0].max()) - # Determine output function - if self.neuron.isTitratable(): - xfunc = self.isExcited - else: - xfunc = self.isSilenced - - # Titrate - return titrate(xfunc, (Fdrive, tstim, toffset, PRF, DC, method), - Arange, TITRATION_ASTIM_DA_MAX) + return binarySearch( + lambda x: xfunc(self.simulate(*x)[0]), + [Fdrive, tstim, toffset, PRF, DC, method], 1, Arange, TITRATION_ASTIM_DA_MAX + ) def createQueue(self, freqs, amps, durations, offsets, PRFs, DCs, method): ''' Create a serialized 2D array of all parameter combinations for a series of individual parameter sweeps, while avoiding repetition of CW protocols for a given PRF sweep. :param freqs: list (or 1D-array) of US frequencies :param amps: list (or 1D-array) of acoustic amplitudes :param durations: list (or 1D-array) of stimulus durations :param offsets: list (or 1D-array) of stimulus offsets (paired with durations array) :param PRFs: list (or 1D-array) of pulse-repetition frequencies :param DCs: list (or 1D-array) of duty cycle values :params method: integration method :return: list of parameters (list) for each simulation ''' if amps is None: amps = [np.nan] DCs = np.array(DCs) queue = [] if 1.0 in DCs: queue += createQueue(freqs, amps, durations, offsets, min(PRFs), 1.0) if np.any(DCs != 1.0): queue += createQueue(freqs, amps, durations, offsets, PRFs, DCs[DCs != 1.0]) for item in queue: if np.isnan(item[1]): item[1] = None item.append(method) return queue def quasiSteadyStates(self, Fdrive, amps=None, charges=None, DCs=1.0, squeeze_output=False): ''' Compute the quasi-steady state values of the neuron's gating variables for a combination of US amplitudes, charge densities and duty cycles, at a specific US frequency. :param Fdrive: US frequency (Hz) :param amps: US amplitudes (Pa) :param charges: membrane charge densities (C/m2) :param DCs: duty cycle value(s) :return: 4-tuple with reference values of US amplitude and charge density, as well as interpolated Vmeff and QSS gating variables ''' # Get DC-averaged lookups interpolated at the appropriate amplitudes and charges amps, charges, lookups = getLookupsDCavg( self.neuron.name, self.a, Fdrive, amps, charges, DCs) # Compute QSS states using these lookups nA, nQ, nDC = lookups['V'].shape QSS = {k: np.empty((nA, nQ, nDC)) for k in self.neuron.states} for iA in range(nA): for iDC in range(nDC): QSS_1D = self.neuron.quasiSteadyStates( {k: v[iA, :, iDC] for k, v in lookups.items()}) for k in QSS.keys(): QSS[k][iA, :, iDC] = QSS_1D[k] # Compress outputs if needed if squeeze_output: QSS = {k: v.squeeze() for k, v in QSS.items()} lookups = {k: v.squeeze() for k, v in lookups.items()} # Return reference inputs and outputs return amps, charges, lookups, QSS def quasiSteadyStateiNet(self, Qm, Fdrive, Adrive, DC): ''' Compute quasi-steady state net membrane current for a given combination of US parameters and a given membrane charge density. :param Qm: membrane charge density (C/m2) :param Fdrive: US frequency (Hz) :param Adrive: US amplitude (Pa) :param DC: duty cycle (-) :return: net membrane current (mA/m2) ''' _, _, lookups, QSS = self.quasiSteadyStates( Fdrive, amps=Adrive, charges=Qm, DCs=DC, squeeze_output=True) return self.neuron.iNet(lookups['V'], np.array(list(QSS.values()))) # mA/m2 def evaluateStability(self, Qm0, states0, lkp): ''' Integrate the effective differential system from a given starting point, until clear convergence or clear divergence is found. :param Qm0: initial membrane charge density (C/m2) :param states0: dictionary of initial states values :param lkp: dictionary of 1D data points of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: boolean indicating convergence state ''' # Initialize y0 vector t0 = 0. y0 = np.array([Qm0] + list(states0.values())) # Initializing empty list to record evolution of charge deviation n = int(QSS_HISTORY_INTERVAL // QSS_INTEGRATION_INTERVAL) # size of history dQ = [] # As long as there is no clear charge convergence or divergence conv, div = False, False tf, yf = t0, y0 while not conv and not div: # Integrate system for small interval and retrieve final charge deviation t0, y0 = tf, yf sol = solve_ivp( lambda t, y: self.effDerivatives(y, t, lkp), [t0, t0 + QSS_INTEGRATION_INTERVAL], y0, method='LSODA' ) tf, yf = sol.t[-1], sol.y[:, -1] dQ.append(yf[0] - Qm0) # logger.debug('{:.0f} ms: dQ = {:.5f} nC/cm2, avg dQ = {:.5f} nC/cm2'.format( # tf * 1e3, dQ[-1] * 1e5, np.mean(dQ[-n:]) * 1e5)) # If last charge deviation is too large -> divergence if np.abs(dQ[-1]) > QSS_Q_DIV_THR: div = True # If last charge deviation or average deviation in recent history # is small enough -> convergence for x in [dQ[-1], np.mean(dQ[-n:])]: if np.abs(x) < QSS_Q_CONV_THR: conv = True # If max integration duration is been reached -> error if tf > QSS_MAX_INTEGRATION_DURATION: raise ValueError('too many iterations') logger.debug('{}vergence after {:.0f} ms: dQ = {:.5f} nC/cm2'.format( {True: 'con', False: 'di'}[conv], tf * 1e3, dQ[-1] * 1e5)) return conv def quasiSteadyStateFixedPoints(self, Fdrive, Adrive, DC, lkp, dQdt): ''' Compute QSS fixed points along the charge dimension for a given combination of US parameters, and determine their stability. :param Fdrive: US frequency (Hz) :param Adrive: US amplitude (Pa) :param DC: duty cycle (-) :param lkp: lookup dictionary for effective variables along charge dimension :param dQdt: charge derivative profile along charge dimension :return: 2-tuple with values of stable and unstable fixed points ''' logger.debug('A = {:.2f} kPa, DC = {:.0f}%'.format(Adrive * 1e-3, DC * 1e2)) # Extract stable and unstable fixed points from QSS charge variation profile dfunc = lambda Qm: - self.quasiSteadyStateiNet(Qm, Fdrive, Adrive, DC) SFP_candidates = getFixedPoints(lkp['Q'], dQdt, filter='stable', der_func=dfunc).tolist() UFPs = getFixedPoints(lkp['Q'], dQdt, filter='unstable', der_func=dfunc).tolist() SFPs = [] pltvars = self.getPltVars() # For each candidate SFP for i, Qm in enumerate(SFP_candidates): logger.debug('Q-SFP = {:.2f} nC/cm2'.format(Qm * 1e5)) # Re-compute QSS *_, QSS_FP = self.quasiSteadyStates(Fdrive, amps=Adrive, charges=Qm, DCs=DC, squeeze_output=True) # Simulate from unperturbed QSS and evaluate stability if not self.evaluateStability(Qm, QSS_FP, lkp): logger.warning('diverging system at ({:.2f} kPa, {:.2f} nC/cm2)'.format( Adrive * 1e-3, Qm * 1e5)) UFPs.append(Qm) else: # For each state unstable_states = [] for x in self.neuron.states: pltvar = pltvars[x] unit_str = pltvar.get('unit', '') factor = pltvar.get('factor', 1) is_stable_direction = [] for sign in [-1, +1]: # Perturb state with small offset QSS_perturbed = deepcopy(QSS_FP) QSS_perturbed[x] *= (1 + sign * QSS_REL_OFFSET) # If gating state, bound within [0., 1.] if self.neuron.isVoltageGated(x): QSS_perturbed[x] = np.clip(QSS_perturbed[x], 0., 1.) logger.debug('{}: {:.5f} -> {:.5f} {}'.format( x, QSS_FP[x] * factor, QSS_perturbed[x] * factor, unit_str)) # Simulate from perturbed QSS and evaluate stability is_stable_direction.append( self.evaluateStability(Qm, QSS_perturbed, lkp)) # Check if system shows stability upon x-state perturbation # in both directions if not np.all(is_stable_direction): unstable_states.append(x) # Classify fixed point as stable only if all states show stability is_stable_FP = len(unstable_states) == 0 {True: SFPs, False: UFPs}[is_stable_FP].append(Qm) logger.info('{}stable fixed-point at ({:.2f} kPa, {:.2f} nC/cm2){}'.format( '' if is_stable_FP else 'un', Adrive * 1e-3, Qm * 1e5, '' if is_stable_FP else ', caused by {} states'.format(unstable_states))) return SFPs, UFPs def findRheobaseAmps(self, DCs, Fdrive, Vthr): ''' Find the rheobase amplitudes (i.e. threshold acoustic amplitudes of infinite duration that would result in excitation) of a specific neuron for various duty cycles. :param DCs: duty cycles vector (-) :param Fdrive: acoustic drive frequency (Hz) :param Vthr: threshold membrane potential above which the neuron necessarily fires (mV) :return: rheobase amplitudes vector (Pa) ''' # Get threshold charge from neuron's spike threshold parameter Qthr = self.neuron.Cm0 * Vthr * 1e-3 # C/m2 # Get QSS variables for each amplitude at threshold charge Aref, _, Vmeff, QS_states = self.quasiSteadyStates(Fdrive, charges=Qthr, DCs=DCs) if DCs.size == 1: QS_states = QS_states.reshape((*QS_states.shape, 1)) Vmeff = Vmeff.reshape((*Vmeff.shape, 1)) # Compute 2D QSS charge variation array at Qthr dQdt = -self.neuron.iNet(Vmeff, QS_states) # Find the threshold amplitude that cancels dQdt for each duty cycle Arheobase = np.array([np.interp(0, dQdt[:, i], Aref, left=0., right=np.nan) for i in range(DCs.size)]) # Check if threshold amplitude is found for all DCs inan = np.where(np.isnan(Arheobase))[0] if inan.size > 0: if inan.size == Arheobase.size: logger.error( 'No rheobase amplitudes within [%s - %sPa] for the provided duty cycles', *si_format((Aref.min(), Aref.max()))) else: minDC = DCs[inan.max() + 1] logger.warning( 'No rheobase amplitudes within [%s - %sPa] below %.1f%% duty cycle', *si_format((Aref.min(), Aref.max())), minDC * 1e2) return Arheobase, Aref def computeEffVars(self, Fdrive, Adrive, Qm, fs): ''' Compute "effective" coefficients of the HH system for a specific combination of stimulus frequency, stimulus amplitude and charge density. A short mechanical simulation is run while imposing the specific charge density, until periodic stabilization. The HH coefficients are then averaged over the last acoustic cycle to yield "effective" coefficients. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param Qm: imposed charge density (C/m2) :param fs: list of sonophore membrane coverage fractions :return: list with computation time and a list of dictionaries of effective variables ''' # Run simulation and retrieve deflection and gas content vectors from last cycle data, tcomp = BilayerSonophore.simulate(self, Fdrive, Adrive, Qm) Z_last = data.loc[-NPC_FULL:, 'Z'].values # m Cm_last = self.v_Capct(Z_last) # F/m2 # For each coverage fraction effvars = [] for x in fs: # Compute membrane capacitance and membrane potential vectors Cm = x * Cm_last + (1 - x) * self.Cm0 # F/m2 Vm = Qm / Cm * 1e3 # mV # Compute average cycle value for membrane potential and rate constants effvars.append({'V': np.mean(Vm)}) effvars[-1].update(self.neuron.computeEffRates(Vm)) # Log process log = '{}: lookups @ {}Hz, {}Pa, {:.2f} nC/cm2'.format( self, *si_format([Fdrive, Adrive], precision=1, space=' '), Qm * 1e5) if len(fs) > 1: log += ', fs = {:.0f} - {:.0f}%'.format(fs.min() * 1e2, fs.max() * 1e2) log += ', tcomp = {:.3f} s'.format(tcomp) logger.info(log) # Return effective coefficients return [tcomp, effvars] diff --git a/PySONIC/core/pneuron.py b/PySONIC/core/pneuron.py index cf60dfe..9617fa1 100644 --- a/PySONIC/core/pneuron.py +++ b/PySONIC/core/pneuron.py @@ -1,607 +1,603 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-08-03 11:53:04 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-05-31 16:59:38 +# @Last Modified time: 2019-06-01 16:38:44 import abc import inspect import re import numpy as np import pandas as pd from .batches import createQueue from .model import Model from .simulators import PWSimulator from ..postpro import findPeaks from ..constants import * -from ..utils import si_format, logger, titrate, plural +from ..utils import si_format, logger, plural, binarySearch class PointNeuron(Model): ''' Generic point-neuron model interface. ''' tscale = 'ms' # relevant temporal scale of the model defvar = 'V' # default plot variable def __repr__(self): return self.__class__.__name__ def pprint(self): return '{} neuron'.format(self.__class__.__name__) def filecode(self, Astim, tstim, toffset, PRF, DC): ''' File naming convention. ''' return 'ESTIM_{}_{}_{:.1f}mA_per_m2_{:.0f}ms{}'.format( self.name, 'CW' if DC == 1 else 'PW', Astim, tstim * 1e3, '_PRF{:.2f}Hz_DC{:.2f}%'.format(PRF, DC * 1e2) if DC < 1. else '') @property @abc.abstractmethod def name(self): raise NotImplementedError @property @abc.abstractmethod def Cm0(self): raise NotImplementedError @property @abc.abstractmethod def Vm0(self): raise NotImplementedError @abc.abstractmethod def currents(self, Vm, states): ''' Compute all ionic currents per unit area. :param Vm: membrane potential (mV) :states: state probabilities of the ion channels :return: dictionary of ionic currents per unit area (mA/m2) ''' def iNet(self, Vm, states): ''' net membrane current :param Vm: membrane potential (mV) :states: states of ion channels gating and related variables :return: current per unit area (mA/m2) ''' return sum(self.currents(Vm, states).values()) def dQdt(self, Vm, states): ''' membrane charge density variation rate :param Vm: membrane potential (mV) :states: states of ion channels gating and related variables :return: variation rate (mA/m2) ''' return -self.iNet(Vm, states) - def isTitratable(self): - ''' Simple method returning whether the neuron can be titrated (defaults to True). ''' - return True + def titrationFunc(self, *args, **kwargs): + ''' Default titration function. ''' + return self.isExcited(*args, **kwargs) def currentToConcentrationRate(self, z_ion, depth): ''' Compute the conversion factor from a specific ionic current (in mA/m2) into a variation rate of submembrane ion concentration (in M/s). :param: z_ion: ion valence :param depth: submembrane depth (m) :return: conversion factor (Mmol.m-1.C-1) ''' return 1e-6 / (z_ion * depth * FARADAY) def nernst(self, z_ion, Cion_in, Cion_out, T): ''' Nernst potential of a specific ion given its intra and extracellular concentrations. :param z_ion: ion valence :param Cion_in: intracellular ion concentration :param Cion_out: extracellular ion concentration :param T: temperature (K) :return: ion Nernst potential (mV) ''' return (Rg * T) / (z_ion * FARADAY) * np.log(Cion_out / Cion_in) * 1e3 def vtrap(self, x, y): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x / y) - 1) def efun(self, x): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x) - 1) def ghkDrive(self, Vm, Z_ion, Cion_in, Cion_out, T): ''' Use the Goldman-Hodgkin-Katz equation to compute the electrochemical driving force of a specific ion species for a given membrane potential. :param Vm: membrane potential (mV) :param Cin: intracellular ion concentration (M) :param Cout: extracellular ion concentration (M) :param T: temperature (K) :return: electrochemical driving force of a single ion particle (mC.m-3) ''' x = Z_ion * FARADAY * Vm / (Rg * T) * 1e-3 # [-] eCin = Cion_in * self.efun(-x) # M eCout = Cion_out * self.efun(x) # M return FARADAY * (eCin - eCout) * 1e6 # mC/m3 def getDesc(self): return inspect.getdoc(self).splitlines()[0] def getCurrentsNames(self): return list(self.currents(np.nan, [np.nan] * len(self.states)).keys()) def getPltScheme(self): pltscheme = { 'Q_m': ['Qm'], 'V_m': ['Vm'] } pltscheme['I'] = self.getCurrentsNames() + ['iNet'] for cname in self.getCurrentsNames(): if 'Leak' not in cname: key = 'i_{{{}}}\ kin.'.format(cname[1:]) cargs = inspect.getargspec(getattr(self, cname))[0][1:] pltscheme[key] = [var for var in cargs if var not in ['Vm', 'Cai']] return pltscheme def getPltVars(self, wrapleft='df["', wrapright='"]'): ''' Return a dictionary with information about all plot variables related to the neuron. ''' pltvars = { 'Qm': { 'desc': 'membrane charge density', 'label': 'Q_m', 'unit': 'nC/cm^2', 'factor': 1e5, 'bounds': (-100, 50) }, 'Vm': { 'desc': 'membrane potential', 'label': 'V_m', 'unit': 'mV', 'y0': self.Vm0, 'bounds': (-150, 70) }, 'ELeak': { 'constant': 'obj.ELeak', 'desc': 'non-specific leakage current resting potential', 'label': 'V_{leak}', 'unit': 'mV', 'ls': '--', 'color': 'k' } } for cname in self.getCurrentsNames(): cfunc = getattr(self, cname) cargs = inspect.getargspec(cfunc)[0][1:] pltvars[cname] = { 'desc': inspect.getdoc(cfunc).splitlines()[0], 'label': 'I_{{{}}}'.format(cname[1:]), 'unit': 'A/m^2', 'factor': 1e-3, 'func': '{}({})'.format(cname, ', '.join(['{}{}{}'.format(wrapleft, a, wrapright) for a in cargs])) } for var in cargs: if var not in ['Vm', 'Cai']: vfunc = getattr(self, 'der{}{}'.format(var[0].upper(), var[1:])) desc = cname + re.sub( '^Evolution of', '', inspect.getdoc(vfunc).splitlines()[0]) pltvars[var] = { 'desc': desc, 'label': var, 'bounds': (-0.1, 1.1) } pltvars['iNet'] = { 'desc': inspect.getdoc(getattr(self, 'iNet')).splitlines()[0], 'label': 'I_{net}', 'unit': 'A/m^2', 'factor': 1e-3, 'func': 'iNet({0}Vm{1}, {2}{3}{4}.values.T)'.format( wrapleft, wrapright, wrapleft[:-1], self.states, wrapright[1:]), 'ls': '--', 'color': 'black' } pltvars['dQdt'] = { 'desc': inspect.getdoc(getattr(self, 'dQdt')).splitlines()[0], 'label': 'dQ_m/dt', 'unit': 'A/m^2', 'factor': 1e-3, 'func': 'dQdt({0}Vm{1}, {2}{3}{4}.values.T)'.format( wrapleft, wrapright, wrapleft[:-1], self.states, wrapright[1:]), 'ls': '--', 'color': 'black' } for x in self.getGates(): for rate in ['alpha', 'beta']: pltvars['{}{}'.format(rate, x)] = { 'label': '\\{}_{{{}}}'.format(rate, x), 'unit': 'ms^{-1}', 'factor': 1e-3 } return pltvars def getRatesNames(self, states): ''' Return a list of names of the alpha and beta rates of the neuron. ''' return list(sum( [['alpha{}'.format(x.lower()), 'beta{}'.format(x.lower())] for x in states], [] )) def Qm0(self): ''' Return the resting charge density (in C/m2). ''' return self.Cm0 * self.Vm0 * 1e-3 # C/cm2 @abc.abstractmethod def steadyStates(self, Vm): ''' Compute the steady-state values for a specific membrane potential value. :param Vm: membrane potential (mV) :return: dictionary of steady-states ''' @abc.abstractmethod def derStates(self, Vm, states): ''' Compute the derivatives of channel states. :param Vm: membrane potential (mV) :states: state probabilities of the ion channels :return: current per unit area (mA/m2) ''' @abc.abstractmethod def computeEffRates(self, Vm): ''' Get the effective rate constants of ion channels, averaged along an acoustic cycle, for future use in effective simulations. :param Vm: array of membrane potential values for an acoustic cycle (mV) :return: a dictionary of rate average constants (s-1) ''' def interpEffRates(self, Qm, lkp, keys=None): ''' Interpolate effective rate constants for a given charge density using reference lookup vectors. :param Qm: membrane charge density (C/m2) :states: state probabilities of the ion channels :param lkp: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: dictionary of interpolated rate constants ''' if keys is None: keys = self.rates return {k: np.interp(Qm, lkp['Q'], lkp[k], left=np.nan, right=np.nan) for k in keys} def interpVmeff(self, Qm, lkp): ''' Interpolate the effective membrane potential for a given charge density using reference lookup vectors. :param Qm: membrane charge density (C/m2) :param lkp: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: dictionary of interpolated rate constants ''' return np.interp(Qm, lkp['Q'], lkp['V'], left=np.nan, right=np.nan) @abc.abstractmethod def derEffStates(self, Qm, states, lkp): ''' Compute the effective derivatives of channel states, based on 1-dimensional linear interpolation of "effective" coefficients that summarize the system's behaviour over an acoustic cycle. :param Qm: membrane charge density (C/m2) :states: state probabilities of the ion channels :param lkp: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. ''' def Qbounds(self): ''' Determine bounds of membrane charge physiological range for a given neuron. ''' return np.array([np.round(self.Vm0 - 25.0), 50.0]) * self.Cm0 * 1e-3 # C/m2 def isVoltageGated(self, state): ''' Determine whether a given state is purely voltage-gated or not.''' return 'alpha{}'.format(state.lower()) in self.rates def getGates(self): ''' Retrieve the names of the neuron's states that match an ion channel gating. ''' gates = [] for x in self.states: if self.isVoltageGated(x): gates.append(x) return gates def qsStates(self, lkp, states): ''' Compute a collection of quasi steady states using the standard xinf = ax / (ax + Bx) equation. :param lkp: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: dictionary of quasi-steady states ''' return { x: lkp['alpha{}'.format(x)] / (lkp['alpha{}'.format(x)] + lkp['beta{}'.format(x)]) for x in states } @abc.abstractmethod def quasiSteadyStates(self, lkp): ''' Compute the quasi-steady states of a neuron for a range of membrane charge densities, based on 1-dimensional lookups interpolated at a given sonophore diameter, US frequency, US amplitude and duty cycle. :param lkp: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: dictionary of quasi-steady states ''' def getRates(self, Vm): ''' Compute the ion channels rate constants for a given membrane potential. :param Vm: membrane potential (mV) :return: a dictionary of rate constants and their values at the given potential. ''' rates = {} for x in self.getGates(): x = x.lower() alpha_str, beta_str = ['{}{}'.format(s, x.lower()) for s in ['alpha', 'beta']] inf_str, tau_str = ['{}inf'.format(x.lower()), 'tau{}'.format(x.lower())] if hasattr(self, 'alpha{}'.format(x)): alphax = getattr(self, alpha_str)(Vm) betax = getattr(self, beta_str)(Vm) elif hasattr(self, '{}inf'.format(x)): xinf = getattr(self, inf_str)(Vm) taux = getattr(self, tau_str)(Vm) alphax = xinf / taux betax = 1 / taux - alphax rates[alpha_str] = alphax rates[beta_str] = betax return rates def Vderivatives(self, y, t, Iinj): ''' Compute the derivatives of a V-cast HH system for a specific value of injected current. :param y: vector of HH system variables at time t :param t: time value (s, unused) :param Iinj: injected current (mA/m2) :return: vector of HH system derivatives at time t ''' Vm, *states = y Iionic = self.iNet(Vm, states) # mA/m2 dVmdt = (- Iionic + Iinj) / self.Cm0 # mV/s dstates = self.derStates(Vm, states) return [dVmdt, *[dstates[k] for k in self.states]] def Qderivatives(self, y, t, Cm=None): ''' Compute the derivatives of the n-ODE HH system variables, based on a value of membrane capacitance. :param y: vector of HH system variables at time t :param t: specific instant in time (s) :param Cm: membrane capacitance (F/m2) :return: vector of HH system derivatives at time t ''' if Cm is None: Cm = self.Cm0 Qm, *states = y Vm = Qm / Cm * 1e3 # mV dQmdt = - self.iNet(Vm, states) * 1e-3 # A/m2 dstates = self.derStates(Vm, states) return [dQmdt, *[dstates[k] for k in self.states]] def checkInputs(self, Astim, tstim, toffset, PRF, DC): ''' Check validity of electrical stimulation parameters. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) ''' # Check validity of stimulation parameters if not all(isinstance(param, float) for param in [Astim, tstim, toffset, DC]): raise TypeError('Invalid stimulation parameters (must be float typed)') if tstim <= 0: raise ValueError('Invalid stimulus duration: {} ms (must be strictly positive)' .format(tstim * 1e3)) if toffset < 0: raise ValueError('Invalid stimulus offset: {} ms (must be positive or null)' .format(toffset * 1e3)) if DC <= 0.0 or DC > 1.0: raise ValueError('Invalid duty cycle: {} (must be within ]0; 1])'.format(DC)) if DC < 1.0: if not isinstance(PRF, float): raise TypeError('Invalid PRF value (must be float typed)') if PRF is None: raise AttributeError('Missing PRF value (must be provided when DC < 1)') if PRF < 1 / tstim: raise ValueError('Invalid PRF: {} Hz (PR interval exceeds stimulus duration)' .format(PRF)) - def simulate(self, Astim, tstim, toffset, PRF=None, DC=1.0, dt=DT_ESTIM): + def simulate(self, Astim, tstim, toffset, PRF=100., DC=1.0): ''' Simulate a specific neuron model for a specific set of electrical parameters, and return output data in a dataframe. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) - :param dt: integration time step (s) :return: 2-tuple with the output dataframe and computation time. ''' logger.info( '%s: %s @ %st = %ss (%ss offset)%s', self, 'titration' if Astim is None else 'simulation', 'A = {}A/m2, '.format(si_format(Astim, 2, space=' ')) if Astim is not None else '', *si_format([tstim, toffset], 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format( si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) - # TODO: If no amplitude provided, perform titration - if Astim is None: - Astim = self.titrate(tstim, toffset, PRF, DC) - if np.isnan(Astim): - logger.error('Could not find threshold excitation amplitude') - return None - # Check validity of stimulation parameters self.checkInputs(Astim, tstim, toffset, PRF, DC) # Set initial conditions steady_states = self.steadyStates(self.Vm0) y0 = np.array([self.Vm0, *[steady_states[k] for k in self.states]]) # Initialize simulator and compute solution logger.debug('Computing solution') simulator = PWSimulator( lambda y, t: self.Vderivatives(y, t, Astim), lambda y, t: self.Vderivatives(y, t, 0.)) - (t, y, stim), tcomp = simulator(y0, dt, tstim, toffset, PRF, DC, monitor_time=True) + (t, y, stim), tcomp = simulator(y0, DT_ESTIM, tstim, toffset, PRF, DC, monitor_time=True) logger.debug('completed in %ss', si_format(tcomp, 1)) # Store output in dataframe data = pd.DataFrame({ 't': t, 'stimstate': stim, 'Vm': y[:, 0], 'Qm': y[:, 0] * self.Cm0 * 1e-3 }) data['Qm'] = data['Vm'].values * self.Cm0 * 1e-3 for i in range(len(self.states)): data[self.states[i]] = y[:, i + 1] # Log number of detected spikes nspikes = self.getNSpikes(data) logger.debug('{} spike{} detected'.format(nspikes, plural(nspikes))) # Return dataframe and computation time return data, tcomp def meta(self, Astim, tstim, toffset, PRF, DC): ''' Return information about object and simulation parameters. :param Astim: stimulus amplitude (mA/m2) :param tstim: stimulus duration (s) :param toffset: stimulus offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: stimulus duty cycle (-) :return: meta-data dictionary ''' return { 'neuron': self.name, 'Astim': Astim, 'tstim': tstim, 'toffset': toffset, 'PRF': PRF, 'DC': DC } def createQueue(self, amps, durations, offsets, PRFs, DCs): ''' Create a serialized 2D array of all parameter combinations for a series of individual parameter sweeps, while avoiding repetition of CW protocols for a given PRF sweep. :param amps: list (or 1D-array) of acoustic amplitudes :param durations: list (or 1D-array) of stimulus durations :param offsets: list (or 1D-array) of stimulus offsets (paired with durations array) :param PRFs: list (or 1D-array) of pulse-repetition frequencies :param DCs: list (or 1D-array) of duty cycle values :return: list of parameters (list) for each simulation ''' if amps is None: amps = [np.nan] DCs = np.array(DCs) queue = [] if 1.0 in DCs: queue += createQueue(amps, durations, offsets, min(PRFs), 1.0) if np.any(DCs != 1.0): queue += createQueue(amps, durations, offsets, PRFs, DCs[DCs != 1.0]) for item in queue: if np.isnan(item[0]): item[0] = None return queue def getNSpikes(self, data): ''' Compute number of spikes in charge profile of simulation output. :param data: dataframe containing output time series :return: number of detected spikes ''' dt = np.diff(data.ix[:1, 't'].values)[0] ipeaks, *_ = findPeaks( data['Qm'].values, SPIKE_MIN_QAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_QPROM ) return ipeaks.size def getStabilizationValue(self, data): ''' Determine stabilization value from the charge profile of a simulation output. :param data: dataframe containing output time series :return: charge stabilization value (or np.nan if no stabilization detected) ''' # Extract charge signal posterior to observation window t, Qm = [data[key].values for key in ['t', 'Qm']] Qm = y[2, t > TMIN_STABILIZATION] # Compute variation range Qm_range = np.ptp(Qm) logger.debug('%.2f nC/cm2 variation range over the last %.0f ms', Qm_range * 1e5, TMIN_STABILIZATION * 1e3) # Return final value only if stabilization is detected if np.ptp(Qm) < QSS_Q_DIV_THR: return Qm[-1] else: return np.nan def isExcited(self, data): ''' Determine if neuron is excited from simulation output. :param data: dataframe containing output time series :return: boolean stating whether neuron is excited or not ''' return self.getNSpikes(data) > 0 def isSilenced(self, data): ''' Determine if neuron is silenced from simulation output. :param data: dataframe containing output time series :return: boolean stating whether neuron is silenced or not ''' return not np.isinan(self.getStabilizationValue(data)) - def titrate(self, tstim, toffset, PRF=None, DC=1.0, Arange=(0., 2 * TITRATION_ESTIM_A_MAX), - xfunc=None): + def titrate(self, tstim, toffset, PRF, DC, xfunc=None, Arange=(0., 2 * TITRATION_ESTIM_A_MAX)): ''' Use a binary search to determine the threshold amplitude needed to obtain neural excitation for a given duration, PRF and duty cycle. :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) + :param xfunc: function determining whether condition is reached from simulation output :param Arange: search interval for Astim, iteratively refined :return: excitation threshold amplitude (mA/m2) ''' - # Determine output function + # Default output function if xfunc is None: - xfunc = self.isExcited - return titrate(xfunc, (tstim, toffset, PRF, DC), Arange, TITRATION_ESTIM_DA_MAX) + xfunc = self.titrationFunc + + return binarySearch( + lambda x: xfunc(self.simulate(*x)[0]), + [tstim, toffset, PRF, DC], 0, Arange, TITRATION_ESTIM_DA_MAX + ) diff --git a/PySONIC/core/simulators.py b/PySONIC/core/simulators.py index 08549a2..301b8eb 100644 --- a/PySONIC/core/simulators.py +++ b/PySONIC/core/simulators.py @@ -1,422 +1,422 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2019-05-28 14:45:12 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-05-31 13:32:29 +# @Last Modified time: 2019-05-31 17:23:43 import abc import numpy as np from scipy.integrate import ode, odeint from tqdm import tqdm from ..utils import * from ..constants import * class Simulator(metaclass=abc.ABCMeta): ''' Generic interface to simulator object. ''' def initialize(self, y0): ''' Initialize global arrays. :param y0: vector of initial conditions :return: 3-tuple with the initialized time vector, solution matrix and state vector ''' return np.array([0.]), np.atleast_2d(y0), np.array([1]) def appendSolution(self, t, y, stim, tnew, ynew, is_on): ''' Append to time vector, solution matrix and state vector. :param t: preceding time vector :param y: preceding solution matrix :param stim: preceding stimulation state vector :param tnew: integration time vector for current interval :param ynew: derivative function for current interval :param is_on: stimulation state for current interval :return: 3-tuple with the appended time vector, solution matrix and state vector ''' t = np.concatenate((t, tnew[1:])) y = np.concatenate((y, ynew[1:]), axis=0) stim = np.concatenate((stim, np.ones(tnew.size - 1) * is_on)) return t, y, stim def integrate(self, t, y, stim, tnew, func, is_on): ''' Integrate system for a time interval and append to preceding solution arrays. :param t: preceding time vector :param y: preceding solution matrix :param stim: preceding stimulation state vector :param tnew: integration time vector for current interval :param func: derivative function for current interval :param is_on: stimulation state for current interval :return: 3-tuple with the appended time vector, solution matrix and state vector ''' if tnew.size == 0: return t, y, stim ynew = odeint(func, y[-1], tnew) return self.appendSolution(t, y, stim, tnew, ynew, is_on) def resample(self, t, y, stim, target_dt): ''' Resample a solution to a new target time step. :param t: time vector :param y: solution matrix :param stim: stimulation state vector :target_dt: target time step after resampling :return: 3-tuple with the resampled time vector, solution matrix and state vector ''' dt = t[1] - t[0] rf = int(np.round(target_dt / dt)) assert rf >= 1, 'Hyper-sampling not supported' logger.debug( 'Downsampling output arrays by factor %u (Fs = %sHz)', rf, si_format(1 / (dt * rf), 2)) return t[::rf], y[::rf, :], stim[::rf] @property @abc.abstractmethod def compute(self): ''' Abstract compute method. ''' return 'Should never reach here' def __call__(self, *args, **kwargs): ''' Call and return compute method, with conditional time monitoring. ''' monitor_time = kwargs.pop('monitor_time') if monitor_time: start_time = time.perf_counter() output = self.compute(*args, **kwargs) if monitor_time: end_time = time.perf_counter() run_time = end_time - start_time output = output, run_time return output class PeriodicSimulator(Simulator): def __init__(self, dfunc, ivars_to_check=None): ''' Initialize simulator with specific derivative function :param dfunc: derivative function :param ivars_to_check: solution indexes of variables to check for stability ''' self.dfunc = dfunc self.ivars_to_check = ivars_to_check def getNPerCycle(self, dt, f): ''' Compute number of samples per cycle given a time step and a specific periodicity. :param dt: integration time step (s) :param f: periodic frequency (Hz) :return: number of samples per cycle ''' return int(np.round(1 / (f * dt))) + 1 def getTimeReference(self, dt, f): ''' Compute reference integration time vector for a specific periodicity. :param dt: integration time step (s) :param f: periodic frequency (Hz) :return: time vector for 1 periodic cycle ''' return np.linspace(0, 1 / f, self.getNPerCycle(dt, f)) def isPeriodicStable(self, y, npc, icycle): ''' Assess the periodic stabilization of a solution, by evaluating the deviation of system variables between two consecutive cycles. :param y: solution matrix :param npc: number of samples per cycle :param icycle: index of cycle of interest :return: boolean stating whether the solution is stable or not ''' # Extract the 2 cycles of interest from the solution y_target = y[icycle * npc: (icycle + 1) * npc, :] y_prec = y[(icycle - 1) * npc: icycle * npc, :] # For each variable of interest, evaluate the RMSE between the two cycles, the # variation range over the last cycle, and the ratio of these 2 quantities x_ratios = np.empty(len(self.ivars_to_check)) for i, ivar in enumerate(self.ivars_to_check): x_target, x_prec = y_target[:, ivar], y_prec[:, ivar] x_ratios[i] = rmse(x_target, x_prec) / np.ptp(x_target) # Classify the solution as stable only if all RMSE/PTP ratios are below critical threshold is_periodically_stable = np.all(x_ratios < MAX_RMSE_PTP_RATIO) logger.debug( 'step %u: ratios = [%s] -> %sstable', icycle, ', '.join(['{:.2e}'.format(r) for r in x_ratios]), '' if is_periodically_stable else 'un' ) return is_periodically_stable def compute(self, y0, dt, f, t0=0.): ''' Simulate system with a specific periodicity until stabilization. :param y0: 1D vector of initial conditions :param dt: integration time step (s) :param f: periodic frequency (Hz) :param t0: starting time :return: 3-tuple with the time profile, the effective solution matrix and a state vector ''' # If none specified, set all variables to be checked for stability if self.ivars_to_check is None: self.ivars_to_check = range(y0.size) # Get reference time vector tref = self.getTimeReference(dt, f) # Initialize global arrays t, y, stim = self.initialize(y0) # Integrate system for a few cycles until stabilization icycle = 0 conv = False while not conv and icycle < NCYCLES_MAX: t, y, stim = self.integrate(t, y, stim, tref + icycle / f, self.dfunc, True) if icycle > 0: conv = self.isPeriodicStable(y, tref.size - 1, icycle) icycle += 1 # Log stopping criterion if icycle == NCYCLES_MAX: logger.warning('No convergence: stopping after %u cycles', icycle) else: logger.debug('Periodic convergence after %u cycles', icycle) # Return output variables return t + t0, y, stim class PWSimulator(Simulator): def __init__(self, dfunc_on, dfunc_off): ''' Initialize simulator with specific derivative functions :param dfunc_on: derivative function for ON periods :param dfunc_off: derivative function for OFF periods ''' self.dfunc_on = dfunc_on self.dfunc_off = dfunc_off def getTimeReference(self, dt, tstim, toffset, PRF, DC): ''' Compute reference integration time vectors for a specific stimulus application pattern. :param dt: integration time step (s) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :return: 3-tuple with time vectors for stimulus ON and OFF periods and stimulus offset ''' # Compute vector sizes T_ON = DC / PRF T_OFF = (1 - DC) / PRF # For high-PRF pulsed protocols: adapt time step to ensure minimal # number of samples during TON or TOFF dt_warning_msg = 'high-PRF protocol: lowering time step to %.2e s to properly integrate %s' for key, T in {'TON': T_ON, 'TOFF': T_OFF}.items(): if T > 0 and T / dt < MIN_SAMPLES_PER_PULSE_INT: dt = T / MIN_SAMPLES_PER_PULSE_INT logger.warning(dt_warning_msg, dt, key) # Initializing accurate time vectors pulse ON and OFF periods, as well as offset t_on = np.linspace(0, T_ON, int(np.round(T_ON / dt)) + 1) t_off = np.linspace(T_ON, 1 / PRF, int(np.round(T_OFF / dt))) t_offset = np.linspace(tstim, tstim + toffset, int(np.round(toffset / dt))) return t_on, t_off, t_offset def adjustPRF(self, tstim, PRF, DC, print_progress): ''' Adjust the PRF in case of continuous wave stimulus, in order to obtain the desired number of integration interval(s) during stimulus. :param tstim: duration of US stimulation (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param print_progress: boolean specifying whether to show a progress bar :return: adjusted PRF value (Hz) ''' if DC < 1.0: # if PW stimuli, then no change return PRF else: # if CW stimuli, then divide integration according to presence of progress bar return {True: 100., False: 1.}[print_progress] / tstim def getNPulses(self, tstim, PRF): ''' Calculate number of pulses from stimulus temporal pattern. :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :return: number of pulses during the stimulus ''' return int(np.round(tstim * PRF)) def compute(self, y0, dt, tstim, toffset, PRF, DC, target_dt=None, print_progress=False): ''' Simulate system for a specific stimulus application pattern. :param y0: 1D vector of initial conditions :param dt: integration time step (s) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param t0: starting time :param target_dt: target time step after resampling :param print_progress: boolean specifying whether to show a progress bar :return: 3-tuple with the time profile, the effective solution matrix and a state vector ''' # Adjust PRF and get number of pulses PRF = self.adjustPRF(tstim, PRF, DC, print_progress) npulses = self.getNPulses(tstim, PRF) # Get reference time vectors t_on, t_off, t_offset = self.getTimeReference(dt, tstim, toffset, PRF, DC) # Initialize global arrays t, y, stim = self.initialize(y0) # Initialize progress bar if print_progress: setHandler(logger, TqdmHandler(my_log_formatter)) ntot = int(npulses * (tstim + toffset) / tstim) pbar = tqdm(total=ntot) # Integrate ON and OFF intervals of each pulse for i in range(npulses): for j, (tref, func) in enumerate(zip([t_on, t_off], [self.dfunc_on, self.dfunc_off])): t, y, stim = self.integrate(t, y, stim, tref + i / PRF, func, j == 0) # Update progress bar if print_progress: pbar.update(i) # Integrate offset interval t, y, stim = self.integrate(t, y, stim, t_offset, self.dfunc_off, False) # Terminate progress bar if print_progress: pbar.update(npulses) pbar.close() # Resample solution if specified if target_dt is not None: t, y, stim = self.resample(t, y, stim, target_dt) # Return output variables return t, y, stim class HybridSimulator(PWSimulator): def __init__(self, dfunc_on, dfunc_off, dfunc_sparse, predfunc, is_dense_var, ivars_to_check=None): ''' Initialize simulator with specific derivative functions :param dfunc_on: derivative function for ON periods :param dfunc_off: derivative function for OFF periods :param dfunc_sparse: derivative function for sparse integration :param predfunc: function computing the extra arguments necessary for sparse integration :param is_dense_var: boolean array stating for each variable if it evolves fast or not :param ivars_to_check: solution indexes of variables to check for stability ''' PWSimulator.__init__(self, dfunc_on, dfunc_off) self.sparse_solver = ode(dfunc_sparse) self.sparse_solver.set_integrator('dop853', nsteps=SOLVER_NSTEPS, atol=1e-12) self.predfunc = predfunc self.is_dense_var = is_dense_var self.is_sparse_var = np.invert(is_dense_var) self.ivars_to_check = ivars_to_check def integrate(self, t, y, stim, tnew, func, is_on): ''' Integrate system for a time interval and append to preceding solution arrays, using a hybrid scheme: - First, the full ODE system is integrated for a few cycles with a dense time granularity until it reaches a periodically stable behavior (limit cycle) - Second, the profiles of all variables over the last cycle are resampled to a far lower (i.e. sparse) sampling rate - Third, a subset of the ODE system is integrated with a sparse time granularity, for the remaining of the time interval, while the remaining variables are periodically expanded from their "stabilized" (last cycle) profile. :param t: preceding time vector :param y: preceding solution matrix :param stim: preceding stimulation state vector :param tnew: integration time vector for current interval :param func: derivative function for current interval :param is_on: stimulation state for current interval :return: 3-tuple with the appended time vector, solution matrix and state vector ''' if tnew.size == 0: return t, y, stim # Initialize periodic solver dense_solver = PeriodicSimulator(func, self.ivars_to_check) dt_dense = tnew[1] - tnew[0] npc_dense = dense_solver.getNPerCycle(dt_dense, self.f) # Until final integration time is reached while t[-1] < tnew[-1]: logger.debug('t = {:.5f} ms: starting new hybrid integration'.format(t[-1] * 1e3)) # Integrate dense system until convergence tdense, ydense, stimdense = dense_solver.compute(y[-1], dt_dense, self.f, t0=t[-1]) t, y, stim = self.appendSolution(t, y, stim, tdense, ydense, is_on) # Resample signals over last acoustic cycle to match sparse time step tlast, ylast, stimlast = self.resample( tdense[-npc_dense:], ydense[-npc_dense:], stimdense[-npc_dense:], self.dt_sparse) npc_sparse = tlast.size # Integrate until either the rest of the interval or max update interval is reached t0 = tdense[-1] tf = min(tnew[-1], tdense[0] + DT_UPDATE) nsparse = int(np.round((tf - t0) / self.dt_sparse)) tsparse = np.linspace(t0, tf, nsparse) ysparse = np.empty((nsparse, y.shape[1])) ysparse[0] = y[-1] self.sparse_solver.set_initial_value(y[-1, self.is_sparse_var], t[-1]) for j in range(1, tsparse.size): self.sparse_solver.set_f_params(self.predfunc(ylast[j % npc_sparse])) self.sparse_solver.integrate(tsparse[j]) if not self.sparse_solver.successful(): raise ValueError( 'integration error at t = {:.5f} ms'.format(tsparse[j] * 1e3)) ysparse[j, self.is_dense_var] = ylast[j % npc_sparse, self.is_dense_var] ysparse[j, self.is_sparse_var] = self.sparse_solver.y t, y, stim = self.appendSolution(t, y, stim, tsparse, ysparse, is_on) return t, y, stim def compute(self, y0, dt_dense, dt_sparse, f, tstim, toffset, PRF, DC, print_progress=False): ''' Simulate system for a specific stimulus application pattern. :param y0: 1D vector of initial conditions :param dt_dense: dense integration time step (s) :param dt_sparse: sparse integration time step (s) :param f: periodic frequency (Hz) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param print_progress: boolean specifying whether to show a progress bar :return: 3-tuple with the time profile, the effective solution matrix and a state vector ''' # Set periodicity and sparse time step self.f = f self.dt_sparse = dt_sparse # Call and return parent compute method return PWSimulator.compute( self, y0, dt_dense, tstim, toffset, PRF, DC, target_dt=None, print_progress=print_progress) diff --git a/PySONIC/neurons/stn.py b/PySONIC/neurons/stn.py index bc7f6a5..1575e2e 100644 --- a/PySONIC/neurons/stn.py +++ b/PySONIC/neurons/stn.py @@ -1,624 +1,618 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2018-11-29 16:56:45 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-05-17 16:15:22 +# @Last Modified time: 2019-06-01 16:37:32 import numpy as np from scipy.optimize import brentq from ..core import PointNeuron from ..constants import FARADAY, Z_Ca class OtsukaSTN(PointNeuron): ''' Sub-thalamic nucleus neuron References: *Otsuka, T., Abe, T., Tsukagawa, T., and Song, W.-J. (2004). Conductance-Based Model of the Voltage-Dependent Generation of a Plateau Potential in Subthalamic Neurons. Journal of Neurophysiology 92, 255–264.* *Tarnaud, T., Joseph, W., Martens, L., and Tanghe, E. (2018). Computational Modeling of Ultrasonic Subthalamic Nucleus Stimulation. IEEE Trans Biomed Eng.* ''' name = 'STN' # Resting parameters Cm0 = 1e-2 # Cell membrane resting capacitance (F/m2) Vm0 = -58.0 # Resting membrane potential (mV) Cai0 = 5e-9 # M (5 nM) # Reversal potentials ENa = 60.0 # Sodium Nernst potential (mV) EK = -90.0 # Potassium Nernst potential (mV) # Physical constants T = 306.15 # K (33°C) # Calcium dynamics Cao = 2e-3 # M (2 mM) taur_Cai = 0.5e-3 # decay time constant for intracellular Ca2+ dissolution (s) # Leakage current gLeak = 3.5 # Conductance of non-specific leakage current (S/m^2) ELeak = -60.0 # Leakage reversal potential (mV) # Fast Na current gNabar = 490.0 # Max. conductance of Sodium current (S/m^2) thetax_m = -40 # mV thetax_h = -45.5 # mV kx_m = -8 # mV kx_h = 6.4 # mV tau0_m = 0.2 * 1e-3 # s tau1_m = 3 * 1e-3 # s tau0_h = 0 * 1e-3 # s tau1_h = 24.5 * 1e-3 # s thetaT_m = -53 # mV thetaT1_h = -50 # mV thetaT2_h = -50 # mV sigmaT_m = -0.7 # mV sigmaT1_h = -15 # mV sigmaT2_h = 16 # mV # Delayed rectifier K+ current gKdbar = 570.0 # Max. conductance of delayed-rectifier Potassium current (S/m^2) thetax_n = -41 # mV kx_n = -14 # mV tau0_n = 0 * 1e-3 # s tau1_n = 11 * 1e-3 # s thetaT1_n = -40 # mV thetaT2_n = -40 # mV sigmaT1_n = -40 # mV sigmaT2_n = 50 # mV # T-type Ca2+ current gCaTbar = 50.0 # Max. conductance of low-threshold Calcium current (S/m^2) thetax_p = -56 # mV thetax_q = -85 # mV kx_p = -6.7 # mV kx_q = 5.8 # mV tau0_p = 5 * 1e-3 # s tau1_p = 0.33 * 1e-3 # s tau0_q = 0 * 1e-3 # s tau1_q = 400 * 1e-3 # s thetaT1_p = -27 # mV thetaT2_p = -102 # mV thetaT1_q = -50 # mV thetaT2_q = -50 # mV sigmaT1_p = -10 # mV sigmaT2_p = 15 # mV sigmaT1_q = -15 # mV sigmaT2_q = 16 # mV # L-type Ca2+ current gCaLbar = 150.0 # Max. conductance of high-threshold Calcium current (S/m^2) thetax_c = -30.6 # mV thetax_d1 = -60 # mV thetax_d2 = 0.1 * 1e-6 # M kx_c = -5 # mV kx_d1 = 7.5 # mV kx_d2 = 0.02 * 1e-6 # M tau0_c = 45 * 1e-3 # s tau1_c = 10 * 1e-3 # s tau0_d1 = 400 * 1e-3 # s tau1_d1 = 500 * 1e-3 # s tau_d2 = 130 * 1e-3 # s thetaT1_c = -27 # mV thetaT2_c = -50 # mV thetaT1_d1 = -40 # mV thetaT2_d1 = -20 # mV sigmaT1_c = -20 # mV sigmaT2_c = 15 # mV sigmaT1_d1 = -15 # mV sigmaT2_d1 = 20 # mV # A-type K+ current gAbar = 50.0 # Max. conductance of A-type Potassium current (S/m^2) thetax_a = -45 # mV thetax_b = -90 # mV kx_a = -14.7 # mV kx_b = 7.5 # mV tau0_a = 1 * 1e-3 # s tau1_a = 1 * 1e-3 # s tau0_b = 0 * 1e-3 # s tau1_b = 200 * 1e-3 # s thetaT_a = -40 # mV thetaT1_b = -60 # mV thetaT2_b = -40 # mV sigmaT_a = -0.5 # mV sigmaT1_b = -30 # mV sigmaT2_b = 10 # mV # Ca2+-activated K+ current gKCabar = 10.0 # Max. conductance of Calcium-dependent Potassium current (S/m^2) thetax_r = 0.17 * 1e-6 # M kx_r = -0.08 * 1e-6 # M tau_r = 2 * 1e-3 # s - def __init__(self): self.states = ['a', 'b', 'c', 'd1', 'd2', 'm', 'h', 'n', 'p', 'q', 'r', 'Cai'] self.rates = self.getRatesNames(['a', 'b', 'c', 'd1', 'm', 'h', 'n', 'p', 'q']) self.deff = self.getEffectiveDepth(self.Cai0, self.Vm0) # m self.iCa_to_Cai_rate = self.currentToConcentrationRate(Z_Ca, self.deff) # Mmol.m-1.C-1 - def getPltScheme(self): pltscheme = super().getPltScheme() pltscheme['[Ca^{2+}]_i'] = ['Cai'] return pltscheme - def getPltVars(self, wrapleft='df["', wrapright='"]'): pltvars = super().getPltVars(wrapleft, wrapright) pltvars['Cai'] = { 'desc': 'submembrane Ca2+ concentration', 'label': '[Ca^{2+}]_i', 'unit': 'uM', 'factor': 1e6 } return pltvars - - def isTitratable(self): - ''' Overriding parent method to False, since this neuron has spontaneous activity. ''' - return False - + def titrationFunc(self, *args, **kwargs): + ''' Overriding default titration function. ''' + return self.isSilenced(*args, **kwargs) def getEffectiveDepth(self, Cai, Vm): ''' Compute effective depth that matches a given membrane potential and intracellular Calcium concentration. :return: effective depth (m) ''' iCaT = self.iCaT(self.pinf(Vm), self.qinf(Vm), Vm, Cai) # mA/m2 iCaL = self.iCaL(self.cinf(Vm), self.d1inf(Vm), self.d2inf(Cai), Vm, Cai) # mA/m2 return -(iCaT + iCaL) / (Z_Ca * FARADAY * Cai / self.taur_Cai) * 1e-6 # m - def _xinf(self, var, theta, k): ''' Generic function computing the steady-state opening of a particular channel gate at a given voltage or ion concentration. :param var: membrane potential (mV) or ion concentration (mM) :param theta: half-(in)activation voltage or concentration (mV or mM) :param k: slope parameter of (in)activation function (mV or mM) :return: steady-state opening (-) ''' return 1 / (1 + np.exp((var - theta) / k)) def ainf(self, Vm): return self._xinf(Vm, self.thetax_a, self.kx_a) def binf(self, Vm): return self._xinf(Vm, self.thetax_b, self.kx_b) def cinf(self, Vm): return self._xinf(Vm, self.thetax_c, self.kx_c) def d1inf(self, Vm): return self._xinf(Vm, self.thetax_d1, self.kx_d1) def d2inf(self, Cai): return self._xinf(Cai, self.thetax_d2, self.kx_d2) def minf(self, Vm): return self._xinf(Vm, self.thetax_m, self.kx_m) def hinf(self, Vm): return self._xinf(Vm, self.thetax_h, self.kx_h) def ninf(self, Vm): return self._xinf(Vm, self.thetax_n, self.kx_n) def pinf(self, Vm): return self._xinf(Vm, self.thetax_p, self.kx_p) def qinf(self, Vm): return self._xinf(Vm, self.thetax_q, self.kx_q) def rinf(self, Cai): return self._xinf(Cai, self.thetax_r, self.kx_r) def _taux1(self, Vm, theta, sigma, tau0, tau1): ''' Generic function computing the voltage-dependent, activation/inactivation time constant of a particular ion channel at a given voltage (first variant). :param Vm: membrane potential (mV) :param theta: voltage at which (in)activation time constant is half-maximal (mV) :param sigma: slope parameter of (in)activation time constant function (mV) :param tau0: minimal time constant (s) :param tau1: modulated time constant (s) :return: (in)activation time constant (s) ''' return tau0 + tau1 / (1 + np.exp(-(Vm - theta) / sigma)) def taua(self, Vm): return self._taux1(Vm, self.thetaT_a, self.sigmaT_a, self.tau0_a, self.tau1_a) def taum(self, Vm): return self._taux1(Vm, self.thetaT_m, self.sigmaT_m, self.tau0_m, self.tau1_m) def _taux2(self, Vm, theta1, theta2, sigma1, sigma2, tau0, tau1): ''' Generic function computing the voltage-dependent, activation/inactivation time constant of a particular ion channel at a given voltage (second variant). :param Vm: membrane potential (mV) :param theta: voltage at which (in)activation time constant is half-maximal (mV) :param sigma: slope parameter of (in)activation time constant function (mV) :param tau0: minimal time constant (s) :param tau1: modulated time constant (s) :return: (in)activation time constant (s) ''' return tau0 + tau1 / (np.exp(-(Vm - theta1) / sigma1) + np.exp(-(Vm - theta2) / sigma2)) def taub(self, Vm): return self._taux2(Vm, self.thetaT1_b, self.thetaT2_b, self.sigmaT1_b, self.sigmaT2_b, self.tau0_b, self.tau1_b) def tauc(self, Vm): return self._taux2(Vm, self.thetaT1_c, self.thetaT2_c, self.sigmaT1_c, self.sigmaT2_c, self.tau0_c, self.tau1_c) def taud1(self, Vm): return self._taux2(Vm, self.thetaT1_d1, self.thetaT2_d1, self.sigmaT1_d1, self.sigmaT2_d1, self.tau0_d1, self.tau1_d1) def tauh(self, Vm): return self._taux2(Vm, self.thetaT1_h, self.thetaT2_h, self.sigmaT1_h, self.sigmaT2_h, self.tau0_h, self.tau1_h) def taun(self, Vm): return self._taux2(Vm, self.thetaT1_n, self.thetaT2_n, self.sigmaT1_n, self.sigmaT2_n, self.tau0_n, self.tau1_n) def taup(self, Vm): return self._taux2(Vm, self.thetaT1_p, self.thetaT2_p, self.sigmaT1_p, self.sigmaT2_p, self.tau0_p, self.tau1_p) def tauq(self, Vm): return self._taux2(Vm, self.thetaT1_q, self.thetaT2_q, self.sigmaT1_q, self.sigmaT2_q, self.tau0_q, self.tau1_q) def derA(self, Vm, a): ''' Evolution of a-gate open-probability :param Vm: membrane potential (mV) :param a: open-probability of a-gate (-) :return: time derivative of a-gate open-probability (s-1) ''' return (self.ainf(Vm) - a) / self.taua(Vm) def derB(self, Vm, b): ''' Evolution of b-gate open-probability :param Vm: membrane potential (mV) :param b: open-probability of b-gate (-) :return: time derivative of b-gate open-probability (s-1) ''' return (self.binf(Vm) - b) / self.taub(Vm) def derC(self, Vm, c): ''' Evolution of c-gate open-probability :param Vm: membrane potential (mV) :param c: open-probability of c-gate (-) :return: time derivative of c-gate open-probability (s-1) ''' return (self.cinf(Vm) - c) / self.tauc(Vm) def derD1(self, Vm, d1): ''' Evolution of d1-gate open-probability :param Vm: membrane potential (mV) :param d1: open-probability of d1-gate (-) :return: time derivative of d1-gate open-probability (s-1) ''' return (self.d1inf(Vm) - d1) / self.taud1(Vm) def derD2(self, Cai, d2): ''' Evolution of Calcium-dependent d2-gate open-probability :param Vm: membrane potential (mV) :param d2: open-probability of d2-gate (-) :return: time derivative of d2-gate open-probability (s-1) ''' return (self.d2inf(Cai) - d2) / self.tau_d2 def derM(self, Vm, m): ''' Evolution of m-gate open-probability :param Vm: membrane potential (mV) :param m: open-probability of m-gate (-) :return: time derivative of m-gate open-probability (s-1) ''' return (self.minf(Vm) - m) / self.taum(Vm) def derH(self, Vm, h): ''' Evolution of h-gate open-probability :param Vm: membrane potential (mV) :param h: open-probability of h-gate (-) :return: time derivative of h-gate open-probability (s-1) ''' return (self.hinf(Vm) - h) / self.tauh(Vm) def derN(self, Vm, n): ''' Evolution of n-gate open-probability :param Vm: membrane potential (mV) :param n: open-probability of n-gate (-) :return: time derivative of n-gate open-probability (s-1) ''' return (self.ninf(Vm) - n) / self.taun(Vm) def derP(self, Vm, p): ''' Evolution of p-gate open-probability :param Vm: membrane potential (mV) :param p: open-probability of p-gate (-) :return: time derivative of p-gate open-probability (s-1) ''' return (self.pinf(Vm) - p) / self.taup(Vm) def derQ(self, Vm, q): ''' Evolution of q-gate open-probability :param Vm: membrane potential (mV) :param q: open-probability of q-gate (-) :return: time derivative of q-gate open-probability (s-1) ''' return (self.qinf(Vm) - q) / self.tauq(Vm) def derR(self, Cai, r): ''' Evolution of Calcium-dependent r-gate open-probability :param Vm: membrane potential (mV) :param s: open-probability of r-gate (-) :return: time derivative of r-gate open-probability (s-1) ''' return (self.rinf(Cai) - r) / self.tau_r def derCai(self, p, q, c, d1, d2, Cai, Vm): ''' Evolution of Calcium concentration in submembrane space. :param p: open-probability of p-gate :param q: open-probability of q-gate :param c: open-probability of c-gate :param d1: open-probability of d1-gate :param d2: open-probability of d2-gate :param Cai: Calcium concentration in submembranal space (M) :param Vm: membrane potential (mV) :return: time derivative of Calcium concentration in submembrane space (M/s) ''' iCaT = self.iCaT(p, q, Vm, Cai) iCaL = self.iCaL(c, d1, d2, Vm, Cai) return - self.iCa_to_Cai_rate * (iCaT + iCaL) - Cai / self.taur_Cai def Caiinf(self, Vm, p, q, c, d1): ''' Find the steady-state intracellular Calcium concentration for a specific membrane potential and voltage-gated channel states. :param Vm: membrane potential (mV) :param p: open-probability of p-gate :param q: open-probability of q-gate :param c: open-probability of c-gate :param d1: open-probability of d1-gate :return: steady-state Calcium concentration in submembrane space (M) ''' if isinstance(Vm, np.ndarray): return np.array([self.Caiinf(Vm[i], p[i], q[i], c[i], d1[i]) for i in range(Vm.size)]) else: return brentq( lambda x: self.derCai(p, q, c, d1, self.d2inf(x), x, Vm), self.Cai0 * 1e-4, self.Cai0 * 1e3, xtol=1e-16 ) def iNa(self, m, h, Vm): ''' Sodium current :param m: open-probability of m-gate (-) :param h: open-probability of h-gate (-) :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.gNabar * m**3 * h * (Vm - self.ENa) def iKd(self, n, Vm): ''' delayed-rectifier Potassium current :param n: open-probability of n-gate (-) :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.gKdbar * n**4 * (Vm - self.EK) def iA(self, a, b, Vm): ''' A-type Potassium current :param a: open-probability of a-gate (-) :param b: open-probability of b-gate (-) :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.gAbar * a**2 * b * (Vm - self.EK) def iCaT(self, p, q, Vm, Cai): ''' low-threshold (T-type) Calcium current :param p: open-probability of p-gate (-) :param q: open-probability of q-gate (-) :param Vm: membrane potential (mV) :param Cai: submembrane Calcium concentration (M) :return: current per unit area (mA/m2) ''' return self.gCaTbar * p**2 * q * (Vm - self.nernst(Z_Ca, Cai, self.Cao, self.T)) def iCaL(self, c, d1, d2, Vm, Cai): ''' high-threshold (L-type) Calcium current :param c: open-probability of c-gate (-) :param d1: open-probability of d1-gate (-) :param d2: open-probability of d2-gate (-) :param Vm: membrane potential (mV) :param Cai: submembrane Calcium concentration (M) :return: current per unit area (mA/m2) ''' return self.gCaLbar * c**2 * d1 * d2 * (Vm - self.nernst(Z_Ca, Cai, self.Cao, self.T)) def iKCa(self, r, Vm): ''' Calcium-activated Potassium current :param r: open-probability of r-gate (-) :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.gKCabar * r**2 * (Vm - self.EK) def iLeak(self, Vm): ''' non-specific leakage current :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.gLeak * (Vm - self.ELeak) def currents(self, Vm, states): ''' Overriding of abstract parent method. ''' a, b, c, d1, d2, m, h, n, p, q, r, Cai = states return { 'iNa': self.iNa(m, h, Vm), 'iKd': self.iKd(n, Vm), 'iA': self.iA(a, b, Vm), 'iCaT': self.iCaT(p, q, Vm, Cai), 'iCaL': self.iCaL(c, d1, d2, Vm, Cai), 'iKCa': self.iKCa(r, Vm), 'iLeak': self.iLeak(Vm) } # mA/m2 def steadyStates(self, Vm): ''' Overriding of abstract parent method. ''' # voltage-gated steady states sstates = { 'a': self.ainf(Vm), 'b': self.binf(Vm), 'c': self.cinf(Vm), 'd1': self.d1inf(Vm), 'm': self.minf(Vm), 'h': self.hinf(Vm), 'n': self.ninf(Vm), 'p': self.pinf(Vm), 'q': self.qinf(Vm) } sstates['Cai'] = self.Caiinf(Vm, sstates['p'], sstates['q'], sstates['c'], sstates['d1']) sstates['d2'] = self.d2inf(sstates['Cai']) sstates['r'] = self.rinf(sstates['Cai']) return sstates def derStates(self, Vm, states): ''' Overriding of abstract parent method. ''' a, b, c, d1, d2, m, h, n, p, q, r, Cai = states return { 'a': self.derA(Vm, a), 'b': self.derB(Vm, b), 'c': self.derC(Vm, c), 'd1': self.derD1(Vm, d1), 'd2': self.derD2(Cai, d2), 'm': self.derM(Vm, m), 'h': self.derH(Vm, h), 'n': self.derN(Vm, n), 'p': self.derP(Vm, p), 'q': self.derQ(Vm, q), 'r': self.derR(Cai, r), 'Cai': self.derCai(p, q, c, d1, d2, Cai, Vm), } def computeEffRates(self, Vm): ''' Overriding of abstract parent method. ''' # Compute average cycle value for rate constants return { 'alphaa': np.mean(self.ainf(Vm) / self.taua(Vm)), 'betaa': np.mean((1 - self.ainf(Vm)) / self.taua(Vm)), 'alphab': np.mean(self.binf(Vm) / self.taub(Vm)), 'betab': np.mean((1 - self.binf(Vm)) / self.taub(Vm)), 'alphac': np.mean(self.cinf(Vm) / self.tauc(Vm)), 'betac': np.mean((1 - self.cinf(Vm)) / self.tauc(Vm)), 'alphad1': np.mean(self.d1inf(Vm) / self.taud1(Vm)), 'betad1': np.mean((1 - self.d1inf(Vm)) / self.taud1(Vm)), 'alpham': np.mean(self.minf(Vm) / self.taum(Vm)), 'betam': np.mean((1 - self.minf(Vm)) / self.taum(Vm)), 'alphah': np.mean(self.hinf(Vm) / self.tauh(Vm)), 'betah': np.mean((1 - self.hinf(Vm)) / self.tauh(Vm)), 'alphan': np.mean(self.ninf(Vm) / self.taun(Vm)), 'betan': np.mean((1 - self.ninf(Vm)) / self.taun(Vm)), 'alphap': np.mean(self.pinf(Vm) / self.taup(Vm)), 'betap': np.mean((1 - self.pinf(Vm)) / self.taup(Vm)), 'alphaq': np.mean(self.qinf(Vm) / self.tauq(Vm)), 'betaq': np.mean((1 - self.qinf(Vm)) / self.tauq(Vm)) } def derEffStates(self, Qm, states, lkp): ''' Overriding of abstract parent method. ''' rates = self.interpEffRates(Qm, lkp) Vmeff = self.interpVmeff(Qm, lkp) a, b, c, d1, d2, m, h, n, p, q, r, Cai = states return { 'a': rates['alphaa'] * (1 - a) - rates['betaa'] * a, 'b': rates['alphab'] * (1 - b) - rates['betab'] * b, 'c': rates['alphac'] * (1 - c) - rates['betac'] * c, 'd1': rates['alphad1'] * (1 - d1) - rates['betad1'] * d1, 'd2': self.derD2(Cai, d2), 'm': rates['alpham'] * (1 - m) - rates['betam'] * m, 'h': rates['alphah'] * (1 - h) - rates['betah'] * h, 'n': rates['alphan'] * (1 - n) - rates['betan'] * n, 'p': rates['alphap'] * (1 - p) - rates['betap'] * p, 'q': rates['alphaq'] * (1 - q) - rates['betaq'] * q, 'r': self.derR(Cai, r), 'Cai': self.derCai(p, q, c, d1, d2, Cai, Vmeff) } def quasiSteadyStates(self, lkp): ''' Overriding of abstract parent method. ''' qsstates = self.qsStates(lkp, ['a', 'b', 'c', 'd1', 'm', 'h', 'n', 'p', 'q']) qsstates['Cai'] = self.Caiinf(lkp['V'], qsstates['p'], qsstates['q'], qsstates['c'], qsstates['d1']) qsstates['d2'] = self.d2inf(qsstates['Cai']) qsstates['r'] = self.rinf(qsstates['Cai']) return qsstates diff --git a/PySONIC/neurons/titrations.log b/PySONIC/neurons/titrations.log index 66329f9..e82cac6 100644 --- a/PySONIC/neurons/titrations.log +++ b/PySONIC/neurons/titrations.log @@ -1,8 +1,10 @@ titrate(, (500000.0, 0.03, 0.05, 100.0, 1.0, 'sonic'), (0, 599999.9999999997), 100.0) 80181.88476562494 titrate(, (500000.0, 0.04, 0.05, 100.0, 1.0, 'sonic'), (0, 599999.9999999997), 100.0) 59902.95410156247 titrate(, (500000.0, 1.0, 0.0, 100.0, 1.0, 'sonic'), (0, 599999.9999999997), 100.0) 19830.322265624985 titrate(, (500000.0, 0.033, 0.05, 100.0, 1.0, 'sonic'), (0, 599999.9999999997), 100.0) 71887.20703124997 titrate(, (500000.0, 0.045, 0.05, 100.0, 1.0, 'sonic'), (0, 599999.9999999997), 100.0) 54930.49621582028 titrate(, (0.045, 0.05, 100.0, 1.0), (0.0, 100.0), 0.1) 7.177734375 titrate(, (500000.0, 0.047, 0.05, 100.0, 1.0, 'sonic'), (0, 599999.9999999997), 100.0) 53283.69140624997 titrate(, (500000.0, 0.046, 0.05, 100.0, 1.0, 'sonic'), (0, 599999.9999999997), 100.0) 54089.35546874997 +titrate(, , [0.1, 0.05, 100.0, 1.0], 0, (0.0, 100.0), 0.1) 5.029296875 +titrate(, , [0.101, 0.05, 100.0, 1.0], 0, (0.0, 100.0), 0.1) 5.029296875 diff --git a/PySONIC/utils.py b/PySONIC/utils.py index 59ed096..bfe0676 100644 --- a/PySONIC/utils.py +++ b/PySONIC/utils.py @@ -1,815 +1,816 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-09-19 22:30:46 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-05-31 14:14:27 +# @Last Modified time: 2019-06-01 16:30:07 ''' Definition of generic utility functions used in other modules ''' import csv from functools import wraps import operator import time import os import math import pickle from tqdm import tqdm import logging import tkinter as tk from tkinter import filedialog import numpy as np import colorlog from scipy.interpolate import interp1d # Package logger my_log_formatter = colorlog.ColoredFormatter( '%(log_color)s %(asctime)s %(message)s', datefmt='%d/%m/%Y %H:%M:%S:', reset=True, log_colors={ 'DEBUG': 'green', 'INFO': 'white', 'WARNING': 'yellow', 'ERROR': 'red', 'CRITICAL': 'red,bg_white', }, style='%' ) def setHandler(logger, handler): for h in logger.handlers: logger.removeHandler(h) logger.addHandler(handler) return logger def setLogger(name, formatter): handler = colorlog.StreamHandler() handler.setFormatter(formatter) logger = colorlog.getLogger(name) logger.addHandler(handler) return logger class TqdmHandler(logging.StreamHandler): def __init__(self, formatter): logging.StreamHandler.__init__(self) self.setFormatter(formatter) def emit(self, record): msg = self.format(record) tqdm.write(msg) logger = setLogger('PySONIC', my_log_formatter) - titrations_logfile = os.path.join(os.path.split(__file__)[0], 'neurons', 'titrations.log') # Figure naming conventions def figtitle(meta): ''' Return appropriate title based on simulation metadata. ''' if 'Cm0' in meta: return '{:.0f}nm radius BLS structure: MECH-STIM {:.0f}kHz, {:.2f}kPa, {:.1f}nC/cm2'.format( meta['a'] * 1e9, meta['Fdrive'] * 1e-3, meta['Adrive'] * 1e-3, meta['Qm'] * 1e5) else: if meta['DC'] < 1: wavetype = 'PW' suffix = ', {:.2f}Hz PRF, {:.0f}% DC'.format(meta['PRF'], meta['DC'] * 1e2) else: wavetype = 'CW' suffix = '' if 'Astim' in meta: return '{} neuron: {} E-STIM {:.2f}mA/m2, {:.0f}ms{}'.format( meta['neuron'], wavetype, meta['Astim'], meta['tstim'] * 1e3, suffix) else: return '{} neuron ({:.1f}nm): {} A-STIM {:.0f}kHz {:.2f}kPa, {:.0f}ms{} - {} model'.format( meta['neuron'], meta['a'] * 1e9, wavetype, meta['Fdrive'] * 1e-3, meta['Adrive'] * 1e-3, meta['tstim'] * 1e3, suffix, meta['method']) # SI units prefixes si_prefixes = { 'y': 1e-24, # yocto 'z': 1e-21, # zepto 'a': 1e-18, # atto 'f': 1e-15, # femto 'p': 1e-12, # pico 'n': 1e-9, # nano 'u': 1e-6, # micro 'm': 1e-3, # mili '': 1e0, # None 'k': 1e3, # kilo 'M': 1e6, # mega 'G': 1e9, # giga 'T': 1e12, # tera 'P': 1e15, # peta 'E': 1e18, # exa 'Z': 1e21, # zetta 'Y': 1e24, # yotta } def loadData(fpath, frequency=1): ''' Load dataframe and metadata dictionary from pickle file. ''' logger.info('Loading data from "%s"', os.path.basename(fpath)) with open(fpath, 'rb') as fh: frame = pickle.load(fh) df = frame['data'].iloc[::frequency] meta = frame['meta'] return df, meta def si_format(x, precision=0, space=' '): ''' Format a float according to the SI unit system, with the appropriate prefix letter. ''' if isinstance(x, float) or isinstance(x, int) or isinstance(x, np.float) or\ isinstance(x, np.int32) or isinstance(x, np.int64): if x == 0: factor = 1e0 prefix = '' else: sorted_si_prefixes = sorted(si_prefixes.items(), key=operator.itemgetter(1)) vals = [tmp[1] for tmp in sorted_si_prefixes] # vals = list(si_prefixes.values()) ix = np.searchsorted(vals, np.abs(x)) - 1 if np.abs(x) == vals[ix + 1]: ix += 1 factor = vals[ix] prefix = sorted_si_prefixes[ix][0] # prefix = list(si_prefixes.keys())[ix] return '{{:.{}f}}{}{}'.format(precision, space, prefix).format(x / factor) elif isinstance(x, list) or isinstance(x, tuple): return [si_format(item, precision, space) for item in x] elif isinstance(x, np.ndarray) and x.ndim == 1: return [si_format(float(item), precision, space) for item in x] else: print(type(x)) def pow10_format(number, precision=2): ''' Format a number in power of 10 notation. ''' ret_string = '{0:.{1:d}e}'.format(number, precision) a, b = ret_string.split("e") a = float(a) b = int(b) return '{}10^{{{}}}'.format('{} * '.format(a) if a != 1. else '', b) def rmse(x1, x2): ''' Compute the root mean square error between two 1D arrays ''' return np.sqrt(((x1 - x2) ** 2).mean()) def rsquared(x1, x2): ''' compute the R-squared coefficient between two 1D arrays ''' residuals = x1 - x2 ss_res = np.sum(residuals**2) ss_tot = np.sum((x1 - np.mean(x1))**2) return 1 - (ss_res / ss_tot) def Pressure2Intensity(p, rho=1075.0, c=1515.0): ''' Return the spatial peak, pulse average acoustic intensity (ISPPA) associated with the specified pressure amplitude. :param p: pressure amplitude (Pa) :param rho: medium density (kg/m3) :param c: speed of sound in medium (m/s) :return: spatial peak, pulse average acoustic intensity (W/m2) ''' return p**2 / (2 * rho * c) def Intensity2Pressure(I, rho=1075.0, c=1515.0): ''' Return the pressure amplitude associated with the specified spatial peak, pulse average acoustic intensity (ISPPA). :param I: spatial peak, pulse average acoustic intensity (W/m2) :param rho: medium density (kg/m3) :param c: speed of sound in medium (m/s) :return: pressure amplitude (Pa) ''' return np.sqrt(2 * rho * c * I) def OpenFilesDialog(filetype, dirname=''): ''' Open a FileOpenDialogBox to select one or multiple file. The default directory and file type are given. :param dirname: default directory :param filetype: default file type :return: tuple of full paths to the chosen filenames ''' root = tk.Tk() root.withdraw() filenames = filedialog.askopenfilenames(filetypes=[(filetype + " files", '.' + filetype)], initialdir=dirname) if filenames: par_dir = os.path.abspath(os.path.join(filenames[0], os.pardir)) else: par_dir = None return (filenames, par_dir) def selectDirDialog(): ''' Open a dialog box to select a directory. :return: full path to selected directory ''' root = tk.Tk() root.withdraw() return filedialog.askdirectory() def SaveFileDialog(filename, dirname=None, ext=None): ''' Open a dialog box to save file. :param filename: filename :param dirname: initial directory :param ext: default extension :return: full path to the chosen filename ''' root = tk.Tk() root.withdraw() filename_out = filedialog.asksaveasfilename( defaultextension=ext, initialdir=dirname, initialfile=filename) return filename_out def downsample(t_dense, y, nsparse): ''' Decimate periodic signals to a specified number of samples.''' if(y.ndim) > 1: nsignals = y.shape[0] else: nsignals = 1 y = np.array([y]) # determine time step and period of input signal T = t_dense[-1] - t_dense[0] dt_dense = t_dense[1] - t_dense[0] # resample time vector linearly t_ds = np.linspace(t_dense[0], t_dense[-1], nsparse) # create MAV window nmav = int(0.03 * T / dt_dense) if nmav % 2 == 0: nmav += 1 mav = np.ones(nmav) / nmav # determine signals padding npad = int((nmav - 1) / 2) # determine indexes of sampling on convolved signals ids = np.round(np.linspace(0, t_dense.size - 1, nsparse)).astype(int) y_ds = np.empty((nsignals, nsparse)) # loop through signals for i in range(nsignals): # pad, convolve and resample pad_left = y[i, -(npad + 2):-2] pad_right = y[i, 1:npad + 1] y_ext = np.concatenate((pad_left, y[i, :], pad_right), axis=0) y_mav = np.convolve(y_ext, mav, mode='valid') y_ds[i, :] = y_mav[ids] if nsignals == 1: y_ds = y_ds[0, :] return (t_ds, y_ds) def rescale(x, lb=None, ub=None, lb_new=0, ub_new=1): ''' Rescale a value to a specific interval by linear transformation. ''' if lb is None: lb = x.min() if ub is None: ub = x.max() xnorm = (x - lb) / (ub - lb) return xnorm * (ub_new - lb_new) + lb_new def getNeuronLookupsFile(mechname, a=None, Fdrive=None, Adrive=None, fs=False): fpath = os.path.join( os.path.split(__file__)[0], 'neurons', '{}_lookups'.format(mechname) ) if a is not None: fpath += '_{:.0f}nm'.format(a * 1e9) if Fdrive is not None: fpath += '_{:.0f}kHz'.format(Fdrive * 1e-3) if Adrive is not None: fpath += '_{:.0f}kPa'.format(Adrive * 1e-3) if fs is True: fpath += '_fs' return '{}.pkl'.format(fpath) def getLookups4D(mechname): ''' Retrieve 4D lookup tables and reference vectors for a given membrane mechanism. :param mechname: name of membrane density mechanism :return: 4-tuple with 1D numpy arrays of reference input vectors (charge density and one other variable), a dictionary of associated 2D lookup numpy arrays, and a dictionary with information about the other variable. ''' # Check lookup file existence lookup_path = getNeuronLookupsFile(mechname) if not os.path.isfile(lookup_path): raise FileNotFoundError('Missing lookup file: "{}"'.format(lookup_path)) # Load lookups dictionary # logger.debug('Loading %s lookup table', mechname) with open(lookup_path, 'rb') as fh: df = pickle.load(fh) inputs = df['input'] lookups4D = df['lookup'] # Retrieve 1D inputs from lookups dictionary aref = inputs['a'] Fref = inputs['f'] Aref = inputs['A'] Qref = inputs['Q'] return aref, Fref, Aref, Qref, lookups4D def getLookupsOff(mechname): ''' Retrieve appropriate US-OFF lookup tables and reference vectors for a given membrane mechanism. :param mechname: name of membrane density mechanism :return: 2-tuple with 1D numpy array of reference charge density and dictionary of associated 1D lookup numpy arrays. ''' # Get 4D lookups and input vectors aref, Fref, Aref, Qref, lookups4D = getLookups4D(mechname) # Perform 2D projection in appropriate dimensions logger.debug('Interpolating lookups at A = 0') lookups_off = {key: y4D[0, 0, 0, :] for key, y4D in lookups4D.items()} return Qref, lookups_off def getLookups2D(mechname, a=None, Fdrive=None, Adrive=None): ''' Retrieve appropriate 2D lookup tables and reference vectors for a given membrane mechanism, projected at a specific combination of sonophore radius, US frequency and/or acoustic pressure amplitude. :param mechname: name of membrane density mechanism :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param Adrive: Acoustic peak pressure amplitude (Hz) :return: 4-tuple with 1D numpy arrays of reference input vectors (charge density and one other variable), a dictionary of associated 2D lookup numpy arrays, and a dictionary with information about the other variable. ''' # Get 4D lookups and input vectors aref, Fref, Aref, Qref, lookups4D = getLookups4D(mechname) # Check that inputs are within lookup range if a is not None: a = isWithin('radius', a, (aref.min(), aref.max())) if Fdrive is not None: Fdrive = isWithin('frequency', Fdrive, (Fref.min(), Fref.max())) if Adrive is not None: Adrive = isWithin('amplitude', Adrive, (Aref.min(), Aref.max())) # Determine projection dimensions based on inputs var_a = {'name': 'a', 'label': 'sonophore radius', 'val': a, 'unit': 'm', 'factor': 1e9, 'ref': aref, 'axis': 0} var_Fdrive = {'name': 'f', 'label': 'frequency', 'val': Fdrive, 'unit': 'Hz', 'factor': 1e-3, 'ref': Fref, 'axis': 1} var_Adrive = {'name': 'A', 'label': 'amplitude', 'val': Adrive, 'unit': 'Pa', 'factor': 1e-3, 'ref': Aref, 'axis': 2} if not isinstance(Adrive, float): var1 = var_a var2 = var_Fdrive var3 = var_Adrive elif not isinstance(Fdrive, float): var1 = var_a var2 = var_Adrive var3 = var_Fdrive elif not isinstance(a, float): var1 = var_Fdrive var2 = var_Adrive var3 = var_a # Perform 2D projection in appropriate dimensions # logger.debug('Interpolating lookups at (%s = %s%s, %s = %s%s)', # var1['name'], si_format(var1['val'], space=' '), var1['unit'], # var2['name'], si_format(var2['val'], space=' '), var2['unit']) lookups3D = {key: interp1d(var1['ref'], y4D, axis=var1['axis'])(var1['val']) for key, y4D in lookups4D.items()} if var2['axis'] > var1['axis']: var2['axis'] -= 1 lookups2D = {key: interp1d(var2['ref'], y3D, axis=var2['axis'])(var2['val']) for key, y3D in lookups3D.items()} if var3['val'] is not None: logger.debug('Interpolating lookups at %d new %s values between %s%s and %s%s', len(var3['val']), var3['name'], si_format(min(var3['val']), space=' '), var3['unit'], si_format(max(var3['val']), space=' '), var3['unit']) lookups2D = {key: interp1d(var3['ref'], y2D, axis=0)(var3['val']) for key, y2D in lookups2D.items()} var3['ref'] = np.array(var3['val']) return var3['ref'], Qref, lookups2D, var3 def getLookups2Dfs(mechname, a, Fdrive, fs): # Check lookup file existence lookup_path = getNeuronLookupsFile(mechname, a=a, Fdrive=Fdrive, fs=True) if not os.path.isfile(lookup_path): raise FileNotFoundError('Missing lookup file: "{}"'.format(lookup_path)) # Load lookups dictionary logger.debug('Loading %s lookup table with fs = %.0f%%', mechname, fs * 1e2) with open(lookup_path, 'rb') as fh: df = pickle.load(fh) inputs = df['input'] lookups3D = df['lookup'] # Retrieve 1D inputs from lookups dictionary fsref = inputs['fs'] Aref = inputs['A'] Qref = inputs['Q'] # Check that fs is within lookup range fs = isWithin('coverage', fs, (fsref.min(), fsref.max())) # Perform projection at fs logger.debug('Interpolating lookups at fs = %s%%', fs * 1e2) lookups2D = {key: interp1d(fsref, y3D, axis=2)(fs) for key, y3D in lookups3D.items()} return Aref, Qref, lookups2D def getLookupsDCavg(mechname, a, Fdrive, amps=None, charges=None, DCs=1.0): ''' Get the DC-averaged lookups of a specific neuron for a combination of US amplitudes, charge densities and duty cycles, at a specific US frequency. :param mechname: name of membrane density mechanism :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param amps: US amplitudes (Pa) :param charges: membrane charge densities (C/m2) :param DCs: duty cycle value(s) :return: 4-tuple with reference values of US amplitude and charge density, as well as interpolated Vmeff and QSS gating variables ''' # Get lookups for specific (a, f, A) combination Aref, Qref, lookups2D, _ = getLookups2D(mechname, a=a, Fdrive=Fdrive) if 'ng' in lookups2D: lookups2D.pop('ng') # Derive inputs from lookups reference if not provided if amps is None: amps = Aref if charges is None: charges = Qref # Transform inputs into arrays if single value provided if isinstance(amps, float): amps = np.array([amps]) if isinstance(charges, float): charges = np.array([charges]) if isinstance(DCs, float): DCs = np.array([DCs]) nA, nQ, nDC = amps.size, charges.size, DCs.size cs = {True: 's', False: ''} # logger.debug('%u amplitude%s, %u charge%s, %u DC%s', # nA, cs[nA > 1], nQ, cs[nQ > 1], nDC, cs[nDC > 1]) # Re-interpolate lookups at input charges lookups2D = {key: interp1d(Qref, y2D, axis=1)(charges) for key, y2D in lookups2D.items()} # Interpolate US-ON (for each input amplitude) and US-OFF (A = 0) lookups amps = isWithin('amplitude', amps, (Aref.min(), Aref.max())) lookups_on = {key: interp1d(Aref, y2D, axis=0)(amps) for key, y2D in lookups2D.items()} lookups_off = {key: interp1d(Aref, y2D, axis=0)(0.0) for key, y2D in lookups2D.items()} # Compute DC-averaged lookups lookups_DCavg = {} for key in lookups2D.keys(): x_on, x_off = lookups_on[key], lookups_off[key] x_avg = np.empty((nA, nQ, nDC)) for iA, Adrive in enumerate(amps): for iDC, DC in enumerate(DCs): x_avg[iA, :, iDC] = x_on[iA, :] * DC + x_off * (1 - DC) lookups_DCavg[key] = x_avg return amps, charges, lookups_DCavg def isWithin(name, val, bounds, rel_tol=1e-9): ''' Check if a floating point number is within an interval. If the value falls outside the interval, an error is raised. If the value falls just outside the interval due to rounding errors, the associated interval bound is returned. :param val: float value :param bounds: interval bounds (float tuple) :return: original or corrected value ''' if isinstance(val, list) or isinstance(val, np.ndarray) or isinstance(val, tuple): return [isWithin(name, v, bounds, rel_tol) for v in val] if val >= bounds[0] and val <= bounds[1]: return val elif val < bounds[0] and math.isclose(val, bounds[0], rel_tol=rel_tol): logger.warning('Rounding %s value (%s) to interval lower bound (%s)', name, val, bounds[0]) return bounds[0] elif val > bounds[1] and math.isclose(val, bounds[1], rel_tol=rel_tol): logger.warning('Rounding %s value (%s) to interval upper bound (%s)', name, val, bounds[1]) return bounds[1] else: raise ValueError('{} value ({}) out of [{}, {}] interval'.format( name, val, bounds[0], bounds[1])) def getLookupsCompTime(mechname): # Check lookup file existence lookup_path = getNeuronLookupsFile(mechname) if not os.path.isfile(lookup_path): raise FileNotFoundError('Missing lookup file: "{}"'.format(lookup_path)) # Load lookups dictionary logger.debug('Loading comp times') with open(lookup_path, 'rb') as fh: df = pickle.load(fh) tcomps4D = df['tcomp'] return np.sum(tcomps4D) def getLowIntensitiesSTN(): ''' Return an array of acoustic intensities (W/m2) used to study the STN neuron in Tarnaud, T., Joseph, W., Martens, L., and Tanghe, E. (2018). Computational Modeling of Ultrasonic Subthalamic Nucleus Stimulation. IEEE Trans Biomed Eng. ''' return np.hstack(( np.arange(10, 101, 10), np.arange(101, 131, 1), np.array([140]) )) # W/m2 def getDistribution(xmin, xmax, nx, scale='lin'): if scale == 'log': xmin, xmax = np.log10(xmin), np.log10(xmax) return {'lin': np.linspace, 'log': np.logspace}[scale](xmin, xmax, nx) def getDistFromList(xlist): if not isinstance(xlist, list): raise TypeError('Input must be a list') if len(xlist) != 4: raise ValueError('List must contain exactly 4 arguments ([type, min, max, n])') scale = xlist[0] if scale not in ('log', 'lin'): raise ValueError('Unknown distribution type (must be "lin" or "log")') xmin, xmax = [float(x) for x in xlist[1:-1]] if xmin >= xmax: raise ValueError('Specified minimum higher or equal than specified maximum') nx = int(xlist[-1]) if nx < 2: raise ValueError('Specified number must be at least 2') return getDistribution(xmin, xmax, nx, scale=scale) def parseUSAmps(args, defaults): # Check if several mutually exclusive arguments were provided Aparams = ['Arange', 'Irange', 'amp', 'intensity'] if sum([x in args for x in Aparams]) > 1: raise ValueError('You must provide only one of the following arguments: {}'.format( ', '.join(Aparams))) if 'Arange' in args: return getDistFromList(args['Arange']) * 1e3 # Pa elif 'Irange' in args: return Intensity2Pressure(getDistFromList(args['Irange']) * 1e4) # Pa elif 'amp' in args: return np.array(args['amp']) * 1e3 # Pa elif 'intensity' in args: return Intensity2Pressure(np.array(args['intensity']) * 1e4) # Pa return np.array(defaults['amp']) * 1e3 # Pa def parseElecAmps(args, defaults): # Check if several mutually exclusive arguments were provided Aparams = ['Arange', 'amp'] if sum([x in args for x in Aparams]) > 1: raise ValueError('You must provide only one of the following arguments: {}'.format( ', '.join(Aparams))) if 'Arange' in args: return getDistFromList(args['Arange']) # mA/m2 elif 'amp' in args: return np.array(args['amp']) # mA/m2 return np.array(defaults['amp']) # mA/m2 def getIndex(container, value): ''' Return the index of a float / string value in a list / array :param container: list / 1D-array of elements :param value: value to search for :return: index of value (if found) ''' if isinstance(value, float): container = np.array(container) imatches = np.where(np.isclose(container, value, rtol=1e-9, atol=1e-16))[0] if len(imatches) == 0: raise ValueError('{} not found in {}'.format(value, container)) return imatches[0] elif isinstance(value, str): return container.index(value) def debug(func): ''' Print the function signature and return value. ''' @wraps(func) def wrapper_debug(*args, **kwargs): args_repr = [repr(a) for a in args] kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()] signature = '{}({})'.format(func.__name__, ', '.join(args_repr + kwargs_repr)) print('Calling {}'.format(signature)) value = func(*args, **kwargs) print(f"{func.__name__!r} returned {value!r}") return value return wrapper_debug def timer(func): ''' Monitor and return the runtime of the decorated function. ''' @wraps(func) def wrapper(*args, **kwargs): start_time = time.perf_counter() value = func(*args, **kwargs) end_time = time.perf_counter() run_time = end_time - start_time return value, run_time return wrapper def cache(fpath, delimiter='\t', out_type=float): ''' Add an extra IO memoization functionality to a function using file caching, to avoid repetitions of tedious computations with identical inputs. ''' def wrapper_with_args(func): @wraps(func) def wrapper(*args, **kwargs): # If function has history -> do not log if 'history' in kwargs: return func(*args, **kwargs) # Translate function arguments into string signature args_repr = [repr(a) for a in args] kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()] signature = '{}({})'.format(func.__name__, ', '.join(args_repr + kwargs_repr)) # If entry present in log, return corresponding output if os.path.isfile(fpath): with open(fpath, 'r', newline='') as f: reader = csv.reader(f, delimiter=delimiter) for row in reader: if row[0] == signature: logger.info('entry found in "{}"'.format(os.path.basename(fpath))) return out_type(row[1]) # Otherwise, compute output and log it into file before returning out = func(*args, **kwargs) with open(fpath, 'a', newline='') as csvfile: writer = csv.writer(csvfile, delimiter=delimiter) writer.writerow([signature, str(out)]) return out return wrapper return wrapper_with_args -@cache(titrations_logfile) -def titrate(xfunc, xargs, xbounds, dx_thr, history=None): +# @cache(titrations_logfile) +def binarySearch(bool_func, args, ix, xbounds, dx_thr, history=None): ''' Use a binary search to determine the threshold satisfying a given condition - within a specific search interval. + within a continuous search interval. - :param xfunc: boolean function returning whether condition is satisfied - :param xargs: list of function arguments other than refined value + :param bool_func: boolean function returning whether condition is satisfied + :param args: list of function arguments other than refined value :param xbounds: search interval for threshold (progressively refined) :param dx_thr: accuracy criterion for threshold :return: excitation threshold ''' # Assign empty history if first function call if history is None: history = [] # Compute function output at interval mid-point x = (xbounds[0] + xbounds[1]) / 2 - history.append(xfunc(x, *xargs)) + sim_args = args[:] + sim_args.insert(ix, x) + history.append(bool_func(sim_args)) # If titration interval is small enough conv = False if (xbounds[1] - xbounds[0]) <= dx_thr: logger.debug('titration interval smaller than defined threshold') # If both conditions have been encountered during titration process, # we're going towards convergence if (0 in history and 1 in history): logger.debug('converging around threshold') # If current value satisfies condition, convergence is achieved # -> return threshold if history[-1]: logger.debug('currently satisfying condition -> convergence') return x # If only one condition has been encountered during titration process, # then no titration is impossible within the defined interval -> return NaN else: logger.warning('titration does not converge within this interval') return np.nan # Return threshold if convergence is reached, otherwise refine interval and iterate if conv: return x else: if x > 0.: xbounds = (xbounds[0], x) if history[-1] else (x, xbounds[1]) else: xbounds = (x, xbounds[1]) if history[-1] else (xbounds[0], x) - return titrate(xfunc, xargs, xbounds, dx_thr, history=history) + return binarySearch(bool_func, args, ix, xbounds, dx_thr, history=history) def resolveDependencies(deps, join_items=True): ''' Solve a dictionary of dependencies. :param arg: dependency dictionary in which the values are the dependencies of their respective keys. :param join_items: boolean specifying whether or not to serialize output :return: list of inter-dependent elements in resolved order ''' # Transform input dictionary of lists into dictionary of sets, # while removing circular (auto) dependencies deps = dict((k, set([x for x in deps[k] if x != k])) for k in deps) # Initialize empty list of resolved dependencies resolved_deps = [] # Iterate while dependencies not entirely resolved while deps: # Extract latest items without dependencies (values that are not in keys # and keys without value) into a set nd_items = set(i for v in deps.values() for i in v) - set(deps.keys()) nd_items.update(k for k, v in deps.items() if not v) # Append new set of non-dependent items to output list resolved_deps.append(nd_items) # Remove those items from remaining dependencies in input dictionary deps = dict(((k, v - nd_items) for k, v in deps.items() if v)) # If specified, merge list of sets into a unique list (while preserving order) if join_items: tmp = [] for item in resolved_deps: tmp += list(item) resolved_deps = tmp return resolved_deps def plural(n): if n < 0: raise ValueError('Cannot format negative integer (n = {})'.format(n)) if n == 0: return '' else: return 's' diff --git a/paper figures/utils.py b/paper figures/utils.py index fbacb66..b166b07 100644 --- a/paper figures/utils.py +++ b/paper figures/utils.py @@ -1,119 +1,119 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2018-10-01 20:45:29 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-05-31 17:03:40 +# @Last Modified time: 2019-06-01 16:36:29 import os import numpy as np import pandas as pd from PySONIC.utils import * from PySONIC.core import NeuronalBilayerSonophore from PySONIC.neurons import * from PySONIC.postpro import computeSpikingMetrics def getCWtitrations_vs_Fdrive(neurons, a, freqs, tstim, toffset, fpath): fkey = 'Fdrive (kHz)' freqs = np.array(freqs) if os.path.isfile(fpath): df = pd.read_csv(fpath, sep=',', index_col=fkey) else: df = pd.DataFrame(index=freqs * 1e-3) for neuron in neurons: - if neuron not in df and neuron.isTitratable(): + if neuron not in df: neuronobj = getNeuronsDict()[neuron]() nbls = NeuronalBilayerSonophore(a, neuronobj) for i, Fdrive in enumerate(freqs): logger.info('Running CW titration for %s neuron @ %sHz', neuron, si_format(Fdrive)) Athr = nbls.titrate(Fdrive, tstim, toffset) # Pa df.loc[Fdrive * 1e-3, neuron] = np.ceil(Athr * 1e-2) / 10 df.sort_index(inplace=True) df.to_csv(fpath, sep=',', index_label=fkey) return df def getCWtitrations_vs_radius(neurons, radii, Fdrive, tstim, toffset, fpath): akey = 'radius (nm)' radii = np.array(radii) if os.path.isfile(fpath): df = pd.read_csv(fpath, sep=',', index_col=akey) else: df = pd.DataFrame(index=radii * 1e9) for neuron in neurons: - if neuron not in df and neuron.isTitratable(): + if neuron not in df: neuronobj = getNeuronsDict()[neuron]() for a in radii: nbls = NeuronalBilayerSonophore(a, neuronobj) logger.info( 'Running CW titration for %s neuron @ %sHz (%.2f nm sonophore radius)', neuron, si_format(Fdrive), a * 1e9) Athr = nbls.titrate(Fdrive, tstim, toffset) # Pa df.loc[a * 1e9, neuron] = np.ceil(Athr * 1e-2) / 10 df.sort_index(inplace=True) df.to_csv(fpath, sep=',', index_label=akey) return df def getSims(outdir, neuron, a, queue): fpaths = [] updated_queue = [] neuronobj = getNeuronsDict()[neuron]() nbls = NeuronalBilayerSonophore(a, neuronobj) for i, item in enumerate(queue): Fdrive, tstim, toffset, PRF, DC, Adrive, method = item fcode = nbls.filecode(Fdrive, Adrive, tstim, toffset, PRF, DC, method) fpath = os.path.join(outdir, '{}.pkl'.format(fcode)) if not os.path.isfile(fpath): print(fpath, 'does not exist') item.insert(0, outdir) updated_queue.append(item) fpaths.append(fpath) if len(updated_queue) > 0: print(updated_queue) # neuron = getNeuronsDict()[neuron]() # nbls = NeuronalBilayerSonophore(a, neuron) # runBatch(nbls.runAndSave, updated_queue, extra_params=[outdir], mpi=True) return fpaths def getSpikingMetrics(outdir, neuron, xvar, xkey, data_fpaths, metrics_fpaths): metrics = {} for stype in data_fpaths.keys(): if os.path.isfile(metrics_fpaths[stype]): logger.info('loading spiking metrics from file: "%s"', metrics_fpaths[stype]) metrics[stype] = pd.read_csv(metrics_fpaths[stype], sep=',') else: logger.warning('computing %s spiking metrics vs. %s for %s neuron', stype, xkey, neuron) metrics[stype] = computeSpikingMetrics(data_fpaths[stype]) metrics[stype][xkey] = pd.Series(xvar, index=metrics[stype].index) metrics[stype].to_csv(metrics_fpaths[stype], sep=',', index=False) return metrics def extractCompTimes(filenames): ''' Extract computation times from a list of simulation files. ''' tcomps = np.empty(len(filenames)) for i, fn in enumerate(filenames): logger.info('Loading data from "%s"', fn) with open(fn, 'rb') as fh: frame = pickle.load(fh) meta = frame['meta'] tcomps[i] = meta['tcomp'] return tcomps def getCompTimesQuant(outdir, neuron, xvars, xkey, data_fpaths, comptimes_fpath): if os.path.isfile(comptimes_fpath): logger.info('reading computation times from file: "%s"', comptimes_fpath) comptimes = pd.read_csv(comptimes_fpath, sep=',', index_col=xkey) else: logger.warning('extracting computation times for %s neuron', neuron) comptimes = pd.DataFrame(index=xvars) for stype in data_fpaths.keys(): for i, xvar in enumerate(xvars): comptimes.loc[xvar, stype] = extractCompTimes([data_fpaths[stype][i]]) comptimes.to_csv(comptimes_fpath, sep=',', index_label=xkey) return comptimes