diff --git a/PySONIC/core/pneuron.py b/PySONIC/core/pneuron.py index 9d43a4f..7240863 100644 --- a/PySONIC/core/pneuron.py +++ b/PySONIC/core/pneuron.py @@ -1,624 +1,627 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-08-03 11:53:04 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-03-15 02:08:30 +# @Last Modified time: 2019-03-18 16:07:44 import os import time import pickle import abc import inspect import re import numpy as np from scipy.integrate import odeint import pandas as pd from ..postpro import findPeaks from ..constants import * from ..utils import si_format, logger, ESTIM_filecode from ..batches import xlslog class PointNeuron(metaclass=abc.ABCMeta): ''' Abstract class defining the common API (i.e. mandatory attributes and methods) of all subclasses implementing the channels mechanisms of specific point neurons. The mandatory attributes are: - **name**: a string defining the name of the mechanism. - **Cm0**: a float defining the membrane resting capacitance (in F/m2) - **Vm0**: a float defining the membrane resting potential (in mV) - **states**: a list of strings defining the names of the different state probabilities governing the channels behaviour (i.e. the differential HH variables). - **rates**: a list of strings defining the names of the different coefficients to be used in effective simulations. The mandatory methods are: - **iNet**: compute the net ionic current density (in mA/m2) across the membrane, given a specific membrane potential (in mV) and channel states. - **steadyStates**: compute the channels steady-state values for a specific membrane potential value (in mV). - **derStates**: compute the derivatives of channel states, given a specific membrane potential (in mV) and channel states. This method must return a list of derivatives ordered identically as in the steadyStates output. - **getEffRates**: get the effective rate constants of ion channels to be used in effective simulations. This method must return an array of effective rates ordered identically as in the rates attribute. - **derStatesEff**: compute the effective derivatives of channel states, based on 1-dimensional linear interpolators of "effective" coefficients. This method must return a list of derivatives ordered identically as in the steadyStates output. - **steadyStates**: compute the steady-state values of all internal states for a given membrane potential. ''' tscale = 'ms' # relevant temporal scale of the model defvar = 'V' # default plot variable def __repr__(self): return self.__class__.__name__ def pprint(self): return '{} neuron'.format(self.__class__.__name__) @property @abc.abstractmethod def name(self): return 'Should never reach here' @property @abc.abstractmethod def Cm0(self): return 'Should never reach here' @property @abc.abstractmethod def Vm0(self): return 'Should never reach here' @abc.abstractmethod def currents(self, Vm, states): ''' Compute all ionic currents per unit area. :param Vm: membrane potential (mV) :states: state probabilities of the ion channels :return: dictionary of ionic currents per unit area (mA/m2) ''' def iNet(self, Vm, states): ''' Net membrane current :param Vm: membrane potential (mV) :states: state probabilities of the ion channels :return: current per unit area (mA/m2) ''' return sum(self.currents(Vm, states).values()) def currentToConcentrationRate(self, z_ion, depth): ''' Compute the conversion factor from a specific ionic current (in mA/m2) into a variation rate of submembrane ion concentration (in M/s). :param: z_ion: ion valence :param depth: submembrane depth (m) :return: time derivative of submembrane ion concentration (M/s) ''' return 1e-6 / (z_ion * depth * FARADAY) def nernst(self, z_ion, Cion_in, Cion_out, T): ''' Nernst potential of a specific ion given its intra and extracellular concentrations. :param z_ion: ion valence :param Cion_in: intracellular ion concentration :param Cion_out: extracellular ion concentration :param T: temperature (K) :return: ion Nernst potential (mV) ''' return (Rg * T) / (z_ion * FARADAY) * np.log(Cion_out / Cion_in) * 1e3 def vtrap(self, x, y): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x / y) - 1) def efun(self, x): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x) - 1) def ghkDrive(self, Vm, Z_ion, Cion_in, Cion_out, T): ''' Use the Goldman-Hodgkin-Katz equation to compute the electrochemical driving force of a specific ion species for a given membrane potential. :param Vm: membrane potential (mV) :param Cin: intracellular ion concentration (M) :param Cout: extracellular ion concentration (M) :param T: temperature (K) :return: electrochemical driving force of a single ion particle (mC.m-3) ''' x = Z_ion * FARADAY * Vm / (Rg * T) * 1e-3 # [-] eCin = Cion_in * self.efun(-x) # M eCout = Cion_out * self.efun(x) # M return FARADAY * (eCin - eCout) * 1e6 # mC/m3 def getDesc(self): return inspect.getdoc(self).splitlines()[0] def getCurrentsNames(self): return list(self.currents(np.nan, [np.nan] * len(self.states)).keys()) def getPltScheme(self): - pltscheme = {'Q_m': ['Qm'], 'V_m': ['Vm']} + pltscheme = { + 'Q_m': ['Qm'], + 'V_m': ['Vm'] + } + pltscheme['I'] = self.getCurrentsNames() + ['iNet'] for cname in self.getCurrentsNames(): if 'Leak' not in cname: - key = 'I_{{{}}}\ kin.'.format(cname[1:]) + key = 'i_{{{}}}\ kin.'.format(cname[1:]) cargs = inspect.getargspec(getattr(self, cname))[0][1:] pltscheme[key] = [var for var in cargs if var not in ['Vm', 'Cai']] - pltscheme['I'] = self.getCurrentsNames() + ['iNet'] return pltscheme def getPltVars(self): ''' Return a dictionary with information about all plot variables related to the neuron. ''' pltvars = { 'Qm': { - 'desc': 'charge density', + 'desc': 'membrane charge density', 'label': 'Q_m', 'unit': 'nC/cm^2', 'factor': 1e5, 'bounds': (-100, 50) }, 'Vm': { 'desc': 'membrane potential', 'label': 'V_m', 'unit': 'mV', 'y0': self.Vm0 }, 'ELeak': { 'constant': 'obj.ELeak', 'desc': 'non-specific leakage current resting potential', 'label': 'V_{leak}', 'unit': 'mV', 'ls': '--', 'color': 'k' } } for cname in self.getCurrentsNames(): cfunc = getattr(self, cname) cargs = inspect.getargspec(cfunc)[0][1:] pltvars[cname] = { 'desc': inspect.getdoc(cfunc).splitlines()[0], 'label': 'I_{{{}}}'.format(cname[1:]), 'unit': 'A/m^2', 'factor': 1e-3, 'func': '{}({})'.format(cname, ', '.join(['df["{}"]'.format(a) for a in cargs])) } for var in cargs: if var not in ['Vm', 'Cai']: vfunc = getattr(self, 'der{}{}'.format(var[0].upper(), var[1:])) desc = cname + re.sub('^Evolution of', '', inspect.getdoc(vfunc).splitlines()[0]) pltvars[var] = { 'desc': desc, 'label': var, 'bounds': (-0.1, 1.1) } pltvars['iNet'] = { 'desc': inspect.getdoc(getattr(self, 'iNet')).splitlines()[0], 'label': 'I_{net}', 'unit': 'A/m^2', 'factor': 1e-3, 'func': 'iNet(df["Vm"], df[obj.states].values.T)', 'ls': '--', - 'color': 'k' + 'color': 'black' } for x in self.getGates(): for rate in ['alpha', 'beta']: pltvars['{}{}'.format(rate, x)] = { 'label': '\\{}_{{{}}}'.format(rate, x), 'unit': 'ms^{-1}', 'factor': 1e-3 } return pltvars @abc.abstractmethod def steadyStates(self, Vm): ''' Compute the channels steady-state values for a specific membrane potential value. :param Vm: membrane potential (mV) :return: array of steady-states ''' @abc.abstractmethod def derStates(self, Vm, states): ''' Compute the derivatives of channel states. :param Vm: membrane potential (mV) :states: state probabilities of the ion channels :return: current per unit area (mA/m2) ''' @abc.abstractmethod def getEffRates(self, Vm): ''' Get the effective rate constants of ion channels, averaged along an acoustic cycle, for future use in effective simulations. :param Vm: array of membrane potential values for an acoustic cycle (mV) :return: an array of rate average constants (s-1) ''' @abc.abstractmethod def derStatesEff(self, Qm, states, interp_data): ''' Compute the effective derivatives of channel states, based on 1-dimensional linear interpolation of "effective" coefficients that summarize the system's behaviour over an acoustic cycle. :param Qm: membrane charge density (C/m2) :states: state probabilities of the ion channels :param interp_data: dictionary of 1D vectors of "effective" coefficients over the charge domain, for specific frequency and amplitude values. ''' def Qbounds(self): ''' Determine bounds of membrane charge physiological range for a given neuron. ''' return np.array([np.round(self.Vm0 - 25.0), 50.0]) * self.Cm0 * 1e-3 # C/m2 def getGates(self): ''' Retrieve the names of the neuron's states that match an ion channel gating. ''' gates = [] for x in self.states: if 'alpha{}'.format(x.lower()) in self.rates: gates.append(x) return gates def getRates(self, Vm): ''' Compute the ion channels rate constants for a given membrane potential. :param Vm: membrane potential (mV) :return: a dictionary of rate constants and their values at the given potential. ''' rates = {} for x in self.getGates(): x = x.lower() alpha_str, beta_str = ['{}{}'.format(s, x.lower()) for s in ['alpha', 'beta']] inf_str, tau_str = ['{}inf'.format(x.lower()), 'tau{}'.format(x.lower())] if hasattr(self, 'alpha{}'.format(x)): alphax = getattr(self, alpha_str)(Vm) betax = getattr(self, beta_str)(Vm) elif hasattr(self, '{}inf'.format(x)): xinf = getattr(self, inf_str)(Vm) taux = getattr(self, tau_str)(Vm) alphax = xinf / taux betax = 1 / taux - alphax rates[alpha_str] = alphax rates[beta_str] = betax return rates def Vderivatives(self, y, t, Iinj): ''' Compute the derivatives of a V-cast HH system for a specific value of injected current. :param y: vector of HH system variables at time t :param t: time value (s, unused) :param Iinj: injected current (mA/m2) :return: vector of HH system derivatives at time t ''' Vm, *states = y Iionic = self.iNet(Vm, states) # mA/m2 dVmdt = (- Iionic + Iinj) / self.Cm0 # mV/s dstates = self.derStates(Vm, states) return [dVmdt, *dstates] def Qderivatives(self, y, t, Cm=None): ''' Compute the derivatives of the n-ODE HH system variables, based on a value of membrane capacitance. :param y: vector of HH system variables at time t :param t: specific instant in time (s) :param Cm: membrane capacitance (F/m2) :return: vector of HH system derivatives at time t ''' if Cm is None: Cm = self.Cm0 Qm, *states = y Vm = Qm / Cm * 1e3 # mV dQm = - self.iNet(Vm, states) * 1e-3 # A/m2 dstates = self.derStates(Vm, states) return [dQm, *dstates] def checkInputs(self, Astim, tstim, toffset, PRF, DC): ''' Check validity of electrical stimulation parameters. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) ''' # Check validity of stimulation parameters if not all(isinstance(param, float) for param in [Astim, tstim, toffset, DC]): raise TypeError('Invalid stimulation parameters (must be float typed)') if tstim <= 0: raise ValueError('Invalid stimulus duration: {} ms (must be strictly positive)' .format(tstim * 1e3)) if toffset < 0: raise ValueError('Invalid stimulus offset: {} ms (must be positive or null)' .format(toffset * 1e3)) if DC <= 0.0 or DC > 1.0: raise ValueError('Invalid duty cycle: {} (must be within ]0; 1])'.format(DC)) if DC < 1.0: if not isinstance(PRF, float): raise TypeError('Invalid PRF value (must be float typed)') if PRF is None: raise AttributeError('Missing PRF value (must be provided when DC < 1)') if PRF < 1 / tstim: raise ValueError('Invalid PRF: {} Hz (PR interval exceeds stimulus duration)' .format(PRF)) def simulate(self, Astim, tstim, toffset, PRF=None, DC=1.0): ''' Compute solutions of a neuron's HH system for a specific set of electrical stimulation parameters, using a classic integration scheme. :param Astim: pulse amplitude (mA/m2) :param tstim: pulse duration (s) :param toffset: offset duration (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :return: 3-tuple with the time profile and solution matrix and a state vector ''' # Check validity of stimulation parameters self.checkInputs(Astim, tstim, toffset, PRF, DC) # Determine system time step dt = DT_ESTIM # if CW stimulus: divide integration during stimulus into single interval if DC == 1.0: PRF = 1 / tstim # Compute vector sizes npulses = int(np.round(PRF * tstim)) Tpulse_on = DC / PRF Tpulse_off = (1 - DC) / PRF # For high-PRF pulsed protocols: adapt time step to ensure minimal # number of samples during TON or TOFF dt_warning_msg = 'high-PRF protocol: lowering time step to %.2e s to properly integrate %s' for key, Tpulse in {'TON': Tpulse_on, 'TOFF': Tpulse_off}.items(): if Tpulse > 0 and Tpulse / dt < MIN_SAMPLES_PER_PULSE_INT: dt = Tpulse / MIN_SAMPLES_PER_PULSE_INT logger.warning(dt_warning_msg, dt, key) n_pulse_on = int(np.round(Tpulse_on / dt)) n_pulse_off = int(np.round(Tpulse_off / dt)) # Compute offset size n_off = int(np.round(toffset / dt)) # Set initial conditions y0 = [self.Vm0, *self.steadyStates(self.Vm0)] nvar = len(y0) # Initialize global arrays t = np.array([0.]) states = np.array([1]) y = np.array([y0]).T # Initialize pulse time and states vectors t_pulse0 = np.linspace(0, Tpulse_on + Tpulse_off, n_pulse_on + n_pulse_off) states_pulse = np.concatenate((np.ones(n_pulse_on), np.zeros(n_pulse_off))) # Loop through all pulse (ON and OFF) intervals for i in range(npulses): # Construct and initialize arrays t_pulse = t_pulse0 + t[-1] y_pulse = np.empty((nvar, n_pulse_on + n_pulse_off)) # Integrate ON system y_pulse[:, :n_pulse_on] = odeint( self.Vderivatives, y[:, -1], t_pulse[:n_pulse_on], args=(Astim,)).T # Integrate OFF system if n_pulse_off > 0: y_pulse[:, n_pulse_on:] = odeint( self.Vderivatives, y_pulse[:, n_pulse_on - 1], t_pulse[n_pulse_on:], args=(0.0,)).T # Append pulse arrays to global arrays states = np.concatenate([states, states_pulse[1:]]) t = np.concatenate([t, t_pulse[1:]]) y = np.concatenate([y, y_pulse[:, 1:]], axis=1) # Integrate offset interval if n_off > 0: t_off = np.linspace(0, toffset, n_off) + t[-1] states_off = np.zeros(n_off) y_off = odeint(self.Vderivatives, y[:, -1], t_off, args=(0.0, )).T # Concatenate offset arrays to global arrays states = np.concatenate([states, states_off[1:]]) t = np.concatenate([t, t_off[1:]]) y = np.concatenate([y, y_off[:, 1:]], axis=1) # Return output variables return (t, y, states) def titrate(self, tstim, toffset, PRF=None, DC=1.0, Arange=(0., 2 * TITRATION_ESTIM_A_MAX)): ''' Use a dichotomic recursive search to determine the threshold amplitude needed to obtain neural excitation for a given duration, PRF and duty cycle. :param tstim: duration of US stimulation (s) :param toffset: duration of the offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: pulse duty cycle (-) :param Arange: search interval for Astim, iteratively refined :return: 5-tuple with the determined threshold, time profile, solution matrix, state vector and response latency ''' Astim = (Arange[0] + Arange[1]) / 2 # Run simulation and detect spikes t0 = time.time() (t, y, states) = self.simulate(Astim, tstim, toffset, PRF, DC) tcomp = time.time() - t0 dt = t[1] - t[0] ipeaks, *_ = findPeaks(y[0, :], SPIKE_MIN_VAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_VPROM) nspikes = ipeaks.size latency = t[ipeaks[0]] if nspikes > 0 else None logger.debug('A = %sA/m2 ---> %s spike%s detected', si_format(Astim * 1e-3, 2, space=' '), nspikes, "s" if nspikes > 1 else "") # If accurate threshold is found, return simulation results if (Arange[1] - Arange[0]) <= TITRATION_ESTIM_DA_MAX and nspikes == 1: return (Astim, t, y, states, latency, tcomp) # Otherwise, refine titration interval and iterate recursively else: if nspikes == 0: # if Astim too close to max then stop if (TITRATION_ESTIM_A_MAX - Astim) <= TITRATION_ESTIM_DA_MAX: return (np.nan, t, y, states, latency, tcomp) Arange = (Astim, Arange[1]) else: Arange = (Arange[0], Astim) return self.titrate(tstim, toffset, PRF, DC, Arange=Arange) def runAndSave(self, outdir, tstim, toffset, PRF=None, DC=1.0, Astim=None): ''' Run a simulation of the point-neuron Hodgkin-Huxley system with specific parameters, and save the results in a PKL file. :param outdir: full path to output directory :param tstim: stimulus duration (s) :param toffset: stimulus offset (s) :param PRF: pulse repetition frequency (Hz) :param DC: stimulus duty cycle (-) :param Astim: stimulus amplitude (mA/m2) ''' # Get date and time info date_str = time.strftime("%Y.%m.%d") daytime_str = time.strftime("%H:%M:%S") if Astim is not None: logger.info('%s: simulation @ A = %sA/m2, t = %ss (%ss offset)%s', self, si_format(Astim * 1e-3, 2, space=' '), *si_format([tstim, toffset], 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format(si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) # Run simulation tstart = time.time() t, y, states = self.simulate(Astim, tstim, toffset, PRF, DC) Vm, *channels = y tcomp = time.time() - tstart # Detect spikes on Vm signal dt = t[1] - t[0] ipeaks, *_ = findPeaks(Vm, SPIKE_MIN_VAMP, int(np.ceil(SPIKE_MIN_DT / dt)), SPIKE_MIN_VPROM) nspikes = ipeaks.size lat = t[ipeaks[0]] if nspikes > 0 else 'N/A' outstr = '{} spike{} detected'.format(nspikes, 's' if nspikes > 1 else '') else: logger.info('%s: titration @ t = %ss%s', self, si_format(tstim, 1, space=' '), (', PRF = {}Hz, DC = {:.2f}%'.format(si_format(PRF, 2, space=' '), DC * 1e2) if DC < 1.0 else '')) # Run titration Astim, t, y, states, lat, tcomp = self.titrate(tstim, toffset, PRF, DC) Vm, *channels = y nspikes = 1 if Astim is np.nan: outstr = 'no spikes detected within titration interval' nspikes = 0 else: nspikes = 1 outstr = 'Athr = {}A/m2'.format(si_format(Astim * 1e-3, 2, space=' ')) logger.debug('completed in %s, %s', si_format(tcomp, 1), outstr) sr = np.mean(1 / np.diff(t[ipeaks])) if nspikes > 1 else None # Store dataframe and metadata df = pd.DataFrame({ 't': t, 'states': states, 'Vm': Vm, 'Qm': Vm * self.Cm0 * 1e-3 }) for j in range(len(self.states)): df[self.states[j]] = channels[j] meta = { 'neuron': self.name, 'Astim': Astim, 'tstim': tstim, 'toffset': toffset, 'PRF': PRF, 'DC': DC, 'tcomp': tcomp } # Export into to PKL file simcode = ESTIM_filecode(self.name, Astim, tstim, PRF, DC) outpath = '{}/{}.pkl'.format(outdir, simcode) with open(outpath, 'wb') as fh: pickle.dump({'meta': meta, 'data': df}, fh) logger.debug('simulation data exported to "%s"', outpath) # Export key metrics to log file logpath = os.path.join(outdir, 'log_ESTIM.xlsx') logentry = { 'Date': date_str, 'Time': daytime_str, 'Neuron Type': self.name, 'Astim (mA/m2)': Astim, 'Tstim (ms)': tstim * 1e3, 'PRF (kHz)': PRF * 1e-3 if DC < 1 else 'N/A', 'Duty factor': DC, '# samples': t.size, 'Comp. time (s)': round(tcomp, 2), '# spikes': nspikes, 'Latency (ms)': lat * 1e3 if isinstance(lat, float) else 'N/A', 'Spike rate (sp/ms)': sr * 1e-3 if isinstance(sr, float) else 'N/A' } if xlslog(logpath, logentry) == 1: logger.debug('log exported to "%s"', logpath) else: logger.error('log export to "%s" aborted', self.logpath) return outpath def findRheobaseAmps(self, DCs, Vthr, curr='net'): ''' Find the rheobase amplitudes (i.e. threshold amplitudes of infinite duration that would result in excitation) of a specific neuron for various stimulation duty cycles. :param DCs: duty cycles vector (-) :param Vthr: threshold membrane potential above which the neuron necessarily fires (mV) :return: rheobase amplitudes vector (mA/m2) ''' # Compute the pulse average net (or leakage) current along the amplitude space if curr == 'net': iNet = self.iNet(Vthr, self.steadyStates(Vthr)) elif curr == 'leak': iNet = self.iLeak(Vthr) # Compute rheobase amplitudes return iNet / np.array(DCs) diff --git a/PySONIC/plt/batch.py b/PySONIC/plt/batch.py index 6bbbd7b..a224005 100644 --- a/PySONIC/plt/batch.py +++ b/PySONIC/plt/batch.py @@ -1,149 +1,148 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2018-09-25 16:19:19 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-03-15 01:16:31 +# @Last Modified time: 2019-03-18 14:18:02 import numpy as np import matplotlib.pyplot as plt from ..utils import * from .pltutils import * def plotBatch(filepaths, pltscheme=None, plt_save=False, directory=None, ask_before_save=True, fig_ext='png', tag='fig', fs=10, lw=2, title=True, show_patches=True, frequency=1): ''' Plot a figure with profiles of several specific NICE output variables, for several NICE simulations. :param filepaths: list of full paths to output data files to be compared :param pltscheme: dict of lists of variables names to extract and plot together :param plt_save: boolean stating whether to save the created figures :param directory: directory where to save figures :param ask_before_save: boolean stating whether to show the created figures :param fig_ext: file extension for the saved figures :param tag: suffix added to the end of the figures name :param fs: labels font size :param lw: curves line width :param title: boolean stating whether to display a general title on the figures :param show_patches: boolean indicating whether to indicate periods of stimulation with colored rectangular patches :param frequency: downsampling frequency for time series :return: list of figure handles ''' figs = [] # Loop through data files for filepath in filepaths: # Retrieve file code and sim type from file name pkl_filename = os.path.basename(filepath) filecode = pkl_filename[0:-4] sim_type = getSimType(pkl_filename) # Load data and extract variables df, meta = loadData(filepath, frequency) t = df['t'].values states = df['states'].values # Determine stimulus patch from states _, tpatch_on, tpatch_off = getStimPulses(t, states) # Initialize appropriate object obj = getObject(sim_type, meta) # Retrieve plot variables tvar, pltvars = getTimePltVar(obj.tscale), obj.getPltVars() # Check plot scheme if provided, otherwise generate it if pltscheme: for key in list(sum(list(pltscheme.values()), [])): if key not in pltvars: raise KeyError('Unknown plot variable: "{}"'.format(key)) else: pltscheme = obj.getPltScheme() # Preset and rescale time vector if tvar['onset'] > 0.0: tonset = np.array([-tvar['onset'], -t[0] - t[1]]) t = np.hstack((tonset, t)) - states = np.hstack((states, np.zeros(2))) t *= tvar['factor'] # Create figure naxes = len(pltscheme) if naxes == 1: fig, ax = plt.subplots(figsize=(11, 4)) axes = [ax] else: fig, axes = plt.subplots(naxes, 1, figsize=(11, min(3 * naxes, 9))) # Loop through each subgraph for ax, (grouplabel, keys) in zip(axes, pltscheme.items()): # Extract variables to plot nvars = len(keys) ax_pltvars = [pltvars[k] for k in keys] if nvars == 1: ax_pltvars[0]['color'] = 'k' ax_pltvars[0]['ls'] = '-' # Set y-axis unit and bounds ax.set_ylabel('$\\rm {}\ ({})$'.format(grouplabel, ax_pltvars[0].get('unit', '')), fontsize=fs) if 'bounds' in ax_pltvars[0]: ax_min = min([ap['bounds'][0] for ap in ax_pltvars]) ax_max = max([ap['bounds'][1] for ap in ax_pltvars]) ax.set_ylim(ax_min, ax_max) # Plot time series icolor = 0 for pltvar, name in zip(ax_pltvars, pltscheme[grouplabel]): var = extractPltVar(obj, pltvar, df, meta, t.size, name) ax.plot(t, var, pltvar.get('ls', '-'), c=pltvar.get('color', 'C{}'.format(icolor)), lw=lw, label='$\\rm {}$'.format(pltvar['label'])) if 'color' not in pltvar: icolor += 1 # Add legend if nvars > 1 or 'gate' in ax_pltvars[0]['desc']: ax.legend(fontsize=fs, loc=7, ncol=nvars // 4 + 1, frameon=False) # Post-process figure for ax in axes: for item in ['top', 'right']: ax.spines[item].set_visible(False) ax.locator_params(axis='y', nbins=2) for item in ax.get_yticklabels(): item.set_fontsize(fs) for ax in axes[:-1]: ax.set_xticklabels([]) for item in axes[-1].get_xticklabels(): item.set_fontsize(fs) axes[-1].set_xlabel('$\\rm {}\ ({})$'.format(tvar['label'], tvar['unit']), fontsize=fs) if show_patches == 1: for ax in axes: plotStimPatches(ax, tpatch_on, tpatch_off, tvar['factor']) if title: axes[0].set_title(figtitle(meta), fontsize=fs) fig.tight_layout() # Save figure if needed (automatic or checked) if plt_save: if directory is None: directory = os.path.split(filepath)[0] if ask_before_save: plt_filename = SaveFileDialog( '{}_{}.{}'.format(filecode, tag, fig_ext), dirname=directory, ext=fig_ext) else: plt_filename = '{}/{}_{}.{}'.format(directory, filecode, tag, fig_ext) if plt_filename: plt.savefig(plt_filename) logger.info('Saving figure as "{}"'.format(plt_filename)) plt.close() figs.append(fig) return figs diff --git a/PySONIC/plt/pltutils.py b/PySONIC/plt/pltutils.py index 3f86e06..c399b70 100644 --- a/PySONIC/plt/pltutils.py +++ b/PySONIC/plt/pltutils.py @@ -1,118 +1,119 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2017-08-21 14:33:36 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-03-15 01:32:28 +# @Last Modified time: 2019-03-18 14:13:18 ''' Useful functions to generate plots. ''' import re import numpy as np import matplotlib from ..core import BilayerSonophore, NeuronalBilayerSonophore from ..neurons import getNeuronsDict # Matplotlib parameters matplotlib.rcParams['pdf.fonttype'] = 42 matplotlib.rcParams['ps.fonttype'] = 42 matplotlib.rcParams['font.family'] = 'arial' rgxp = re.compile('(ESTIM|ASTIM)_([A-Za-z]*)_(.*).pkl') rgxp_mech = re.compile('(MECH)_(.*).pkl') def cm2inch(*tupl): inch = 2.54 if isinstance(tupl[0], tuple): return tuple(i / inch for i in tupl[0]) else: return tuple(i / inch for i in tupl) def getTimePltVar(tscale): ''' Return time plot variable for a given temporal scale. ''' return { 'desc': 'time', 'label': 'time', 'unit': tscale, 'factor': {'ms': 1e3, 'us': 1e6}[tscale], 'onset': {'ms': 1e-3, 'us': 1e-6}[tscale] } def getSimType(fname): ''' Get sim type from filename. ''' for exp in [rgxp, rgxp_mech]: mo = exp.fullmatch(fname) if mo: sim_type = mo.group(1) if sim_type not in ('MECH', 'ASTIM', 'ESTIM'): raise ValueError('Invalid simulation type: {}'.format(sim_type)) return sim_type raise ValueError('Error: "{}" file does not match regexp pattern'.format(fname)) def getObject(sim_type, meta): if sim_type == 'MECH': obj = BilayerSonophore(meta['a'], meta['Cm0'], meta['Qm0']) else: obj = getNeuronsDict()[meta['neuron']]() if sim_type == 'ASTIM': obj = NeuronalBilayerSonophore(meta['a'], obj, meta['Fdrive']) return obj def getStimPulses(t, states): ''' Determine the onset and offset times of pulses from a stimulation vector. :param t: time vector (s). :param states: a vector of stimulation state (ON/OFF) at each instant in time. :return: 3-tuple with number of patches, timing of STIM-ON an STIM-OFF instants. ''' # Compute states derivatives and identify bounds indexes of pulses dstates = np.diff(states) ipulse_on = np.insert(np.where(dstates > 0.0)[0] + 1, 0, 0) ipulse_off = np.where(dstates < 0.0)[0] + 1 if ipulse_off.size < ipulse_on.size: ioff = t.size - 1 if ipulse_off.size == 0: ipulse_off = np.array([ioff]) else: ipulse_off = np.insert(ipulse_off, ipulse_off.size - 1, ioff) # Get time instants for pulses ON and OFF npulses = ipulse_on.size tpulse_on = t[ipulse_on] tpulse_off = t[ipulse_off] # return 3-tuple with #pulses, pulse ON and pulse OFF instants return npulses, tpulse_on, tpulse_off def plotStimPatches(ax, tpatch_on, tpatch_off, tfactor): for j in range(tpatch_on.size): ax.axvspan(tpatch_on[j] * tfactor, tpatch_off[j] * tfactor, edgecolor='none', facecolor='#8A8A8A', alpha=0.2) def extractPltVar(obj, pltvar, df, meta, nsamples, name): if 'func' in pltvar: s = 'obj.{}'.format(pltvar['func']) try: var = eval(s) except AttributeError: var = eval(s.replace('obj', 'obj.neuron')) elif 'key' in pltvar: var = df[pltvar['key']] elif 'constant' in pltvar: var = eval(pltvar['constant']) * np.ones(nsamples) else: var = df[name] + var = var.values.copy() if var.size == nsamples - 2: var = np.hstack((np.array([pltvar.get('y0', var[0])] * 2), var)) var *= pltvar.get('factor', 1) return var diff --git a/PySONIC/utils.py b/PySONIC/utils.py index b3c940f..ff9f0de 100644 --- a/PySONIC/utils.py +++ b/PySONIC/utils.py @@ -1,507 +1,525 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Date: 2016-09-19 22:30:46 # @Email: theo.lemaire@epfl.ch # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-03-15 00:29:15 +# @Last Modified time: 2019-03-18 16:50:08 ''' Definition of generic utility functions used in other modules ''' import operator import os import math import pickle import tkinter as tk from tkinter import filedialog import numpy as np import colorlog from scipy.interpolate import interp1d # Package logger def setLogger(): log_formatter = colorlog.ColoredFormatter( '%(log_color)s %(asctime)s %(message)s', datefmt='%d/%m/%Y %H:%M:%S:', reset=True, log_colors={ 'DEBUG': 'green', 'INFO': 'white', 'WARNING': 'yellow', 'ERROR': 'red', 'CRITICAL': 'red,bg_white', }, style='%' ) log_handler = colorlog.StreamHandler() log_handler.setFormatter(log_formatter) color_logger = colorlog.getLogger('PySONIC') color_logger.addHandler(log_handler) return color_logger logger = setLogger() # File naming conventions def ESTIM_filecode(neuron, Astim, tstim, PRF, DC): return 'ESTIM_{}_{}_{:.1f}mA_per_m2_{:.0f}ms{}'.format( neuron, 'CW' if DC == 1 else 'PW', Astim, tstim * 1e3, '_PRF{:.2f}Hz_DC{:.2f}%'.format(PRF, DC * 1e2) if DC < 1. else '') def ASTIM_filecode(neuron, a, Fdrive, Adrive, tstim, PRF, DC, method): return 'ASTIM_{}_{}_{:.0f}nm_{:.0f}kHz_{:.2f}kPa_{:.0f}ms_{}{}'.format( neuron, 'CW' if DC == 1 else 'PW', a * 1e9, Fdrive * 1e-3, Adrive * 1e-3, tstim * 1e3, 'PRF{:.2f}Hz_DC{:.2f}%_'.format(PRF, DC * 1e2) if DC < 1. else '', method) def MECH_filecode(a, Fdrive, Adrive, Qm): return 'MECH_{:.0f}nm_{:.0f}kHz_{:.1f}kPa_{:.1f}nCcm2'.format( a * 1e9, Fdrive * 1e-3, Adrive * 1e-3, Qm * 1e5) # Figure naming conventions def figtitle(meta): ''' Return appropriate title based on simulation metadata. ''' if 'Cm0' in meta: return '{:.0f}nm radius BLS structure: MECH-STIM {:.0f}kHz, {:.2f}kPa, {:.1f}nC/cm2'.format( meta['a'] * 1e9, meta['Fdrive'] * 1e-3, meta['Adrive'] * 1e-3, meta['Qm'] * 1e5) else: if meta['DC'] < 1: wavetype = 'PW' suffix = ', {:.2f}Hz PRF, {:.0f}% DC'.format(meta['PRF'], meta['DC'] * 1e2) else: wavetype = 'CW' suffix = '' if 'Astim' in meta: return '{} neuron: {} E-STIM {:.2f}mA/m2, {:.0f}ms{}'.format( meta['neuron'], wavetype, meta['Astim'], meta['tstim'] * 1e3, suffix) else: return '{} neuron ({:.1f}nm): {} A-STIM {:.0f}kHz {:.2f}kPa, {:.0f}ms{} - {} model'.format( meta['neuron'], meta['a'] * 1e9, wavetype, meta['Fdrive'] * 1e-3, meta['Adrive'] * 1e-3, meta['tstim'] * 1e3, suffix, meta['method']) # SI units prefixes si_prefixes = { 'y': 1e-24, # yocto 'z': 1e-21, # zepto 'a': 1e-18, # atto 'f': 1e-15, # femto 'p': 1e-12, # pico 'n': 1e-9, # nano 'u': 1e-6, # micro 'm': 1e-3, # mili '': 1e0, # None 'k': 1e3, # kilo 'M': 1e6, # mega 'G': 1e9, # giga 'T': 1e12, # tera 'P': 1e15, # peta 'E': 1e18, # exa 'Z': 1e21, # zetta 'Y': 1e24, # yotta } def loadData(fpath, frequency=1): ''' Load dataframe and metadata dictionary from pickle file. ''' logger.info('Loading data from "%s"', os.path.basename(fpath)) with open(fpath, 'rb') as fh: frame = pickle.load(fh) df = frame['data'].iloc[::frequency] meta = frame['meta'] return df, meta def si_format(x, precision=0, space=' '): ''' Format a float according to the SI unit system, with the appropriate prefix letter. ''' if isinstance(x, float) or isinstance(x, int) or isinstance(x, np.float) or\ isinstance(x, np.int32) or isinstance(x, np.int64): if x == 0: factor = 1e0 prefix = '' else: sorted_si_prefixes = sorted(si_prefixes.items(), key=operator.itemgetter(1)) vals = [tmp[1] for tmp in sorted_si_prefixes] # vals = list(si_prefixes.values()) ix = np.searchsorted(vals, np.abs(x)) - 1 if np.abs(x) == vals[ix + 1]: ix += 1 factor = vals[ix] prefix = sorted_si_prefixes[ix][0] # prefix = list(si_prefixes.keys())[ix] return '{{:.{}f}}{}{}'.format(precision, space, prefix).format(x / factor) elif isinstance(x, list) or isinstance(x, tuple): return [si_format(item, precision, space) for item in x] elif isinstance(x, np.ndarray) and x.ndim == 1: return [si_format(float(item), precision, space) for item in x] else: print(type(x)) def pow10_format(number, precision=2): ''' Format a number in power of 10 notation. ''' ret_string = '{0:.{1:d}e}'.format(number, precision) a, b = ret_string.split("e") a = float(a) b = int(b) return '{}10^{{{}}}'.format('{} * '.format(a) if a != 1. else '', b) def rmse(x1, x2): ''' Compute the root mean square error between two 1D arrays ''' return np.sqrt(((x1 - x2) ** 2).mean()) def rsquared(x1, x2): ''' compute the R-squared coefficient between two 1D arrays ''' residuals = x1 - x2 ss_res = np.sum(residuals**2) ss_tot = np.sum((x1 - np.mean(x1))**2) return 1 - (ss_res / ss_tot) def Pressure2Intensity(p, rho=1075.0, c=1515.0): ''' Return the spatial peak, pulse average acoustic intensity (ISPPA) associated with the specified pressure amplitude. :param p: pressure amplitude (Pa) :param rho: medium density (kg/m3) :param c: speed of sound in medium (m/s) :return: spatial peak, pulse average acoustic intensity (W/m2) ''' return p**2 / (2 * rho * c) def Intensity2Pressure(I, rho=1075.0, c=1515.0): ''' Return the pressure amplitude associated with the specified spatial peak, pulse average acoustic intensity (ISPPA). :param I: spatial peak, pulse average acoustic intensity (W/m2) :param rho: medium density (kg/m3) :param c: speed of sound in medium (m/s) :return: pressure amplitude (Pa) ''' return np.sqrt(2 * rho * c * I) def OpenFilesDialog(filetype, dirname=''): ''' Open a FileOpenDialogBox to select one or multiple file. The default directory and file type are given. :param dirname: default directory :param filetype: default file type :return: tuple of full paths to the chosen filenames ''' root = tk.Tk() root.withdraw() filenames = filedialog.askopenfilenames(filetypes=[(filetype + " files", '.' + filetype)], initialdir=dirname) if filenames: par_dir = os.path.abspath(os.path.join(filenames[0], os.pardir)) else: par_dir = None return (filenames, par_dir) def selectDirDialog(): ''' Open a dialog box to select a directory. :return: full path to selected directory ''' root = tk.Tk() root.withdraw() return filedialog.askdirectory() def SaveFileDialog(filename, dirname=None, ext=None): ''' Open a dialog box to save file. :param filename: filename :param dirname: initial directory :param ext: default extension :return: full path to the chosen filename ''' root = tk.Tk() root.withdraw() filename_out = filedialog.asksaveasfilename( defaultextension=ext, initialdir=dirname, initialfile=filename) return filename_out def downsample(t_dense, y, nsparse): ''' Decimate periodic signals to a specified number of samples.''' if(y.ndim) > 1: nsignals = y.shape[0] else: nsignals = 1 y = np.array([y]) # determine time step and period of input signal T = t_dense[-1] - t_dense[0] dt_dense = t_dense[1] - t_dense[0] # resample time vector linearly t_ds = np.linspace(t_dense[0], t_dense[-1], nsparse) # create MAV window nmav = int(0.03 * T / dt_dense) if nmav % 2 == 0: nmav += 1 mav = np.ones(nmav) / nmav # determine signals padding npad = int((nmav - 1) / 2) # determine indexes of sampling on convolved signals ids = np.round(np.linspace(0, t_dense.size - 1, nsparse)).astype(int) y_ds = np.empty((nsignals, nsparse)) # loop through signals for i in range(nsignals): # pad, convolve and resample pad_left = y[i, -(npad + 2):-2] pad_right = y[i, 1:npad + 1] y_ext = np.concatenate((pad_left, y[i, :], pad_right), axis=0) y_mav = np.convolve(y_ext, mav, mode='valid') y_ds[i, :] = y_mav[ids] if nsignals == 1: y_ds = y_ds[0, :] return (t_ds, y_ds) def rescale(x, lb=None, ub=None, lb_new=0, ub_new=1): ''' Rescale a value to a specific interval by linear transformation. ''' if lb is None: lb = x.min() if ub is None: ub = x.max() xnorm = (x - lb) / (ub - lb) return xnorm * (ub_new - lb_new) + lb_new def getNeuronLookupsFile(mechname, a=None, Fdrive=None, Adrive=None, fs=False): fpath = os.path.join( os.path.split(__file__)[0], 'neurons', '{}_lookups'.format(mechname) ) if a is not None: fpath += '_{:.0f}nm'.format(a * 1e9) if Fdrive is not None: fpath += '_{:.0f}kHz'.format(Fdrive * 1e-3) if Adrive is not None: fpath += '_{:.0f}kPa'.format(Adrive * 1e-3) if fs is True: fpath += '_fs' return '{}.pkl'.format(fpath) def getLookups4D(mechname): ''' Retrieve 4D lookup tables and reference vectors for a given membrane mechanism. :param mechname: name of membrane density mechanism :return: 4-tuple with 1D numpy arrays of reference input vectors (charge density and one other variable), a dictionary of associated 2D lookup numpy arrays, and a dictionnary with information about the other variable. ''' # Check lookup file existence lookup_path = getNeuronLookupsFile(mechname) if not os.path.isfile(lookup_path): raise FileNotFoundError('Missing lookup file: "{}"'.format(lookup_path)) # Load lookups dictionary logger.debug('Loading lookup table') with open(lookup_path, 'rb') as fh: df = pickle.load(fh) inputs = df['input'] lookups4D = df['lookup'] # Retrieve 1D inputs from lookups dictionary aref = inputs['a'] Fref = inputs['f'] Aref = inputs['A'] Qref = inputs['Q'] return aref, Fref, Aref, Qref, lookups4D def getLookupsOff(mechname): ''' Retrieve appropriate US-OFF lookup tables and reference vectors for a given membrane mechanism. :param mechname: name of membrane density mechanism :return: 2-tuple with 1D numpy array of reference charge density and dictionary of associated 1D lookup numpy arrays. ''' # Get 4D lookups and input vectors aref, Fref, Aref, Qref, lookups4D = getLookups4D(mechname) # Perform 2D projection in appropriate dimensions logger.debug('Interpolating lookups at A = 0') lookups_off = {key: y4D[0, 0, 0, :] for key, y4D in lookups4D.items()} return Qref, lookups_off def getLookups2D(mechname, a=None, Fdrive=None, Adrive=None): ''' Retrieve appropriate 2D lookup tables and reference vectors for a given membrane mechanism, projected at a specific combination of sonophore radius, US frequency and/or acoustic pressure amplitude. :param mechname: name of membrane density mechanism :param a: sonophore radius (m) :param Fdrive: US frequency (Hz) :param Adrive: Acoustic peak pressure ampplitude (Hz) :return: 4-tuple with 1D numpy arrays of reference input vectors (charge density and one other variable), a dictionary of associated 2D lookup numpy arrays, and a dictionnary with information about the other variable. ''' # Get 4D lookups and input vectors aref, Fref, Aref, Qref, lookups4D = getLookups4D(mechname) # Check that inputs are within lookup range if a is not None: a = isWithin('radius', a, (aref.min(), aref.max())) if Fdrive is not None: Fdrive = isWithin('frequency', Fdrive, (Fref.min(), Fref.max())) if Adrive is not None: Adrive = isWithin('amplitude', Adrive, (Aref.min(), Aref.max())) # Determine projection dimensions based on inputs var_a = {'name': 'a', 'label': 'sonophore radius', 'val': a, 'unit': 'm', 'factor': 1e9, 'ref': aref, 'axis': 0} var_Fdrive = {'name': 'f', 'label': 'frequency', 'val': Fdrive, 'unit': 'Hz', 'factor': 1e-3, 'ref': Fref, 'axis': 1} var_Adrive = {'name': 'A', 'label': 'amplitude', 'val': Adrive, 'unit': 'Pa', 'factor': 1e-3, 'ref': Aref, 'axis': 2} if not isinstance(Adrive, float): var1 = var_a var2 = var_Fdrive var3 = var_Adrive elif not isinstance(Fdrive, float): var1 = var_a var2 = var_Adrive var3 = var_Fdrive elif not isinstance(a, float): var1 = var_Fdrive var2 = var_Adrive var3 = var_a # Perform 2D projection in appropriate dimensions logger.debug('Interpolating lookups at (%s = %s%s, %s = %s%s)', var1['name'], si_format(var1['val'], space=' '), var1['unit'], var2['name'], si_format(var2['val'], space=' '), var2['unit']) lookups3D = {key: interp1d(var1['ref'], y4D, axis=var1['axis'])(var1['val']) for key, y4D in lookups4D.items()} if var2['axis'] > var1['axis']: var2['axis'] -= 1 lookups2D = {key: interp1d(var2['ref'], y3D, axis=var2['axis'])(var2['val']) for key, y3D in lookups3D.items()} if var3['val'] is not None: logger.debug('Interpolating lookups at %d new %s values between %s%s and %s%s', len(var3['val']), var3['name'], si_format(min(var3['val']), space=' '), var3['unit'], si_format(max(var3['val']), space=' '), var3['unit']) lookups2D = {key: interp1d(var3['ref'], y2D, axis=0)(var3['val']) for key, y2D in lookups2D.items()} var3['ref'] = np.array(var3['val']) return var3['ref'], Qref, lookups2D, var3 def getLookups2Dfs(mechname, a, Fdrive, fs): # Check lookup file existence lookup_path = getNeuronLookupsFile(mechname, a=a, Fdrive=Fdrive, fs=True) if not os.path.isfile(lookup_path): raise FileNotFoundError('Missing lookup file: "{}"'.format(lookup_path)) # Load lookups dictionary logger.debug('Loading lookup table') with open(lookup_path, 'rb') as fh: df = pickle.load(fh) inputs = df['input'] lookups3D = df['lookup'] # Retrieve 1D inputs from lookups dictionary fsref = inputs['fs'] Aref = inputs['A'] Qref = inputs['Q'] # Check that fs is within lookup range fs = isWithin('coverage', fs, (fsref.min(), fsref.max())) # Perform projection at fs logger.debug('Interpolating lookups at fs = %s%%', fs * 1e2) lookups2D = {key: interp1d(fsref, y3D, axis=2)(fs) for key, y3D in lookups3D.items()} return Aref, Qref, lookups2D def isWithin(name, val, bounds, rel_tol=1e-9): ''' Check if a floating point number is within an interval. If the value falls outside the interval, an error is raised. If the value falls just outside the interval due to rounding errors, the associated interval bound is returned. :param val: float value :param bounds: interval bounds (float tuple) :return: original or corrected value ''' if isinstance(val, list) or isinstance(val, np.ndarray) or isinstance(val, tuple): return [isWithin(name, v, bounds, rel_tol) for v in val] if val >= bounds[0] and val <= bounds[1]: return val elif val < bounds[0] and math.isclose(val, bounds[0], rel_tol=rel_tol): logger.warning('Rounding %s value (%s) to interval lower bound (%s)', name, val, bounds[0]) return bounds[0] elif val > bounds[1] and math.isclose(val, bounds[1], rel_tol=rel_tol): logger.warning('Rounding %s value (%s) to interval upper bound (%s)', name, val, bounds[1]) return bounds[1] else: raise ValueError('{} value ({}) out of [{}, {}] interval'.format( name, val, bounds[0], bounds[1])) def getLookupsCompTime(mechname): # Check lookup file existence lookup_path = getNeuronLookupsFile(mechname) if not os.path.isfile(lookup_path): raise FileNotFoundError('Missing lookup file: "{}"'.format(lookup_path)) # Load lookups dictionary logger.debug('Loading comp times') with open(lookup_path, 'rb') as fh: df = pickle.load(fh) tcomps4D = df['tcomp'] return np.sum(tcomps4D) def getLowIntensitiesSTN(): ''' Return an array of acoustic intensities (W/m2) used to study the STN neuron in Tarnaud, T., Joseph, W., Martens, L., and Tanghe, E. (2018). Computational Modeling of Ultrasonic Subthalamic Nucleus Stimulation. IEEE Trans Biomed Eng. ''' return np.hstack(( np.arange(10, 101, 10), np.arange(101, 131, 1), np.array([140]) )) # W/m2 + + + +def getIndex(container, value): + ''' Return the index of a float / string value in a list / array + + :param container: list / 1D-array of elements + :param value: value to search for + :return: index of value (if found) + ''' + if isinstance(value, float): + container = np.array(container) + imatches = np.where(np.isclose(container, value, rtol=1e-9, atol=1e-16))[0] + if len(imatches) == 0: + raise ValueError('{} not found in {}'.format(value, container)) + return imatches[0] + elif isinstance(value, str): + return container.index(value)