diff --git a/PySONIC/__init__.py b/PySONIC/__init__.py index 8172a43..b50d2c2 100644 --- a/PySONIC/__init__.py +++ b/PySONIC/__init__.py @@ -1,17 +1,18 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-06-06 13:36:00 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-12-02 18:06:16 +# @Last Modified time: 2020-09-24 18:51:40 ''' Import the core classes, generic utilities and algorithmic constants. ''' from .core import * from .neurons import * from .utils import * from .threshold import * from .parsers import * from .postpro import * from .constants import * from .plt import * +from .multicomp import * diff --git a/PySONIC/core/__init__.py b/PySONIC/core/__init__.py index 87eb579..a70e9a5 100644 --- a/PySONIC/core/__init__.py +++ b/PySONIC/core/__init__.py @@ -1,51 +1,50 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-06-06 13:36:00 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-09-24 15:30:48 +# @Last Modified time: 2020-09-24 18:50:39 from types import MethodType import inspect import sys from .solvers import * from .batches import * from .model import * from .pneuron import * from .bls import * from .translators import * from .nbls import * from .vclamp import * from .lookups import * from .stimobj import * from .protocols import * from .drives import * -from .multicomp_benchmark import * from ..neurons import getPointNeuron def getModelsDict(): ''' Construct a dictionary of all model classes, indexed by simulation key. ''' current_module = sys.modules[__name__] models_dict = {} for _, obj in inspect.getmembers(current_module): if inspect.isclass(obj) and hasattr(obj, 'simkey') and isinstance(obj.simkey, str): models_dict[obj.simkey] = obj return models_dict # Add an initFromMeta method to the PointNeuron class (done here to avoid circular import) PointNeuron.initFromMeta = MethodType( lambda self, meta: getPointNeuron(meta['neuron']), PointNeuron) models_dict = getModelsDict() def getModel(meta): ''' Return appropriate model object based on a dictionary of meta-information. 
''' simkey = meta['simkey'] try: return models_dict[simkey].initFromMeta(meta['model']) except KeyError: raise ValueError(f'Unknown model type: {simkey}') diff --git a/PySONIC/multicomp/__init__.py b/PySONIC/multicomp/__init__.py new file mode 100644 index 0000000..fbfccf4 --- /dev/null +++ b/PySONIC/multicomp/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +# @Author: Theo Lemaire +# @Email: theo.lemaire@epfl.ch +# @Date: 2020-09-24 18:50:50 +# @Last Modified by: Theo Lemaire +# @Last Modified time: 2020-09-24 19:01:42 + +from .gamma_map import * +from .benchmark import * +from .divmaps import * \ No newline at end of file diff --git a/PySONIC/core/multicomp_benchmark.py b/PySONIC/multicomp/benchmark.py similarity index 94% rename from PySONIC/core/multicomp_benchmark.py rename to PySONIC/multicomp/benchmark.py index 2dd6af5..0d9955f 100644 --- a/PySONIC/core/multicomp_benchmark.py +++ b/PySONIC/multicomp/benchmark.py @@ -1,592 +1,600 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2020-09-24 15:30:34 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-09-24 18:44:18 +# @Last Modified time: 2020-09-24 20:03:59 import numpy as np from scipy.integrate import odeint import matplotlib.pyplot as plt from PySONIC.core import EffectiveVariablesLookup from ..utils import logger, timer, isWithin, si_format, rmse from ..neurons import passiveNeuron class SonicBenchmark: ''' Interface for running benchmark simulations of a two-compartment model incorporating the SONIC paradigm, with a simplified sinusoidal capacitive drive. ''' npc = 25 # number of samples per cycle min_ncycles = 10 # minimum number of cycles per simulation varunits = { 't': 'ms', 'Cm': 'uF/cm2', 'Vm': 'mV', 'Qm': 'nC/cm2' } varfactors = { 't': 1e3, 'Cm': 1e2, 'Vm': 1e0, 'Qm': 1e5 } nodelabels = ['node 1', 'node 2'] ga_bounds = [1e-10, 1e10] # S/m2 - def __init__(self, pneuron, ga, f, gammas, passive=False): + def __init__(self, pneuron, ga, Fdrive, gammas, passive=False): ''' Initialization. 
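A minimal usage sketch (illustrative values only; `passiveNeuron` is the factory imported above, with the (Cm0, gLeak, ELeak) argument order used elsewhere in this class):

>>> pneuron = passiveNeuron(1e-2, 1., -70.)  # Cm0 (F/m2), gLeak (S/m2), ELeak (mV)
>>> bm = SonicBenchmark(pneuron, 1., 500e3, (0.8, 0.), passive=True)
>>> t, sol = bm.simAllMethods(bm.passive_tstop)  # full, effective and cycle-averaged solutions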
:param pneuron: point-neuron object :param ga: axial conductance (S/m2) - :param f: US frequency (Hz) + :param Fdrive: US frequency (Hz) :param gammas: pair of relative capacitance oscillation ranges ''' self.pneuron = pneuron self.ga = ga - self.f = f + self.Fdrive = Fdrive self.gammas = gammas self.passive = passive self.computeLookups() def copy(self): - return self.__class__(self.pneuron, self.ga, self.f, self.gammas, passive=self.passive) + return self.__class__(self.pneuron, self.ga, self.Fdrive, self.gammas, passive=self.passive) @property - def strGammas(self): - return f"({', '.join([f'{x:.2f}' for x in self.gammas])})" + def gammalist(self): + return [f'{x:.2f}' for x in self.gammas] + + @property + def gammastr(self): + return f"({', '.join(self.gammalist)})" + + @property + def fstr(self): + return f'{si_format(self.Fdrive)}Hz' def __repr__(self): params = [ f'ga = {self.ga:.2e} S/m2', - f'f = {si_format(self.f)}Hz', - f'gamma = {self.strGammas}' + f'f = {self.fstr}', + f'gamma = {self.gammastr}' ] dynamics = 'passive ' if self.passive else '' mech = f'{dynamics}{self.pneuron.name} dynamics' return f'{self.__class__.__name__}({mech}, {", ".join(params)})' @property def pneuron(self): return self._pneuron @pneuron.setter def pneuron(self, value): self._pneuron = value.copy() self.states = self._pneuron.statesNames() if hasattr(self, 'lkps'): self.computeLookups() def isPassive(self): return self.pneuron.name.startswith('pas_') @property - def f(self): - return self._f + def Fdrive(self): + return self._Fdrive - @f.setter - def f(self, value): - self._f = value + @Fdrive.setter + def Fdrive(self, value): + self._Fdrive = value if hasattr(self, 'lkps'): self.computeLookups() @property def gammas(self): return self._gammas @gammas.setter def gammas(self, value): self._gammas = value if hasattr(self, 'lkps'): self.computeLookups() @property def passive(self): return self._passive @passive.setter def passive(self, value): assert isinstance(value, bool), 'passive must be boolean typed' self._passive = value if hasattr(self, 'lkps'): self.computeLookups() @property def ga(self): return self._ga @ga.setter def ga(self, value): if value != 0.: assert isWithin('ga', value, self.ga_bounds) self._ga = value @property def gPas(self): ''' Passive membrane conductance (S/m2). ''' return self.pneuron.gLeak @property def Cm0(self): ''' Resting capacitance (F/m2). ''' return self.pneuron.Cm0 @property def Vm0(self): ''' Resting membrane potential (mV). ''' return self.pneuron.Vm0 @property def Qm0(self): ''' Resting membrane charge density (C/m2). ''' return self.Vm0 * self.Cm0 * 1e-3 @property def Qref(self): ''' Reference charge linear space. ''' return np.arange(*self.pneuron.Qbounds, 1e-5) # C/m2 @property def Cmeff(self): ''' Analytical solution for effective membrane capacitance (F/m2). ''' return self.Cm0 * np.sqrt(1 - np.array(self.gammas)**2 / 4) @property def Qminf(self): ''' Analytical solution for steady-state charge density (C/m2). ''' return self.Cmeff * self.pneuron.ELeak * 1e-3 def capct(self, gamma, t): ''' Time-varying sinusoidal capacitance (in F/m2) ''' - return self.Cm0 * (1 + 0.5 * gamma * np.sin(2 * np.pi * self.f * t)) + return self.Cm0 * (1 + 0.5 * gamma * np.sin(2 * np.pi * self.Fdrive * t)) def vCapct(self, t): ''' Vector of time-varying capacitance (in F/m2) ''' return np.array([self.capct(gamma, t) for gamma in self.gammas]) def getLookup(self, Cm): ''' Get a lookup object of effective variables for a given capacitance cycle vector. 
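Each effective rate table entry is the instantaneous rate averaged over the capacitance cycle at constant charge, i.e. (restating the code below): r_eff(Q) = mean_k[r(Q / Cm_k * 1e3)], with Q in C/m2, Cm_k the cycle capacitance samples in F/m2, and rates evaluated at the resulting potentials in mV.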
''' Vmarray = np.array([Q / Cm for Q in self.Qref]) * 1e3 # mV refs = {'Q': self.Qref} # C/m2 tables = { k: np.array([np.mean(np.vectorize(v)(Vmvec)) for Vmvec in Vmarray]) for k, v in self.pneuron.effRates().items() } return EffectiveVariablesLookup(refs, tables) @property def tcycle(self): ''' Time vector over 1 acoustic cycle (s). ''' - return np.linspace(0, 1 / self.f, self.npc) + return np.linspace(0, 1 / self.Fdrive, self.npc) @property def dt_full(self): ''' Full time step (s). ''' - return 1 / (self.npc * self.f) + return 1 / (self.npc * self.Fdrive) @property def dt_sparse(self): ''' Sparse time step (s). ''' - return 1 / self.f + return 1 / self.Fdrive def computeLookups(self): ''' Compute benchmark lookups. ''' self.lkps = [] if not self.passive: self.lkps = [self.getLookup(Cm_cycle) for Cm_cycle in self.vCapct(self.tcycle)] def getCmeff(self, Cm_cycle): ''' Compute effective capacitance from capacitance profile over 1 cycle. ''' return 1 / np.mean(1 / Cm_cycle) # F/m2 def iax(self, Vm, Vmother): ''' Compute axial current flowing in the compartment from another compartment (in mA/m2). [iax] = S/m2 * mV = 1e-3 A/m2 = 1 mA/m2 ''' return self.ga * (Vmother - Vm) def vIax(self, Vm): ''' Compute array of axial currents in each compartment based on array of potentials. ''' return np.array([self.iax(*Vm), self.iax(*Vm[::-1])]) # mA/m2 def serialize(self, y): ''' Serialize a state-per-node matrix into a single state vector. ''' return np.reshape(y.copy(), (self.npernode * 2)) def deserialize(self, y): ''' Deserialize a single state vector into a state-per-node matrix. ''' return np.reshape(y.copy(), (2, self.npernode)) def derivatives(self, t, y, Cm, dstates_func): ''' Generic derivatives method. ''' # Deserialize states vector and initialize derivatives array y = self.deserialize(y) dydt = np.empty(y.shape) # Extract charge density and membrane potential vectors Qm = y[:, 0] # C/m2 Vm = y[:, 0] / Cm * 1e3 # mV # Extract states array states_array = y[:, 1:] # Compute membrane dynamics for each node for i, (qm, vm, states) in enumerate(zip(Qm, Vm, states_array)): # If passive, compute only leakage current if self.passive: im = self.pneuron.iLeak(vm) # mA/m2 # Otherwise, compute states derivatives and total membrane current if not self.passive: states_dict = dict(zip(self.states, states)) dydt[i, 1:] = dstates_func(i, qm, vm, states_dict) # s-1 im = self.pneuron.iNet(vm, states_dict) # mA/m2 dydt[i, 0] = -im # mA/m2 # Add axial currents to currents column dydt[:, 0] += self.vIax(Vm) # mA/m2 # Rescale currents column into charge derivative units dydt[:, 0] *= 1e-3 # C/m2.s # Return serialized derivatives vector return self.serialize(dydt) def dstatesFull(self, i, qm, vm, states): ''' Compute detailed states derivatives. ''' return self.pneuron.getDerStates(vm, states) def dfull(self, t, y): ''' Compute detailed derivatives vector. ''' return self.derivatives(t, y, self.vCapct(t), self.dstatesFull) def dstatesEff(self, i, qm, vm, states): ''' Compute effective states derivatives. ''' lkp0d = self.lkps[i].interpolate1D(qm) return np.array([self.pneuron.derEffStates()[k](lkp0d, states) for k in self.states]) def deff(self, t, y): ''' Compute effective derivatives vector. ''' return self.derivatives(t, y, self.Cmeff, self.dstatesEff) @property def y0node(self): ''' Get initial conditions vector (common to every node). 
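For a passive model this reduces to [Qm0]; otherwise, steady-state values of all gating variables are appended. As an illustration (hypothetical model with two gating states m and h): y0node = [Qm0, m0, h0], which the `y0` property below duplicates across both nodes into [Qm0, m0, h0, Qm0, m0, h0].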
''' if self.passive: return [self.Qm0] else: return [self.Qm0, *[self.pneuron.steadyStates()[k](self.Vm0) for k in self.states]] @property def y0(self): ''' Get full initial conditions vector (duplicated y0node vector). ''' self.npernode = len(self.y0node) return self.y0node + self.y0node def integrate(self, dfunc, t): ''' Integrate over a time vector and return charge density arrays. ''' # Integrate system y = odeint(dfunc, self.y0, t, tfirst=True).T # Cast each solution variable as a time-per-node matrix sol = {'Qm': y[::self.npernode]} if not self.passive: for i, k in enumerate(self.states): sol[k] = y[i + 1::self.npernode] # Return recast solution dictionary return sol def orderedKeys(self, varkeys): ''' Get ordered list of solution keys. ''' mainkeys = ['Qm', 'Vm', 'Cm'] otherkeys = list(set(varkeys) - set(mainkeys)) return mainkeys + otherkeys def orderedSol(self, sol): ''' Re-order solution according to keys list. ''' return {k: sol[k] for k in self.orderedKeys(sol.keys())} def nsamples(self, tstop): ''' Compute the number of samples required to integrate over a given time interval. ''' return self.getNCycles(tstop) * self.npc @timer def simFull(self, tstop): ''' Simulate the full system until a specific stop time. ''' t = np.linspace(0, tstop, self.nsamples(tstop)) sol = self.integrate(self.dfull, t) sol['Cm'] = self.vCapct(t) sol['Vm'] = sol['Qm'] / sol['Cm'] * 1e3 return t, self.orderedSol(sol) @timer def simEff(self, tstop): ''' Simulate the effective system until a specific stop time. ''' t = np.linspace(0, tstop, self.getNCycles(tstop)) sol = self.integrate(self.deff, t) sol['Cm'] = np.array([np.ones(t.size) * Cmeff for Cmeff in self.Cmeff]) sol['Vm'] = sol['Qm'] / sol['Cm'] * 1e3 return t, self.orderedSol(sol) @property def methods(self): ''' Dictionary of simulation methods. ''' return {'full': self.simFull, 'effective': self.simEff} def getNCycles(self, duration): ''' Compute number of cycles from a duration. ''' - return int(np.ceil(duration * self.f)) + return int(np.ceil(duration * self.Fdrive)) def simulate(self, mtype, tstop): ''' Simulate the system with a specific method for a given duration. ''' # Cast tstop as a multiple of the acoustic period - tstop = self.getNCycles(tstop) / self.f # s + tstop = self.getNCycles(tstop) / self.Fdrive # s # Retrieve simulation method try: method = self.methods[mtype] except KeyError: raise ValueError(f'"{mtype}" is not a valid method type') # Run simulation and return output logger.debug(f'running {mtype} {tstop * 1e3:.2f} ms simulation') output, tcomp = method(tstop) logger.debug(f'completed in {tcomp:.2f} s') return output def cycleAvg(self, y): ''' Cycle-average a solution vector according to the number of samples per cycle. ''' ypercycle = np.reshape(y, (int(y.shape[0] / self.npc), self.npc)) return np.mean(ypercycle, axis=1) def cycleAvgSol(self, t, sol): ''' Cycle-average a time vector and a solution dictionary. ''' solavg = {} # For each per-node-matrix in the solution for k, ymat in sol.items(): # Cycle-average each node vector of the matrix solavg[k] = np.array([self.cycleAvg(yvec) for yvec in ymat]) # Re-sample time vector at system periodicity - tavg = t[::self.npc] # + 0.5 / self.f + tavg = t[::self.npc] # + 0.5 / self.Fdrive # Return cycle-averaged time vector and solution dictionary return tavg, solavg def g2tau(self, g): ''' Convert conductance per unit membrane area (S/m2) to time constant (s). ''' return self.Cm0 / g # s def tau2g(self, tau): ''' Convert time constant (s) to conductance per unit membrane area (S/m2). 
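Illustrative conversion (a sketch, assuming a benchmark `bm` whose Cm0 is 1e-2 F/m2, i.e. 1 uF/cm2):

>>> bm.tau2g(1e-3)  # Cm0 / tau = 1e-2 / 1e-3
10.0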
''' return self.Cm0 / tau # S/m2 @property def taum(self): ''' Passive membrane time constant (s). ''' return self.pneuron.tau_pas @taum.setter def taum(self, value): ''' Update point-neuron leakage conductance to match the new membrane time constant. ''' if not self.isPassive(): raise ValueError('taum can only be set for passive neurons') self.pneuron = passiveNeuron( self.pneuron.Cm0, self.tau2g(value), # S/m2 self.pneuron.ELeak) @property def tauax(self): ''' Axial time constant (s). ''' return self.g2tau(self.ga) @tauax.setter def tauax(self, value): ''' Update axial conductance per unit area to match the new axial time constant. ''' self.ga = self.tau2g(value) # S/m2 @property def taumax(self): ''' Maximal time constant of the model (s). ''' return max(self.taum, self.tauax) def setTimeConstants(self, taum, tauax): ''' Update benchmark according to pair of time constants (in s). ''' self.taum = taum # s self.tauax = tauax # s def setDrive(self, f, gammas): ''' Update benchmark drive to a new frequency and amplitude. ''' - self.f = f + self.Fdrive = f self.gammas = gammas def getPassiveTstop(self, f): ''' Compute minimum simulation time for a passive model (s). ''' return max(5 * self.taumax, self.min_ncycles / f) @property def passive_tstop(self): - return self.getPassiveTstop(self.f) + return self.getPassiveTstop(self.Fdrive) def simAllMethods(self, tstop): ''' Simulate the model with both methods. ''' logger.info(f'{self}: {si_format(tstop)}s simulation') # Simulate with full and effective systems t, sol = {}, {} for method in self.methods.keys(): t[method], sol[method] = self.simulate(method, tstop) t, sol = self.postproSol(t, sol) return t, sol def computeGradient(self, sol): ''' Compute the gradient of a solution array. ''' return {k: np.vstack((y, np.diff(y, axis=0))) for k, y in sol.items()} def addOnset(self, ymat, y0): return np.hstack((np.ones((2, 2)) * y0, ymat)) def postproSol(self, t, sol): ''' Post-process solution. ''' # Add cycle-average of full solution t['cycle-avg'], sol['cycle-avg'] = self.cycleAvgSol(t['full'], sol['full']) keys = list(sol.keys()) tonset = 0.05 * np.ptp(t['full']) # Add onset y0dict = {'Cm': self.Cm0, 'Qm': self.Qm0, 'Vm': self.Vm0} for k in keys: t[k] = np.hstack(([-tonset, 0], t[k])) sol[k] = {vk: self.addOnset(ymat, y0dict[vk]) for vk, ymat in sol[k].items()} # Add gradient across nodes for each variable for k in keys: sol[f'{k}-grad'] = self.computeGradient(sol[k]) return t, sol def plot(self, t, sol, Qonly=False, gradient=False): ''' Plot results of benchmark simulations of the model. ''' colors = ['C0', 'C1', 'darkgrey'] markers = ['-', '--', '-'] alphas = [0.5, 1., 1.] 
# Reduce solution dictionary if only Q needs to be plotted if Qonly: sol = {key: {'Qm': value['Qm']} for key, value in sol.items()} # Extract simulation duration tstop = t[list(t.keys())[0]][-1] # s # Gather keys of methods and variables to plot mkeys = list(sol.keys()) varkeys = list(sol[mkeys[0]].keys()) naxes = len(varkeys) # Get node labels (on a copy, to avoid mutating the class attribute) lbls = self.nodelabels.copy() if gradient: lbls.append('gradient') # Create figure fig, axes = plt.subplots(naxes, 1, sharex=True, figsize=(10, min(3 * naxes, 10))) if naxes == 1: axes = [axes] axes[0].set_title(f'{self} - {si_format(tstop)}s simulation') axes[-1].set_xlabel(f'time ({self.varunits["t"]})') for ax, vk in zip(axes, varkeys): ax.set_ylabel(f'{vk} ({self.varunits.get(vk, "-")})') # Add horizontal lines for node-specific SONIC steady-states on charge density plot Qm_ax = axes[varkeys.index('Qm')] for Qm, c in zip(self.Qminf, colors): Qm_ax.axhline(Qm * self.varfactors['Qm'], c=c, linestyle=':') if gradient: Qm_ax.axhline(np.diff(self.Qminf) * self.varfactors['Qm'], c=colors[-1], linestyle=':') # For each solution type for m, alpha, (mkey, varsdict) in zip(markers, alphas, sol.items()): tplt = t[mkey] * self.varfactors['t'] # For each solution variable for ax, (vkey, v) in zip(axes, varsdict.items()): # For each node for y, c, lbl in zip(v, colors, lbls): # Plot node variable with appropriate color and marker ax.plot(tplt, y * self.varfactors[vkey], m, alpha=alpha, c=c, label=f'{lbl} - {mkey}') # Add legend fig.subplots_adjust(bottom=0.2) axes[-1].legend( bbox_to_anchor=(0., -0.7, 1., .1), loc='upper center', ncol=3, mode="expand", borderaxespad=0.) # Return figure return fig - def plotV(self, t, sol): + def plotQnorm(self, t, sol): ''' Plot charge density normalized by resting capacitance (Qm / Cm0) for benchmark simulations of the model. ''' colors = ['C0', 'C1'] markers = ['-', '--', '-'] alphas = [0.5, 1., 1.] V = {key: value['Qm'] / self.Cm0 for key, value in sol.items()} fig, ax = plt.subplots(figsize=(10, 3)) ax.set_title(f'{self} - {t[list(t.keys())[0]][-1] * 1e3:.2f} ms simulation') ax.set_xlabel(f'time ({self.varunits["t"]})') ax.set_ylabel('Qm / Cm0 (mV)') ax.set_ylim(-100.0, 50.) for m, alpha, (key, varsdict) in zip(markers, alphas, sol.items()): for y, c, lbl in zip(V[key], colors, self.nodelabels): ax.plot(t[key] * self.varfactors['t'], y * 1e3, m, alpha=alpha, c=c, label=f'{lbl} - {key}') fig.subplots_adjust(bottom=0.2) ax.legend(bbox_to_anchor=(0., -0.7, 1., .1), loc='upper center', ncol=3, mode="expand", borderaxespad=0.) return fig def simplot(self, *args, **kwargs): ''' Run benchmark simulation and plot results. ''' return self.plot(*self.simAllMethods(*args, **kwargs)) @property def eval_funcs(self): ''' Different functions to evaluate the divergence between two solutions. ''' return { 'rmse': lambda y1, y2: rmse(y1, y2), # RMSE 'ss': lambda y1, y2: np.abs(y1[-1] - y2[-1]), # steady-state absolute difference 'amax': lambda y1, y2: np.max(np.abs(y1 - y2)) # max absolute difference } def divergencePerNode(self, t, sol, eval_mode='rmse'): ''' Evaluate the divergence between the effective and cycle-averaged full solutions, computing per-node differences in charge density values divided by resting capacitance. 
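Supported eval_mode keys match `eval_funcs` above: 'rmse' (root-mean-square error over the trace), 'ss' (absolute steady-state difference) and 'amax' (maximum absolute difference). For instance, for the rmse metric (illustrative arrays):

>>> rmse(np.array([0., 1.]), np.array([0., 3.]))  # sqrt(mean([0., 4.]))
1.4142135623730951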
''' if eval_mode not in self.eval_funcs.keys(): raise ValueError(f'{eval_mode} evaluation mode is not supported') # Extract charge matrices from solution dictionary Qsol = {k: sol[k]['Qm'] for k in ['effective', 'cycle-avg']} # C/m2 # Normalize matrices by resting capacitance Qnorm = {k: v / self.Cm0 * 1e3 for k, v in Qsol.items()} # mV # Keep only the first two rows (3rd one, if any, is a gradient) Qnorm = {k: v[:2, :] for k, v in Qnorm.items()} # Discard the first 3 columns (artificial onset and first-cycle artefact) Qnorm = {k: v[:, 3:] for k, v in Qnorm.items()} eval_func = self.eval_funcs[eval_mode] # Compute deviation across nodes according to evaluation mode div_per_node = [eval_func(*[v[i] for v in Qnorm.values()]) for i in range(2)] # Cast into dictionary and return div_per_node = dict(zip(self.nodelabels, div_per_node)) logger.debug('divergence per node: ' + ', '.join([f'{k}: {v:.2e} mV' for k, v in div_per_node.items()])) return div_per_node def divergence(self, *args, **kwargs): div_per_node = self.divergencePerNode(*args, **kwargs) # mV return max(list(div_per_node.values())) # mV diff --git a/PySONIC/multicomp/divmaps.py b/PySONIC/multicomp/divmaps.py new file mode 100644 index 0000000..ff206d6 --- /dev/null +++ b/PySONIC/multicomp/divmaps.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- +# @Author: Theo Lemaire +# @Email: theo.lemaire@epfl.ch +# @Date: 2020-06-29 18:11:24 +# @Last Modified by: Theo Lemaire +# @Last Modified time: 2020-09-24 20:34:59 + +import os +import numpy as np +import matplotlib.pyplot as plt + +from ..utils import logger, bounds, si_format +from ..plt import XYMap + + +class DivergenceMap(XYMap): + ''' Interface to a 2D map showing divergence of the SONIC output from a + cycle-averaged NICE output, for various combinations of parameters. + ''' + zunit = 'mV' + zfactor = 1e0 + + def __init__(self, root, benchmark, eval_mode, *args, tstop=None, **kwargs): + self.benchmark = benchmark.copy() + self.eval_mode = eval_mode + self.tstop = tstop + super().__init__(root, *args, **kwargs) + + @property + def zkey(self): + return self.eval_mode + + @property + def suffix(self): + return self.eval_mode + + @property + def tstop(self): + if self._tstop is None: + return self.benchmark.passive_tstop + return self._tstop + + @tstop.setter + def tstop(self, value): + self._tstop = value + + def descPair(self, x1, x2): + raise NotImplementedError + + def updateBenchmark(self, x): + raise NotImplementedError + + def logDiv(self, x, div): + ''' Log divergence for a particular inputs combination. ''' + logger.info(f'{self.descPair(*x)}: {self.eval_mode} = {div:.2e} mV') + + def compute(self, x): + self.updateBenchmark(x) + t, sol = self.benchmark.simAllMethods(self.tstop) + div = self.benchmark.divergence(t, sol, eval_mode=self.eval_mode) # mV + self.logDiv(x, div) + return div + + def onClick(self, event): + ''' Execute action when the user clicks on a cell in the 2D map. 
''' + x = self.getOnClickXY(event) + + # Update benchmark object to selected configuration + self.updateBenchmark(x) + + # Get divergence output from log + ix, iy = [np.where(vec == val)[0][0] for vec, val in zip([self.xvec, self.yvec], x)] + div_log = self.getOutput()[iy, ix] # mV + + # Simulate model and re-compute divergence + t, sol = self.benchmark.simAllMethods(self.tstop) + div = self.benchmark.divergence(t, sol, eval_mode=self.eval_mode) # mV + + # Raise error if computed divergence does not match log reference + if not np.isclose(div_log, div): + err_str = 'computed {} ({:.2e} mV) does not match log reference ({:.2e} mV)' + raise ValueError(err_str.format(self.eval_mode, div, div_log)) + + # Log divergence + self.logDiv(x, div) + + # Show related plot + fig = self.benchmark.plot(t, sol) + fig.axes[0].set_title(self.descPair(*x)) + plt.show() + + def render(self, zscale='log', levels=None, zbounds=(1e-1, 1e1), + extend_under=True, extend_over=True, cmap='Spectral_r', figsize=(6, 4), fs=12, + **kwargs): + ''' Render and add specific contour levels. ''' + fig = super().render( + zscale=zscale, zbounds=zbounds, extend_under=extend_under, extend_over=extend_over, + cmap=cmap, figsize=figsize, fs=fs, **kwargs) + if levels is not None: + ax = fig.axes[0] + fmt = lambda x: f'{x:g}' # ' mV' + CS = ax.contour( + self.xvec, self.yvec, self.getOutput(), levels, colors='k') + ax.clabel(CS, fontsize=fs, fmt=fmt, inline_spacing=2) + return fig + + +class ModelDivergenceMap(DivergenceMap): + ''' Divergence map of a passive model for various combinations of + membrane (taum) and axial (tauax) time constants. + ''' + + xkey = 'tau_m' + xfactor = 1e0 + xunit = 's' + ykey = 'tau_ax' + yfactor = 1e0 + yunit = 's' + ga_default = 1e0 # mS/cm2 + + @property + def title(self): + return f'Model divmap (f = {self.benchmark.fstr}, gamma = {self.benchmark.gammastr})' + + def corecode(self): + gstr = '_'.join(self.benchmark.gammalist) + code = f'model_divmap_f{self.benchmark.fstr}_gamma{gstr}' + return code.replace(' ', '') + + def descPair(self, taum, tauax): + return f'taum = {si_format(taum, 2)}s, tauax = {si_format(tauax, 2)}s' + + def updateBenchmark(self, x): + self.benchmark.setTimeConstants(*x) + + def render(self, xscale='log', yscale='log', add_periodicity=True, insets=None, **kwargs): + ''' Render with insets and drive periodicity indicator. ''' + fig = super().render(xscale=xscale, yscale=yscale, **kwargs) + fig.canvas.set_window_title(self.corecode()) + ax = fig.axes[0] + axis_to_data = ax.transAxes + ax.transData.inverted() + data_to_axis = axis_to_data.inverted() + + # Indicate periodicity if required + if add_periodicity: + T_US = 1 / self.benchmark.Fdrive + xyTUS = data_to_axis.transform((T_US, T_US)) + for i, k in enumerate(['h', 'v']): + getattr(ax, f'ax{k[0]}line')(T_US, color='k', linestyle='-', linewidth=1) + xy = np.empty(2) + xy_offset = np.empty(2) + xy[i] = xyTUS[i] + xy[1 - i] = 0. + xy_offset[i] = 0. 
+ xy_offset[1 - i] = 0.2 + ax.annotate( + 'TUS', xy=xy, xytext=xy - xy_offset, xycoords=ax.transAxes, fontsize=10, + arrowprops={'facecolor': 'black', 'arrowstyle': '-'}, **{f'{k}a': 'center'}) + + # Add potential insets + if insets is not None: + for k, (taum, tauax) in insets.items(): + xy = data_to_axis.transform((taum, tauax)) + ax.scatter(*xy, transform=ax.transAxes, facecolor='k', edgecolor='none', + linestyle='--', lw=1) + ax.annotate(k, xy=xy, xytext=np.array(xy) + np.array([0, 0.1]), + xycoords=ax.transAxes, fontsize=10, + arrowprops={'facecolor': 'black', 'arrowstyle': '-'}, ha='right') + + return fig + + +class OldDriveDivergenceMap(DivergenceMap): + ''' Divergence map of a specific (membrane model, axial coupling) pair for various + combinations of drive frequencies and drive amplitudes. + ''' + + xkey = 'f_US' + xfactor = 1e0 + xunit = 'kHz' + ykey = 'gamma' + yfactor = 1e0 + yunit = '-' + + @property + def title(self): + return f'Drive divergence map - {self.benchmark.pneuron.name}, tauax = {self.benchmark.tauax:.2e} s' + + def corecode(self): + if self.benchmark.isPassive(): + neuron_desc = f'passive_taum_{self.benchmark.taum:.2e}ms' + else: + neuron_desc = self.benchmark.pneuron.name + if self.benchmark.passive: + neuron_desc = f'passive_{neuron_desc}' + code = f'drive_divmap_{neuron_desc}_tauax_{self.benchmark.tauax:.2e}ms' + if self._tstop is not None: + code = f'{code}_tstop{self.tstop:.2f}ms' + return code + + def descPair(self, f_US, A_Cm): + return f'f = {f_US:.2f} kHz, gamma = {A_Cm:.2f}' + + def updateBenchmark(self, x): + f, gamma = x + self.benchmark.setDrive(f, (gamma, 0.)) + + def threshold_filename(self, method): + fmin, fmax = bounds(self.xvec) + return f'{self.corecode()}_f{fmin:.0f}kHz_{fmax:.0f}kHz_{self.xvec.size}_gammathrs_{method}.txt' + + def threshold_filepath(self, *args, **kwargs): + return os.path.join(self.root, self.threshold_filename(*args, **kwargs)) + + def addThresholdCurves(self, ax): + ls = ['--', '-.'] + for j, method in enumerate(['effective', 'full']): + fpath = self.threshold_filepath(method) + if os.path.isfile(fpath): + gamma_thrs = np.loadtxt(fpath) + else: + gamma_thrs = np.empty(self.xvec.size) + for i, f in enumerate(self.xvec): + self.benchmark.Fdrive = f + gamma_thrs[i] = self.benchmark.titrate(self.tstop, method=method) + np.savetxt(fpath, gamma_thrs) + ylims = ax.get_ylim() + ax.plot(self.xvec * self.xfactor, gamma_thrs * self.yfactor, ls[j], color='k') + ax.set_ylim(ylims) + + def render(self, xscale='log', thresholds=False, **kwargs): + fig = super().render(xscale=xscale, **kwargs) + if thresholds: + self.addThresholdCurves(fig.axes[0]) + return fig diff --git a/PySONIC/multicomp/gamma_map.py b/PySONIC/multicomp/gamma_map.py new file mode 100644 index 0000000..ffc8ea3 --- /dev/null +++ b/PySONIC/multicomp/gamma_map.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# @Author: Theo Lemaire +# @Email: theo.lemaire@epfl.ch +# @Date: 2020-09-24 19:00:54 +# @Last Modified by: Theo Lemaire +# @Last Modified time: 2020-09-24 19:14:57 + +import os +import numpy as np +import matplotlib.pyplot as plt + +from ..core import AcousticDrive, Lookup +from ..utils import logger, si_format, bounds +from ..plt import XYMap +from ..constants import NPC_DENSE + + +class GammaMap(XYMap): + ''' Interface to a 2D map showing relative capacitance oscillation amplitude + resulting from BLS simulations at various frequencies and amplitudes. 
+ ''' + xkey = 'f' + xfactor = 1e0 + xunit = 'Hz' + ykey = 'A' + yfactor = 1e0 + yunit = 'Pa' + zkey = 'gamma' + zfactor = 1e0 + zunit = '-' + suffix = 'gamma' + + def __init__(self, root, bls, Qm, freqs, amps): + self.bls = bls.copy() + self.Qm = Qm + super().__init__(root, freqs, amps) + + @property + def title(self): + return f'Gamma map - {self.bls}' + + @property + def pdict(self): + return { + 'a': f'{self.bls.a * 1e9:.0f}nm', + 'Cm0': f'{self.bls.Cm0 * 1e2:.1f}uF_cm2', + 'Qm0': f'{self.bls.Qm0 * 1e5:.0f}nC_cm2', + 'Qm': f'{self.Qm * 1e5:.0f}nC_cm2', + } + + @property + def pcode(self): + return 'bls_' + '_'.join([f'{k}{v}' for k, v in self.pdict.items()]) + + def corecode(self): + return f'gamma_map_{self.pcode}' + + def compute(self, x): + f, A = x + data = self.bls.simCycles(AcousticDrive(f, A), self.Qm).tail(NPC_DENSE) + Cm = self.bls.v_capacitance(data['Z']) + gamma = np.ptp(Cm) / self.bls.Cm0 + logger.info(f'f = {si_format(f, 1)}Hz, A = {si_format(A)}Pa, gamma = {gamma:.2f}') + return gamma + + def onClick(self, event): + ''' Show capacitance profile when the user clicks on a cell in the 2D map. ''' + x = self.getOnClickXY(event) + f, A = x + + # Simulate mechanical model + data = self.bls.simCycles(AcousticDrive(f, A), self.Qm).tail(NPC_DENSE) + + # Retrieve time and relative capacitance profiles from last cycle + t = data['t'].values + rel_Cm = self.bls.v_capacitance(data['Z']) / self.bls.Cm0 + + # Create figure + fig, ax = plt.subplots() + ax.set_title(f'f = {si_format(f, 1)}Hz, A = {si_format(A)}Pa') + ax.set_xlabel('time (us)') + ax.set_ylabel('Cm / Cm0') + for sk in ['right', 'top']: + ax.spines[sk].set_visible(False) + + # Plot capacitance profile + ax.plot((t - t[0]) * 1e6, rel_Cm) + ax.axhline(1.0, c='k', linewidth=0.5) + + # Indicate relative oscillation range + ybounds = bounds(rel_Cm) + gamma = ybounds[1] - ybounds[0] + for y in ybounds: + ax.axhline(y, linestyle='--', c='k') + axis_to_data = ax.transAxes + ax.transData.inverted() + data_to_axis = axis_to_data.inverted() + ax_ybounds = [data_to_axis.transform((ax.get_ylim()[0], y))[1] for y in ybounds] + xarrow = 0.9 + ax.text(0.85, np.mean(ax_ybounds), f'gamma = {gamma:.2f}', transform=ax.transAxes, + rotation='vertical', va='center', ha='center', color='k', fontsize=10) + ax.annotate( + '', xy=(xarrow, ax_ybounds[0]), xytext=(xarrow, ax_ybounds[1]), + xycoords='axes fraction', textcoords='axes fraction', + arrowprops=dict(facecolor='k', edgecolor='k', arrowstyle='<|-|>')) + + # Show figure + plt.show() + + def render(self, xscale='log', yscale='log', figsize=(6, 4), fs=12, levels=None, **kwargs): + ''' Render and add specific contour levels. ''' + fig = super().render(xscale=xscale, yscale=yscale, figsize=figsize, fs=fs, **kwargs) + if levels is not None: + colors = ['k' if l > 0.5 else 'w' for l in levels] + ax = fig.axes[0] + CS = ax.contour( + self.xvec, self.yvec, self.getOutput(), levels, colors=colors) + ax.clabel(CS, fontsize=fs, fmt=lambda x: f'{x:g}', inline_spacing=2) + return fig + + def toPickle(self, root): + ''' Output map to a lookup file (adding amplitude-zero). 
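+ The amplitude axis is prepended with an A = 0 entry mapping to gamma = 0 (see the np.hstack / np.vstack calls below), so that interpolation of the resulting lookup extends down to the zero-drive limit.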
''' + lkp = Lookup( + {'f': self.xvec, 'A': np.hstack(([0.], self.yvec))}, + {'gamma': np.vstack([np.zeros(self.xvec.size), self.getOutput()]).T}) + + xcode = self.rangecode(lkp.refs['f'], self.xkey, self.xunit) + ycode = self.rangecode(lkp.refs['A'], self.ykey, self.yunit) + xycode = '_'.join([xcode, ycode]) + lkp.toPickle(os.path.join(root, f'gamma_lkp_{self.pcode}_{xycode}.lkp')) diff --git a/PySONIC/plt/actmap.py b/PySONIC/plt/actmap.py index 0789b70..9dbb867 100644 --- a/PySONIC/plt/actmap.py +++ b/PySONIC/plt/actmap.py @@ -1,558 +1,456 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2019-06-04 18:24:29 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-09-24 15:26:27 +# @Last Modified time: 2020-09-24 19:03:20 -import os import abc import csv from itertools import product import numpy as np import pandas as pd import matplotlib.pyplot as plt import copy -from ..core import NeuronalBilayerSonophore, PulsedProtocol, AcousticDrive, LogBatch, Batch, Lookup -from ..utils import logger, si_format, isIterable, bounds +from ..core import NeuronalBilayerSonophore, PulsedProtocol, AcousticDrive, LogBatch, Batch +from ..utils import logger, si_format, isIterable from .pltutils import cm2inch, setNormalizer from .timeseries import GroupedTimeSeries from ..postpro import detectSpikes -from ..constants import NPC_DENSE class XYMap(LogBatch): ''' Generic 2D map object interface. ''' offset_options = { 'lr': (1, -1), 'ur': (1, 1), 'll': (-1, -1), 'ul': (-1, 1) } def __init__(self, root, xvec, yvec): self.root = root self.xvec = xvec self.yvec = yvec super().__init__([list(pair) for pair in product(self.xvec, self.yvec)], root=root) def checkVector(self, name, value): if not isIterable(value): raise ValueError(f'{name} vector must be an iterable') if not isinstance(value, np.ndarray): value = np.asarray(value) if len(value.shape) > 1: raise ValueError(f'{name} vector must be one-dimensional') return value @property def in_key(self): return self.xkey @property def unit(self): return self.xunit @property def xvec(self): return self._xvec @xvec.setter def xvec(self, value): self._xvec = self.checkVector('x', value) @property def yvec(self): return self._yvec @yvec.setter def yvec(self, value): self._yvec = self.checkVector('y', value) @property @abc.abstractmethod def xkey(self): raise NotImplementedError @property @abc.abstractmethod def xfactor(self): raise NotImplementedError @property @abc.abstractmethod def xunit(self): raise NotImplementedError @property @abc.abstractmethod def ykey(self): raise NotImplementedError @property @abc.abstractmethod def yfactor(self): raise NotImplementedError @property @abc.abstractmethod def yunit(self): raise NotImplementedError @property @abc.abstractmethod def zkey(self): raise NotImplementedError @property @abc.abstractmethod def zunit(self): raise NotImplementedError @property @abc.abstractmethod def zfactor(self): raise NotImplementedError @property def out_keys(self): return [f'{self.zkey} ({self.zunit})'] @property def in_labels(self): return [f'{self.xkey} ({self.xunit})', f'{self.ykey} ({self.yunit})'] def getLogData(self): ''' Retrieve the batch log file data (inputs and outputs) as a dataframe. ''' return pd.read_csv(self.fpath, sep=self.delimiter).sort_values(self.in_labels) def getInput(self): ''' Retrieve the logged batch inputs as an array. 
''' return self.getLogData()[self.in_labels].values def getOutput(self): return np.reshape(super().getOutput(), (self.xvec.size, self.yvec.size)).T def writeLabels(self): with open(self.fpath, 'w') as csvfile: writer = csv.writer(csvfile, delimiter=self.delimiter) writer.writerow([*self.in_labels, *self.out_keys]) def isEntry(self, comb): ''' Check if a given input is logged in the batch log file. ''' inputs = self.getInput() if len(inputs) == 0: return False imatches_x = np.where(np.isclose(inputs[:, 0], comb[0], rtol=self.rtol, atol=self.atol))[0] imatches_y = np.where(np.isclose(inputs[:, 1], comb[1], rtol=self.rtol, atol=self.atol))[0] imatches = list(set(imatches_x).intersection(imatches_y)) if len(imatches) == 0: return False return True @property def inputscode(self): ''' String describing the batch inputs. ''' xcode = self.rangecode(self.xvec, self.xkey, self.xunit) ycode = self.rangecode(self.yvec, self.ykey, self.yunit) return '_'.join([xcode, ycode]) @staticmethod def getScaleType(x): xmin, xmax, nx = x.min(), x.max(), x.size if np.all(np.isclose(x, np.logspace(np.log10(xmin), np.log10(xmax), nx))): return 'log' else: return 'lin' # elif np.all(np.isclose(x, np.linspace(xmin, xmax, nx))): # return 'lin' # else: # raise ValueError('Unknown distribution type') @property def xscale(self): return self.getScaleType(self.xvec) @property def yscale(self): return self.getScaleType(self.yvec) @staticmethod def computeMeshEdges(x, scale): ''' Compute the appropriate edges of a mesh that quads a linear or logarithmic distribution. :param x: the input vector :param scale: the type of distribution ('lin' for linear, 'log' for logarithmic) :return: the edges vector ''' if scale == 'log': x = np.log10(x) range_func = np.logspace else: range_func = np.linspace dx = x[1] - x[0] n = x.size + 1 return range_func(x[0] - dx / 2, x[-1] + dx / 2, n) @abc.abstractmethod def compute(self, x): ''' Compute the necessary output(s) for a given inputs combination. ''' raise NotImplementedError def run(self, **kwargs): super().run(**kwargs) self.getLogData().to_csv(self.filepath(), sep=self.delimiter, index=False) def getOnClickXY(self, event): ''' Get x and y values from x and y click event coordinates. ''' x = self.xvec[np.searchsorted(self.xedges, event.xdata) - 1] y = self.yvec[np.searchsorted(self.yedges, event.ydata) - 1] return x, y def onClick(self, event): ''' Execute specific action when the user clicks on a cell in the 2D map. 
''' pass @property @abc.abstractmethod def title(self): raise NotImplementedError def getZBounds(self): matrix = self.getOutput() zmin, zmax = np.nanmin(matrix), np.nanmax(matrix) logger.info( f'{self.zkey} range: {zmin:.0f} - {zmax:.0f} {self.zunit}') return zmin, zmax def checkZbounds(self, zbounds): zmin, zmax = self.getZBounds() if zmin < zbounds[0]: logger.warning( f'Minimal {self.zkey} ({zmin:.0f} {self.zunit}) is below defined lower bound ({zbounds[0]:.0f} {self.zunit})') if zmax > zbounds[1]: logger.warning( f'Maximal {self.zkey} ({zmax:.0f} {self.zunit}) is above defined upper bound ({zbounds[1]:.0f} {self.zunit})') def render(self, xscale='lin', yscale='lin', zscale='lin', zbounds=None, fs=8, cmap='viridis', interactive=False, figsize=None, insets=None, inset_offset=0.05, extend_under=False, extend_over=False): # Get figure size if figsize is None: figsize = cm2inch(12, 7) # Compute Z normalizer mymap = copy.copy(plt.get_cmap(cmap)) if not extend_under: mymap.set_under('silver') if not extend_over: mymap.set_over('silver') if zbounds is None: zbounds = self.getZBounds() else: self.checkZbounds(zbounds) norm, sm = setNormalizer(mymap, zbounds, zscale) nan_eq = zbounds[0] - 1 if zscale == 'lin' else 0.5 * zbounds[0] # Compute mesh edges self.xedges = self.computeMeshEdges(self.xvec, xscale) self.yedges = self.computeMeshEdges(self.yvec, yscale) # Create figure fig, ax = plt.subplots(figsize=figsize) fig.subplots_adjust(left=0.15, bottom=0.15, right=0.8, top=0.92) ax.set_title(self.title, fontsize=fs) ax.set_xlabel(f'{self.xkey} ({self.xunit})', fontsize=fs, labelpad=-0.5) ax.set_ylabel(f'{self.ykey} ({self.yunit})', fontsize=fs) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) if xscale == 'log': ax.set_xscale('log') if yscale == 'log': ax.set_yscale('log') # Retrieve data and replace NaNs with specific out-of-bounds value data = self.getOutput() data[np.isnan(data)] = nan_eq # Plot map with specific color code ax.pcolormesh(self.xedges, self.yedges, data, cmap=mymap, norm=norm) # Plot potential insets if insets is not None: x_data, y_data, *_ = zip(*insets) ax.scatter(x_data, y_data, s=80, facecolors='none', edgecolors='k', linestyle='--', lw=1) axis_to_data = ax.transAxes + ax.transData.inverted() data_to_axis = axis_to_data.inverted() for x, y, label, direction in insets: xyoffset = np.array(self.offset_options[direction]) * inset_offset # in axis coords xytext = axis_to_data.transform(np.array(data_to_axis.transform((x, y))) + xyoffset) ax.annotate(label, xy=(x, y), xytext=xytext, fontsize=fs, horizontalalignment='right', arrowprops={'facecolor': 'black', 'arrowstyle': '-'}) # Plot z-scale colorbar pos1 = ax.get_position() # get the map axis position cbarax = fig.add_axes([pos1.x1 + 0.02, pos1.y0, 0.03, pos1.height]) if not extend_under and not extend_over: extend = 'neither' elif extend_under and extend_over: extend = 'both' else: extend = 'max' if extend_over else 'min' fig.colorbar(sm, cax=cbarax, extend=extend) cbarax.set_ylabel(f'{self.zkey} ({self.zunit})', fontsize=fs) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) if interactive: fig.canvas.mpl_connect('button_press_event', lambda event: self.onClick(event)) return fig class ActivationMap(XYMap): xkey = 'Duty cycle' xfactor = 1e2 xunit = '%' ykey = 'Amplitude' yfactor = 1e-3 yunit = 'kPa' onclick_colors = None def __init__(self, root, pneuron, a, fs, f, tstim, PRF, amps, DCs): self.nbls = NeuronalBilayerSonophore(a, pneuron) self.drive = AcousticDrive(f, None) self.pp = 
PulsedProtocol(tstim, 0., PRF, .5) self.fs = fs super().__init__(root, DCs * self.xfactor, amps * self.yfactor) @property def sim_args(self): return [self.drive, self.pp, self.fs, 'sonic', None] @property def title(self): s = 'Activation map - {} neuron @ {}Hz, {}Hz PRF ({}m sonophore'.format( self.nbls.pneuron.name, *si_format([self.drive.f, self.pp.PRF, self.nbls.a])) if self.fs < 1: s = f'{s}, {self.fs * 1e2:.0f}% coverage' return f'{s})' def corecode(self): corecodes = self.nbls.filecodes(*self.sim_args) del corecodes['nature'] if 'DC' in corecodes: del corecodes['DC'] return '_'.join(filter(lambda x: x is not None, corecodes.values())) def compute(self, x): ''' Compute map output metric from simulation output ''' # Adapt drive and pulsed protocol self.pp.DC = x[0] / self.xfactor self.drive.A = x[1] / self.yfactor # Get model output, running simulation if needed data, _ = self.nbls.getOutput(*self.sim_args, outputdir=self.root) return self.xfunc(data) @abc.abstractmethod def xfunc(self, data): raise NotImplementedError def addThresholdCurve(self, ax, fs, mpi=False): queue = [[ self.drive, PulsedProtocol(self.pp.tstim, self.pp.toffset, self.pp.PRF, DC / self.xfactor), self.fs, 'sonic', None] for DC in self.xvec] batch = Batch(self.nbls.titrate, queue) Athrs = np.array(batch.run(mpi=mpi, loglevel=logger.level)) ax.plot(self.xvec, Athrs * self.yfactor, '-', color='#F26522', linewidth=3, label='threshold amplitudes') ax.legend(loc='lower center', frameon=False, fontsize=fs) @property @abc.abstractmethod def onclick_pltscheme(self): raise NotImplementedError def onClick(self, event): ''' Execute action when the user clicks on a cell in the 2D map. ''' DC, A = self.getOnClickXY(event) self.plotTimeseries(DC, A) plt.show() def plotTimeseries(self, DC, A, **kwargs): ''' Plot related timeseries for a given duty cycle and amplitude. ''' self.drive.A = A / self.yfactor self.pp.DC = DC / self.xfactor # Get model output, running simulation if needed data, meta = self.nbls.getOutput(*self.sim_args, outputdir=self.root) # Plot timeseries of appropriate variables timeseries = GroupedTimeSeries([(data, meta)], pltscheme=self.onclick_pltscheme) return timeseries.render(colors=self.onclick_colors, **kwargs)[0] def render(self, yscale='log', thresholds=False, mpi=False, **kwargs): fig = super().render(yscale=yscale, **kwargs) if thresholds: self.addThresholdCurve(fig.axes[0], fs=12, mpi=mpi) return fig class FiringRateMap(ActivationMap): zkey = 'Firing rate' zunit = 'Hz' zfactor = 1e0 suffix = 'FRmap' onclick_pltscheme = {'V_m\ |\ Q_m/C_{m0}': ['Vm', 'Qm/Cm0']} onclick_colors = ['darkgrey', 'k'] def xfunc(self, data): ''' Detect spikes in data and compute firing rate. ''' ispikes, _ = detectSpikes(data) if ispikes.size > 1: t = data['t'].values sr = 1 / np.diff(t[ispikes]) return np.mean(sr) else: return np.nan def render(self, zscale='log', **kwargs): return super().render(zscale=zscale, **kwargs) class CalciumMap(ActivationMap): zkey = '[Ca2+]i' zunit = 'uM' zfactor = 1e6 suffix = 'Camap' onclick_pltscheme = {'Cai': ['Cai']} def xfunc(self, data): ''' Compute average intracellular calcium concentration. 
''' Cai = data['Cai'].values * self.zfactor # uM return np.mean(Cai) def render(self, zscale='log', **kwargs): return super().render(zscale=zscale, **kwargs) map_classes = { 'FR': FiringRateMap, 'Cai': CalciumMap } def getActivationMap(key, *args, **kwargs): if key not in map_classes: raise ValueError(f'{key} is not a valid map type') return map_classes[key](*args, **kwargs) - - -class GammaMap(XYMap): - ''' Interface to a 2D map showing relative capacitance oscillation amplitude - resulting from BLS simulations at various frequencies and amplitude. - ''' - xkey = 'f' - xfactor = 1e0 - xunit = 'Hz' - ykey = 'A' - yfactor = 1e0 - yunit = 'Pa' - zkey = 'gamma' - zfactor = 1e0 - zunit = '-' - suffix = 'gamma' - - def __init__(self, root, bls, Qm, freqs, amps): - self.bls = bls.copy() - self.Qm = Qm - super().__init__(root, freqs, amps) - - @property - def title(self): - return f'Gamma map - {self.bls}' - - @property - def pdict(self): - return { - 'a': f'{self.bls.a * 1e9:.0f}nm', - 'Cm0': f'{self.bls.Cm0 * 1e2:.1f}uF_cm2', - 'Qm0': f'{self.bls.Qm0 * 1e5:.0f}nC_cm2', - 'Qm': f'{self.Qm * 1e5:.0f}nC_cm2', - } - - @property - def pcode(self): - return 'bls_' + '_'.join([f'{k}{v}' for k, v in self.pdict.items()]) - - def corecode(self): - return f'gamma_map_{self.pcode}' - - def compute(self, x): - f, A = x - data = self.bls.simCycles(AcousticDrive(f, A), self.Qm).tail(NPC_DENSE) - Cm = self.bls.v_capacitance(data['Z']) - gamma = np.ptp(Cm) / self.bls.Cm0 - logger.info(f'f = {si_format(f, 1)}Hz, A = {si_format(A)}Pa, gamma = {gamma:.2f}') - return gamma - - def onClick(self, event): - ''' Execute action when the user clicks on a cell in the 2D map. ''' - x = self.getOnClickXY(event) - f, A = x - data = self.bls.simCycles(AcousticDrive(f, A), self.Qm).tail(NPC_DENSE) - t = data['t'].values - rel_Cm = self.bls.v_capacitance(data['Z']) / self.bls.Cm0 - gamma = np.ptp(rel_Cm) - fig, ax = plt.subplots() - ax.set_xlabel('time (us)') - ax.set_ylabel('Cm / Cm0') - for sk in ['right', 'top']: - ax.spines[sk].set_visible(False) - ax.plot((t - t[0]) * 1e6, rel_Cm) - ax.axhline(1.0, c='k', linewidth=0.5) - ybounds = bounds(rel_Cm) - for y in ybounds: - ax.axhline(y, linestyle='--', c='k') - - axis_to_data = ax.transAxes + ax.transData.inverted() - data_to_axis = axis_to_data.inverted() - ax_ybounds = [data_to_axis.transform((ax.get_ylim()[0], y))[1] for y in ybounds] - xarrow = 0.9 - ax.text(0.85, np.mean(ax_ybounds), f'gamma = {gamma:.2f}', transform=ax.transAxes, - rotation='vertical', va='center', ha='center', color='k') - ax.annotate( - '', xy=(xarrow, ax_ybounds[0]), xytext=(xarrow, ax_ybounds[1]), - xycoords='axes fraction', textcoords='axes fraction', - arrowprops=dict(facecolor='k', edgecolor='k', arrowstyle='<|-|>')) - plt.show() - - def render(self, xscale='log', yscale='log', figsize=(6, 4), fs=12, levels=None, **kwargs): - fig = super().render(xscale=xscale, yscale=yscale, figsize=figsize, fs=fs, **kwargs) - if levels is not None: - colors = ['k' if l > 0.5 else 'w' for l in levels] - ax = fig.axes[0] - CS = ax.contour( - self.xvec, self.yvec, self.getOutput(), levels, colors=colors) - ax.clabel(CS, fontsize=fs, fmt=lambda x: f'{x:g}', inline_spacing=2) - return fig - - def toPickle(self, root): - lkp = Lookup( - {'f': self.xvec, 'A': np.hstack(([0.], self.yvec))}, - {'gamma': np.vstack([np.zeros(self.xvec.size), self.getOutput()]).T}) - - xcode = self.rangecode(lkp.refs['f'], self.xkey, self.xunit) - ycode = self.rangecode(lkp.refs['A'], self.ykey, self.yunit) - xycode = '_'.join([xcode, 
ycode]) - lkp.toPickle(os.path.join(root, f'gamma_lkp_{self.pcode}_{xycode}.lkp')) diff --git a/PySONIC/utils.py b/PySONIC/utils.py index 1ccfbfe..1960946 100644 --- a/PySONIC/utils.py +++ b/PySONIC/utils.py @@ -1,1155 +1,1155 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2016-09-19 22:30:46 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-09-22 17:34:17 +# @Last Modified time: 2020-09-24 20:42:31 ''' Definition of generic utility functions used in other modules ''' import sys import itertools import csv from functools import wraps import operator import time from inspect import signature import os from shutil import get_terminal_size import lockfile import math import pickle import json from tqdm import tqdm import logging import tkinter as tk from tkinter import filedialog import base64 import datetime import numpy as np import pandas as pd from scipy.optimize import brentq from scipy.interpolate import interp1d from scipy import linalg import colorlog from pushbullet import Pushbullet # Package logger my_log_formatter = colorlog.ColoredFormatter( '%(log_color)s %(asctime)s %(message)s', datefmt='%d/%m/%Y %H:%M:%S:', reset=True, log_colors={ 'DEBUG': 'green', 'INFO': 'white', 'WARNING': 'yellow', 'ERROR': 'red', 'CRITICAL': 'red,bg_white', }, style='%') def setHandler(logger, handler): for h in logger.handlers: logger.removeHandler(h) logger.addHandler(handler) return logger def setLogger(name, formatter): handler = colorlog.StreamHandler() handler.setFormatter(formatter) handler.stream = sys.stdout logger = colorlog.getLogger(name) logger.addHandler(handler) return logger class TqdmHandler(logging.StreamHandler): def __init__(self, formatter): logging.StreamHandler.__init__(self) self.setFormatter(formatter) def emit(self, record): msg = self.format(record) tqdm.write(msg) logger = setLogger('PySONIC', my_log_formatter) LOOKUP_DIR = os.path.abspath(os.path.split(__file__)[0] + "/lookups/") def fillLine(text, char='-', totlength=None): ''' Surround a text with repetitions of a specific character in order to fill a line to a given total length. :param text: text to be surrounded :param char: surrounding character :param totlength: target number of characters in filled text line :return: filled text line ''' if totlength is None: totlength = get_terminal_size().columns - 1 ndashes = totlength - len(text) - 2 if ndashes < 2: return text else: nside = ndashes // 2 nleft, nright = nside, nside if ndashes % 2 == 1: nright += 1 return f'{char * nleft} {text} {char * nright}' # SI units prefixes si_prefixes = { 'y': 1e-24, # yocto 'z': 1e-21, # zepto 'a': 1e-18, # atto 'f': 1e-15, # femto 'p': 1e-12, # pico 'n': 1e-9, # nano 'u': 1e-6, # micro 'm': 1e-3, # mili '': 1e0, # None 'k': 1e3, # kilo 'M': 1e6, # mega 'G': 1e9, # giga 'T': 1e12, # tera 'P': 1e15, # peta 'E': 1e18, # exa 'Z': 1e21, # zetta 'Y': 1e24, # yotta } sorted_si_prefixes = sorted(si_prefixes.items(), key=operator.itemgetter(1)) def getSIpair(x, scale='lin'): ''' Get the correct SI factor and prefix for a floating point number. 
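For example (doctest-style sketch):

>>> getSIpair(1.5e-3)
(0.001, 'm')
>>> getSIpair(2e6)
(1000000.0, 'M')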
''' if isIterable(x): # If iterable, get a representative number of the distribution x = np.asarray(x) x = x.prod()**(1.0 / x.size) if scale == 'log' else np.mean(x) if x == 0: return 1e0, '' else: vals = [tmp[1] for tmp in sorted_si_prefixes] ix = np.searchsorted(vals, np.abs(x)) - 1 if np.abs(x) == vals[ix + 1]: ix += 1 return vals[ix], sorted_si_prefixes[ix][0] def si_format(x, precision=0, space=' '): ''' Format a float according to the SI unit system, with the appropriate prefix letter. ''' if isinstance(x, float) or isinstance(x, int) or isinstance(x, np.float) or\ isinstance(x, np.int32) or isinstance(x, np.int64): factor, prefix = getSIpair(x) return f'{x / factor:.{precision}f}{space}{prefix}' elif isinstance(x, list) or isinstance(x, tuple): return [si_format(item, precision, space) for item in x] elif isinstance(x, np.ndarray) and x.ndim == 1: return [si_format(float(item), precision, space) for item in x] else: raise ValueError(f'cannot si_format {type(x)} objects') def pow10_format(number, precision=2): ''' Format a number in power of 10 notation. ''' sci_string = f'{number:.{precision}e}' value, exponent = sci_string.split("e") value, exponent = float(value), int(exponent) val_str = f'{value} * ' if value != 1. else '' return f'{val_str}10^{{{exponent}}}' def rmse(x1, x2, axis=None): ''' Compute the root mean square error between two 1D arrays ''' return np.sqrt(((x1 - x2) ** 2).mean(axis=axis)) def rsquared(x1, x2): ''' compute the R-squared coefficient between two 1D arrays ''' residuals = x1 - x2 ss_res = np.sum(residuals**2) ss_tot = np.sum((x1 - np.mean(x1))**2) return 1 - (ss_res / ss_tot) def Pressure2Intensity(p, rho=1075.0, c=1515.0): ''' Return the spatial peak, pulse average acoustic intensity (ISPPA) associated with the specified pressure amplitude. :param p: pressure amplitude (Pa) :param rho: medium density (kg/m3) :param c: speed of sound in medium (m/s) :return: spatial peak, pulse average acoustic intensity (W/m2) ''' return p**2 / (2 * rho * c) def Intensity2Pressure(I, rho=1075.0, c=1515.0): ''' Return the pressure amplitude associated with the specified spatial peak, pulse average acoustic intensity (ISPPA). :param I: spatial peak, pulse average acoustic intensity (W/m2) :param rho: medium density (kg/m3) :param c: speed of sound in medium (m/s) :return: pressure amplitude (Pa) ''' return np.sqrt(2 * rho * c * I) def convertPKL2JSON(): for pkl_filepath in OpenFilesDialog('pkl')[0]: logger.info(f'Processing {pkl_filepath} ...') json_filepath = f'{os.path.splitext(pkl_filepath)[0]}.json' with open(pkl_filepath, 'rb') as fpkl, open(json_filepath, 'w') as fjson: data = pickle.load(fpkl) json.dump(data, fjson, ensure_ascii=False, sort_keys=True, indent=4) logger.info('All done!') def OpenFilesDialog(filetype, dirname=''): ''' Open a FileOpenDialogBox to select one or multiple file. The default directory and file type are given. :param dirname: default directory :param filetype: default file type :return: tuple of full paths to the chosen filenames ''' root = tk.Tk() root.withdraw() filenames = filedialog.askopenfilenames( filetypes=[(filetype + " files", '.' + filetype)], initialdir=dirname ) if len(filenames) == 0: raise ValueError('no input file selected') par_dir = os.path.abspath(os.path.join(filenames[0], os.pardir)) return filenames, par_dir def selectDirDialog(title='Select directory'): ''' Open a dialog box to select a directory. 
:return: full path to selected directory ''' root = tk.Tk() root.withdraw() directory = filedialog.askdirectory(title=title) if directory == '': raise ValueError('no directory selected') return directory def SaveFileDialog(filename, dirname=None, ext=None): ''' Open a dialog box to save file. :param filename: filename :param dirname: initial directory :param ext: default extension :return: full path to the chosen filename ''' root = tk.Tk() root.withdraw() filename_out = filedialog.asksaveasfilename( defaultextension=ext, initialdir=dirname, initialfile=filename) if len(filename_out) == 0: raise ValueError('no output filepath selected') return filename_out def loadData(fpath, frequency=1): ''' Load dataframe and metadata dictionary from pickle file. ''' logger.info('Loading data from "%s"', os.path.basename(fpath)) with open(fpath, 'rb') as fh: frame = pickle.load(fh) df = frame['data'].iloc[::frequency] meta = frame['meta'] return df, meta def rescale(x, lb=None, ub=None, lb_new=0, ub_new=1): ''' Rescale a value to a specific interval by linear transformation. ''' if lb is None: lb = x.min() if ub is None: ub = x.max() xnorm = (x - lb) / (ub - lb) return xnorm * (ub_new - lb_new) + lb_new def expandRange(xmin, xmax, exp_factor=2): ''' Expand a range by a specific factor around its mid-point. ''' if xmin > xmax: raise ValueError('values must be provided in (min, max) order') xptp = xmax - xmin xmid = (xmin + xmax) / 2 xdev = xptp * exp_factor / 2 - return (xmid - xdev, xmin + xdev) + return (xmid - xdev, xmid + xdev) def isIterable(x): for t in [list, tuple, np.ndarray]: if isinstance(x, t): return True return False def isWithin(name, val, bounds, rel_tol=1e-9, raise_warning=True): ''' Check if a floating point number is within an interval. If the value falls outside the interval, an error is raised. If the value falls just outside the interval due to rounding errors, the associated interval bound is returned. 
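For example (illustrative): isWithin('x', 1. + 1e-12, (0., 1.)) logs a rounding warning and returns 1.0, whereas isWithin('x', 2., (0., 1.)) raises a ValueError.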
def isWithin(name, val, bounds, rel_tol=1e-9, raise_warning=True):
    ''' Check if a floating point number is within an interval.

        If the value falls outside the interval, an error is raised.

        If the value falls just outside the interval due to rounding errors,
        the associated interval bound is returned.

        :param val: float value
        :param bounds: interval bounds (float tuple)
        :return: original or corrected value
    '''
    if isIterable(val):
        return np.array([isWithin(name, v, bounds, rel_tol, raise_warning) for v in val])
    if val >= bounds[0] and val <= bounds[1]:
        return val
    elif val < bounds[0] and math.isclose(val, bounds[0], rel_tol=rel_tol):
        if raise_warning:
            logger.warning(
                'Rounding %s value (%s) to interval lower bound (%s)', name, val, bounds[0])
        return bounds[0]
    elif val > bounds[1] and math.isclose(val, bounds[1], rel_tol=rel_tol):
        if raise_warning:
            logger.warning(
                'Rounding %s value (%s) to interval upper bound (%s)', name, val, bounds[1])
        return bounds[1]
    else:
        raise ValueError(f'{name} value ({val}) out of [{bounds[0]}, {bounds[1]}] interval')


def getDistribution(xmin, xmax, nx, scale='lin'):
    if scale == 'log':
        xmin, xmax = np.log10(xmin), np.log10(xmax)
    return {'lin': np.linspace, 'log': np.logspace}[scale](xmin, xmax, nx)


def getDistFromList(xlist):
    if not isinstance(xlist, list):
        raise TypeError('Input must be a list')
    if len(xlist) != 4:
        raise ValueError('List must contain exactly 4 arguments ([type, min, max, n])')
    scale = xlist[0]
    if scale not in ('log', 'lin'):
        raise ValueError('Unknown distribution type (must be "lin" or "log")')
    xmin, xmax = [float(x) for x in xlist[1:-1]]
    if xmin >= xmax:
        raise ValueError('Specified minimum higher than (or equal to) specified maximum')
    nx = int(xlist[-1])
    if nx < 2:
        raise ValueError('Specified number must be at least 2')
    return getDistribution(xmin, xmax, nx, scale=scale)


def getIndex(container, value):
    ''' Return the index of a float / string value in a list / array.

        :param container: list / 1D-array of elements
        :param value: value to search for
        :return: index of value (if found)
    '''
    if isinstance(value, float):
        container = np.array(container)
        imatches = np.where(np.isclose(container, value, rtol=1e-9, atol=1e-16))[0]
        if len(imatches) == 0:
            raise ValueError(f'{value} not found in {container}')
        return imatches[0]
    elif isinstance(value, str):
        return container.index(value)


def funcSig(func, args, kwargs):
    args_repr = [repr(a) for a in args]
    kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
    return f'{func.__name__}({", ".join(args_repr + kwargs_repr)})'


def debug(func):
    ''' Print the function signature and return value. '''
    @wraps(func)
    def wrapper_debug(*args, **kwargs):
        print(f'Calling {funcSig(func, args, kwargs)}')
        value = func(*args, **kwargs)
        print(f"{func.__name__!r} returned {value!r}")
        return value
    return wrapper_debug


def timer(func):
    ''' Monitor and return the runtime of the decorated function. '''
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        value = func(*args, **kwargs)
        end_time = time.perf_counter()
        run_time = end_time - start_time
        return value, run_time
    return wrapper
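# Quick sketches of the distribution helpers above (indicative outputs):
# >>> getDistribution(1, 100, 3, scale='log')
# array([  1.,  10., 100.])
# >>> getDistFromList(['lin', '0', '1', '5'])
# array([0.  , 0.25, 0.5 , 0.75, 1.  ])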
def alignWithFuncDef(func, args, kwargs):
    ''' Align a set of provided positional and keyword arguments with the arguments
        signature in a specific function definition.

        :param func: function object
        :param args: list of provided positional arguments
        :param kwargs: dictionary of provided keyword arguments
        :return: 2-tuple with the modified positional and keyword arguments
    '''
    # Get positional and keyword arguments from function signature
    sig_params = {k: v for k, v in signature(func).parameters.items()}
    sig_args = list(filter(lambda x: x.default == x.empty, sig_params.values()))
    sig_kwargs = {k: v.default for k, v in sig_params.items() if v.default != v.empty}
    sig_nargs = len(sig_args)
    kwarg_keys = list(sig_kwargs.keys())

    # Restrain provided positional arguments to those that are also positional in signature
    new_args = args[:sig_nargs]

    # Construct hybrid keyword arguments dictionary from:
    # - remaining positional arguments
    # - provided keyword arguments
    # - default keyword arguments
    new_kwargs = sig_kwargs
    for i, x in enumerate(args[sig_nargs:]):
        new_kwargs[kwarg_keys[i]] = x
    for k, v in kwargs.items():
        new_kwargs[k] = v

    return new_args, new_kwargs


def alignWithMethodDef(method, args, kwargs):
    args, kwargs = alignWithFuncDef(method, [None] + list(args), kwargs)
    return tuple(args[1:]), kwargs


def logCache(fpath, delimiter='\t', out_type=float):
    ''' Add an extra IO memoization functionality to a function using file caching,
        to avoid repetitions of tedious computations with identical inputs. '''
    def wrapper_with_args(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # If function has history -> do not log
            if 'history' in kwargs:
                return func(*args, **kwargs)

            # Modify positional and keyword arguments to match function signature, if needed
            args, kwargs = alignWithFuncDef(func, args, kwargs)

            # Translate args and kwargs into string signature
            fsignature = funcSig(func, args, kwargs)

            # If entry present in log, return corresponding output
            if os.path.isfile(fpath):
                with open(fpath, 'r', newline='') as f:
                    reader = csv.reader(f, delimiter=delimiter)
                    for row in reader:
                        if row[0] == fsignature:
                            logger.debug(f'entry found in "{os.path.basename(fpath)}"')
                            return out_type(row[1])

            # Otherwise, compute output and log it into file before returning
            out = func(*args, **kwargs)
            lock = lockfile.FileLock(fpath)
            lock.acquire()
            with open(fpath, 'a', newline='') as csvfile:
                writer = csv.writer(csvfile, delimiter=delimiter)
                writer.writerow([fsignature, str(out)])
            lock.release()
            return out

        return wrapper
    return wrapper_with_args


def fileCache(root, fcode_func, ext='json'):
    def wrapper_with_args(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Get load and dump functions from file extension
            try:
                load_func = {
                    'json': json.load,
                    'pkl': pickle.load,
                    'csv': lambda f: np.loadtxt(f, delimiter=',')
                }[ext]
                dump_func = {
                    'json': json.dump,
                    'pkl': pickle.dump,
                    'csv': lambda x, f: np.savetxt(f, x, delimiter=',')
                }[ext]
            except KeyError:
                raise ValueError('Unknown file extension')

            # Get read and write mode (text or binary) from file extension
            mode = 'b' if ext == 'pkl' else ''

            # Get file path from root and function arguments, using fcode function
            if callable(fcode_func):
                fcode = fcode_func(*args)
            else:
                fcode = fcode_func
            fpath = os.path.join(os.path.abspath(root), f'{fcode}.{ext}')

            # If file exists, load output from it
            if os.path.isfile(fpath):
                logger.info(f'loading data from "{fpath}"')
                with open(fpath, 'r' + mode) as f:
                    out = load_func(f)

            # Otherwise, execute function and create the file to dump the output
            else:
                logger.warning(f'reference data file not found: "{fpath}"')
                out = func(*args, **kwargs)
                logger.info(f'dumping data in "{fpath}"')
                lock = lockfile.FileLock(fpath)
                lock.acquire()
                with open(fpath, 'w' + mode) as f:
                    dump_func(out, f)
                lock.release()
            return out

        return wrapper
    return wrapper_with_args
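# Hypothetical usage sketch of the two caching decorators above; the function
# and file names below are illustrative assumptions, not part of the package API:
#
# @logCache('sqrt_log.tsv')
# def slowSqrt(x):
#     return x**0.5  # memoized per unique call signature in a TSV log file
#
# @fileCache('cache_dir', lambda x: f'doubled_{x}', ext='json')
# def doubled(x):
#     return {'result': 2 * x}  # dumped to / reloaded from cache_dir/doubled_<x>.json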
def derivative(f, x, eps, method='central'):
    ''' Compute the difference formula for f'(x) with perturbation size eps.

        :param f: derivatives function, taking an array of states and
            returning an array of derivatives
        :param x: states vector
        :param eps: perturbation vector (same size as states vector)
        :param method: difference formula: 'forward', 'backward' or 'central'
        :return: numerical approximation of the derivative around the fixed point
    '''
    if isIterable(x):
        if not isIterable(eps) or len(eps) != len(x):
            raise ValueError('eps must be the same size as x')
        elif np.sum(eps != 0.) != 1:
            raise ValueError('eps must be zero-valued across all but one dimension')
        eps_val = np.sum(eps)
    else:
        eps_val = eps
    if method == 'central':
        df = (f(x + eps) - f(x - eps)) / 2
    elif method == 'forward':
        df = f(x + eps) - f(x)
    elif method == 'backward':
        df = f(x) - f(x - eps)
    else:
        raise ValueError("Method must be 'central', 'forward' or 'backward'.")
    return df / eps_val


def jacobian(dfunc, x, rel_eps=None, abs_eps=None, method='central'):
    ''' Evaluate the Jacobian matrix of a (time-invariant) system, given a states vector
        and derivatives function.

        :param dfunc: derivatives function, taking an array of n states and
            returning an array of n derivatives
        :param x: n-states vector
        :return: n-by-n square Jacobian matrix
    '''
    if sum(e is not None for e in [abs_eps, rel_eps]) != 1:
        raise ValueError('one (and only one) of "rel_eps" or "abs_eps" parameters must be provided')

    # Determine vector size
    x = np.asarray(x)
    n = x.size

    # Initialize Jacobian matrix
    J = np.empty((n, n))

    # Create epsilon vector
    if rel_eps is not None:
        mode = 'relative'
        eps_vec = rel_eps
    else:
        mode = 'absolute'
        eps_vec = abs_eps
    if not isIterable(eps_vec):
        eps_vec = np.array([eps_vec] * n)
    if mode == 'relative':
        eps = x * eps_vec
    else:
        eps = eps_vec

    # Perturb each state by epsilon on both sides, re-evaluate derivatives
    # and assemble Jacobian matrix
    ei = np.zeros(n)
    for i in range(n):
        ei[i] = 1
        J[:, i] = derivative(dfunc, x, eps * ei, method=method)
        ei[i] = 0
    return J


def classifyFixedPoint(x, dfunc):
    ''' Characterize the stability of a fixed point by numerically evaluating its
        Jacobian matrix and evaluating the sign of the real part of its associated
        eigenvalues.

        :param x: n-states vector
        :param dfunc: derivatives function, taking an array of n states and
            returning an array of n derivatives
    '''
    # Compute Jacobian numerically
    # print(f'x = {x}, dfunc(x) = {dfunc(x)}')
    eps_machine = np.sqrt(np.finfo(float).eps)
    J = jacobian(dfunc, x, rel_eps=eps_machine, method='forward')

    # Compute eigenvalues and eigenvectors
    eigvals, eigvecs = linalg.eig(J)
    logger.debug(f"eigenvalues = {[f'({x.real:.2e} + {x.imag:.2e}j)' for x in eigvals]}")

    # Determine fixed point stability based on eigenvalues
    is_neg_eigvals = eigvals.real < 0
    if is_neg_eigvals.all():    # all real parts negative -> stable
        key = 'stable'
    elif is_neg_eigvals.any():  # both positive and negative real parts -> saddle
        key = 'saddle'
    else:                       # all real parts positive -> unstable
        key = 'unstable'
    return eigvals, key
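# Minimal stability-analysis sketch: for the 1D system dx/dt = 1 - x, the fixed
# point x* = 1 has Jacobian [-1] and should be classified as stable:
# >>> eigvals, key = classifyFixedPoint(np.array([1.]), lambda x: 1. - x)
# >>> key
# 'stable'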
def findModifiedEq(x0, dfunc, *args):
    ''' Find an equilibrium variable in a modified system by searching for its
        derivative root within an interval around its original equilibrium.

        :param x0: equilibrium value in original system
        :param dfunc: derivative function, taking the variable as first parameter
        :param *args: remaining arguments needed for the derivative function
        :return: variable equilibrium value in modified system
    '''
    is_iterable = [isIterable(arg) for arg in args]
    if any(is_iterable):
        if not all(is_iterable):
            raise ValueError('mix of iterables and non-iterables')
        lengths = [len(arg) for arg in args]
        if not all(n == lengths[0] for n in lengths):
            raise ValueError(f'inputs are not of the same size: {lengths}')
        n = lengths[0]
        res = []
        for i in range(n):
            x = [arg[i] for arg in args]
            res.append(findModifiedEq(x0, dfunc, *x))
        return np.array(res)
    else:
        return brentq(lambda x: dfunc(x, *args), x0 * 1e-4, x0 * 1e3, xtol=1e-16)


def swapFirstLetterCase(s):
    # Swap the case of the first letter only, leaving the rest of the string intact
    if s[0].islower():
        return s[0].upper() + s[1:]
    else:
        return s[0].lower() + s[1:]


def getPow10(x, direction='up'):
    ''' Get the power of 10 that is closest to a number, in either direction ("down" or "up"). '''
    round_method = {'up': np.ceil, 'down': np.floor}[direction]
    return np.power(10, round_method(np.log10(x)))


def rotAroundPoint2D(x, theta, p):
    ''' Rotate a 2D vector around a center point by a given angle.

        :param x: 2D coordinates vector
        :param theta: rotation angle (in rad)
        :param p: 2D center point coordinates
        :return: 2D rotated coordinates vector
    '''
    n1, n2 = x.shape
    if n1 != 2:
        if n2 == 2:
            x = x.T
        else:
            raise ValueError('x should be a 2-by-n vector')

    # Rotation matrix
    R = np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta), np.cos(theta)],
    ])

    # Broadcast center point to input vector
    ptile = np.tile(p, (x.shape[1], 1)).T

    # Subtract, rotate and add
    return R.dot(x - ptile) + ptile


def getKey(keyfile='pushbullet.key'):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    package_root = os.path.abspath(os.path.join(dir_path, os.pardir))
    fpath = os.path.join(package_root, keyfile)
    if not os.path.isfile(fpath):
        raise FileNotFoundError('pushbullet API key file not found')
    with open(fpath) as f:
        encoded_key = f.readlines()[0]
    return base64.b64decode(str.encode(encoded_key)).decode()


def sendPushNotification(msg):
    try:
        key = getKey()
        pb = Pushbullet(key)
        dt = datetime.datetime.now()
        s = dt.strftime('%Y-%m-%d %H:%M:%S')
        pb.push_note('Code Messenger', f'{s}\n{msg}')
    except FileNotFoundError:
        logger.error(f'Could not send push notification: "{msg}"')


def alert(func):
    ''' Run a function, and send a push notification upon completion,
        or if an error is raised during its execution. '''
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            out = func(*args, **kwargs)
            sendPushNotification(f'completed "{func.__name__}" execution successfully')
            return out
        except BaseException as e:
            sendPushNotification(f'error during "{func.__name__}" execution: {e}')
            raise e
    return wrapper


def sunflower(n, radius=1, alpha=1):
    ''' Generate a population of uniformly distributed 2D data points in a unit circle.

        :param n: number of data points
        :param alpha: coefficient determining evenness of the boundary
        :return: 2D matrix of Cartesian (x, y) positions
    '''
    nbounds = np.round(alpha * np.sqrt(n))    # number of boundary points
    phi = (np.sqrt(5) + 1) / 2                # golden ratio
    k = np.arange(1, n + 1)                   # index vector
    theta = 2 * np.pi * k / phi**2            # angle vector
    r = np.sqrt((k - 1) / (n - nbounds - 1))  # radius vector
    r[r > 1] = 1
    x = r * np.cos(theta)
    y = r * np.sin(theta)
    return radius * np.vstack((x, y))
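# Geometric sanity check for rotAroundPoint2D: rotating (2, 1) by pi/2 around
# (1, 1) should map it onto (1, 2), up to floating-point error:
# >>> rotAroundPoint2D(np.array([[2.], [1.]]), np.pi / 2, np.array([1., 1.]))
# array([[1.],
#        [2.]])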
def filecode(model, *args):
    ''' Generate file code given a specific combination of model input parameters. '''
    # If meta dictionary was passed, generate inputs list from it
    if len(args) == 1 and isinstance(args[0], dict):
        meta = args[0].copy()
        if meta['simkey'] == 'ASTIM' and 'fs' not in meta:
            meta['fs'] = meta['model']['fs']
            meta['method'] = meta['model']['method']
            meta['qss_vars'] = None
        for k in ['simkey', 'model', 'tcomp', 'dt', 'atol']:
            if k in meta:
                del meta[k]
        args = list(meta.values())
    # Otherwise, transform args tuple into list
    else:
        args = list(args)

    # If any argument is an iterable -> transform it to a continuous string
    for i in range(len(args)):
        if isIterable(args[i]):
            args[i] = ''.join([str(x) for x in args[i]])

    # Create file code by joining string-encoded inputs with underscores
    codes = model.filecodes(*args).values()
    return '_'.join([x for x in codes if x is not None])


def simAndSave(model, *args, **kwargs):
    ''' Simulate the model and save the results in a specific output directory.

        :param *args: list of arguments provided to the simulation function
        :param **kwargs: optional arguments dictionary
        :return: output filepath
    '''
    # Extract output directory and overwrite boolean from keyword arguments
    outputdir = kwargs.pop('outputdir', '.')
    overwrite = kwargs.pop('overwrite', True)

    # Set data and meta to None
    data, meta = None, None

    # Extract drive object from args
    drive, *other_args = args

    # If drive is searchable and not fully resolved
    if drive.is_searchable:
        if not drive.is_resolved:
            # Call simulate to perform titration
            out = model.simulate(*args)

            # If titration yields nothing -> no file produced -> return None
            if out is None:
                logger.warning('returning None')
                return None

            # Store data and meta
            data, meta = out

            # Update args list with resolved drive
            try:
                args = (meta['drive'], *other_args)
            except KeyError:
                args = (meta['source'], *other_args)

    # Check if an output file corresponding to sim inputs is found in the output directory.
    # That check is performed prior to running the simulation, such that the simulation
    # is not run if the file is present and overwrite is set to False.
    fname = f'{model.filecode(*args)}.pkl'
    fpath = os.path.join(outputdir, fname)
    existing_file_msg = f'File "{fname}" already present in directory "{outputdir}"'
    existing_file = os.path.isfile(fpath)

    # If file exists and overwrite is set to False -> return
    if existing_file and not overwrite:
        logger.warning(f'{existing_file_msg} -> preserving')
        return fpath

    # Run simulation if not already done (for titration cases)
    if data is None:
        data, meta = model.simulate(*args)

    # Raise warning if an existing file is overwritten
    if existing_file:
        logger.warning(f'{existing_file_msg} -> overwriting')

    # Save output file and return output filepath
    with open(fpath, 'wb') as fh:
        pickle.dump({'meta': meta, 'data': data}, fh)
    logger.debug('simulation data exported to "%s"', fpath)
    return fpath
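# Hypothetical call sketch for simAndSave (the model, drive and protocol objects
# below are placeholders, not defined in this module):
#
# fpath = simAndSave(model, drive, pp, outputdir='outputs', overwrite=False)
# # -> runs (or titrates) the simulation, unless 'outputs/<filecode>.pkl' already exists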
def moveItem(l, value, itarget):
    ''' Move a list item to a specific target index.

        :param l: list object
        :param value: value of the item to move
        :param itarget: target index
        :return: re-ordered list
    '''
    # Get absolute target index
    if itarget < 0:
        itarget += len(l)
    assert itarget < len(l), f'target index {itarget} exceeds list size ({len(l)})'

    # Get index corresponding to element and delete entry from list
    iref = l.index(value)
    new_l = l.copy()
    del new_l[iref]

    # Return re-organized list
    return new_l[:itarget] + [value] + new_l[itarget:]


def gaussian(x, mu=0., sigma=1., A=1.):
    return A * np.exp(-((x - mu) / sigma)**2 / 2)


def isPickable(obj):
    try:
        pickle.dumps(obj)
    except Exception:
        return False
    return True


def resolveFuncArgs(func, *args, **kwargs):
    ''' Return a dictionary of positional and keyword arguments upon function call,
        adding defaults from the function signature if not provided at call time. '''
    bound_args = signature(func).bind(*args, **kwargs)
    bound_args.apply_defaults()
    return dict(bound_args.arguments)


def getMeta(model, simfunc, *args, **kwargs):
    ''' Construct an informative dictionary about the model and simulation parameters. '''
    # Retrieve function call arguments
    args_dict = resolveFuncArgs(simfunc, model, *args, **kwargs)
    # Construct meta dictionary
    meta = {'simkey': model.simkey}
    for k, v in args_dict.items():
        if k == 'self':
            meta['model'] = v.meta
        else:
            meta[k] = v
    return meta


def bounds(arr):
    ''' Return the bounds of a numpy array / list. '''
    return (np.nanmin(arr), np.nanmax(arr))


def addColumn(df, key, arr, preceding_key=None):
    ''' Add a new column to a dataframe, right after a specific column. '''
    df[key] = arr
    if preceding_key is not None:
        cols = df.columns.tolist()[:-1]
        preceding_index = cols.index(preceding_key)
        df = df[cols[:preceding_index + 1] + [key] + cols[preceding_index + 1:]]
    return df


def integerSuffix(n):
    return 'th' if 4 <= n % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')


def customStrftime(fmt, dt_obj):
    return dt_obj.strftime(fmt).replace('{S}', str(dt_obj.day) + integerSuffix(dt_obj.day))


def friendlyLogspace(xmin, xmax, bases=None):
    ''' Define a "friendly" logspace between two bounds. '''
    if bases is None:
        bases = [1, 2, 5]
    bases = np.asarray(bases)
    bounds = np.array([xmin, xmax])
    logbounds = np.log10(bounds)
    bounds_orders = np.floor(logbounds)
    orders = np.arange(bounds_orders[0], bounds_orders[1] + 1)
    factors = np.power(10., np.floor(orders))
    seq = np.hstack([bases * f for f in factors])
    if xmax > seq.max():
        seq = np.hstack((seq, xmax))
    seq = seq[np.logical_and(seq >= xmin, seq <= xmax)]
    if xmin not in seq:
        seq = np.hstack((xmin, seq))
    if xmax not in seq:
        seq = np.hstack((seq, xmax))
    return seq
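# Indicative examples for the list and logspace helpers above:
# >>> moveItem(['a', 'b', 'c', 'd'], 'd', 0)
# ['d', 'a', 'b', 'c']
# >>> friendlyLogspace(3, 40)
# array([ 3.,  5., 10., 20., 40.])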
def differing(d1, d2, subdkey=None, diff=None):
    ''' Find differences in values across two dictionaries (recursively).

        :param d1: first dictionary
        :param d2: second dictionary
        :param subdkey: specific sub-dictionary attribute key for objects
        :param diff: existing diff list to append to
        :return: list of (key, value1, value2) tuples for each differing value
    '''
    # Initialize diff list
    if diff is None:
        diff = []
    # Check that the two dicts have the same structure
    if sorted(list(d1.keys())) != sorted(list(d2.keys())):
        raise ValueError('inconsistent inputs')
    # For each key - values triplet
    for k in d1.keys():
        # If values are dicts themselves, loop recursively through them
        if isinstance(d1[k], dict):
            diff = differing(d1[k], d2[k], subdkey=subdkey, diff=diff)
        # If values are objects with a specific sub-dictionary attribute,
        # loop recursively through them
        elif hasattr(d1[k], subdkey):
            diff = differing(getattr(d1[k], subdkey), getattr(d2[k], subdkey),
                             subdkey=subdkey, diff=diff)
        # Otherwise, if values differ, add the key - values triplet to the diff list
        else:
            if d1[k] != d2[k]:
                diff.append((k, d1[k], d2[k]))
    # Return the diff list
    return diff


def extractCommonPrefix(labels):
    ''' Extract a common prefix and a list of suffixes from a list of labels. '''
    prefix = os.path.commonprefix(labels)
    if len(prefix) == 0:
        return None
    return prefix, [s.split(prefix)[1] for s in labels]


def cycleAvg(t, y, T):
    ''' Cycle-average a vector according to a given periodicity.

        :param t: time vector
        :param y: signal vector
        :param T: periodicity
        :return: cycle-averaged signal vector
    '''
    t -= t[0]
    n = int(np.ceil(t[-1] / T))
    return np.array([
        np.mean(y[np.where((t >= i * T) & (t < (i + 1) * T))[0]])
        for i in range(n)])
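# Sanity check for cycleAvg: averaging a unit-period sinusoid over 2 full cycles
# should yield near-zero values for both cycles:
# >>> t, T = np.linspace(0., 2., 201), 1.
# >>> np.allclose(cycleAvg(t, np.sin(2 * np.pi * t / T), T), 0.)
# True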
class TimeSeries(pd.DataFrame):
    ''' Wrapper around pandas DataFrame to store timeseries data. '''

    time_key = 't'
    stim_key = 'stimstate'

    def __init__(self, t, stim, dout):
        super().__init__(data={
            self.time_key: t,
            self.stim_key: stim,
            **dout
        })

    @property
    def time(self):
        return self[self.time_key].values

    @property
    def tbounds(self):
        return self.time.min(), self.time.max()

    @property
    def stim(self):
        return self[self.stim_key].values

    @property
    def inputs(self):
        return [self.time_key, self.stim_key]

    @property
    def outputs(self):
        return list(set(self.columns.values) - set(self.inputs))

    def interpCol(self, t, k, kind):
        ''' Interpolate a column according to a new time vector. '''
        self[k] = interp1d(self.time, self[k].values, kind=kind)(t)

    def interp1d(self, t):
        ''' Interpolate the entire dataframe according to a new time vector. '''
        for k in self.outputs:
            self.interpCol(t, k, 'linear')
        self.interpCol(t, self.stim_key, 'nearest')
        self[self.time_key] = t

    def resample(self, dt):
        ''' Resample dataframe at regular time step. '''
        tmin, tmax = self.tbounds
        n = int((tmax - tmin) / dt) + 1
        self.interp1d(np.linspace(tmin, tmax, n))

    def cycleAveraged(self, T):
        ''' Cycle-average a periodic solution. '''
        t = np.arange(self.time[0], self.time[-1], T)
        stim = interp1d(self.time, self.stim, kind='nearest')(t)
        outputs = {k: cycleAvg(self.time, self[k].values, T) for k in self.outputs}
        return self.__class__(t, stim, outputs)

    def prepend(self, t0=0):
        ''' Repeat first row outputs for a preceding time. '''
        if t0 > self.time.min():
            raise ValueError('t0 greater than minimal time value')
        self.loc[-1] = self.iloc[0]  # repeat first row
        self.index = self.index + 1  # shift index
        self.sort_index(inplace=True)
        self[self.time_key][0] = t0
        self[self.stim_key][0] = 0

    def bound(self, tbounds):
        ''' Restrict all columns of dataframe to indexes corresponding to
            time values within specific bounds.
        '''
        tmin, tmax = tbounds
        return self[np.logical_and(self.time >= tmin, self.time <= tmax)].reset_index(drop=True)

    def checkAgainst(self, other):
        assert isinstance(other, self.__class__), 'classes do not match'
        assert all(self.keys() == other.keys()), 'differing keys'
        for k in self.inputs:
            assert all(self[k].values == other[k].values), f'{k} vectors do not match'

    def operate(self, other, op):
        ''' Generic arithmetic operator. '''
        self.checkAgainst(other)
        return self.__class__(
            self.time, self.stim,
            {k: getattr(self[k].values, op)(other[k].values) for k in self.outputs})

    def __add__(self, other):
        ''' Addition operator. '''
        return self.operate(other, '__add__')

    def __sub__(self, other):
        ''' Subtraction operator. '''
        return self.operate(other, '__sub__')

    def __mul__(self, other):
        ''' Multiplication operator. '''
        return self.operate(other, '__mul__')

    def __truediv__(self, other):
        ''' Division operator. '''
        return self.operate(other, '__truediv__')


def pairwise(iterable):
    ''' s -> (s0,s1), (s1,s2), (s2, s3), ... '''
    a, b = itertools.tee(iterable)
    next(b, None)
    return list(zip(a, b))


def padleft(x):
    return np.pad(x, (1, 0), 'edge')


def padright(x):
    return np.pad(x, (0, 1), 'edge')


def timeThreshold(t, y, dy_thr):
    ''' Find the time interval required to reach a given threshold in a non-monotonic signal. '''
    y -= y[0]  # remove initial offset
    ifirst = np.where(y > dy_thr)[0][0]
    return np.interp(dy_thr, y[:ifirst + 1], t[:ifirst + 1])


def flatten(din):
    ''' Flatten a two-level dictionary. '''
    dout = {}
    for k, v in din.items():
        for k2, v2 in v.items():
            dout[f'{k} - {k2}'] = v2
    return dout
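# Illustrative TimeSeries round trip with toy data (keys and values are arbitrary):
# >>> ts = TimeSeries(np.linspace(0., 1e-3, 11), np.ones(11), {'Qm': np.zeros(11)})
# >>> ts.resample(5e-5)  # in-place resampling at a 50 us time step
# >>> flatten({'a': {'x': 1}, 'b': {'y': 2}})
# {'a - x': 1, 'b - y': 2}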