diff --git a/PySONIC/core/pneuron.py b/PySONIC/core/pneuron.py index 3f08fc8..813e77a 100644 --- a/PySONIC/core/pneuron.py +++ b/PySONIC/core/pneuron.py @@ -1,596 +1,597 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-08-03 11:53:04 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-03-02 13:27:17 +# @Last Modified time: 2021-05-27 20:34:55 import abc import inspect import numpy as np from .protocols import * from .model import Model from .lookups import EffectiveVariablesLookup from .solvers import EventDrivenSolver from .drives import Drive, ElectricDrive from ..postpro import detectSpikes, computeFRProfile from ..constants import * from ..utils import * class PointNeuron(Model): ''' Generic point-neuron model interface. ''' tscale = 'ms' # relevant temporal scale of the model simkey = 'ESTIM' # keyword used to characterize simulations made with this model celsius = 36.0 # Temperature (Celsius) T = celsius + CELSIUS_2_KELVIN def __repr__(self): return self.__class__.__name__ def copy(self): return self.__class__() def __eq__(self, other): if not isinstance(other, PointNeuron): return False return self.name == other.name @property @classmethod @abc.abstractmethod def name(cls): ''' Neuron name. ''' raise NotImplementedError @property @classmethod @abc.abstractmethod def Cm0(cls): ''' Neuron's resting capacitance (F/m2). ''' raise NotImplementedError @property @classmethod @abc.abstractmethod def Vm0(cls): ''' Neuron's resting membrane potential(mV). ''' raise NotImplementedError @property def Qm0(self): return self.Cm0 * self.Vm0 * 1e-3 # C/m2 @property def tau_pas(self): ''' Passive membrane time constant (s). ''' return self.Cm0 / self.gLeak @property def meta(self): return {'neuron': self.name} @staticmethod def inputs(): return ElectricDrive.inputs() @classmethod def filecodes(cls, drive, pp): return { 'simkey': cls.simkey, 'neuron': cls.name, 'nature': pp.nature, **drive.filecodes, **pp.filecodes } @classmethod def normalizedQm(cls, Qm): ''' Compute membrane charge density normalized by resting capacitance. :param Qm: membrane charge density (Q/m2) :return: normalized charge density (mV) ''' return Qm / cls.Cm0 * 1e3 @classmethod def getPltVars(cls, wl='df["', wr='"]'): pltvars = { 'Qm': { 'desc': 'membrane charge density', 'label': 'Q_m', 'unit': 'nC/cm^2', 'factor': 1e5, 'bounds': ((cls.Vm0 - 20.0) * cls.Cm0 * 1e2, 60) }, 'Qm/Cm0': { 'desc': 'membrane charge density over resting capacitance', 'label': 'Q_m / C_{m0}', 'unit': 'mV', 'bounds': (-150, 70), - 'func': f"normalizedQm({wl}Qm{wr})" + 'func': f"normalizedQm({wl}Qm{wr})", + 'factor': 1e3 / cls.Cm0 }, 'Vm': { 'desc': 'membrane potential', 'label': 'V_m', 'unit': 'mV', 'bounds': (-150, 70) }, 'ELeak': { 'constant': 'obj.ELeak', 'desc': 'non-specific leakage current resting potential', 'label': 'V_{leak}', 'unit': 'mV', 'ls': '--', 'color': 'k' } } for cname in cls.getCurrentsNames(): cfunc = getattr(cls, cname) cargs = inspect.getargspec(cfunc)[0][1:] pltvars[cname] = { 'desc': inspect.getdoc(cfunc).splitlines()[0], 'label': f'I_{{{cname[1:]}}}', 'unit': 'A/m^2', 'factor': 1e-3, 'func': f"{cname}({', '.join([f'{wl}{a}{wr}' for a in cargs])})" } for var in cargs: if var != 'Vm': pltvars[var] = { 'desc': cls.states[var], 'label': var, 'bounds': (-0.1, 1.1) } pltvars['iNet'] = { 'desc': inspect.getdoc(getattr(cls, 'iNet')).splitlines()[0], 'label': 'I_{net}', 'unit': 'A/m^2', 'factor': 1e-3, 'func': f'iNet({wl}Vm{wr}, {wl[:-1]}{cls.statesNames()}{wr[1:]})', 'ls': '--', 'color': 'black' } pltvars['dQdt'] = { 'desc': inspect.getdoc(getattr(cls, 'dQdt')).splitlines()[0], 'label': 'dQ_m/dt', 'unit': 'A/m^2', 'factor': 1e-3, 'func': f'dQdt({wl}t{wr}, {wl}Qm{wr})', 'ls': '--', 'color': 'black' } pltvars['iax'] = { 'desc': inspect.getdoc(getattr(cls, 'iax')).splitlines()[0], 'label': 'i_{ax}', 'unit': 'A/m^2', 'factor': 1e-3, # 'func': f'iax({wl}t{wr}, {wl}Qm{wr}, {wl}Vm{wr}, {wl[:-1]}{cls.statesNames()}{wr[1:]})', 'ls': '--', 'color': 'black', # 'bounds': (-1e2, 1e2) } pltvars['iCap'] = { 'desc': inspect.getdoc(getattr(cls, 'iCap')).splitlines()[0], 'label': 'I_{cap}', 'unit': 'A/m^2', 'factor': 1e-3, 'func': f'iCap({wl}t{wr}, {wl}Vm{wr})' } for rate in cls.rates: if 'alpha' in rate: prefix, suffix = 'alpha', rate[5:] else: prefix, suffix = 'beta', rate[4:] pltvars[rate] = { 'label': '\\{}_{{{}}}'.format(prefix, suffix), 'unit': 'ms^{-1}', 'factor': 1e-3 } pltvars['FR'] = { 'desc': 'riring rate', 'label': 'FR', 'unit': 'Hz', 'factor': 1e0, # 'bounds': (0, 1e3), 'func': f'firingRateProfile({wl[:-2]})' } return pltvars @classmethod def iCap(cls, t, Vm): ''' Capacitive current. ''' dVdt = np.insert(np.diff(Vm) / np.diff(t), 0, 0.) return cls.Cm0 * dVdt @property def pltScheme(self): pltscheme = { 'Q_m': ['Qm'], 'V_m': ['Vm'] } pltscheme['I'] = self.getCurrentsNames() + ['iNet'] for cname in self.getCurrentsNames(): if 'Leak' not in cname: key = f'i_{{{cname[1:]}}}\ kin.' cargs = inspect.getargspec(getattr(self, cname))[0][1:] pltscheme[key] = [var for var in cargs if var not in ['Vm', 'Cai']] return pltscheme @classmethod def statesNames(cls): ''' Return a list of names of all state variables of the model. ''' return list(cls.states.keys()) @classmethod @abc.abstractmethod def derStates(cls): ''' Dictionary of states derivatives functions ''' raise NotImplementedError @classmethod def getDerStates(cls, Vm, states): ''' Compute states derivatives array given a membrane potential and states dictionary ''' return np.array([cls.derStates()[k](Vm, states) for k in cls.statesNames()]) @classmethod @abc.abstractmethod def steadyStates(cls): ''' Return a dictionary of steady-states functions ''' raise NotImplementedError @classmethod def getSteadyStates(cls, Vm): ''' Compute array of steady-states for a given membrane potential ''' return np.array([cls.steadyStates()[k](Vm) for k in cls.statesNames()]) @classmethod def getDerEffStates(cls, lkp, states): ''' Compute effective states derivatives array given lookups and states dictionaries. ''' return np.array([cls.derEffStates()[k](lkp, states) for k in cls.statesNames()]) @classmethod def getEffRates(cls, Vm): ''' Compute array of effective rate constants for a given membrane potential vector. ''' return {k: np.mean(np.vectorize(v)(Vm)) for k, v in cls.effRates().items()} def getLookup(self): ''' Get lookup of membrane potential rate constants interpolated along the neuron's charge physiological range. ''' logger.debug(f'generating {self} baseline lookup') Qmin, Qmax = expandRange(*self.Qbounds, exp_factor=10.) Qref = np.arange(Qmin, Qmax, 1e-5) # C/m2 Vref = Qref / self.Cm0 * 1e3 # mV tables = {k: np.vectorize(v)(Vref) for k, v in self.effRates().items()} return EffectiveVariablesLookup({'Q': Qref}, {'V': Vref, **tables}) @classmethod @abc.abstractmethod def currents(cls): ''' Dictionary of ionic currents functions (returning current densities in mA/m2) ''' @classmethod def iNet(cls, Vm, states): ''' net membrane current :param Vm: membrane potential (mV) :param states: states of ion channels gating and related variables :return: current per unit area (mA/m2) ''' return sum([cfunc(Vm, states) for cfunc in cls.currents().values()]) @classmethod def dQdt(cls, t, Qm, pad='right'): ''' membrane charge density variation rate :param t: time vector (s) :param Qm: membrane charge density vector (C/m2) :return: variation rate vector (mA/m2) ''' dQdt = np.diff(Qm) / np.diff(t) * 1e3 # mA/m2 return {'left': padleft, 'right': padright}[pad](dQdt) @classmethod def iax(cls, t, Qm, Vm, states): ''' axial current density (computed as sum of charge variation and net membrane ionic current) :param t: time vector (s) :param Qm: membrane charge density vector (C/m2) :param Vm: membrane potential (mV) :param states: states of ion channels gating and related variables :return: axial current density (mA/m2) ''' return cls.iNet(Vm, states) + cls.dQdt(t, Qm) @classmethod def titrationFunc(cls, *args, **kwargs): ''' Default titration function. ''' return cls.isExcited(*args, **kwargs) @staticmethod def currentToConcentrationRate(z_ion, depth): ''' Compute the conversion factor from a specific ionic current (in mA/m2) into a variation rate of submembrane ion concentration (in M/s). :param: z_ion: ion valence :param depth: submembrane depth (m) :return: conversion factor (Mmol.m-1.C-1) ''' return 1e-6 / (z_ion * depth * FARADAY) @staticmethod def nernst(z_ion, Cion_in, Cion_out, T): ''' Nernst potential of a specific ion given its intra and extracellular concentrations. :param z_ion: ion valence :param Cion_in: intracellular ion concentration :param Cion_out: extracellular ion concentration :param T: temperature (K) :return: ion Nernst potential (mV) ''' return (Rg * T) / (z_ion * FARADAY) * np.log(Cion_out / Cion_in) * 1e3 @staticmethod def vtrap(x, y): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x / y) - 1) @staticmethod def efun(x): ''' Generic function used to compute rate constants. ''' return x / (np.exp(x) - 1) @classmethod def ghkDrive(cls, Vm, Z_ion, Cion_in, Cion_out, T): ''' Use the Goldman-Hodgkin-Katz equation to compute the electrochemical driving force of a specific ion species for a given membrane potential. :param Vm: membrane potential (mV) :param Cin: intracellular ion concentration (M) :param Cout: extracellular ion concentration (M) :param T: temperature (K) :return: electrochemical driving force of a single ion particle (mC.m-3) ''' x = Z_ion * FARADAY * Vm / (Rg * T) * 1e-3 # [-] eCin = Cion_in * cls.efun(-x) # M eCout = Cion_out * cls.efun(x) # M return FARADAY * (eCin - eCout) * 1e6 # mC/m3 @classmethod def xBG(cls, Vref, Vm): ''' Compute dimensionless Borg-Graham ratio for a given voltage. :param Vref: reference voltage membrane (mV) :param Vm: membrane potential (mV) :return: dimensionless ratio ''' return (Vm - Vref) * FARADAY / (Rg * cls.T) * 1e-3 # [-] @classmethod def alphaBG(cls, alpha0, zeta, gamma, Vref, Vm): ''' Compute the activation rate constant for a given voltage and temperature, using a Borg-Graham formalism. :param alpha0: pre-exponential multiplying factor :param zeta: effective valence of the gating particle :param gamma: normalized position of the transition state within the membrane :param Vref: membrane voltage at which alpha = alpha0 (mV) :param Vm: membrane potential (mV) :return: rate constant (in alpha0 units) ''' return alpha0 * np.exp(-zeta * gamma * cls.xBG(Vref, Vm)) @classmethod def betaBG(cls, beta0, zeta, gamma, Vref, Vm): ''' Compute the inactivation rate constant for a given voltage and temperature, using a Borg-Graham formalism. :param beta0: pre-exponential multiplying factor :param zeta: effective valence of the gating particle :param gamma: normalized position of the transition state within the membrane :param Vref: membrane voltage at which beta = beta0 (mV) :param Vm: membrane potential (mV) :return: rate constant (in beta0 units) ''' return beta0 * np.exp(zeta * (1 - gamma) * cls.xBG(Vref, Vm)) @classmethod def getCurrentsNames(cls): return list(cls.currents().keys()) @staticmethod def firingRateProfile(*args, **kwargs): return computeFRProfile(*args, **kwargs) @property def Qbounds(self): ''' Determine bounds of membrane charge physiological range for a given neuron. ''' return np.array([np.round(self.Vm0 - 25.0), 50.0]) * self.Cm0 * 1e-3 # C/m2 @classmethod def isVoltageGated(cls, state): ''' Determine whether a given state is purely voltage-gated or not.''' return f'alpha{state.lower()}' in cls.rates @classmethod @Model.checkOutputDir def simQueue(cls, amps, durations, offsets, PRFs, DCs, **kwargs): ''' Create a serialized 2D array of all parameter combinations for a series of individual parameter sweeps. :param amps: list (or 1D-array) of acoustic amplitudes :param durations: list (or 1D-array) of stimulus durations :param offsets: list (or 1D-array) of stimulus offsets (paired with durations array) :param PRFs: list (or 1D-array) of pulse-repetition frequencies :param DCs: list (or 1D-array) of duty cycle values :return: list of parameters (list) for each simulation ''' if amps is None: amps = [None] drives = ElectricDrive.createQueue(amps) protocols = PulsedProtocol.createQueue(durations, offsets, PRFs, DCs) queue = [] for drive in drives: for pp in protocols: queue.append([drive, pp]) return queue @classmethod @Model.checkOutputDir def simQueueBurst(cls, amps, durations, PRFs, DCs, BRFs, nbursts, **kwargs): if amps is None: amps = [None] drives = ElectricDrive.createQueue(amps) protocols = BurstProtocol.createQueue(durations, PRFs, DCs, BRFs, nbursts) queue = [] for drive in drives: for pp in protocols: queue.append([drive, pp]) return queue @staticmethod def checkInputs(drive, pp): ''' Check validity of electrical stimulation parameters. :param drive: electric drive object :param pp: pulse protocol object ''' if not isinstance(drive, Drive): raise TypeError(f'Invalid "drive" parameter (must be an "Drive" object)') if not isinstance(pp, TimeProtocol): raise TypeError('Invalid time protocol (must be "TimeProtocol" instance)') def chooseTimeStep(self): ''' Determine integration time step based on intrinsic temporal properties. ''' return DT_EFFECTIVE @classmethod def derivatives(cls, t, y, Cm=None, drive=None): ''' Compute system derivatives for a given membrane capacitance and injected current. :param t: specific instant in time (s) :param y: vector of HH system variables at time t :param Cm: membrane capacitance (F/m2) :param Iinj: injected current (mA/m2) :return: vector of system derivatives at time t ''' if Cm is None: Cm = cls.Cm0 Qm, *states = y Vm = Qm / Cm * 1e3 # mV states_dict = dict(zip(cls.statesNames(), states)) dQmdt = - cls.iNet(Vm, states_dict) # mA/m2 if drive is not None: dQmdt += drive.compute(t) dQmdt *= 1e-3 # A/m2 # dQmdt = (Iinj - cls.iNet(Vm, states_dict)) * 1e-3 # A/m2 return [dQmdt, *cls.getDerStates(Vm, states_dict)] @Model.logNSpikes @Model.checkTitrate @Model.addMeta @Model.logDesc @Model.checkSimParams def simulate(self, drive, pp): ''' Simulate a specific neuron model for a set of simulation parameters, and return output data in a dataframe. :param drive: electric drive object :param pp: pulse protocol object :return: output DataFrame ''' # Set initial conditions y0 = { 'Qm': self.Qm0, **{k: self.steadyStates()[k](self.Vm0) for k in self.statesNames()} } # Initialize solver and compute solution solver = EventDrivenSolver( lambda x: setattr(solver.drive, 'xvar', drive.xvar * x), # eventfunc y0.keys(), # variables lambda t, y: self.derivatives(t, y, drive=solver.drive), # dfunc event_params={'drive': drive.copy().updatedX(0.)}, # event parameters dt=self.chooseTimeStep()) # time step data = solver(y0, pp.stimEvents(), pp.tstop) # Add Vm timeries to solution data = addColumn(data, 'Vm', data['Qm'].values / self.Cm0 * 1e3, preceding_key='Qm') # Return solution dataframe return data def desc(self, meta): return f'{self}: simulation @ {meta["drive"].desc}, {meta["pp"].desc}' @staticmethod def getNSpikes(data): ''' Compute number of spikes in charge profile of simulation output. :param data: dataframe containing output time series :return: number of detected spikes ''' return detectSpikes(data)[0].size @staticmethod def getStabilizationValue(data): ''' Determine stabilization value from the charge profile of a simulation output. :param data: dataframe containing output time series :return: charge stabilization value (or np.nan if no stabilization detected) ''' # Extract charge signal posterior to observation window t, Qm = [data[key].values for key in ['t', 'Qm']] if t.max() <= TMIN_STABILIZATION: raise ValueError('solution length is too short to assess stabilization') Qm = Qm[t > TMIN_STABILIZATION] # Compute variation range Qm_range = np.ptp(Qm) logger.debug('%.2f nC/cm2 variation range over the last %.0f ms, Qmf = %.2f nC/cm2', Qm_range * 1e5, TMIN_STABILIZATION * 1e3, Qm[-1] * 1e5) # Return final value only if stabilization is detected if np.ptp(Qm) < QSS_Q_DIV_THR: return Qm[-1] else: return np.nan @classmethod def isExcited(cls, data): ''' Determine if neuron is excited from simulation output. :param data: dataframe containing output time series :return: boolean stating whether neuron is excited or not ''' return cls.getNSpikes(data) > 0 @classmethod def isSilenced(cls, data): ''' Determine if neuron is silenced from simulation output. :param data: dataframe containing output time series :return: boolean stating whether neuron is silenced or not ''' return not np.isnan(cls.getStabilizationValue(data)) def getArange(self, drive): return drive.xvar_range diff --git a/PySONIC/multicomp/benchmarks.py b/PySONIC/multicomp/benchmarks.py index 15a1a08..e3d0911 100644 --- a/PySONIC/multicomp/benchmarks.py +++ b/PySONIC/multicomp/benchmarks.py @@ -1,347 +1,353 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2021-05-14 19:42:00 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-05-26 18:28:44 +# @Last Modified time: 2021-05-27 20:38:11 import os import numpy as np import matplotlib.pyplot as plt from ..core import NeuronalBilayerSonophore, PulsedProtocol, Batch from ..core.drives import AcousticDrive, AcousticDriveArray from ..utils import si_format, rmse, rescale from ..neurons import passiveNeuron from ..postpro import gamma from ..plt import harmonizeAxesLimits, hideSpines, hideTicks, addYscale, addXscale from .coupled_nbls import CoupledSonophores class Benchmark: def __init__(self, a, nnodes, outdir=None): self.a = a self.nnodes = nnodes self.outdir = outdir if not os.path.isdir(self.outdir): os.mkdir(self.outdir) def pdict(self): return { 'a': f'{self.a * 1e9:.0f} nm', 'nnodes': f'{self.nnodes} nodes', } def pstr(self): l = [] for k, v in self.pdict().items(): if k == 'nnodes': l.append(v) else: l.append(f'{k} = {v}') return ', '.join(l) def __repr__(self): return f'{self.__class__.__name__}({self.pstr()})' def code(self): s = self.__repr__() for k in ['/', '(', ',']: s = s.replace(k, '_') for k in ['=', ' ', ')']: s = s.replace(k, '') return s def runSims(self, model, drives, tstim, covs): ''' Run full and sonic simulations for a specific combination drives, pulsed protocol and coverage fractions, harmonize outputs and compute normalized charge density profiles. ''' Fdrive = drives[0].f assert all(x.f == Fdrive for x in drives), 'frequencies do not match' assert len(covs) == model.nnodes, 'coverages do not match model dimensions' assert len(drives) == model.nnodes, 'drives do not match model dimensions' # If not provided, compute stimulus duration from model passive properties min_ncycles = 10 ntaumax_conv = 5 if tstim is None: tstim = max(ntaumax_conv * model.taumax, min_ncycles / Fdrive) # Recast stimulus duration as finite multiple of acoustic period tstim = int(np.ceil(tstim * Fdrive)) / Fdrive # s # Pulsed protocol pp = PulsedProtocol(tstim, 0) # Simulate/Load with full and sonic methods data, meta = {}, {} for method in ['full', 'sonic']: data[method], meta[method] = model.simAndSave( drives, pp, covs, method, outdir=self.outdir, overwrite=False, minimize_output=True) # Cycle-average full solution and interpolate sonic solution along same time vector data['cycleavg'] = data['full'].cycleAveraged(1 / Fdrive) data['sonic'] = data['sonic'].interpolate(data['cycleavg'].time) # Compute normalized charge density profiles and add them to dataset for simkey, simdata in data.items(): for nodekey, nodedata in simdata.items(): nodedata['Qnorm'] = nodedata['Qm'] / model.refpneuron.Cm0 * 1e3 # mV # Return dataset return data, meta def getQnorms(self, data, k, cut_bounds=True): ''' Get node-specific list of cycle-averaged and sonic normalized charge vectors (with, by default, discarding of bounding artefact elements). ''' Qnorms = np.array([data[simkey][k]['Qnorm'].values for simkey in ['cycleavg', 'sonic']]) if cut_bounds: Qnorms = Qnorms[:, 1:-1] return Qnorms def computeGamma(self, data, *args): ''' Evaluate per-node gamma on charge density profiles. ''' gamma_dict = {} resolution = list(data['cycleavg'].values())[0].dt for k in data['cycleavg'].keys(): # Get normalized charge vectors (discarding 1st and last indexes) and compute gamma gamma_dict[k] = gamma(*self.getQnorms(data, k), *args, resolution) # Pad gamma with nan on each side to ensure size consistency with time vector gamma_dict[k] = np.pad(gamma_dict[k], 1, mode='constant', constant_values=(np.nan,)) return gamma_dict def computeRMSE(self, data): ''' Evaluate per-node RMSE on charge density profiles. ''' return {k: rmse(*self.getQnorms(data, k)) for k in data['cycleavg'].keys()} def computeSteadyStateDeviation(self, data): ''' Evaluate per-node steady-state absolute deviation on charge density profiles. ''' return {k: np.abs(np.squeeze(np.diff(self.getQnorms(data, k), axis=0)))[-1] for k in data['cycleavg'].keys()} def computeMeanNormDifference(self, data): ''' Evaluate per-node mean absolute difference on [0, 1] normalized charge profiles. ''' d = {} for k in data['cycleavg'].keys(): y = self.getQnorms(data, k) # Rescale signals linearly between 0 and 1 ynorms = np.array([rescale(yy) for yy in y]) # Compute absolute difference signal ydiff = np.squeeze(np.abs(np.diff(ynorms, axis=0))) # Compute mean absolute difference d[k] = np.mean(ydiff) * 1e2 # % return d def computeDivergence(self, data, eval_mode, *args): ''' Compute divergence according to given eval_mode. ''' if eval_mode == 'rmse': div_dict = self.computeRMSE(data) elif eval_mode == 'gamma': div_dict = {k: np.nanmax(v) for k, v in self.computeGamma(data, *args).items()} elif eval_mode == 'ss': div_dict = self.computeSteadyStateDeviation(data) elif eval_mode == 'normdiff': div_dict = self.computeMeanNormDifference(data) return max(div_dict.values()) def plotQnorm(self, ax, data): ''' Plot normalized charge density signals on an axis. ''' markers = {'full': '-', 'cycleavg': '--', 'sonic': '-'} alphas = {'full': 0.5, 'cycleavg': 1., 'sonic': 1.} + # tplt = TimeSeriesPlot.getTimePltVar('ms') + # yplt = self.model.refpneuron.getPltVars()['Qm/Cm0'] + # mode = 'details' for simkey, simdata in data.items(): for i, (nodekey, nodedata) in enumerate(simdata.items()): + c = f'C{i}' y = nodedata['Qnorm'].values y[-1] = y[-2] - ax.plot(nodedata.time * 1e3, y, markers[simkey], c=f'C{i}', + ax.plot(nodedata.time * 1e3, y, markers[simkey], c=c, alpha=alphas[simkey], label=f'{simkey} - {nodekey}', clip_on=False) + # if simkey == 'cycleavg': + # TimeSeriesPlot.materializeSpikes(ax, nodedata, tplt, yplt, c, mode) def plotMeanNormDifference(self, ax, data): for i, (k, nodedata) in enumerate(data['cycleavg'].items()): t = nodedata.time[1:-1] y = self.getQnorms(data, k) c = f'C{i}' ynorms = np.array([rescale(yy) for yy in y]) for yn, marker in zip(ynorms, ['--', '-']): ax.plot(t * 1e3, yn, marker, c=c, clip_on=False) ax.fill_between(t * 1e3, *ynorms, alpha=0.5) ydiff = np.squeeze(np.abs(np.diff(ynorms, axis=0))) ymean = np.mean(ydiff) ax.text(0.5, 0.3 * (i + 1), f'{ymean * 1e2:.2f}%', c=c, transform=ax.transAxes) def plotGamma(self, ax, data, *gamma_args): gamma_dict = self.computeGamma(data, *gamma_args) tplt = list(data['cycleavg'].values())[0].time * 1e3 for i, (nodekey, nodegamma) in enumerate(gamma_dict.items()): ax.plot(tplt, nodegamma, c=f'C{i}', label=nodekey, clip_on=False) ax.axhline(1, linestyle='--', c='k') def plotSignalsOver2DSpace(self, gridxkey, gridxvec, gridxunit, gridykey, gridyvec, gridyunit, results, pltfunc, yunit='', title=None, fs=10, flipud=True, fliplr=False): ''' Plot signals over 2D space. ''' # Create grid-like figure fig, axes = plt.subplots(gridxvec.size, gridyvec.size) # Re-arrange axes and labels if flipud/fliplr option is set supylabel_args = {} supxlabel_args = {'y': 1.0, 'va': 'top'} if flipud: axes = axes[::-1] supxlabel_args = {} if fliplr: axes = axes[:, ::-1] supylabel_args = {'x': 1.0, 'ha': 'right'} # Add title and general axes labels if title is not None: fig.suptitle(title, fontsize=fs + 2) fig.supxlabel(gridxkey, fontsize=fs + 2, **supxlabel_args) fig.supylabel(gridykey, fontsize=fs + 2, **supylabel_args) # Loop through the axes and plot results, while storing time ranges i = 0 tranges = [] for i, axrow in enumerate(axes): for j, ax in enumerate(axrow): hideSpines(ax) hideTicks(ax) ax.margins(0) if results[i, j] is not None: pltfunc(ax, results[i, j]) tranges.append(np.ptp(ax.get_xlim())) if len(np.unique(tranges)) > 1: # If more than one time range, add common x-scale to all axes tmin = min(tranges) for axrow in axes[::-1]: for ax in axrow: trans = (ax.transData + ax.transAxes.inverted()) xpoints = [trans.transform([x, 0])[0] for x in [0, tmin]] ax.plot(xpoints, [-0.05] * 2, c='k', lw=2, transform=ax.transAxes, clip_on=False) else: # Otherwise, add x-scale only to axis opposite to origin side = 'top' if flipud else 'bottom' addXscale(axes[-1, -1], 0, 0.05, unit='ms', fmt='.0f', fs=fs, side=side) # Harmonize y-limits across all axes, and add y-scale to axis opposite to origin harmonizeAxesLimits(axes, dim='y') side = 'left' if fliplr else 'right' if yunit is not None: addYscale(axes[-1, -1], 0.05, 0, unit=yunit, fmt='.0f', fs=fs, side=side) # Set labels for xvec and yvec values along the two figure grid dimensions for ax, x in zip(axes[0, :], gridxvec): ax.set_xlabel(f'{si_format(x)}{gridxunit}', labelpad=15, fontsize=fs + 2) if not flipud: ax.xaxis.set_label_position('top') for ax, y in zip(axes[:, 0], gridyvec): if fliplr: ax.yaxis.set_label_position('right') ax.set_ylabel(f'{si_format(y)}{gridyunit}', labelpad=15, fontsize=fs + 2) # Return figure object return fig class PassiveBenchmark(Benchmark): def __init__(self, a, nnodes, Cm0, ELeak, **kwargs): super().__init__(a, nnodes, **kwargs) self.Cm0 = Cm0 self.ELeak = ELeak def pdict(self): return { **super().pdict(), 'Cm0': f'{self.Cm0 * 1e2:.1f} uF/cm2', 'ELeak': f'{self.ELeak} mV', } def getModelAndRunSims(self, drives, covs, taum, tauax): ''' Create passive model for a combination of time constants. ''' gLeak = self.Cm0 / taum ga = self.Cm0 / tauax pneuron = passiveNeuron(self.Cm0, gLeak, self.ELeak) model = CoupledSonophores([ NeuronalBilayerSonophore(self.a, pneuron) for i in range(self.nnodes)], ga) return self.runSims(model, drives, None, covs) def runSimsOverTauSpace(self, drives, covs, taum_range, tauax_range, mpi=False): ''' Run simulations over 2D time constant space. ''' queue = [[drives, covs] + x for x in Batch.createQueue(taum_range, tauax_range)] batch = Batch(self.getModelAndRunSims, queue) # batch.printQueue(queue) output = batch.run(mpi=mpi) results = [x[0] for x in output] # removing meta return np.reshape(results, (taum_range.size, tauax_range.size)).T def plotSignalsOverTauSpace(self, taum_range, tauax_range, results, pltfunc=None, fs=10): if pltfunc is None: pltfunc = 'plotQnorm' yunit = {'plotQnorm': 'mV', 'plotMeanNormDifference': None}[pltfunc] title = pltfunc[4:] pltfunc = getattr(self, pltfunc) return self.plotSignalsOver2DSpace( 'taum', taum_range, 's', 'tauax', tauax_range, 's', results, pltfunc, title=title, yunit=yunit) class FiberBenchmark(Benchmark): def __init__(self, a, nnodes, pneuron, ga, **kwargs): super().__init__(a, nnodes, **kwargs) self.model = CoupledSonophores([ NeuronalBilayerSonophore(self.a, pneuron) for i in range(self.nnodes)], ga) def pdict(self): return { **super().pdict(), 'ga': self.model.gastr, 'pneuron': self.model.refpneuron, } def getModelAndRunSims(self, Fdrive, tstim, covs, A1, A2): ''' Create passive model for a combination of time constants. ''' drives = AcousticDriveArray([AcousticDrive(Fdrive, A1), AcousticDrive(Fdrive, A2)]) return self.runSims(self.model, drives, tstim, covs) def runSimsOverAmplitudeSpace(self, Fdrive, tstim, covs, A_range, mpi=False): ''' Run simulations over 2D time constant space. ''' # Generate 2D amplitudes meshgrid A_combs = np.meshgrid(A_range, A_range) # Set elements below main diagonal to NaN tril_idxs = np.tril_indices(A_range.size, -1) for x in A_combs: x[tril_idxs] = np.nan # Flatten the meshgrid and assemble into list of tuples A_combs = list(zip(*[x.flatten().tolist() for x in A_combs])) # Remove NaN elements A_combs = list(filter(lambda x: not any(np.isnan(xx) for xx in x), A_combs)) # Assemble queue queue = [[Fdrive, tstim, covs] + list(x) for x in A_combs] batch = Batch(self.getModelAndRunSims, queue) output = batch.run(mpi=mpi) results = [x[0] for x in output] # removing meta # Re-organize results into upper-triangle matrix new_results = np.empty((A_range.size, A_range.size), dtype=object) triu_idxs = np.triu_indices(A_range.size, 0) for *idx, res in zip(*triu_idxs, results): new_results[idx[0], idx[1]] = res return new_results def plotSignalsOverAmplitudeSpace(self, A_range, results, *gamma_args, fs=10): plt_gamma = len(gamma_args) > 0 title = 'gamma' if plt_gamma else 'Qm/Cm0' if plt_gamma: pltfunc = lambda *args: self.plotGamma(*args, *gamma_args) yunit = '' else: pltfunc = self.plotQnorm yunit = 'mV' return self.plotSignalsOver2DSpace( 'A1', A_range, 'Pa', 'A2', A_range, 'Pa', results, pltfunc, title=title, yunit=yunit) diff --git a/PySONIC/plt/timeseries.py b/PySONIC/plt/timeseries.py index 7a746aa..9b8e366 100644 --- a/PySONIC/plt/timeseries.py +++ b/PySONIC/plt/timeseries.py @@ -1,510 +1,511 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2018-09-25 16:18:45 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-05-23 12:01:55 +# @Last Modified time: 2021-05-27 20:35:42 import numpy as np import matplotlib.pyplot as plt from ..postpro import detectSpikes, convertPeaksProperties from ..utils import * from .pltutils import * class TimeSeriesPlot(GenericPlot): ''' Generic interface to build a plot displaying temporal profiles of model simulations. ''' @classmethod def setTimeLabel(cls, ax, tplt, fs): return super().setXLabel(ax, tplt, fs) @classmethod def setYLabel(cls, ax, yplt, fs, grouplabel=None): if grouplabel is not None: yplt['label'] = grouplabel return super().setYLabel(ax, yplt, fs) def checkInputs(self, *args, **kwargs): raise NotImplementedError @staticmethod def getStimStates(df): try: stimstate = df['stimstate'] except KeyError: stimstate = df['states'] return stimstate.values @classmethod def getStimPulses(cls, t, states): ''' Determine the onset and offset times of pulses from a stimulation vector. :param t: time vector (s). :param states: a vector of stimulation state (ON/OFF) at each instant in time. :return: list of 3-tuples start time, end time and value of each pulse. ''' # Compute states derivatives and identify transition indexes dstates = np.diff(states) itransitions = np.where(np.abs(dstates) > 0)[0] + 1 if states[0] != 0.: itransitions = np.hstack(([0], itransitions)) if states[-1] != 0: itransitions = np.hstack((itransitions, [t.size - 1])) pulses = list(zip(t[itransitions[:-1]], t[itransitions[1:]], states[itransitions[:-1]])) return list(filter(lambda x: x[2] != 0, pulses)) def addLegend(self, fig, ax, handles, labels, fs, color=None, ls=None): lh = ax.legend(handles, labels, loc=1, fontsize=fs, frameon=False) if color is not None: for l in lh.get_lines(): l.set_color(color) if ls: for l in lh.get_lines(): l.set_linestyle(ls) @classmethod def materializeSpikes(cls, ax, data, tplt, yplt, color, mode, add_to_legend=False): ispikes, properties = detectSpikes(data) t = data['t'].values Qm = data['Qm'].values + yfactor = yplt.get('factor', 1.0) if ispikes is not None: yoffset = 5 - ax.plot(t[ispikes] * tplt['factor'], Qm[ispikes] * yplt['factor'] + yoffset, + ax.plot(t[ispikes] * tplt['factor'], Qm[ispikes] * yfactor + yoffset, 'v', color=color, label='spikes' if add_to_legend else None) if mode == 'details': ileft = properties['left_bases'] iright = properties['right_bases'] properties = convertPeaksProperties(t, properties) - ax.plot(t[ileft] * tplt['factor'], Qm[ileft] * yplt['factor'] - 5, + ax.plot(t[ileft] * tplt['factor'], Qm[ileft] * yfactor - 5, '<', color=color, label='left-bases' if add_to_legend else None) - ax.plot(t[iright] * tplt['factor'], Qm[iright] * yplt['factor'] - 10, + ax.plot(t[iright] * tplt['factor'], Qm[iright] * yfactor - 10, '>', color=color, label='right-bases' if add_to_legend else None) ax.vlines( x=t[ispikes] * tplt['factor'], - ymin=(Qm[ispikes] - properties['prominences']) * yplt['factor'], - ymax=Qm[ispikes] * yplt['factor'], + ymin=(Qm[ispikes] - properties['prominences']) * yfactor, + ymax=Qm[ispikes] * yfactor, color=color, linestyles='dashed', label='prominences' if add_to_legend else '') ax.hlines( - y=properties['width_heights'] * yplt['factor'], + y=properties['width_heights'] * yfactor, xmin=properties['left_ips'] * tplt['factor'], xmax=properties['right_ips'] * tplt['factor'], color=color, linestyles='dotted', label='half-widths' if add_to_legend else '') return add_to_legend @staticmethod def prepareTime(t, tplt): if tplt['onset'] > 0.0: tonset = t.min() - 0.05 * np.ptp(t) t = np.insert(t, 0, tonset) return t * tplt['factor'] @staticmethod def getPatchesColors(x): if np.all([xx == x[0] for xx in x]): return ['#8A8A8A'] * len(x) else: xabsmax = np.abs(x).max() _, sm = setNormalizer(plt.get_cmap('RdGy'), (-xabsmax, xabsmax), 'lin') return [sm.to_rgba(xx) for xx in x] @classmethod def addPatches(cls, ax, pulses, tplt, color=None): tstart, tend, x = zip(*pulses) if color is None: colors = cls.getPatchesColors(x) else: colors = [color] * len(x) for i in range(len(pulses)): ax.axvspan(tstart[i] * tplt['factor'], tend[i] * tplt['factor'], edgecolor='none', facecolor=colors[i], alpha=0.2) @staticmethod def plotInset(inset_ax, inset, t, y, tplt, yplt, line, color, lw): inset_ax.plot(t, y, linewidth=lw, linestyle=line, color=color) return inset_ax @classmethod def addInsetPatches(cls, ax, inset_ax, inset, pulses, tplt, color): tstart, tend, x = [np.array([z]) for z in zip(*pulses)] tfactor = tplt['factor'] ybottom, ytop = ax.get_ylim() cond_start = np.logical_and(tstart > (inset['xlims'][0] / tfactor), tstart < (inset['xlims'][1] / tfactor)) cond_end = np.logical_and(tend > (inset['xlims'][0] / tfactor), tend < (inset['xlims'][1] / tfactor)) cond_glob = np.logical_and(tstart < (inset['xlims'][0] / tfactor), tend > (inset['xlims'][1] / tfactor)) cond_onoff = np.logical_or(cond_start, cond_end) cond = np.logical_or(cond_onoff, cond_glob) tstart, tend, x = [z[cond] for z in [tstart, tend, x]] colors = cls.getPatchesColors(x) npatches_inset = tstart.size for i in range(npatches_inset): inset_ax.add_patch(Rectangle( (tstart[i] * tfactor, ybottom), (tend[i] - tstart[i]) * tfactor, ytop - ybottom, color=colors[i], alpha=0.1)) class CompTimeSeries(ComparativePlot, TimeSeriesPlot): ''' Interface to build a comparative plot displaying profiles of a specific output variable across different model simulations. ''' def __init__(self, outputs, varname): ''' Constructor. :param outputs: list / generator of simulator outputs to be compared. :param varname: name of variable to extract and compare ''' ComparativePlot.__init__(self, outputs, varname) def checkPatches(self, patches): self.greypatch = False if patches == 'none': self.patchfunc = lambda _: False elif patches == 'all': self.patchfunc = lambda _: True elif patches == 'one': self.patchfunc = lambda j: True if j == 0 else False self.greypatch = True elif isinstance(patches, list): if not all(isinstance(p, bool) for p in patches): raise TypeError('Invalid patch sequence: all list items must be boolean typed') self.patchfunc = lambda j: patches[j] if len(patches) > j else False else: raise ValueError( 'Invalid patches: must be either "none", all", "one", or a boolean list') def checkInputs(self, labels, patches): self.checkLabels(labels) self.checkPatches(patches) @staticmethod def createBackBone(figsize): fig, ax = plt.subplots(figsize=figsize) ax.set_zorder(0) return fig, ax @classmethod def postProcess(cls, ax, tplt, yplt, fs, meta, prettify): cls.removeSpines(ax) if 'bounds' in yplt: ymin, ymax = ax.get_ylim() ax.set_ylim(min(ymin, yplt['bounds'][0]), max(ymax, yplt['bounds'][1])) elif 'strictbounds' in yplt: ax.set_ylim(*yplt['strictbounds']) cls.setTimeLabel(ax, tplt, fs) cls.setYLabel(ax, yplt, fs) if prettify: cls.prettify(ax, xticks=(0, meta['tstim'] * tplt['factor'])) cls.setTickLabelsFontSize(ax, fs) def render(self, figsize=(11, 4), fs=10, lw=2, labels=None, colors=None, lines=None, patches='one', inset=None, frequency=1, spikes='none', cmap=None, cscale='lin', trange=None, prettify=False): ''' Render plot. :param figsize: figure size (x, y) :param fs: labels fontsize :param lw: linewidth :param labels: list of labels to use in the legend :param colors: list of colors to use for each curve :param lines: list of linestyles :param patches: string indicating whether/how to mark stimulation periods with rectangular patches :param inset: string indicating whether/how to mark an inset zooming on a particular region of the graph :param frequency: frequency at which to plot samples :param spikes: string indicating how to show spikes ("none", "marks" or "details") :param cmap: color map to use for colobar-based comparison (if not None) :param cscale: color scale to use for colobar-based comparison :param trange: optional lower and upper bounds to time axis :return: figure handle ''' self.checkInputs(labels, patches) fcodes = [] fig, ax = self.createBackBone(figsize) if inset is not None: inset_ax = self.addInset(fig, ax, inset) # Loop through data files handles, comp_values, full_labels = [], [], [] tmin, tmax = np.inf, -np.inf for j, output in enumerate(self.outputs): color = f'C{j}' if colors is None else colors[j] line = '-' if lines is None else lines[j] patch = self.patchfunc(j) # Load data try: data, meta = self.getData(output, frequency, trange) except ValueError: continue if 'tcomp' in meta: meta.pop('tcomp') # Extract model model = self.getModel(meta) fcodes.append(model.filecode(meta)) # Add label to list full_labels.append(self.figtitle(model, meta)) # Check consistency of sim types and check differing inputs comp_values = self.checkConsistency(meta, comp_values) # Extract time and stim pulses t = data['t'].values stimstate = self.getStimStates(data) pulses = self.getStimPulses(t, stimstate) tplt = self.getTimePltVar(model.tscale) t = self.prepareTime(t, tplt) # Extract y-variable pltvars = model.getPltVars() if self.varname not in pltvars: pltvars_str = ', '.join([f'"{p}"' for p in pltvars.keys()]) raise KeyError( f'Unknown plot variable: "{self.varname}". Candidates are: {pltvars_str}') yplt = pltvars[self.varname] y = extractPltVar(model, yplt, data, meta, t.size, self.varname) # Plot time series handles.append(ax.plot(t, y, linewidth=lw, linestyle=line, color=color)[0]) # Optional: add spikes if self.varname == 'Qm' and spikes != 'none': self.materializeSpikes(ax, data, tplt, yplt, color, spikes) # Plot optional inset if inset is not None: inset_ax = self.plotInset( inset_ax, inset, t, y, tplt, yplt, lines[j], color, lw) # Add optional STIM-ON patches if patch: ybottom, ytop = ax.get_ylim() patchcolor = None if self.greypatch else handles[j].get_color() self.addPatches(ax, pulses, tplt, color=patchcolor) if inset is not None: self.addInsetPatches(ax, inset_ax, inset, pulses, tplt, patchcolor) tmin, tmax = min(tmin, t.min()), max(tmax, t.max()) # Get common label and add it as title common_label = self.getCommonLabel(full_labels.copy(), seps=':@,()') self.wraptitle(ax, common_label, fs=fs) # Get comp info if any if self.comp_ref_key is not None: self.comp_info = model.inputs().get(self.comp_ref_key, None) # Post-process figure self.postProcess(ax, tplt, yplt, fs, meta, prettify) ax.set_xlim(tmin, tmax) fig.tight_layout() # Materialize inset if any if inset is not None: self.materializeInset(ax, inset_ax, inset) # Add labels or colorbar legend if cmap is not None: if not self.is_unique_comp: raise ValueError('Colormap mode unavailable for multiple differing parameters') if self.comp_info is None: raise ValueError('Colormap mode unavailable for qualitative comparisons') self.addCmap( fig, cmap, handles, comp_values, self.comp_info, fs, prettify, zscale=cscale) else: comp_values, comp_labels = self.getCompLabels(comp_values) labels = self.chooseLabels(labels, comp_labels, full_labels) self.addLegend(fig, ax, handles, labels, fs) # Add window title based on common pattern common_fcode = self.getCommonLabel(fcodes.copy()) fig.canvas.manager.set_window_title(common_fcode) return fig class GroupedTimeSeries(TimeSeriesPlot): ''' Interface to build a plot displaying profiles of several output variables arranged into specific schemes. ''' def __init__(self, outputs, pltscheme=None): ''' Constructor. :param outputs: list / generator of simulation outputs. :param varname: name of variable to extract and compare ''' super().__init__(outputs) self.pltscheme = pltscheme @staticmethod def createBackBone(pltscheme): naxes = len(pltscheme) if naxes == 1: fig, ax = plt.subplots(figsize=(11, 4)) axes = [ax] else: fig, axes = plt.subplots(naxes, 1, figsize=(11, min(3 * naxes, 9))) return fig, axes @staticmethod def shareX(axes): for ax in axes[:-1]: ax.get_shared_x_axes().join(ax, axes[-1]) ax.set_xticklabels([]) @classmethod def postProcess(cls, axes, tplt, fs, meta, prettify): for ax in axes: cls.removeSpines(ax) if prettify: cls.prettify(ax, xticks=(0, meta['pp'].tstim * tplt['factor']), yfmt=None) cls.setTickLabelsFontSize(ax, fs) cls.shareX(axes) cls.setTimeLabel(axes[-1], tplt, fs) def render(self, fs=10, lw=2, labels=None, colors=None, lines=None, patches='one', save=False, outputdir=None, fig_ext='png', frequency=1, spikes='none', trange=None, prettify=False): ''' Render plot. :param fs: labels fontsize :param lw: linewidth :param labels: list of labels to use in the legend :param colors: list of colors to use for each curve :param lines: list of linestyles :param patches: boolean indicating whether to mark stimulation periods with rectangular patches :param save: boolean indicating whether or not to save the figure(s) :param outputdir: path to output directory in which to save figure(s) :param fig_ext: string indcating figure extension ("png", "pdf", ...) :param frequency: frequency at which to plot samples :param spikes: string indicating how to show spikes ("none", "marks" or "details") :param trange: optional lower and upper bounds to time axis :return: figure handle(s) ''' if colors is None: colors = plt.get_cmap('tab10').colors figs = [] for output in self.outputs: # Load data and extract model try: data, meta = self.getData(output, frequency, trange) except ValueError: continue model = self.getModel(meta) # Extract time and stim pulses t = data['t'].values stimstate = self.getStimStates(data) pulses = self.getStimPulses(t, stimstate) tplt = self.getTimePltVar(model.tscale) t = self.prepareTime(t, tplt) # Check plot scheme if provided, otherwise generate it pltvars = model.getPltVars() if self.pltscheme is not None: for key in list(sum(list(self.pltscheme.values()), [])): if key not in pltvars: raise KeyError(f'Unknown plot variable: "{key}"') pltscheme = self.pltscheme else: pltscheme = model.pltScheme # Create figure fig, axes = self.createBackBone(pltscheme) # Loop through each subgraph for ax, (grouplabel, keys) in zip(axes, pltscheme.items()): ax_legend_spikes = False # Extract variables to plot nvars = len(keys) ax_pltvars = [pltvars[k] for k in keys] if nvars == 1: ax_pltvars[0]['color'] = 'k' ax_pltvars[0]['ls'] = '-' # Plot time series icolor = 0 for yplt, name in zip(ax_pltvars, pltscheme[grouplabel]): color = yplt.get('color', colors[icolor]) y = extractPltVar(model, yplt, data, meta, t.size, name) ax.plot(t, y, yplt.get('ls', '-'), c=color, lw=lw, label='$\\rm {}$'.format(yplt["label"])) if 'color' not in yplt: icolor += 1 # Optional: add spikes if name == 'Qm' and spikes != 'none': ax_legend_spikes = self.materializeSpikes( ax, data, tplt, yplt, color, spikes, add_to_legend=True) # Set y-axis unit and bounds self.setYLabel(ax, ax_pltvars[0].copy(), fs, grouplabel=grouplabel) if 'bounds' in ax_pltvars[0]: ymin, ymax = ax.get_ylim() ax_min = min(ymin, *[ap['bounds'][0] for ap in ax_pltvars]) ax_max = max(ymax, *[ap['bounds'][1] for ap in ax_pltvars]) ax.set_ylim(ax_min, ax_max) # Add legend if nvars > 1 or 'gate' in ax_pltvars[0]['desc'] or ax_legend_spikes: ax.legend(fontsize=fs, loc=7, ncol=nvars // 4 + 1, frameon=False) # Set x-limits and add optional patches for ax in axes: ax.set_xlim(t.min(), t.max()) if patches != 'none': self.addPatches(ax, pulses, tplt) # Post-process figure self.postProcess(axes, tplt, fs, meta, prettify) self.wraptitle(axes[0], self.figtitle(model, meta), fs=fs) fig.tight_layout() fig.canvas.manager.set_window_title(model.filecode(meta)) # Save figure if needed (automatic or checked) if save: filecode = model.filecode(meta) if outputdir is None: raise ValueError('output directory not specified') plt_filename = f'{outputdir}/{filecode}.{fig_ext}' plt.savefig(plt_filename) logger.info(f'Saving figure as "{plt_filename}"') plt.close() figs.append(fig) return figs if __name__ == '__main__': # example of use filepaths = OpenFilesDialog('pkl')[0] comp_plot = CompTimeSeries(filepaths, 'Qm') fig = comp_plot.render( lines=['-', '--'], labels=['60 kPa', '80 kPa'], patches='one', colors=['r', 'g'], xticks=[0, 100], yticks=[-80, +50], inset={'xcoords': [5, 40], 'ycoords': [-35, 45], 'xlims': [57.5, 60.5], 'ylims': [10, 35]} ) scheme_plot = GroupedTimeSeries(filepaths) figs = scheme_plot.render() plt.show()