diff --git a/PySONIC/core/model.py b/PySONIC/core/model.py index 9dacc88..a8cf192 100644 --- a/PySONIC/core/model.py +++ b/PySONIC/core/model.py @@ -1,242 +1,242 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-08-03 11:53:04 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-08-14 17:37:41 +# @Last Modified time: 2019-08-16 19:14:50 import os from functools import wraps from inspect import signature, getdoc import pickle import abc import inspect import numpy as np from .batches import Batch -from ..utils import logger, loadData, timer, si_format, plural +from ..utils import logger, loadData, timer, si_format, plural, debug class Model(metaclass=abc.ABCMeta): ''' Generic model interface. ''' @property @abc.abstractmethod def tscale(self): ''' Relevant temporal scale of the model. ''' raise NotImplementedError @property @abc.abstractmethod def simkey(self): ''' Keyword used to characterize simulations made with the model. ''' raise NotImplementedError @property @abc.abstractmethod def __repr__(self): ''' String representation. ''' raise NotImplementedError def params(self): ''' Return a dictionary of all model parameters (class and instance attributes) ''' def toAvoid(p): return (p.startswith('__') and p.endswith('__')) or p.startswith('_abc_') class_attrs = inspect.getmembers(self.__class__, lambda a: not(inspect.isroutine(a))) inst_attrs = inspect.getmembers(self, lambda a: not(inspect.isroutine(a))) class_attrs = [a for a in class_attrs if not toAvoid(a[0])] inst_attrs = [a for a in inst_attrs if not toAvoid(a[0]) and a not in class_attrs] params_dict = {a[0]: a[1] for a in class_attrs + inst_attrs} return params_dict @classmethod def description(cls): return getdoc(cls).split('\n', 1)[0].strip() @staticmethod @abc.abstractmethod def inputs(): ''' Return an informative dictionary on input variables used to simulate the model. ''' raise NotImplementedError @property @abc.abstractmethod def filecodes(self, *args): ''' Return a dictionary of string-encoded inputs used for file naming. ''' raise NotImplementedError def filecode(self, *args): ''' Generate file code given a specific combination of model input parameters. ''' # If meta dictionary was passed, generate inputs list from it if len(args) == 1 and isinstance(args[0], dict): meta = args[0] meta.pop('tcomp', None) nparams = len(signature(self.meta).parameters) args = list(meta.values())[-nparams:] # Create file code by joining string-encoded inputs with underscores codes = self.filecodes(*args).values() return '_'.join([x for x in codes if x is not None]) @classmethod @abc.abstractmethod def getPltVars(self, *args, **kwargs): ''' Return a dictionary with information about all plot variables related to the model. ''' raise NotImplementedError @classmethod @abc.abstractmethod def getPltScheme(self): ''' Return a dictionary model plot variables grouped by context. ''' raise NotImplementedError @staticmethod def checkOutputDir(queue, outputdir): ''' Check if an outputdir is provided in input arguments, and if so, add it as the first element of each item in the returned queue. ''' if outputdir is not None: for item in queue: item.insert(0, outputdir) else: if len(queue) > 5: logger.warning('Running more than 5 simulations without file saving') return queue @classmethod def simQueue(cls, *args, outputdir=None): return cls.checkOutputDir(Batch.createQueue(*args), outputdir) @staticmethod @abc.abstractmethod def checkInputs(self, *args): ''' Check the validity of simulation input parameters. ''' raise NotImplementedError @property @abc.abstractmethod def derivatives(self, *args, **kwargs): ''' Compute ODE derivatives for a specific set of ODE states and external parameters. ''' raise NotImplementedError @property @abc.abstractmethod def simulate(self, *args, **kwargs): ''' Simulate the model's differential system for specific input parameters and return output data in a dataframe. ''' raise NotImplementedError @classmethod @abc.abstractmethod def meta(self, *args): ''' Return an informative dictionary about model and simulation parameters. ''' raise NotImplementedError @staticmethod def addMeta(simfunc): ''' Add an informative dictionary about model and simulation parameters to simulation output ''' @wraps(simfunc) def wrapper(self, *args, **kwargs): data, tcomp = timer(simfunc)(self, *args, **kwargs) logger.debug('completed in %ss', si_format(tcomp, 1)) # Add keyword arguments from simfunc signature if not provided - bound_args = inspect.signature(simfunc).bind(self, *args, **kwargs) + bound_args = signature(simfunc).bind(self, *args, **kwargs) bound_args.apply_defaults() target_args = dict(bound_args.arguments) # Try to retrieve meta information try: - meta_params = [target_args[k] for k in inspect.signature(self.meta).parameters.keys()] + meta_params = [target_args[k] for k in signature(self.meta).parameters.keys()] meta = self.meta(*meta_params) except KeyError: meta = {} # Add computation time to it meta['tcomp'] = tcomp # Return data with meta dict return data, meta return wrapper @staticmethod def logNSpikes(simfunc): ''' Log number of detected spikes on charge profile of simulation output. ''' @wraps(simfunc) def wrapper(self, *args, **kwargs): out = simfunc(self, *args, **kwargs) if out is None: return None data, meta = out nspikes = self.getNSpikes(data) logger.debug('{} spike{} detected'.format(nspikes, plural(nspikes))) return data, meta return wrapper @staticmethod def checkTitrate(argname): ''' If no None provided in the list of input parameters, perform a titration to find the threshold parameter and add it to the list. ''' def wrapper_with_args(simfunc): @wraps(simfunc) def wrapper(self, *args, **kwargs): # Get argument index from function signature - func_args = list(inspect.signature(simfunc).parameters.keys())[1:] + func_args = list(signature(simfunc).parameters.keys())[1:] iarg = func_args.index(argname) # If argument is None if args[iarg] is None: # Generate new args list without argument args = list(args) new_args = args.copy() del new_args[iarg] # Perform titration to find threshold argument value xthr = self.titrate(*new_args) if np.isnan(xthr): logger.error(f'Could not find threshold {argname}') return None # Re-insert it into arguments list args[iarg] = xthr # Execute simulation function return simfunc(self, *args, **kwargs) return wrapper return wrapper_with_args def simAndSave(self, outdir, *args): ''' Simulate the model and save the results in a specific output directory. ''' out = self.simulate(*args) if out is None: return None data, meta = out if None in args: args = list(args) iNone = next(i for i, arg in enumerate(args) if arg is None) args[iNone] = meta['Adrive'] fpath = '{}/{}.pkl'.format(outdir, self.filecode(*args)) with open(fpath, 'wb') as fh: pickle.dump({'meta': meta, 'data': data}, fh) logger.debug('simulation data exported to "%s"', fpath) return fpath def getOutput(self, outdir, *args): ''' Get simulation output data for a specific parameters combination, by looking for an output file into a specific directory. If a corresponding output file is not found in the specified directory, the model is first run and results are saved in the output file. ''' fpath = '{}/{}.pkl'.format(outdir, self.filecode(*args)) if not os.path.isfile(fpath): self.simAndSave(outdir, *args) return loadData(fpath) diff --git a/PySONIC/plt/pltutils.py b/PySONIC/plt/pltutils.py index fe9d8d7..da6d270 100644 --- a/PySONIC/plt/pltutils.py +++ b/PySONIC/plt/pltutils.py @@ -1,400 +1,409 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-08-21 14:33:36 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-08-15 11:50:16 +# @Last Modified time: 2019-08-16 20:08:57 ''' Useful functions to generate plots. ''' import re import numpy as np import pandas as pd import matplotlib from matplotlib.patches import Rectangle from matplotlib import cm, colors import matplotlib.pyplot as plt +from ..core import getModel from ..utils import logger, isIterable, loadData, rescale, swapFirstLetterCase from ..postpro import findPeaks from ..constants import SPIKE_MIN_DT, SPIKE_MIN_QAMP, SPIKE_MIN_QPROM # Matplotlib parameters matplotlib.rcParams['pdf.fonttype'] = 42 matplotlib.rcParams['ps.fonttype'] = 42 matplotlib.rcParams['font.family'] = 'arial' def figtitle(meta): ''' Return appropriate title based on simulation metadata. ''' if 'Cm0' in meta: return '{:.0f}nm radius BLS structure: MECH-STIM {:.0f}kHz, {:.2f}kPa, {:.1f}nC/cm2'.format( meta['a'] * 1e9, meta['Fdrive'] * 1e-3, meta['Adrive'] * 1e-3, meta['Qm'] * 1e5) else: if 'DC' in meta: if meta['DC'] < 1: wavetype = 'PW' suffix = ', {:.2f}Hz PRF, {:.0f}% DC'.format(meta['PRF'], meta['DC'] * 1e2) else: wavetype = 'CW' suffix = '' if 'Astim' in meta: return '{} neuron: {} E-STIM {:.2f}mA/m2, {:.0f}ms{}'.format( meta['neuron'], wavetype, meta['Astim'], meta['tstim'] * 1e3, suffix) else: return '{} neuron ({:.1f}nm): {} A-STIM {:.0f}kHz {:.2f}kPa, {:.0f}ms{} - {} model'.format( meta['neuron'], meta['a'] * 1e9, wavetype, meta['Fdrive'] * 1e-3, meta['Adrive'] * 1e-3, meta['tstim'] * 1e3, suffix, meta['method']) else: return '{} neuron: V-CLAMP {:.1f}-{:.1f}mV, {:.0f}ms'.format( meta['neuron'], meta['Vhold'], meta['Vstep'], meta['tstim'] * 1e3) def cm2inch(*tupl): inch = 2.54 if isinstance(tupl[0], tuple): return tuple(i / inch for i in tupl[0]) else: return tuple(i / inch for i in tupl) def extractPltVar(model, pltvar, df, meta=None, nsamples=0, name=''): if 'func' in pltvar: s = 'model.{}'.format(pltvar['func']) try: var = eval(s) except AttributeError: var = eval(s.replace('model', 'model.pneuron')) elif 'key' in pltvar: var = df[pltvar['key']] elif 'constant' in pltvar: var = eval(pltvar['constant']) * np.ones(nsamples) else: var = df[name] if isinstance(var, pd.Series): var = var.values var = var.copy() if var.size == nsamples - 1: var = np.insert(var, 0, var[0]) var *= pltvar.get('factor', 1) return var def setGrid(n, ncolmax=3): ''' Determine number of rows and columns in figure grid, based on number of variables to plot. ''' if n <= ncolmax: return (1, n) else: return ((n - 1) // ncolmax + 1, ncolmax) def setNormalizer(cmap, bounds, scale='lin'): norm = { 'lin': colors.Normalize, 'log': colors.LogNorm }[scale](*bounds) sm = cm.ScalarMappable(norm=norm, cmap=cmap) sm._A = [] return norm, sm class GenericPlot: def __init__(self, filepaths): ''' Constructor. :param filepaths: list of full paths to output data files to be compared ''' if not isIterable(filepaths): filepaths = [filepaths] self.filepaths = filepaths def __call__(self, *args, **kwargs): return self.render(*args, **kwargs) @staticmethod def getData(entry, frequency=1, trange=None): if entry is None: raise ValueError('non-existing data') if isinstance(entry, str): data, meta = loadData(entry, frequency) else: data, meta = entry data = data.iloc[::frequency] if trange is not None: tmin, tmax = trange data = data.loc[(data['t'] >= tmin) & (data['t'] <= tmax)] return data, meta def render(self, *args, **kwargs): return NotImplementedError @staticmethod def getSimType(fname): ''' Get sim type from filename. ''' mo = re.search('(^[A-Z]*)_(.*).pkl', fname) if not mo: raise ValueError('Could not find sim-key in filename: "{}"'.format(fname)) return mo.group(1) + @staticmethod + def getModel(*args, **kwargs): + return getModel(*args, **kwargs) + + @staticmethod + def figtitle(*args, **kwargs): + return figtitle(*args, **kwargs) + @staticmethod def getTimePltVar(tscale): ''' Return time plot variable for a given temporal scale. ''' return { 'desc': 'time', 'label': 'time', 'unit': tscale, 'factor': {'ms': 1e3, 'us': 1e6}[tscale], 'onset': {'ms': 1e-3, 'us': 1e-6}[tscale] } @staticmethod def createBackBone(*args, **kwargs): return NotImplementedError @staticmethod def prettify(ax, xticks=None, yticks=None, xfmt='{:.0f}', yfmt='{:+.0f}'): try: ticks = ax.get_ticks() ticks = (min(ticks), max(ticks)) ax.set_ticks(ticks) ax.set_ticklabels([xfmt.format(x) for x in ticks]) except AttributeError: if xticks is None: xticks = ax.get_xticks() xticks = (min(xticks), max(xticks)) if yticks is None: yticks = ax.get_yticks() yticks = (min(yticks), max(yticks)) ax.set_xticks(xticks) ax.set_yticks(yticks) if xfmt is not None: ax.set_xticklabels([xfmt.format(x) for x in xticks]) if yfmt is not None: ax.set_yticklabels([yfmt.format(y) for y in yticks]) @staticmethod def addInset(fig, ax, inset): ''' Create inset axis. ''' inset_ax = fig.add_axes(ax.get_position()) inset_ax.set_zorder(1) inset_ax.set_xlim(inset['xlims'][0], inset['xlims'][1]) inset_ax.set_ylim(inset['ylims'][0], inset['ylims'][1]) inset_ax.set_xticks([]) inset_ax.set_yticks([]) inset_ax.add_patch(Rectangle((inset['xlims'][0], inset['ylims'][0]), inset['xlims'][1] - inset['xlims'][0], inset['ylims'][1] - inset['ylims'][0], color='w')) return inset_ax @staticmethod def materializeInset(ax, inset_ax, inset): ''' Materialize inset with zoom boox. ''' # Re-position inset axis axpos = ax.get_position() left, right, = rescale(inset['xcoords'], ax.get_xlim()[0], ax.get_xlim()[1], axpos.x0, axpos.x0 + axpos.width) bottom, top, = rescale(inset['ycoords'], ax.get_ylim()[0], ax.get_ylim()[1], axpos.y0, axpos.y0 + axpos.height) inset_ax.set_position([left, bottom, right - left, top - bottom]) for i in inset_ax.spines.values(): i.set_linewidth(2) # Materialize inset target region with contour frame ax.plot(inset['xlims'], [inset['ylims'][0]] * 2, linestyle='-', color='k') ax.plot(inset['xlims'], [inset['ylims'][1]] * 2, linestyle='-', color='k') ax.plot([inset['xlims'][0]] * 2, inset['ylims'], linestyle='-', color='k') ax.plot([inset['xlims'][1]] * 2, inset['ylims'], linestyle='-', color='k') # Link target and inset with dashed lines if possible if inset['xcoords'][1] < inset['xlims'][0]: ax.plot([inset['xcoords'][1], inset['xlims'][0]], [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k') ax.plot([inset['xcoords'][1], inset['xlims'][0]], [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k') elif inset['xcoords'][0] > inset['xlims'][1]: ax.plot([inset['xcoords'][0], inset['xlims'][1]], [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k') ax.plot([inset['xcoords'][0], inset['xlims'][1]], [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k') else: logger.warning('Inset x-coordinates intersect with those of target region') def postProcess(self, *args, **kwargs): return NotImplementedError @staticmethod def removeSpines(ax): for item in ['top', 'right']: ax.spines[item].set_visible(False) @staticmethod def setXTicks(ax, xticks=None): if xticks is not None: ax.set_xticks(xticks) @staticmethod def setYTicks(ax, yticks=None): if yticks is not None: ax.set_yticks(yticks) @staticmethod def setTickLabelsFontSize(ax, fs): for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks(): tick.label.set_fontsize(fs) @staticmethod def setXLabel(ax, xplt, fs): ax.set_xlabel('$\\rm {}\ ({})$'.format(xplt['label'], xplt['unit']), fontsize=fs) @staticmethod def setYLabel(ax, yplt, fs): ax.set_ylabel('$\\rm {}\ ({})$'.format(yplt['label'], yplt.get('unit', '')), fontsize=fs) @classmethod def addCmap(cls, fig, cmap, handles, comp_values, comp_info, fs, prettify, zscale='lin'): # Create colormap and normalizer try: mymap = plt.get_cmap(cmap) except ValueError: mymap = plt.get_cmap(swapFirstLetterCase(cmap)) norm, sm = setNormalizer(mymap, (comp_values.min(), comp_values.max()), zscale) # Adjust line colors for lh, z in zip(handles, comp_values): if isIterable(lh): for item in lh: item.set_color(sm.to_rgba(z)) else: lh.set_color(sm.to_rgba(z)) # Add colorbar fig.subplots_adjust(left=0.1, right=0.8, bottom=0.15, top=0.95, hspace=0.5) cbarax = fig.add_axes([0.85, 0.15, 0.03, 0.8]) cbar = fig.colorbar(sm, cax=cbarax, orientation='vertical') cbarax.set_ylabel('$\\rm {}\ ({})$'.format( comp_info['desc'].replace(' ', '\ '), comp_info['unit']), fontsize=fs) if prettify: cls.prettify(cbar) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) @staticmethod def getSpikes(data, key='Qm', mph=SPIKE_MIN_QAMP, mpp=SPIKE_MIN_QPROM, mpt=SPIKE_MIN_DT): if key not in data: raise ValueError('charge profile not avilable in dataframe') t, y = [data[k].values for k in['t', key]] dt = t[2] - t[1] ipeaks, proms, widths, ihalfmaxbounds, ibounds = findPeaks( data[key].values, mph=mph, mpd=int(np.ceil(mpt / dt)), mpp=mpp) if ipeaks is None: return [None] * 4 widths *= dt indexes = np.arange(t.size) thalfmaxbounds = np.array([ np.interp(ihalfmaxbounds[:, i], indexes, t, left=np.nan, right=np.nan) for i in range(2) ]).T tbounds = np.array([ np.interp(ibounds[:, i], indexes, t, left=np.nan, right=np.nan) for i in range(2) ]).T return np.array(t[ipeaks]), np.array(y[ipeaks]), np.array(proms), thalfmaxbounds, tbounds class ComparativePlot(GenericPlot): def __init__(self, filepaths, varname): ''' Constructor. :param filepaths: list of full paths to output data files to be compared :param varname: name of variable to extract and compare ''' super().__init__(filepaths) self.varname = varname self.comp_ref_key = None self.meta_ref = None self.comp_info = None self.is_unique_comp = False def checkColors(self, colors): if colors is None: colors = ['C{}'.format(j) for j in range(len(self.filepaths))] return colors def checkLines(self, lines): if lines is None: lines = ['-'] * len(self.filepaths) return lines def checkLabels(self, labels): if labels is not None: if len(labels) != len(self.filepaths): raise ValueError( 'Invalid labels ({}): not matching number of compared files ({})'.format( len(labels), len(self.filepaths))) if not all(isinstance(x, str) for x in labels): raise TypeError('Invalid labels: must be string typed') def checkSimType(self, meta): ''' Check consistency of sim types across files. ''' if meta['simkey'] != self.meta_ref['simkey']: raise ValueError('Invalid comparison: different simulation types') def checkCompValues(self, meta, comp_values): ''' Check consistency of differing values across files. ''' differing = {k: meta[k] != self.meta_ref[k] for k in meta.keys()} if sum(differing.values()) > 1: logger.warning('More than one differing inputs') self.comp_ref_key = None return [] zkey = (list(differing.keys())[list(differing.values()).index(True)]) if self.comp_ref_key is None: self.comp_ref_key = zkey self.is_unique_comp = True comp_values.append(self.meta_ref[self.comp_ref_key]) comp_values.append(meta[self.comp_ref_key]) else: if zkey != self.comp_ref_key: logger.warning('inconsitent differing inputs') self.comp_ref_key = None return [] else: comp_values.append(meta[self.comp_ref_key]) return comp_values def checkConsistency(self, meta, comp_values): ''' Check consistency of sim types and check differing inputs. ''' if self.meta_ref is None: self.meta_ref = meta else: self.checkSimType(meta) comp_values = self.checkCompValues(meta, comp_values) if self.comp_ref_key is None: self.is_unique_comp = False return comp_values def getCompLabels(self, comp_values): if self.comp_info is not None: comp_values = np.array(comp_values) * self.comp_info.get('factor', 1) comp_labels = [ '$\\rm{} = {}\ {}$'.format(self.comp_info['label'], x, self.comp_info['unit']) for x in comp_values] else: comp_labels = comp_values return comp_values, comp_labels def chooseLabels(self, labels, comp_labels, full_labels): if labels is not None: return labels else: if self.is_unique_comp: return comp_labels else: return full_labels diff --git a/PySONIC/plt/timeseries.py b/PySONIC/plt/timeseries.py index bc78b0a..0e6c4fb 100644 --- a/PySONIC/plt/timeseries.py +++ b/PySONIC/plt/timeseries.py @@ -1,469 +1,468 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2018-09-25 16:18:45 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2019-08-14 17:34:35 +# @Last Modified time: 2019-08-16 20:09:11 import numpy as np import matplotlib.pyplot as plt -from ..core import getModel from ..utils import * from .pltutils import * class TimeSeriesPlot(GenericPlot): ''' Generic interface to build a plot displaying temporal profiles of model simulations. ''' @classmethod def setTimeLabel(cls, ax, tplt, fs): return super().setXLabel(ax, tplt, fs) @classmethod def setYLabel(cls, ax, yplt, fs, grouplabel=None): if grouplabel is not None: yplt['label'] = grouplabel return super().setYLabel(ax, yplt, fs) def checkInputs(self, *args, **kwargs): return NotImplementedError @staticmethod def getStimStates(df): try: stimstate = df['stimstate'] except KeyError: stimstate = df['states'] return stimstate.values @classmethod def getStimPulses(cls, t, states): ''' Determine the onset and offset times of pulses from a stimulation vector. :param t: time vector (s). :param states: a vector of stimulation state (ON/OFF) at each instant in time. :return: 3-tuple with number of patches, timing of STIM-ON an STIM-OFF instants. ''' # Compute states derivatives and identify bounds indexes of pulses dstates = np.diff(states) ipulse_on = np.where(dstates > 0.0)[0] + 1 ipulse_off = np.where(dstates < 0.0)[0] + 1 if ipulse_off.size < ipulse_on.size: ioff = t.size - 1 if ipulse_off.size == 0: ipulse_off = np.array([ioff]) else: ipulse_off = np.insert(ipulse_off, ipulse_off.size - 1, ioff) # Get time instants for pulses ON and OFF tpulse_on = t[ipulse_on] tpulse_off = t[ipulse_off] return tpulse_on, tpulse_off @staticmethod def addLegend(ax, handles, labels, fs, color=None, ls=None): lh = ax.legend(handles, labels, loc=1, fontsize=fs, frameon=False) if color is not None: for l in lh.get_lines(): l.set_color(color) if ls: for l in lh.get_lines(): l.set_linestyle(ls) @classmethod def materializeSpikes(cls, ax, data, tplt, yplt, color, mode, add_to_legend=False): tspikes, Qspikes, Qprominences, thalfmaxbounds, _ = cls.getSpikes(data) if tspikes is not None: ax.scatter(tspikes * tplt['factor'], Qspikes * yplt['factor'] + 10, color=color, label='spikes' if add_to_legend else None, marker='v') if mode == 'details': Qbottoms = Qspikes - Qprominences Qmiddles = Qspikes - 0.5 * Qprominences for i in range(len(tspikes)): ax.plot(np.array([tspikes[i]] * 2) * tplt['factor'], np.array([Qspikes[i], Qbottoms[i]]) * yplt['factor'], '--', color=color, label='prominences' if i == 0 and add_to_legend else '') ax.plot(thalfmaxbounds[i] * tplt['factor'], np.array([Qmiddles[i]] * 2) * yplt['factor'], '-.', color=color, label='widths' if i == 0 and add_to_legend else '') return add_to_legend @staticmethod def prepareTime(t, tplt): if tplt['onset'] > 0.0: t = np.insert(t, 0, -tplt['onset']) return t * tplt['factor'] @staticmethod def addPatches(ax, tpatch_on, tpatch_off, tplt, color='#8A8A8A'): for i in range(tpatch_on.size): ax.axvspan(tpatch_on[i] * tplt['factor'], tpatch_off[i] * tplt['factor'], edgecolor='none', facecolor=color, alpha=0.2) @staticmethod def plotInset(inset_ax, t, y, tplt, yplt, line, color, lw): inset_window = np.logical_and(t > (inset['xlims'][0] / tplt['factor']), t < (inset['xlims'][1] / tplt['factor'])) inset_ax.plot(t[inset_window] * tplt['factor'], y[inset_window] * yplt['factor'], linewidth=lw, linestyle=line, color=color) return inset_ax @staticmethod def addInsetPatches(ax, inset_ax, inset, tpatch_on, tpatch_off, tplt, color): ybottom, ytop = ax.get_ylim() cond_on = np.logical_and(tpatch_on > (inset['xlims'][0] / tfactor), tpatch_on < (inset['xlims'][1] / tfactor)) cond_off = np.logical_and(tpatch_off > (inset['xlims'][0] / tfactor), tpatch_off < (inset['xlims'][1] / tfactor)) cond_glob = np.logical_and(tpatch_on < (inset['xlims'][0] / tfactor), tpatch_off > (inset['xlims'][1] / tfactor)) cond_onoff = np.logical_or(cond_on, cond_off) cond = np.logical_or(cond_onoff, cond_glob) npatches_inset = np.sum(cond) for i in range(npatches_inset): inset_ax.add_patch(Rectangle((tpatch_on[cond][i] * tfactor, ybottom), (tpatch_off[cond][i] - tpatch_on[cond][i]) * tfactor, ytop - ybottom, color=color, alpha=0.1)) class CompTimeSeries(ComparativePlot, TimeSeriesPlot): ''' Interface to build a comparative plot displaying profiles of a specific output variable across different model simulations. ''' def __init__(self, filepaths, varname): ''' Constructor. :param filepaths: list of full paths to output data files to be compared :param varname: name of variable to extract and compare ''' ComparativePlot.__init__(self, filepaths, varname) def checkPatches(self, patches): greypatch = False if patches == 'none': patches = [False] * len(self.filepaths) elif patches == 'all': patches = [True] * len(self.filepaths) elif patches == 'one': patches = [True] + [False] * (len(self.filepaths) - 1) greypatch = True elif isinstance(patches, list): if len(patches) != len(self.filepaths): raise ValueError( 'Invalid patches ({}): not matching number of compared files ({})'.format( len(patches), len(self.filepaths))) if not all(isinstance(p, bool) for p in patches): raise TypeError('Invalid patch sequence: all list items must be boolean typed') else: raise ValueError( 'Invalid patches: must be either "none", all", "one", or a boolean list') return patches, greypatch def checkInputs(self, lines, labels, colors, patches): self.checkLabels(labels) lines = self.checkLines(lines) colors = self.checkColors(colors) patches, greypatch = self.checkPatches(patches) return lines, labels, colors, patches, greypatch @staticmethod def createBackBone(figsize): fig, ax = plt.subplots(figsize=figsize) ax.set_zorder(0) return fig, ax @classmethod def postProcess(cls, ax, tplt, yplt, fs, meta, prettify): cls.removeSpines(ax) if 'bounds' in yplt: ax.set_ylim(*yplt['bounds']) cls.setTimeLabel(ax, tplt, fs) cls.setYLabel(ax, yplt, fs) if prettify: cls.prettify(ax, xticks=(0, meta['tstim'] * tplt['factor'])) cls.setTickLabelsFontSize(ax, fs) def render(self, figsize=(11, 4), fs=10, lw=2, labels=None, colors=None, lines=None, patches='one', inset=None, frequency=1, spikes='none', cmap=None, cscale='lin', trange=None, prettify=False): ''' Render plot. :param figsize: figure size (x, y) :param fs: labels fontsize :param lw: linewidth :param labels: list of labels to use in the legend :param colors: list of colors to use for each curve :param lines: list of linestyles :param patches: string indicating whether/how to mark stimulation periods with rectangular patches :param inset: string indicating whether/how to mark an inset zooming on a particular region of the graph :param frequency: frequency at which to plot samples :param spikes: string indicating how to show spikes ("none", "marks" or "details") :param cmap: color map to use for colobar-based comparison (if not None) :param cscale: color scale to use for colobar-based comparison :param trange: optional lower and upper bounds to time axis :return: figure handle ''' lines, labels, colors, patches, greypatch = self.checkInputs( lines, labels, colors, patches) fig, ax = self.createBackBone(figsize) if inset is not None: inset_ax = self.addInset(fig, ax, inset) # Loop through data files handles, comp_values, full_labels = [], [], [] tmin, tmax = np.inf, -np.inf for j, filepath in enumerate(self.filepaths): # Load data try: data, meta = self.getData(filepath, frequency, trange) except ValueError as err: continue if 'tcomp' in meta: meta.pop('tcomp') - full_labels.append(figtitle(meta)) + full_labels.append(self.figtitle(meta)) # Extract model - model = getModel(meta) + model = self.getModel(meta) # Check consistency of sim types and check differing inputs comp_values = self.checkConsistency(meta, comp_values) # Extract time and stim pulses t = data['t'].values stimstate = self.getStimStates(data) tpatch_on, tpatch_off = self.getStimPulses(t, stimstate) tplt = self.getTimePltVar(model.tscale) t = self.prepareTime(t, tplt) # Extract y-variable pltvars = model.getPltVars() if self.varname not in pltvars: raise KeyError( 'Unknown plot variable: "{}". Possible plot variables are: {}'.format( self.varname, ', '.join(['"{}"'.format(p) for p in pltvars.keys()]))) yplt = pltvars[self.varname] y = extractPltVar(model, yplt, data, meta, t.size, self.varname) # Plot time series handles.append(ax.plot(t, y, linewidth=lw, linestyle=lines[j], color=colors[j])[0]) # Optional: add spikes if self.varname == 'Qm' and spikes != 'none': self.materializeSpikes(ax, data, tplt, yplt, colors[j], spikes) # Plot optional inset if inset is not None: inset_ax = self.plotInset(inset_ax, t, y, tplt, yplt, lines[j], colors[j], lw) # Add optional STIM-ON patches if patches[j]: ybottom, ytop = ax.get_ylim() color = '#8A8A8A' if greypatch else handles[j].get_color() self.addPatches(ax, tpatch_on, tpatch_off, tplt, color) if inset is not None: self.addInsetPatches(ax, inset_ax, inset, tpatch_on, tpatch_off, tplt, color) tmin, tmax = min(tmin, t.min()), max(tmax, t.max()) # Determine labels if self.comp_ref_key is not None: self.comp_info = model.inputs().get(self.comp_ref_key, None) comp_values, comp_labels = self.getCompLabels(comp_values) labels = self.chooseLabels(labels, comp_labels, full_labels) # Post-process figure self.postProcess(ax, tplt, yplt, fs, meta, prettify) ax.set_xlim(tmin, tmax) fig.tight_layout() if inset is not None: self.materializeInset(ax, inset_ax, inset) # Add labels or colorbar legend if cmap is not None: if not self.is_unique_comp: raise ValueError('Colormap mode unavailable for multiple differing parameters') if self.comp_info is None: raise ValueError('Colormap mode unavailable for qualitative comparisons') self.addCmap( fig, cmap, handles, comp_values, self.comp_info, fs, prettify, zscale=cscale) else: self.addLegend(ax, handles, labels, fs) return fig class GroupedTimeSeries(TimeSeriesPlot): ''' Interface to build a plot displaying profiles of several output variables arranged into specific schemes. ''' def __init__(self, filepaths, pltscheme=None): ''' Constructor. :param filepaths: list of full paths to output data files to be compared :param varname: name of variable to extract and compare ''' super().__init__(filepaths) self.pltscheme = pltscheme @staticmethod def createBackBone(pltscheme): naxes = len(pltscheme) if naxes == 1: fig, ax = plt.subplots(figsize=(11, 4)) axes = [ax] else: fig, axes = plt.subplots(naxes, 1, figsize=(11, min(3 * naxes, 9))) return fig, axes @classmethod def postProcess(cls, axes, tplt, fs, meta, prettify): for ax in axes: cls.removeSpines(ax) # if prettify: # cls.prettify(ax, xticks=(0, meta['tstim'] * tplt['factor']), yfmt=None) cls.setTickLabelsFontSize(ax, fs) for ax in axes[:-1]: ax.set_xticklabels([]) cls.setTimeLabel(axes[-1], tplt, fs) def render(self, fs=10, lw=2, labels=None, colors=None, lines=None, patches='one', save=False, outputdir=None, fig_ext='png', frequency=1, spikes='none', trange=None, prettify=False): ''' Render plot. :param fs: labels fontsize :param lw: linewidth :param labels: list of labels to use in the legend :param colors: list of colors to use for each curve :param lines: list of linestyles :param patches: boolean indicating whether to mark stimulation periods with rectangular patches :param save: boolean indicating whether or not to save the figure(s) :param outputdir: path to output directory in which to save figure(s) :param fig_ext: string indcating figure extension ("png", "pdf", ...) :param frequency: frequency at which to plot samples :param spikes: string indicating how to show spikes ("none", "marks" or "details") :param trange: optional lower and upper bounds to time axis :return: figure handle(s) ''' figs = [] for filepath in self.filepaths: # Load data and extract model try: data, meta = self.getData(filepath, frequency, trange) except ValueError as err: continue - model = getModel(meta) + model = self.getModel(meta) # Extract time and stim pulses t = data['t'].values stimstate = self.getStimStates(data) tpatch_on, tpatch_off = self.getStimPulses(t, stimstate) tplt = self.getTimePltVar(model.tscale) t = self.prepareTime(t, tplt) # Check plot scheme if provided, otherwise generate it pltvars = model.getPltVars() if self.pltscheme is not None: for key in list(sum(list(self.pltscheme.values()), [])): if key not in pltvars: raise KeyError('Unknown plot variable: "{}"'.format(key)) pltscheme = self.pltscheme else: pltscheme = model.getPltScheme() # Create figure fig, axes = self.createBackBone(pltscheme) # Loop through each subgraph for ax, (grouplabel, keys) in zip(axes, pltscheme.items()): ax_legend_spikes = False # Extract variables to plot nvars = len(keys) ax_pltvars = [pltvars[k] for k in keys] if nvars == 1: ax_pltvars[0]['color'] = 'k' ax_pltvars[0]['ls'] = '-' # Set y-axis unit and bounds self.setYLabel(ax, ax_pltvars[0].copy(), fs, grouplabel=grouplabel) if 'bounds' in ax_pltvars[0]: ax_min = min([ap['bounds'][0] for ap in ax_pltvars]) ax_max = max([ap['bounds'][1] for ap in ax_pltvars]) ax.set_ylim(ax_min, ax_max) # Plot time series icolor = 0 for yplt, name in zip(ax_pltvars, pltscheme[grouplabel]): color = yplt.get('color', 'C{}'.format(icolor)) y = extractPltVar(model, yplt, data, meta, t.size, name) ax.plot(t, y, yplt.get('ls', '-'), c=color, lw=lw, label='$\\rm {}$'.format(yplt['label'])) if 'color' not in yplt: icolor += 1 # Optional: add spikes if name == 'Qm' and spikes != 'none': ax_legend_spikes = self.materializeSpikes( ax, data, tplt, yplt, color, spikes, add_to_legend=True) # Add legend if nvars > 1 or 'gate' in ax_pltvars[0]['desc'] or ax_legend_spikes: ax.legend(fontsize=fs, loc=7, ncol=nvars // 4 + 1, frameon=False) # Set x-limits and add optional patches for ax in axes: ax.set_xlim(t.min(), t.max()) if patches != 'none': self.addPatches(ax, tpatch_on, tpatch_off, tplt) # Post-process figure self.postProcess(axes, tplt, fs, meta, prettify) - axes[0].set_title(figtitle(meta), fontsize=fs) + axes[0].set_title(self.figtitle(meta), fontsize=fs) fig.tight_layout() fig.canvas.set_window_title(model.filecode(meta)) # Save figure if needed (automatic or checked) if save: filecode = model.filecode(meta) if outputdir is None: outputdir = os.path.split(filepath)[0] plt_filename = '{}/{}.{}'.format(outputdir, filecode, fig_ext) plt.savefig(plt_filename) logger.info('Saving figure as "{}"'.format(plt_filename)) plt.close() figs.append(fig) return figs if __name__ == '__main__': # example of use filepaths = OpenFilesDialog('pkl')[0] comp_plot = CompTimeSeries(filepaths, 'Qm') fig = comp_plot.render( lines=['-', '--'], labels=['60 kPa', '80 kPa'], patches='one', colors=['r', 'g'], xticks=[0, 100], yticks=[-80, +50], inset={'xcoords': [5, 40], 'ycoords': [-35, 45], 'xlims': [57.5, 60.5], 'ylims': [10, 35]} ) scheme_plot = GroupedTimeSeries(filepaths) figs = scheme_plot.render() plt.show()