diff --git a/.gitignore b/.gitignore index 3fad5b5..c3e78b6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,112 +1,113 @@ # Sphinx tools and doc _build/ _static/ _templates/ Makefile *.bat conf.py -# Sublime Workspace files +# Sublime files +*.sublime-project *.sublime-workspace # PNG images *.png # Temporary stats file *.stats # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # NEURON object and binary files *.o *.c *.dll # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *,cover .hypothesis/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # dotenv .env # virtualenv .venv/ venv/ ENV/ # Spyder project settings .spyderproject # Rope project settings .ropeproject \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index f1f150f..3e3d1c2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,5 @@ include README.md include PySONIC/neurons/*.pkl -include /scripts/*.py \ No newline at end of file +include /scripts/*.py +include /tests/* +include /notebooks/*.ipynb \ No newline at end of file diff --git a/PySONIC.sublime-project b/PySONIC.sublime-project index f0adbd2..e0ab6dc 100644 --- a/PySONIC.sublime-project +++ b/PySONIC.sublime-project @@ -1,36 +1,41 @@ { "build_systems": [ { "file_regex": "^[ ]*File \"(...*?)\", line ([0-9]*)", "name": "Anaconda Python Builder", "selector": "source.python", "shell_cmd": "\"C:\\ProgramData\\Anaconda3\\python.exe\" -u \"$file\"" } ], "folders": [ { "file_exclude_patterns": [ "*.sublime-workspace", - "*.sublime-project", "MANIFEST.in", "LICENSE", "conf.py", "index.rst", "*.gitignore", "__init__.py", "*.c", - "run_tests.sh" + "*.sh", + "*.bat", + "Makefile" ], "folder_exclude_patterns": [ "docs", - "*.egg-info" + "*.egg-info", + ".ipynb_checkpoints", + "_build", + "_static", + "_templates" ], "path": "." } ], "translate_tabs_to_spaces": true } diff --git a/PySONIC/batches.py b/PySONIC/batches.py index f253198..32cb500 100644 --- a/PySONIC/batches.py +++ b/PySONIC/batches.py @@ -1,187 +1,187 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-08-22 14:33:04 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-27 02:02:34 +# @Last Modified time: 2018-09-28 14:14:11 -""" Utility functions used in simulations """ +''' Utility functions used in simulations ''' import os import lockfile import logging import multiprocessing as mp import numpy as np import pandas as pd from .utils import logger class Consumer(mp.Process): ''' Generic consumer process, taking tasks from a queue and outputing results in another queue. ''' def __init__(self, queue_in, queue_out): mp.Process.__init__(self) self.queue_in = queue_in self.queue_out = queue_out logger.debug('Starting %s', self.name) def run(self): while True: nextTask = self.queue_in.get() if nextTask is None: logger.debug('Exiting %s', self.name) self.queue_in.task_done() break answer = nextTask() self.queue_in.task_done() self.queue_out.put(answer) return class Worker(): ''' Generic worker class calling a specific object's method with a given set of parameters. ''' def __init__(self, wid, obj, method_str, params, loglevel): ''' Worker constructor. :param wid: worker ID :param obj: object containing the method to call :param method_str: name of the method to call :param params: list of method parameters :param loglevel: logging level ''' self.id = wid self.obj = obj self.method_str = method_str self.params = params self.loglevel = loglevel def __call__(self): ''' Caller to the specific object method. ''' logger.setLevel(self.loglevel) return self.id, getattr(self.obj, self.method_str)(*self.params) def createQueue(dims): ''' Create a serialized 2D array of all parameter combinations for a series of individual parameter sweeps. :param dims: list of lists (or 1D arrays) of input parameters :return: list of parameters (list) for each simulation ''' ndims = len(dims) dims_in = [dims[1], dims[0]] inds_out = [1, 0] if ndims > 2: dims_in += dims[2:] inds_out += list(range(2, ndims)) queue = np.stack(np.meshgrid(*dims_in), -1).reshape(-1, ndims) queue = queue[:, inds_out] return queue.tolist() def createSimQueue(amps, durations, offsets, PRFs, DCs): ''' Create a serialized 2D array of all parameter combinations for a series of individual parameter sweeps, while avoiding repetition of CW protocols for a given PRF sweep. :param amps: list (or 1D-array) of acoustic amplitudes :param durations: list (or 1D-array) of stimulus durations :param offsets: list (or 1D-array) of stimulus offsets (paired with durations array) :param PRFs: list (or 1D-array) of pulse-repetition frequencies :param DCs: list (or 1D-array) of duty cycle values :return: list of parameters (list) for each simulation ''' DCs = np.array(DCs) queue = [] if 1.0 in DCs: queue += createQueue((durations, offsets, PRFs.min(), 1.0, amps)) if np.any(DCs != 1.0): queue += createQueue((durations, offsets, PRFs, DCs[DCs != 1.0], amps)) return queue def runBatch(obj, method_str, queue, extra_params=[], mpi=False, loglevel=logging.INFO): ''' Run batch of simulations of a given object for various combinations of stimulation parameters. :param queue: array of all stimulation parameters combinations :param mpi: boolean stating whether or not to use multiprocessing ''' nsims = len(queue) if mpi: mp.freeze_support() tasks = mp.JoinableQueue() results = mp.Queue() nconsumers = min(mp.cpu_count(), nsims) consumers = [Consumer(tasks, results) for i in range(nconsumers)] for w in consumers: w.start() # Run simulations outputs = [] for i, stim_params in enumerate(queue): params = extra_params + stim_params if mpi: worker = Worker(i, obj, method_str, params, loglevel) tasks.put(worker, block=False) else: outputs.append(getattr(obj, method_str)(*params)) if mpi: for i in range(nconsumers): tasks.put(None, block=False) tasks.join() idxs = [] for i in range(nsims): wid, out = results.get() outputs.append(out) idxs.append(wid) outputs = [x for _, x in sorted(zip(idxs, outputs))] # Close tasks and results queues tasks.close() results.close() return outputs def xlslog(filepath, logentry, sheetname='Data'): - """ Append log data on a new row to specific sheet of excel workbook, using a lockfile + ''' Append log data on a new row to specific sheet of excel workbook, using a lockfile to avoid read/write errors between concurrent processes. :param filepath: absolute or relative path to the Excel workbook :param logentry: log entry (dictionary) to add to log file :param sheetname: name of the Excel spreadsheet to which data is appended :return: boolean indicating success (1) or failure (0) of operation - """ + ''' # Parse log dataframe from Excel file if it exists, otherwise create new one if not os.path.isfile(filepath): df = pd.DataFrame(columns=logentry.keys()) else: df = pd.read_excel(filepath, sheet_name=sheetname) # Add log entry to log dataframe df = df.append(pd.Series(logentry), ignore_index=True) # Write log dataframe to Excel file try: lock = lockfile.FileLock(filepath) lock.acquire() writer = pd.ExcelWriter(filepath) df.to_excel(writer, sheet_name=sheetname, index=False) writer.save() lock.release() return 1 except PermissionError: # If file cannot be accessed for writing because already opened logger.warning('Cannot write to "%s". Close the file and type "Y"', filepath) user_str = input() if user_str in ['y', 'Y']: return xlslog(filepath, logentry, sheetname) else: return 0 diff --git a/PySONIC/constants.py b/PySONIC/constants.py index 5b2c338..5955944 100644 --- a/PySONIC/constants.py +++ b/PySONIC/constants.py @@ -1,56 +1,51 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-11-04 13:23:31 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-22 20:35:56 +# @Last Modified time: 2018-09-28 14:08:01 ''' Algorithmic constants used in the package. ''' - # Fitting and pre-processing LJFIT_PM_MAX = 1e8 # intermolecular pressure at the deflection lower bound for LJ fitting (Pa) PNET_EQ_MAX = 1e-1 # error threshold for net pressure at computed equilibrium position (Pa) PMAVG_STD_ERR_MAX = 2000 # error threshold in nonlinear fit of molecular pressure (Pa) - # Mechanical simulations Z_ERR_MAX = 1e-11 # periodic convergence threshold for deflection (m) NG_ERR_MAX = 1e-24 # periodic convergence threshold for gas content (mol) NCYCLES_MAX = 10 # max number of acoustic cycles in mechanical simulations CHARGE_RANGE = (-120e-5, 70e-5) # physiological charge range constraining the membrane (C/m2) # E-STIM simulations DT_ESTIM = 1e-4 - # A-STIM simulations SOLVER_NSTEPS = 1000 # maximum number of steps allowed during one call to the LSODA/DOP853 solvers CLASSIC_TARGET_DT = 1e-8 # target temporal resolution for output arrays of classic simulations NPC_FULL = 1000 # nb of samples per acoustic period in full system NPC_HH = 40 # nb of samples per acoustic period in HH system DQ_UPDATE = 1e-5 # charge evolution threshold between two hybrid integrations (C/m2) DT_UPDATE = 5e-4 # time interval between two hybrid integrations (s) DT_EFF = 5e-5 # time step for effective integration (s) MIN_SAMPLES_PER_PULSE_INT = 1 # minimal number of time points per pulse interval (TON of TOFF) - # Spike detection SPIKE_MIN_QAMP = 5e-5 # threshold amplitude for spike detection on charge signal (C/m2) SPIKE_MIN_QPROM = 5e-5 # threshold prominence for spike detection on charge signal (C/m2) SPIKE_MIN_VAMP = 10.0 # threshold amplitude for spike detection on potential signal (mV) SPIKE_MIN_VPROM = 10.0 # threshold prominence for spike detection on potential signal (mV) SPIKE_MIN_DT = 1e-3 # minimal time interval for spike detection on charge signal (s) - # Titrations TITRATION_T_OFFSET = 50e-3 # offset period for titration procedures (s) TITRATION_ASTIM_A_MAX = 6e5 - 1 # initial acoustic pressure upper bound for titration (Pa) TITRATION_ASTIM_DA_MAX = 1e3 # acoustic pressure search range threshold for titration (Pa) TITRATION_ESTIM_A_MAX = 50.0 # initial current density upper bound for titration (mA/m2) TITRATION_ESTIM_DA_MAX = 0.1 # current density search range threshold for titration (mA/m2) TITRATION_T_MAX = 2e-1 # initial stimulus duration upper bound for titration (s) TITRATION_DT_THR = 1e-3 # stimulus duration search range threshold for titration (s) TITRATION_DDC_THR = 0.01 # stimulus duty cycle search range threshold for titration (-) TITRATION_DC_MAX = 1.0 # initial stimulus duty cycle upper bound for titration (-) diff --git a/PySONIC/core/bls.py b/PySONIC/core/bls.py index 4e25f3b..b094d55 100644 --- a/PySONIC/core/bls.py +++ b/PySONIC/core/bls.py @@ -1,869 +1,869 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-09-29 16:16:19 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-26 13:51:38 +# @Last Modified time: 2018-09-28 14:14:11 from enum import Enum import time import os import json import inspect import pickle import numpy as np import pandas as pd import scipy.integrate as integrate from scipy.optimize import brentq, curve_fit from ..utils import logger, rmse, si_format, MECH_filecode from ..constants import * from ..batches import xlslog class PmCompMethod(Enum): - """ Enum: types of computation method for the intermolecular pressure """ + ''' Enum: types of computation method for the intermolecular pressure ''' direct = 1 predict = 2 def LennardJones(x, beta, alpha, C, m, n): - """ Generic expression of a Lennard-Jones function, adapted for the context of + ''' Generic expression of a Lennard-Jones function, adapted for the context of symmetric deflection (distance = 2x). :param x: deflection (i.e. half-distance) :param beta: x-shifting factor :param alpha: x-scaling factor :param C: y-scaling factor :param m: exponent of the repulsion term :param n: exponent of the attraction term :return: Lennard-Jones potential at given distance (2x) - """ + ''' return C * (np.power((alpha / (2 * x + beta)), m) - np.power((alpha / (2 * x + beta)), n)) class BilayerSonophore: - """ This class contains the geometric and mechanical parameters of the + ''' This class contains the geometric and mechanical parameters of the Bilayer Sonophore Model, as well as all the core functions needed to compute the dynamics (kinetics and kinematics) of the bilayer membrane cavitation, and run dynamic BLS simulations. - """ + ''' # BIOMECHANICAL PARAMETERS T = 309.15 # Temperature (K) Rg = 8.314 # Universal gas constant (Pa.m^3.mol^-1.K^-1) delta0 = 2.0e-9 # Thickness of the leaflet (m) Delta_ = 1.4e-9 # Initial gap between the two leaflets on a non-charged membrane at equil. (m) pDelta = 1.0e5 # Attraction/repulsion pressure coefficient (Pa) m = 5.0 # Exponent in the repulsion term (dimensionless) n = 3.3 # Exponent in the attraction term (dimensionless) rhoL = 1075.0 # Density of the surrounding fluid (kg/m^3) muL = 7.0e-4 # Dynamic viscosity of the surrounding fluid (Pa.s) muS = 0.035 # Dynamic viscosity of the leaflet (Pa.s) kA = 0.24 # Area compression modulus of the leaflet (N/m) alpha = 7.56 # Tissue shear loss modulus frequency coefficient (Pa.s) C0 = 0.62 # Initial gas molar concentration in the surrounding fluid (mol/m^3) kH = 1.613e5 # Henry's constant (Pa.m^3/mol) P0 = 1.0e5 # Static pressure in the surrounding fluid (Pa) Dgl = 3.68e-9 # Diffusion coefficient of gas in the fluid (m^2/s) xi = 0.5e-9 # Boundary layer thickness for gas transport across leaflet (m) c = 1515.0 # Speed of sound in medium (m/s) # BIOPHYSICAL PARAMETERS epsilon0 = 8.854e-12 # Vacuum permittivity (F/m) epsilonR = 1.0 # Relative permittivity of intramembrane cavity (dimensionless) def __init__(self, a, Cm0, Qm0, Fdrive=None, embedding_depth=0.0): - """ Constructor of the class. + ''' Constructor of the class. :param a: in-plane diameter of the sonophore structure within the membrane (m) :param Cm0: membrane resting capacitance (F/m2) :param Qm0: membrane resting charge density (C/m2) :param Fdrive: frequency of acoustic perturbation (Hz) :param embedding_depth: depth of the embedding tissue around the membrane (m) - """ + ''' # Extract resting constants and geometry self.Cm0 = Cm0 self.Qm0 = Qm0 self.a = a self.d = embedding_depth self.S0 = np.pi * self.a**2 # Derive frequency-dependent tissue elastic modulus if Fdrive is not None: G_tissue = self.alpha * Fdrive # G'' (Pa) self.kA_tissue = 2 * G_tissue * self.d # kA of the tissue layer (N/m) else: self.kA_tissue = 0. # Check existence of lookups for derived parameters lookups = self.getLookups() akey = '{:.1f}'.format(a * 1e9) Qkey = '{:.2f}'.format(Qm0 * 1e5) # If no lookup, compute parameters and store them in lookup if akey not in lookups or Qkey not in lookups[akey]: # Find Delta that cancels out Pm + Pec at Z = 0 (m) if self.Qm0 == 0.0: D_eq = self.Delta_ else: (D_eq, Pnet_eq) = self.findDeltaEq(self.Qm0) assert Pnet_eq < PNET_EQ_MAX, 'High Pnet at Z = 0 with ∆ = %.2f nm' % (D_eq * 1e9) self.Delta = D_eq # Find optimal Lennard-Jones parameters to approximate PMavg (LJ_approx, std_err, _) = self.LJfitPMavg() assert std_err < PMAVG_STD_ERR_MAX, 'High error in PmAvg nonlinear fit:'\ ' std_err = %.2f Pa' % std_err self.LJ_approx = LJ_approx if akey not in lookups: lookups[akey] = {Qkey: {'LJ_approx': LJ_approx, 'Delta_eq': D_eq}} else: lookups[akey][Qkey] = {'LJ_approx': LJ_approx, 'Delta_eq': D_eq} logger.debug('Saving BLS derived parameters to lookup file') self.saveLookups(lookups) # If lookup exists, load parameters from it else: logger.debug('Loading BLS derived parameters from lookup file') self.LJ_approx = lookups[akey][Qkey]['LJ_approx'] self.Delta = lookups[akey][Qkey]['Delta_eq'] # Compute initial volume and gas content self.V0 = np.pi * self.Delta * self.a**2 self.ng0 = self.gasPa2mol(self.P0, self.V0) def __repr__(self): return 'BilayerSonophore({}m, {}F/cm2, {}C/cm2, embedding_depth={}m'.format( si_format([self.a, self.Cm0 * 1e-4, self.Qm0 * 1e-4, self.embedding_depth], precision=1, space=' ')) def pprint(self): return '{}m diameter BilayerSonophore'.format( si_format(self.a, precision=0, space=' ')) def getLookupsPath(self): return os.path.join(os.path.split(__file__)[0], 'bls_lookups.json') def getLookups(self): try: with open(self.getLookupsPath()) as fh: sample = json.load(fh) return sample except FileNotFoundError: return {} def saveLookups(self, lookups): with open(self.getLookupsPath(), 'w') as fh: json.dump(lookups, fh, indent=2) def pparams(self): s = '-------- Bilayer Sonophore --------\n' s += 'class attributes:\n' class_attrs = inspect.getmembers(self.__class__, lambda a: not(inspect.isroutine(a))) class_attrs = [a for a in class_attrs if not(a[0].startswith('__') and a[0].endswith('__'))] for ca in class_attrs: s += '{} = {}\n'.format(ca[0], ca[1]) s += 'instance attributes:\n' inst_attrs = inspect.getmembers(self, lambda a: not(inspect.isroutine(a))) inst_attrs = [a for a in inst_attrs if not(a[0].startswith('__') and a[0].endswith('__')) and a not in class_attrs] for ia in inst_attrs: s += '{} = {}\n'.format(ia[0], ia[1]) return s def reinit(self): logger.debug('Re-initializing BLS object') # Find Delta that cancels out Pm + Pec at Z = 0 (m) if self.Qm0 == 0.0: D_eq = self.Delta_ else: (D_eq, Pnet_eq) = self.findDeltaEq(self.Qm0) assert Pnet_eq < PNET_EQ_MAX, 'High Pnet at Z = 0 with ∆ = %.2f nm' % (D_eq * 1e9) self.Delta = D_eq # Compute initial volume and gas content self.V0 = np.pi * self.Delta * self.a**2 self.ng0 = self.gasPa2mol(self.P0, self.V0) def curvrad(self, Z): - """ Return the (signed) instantaneous curvature radius of the leaflet. + ''' Return the (signed) instantaneous curvature radius of the leaflet. :param Z: leaflet apex outward deflection value (m) :return: leaflet curvature radius (m) - """ + ''' if Z == 0.0: return np.inf else: return (self.a**2 + Z**2) / (2 * Z) def v_curvrad(self, Z): ''' Vectorized curvrad function ''' return np.array(list(map(self.curvrad, Z))) def surface(self, Z): - """ Return the surface area of the stretched leaflet (spherical cap). + ''' Return the surface area of the stretched leaflet (spherical cap). :param Z: leaflet apex outward deflection value (m) :return: surface of the stretched leaflet (m^2) - """ + ''' return np.pi * (self.a**2 + Z**2) def volume(self, Z): - """ Return the total volume of the inter-leaflet space (cylinder +/- + ''' Return the total volume of the inter-leaflet space (cylinder +/- spherical cap). :param Z: leaflet apex outward deflection value (m) :return: inner volume of the bilayer sonophore structure (m^3) - """ + ''' return np.pi * self.a**2 * self.Delta\ * (1 + (Z / (3 * self.Delta) * (3 + Z**2 / self.a**2))) def arealstrain(self, Z): - """ Compute the areal strain of the stretched leaflet. + ''' Compute the areal strain of the stretched leaflet. epsilon = (S - S0)/S0 = (Z/a)^2 :param Z: leaflet apex outward deflection value (m) :return: areal strain (dimensionless) - """ + ''' return (Z / self.a)**2 def Capct(self, Z): - """ Compute the membrane capacitance per unit area, + ''' Compute the membrane capacitance per unit area, under the assumption of parallel-plate capacitor with average inter-layer distance. :param Z: leaflet apex outward deflection value (m) :return: capacitance per unit area (F/m2) - """ + ''' if Z == 0.0: return self.Cm0 else: return ((self.Cm0 * self.Delta / self.a**2) * (Z + (self.a**2 - Z**2 - Z * self.Delta) / (2 * Z) * np.log((2 * Z + self.Delta) / self.Delta))) def v_Capct(self, Z): ''' Vectorized Capct function ''' return np.array(list(map(self.Capct, Z))) def derCapct(self, Z, U): - """ Compute the derivative of the membrane capacitance per unit area + ''' Compute the derivative of the membrane capacitance per unit area with respect to time, under the assumption of parallel-plate capacitor. :param Z: leaflet apex outward deflection value (m) :param U: leaflet apex outward deflection velocity (m/s) :return: derivative of capacitance per unit area (F/m2.s) - """ + ''' dCmdZ = ((self.Cm0 * self.Delta / self.a**2) * ((Z**2 + self.a**2) / (Z * (2 * Z + self.Delta)) - ((Z**2 + self.a**2) * np.log((2 * Z + self.Delta) / self.Delta)) / (2 * Z**2))) return dCmdZ * U def localdef(self, r, Z, R): - """ Compute the (signed) local transverse leaflet deviation at a distance + ''' Compute the (signed) local transverse leaflet deviation at a distance r from the center of the dome. :param r: in-plane distance from center of the sonophore (m) :param Z: leaflet apex outward deflection value (m) :param R: leaflet curvature radius (m) :return: local transverse leaflet deviation (m) - """ + ''' if np.abs(Z) == 0.0: return 0.0 else: return np.sign(Z) * (np.sqrt(R**2 - r**2) - np.abs(R) + np.abs(Z)) def Pacoustic(self, t, Adrive, Fdrive, phi=np.pi): - """ Compute the acoustic pressure at a specific time, given + ''' Compute the acoustic pressure at a specific time, given the amplitude, frequency and phase of the acoustic stimulus. :param t: time of interest :param Adrive: acoustic drive amplitude (Pa) :param Fdrive: acoustic drive frequency (Hz) :param phi: acoustic drive phase (rad) - """ + ''' return Adrive * np.sin(2 * np.pi * Fdrive * t - phi) def PMlocal(self, r, Z, R): - """ Compute the local intermolecular pressure. + ''' Compute the local intermolecular pressure. :param r: in-plane distance from center of the sonophore (m) :param Z: leaflet apex outward deflection value (m) :param R: leaflet curvature radius (m) :return: local intermolecular pressure (Pa) - """ + ''' z = self.localdef(r, Z, R) relgap = (2 * z + self.Delta) / self.Delta_ return self.pDelta * ((1 / relgap)**self.m - (1 / relgap)**self.n) def PMavg(self, Z, R, S): - """ Compute the average intermolecular pressure felt across the leaflet + ''' Compute the average intermolecular pressure felt across the leaflet by quadratic integration. :param Z: leaflet apex outward deflection value (m) :param R: leaflet curvature radius (m) :param S: surface of the stretched leaflet (m^2) :return: averaged intermolecular resultant pressure across the leaflet (Pa) .. warning:: quadratic integration is computationally expensive. - """ + ''' # Integrate intermolecular force over an infinitely thin ring of radius r from 0 to a fTotal, _ = integrate.quad(lambda r, Z, R: 2 * np.pi * r * self.PMlocal(r, Z, R), 0, self.a, args=(Z, R)) return fTotal / S def v_PMavg(self, Z, R, S): ''' Vectorized PMavg function ''' return np.array(list(map(self.PMavg, Z, R, S))) def LJfitPMavg(self): - """ Determine optimal parameters of a Lennard-Jones expression + ''' Determine optimal parameters of a Lennard-Jones expression approximating the average intermolecular pressure. These parameters are obtained by a nonlinear fit of the Lennard-Jones function for a range of deflection values between predetermined Zmin and Zmax. :return: 3-tuple with optimized LJ parameters for PmAvg prediction (Map) and the standard and max errors of the prediction in the fitting range (in Pascals) - """ + ''' # Determine lower bound of deflection range: when Pm = Pmmax PMmax = LJFIT_PM_MAX # Pa Zminlb = -0.49 * self.Delta Zminub = 0.0 Zmin = brentq(lambda Z, Pmmax: self.PMavg(Z, self.curvrad(Z), self.surface(Z)) - PMmax, Zminlb, Zminub, args=(PMmax), xtol=1e-16) # Create vectors for geometric variables Zmax = 2 * self.a Z = np.arange(Zmin, Zmax, 1e-11) Pmavg = self.v_PMavg(Z, self.v_curvrad(Z), self.surface(Z)) # Compute optimal nonlinear fit of custom LJ function with initial guess x0_guess = self.delta0 C_guess = 0.1 * self.pDelta nrep_guess = self.m nattr_guess = self.n pguess = (x0_guess, C_guess, nrep_guess, nattr_guess) popt, _ = curve_fit(lambda x, x0, C, nrep, nattr: LennardJones(x, self.Delta, x0, C, nrep, nattr), Z, Pmavg, p0=pguess, maxfev=10000) (x0_opt, C_opt, nrep_opt, nattr_opt) = popt Pmavg_fit = LennardJones(Z, self.Delta, x0_opt, C_opt, nrep_opt, nattr_opt) # Compute prediction error residuals = Pmavg - Pmavg_fit ss_res = np.sum(residuals**2) N = residuals.size std_err = np.sqrt(ss_res / N) max_err = max(np.abs(residuals)) logger.debug('LJ approx: x0 = %.2f nm, C = %.2f kPa, m = %.2f, n = %.2f', x0_opt * 1e9, C_opt * 1e-3, nrep_opt, nattr_opt) LJ_approx = {"x0": x0_opt, "C": C_opt, "nrep": nrep_opt, "nattr": nattr_opt} return (LJ_approx, std_err, max_err) def PMavgpred(self, Z): - """ Return the predicted intermolecular pressure based on a specific Lennard-Jones + ''' Return the predicted intermolecular pressure based on a specific Lennard-Jones function fitted on the deflection physiological range. :param Z: leaflet apex outward deflection value (m) :return: predicted average intermolecular pressure (Pa) - """ + ''' return LennardJones(Z, self.Delta, self.LJ_approx['x0'], self.LJ_approx['C'], self.LJ_approx['nrep'], self.LJ_approx['nattr']) def Pelec(self, Z, Qm): - """ Compute the electric equivalent pressure term. + ''' Compute the electric equivalent pressure term. :param Z: leaflet apex outward deflection value (m) :param Qm: membrane charge density (C/m2) :return: electric equivalent pressure (Pa) - """ + ''' relS = self.S0 / self.surface(Z) abs_perm = self.epsilon0 * self.epsilonR # F/m return -relS * Qm**2 / (2 * abs_perm) # Pa def findDeltaEq(self, Qm): - """ Compute the Delta that cancels out the (Pm + Pec) equation at Z = 0 + ''' Compute the Delta that cancels out the (Pm + Pec) equation at Z = 0 for a given membrane charge density, using the Brent method to refine the pressure root iteratively. :param Qm: membrane charge density (C/m2) :return: equilibrium value (m) and associated pressure (Pa) - """ + ''' f = lambda Delta: (self.pDelta * ( (self.Delta_ / Delta)**self.m - (self.Delta_ / Delta)**self.n) + self.Pelec(0.0, Qm) ) Delta_lb = 0.1 * self.Delta_ Delta_ub = 2.0 * self.Delta_ Delta_eq = brentq(f, Delta_lb, Delta_ub, xtol=1e-16) logger.debug('∆eq = %.2f nm', Delta_eq * 1e9) return (Delta_eq, f(Delta_eq)) def gasflux(self, Z, P): - """ Compute the gas molar flux through the BLS boundary layer for + ''' Compute the gas molar flux through the BLS boundary layer for an unsteady system. :param Z: leaflet apex outward deflection value (m) :param P: internal gas pressure in the inter-leaflet space (Pa) :return: gas molar flux (mol/s) - """ + ''' dC = self.C0 - P / self.kH return 2 * self.surface(Z) * self.Dgl * dC / self.xi def gasmol2Pa(self, ng, V): - """ Compute the gas pressure in the inter-leaflet space for an + ''' Compute the gas pressure in the inter-leaflet space for an unsteady system, from the value of gas molar content. :param ng: internal molar content (mol) :param V: inner volume of the bilayer sonophore structure (m^3) :return: internal gas pressure (Pa) - """ + ''' return ng * self.Rg * self.T / V def gasPa2mol(self, P, V): - """ Compute the gas molar content in the inter-leaflet space for + ''' Compute the gas molar content in the inter-leaflet space for an unsteady system, from the value of internal gas pressure. :param P: internal gas pressure in the inter-leaflet space (Pa) :param V: inner volume of the bilayer sonophore structure (m^3) :return: internal gas molar content (mol) - """ + ''' return P * V / (self.Rg * self.T) def PtotQS(self, Z, ng, Qm, Pac, Pm_comp_method): - """ Compute the balance pressure of the quasi-steady system, upon application + ''' Compute the balance pressure of the quasi-steady system, upon application of an external perturbation on a charged membrane: Ptot = Pm + Pg + Pec - P0 - Pac. :param Z: leaflet apex outward deflection value (m) :param ng: internal molar content (mol) :param Qm: membrane charge density (C/m2) :param Pac: external acoustic perturbation (Pa) :param Pm_comp_method: type of method used to compute average intermolecular pressure :return: total balance pressure (Pa) - """ + ''' if Pm_comp_method is PmCompMethod.direct: Pm = self.PMavg(Z, self.curvrad(Z), self.surface(Z)) elif Pm_comp_method is PmCompMethod.predict: Pm = self.PMavgpred(Z) return Pm + self.gasmol2Pa(ng, self.volume(Z)) - self.P0 - Pac + self.Pelec(Z, Qm) def balancedefQS(self, ng, Qm, Pac=0.0, Pm_comp_method=PmCompMethod.predict): - """ Compute the leaflet deflection upon application of an external + ''' Compute the leaflet deflection upon application of an external perturbation to a quasi-steady system with a charged membrane. This function uses the Brent method (progressive approximation of function root) to solve the following transcendental equation for Z: Pm + Pg + Pec - P0 - Pac = 0. :param ng: internal molar content (mol) :param Qm: membrane charge density (C/m2) :param Pac: external acoustic perturbation (Pa) :param Pm_comp_method: type of method used to compute average intermolecular pressure :return: leaflet deflection (Z) canceling out the balance equation - """ + ''' lb = -0.49 * self.Delta ub = self.a Plb = self.PtotQS(lb, ng, Qm, Pac, Pm_comp_method) Pub = self.PtotQS(ub, ng, Qm, Pac, Pm_comp_method) assert (Plb > 0 > Pub), '[%d, %d] is not a sign changing interval for PtotQS' % (lb, ub) return brentq(self.PtotQS, lb, ub, args=(ng, Qm, Pac, Pm_comp_method), xtol=1e-16) def TEleaflet(self, Z): - """ Compute the circumferential elastic tension felt across the + ''' Compute the circumferential elastic tension felt across the entire leaflet upon stretching. :param Z: leaflet apex outward deflection value (m) :return: circumferential elastic tension (N/m) - """ + ''' return self.kA * self.arealstrain(Z) def TEtissue(self, Z): - """ Compute the circumferential elastic tension felt across the + ''' Compute the circumferential elastic tension felt across the embedding viscoelastic tissue layer upon stretching. :param Z: leaflet apex outward deflection value (m) :return: circumferential elastic tension (N/m) - """ + ''' return self.kA_tissue * self.arealstrain(Z) def TEtot(self, Z): - """ Compute the total circumferential elastic tension (leaflet + ''' Compute the total circumferential elastic tension (leaflet and embedding tissue) felt upon stretching. :param Z: leaflet apex outward deflection value (m) :return: circumferential elastic tension (N/m) - """ + ''' return self.TEleaflet(Z) + self.TEtissue(Z) def PEtot(self, Z, R): - """ Compute the total elastic tension pressure (leaflet + embedding + ''' Compute the total elastic tension pressure (leaflet + embedding tissue) felt upon stretching. :param Z: leaflet apex outward deflection value (m) :param R: leaflet curvature radius (m) :return: elastic tension pressure (Pa) - """ + ''' return - self.TEtot(Z) / R def PVleaflet(self, U, R): - """ Compute the viscous stress felt across the entire leaflet + ''' Compute the viscous stress felt across the entire leaflet upon stretching. :param U: leaflet apex outward deflection velocity (m/s) :param R: leaflet curvature radius (m) :return: leaflet viscous stress (Pa) - """ + ''' return - 12 * U * self.delta0 * self.muS / R**2 def PVfluid(self, U, R): - """ Compute the viscous stress felt across the entire fluid + ''' Compute the viscous stress felt across the entire fluid upon stretching. :param U: leaflet apex outward deflection velocity (m/s) :param R: leaflet curvature radius (m) :return: fluid viscous stress (Pa) - """ + ''' return - 4 * U * self.muL / np.abs(R) def accP(self, Pres, R): - """ Compute the pressure-driven acceleration of the leaflet in the + ''' Compute the pressure-driven acceleration of the leaflet in the unsteady system, upon application of an external perturbation. :param Pres: net resultant pressure (Pa) :param R: leaflet curvature radius (m) :return: pressure-driven acceleration (m/s^2) - """ + ''' return Pres / (self.rhoL * np.abs(R)) def accNL(self, U, R): - """ Compute the non-linear term of the leaflet acceleration in the + ''' Compute the non-linear term of the leaflet acceleration in the unsteady system, upon application of an external perturbation. :param U: leaflet apex outward deflection velocity (m/s) :param R: leaflet curvature radius (m) :return: nonlinear acceleration (m/s^2) .. note:: A simplified version of nonlinear acceleration (neglecting dR/dH) is used here. - """ + ''' # return - (3/2 - 2*R/H) * U**2 / R return -(3 * U**2) / (2 * R) def derivatives(self, y, t, Adrive, Fdrive, Qm, phi, Pm_comp_method=PmCompMethod.predict): - """ Compute the derivatives of the 3-ODE mechanical system variables, + ''' Compute the derivatives of the 3-ODE mechanical system variables, with an imposed constant charge density. :param y: vector of HH system variables at time t :param t: specific instant in time (s) :param Adrive: acoustic drive amplitude (Pa) :param Fdrive: acoustic drive frequency (Hz) :param Qm: membrane charge density (F/m2) :param phi: acoustic drive phase (rad) :param Pm_comp_method: type of method used to compute average intermolecular pressure :return: vector of mechanical system derivatives at time t - """ + ''' # Split input vector explicitly (U, Z, ng) = y # Correct deflection value is below critical compression if Z < -0.5 * self.Delta: logger.warning('Deflection out of range: Z = %.2f nm', Z * 1e9) Z = -0.49 * self.Delta # Compute curvature radius R = self.curvrad(Z) # Compute total pressure Pg = self.gasmol2Pa(ng, self.volume(Z)) if Pm_comp_method is PmCompMethod.direct: Pm = self.PMavg(Z, self.curvrad(Z), self.surface(Z)) elif Pm_comp_method is PmCompMethod.predict: Pm = self.PMavgpred(Z) Ptot = (Pm + Pg - self.P0 - self.Pacoustic(t, Adrive, Fdrive, phi) + self.PEtot(Z, R) + self.PVleaflet(U, R) + self.PVfluid(U, R) + self.Pelec(Z, Qm)) # Compute derivatives dUdt = self.accP(Ptot, R) + self.accNL(U, R) dZdt = U dngdt = self.gasflux(Z, Pg) # Return derivatives vector return [dUdt, dZdt, dngdt] def checkInputs(self, Fdrive, Adrive, Qm, phi): ''' Check validity of stimulation parameters. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param phi: acoustic drive phase (rad) :param Qm: imposed membrane charge density (C/m2) ''' if not all(isinstance(param, float) for param in [Fdrive, Adrive, Qm, phi]): raise TypeError('Invalid stimulation parameters (must be float typed)') if Fdrive <= 0: raise ValueError('Invalid US driving frequency: {} kHz (must be strictly positive)' .format(Fdrive * 1e-3)) if Adrive < 0: raise ValueError('Invalid US pressure amplitude: {} kPa (must be positive or null)' .format(Adrive * 1e-3)) if Qm < CHARGE_RANGE[0] or Qm > CHARGE_RANGE[1]: raise ValueError('Invalid applied charge: {} nC/cm2 (must be within [{}, {}] interval' .format(Qm * 1e5, CHARGE_RANGE[0] * 1e5, CHARGE_RANGE[1] * 1e5)) if phi < 0 or phi >= 2 * np.pi: raise ValueError('Invalid US pressure phase: {:.2f} rad (must be within [0, 2 PI[ rad' .format(phi)) def simulate(self, Fdrive, Adrive, Qm, phi=np.pi, Pm_comp_method=PmCompMethod.predict): - """ Compute short solutions of the mechanical system for specific + ''' Compute short solutions of the mechanical system for specific US stimulation parameters and with an imposed membrane charge density. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param phi: acoustic drive phase (rad) :param Qm: imposed membrane charge density (C/m2) :param Pm_comp_method: type of method used to compute average intermolecular pressure :return: 3-tuple with the time profile, the solution matrix and a state vector - """ + ''' # Check validity of stimulation parameters self.checkInputs(Fdrive, Adrive, Qm, phi) # Determine mechanical system time step Tdrive = 1 / Fdrive dt_mech = Tdrive / NPC_FULL t_mech_cycle = np.linspace(0, Tdrive - dt_mech, NPC_FULL) # Initialize system variables t0 = 0.0 Z0 = 0.0 U0 = 0.0 ng0 = self.ng0 # Solve quasi-steady equation to compute first deflection value Pac1 = self.Pacoustic(t0 + dt_mech, Adrive, Fdrive, phi) Z1 = self.balancedefQS(ng0, Qm, Pac1, Pm_comp_method) U1 = (Z1 - Z0) / dt_mech # Construct arrays to hold system variables states = np.array([1, 1]) t = np.array([t0, t0 + dt_mech]) y = np.array([[U0, U1], [Z0, Z1], [ng0, ng0]]) # Integrate mechanical system for a few acoustic cycles until stabilization j = 0 ng_last = None Z_last = None periodic_conv = False while not periodic_conv and j < NCYCLES_MAX: t_mech = t_mech_cycle + t[-1] + dt_mech y_mech = integrate.odeint(self.derivatives, y[:, -1], t_mech, args=(Adrive, Fdrive, Qm, phi, Pm_comp_method)).T # Compare Z and ng signals over the last 2 acoustic periods if j > 0: Z_rmse = rmse(Z_last, y_mech[1, :]) ng_rmse = rmse(ng_last, y_mech[2, :]) logger.debug('step %u: Z_rmse = %.2e m, ng_rmse = %.2e mol', j, Z_rmse, ng_rmse) if Z_rmse < Z_ERR_MAX and ng_rmse < NG_ERR_MAX: periodic_conv = True # Update last vectors for next comparison Z_last = y_mech[1, :] ng_last = y_mech[2, :] # Concatenate time and solutions to global vectors states = np.concatenate([states, np.ones(NPC_FULL)], axis=0) t = np.concatenate([t, t_mech], axis=0) y = np.concatenate([y, y_mech], axis=1) # Increment loop index j += 1 if j == NCYCLES_MAX: logger.warning('No convergence: stopping after %u cycles', j) else: logger.debug('Periodic convergence after %u cycles', j) states[-1] = 0 # return output variables return (t, y[1:, :], states) def runAndSave(self, outdir, Fdrive, Adrive, Qm): ''' Run a simulation of the mechanical system with specific stimulation parameters and an imposed value of charge density, and save the results in a PKL file. :param outdir: full path to output directory :param Fdrive: US frequency (Hz) :param Adrive: acoustic pressure amplitude (Pa) :param Qm: applided membrane charge density (C/m2) ''' # Get date and time info date_str = time.strftime("%Y.%m.%d") daytime_str = time.strftime("%H:%M:%S") logger.info('%s: simulation @ f = %sHz, A = %sPa, Q = %sC/cm2', self.pprint(), *si_format([Fdrive, Adrive, Qm * 1e-4], 2, space=' ')) # Run simulation tstart = time.time() (t, y, states) = self.simulate(Fdrive, Adrive, Qm) (Z, ng) = y tcomp = time.time() - tstart logger.debug('completed in %s', si_format(tcomp, 1)) U = np.insert(np.diff(Z) / np.diff(t), 0, 0.0) # Store dataframe and metadata df = pd.DataFrame({ 't': t, 'states': states, 'U': U, 'Z': Z, 'ng': ng }) meta = { 'a': self.a, 'd': self.d, 'Cm0': self.Cm0, 'Qm0': self.Qm0, 'Fdrive': Fdrive, 'Adrive': Adrive, 'phi': np.pi, 'Qm': Qm, 'tcomp': tcomp } # Export into to PKL file simcode = MECH_filecode(self.a, Fdrive, Adrive, Qm) outpath = '{}/{}.pkl'.format(outdir, simcode) with open(outpath, 'wb') as fh: pickle.dump({'meta': meta, 'data': df}, fh) logger.debug('simulation data exported to "%s"', outpath) # Compute key output metrics Zmax = np.amax(Z) Zmin = np.amin(Z) Zabs_max = np.amax(np.abs([Zmin, Zmax])) eAmax = self.arealstrain(Zabs_max) Tmax = self.TEtot(Zabs_max) Pmmax = self.PMavgpred(Zmin) ngmax = np.amax(ng) dUdtmax = np.amax(np.abs(np.diff(U) / np.diff(t)**2)) # Export key metrics to log file logpath = os.path.join(outdir, 'log_MECH.xlsx') logentry = { 'Date': date_str, 'Time': daytime_str, 'Radius (nm)': self.a * 1e9, 'Thickness (um)': self.d * 1e6, 'Fdrive (kHz)': Fdrive * 1e-3, 'Adrive (kPa)': Adrive * 1e-3, 'Charge (nC/cm2)': Qm * 1e5, '# samples': t.size, 'Comp. time (s)': round(tcomp, 2), 'kA total (N/m)': self.kA + self.kA_tissue, 'Max Z (nm)': Zmax * 1e9, 'Max eA (-)': eAmax, 'Max Te (mN/m)': Tmax * 1e3, 'Max rel. ng increase (-)': (ngmax - self.ng0) / self.ng0, 'Max Pm (kPa)': Pmmax * 1e-3, 'Max acc. (m/s2)': dUdtmax } if xlslog(logpath, logentry) == 1: logger.debug('log exported to "%s"', logpath) else: logger.error('log export to "%s" aborted', logpath) return outpath def getCycleProfiles(self, Fdrive, Adrive, Qm): ''' Run a mechanical simulation until periodic stabilization, and compute pressure profiles over the last acoustic cycle. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param Qm: imposed membrane charge density (C/m2) :return: dataframe with the time, kinematic and pressure profiles over the last cycle. ''' # Run default simulation and compute relevant profiles logger.info('Running mechanical simulation (a = %sm, f = %sHz, A = %sPa)', si_format(self.a, 1), si_format(Fdrive, 1), si_format(Adrive, 1)) t, y, _ = self.simulate(Fdrive, Adrive, Qm, Pm_comp_method=PmCompMethod.direct) dt = (t[-1] - t[0]) / (t.size - 1) Z, ng = y[:, -NPC_FULL:] t = t[-NPC_FULL:] t -= t[0] logger.info('Computing pressure cyclic profiles') R = self.v_curvrad(Z) U = np.diff(Z) / dt U = np.hstack((U, U[-1])) data = { 't': t, 'Z': Z, 'Cm': self.v_Capct(Z), 'P_M': self.v_PMavg(Z, R, self.surface(Z)), 'P_Q': self.Pelec(Z, Qm), 'P_{VE}': self.PEtot(Z, R) + self.PVleaflet(U, R), 'P_V': self.PVfluid(U, R), 'P_G': self.gasmol2Pa(ng, self.volume(Z)), 'P_0': - np.ones(Z.size) * self.P0 } return pd.DataFrame(data, columns=data.keys()) diff --git a/PySONIC/core/nbls.py b/PySONIC/core/nbls.py index dd7c6c5..847ba41 100644 --- a/PySONIC/core/nbls.py +++ b/PySONIC/core/nbls.py @@ -1,867 +1,867 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-09-29 16:16:19 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-27 01:48:07 +# @Last Modified time: 2018-09-28 14:14:12 import os import time import logging import pickle import progressbar as pb import numpy as np import pandas as pd from scipy.integrate import ode, odeint from scipy.interpolate import interp1d from .bls import BilayerSonophore from .pneuron import PointNeuron from ..utils import logger, si_format, downsample, rmse, ASTIM_filecode, getLookups2D from ..constants import * from ..postpro import findPeaks from ..batches import xlslog class NeuronalBilayerSonophore(BilayerSonophore): - """ This class inherits from the BilayerSonophore class and receives an PointNeuron instance - at initialization, to define the electro-mechanical NICE model and its SONIC variant. """ + ''' This class inherits from the BilayerSonophore class and receives an PointNeuron instance + at initialization, to define the electro-mechanical NICE model and its SONIC variant. ''' def __init__(self, diameter, neuron, Fdrive=None, embedding_depth=0.0): - """ Constructor of the class. + ''' Constructor of the class. :param diameter: in-plane diameter of the sonophore structure within the membrane (m) :param neuron: neuron object :param Fdrive: frequency of acoustic perturbation (Hz) :param embedding_depth: depth of the embedding tissue around the membrane (m) - """ + ''' # Check validity of input parameters if not isinstance(neuron, PointNeuron): raise ValueError('Invalid neuron type: "{}" (must inherit from PointNeuron class)' .format(neuron.name)) self.neuron = neuron # Initialize BilayerSonophore parent object BilayerSonophore.__init__(self, diameter, neuron.Cm0, neuron.Cm0 * neuron.Vm0 * 1e-3, embedding_depth) def __repr__(self): return 'NeuronalBilayerSonophore({}m, {})'.format( si_format(self.a, precision=1, space=' '), self.neuron) def pprint(self): return '{}m diameter NBLS - {} neuron'.format( si_format(self.a, precision=0, space=' '), self.neuron.name) def fullDerivatives(self, y, t, Adrive, Fdrive, phi): - """ Compute the derivatives of the (n+3) ODE full NBLS system variables. + ''' Compute the derivatives of the (n+3) ODE full NBLS system variables. :param y: vector of state variables :param t: specific instant in time (s) :param Adrive: acoustic drive amplitude (Pa) :param Fdrive: acoustic drive frequency (Hz) :param phi: acoustic drive phase (rad) :return: vector of derivatives - """ + ''' dydt_mech = BilayerSonophore.derivatives(self, y[:3], t, Adrive, Fdrive, y[3], phi) dydt_elec = self.neuron.Qderivatives(y[3:], t, self.Capct(y[1])) return dydt_mech + dydt_elec def effDerivatives(self, t, y, interp_data): - """ Compute the derivatives of the n-ODE effective HH system variables, + ''' Compute the derivatives of the n-ODE effective HH system variables, based on 1-dimensional linear interpolation of "effective" coefficients that summarize the system's behaviour over an acoustic cycle. :param t: specific instant in time (s) :param y: vector of HH system variables at time t :param interp_data: dictionary of 1D data points of "effective" coefficients over the charge domain, for specific frequency and amplitude values. :return: vector of effective system derivatives at time t - """ + ''' # Split input vector explicitly Qm, *states = y # Compute charge and channel states variation Vm = np.interp(Qm, interp_data['Q'], interp_data['V']) # mV dQmdt = - self.neuron.currNet(Vm, states) * 1e-3 dstates = self.neuron.derStatesEff(Qm, states, interp_data) # Return derivatives vector return [dQmdt, *dstates] def runFull(self, Fdrive, Adrive, tstim, toffset, PRF, DC, phi=np.pi): - """ Compute solutions of the full electro-mechanical system for a specific set of + ''' Compute solutions of the full electro-mechanical system for a specific set of US stimulation parameters, using a classic integration scheme. The first iteration uses the quasi-steady simplification to compute the initiation of motion from a flat leaflet configuration. Afterwards, the ODE system is solved iteratively until completion. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param phi: acoustic drive phase (rad) :return: 3-tuple with the time profile, the effective solution matrix and a state vector - """ + ''' # Determine system time step Tdrive = 1 / Fdrive dt = Tdrive / NPC_FULL # if CW stimulus: divide integration during stimulus into 100 intervals if DC == 1.0: PRF = 100 / tstim # Compute vector sizes npulses = int(np.round(PRF * tstim)) Tpulse_on = DC / PRF Tpulse_off = (1 - DC) / PRF n_pulse_on = int(np.round(Tpulse_on / dt)) n_pulse_off = int(np.round(Tpulse_off / dt)) n_off = int(np.round(toffset / dt)) # Solve quasi-steady equation to compute first deflection value Z0 = 0.0 ng0 = self.ng0 Qm0 = self.Qm0 Pac1 = self.Pacoustic(dt, Adrive, Fdrive, phi) Z1 = self.balancedefQS(ng0, Qm0, Pac1) # Initialize global arrays states = np.array([1, 1]) t = np.array([0., dt]) y_membrane = np.array([[0., (Z1 - Z0) / dt], [Z0, Z1], [ng0, ng0], [Qm0, Qm0]]) y_channels = np.tile(self.neuron.states0, (2, 1)).T y = np.vstack((y_membrane, y_channels)) nvar = y.shape[0] # Initialize pulse time and states vectors t_pulse0 = np.linspace(0, Tpulse_on + Tpulse_off, n_pulse_on + n_pulse_off) states_pulse = np.concatenate((np.ones(n_pulse_on), np.zeros(n_pulse_off))) # Initialize progress bar if logger.getEffectiveLevel() <= logging.INFO: widgets = ['Running: ', pb.Percentage(), ' ', pb.Bar(), ' ', pb.ETA()] pbar = pb.ProgressBar(widgets=widgets, max_value=int(npulses * (toffset + tstim) / tstim)) pbar.start() # Loop through all pulse (ON and OFF) intervals for i in range(npulses): # Construct and initialize arrays t_pulse = t_pulse0 + t[-1] y_pulse = np.empty((nvar, n_pulse_on + n_pulse_off)) # Integrate ON system y_pulse[:, :n_pulse_on] = odeint( self.fullDerivatives, y[:, -1], t_pulse[:n_pulse_on], args=(Adrive, Fdrive, phi)).T # Integrate OFF system if n_pulse_off > 0: y_pulse[:, n_pulse_on:] = odeint( self.fullDerivatives, y_pulse[:, n_pulse_on - 1], t_pulse[n_pulse_on:], args=(0.0, 0.0, 0.0)).T # Append pulse arrays to global arrays states = np.concatenate([states, states_pulse[1:]]) t = np.concatenate([t, t_pulse[1:]]) y = np.concatenate([y, y_pulse[:, 1:]], axis=1) # Update progress bar if logger.getEffectiveLevel() <= logging.INFO: pbar.update(i) # Integrate offset interval if n_off > 0: t_off = np.linspace(0, toffset, n_off) + t[-1] states_off = np.zeros(n_off) y_off = odeint(self.fullDerivatives, y[:, -1], t_off, args=(0.0, 0.0, 0.0)).T # Concatenate offset arrays to global arrays states = np.concatenate([states, states_off[1:]]) t = np.concatenate([t, t_off[1:]]) y = np.concatenate([y, y_off[:, 1:]], axis=1) # Terminate progress bar if logger.getEffectiveLevel() <= logging.INFO: pbar.finish() # Downsample arrays in time-domain accordgin to target temporal resolution ds_factor = int(np.round(CLASSIC_TARGET_DT / dt)) if ds_factor > 1: Fs = 1 / (dt * ds_factor) logger.info('Downsampling output arrays by factor %u (Fs = %.2f MHz)', ds_factor, Fs * 1e-6) t = t[::ds_factor] y = y[:, ::ds_factor] states = states[::ds_factor] # Compute membrane potential vector (in mV) Vm = y[3, :] / self.v_Capct(y[1, :]) * 1e3 # mV # Return output variables with Vm # return (t, y[1:, :], states) return (t, np.vstack([y[1:4, :], Vm, y[4:, :]]), states) def runSONIC(self, Fdrive, Adrive, tstim, toffset, PRF, DC, dt=DT_EFF): - """ Compute solutions of the system for a specific set of + ''' Compute solutions of the system for a specific set of US stimulation parameters, using charge-predicted "effective" coefficients to solve the HH equations at each step. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param dt: integration time step (s) :return: 3-tuple with the time profile, the effective solution matrix and a state vector - """ + ''' # Load appropriate 2D lookups Aref, Qref, lookups2D = getLookups2D(self.neuron.name, self.a, Fdrive) # Check that acoustic amplitude is within lookup range margin = 1e-9 # adding margin to compensate for eventual round error Arange = (Aref.min() - margin, Aref.max() + margin) if Adrive < Arange[0] or Adrive > Arange[1]: raise ValueError('Invalid amplitude: {}Pa (must be within {}Pa - {} Pa lookup interval)' .format(*si_format([Adrive, *Arange], precision=2, space=' '))) # Interpolate 2D lookups at US amplitude (along with "ng" at zero amplitude) lookups1D = {key: interp1d(Aref, y2D, axis=0)(Adrive) for key, y2D in lookups2D.items()} lookups1D['ng0'] = interp1d(Aref, lookups2D['ng'], axis=0)(0.0) # Add reference charge vector to 1D lookup dictionary lookups1D['Q'] = Qref # Initialize system solvers solver_on = ode(self.effDerivatives) solver_on.set_integrator('lsoda', nsteps=SOLVER_NSTEPS) solver_on.set_f_params(lookups1D) solver_off = ode(lambda t, y, Cm: self.neuron.Qderivatives(y, t, Cm)) solver_off.set_integrator('lsoda', nsteps=SOLVER_NSTEPS) # if CW stimulus: change PRF to have exactly one integration interval during stimulus if DC == 1.0: PRF = 1 / tstim # Compute vector sizes npulses = int(np.round(PRF * tstim)) Tpulse_on = DC / PRF Tpulse_off = (1 - DC) / PRF # For high-PRF pulsed protocols: adapt time step to ensure minimal # number of samples during TON or TOFF dt_warning_msg = 'high-PRF protocol: lowering time step to %.2e s to properly integrate %s' for key, Tpulse in {'TON': Tpulse_on, 'TOFF': Tpulse_off}.items(): if Tpulse > 0 and Tpulse / dt < MIN_SAMPLES_PER_PULSE_INT: dt = Tpulse / MIN_SAMPLES_PER_PULSE_INT logger.warning(dt_warning_msg, dt, key) n_pulse_on = int(np.round(Tpulse_on / dt)) + 1 n_pulse_off = int(np.round(Tpulse_off / dt)) # Compute ofset size n_off = int(np.round(toffset / dt)) # Initialize global arrays states = np.array([1]) t = np.array([0.0]) y = np.atleast_2d(np.insert(self.neuron.states0, 0, self.Qm0)).T nvar = y.shape[0] Zeff = np.array([0.0]) ngeff = np.array([self.ng0]) # Initializing accurate pulse time vector t_pulse_on = np.linspace(0, Tpulse_on, n_pulse_on) t_pulse_off = np.linspace(dt, Tpulse_off, n_pulse_off) + Tpulse_on t_pulse0 = np.concatenate([t_pulse_on, t_pulse_off]) states_pulse = np.concatenate((np.ones(n_pulse_on), np.zeros(n_pulse_off))) # Loop through all pulse (ON and OFF) intervals for i in range(npulses): # Construct and initialize arrays t_pulse = t_pulse0 + t[-1] y_pulse = np.empty((nvar, n_pulse_on + n_pulse_off)) ngeff_pulse = np.empty(n_pulse_on + n_pulse_off) Zeff_pulse = np.empty(n_pulse_on + n_pulse_off) y_pulse[:, 0] = y[:, -1] ngeff_pulse[0] = ngeff[-1] Zeff_pulse[0] = Zeff[-1] # Initialize iterator k = 0 # Integrate ON system solver_on.set_initial_value(y_pulse[:, k], t_pulse[k]) while solver_on.successful() and k < n_pulse_on - 1: k += 1 solver_on.integrate(t_pulse[k]) y_pulse[:, k] = solver_on.y ngeff_pulse[k] = np.interp(y_pulse[0, k], lookups1D['Q'], lookups1D['ng']) # mole Zeff_pulse[k] = self.balancedefQS(ngeff_pulse[k], y_pulse[0, k]) # m # Integrate OFF system if n_pulse_off > 0: solver_off.set_initial_value(y_pulse[:, k], t_pulse[k]) solver_off.set_f_params(self.Capct(Zeff_pulse[k])) while solver_off.successful() and k < n_pulse_on + n_pulse_off - 1: k += 1 solver_off.integrate(t_pulse[k]) y_pulse[:, k] = solver_off.y ngeff_pulse[k] = np.interp(y_pulse[0, k], lookups1D['Q'], lookups1D['ng0']) # mole Zeff_pulse[k] = self.balancedefQS(ngeff_pulse[k], y_pulse[0, k]) # m solver_off.set_f_params(self.Capct(Zeff_pulse[k])) # Append pulse arrays to global arrays states = np.concatenate([states[:-1], states_pulse]) t = np.concatenate([t, t_pulse[1:]]) y = np.concatenate([y, y_pulse[:, 1:]], axis=1) Zeff = np.concatenate([Zeff, Zeff_pulse[1:]]) ngeff = np.concatenate([ngeff, ngeff_pulse[1:]]) # Integrate offset interval if n_off > 0: t_off = np.linspace(0, toffset, n_off) + t[-1] states_off = np.zeros(n_off) y_off = np.empty((nvar, n_off)) ngeff_off = np.empty(n_off) Zeff_off = np.empty(n_off) y_off[:, 0] = y[:, -1] ngeff_off[0] = ngeff[-1] Zeff_off[0] = Zeff[-1] solver_off.set_initial_value(y_off[:, 0], t_off[0]) solver_off.set_f_params(self.Capct(Zeff_pulse[-1])) k = 0 while solver_off.successful() and k < n_off - 1: k += 1 solver_off.integrate(t_off[k]) y_off[:, k] = solver_off.y ngeff_off[k] = np.interp(y_off[0, k], lookups1D['Q'], lookups1D['ng0']) # mole Zeff_off[k] = self.balancedefQS(ngeff_off[k], y_off[0, k]) # m solver_off.set_f_params(self.Capct(Zeff_off[k])) # Concatenate offset arrays to global arrays states = np.concatenate([states, states_off[1:]]) t = np.concatenate([t, t_off[1:]]) y = np.concatenate([y, y_off[:, 1:]], axis=1) Zeff = np.concatenate([Zeff, Zeff_off[1:]]) ngeff = np.concatenate([ngeff, ngeff_off[1:]]) # Compute membrane potential vector (in mV) Vm = np.zeros(states.size) Vm[states == 0] = y[0, states == 0] / self.v_Capct(Zeff[states == 0]) * 1e3 # mV Vm[states == 1] = np.interp(y[0, states == 1], lookups1D['Q'], lookups1D['V']) # mV # Add Zeff, ngeff and Vm to solution matrix y = np.vstack([Zeff, ngeff, y[0, :], Vm, y[1:, :]]) # return output variables return (t, y, states) def runHybrid(self, Fdrive, Adrive, tstim, toffset, phi=np.pi): - """ Compute solutions of the system for a specific set of + ''' Compute solutions of the system for a specific set of US stimulation parameters, using a hybrid integration scheme. The first iteration uses the quasi-steady simplification to compute the initiation of motion from a flat leaflet configuration. Afterwards, the NBLS ODE system is solved iteratively for "slices" of N-microseconds, in a 2-steps scheme: - First, the full (n+3) ODE system is integrated for a few acoustic cycles until Z and ng reach a stable periodic solution (limit cycle) - Second, the signals of the 3 mechanical variables over the last acoustic period are selected and resampled to a far lower sampling rate - Third, the HH n-ODE system is integrated for the remaining time of the slice, using periodic expansion of the mechanical signals to precompute the values of capacitance. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param phi: acoustic drive phase (rad) :return: 3-tuple with the time profile, the solution matrix and a state vector .. warning:: This method cannot handle pulsed stimuli - """ + ''' # Initialize full and HH systems solvers solver_full = ode( lambda t, y, Adrive, Fdrive, phi: self.fullDerivatives(y, t, Adrive, Fdrive, phi)) solver_full.set_f_params(Adrive, Fdrive, phi) solver_full.set_integrator('lsoda', nsteps=SOLVER_NSTEPS) solver_hh = ode(lambda t, y, Cm: self.neuron.Qderivatives(y, t, Cm)) solver_hh.set_integrator('dop853', nsteps=SOLVER_NSTEPS, atol=1e-12) # Determine full and HH systems time steps Tdrive = 1 / Fdrive dt_full = Tdrive / NPC_FULL dt_hh = Tdrive / NPC_HH n_full_per_hh = int(NPC_FULL / NPC_HH) t_full_cycle = np.linspace(0, Tdrive - dt_full, NPC_FULL) t_hh_cycle = np.linspace(0, Tdrive - dt_hh, NPC_HH) # Determine number of samples in prediction vectors npc_pred = NPC_FULL - n_full_per_hh + 1 # Solve quasi-steady equation to compute first deflection value Z0 = 0.0 ng0 = self.ng0 Qm0 = self.Qm0 Pac1 = self.Pacoustic(dt_full, Adrive, Fdrive, phi) Z1 = self.balancedefQS(ng0, Qm0, Pac1) # Initialize global arrays states = np.array([1, 1]) t = np.array([0., dt_full]) y_membrane = np.array([[0., (Z1 - Z0) / dt_full], [Z0, Z1], [ng0, ng0], [Qm0, Qm0]]) y_channels = np.tile(self.neuron.states0, (2, 1)).T y = np.vstack((y_membrane, y_channels)) nvar = y.shape[0] # Initialize progress bar if logger.getEffectiveLevel() == logging.DEBUG: widgets = ['Running: ', pb.Percentage(), ' ', pb.Bar(), ' ', pb.ETA()] pbar = pb.ProgressBar(widgets=widgets, max_value=1000) pbar.start() # For each hybrid integration interval irep = 0 sim_error = False while not sim_error and t[-1] < tstim + toffset: # Integrate full system for a few acoustic cycles until stabilization periodic_conv = False j = 0 ng_last = None Z_last = None while not sim_error and not periodic_conv: if t[-1] > tstim: solver_full.set_f_params(0.0, 0.0, 0.0) t_full = t_full_cycle + t[-1] + dt_full y_full = np.empty((nvar, NPC_FULL)) y0_full = y[:, -1] solver_full.set_initial_value(y0_full, t[-1]) k = 0 while solver_full.successful() and k <= NPC_FULL - 1: solver_full.integrate(t_full[k]) y_full[:, k] = solver_full.y k += 1 # Compare Z and ng signals over the last 2 acoustic periods if j > 0 and rmse(Z_last, y_full[1, :]) < Z_ERR_MAX \ and rmse(ng_last, y_full[2, :]) < NG_ERR_MAX: periodic_conv = True # Update last vectors for next comparison Z_last = y_full[1, :] ng_last = y_full[2, :] # Concatenate time and solutions to global vectors states = np.concatenate([states, np.ones(NPC_FULL)], axis=0) t = np.concatenate([t, t_full], axis=0) y = np.concatenate([y, y_full], axis=1) # Increment loop index j += 1 # Retrieve last period of the 3 mechanical variables to propagate in HH system t_last = t[-npc_pred:] mech_last = y[0:3, -npc_pred:] # Downsample signals to specified HH system time step (_, mech_pred) = downsample(t_last, mech_last, NPC_HH) # Integrate HH system until certain dQ or dT is reached Q0 = y[3, -1] dQ = 0.0 t0_interval = t[-1] dt_interval = 0.0 j = 0 if t[-1] < tstim: tlim = tstim else: tlim = tstim + toffset while (not sim_error and t[-1] < tlim and (np.abs(dQ) < DQ_UPDATE or dt_interval < DT_UPDATE)): t_hh = t_hh_cycle + t[-1] + dt_hh y_hh = np.empty((nvar - 3, NPC_HH)) y0_hh = y[3:, -1] solver_hh.set_initial_value(y0_hh, t[-1]) k = 0 while solver_hh.successful() and k <= NPC_HH - 1: solver_hh.set_f_params(self.Capct(mech_pred[1, k])) solver_hh.integrate(t_hh[k]) y_hh[:, k] = solver_hh.y k += 1 # Concatenate time and solutions to global vectors states = np.concatenate([states, np.zeros(NPC_HH)], axis=0) t = np.concatenate([t, t_hh], axis=0) y = np.concatenate([y, np.concatenate([mech_pred, y_hh], axis=0)], axis=1) # Compute charge variation from interval beginning dQ = y[3, -1] - Q0 dt_interval = t[-1] - t0_interval # Increment loop index j += 1 # Update progress bar if logger.getEffectiveLevel() == logging.DEBUG: pbar.update(int(1000 * (t[-1] / (tstim + toffset)))) irep += 1 # Terminate progress bar if logger.getEffectiveLevel() == logging.DEBUG: pbar.finish() # Compute membrane potential vector (in mV) Vm = y[3, :] / self.v_Capct(y[1, :]) * 1e3 # mV # Return output variables with Vm # return (t, y[1:, :], states) return (t, np.vstack([y[1:4, :], Vm, y[4:, :]]), states) def checkInputsFull(self, Fdrive, Adrive, tstim, toffset, PRF, DC, method): - """ Check validity of simulation parameters. + ''' Check validity of simulation parameters. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param method: selected integration method :return: 3-tuple with the time profile, the solution matrix and a state vector - """ + ''' BilayerSonophore.checkInputs(self, Fdrive, Adrive, 0.0, 0.0) self.neuron.checkInputs(Adrive, tstim, toffset, PRF, DC) # Check validity of simulation type if method not in ('full', 'hybrid', 'sonic'): raise ValueError('Invalid integration method: "{}"'.format(method)) def simulate(self, Fdrive, Adrive, tstim, toffset, PRF=None, DC=1.0, method='sonic'): - """ Run simulation of the system for a specific set of + ''' Run simulation of the system for a specific set of US stimulation parameters. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param method: selected integration method :return: 3-tuple with the time profile, the solution matrix and a state vector - """ + ''' # Check validity of stimulation parameters self.checkInputsFull(Fdrive, Adrive, tstim, toffset, PRF, DC, method) # Call appropriate simulation function if method == 'full': return self.runFull(Fdrive, Adrive, tstim, toffset, PRF, DC) elif method == 'sonic': return self.runSONIC(Fdrive, Adrive, tstim, toffset, PRF, DC) elif method == 'hybrid': if DC < 1.0: raise ValueError('Pulsed protocol incompatible with hybrid integration method') return self.runHybrid(Fdrive, Adrive, tstim, toffset) def titrate(self, Fdrive, tstim, toffset, PRF=None, DC=1.0, Arange=(0., 2 * TITRATION_ASTIM_A_MAX), method='sonic'): ''' Use a dichotomic recursive search to determine the threshold amplitude needed to obtain neural excitation for a given frequency, duration, PRF and duty cycle. :param Fdrive: US frequency (Hz) :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param Arange: search interval for Adrive, iteratively refined :return: 5-tuple with the determined threshold, time profile, solution matrix, state vector and response latency ''' Adrive = (Arange[0] + Arange[1]) / 2 # Run simulation and detect spikes t0 = time.time() (t, y, states) = self.simulate(Fdrive, Adrive, tstim, toffset, PRF, DC, method=method) tcomp = time.time() - t0 dt = t[1] - t[0] ipeaks, *_ = findPeaks(y[2, :], SPIKE_MIN_QAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_QPROM) nspikes = ipeaks.size latency = t[ipeaks[0]] if nspikes > 0 else None logger.debug('A = %sPa ---> %s spike%s detected', si_format(Adrive, 2, space=' '), nspikes, "s" if nspikes > 1 else "") # If accurate threshold is found, return simulation results if (Arange[1] - Arange[0]) <= TITRATION_ASTIM_DA_MAX and nspikes == 1: return (Adrive, t, y, states, latency, tcomp) # Otherwise, refine titration interval and iterate recursively else: if nspikes == 0: # if Adrive too close to max then stop if (TITRATION_ASTIM_A_MAX - Adrive) <= TITRATION_ASTIM_DA_MAX: return (np.nan, t, y, states, latency, tcomp) Arange = (Adrive, Arange[1]) else: Arange = (Arange[0], Adrive) return self.titrate(Fdrive, tstim, toffset, PRF, DC, Arange=Arange, method=method) def runAndSave(self, outdir, Fdrive, tstim, toffset, PRF=None, DC=1.0, Adrive=None, method='sonic'): ''' Run a simulation of the full electro-mechanical system for a given neuron type with specific parameters, and save the results in a PKL file. :param outdir: full path to output directory :param Fdrive: US frequency (Hz) :param tstim: stimulus duration (s) :param toffset: stimulus offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: stimulus duty cycle (-) :param Adrive: acoustic pressure amplitude (Pa) :param method: integration method ''' # Get date and time info date_str = time.strftime("%Y.%m.%d") daytime_str = time.strftime("%H:%M:%S") if Adrive is not None: logger.info('%s: simulation @ f = %sHz, A = %sPa, t = %ss%s', self, si_format(Fdrive, 0, space=' '), si_format(Adrive, 2, space=' '), si_format(tstim, 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format(si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) # Run simulation tstart = time.time() t, y, states = self.simulate(Fdrive, Adrive, tstim, toffset, PRF, DC, method=method) tcomp = time.time() - tstart Z, ng, Qm, Vm, *channels = y # Detect spikes on Qm signal dt = t[1] - t[0] ipeaks, *_ = findPeaks(Qm, SPIKE_MIN_QAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_QPROM) nspikes = ipeaks.size lat = t[ipeaks[0]] if nspikes > 0 else 'N/A' outstr = '{} spike{} detected'.format(nspikes, 's' if nspikes > 1 else '') else: logger.info('%s: titration @ f = %sHz, t = %ss%s', self, si_format(Fdrive, 0, space=' '), si_format(tstim, 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format(si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) # Run titration Adrive, t, y, states, lat, tcomp = self.titrate(Fdrive, tstim, toffset, PRF, DC, method=method) Z, ng, Qm, Vm, *channels = y if Adrive is np.nan: outstr = 'no spikes detected within titration interval' nspikes = 0 else: outstr = 'Athr = {}Pa'.format(si_format(Adrive, 2, space=' ')) nspikes = 1 logger.debug('completed in %s, %s', si_format(tcomp, 1), outstr) sr = np.mean(1 / np.diff(t[ipeaks])) if nspikes > 1 else None # Store dataframe and metadata U = np.insert(np.diff(Z) / np.diff(t), 0, 0.0) df = pd.DataFrame({ 't': t, 'states': states, 'U': U, 'Z': Z, 'ng': ng, 'Qm': Qm, 'Vm': Vm }) for j in range(len(self.neuron.states_names)): df[self.neuron.states_names[j]] = channels[j] meta = { 'neuron': self.neuron.name, 'a': self.a, 'd': self.d, 'Fdrive': Fdrive, 'Adrive': Adrive, 'phi': np.pi, 'tstim': tstim, 'toffset': toffset, 'PRF': PRF, 'DC': DC, 'tcomp': tcomp, 'method': method } # Export into to PKL file simcode = ASTIM_filecode(self.neuron.name, self.a, Fdrive, Adrive, tstim, PRF, DC, method) outpath = '{}/{}.pkl'.format(outdir, simcode) with open(outpath, 'wb') as fh: pickle.dump({'meta': meta, 'data': df}, fh) logger.debug('simulation data exported to "%s"', outpath) # Export key metrics to log file logpath = os.path.join(outdir, 'log_ASTIM.xlsx') logentry = { 'Date': date_str, 'Time': daytime_str, 'Neuron Type': self.neuron.name, 'Radius (nm)': self.a * 1e9, 'Thickness (um)': self.d * 1e6, 'Fdrive (kHz)': Fdrive * 1e-3, 'Adrive (kPa)': Adrive * 1e-3, 'Tstim (ms)': tstim * 1e3, 'PRF (kHz)': PRF * 1e-3 if DC < 1 else 'N/A', 'Duty factor': DC, 'Sim. Type': method, '# samples': t.size, 'Comp. time (s)': round(tcomp, 2), '# spikes': nspikes, 'Latency (ms)': lat * 1e3 if isinstance(lat, float) else 'N/A', 'Spike rate (sp/ms)': sr * 1e-3 if isinstance(sr, float) else 'N/A' } if xlslog(logpath, logentry) == 1: logger.debug('log exported to "%s"', logpath) else: logger.error('log export to "%s" aborted', self.logpath) return outpath def findRheobaseAmps(self, DCs, Fdrive, Vthr, curr='net'): ''' Find the rheobase amplitudes (i.e. threshold acoustic amplitudes of infinite duration that would result in excitation) of a specific neuron for various stimulation duty cycles. :param DCs: duty cycles vector (-) :param Fdrive: acoustic drive frequency (Hz) :param Vthr: threshold membrane potential above which the neuron necessarily fires (mV) :return: rheobase amplitudes vector (Pa) ''' # Get lookups projected at specific (a, Fdrive, Qthr) combination. Aref, Qref, lookups2D = getLookups2D(self.neuron.name, self.a, Fdrive) Qthr = self.neuron.Cm0 * Vthr * 1e-3 # C/m2 lookups1D = {key: interp1d(Qref, y2D, axis=1)(Qthr) for key, y2D in lookups2D.items()} # Remove unnecessary items ot get ON rates and effective potential at threshold charge rates_on = lookups1D rates_on.pop('ng') Vm_on = rates_on.pop('V') # Compute neuron OFF rates at threshold potential rates_off = self.neuron.getRates(Vthr) # Compute rheobase amplitudes rheboase_amps = np.empty(DCs.size) for i, DC in enumerate(DCs): sstates_pulse = np.empty((len(self.neuron.states_names), Aref.size)) for j, x in enumerate(self.neuron.states_names): # If channel state, compute pulse-average steady-state values if x in self.neuron.getGates(): x = x.lower() alpha_str, beta_str = ['{}{}'.format(s, x) for s in ['alpha', 'beta']] alphax_pulse = rates_on[alpha_str] * DC + rates_off[alpha_str] * (1 - DC) betax_pulse = rates_on[beta_str] * DC + rates_off[beta_str] * (1 - DC) sstates_pulse[j, :] = alphax_pulse / (alphax_pulse + betax_pulse) # Otherwise assume the state has reached a steady-state value for Vthr else: sstates_pulse[j, :] = np.ones(Aref.size) * self.neuron.steadyStates(Vthr)[j] # Compute the pulse average net (or leakage) current along the amplitude space if curr == 'net': iNet_on = self.neuron.currNet(Vm_on, sstates_pulse) iNet_off = self.neuron.currNet(Vthr, sstates_pulse) elif curr == 'leak': iNet_on = self.neuron.currL(Vm_on) iNet_off = self.neuron.currL(Vthr) iNet_avg = iNet_on * DC + iNet_off * (1 - DC) # Find the threshold amplitude that cancels the pulse average net current rheboase_amps[i] = np.interp(0, -iNet_avg, Aref, left=0., right=np.nan) inan = np.where(np.isnan(rheboase_amps))[0] if inan.size > 0: if inan.size == rheboase_amps.size: logger.error('No rheobase amplitudes within [%s - %sPa] for the provided duty cycles', *si_format((Aref.min(), Aref.max()))) else: minDC = DCs[inan.max() + 1] logger.warning('No rheobase amplitudes within [%s - %sPa] below %.1f%% duty cycle', *si_format((Aref.min(), Aref.max())), minDC * 1e2) return rheboase_amps def computeEffVars(self, Fdrive, Adrive, Qm, phi=np.pi): ''' Compute "effective" coefficients of the HH system for a specific combination of stimulus frequency, stimulus amplitude and charge density. A short mechanical simulation is run while imposing the specific charge density, until periodic stabilization. The HH coefficients are then averaged over the last acoustic cycle to yield "effective" coefficients. :param Fdrive: acoustic drive frequency (Hz) :param Adrive: acoustic drive amplitude (Pa) :param Qm: imposed charge density (C/m2) :param phi: acoustic drive phase (rad) ''' logger.info( '%s: lookups @ %sHz, %sPa, %.2f nC/cm2', self, *si_format([Fdrive, Adrive], precision=1, space=' '), Qm * 1e5) # Run simulation and retrieve deflection and gas content vectors from last cycle _, [Z, ng], _ = BilayerSonophore.simulate(self, Fdrive, Adrive, Qm, phi) Z_last = Z[-NPC_FULL:] # m # Compute membrane potential vector Vm = Qm / self.v_Capct(Z_last) * 1e3 # mV # Compute average cycle value for membrane potential and rate constants Vm_eff = np.mean(Vm) # mV rates_eff = self.neuron.getEffRates(Vm) # Take final cycle value for gas content ng_eff = ng[-1] # mole # Return effective coefficients return [Vm_eff, ng_eff, *rates_eff] diff --git a/PySONIC/core/pneuron.py b/PySONIC/core/pneuron.py index 323f8d3..dd516ec 100644 --- a/PySONIC/core/pneuron.py +++ b/PySONIC/core/pneuron.py @@ -1,481 +1,474 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-08-03 11:53:04 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-26 13:52:21 - -''' Module standard API for all neuron mechanisms. - - Each mechanism class can use different methods to define the membrane dynamics of a - specific neuron type. However, they must contain some mandatory attributes and methods - in order to be properly imported in other sonic modules and used in NICE simulations. -''' +# @Last Modified time: 2018-09-28 14:14:12 import os import time import pickle import abc import numpy as np from scipy.integrate import odeint import pandas as pd from ..postpro import findPeaks from ..constants import * from ..utils import si_format, logger, ESTIM_filecode from ..batches import xlslog class PointNeuron(metaclass=abc.ABCMeta): ''' Abstract class defining the common API (i.e. mandatory attributes and methods) of all subclasses implementing the channels mechanisms of specific point neurons. The mandatory attributes are: - **name**: a string defining the name of the mechanism. - **Cm0**: a float defining the membrane resting capacitance (in F/m2) - **Vm0**: a float defining the membrane resting potential (in mV) - **states_names**: a list of strings defining the names of the different state probabilities governing the channels behaviour (i.e. the differential HH variables). - **states0**: a 1D array of floats (NOT integers !!!) defining the initial values of the different state probabilities. - **coeff_names**: a list of strings defining the names of the different coefficients to be used in effective simulations. The mandatory methods are: - **currNet**: compute the net ionic current density (in mA/m2) across the membrane, given a specific membrane potential (in mV) and channel states. - **steadyStates**: compute the channels steady-state values for a specific membrane potential value (in mV). - **derStates**: compute the derivatives of channel states, given a specific membrane potential (in mV) and channel states. This method must return a list of derivatives ordered identically as in the states0 attribute. - **getEffRates**: get the effective rate constants of ion channels to be used in effective simulations. This method must return an array of effective rates ordered identically as in the coeff_names attribute. - **derStatesEff**: compute the effective derivatives of channel states, based on 1-dimensional linear interpolators of "effective" coefficients. This method must return a list of derivatives ordered identically as in the states0 attribute. ''' def __repr__(self): return self.__class__.__name__ def pprint(self): return '{} neuron'.format(self.__class__.__name__) @property @abc.abstractmethod def name(self): return 'Should never reach here' @property @abc.abstractmethod def Cm0(self): return 'Should never reach here' @property @abc.abstractmethod def Vm0(self): return 'Should never reach here' @abc.abstractmethod def currNet(self, Vm, states): ''' Compute the net ionic current per unit area. :param Vm: membrane potential (mV) :states: state probabilities of the ion channels :return: current per unit area (mA/m2) ''' @abc.abstractmethod def steadyStates(self, Vm): ''' Compute the channels steady-state values for a specific membrane potential value. :param Vm: membrane potential (mV) :return: array of steady-states ''' @abc.abstractmethod def derStates(self, Vm, states): ''' Compute the derivatives of channel states. :param Vm: membrane potential (mV) :states: state probabilities of the ion channels :return: current per unit area (mA/m2) ''' @abc.abstractmethod def getEffRates(self, Vm): ''' Get the effective rate constants of ion channels, averaged along an acoustic cycle, for future use in effective simulations. :param Vm: array of membrane potential values for an acoustic cycle (mV) :return: an array of rate average constants (s-1) ''' @abc.abstractmethod def derStatesEff(self, Qm, states, interp_data): ''' Compute the effective derivatives of channel states, based on 1-dimensional linear interpolation of "effective" coefficients that summarize the system's behaviour over an acoustic cycle. :param Qm: membrane charge density (C/m2) :states: state probabilities of the ion channels :param interp_data: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. ''' def getGates(self): ''' Retrieve the names of the neuron's states that match an ion channel gating. ''' gates = [] for x in self.states_names: if 'alpha{}'.format(x.lower()) in self.coeff_names: gates.append(x) return gates def getRates(self, Vm): ''' Compute the ion channels rate constants for a given membrane potential. :param Vm: membrane potential (mV) :return: a dictionary of rate constants and their values at the given potential. ''' rates = {} for x in self.getGates(): x = x.lower() alpha_str, beta_str = ['{}{}'.format(s, x.lower()) for s in ['alpha', 'beta']] inf_str, tau_str = ['{}inf'.format(x.lower()), 'tau{}'.format(x.lower())] if hasattr(self, 'alpha{}'.format(x)): alphax = getattr(self, alpha_str)(Vm) betax = getattr(self, beta_str)(Vm) elif hasattr(self, '{}inf'.format(x)): xinf = getattr(self, inf_str)(Vm) taux = getattr(self, tau_str)(Vm) alphax = xinf / taux betax = 1 / taux - alphax rates[alpha_str] = alphax rates[beta_str] = betax return rates def vtrap(self, x, y): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x / y) - 1) def Vderivatives(self, y, t, Iinj): ''' Compute the derivatives of a V-cast HH system for a specific value of injected current. :param y: vector of HH system variables at time t :param t: time value (s, unused) :param Iinj: injected current (mA/m2) :return: vector of HH system derivatives at time t ''' Vm, *states = y Iionic = self.currNet(Vm, states) # mA/m2 dVmdt = (- Iionic + Iinj) / self.Cm0 # mV/s dstates = self.derStates(Vm, states) return [dVmdt, *dstates] def Qderivatives(self, y, t, Cm=None): - """ Compute the derivatives of the n-ODE HH system variables, + ''' Compute the derivatives of the n-ODE HH system variables, based on a value of membrane capacitance. :param y: vector of HH system variables at time t :param t: specific instant in time (s) :param Cm: membrane capacitance (F/m2) :return: vector of HH system derivatives at time t - """ + ''' if Cm is None: Cm = self.Cm0 Qm, *states = y Vm = Qm / Cm * 1e3 # mV dQm = - self.currNet(Vm, states) * 1e-3 # A/m2 dstates = self.derStates(Vm, states) return [dQm, *dstates] def checkInputs(self, Astim, tstim, toffset, PRF, DC): ''' Check validity of electrical stimulation parameters. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) ''' # Check validity of stimulation parameters if not all(isinstance(param, float) for param in [Astim, tstim, toffset, DC]): raise TypeError('Invalid stimulation parameters (must be float typed)') if tstim <= 0: raise ValueError('Invalid stimulus duration: {} ms (must be strictly positive)' .format(tstim * 1e3)) if toffset < 0: raise ValueError('Invalid stimulus offset: {} ms (must be positive or null)' .format(toffset * 1e3)) if DC <= 0.0 or DC > 1.0: raise ValueError('Invalid duty cycle: {} (must be within ]0; 1])'.format(DC)) if DC < 1.0: if not isinstance(PRF, float): raise TypeError('Invalid PRF value (must be float typed)') if PRF is None: raise AttributeError('Missing PRF value (must be provided when DC < 1)') if PRF < 1 / tstim: raise ValueError('Invalid PRF: {} Hz (PR interval exceeds stimulus duration)' .format(PRF)) def simulate(self, Astim, tstim, toffset, PRF=None, DC=1.0): ''' Compute solutions of a neuron's HH system for a specific set of electrical stimulation parameters, using a classic integration scheme. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :return: 3-tuple with the time profile and solution matrix and a state vector ''' # Check validity of stimulation parameters self.checkInputs(Astim, tstim, toffset, PRF, DC) # Determine system time step dt = DT_ESTIM # if CW stimulus: divide integration during stimulus into single interval if DC == 1.0: PRF = 1 / tstim # Compute vector sizes npulses = int(np.round(PRF * tstim)) Tpulse_on = DC / PRF Tpulse_off = (1 - DC) / PRF # For high-PRF pulsed protocols: adapt time step to ensure minimal # number of samples during TON or TOFF dt_warning_msg = 'high-PRF protocol: lowering time step to %.2e s to properly integrate %s' for key, Tpulse in {'TON': Tpulse_on, 'TOFF': Tpulse_off}.items(): if Tpulse > 0 and Tpulse / dt < MIN_SAMPLES_PER_PULSE_INT: dt = Tpulse / MIN_SAMPLES_PER_PULSE_INT logger.warning(dt_warning_msg, dt, key) n_pulse_on = int(np.round(Tpulse_on / dt)) n_pulse_off = int(np.round(Tpulse_off / dt)) # Compute offset size n_off = int(np.round(toffset / dt)) # Set initial conditions y0 = [self.Vm0, *self.states0] nvar = len(y0) # Initialize global arrays t = np.array([0.]) states = np.array([1]) y = np.array([y0]).T # Initialize pulse time and states vectors t_pulse0 = np.linspace(0, Tpulse_on + Tpulse_off, n_pulse_on + n_pulse_off) states_pulse = np.concatenate((np.ones(n_pulse_on), np.zeros(n_pulse_off))) # Loop through all pulse (ON and OFF) intervals for i in range(npulses): # Construct and initialize arrays t_pulse = t_pulse0 + t[-1] y_pulse = np.empty((nvar, n_pulse_on + n_pulse_off)) # Integrate ON system y_pulse[:, :n_pulse_on] = odeint( self.Vderivatives, y[:, -1], t_pulse[:n_pulse_on], args=(Astim,)).T # Integrate OFF system if n_pulse_off > 0: y_pulse[:, n_pulse_on:] = odeint( self.Vderivatives, y_pulse[:, n_pulse_on - 1], t_pulse[n_pulse_on:], args=(0.0,)).T # Append pulse arrays to global arrays states = np.concatenate([states, states_pulse[1:]]) t = np.concatenate([t, t_pulse[1:]]) y = np.concatenate([y, y_pulse[:, 1:]], axis=1) # Integrate offset interval if n_off > 0: t_off = np.linspace(0, toffset, n_off) + t[-1] states_off = np.zeros(n_off) y_off = odeint(self.Vderivatives, y[:, -1], t_off, args=(0.0, )).T # Concatenate offset arrays to global arrays states = np.concatenate([states, states_off[1:]]) t = np.concatenate([t, t_off[1:]]) y = np.concatenate([y, y_off[:, 1:]], axis=1) # Return output variables return (t, y, states) def titrate(self, tstim, toffset, PRF=None, DC=1.0, Arange=(0., 2 * TITRATION_ESTIM_A_MAX)): ''' Use a dichotomic recursive search to determine the threshold amplitude needed to obtain neural excitation for a given duration, PRF and duty cycle. :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param Arange: search interval for Astim, iteratively refined :return: 5-tuple with the determined threshold, time profile, solution matrix, state vector and response latency ''' Astim = (Arange[0] + Arange[1]) / 2 # Run simulation and detect spikes t0 = time.time() (t, y, states) = self.simulate(Astim, tstim, toffset, PRF, DC) tcomp = time.time() - t0 dt = t[1] - t[0] ipeaks, *_ = findPeaks(y[0, :], SPIKE_MIN_VAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_VPROM) nspikes = ipeaks.size latency = t[ipeaks[0]] if nspikes > 0 else None logger.debug('A = %sA/m2 ---> %s spike%s detected', si_format(Astim * 1e-3, 2, space=' '), nspikes, "s" if nspikes > 1 else "") # If accurate threshold is found, return simulation results if (Arange[1] - Arange[0]) <= TITRATION_ESTIM_DA_MAX and nspikes == 1: return (Astim, t, y, states, latency, tcomp) # Otherwise, refine titration interval and iterate recursively else: if nspikes == 0: # if Astim too close to max then stop if (TITRATION_ESTIM_A_MAX - Astim) <= TITRATION_ESTIM_DA_MAX: return (np.nan, t, y, states, latency, tcomp) Arange = (Astim, Arange[1]) else: Arange = (Arange[0], Astim) return self.titrate(tstim, toffset, PRF, DC, Arange=Arange) def runAndSave(self, outdir, tstim, toffset, PRF=None, DC=1.0, Astim=None): ''' Run a simulation of the point-neuron Hodgkin-Huxley system with specific parameters, and save the results in a PKL file. :param outdir: full path to output directory :param tstim: stimulus duration (s) :param toffset: stimulus offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: stimulus duty cycle (-) :param Astim: stimulus amplitude (mA/m2) ''' # Get date and time info date_str = time.strftime("%Y.%m.%d") daytime_str = time.strftime("%H:%M:%S") if Astim is not None: logger.info('%s: simulation @ A = %sA/m2, t = %ss%s', self, si_format(Astim * 1e-3, 2, space=' '), si_format(tstim, 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format(si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) # Run simulation tstart = time.time() t, y, states = self.simulate(Astim, tstim, toffset, PRF, DC) Vm, *channels = y tcomp = time.time() - tstart # Detect spikes on Vm signal dt = t[1] - t[0] ipeaks, *_ = findPeaks(Vm, SPIKE_MIN_VAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_VPROM) nspikes = ipeaks.size lat = t[ipeaks[0]] if nspikes > 0 else 'N/A' outstr = '{} spike{} detected'.format(nspikes, 's' if nspikes > 1 else '') else: logger.info('%s: titration @ t = %ss%s', self, si_format(tstim, 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format(si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) # Run titration Astim, t, y, states, lat, tcomp = self.titrate(tstim, toffset, PRF, DC) Vm, *channels = y nspikes = 1 if Astim is np.nan: outstr = 'no spikes detected within titration interval' nspikes = 0 else: nspikes = 1 outstr = 'Athr = {}A/m2'.format(si_format(Astim * 1e-3, 2, space=' ')) logger.debug('completed in %s, %s', si_format(tcomp, 1), outstr) sr = np.mean(1 / np.diff(t[ipeaks])) if nspikes > 1 else None # Store dataframe and metadata df = pd.DataFrame({ 't': t, 'states': states, 'Vm': Vm, 'Qm': Vm * self.Cm0 * 1e-3 }) for j in range(len(self.states_names)): df[self.states_names[j]] = channels[j] meta = { 'neuron': self.name, 'Astim': Astim, 'tstim': tstim, 'toffset': toffset, 'PRF': PRF, 'DC': DC, 'tcomp': tcomp } # Export into to PKL file simcode = ESTIM_filecode(self.name, Astim, tstim, PRF, DC) outpath = '{}/{}.pkl'.format(outdir, simcode) with open(outpath, 'wb') as fh: pickle.dump({'meta': meta, 'data': df}, fh) logger.debug('simulation data exported to "%s"', outpath) # Export key metrics to log file logpath = os.path.join(outdir, 'log_ESTIM.xlsx') logentry = { 'Date': date_str, 'Time': daytime_str, 'Neuron Type': self.name, 'Astim (mA/m2)': Astim, 'Tstim (ms)': tstim * 1e3, 'PRF (kHz)': PRF * 1e-3 if DC < 1 else 'N/A', 'Duty factor': DC, '# samples': t.size, 'Comp. time (s)': round(tcomp, 2), '# spikes': nspikes, 'Latency (ms)': lat * 1e3 if isinstance(lat, float) else 'N/A', 'Spike rate (sp/ms)': sr * 1e-3 if isinstance(sr, float) else 'N/A' } if xlslog(logpath, logentry) == 1: logger.debug('log exported to "%s"', logpath) else: logger.error('log export to "%s" aborted', self.logpath) return outpath def findRheobaseAmps(self, DCs, Vthr, curr='net'): ''' Find the rheobase amplitudes (i.e. threshold amplitudes of infinite duration that would result in excitation) of a specific neuron for various stimulation duty cycles. :param DCs: duty cycles vector (-) :param Vthr: threshold membrane potential above which the neuron necessarily fires (mV) :return: rheobase amplitudes vector (mA/m2) ''' # Compute the pulse average net (or leakage) current along the amplitude space if curr == 'net': iNet = self.currNet(Vthr, self.steadyStates(Vthr)) elif curr == 'leak': iNet = self.currL(Vthr) # Compute rheobase amplitudes return iNet / np.array(DCs) diff --git a/PySONIC/neurons/cortical.py b/PySONIC/neurons/cortical.py index 529c1b1..cbda424 100644 --- a/PySONIC/neurons/cortical.py +++ b/PySONIC/neurons/cortical.py @@ -1,813 +1,811 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-07-31 15:19:51 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-26 17:23:09 - -''' Channels mechanisms for thalamic neurons. ''' +# @Last Modified time: 2018-09-28 14:05:49 import numpy as np from ..core import PointNeuron class Cortical(PointNeuron): ''' Class defining the generic membrane channel dynamics of a cortical neuron with 4 different current types: - Inward Sodium current - Outward, delayed-rectifier Potassium current - Outward, slow non.inactivating Potassium current - Non-specific leakage current This generic class cannot be used directly as it does not contain any specific parameters. Reference: *Pospischil, M., Toledo-Rodriguez, M., Monier, C., Piwkowska, Z., Bal, T., Frégnac, Y., Markram, H., and Destexhe, A. (2008). Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biol Cybern 99, 427–441.* ''' # Generic biophysical parameters of cortical cells Cm0 = 1e-2 # Cell membrane resting capacitance (F/m2) Vm0 = 0.0 # Dummy value for membrane potential (mV) VNa = 50.0 # Sodium Nernst potential (mV) VK = -90.0 # Potassium Nernst potential (mV) VCa = 120.0 # # Calcium Nernst potential (mV) def __init__(self): ''' Constructor of the class ''' # Names and initial states of the channels state probabilities self.states_names = ['m', 'h', 'n', 'p'] self.states0 = np.array([]) # Names of the different coefficients to be averaged in a lookup table. self.coeff_names = ['alpham', 'betam', 'alphah', 'betah', 'alphan', 'betan', 'alphap', 'betap'] # Charge interval bounds for lookup creation self.Qbounds = (np.round(self.Vm0 - 25.0) * 1e-5, 50.0e-5) def alpham(self, Vm): ''' Compute the alpha rate for the open-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT alpha = 0.32 * self.vtrap(13 - Vdiff, 4) # ms-1 return alpha * 1e3 # s-1 def betam(self, Vm): ''' Compute the beta rate for the open-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT beta = 0.28 * self.vtrap(Vdiff - 40, 5) # ms-1 return beta * 1e3 # s-1 def alphah(self, Vm): ''' Compute the alpha rate for the inactivation-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT alpha = (0.128 * np.exp(-(Vdiff - 17) / 18)) # ms-1 return alpha * 1e3 # s-1 def betah(self, Vm): ''' Compute the beta rate for the inactivation-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT beta = (4 / (1 + np.exp(-(Vdiff - 40) / 5))) # ms-1 return beta * 1e3 # s-1 def alphan(self, Vm): ''' Compute the alpha rate for the open-probability of delayed-rectifier Potassium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT alpha = 0.032 * self.vtrap(15 - Vdiff, 5) # ms-1 return alpha * 1e3 # s-1 def betan(self, Vm): ''' Compute the beta rate for the open-probability of delayed-rectifier Potassium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT beta = (0.5 * np.exp(-(Vdiff - 10) / 40)) # ms-1 return beta * 1e3 # s-1 def pinf(self, Vm): ''' Compute the asymptotic value of the open-probability of slow non-inactivating Potassium channels. :param Vm: membrane potential (mV) :return: asymptotic probability (-) ''' return 1.0 / (1 + np.exp(-(Vm + 35) / 10)) # prob def taup(self, Vm): ''' Compute the decay time constant for adaptation of slow non-inactivating Potassium channels. :param Vm: membrane potential (mV) :return: decayed time constant (s) ''' return self.TauMax / (3.3 * np.exp((Vm + 35) / 20) + np.exp(-(Vm + 35) / 20)) # s def derM(self, Vm, m): ''' Compute the evolution of the open-probability of Sodium channels. :param Vm: membrane potential (mV) :param m: open-probability of Sodium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alpham(Vm) * (1 - m) - self.betam(Vm) * m def derH(self, Vm, h): ''' Compute the evolution of the inactivation-probability of Sodium channels. :param Vm: membrane potential (mV) :param h: inactivation-probability of Sodium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alphah(Vm) * (1 - h) - self.betah(Vm) * h def derN(self, Vm, n): ''' Compute the evolution of the open-probability of delayed-rectifier Potassium channels. :param Vm: membrane potential (mV) :param n: open-probability of delayed-rectifier Potassium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alphan(Vm) * (1 - n) - self.betan(Vm) * n def derP(self, Vm, p): ''' Compute the evolution of the open-probability of slow non-inactivating Potassium channels. :param Vm: membrane potential (mV) :param p: open-probability of slow non-inactivating Potassium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return (self.pinf(Vm) - p) / self.taup(Vm) def currNa(self, m, h, Vm): ''' Compute the inward Sodium current per unit area. :param m: open-probability of Sodium channels :param h: inactivation-probability of Sodium channels :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' GNa = self.GNaMax * m**3 * h return GNa * (Vm - self.VNa) def currK(self, n, Vm): ''' Compute the outward, delayed-rectifier Potassium current per unit area. :param n: open-probability of delayed-rectifier Potassium channels :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' GK = self.GKMax * n**4 return GK * (Vm - self.VK) def currM(self, p, Vm): ''' Compute the outward, slow non-inactivating Potassium current per unit area. :param p: open-probability of the slow non-inactivating Potassium channels :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' GM = self.GMMax * p return GM * (Vm - self.VK) def currL(self, Vm): ''' Compute the non-specific leakage current per unit area. :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.GL * (Vm - self.VL) def currNet(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' m, h, n, p = states return (self.currNa(m, h, Vm) + self.currK(n, Vm) + self.currM(p, Vm) + self.currL(Vm)) # mA/m2 def steadyStates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Solve the equation dx/dt = 0 at Vm for each x-state meq = self.alpham(Vm) / (self.alpham(Vm) + self.betam(Vm)) heq = self.alphah(Vm) / (self.alphah(Vm) + self.betah(Vm)) neq = self.alphan(Vm) / (self.alphan(Vm) + self.betan(Vm)) peq = self.pinf(Vm) return np.array([meq, heq, neq, peq]) def derStates(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' m, h, n, p = states dmdt = self.derM(Vm, m) dhdt = self.derH(Vm, h) dndt = self.derN(Vm, n) dpdt = self.derP(Vm, p) return [dmdt, dhdt, dndt, dpdt] def getEffRates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Compute average cycle value for rate constants am_avg = np.mean(self.alpham(Vm)) bm_avg = np.mean(self.betam(Vm)) ah_avg = np.mean(self.alphah(Vm)) bh_avg = np.mean(self.betah(Vm)) an_avg = np.mean(self.alphan(Vm)) bn_avg = np.mean(self.betan(Vm)) Tp = self.taup(Vm) pinf = self.pinf(Vm) ap_avg = np.mean(pinf / Tp) bp_avg = np.mean(1 / Tp) - ap_avg # Return array of coefficients return np.array([am_avg, bm_avg, ah_avg, bh_avg, an_avg, bn_avg, ap_avg, bp_avg]) def derStatesEff(self, Qm, states, interp_data): ''' Concrete implementation of the abstract API method. ''' rates = np.array([np.interp(Qm, interp_data['Q'], interp_data[rn]) for rn in self.coeff_names]) m, h, n, p = states dmdt = rates[0] * (1 - m) - rates[1] * m dhdt = rates[2] * (1 - h) - rates[3] * h dndt = rates[4] * (1 - n) - rates[5] * n dpdt = rates[6] * (1 - p) - rates[7] * p return [dmdt, dhdt, dndt, dpdt] class CorticalRS(Cortical): ''' Specific membrane channel dynamics of a cortical regular spiking, excitatory pyramidal neuron. Reference: *Pospischil, M., Toledo-Rodriguez, M., Monier, C., Piwkowska, Z., Bal, T., Frégnac, Y., Markram, H., and Destexhe, A. (2008). Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biol Cybern 99, 427–441.* ''' # Name of channel mechanism name = 'RS' # Cell-specific biophysical parameters Vm0 = -71.9 # Cell membrane resting potential (mV) GNaMax = 560.0 # Max. conductance of Sodium current (S/m^2) GKMax = 60.0 # Max. conductance of delayed Potassium current (S/m^2) GMMax = 0.75 # Max. conductance of slow non-inactivating Potassium current (S/m^2) GL = 0.205 # Conductance of non-specific leakage current (S/m^2) VL = -70.3 # Non-specific leakage Nernst potential (mV) VT = -56.2 # Spike threshold adjustment parameter (mV) TauMax = 0.608 # Max. adaptation decay of slow non-inactivating Potassium current (s) # Default plotting scheme pltvars_scheme = { 'i_{Na}\ kin.': ['m', 'h'], 'i_K\ kin.': ['n'], 'i_M\ kin.': ['p'], 'I': ['iNa', 'iK', 'iM', 'iL', 'iNet'] } def __init__(self): ''' Constructor of the class. ''' # Instantiate parent class super().__init__() # Define initial channel probabilities (solving dx/dt = 0 at resting potential) self.states0 = self.steadyStates(self.Vm0) class CorticalFS(Cortical): ''' Specific membrane channel dynamics of a cortical fast-spiking, inhibitory neuron. Reference: *Pospischil, M., Toledo-Rodriguez, M., Monier, C., Piwkowska, Z., Bal, T., Frégnac, Y., Markram, H., and Destexhe, A. (2008). Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biol Cybern 99, 427–441.* ''' # Name of channel mechanism name = 'FS' # Cell-specific biophysical parameters Vm0 = -71.4 # Cell membrane resting potential (mV) GNaMax = 580.0 # Max. conductance of Sodium current (S/m^2) GKMax = 39.0 # Max. conductance of delayed Potassium current (S/m^2) GMMax = 0.787 # Max. conductance of slow non-inactivating Potassium current (S/m^2) GL = 0.38 # Conductance of non-specific leakage current (S/m^2) VL = -70.4 # Non-specific leakage Nernst potential (mV) VT = -57.9 # Spike threshold adjustment parameter (mV) TauMax = 0.502 # Max. adaptation decay of slow non-inactivating Potassium current (s) # Default plotting scheme pltvars_scheme = { 'i_{Na}\ kin.': ['m', 'h'], 'i_K\ kin.': ['n'], 'i_M\ kin.': ['p'], 'I': ['iNa', 'iK', 'iM', 'iL', 'iNet'] } def __init__(self): ''' Constructor of the class. ''' # Instantiate parent class super().__init__() # Define initial channel probabilities (solving dx/dt = 0 at resting potential) self.states0 = self.steadyStates(self.Vm0) class CorticalLTS(Cortical): ''' Specific membrane channel dynamics of a cortical low-threshold spiking, inhibitory neuron with an additional inward Calcium current due to the presence of a T-type channel. References: *Pospischil, M., Toledo-Rodriguez, M., Monier, C., Piwkowska, Z., Bal, T., Frégnac, Y., Markram, H., and Destexhe, A. (2008). Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biol Cybern 99, 427–441.* *Huguenard, J.R., and McCormick, D.A. (1992). Simulation of the currents involved in rhythmic oscillations in thalamic relay neurons. J. Neurophysiol. 68, 1373–1383.* ''' # Name of channel mechanism name = 'LTS' # Cell-specific biophysical parameters Vm0 = -54.0 # Cell membrane resting potential (mV) GNaMax = 500.0 # Max. conductance of Sodium current (S/m^2) GKMax = 40.0 # Max. conductance of delayed Potassium current (S/m^2) GMMax = 0.28 # Max. conductance of slow non-inactivating Potassium current (S/m^2) GTMax = 4.0 # Max. conductance of low-threshold Calcium current (S/m^2) GL = 0.19 # Conductance of non-specific leakage current (S/m^2) VL = -50.0 # Non-specific leakage Nernst potential (mV) VT = -50.0 # Spike threshold adjustment parameter (mV) TauMax = 4.0 # Max. adaptation decay of slow non-inactivating Potassium current (s) Vx = -7.0 # Voltage-dependence uniform shift factor at 36°C (mV) # Default plotting scheme pltvars_scheme = { 'i_{Na}\ kin.': ['m', 'h'], 'i_K\ kin.': ['n'], 'i_M\ kin.': ['p'], 'i_T\ kin.': ['s', 'u'], 'I': ['iNa', 'iK', 'iM', 'iT', 'iL', 'iNet'] } def __init__(self): ''' Constructor of the class. ''' # Instantiate parent class super().__init__() # Add names of cell-specific Calcium channel probabilities self.states_names += ['s', 'u'] # Define initial channel probabilities (solving dx/dt = 0 at resting potential) self.states0 = self.steadyStates(self.Vm0) # Define the names of the different coefficients to be averaged in a lookup table. self.coeff_names += ['alphas', 'betas', 'alphau', 'betau'] def sinf(self, Vm): ''' Compute the asymptotic value of the open-probability of the S-type, activation gate of Calcium channels. :param Vm: membrane potential (mV) :return: asymptotic probability (-) ''' return 1.0 / (1.0 + np.exp(-(Vm + self.Vx + 57.0) / 6.2)) # prob def taus(self, Vm): ''' Compute the decay time constant for adaptation of S-type, activation gate of Calcium channels. :param Vm: membrane potential (mV) :return: decayed time constant (s) ''' tmp = np.exp(-(Vm + self.Vx + 132.0) / 16.7) + np.exp((Vm + self.Vx + 16.8) / 18.2) return 1.0 / 3.7 * (0.612 + 1.0 / tmp) * 1e-3 # s def uinf(self, Vm): ''' Compute the asymptotic value of the open-probability of the U-type, inactivation gate of Calcium channels. :param Vm: membrane potential (mV) :return: asymptotic probability (-) ''' return 1.0 / (1.0 + np.exp((Vm + self.Vx + 81.0) / 4.0)) # prob def tauu(self, Vm): ''' Compute the decay time constant for adaptation of U-type, inactivation gate of Calcium channels. :param Vm: membrane potential (mV) :return: decayed time constant (s) ''' if Vm + self.Vx < -80.0: return 1.0 / 3.7 * np.exp((Vm + self.Vx + 467.0) / 66.6) * 1e-3 # s else: return 1.0 / 3.7 * (np.exp(-(Vm + self.Vx + 22) / 10.5) + 28.0) * 1e-3 # s def derS(self, Vm, s): ''' Compute the evolution of the open-probability of the S-type, activation gate of Calcium channels. :param Vm: membrane potential (mV) :param s: open-probability of S-type Calcium activation gates (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return (self.sinf(Vm) - s) / self.taus(Vm) def derU(self, Vm, u): ''' Compute the evolution of the open-probability of the U-type, inactivation gate of Calcium channels. :param Vm: membrane potential (mV) :param u: open-probability of U-type Calcium inactivation gates (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return (self.uinf(Vm) - u) / self.tauu(Vm) def currCa(self, s, u, Vm): ''' Compute the inward, low-threshold Calcium current per unit area. :param s: open-probability of the S-type activation gate of Calcium channels :param u: open-probability of the U-type inactivation gate of Calcium channels :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' GT = self.GTMax * s**2 * u return GT * (Vm - self.VCa) def currNet(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' m, h, n, p, s, u = states return (self.currNa(m, h, Vm) + self.currK(n, Vm) + self.currM(p, Vm) + self.currCa(s, u, Vm) + self.currL(Vm)) # mA/m2 def steadyStates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Call parent method to compute Sodium and Potassium channels gates steady-states NaK_eqstates = super().steadyStates(Vm) # Compute Calcium channel gates steady-states seq = self.sinf(Vm) ueq = self.uinf(Vm) Ca_eqstates = np.array([seq, ueq]) # Merge all steady-states and return return np.concatenate((NaK_eqstates, Ca_eqstates)) def derStates(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' # Unpack input states *NaK_states, s, u = states # Call parent method to compute Sodium and Potassium channels states derivatives NaK_derstates = super().derStates(Vm, NaK_states) # Compute Calcium channels states derivatives dsdt = self.derS(Vm, s) dudt = self.derU(Vm, u) # Merge all states derivatives and return return NaK_derstates + [dsdt, dudt] def getEffRates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Call parent method to compute Sodium and Potassium effective rate constants NaK_rates = super().getEffRates(Vm) # Compute Calcium effective rate constants Ts = self.taus(Vm) as_avg = np.mean(self.sinf(Vm) / Ts) bs_avg = np.mean(1 / Ts) - as_avg Tu = np.array([self.tauu(v) for v in Vm]) au_avg = np.mean(self.uinf(Vm) / Tu) bu_avg = np.mean(1 / Tu) - au_avg Ca_rates = np.array([as_avg, bs_avg, au_avg, bu_avg]) # Merge all rates and return return np.concatenate((NaK_rates, Ca_rates)) def derStatesEff(self, Qm, states, interp_data): ''' Concrete implementation of the abstract API method. ''' # Unpack input states *NaK_states, s, u = states # Call parent method to compute Sodium and Potassium channels states derivatives NaK_dstates = super().derStatesEff(Qm, NaK_states, interp_data) # Compute Calcium channels states derivatives Ca_rates = np.array([np.interp(Qm, interp_data['Q'], interp_data[rn]) for rn in self.coeff_names[8:]]) dsdt = Ca_rates[0] * (1 - s) - Ca_rates[1] * s dudt = Ca_rates[2] * (1 - u) - Ca_rates[3] * u # Merge all states derivatives and return return NaK_dstates + [dsdt, dudt] class CorticalIB(Cortical): ''' Specific membrane channel dynamics of a cortical intrinsically bursting neuron with an additional inward Calcium current due to the presence of a L-type channel. References: *Pospischil, M., Toledo-Rodriguez, M., Monier, C., Piwkowska, Z., Bal, T., Frégnac, Y., Markram, H., and Destexhe, A. (2008). Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biol Cybern 99, 427–441.* *Reuveni, I., Friedman, A., Amitai, Y., and Gutnick, M.J. (1993). Stepwise repolarization from Ca2+ plateaus in neocortical pyramidal cells: evidence for nonhomogeneous distribution of HVA Ca2+ channels in dendrites. J. Neurosci. 13, 4609–4621.* ''' # Name of channel mechanism name = 'IB' # Cell-specific biophysical parameters Vm0 = -71.4 # Cell membrane resting potential (mV) GNaMax = 500 # Max. conductance of Sodium current (S/m^2) GKMax = 50 # Max. conductance of delayed Potassium current (S/m^2) GMMax = 0.3 # Max. conductance of slow non-inactivating Potassium current (S/m^2) GCaLMax = 1.0 # Max. conductance of L-type Calcium current (S/m^2) GL = 0.1 # Conductance of non-specific leakage current (S/m^2) VL = -70 # Non-specific leakage Nernst potential (mV) VT = -56.2 # Spike threshold adjustment parameter (mV) TauMax = 0.608 # Max. adaptation decay of slow non-inactivating Potassium current (s) # Default plotting scheme pltvars_scheme = { 'i_{Na}\ kin.': ['m', 'h'], 'i_K\ kin.': ['n'], 'i_M\ kin.': ['p'], 'i_{CaL}\ kin.': ['q', 'r', 'q2r'], 'I': ['iNa', 'iK', 'iM', 'iCaL', 'iL', 'iNet'] } def __init__(self): ''' Constructor of the class. ''' # Instantiate parent class super().__init__() # Add names of cell-specific Calcium channel probabilities self.states_names += ['q', 'r'] # Define initial channel probabilities (solving dx/dt = 0 at resting potential) self.states0 = self.steadyStates(self.Vm0) # Define the names of the different coefficients to be averaged in a lookup table. self.coeff_names += ['alphaq', 'betaq', 'alphar', 'betar'] def alphaq(self, Vm): ''' Compute the alpha rate for the open-probability of L-type Calcium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' alpha = 0.055 * self.vtrap(-(Vm + 27), 3.8) # ms-1 return alpha * 1e3 # s-1 def betaq(self, Vm): ''' Compute the beta rate for the open-probability of L-type Calcium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' beta = 0.94 * np.exp(-(Vm + 75) / 17) # ms-1 return beta * 1e3 # s-1 def alphar(self, Vm): ''' Compute the alpha rate for the inactivation-probability of L-type Calcium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' alpha = 0.000457 * np.exp(-(Vm + 13) / 50) # ms-1 return alpha * 1e3 # s-1 def betar(self, Vm): ''' Compute the beta rate for the inactivation-probability of L-type Calcium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' beta = 0.0065 / (np.exp(-(Vm + 15) / 28) + 1) # ms-1 return beta * 1e3 # s-1 def derQ(self, Vm, q): ''' Compute the evolution of the open-probability of the Q (activation) gate of L-type Calcium channels. :param Vm: membrane potential (mV) :param q: open-probability of Q gate (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alphaq(Vm) * (1 - q) - self.betaq(Vm) * q def derR(self, Vm, r): ''' Compute the evolution of the open-probability of the R (inactivation) gate of L-type Calcium channels. :param Vm: membrane potential (mV) :param r: open-probability of R gate (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alphar(Vm) * (1 - r) - self.betar(Vm) * r def currCaL(self, q, r, Vm): ''' Compute the inward L-type Calcium current per unit area. :param q: open-probability of Q gate (prob) :param r: open-probability of R gate (prob) :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' GCaL = self.GCaLMax * q**2 * r return GCaL * (Vm - self.VCa) def currNet(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' m, h, n, p, q, r = states return (self.currNa(m, h, Vm) + self.currK(n, Vm) + self.currM(p, Vm) + self.currCaL(q, r, Vm) + self.currL(Vm)) # mA/m2 def steadyStates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Call parent method to compute Sodium and Potassium channels gates steady-states NaK_eqstates = super().steadyStates(Vm) # Compute L-type Calcium channel gates steady-states qeq = self.alphaq(Vm) / (self.alphaq(Vm) + self.betaq(Vm)) req = self.alphar(Vm) / (self.alphar(Vm) + self.betar(Vm)) CaL_eqstates = np.array([qeq, req]) # Merge all steady-states and return return np.concatenate((NaK_eqstates, CaL_eqstates)) def derStates(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' # Unpack input states *NaK_states, q, r = states # Call parent method to compute Sodium and Potassium channels states derivatives NaK_derstates = super().derStates(Vm, NaK_states) # Compute L-type Calcium channels states derivatives dqdt = self.derQ(Vm, q) drdt = self.derR(Vm, r) # Merge all states derivatives and return return NaK_derstates + [dqdt, drdt] def getEffRates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Call parent method to compute Sodium and Potassium effective rate constants NaK_rates = super().getEffRates(Vm) # Compute Calcium effective rate constants aq_avg = np.mean(self.alphaq(Vm)) bq_avg = np.mean(self.betaq(Vm)) ar_avg = np.mean(self.alphar(Vm)) br_avg = np.mean(self.betar(Vm)) CaL_rates = np.array([aq_avg, bq_avg, ar_avg, br_avg]) # Merge all rates and return return np.concatenate((NaK_rates, CaL_rates)) def derStatesEff(self, Qm, states, interp_data): ''' Concrete implementation of the abstract API method. ''' # Unpack input states *NaK_states, q, r = states # Call parent method to compute Sodium and Potassium channels states derivatives NaK_dstates = super().derStatesEff(Qm, NaK_states, interp_data) # Compute Calcium channels states derivatives CaL_rates = np.array([np.interp(Qm, interp_data['Q'], interp_data[rn]) for rn in self.coeff_names[8:]]) dqdt = CaL_rates[0] * (1 - q) - CaL_rates[1] * q drdt = CaL_rates[2] * (1 - r) - CaL_rates[3] * r # Merge all states derivatives and return return NaK_dstates + [dqdt, drdt] diff --git a/PySONIC/neurons/leech.py b/PySONIC/neurons/leech.py index 0556c73..23d08ef 100644 --- a/PySONIC/neurons/leech.py +++ b/PySONIC/neurons/leech.py @@ -1,1079 +1,1076 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-07-31 15:20:54 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-26 17:22:59 +# @Last Modified time: 2018-09-28 14:06:06 -''' Channels mechanisms for leech ganglion neurons. ''' from functools import partialmethod import numpy as np from ..core import PointNeuron -from ..utils import logger - class LeechTouch(PointNeuron): ''' Class defining the membrane channel dynamics of a leech touch sensory neuron. with 4 different current types: - Inward Sodium current - Outward Potassium current - Inward Calcium current - Non-specific leakage current - Calcium-dependent, outward Potassium current - Outward, Sodium pumping current Reference: *Cataldo, E., Brunelli, M., Byrne, J.H., Av-Ron, E., Cai, Y., and Baxter, D.A. (2005). Computational model of touch sensory cells (T Cells) of the leech: role of the afterhyperpolarization (AHP) in activity-dependent conduction failure. J Comput Neurosci 18, 5–24.* ''' # Name of channel mechanism name = 'LeechT' # Cell-specific biophysical parameters Cm0 = 1e-2 # Cell membrane resting capacitance (F/m2) Vm0 = -53.58 # Cell membrane resting potential (mV) VNa = 45.0 # Sodium Nernst potential (mV) VK = -62.0 # Potassium Nernst potential (mV) VCa = 60.0 # Calcium Nernst potential (mV) VL = -48.0 # Non-specific leakage Nernst potential (mV) VPumpNa = -300.0 # Sodium pump current reversal potential (mV) GNaMax = 3500.0 # Max. conductance of Sodium current (S/m^2) GKMax = 900.0 # Max. conductance of Potassium current (S/m^2) GCaMax = 20.0 # Max. conductance of Calcium current (S/m^2) GKCaMax = 236.0 # Max. conductance of Calcium-dependent Potassium current (S/m^2) GL = 1.0 # Conductance of non-specific leakage current (S/m^2) GPumpNa = 20.0 # Max. conductance of Sodium pump current (S/m^2) taum = 0.1e-3 # Sodium activation time constant (s) taus = 0.6e-3 # Calcium activation time constant (s) # Original conversion constants from inward ion current (nA) to build-up of # intracellular ion concentration (arb.) K_Na_original = 0.016 # iNa to intracellular [Na+] K_Ca_original = 0.1 # iCa to intracellular [Ca2+] # Constants needed to convert K from original model (soma compartment) # to current model (point-neuron) surface = 6434.0e-12 # surface of cell assumed as a single soma (m2) curr_factor = 1e6 # mA to nA # Time constants for the removal of ions from intracellular pools (s) tau_Na_removal = 16.0 # Na+ removal tau_Ca_removal = 1.25 # Ca2+ removal # Time constants for the iPumpNa and iKCa currents activation # from specific intracellular ions (s) tau_PumpNa_act = 0.1 # iPumpNa activation from intracellular Na+ tau_KCa_act = 0.01 # iKCa activation from intracellular Ca2+ # Default plotting scheme pltvars_scheme = { 'i_{Na}\ kin.': ['m', 'h', 'm3h'], 'i_K\ kin.': ['n'], 'i_{Ca}\ kin.': ['s'], 'pools': ['C_Na_arb', 'C_Na_arb_activation', 'C_Ca_arb', 'C_Ca_arb_activation'], 'I': ['iNa', 'iK', 'iCa', 'iKCa', 'iPumpNa', 'iL', 'iNet'] } def __init__(self): ''' Constructor of the class. ''' # Names and initial states of the channels state probabilities self.states_names = ['m', 'h', 'n', 's', 'C_Na', 'A_Na', 'C_Ca', 'A_Ca'] self.states0 = np.array([]) # Names of the channels effective coefficients self.coeff_names = ['alpham', 'betam', 'alphah', 'betah', 'alphan', 'betan', 'alphas', 'betas'] self.K_Na = self.K_Na_original * self.surface * self.curr_factor self.K_Ca = self.K_Ca_original * self.surface * self.curr_factor # Define initial channel probabilities (solving dx/dt = 0 at resting potential) self.states0 = self.steadyStates(self.Vm0) # Charge interval bounds for lookup creation self.Qbounds = (np.round(self.Vm0 - 10.0) * 1e-5, 50.0e-5) # ----------------- Generic ----------------- def _xinf(self, Vm, halfmax, slope, power): ''' Generic function computing the steady-state activation/inactivation of a particular ion channel at a given voltage. :param Vm: membrane potential (mV) :param halfmax: half-(in)activation voltage (mV) :param slope: slope parameter of (in)activation function (mV) :param power: power exponent multiplying the exponential expression (integer) :return: steady-state (in)activation (-) ''' return 1 / (1 + np.exp((Vm - halfmax) / slope))**power def _taux(self, Vm, halfmax, slope, tauMax, tauMin): ''' Generic function computing the voltage-dependent, activation/inactivation time constant of a particular ion channel at a given voltage. :param Vm: membrane potential (mV) :param halfmax: voltage at which (in)activation time constant is half-maximal (mV) :param slope: slope parameter of (in)activation time constant function (mV) :return: steady-state (in)activation (-) ''' return (tauMax - tauMin) / (1 + np.exp((Vm - halfmax) / slope)) + tauMin def _derC_ion(self, Cion, Iion, Kion, tau): ''' Generic function computing the time derivative of the concentration of a specific ion in its intracellular pool. :param Cion: ion concentration in the pool (arbitrary unit) :param Iion: ionic current (mA/m2) :param Kion: scaling factor for current contribution to pool (arb. unit / nA???) :param tau: time constant for removal of ions from the pool (s) :return: variation of ionic concentration in the pool (arbitrary unit /s) ''' return (Kion * (-Iion) - Cion) / tau def _derA_ion(self, Aion, Cion, tau): ''' Generic function computing the time derivative of the concentration and time dependent activation function, for a specific pool-dependent ionic current. :param Aion: concentration and time dependent activation function (arbitrary unit) :param Cion: ion concentration in the pool (arbitrary unit) :param tau: time constant for activation function variation (s) :return: variation of activation function (arbitrary unit / s) ''' return (Cion - Aion) / tau # ------------------ Na ------------------- minf = partialmethod(_xinf, halfmax=-35.0, slope=-5.0, power=1) hinf = partialmethod(_xinf, halfmax=-50.0, slope=9.0, power=2) tauh = partialmethod(_taux, halfmax=-36.0, slope=3.5, tauMax=14.0e-3, tauMin=0.2e-3) def derM(self, Vm, m): ''' Instantaneous derivative of Sodium activation. ''' return (self.minf(Vm) - m) / self.taum # s-1 def derH(self, Vm, h): ''' Instantaneous derivative of Sodium inactivation. ''' return (self.hinf(Vm) - h) / self.tauh(Vm) # s-1 # ------------------ K ------------------- ninf = partialmethod(_xinf, halfmax=-22.0, slope=-9.0, power=1) taun = partialmethod(_taux, halfmax=-10.0, slope=10.0, tauMax=6.0e-3, tauMin=1.0e-3) def derN(self, Vm, n): ''' Instantaneous derivative of Potassium activation. ''' return (self.ninf(Vm) - n) / self.taun(Vm) # s-1 # ------------------ Ca ------------------- sinf = partialmethod(_xinf, halfmax=-10.0, slope=-2.8, power=1) def derS(self, Vm, s): ''' Instantaneous derivative of Calcium activation. ''' return (self.sinf(Vm) - s) / self.taus # s-1 # ------------------ Pools ------------------- def derC_Na(self, C_Na, I_Na): ''' Derivative of Sodium concentration in intracellular pool. ''' return self._derC_ion(C_Na, I_Na, self.K_Na, self.tau_Na_removal) def derA_Na(self, A_Na, C_Na): ''' Derivative of Sodium pool-dependent activation function for iPumpNa. ''' return self._derA_ion(A_Na, C_Na, self.tau_PumpNa_act) def derC_Ca(self, C_Ca, I_Ca): ''' Derivative of Calcium concentration in intracellular pool. ''' return self._derC_ion(C_Ca, I_Ca, self.K_Ca, self.tau_Ca_removal) def derA_Ca(self, A_Ca, C_Ca): ''' Derivative of Calcium pool-dependent activation function for iKCa. ''' return self._derA_ion(A_Ca, C_Ca, self.tau_KCa_act) # ------------------ Currents ------------------- def currNa(self, m, h, Vm): ''' Sodium inward current. ''' return self.GNaMax * m**3 * h * (Vm - self.VNa) def currK(self, n, Vm): ''' Potassium outward current. ''' return self.GKMax * n**2 * (Vm - self.VK) def currCa(self, s, Vm): ''' Calcium inward current. ''' return self.GCaMax * s * (Vm - self.VCa) def currKCa(self, A_Ca, Vm): ''' Calcium-activated Potassium outward current. ''' return self.GKCaMax * A_Ca * (Vm - self.VK) def currPumpNa(self, A_Na, Vm): ''' Outward current mimicking the activity of the NaK-ATPase pump. ''' return self.GPumpNa * A_Na * (Vm - self.VPumpNa) def currL(self, Vm): ''' Leakage current. ''' return self.GL * (Vm - self.VL) def currNet(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' m, h, n, s, _, A_Na, _, A_Ca = states return (self.currNa(m, h, Vm) + self.currK(n, Vm) + self.currCa(s, Vm) + self.currL(Vm) + self.currPumpNa(A_Na, Vm) + self.currKCa(A_Ca, Vm)) # mA/m2 def steadyStates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Standard gating dynamics: Solve the equation dx/dt = 0 at Vm for each x-state meq = self.minf(Vm) heq = self.hinf(Vm) neq = self.ninf(Vm) seq = self.sinf(Vm) # PumpNa pool concentration and activation steady-state INa_eq = self.currNa(meq, heq, Vm) CNa_eq = self.K_Na * (-INa_eq) ANa_eq = CNa_eq # KCa current pool concentration and activation steady-state ICa_eq = self.currCa(seq, Vm) CCa_eq = self.K_Ca * (-ICa_eq) ACa_eq = CCa_eq return np.array([meq, heq, neq, seq, CNa_eq, ANa_eq, CCa_eq, ACa_eq]) def derStates(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' # Unpack states m, h, n, s, C_Na, A_Na, C_Ca, A_Ca = states # Standard gating states derivatives dmdt = self.derM(Vm, m) dhdt = self.derH(Vm, h) dndt = self.derN(Vm, n) dsdt = self.derS(Vm, s) # PumpNa current pool concentration and activation state I_Na = self.currNa(m, h, Vm) dCNa_dt = self.derC_Na(C_Na, I_Na) dANa_dt = self.derA_Na(A_Na, C_Na) # KCa current pool concentration and activation state I_Ca = self.currCa(s, Vm) dCCa_dt = self.derC_Ca(C_Ca, I_Ca) dACa_dt = self.derA_Ca(A_Ca, C_Ca) # Pack derivatives and return return [dmdt, dhdt, dndt, dsdt, dCNa_dt, dANa_dt, dCCa_dt, dACa_dt] def getEffRates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Compute average cycle value for rate constants Tm = self.taum minf = self.minf(Vm) am_avg = np.mean(minf / Tm) bm_avg = np.mean(1 / Tm) - am_avg Th = self.tauh(Vm) hinf = self.hinf(Vm) ah_avg = np.mean(hinf / Th) bh_avg = np.mean(1 / Th) - ah_avg Tn = self.taun(Vm) ninf = self.ninf(Vm) an_avg = np.mean(ninf / Tn) bn_avg = np.mean(1 / Tn) - an_avg Ts = self.taus sinf = self.sinf(Vm) as_avg = np.mean(sinf / Ts) bs_avg = np.mean(1 / Ts) - as_avg # Return array of coefficients return np.array([am_avg, bm_avg, ah_avg, bh_avg, an_avg, bn_avg, as_avg, bs_avg]) def derStatesEff(self, Qm, states, interp_data): ''' Concrete implementation of the abstract API method. ''' rates = np.array([np.interp(Qm, interp_data['Q'], interp_data[rn]) for rn in self.coeff_names]) Vmeff = np.interp(Qm, interp_data['Q'], interp_data['V']) # Unpack states m, h, n, s, C_Na, A_Na, C_Ca, A_Ca = states # Standard gating states derivatives dmdt = rates[0] * (1 - m) - rates[1] * m dhdt = rates[2] * (1 - h) - rates[3] * h dndt = rates[4] * (1 - n) - rates[5] * n dsdt = rates[6] * (1 - s) - rates[7] * s # PumpNa current pool concentration and activation state I_Na = self.currNa(m, h, Vmeff) dCNa_dt = self.derC_Na(C_Na, I_Na) dANa_dt = self.derA_Na(A_Na, C_Na) # KCa current pool concentration and activation state I_Ca_eff = self.currCa(s, Vmeff) dCCa_dt = self.derC_Ca(C_Ca, I_Ca_eff) dACa_dt = self.derA_Ca(A_Ca, C_Ca) # Pack derivatives and return return [dmdt, dhdt, dndt, dsdt, dCNa_dt, dANa_dt, dCCa_dt, dACa_dt] class LeechMech(PointNeuron): ''' Class defining the basic dynamics of Sodium, Potassium and Calcium channels for several neurons of the leech. Reference: *Baccus, S.A. (1998). Synaptic facilitation by reflected action potentials: enhancement of transmission when nerve impulses reverse direction at axon branch points. Proc. Natl. Acad. Sci. U.S.A. 95, 8345–8350.* ''' alphaC_sf = 1e-5 # Calcium activation rate constant scaling factor (M) betaC = 0.1e3 # beta rate for the open-probability of Ca2+-dependent Potassium channels (s-1) T = 293.15 # Room temperature (K) Rg = 8.314 # Universal gas constant (J.mol^-1.K^-1) Faraday = 9.6485e4 # Faraday constant (C/mol) def alpham(self, Vm): ''' Compute the alpha rate for the open-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' alpha = -0.03 * (Vm + 28) / (np.exp(- (Vm + 28) / 15) - 1) # ms-1 return alpha * 1e3 # s-1 def betam(self, Vm): ''' Compute the beta rate for the open-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' beta = 2.7 * np.exp(-(Vm + 53) / 18) # ms-1 return beta * 1e3 # s-1 def alphah(self, Vm): ''' Compute the alpha rate for the inactivation-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' alpha = 0.045 * np.exp(-(Vm + 58) / 18) # ms-1 return alpha * 1e3 # s-1 def betah(self, Vm): ''' Compute the beta rate for the inactivation-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) .. warning:: the original paper contains an error (multiplication) in the expression of this rate constant, corrected in the mod file on ModelDB (division). ''' beta = 0.72 / (np.exp(-(Vm + 23) / 14) + 1) # ms-1 return beta * 1e3 # s-1 def alphan(self, Vm): ''' Compute the alpha rate for the open-probability of delayed-rectifier Potassium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' alpha = -0.024 * (Vm - 17) / (np.exp(-(Vm - 17) / 8) - 1) # ms-1 return alpha * 1e3 # s-1 def betan(self, Vm): ''' Compute the beta rate for the open-probability of delayed-rectifier Potassium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' beta = 0.2 * np.exp(-(Vm + 48) / 35) # ms-1 return beta * 1e3 # s-1 def alphas(self, Vm): ''' Compute the alpha rate for the open-probability of Calcium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' alpha = -1.5 * (Vm - 20) / (np.exp(-(Vm - 20) / 5) - 1) # ms-1 return alpha * 1e3 # s-1 def betas(self, Vm): ''' Compute the beta rate for the open-probability of Calcium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' beta = 1.5 * np.exp(-(Vm + 25) / 10) # ms-1 return beta * 1e3 # s-1 def alphaC(self, C_Ca_in): ''' Compute the alpha rate for the open-probability of Calcium-dependent Potassium channels. :param C_Ca_in: intracellular Calcium concentration (M) :return: rate constant (s-1) ''' alpha = 0.1 * C_Ca_in / self.alphaC_sf # ms-1 return alpha * 1e3 # s-1 def derM(self, Vm, m): ''' Compute the evolution of the open-probability of Sodium channels. :param Vm: membrane potential (mV) :param m: open-probability of Sodium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alpham(Vm) * (1 - m) - self.betam(Vm) * m def derH(self, Vm, h): ''' Compute the evolution of the inactivation-probability of Sodium channels. :param Vm: membrane potential (mV) :param h: inactivation-probability of Sodium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alphah(Vm) * (1 - h) - self.betah(Vm) * h def derN(self, Vm, n): ''' Compute the evolution of the open-probability of delayed-rectifier Potassium channels. :param Vm: membrane potential (mV) :param n: open-probability of delayed-rectifier Potassium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alphan(Vm) * (1 - n) - self.betan(Vm) * n def derS(self, Vm, s): ''' Compute the evolution of the open-probability of Calcium channels. :param Vm: membrane potential (mV) :param s: open-probability of Calcium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alphas(Vm) * (1 - s) - self.betas(Vm) * s def derC(self, c, C_Ca_in): ''' Compute the evolution of the open-probability of Calcium-dependent Potassium channels. :param c: open-probability of Calcium-dependent Potassium channels (prob) :param C_Ca_in: intracellular Calcium concentration (M) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alphaC(C_Ca_in) * (1 - c) - self.betaC * c def currNa(self, m, h, Vm, C_Na_in): ''' Compute the inward Sodium current per unit area. :param m: open-probability of Sodium channels :param h: inactivation-probability of Sodium channels :param Vm: membrane potential (mV) :param C_Na_in: intracellular Sodium concentration (M) :return: current per unit area (mA/m2) ''' GNa = self.GNaMax * m**4 * h VNa = self.nernst(self.Z_Na, C_Na_in, self.C_Na_out) # Sodium Nernst potential return GNa * (Vm - VNa) def currK(self, n, Vm): ''' Compute the outward, delayed-rectifier Potassium current per unit area. :param n: open-probability of delayed-rectifier Potassium channels :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' GK = self.GKMax * n**2 return GK * (Vm - self.VK) def currCa(self, s, Vm, C_Ca_in): ''' Compute the inward Calcium current per unit area. :param s: open-probability of Calcium channels :param Vm: membrane potential (mV) :param C_Ca_in: intracellular Calcium concentration (M) :return: current per unit area (mA/m2) ''' GCa = self.GCaMax * s VCa = self.nernst(self.Z_Ca, C_Ca_in, self.C_Ca_out) # Calcium Nernst potential return GCa * (Vm - VCa) def currKCa(self, c, Vm): ''' Compute the outward Calcium-dependent Potassium current per unit area. :param c: open-probability of Calcium-dependent Potassium channels :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' GKCa = self.GKCaMax * c return GKCa * (Vm - self.VK) def currL(self, Vm): ''' Compute the non-specific leakage current per unit area. :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.GL * (Vm - self.VL) class LeechPressure(LeechMech): ''' Class defining the membrane channel dynamics of a leech pressure sensory neuron. with 7 different current types: - Inward Sodium current - Outward Potassium current - Inward high-voltage-activated Calcium current - Non-specific leakage current - Calcium-dependent, outward Potassium current - Sodium pump current - Calcium pump current Reference: *Baccus, S.A. (1998). Synaptic facilitation by reflected action potentials: enhancement of transmission when nerve impulses reverse direction at axon branch points. Proc. Natl. Acad. Sci. U.S.A. 95, 8345–8350.* ''' # Name of channel mechanism name = 'LeechP' # Cell-specific biophysical parameters Cm0 = 1e-2 # Cell membrane resting capacitance (F/m2) Vm0 = -48.865 # Cell membrane resting potential (mV) C_Na_out = 0.11 # Sodium extracellular concentration (M) C_Ca_out = 1.8e-3 # Calcium extracellular concentration (M) C_Na_in0 = 0.01 # Initial Sodium intracellular concentration (M) C_Ca_in0 = 1e-7 # Initial Calcium intracellular concentration (M) # VNa = 60 # Sodium Nernst potential, from MOD file on ModelDB (mV) # VCa = 125 # Calcium Nernst potential, from MOD file on ModelDB (mV) VK = -68.0 # Potassium Nernst potential (mV) VL = -49.0 # Non-specific leakage Nernst potential (mV) INaPmax = 70.0 # Maximum pump rate of the NaK-ATPase (mA/m2) khalf_Na = 0.012 # Sodium concentration at which NaK-ATPase is at half its maximum rate (M) ksteep_Na = 1e-3 # Sensitivity of NaK-ATPase to varying Sodium concentrations (M) iCaS = 0.1 # Calcium pump current parameter (mA/m2) GNaMax = 3500.0 # Max. conductance of Sodium current (S/m^2) GKMax = 60.0 # Max. conductance of Potassium current (S/m^2) GCaMax = 0.02 # Max. conductance of Calcium current (S/m^2) GKCaMax = 8.0 # Max. conductance of Calcium-dependent Potassium current (S/m^2) GL = 5.0 # Conductance of non-specific leakage current (S/m^2) diam = 50e-6 # Cell soma diameter (m) Z_Na = 1 # Sodium valence Z_Ca = 2 # Calcium valence # Default plotting scheme pltvars_scheme = { 'i_{Na}\ kin.': ['m', 'h', 'm4h'], 'i_K\ kin.': ['n'], 'i_{Ca}\ kin.': ['s'], 'i_{KCa}\ kin.': ['c'], 'pools': ['C_Na', 'C_Ca'], 'I': ['iNa2', 'iK', 'iCa2', 'iKCa2', 'iPumpNa2', 'iPumpCa2', 'iL', 'iNet'] } def __init__(self): ''' Constructor of the class. ''' SV_ratio = 6 / self.diam # surface to volume ratio of the (spherical) cell soma # Conversion constant from membrane ionic currents into # change rate of intracellular ionic concentrations self.K_Na = SV_ratio / (self.Z_Na * self.Faraday) * 1e-6 # Sodium (M/s) self.K_Ca = SV_ratio / (self.Z_Ca * self.Faraday) * 1e-6 # Calcium (M/s) # Names and initial states of the channels state probabilities self.states_names = ['m', 'h', 'n', 's', 'c', 'C_Na', 'C_Ca'] self.states0 = np.array([]) # Names of the channels effective coefficients self.coeff_names = ['alpham', 'betam', 'alphah', 'betah', 'alphan', 'betan', 'alphas', 'betas'] # Define initial channel probabilities (solving dx/dt = 0 at resting potential) self.states0 = self.steadyStates(self.Vm0) # Charge interval bounds for lookup creation self.Qbounds = (np.round(self.Vm0 - 10.0) * 1e-5, 60.0e-5) def nernst(self, z_ion, C_ion_in, C_ion_out): ''' Return the Nernst potential of a specific ion given its intra and extracellular concentrations. :param z_ion: ion valence :param C_ion_in: intracellular ion concentration (M) :param C_ion_out: extracellular ion concentration (M) :return: ion Nernst potential (mV) ''' return (self.Rg * self.T) / (z_ion * self.Faraday) * np.log(C_ion_out / C_ion_in) * 1e3 def currPumpNa(self, C_Na_in): ''' Outward current mimicking the activity of the NaK-ATPase pump. :param C_Na_in: intracellular Sodium concentration (M) :return: current per unit area (mA/m2) ''' return self.INaPmax / (1 + np.exp((self.khalf_Na - C_Na_in) / self.ksteep_Na)) def currPumpCa(self, C_Ca_in): ''' Outward current representing the activity of a Calcium pump. :param C_Ca_in: intracellular Calcium concentration (M) :return: current per unit area (mA/m2) ''' return self.iCaS * (C_Ca_in - self.C_Ca_in0) / 1.5 def currNet(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' m, h, n, s, c, C_Na_in, C_Ca_in = states return (self.currNa(m, h, Vm, C_Na_in) + self.currK(n, Vm) + self.currCa(s, Vm, C_Ca_in) + self.currKCa(c, Vm) + self.currL(Vm) + (self.currPumpNa(C_Na_in) / 3.) + self.currPumpCa(C_Ca_in)) # mA/m2 def steadyStates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Intracellular concentrations C_Na_eq = self.C_Na_in0 C_Ca_eq = self.C_Ca_in0 # Standard gating dynamics: Solve the equation dx/dt = 0 at Vm for each x-state meq = self.alpham(Vm) / (self.alpham(Vm) + self.betam(Vm)) heq = self.alphah(Vm) / (self.alphah(Vm) + self.betah(Vm)) neq = self.alphan(Vm) / (self.alphan(Vm) + self.betan(Vm)) seq = self.alphas(Vm) / (self.alphas(Vm) + self.betas(Vm)) ceq = self.alphaC(C_Ca_eq) / (self.alphaC(C_Ca_eq) + self.betaC) return np.array([meq, heq, neq, seq, ceq, C_Na_eq, C_Ca_eq]) def derStates(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' # Unpack states m, h, n, s, c, C_Na_in, C_Ca_in = states # Standard gating states derivatives dmdt = self.derM(Vm, m) dhdt = self.derH(Vm, h) dndt = self.derN(Vm, n) dsdt = self.derS(Vm, s) dcdt = self.derC(c, C_Ca_in) # Intracellular concentrations dCNa_dt = - (self.currNa(m, h, Vm, C_Na_in) + self.currPumpNa(C_Na_in)) * self.K_Na # M/s dCCa_dt = -(self.currCa(s, Vm, C_Ca_in) + self.currPumpCa(C_Ca_in)) * self.K_Ca # M/s # Pack derivatives and return return [dmdt, dhdt, dndt, dsdt, dcdt, dCNa_dt, dCCa_dt] def getEffRates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Compute average cycle value for rate constants am_avg = np.mean(self.alpham(Vm)) bm_avg = np.mean(self.betam(Vm)) ah_avg = np.mean(self.alphah(Vm)) bh_avg = np.mean(self.betah(Vm)) an_avg = np.mean(self.alphan(Vm)) bn_avg = np.mean(self.betan(Vm)) as_avg = np.mean(self.alphas(Vm)) bs_avg = np.mean(self.betas(Vm)) # Return array of coefficients return np.array([am_avg, bm_avg, ah_avg, bh_avg, an_avg, bn_avg, as_avg, bs_avg]) def derStatesEff(self, Qm, states, interp_data): ''' Concrete implementation of the abstract API method. ''' rates = np.array([np.interp(Qm, interp_data['Q'], interp_data[rn]) for rn in self.coeff_names]) Vmeff = np.interp(Qm, interp_data['Q'], interp_data['V']) # Unpack states m, h, n, s, c, C_Na_in, C_Ca_in = states # Standard gating states derivatives dmdt = rates[0] * (1 - m) - rates[1] * m dhdt = rates[2] * (1 - h) - rates[3] * h dndt = rates[4] * (1 - n) - rates[5] * n dsdt = rates[6] * (1 - s) - rates[7] * s # KCa current gating state derivative dcdt = self.derC(c, C_Ca_in) # Intracellular concentrations dCNa_dt = - (self.currNa(m, h, Vmeff, C_Na_in) + self.currPumpNa(C_Na_in)) * self.K_Na # M/s dCCa_dt = -(self.currCa(s, Vmeff, C_Ca_in) + self.currPumpCa(C_Ca_in)) * self.K_Ca # M/s # Pack derivatives and return return [dmdt, dhdt, dndt, dsdt, dcdt, dCNa_dt, dCCa_dt] class LeechRetzius(LeechMech): ''' Class defining the membrane channel dynamics of a leech Retzius neuron. with 5 different current types: - Inward Sodium current - Outward Potassium current - Inward high-voltage-activated Calcium current - Non-specific leakage current - Calcium-dependent, outward Potassium current References: *Vazquez, Y., Mendez, B., Trueta, C., and De-Miguel, F.F. (2009). Summation of excitatory postsynaptic potentials in electrically-coupled neurones. Neuroscience 163, 202–212.* *ModelDB link: https://senselab.med.yale.edu/modeldb/ShowModel.cshtml?model=120910* ''' # Name of channel mechanism # name = 'LeechR' # Cell-specific biophysical parameters Cm0 = 5e-2 # Cell membrane resting capacitance (F/m2) Vm0 = -44.45 # Cell membrane resting potential (mV) VNa = 50.0 # Sodium Nernst potential, from retztemp.ses file on ModelDB (mV) VCa = 125.0 # Calcium Nernst potential, from cachdend.mod file on ModelDB (mV) VK = -79.0 # Potassium Nernst potential, from retztemp.ses file on ModelDB (mV) VL = -30.0 # Non-specific leakage Nernst potential, from leakdend.mod file on ModelDB (mV) GNaMax = 1250.0 # Max. conductance of Sodium current (S/m^2) GKMax = 10.0 # Max. conductance of Potassium current (S/m^2) GAMax = 100.0 # Max. conductance of transient Potassium current (S/m^2) GCaMax = 4.0 # Max. conductance of Calcium current (S/m^2) GKCaMax = 130.0 # Max. conductance of Calcium-dependent Potassium current (S/m^2) GL = 1.25 # Conductance of non-specific leakage current (S/m^2) Vhalf = -73.1 # mV C_Ca_in = 5e-8 # Calcium intracellular concentration, from retztemp.ses file (M) # Default plotting scheme pltvars_scheme = { 'i_{Na}\ kin.': ['m', 'h', 'm4h'], 'i_K\ kin.': ['n'], 'i_A\ kin.': ['a', 'b', 'ab'], 'i_{Ca}\ kin.': ['s'], 'i_{KCa}\ kin.': ['c'], 'I': ['iNa', 'iK', 'iCa', 'iKCa2', 'iA', 'iL', 'iNet'] } def __init__(self): ''' Constructor of the class. ''' # Names and initial states of the channels state probabilities self.states_names = ['m', 'h', 'n', 's', 'c', 'a', 'b'] self.states0 = np.array([]) # Names of the channels effective coefficients self.coeff_names = ['alpham', 'betam', 'alphah', 'betah', 'alphan', 'betan', 'alphas', 'betas', 'alphac', 'betac', 'alphaa', 'betaa' 'alphab', 'betab'] # Define initial channel probabilities (solving dx/dt = 0 at resting potential) self.states0 = self.steadyStates(self.Vm0) def ainf(self, Vm): ''' Steady-state activation probability of transient Potassium channels. Source: *Beck, H., Ficker, E., and Heinemann, U. (1992). Properties of two voltage-activated potassium currents in acutely isolated juvenile rat dentate gyrus granule cells. J. Neurophysiol. 68, 2086–2099.* :param Vm: membrane potential (mV) :return: time constant (s) ''' Vth = -55.0 # mV return 0 if Vm <= Vth else min(1, 2 * (Vm - Vth)**3 / ((11 - Vth)**3 + (Vm - Vth)**3)) def taua(self, Vm): ''' Activation time constant of transient Potassium channels. (assuming T = 20°C). Source: *Beck, H., Ficker, E., and Heinemann, U. (1992). Properties of two voltage-activated potassium currents in acutely isolated juvenile rat dentate gyrus granule cells. J. Neurophysiol. 68, 2086–2099.* :param Vm: membrane potential (mV) :return: time constant (s) ''' x = -1.5 * (Vm - self.Vhalf) * 1e-3 * self.Faraday / (self.Rg * self.T) # [-] alpha = np.exp(x) # ms-1 beta = np.exp(0.7 * x) # ms-1 return max(0.5, beta / (0.3 * (1 + alpha))) * 1e-3 # s def binf(self, Vm): ''' Steady-state inactivation probability of transient Potassium channels. Source: *Beck, H., Ficker, E., and Heinemann, U. (1992). Properties of two voltage-activated potassium currents in acutely isolated juvenile rat dentate gyrus granule cells. J. Neurophysiol. 68, 2086–2099.* :param Vm: membrane potential (mV) :return: time constant (s) ''' return 1. / (1 + np.exp((self.Vhalf - Vm) / -6.3)) def taub(self, Vm): ''' Inactivation time constant of transient Potassium channels. (assuming T = 20°C). Source: *Beck, H., Ficker, E., and Heinemann, U. (1992). Properties of two voltage-activated potassium currents in acutely isolated juvenile rat dentate gyrus granule cells. J. Neurophysiol. 68, 2086–2099.* :param Vm: membrane potential (mV) :return: time constant (s) ''' x = 2 * (Vm - self.Vhalf) * 1e-3 * self.Faraday / (self.Rg * self.T) alpha = np.exp(x) beta = np.exp(0.65 * x) return max(7.5, beta / (0.02 * (1 + alpha))) * 1e-3 # s def derA(self, Vm, a): ''' Compute the evolution of the activation-probability of transient Potassium channels. :param Vm: membrane potential (mV) :param a: activation-probability of transient Potassium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return (self.ainf(Vm) - a) / self.taua(Vm) def derB(self, Vm, b): ''' Compute the evolution of the inactivation-probability of transient Potassium channels. :param Vm: membrane potential (mV) :param b: inactivation-probability of transient Potassium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return (self.binf(Vm) - b) / self.taub(Vm) def currA(self, a, b, Vm): ''' Compute the outward, transient Potassium current per unit area. :param a: open-probability of transient Potassium channels :param b: inactivation-probability of transient Potassium channels :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' GK = self.GAMax * a * b return GK * (Vm - self.VK) def currNet(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' m, h, n, s, c, a, b = states return (self.currNa(m, h, Vm) + self.currK(n, Vm) + self.currCa(s, Vm) + self.currL(Vm) + self.currKCa(c, Vm) + self.currA(a, b, Vm)) # mA/m2 def steadyStates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Standard gating dynamics: Solve the equation dx/dt = 0 at Vm for each x-state meq = self.alpham(Vm) / (self.alpham(Vm) + self.betam(Vm)) heq = self.alphah(Vm) / (self.alphah(Vm) + self.betah(Vm)) neq = self.alphan(Vm) / (self.alphan(Vm) + self.betan(Vm)) seq = self.alphas(Vm) / (self.alphas(Vm) + self.betas(Vm)) ceq = self.alphaC(self.C_Ca_in) / (self.alphaC(self.C_Ca_in) + self.betaC) aeq = self.ainf(Vm) beq = self.binf(Vm) return np.array([meq, heq, neq, seq, ceq, aeq, beq]) def derStates(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' # Unpack states m, h, n, s, c, a, b = states # Standard gating states derivatives dmdt = self.derM(Vm, m) dhdt = self.derH(Vm, h) dndt = self.derN(Vm, n) dsdt = self.derS(Vm, s) dcdt = self.derC(c, self.C_Ca_in) dadt = self.derA(Vm, a) dbdt = self.derB(Vm, b) # Pack derivatives and return return [dmdt, dhdt, dndt, dsdt, dcdt, dadt, dbdt] def getEffRates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Compute average cycle value for rate constants am_avg = np.mean(self.alpham(Vm)) bm_avg = np.mean(self.betam(Vm)) ah_avg = np.mean(self.alphah(Vm)) bh_avg = np.mean(self.betah(Vm)) an_avg = np.mean(self.alphan(Vm)) bn_avg = np.mean(self.betan(Vm)) as_avg = np.mean(self.alphas(Vm)) bs_avg = np.mean(self.betas(Vm)) Ta = self.taua(Vm) ainf = self.ainf(Vm) aa_avg = np.mean(ainf / Ta) ba_avg = np.mean(1 / Ta) - aa_avg Tb = self.taub(Vm) binf = self.binf(Vm) ab_avg = np.mean(binf / Tb) bb_avg = np.mean(1 / Tb) - ab_avg # Return array of coefficients return np.array([am_avg, bm_avg, ah_avg, bh_avg, an_avg, bn_avg, as_avg, bs_avg, aa_avg, ba_avg, ab_avg, bb_avg]) def derStatesEff(self, Qm, states, interp_data): ''' Concrete implementation of the abstract API method. ''' rates = np.array([np.interp(Qm, interp_data['Q'], interp_data[rn]) for rn in self.coeff_names]) # Unpack states m, h, n, s, c, a, b = states # Standard gating states derivatives dmdt = rates[0] * (1 - m) - rates[1] * m dhdt = rates[2] * (1 - h) - rates[3] * h dndt = rates[4] * (1 - n) - rates[5] * n dsdt = rates[6] * (1 - s) - rates[7] * s dadt = rates[8] * (1 - a) - rates[9] * a dbdt = rates[10] * (1 - b) - rates[11] * b # KCa current gating state derivative dcdt = self.derC(c, self.C_Ca_in) # Pack derivatives and return return [dmdt, dhdt, dndt, dsdt, dcdt, dadt, dbdt] diff --git a/PySONIC/neurons/thalamic.py b/PySONIC/neurons/thalamic.py index f4fcd57..37a653d 100644 --- a/PySONIC/neurons/thalamic.py +++ b/PySONIC/neurons/thalamic.py @@ -1,788 +1,786 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-07-31 15:20:54 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-26 17:23:31 - -''' Channels mechanisms for thalamic neurons. ''' +# @Last Modified time: 2018-09-28 14:06:20 import numpy as np from ..core import PointNeuron class Thalamic(PointNeuron): ''' Class defining the generic membrane channel dynamics of a thalamic neuron with 4 different current types: - Inward Sodium current - Outward Potassium current - Inward Calcium current - Non-specific leakage current This generic class cannot be used directly as it does not contain any specific parameters. Reference: *Plaksin, M., Kimmel, E., and Shoham, S. (2016). Cell-Type-Selective Effects of Intramembrane Cavitation as a Unifying Theoretical Framework for Ultrasonic Neuromodulation. eNeuro 3.* ''' # Generic biophysical parameters of thalamic cells Cm0 = 1e-2 # Cell membrane resting capacitance (F/m2) Vm0 = 0.0 # Dummy value for membrane potential (mV) VNa = 50.0 # Sodium Nernst potential (mV) VK = -90.0 # Potassium Nernst potential (mV) VCa = 120.0 # Calcium Nernst potential (mV) def __init__(self): ''' Constructor of the class ''' # Names and initial states of the channels state probabilities self.states_names = ['m', 'h', 'n', 's', 'u'] self.states0 = np.array([]) # Names of the different coefficients to be averaged in a lookup table. self.coeff_names = ['alpham', 'betam', 'alphah', 'betah', 'alphan', 'betan', 'alphas', 'betas', 'alphau', 'betau'] # Charge interval bounds for lookup creation self.Qbounds = (np.round(self.Vm0 - 25.0) * 1e-5, 50.0e-5) def alpham(self, Vm): ''' Compute the alpha rate for the open-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT alpha = 0.32 * self.vtrap(13 - Vdiff, 4) # ms-1 return alpha * 1e3 # s-1 def betam(self, Vm): ''' Compute the beta rate for the open-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT beta = 0.28 * self.vtrap(Vdiff - 40, 5) # ms-1 return beta * 1e3 # s-1 def alphah(self, Vm): ''' Compute the alpha rate for the inactivation-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT alpha = (0.128 * np.exp(-(Vdiff - 17) / 18)) # ms-1 return alpha * 1e3 # s-1 def betah(self, Vm): ''' Compute the beta rate for the inactivation-probability of Sodium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT beta = (4 / (1 + np.exp(-(Vdiff - 40) / 5))) # ms-1 return beta * 1e3 # s-1 def alphan(self, Vm): ''' Compute the alpha rate for the open-probability of delayed-rectifier Potassium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT alpha = 0.032 * self.vtrap(15 - Vdiff, 5) # ms-1 return alpha * 1e3 # s-1 def betan(self, Vm): ''' Compute the beta rate for the open-probability of delayed-rectifier Potassium channels. :param Vm: membrane potential (mV) :return: rate constant (s-1) ''' Vdiff = Vm - self.VT beta = (0.5 * np.exp(-(Vdiff - 10) / 40)) # ms-1 return beta * 1e3 # s-1 def derM(self, Vm, m): ''' Compute the evolution of the open-probability of Sodium channels. :param Vm: membrane potential (mV) :param m: open-probability of Sodium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alpham(Vm) * (1 - m) - self.betam(Vm) * m def derH(self, Vm, h): ''' Compute the evolution of the inactivation-probability of Sodium channels. :param Vm: membrane potential (mV) :param h: inactivation-probability of Sodium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alphah(Vm) * (1 - h) - self.betah(Vm) * h def derN(self, Vm, n): ''' Compute the evolution of the open-probability of delayed-rectifier Potassium channels. :param Vm: membrane potential (mV) :param n: open-probability of delayed-rectifier Potassium channels (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return self.alphan(Vm) * (1 - n) - self.betan(Vm) * n def derS(self, Vm, s): ''' Compute the evolution of the open-probability of the S-type, activation gate of Calcium channels. :param Vm: membrane potential (mV) :param s: open-probability of S-type Calcium activation gates (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return (self.sinf(Vm) - s) / self.taus(Vm) def derU(self, Vm, u): ''' Compute the evolution of the open-probability of the U-type, inactivation gate of Calcium channels. :param Vm: membrane potential (mV) :param u: open-probability of U-type Calcium inactivation gates (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return (self.uinf(Vm) - u) / self.tauu(Vm) def currNa(self, m, h, Vm): ''' Compute the inward Sodium current per unit area. :param m: open-probability of Sodium channels :param h: inactivation-probability of Sodium channels :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' GNa = self.GNaMax * m**3 * h return GNa * (Vm - self.VNa) def currK(self, n, Vm): ''' Compute the outward delayed-rectifier Potassium current per unit area. :param n: open-probability of delayed-rectifier Potassium channels :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' GK = self.GKMax * n**4 return GK * (Vm - self.VK) def currCa(self, s, u, Vm): ''' Compute the inward Calcium current per unit area. :param s: open-probability of the S-type activation gate of Calcium channels :param u: open-probability of the U-type inactivation gate of Calcium channels :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' GT = self.GTMax * s**2 * u return GT * (Vm - self.VCa) def currL(self, Vm): ''' Compute the non-specific leakage current per unit area. :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.GL * (Vm - self.VL) def currNet(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' m, h, n, s, u = states return (self.currNa(m, h, Vm) + self.currK(n, Vm) + self.currCa(s, u, Vm) + self.currL(Vm)) # mA/m2 def steadyStates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Solve the equation dx/dt = 0 at Vm for each x-state meq = self.alpham(Vm) / (self.alpham(Vm) + self.betam(Vm)) heq = self.alphah(Vm) / (self.alphah(Vm) + self.betah(Vm)) neq = self.alphan(Vm) / (self.alphan(Vm) + self.betan(Vm)) seq = self.sinf(Vm) ueq = self.uinf(Vm) return np.array([meq, heq, neq, seq, ueq]) def derStates(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' m, h, n, s, u = states dmdt = self.derM(Vm, m) dhdt = self.derH(Vm, h) dndt = self.derN(Vm, n) dsdt = self.derS(Vm, s) dudt = self.derU(Vm, u) return [dmdt, dhdt, dndt, dsdt, dudt] def getEffRates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Compute average cycle value for rate constants am_avg = np.mean(self.alpham(Vm)) bm_avg = np.mean(self.betam(Vm)) ah_avg = np.mean(self.alphah(Vm)) bh_avg = np.mean(self.betah(Vm)) an_avg = np.mean(self.alphan(Vm)) bn_avg = np.mean(self.betan(Vm)) Ts = self.taus(Vm) as_avg = np.mean(self.sinf(Vm) / Ts) bs_avg = np.mean(1 / Ts) - as_avg Tu = np.array([self.tauu(v) for v in Vm]) au_avg = np.mean(self.uinf(Vm) / Tu) bu_avg = np.mean(1 / Tu) - au_avg # Return array of coefficients return np.array([am_avg, bm_avg, ah_avg, bh_avg, an_avg, bn_avg, as_avg, bs_avg, au_avg, bu_avg]) def derStatesEff(self, Qm, states, interp_data): ''' Concrete implementation of the abstract API method. ''' rates = np.array([np.interp(Qm, interp_data['Q'], interp_data[rn]) for rn in self.coeff_names]) m, h, n, s, u = states dmdt = rates[0] * (1 - m) - rates[1] * m dhdt = rates[2] * (1 - h) - rates[3] * h dndt = rates[4] * (1 - n) - rates[5] * n dsdt = rates[6] * (1 - s) - rates[7] * s dudt = rates[8] * (1 - u) - rates[9] * u return [dmdt, dhdt, dndt, dsdt, dudt] class ThalamicRE(Thalamic): ''' Specific membrane channel dynamics of a thalamic reticular neuron. References: *Destexhe, A., Contreras, D., Steriade, M., Sejnowski, T.J., and Huguenard, J.R. (1996). In vivo, in vitro, and computational analysis of dendritic calcium currents in thalamic reticular neurons. J. Neurosci. 16, 169–185.* *Huguenard, J.R., and Prince, D.A. (1992). A novel T-type current underlies prolonged Ca(2+)-dependent burst firing in GABAergic neurons of rat thalamic reticular nucleus. J. Neurosci. 12, 3804–3817.* ''' # Name of channel mechanism name = 'RE' # Cell-specific biophysical parameters Vm0 = -89.5 # Cell membrane resting potential (mV) GNaMax = 2000.0 # Max. conductance of Sodium current (S/m^2) GKMax = 200.0 # Max. conductance of Potassium current (S/m^2) GTMax = 30.0 # Max. conductance of low-threshold Calcium current (S/m^2) GL = 0.5 # Conductance of non-specific leakage current (S/m^2) VL = -90.0 # Non-specific leakage Nernst potential (mV) VT = -67.0 # Spike threshold adjustment parameter (mV) # Default plotting scheme pltvars_scheme = { 'i_{Na}\ kin.': ['m', 'h', 'm3h'], 'i_K\ kin.': ['n'], 'i_{TS}\ kin.': ['s', 'u', 's2u'], 'I': ['iNa', 'iK', 'iTs', 'iL', 'iNet'] } def __init__(self): ''' Constructor of the class. ''' # Instantiate parent class super().__init__() # Define initial channel probabilities (solving dx/dt = 0 at resting potential) self.states0 = self.steadyStates(self.Vm0) def sinf(self, Vm): ''' Compute the asymptotic value of the open-probability of the S-type, activation gate of Calcium channels. :param Vm: membrane potential (mV) :return: asymptotic probability (-) ''' return 1.0 / (1.0 + np.exp(-(Vm + 52.0) / 7.4)) # prob def taus(self, Vm): ''' Compute the decay time constant for adaptation of S-type, activation gate of Calcium channels. :param Vm: membrane potential (mV) :return: decayed time constant (s) ''' return (1 + 0.33 / (np.exp((Vm + 27.0) / 10.0) + np.exp(-(Vm + 102.0) / 15.0))) * 1e-3 # s def uinf(self, Vm): ''' Compute the asymptotic value of the open-probability of the U-type, inactivation gate of Calcium channels. :param Vm: membrane potential (mV) :return: asymptotic probability (-) ''' return 1.0 / (1.0 + np.exp((Vm + 80.0) / 5.0)) # prob def tauu(self, Vm): ''' Compute the decay time constant for adaptation of U-type, inactivation gate of Calcium channels. :param Vm: membrane potential (mV) :return: decayed time constant (s) ''' return (28.3 + 0.33 / (np.exp((Vm + 48.0) / 4.0) + np.exp(-(Vm + 407.0) / 50.0))) * 1e-3 # s class ThalamoCortical(Thalamic): ''' Specific membrane channel dynamics of a thalamo-cortical neuron, with a specific hyperpolarization-activated, mixed cationic current and a leakage Potassium current. References: *Pospischil, M., Toledo-Rodriguez, M., Monier, C., Piwkowska, Z., Bal, T., Frégnac, Y., Markram, H., and Destexhe, A. (2008). Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biol Cybern 99, 427–441.* *Destexhe, A., Bal, T., McCormick, D.A., and Sejnowski, T.J. (1996). Ionic mechanisms underlying synchronized oscillations and propagating waves in a model of ferret thalamic slices. J. Neurophysiol. 76, 2049–2070.* *McCormick, D.A., and Huguenard, J.R. (1992). A model of the electrophysiological properties of thalamocortical relay neurons. J. Neurophysiol. 68, 1384–1400.* ''' # Name of channel mechanism name = 'TC' # Cell-specific biophysical parameters # Vm0 = -63.4 # Cell membrane resting potential (mV) Vm0 = -61.93 # Cell membrane resting potential (mV) GNaMax = 900.0 # Max. conductance of Sodium current (S/m^2) GKMax = 100.0 # Max. conductance of Potassium current (S/m^2) GTMax = 20.0 # Max. conductance of low-threshold Calcium current (S/m^2) GKL = 0.138 # Conductance of leakage Potassium current (S/m^2) GhMax = 0.175 # Max. conductance of mixed cationic current (S/m^2) GL = 0.1 # Conductance of non-specific leakage current (S/m^2) Vh = -40.0 # Mixed cationic current reversal potential (mV) VL = -70.0 # Non-specific leakage Nernst potential (mV) VT = -52.0 # Spike threshold adjustment parameter (mV) Vx = 0.0 # Voltage-dependence uniform shift factor at 36°C (mV) tau_Ca_removal = 5e-3 # decay time constant for intracellular Ca2+ dissolution (s) CCa_min = 50e-9 # minimal intracellular Calcium concentration (M) deff = 100e-9 # effective depth beneath membrane for intracellular [Ca2+] calculation F_Ca = 1.92988e5 # Faraday constant for bivalent ion (Coulomb / mole) nCa = 4 # number of Calcium binding sites on regulating factor k1 = 2.5e22 # intracellular Ca2+ regulation factor (M-4 s-1) k2 = 0.4 # intracellular Ca2+ regulation factor (s-1) k3 = 100.0 # intracellular Ca2+ regulation factor (s-1) k4 = 1.0 # intracellular Ca2+ regulation factor (s-1) # Default plotting scheme pltvars_scheme = { 'i_{Na}\ kin.': ['m', 'h'], 'i_K\ kin.': ['n'], 'i_{T}\ kin.': ['s', 'u'], 'i_{H}\ kin.': ['O', 'OL', 'O + 2OL'], 'I': ['iNa', 'iK', 'iT', 'iH', 'iKL', 'iL', 'iNet'] } def __init__(self): ''' Constructor of the class. ''' # Instantiate parent class super().__init__() # Compute current to concentration conversion constant self.iT_2_CCa = 1e-6 / (self.deff * self.F_Ca) # Define names of the channels state probabilities self.states_names += ['O', 'C', 'P0', 'C_Ca'] # Define the names of the different coefficients to be averaged in a lookup table. self.coeff_names += ['alphao', 'betao'] # Define initial channel probabilities (solving dx/dt = 0 at resting potential) self.states0 = self.steadyStates(self.Vm0) def sinf(self, Vm): ''' Compute the asymptotic value of the open-probability of the S-type, activation gate of Calcium channels. Reference: *Pospischil, M., Toledo-Rodriguez, M., Monier, C., Piwkowska, Z., Bal, T., Frégnac, Y., Markram, H., and Destexhe, A. (2008). Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biol Cybern 99, 427–441.* :param Vm: membrane potential (mV) :return: asymptotic probability (-) ''' return 1.0 / (1.0 + np.exp(-(Vm + self.Vx + 57.0) / 6.2)) # prob def taus(self, Vm): ''' Compute the decay time constant for adaptation of S-type, activation gate of Calcium channels. Reference: *Pospischil, M., Toledo-Rodriguez, M., Monier, C., Piwkowska, Z., Bal, T., Frégnac, Y., Markram, H., and Destexhe, A. (2008). Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biol Cybern 99, 427–441.* :param Vm: membrane potential (mV) :return: decayed time constant (s) ''' tmp = np.exp(-(Vm + self.Vx + 132.0) / 16.7) + np.exp((Vm + self.Vx + 16.8) / 18.2) return 1.0 / 3.7 * (0.612 + 1.0 / tmp) * 1e-3 # s def uinf(self, Vm): ''' Compute the asymptotic value of the open-probability of the U-type, inactivation gate of Calcium channels. Reference: *Pospischil, M., Toledo-Rodriguez, M., Monier, C., Piwkowska, Z., Bal, T., Frégnac, Y., Markram, H., and Destexhe, A. (2008). Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biol Cybern 99, 427–441.* :param Vm: membrane potential (mV) :return: asymptotic probability (-) ''' return 1.0 / (1.0 + np.exp((Vm + self.Vx + 81.0) / 4.0)) # prob def tauu(self, Vm): ''' Compute the decay time constant for adaptation of U-type, inactivation gate of Calcium channels. Reference: *Pospischil, M., Toledo-Rodriguez, M., Monier, C., Piwkowska, Z., Bal, T., Frégnac, Y., Markram, H., and Destexhe, A. (2008). Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biol Cybern 99, 427–441.* :param Vm: membrane potential (mV) :return: decayed time constant (s) ''' if Vm + self.Vx < -80.0: return 1.0 / 3.7 * np.exp((Vm + self.Vx + 467.0) / 66.6) * 1e-3 # s else: return 1 / 3.7 * (np.exp(-(Vm + self.Vx + 22) / 10.5) + 28.0) * 1e-3 # s def derS(self, Vm, s): ''' Compute the evolution of the open-probability of the S-type, activation gate of Calcium channels. :param Vm: membrane potential (mV) :param s: open-probability of S-type Calcium activation gates (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return (self.sinf(Vm) - s) / self.taus(Vm) def derU(self, Vm, u): ''' Compute the evolution of the open-probability of the U-type, inactivation gate of Calcium channels. :param Vm: membrane potential (mV) :param u: open-probability of U-type Calcium inactivation gates (prob) :return: derivative of open-probability w.r.t. time (prob/s) ''' return (self.uinf(Vm) - u) / self.tauu(Vm) def oinf(self, Vm): ''' Voltage-dependent steady-state activation of hyperpolarization-activated cation current channels. Reference: *Huguenard, J.R., and McCormick, D.A. (1992). Simulation of the currents involved in rhythmic oscillations in thalamic relay neurons. J. Neurophysiol. 68, 1373–1383.* :param Vm: membrane potential (mV) :return: steady-state activation (-) ''' return 1.0 / (1.0 + np.exp((Vm + 75.0) / 5.5)) def tauo(self, Vm): ''' Time constant for activation of hyperpolarization-activated cation current channels. Reference: *Huguenard, J.R., and McCormick, D.A. (1992). Simulation of the currents involved in rhythmic oscillations in thalamic relay neurons. J. Neurophysiol. 68, 1373–1383.* :param Vm: membrane potential (mV) :return: time constant (s) ''' return 1 / (np.exp(-14.59 - 0.086 * Vm) + np.exp(-1.87 + 0.0701 * Vm)) * 1e-3 def alphao(self, Vm): ''' Transition rate between closed and open form of hyperpolarization-activated cation current channels. :param Vm: membrane potential (mV) :return: transition rate (s-1) ''' return self.oinf(Vm) / self.tauo(Vm) def betao(self, Vm): ''' Transition rate between open and closed form of hyperpolarization-activated cation current channels. :param Vm: membrane potential (mV) :return: transition rate (s-1) ''' return (1 - self.oinf(Vm)) / self.tauo(Vm) def derC(self, C, O, Vm): ''' Compute the evolution of the proportion of hyperpolarization-activated cation current channels in closed state. Kinetics scheme of Calcium dependent activation derived from: *Destexhe, A., Bal, T., McCormick, D.A., and Sejnowski, T.J. (1996). Ionic mechanisms underlying synchronized oscillations and propagating waves in a model of ferret thalamic slices. J. Neurophysiol. 76, 2049–2070.* :param Vm: membrane potential (mV) :param C: proportion of Ih channels in closed state (-) :param O: proportion of Ih channels in open state (-) :return: derivative of proportion w.r.t. time (s-1) ''' return self.betao(Vm) * O - self.alphao(Vm) * C def derO(self, C, O, P0, Vm): ''' Compute the evolution of the proportion of hyperpolarization-activated cation current channels in open state. Kinetics scheme of Calcium dependent activation derived from: *Destexhe, A., Bal, T., McCormick, D.A., and Sejnowski, T.J. (1996). Ionic mechanisms underlying synchronized oscillations and propagating waves in a model of ferret thalamic slices. J. Neurophysiol. 76, 2049–2070.* :param Vm: membrane potential (mV) :param C: proportion of Ih channels in closed state (-) :param O: proportion of Ih channels in open state (-) :param P0: proportion of Ih channels regulating factor in unbound state (-) :return: derivative of proportion w.r.t. time (s-1) ''' return - self.derC(C, O, Vm) - self.k3 * O * (1 - P0) + self.k4 * (1 - O - C) def derP0(self, P0, C_Ca): ''' Compute the evolution of the proportion of Ih channels regulating factor in unbound state. Kinetics scheme of Calcium dependent activation derived from: *Destexhe, A., Bal, T., McCormick, D.A., and Sejnowski, T.J. (1996). Ionic mechanisms underlying synchronized oscillations and propagating waves in a model of ferret thalamic slices. J. Neurophysiol. 76, 2049–2070.* :param Vm: membrane potential (mV) :param P0: proportion of Ih channels regulating factor in unbound state (-) :param C_Ca: Calcium concentration in effective submembranal space (M) :return: derivative of proportion w.r.t. time (s-1) ''' return self.k2 * (1 - P0) - self.k1 * P0 * C_Ca**self.nCa def derC_Ca(self, C_Ca, ICa): ''' Compute the evolution of the Calcium concentration in submembranal space. Model of Ca2+ buffering and contribution from iCa derived from: *McCormick, D.A., and Huguenard, J.R. (1992). A model of the electrophysiological properties of thalamocortical relay neurons. J. Neurophysiol. 68, 1384–1400.* :param Vm: membrane potential (mV) :param C_Ca: Calcium concentration in submembranal space (M) :param ICa: inward Calcium current filling up the submembranal space with Ca2+ (mA/m2) :return: derivative of Calcium concentration in submembranal space w.r.t. time (s-1) ''' return (self.CCa_min - C_Ca) / self.tau_Ca_removal - self.iT_2_CCa * ICa def currKL(self, Vm): ''' Compute the voltage-dependent leak Potassium current per unit area. :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' return self.GKL * (Vm - self.VK) def currH(self, O, C, Vm): ''' Compute the outward mixed cationic current per unit area. :param O: proportion of the channels in open form :param OL: proportion of the channels in locked-open form :param Vm: membrane potential (mV) :return: current per unit area (mA/m2) ''' OL = 1 - O - C return self.GhMax * (O + 2 * OL) * (Vm - self.Vh) def currNet(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' m, h, n, s, u, O, C, _, _ = states return (self.currNa(m, h, Vm) + self.currK(n, Vm) + self.currCa(s, u, Vm) + self.currKL(Vm) + self.currH(O, C, Vm) + self.currL(Vm)) # mA/m2 def steadyStates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Call parent method to compute Sodium, Potassium and Calcium channels gates steady-states NaKCa_eqstates = super().steadyStates(Vm) # Compute steady-state Calcium current seq = NaKCa_eqstates[3] ueq = NaKCa_eqstates[4] iTeq = self.currCa(seq, ueq, Vm) # Compute steady-state variables for the kinetics system of Ih CCa_eq = self.CCa_min - self.tau_Ca_removal * self.iT_2_CCa * iTeq P0_eq = self.k2 / (self.k2 + self.k1 * CCa_eq**self.nCa) BA = self.betao(Vm) / self.alphao(Vm) O_eq = self.k4 / (self.k3 * (1 - P0_eq) + self.k4 * (1 + BA)) C_eq = BA * O_eq kin_eqstates = np.array([O_eq, C_eq, P0_eq, CCa_eq]) # Merge all steady-states and return return np.concatenate((NaKCa_eqstates, kin_eqstates)) def derStates(self, Vm, states): ''' Concrete implementation of the abstract API method. ''' m, h, n, s, u, O, C, P0, C_Ca = states NaKCa_states = [m, h, n, s, u] NaKCa_derstates = super().derStates(Vm, NaKCa_states) dO_dt = self.derO(C, O, P0, Vm) dC_dt = self.derC(C, O, Vm) dP0_dt = self.derP0(P0, C_Ca) ICa = self.currCa(s, u, Vm) dCCa_dt = self.derC_Ca(C_Ca, ICa) return NaKCa_derstates + [dO_dt, dC_dt, dP0_dt, dCCa_dt] def getEffRates(self, Vm): ''' Concrete implementation of the abstract API method. ''' # Compute effective coefficients for Sodium, Potassium and Calcium conductances NaKCa_effrates = super().getEffRates(Vm) # Compute effective coefficients for Ih conductance ao_avg = np.mean(self.alphao(Vm)) bo_avg = np.mean(self.betao(Vm)) iH_effrates = np.array([ao_avg, bo_avg]) # Return array of coefficients return np.concatenate((NaKCa_effrates, iH_effrates)) def derStatesEff(self, Qm, states, interp_data): ''' Concrete implementation of the abstract API method. ''' rates = np.array([np.interp(Qm, interp_data['Q'], interp_data[rn]) for rn in self.coeff_names]) Vmeff = np.interp(Qm, interp_data['Q'], interp_data['V']) # Unpack states m, h, n, s, u, O, C, P0, C_Ca = states # INa, IK, ICa effective states derivatives dmdt = rates[0] * (1 - m) - rates[1] * m dhdt = rates[2] * (1 - h) - rates[3] * h dndt = rates[4] * (1 - n) - rates[5] * n dsdt = rates[6] * (1 - s) - rates[7] * s dudt = rates[8] * (1 - u) - rates[9] * u # Ih effective states derivatives dC_dt = rates[11] * O - rates[10] * C dO_dt = - dC_dt - self.k3 * O * (1 - P0) + self.k4 * (1 - O - C) dP0_dt = self.derP0(P0, C_Ca) ICa_eff = self.currCa(s, u, Vmeff) dCCa_dt = self.derC_Ca(C_Ca, ICa_eff) # Merge derivatives and return return [dmdt, dhdt, dndt, dsdt, dudt, dO_dt, dC_dt, dP0_dt, dCCa_dt] diff --git a/PySONIC/plt/actmap.py b/PySONIC/plt/actmap.py index 0c1ad4d..74b2635 100644 --- a/PySONIC/plt/actmap.py +++ b/PySONIC/plt/actmap.py @@ -1,279 +1,285 @@ +# -*- coding: utf-8 -*- +# @Author: Theo Lemaire +# @Date: 2018-09-26 16:47:18 +# @Last Modified by: Theo Lemaire +# @Last Modified time: 2018-09-28 14:06:48 + import os import ntpath import pickle import numpy as np import matplotlib.pyplot as plt import matplotlib from matplotlib.ticker import FormatStrFormatter from ..core import NeuronalBilayerSonophore from ..utils import logger, si_format, ASTIM_filecode, cm2inch from ..postpro import findPeaks from ..constants import * from ..neurons import getNeuronsDict def getActivationMap(root, neuron, a, Fdrive, tstim, PRF, amps, DCs): ''' Compute the activation map of a neuron at a given frequency and PRF, by computing the spiking metrics of simulation results over a 2D space (amplitude x duty cycle). :param root: directory containing the input data files :param neuron: neuron name :param a: sonophore diameter :param Fdrive: US frequency (Hz) :param tstim: duration of US stimulation (s) :param PRF: pulse repetition frequency (Hz) :param amps: vector of acoustic amplitudes (Pa) :param DCs: vector of duty cycles (-) :return the activation matrix ''' # Load activation map from file if it exists actmap_filename = 'actmap {} {}Hz PRF{}Hz {}s.csv'.format( neuron, *si_format([Fdrive, PRF, tstim], space='')) actmap_filepath = os.path.join(root, actmap_filename) if os.path.isfile(actmap_filepath): logger.info('Loading activation map for %s neuron', neuron) return np.loadtxt(actmap_filepath, delimiter=',') # Otherwise generate it logger.info('Generating activation map for %s neuron', neuron) actmap = np.empty((amps.size, DCs.size)) nfiles = DCs.size * amps.size for i, A in enumerate(amps): for j, DC in enumerate(DCs): fname = '{}.pkl'.format(ASTIM_filecode(neuron, a, Fdrive, A, tstim, PRF, DC, 'sonic')) fpath = os.path.join(root, fname) if not os.path.isfile(fpath): logger.error('"{}" file not found'.format(fname)) actmap[i, j] = np.nan else: # Load data logger.debug('Loading file {}/{}: "{}"'.format(i * amps.size + j + 1, nfiles, fname)) with open(fpath, 'rb') as fh: frame = pickle.load(fh) df = frame['data'] meta = frame['meta'] tstim = meta['tstim'] t = df['t'].values Qm = df['Qm'].values dt = t[1] - t[0] # Detect spikes on charge profile during stimulus mpd = int(np.ceil(SPIKE_MIN_DT / dt)) ispikes, *_ = findPeaks(Qm[t <= tstim], SPIKE_MIN_QAMP, mpd, SPIKE_MIN_QPROM) # Compute firing metrics if ispikes.size == 0: # if no spike, assign -1 actmap[i, j] = -1 elif ispikes.size == 1: # if only 1 spike, assign 0 actmap[i, j] = 0 else: # if more than 1 spike, assign firing rate FRs = 1 / np.diff(t[ispikes]) actmap[i, j] = np.mean(FRs) # Save activation map to file np.savetxt(actmap_filepath, actmap, delimiter=',') return actmap def computeMeshEdges(x, scale='lin'): ''' Compute the appropriate edges of a mesh that quads a linear or logarihtmic distribution. :param x: the input vector :param scale: the type of distribution ('lin' for linear, 'log' for logarihtmic) :return: the edges vector ''' if scale is 'log': x = np.log10(x) dx = x[1] - x[0] if scale is 'lin': y = np.linspace(x[0] - dx / 2, x[-1] + dx / 2, x.size + 1) elif scale is 'log': y = np.logspace(x[0] - dx / 2, x[-1] + dx / 2, x.size + 1) return y def onClick(event, root, neuron, a, Fdrive, tstim, PRF, amps, DCs, meshedges, tmax, Vbounds): ''' Retrieve the specific input parameters of the x and y dimensions when the user clicks on a cell in the 2D map, and define filename from it. ''' # Get DC and A from x and y coordinates x, y = event.xdata, event.ydata DC = DCs[np.searchsorted(meshedges[0], x * 1e-2) - 1] Adrive = amps[np.searchsorted(meshedges[1], y * 1e3) - 1] # Define filepath fname = '{}.pkl'.format(ASTIM_filecode(neuron, a, Fdrive, Adrive, tstim, PRF, DC, 'sonic')) filepath = os.path.join(root, fname) # Plot Q-trace try: plotQVeff(filepath, tmax=tmax, ybounds=Vbounds) plt.show() except FileNotFoundError as err: logger.error(err) def plotQVeff(filepath, tonset=10, tmax=None, ybounds=None, fs=8, lw=1): ''' Plot superimposed profiles of membrane charge density and effective membrane potential. :param filepath: full path to the data file :param tonset: pre-stimulus onset to add to profiles (ms) :param tmax: max time value showed on graph (ms) :param ybounds: y-axis bounds (mV / nC/cm2) :return: handle to the generated figure ''' # Check file existence fname = ntpath.basename(filepath) if not os.path.isfile(filepath): raise FileNotFoundError('Error: "{}" file does not exist'.format(fname)) # Load data logger.debug('Loading data from "%s"', fname) with open(filepath, 'rb') as fh: frame = pickle.load(fh) df = frame['data'] meta = frame['meta'] t = df['t'].values * 1e3 # ms Qm = df['Qm'].values * 1e5 # nC/cm2 Vm = df['Vm'].values # mV # Add onset to profiles t = np.hstack((np.array([-tonset, t[0]]), t)) Vm = np.hstack((np.array([getNeuronsDict()[meta['neuron']]().Vm0] * 2), Vm)) Qm = np.hstack((np.array([Qm[0]] * 2), Qm)) # Determine axes bounds if tmax is None: tmax = t.max() if ybounds is None: ybounds = (min(Vm.min(), Qm.min()), max(Vm.max(), Qm.max())) # Create figure fig, ax = plt.subplots(figsize=cm2inch(7, 3)) fig.canvas.set_window_title(fname) plt.subplots_adjust(left=0.2, bottom=0.2, right=0.95, top=0.95) for key in ['top', 'right']: ax.spines[key].set_visible(False) for key in ['bottom', 'left']: ax.spines[key].set_position(('axes', -0.03)) ax.spines[key].set_linewidth(2) ax.yaxis.set_tick_params(width=2) ax.yaxis.set_major_formatter(FormatStrFormatter('%.0f')) print(-tonset, tmax) ax.set_xlim((-tonset, tmax)) ax.set_xticks([]) ax.set_xlabel('{}s'.format(si_format((tonset + tmax) * 1e-3, space=' ')), fontsize=fs) ax.set_ylabel('mV - $\\rm nC/cm^2$', fontsize=fs, labelpad=-15) ax.set_ylim(ybounds) ax.set_yticks(ybounds) for item in ax.get_yticklabels(): item.set_fontsize(fs) # Plot Qm and Vmeff profiles ax.plot(t, Vm, color='darkgrey', linewidth=lw) ax.plot(t, Qm, color='k', linewidth=lw) # fig.tight_layout() return fig def plotActivationMap(root, neuron, a, Fdrive, tstim, PRF, amps, DCs, Ascale='log', FRscale='log', FRbounds=None, title=None, fs=8, rheobase=True, connect=False, tmax=None, Vbounds=None): ''' Plot a neuron's activation map over the amplitude x duty cycle 2D space. :param root: directory containing the input data files :param neuron: neuron name :param a: sonophore diameter :param Fdrive: US frequency (Hz) :param tstim: duration of US stimulation (s) :param PRF: pulse repetition frequency (Hz) :param amps: vector of acoustic amplitudes (Pa) :param DCs: vector of duty cycles (-) :param Ascale: scale to use for the amplitude dimension ('lin' or 'log') :param FRscale: scale to use for the firing rate coloring ('lin' or 'log') :param FRbounds: lower and upper bounds of firing rate color-scale :param title: figure title :param fs: fontsize to use for the title and labels :return: 3-tuple with the handle to the generated figure and the mesh x and y coordinates ''' # Get activation map actmap = getActivationMap(root, neuron, a, Fdrive, tstim, PRF, amps, DCs) # Check firing rate bounding minFR, maxFR = (actmap[actmap > 0].min(), actmap.max()) logger.info('FR range: %.0f - %.0f Hz', minFR, maxFR) if FRbounds is None: FRbounds = (minFR, maxFR) else: if minFR < FRbounds[0]: logger.warning('Minimal firing rate (%.0f Hz) is below defined lower bound (%.0f Hz)', minFR, FRbounds[0]) if maxFR > FRbounds[1]: logger.warning('Maximal firing rate (%.0f Hz) is above defined upper bound (%.0f Hz)', maxFR, FRbounds[1]) # Plot activation map if FRscale == 'lin': norm = matplotlib.colors.Normalize(*FRbounds) elif FRscale == 'log': norm = matplotlib.colors.LogNorm(*FRbounds) fig, ax = plt.subplots(figsize=cm2inch(8, 5.8)) fig.subplots_adjust(left=0.15, bottom=0.15, right=0.8, top=0.92) if title is None: title = '{} neuron @ {}Hz, {}Hz PRF ({}m sonophore)'.format( neuron, *si_format([Fdrive, PRF, a])) ax.set_title(title, fontsize=fs) if Ascale == 'log': ax.set_yscale('log') ax.set_xlabel('Duty cycle (%)', fontsize=fs, labelpad=-0.5) ax.set_ylabel('Amplitude (kPa)', fontsize=fs) ax.set_xlim(np.array([DCs.min(), DCs.max()]) * 1e2) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) xedges = computeMeshEdges(DCs) yedges = computeMeshEdges(amps, scale=Ascale) actmap[actmap == -1] = np.nan actmap[actmap == 0] = 1e-3 cmap = plt.get_cmap('viridis') cmap.set_bad('silver') cmap.set_under('k') ax.pcolormesh(xedges * 1e2, yedges * 1e-3, actmap, cmap=cmap, norm=norm) # Plot rheobase amplitudes if specified if rheobase: logger.info('Computing rheobase amplitudes') dDC = 0.01 DCs_dense = np.arange(dDC, 100 + dDC / 2, dDC) / 1e2 neuronobj = getNeuronsDict()[neuron]() nbls = NeuronalBilayerSonophore(a, neuronobj) Athrs = nbls.findRheobaseAmps(DCs_dense, Fdrive, neuron.VT) ax.plot(DCs_dense * 1e2, Athrs * 1e-3, '-', color='#F26522', linewidth=2, label='threshold amplitudes') ax.legend(loc='lower center', frameon=False, fontsize=8) # Plot firing rate colorbar sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) sm._A = [] pos1 = ax.get_position() # get the map axis position cbarax = fig.add_axes([pos1.x1 + 0.02, pos1.y0, 0.03, pos1.height]) fig.colorbar(sm, cax=cbarax) cbarax.set_ylabel('Firing rate (Hz)', fontsize=fs) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) # Link callback to figure if connect: fig.canvas.mpl_connect( 'button_press_event', lambda event: onClick(event, root, neuron, a, Fdrive, tstim, PRF, amps, DCs, (xedges, yedges), tmax, Vbounds) ) - return fig \ No newline at end of file + return fig diff --git a/PySONIC/plt/batch.py b/PySONIC/plt/batch.py index 15b5515..12c96e2 100644 --- a/PySONIC/plt/batch.py +++ b/PySONIC/plt/batch.py @@ -1,231 +1,237 @@ +# -*- coding: utf-8 -*- +# @Author: Theo Lemaire +# @Date: 2018-09-25 16:19:19 +# @Last Modified by: Theo Lemaire +# @Last Modified time: 2018-09-28 14:06:57 + import sys import pickle import ntpath import numpy as np import matplotlib.pyplot as plt from ..utils import * from ..core import BilayerSonophore from .pltvars import pltvars from ..neurons import getNeuronsDict def plotBatch(filepaths, vars_dict=None, plt_save=False, directory=None, ask_before_save=True, fig_ext='png', tag='fig', fs=15, lw=2, title=True, show_patches=True): ''' Plot a figure with profiles of several specific NICE output variables, for several NICE simulations. :param filepaths: list of full paths to output data files to be compared :param vars_dict: dict of lists of variables names to extract and plot together :param plt_save: boolean stating whether to save the created figures :param directory: directory where to save figures :param ask_before_save: boolean stating whether to show the created figures :param fig_ext: file extension for the saved figures :param tag: suffix added to the end of the figures name :param fs: labels font size :param lw: curves line width :param title: boolean stating whether to display a general title on the figures :param show_patches: boolean indicating whether to indicate periods of stimulation with colored rectangular patches ''' # Check validity of plot variables if vars_dict: yvars = list(sum(list(vars_dict.values()), [])) for key in yvars: if key not in pltvars: raise KeyError('Unknown plot variable: "{}"'.format(key)) # Dictionary of neurons neurons_dict = getNeuronsDict() # Loop through data files figs = [] for filepath in filepaths: # Get code from file name pkl_filename = ntpath.basename(filepath) filecode = pkl_filename[0:-4] # Retrieve sim type mo1 = rgxp.fullmatch(pkl_filename) mo2 = rgxp_mech.fullmatch(pkl_filename) if mo1: mo = mo1 elif mo2: mo = mo2 else: logger.error('Error: "%s" file does not match regexp pattern', pkl_filename) sys.exit(1) sim_type = mo.group(1) if sim_type not in ('MECH', 'ASTIM', 'ESTIM'): raise ValueError('Invalid simulation type: {}'.format(sim_type)) # Load data logger.info('Loading data from "%s"', pkl_filename) with open(filepath, 'rb') as fh: frame = pickle.load(fh) df = frame['data'] meta = frame['meta'] # Extract variables logger.info('Extracting variables') t = df['t'].values states = df['states'].values nsamples = t.size # Initialize channel mechanism if sim_type in ['ASTIM', 'ESTIM']: neuron_name = mo.group(2) global neuron neuron = neurons_dict[neuron_name]() neuron_states = [df[sn].values for sn in neuron.states_names] Cm0 = neuron.Cm0 Qm0 = Cm0 * neuron.Vm0 * 1e-3 t_plt = pltvars['t_ms'] else: Cm0 = meta['Cm0'] Qm0 = meta['Qm0'] t_plt = pltvars['t_us'] # Initialize BLS if sim_type in ['MECH', 'ASTIM']: global bls Fdrive = meta['Fdrive'] a = meta['a'] bls = BilayerSonophore(a, Cm0, Qm0) # Determine patches location npatches, tpatch_on, tpatch_off = getStimPulses(t, states) # Adding onset to time vector if t_plt['onset'] > 0.0: tonset = np.array([-t_plt['onset'], -t[0] - t[1]]) t = np.hstack((tonset, t)) states = np.hstack((states, np.zeros(2))) # Determine variables to plot if not provided if not vars_dict: if sim_type == 'ASTIM': vars_dict = {'Z': ['Z'], 'Q_m': ['Qm']} elif sim_type == 'ESTIM': vars_dict = {'V_m': ['Vm']} elif sim_type == 'MECH': vars_dict = {'P_{AC}': ['Pac'], 'Z': ['Z'], 'n_g': ['ng']} if sim_type in ['ASTIM', 'ESTIM'] and hasattr(neuron, 'pltvars_scheme'): vars_dict.update(neuron.pltvars_scheme) labels = list(vars_dict.keys()) naxes = len(vars_dict) # Plotting if naxes == 1: fig, ax = plt.subplots(figsize=(11, 4)) axes = [ax] else: fig, axes = plt.subplots(naxes, 1, figsize=(11, min(3 * naxes, 9))) for i in range(naxes): ax = axes[i] for item in ['top', 'right']: ax.spines[item].set_visible(False) ax_pltvars = [pltvars[j] for j in vars_dict[labels[i]]] nvars = len(ax_pltvars) # X-axis if i < naxes - 1: ax.get_xaxis().set_ticklabels([]) else: ax.set_xlabel('${}\ ({})$'.format(t_plt['label'], t_plt['unit']), fontsize=fs) for tick in ax.xaxis.get_major_ticks(): tick.label.set_fontsize(fs) # Y-axis if ax_pltvars[0]['unit']: ax.set_ylabel('${}\ ({})$'.format(labels[i], ax_pltvars[0]['unit']), fontsize=fs) else: ax.set_ylabel('${}$'.format(labels[i]), fontsize=fs) if 'min' in ax_pltvars[0] and 'max' in ax_pltvars[0]: ax_min = min([ap['min'] for ap in ax_pltvars]) ax_max = max([ap['max'] for ap in ax_pltvars]) ax.set_ylim(ax_min, ax_max) ax.locator_params(axis='y', nbins=2) for tick in ax.yaxis.get_major_ticks(): tick.label.set_fontsize(fs) # Time series icolor = 0 for j in range(nvars): # Extract variable pltvar = ax_pltvars[j] if 'alias' in pltvar: var = eval(pltvar['alias']) elif 'key' in pltvar: var = df[pltvar['key']].values elif 'constant' in pltvar: var = eval(pltvar['constant']) * np.ones(nsamples) else: var = df[vars_dict[labels[i]][j]].values if var.size == t.size - 2: if pltvar['desc'] == 'membrane potential': var = np.hstack((np.array([neuron.Vm0] * 2), var)) else: var = np.hstack((np.array([var[0]] * 2), var)) # var = np.insert(var, 0, var[0]) # Plot variable if 'constant' in pltvar or pltvar['desc'] in ['net current']: ax.plot(t * t_plt['factor'], var * pltvar['factor'], '--', c='black', lw=lw, label='${}$'.format(pltvar['label'])) else: ax.plot(t * t_plt['factor'], var * pltvar['factor'], c='C{}'.format(icolor), lw=lw, label='${}$'.format(pltvar['label'])) icolor += 1 # Patches if show_patches == 1: (ybottom, ytop) = ax.get_ylim() for j in range(npatches): ax.axvspan(tpatch_on[j] * t_plt['factor'], tpatch_off[j] * t_plt['factor'], edgecolor='none', facecolor='#8A8A8A', alpha=0.2) # Legend if nvars > 1: ax.legend(fontsize=fs, loc=7, ncol=nvars // 4 + 1) # Title if title: if sim_type == 'ESTIM': fig_title = ESTIM_title( neuron.name, meta['Astim'], meta['tstim'] * 1e3, meta['PRF'], meta['DC'] * 1e2) elif sim_type == 'ASTIM': fig_title = ASTIM_title( neuron.name, Fdrive * 1e-3, meta['Adrive'] * 1e-3, meta['tstim'] * 1e3, meta['PRF'], meta['DC'] * 1e2) elif sim_type == 'MECH': fig_title = MECH_title(a * 1e9, Fdrive * 1e-3, meta['Adrive'] * 1e-3) axes[0].set_title(fig_title, fontsize=fs) fig.tight_layout() # Save figure if needed (automatic or checked) if plt_save: if directory is None: directory = os.path.split(filepath)[0] if ask_before_save: plt_filename = SaveFileDialog( '{}_{}.{}'.format(filecode, tag, fig_ext), dirname=directory, ext=fig_ext) else: plt_filename = '{}/{}_{}.{}'.format(directory, filecode, tag, fig_ext) if plt_filename: plt.savefig(plt_filename) logger.info('Saving figure as "{}"'.format(plt_filename)) plt.close() figs.append(fig) return figs diff --git a/PySONIC/plt/comp.py b/PySONIC/plt/comp.py index 061c886..609ff77 100644 --- a/PySONIC/plt/comp.py +++ b/PySONIC/plt/comp.py @@ -1,363 +1,369 @@ +# -*- coding: utf-8 -*- +# @Author: Theo Lemaire +# @Date: 2018-09-25 16:18:45 +# @Last Modified by: Theo Lemaire +# @Last Modified time: 2018-09-28 14:14:10 + import sys import pickle import ntpath import numpy as np import matplotlib.pyplot as plt from matplotlib.patches import Rectangle from matplotlib.ticker import FormatStrFormatter from ..utils import * from .pltvars import pltvars from ..core import BilayerSonophore from ..neurons import getNeuronsDict class InteractiveLegend(object): - """ Class defining an interactive matplotlib legend, where lines visibility can + ''' Class defining an interactive matplotlib legend, where lines visibility can be toggled by simply clicking on the corresponding legend label. Other graphic objects can also be associated to the toggle of a specific line Adapted from: http://stackoverflow.com/questions/31410043/hiding-lines-after-showing-a-pyplot-figure - """ + ''' def __init__(self, legend, aliases): self.legend = legend self.fig = legend.axes.figure self.lookup_artist, self.lookup_handle = self._build_lookups(legend) self._setup_connections() self.handles_aliases = aliases self.update() def _setup_connections(self): for artist in self.legend.texts + self.legend.legendHandles: artist.set_picker(10) # 10 points tolerance self.fig.canvas.mpl_connect('pick_event', self.on_pick) def _build_lookups(self, legend): ''' Method of the InteractiveLegend class building the legend lookups. ''' labels = [t.get_text() for t in legend.texts] handles = legend.legendHandles label2handle = dict(zip(labels, handles)) handle2text = dict(zip(handles, legend.texts)) lookup_artist = {} lookup_handle = {} for artist in legend.axes.get_children(): if artist.get_label() in labels: handle = label2handle[artist.get_label()] lookup_handle[artist] = handle lookup_artist[handle] = artist lookup_artist[handle2text[handle]] = artist lookup_handle.update(zip(handles, handles)) lookup_handle.update(zip(legend.texts, handles)) return lookup_artist, lookup_handle def on_pick(self, event): handle = event.artist if handle in self.lookup_artist: artist = self.lookup_artist[handle] artist.set_visible(not artist.get_visible()) self.update() def update(self): for artist in self.lookup_artist.values(): handle = self.lookup_handle[artist] if artist.get_visible(): handle.set_visible(True) if artist in self.handles_aliases: for al in self.handles_aliases[artist]: al.set_visible(True) else: handle.set_visible(False) if artist in self.handles_aliases: for al in self.handles_aliases[artist]: al.set_visible(False) self.fig.canvas.draw() def show(self): ''' showing the interactive legend ''' plt.show() def plotComp(filepaths, varname, labels=None, fs=15, lw=2, colors=None, lines=None, patches='one', xticks=None, yticks=None, blacklegend=False, straightlegend=False, inset=None, figsize=(11, 4)): ''' Compare profiles of several specific output variables of NICE simulations. :param filepaths: list of full paths to output data files to be compared :param varname: name of variable to extract and compare :param labels: list of labels to use in the legend :param fs: labels fontsize :param patches: string indicating whether to indicate periods of stimulation with colored rectangular patches ''' # Input check 1: variable name if varname not in pltvars: raise KeyError('Unknown plot variable: "{}"'.format(varname)) pltvar = pltvars[varname] # Input check 2: labels if labels is not None: if len(labels) != len(filepaths): raise AssertionError('Invalid labels ({}): not matching number of compared files ({})' .format(len(labels), len(filepaths))) if not all(isinstance(x, str) for x in labels): raise TypeError('Invalid labels: must be string typed') # Input check 3: line styles and colors if colors is None: colors = ['C{}'.format(j) for j in range(len(filepaths))] if lines is None: lines = ['-'] * len(filepaths) # Input check 4: STIM-ON patches greypatch = False if patches == 'none': patches = [False] * len(filepaths) elif patches == 'all': patches = [True] * len(filepaths) elif patches == 'one': patches = [True] + [False] * (len(filepaths) - 1) greypatch = True elif isinstance(patches, list): if len(patches) != len(filepaths): raise AssertionError('Invalid patches ({}): not matching number of compared files ({})' .format(len(patches), len(filepaths))) if not all(isinstance(p, bool) for p in patches): raise TypeError('Invalid patch sequence: all list items must be boolean typed') else: raise ValueError('Invalid patches: must be either "none", all", "one", or a boolean list') # Initialize figure and axis fig, ax = plt.subplots(figsize=figsize) ax.set_zorder(0) for item in ['top', 'right']: ax.spines[item].set_visible(False) if 'min' in pltvar and 'max' in pltvar: # optional min and max on y-axis ax.set_ylim(pltvar['min'], pltvar['max']) if pltvar['unit']: # y-label with optional unit ax.set_ylabel('$\\rm {}\ ({})$'.format(pltvar['label'], pltvar['unit']), fontsize=fs) else: ax.set_ylabel('$\\rm {}$'.format(pltvar['label']), fontsize=fs) if xticks is not None: # optional x-ticks ax.set_xticks(xticks) if yticks is not None: # optional y-ticks ax.set_yticks(yticks) else: ax.locator_params(axis='y', nbins=2) if any(ax.get_yticks() < 0): ax.yaxis.set_major_formatter(FormatStrFormatter('%+.0f')) for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks(): tick.label.set_fontsize(fs) # Optional inset axis if inset is not None: inset_ax = fig.add_axes(ax.get_position()) inset_ax.set_xlim(inset['xlims'][0], inset['xlims'][1]) inset_ax.set_ylim(inset['ylims'][0], inset['ylims'][1]) inset_ax.set_xticks([]) inset_ax.set_yticks([]) # inset_ax.patch.set_alpha(1.0) inset_ax.set_zorder(1) inset_ax.add_patch(Rectangle((inset['xlims'][0], inset['ylims'][0]), inset['xlims'][1] - inset['xlims'][0], inset['ylims'][1] - inset['ylims'][0], color='w')) # Retrieve neurons dictionary neurons_dict = getNeuronsDict() # Loop through data files aliases = {} for j, filepath in enumerate(filepaths): # Retrieve sim type pkl_filename = ntpath.basename(filepath) mo1 = rgxp.fullmatch(pkl_filename) mo2 = rgxp_mech.fullmatch(pkl_filename) if mo1: mo = mo1 elif mo2: mo = mo2 else: logger.error('Error: "%s" file does not match regexp pattern', pkl_filename) sys.exit(1) sim_type = mo.group(1) if sim_type not in ('MECH', 'ASTIM', 'ESTIM'): raise ValueError('Invalid simulation type: {}'.format(sim_type)) if j == 0: sim_type_ref = sim_type t_plt = pltvars[timeunits[sim_type]] elif sim_type != sim_type_ref: raise ValueError('Invalid comparison: different simulation types') # Load data logger.info('Loading data from "%s"', pkl_filename) with open(filepath, 'rb') as fh: frame = pickle.load(fh) df = frame['data'] meta = frame['meta'] # Extract variables t = df['t'].values states = df['states'].values nsamples = t.size # Initialize neuron object if ESTIM or ASTIM sim type if sim_type in ['ASTIM', 'ESTIM']: neuron_name = mo.group(2) global neuron neuron = neurons_dict[neuron_name]() Cm0 = neuron.Cm0 Qm0 = Cm0 * neuron.Vm0 * 1e-3 # Extract neuron states if needed if 'alias' in pltvar and 'neuron_states' in pltvar['alias']: neuron_states = [df[sn].values for sn in neuron.states_names] else: Cm0 = meta['Cm0'] Qm0 = meta['Qm0'] # Initialize BLS if needed if sim_type in ['MECH', 'ASTIM'] and 'alias' in pltvar and 'bls' in pltvar['alias']: global bls bls = BilayerSonophore(meta['a'], Cm0, Qm0) # Determine patches location npatches, tpatch_on, tpatch_off = getStimPulses(t, states) # Add onset to time vectors if t_plt['onset'] > 0.0: tonset = np.array([-t_plt['onset'], -t[0] - t[1]]) t = np.hstack((tonset, t)) states = np.hstack((states, np.zeros(2))) # Set x-axis label ax.set_xlabel('$\\rm {}\ ({})$'.format(t_plt['label'], t_plt['unit']), fontsize=fs) # Extract variable to plot if 'alias' in pltvar: var = eval(pltvar['alias']) elif 'key' in pltvar: var = df[pltvar['key']].values elif 'constant' in pltvar: var = eval(pltvar['constant']) * np.ones(nsamples) else: var = df[varname].values if var.size == t.size - 2: if varname is 'Vm': var = np.hstack((np.array([neuron.Vm0] * 2), var)) else: var = np.hstack((np.array([var[0]] * 2), var)) # var = np.insert(var, 0, var[0]) # Determine legend label if labels is not None: label = labels[j] else: if sim_type == 'ESTIM': label = ESTIM_title( neuron.name, meta['Astim'], meta['tstim'] * 1e3, meta['PRF'], meta['DC'] * 1e2) elif sim_type == 'ASTIM': label = ASTIM_title( neuron.name, meta['Fdrive'] * 1e-3, meta['Adrive'] * 1e-3, meta['tstim'] * 1e3, meta['PRF'], meta['DC'] * 1e2) elif sim_type == 'MECH': label = MECH_title(a * 1e9, meta['Fdrive'] * 1e-3, meta['Adrive'] * 1e-3) # Plot trace handle = ax.plot(t * t_plt['factor'], var * pltvar['factor'], linewidth=lw, linestyle=lines[j], color=colors[j], label=label) if inset is not None: inset_window = np.logical_and(t > (inset['xlims'][0] / t_plt['factor']), t < (inset['xlims'][1] / t_plt['factor'])) inset_ax.plot(t[inset_window] * t_plt['factor'], var[inset_window] * pltvar['factor'], linewidth=lw, linestyle=lines[j], color=colors[j]) # Add optional STIM-ON patches if patches[j]: (ybottom, ytop) = ax.get_ylim() la = [] color = '#8A8A8A' if greypatch else handle[0].get_color() for i in range(npatches): la.append(ax.axvspan(tpatch_on[i] * t_plt['factor'], tpatch_off[i] * t_plt['factor'], edgecolor='none', facecolor=color, alpha=0.2)) aliases[handle[0]] = la if inset is not None: cond_on = np.logical_and(tpatch_on > (inset['xlims'][0] / t_plt['factor']), tpatch_on < (inset['xlims'][1] / t_plt['factor'])) cond_off = np.logical_and(tpatch_off > (inset['xlims'][0] / t_plt['factor']), tpatch_off < (inset['xlims'][1] / t_plt['factor'])) cond_glob = np.logical_and(tpatch_on < (inset['xlims'][0] / t_plt['factor']), tpatch_off > (inset['xlims'][1] / t_plt['factor'])) cond_onoff = np.logical_or(cond_on, cond_off) cond = np.logical_or(cond_onoff, cond_glob) npatches_inset = np.sum(cond) for i in range(npatches_inset): inset_ax.add_patch(Rectangle((tpatch_on[cond][i] * t_plt['factor'], ybottom), (tpatch_off[cond][i] - tpatch_on[cond][i]) * t_plt['factor'], ytop - ybottom, color=color, alpha=0.1)) fig.tight_layout() # Optional operations on inset: if inset is not None: # Re-position inset axis axpos = ax.get_position() left, right, = rescale(inset['xcoords'], ax.get_xlim()[0], ax.get_xlim()[1], axpos.x0, axpos.x0 + axpos.width) bottom, top, = rescale(inset['ycoords'], ax.get_ylim()[0], ax.get_ylim()[1], axpos.y0, axpos.y0 + axpos.height) inset_ax.set_position([left, bottom, right - left, top - bottom]) for i in inset_ax.spines.values(): i.set_linewidth(2) # Materialize inset target region with contour frame ax.plot(inset['xlims'], [inset['ylims'][0]] * 2, linestyle='-', color='k') ax.plot(inset['xlims'], [inset['ylims'][1]] * 2, linestyle='-', color='k') ax.plot([inset['xlims'][0]] * 2, inset['ylims'], linestyle='-', color='k') ax.plot([inset['xlims'][1]] * 2, inset['ylims'], linestyle='-', color='k') # Link target and inset with dashed lines if possible if inset['xcoords'][1] < inset['xlims'][0]: ax.plot([inset['xcoords'][1], inset['xlims'][0]], [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k') ax.plot([inset['xcoords'][1], inset['xlims'][0]], [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k') elif inset['xcoords'][0] > inset['xlims'][1]: ax.plot([inset['xcoords'][0], inset['xlims'][1]], [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k') ax.plot([inset['xcoords'][0], inset['xlims'][1]], [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k') else: logger.warning('Inset x-coordinates intersect with those of target region') # Create interactive legend leg = ax.legend(loc=1, fontsize=fs, frameon=False) if blacklegend: for l in leg.get_lines(): l.set_color('k') if straightlegend: for l in leg.get_lines(): l.set_linestyle('-') interactive_legend = InteractiveLegend(ax.legend_, aliases) return fig \ No newline at end of file diff --git a/PySONIC/postpro.py b/PySONIC/postpro.py index b3d953b..6183593 100644 --- a/PySONIC/postpro.py +++ b/PySONIC/postpro.py @@ -1,436 +1,436 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-08-22 14:33:04 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-26 17:23:51 +# @Last Modified time: 2018-09-28 14:14:11 ''' Utility functions to detect spikes on signals and compute spiking metrics. ''' import pickle import numpy as np import pandas as pd from .constants import * from .utils import logger def detectPeaks(x, mph=None, mpd=1, threshold=0, edge='rising', kpsh=False, valley=False, ax=None): ''' Detect peaks in data based on their amplitude and other features. Adapted from Marco Duarte: http://nbviewer.jupyter.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb :param x: 1D array_like data. :param mph: minimum peak height (default = None). :param mpd: minimum peak distance in indexes (default = 1) :param threshold : minimum peak prominence (default = 0) :param edge : for a flat peak, keep only the rising edge ('rising'), only the falling edge ('falling'), both edges ('both'), or don't detect a flat peak (None). (default = 'rising') :param kpsh: keep peaks with same height even if they are closer than `mpd` (default = False). :param valley: detect valleys (local minima) instead of peaks (default = False). :param show: plot data in matplotlib figure (default = False). :param ax: a matplotlib.axes.Axes instance, optional (default = None). :return: 1D array with the indices of the peaks ''' print('min peak height:', mph, ', min peak distance:', mpd, ', min peak prominence:', threshold) # Convert input to numpy array x = np.atleast_1d(x).astype('float64') # Revert signal sign for valley detection if valley: x = -x # Differentiate signal dx = np.diff(x) # Find indices of all peaks with edge criterion ine, ire, ife = np.array([[], [], []], dtype=int) if not edge: ine = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) > 0))[0] else: if edge.lower() in ['rising', 'both']: ire = np.where((np.hstack((dx, 0)) <= 0) & (np.hstack((0, dx)) > 0))[0] if edge.lower() in ['falling', 'both']: ife = np.where((np.hstack((dx, 0)) < 0) & (np.hstack((0, dx)) >= 0))[0] ind = np.unique(np.hstack((ine, ire, ife))) # Remove first and last values of x if they are detected as peaks if ind.size and ind[0] == 0: ind = ind[1:] if ind.size and ind[-1] == x.size - 1: ind = ind[:-1] print('{} raw peaks'.format(ind.size)) # Remove peaks < minimum peak height if ind.size and mph is not None: ind = ind[x[ind] >= mph] print('{} height-filtered peaks'.format(ind.size)) # Remove peaks - neighbors < threshold if ind.size and threshold > 0: dx = np.min(np.vstack([x[ind] - x[ind - 1], x[ind] - x[ind + 1]]), axis=0) ind = np.delete(ind, np.where(dx < threshold)[0]) print('{} prominence-filtered peaks'.format(ind.size)) # Detect small peaks closer than minimum peak distance if ind.size and mpd > 1: ind = ind[np.argsort(x[ind])][::-1] # sort ind by peak height idel = np.zeros(ind.size, dtype=bool) for i in range(ind.size): if not idel[i]: # keep peaks with the same height if kpsh is True idel = idel | (ind >= ind[i] - mpd) & (ind <= ind[i] + mpd) \ & (x[ind[i]] > x[ind] if kpsh else True) idel[i] = 0 # Keep current peak # remove the small peaks and sort back the indices by their occurrence ind = np.sort(ind[~idel]) print('{} distance-filtered peaks'.format(ind.size)) return ind def detectPeaksTime(t, y, mph, mtd, mpp=0): - """ Extension of the detectPeaks function to detect peaks in data based on their + ''' Extension of the detectPeaks function to detect peaks in data based on their amplitude and time difference, with a non-uniform time vector. :param t: time vector (not necessarily uniform) :param y: signal :param mph: minimal peak height :param mtd: minimal time difference :mpp: minmal peak prominence :return: array of peak indexes - """ + ''' # Determine whether time vector is uniform (threshold in time step variation) dt = np.diff(t) if (dt.max() - dt.min()) / dt.min() < 1e-2: isuniform = True else: isuniform = False if isuniform: print('uniform time vector') dt = t[1] - t[0] mpd = int(np.ceil(mtd / dt)) ipeaks = detectPeaks(y, mph, mpd=mpd, threshold=mpp) else: print('non-uniform time vector') # Detect peaks on signal with no restriction on inter-peak distance irawpeaks = detectPeaks(y, mph, mpd=1, threshold=mpp) npeaks = irawpeaks.size if npeaks > 0: # Filter relevant peaks with temporal distance ipeaks = [irawpeaks[0]] for i in range(1, npeaks): i1 = ipeaks[-1] i2 = irawpeaks[i] if t[i2] - t[i1] < mtd: if y[i2] > y[i1]: ipeaks[-1] = i2 else: ipeaks.append(i2) else: ipeaks = [] ipeaks = np.array(ipeaks) return ipeaks def detectSpikes(t, Qm, min_amp, min_dt): ''' Detect spikes on a charge density signal, and return their number, latency and rate. :param t: time vector (s) :param Qm: charge density vector (C/m2) :param min_amp: minimal charge amplitude to detect spikes (C/m2) :param min_dt: minimal time interval between 2 spikes (s) :return: 3-tuple with number of spikes, latency (s) and spike rate (sp/s) ''' i_spikes = detectPeaksTime(t, Qm, min_amp, min_dt) if len(i_spikes) > 0: latency = t[i_spikes[0]] # s n_spikes = i_spikes.size if n_spikes > 1: first_to_last_spike = t[i_spikes[-1]] - t[i_spikes[0]] # s spike_rate = (n_spikes - 1) / first_to_last_spike # spikes/s else: spike_rate = 'N/A' else: latency = 'N/A' spike_rate = 'N/A' n_spikes = 0 return (n_spikes, latency, spike_rate) def findPeaks(y, mph=None, mpd=None, mpp=None): ''' Detect peaks in a signal based on their height, prominence and/or separating distance. :param y: signal vector :param mph: minimum peak height (in signal units, default = None). :param mpd: minimum inter-peak distance (in indexes, default = None) :param mpp: minimum peak prominence (in signal units, default = None) :return: 4-tuple of arrays with the indexes of peaks occurence, peaks prominence, peaks width at half-prominence and peaks half-prominence bounds (left and right) Adapted from: - Marco Duarte's detect_peaks function (http://nbviewer.jupyter.org/github/demotu/BMC/blob/master/notebooks/DetectPeaks.ipynb) - MATLAB findpeaks function (https://ch.mathworks.com/help/signal/ref/findpeaks.html) ''' # Define empty output empty = (np.array([]),) * 4 # Differentiate signal dy = np.diff(y) # Find all peaks and valleys # s = np.sign(dy) # ipeaks = np.where(np.diff(s) < 0.0)[0] + 1 # ivalleys = np.where(np.diff(s) > 0.0)[0] + 1 ipeaks = np.where((np.hstack((dy, 0)) <= 0) & (np.hstack((0, dy)) > 0))[0] ivalleys = np.where((np.hstack((dy, 0)) >= 0) & (np.hstack((0, dy)) < 0))[0] # Return empty output if no peak detected if ipeaks.size == 0: return empty logger.debug('%u peaks found, starting at index %u and ending at index %u', ipeaks.size, ipeaks[0], ipeaks[-1]) if ivalleys.size > 0: logger.debug('%u valleys found, starting at index %u and ending at index %u', ivalleys.size, ivalleys[0], ivalleys[-1]) else: logger.debug('no valleys found') # Ensure each peak is bounded by two valleys, adding signal boundaries as valleys if necessary if ivalleys.size == 0 or ipeaks[0] < ivalleys[0]: ivalleys = np.insert(ivalleys, 0, -1) if ipeaks[-1] > ivalleys[-1]: ivalleys = np.insert(ivalleys, ivalleys.size, y.size - 1) if ivalleys.size - ipeaks.size != 1: logger.debug('Cleaning up incongruities') i = 0 while i < min(ipeaks.size, ivalleys.size) - 1: if ipeaks[i] < ivalleys[i]: # 2 peaks between consecutive valleys -> remove lowest idel = i - 1 if y[ipeaks[i - 1]] < y[ipeaks[i]] else i logger.debug('Removing abnormal peak at index %u', ipeaks[idel]) ipeaks = np.delete(ipeaks, idel) if ipeaks[i] > ivalleys[i + 1]: idel = i + 1 if y[ivalleys[i]] < y[ivalleys[i + 1]] else i logger.debug('Removing abnormal valley at index %u', ivalleys[idel]) ivalleys = np.delete(ivalleys, idel) else: i += 1 logger.debug('Post-cleanup: %u peaks and %u valleys', ipeaks.size, ivalleys.size) # Remove peaks < minimum peak height if mph is not None: ipeaks = ipeaks[y[ipeaks] >= mph] if ipeaks.size == 0: return empty # Detect small peaks closer than minimum peak distance if mpd is not None: ipeaks = ipeaks[np.argsort(y[ipeaks])][::-1] # sort ipeaks by descending peak height idel = np.zeros(ipeaks.size, dtype=bool) # initialize boolean deletion array (all false) for i in range(ipeaks.size): # for each peak if not idel[i]: # if not marked for deletion closepeaks = (ipeaks >= ipeaks[i] - mpd) & (ipeaks <= ipeaks[i] + mpd) # close peaks idel = idel | closepeaks # mark for deletion along with previously marked peaks # idel = idel | (ipeaks >= ipeaks[i] - mpd) & (ipeaks <= ipeaks[i] + mpd) idel[i] = 0 # keep current peak # remove the small peaks and sort back the indices by their occurrence ipeaks = np.sort(ipeaks[~idel]) # Detect smallest valleys between consecutive relevant peaks ibottomvalleys = [] if ipeaks[0] > ivalleys[0]: itrappedvalleys = ivalleys[ivalleys < ipeaks[0]] ibottomvalleys.append(itrappedvalleys[np.argmin(y[itrappedvalleys])]) for i, j in zip(ipeaks[:-1], ipeaks[1:]): itrappedvalleys = ivalleys[np.logical_and(ivalleys > i, ivalleys < j)] ibottomvalleys.append(itrappedvalleys[np.argmin(y[itrappedvalleys])]) if ipeaks[-1] < ivalleys[-1]: itrappedvalleys = ivalleys[ivalleys > ipeaks[-1]] ibottomvalleys.append(itrappedvalleys[np.argmin(y[itrappedvalleys])]) ipeaks = ipeaks ivalleys = np.array(ibottomvalleys, dtype=int) # Ensure each peak is bounded by two valleys, adding signal boundaries as valleys if necessary if ipeaks[0] < ivalleys[0]: ivalleys = np.insert(ivalleys, 0, 0) if ipeaks[-1] > ivalleys[-1]: ivalleys = np.insert(ivalleys, ivalleys.size, y.size - 1) # Remove peaks < minimum peak prominence if mpp is not None: # Compute peaks prominences as difference between peaks and their closest valley prominences = y[ipeaks] - np.amax((y[ivalleys[:-1]], y[ivalleys[1:]]), axis=0) # initialize peaks and valleys deletion tables idelp = np.zeros(ipeaks.size, dtype=bool) idelv = np.zeros(ivalleys.size, dtype=bool) # for each peak (sorted by ascending prominence order) for ind in np.argsort(prominences): ipeak = ipeaks[ind] # get peak index # get peak bases as first valleys on either side not marked for deletion indleftbase = ind indrightbase = ind + 1 while idelv[indleftbase]: indleftbase -= 1 while idelv[indrightbase]: indrightbase += 1 ileftbase = ivalleys[indleftbase] irightbase = ivalleys[indrightbase] # Compute peak prominence and mark for deletion if < mpp indmaxbase = indleftbase if y[ileftbase] > y[irightbase] else indrightbase if y[ipeak] - y[ivalleys[indmaxbase]] < mpp: idelp[ind] = True # mark peak for deletion idelv[indmaxbase] = True # mark highest surrouding valley for deletion # remove irrelevant peaks and valleys, and sort back the indices by their occurrence ipeaks = np.sort(ipeaks[~idelp]) ivalleys = np.sort(ivalleys[~idelv]) if ipeaks.size == 0: return empty # Compute peaks prominences and reference half-prominence levels prominences = y[ipeaks] - np.amax((y[ivalleys[:-1]], y[ivalleys[1:]]), axis=0) refheights = y[ipeaks] - prominences / 2 # Compute half-prominence bounds ibounds = np.empty((ipeaks.size, 2)) for i in range(ipeaks.size): # compute the index of the left-intercept at half max ileft = ipeaks[i] while ileft >= ivalleys[i] and y[ileft] > refheights[i]: ileft -= 1 if ileft < ivalleys[i]: # intercept exactly on valley ibounds[i, 0] = ivalleys[i] else: # interpolate intercept linearly between signal boundary points a = (y[ileft + 1] - y[ileft]) / 1 b = y[ileft] - a * ileft ibounds[i, 0] = (refheights[i] - b) / a # compute the index of the right-intercept at half max iright = ipeaks[i] while iright <= ivalleys[i + 1] and y[iright] > refheights[i]: iright += 1 if iright > ivalleys[i + 1]: # intercept exactly on valley ibounds[i, 1] = ivalleys[i + 1] else: # interpolate intercept linearly between signal boundary points if iright == y.size - 1: # special case: if end of signal is reached, decrement iright iright -= 1 a = (y[iright + 1] - y[iright]) / 1 b = y[iright] - a * iright ibounds[i, 1] = (refheights[i] - b) / a # Compute peaks widths at half-prominence widths = np.diff(ibounds, axis=1) return (ipeaks - 1, prominences, widths, ibounds) def computeSpikeMetrics(filenames): ''' Analyze the charge density profile from a list of files and compute for each one of them the following spiking metrics: - latency (ms) - firing rate mean and standard deviation (Hz) - spike amplitude mean and standard deviation (nC/cm2) - spike width mean and standard deviation (ms) :param filenames: list of files to analyze :return: a dataframe with the computed metrics ''' # Initialize metrics dictionaries keys = [ 'latencies (ms)', 'mean firing rates (Hz)', 'std firing rates (Hz)', 'mean spike amplitudes (nC/cm2)', 'std spike amplitudes (nC/cm2)', 'mean spike widths (ms)', 'std spike widths (ms)' ] metrics = {k: [] for k in keys} # Compute spiking metrics for fname in filenames: # Load data from file logger.debug('loading data from file "{}"'.format(fname)) with open(fname, 'rb') as fh: frame = pickle.load(fh) df = frame['data'] meta = frame['meta'] tstim = meta['tstim'] t = df['t'].values Qm = df['Qm'].values dt = t[1] - t[0] # Detect spikes on charge profile mpd = int(np.ceil(SPIKE_MIN_DT / dt)) ispikes, prominences, widths, _ = findPeaks(Qm, SPIKE_MIN_QAMP, mpd, SPIKE_MIN_QPROM) widths *= dt if ispikes.size > 0: # Compute latency latency = t[ispikes[0]] # Select prior-offset spikes ispikes_prior = ispikes[t[ispikes] < tstim] else: latency = np.nan ispikes_prior = np.array([]) # Compute spikes widths and amplitude if ispikes_prior.size > 0: widths_prior = widths[:ispikes_prior.size] prominences_prior = prominences[:ispikes_prior.size] else: widths_prior = np.array([np.nan]) prominences_prior = np.array([np.nan]) # Compute inter-spike intervals and firing rates if ispikes_prior.size > 1: ISIs_prior = np.diff(t[ispikes_prior]) FRs_prior = 1 / ISIs_prior else: ISIs_prior = np.array([np.nan]) FRs_prior = np.array([np.nan]) # Log spiking metrics logger.debug('%u spikes detected (%u prior to offset)', ispikes.size, ispikes_prior.size) logger.debug('latency: %.2f ms', latency * 1e3) logger.debug('average spike width within stimulus: %.2f +/- %.2f ms', np.nanmean(widths_prior) * 1e3, np.nanstd(widths_prior) * 1e3) logger.debug('average spike amplitude within stimulus: %.2f +/- %.2f nC/cm2', np.nanmean(prominences_prior) * 1e5, np.nanstd(prominences_prior) * 1e5) logger.debug('average ISI within stimulus: %.2f +/- %.2f ms', np.nanmean(ISIs_prior) * 1e3, np.nanstd(ISIs_prior) * 1e3) logger.debug('average FR within stimulus: %.2f +/- %.2f Hz', np.nanmean(FRs_prior), np.nanstd(FRs_prior)) # Complete metrics dictionaries metrics['latencies (ms)'].append(latency * 1e3) metrics['mean firing rates (Hz)'].append(np.mean(FRs_prior)) metrics['std firing rates (Hz)'].append(np.std(FRs_prior)) metrics['mean spike amplitudes (nC/cm2)'].append(np.mean(prominences_prior) * 1e5) metrics['std spike amplitudes (nC/cm2)'].append(np.std(prominences_prior) * 1e5) metrics['mean spike widths (ms)'].append(np.mean(widths_prior) * 1e3) metrics['std spike widths (ms)'].append(np.std(widths_prior) * 1e3) # Return dataframe with metrics return pd.DataFrame(metrics, columns=metrics.keys()) diff --git a/PySONIC/utils.py b/PySONIC/utils.py index 6ee73f2..295c9e7 100644 --- a/PySONIC/utils.py +++ b/PySONIC/utils.py @@ -1,395 +1,395 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-09-19 22:30:46 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-27 02:23:20 +# @Last Modified time: 2018-09-28 14:14:11 -""" Definition of generic utility functions used in other modules """ +''' Definition of generic utility functions used in other modules ''' import operator import os import pickle import re import tkinter as tk from tkinter import filedialog import numpy as np import colorlog from scipy.interpolate import interp1d import matplotlib # Matplotlib parameters matplotlib.rcParams['pdf.fonttype'] = 42 matplotlib.rcParams['ps.fonttype'] = 42 matplotlib.rcParams['font.family'] = 'arial' # Package logger def setLogger(): log_formatter = colorlog.ColoredFormatter( '%(log_color)s %(asctime)s %(message)s', datefmt='%d/%m/%Y %H:%M:%S:', reset=True, log_colors={ 'DEBUG': 'green', 'INFO': 'white', 'WARNING': 'yellow', 'ERROR': 'red', 'CRITICAL': 'red,bg_white', }, style='%' ) log_handler = colorlog.StreamHandler() log_handler.setFormatter(log_formatter) color_logger = colorlog.getLogger('PySONIC') color_logger.addHandler(log_handler) return color_logger logger = setLogger() # File naming conventions def ESTIM_filecode(neuron, Astim, tstim, PRF, DC): return 'ESTIM_{}_{}_{:.1f}mA_per_m2_{:.0f}ms{}'.format( neuron, 'CW' if DC == 1 else 'PW', Astim, tstim * 1e3, '_PRF{:.2f}Hz_DC{:.2f}%'.format(PRF, DC * 1e2) if DC < 1. else '') def ASTIM_filecode(neuron, a, Fdrive, Adrive, tstim, PRF, DC, method): return 'ASTIM_{}_{}_{:.0f}nm_{:.0f}kHz_{:.1f}kPa_{:.0f}ms_{}{}'.format( neuron, 'CW' if DC == 1 else 'PW', a * 1e9, Fdrive * 1e-3, Adrive * 1e-3, tstim * 1e3, 'PRF{:.2f}Hz_DC{:.2f}%_'.format(PRF, DC * 1e2) if DC < 1. else '', method) def MECH_filecode(a, Fdrive, Adrive, Qm): return 'MECH_{:.0f}nm_{:.0f}kHz_{:.1f}kPa_{:.1f}nCcm2'.format( a * 1e9, Fdrive * 1e-3, Adrive * 1e-3, Qm * 1e5) rgxp = re.compile('(ESTIM|ASTIM)_([A-Za-z]*)_(.*).pkl') rgxp_mech = re.compile('(MECH)_(.*).pkl') # Figure naming conventions def ESTIM_title(name, A, t, PRF, DC): return '{} neuron: {} E-STIM {:.2f}mA/m2, {:.0f}ms{}'.format( name, 'PW' if DC < 1. else 'CW', A, t, ', {:.2f}Hz PRF, {:.0f}% DC'.format(PRF, DC) if DC < 1. else '') def ASTIM_title(name, f, A, t, PRF, DC): return '{} neuron: {} A-STIM {:.0f}kHz {:.0f}kPa, {:.0f}ms{}'.format( name, 'PW' if DC < 1. else 'CW', f, A, t, ', {:.2f}Hz PRF, {:.0f}% DC'.format(PRF, DC) if DC < 1. else '') def MECH_title(a, f, A): return '{:.0f}nm BLS structure: MECH-STIM {:.0f}kHz, {:.0f}kPa'.format(a, f, A) timeunits = { 'ASTIM': 't_ms', 'ESTIM': 't_ms', 'MECH': 't_us' } def cm2inch(*tupl): inch = 2.54 if isinstance(tupl[0], tuple): return tuple(i / inch for i in tupl[0]) else: return tuple(i / inch for i in tupl) # SI units prefixes si_prefixes = { 'y': 1e-24, # yocto 'z': 1e-21, # zepto 'a': 1e-18, # atto 'f': 1e-15, # femto 'p': 1e-12, # pico 'n': 1e-9, # nano 'u': 1e-6, # micro 'm': 1e-3, # mili '': 1e0, # None 'k': 1e3, # kilo 'M': 1e6, # mega 'G': 1e9, # giga 'T': 1e12, # tera 'P': 1e15, # peta 'E': 1e18, # exa 'Z': 1e21, # zetta 'Y': 1e24, # yotta } def si_format(x, precision=0, space=' '): ''' Format a float according to the SI unit system, with the appropriate prefix letter. ''' if isinstance(x, float) or isinstance(x, int) or isinstance(x, np.float) or\ isinstance(x, np.int32) or isinstance(x, np.int64): if x == 0: factor = 1e0 prefix = '' else: sorted_si_prefixes = sorted(si_prefixes.items(), key=operator.itemgetter(1)) vals = [tmp[1] for tmp in sorted_si_prefixes] # vals = list(si_prefixes.values()) ix = np.searchsorted(vals, np.abs(x)) - 1 if np.abs(x) == vals[ix + 1]: ix += 1 factor = vals[ix] prefix = sorted_si_prefixes[ix][0] # prefix = list(si_prefixes.keys())[ix] return '{{:.{}f}}{}{}'.format(precision, space, prefix).format(x / factor) elif isinstance(x, list) or isinstance(x, tuple): return [si_format(item, precision, space) for item in x] elif isinstance(x, np.ndarray) and x.ndim == 1: return [si_format(float(item), precision, space) for item in x] else: print(type(x)) def pow10_format(number, precision=2): ''' Format a number in power of 10 notation. ''' ret_string = '{0:.{1:d}e}'.format(number, precision) a, b = ret_string.split("e") a = float(a) b = int(b) return '{}10^{{{}}}'.format('{} * '.format(a) if a != 1. else '', b) def rmse(x1, x2): - """ Compute the root mean square error between two 1D arrays """ + ''' Compute the root mean square error between two 1D arrays ''' return np.sqrt(((x1 - x2) ** 2).mean()) def rsquared(x1, x2): ''' compute the R-squared coefficient between two 1D arrays ''' residuals = x1 - x2 ss_res = np.sum(residuals**2) ss_tot = np.sum((x1 - np.mean(x1))**2) return 1 - (ss_res / ss_tot) def Pressure2Intensity(p, rho=1075.0, c=1515.0): - """ Return the spatial peak, pulse average acoustic intensity (ISPPA) + ''' Return the spatial peak, pulse average acoustic intensity (ISPPA) associated with the specified pressure amplitude. :param p: pressure amplitude (Pa) :param rho: medium density (kg/m3) :param c: speed of sound in medium (m/s) :return: spatial peak, pulse average acoustic intensity (W/m2) - """ + ''' return p**2 / (2 * rho * c) def Intensity2Pressure(I, rho=1075.0, c=1515.0): - """ Return the pressure amplitude associated with the specified + ''' Return the pressure amplitude associated with the specified spatial peak, pulse average acoustic intensity (ISPPA). :param I: spatial peak, pulse average acoustic intensity (W/m2) :param rho: medium density (kg/m3) :param c: speed of sound in medium (m/s) :return: pressure amplitude (Pa) - """ + ''' return np.sqrt(2 * rho * c * I) def OpenFilesDialog(filetype, dirname=''): - """ Open a FileOpenDialogBox to select one or multiple file. + ''' Open a FileOpenDialogBox to select one or multiple file. The default directory and file type are given. :param dirname: default directory :param filetype: default file type :return: tuple of full paths to the chosen filenames - """ + ''' root = tk.Tk() root.withdraw() filenames = filedialog.askopenfilenames(filetypes=[(filetype + " files", '.' + filetype)], initialdir=dirname) if filenames: par_dir = os.path.abspath(os.path.join(filenames[0], os.pardir)) else: par_dir = None return (filenames, par_dir) def selectDirDialog(): - """ Open a dialog box to select a directory. + ''' Open a dialog box to select a directory. :return: full path to selected directory - """ + ''' root = tk.Tk() root.withdraw() return filedialog.askdirectory() def SaveFileDialog(filename, dirname=None, ext=None): ''' Open a dialog box to save file. :param filename: filename :param dirname: initial directory :param ext: default extension :return: full path to the chosen filename ''' root = tk.Tk() root.withdraw() filename_out = filedialog.asksaveasfilename( defaultextension=ext, initialdir=dirname, initialfile=filename) return filename_out def downsample(t_dense, y, nsparse): - """ Decimate periodic signals to a specified number of samples.""" + ''' Decimate periodic signals to a specified number of samples.''' if(y.ndim) > 1: nsignals = y.shape[0] else: nsignals = 1 y = np.array([y]) # determine time step and period of input signal T = t_dense[-1] - t_dense[0] dt_dense = t_dense[1] - t_dense[0] # resample time vector linearly t_ds = np.linspace(t_dense[0], t_dense[-1], nsparse) # create MAV window nmav = int(0.03 * T / dt_dense) if nmav % 2 == 0: nmav += 1 mav = np.ones(nmav) / nmav # determine signals padding npad = int((nmav - 1) / 2) # determine indexes of sampling on convolved signals ids = np.round(np.linspace(0, t_dense.size - 1, nsparse)).astype(int) y_ds = np.empty((nsignals, nsparse)) # loop through signals for i in range(nsignals): # pad, convolve and resample pad_left = y[i, -(npad + 2):-2] pad_right = y[i, 1:npad + 1] y_ext = np.concatenate((pad_left, y[i, :], pad_right), axis=0) y_mav = np.convolve(y_ext, mav, mode='valid') y_ds[i, :] = y_mav[ids] if nsignals == 1: y_ds = y_ds[0, :] return (t_ds, y_ds) def rescale(x, lb=None, ub=None, lb_new=0, ub_new=1): ''' Rescale a value to a specific interval by linear transformation. ''' if lb is None: lb = x.min() if ub is None: ub = x.max() xnorm = (x - lb) / (ub - lb) return xnorm * (ub_new - lb_new) + lb_new def getStimPulses(t, states): ''' Determine the onset and offset times of pulses from a stimulation vector. :param t: time vector (s). :param states: a vector of stimulation state (ON/OFF) at each instant in time. :return: 3-tuple with number of patches, timing of STIM-ON an STIM-OFF instants. ''' # Compute states derivatives and identify bounds indexes of pulses dstates = np.diff(states) ipulse_on = np.insert(np.where(dstates > 0.0)[0] + 1, 0, 0) ipulse_off = np.where(dstates < 0.0)[0] + 1 if ipulse_off.size < ipulse_on.size: ioff = t.size - 1 if ipulse_off.size == 0: ipulse_off = np.array([ioff]) else: ipulse_off = np.insert(ipulse_off, ipulse_off.size - 1, ioff) # Get time instants for pulses ON and OFF npulses = ipulse_on.size tpulse_on = t[ipulse_on] tpulse_off = t[ipulse_off] # return 3-tuple with #pulses, pulse ON and pulse OFF instants return npulses, tpulse_on, tpulse_off def extractCompTimes(filenames): ''' Extract computation times from a list of simulation files. ''' tcomps = np.empty(len(filenames)) for i, fn in enumerate(filenames): logger.info('Loading data from "%s"', fn) with open(fn, 'rb') as fh: frame = pickle.load(fh) meta = frame['meta'] tcomps[i] = meta['tcomp'] return tcomps def getNeuronLookupsFile(mechname): return os.path.join( os.path.split(__file__)[0], 'neurons', '{}_lookups.pkl'.format(mechname)) def getLookups2D(mechname, a, Fdrive): ''' Retrieve appropriate 2D lookup tables and reference vectors for a given membrane mechanism, sonophore diameter and US frequency. :param mechname: name of membrane density mechanism :param a: sonophore diameter (m) :param Fdrive: US frequency (Hz) :return: 3-tuple with 1D numpy arrays of reference acoustic amplitudes and charge densities, and a dictionary of 2D lookup numpy arrays ''' # Check lookup file existence lookup_path = getNeuronLookupsFile(mechname) if not os.path.isfile(lookup_path): raise FileNotFoundError('Missing lookup file: "{}"'.format(lookup_path)) # Load lookups dictionary logger.debug('Loading lookup table') with open(lookup_path, 'rb') as fh: lookups4D = pickle.load(fh) # Retrieve 1D inputs from lookups dictionary aref = lookups4D.pop('a') Fref = lookups4D.pop('f') Aref = lookups4D.pop('A') Qref = lookups4D.pop('Q') # Check that sonophore diameter is within lookup range arange = (aref.min() - 1e-12, aref.max() + 1e-12) if a < arange[0] or a > arange[1]: raise ValueError('Invalid sonophore diameter: {}m (must be within {}m - {}m lookup interval)' .format(*si_format([a, *arange], precision=2, space=' '))) # Check that US frequency is within lookup range Frange = (Fref.min() - 1e-9, Fref.max() + 1e-9) if Fdrive < Frange[0] or Fdrive > Frange[1]: raise ValueError('Invalid frequency: {}Hz (must be within {}Hz - {}Hz lookup interval)' .format(*si_format([Fdrive, *Frange], precision=2, space=' '))) # Interpolate 4D lookups at sonophore diameter and then at US frequency logger.debug('Interpolating lookups at a = {}m'.format(si_format(a, space=' '))) lookups3D = {key: interp1d(aref, y4D, axis=0)(a) for key, y4D in lookups4D.items()} logger.debug('Interpolating lookups at f = {}Hz'.format(si_format(Fdrive, space=' '))) lookups2D = {key: interp1d(Fref, y3D, axis=0)(Fdrive) for key, y3D in lookups3D.items()} return Aref, Qref, lookups2D diff --git a/deprecated/Taylor expansions/test_alpham_Taylor.py b/deprecated/Taylor expansions/test_alpham_Taylor.py index 97f4452..f30b6dc 100644 --- a/deprecated/Taylor expansions/test_alpham_Taylor.py +++ b/deprecated/Taylor expansions/test_alpham_Taylor.py @@ -1,46 +1,46 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-03-21 11:38:56 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2017-03-29 19:12:27 +# @Last Modified time: 2018-09-28 14:14:11 -""" Taylor expansions of the alpha_m function around different potential values. """ +''' Taylor expansions of the alpha_m function around different potential values. ''' import numpy as np from scipy.special import factorial import matplotlib.pyplot as plt from utils import bilinearExp # Vm vector nVm = 100 Vm = np.linspace(-80.0, 50.0, nVm) # mV # alpha_m vector am_params = (-43.2, -0.32, 0.25) alpham = bilinearExp(Vm, am_params, 0) # alpha_m Taylor expansion npoints = 10 norder = 4 Vm0 = np.linspace(-80.0, 50.0, npoints) # mV Vmdiff = Vm - np.tile(Vm0, (nVm, 1)).transpose() Talpham = np.empty((npoints, nVm)) for i in range(npoints): T = np.zeros(nVm) for j in range(norder + 1): T[:] += bilinearExp(Vm0[i], am_params, j) * Vmdiff[i, :]**j / factorial(j) Talpham[i, :] = T # Plot standard alpha_m vs. Taylor reconstruction around Vm0 _, ax = plt.subplots(figsize=(22, 10)) ax.set_xlabel('$V_m\ [mV]$', fontsize=20) ax.set_ylabel('$[ms^{-1}]$', fontsize=20) ax.plot(Vm, alpham, linewidth=2, label='$\\alpha_m$') for i in range(npoints): ax.plot(Vm, Talpham[i, :], linewidth=2, label='$T_{}\\alpha_m({:.1f})$'.format(norder, Vm0[i])) ax.legend(fontsize=20) plt.show() diff --git a/deprecated/Taylor expansions/test_alpham_eff_Taylor.py b/deprecated/Taylor expansions/test_alpham_eff_Taylor.py index 2003592..ad8dd25 100644 --- a/deprecated/Taylor expansions/test_alpham_eff_Taylor.py +++ b/deprecated/Taylor expansions/test_alpham_eff_Taylor.py @@ -1,152 +1,152 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-03-21 11:38:56 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2017-03-29 19:40:44 +# @Last Modified time: 2018-09-28 14:14:11 -""" Perform Taylor expansions (up to 4th order) of the alpha_m function - along one acoustic cycle. """ +''' Perform Taylor expansions (up to 4th order) of the alpha_m function + along one acoustic cycle. ''' import importlib import numpy as np from scipy.special import factorial import matplotlib.pyplot as plt import matplotlib.cm as cm import nblscore from utils import LoadParams, rescale, bilinearExp from constants import * importlib.reload(nblscore) # reloading nblscore module # Load NBLS parameters params = LoadParams("params.yaml") biomech = params['biomech'] ac_imp = biomech['rhoL'] * biomech['c'] # Rayl # Set geometry of NBLS structure a = 32e-9 # in-plane radius (m) d = 0.0e-6 # embedding tissue thickness (m) geom = {"a": a, "d": d} # Create a NBLS instance here (with dummy frequency parameter) nbls = nblscore.NeuronalBilayerSonophore(geom, params, 0.0, True) # Set stimulation parameters Fdrive = 3.5e5 # Hz Adrive = 1e5 # Pa phi = np.pi # acoustic wave phase # Set charge linear space nQ = 100 charges = np.linspace(-80.0, 50.0, nQ) * 1e-5 # C/m2 Qmin = np.amin(charges) Qmax = np.amax(charges) # Set alpha_m parameters am_params = (-43.2, -0.32, 0.25) # Set highest Taylor expansion order norder = 4 # Set time vector T = 1 / Fdrive t = np.linspace(0, T, NPC_FULL) dt = t[1] - t[0] # Initialize coefficients vectors deflections = np.empty((nQ, NPC_FULL)) Vm = np.empty((nQ, NPC_FULL)) alpham = np.empty((nQ, NPC_FULL)) # Run mechanical simulations for each imposed charge density print('Running {} mechanical simulations with imposed charge densities'.format(nQ)) simcount = 0 for i in range(nQ): simcount += 1 # Log to console print('--- sim {}/{}: Q = {:.1f} nC/cm2'.format(simcount, nQ, charges[i] * 1e5)) # Run simulation and retrieve deflection vector (_, y, _) = nbls.runMech(Adrive, Fdrive, phi, charges[i]) (_, Z, _) = y deflections[i, :] = Z[-NPC_FULL:] # Compute Vm and alpham vectors Vm[i, :] = [charges[i] / nbls.Capct(ZZ) for ZZ in deflections[i, :]] alpham[i, :] = bilinearExp(Vm[i, :] * 1e3, am_params, 0) # time-average Vm and alpham Vmavg = np.mean(Vm, axis=1) alphamavg = np.mean(alpham, axis=1) # (Vm - Vmavg) differences along cycle Vmavgext = np.tile(Vmavg, (NPC_FULL, 1)).transpose() Vmdiff = (Vm - Vmavgext) * 1e3 # alpham derivatives dalpham = np.empty((norder + 1, nQ)) for j in range(norder + 1): dalpham[j, :] = bilinearExp(Vmavg * 1e3, am_params, j) # Taylor expansions along cycle Talpham = np.empty((norder + 1, nQ, NPC_FULL)) dalphamext = np.tile(dalpham.transpose(), (NPC_FULL, 1, 1)).transpose() Talpham[0, :, :] = dalphamext[0, :, :] for j in range(1, norder + 1): jterm = dalphamext[j, :, :] * Vmdiff[:, :]**j / factorial(j) Talpham[j, :, :] = Talpham[j - 1, :, :] + jterm # time-averaging of Taylor expansions Talphamavg = np.mean(Talpham, axis=2) # ------------------ PLOTS ------------------- mymap = cm.get_cmap('jet') sm_Q = plt.cm.ScalarMappable(cmap=mymap, norm=plt.Normalize(Qmin * 1e5, Qmax * 1e5)) sm_Q._A = [] t_factor = 1e6 # 1: time average Vm _, ax = plt.subplots(figsize=(22, 10)) ax.set_xlabel('$Qm\ [uF/cm^2]$', fontsize=20) ax.set_ylabel('$\\overline{V_m}\ [mV]$', fontsize=20) ax.plot(charges * 1e5, Vmavg * 1e3, linewidth=2) # 2: alpham: standard time-averaged vs.evaluated at time-average Vm # vs. Taylor reconstructions around Vm_avg _, ax = plt.subplots(figsize=(22, 10)) ax.set_xlabel('$Qm\ [uF/cm^2]$', fontsize=20) ax.set_ylabel('$[ms^{-1}]$', fontsize=20) ax.plot(charges * 1e5, alphamavg, linewidth=2, label='$\\overline{\\alpha_m(V_m)}$') for j in range(norder + 1): ax.plot(charges * 1e5, Talphamavg[j, :], linewidth=2, label='$\\overline{T_' + str(j) + '[\\alpha_m(\\overline{V_m})]}$') ax.legend(fontsize=20) # 3: original alpham vs. highest order Taylor alpham reconstruction _, ax = plt.subplots(figsize=(22, 10)) ax.set_xlabel('$t \ (us)$', fontsize=20) ax.set_ylabel('$[ms^{-1}]$', fontsize=20) ax.plot(t * t_factor, alpham[0, :], linewidth=2, c=mymap(rescale(charges[0], Qmin, Qmax)), label='$\\overline{\\alpha_m(V_m)}$') ax.plot(t * t_factor, Talpham[-1, 0, :], '--', linewidth=2, c=mymap(rescale(charges[0], Qmin, Qmax)), label='$T_' + str(norder) + '[\\alpha_m(\\overline{V_m})]$') for i in range(1, nQ): ax.plot(t * t_factor, alpham[i, :], linewidth=2, c=mymap(rescale(charges[i], Qmin, Qmax))) ax.plot(t * t_factor, Talpham[-1, i, :], '--', linewidth=2, c=mymap(rescale(charges[i], Qmin, Qmax))) cbar = plt.colorbar(sm_Q) cbar.ax.set_ylabel('$Q \ (nC/cm^2)$', fontsize=28) ax.legend(fontsize=20) plt.tight_layout() plt.show() diff --git a/docs/index.rst b/docs/index.rst index 1cc870f..cb74bc1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,26 +1,26 @@ -***************************** -SONIC model -***************************** +.. PySONIC documentation master file, created by + sphinx-quickstart on Fri Sep 28 11:49:08 2018. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. -.. include:: ../README.md - -Modules: -========== +Welcome to PySONIC's documentation! +=================================== .. toctree:: :maxdepth: 2 - sonic.bls - sonic.solvers - sonic.channels - sonic.utils + core + neurons + utils + batches + postpro + plt Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` - diff --git a/scripts/plot.py b/scripts/plot.py index dd6f1d6..99a2c8c 100644 --- a/scripts/plot.py +++ b/scripts/plot.py @@ -1,62 +1,62 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-02-13 12:41:26 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-25 17:24:13 +# @Last Modified time: 2018-09-28 14:09:53 -""" Compare profiles of several specific output variables of NICE simulations. """ +''' Plot temporal profiles of specific simulation output variables. ''' import logging from argparse import ArgumentParser import matplotlib.pyplot as plt from PySONIC.utils import logger, OpenFilesDialog from PySONIC.plt import plotComp, plotBatch # Set logging level logger.setLevel(logging.INFO) default_comp = 'Qm' defaults_batch = {'V_m': ['Vm'], 'Q_m': ['Qm']} def main(): ap = ArgumentParser() # Runtime options ap.add_argument('-v', '--verbose', default=False, action='store_true', help='Increase verbosity') ap.add_argument('--hide', default=False, action='store_true', help='Hide output') ap.add_argument('-o', '--outputdir', type=str, default=None, help='Output directory') ap.add_argument('-c', '--compare', default=False, action='store_true', help='Comparative graph') ap.add_argument('-s', '--save', default=False, action='store_true', help='Save output') ap.add_argument('--vars', type=str, nargs='+', default=None, help='Variables to plot') # Parse arguments args = ap.parse_args() loglevel = logging.DEBUG if args.verbose is True else logging.INFO logger.setLevel(loglevel) # Select data files pkl_filepaths, _ = OpenFilesDialog('pkl') if not pkl_filepaths: logger.error('No input file') return # Comparative plot if args.compare: varname = default_comp if args.vars is None else args.vars[0] plotComp(pkl_filepaths, varname=varname) else: vars_dict = defaults_batch if args.vars is None else {key: [key] for key in args.vars} plotBatch(pkl_filepaths, title=True, vars_dict=vars_dict, directory=args.outputdir, plt_save=args.save, ask_before_save=not args.save) if not args.hide: plt.show() if __name__ == '__main__': main() diff --git a/scripts/plot_activation_map.py b/scripts/plot_activation_map.py index 0205582..d493dee 100644 --- a/scripts/plot_activation_map.py +++ b/scripts/plot_activation_map.py @@ -1,115 +1,116 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2018-09-26 09:51:43 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-26 16:50:27 +# @Last Modified time: 2018-09-28 14:11:11 + +''' Plot (duty-cycle x amplitude) US activation map of a neuron at a given frequency and PRF. ''' import numpy as np import logging import matplotlib.pyplot as plt from argparse import ArgumentParser from PySONIC.utils import logger, selectDirDialog, Intensity2Pressure from PySONIC.plt import plotActivationMap - # Default parameters defaults = dict( neuron='RS', diam=32, # nm freq=500, # kHz duration=1000, # ms PRF=100, # Hz amps=np.logspace(np.log10(10), np.log10(600), num=30), # kPa DCs=np.arange(1, 101), # % Ascale='log', FRscale='log', FRbounds=(1e0, 1e3), # Hz tmax=240, # ms Vbounds=(-150, 50), # mV ) def main(): ap = ArgumentParser() # Runtime options ap.add_argument('-v', '--verbose', default=False, action='store_true', help='Increase verbosity') ap.add_argument('-i', '--inputdir', type=str, default=None, help='Input directory') ap.add_argument('-r', '--rheobase', default=False, action='store_true', help='Show rheobase amplitudes') ap.add_argument('-c', '--connect', default=False, action='store_true', help='Show traces on click') # Stimulation parameters ap.add_argument('-n', '--neuron', type=str, default=defaults['neuron'], help='Neuron name (string)') ap.add_argument('-a', '--diam', type=float, default=defaults['diam'], help='Sonophore diameter (nm)') ap.add_argument('-f', '--freq', type=float, default=defaults['freq'], help='US frequency (kHz)') ap.add_argument('-d', '--duration', type=float, default=defaults['duration'], help='Stimulus duration (ms)') ap.add_argument('-A', '--amps', nargs='+', type=float, help='Acoustic pressure amplitude (kPa)') ap.add_argument('-I', '--intensities', nargs='+', type=float, help='Acoustic intensity (W/cm2)') ap.add_argument('--PRF', type=float, default=defaults['PRF'], help='PRF (Hz)') ap.add_argument('--DC', nargs='+', type=float, help='Duty cycle (%%)') # Plot options ap.add_argument('--Ascale', type=str, default=defaults['Ascale'], help='y-axis scale ("log" or "lin")') ap.add_argument('--FRscale', type=str, default=defaults['FRscale'], help='map color scale ("log" or "lin")') ap.add_argument('--FRbounds', type=float, nargs='+', default=defaults['FRbounds'], help='Lower and upper bounds for firing rate (Hz)') ap.add_argument('--tmax', type=float, default=defaults['tmax'], help='Max time value for callback graphs (ms)') ap.add_argument('--Vbounds', type=float, nargs='+', default=defaults['Vbounds'], help='Y-axis extent for callback graphs (mV)') # Parse arguments args = {key: value for key, value in vars(ap.parse_args()).items() if value is not None} # Runtime options loglevel = logging.DEBUG if args['verbose'] is True else logging.INFO logger.setLevel(loglevel) inputdir = args['inputdir'] if 'inputdir' in args else selectDirDialog() if inputdir == '': logger.error('Operation cancelled') return # Parameters neuron = args['neuron'] a = args['diam'] * 1e-9 # m Fdrive = args['freq'] * 1e3 # Hz tstim = args['duration'] * 1e-3 # s PRF = args['PRF'] # Hz DCs = np.array(args.get('DCs', defaults['DCs'])) * 1e-2 # (-) if 'amps' in args: amps = np.array(args['amps']) * 1e3 # Pa elif 'intensities' in args: amps = Intensity2Pressure(np.array(args['intensities']) * 1e4) # Pa else: amps = np.array(defaults['amps']) * 1e3 # Pa # Plot options for item in ['Ascale', 'FRscale']: assert args[item] in ('lin', 'log'), 'Unknown {}'.format(item) Ascale = args['Ascale'] FRscale = args['FRscale'] tmax = args['tmax'] # ms Vbounds = args['Vbounds'] # mV FRbounds = args['FRbounds'] # Hz # Plot activation map plotActivationMap( inputdir, neuron, a, Fdrive, tstim, PRF, amps, DCs, Ascale=Ascale, FRscale=FRscale, FRbounds=FRbounds, rheobase=args['rheobase'], connect=args['connect'], tmax=tmax, Vbounds=Vbounds) plt.show() if __name__ == '__main__': main() diff --git a/scripts/plot_effective_variables.py b/scripts/plot_effective_variables.py index 42aac23..b15b5b8 100644 --- a/scripts/plot_effective_variables.py +++ b/scripts/plot_effective_variables.py @@ -1,183 +1,181 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-02-15 15:59:37 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-24 20:35:30 - -''' Plot the profiles of effective variables as a function of charge density - with amplitude color code. ''' +# @Last Modified time: 2018-09-28 14:10:11 +''' Plot the effective variables as a function of charge density with amplitude color code. ''' import numpy as np from scipy.interpolate import interp2d import matplotlib.pyplot as plt import matplotlib.cm as cm import matplotlib from argparse import ArgumentParser from PySONIC.plt import pltvars from PySONIC.utils import logger, getLookups2D from PySONIC.neurons import getNeuronsDict # Default parameters defaults = dict( neuron='RS', diam=32.0, freq=500.0, amps=np.logspace(np.log10(1), np.log10(600), 10), # kPa ) def setGrid(n, ncolmax=3): ''' Determine number of rows and columns in figure grid, based on number of variables to plot. ''' if n <= ncolmax: return (1, n) else: return ((n - 1) // ncolmax + 1, ncolmax) def plotEffectiveVariables(neuron, a, Fdrive, amps=None, fs=12, ncolmax=2): ''' Plot the profiles of effective variables of a specific neuron for a given frequency. For each variable, one line chart per amplitude is plotted, using charge as the input variable on the abscissa and a linear color code for the amplitude value. :param neuron: channel mechanism object :param a: sonophore diameter (m) :param Fdrive: acoustic drive frequency (Hz) :param amps: vector of amplitudes at which variables must be plotted (Pa) :param fs: figure fontsize :param ncolmax: max number of columns on the figure :return: handle to the created figure ''' # Get 2D lookups at specific (a, Fdrive) combination Aref, Qref, lookups2D = getLookups2D(neuron.name, a, Fdrive) if 'V' in lookups2D: lookups2D['Vm'] = lookups2D.pop('V') # Define log-amplitude color code if amps is None: amps = Aref mymap = cm.get_cmap('Oranges') norm = matplotlib.colors.LogNorm(amps.min(), amps.max()) sm = cm.ScalarMappable(norm=norm, cmap=mymap) sm._A = [] # Plot logger.info('plotting') nrows, ncols = setGrid(len(lookups2D), ncolmax=ncolmax) xvar = pltvars['Qm'] Qbounds = np.array([Qref.min(), Qref.max()]) * xvar['factor'] fig, _ = plt.subplots(figsize=(3 * ncols, 1 * nrows), squeeze=False) for j, key in enumerate(lookups2D.keys()): ax = plt.subplot2grid((nrows, ncols), (j // ncols, j % ncols)) for s in ['right', 'top']: ax.spines[s].set_visible(False) yvar = pltvars[key] if j // ncols == nrows - 1: ax.set_xlabel('$\\rm {}\ ({})$'.format(xvar['label'], xvar['unit']), fontsize=fs) ax.set_xticks(Qbounds) else: ax.set_xticks([]) ax.spines['bottom'].set_visible(False) ax.xaxis.set_label_coords(0.5, -0.1) ax.yaxis.set_label_coords(-0.02, 0.5) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) ymin = np.inf ymax = -np.inf # Plot effective variable for each selected amplitude y0 = np.squeeze(interp2d(Aref, Qref, lookups2D[key].T)(0, Qref)) for Adrive in amps: y = np.squeeze(interp2d(Aref, Qref, lookups2D[key].T)(Adrive, Qref)) if 'alpha' in key or 'beta' in key: y[y > y0.max() * 2] = np.nan ax.plot(Qref * xvar['factor'], y * yvar['factor'], c=sm.to_rgba(Adrive)) ymin = min(ymin, y.min()) ymax = max(ymax, y.max()) # Plot reference variable ax.plot(Qref * xvar['factor'], y0 * yvar['factor'], '--', c='k') ymax = max(ymax, y0.max()) ymin = min(ymin, y0.min()) # Set axis y-limits if 'alpha' in key or 'beta' in key: ymax = y0.max() * 2 ylim = [ymin * yvar['factor'], ymax * yvar['factor']] if key == 'ng': ylim = [np.floor(ylim[0] * 1e2) / 1e2, np.ceil(ylim[1] * 1e2) / 1e2] else: factor = 1 / np.power(10, np.floor(np.log10(ylim[1]))) ylim = [np.floor(ylim[0] * factor) / factor, np.ceil(ylim[1] * factor) / factor] dy = ylim[1] - ylim[0] ax.set_yticks(ylim) ax.set_ylim(ylim) # ax.set_ylim([ylim[0] - 0.05 * dy, ylim[1] + 0.05 * dy]) # Annotate variable and unit xlim = ax.get_xlim() if np.argmax(y0) < np.argmin(y0): xtext = xlim[0] + 0.6 * (xlim[1] - xlim[0]) else: xtext = xlim[0] + 0.01 * (xlim[1] - xlim[0]) if key in ['Vm', 'ng']: ytext = ylim[0] + 0.85 * dy else: ytext = ylim[0] + 0.15 * dy ax.text(xtext, ytext, '$\\rm {}\ ({})$'.format(yvar['label'], yvar['unit']), fontsize=fs) fig.suptitle('{} neuron: original vs. effective variables @ {:.0f} kHz'.format( neuron.name, Fdrive * 1e-3)) # Plot colorbar fig.subplots_adjust(left=0.10, bottom=0.05, top=0.9, right=0.85) cbarax = fig.add_axes([0.87, 0.05, 0.04, 0.85]) fig.colorbar(sm, cax=cbarax) cbarax.set_ylabel('amplitude (Pa)', fontsize=fs) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) return fig def main(): ap = ArgumentParser() # Stimulation parameters ap.add_argument('-n', '--neuron', type=str, default=defaults['neuron'], help='Neuron name (string)') ap.add_argument('-a', '--diam', type=float, default=defaults['diam'], help='Sonophore diameter (nm)') ap.add_argument('-f', '--freq', type=float, default=defaults['freq'], help='US frequency (kHz)') ap.add_argument('-A', '--amps', nargs='+', type=float, help='Acoustic pressure amplitude (kPa)') # Parse arguments args = {key: value for key, value in vars(ap.parse_args()).items() if value is not None} neuron_str = args['neuron'] diam = args['diam'] * 1e-9 # m Fdrive = args['freq'] * 1e3 # Hz amps = np.array(args.get('amps', defaults['amps'])) * 1e3 # Pa # Plot effective variables if neuron_str not in getNeuronsDict(): logger.error('Unknown neuron type: "%s"', neuron_str) return neuron = getNeuronsDict()[neuron_str]() plotEffectiveVariables(neuron, diam, Fdrive, amps=amps) plt.show() if __name__ == '__main__': main() diff --git a/scripts/plot_gating_kinetics.py b/scripts/plot_gating_kinetics.py index b977663..e46cc35 100644 --- a/scripts/plot_gating_kinetics.py +++ b/scripts/plot_gating_kinetics.py @@ -1,131 +1,131 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-10-11 20:35:38 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-24 20:35:03 +# @Last Modified time: 2018-09-28 14:11:36 -""" Plot the voltage-dependent steady-states and time constants of activation and inactivation - gates of the different ionic currents involved in the neuron's membrane. """ +''' Plot the voltage-dependent steady-states and time constants of activation and inactivation + gates of the different ionic currents involved in the neuron's membrane dynamics. ''' import numpy as np import matplotlib.pyplot as plt from argparse import ArgumentParser from PySONIC.utils import logger from PySONIC.neurons import getNeuronsDict # Default parameters defaults = dict( neuron='RS' ) def plotGatingKinetics(neuron, fs=15): ''' Plot the voltage-dependent steady-states and time constants of activation and inactivation gates of the different ionic currents involved in a specific neuron's membrane. :param neuron: specific channel mechanism object :param fs: labels and title font size ''' # Input membrane potential vector Vm = np.linspace(-100, 50, 300) xinf_dict = {} taux_dict = {} logger.info('Computing %s neuron gating kinetics', neuron.name) names = neuron.states_names for xname in names: Vm_state = True # Names of functions of interest xinf_func_str = xname.lower() + 'inf' taux_func_str = 'tau' + xname.lower() alphax_func_str = 'alpha' + xname.lower() betax_func_str = 'beta' + xname.lower() # derx_func_str = 'der' + xname.upper() # 1st choice: use xinf and taux function if hasattr(neuron, xinf_func_str) and hasattr(neuron, taux_func_str): xinf_func = getattr(neuron, xinf_func_str) taux_func = getattr(neuron, taux_func_str) xinf = np.array([xinf_func(v) for v in Vm]) if isinstance(taux_func, float): taux = taux_func * np.ones(len(Vm)) else: taux = np.array([taux_func(v) for v in Vm]) # 2nd choice: use alphax and betax functions elif hasattr(neuron, alphax_func_str) and hasattr(neuron, betax_func_str): alphax_func = getattr(neuron, alphax_func_str) betax_func = getattr(neuron, betax_func_str) alphax = np.array([alphax_func(v) for v in Vm]) if isinstance(betax_func, float): betax = betax_func * np.ones(len(Vm)) else: betax = np.array([betax_func(v) for v in Vm]) taux = 1.0 / (alphax + betax) xinf = taux * alphax # # 3rd choice: use derX choice # elif hasattr(neuron, derx_func_str): # derx_func = getattr(neuron, derx_func_str) # xinf = brentq(lambda x: derx_func(neuron.Vm, x), 0, 1) else: Vm_state = False if not Vm_state: logger.error('no function to compute %s-state gating kinetics', xname) else: xinf_dict[xname] = xinf taux_dict[xname] = taux fig, axes = plt.subplots(2) fig.suptitle('{} neuron: gating dynamics'.format(neuron.name)) ax = axes[0] ax.get_xaxis().set_ticklabels([]) ax.set_ylabel('$X_{\infty}$', fontsize=fs) for xname in names: if xname in xinf_dict: ax.plot(Vm, xinf_dict[xname], lw=2, label='$' + xname + '_{\infty}$') ax.legend(fontsize=fs, loc=7) ax = axes[1] ax.set_xlabel('$V_m\ (mV)$', fontsize=fs) ax.set_ylabel('$\\tau_X\ (ms)$', fontsize=fs) for xname in names: if xname in taux_dict: ax.plot(Vm, taux_dict[xname] * 1e3, lw=2, label='$\\tau_{' + xname + '}$') ax.legend(fontsize=fs, loc=7) return fig def main(): ap = ArgumentParser() # Stimulation parameters ap.add_argument('-n', '--neuron', type=str, default=defaults['neuron'], help='Neuron name (string)') # Parse arguments args = ap.parse_args() neuron_str = args.neuron # Plot gating kinetics variables if neuron_str not in getNeuronsDict(): logger.error('Unknown neuron type: "%s"', neuron_str) return neuron = getNeuronsDict()[neuron_str]() plotGatingKinetics(neuron) plt.show() if __name__ == '__main__': main() diff --git a/scripts/plot_rheobase_amps.py b/scripts/plot_rheobase_amps.py index fbb39ed..4273c6d 100644 --- a/scripts/plot_rheobase_amps.py +++ b/scripts/plot_rheobase_amps.py @@ -1,104 +1,105 @@ # -*- coding: utf-8 -*- # @Author: Theo # @Date: 2018-04-30 21:06:10 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-26 11:44:29 +# @Last Modified time: 2018-09-28 14:12:31 -''' Plot neuron-specific rheobase acoustic amplitudes for various duty cycles. ''' +''' Plot duty-cycle dependent rheobase acoustic amplitudes of various neurons + for a specific US frequency and PRF. ''' import logging import numpy as np import matplotlib.pyplot as plt from argparse import ArgumentParser from PySONIC.utils import logger, si_format from PySONIC.core import NeuronalBilayerSonophore from PySONIC.neurons import getNeuronsDict # Set logging level logger.setLevel(logging.INFO) # Default parameters defaults = dict( neurons=['RS', 'FS', 'RE'], diam=32.0, freq=500.0 ) def plotAstimRheobaseAmps(neurons, a, Fdrive, fs=15): fig, ax = plt.subplots() ax.set_title('Rheobase amplitudes @ {}Hz ({:.0f} nm sonophore)'.format( si_format(Fdrive, 1, space=' '), a * 1e9), fontsize=fs) ax.set_xlabel('Duty cycle (%)', fontsize=fs) ax.set_ylabel('Threshold amplitude (kPa)', fontsize=fs) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) ax.set_yscale('log') ax.set_ylim([10, 600]) DCs = np.arange(1, 101) / 1e2 for neuron in neurons: nbls = NeuronalBilayerSonophore(a, neuron) Athrs = nbls.findRheobaseAmps(DCs, Fdrive, neuron.VT) ax.plot(DCs * 1e2, Athrs * 1e-3, label='{} neuron'.format(neuron.name)) ax.legend(fontsize=fs, frameon=False) fig.tight_layout() return fig def plotEstimRheobaseAmps(neurons, fs=15): fig, ax = plt.subplots() ax.set_title('Rheobase amplitudes', fontsize=fs) ax.set_xlabel('Duty cycle (%)', fontsize=fs) ax.set_ylabel('Threshold amplitude (mA/m2)', fontsize=fs) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) ax.set_yscale('log') ax.set_ylim([1e0, 1e3]) DCs = np.arange(1, 101) / 1e2 for neuron in neurons: Athrs = neuron.findRheobaseAmps(DCs, neuron.VT) ax.plot(DCs * 1e2, Athrs, label='{} neuron'.format(neuron.name)) ax.legend(fontsize=fs, frameon=False) fig.tight_layout() return fig def main(): ap = ArgumentParser() # Stimulation parameters ap.add_argument('-n', '--neurons', type=str, nargs='+', default=defaults['neurons'], help='Neuron name (string)') ap.add_argument('-a', '--diam', type=float, default=defaults['diam'], help='Sonophore diameter (nm)') ap.add_argument('-f', '--freq', type=float, default=defaults['freq'], help='US frequency (kHz)') ap.add_argument('-m', '--mode', type=str, default='US', help='Stimulation modality (US or elec)') # Parse arguments args = {key: value for key, value in vars(ap.parse_args()).items() if value is not None} neurons_str = args.get('neurons', defaults['neurons']) neurons = [] for n in neurons_str: if n not in getNeuronsDict(): logger.error('Invalid neuron type: "%s"', n) return neurons.append(getNeuronsDict()[n]()) mode = args['mode'] if mode == 'US': diam = args.get('diam', defaults['diam']) * 1e-9 # m Fdrive = args.get('freq', defaults['freq']) * 1e3 # Hz plotAstimRheobaseAmps(neurons, diam, Fdrive) elif mode == 'elec': plotEstimRheobaseAmps(neurons) else: logger.error('Invalid stimulation type: "%s"', mode) return plt.show() if __name__ == '__main__': main() diff --git a/scripts/plot_spikes_details.py b/scripts/plot_spikes_details.py index 690799c..e4fc41c 100644 --- a/scripts/plot_spikes_details.py +++ b/scripts/plot_spikes_details.py @@ -1,75 +1,75 @@ # -*- coding: utf-8 -*- # @Author: Theo # @Date: 2018-04-04 11:49:07 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-24 20:54:26 +# @Last Modified time: 2018-09-28 14:12:50 -''' Plot detected spikes on charge profiles. ''' +''' Plot features of detected spikes on charge profiles. ''' import os import pickle import logging import numpy as np import matplotlib.pyplot as plt from PySONIC.utils import logger, OpenFilesDialog from PySONIC.postpro import findPeaks from PySONIC.constants import * # Set logging level logger.setLevel(logging.INFO) def plotSpikesDetails(filepaths, fs=15, lw=2): for fpath in filepaths: # Load charge profile from file fname = os.path.basename(fpath) logger.info('Loading data from "%s" file', fname) with open(fpath, 'rb') as fh: frame = pickle.load(fh) df = frame['data'] t = df['t'].values Qm = df['Qm'].values dt = t[1] - t[0] indexes = np.arange(t.size) mpd = int(np.ceil(SPIKE_MIN_DT / dt)) ipeaks, prominences, widths, ibounds = findPeaks(Qm, mph=SPIKE_MIN_QAMP, mpd=mpd, mpp=SPIKE_MIN_QPROM) if ipeaks is not None: widths *= dt tleftbounds = np.interp(ibounds[:, 0], indexes, t) trightbounds = np.interp(ibounds[:, 1], indexes, t) # Plot results fig, ax = plt.subplots(figsize=(8, 4)) ax.set_title(os.path.splitext(fname)[0], fontsize=fs) ax.set_xlabel('time (ms)', fontsize=fs) ax.set_ylabel('charge\ (nC/cm2)', fontsize=fs) ax.plot(t * 1e3, Qm * 1e5, color='C0', label='trace', linewidth=lw) if ipeaks is not None: ax.scatter(t[ipeaks] * 1e3, Qm[ipeaks] * 1e5 + 3, color='k', label='peaks', marker='v') for i in range(len(ipeaks)): ax.plot(np.array([t[ipeaks[i]]] * 2) * 1e3, np.array([Qm[ipeaks[i]], Qm[ipeaks[i]] - prominences[i]]) * 1e5, color='C1', label='prominences' if i == 0 else '') ax.plot(np.array([tleftbounds[i], trightbounds[i]]) * 1e3, np.array([Qm[ipeaks[i]] - 0.5 * prominences[i]] * 2) * 1e5, color='C2', label='widths' if i == 0 else '') ax.legend(frameon=False) plt.show() def main(): # Select data files pkl_filepaths, _ = OpenFilesDialog('pkl') if not pkl_filepaths: logger.error('No input file') return plotSpikesDetails(pkl_filepaths) if __name__ == '__main__': main() diff --git a/scripts/run_lookups.py b/scripts/run_lookups.py index 2517743..066d56e 100644 --- a/scripts/run_lookups.py +++ b/scripts/run_lookups.py @@ -1,155 +1,155 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-06-02 17:50:10 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2018-09-27 02:03:15 +# @Last Modified time: 2018-09-28 14:14:11 -""" Create lookup table for specific neuron. """ +''' Create lookup table for specific neuron. ''' import os import pickle import logging import numpy as np from argparse import ArgumentParser from PySONIC.utils import logger, getNeuronLookupsFile from PySONIC.batches import createQueue, runBatch from PySONIC.neurons import getNeuronsDict from PySONIC.core import NeuronalBilayerSonophore # Default parameters defaults = dict( neuron='RS', diams=np.array([16.0, 32.0, 64.0]), # nm freqs=np.array([20., 100., 500., 1e3, 2e3, 3e3, 4e3]), # kHz amps=np.insert(np.logspace(np.log10(0.1), np.log10(600), num=50), 0, 0.0), # kPa charges=None ) def computeAStimLookups(neuron, aref, fref, Aref, Qref, phi=np.pi, mpi=False, loglevel=logging.INFO): ''' Run simulations of the mechanical system for a multiple combinations of imposed US frequencies, acoustic amplitudes and charge densities, compute effective coefficients and store them in a dictionary of 3D arrays. :param neuron: neuron object :param aref: array of sonophore diameters (m) :param fref: array of acoustic drive frequencies (Hz) :param Aref: array of acoustic drive amplitudes (Pa) :param Qref: array of membrane charge densities (C/m2) :param phi: acoustic drive phase (rad) :param mpi: boolean statting wether or not to use multiprocessing :return: lookups dictionary ''' # Check validity of input parameters for key, values in {'diameters': aref, 'frequencies': fref, 'amplitudes': Aref}.items(): if not (isinstance(values, list) or isinstance(values, np.ndarray)): raise TypeError('Invalid {} (must be provided as list or numpy array)'.format(key)) if not all(isinstance(x, float) for x in values): raise TypeError('Invalid {} (must all be float typed)'.format(key)) if len(values) == 0: raise ValueError('Empty {} array'.format(key)) if key in ('diameters', 'frequencies') and min(values) <= 0: raise ValueError('Invalid {} (must all be strictly positive)'.format(key)) if key is 'amplitudes' and min(values) < 0: raise ValueError('Invalid {} (must all be positive or null)'.format(key)) # create simulation queue na, nf, nA, nQ = len(aref), len(fref), len(Aref), len(Qref) queue = createQueue((fref, Aref, Qref)) # run simulations and populate outputs (list of lists) logger.info('Starting simulation batch for %s neuron', neuron.name) outputs = [] for a in aref: nbls = NeuronalBilayerSonophore(a, neuron) outputs += runBatch(nbls, 'computeEffVars', queue, mpi=mpi, loglevel=loglevel) outputs = np.array(outputs).T # populate lookups dictionary with input vectors lookups = dict( a=aref, # nm f=fref, # Hz A=Aref, # Pa Q=Qref # C/m2 ) # reshape outputs into 4D arrays and add them to lookups dictionary logger.info('Reshaping output into lookup tables') keys = ['V', 'ng'] + neuron.coeff_names assert len(keys) == len(outputs), 'Lookup keys not matching array size' for key, output in zip(keys, outputs): lookups[key] = output.reshape(na, nf, nA, nQ) return lookups def main(): ap = ArgumentParser() # Runtime options ap.add_argument('--mpi', default=False, action='store_true', help='Use multiprocessing') ap.add_argument('-v', '--verbose', default=False, action='store_true', help='Increase verbosity') ap.add_argument('-t', '--test', default=False, action='store_true', help='Test configuration') # Stimulation parameters ap.add_argument('-n', '--neuron', type=str, default=defaults['neuron'], help='Neuron name (string)') ap.add_argument('-a', '--diams', nargs='+', type=float, help='Sonophore diameter (nm)') ap.add_argument('-f', '--freqs', nargs='+', type=float, help='US frequency (kHz)') ap.add_argument('-A', '--amps', nargs='+', type=float, help='Acoustic pressure amplitude (kPa)') ap.add_argument('-Q', '--charges', nargs='+', type=float, help='Mmebrane charge density (nC/cm2)') # Parse arguments args = {key: value for key, value in vars(ap.parse_args()).items() if value is not None} loglevel = logging.DEBUG if args['verbose'] is True else logging.INFO logger.setLevel(loglevel) mpi = args['mpi'] neuron_str = args['neuron'] diams = np.array(args.get('diams', defaults['diams'])) * 1e-9 # m freqs = np.array(args.get('freqs', defaults['freqs'])) * 1e3 # Hz amps = np.array(args.get('amps', defaults['amps'])) * 1e3 # Pa # Check neuron name validity if neuron_str not in getNeuronsDict(): logger.error('Unknown neuron type: "%s"', neuron_str) return neuron = getNeuronsDict()[neuron_str]() if 'charges' in args: charges = np.array(args['charges']) * 1e-5 # C/m2 else: charges = np.arange(neuron.Qbounds[0], neuron.Qbounds[1] + 1e-5, 1e-5) # C/m2 if args['test']: diams = np.array([diams.min(), diams.max()]) freqs = np.array([freqs.min(), freqs.max()]) amps = np.array([amps.min(), amps.max()]) charges = np.array([charges.min(), 0., charges.max()]) # Check if lookup file already exists lookup_path = getNeuronLookupsFile(neuron.name) if os.path.isfile(lookup_path): logger.warning('"%s" file already exists and will be overwritten. ' + 'Continue? (y/n)', lookup_path) user_str = input() if user_str not in ['y', 'Y']: logger.error('%s Lookup creation canceled', neuron.name) return # compute lookups lookup_dict = computeAStimLookups(neuron, diams, freqs, amps, charges, mpi=mpi, loglevel=loglevel) # Save dictionary in lookup file logger.info('Saving %s neuron lookup table in file: "%s"', neuron.name, lookup_path) with open(lookup_path, 'wb') as fh: pickle.dump(lookup_dict, fh) if __name__ == '__main__': main()