diff --git a/PySONIC/neurons/sundt.py b/PySONIC/neurons/sundt.py index ae6ce80..f608d36 100644 --- a/PySONIC/neurons/sundt.py +++ b/PySONIC/neurons/sundt.py @@ -1,313 +1,311 @@ # -*- coding: utf-8 -*- # @Author: Mariia Popova # @Email: theo.lemaire@epfl.ch # @Date: 2019-10-03 15:58:38 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-11-06 10:22:31 +# @Last Modified time: 2019-11-06 19:27:27 import numpy as np from ..core import PointNeuron from ..constants import CELSIUS_2_KELVIN, FARADAY, Rg, Z_Ca from ..utils import findModifiedEq class Sundt(PointNeuron): ''' Unmyelinated C-fiber model. Reference: *Sundt D., Gamper N., Jaffe D. B., Spike propagation through the dorsal root ganglia in an unmyelinated sensory neuron: a modeling study. Journal of Neurophysiology (2015)* ''' # Neuron name name = 'sundt' # ------------------------------ Biophysical parameters ------------------------------ # Resting parameters Cm0 = 1e-2 # Membrane capacitance (F/m2) Vm0 = -60. # Membrane potential (mV) # Reversal potentials (mV) ENa = 55.0 # Sodium EK = -90.0 # Potassium # Maximal channel conductances (S/m2) gNabar = 400.0 # Sodium gKdbar = 400.0 # Delayed-rectifier Potassium - gMbar = 4.0 # Slow non-inactivating Potassium - gCaLbar = 30 # High-threshold Calcium (???) - gKCabar = 2.0 # Calcium dependent Potassium (???) + gMbar = 3.1 # Slow non-inactivating Potassium (from MOD file, paper studies 2-8 S/m2 range) gLeak = 1e0 # Non-specific leakage + # gCaLbar = 30 # High-threshold Calcium (from MOD file, but only in soma !!!) + # gKCabar = 2.0 # Calcium dependent Potassium (only in soma !!!) # Na+ current parameters deltaVm = 6.0 # Voltage offset to shift the rate constants (6 mV in Sundt 2015) # Kd current parameters (Borg Graham 1987 for the formalism, Migliore 1995 for the values) alphan0 = 0.03 # ms-1 alphal0 = 0.001 # ms-1 betan0 = alphan0 # ms-1 betal0 = alphal0 # ms-1 Vhalfn = -32 # Membrane voltage at which alphan = alphan0 and betan = betan0 (mV) Vhalfl = -61 # Membrane voltage at which alphal = alphal0 and betal = betal0 (mV) zn = 5 # Effective valence of the n-gating particle zl = -2 # Effective valence of the l-gating particle gamman = 0.4 # Normalized position of the n-transition state within the membrane gammal = 1 # Normalized position of the l-transition state within the membrane # iM parameters taupMax = 1.0 # Max. adaptation decay of slow non-inactivating Potassium current (s) # Ca2+ parameters - Cao = 2e-3 # Extracellular Calcium concentration (M) - Cai0 = 70e-9 # Intracellular Calcium concentration at rest (M) (Aradi 1999) - deff = 200e-9 # effective depth beneath membrane for intracellular [Ca2+] calculation (m) - taur_Cai = 20e-3 # decay time constant for intracellular Ca2+ dissolution (s) + # Cao = 2e-3 # Extracellular Calcium concentration (M) + # Cai0 = 70e-9 # Intracellular Calcium concentration at rest (M) (Aradi 1999) + # deff = 200e-9 # effective depth beneath membrane for intracellular [Ca2+] calculation (m) + # taur_Cai = 20e-3 # decay time constant for intracellular Ca2+ dissolution (s) # iKCa parameters - Ca_factor = 1e6 # conversion factor for q-gate Calcium sensitivity (expressed in uM) - Ca_power = 3 # power exponent for q-gate Calcium sensitivity (-) + # Ca_factor = 1e6 # conversion factor for q-gate Calcium sensitivity (expressed in uM) + # Ca_power = 3 # power exponent for q-gate Calcium sensitivity (-) # Additional parameters celsius = 35.0 # Temperature (Celsius) celsius_Traub = 30.0 # Temperature in Traub 1991 (Celsius) celsius_Yamada = 23.5 # Temperature in Yamada 1989 (Celsius) # ------------------------------ States names & descriptions ------------------------------ states = { 'm': 'iNa activation gate', 'h': 'iNa inactivation gate', 'n': 'iKd activation gate', 'l': 'iKd inactivation gate', 'p': 'iM gate', - 'c': 'iCaL gate', - 'q': 'iKCa Calcium dependent gate', - 'Cai': 'Calcium intracellular concentration (M)' + # 'c': 'iCaL gate', + # 'q': 'iKCa Calcium dependent gate', + # 'Cai': 'Calcium intracellular concentration (M)' } def __new__(cls): cls.q10_Traub = 3**((cls.celsius - cls.celsius_Traub) / 10) cls.q10_Yamada = 3**((cls.celsius - cls.celsius_Yamada) / 10) cls.T = cls.celsius + CELSIUS_2_KELVIN - cls.current_to_molar_rate_Ca = cls.currentToConcentrationRate(Z_Ca, cls.deff) + # cls.current_to_molar_rate_Ca = cls.currentToConcentrationRate(Z_Ca, cls.deff) cls.Vref = Rg * cls.T / FARADAY * 1e3 # reference voltagte for iKd rate constants (mV) - # Compute total current at resting potential, without iLeak + # Compute Eleak such that iLeak cancels out the net current at resting potential sstates = {k: cls.steadyStates()[k](cls.Vm0) for k in cls.statesNames()} i_dict = cls.currents() del i_dict['iLeak'] iNet = sum([cfunc(cls.Vm0, sstates) for cfunc in i_dict.values()]) # mA/m2 - - # Compute Eleak such that iLeak cancels out the net current at resting potential cls.ELeak = cls.Vm0 + iNet / cls.gLeak # mV # print(f'Eleak = {cls.ELeak:.2f} mV') return super(Sundt, cls).__new__(cls) - @classmethod - def getPltScheme(cls): - pltscheme = super().getPltScheme() - pltscheme['[Ca^{2+}]_i'] = ['Cai'] - return pltscheme + # @classmethod + # def getPltScheme(cls): + # pltscheme = super().getPltScheme() + # pltscheme['[Ca^{2+}]_i'] = ['Cai'] + # return pltscheme - @classmethod - def getPltVars(cls, wrapleft='df["', wrapright='"]'): - return {**super().getPltVars(wrapleft, wrapright), **{ - 'Cai': { - 'desc': 'sumbmembrane Ca2+ concentration', - 'label': '[Ca^{2+}]_i', - 'unit': 'uM', - 'factor': 1e6 - } - }} + # @classmethod + # def getPltVars(cls, wrapleft='df["', wrapright='"]'): + # return {**super().getPltVars(wrapleft, wrapright), **{ + # 'Cai': { + # 'desc': 'sumbmembrane Ca2+ concentration', + # 'label': '[Ca^{2+}]_i', + # 'unit': 'uM', + # 'factor': 1e6 + # } + # }} # ------------------------------ Gating states kinetics ------------------------------ # iNa kinetics: adapted from Traub 1991, with 2 notable changes: # - Q10 correction to account for temperature adaptation from 30 to 35 degrees # - 6 mV voltage offset in the activation and inactivation rates to shift iNa voltage dependence # approximately midway between values reported for Nav1.7 and Nav1.8 currents. @classmethod def alpham(cls, Vm): Vm += cls.deltaVm return cls.q10_Traub * 0.32 * cls.vtrap((13.1 - Vm), 4) * 1e3 # s-1 @classmethod def betam(cls, Vm): Vm += cls.deltaVm return cls.q10_Traub * 0.28 * cls.vtrap((Vm - 40.1), 5) * 1e3 # s-1 @classmethod def alphah(cls, Vm): Vm += cls.deltaVm return cls.q10_Traub * 0.128 * np.exp((17.0 - Vm) / 18) * 1e3 # s-1 @classmethod def betah(cls, Vm): Vm += cls.deltaVm return cls.q10_Traub * 4 / (1 + np.exp((40.0 - Vm) / 5)) * 1e3 # s-1 # iKd kinetics: using Migliore 1995 values, with Borg-Graham 1991 formalism @classmethod def alphan(cls, Vm): return cls.alphan0 * np.exp(cls.zn * cls.gamman * (Vm - cls.Vhalfn) / cls.Vref) * 1e3 # s-1 @classmethod def betan(cls, Vm): return cls.betan0 * np.exp(-cls.zn * (1 - cls.gamman) * (Vm - cls.Vhalfn) / cls.Vref) * 1e3 # s-1 @classmethod def alphal(cls, Vm): return cls.alphal0 * np.exp(cls.zl * cls.gammal * (Vm - cls.Vhalfl) / cls.Vref) * 1e3 # s-1 @classmethod def betal(cls, Vm): return cls.betal0 * np.exp(-cls.zl * (1 - cls.gammal) * (Vm - cls.Vhalfl) / cls.Vref) * 1e3 # s-1 # iM kinetics: taken from Yamada 1989, with notable changes: # - Q10 correction to account for temperature adaptation from 23.5 to 35 degrees # - not sure about tau_p formulation (3.3 factor multiplying first-only or both exponential terms ???) @staticmethod def pinf(Vm): return 1.0 / (1 + np.exp(-(Vm + 35) / 10)) @classmethod def taup(cls, Vm): tau = cls.taupMax / (3.3 * (np.exp((Vm + 35) / 20) + np.exp(-(Vm + 35) / 20))) # s return tau * cls.q10_Yamada # iCaL kinetics: from Migliore 1995 that itself refers to Jaffe 1994. @classmethod def alphac(cls, Vm): return 15.69 * cls.vtrap((81.5 - Vm), 10.) * 1e3 # s-1 @classmethod def betac(cls, Vm): return 0.29 * np.exp(-Vm / 10.86) * 1e3 # s-1 - # iKCa kinetics: from Aradi 1999, which uses equations from Yuen 1991 with a few modifications: - # - 12 mV (???) shift in activation curve - # - log10 instead of log for Ca2+ sensitivity - # - global dampening factor of 1.67 applied on both rates - # Sundt 2015 applies an extra modification: - # - higher Calcium sensitivity (third power of Ca concentration) - # Also, there is an error in the alphaq denominator in the paper: using -4 instead of -4.5 + # # iKCa kinetics: from Aradi 1999, which uses equations from Yuen 1991 with a few modifications: + # # - 12 mV (???) shift in activation curve + # # - log10 instead of log for Ca2+ sensitivity + # # - global dampening factor of 1.67 applied on both rates + # # Sundt 2015 applies an extra modification: + # # - higher Calcium sensitivity (third power of Ca concentration) + # # Also, there is an error in the alphaq denominator in the paper: using -4 instead of -4.5 - @classmethod - def alphaq(cls, Cai): - return 0.00246 / np.exp((12 * np.log10((Cai * cls.Ca_factor)**cls.Ca_power) + 28.48) / -4.5) * 1e3 # s-1 + # @classmethod + # def alphaq(cls, Cai): + # return 0.00246 / np.exp((12 * np.log10((Cai * cls.Ca_factor)**cls.Ca_power) + 28.48) / -4.5) * 1e3 # s-1 - @classmethod - def betaq(cls, Cai): - return 0.006 / np.exp((12 * np.log10((Cai * cls.Ca_factor)**cls.Ca_power) + 60.4) / 35) * 1e3 # s-1 + # @classmethod + # def betaq(cls, Cai): + # return 0.006 / np.exp((12 * np.log10((Cai * cls.Ca_factor)**cls.Ca_power) + 60.4) / 35) * 1e3 # s-1 # ------------------------------ States derivatives ------------------------------ - # Ca2+ dynamics: using accumulation-dissolution formalism as in Aradi, with - # a longer Ca2+ intracellular dissolution time constant (20 ms vs. 9 ms) - - @classmethod - def derCai(cls, c, Cai, Vm): - return -cls.current_to_molar_rate_Ca * cls.iCaL(c, Cai, Vm) - (Cai - cls.Cai0) / cls.taur_Cai # M/s + # @classmethod + # def derCai(cls, c, Cai, Vm): + # ''' Using accumulation-dissolution formalism as in Aradi, with + # a longer Ca2+ intracellular dissolution time constant (20 ms vs. 9 ms). + # ''' + # return -cls.current_to_molar_rate_Ca * cls.iCaL(c, Cai, Vm) - (Cai - cls.Cai0) / cls.taur_Cai # M/s @classmethod def derStates(cls): return { 'm': lambda Vm, x: cls.alpham(Vm) * (1 - x['m']) - cls.betam(Vm) * x['m'], 'h': lambda Vm, x: cls.alphah(Vm) * (1 - x['h']) - cls.betah(Vm) * x['h'], 'n': lambda Vm, x: cls.alphan(Vm) * (1 - x['n']) - cls.betan(Vm) * x['n'], 'l': lambda Vm, x: cls.alphal(Vm) * (1 - x['l']) - cls.betal(Vm) * x['l'], 'p': lambda Vm, x: (cls.pinf(Vm) - x['p']) / cls.taup(Vm), - 'c': lambda Vm, x: cls.alphac(Vm) * (1 - x['c']) - cls.betac(Vm) * x['c'], - 'q': lambda Vm, x: cls.alphaq(x['Cai']) * (1 - x['q']) - cls.betaq(x['Cai']) * x['q'], - 'Cai': lambda Vm, x: cls.derCai(x['c'], x['Cai'], Vm) + # 'c': lambda Vm, x: cls.alphac(Vm) * (1 - x['c']) - cls.betac(Vm) * x['c'], + # 'q': lambda Vm, x: cls.alphaq(x['Cai']) * (1 - x['q']) - cls.betaq(x['Cai']) * x['q'], + # 'Cai': lambda Vm, x: cls.derCai(x['c'], x['Cai'], Vm) } # ------------------------------ Steady states ------------------------------ - @classmethod - def qinf(cls, Cai): - return cls.alphaq(Cai) / (cls.alphaq(Cai) + cls.betaq(Cai)) + # @classmethod + # def qinf(cls, Cai): + # return cls.alphaq(Cai) / (cls.alphaq(Cai) + cls.betaq(Cai)) - @classmethod - def Caiinf(cls, c, Vm): - return findModifiedEq( - cls.Cai0, - lambda Cai, c, Vm: cls.derCai(c, Cai, Vm), - c, Vm - ) + # @classmethod + # def Caiinf(cls, c, Vm): + # return findModifiedEq( + # cls.Cai0, + # lambda Cai, c, Vm: cls.derCai(c, Cai, Vm), + # c, Vm + # ) @classmethod def steadyStates(cls): lambda_dict = { 'm': lambda Vm: cls.alpham(Vm) / (cls.alpham(Vm) + cls.betam(Vm)), 'h': lambda Vm: cls.alphah(Vm) / (cls.alphah(Vm) + cls.betah(Vm)), 'n': lambda Vm: cls.alphan(Vm) / (cls.alphan(Vm) + cls.betan(Vm)), 'l': lambda Vm: cls.alphal(Vm) / (cls.alphal(Vm) + cls.betal(Vm)), 'p': lambda Vm: cls.pinf(Vm), - 'c': lambda Vm: cls.alphac(Vm) / (cls.alphac(Vm) + cls.betac(Vm)), + # 'c': lambda Vm: cls.alphac(Vm) / (cls.alphac(Vm) + cls.betac(Vm)), } - lambda_dict['Cai'] = lambda Vm: cls.Caiinf(lambda_dict['c'](Vm), Vm) - lambda_dict['q'] = lambda Vm: cls.qinf(lambda_dict['Cai'](Vm)) + # lambda_dict['Cai'] = lambda Vm: cls.Caiinf(lambda_dict['c'](Vm), Vm) + # lambda_dict['q'] = lambda Vm: cls.qinf(lambda_dict['Cai'](Vm)) return lambda_dict # ------------------------------ Membrane currents ------------------------------ # Sodium current: inconsistency with 1991 ref: m2h vs. m3h @classmethod def iNa(cls, m, h, Vm): ''' Sodium current. Gating formalism from Migliore 1995, using 3rd power for m in order to reproduce thinner AP waveform (half-width of ca. 1 ms) ''' return cls.gNabar * m**3 * h * (Vm - cls.ENa) # mA/m2 @classmethod def iKd(cls, n, l, Vm): ''' delayed-rectifier Potassium current ''' return cls.gKdbar * n**3 * l * (Vm - cls.EK) # mA/m2 @classmethod def iM(cls, p, Vm): ''' slow non-inactivating Potassium current ''' return cls.gMbar * p * (Vm - cls.EK) # mA/m2 - @classmethod - def iCaL(cls, c, Cai, Vm): - ''' Calcium current ''' - ECa = cls.nernst(Z_Ca, Cai, cls.Cao, cls.T) # mV - return cls.gCaLbar * c**2 * (Vm - ECa) # mA/m2 + # @classmethod + # def iCaL(cls, c, Cai, Vm): + # ''' Calcium current ''' + # ECa = cls.nernst(Z_Ca, Cai, cls.Cao, cls.T) # mV + # return cls.gCaLbar * c**2 * (Vm - ECa) # mA/m2 - @classmethod - def iKCa(cls, q, Vm): - ''' Calcium-dependent Potassium current ''' - return cls.gKCabar * q**2 * (Vm - cls.EK) # mA/m2 + # @classmethod + # def iKCa(cls, q, Vm): + # ''' Calcium-dependent Potassium current ''' + # return cls.gKCabar * q**2 * (Vm - cls.EK) # mA/m2 @classmethod def iLeak(cls, Vm): ''' non-specific leakage current ''' return cls.gLeak * (Vm - cls.ELeak) # mA/m2 @classmethod def currents(cls): return { 'iNa': lambda Vm, x: cls.iNa(x['m'], x['h'], Vm), 'iKd': lambda Vm, x: cls.iKd(x['n'], x['l'], Vm), 'iM': lambda Vm, x: cls.iM(x['p'], Vm), - 'iCaL': lambda Vm, x: cls.iCaL(x['c'], x['Cai'], Vm), - 'iKCa': lambda Vm, x: cls.iKCa(x['q'], Vm), + # 'iCaL': lambda Vm, x: cls.iCaL(x['c'], x['Cai'], Vm), + # 'iKCa': lambda Vm, x: cls.iKCa(x['q'], Vm), 'iLeak': lambda Vm, _: cls.iLeak(Vm) } def chooseTimeStep(self): ''' neuron-specific time step for fast dynamics. ''' return super().chooseTimeStep() * 1e-2 \ No newline at end of file diff --git a/PySONIC/parsers.py b/PySONIC/parsers.py index 0379160..5a02f75 100644 --- a/PySONIC/parsers.py +++ b/PySONIC/parsers.py @@ -1,699 +1,714 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2019-06-04 18:24:29 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-11-06 15:11:14 +# @Last Modified time: 2019-11-06 15:59:45 import os import logging import pprint import numpy as np import matplotlib.pyplot as plt from argparse import ArgumentParser from .utils import Intensity2Pressure, selectDirDialog, OpenFilesDialog, isIterable from .neurons import getPointNeuron, CorticalRS from .plt import GroupedTimeSeries, CompTimeSeries DEFAULT_OUTPUT_FOLDER = os.path.abspath(os.path.split(__file__)[0] + '../../../../dump') class Parser(ArgumentParser): ''' Generic parser interface. ''' dist_str = '[scale min max n]' def __init__(self): super().__init__() self.pp = pprint.PrettyPrinter(indent=4) self.defaults = {} self.allowed = {} self.factors = {} self.to_parse = {} self.addPlot() self.addVerbose() def pprint(self, args): self.pp.pprint(args) def getDistribution(self, xmin, xmax, nx, scale='lin'): if scale == 'log': xmin, xmax = np.log10(xmin), np.log10(xmax) return {'lin': np.linspace, 'log': np.logspace}[scale](xmin, xmax, nx) def getDistFromList(self, xlist): if not isinstance(xlist, list): raise TypeError('Input must be a list') if len(xlist) != 4: raise ValueError('List must contain exactly 4 arguments ([type, min, max, n])') scale = xlist[0] if scale not in ('log', 'lin'): raise ValueError('Unknown distribution type (must be "lin" or "log")') xmin, xmax = [float(x) for x in xlist[1:-1]] if xmin >= xmax: raise ValueError('Specified minimum higher or equal than specified maximum') nx = int(xlist[-1]) if nx < 2: raise ValueError('Specified number must be at least 2') return self.getDistribution(xmin, xmax, nx, scale=scale) def addVerbose(self): self.add_argument( '-v', '--verbose', default=False, action='store_true', help='Increase verbosity') self.to_parse['loglevel'] = self.parseLogLevel def addPlot(self): self.add_argument( '-p', '--plot', type=str, nargs='+', help='Variables to plot') self.to_parse['pltscheme'] = self.parsePltScheme + def addPhase(self): + self.add_argument( + '--phase', default=False, action='store_true', help='Phase plot') + def addMPI(self): self.add_argument( '--mpi', default=False, action='store_true', help='Use multiprocessing') def addTest(self): self.add_argument( '--test', default=False, action='store_true', help='Run test configuration') def addSave(self): self.add_argument( '-s', '--save', default=False, action='store_true', help='Save output(s)') def addCheckForOutput(self): self.add_argument( '--checkout', default=False, action='store_true', help='Run only simulations for which there is no output file in the output directory') self.to_parse['overwrite'] = self.parseOverwrite def addOverwrite(self): self.add_argument( '--overwrite', default=False, action='store_true', help='Overwrite pre-existing simulation files with new output') def addFigureExtension(self): self.add_argument( '--figext', type=str, default='png', help='Figure type (extension)') def addCompare(self, desc='Comparative graph'): self.add_argument( '--compare', default=False, action='store_true', help=desc) def addSamplingRate(self): self.add_argument( '--sr', type=int, default=1, help='Sampling rate for plot') def addSpikes(self): self.add_argument( '--spikes', type=str, default='none', help='How to indicate spikes on charge profile ("none", "marks" or "details")') def addNColumns(self): self.add_argument( '--ncol', type=int, default=1, help='Number of columns in figure') def addNLevels(self): self.add_argument( '--nlevels', type=int, default=10, help='Number of levels') def addHideOutput(self): self.add_argument( '--hide', default=False, action='store_true', help='Hide output') def addTimeRange(self, default=None): self.add_argument( '--trange', type=float, nargs=2, default=default, help='Time lower and upper bounds (ms)') self.to_parse['trange'] = self.parseTimeRange def addPotentialBounds(self, default=None): self.add_argument( '--Vbounds', type=float, nargs=2, default=default, help='Membrane potential lower and upper bounds (mV)') def addFiringRateBounds(self, default): self.add_argument( '--FRbounds', type=float, nargs=2, default=default, help='Firing rate lower and upper bounds (Hz)') def addFiringRateScale(self, default='lin'): self.add_argument( '--FRscale', type=str, choices=('lin', 'log'), default=default, help='Firing rate scale for plot ("lin" or "log")') def addCmap(self, default=None): self.add_argument( '--cmap', type=str, default=default, help='Colormap name') def addCscale(self, default='lin'): self.add_argument( '--cscale', type=str, default=default, choices=('lin', 'log'), help='Color scale ("lin" or "log")') def addInputDir(self, dep_key=None): self.inputdir_dep_key = dep_key self.add_argument( '-i', '--inputdir', type=str, help='Input directory') self.to_parse['inputdir'] = self.parseInputDir def addOutputDir(self, dep_key=None): self.outputdir_dep_key = dep_key self.add_argument( '-o', '--outputdir', type=str, help='Output directory') self.to_parse['outputdir'] = self.parseOutputDir def addInputFiles(self, dep_key=None): self.inputfiles_dep_key = dep_key self.add_argument( '-i', '--inputfiles', type=str, help='Input files') self.to_parse['inputfiles'] = self.parseInputFiles def addPatches(self): self.add_argument( '--patches', type=str, default='one', help='Stimulus patching mode ("none", "one", all", or a boolean list)') self.to_parse['patches'] = self.parsePatches def addThresholdCurve(self): self.add_argument( '--threshold', default=False, action='store_true', help='Show threshold amplitudes') def addNeuron(self): self.add_argument( '-n', '--neuron', type=str, nargs='+', help='Neuron name (string)') self.to_parse['neuron'] = self.parseNeuron def parseNeuron(self, args): return [getPointNeuron(n) for n in args['neuron']] def addInteractive(self): self.add_argument( '--interactive', default=False, action='store_true', help='Make interactive') def addLabels(self): self.add_argument( '--labels', type=str, nargs='+', default=None, help='Labels') def addRelativeTimeBounds(self): self.add_argument( '--rel_tbounds', type=float, nargs='+', default=None, help='Relative time lower and upper bounds') def addPretty(self): self.add_argument( '--pretty', default=False, action='store_true', help='Make figure pretty') def addSubset(self, choices): self.add_argument( '--subset', type=str, nargs='+', default=['all'], choices=choices + ['all'], help='Run specific subset(s)') self.subset_choices = choices self.to_parse['subset'] = self.parseSubset def parseSubset(self, args): if args['subset'] == ['all']: return self.subset_choices else: return args['subset'] def parseTimeRange(self, args): if args['trange'] is None: return None return np.array(args['trange']) * 1e-3 def parsePatches(self, args): if args['patches'] not in ('none', 'one', 'all'): return eval(args['patches']) else: return args['patches'] def parseInputFiles(self, args): if self.inputfiles_dep_key is not None and not args[self.inputfiles_dep_key]: return None elif args['inputfiles'] is None: return OpenFilesDialog('pkl')[0] def parseDir(self, key, args, title, dep_key=None): if dep_key is not None and args[dep_key] is False: return None try: if args[key] is not None: return os.path.abspath(args[key]) else: return selectDirDialog(title=title) except ValueError: raise ValueError('No {} selected'.format(key)) def parseInputDir(self, args): return self.parseDir( 'inputdir', args, 'Select input directory', self.inputdir_dep_key) def parseOutputDir(self, args): if hasattr(self, 'outputdir') and self.outputdir is not None: return self.outputdir else: if args['outputdir'] is not None and args['outputdir'] == 'dump': return DEFAULT_OUTPUT_FOLDER else: return self.parseDir( 'outputdir', args, 'Select output directory', self.outputdir_dep_key) def parseLogLevel(self, args): return logging.DEBUG if args.pop('verbose') else logging.INFO def parsePltScheme(self, args): if args['plot'] is None or args['plot'] == ['all']: return None else: return {x: [x] for x in args['plot']} def parseOverwrite(self, args): check_for_output = args.pop('checkout') return not check_for_output def restrict(self, args, keys): if sum([args[x] is not None for x in keys]) > 1: raise ValueError( 'You must provide only one of the following arguments: {}'.format(', '.join(keys))) def parse2array(self, args, key, factor=1): return np.array(args[key]) * factor def parse(self): args = vars(super().parse_args()) for k, v in self.defaults.items(): if k in args and args[k] is None: args[k] = v if isIterable(v) else [v] for k, parse_method in self.to_parse.items(): args[k] = parse_method(args) return args @staticmethod def parsePlot(args, output): render_args = {} if 'spikes' in args: render_args['spikes'] = args['spikes'] if args['compare']: if args['plot'] == ['all']: logger.error('Specific variables must be specified for comparative plots') return for key in ['cmap', 'cscale']: if key in args: render_args[key] = args[key] for pltvar in args['plot']: comp_plot = CompTimeSeries(output, pltvar) comp_plot.render(**render_args) else: scheme_plot = GroupedTimeSeries(output, pltscheme=args['pltscheme']) scheme_plot.render(**render_args) + + # phase_plot = PhaseDiagram(output, args['plot'][0]) + # phase_plot.render( + # # trange=args['trange'], + # # rel_tbounds=args['rel_tbounds'], + # labels=args['labels'], + # prettify=args['pretty'], + # cmap=args['cmap'], + # cscale=args['cscale'] + # ) + plt.show() class TestParser(Parser): def __init__(self, valid_subsets): super().__init__() self.addProfiling() self.addSubset(valid_subsets) def addProfiling(self): self.add_argument( '--profile', default=False, action='store_true', help='Run with profiling') class FigureParser(Parser): def __init__(self, valid_subsets): super().__init__() self.addSubset(valid_subsets) self.addSave() self.addOutputDir(dep_key='save') class PlotParser(Parser): def __init__(self): super().__init__() self.addHideOutput() self.addInputFiles() self.addOutputDir(dep_key='save') self.addSave() self.addFigureExtension() self.addCmap() self.addPretty() self.addTimeRange() self.addCscale() self.addLabels() class TimeSeriesParser(PlotParser): def __init__(self): super().__init__() self.addSpikes() self.addSamplingRate() self.addCompare() self.addPatches() class SimParser(Parser): ''' Generic simulation parser interface. ''' def __init__(self, outputdir=None): super().__init__() self.outputdir = outputdir self.addMPI() self.addOutputDir(dep_key='save') self.addSave() self.addCheckForOutput() self.addCompare() self.addCmap() self.addCscale() def parse(self): args = super().parse() return args class MechSimParser(SimParser): ''' Parser to run mechanical simulations from the command line. ''' def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.defaults.update({ 'radius': 32.0, # nm 'embedding': 0., # um 'Cm0': CorticalRS.Cm0 * 1e2, # uF/m2 'Qm0': CorticalRS.Qm0() * 1e5, # nC/m2 'freq': 500.0, # kHz 'amp': 100.0, # kPa 'charge': 0., # nC/cm2 'fs': 100. # % }) self.factors.update({ 'radius': 1e-9, 'embedding': 1e-6, 'Cm0': 1e-2, 'Qm0': 1e-5, 'freq': 1e3, 'amp': 1e3, 'charge': 1e-5, 'fs': 1e-2 }) self.addRadius() self.addEmbedding() self.addCm0() self.addQm0() self.addFdrive() self.addAdrive() self.addCharge() self.addFs() def addRadius(self): self.add_argument( '-a', '--radius', nargs='+', type=float, help='Sonophore radius (nm)') def addEmbedding(self): self.add_argument( '--embedding', nargs='+', type=float, help='Embedding depth (um)') def addCm0(self): self.add_argument( '--Cm0', type=float, nargs='+', help='Resting membrane capacitance (uF/cm2)') def addQm0(self): self.add_argument( '--Qm0', type=float, nargs='+', help='Resting membrane charge density (nC/cm2)') def addFdrive(self): self.add_argument( '-f', '--freq', nargs='+', type=float, help='US frequency (kHz)') def addAdrive(self): self.add_argument( '-A', '--amp', nargs='+', type=float, help='Acoustic pressure amplitude (kPa)') self.add_argument( '--Arange', type=str, nargs='+', help='Amplitude range {} (kPa)'.format(self.dist_str)) self.add_argument( '-I', '--intensity', nargs='+', type=float, help='Acoustic intensity (W/cm2)') self.add_argument( '--Irange', type=str, nargs='+', help='Intensity range {} (W/cm2)'.format(self.dist_str)) self.to_parse['amp'] = self.parseAmp def addAscale(self, default='lin'): self.add_argument( '--Ascale', type=str, choices=('lin', 'log'), default=default, help='Amplitude scale for plot ("lin" or "log")') def addCharge(self): self.add_argument( '-Q', '--charge', nargs='+', type=float, help='Membrane charge density (nC/cm2)') def addFs(self): self.add_argument( '--fs', nargs='+', type=float, help='Sonophore coverage fraction (%%)') self.add_argument( '--spanFs', default=False, action='store_true', help='Span Fs from 1 to 100%%') self.to_parse['fs'] = self.parseFs def parseAmp(self, args): params = ['Irange', 'Arange', 'intensity', 'amp'] self.restrict(args, params[:-1]) Irange, Arange, Int, Adrive = [args.pop(k) for k in params] if Irange is not None: amps = Intensity2Pressure(self.getDistFromList(Irange) * 1e4) # Pa elif Int is not None: amps = Intensity2Pressure(np.array(Int) * 1e4) # Pa elif Arange is not None: amps = self.getDistFromList(Arange) * self.factors['amp'] # Pa else: amps = np.array(Adrive) * self.factors['amp'] # Pa return amps def parseFs(self, args): if args.pop('spanFs', False): return np.arange(1, 101) * self.factors['fs'] # (-) else: return np.array(args['fs']) * self.factors['fs'] # (-) def parse(self): args = super().parse() for key in ['radius', 'embedding', 'Cm0', 'Qm0', 'freq', 'charge']: args[key] = self.parse2array(args, key, factor=self.factors[key]) return args @staticmethod def parseSimInputs(args): return [args[k] for k in ['freq', 'amp', 'charge']] class NeuronSimParser(SimParser): def __init__(self): super().__init__() self.defaults.update({ 'neuron': 'RS', 'tstim': 100.0, # ms 'toffset': 50. # ms }) self.factors.update({ 'tstim': 1e-3, 'toffset': 1e-3 }) self.addNeuron() self.addTstim() self.addToffset() def addTstim(self): self.add_argument( '-t', '--tstim', nargs='+', type=float, help='Stimulus duration (ms)') def addToffset(self): self.add_argument( '--toffset', nargs='+', type=float, help='Offset duration (ms)') class VClampParser(NeuronSimParser): def __init__(self): super().__init__() self.defaults.update({ 'vhold': -70.0, # mV 'vstep': 0.0 # mV }) self.factors.update({ 'vhold': 1., 'vstep': 1. }) self.addVhold() self.addVstep() def addVhold(self): self.add_argument( '--vhold', nargs='+', type=float, help='Held membrane potential (mV)') def addVstep(self): self.add_argument( '--vstep', nargs='+', type=float, help='Step membrane potential (mV)') self.add_argument( '--vsteprange', type=str, nargs='+', help='Step membrane potential range {} (mV)'.format(self.dist_str)) self.to_parse['vstep'] = self.parseVstep def parseVstep(self, args): vstep_range, vstep = [args.pop(k) for k in ['vsteprange', 'vstep']] if vstep_range is not None: vsteps = self.getDistFromList(vstep_range) # mV else: vsteps = np.array(vstep) # mV return vsteps def parse(self, args=None): if args is None: args = super().parse() for key in ['vhold', 'vstep', 'tstim', 'toffset']: args[key] = self.parse2array(args, key, factor=self.factors[key]) return args @staticmethod def parseSimInputs(args): return [args[k] for k in ['vhold', 'vstep', 'tstim', 'toffset']] class PWSimParser(NeuronSimParser): ''' Generic parser interface to run PW patterned simulations from the command line. ''' def __init__(self): super().__init__() self.defaults.update({ 'PRF': 100.0, # Hz 'DC': 100.0 # % }) self.factors.update({ 'PRF': 1., 'DC': 1e-2 }) self.allowed.update({ 'DC': range(101) }) self.addPRF() self.addDC() self.addTitrate() self.addSpikes() def addPRF(self): self.add_argument( '--PRF', nargs='+', type=float, help='PRF (Hz)') def addDC(self): self.add_argument( '--DC', nargs='+', type=float, help='Duty cycle (%%)') self.add_argument( '--spanDC', default=False, action='store_true', help='Span DC from 1 to 100%%') self.to_parse['DC'] = self.parseDC def addTitrate(self): self.add_argument( '--titrate', default=False, action='store_true', help='Perform titration') def parseAmp(self, args): return NotImplementedError def parseDC(self, args): if args.pop('spanDC'): return np.arange(1, 101) * self.factors['DC'] # (-) else: return np.array(args['DC']) * self.factors['DC'] # (-) def parse(self, args=None): if args is None: args = super().parse() for key in ['tstim', 'toffset', 'PRF']: args[key] = self.parse2array(args, key, factor=self.factors[key]) return args @staticmethod def parseSimInputs(args): return [args[k] for k in ['amp', 'tstim', 'toffset', 'PRF', 'DC']] class EStimParser(PWSimParser): ''' Parser to run E-STIM simulations from the command line. ''' def __init__(self): super().__init__() self.defaults.update({'amp': 10.0}) # mA/m2 self.factors.update({'amp': 1.}) self.addAstim() def addAstim(self): self.add_argument( '-A', '--amp', nargs='+', type=float, help='Amplitude of injected current density (mA/m2)') self.add_argument( '--Arange', type=str, nargs='+', help='Amplitude range {} (mA/m2)'.format(self.dist_str)) self.to_parse['amp'] = self.parseAmp def addVext(self): self.add_argument( '--Vext', nargs='+', type=float, help='Extracellular potential (mV)') def parseAmp(self, args): if args.pop('titrate'): return None Arange, Astim = [args.pop(k) for k in ['Arange', 'amp']] if Arange is not None: amps = self.getDistFromList(Arange) * self.factors['amp'] # mA/m2 else: amps = np.array(Astim) * self.factors['amp'] # mA/m2 return amps class AStimParser(PWSimParser, MechSimParser): ''' Parser to run A-STIM simulations from the command line. ''' def __init__(self): MechSimParser.__init__(self) PWSimParser.__init__(self) self.defaults.update({'method': 'sonic'}) self.allowed.update({'method': ['full', 'hybrid', 'sonic', 'qss']}) self.addMethod() self.addQSSVars() def addMethod(self): self.add_argument( '-m', '--method', nargs='+', type=str, help='Numerical integration method ({})'.format(', '.join(self.allowed['method']))) self.to_parse['method'] = self.parseMethod def parseMethod(self, args): for item in args['method']: if item not in self.allowed['method']: raise ValueError('Unknown method type: "{}"'.format(item)) return args['method'] def addQSSVars(self): self.add_argument( '--qss', nargs='+', type=str, help='QSS variables') def parseAmp(self, args): if args.pop('titrate'): return None return MechSimParser.parseAmp(self, args) def parse(self): args = PWSimParser.parse(self, args=MechSimParser.parse(self)) for k in ['Cm0', 'Qm0', 'embedding', 'charge']: del args[k] return args @staticmethod def parseSimInputs(args): return [args['freq']] + PWSimParser.parseSimInputs(args) + [args[k] for k in ['fs', 'method']] diff --git a/PySONIC/plt/__init__.py b/PySONIC/plt/__init__.py index 29cac42..b260c88 100644 --- a/PySONIC/plt/__init__.py +++ b/PySONIC/plt/__init__.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-06-06 13:36:00 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-08-14 11:07:38 +# @Last Modified time: 2019-11-06 16:07:08 from .pltutils import * from .timeseries import * from .actmap import * from .QSS import * -from .phasediagram import * +from .spikes import * +from .phaseplot import * from .effvars import * \ No newline at end of file diff --git a/PySONIC/plt/phasediagram.py b/PySONIC/plt/phaseplot.py similarity index 89% copy from PySONIC/plt/phasediagram.py copy to PySONIC/plt/phaseplot.py index 42f7fa0..44fbebf 100644 --- a/PySONIC/plt/phasediagram.py +++ b/PySONIC/plt/phaseplot.py @@ -1,191 +1,168 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2018-10-01 20:40:28 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-08-22 14:59:13 +# @Last Modified time: 2019-11-06 16:05:54 import numpy as np import matplotlib.pyplot as plt from ..core import getModel from ..utils import * from ..constants import * from .pltutils import * -from ..postpro import detectSpikes, convertPeaksProperties - - -class PhaseDiagram(ComparativePlot): - - phaseplotvars = { - 'Vm': { - 'label': 'V_m\ (mV)', - 'dlabel': 'dV/dt\ (V/s)', - 'factor': 1e0, - 'lim': (-80.0, 50.0), - 'dfactor': 1e-3, - 'dlim': (-300, 700), - 'thr_amp': SPIKE_MIN_VAMP, - 'thr_prom': SPIKE_MIN_VPROM - }, - 'Qm': { - 'label': 'Q_m\ (nC/cm^2)', - 'dlabel': 'I\ (A/m^2)', - 'factor': 1e5, - 'lim': (-80.0, 50.0), - 'dfactor': 1e0, - 'dlim': (-2, 5), - 'thr_amp': SPIKE_MIN_QAMP, - 'thr_prom': SPIKE_MIN_QPROM - } - } + + +class PhasePlot(GenericPlot): + ''' Generic interface to build a phase plot displaying the evolution of 2 variables resulting from model simulations. ''' @classmethod def createBackBone(cls, pltvar, tbounds, fs, prettify): # Create figure fig, axes = plt.subplots(1, 2, figsize=(8, 4)) # 1st axis: variable as function of time ax = axes[0] ax.set_xlabel('$\\rm time\ (ms)$', fontsize=fs) ax.set_ylabel('$\\rm {}$'.format(pltvar['label']), fontsize=fs) ax.set_xlim(tbounds) ax.set_ylim(pltvar['lim']) # 2nd axis: phase plot (derivative of variable vs variable) ax = axes[1] ax.set_xlabel('$\\rm {}$'.format(pltvar['label']), fontsize=fs) ax.set_ylabel('$\\rm {}$'.format(pltvar['dlabel']), fontsize=fs) ax.set_xlim(pltvar['lim']) ax.set_ylim(pltvar['dlim']) ax.plot([0, 0], [pltvar['dlim'][0], pltvar['dlim'][1]], '--', color='k', linewidth=1) ax.plot([pltvar['lim'][0], pltvar['lim'][1]], [0, 0], '--', color='k', linewidth=1) if prettify: cls.prettify(axes[0], xticks=tbounds, yticks=pltvar['lim']) cls.prettify(axes[1], xticks=pltvar['lim'], yticks=pltvar['dlim']) for ax in axes: cls.removeSpines(ax) cls.setTickLabelsFontSize(ax, fs) return fig, axes def checkInputs(self, labels): self.checkLabels(labels) @staticmethod def extractSpikesData(t, y, tbounds, rel_tbounds, tspikes): spikes_tvec, spikes_yvec, spikes_dydtvec = [], [], [] for j, (tspike, tbound) in enumerate(zip(tspikes, tbounds)): left_bound = max(tbound[0], rel_tbounds[0] + tspike) right_bound = min(tbound[1], rel_tbounds[1] + tspike) inds = np.where((t > left_bound) & (t < right_bound))[0] spikes_tvec.append(t[inds] - tspike) spikes_yvec.append(y[inds]) dinds = np.hstack(([inds[0] - 1], inds, [inds[-1] + 1])) dydt = np.diff(y[dinds]) / np.diff(t[dinds]) spikes_dydtvec.append((dydt[:-1] + dydt[1:]) / 2) # average of the two return spikes_tvec, spikes_yvec, spikes_dydtvec def addLegend(self, fig, axes, handles, labels, fs): fig.subplots_adjust(top=0.8) if len(self.filepaths) > 1: axes[0].legend(handles, labels, fontsize=fs, frameon=False, loc='upper center', bbox_to_anchor=(1.0, 1.35)) else: fig.suptitle(labels[0], fontsize=fs) def render(self, no_offset=False, no_first=False, labels=None, colors=None, fs=10, lw=2, trange=None, rel_tbounds=None, prettify=False, cmap=None, cscale='lin'): self.checkInputs(labels) if rel_tbounds is None: rel_tbounds = np.array((-1.5e-3, 1.5e-3)) # Check pltvar if self.varname not in self.phaseplotvars: raise KeyError( 'Unknown plot variable: "{}". Possible plot variables are: {}'.format( self.varname, ', '.join(['"{}"'.format(p) for p in self.phaseplotvars.keys()]))) pltvar = self.phaseplotvars[self.varname] fig, axes = self.createBackBone(pltvar, rel_tbounds * 1e3, fs, prettify) # Loop through data files comp_values, full_labels = [], [] handles0, handles1 = [], [] for i, filepath in enumerate(self.filepaths): # Load data data, meta = self.getData(filepath, trange=trange) meta.pop('tcomp') full_labels.append(figtitle(meta)) # Extract model model = getModel(meta) # Check consistency of sim types and check differing inputs comp_values = self.checkConsistency(meta, comp_values) # Extract time and y-variable t = data['t'].values y = data[self.varname].values # Detect spikes in signal ispikes, properties = detectSpikes( data, key=self.varname, mph=pltvar['thr_amp'], mpp=pltvar['thr_prom']) nspikes = ispikes.size tspikes = t[ispikes] yspikes = y[ispikes] properties = convertPeaksProperties(t, properties) tbounds = np.array(list(zip(properties['left_bases'], properties['right_bases']))) if nspikes == 0: logger.warning('No spikes detected') else: # Store spikes in dedicated lists spikes_tvec, spikes_yvec, spikes_dydtvec = self.extractSpikesData( t, y, tbounds, rel_tbounds, tspikes) # Plot spikes temporal profiles and phase-plane diagrams lh0, lh1 = [], [] for j in range(nspikes): if colors is None: color = 'C{}'.format(i if len(self.filepaths) > 1 else j % 10) else: color = colors[i] lh0.append(axes[0].plot( spikes_tvec[j] * 1e3, spikes_yvec[j] * pltvar['factor'], linewidth=lw, c=color)[0]) lh1.append(axes[1].plot( spikes_yvec[j] * pltvar['factor'], spikes_dydtvec[j] * pltvar['dfactor'], linewidth=lw, c=color)[0]) handles0.append(lh0) handles1.append(lh1) # Determine labels if self.comp_ref_key is not None: self.comp_info = model.inputs().get(self.comp_ref_key, None) comp_values, comp_labels = self.getCompLabels(comp_values) labels = self.chooseLabels(labels, comp_labels, full_labels) # Post-process figure fig.tight_layout() # Add labels or colorbar legend if cmap is not None: if not self.is_unique_comp: raise ValueError('Colormap mode unavailable for multiple differing parameters') if self.comp_info is None: raise ValueError('Colormap mode unavailable for qualitative comparisons') cmap_handles = [h0 + h1 for h0, h1 in zip(handles0, handles1)] self.addCmap( fig, cmap, cmap_handles, comp_values, self.comp_info, fs, prettify, zscale=cscale) else: leg_handles = [x[0] for x in handles0] self.addLegend(fig, axes, leg_handles, labels, fs) return fig diff --git a/PySONIC/plt/phasediagram.py b/PySONIC/plt/spikes.py similarity index 98% rename from PySONIC/plt/phasediagram.py rename to PySONIC/plt/spikes.py index 42f7fa0..0cd9f1d 100644 --- a/PySONIC/plt/phasediagram.py +++ b/PySONIC/plt/spikes.py @@ -1,191 +1,191 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2018-10-01 20:40:28 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-08-22 14:59:13 +# @Last Modified time: 2019-11-06 16:01:32 import numpy as np import matplotlib.pyplot as plt from ..core import getModel from ..utils import * from ..constants import * from .pltutils import * from ..postpro import detectSpikes, convertPeaksProperties -class PhaseDiagram(ComparativePlot): +class SpikesDiagram(ComparativePlot): phaseplotvars = { 'Vm': { 'label': 'V_m\ (mV)', 'dlabel': 'dV/dt\ (V/s)', 'factor': 1e0, 'lim': (-80.0, 50.0), 'dfactor': 1e-3, 'dlim': (-300, 700), 'thr_amp': SPIKE_MIN_VAMP, 'thr_prom': SPIKE_MIN_VPROM }, 'Qm': { 'label': 'Q_m\ (nC/cm^2)', 'dlabel': 'I\ (A/m^2)', 'factor': 1e5, 'lim': (-80.0, 50.0), 'dfactor': 1e0, 'dlim': (-2, 5), 'thr_amp': SPIKE_MIN_QAMP, 'thr_prom': SPIKE_MIN_QPROM } } @classmethod def createBackBone(cls, pltvar, tbounds, fs, prettify): # Create figure fig, axes = plt.subplots(1, 2, figsize=(8, 4)) # 1st axis: variable as function of time ax = axes[0] ax.set_xlabel('$\\rm time\ (ms)$', fontsize=fs) ax.set_ylabel('$\\rm {}$'.format(pltvar['label']), fontsize=fs) ax.set_xlim(tbounds) ax.set_ylim(pltvar['lim']) # 2nd axis: phase plot (derivative of variable vs variable) ax = axes[1] ax.set_xlabel('$\\rm {}$'.format(pltvar['label']), fontsize=fs) ax.set_ylabel('$\\rm {}$'.format(pltvar['dlabel']), fontsize=fs) ax.set_xlim(pltvar['lim']) ax.set_ylim(pltvar['dlim']) ax.plot([0, 0], [pltvar['dlim'][0], pltvar['dlim'][1]], '--', color='k', linewidth=1) ax.plot([pltvar['lim'][0], pltvar['lim'][1]], [0, 0], '--', color='k', linewidth=1) if prettify: cls.prettify(axes[0], xticks=tbounds, yticks=pltvar['lim']) cls.prettify(axes[1], xticks=pltvar['lim'], yticks=pltvar['dlim']) for ax in axes: cls.removeSpines(ax) cls.setTickLabelsFontSize(ax, fs) return fig, axes def checkInputs(self, labels): self.checkLabels(labels) @staticmethod def extractSpikesData(t, y, tbounds, rel_tbounds, tspikes): spikes_tvec, spikes_yvec, spikes_dydtvec = [], [], [] for j, (tspike, tbound) in enumerate(zip(tspikes, tbounds)): left_bound = max(tbound[0], rel_tbounds[0] + tspike) right_bound = min(tbound[1], rel_tbounds[1] + tspike) inds = np.where((t > left_bound) & (t < right_bound))[0] spikes_tvec.append(t[inds] - tspike) spikes_yvec.append(y[inds]) dinds = np.hstack(([inds[0] - 1], inds, [inds[-1] + 1])) dydt = np.diff(y[dinds]) / np.diff(t[dinds]) spikes_dydtvec.append((dydt[:-1] + dydt[1:]) / 2) # average of the two return spikes_tvec, spikes_yvec, spikes_dydtvec def addLegend(self, fig, axes, handles, labels, fs): fig.subplots_adjust(top=0.8) if len(self.filepaths) > 1: axes[0].legend(handles, labels, fontsize=fs, frameon=False, loc='upper center', bbox_to_anchor=(1.0, 1.35)) else: fig.suptitle(labels[0], fontsize=fs) def render(self, no_offset=False, no_first=False, labels=None, colors=None, fs=10, lw=2, trange=None, rel_tbounds=None, prettify=False, cmap=None, cscale='lin'): self.checkInputs(labels) if rel_tbounds is None: rel_tbounds = np.array((-1.5e-3, 1.5e-3)) # Check pltvar if self.varname not in self.phaseplotvars: raise KeyError( 'Unknown plot variable: "{}". Possible plot variables are: {}'.format( self.varname, ', '.join(['"{}"'.format(p) for p in self.phaseplotvars.keys()]))) pltvar = self.phaseplotvars[self.varname] fig, axes = self.createBackBone(pltvar, rel_tbounds * 1e3, fs, prettify) # Loop through data files comp_values, full_labels = [], [] handles0, handles1 = [], [] for i, filepath in enumerate(self.filepaths): # Load data data, meta = self.getData(filepath, trange=trange) meta.pop('tcomp') full_labels.append(figtitle(meta)) # Extract model model = getModel(meta) # Check consistency of sim types and check differing inputs comp_values = self.checkConsistency(meta, comp_values) # Extract time and y-variable t = data['t'].values y = data[self.varname].values # Detect spikes in signal ispikes, properties = detectSpikes( data, key=self.varname, mph=pltvar['thr_amp'], mpp=pltvar['thr_prom']) nspikes = ispikes.size tspikes = t[ispikes] yspikes = y[ispikes] properties = convertPeaksProperties(t, properties) tbounds = np.array(list(zip(properties['left_bases'], properties['right_bases']))) if nspikes == 0: logger.warning('No spikes detected') else: # Store spikes in dedicated lists spikes_tvec, spikes_yvec, spikes_dydtvec = self.extractSpikesData( t, y, tbounds, rel_tbounds, tspikes) # Plot spikes temporal profiles and phase-plane diagrams lh0, lh1 = [], [] for j in range(nspikes): if colors is None: color = 'C{}'.format(i if len(self.filepaths) > 1 else j % 10) else: color = colors[i] lh0.append(axes[0].plot( spikes_tvec[j] * 1e3, spikes_yvec[j] * pltvar['factor'], linewidth=lw, c=color)[0]) lh1.append(axes[1].plot( spikes_yvec[j] * pltvar['factor'], spikes_dydtvec[j] * pltvar['dfactor'], linewidth=lw, c=color)[0]) handles0.append(lh0) handles1.append(lh1) # Determine labels if self.comp_ref_key is not None: self.comp_info = model.inputs().get(self.comp_ref_key, None) comp_values, comp_labels = self.getCompLabels(comp_values) labels = self.chooseLabels(labels, comp_labels, full_labels) # Post-process figure fig.tight_layout() # Add labels or colorbar legend if cmap is not None: if not self.is_unique_comp: raise ValueError('Colormap mode unavailable for multiple differing parameters') if self.comp_info is None: raise ValueError('Colormap mode unavailable for qualitative comparisons') cmap_handles = [h0 + h1 for h0, h1 in zip(handles0, handles1)] self.addCmap( fig, cmap, cmap_handles, comp_values, self.comp_info, fs, prettify, zscale=cscale) else: leg_handles = [x[0] for x in handles0] self.addLegend(fig, axes, leg_handles, labels, fs) return fig diff --git a/scripts/plot_phase_diagram.py b/scripts/plot_spikes.py similarity index 73% rename from scripts/plot_phase_diagram.py rename to scripts/plot_spikes.py index 6072ded..222561c 100644 --- a/scripts/plot_phase_diagram.py +++ b/scripts/plot_spikes.py @@ -1,39 +1,39 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-02-13 12:41:26 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-06-17 19:59:21 +# @Last Modified time: 2019-11-06 16:03:16 -''' Plot phase plane diagram of specific simulation output variables. ''' +''' Plot spikes diagrams of specific simulation output variables. ''' import matplotlib.pyplot as plt from PySONIC.utils import logger -from PySONIC.plt import PhaseDiagram +from PySONIC.plt import SpikesDiagram from PySONIC.parsers import PlotParser def main(): parser = PlotParser() parser.addRelativeTimeBounds() args = parser.parse() logger.setLevel(args['loglevel']) - # Plot phase-plane diagram - phase_diag = PhaseDiagram(args['inputfiles'], args['plot'][0]) + # Plot spikes diagram + phase_diag = SpikesDiagram(args['inputfiles'], args['plot'][0]) phase_diag.render( trange=args['trange'], rel_tbounds=args['rel_tbounds'], labels=args['labels'], prettify=args['pretty'], cmap=args['cmap'], cscale=args['cscale'] ) if not args['hide']: plt.show() if __name__ == '__main__': main()