diff --git a/PySONIC/constants.py b/PySONIC/constants.py index 599bd58..9db9734 100644 --- a/PySONIC/constants.py +++ b/PySONIC/constants.py @@ -1,55 +1,55 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-11-04 13:23:31 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-04-05 17:14:29 +# @Last Modified time: 2019-04-29 11:48:22 ''' Algorithmic constants used in the package. ''' # Biophysical constants FARADAY = 9.64853e4 # Faraday constant (C/mol) Rg = 8.31342 # Universal gas constant (Pa.m^3.mol^-1.K^-1 or J.mol^-1.K^-1) Z_Ca = 2 # Calcium valence Z_Na = 1 # Sodium valence Z_K = 1 # Potassium valence CELSIUS_2_KELVIN = 273.15 # Celsius to Kelvin conversion constant # Fitting and pre-processing LJFIT_PM_MAX = 1e8 # intermolecular pressure at the deflection lower bound for LJ fitting (Pa) PNET_EQ_MAX = 1e-1 # error threshold for net pressure at computed equilibrium position (Pa) PMAVG_STD_ERR_MAX = 3000 # error threshold in nonlinear fit of molecular pressure (Pa) # Mechanical simulations Z_ERR_MAX = 1e-11 # periodic convergence threshold for deflection (m) NG_ERR_MAX = 1e-24 # periodic convergence threshold for gas content (mol) NCYCLES_MAX = 10 # max number of acoustic cycles in mechanical simulations CHARGE_RANGE = (-200e-5, 150e-5) # physiological charge range constraining the membrane (C/m2) # E-STIM simulations DT_ESTIM = 1e-5 # A-STIM simulations SOLVER_NSTEPS = 1000 # maximum number of steps allowed during one call to the LSODA/DOP853 solvers CLASSIC_TARGET_DT = 1e-8 # target temporal resolution for output arrays of classic simulations NPC_FULL = 1000 # nb of samples per acoustic period in full system NPC_HH = 40 # nb of samples per acoustic period in HH system DQ_UPDATE = 1e-5 # charge evolution threshold between two hybrid integrations (C/m2) DT_UPDATE = 5e-4 # time interval between two hybrid integrations (s) DT_EFF = 5e-5 # time step for effective integration (s) MIN_SAMPLES_PER_PULSE_INT = 1 # minimal number of time points per pulse interval (TON of TOFF) # Spike detection SPIKE_MIN_QAMP = 5e-5 # threshold amplitude for spike detection on charge signal (C/m2) SPIKE_MIN_QPROM = 20e-5 # threshold prominence for spike detection on charge signal (C/m2) SPIKE_MIN_VAMP = 10.0 # threshold amplitude for spike detection on potential signal (mV) SPIKE_MIN_VPROM = 20.0 # threshold prominence for spike detection on potential signal (mV) SPIKE_MIN_DT = 5e-4 # minimal time interval for spike detection on charge signal (s) MIN_NSPIKES_SPECTRUM = 3 # minimum number of spikes to compute firing rate spectrum # Titrations -TITRATION_T_OFFSET = 50e-3 # offset period for titration procedures (s) +TITRATION_T_OFFSET = 200e-3 # offset period for titration procedures (s) TITRATION_ASTIM_DA_MAX = 1e2 # acoustic pressure search range threshold for titration (Pa) TITRATION_ESTIM_A_MAX = 50.0 # initial current density upper bound for titration (mA/m2) TITRATION_ESTIM_DA_MAX = 0.1 # current density search range threshold for titration (mA/m2) diff --git a/PySONIC/core/pneuron.py b/PySONIC/core/pneuron.py index 3fa398a..07be030 100644 --- a/PySONIC/core/pneuron.py +++ b/PySONIC/core/pneuron.py @@ -1,624 +1,624 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-08-03 11:53:04 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-04-08 14:07:00 +# @Last Modified time: 2019-04-29 12:04:37 import os import time import pickle import abc import inspect import re import numpy as np from scipy.integrate import odeint import pandas as pd from ..postpro import findPeaks from ..constants import * from ..utils import si_format, logger, ESTIM_filecode, titrate from ..batches import xlslog class PointNeuron(metaclass=abc.ABCMeta): ''' Abstract class defining the common API (i.e. mandatory attributes and methods) of all subclasses implementing the channels mechanisms of specific point neurons. The mandatory attributes are: - **name**: a string defining the name of the mechanism. - **Cm0**: a float defining the membrane resting capacitance (in F/m2) - **Vm0**: a float defining the membrane resting potential (in mV) - **states**: a list of strings defining the names of the different state probabilities governing the channels behaviour (i.e. the differential HH variables). - **rates**: a list of strings defining the names of the different coefficients to be used in effective simulations. The mandatory methods are: - **iNet**: compute the net ionic current density (in mA/m2) across the membrane, given a specific membrane potential (in mV) and channel states. - **steadyStates**: compute the channels steady-state values for a specific membrane potential value (in mV). - **derStates**: compute the derivatives of channel states, given a specific membrane potential (in mV) and channel states. This method must return a list of derivatives ordered identically as in the steadyStates output. - **getEffRates**: get the effective rate constants of ion channels to be used in effective simulations. This method must return an array of effective rates ordered identically as in the rates attribute. - **derStatesEff**: compute the effective derivatives of channel states, based on 1-dimensional linear interpolators of "effective" coefficients. This method must return a list of derivatives ordered identically as in the steadyStates output. - **steadyStates**: compute the steady-state values of all internal states for a given membrane potential. ''' tscale = 'ms' # relevant temporal scale of the model defvar = 'V' # default plot variable def __repr__(self): return self.__class__.__name__ def pprint(self): return '{} neuron'.format(self.__class__.__name__) @property @abc.abstractmethod def name(self): return 'Should never reach here' @property @abc.abstractmethod def Cm0(self): return 'Should never reach here' @property @abc.abstractmethod def Vm0(self): return 'Should never reach here' @abc.abstractmethod def currents(self, Vm, states): ''' Compute all ionic currents per unit area. :param Vm: membrane potential (mV) :states: state probabilities of the ion channels :return: dictionary of ionic currents per unit area (mA/m2) ''' def iNet(self, Vm, states): ''' net membrane current :param Vm: membrane potential (mV) :states: states of ion channels gating and related variables :return: current per unit area (mA/m2) ''' return sum(self.currents(Vm, states).values()) def currentToConcentrationRate(self, z_ion, depth): ''' Compute the conversion factor from a specific ionic current (in mA/m2) into a variation rate of submembrane ion concentration (in M/s). :param: z_ion: ion valence :param depth: submembrane depth (m) :return: conversion factor (Mmol.m-1.C-1) ''' return 1e-6 / (z_ion * depth * FARADAY) def nernst(self, z_ion, Cion_in, Cion_out, T): ''' Nernst potential of a specific ion given its intra and extracellular concentrations. :param z_ion: ion valence :param Cion_in: intracellular ion concentration :param Cion_out: extracellular ion concentration :param T: temperature (K) :return: ion Nernst potential (mV) ''' return (Rg * T) / (z_ion * FARADAY) * np.log(Cion_out / Cion_in) * 1e3 def vtrap(self, x, y): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x / y) - 1) def efun(self, x): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x) - 1) def ghkDrive(self, Vm, Z_ion, Cion_in, Cion_out, T): ''' Use the Goldman-Hodgkin-Katz equation to compute the electrochemical driving force of a specific ion species for a given membrane potential. :param Vm: membrane potential (mV) :param Cin: intracellular ion concentration (M) :param Cout: extracellular ion concentration (M) :param T: temperature (K) :return: electrochemical driving force of a single ion particle (mC.m-3) ''' x = Z_ion * FARADAY * Vm / (Rg * T) * 1e-3 # [-] eCin = Cion_in * self.efun(-x) # M eCout = Cion_out * self.efun(x) # M return FARADAY * (eCin - eCout) * 1e6 # mC/m3 def getDesc(self): return inspect.getdoc(self).splitlines()[0] def getCurrentsNames(self): return list(self.currents(np.nan, [np.nan] * len(self.states)).keys()) def getPltScheme(self): pltscheme = { 'Q_m': ['Qm'], 'V_m': ['Vm'] } pltscheme['I'] = self.getCurrentsNames() + ['iNet'] for cname in self.getCurrentsNames(): if 'Leak' not in cname: key = 'i_{{{}}}\ kin.'.format(cname[1:]) cargs = inspect.getargspec(getattr(self, cname))[0][1:] pltscheme[key] = [var for var in cargs if var not in ['Vm', 'Cai']] return pltscheme def getPltVars(self, wrapleft='df["', wrapright='"]'): ''' Return a dictionary with information about all plot variables related to the neuron. ''' pltvars = { 'Qm': { 'desc': 'membrane charge density', 'label': 'Q_m', 'unit': 'nC/cm^2', 'factor': 1e5, 'bounds': (-100, 50) }, 'Vm': { 'desc': 'membrane potential', 'label': 'V_m', 'unit': 'mV', 'y0': self.Vm0, 'bounds': (-150, 70) }, 'ELeak': { 'constant': 'obj.ELeak', 'desc': 'non-specific leakage current resting potential', 'label': 'V_{leak}', 'unit': 'mV', 'ls': '--', 'color': 'k' } } for cname in self.getCurrentsNames(): cfunc = getattr(self, cname) cargs = inspect.getargspec(cfunc)[0][1:] pltvars[cname] = { 'desc': inspect.getdoc(cfunc).splitlines()[0], 'label': 'I_{{{}}}'.format(cname[1:]), 'unit': 'A/m^2', 'factor': 1e-3, 'func': '{}({})'.format(cname, ', '.join(['{}{}{}'.format(wrapleft, a, wrapright) for a in cargs])) } for var in cargs: if var not in ['Vm', 'Cai']: vfunc = getattr(self, 'der{}{}'.format(var[0].upper(), var[1:])) desc = cname + re.sub('^Evolution of', '', inspect.getdoc(vfunc).splitlines()[0]) pltvars[var] = { 'desc': desc, 'label': var, 'bounds': (-0.1, 1.1) } pltvars['iNet'] = { 'desc': inspect.getdoc(getattr(self, 'iNet')).splitlines()[0], 'label': 'I_{net}', 'unit': 'A/m^2', 'factor': 1e-3, 'func': 'iNet({0}Vm{1}, {2}{3}{4}.values.T)'.format( wrapleft, wrapright, wrapleft[:-1], self.states, wrapright[1:]), 'ls': '--', 'color': 'black' } for x in self.getGates(): for rate in ['alpha', 'beta']: pltvars['{}{}'.format(rate, x)] = { 'label': '\\{}_{{{}}}'.format(rate, x), 'unit': 'ms^{-1}', 'factor': 1e-3 } return pltvars def getRatesNames(self, states): return list(sum( [['alpha{}'.format(x.lower()), 'beta{}'.format(x.lower())] for x in states], [] )) def Qm0(self): ''' Return the resting charge density (in C/m2). ''' return self.Cm0 * self.Vm0 * 1e-3 # C/cm2 @abc.abstractmethod def steadyStates(self, Vm): ''' Compute the steady-state values for a specific membrane potential value. :param Vm: membrane potential (mV) :return: array of steady-states ''' @abc.abstractmethod def derStates(self, Vm, states): ''' Compute the derivatives of channel states. :param Vm: membrane potential (mV) :states: state probabilities of the ion channels :return: current per unit area (mA/m2) ''' @abc.abstractmethod def getEffRates(self, Vm): ''' Get the effective rate constants of ion channels, averaged along an acoustic cycle, for future use in effective simulations. :param Vm: array of membrane potential values for an acoustic cycle (mV) :return: an array of rate average constants (s-1) ''' @abc.abstractmethod def derStatesEff(self, Qm, states, interp_data): ''' Compute the effective derivatives of channel states, based on 1-dimensional linear interpolation of "effective" coefficients that summarize the system's behaviour over an acoustic cycle. :param Qm: membrane charge density (C/m2) :states: state probabilities of the ion channels :param interp_data: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. ''' def Qbounds(self): ''' Determine bounds of membrane charge physiological range for a given neuron. ''' return np.array([np.round(self.Vm0 - 25.0), 50.0]) * self.Cm0 * 1e-3 # C/m2 def isVoltageGated(self, state): ''' Determine whether a given state is purely voltage-gated or not.''' return 'alpha{}'.format(state.lower()) in self.rates def getGates(self): ''' Retrieve the names of the neuron's states that match an ion channel gating. ''' gates = [] for x in self.states: if self.isVoltageGated(x): gates.append(x) return gates def getRates(self, Vm): ''' Compute the ion channels rate constants for a given membrane potential. :param Vm: membrane potential (mV) :return: a dictionary of rate constants and their values at the given potential. ''' rates = {} for x in self.getGates(): x = x.lower() alpha_str, beta_str = ['{}{}'.format(s, x.lower()) for s in ['alpha', 'beta']] inf_str, tau_str = ['{}inf'.format(x.lower()), 'tau{}'.format(x.lower())] if hasattr(self, 'alpha{}'.format(x)): alphax = getattr(self, alpha_str)(Vm) betax = getattr(self, beta_str)(Vm) elif hasattr(self, '{}inf'.format(x)): xinf = getattr(self, inf_str)(Vm) taux = getattr(self, tau_str)(Vm) alphax = xinf / taux betax = 1 / taux - alphax rates[alpha_str] = alphax rates[beta_str] = betax return rates def Vderivatives(self, y, t, Iinj): ''' Compute the derivatives of a V-cast HH system for a specific value of injected current. :param y: vector of HH system variables at time t :param t: time value (s, unused) :param Iinj: injected current (mA/m2) :return: vector of HH system derivatives at time t ''' Vm, *states = y Iionic = self.iNet(Vm, states) # mA/m2 dVmdt = (- Iionic + Iinj) / self.Cm0 # mV/s dstates = self.derStates(Vm, states) return [dVmdt, *dstates] def Qderivatives(self, y, t, Cm=None): ''' Compute the derivatives of the n-ODE HH system variables, based on a value of membrane capacitance. :param y: vector of HH system variables at time t :param t: specific instant in time (s) :param Cm: membrane capacitance (F/m2) :return: vector of HH system derivatives at time t ''' if Cm is None: Cm = self.Cm0 Qm, *states = y Vm = Qm / Cm * 1e3 # mV dQm = - self.iNet(Vm, states) * 1e-3 # A/m2 dstates = self.derStates(Vm, states) return [dQm, *dstates] def checkInputs(self, Astim, tstim, toffset, PRF, DC): ''' Check validity of electrical stimulation parameters. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) ''' # Check validity of stimulation parameters if not all(isinstance(param, float) for param in [Astim, tstim, toffset, DC]): raise TypeError('Invalid stimulation parameters (must be float typed)') if tstim <= 0: raise ValueError('Invalid stimulus duration: {} ms (must be strictly positive)' .format(tstim * 1e3)) if toffset < 0: raise ValueError('Invalid stimulus offset: {} ms (must be positive or null)' .format(toffset * 1e3)) if DC <= 0.0 or DC > 1.0: raise ValueError('Invalid duty cycle: {} (must be within ]0; 1])'.format(DC)) if DC < 1.0: if not isinstance(PRF, float): raise TypeError('Invalid PRF value (must be float typed)') if PRF is None: raise AttributeError('Missing PRF value (must be provided when DC < 1)') if PRF < 1 / tstim: raise ValueError('Invalid PRF: {} Hz (PR interval exceeds stimulus duration)' .format(PRF)) def simulate(self, Astim, tstim, toffset, PRF=None, DC=1.0): ''' Compute solutions of a neuron's HH system for a specific set of electrical stimulation parameters, using a classic integration scheme. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :return: 3-tuple with the time profile and solution matrix and a state vector ''' # Check validity of stimulation parameters self.checkInputs(Astim, tstim, toffset, PRF, DC) # Determine system time step dt = DT_ESTIM # if CW stimulus: divide integration during stimulus into single interval if DC == 1.0: PRF = 1 / tstim # Compute vector sizes npulses = int(np.round(PRF * tstim)) Tpulse_on = DC / PRF Tpulse_off = (1 - DC) / PRF # For high-PRF pulsed protocols: adapt time step to ensure minimal # number of samples during TON or TOFF dt_warning_msg = 'high-PRF protocol: lowering time step to %.2e s to properly integrate %s' for key, Tpulse in {'TON': Tpulse_on, 'TOFF': Tpulse_off}.items(): if Tpulse > 0 and Tpulse / dt < MIN_SAMPLES_PER_PULSE_INT: dt = Tpulse / MIN_SAMPLES_PER_PULSE_INT logger.warning(dt_warning_msg, dt, key) n_pulse_on = int(np.round(Tpulse_on / dt)) n_pulse_off = int(np.round(Tpulse_off / dt)) # Compute offset size n_off = int(np.round(toffset / dt)) # Set initial conditions y0 = [self.Vm0, *self.steadyStates(self.Vm0)] nvar = len(y0) # Initialize global arrays t = np.array([0.]) stimstate = np.array([1]) y = np.array([y0]).T # Initialize pulse time and stimstate vectors t_pulse0 = np.linspace(0, Tpulse_on + Tpulse_off, n_pulse_on + n_pulse_off) stimstate_pulse = np.concatenate((np.ones(n_pulse_on), np.zeros(n_pulse_off))) # Loop through all pulse (ON and OFF) intervals for i in range(npulses): # Construct and initialize arrays t_pulse = t_pulse0 + t[-1] y_pulse = np.empty((nvar, n_pulse_on + n_pulse_off)) # Integrate ON system y_pulse[:, :n_pulse_on] = odeint( self.Vderivatives, y[:, -1], t_pulse[:n_pulse_on], args=(Astim,)).T # Integrate OFF system if n_pulse_off > 0: y_pulse[:, n_pulse_on:] = odeint( self.Vderivatives, y_pulse[:, n_pulse_on - 1], t_pulse[n_pulse_on:], args=(0.0,)).T # Append pulse arrays to global arrays stimstate = np.concatenate([stimstate, stimstate_pulse[1:]]) t = np.concatenate([t, t_pulse[1:]]) y = np.concatenate([y, y_pulse[:, 1:]], axis=1) # Integrate offset interval if n_off > 0: t_off = np.linspace(0, toffset, n_off) + t[-1] stimstate_off = np.zeros(n_off) y_off = odeint(self.Vderivatives, y[:, -1], t_off, args=(0.0, )).T # Concatenate offset arrays to global arrays stimstate = np.concatenate([stimstate, stimstate_off[1:]]) t = np.concatenate([t, t_off[1:]]) y = np.concatenate([y, y_off[:, 1:]], axis=1) # Return output variables return (t, y, stimstate) def nSpikes(self, Astim, tstim, toffset, PRF, DC): ''' Run a simulation and determine number of spikes in the response. :param Astim: current amplitude (mA/m2) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :return: number of spikes found in response ''' t, y, _ = self.simulate(Astim, tstim, toffset, PRF, DC) dt = t[1] - t[0] ipeaks, *_ = findPeaks(y[0, :], SPIKE_MIN_VAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_VPROM) nspikes = ipeaks.size logger.debug('A = %sA/m2 ---> %s spike%s detected', si_format(Astim * 1e-3, 2, space=' '), nspikes, "s" if nspikes > 1 else "") return nspikes def titrate(self, tstim, toffset, PRF=None, DC=1.0, Arange=(0., 2 * TITRATION_ESTIM_A_MAX)): ''' Use a binary search to determine the threshold amplitude needed to obtain neural excitation for a given duration, PRF and duty cycle. :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param Arange: search interval for Astim, iteratively refined :return: excitation threshold amplitude (mA/m2) ''' return titrate(self.nSpikes, (tstim, toffset, PRF, DC), Arange, TITRATION_ESTIM_DA_MAX) def runAndSave(self, outdir, tstim, toffset, PRF=None, DC=1.0, Astim=None): ''' Run a simulation of the point-neuron Hodgkin-Huxley system with specific parameters, and save the results in a PKL file. :param outdir: full path to output directory :param tstim: stimulus duration (s) :param toffset: stimulus offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: stimulus duty cycle (-) :param Astim: stimulus amplitude (mA/m2) ''' # Get date and time info date_str = time.strftime("%Y.%m.%d") daytime_str = time.strftime("%H:%M:%S") logger.info( '%s: %s @ %st = %ss (%ss offset)%s', self, 'titration' if Astim is None else 'simulation', 'A = {}A/m2, '.format(si_format(Astim, 2, space=' ')) if Astim is not None else '', *si_format([tstim, toffset], 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format(si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) if Astim is None: Astim = self.titrate(tstim, toffset, PRF, DC) if np.isnan(Astim): logger.error('Could not find threshold excitation amplitude') return None # Run simulation tstart = time.time() t, y, stimstate = self.simulate(Astim, tstim, toffset, PRF, DC) Vm, *channels = y tcomp = time.time() - tstart # Detect spikes on Vm signal dt = t[1] - t[0] ipeaks, *_ = findPeaks(Vm, SPIKE_MIN_VAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_VPROM) nspikes = ipeaks.size lat = t[ipeaks[0]] if nspikes > 0 else 'N/A' outstr = '{} spike{} detected'.format(nspikes, 's' if nspikes > 1 else '') logger.debug('completed in %s, %s', si_format(tcomp, 1), outstr) sr = np.mean(1 / np.diff(t[ipeaks])) if nspikes > 1 else None # Store dataframe and metadata df = pd.DataFrame({ 't': t, 'stimstate': stimstate, 'Vm': Vm, 'Qm': Vm * self.Cm0 * 1e-3 }) for j in range(len(self.states)): df[self.states[j]] = channels[j] meta = { 'neuron': self.name, 'Astim': Astim, 'tstim': tstim, 'toffset': toffset, 'PRF': PRF, 'DC': DC, 'tcomp': tcomp } # Export into to PKL file simcode = ESTIM_filecode(self.name, Astim, tstim, PRF, DC) outpath = '{}/{}.pkl'.format(outdir, simcode) with open(outpath, 'wb') as fh: pickle.dump({'meta': meta, 'data': df}, fh) logger.debug('simulation data exported to "%s"', outpath) # Export key metrics to log file logpath = os.path.join(outdir, 'log_ESTIM.xlsx') logentry = { 'Date': date_str, 'Time': daytime_str, 'Neuron Type': self.name, 'Astim (mA/m2)': Astim, 'Tstim (ms)': tstim * 1e3, 'PRF (kHz)': PRF * 1e-3 if DC < 1 else 'N/A', 'Duty factor': DC, '# samples': t.size, 'Comp. time (s)': round(tcomp, 2), '# spikes': nspikes, 'Latency (ms)': lat * 1e3 if isinstance(lat, float) else 'N/A', 'Spike rate (sp/ms)': sr * 1e-3 if isinstance(sr, float) else 'N/A' } if xlslog(logpath, logentry) == 1: logger.debug('log exported to "%s"', logpath) else: logger.error('log export to "%s" aborted', self.logpath) return outpath def findRheobaseAmps(self, DCs, Vthr): ''' Find the rheobase amplitudes (i.e. threshold amplitudes of infinite duration that would result in excitation) of a specific neuron for various stimulation duty cycles. :param DCs: duty cycles vector (-) :param Vthr: threshold membrane potential above which the neuron necessarily fires (mV) :return: rheobase amplitudes vector (mA/m2) ''' # Compute the pulse average net (or leakage) current along the amplitude space iNet = self.iNet(Vthr, self.steadyStates(Vthr)) # Compute rheobase amplitudes return iNet / np.array(DCs) diff --git a/PySONIC/plt/QSS.py b/PySONIC/plt/QSS.py index a315bee..5349f78 100644 --- a/PySONIC/plt/QSS.py +++ b/PySONIC/plt/QSS.py @@ -1,425 +1,433 @@ import inspect import pandas as pd import numpy as np import matplotlib.pyplot as plt from matplotlib import cm, colors from ..postpro import getFixedPoints, getEqPoint1D from ..core import NeuronalBilayerSonophore from .pltutils import * from ..constants import TITRATION_T_OFFSET from ..utils import logger def plotVarDynamics(neuron, a, Fdrive, Adrive, charges, varname, varrange, fs=12): ''' Plot the QSS-approximated derivative of a specific variable as function of the variable itself, as well as equilibrium values, for various membrane charge densities at a given acoustic amplitude. :param neuron: neuron object :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param Adrive: US amplitude (Pa) :param charges: charge density vector (C/m2) :param varname: name of variable to plot :param varrange: range over which to compute the derivative :return: figure handle ''' # Extract information about variable to plot pltvar = neuron.getPltVars()[varname] # Get methods to compute derivative and steady-state of variable of interest derX_func = getattr(neuron, 'der{}{}'.format(varname[0].upper(), varname[1:])) Xinf_func = getattr(neuron, '{}inf'.format(varname)) derX_args = inspect.getargspec(derX_func)[0][1:] Xinf_args = inspect.getargspec(Xinf_func)[0][1:] # Get dictionary of charge and amplitude dependent QSS variables nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) _, Qref, Vmeff, QS_states = nbls.quasiSteadyStates(Fdrive, amps=Adrive, charges=charges) df = {k: QS_states[i] for i, k in enumerate(neuron.states)} df['Vm'] = Vmeff # Create figure fig, ax = plt.subplots(figsize=(6, 4)) ax.set_title('{} neuron - QSS {} dynamics @ {:.2f} kPa'.format( neuron.name, pltvar['desc'], Adrive * 1e-3), fontsize=fs) ax.set_xscale('log') for key in ['top', 'right']: ax.spines[key].set_visible(False) ax.set_xlabel('$\\rm {}\ ({})$'.format(pltvar['label'], pltvar.get('unit', '')), fontsize=fs) ax.set_ylabel('$\\rm QSS\ d{}/dt\ ({}/s)$'.format(pltvar['label'], pltvar.get('unit', '1')), fontsize=fs) ax.set_ylim(-40, 40) ax.axhline(0, c='k', linewidth=0.5) y0_str = '{}0'.format(varname) if hasattr(neuron, y0_str): ax.axvline(getattr(neuron, y0_str) * pltvar.get('factor', 1), label=y0_str, c='k', linewidth=0.5) # For each charge value icolor = 0 for j, Qm in enumerate(charges): lbl = 'Q = {:.0f} nC/cm2'.format(Qm * 1e5) # Compute variable derivative as a function of its value, as well as equilibrium value, # keeping other variables at quasi steady-state derX_inputs = [varrange if arg == varname else df[arg][j] for arg in derX_args] Xinf_inputs = [df[arg][j] for arg in Xinf_args] dX_QSS = neuron.derCai(*derX_inputs) Xeq_QSS = neuron.Caiinf(*Xinf_inputs) # Plot variable derivative and its root as a function of the variable itself c = 'C{}'.format(icolor) ax.plot(varrange * pltvar.get('factor', 1), dX_QSS * pltvar.get('factor', 1), c=c, label=lbl) ax.axvline(Xeq_QSS * pltvar.get('factor', 1), linestyle='--', c=c) icolor += 1 ax.legend(frameon=False, fontsize=fs - 3) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) fig.tight_layout() fig.canvas.set_window_title('{}_QSS_{}_dynamics_{:.2f}kPa'.format( neuron.name, varname, Adrive * 1e-3)) return fig def plotVarsQSS(neuron, a, Fdrive, Adrive, fs=12): ''' Plot effective membrane potential, quasi-steady states and resulting membrane currents as a function of membrane charge density, for a given acoustic amplitudes. :param neuron: neuron object :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param Adrive: US amplitude (Pa) :return: figure handle ''' # Get neuron-specific pltvars pltvars = neuron.getPltVars() # Compute neuron-specific charge and amplitude dependent QS states at this amplitude nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) _, Qref, Vmeff, QS_states = nbls.quasiSteadyStates(Fdrive, amps=Adrive) # Compute QSS currents currents = neuron.currents(Vmeff, QS_states) iNet = sum(currents.values()) # Extract dimensionless states norm_QS_states = {} for i, label in enumerate(neuron.states): if 'unit' not in pltvars[label]: norm_QS_states[label] = QS_states[i] # Create figure fig, axes = plt.subplots(3, 1, figsize=(7, 9)) axes[-1].set_xlabel('Charge Density (nC/cm2)', fontsize=fs) for ax in axes: for skey in ['top', 'right']: ax.spines[skey].set_visible(False) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) for item in ax.get_xticklabels(minor=True): item.set_visible(False) figname = '{} neuron - QSS dynamics @ {:.2f} kPa'.format(neuron.name, Adrive * 1e-3) fig.suptitle(figname, fontsize=fs) # Subplot: Vmeff ax = axes[0] ax.set_ylabel('$V_m^*$ (mV)', fontsize=fs) ax.plot(Qref * 1e5, Vmeff, color='k') ax.axhline(neuron.Vm0, linewidth=0.5, color='k') # Subplot: dimensionless quasi-steady states cset = plt.get_cmap('tab10').colors + plt.get_cmap('Dark2').colors ax = axes[1] ax.set_ylabel('$X_\infty$', fontsize=fs) ax.set_yticks([0, 0.5, 1]) ax.set_ylim([-0.05, 1.05]) for i, (label, QS_state) in enumerate(norm_QS_states.items()): ax.plot(Qref * 1e5, QS_state, label=label, c=cset[i]) # Subplot: currents ax = axes[2] ax.set_ylabel('QSS currents (A/m2)', fontsize=fs) for k, I in currents.items(): ax.plot(Qref * 1e5, I * 1e-3, label=k) ax.plot(Qref * 1e5, iNet * 1e-3, color='k', label='iNet') ax.axhline(0, color='k', linewidth=0.5) fig.tight_layout() fig.subplots_adjust(right=0.8) for ax in axes[1:]: ax.legend(loc='center right', fontsize=fs, frameon=False, bbox_to_anchor=(1.3, 0.5)) fig.canvas.set_window_title( '{}_QSS_states_vs_Qm_{:.2f}kPa'.format(neuron.name, Adrive * 1e-3)) return fig def plotQSSVarVsAmp(neuron, a, Fdrive, varname, amps=None, DC=1., Qi=None, plotQi=True, fs=12, cmap='viridis', yscale='lin', zscale='lin'): ''' Plot a specific QSS variable (state or current) as a function of membrane charge density, for various acoustic amplitudes. :param neuron: neuron object :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param amps: US amplitudes (Pa) :param DC: duty cycle (-) :param Qi: initial membrane charge density for phase-plane analysis (C/m2) :param varname: extraction key for variable to plot :return: figure handle ''' # Determine stimulation modality if a is None and Fdrive is None: stim_type = 'elec' a = 32e-9 Fdrive = 500e3 else: stim_type = 'US' # Extract information about variable to plot pltvar = neuron.getPltVars()[varname] Qvar = neuron.getPltVars()['Qm'] - logger.info('plotting %s QSS profiles for %s stimulation @ %.0f%% DC', - varname, stim_type, DC * 1e2) + logger.info('plotting %s %s QSS profiles for %s stimulation @ %.0f%% DC', + neuron.name, varname, stim_type, DC * 1e2) nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) # Get reference QSS dictionary for zero amplitude _, Qref, Vmeff0, QS_states0 = nbls.quasiSteadyStates(Fdrive, amps=0.) + if stim_type == 'elec': # if E-STIM case, compute steady states with constant capacitance + Vmeff0 = Qref / neuron.Cm0 * 1e3 + QS_states0 = np.array(list(map(neuron.steadyStates, Vmeff0))).T df0 = {k: QS_states0[i] for i, k in enumerate(neuron.states)} df0['Vm'] = Vmeff0 if stim_type == 'US': # Get dictionary of charge and amplitude dependent QSS variables Aref, Qref, Vmeff, QS_states = nbls.quasiSteadyStates(Fdrive, amps=amps, DCs=DC) df = {k: QS_states[i] for i, k in enumerate(neuron.states)} df['Vm'] = Vmeff Afactor = 1e-3 else: # Repeat zero-amplitude QSS dictionary for all amplitudes Aref = amps df = {k: np.tile(df0[k], (amps.size, 1)) for k in df0} Afactor = 1. # Define color code mymap = plt.get_cmap(cmap) zref = Aref * Afactor if zscale == 'lin': norm = colors.Normalize(zref.min(), zref.max()) elif zscale == 'log': norm = colors.LogNorm(zref.min(), zref.max()) sm = cm.ScalarMappable(norm=norm, cmap=mymap) sm._A = [] # Create figure fig, ax = plt.subplots(figsize=(6, 4)) ax.set_title('{} neuron - {:.0f} % DC {} stim\nquasi steady-state {} vs. amplitude'.format( neuron.name, DC * 1e2, stim_type, pltvar['desc']), fontsize=fs) ax.set_xlabel('{} ($\\rm {}$)'.format(Qvar['desc'], Qvar['unit']), fontsize=fs) ax.set_ylabel('$\\rm QSS\ {}\ ({})$'.format(pltvar['label'], pltvar.get('unit', '')), fontsize=fs) if yscale == 'log': ax.set_yscale('log') for key in ['top', 'right']: ax.spines[key].set_visible(False) # Plot charge starting point, if any if plotQi and Qi is not None: ax.axvline(Qi * Qvar['factor'], label='$\\rm Q_{m,i}$', c='silver') # Plot y-variable reference line, if any y0_str = '{}0'.format(varname) if hasattr(neuron, y0_str): y0 = getattr(neuron, y0_str) * pltvar.get('factor', 1) elif varname == 'iNet': y0 = 0. y0_str = '' ax.axhline(y0, label=y0_str, c='k', linewidth=0.5) # Plot QSS profile of variable as a function of charge density for various amplitudes tmp = False Qzeros = [] for i, A in enumerate(Aref): var = extractPltVar( neuron, pltvar, pd.DataFrame({k: df[k][i] for k in df.keys()}), name=varname) if varname == 'iNet' and stim_type == 'elec': var -= A * DC * pltvar['factor'] ax.plot(Qref * Qvar['factor'], var, c=sm.to_rgba(A * Afactor), zorder=0) if varname == 'iNet': # mark eq. point if starting point provided, otherwise mark all SFPs if Qi is None: SFPs = getFixedPoints(Qref, -var) if SFPs is not None: Qzeros += SFPs.tolist() if A > 0 and SFPs.min() > -40e-5 and not tmp: - print(A) + print('Athr = {:.2f} mA/m2'.format(A)) tmp = True else: Qzeros += [getEqPoint1D(Qref, -var, Qi)] # Plot reference QSS profile of variable as a function of charge density - var = extractPltVar( + var0 = extractPltVar( neuron, pltvar, pd.DataFrame({k: df0[k] for k in df0.keys()}), name=varname) - ax.plot(Qref * Qvar['factor'], var, '--', c='k', zorder=1, + ax.plot(Qref * Qvar['factor'], var0, '--', c='k', zorder=1, label='$\\rm resting\ {}\ (A=0)$'.format(pltvar['label'])) if varname == 'iNet': if Qi is None: - Qzeros += getFixedPoints(Qref, -var).tolist() + Qzeros += getFixedPoints(Qref, -var0).tolist() else: - Qzeros += [getEqPoint1D(Qref, -var, Qi)] + Qzeros += [getEqPoint1D(Qref, -var0, Qi)] # Plot fixed-points, if any if len(Qzeros) > 0: ax.plot(np.array(Qzeros) * Qvar['factor'], np.zeros(len(Qzeros)), '.', c='k', zorder=1, label='$\\rm Q_{m,eq}$' if Qi is not None else '$\\rm Q_m\ SPFs$') # Add legend and adjust layout ax.legend(frameon=False, fontsize=fs) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) fig.tight_layout() # Plot amplitude colorbar fig.subplots_adjust(bottom=0.15, top=0.9, right=0.80, hspace=0.5) cbarax = fig.add_axes([0.85, 0.15, 0.03, 0.75]) fig.colorbar(sm, cax=cbarax) cbarax.set_ylabel( 'Amplitude ({})'.format({'US': 'kPa', 'elec': 'mA/m2'}[stim_type]), fontsize=fs) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) fig.canvas.set_window_title('{}_QSS_{}_vs_{}A_{}_{:.0f}%DC'.format( neuron.name, varname, zscale, stim_type, DC * 1e2)) return fig def plotEqChargeVsAmp(neurons, a, Fdrive, amps=None, tstim=250e-3, PRF=100.0, DCs=[1.], Qi=None, fs=12, xscale='lin', titrate=False): ''' Plot the equilibrium membrane charge density as a function of acoustic amplitude, given an initial value of membrane charge density. :param neurons: neuron objects :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param amps: US amplitudes (Pa) :param Qi: initial membrane charge density for phase-plane analysis (C/m2) :return: figure handle ''' # Determine stimulation modality if a is None and Fdrive is None: stim_type = 'elec' a = 32e-9 Fdrive = 500e3 else: stim_type = 'US' logger.info('plotting equilibrium charges for %s stimulation', stim_type) # Create figure fig, ax = plt.subplots(figsize=(6, 4)) figname = 'equilibrium charge vs. amplitude' ax.set_title(figname) ax.set_xlabel('Amplitude ({})'.format({'US': 'kPa', 'elec': 'mA/m2'}[stim_type]), fontsize=fs) if Qi[0] is not None: ax.set_ylabel('$\\rm Q_{m, eq}\ (nC/cm^2)$', fontsize=fs) else: ax.set_ylabel('$\\rm Q_m\ SPFs\ (nC/cm^2)$', fontsize=fs) if xscale == 'log': ax.set_xscale('log') for skey in ['top', 'right']: ax.spines[skey].set_visible(False) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) - # For each neuron icolor = 0 - for i, neuron in enumerate(neurons): nbls = NeuronalBilayerSonophore(a, neuron, Fdrive) # Compute reference charge variation array for zero amplitude _, Qref, Vmeff0, QS_states0 = nbls.quasiSteadyStates(Fdrive, amps=0.) + if stim_type == 'elec': # if E-STIM case, compute steady states with constant capacitance + Vmeff0 = Qref / neuron.Cm0 * 1e3 + QS_states0 = np.array(list(map(neuron.steadyStates, Vmeff0))).T dQdt0 = -neuron.iNet(Vmeff0, QS_states0) # mA/m2 # Compute 3D QSS charge variation array if stim_type == 'US': Aref, _, Vmeff, QS_states = nbls.quasiSteadyStates(Fdrive, amps=amps, DCs=DCs) if DCs.size == 1: QS_states = QS_states.reshape((*QS_states.shape, 1)) Vmeff = Vmeff.reshape((*Vmeff.shape, 1)) dQdt = -neuron.iNet(Vmeff, QS_states) # mA/m2 Afactor = 1e-3 else: Aref = amps Afactor = 1. dQdt = np.empty((Aref.size, Qref.size, DCs.size)) for iA, A in enumerate(Aref): for iDC, DC in enumerate(DCs): dQdt[iA, :, iDC] = dQdt0 + A * DC # For each duty cycle for j, DC in enumerate(DCs): color = 'C{}'.format(icolor) # Plot either all charge SFPs or only equilibrium charge (if Qi provided) # for each acoustic amplitude Aplot, Qplot = [], [] for k, Adrive in enumerate(Aref): dQ_profile = dQdt[k, :, j] if Qi[i] is None: Qzeros = getFixedPoints(Qref, dQ_profile) if Qzeros is not None: Qzeros = Qzeros.tolist() else: Qzeros = [np.nan] else: Qzeros = [getEqPoint1D(Qref, dQ_profile, Qi[i])] Qplot += Qzeros Aplot += [Adrive] * len(Qzeros) ax.plot(np.array(Aplot) * Afactor, np.array(Qplot) * 1e5, '.', c=color, label='{} neuron - {:.0f} % DC'.format(neuron.name, DC * 1e2)) # If specified, compute and plot the threshold excitation amplitude if titrate: if stim_type == 'US': Athr = nbls.titrate(Fdrive, tstim, TITRATION_T_OFFSET, PRF=PRF, DC=DC, Arange=(Aref.min(), Aref.max())) # Pa + ax.axvline(Athr * Afactor, c=color, linestyle='--') else: - Athr = neuron.titrate(tstim, TITRATION_T_OFFSET, PRF=PRF, DC=DC, - Arange=(0., Aref.max())) # mA/m2 - ax.axvline(Athr * Afactor, c=color, linestyle='--') - ax.axvline(-Athr * Afactor, c=color, linestyle='--') + Athr_pos = neuron.titrate(tstim, TITRATION_T_OFFSET, PRF=PRF, DC=DC, + Arange=(0., Aref.max())) # mA/m2 + ax.axvline(Athr_pos * Afactor, c=color, linestyle='--') + Athr_neg = neuron.titrate(tstim, TITRATION_T_OFFSET, PRF=PRF, DC=DC, + Arange=(Aref.min(), 0.)) # mA/m2 + ax.axvline(Athr_neg * Afactor, c=color, linestyle='-.') + icolor += 1 # Post-process figure ax.legend(frameon=False, fontsize=fs) fig.tight_layout() fig.canvas.set_window_title('QSS_{}_vs_{}A_{}_{}_{}%DC{}'.format( 'Qeq' if Qi[0] is not None else 'SFPs', xscale, '_'.join([n.name for n in neurons]), stim_type, '_'.join(['{:.0f}'.format(DC * 1e2) for DC in DCs]), '_with_thresholds' if titrate else '' )) return fig diff --git a/PySONIC/utils.py b/PySONIC/utils.py index cdabcc0..8cf991e 100644 --- a/PySONIC/utils.py +++ b/PySONIC/utils.py @@ -1,574 +1,580 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-09-19 22:30:46 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-04-05 17:14:39 +# @Last Modified time: 2019-04-29 12:04:32 ''' Definition of generic utility functions used in other modules ''' import operator import os import math import pickle import tkinter as tk from tkinter import filedialog import numpy as np import colorlog from scipy.interpolate import interp1d # Package logger def setLogger(): log_formatter = colorlog.ColoredFormatter( '%(log_color)s %(asctime)s %(message)s', datefmt='%d/%m/%Y %H:%M:%S:', reset=True, log_colors={ 'DEBUG': 'green', 'INFO': 'white', 'WARNING': 'yellow', 'ERROR': 'red', 'CRITICAL': 'red,bg_white', }, style='%' ) log_handler = colorlog.StreamHandler() log_handler.setFormatter(log_formatter) color_logger = colorlog.getLogger('PySONIC') color_logger.addHandler(log_handler) return color_logger logger = setLogger() # File naming conventions def ESTIM_filecode(neuron, Astim, tstim, PRF, DC): return 'ESTIM_{}_{}_{:.1f}mA_per_m2_{:.0f}ms{}'.format( neuron, 'CW' if DC == 1 else 'PW', Astim, tstim * 1e3, '_PRF{:.2f}Hz_DC{:.2f}%'.format(PRF, DC * 1e2) if DC < 1. else '') def ASTIM_filecode(neuron, a, Fdrive, Adrive, tstim, PRF, DC, method): return 'ASTIM_{}_{}_{:.0f}nm_{:.0f}kHz_{:.2f}kPa_{:.0f}ms_{}{}'.format( neuron, 'CW' if DC == 1 else 'PW', a * 1e9, Fdrive * 1e-3, Adrive * 1e-3, tstim * 1e3, 'PRF{:.2f}Hz_DC{:.2f}%_'.format(PRF, DC * 1e2) if DC < 1. else '', method) def MECH_filecode(a, Fdrive, Adrive, Qm): return 'MECH_{:.0f}nm_{:.0f}kHz_{:.1f}kPa_{:.1f}nCcm2'.format( a * 1e9, Fdrive * 1e-3, Adrive * 1e-3, Qm * 1e5) # Figure naming conventions def figtitle(meta): ''' Return appropriate title based on simulation metadata. ''' if 'Cm0' in meta: return '{:.0f}nm radius BLS structure: MECH-STIM {:.0f}kHz, {:.2f}kPa, {:.1f}nC/cm2'.format( meta['a'] * 1e9, meta['Fdrive'] * 1e-3, meta['Adrive'] * 1e-3, meta['Qm'] * 1e5) else: if meta['DC'] < 1: wavetype = 'PW' suffix = ', {:.2f}Hz PRF, {:.0f}% DC'.format(meta['PRF'], meta['DC'] * 1e2) else: wavetype = 'CW' suffix = '' if 'Astim' in meta: return '{} neuron: {} E-STIM {:.2f}mA/m2, {:.0f}ms{}'.format( meta['neuron'], wavetype, meta['Astim'], meta['tstim'] * 1e3, suffix) else: return '{} neuron ({:.1f}nm): {} A-STIM {:.0f}kHz {:.2f}kPa, {:.0f}ms{} - {} model'.format( meta['neuron'], meta['a'] * 1e9, wavetype, meta['Fdrive'] * 1e-3, meta['Adrive'] * 1e-3, meta['tstim'] * 1e3, suffix, meta['method']) # SI units prefixes si_prefixes = { 'y': 1e-24, # yocto 'z': 1e-21, # zepto 'a': 1e-18, # atto 'f': 1e-15, # femto 'p': 1e-12, # pico 'n': 1e-9, # nano 'u': 1e-6, # micro 'm': 1e-3, # mili '': 1e0, # None 'k': 1e3, # kilo 'M': 1e6, # mega 'G': 1e9, # giga 'T': 1e12, # tera 'P': 1e15, # peta 'E': 1e18, # exa 'Z': 1e21, # zetta 'Y': 1e24, # yotta } def loadData(fpath, frequency=1): ''' Load dataframe and metadata dictionary from pickle file. ''' logger.info('Loading data from "%s"', os.path.basename(fpath)) with open(fpath, 'rb') as fh: frame = pickle.load(fh) df = frame['data'].iloc[::frequency] meta = frame['meta'] return df, meta def si_format(x, precision=0, space=' '): ''' Format a float according to the SI unit system, with the appropriate prefix letter. ''' if isinstance(x, float) or isinstance(x, int) or isinstance(x, np.float) or\ isinstance(x, np.int32) or isinstance(x, np.int64): if x == 0: factor = 1e0 prefix = '' else: sorted_si_prefixes = sorted(si_prefixes.items(), key=operator.itemgetter(1)) vals = [tmp[1] for tmp in sorted_si_prefixes] # vals = list(si_prefixes.values()) ix = np.searchsorted(vals, np.abs(x)) - 1 if np.abs(x) == vals[ix + 1]: ix += 1 factor = vals[ix] prefix = sorted_si_prefixes[ix][0] # prefix = list(si_prefixes.keys())[ix] return '{{:.{}f}}{}{}'.format(precision, space, prefix).format(x / factor) elif isinstance(x, list) or isinstance(x, tuple): return [si_format(item, precision, space) for item in x] elif isinstance(x, np.ndarray) and x.ndim == 1: return [si_format(float(item), precision, space) for item in x] else: print(type(x)) def pow10_format(number, precision=2): ''' Format a number in power of 10 notation. ''' ret_string = '{0:.{1:d}e}'.format(number, precision) a, b = ret_string.split("e") a = float(a) b = int(b) return '{}10^{{{}}}'.format('{} * '.format(a) if a != 1. else '', b) def rmse(x1, x2): ''' Compute the root mean square error between two 1D arrays ''' return np.sqrt(((x1 - x2) ** 2).mean()) def rsquared(x1, x2): ''' compute the R-squared coefficient between two 1D arrays ''' residuals = x1 - x2 ss_res = np.sum(residuals**2) ss_tot = np.sum((x1 - np.mean(x1))**2) return 1 - (ss_res / ss_tot) def Pressure2Intensity(p, rho=1075.0, c=1515.0): ''' Return the spatial peak, pulse average acoustic intensity (ISPPA) associated with the specified pressure amplitude. :param p: pressure amplitude (Pa) :param rho: medium density (kg/m3) :param c: speed of sound in medium (m/s) :return: spatial peak, pulse average acoustic intensity (W/m2) ''' return p**2 / (2 * rho * c) def Intensity2Pressure(I, rho=1075.0, c=1515.0): ''' Return the pressure amplitude associated with the specified spatial peak, pulse average acoustic intensity (ISPPA). :param I: spatial peak, pulse average acoustic intensity (W/m2) :param rho: medium density (kg/m3) :param c: speed of sound in medium (m/s) :return: pressure amplitude (Pa) ''' return np.sqrt(2 * rho * c * I) def OpenFilesDialog(filetype, dirname=''): ''' Open a FileOpenDialogBox to select one or multiple file. The default directory and file type are given. :param dirname: default directory :param filetype: default file type :return: tuple of full paths to the chosen filenames ''' root = tk.Tk() root.withdraw() filenames = filedialog.askopenfilenames(filetypes=[(filetype + " files", '.' + filetype)], initialdir=dirname) if filenames: par_dir = os.path.abspath(os.path.join(filenames[0], os.pardir)) else: par_dir = None return (filenames, par_dir) def selectDirDialog(): ''' Open a dialog box to select a directory. :return: full path to selected directory ''' root = tk.Tk() root.withdraw() return filedialog.askdirectory() def SaveFileDialog(filename, dirname=None, ext=None): ''' Open a dialog box to save file. :param filename: filename :param dirname: initial directory :param ext: default extension :return: full path to the chosen filename ''' root = tk.Tk() root.withdraw() filename_out = filedialog.asksaveasfilename( defaultextension=ext, initialdir=dirname, initialfile=filename) return filename_out def downsample(t_dense, y, nsparse): ''' Decimate periodic signals to a specified number of samples.''' if(y.ndim) > 1: nsignals = y.shape[0] else: nsignals = 1 y = np.array([y]) # determine time step and period of input signal T = t_dense[-1] - t_dense[0] dt_dense = t_dense[1] - t_dense[0] # resample time vector linearly t_ds = np.linspace(t_dense[0], t_dense[-1], nsparse) # create MAV window nmav = int(0.03 * T / dt_dense) if nmav % 2 == 0: nmav += 1 mav = np.ones(nmav) / nmav # determine signals padding npad = int((nmav - 1) / 2) # determine indexes of sampling on convolved signals ids = np.round(np.linspace(0, t_dense.size - 1, nsparse)).astype(int) y_ds = np.empty((nsignals, nsparse)) # loop through signals for i in range(nsignals): # pad, convolve and resample pad_left = y[i, -(npad + 2):-2] pad_right = y[i, 1:npad + 1] y_ext = np.concatenate((pad_left, y[i, :], pad_right), axis=0) y_mav = np.convolve(y_ext, mav, mode='valid') y_ds[i, :] = y_mav[ids] if nsignals == 1: y_ds = y_ds[0, :] return (t_ds, y_ds) def rescale(x, lb=None, ub=None, lb_new=0, ub_new=1): ''' Rescale a value to a specific interval by linear transformation. ''' if lb is None: lb = x.min() if ub is None: ub = x.max() xnorm = (x - lb) / (ub - lb) return xnorm * (ub_new - lb_new) + lb_new def getNeuronLookupsFile(mechname, a=None, Fdrive=None, Adrive=None, fs=False): fpath = os.path.join( os.path.split(__file__)[0], 'neurons', '{}_lookups'.format(mechname) ) if a is not None: fpath += '_{:.0f}nm'.format(a * 1e9) if Fdrive is not None: fpath += '_{:.0f}kHz'.format(Fdrive * 1e-3) if Adrive is not None: fpath += '_{:.0f}kPa'.format(Adrive * 1e-3) if fs is True: fpath += '_fs' return '{}.pkl'.format(fpath) def getLookups4D(mechname): ''' Retrieve 4D lookup tables and reference vectors for a given membrane mechanism. :param mechname: name of membrane density mechanism :return: 4-tuple with 1D numpy arrays of reference input vectors (charge density and one other variable), a dictionary of associated 2D lookup numpy arrays, and a dictionary with information about the other variable. ''' # Check lookup file existence lookup_path = getNeuronLookupsFile(mechname) if not os.path.isfile(lookup_path): raise FileNotFoundError('Missing lookup file: "{}"'.format(lookup_path)) # Load lookups dictionary logger.debug('Loading %s lookup table', mechname) with open(lookup_path, 'rb') as fh: df = pickle.load(fh) inputs = df['input'] lookups4D = df['lookup'] # Retrieve 1D inputs from lookups dictionary aref = inputs['a'] Fref = inputs['f'] Aref = inputs['A'] Qref = inputs['Q'] return aref, Fref, Aref, Qref, lookups4D def getLookupsOff(mechname): ''' Retrieve appropriate US-OFF lookup tables and reference vectors for a given membrane mechanism. :param mechname: name of membrane density mechanism :return: 2-tuple with 1D numpy array of reference charge density and dictionary of associated 1D lookup numpy arrays. ''' # Get 4D lookups and input vectors aref, Fref, Aref, Qref, lookups4D = getLookups4D(mechname) # Perform 2D projection in appropriate dimensions logger.debug('Interpolating lookups at A = 0') lookups_off = {key: y4D[0, 0, 0, :] for key, y4D in lookups4D.items()} return Qref, lookups_off def getLookups2D(mechname, a=None, Fdrive=None, Adrive=None): ''' Retrieve appropriate 2D lookup tables and reference vectors for a given membrane mechanism, projected at a specific combination of sonophore radius, US frequency and/or acoustic pressure amplitude. :param mechname: name of membrane density mechanism :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param Adrive: Acoustic peak pressure amplitude (Hz) :return: 4-tuple with 1D numpy arrays of reference input vectors (charge density and one other variable), a dictionary of associated 2D lookup numpy arrays, and a dictionary with information about the other variable. ''' # Get 4D lookups and input vectors aref, Fref, Aref, Qref, lookups4D = getLookups4D(mechname) # Check that inputs are within lookup range if a is not None: a = isWithin('radius', a, (aref.min(), aref.max())) if Fdrive is not None: Fdrive = isWithin('frequency', Fdrive, (Fref.min(), Fref.max())) if Adrive is not None: Adrive = isWithin('amplitude', Adrive, (Aref.min(), Aref.max())) # Determine projection dimensions based on inputs var_a = {'name': 'a', 'label': 'sonophore radius', 'val': a, 'unit': 'm', 'factor': 1e9, 'ref': aref, 'axis': 0} var_Fdrive = {'name': 'f', 'label': 'frequency', 'val': Fdrive, 'unit': 'Hz', 'factor': 1e-3, 'ref': Fref, 'axis': 1} var_Adrive = {'name': 'A', 'label': 'amplitude', 'val': Adrive, 'unit': 'Pa', 'factor': 1e-3, 'ref': Aref, 'axis': 2} if not isinstance(Adrive, float): var1 = var_a var2 = var_Fdrive var3 = var_Adrive elif not isinstance(Fdrive, float): var1 = var_a var2 = var_Adrive var3 = var_Fdrive elif not isinstance(a, float): var1 = var_Fdrive var2 = var_Adrive var3 = var_a # Perform 2D projection in appropriate dimensions logger.debug('Interpolating lookups at (%s = %s%s, %s = %s%s)', var1['name'], si_format(var1['val'], space=' '), var1['unit'], var2['name'], si_format(var2['val'], space=' '), var2['unit']) lookups3D = {key: interp1d(var1['ref'], y4D, axis=var1['axis'])(var1['val']) for key, y4D in lookups4D.items()} if var2['axis'] > var1['axis']: var2['axis'] -= 1 lookups2D = {key: interp1d(var2['ref'], y3D, axis=var2['axis'])(var2['val']) for key, y3D in lookups3D.items()} if var3['val'] is not None: logger.debug('Interpolating lookups at %d new %s values between %s%s and %s%s', len(var3['val']), var3['name'], si_format(min(var3['val']), space=' '), var3['unit'], si_format(max(var3['val']), space=' '), var3['unit']) lookups2D = {key: interp1d(var3['ref'], y2D, axis=0)(var3['val']) for key, y2D in lookups2D.items()} var3['ref'] = np.array(var3['val']) return var3['ref'], Qref, lookups2D, var3 def getLookups2Dfs(mechname, a, Fdrive, fs): # Check lookup file existence lookup_path = getNeuronLookupsFile(mechname, a=a, Fdrive=Fdrive, fs=True) if not os.path.isfile(lookup_path): raise FileNotFoundError('Missing lookup file: "{}"'.format(lookup_path)) # Load lookups dictionary logger.debug('Loading %s lookup table with fs = %.0f%%', mechname, fs * 1e2) with open(lookup_path, 'rb') as fh: df = pickle.load(fh) inputs = df['input'] lookups3D = df['lookup'] # Retrieve 1D inputs from lookups dictionary fsref = inputs['fs'] Aref = inputs['A'] Qref = inputs['Q'] # Check that fs is within lookup range fs = isWithin('coverage', fs, (fsref.min(), fsref.max())) # Perform projection at fs logger.debug('Interpolating lookups at fs = %s%%', fs * 1e2) lookups2D = {key: interp1d(fsref, y3D, axis=2)(fs) for key, y3D in lookups3D.items()} return Aref, Qref, lookups2D def isWithin(name, val, bounds, rel_tol=1e-9): ''' Check if a floating point number is within an interval. If the value falls outside the interval, an error is raised. If the value falls just outside the interval due to rounding errors, the associated interval bound is returned. :param val: float value :param bounds: interval bounds (float tuple) :return: original or corrected value ''' if isinstance(val, list) or isinstance(val, np.ndarray) or isinstance(val, tuple): return [isWithin(name, v, bounds, rel_tol) for v in val] if val >= bounds[0] and val <= bounds[1]: return val elif val < bounds[0] and math.isclose(val, bounds[0], rel_tol=rel_tol): logger.warning('Rounding %s value (%s) to interval lower bound (%s)', name, val, bounds[0]) return bounds[0] elif val > bounds[1] and math.isclose(val, bounds[1], rel_tol=rel_tol): logger.warning('Rounding %s value (%s) to interval upper bound (%s)', name, val, bounds[1]) return bounds[1] else: raise ValueError('{} value ({}) out of [{}, {}] interval'.format( name, val, bounds[0], bounds[1])) def getLookupsCompTime(mechname): # Check lookup file existence lookup_path = getNeuronLookupsFile(mechname) if not os.path.isfile(lookup_path): raise FileNotFoundError('Missing lookup file: "{}"'.format(lookup_path)) # Load lookups dictionary logger.debug('Loading comp times') with open(lookup_path, 'rb') as fh: df = pickle.load(fh) tcomps4D = df['tcomp'] return np.sum(tcomps4D) def getLowIntensitiesSTN(): ''' Return an array of acoustic intensities (W/m2) used to study the STN neuron in Tarnaud, T., Joseph, W., Martens, L., and Tanghe, E. (2018). Computational Modeling of Ultrasonic Subthalamic Nucleus Stimulation. IEEE Trans Biomed Eng. ''' return np.hstack(( np.arange(10, 101, 10), np.arange(101, 131, 1), np.array([140]) )) # W/m2 def getIndex(container, value): ''' Return the index of a float / string value in a list / array :param container: list / 1D-array of elements :param value: value to search for :return: index of value (if found) ''' if isinstance(value, float): container = np.array(container) imatches = np.where(np.isclose(container, value, rtol=1e-9, atol=1e-16))[0] if len(imatches) == 0: raise ValueError('{} not found in {}'.format(value, container)) return imatches[0] elif isinstance(value, str): return container.index(value) -def titrate(nspikes_func, args, xbounds, dx_thr, is_excited=[]): +def titrate(nspikes_func, args, xbounds, dx_thr, is_excited=None): ''' Use a binary search to determine an excitation threshold within a specific search interval. :param nspikes_func: function returning the number of spikes for a given condition :param args: list of function arguments other than refined value :param xbounds: search interval for threshold (progressively refined) :param dx_thr: accuracy criterion for threshold :return: excitation threshold ''' + + if is_excited is None: + is_excited = [] + x = (xbounds[0] + xbounds[1]) / 2 nspikes = nspikes_func(x, *args) is_excited.append(nspikes > 0) - conv = False # When titration interval becomes small enough if (xbounds[1] - xbounds[0]) <= dx_thr: logger.debug('titration interval smaller than defined threshold') # If exactly one spike, convergence is achieved -> return if nspikes == 1: logger.debug('exactly one spike -> convergence') conv = True # If both conditions have been encountered during titration process, # we're going towards convergence elif (0 in is_excited and 1 in is_excited): logger.debug('converging around threshold') # if current condition yields excitation, convergence is achieved if is_excited[-1]: logger.debug('currently above threshold -> convergence') conv = True # If only one condition has been encountered during titration process, # then no titration is impossible within the defined interval -> return NaN else: logger.warning('titration does not converge within this interval') return np.nan # Return threshold if convergence is reached, otherwise refine interval and iterate if conv: return x else: - xbounds = (xbounds[0], x) if is_excited[-1] else (x, xbounds[1]) + if x > 0.: + xbounds = (xbounds[0], x) if is_excited[-1] else (x, xbounds[1]) + else: + xbounds = (x, xbounds[1]) if is_excited[-1] else (xbounds[0], x) return titrate(nspikes_func, args, xbounds, dx_thr, is_excited=is_excited) diff --git a/scripts/plot_QSS_IQ.py b/scripts/plot_QSS_IQ.py index 3b63959..75b124c 100644 --- a/scripts/plot_QSS_IQ.py +++ b/scripts/plot_QSS_IQ.py @@ -1,121 +1,121 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2018-09-28 16:13:34 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-04-08 20:48:52 +# @Last Modified time: 2019-04-26 16:41:11 ''' Phase-plane analysis of neuron behavior under quasi-steady state approximation. ''' import os import numpy as np import matplotlib.pyplot as plt from argparse import ArgumentParser import logging from PySONIC.utils import logger, selectDirDialog from PySONIC.neurons import getNeuronsDict from PySONIC.plt import plotQSSVarVsAmp, plotEqChargeVsAmp def main(): ap = ArgumentParser() # Stimulation parameters ap.add_argument('-n', '--neurons', type=str, nargs='+', default=None, help='Neuron types') ap.add_argument('-o', '--outputdir', type=str, default=None, help='Output directory') ap.add_argument('-c', '--cmap', type=str, default='viridis', help='Colormap name') ap.add_argument('-v', '--verbose', default=False, action='store_true', help='Increase verbosity') ap.add_argument('-s', '--save', default=False, action='store_true', help='Save output figures') ap.add_argument('--titrate', default=False, action='store_true', help='Titrate excitation threshold') ap.add_argument('--tstim', type=float, default=500., help='Stimulus duration for titration (ms)') ap.add_argument('--PRF', type=float, default=100., help='Pulse-repetition-frequency for titration (Hz)') ap.add_argument('--DC', type=float, nargs='+', default=None, help='Duty cycle (%)') ap.add_argument('--Qi', type=str, default=None, help='Initial membrane charge density for phase-plane analysis (nC/cm2)') ap.add_argument('--Ascale', type=str, default='lin', help='Scale type for acoustic amplitude ("lin" or "log")') ap.add_argument('-p', '--plotdetails', default=False, action='store_true', help='Plot details') ap.add_argument('--stim', type=str, default='US', help='Stimulation type ("US" or "elec")') # Parse arguments args = ap.parse_args() logger.setLevel(logging.DEBUG if args.verbose else logging.INFO) neurons = ['RS', 'LTS'] if args.neurons is None else args.neurons neurons = [getNeuronsDict()[n]() for n in neurons] if args.Qi == 'Qm0': Qi = np.array([neuron.Qm0() for neuron in neurons]) # C/m2 elif args.Qi is None: Qi = [None] * len(neurons) else: Qi = np.ones(len(neurons)) * float(args.Qi) * 1e-5 # C/m2 # US parameters a = 32e-9 # m Fdrive = 500e3 # Hz Arange = (1., 50.) # kPa nA = 300 US_amps = { 'lin': np.linspace(Arange[0], Arange[1], nA), 'log': np.logspace(np.log10(Arange[0]), np.log10(Arange[1]), nA) }[args.Ascale] * 1e3 # Pa # E-STIM parameters Irange = (-10., 10.) # mA/m2 nI = 100 Iinjs = np.linspace(Irange[0], Irange[1], nI) # mA/m2 # Pulsing parameters tstim = args.tstim * 1e-3 # s PRF = args.PRF # Hz DCs = [100.] if args.DC is None else args.DC # % DCs = np.array(DCs) * 1e-2 # (-) if args.stim == 'US': amps = US_amps cmap = args.cmap else: a = None Fdrive = None amps = Iinjs - cmap = 'coolwarm' + cmap = 'RdBu_r' figs = [] if args.plotdetails: # Plot iNet vs Q for different amplitudes for each neuron and DC for i, neuron in enumerate(neurons): for DC in DCs: figs.append(plotQSSVarVsAmp( neuron, a, Fdrive, 'iNet', amps=amps, DC=DC, Qi=Qi[i], cmap=cmap, zscale=args.Ascale)) # Plot equilibrium charge as a function of amplitude for each neuron figs.append( plotEqChargeVsAmp( neurons, a, Fdrive, amps=amps, tstim=tstim, PRF=PRF, DCs=DCs, Qi=Qi, xscale=args.Ascale, titrate=args.titrate)) if args.save: outputdir = args.outputdir if args.outputdir is not None else selectDirDialog() if outputdir == '': logger.error('no output directory') else: for fig in figs: s = fig.canvas.get_window_title() s = s.replace('(', '- ').replace('/', '_').replace(')', '') figname = '{}.png'.format(s) fig.savefig(os.path.join(outputdir, figname), transparent=True) else: plt.show() if __name__ == '__main__': main()