diff --git a/PySONIC/plt/pltutils.py b/PySONIC/plt/pltutils.py index 6618fdc..b74f77d 100644 --- a/PySONIC/plt/pltutils.py +++ b/PySONIC/plt/pltutils.py @@ -1,444 +1,423 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-08-21 14:33:36 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-04-29 12:29:01 +# @Last Modified time: 2020-05-01 22:36:58 ''' Useful functions to generate plots. ''' import re import numpy as np import pandas as pd import matplotlib from matplotlib.patches import Rectangle from matplotlib import cm, colors import matplotlib.pyplot as plt from ..core import getModel from ..utils import * # Matplotlib parameters matplotlib.rcParams['pdf.fonttype'] = 42 matplotlib.rcParams['ps.fonttype'] = 42 matplotlib.rcParams['font.family'] = 'arial' def cm2inch(*tupl): inch = 2.54 if isinstance(tupl[0], tuple): return tuple(i / inch for i in tupl[0]) else: return tuple(i / inch for i in tupl) def extractPltVar(model, pltvar, df, meta=None, nsamples=0, name=''): if 'func' in pltvar: s = pltvar['func'] if not s.startswith('meta'): s = f'model.{s}' try: var = eval(s) except AttributeError: var = eval(s.replace('model', 'model.pneuron')) elif 'key' in pltvar: var = df[pltvar['key']] elif 'constant' in pltvar: var = eval(pltvar['constant']) * np.ones(nsamples) else: var = df[name] if isinstance(var, pd.Series): var = var.values var = var.copy() if var.size == nsamples - 1: var = np.insert(var, 0, var[0]) var *= pltvar.get('factor', 1) return var def setGrid(n, ncolmax=3): ''' Determine number of rows and columns in figure grid, based on number of variables to plot. ''' if n <= ncolmax: return (1, n) else: return ((n - 1) // ncolmax + 1, ncolmax) def setNormalizer(cmap, bounds, scale='lin'): norm = { 'lin': colors.Normalize, 'log': colors.LogNorm }[scale](*bounds) sm = cm.ScalarMappable(norm=norm, cmap=cmap) sm._A = [] return norm, sm class GenericPlot: def __init__(self, outputs): ''' Constructor. :param outputs: list / generator of simulation outputs ''' try: iter(outputs) except TypeError: outputs = [outputs] self.outputs = outputs def __call__(self, *args, **kwargs): return self.render(*args, **kwargs) - @property - def nouts(self): - ''' Number of outputs. ''' - return sum(1 for _ in self.outputs) - def figtitle(self, model, meta): return model.desc(meta) @staticmethod def wraptitle(ax, title, maxwidth=120, sep=':', fs=10, y=1.0): if len(title) > maxwidth: title = '\n'.join(title.split(sep)) y = 0.94 h = ax.set_title(title, fontsize=fs) h.set_y(y) @staticmethod def getData(entry, frequency=1, trange=None): if entry is None: raise ValueError('non-existing data') if isinstance(entry, str): data, meta = loadData(entry, frequency) else: data, meta = entry data = data.iloc[::frequency] if trange is not None: tmin, tmax = trange data = data.loc[(data['t'] >= tmin) & (data['t'] <= tmax)] return data, meta def render(self, *args, **kwargs): raise NotImplementedError @staticmethod def getSimType(fname): ''' Get sim type from filename. ''' mo = re.search('(^[A-Z]*)_(.*).pkl', fname) if not mo: raise ValueError(f'Could not find sim-key in filename: "{fname}"') return mo.group(1) @staticmethod def getModel(*args, **kwargs): return getModel(*args, **kwargs) @staticmethod def getTimePltVar(tscale): ''' Return time plot variable for a given temporal scale. ''' return { 'desc': 'time', 'label': 'time', 'unit': tscale, 'factor': {'ms': 1e3, 'us': 1e6}[tscale], 'onset': {'ms': 1e-3, 'us': 1e-6}[tscale] } @staticmethod def createBackBone(*args, **kwargs): raise NotImplementedError @staticmethod def prettify(ax, xticks=None, yticks=None, xfmt='{:.0f}', yfmt='{:+.0f}'): try: ticks = ax.get_ticks() ticks = (min(ticks), max(ticks)) ax.set_ticks(ticks) ax.set_ticklabels([xfmt.format(x) for x in ticks]) except AttributeError: if xticks is None: xticks = ax.get_xticks() xticks = (min(xticks), max(xticks)) if yticks is None: yticks = ax.get_yticks() yticks = (min(yticks), max(yticks)) ax.set_xticks(xticks) ax.set_yticks(yticks) if xfmt is not None: ax.set_xticklabels([xfmt.format(x) for x in xticks]) if yfmt is not None: ax.set_yticklabels([yfmt.format(y) for y in yticks]) @staticmethod def addInset(fig, ax, inset): ''' Create inset axis. ''' inset_ax = fig.add_axes(ax.get_position()) inset_ax.set_zorder(1) inset_ax.set_xlim(inset['xlims'][0], inset['xlims'][1]) inset_ax.set_ylim(inset['ylims'][0], inset['ylims'][1]) inset_ax.set_xticks([]) inset_ax.set_yticks([]) inset_ax.add_patch(Rectangle((inset['xlims'][0], inset['ylims'][0]), inset['xlims'][1] - inset['xlims'][0], inset['ylims'][1] - inset['ylims'][0], color='w')) return inset_ax @staticmethod def materializeInset(ax, inset_ax, inset): ''' Materialize inset with zoom boox. ''' # Re-position inset axis axpos = ax.get_position() left, right, = rescale(inset['xcoords'], ax.get_xlim()[0], ax.get_xlim()[1], axpos.x0, axpos.x0 + axpos.width) bottom, top, = rescale(inset['ycoords'], ax.get_ylim()[0], ax.get_ylim()[1], axpos.y0, axpos.y0 + axpos.height) inset_ax.set_position([left, bottom, right - left, top - bottom]) for i in inset_ax.spines.values(): i.set_linewidth(2) # Materialize inset target region with contour frame ax.plot(inset['xlims'], [inset['ylims'][0]] * 2, linestyle='-', color='k') ax.plot(inset['xlims'], [inset['ylims'][1]] * 2, linestyle='-', color='k') ax.plot([inset['xlims'][0]] * 2, inset['ylims'], linestyle='-', color='k') ax.plot([inset['xlims'][1]] * 2, inset['ylims'], linestyle='-', color='k') # Link target and inset with dashed lines if possible if inset['xcoords'][1] < inset['xlims'][0]: ax.plot([inset['xcoords'][1], inset['xlims'][0]], [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k') ax.plot([inset['xcoords'][1], inset['xlims'][0]], [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k') elif inset['xcoords'][0] > inset['xlims'][1]: ax.plot([inset['xcoords'][0], inset['xlims'][1]], [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k') ax.plot([inset['xcoords'][0], inset['xlims'][1]], [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k') else: logger.warning('Inset x-coordinates intersect with those of target region') def postProcess(self, *args, **kwargs): raise NotImplementedError @staticmethod def removeSpines(ax): for item in ['top', 'right']: ax.spines[item].set_visible(False) @staticmethod def setXTicks(ax, xticks=None): if xticks is not None: ax.set_xticks(xticks) @staticmethod def setYTicks(ax, yticks=None): if yticks is not None: ax.set_yticks(yticks) @staticmethod def setTickLabelsFontSize(ax, fs): for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks(): tick.label.set_fontsize(fs) @staticmethod def setXLabel(ax, xplt, fs): ax.set_xlabel('$\\rm {}\ ({})$'.format(xplt["label"], xplt["unit"]), fontsize=fs) @staticmethod def setYLabel(ax, yplt, fs): ax.set_ylabel('$\\rm {}\ ({})$'.format(yplt["label"], yplt.get("unit", "")), fontsize=fs) @classmethod def addCmap(cls, fig, cmap, handles, comp_values, comp_info, fs, prettify, zscale='lin'): # Rescale comparison values and adjust unit comp_values = np.asarray(comp_values) * comp_info.get('factor', 1.) comp_factor, comp_prefix = getSIpair(comp_values, scale=zscale) comp_values /= comp_factor comp_info['unit'] = comp_prefix + comp_info['unit'] # Create colormap and normalizer try: mymap = plt.get_cmap(cmap) except ValueError: mymap = plt.get_cmap(swapFirstLetterCase(cmap)) norm, sm = setNormalizer(mymap, (comp_values.min(), comp_values.max()), zscale) # Adjust line colors for lh, z in zip(handles, comp_values): if isIterable(lh): for item in lh: item.set_color(sm.to_rgba(z)) else: lh.set_color(sm.to_rgba(z)) # Add colorbar fig.subplots_adjust(left=0.1, right=0.8, bottom=0.15, top=0.95, hspace=0.5) cbarax = fig.add_axes([0.85, 0.15, 0.03, 0.8]) cbar = fig.colorbar(sm, cax=cbarax, orientation='vertical') desc_str = comp_info["desc"].replace(" ", "\ ") cbarax.set_ylabel('$\\rm {}\ ({})$'.format(desc_str, comp_info["unit"]), fontsize=fs) if prettify: cls.prettify(cbar) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) class ComparativePlot(GenericPlot): def __init__(self, outputs, varname): ''' Constructor. :param outputs: list /generator of simulation outputs to be compared. :param varname: name of variable to extract and compare. ''' super().__init__(outputs) self.varname = varname self.comp_ref_key = None self.meta_ref = None self.comp_info = None self.is_unique_comp = False - @property - def ncomps(self): - return sum(1 for _ in self.outputs) - - def checkColors(self, colors): - if colors is None: - colors = [f'C{j}' for j in range(self.nouts)] - return colors - - def checkLines(self, lines): - if lines is None: - lines = ['-'] * self.nouts - return lines - def checkLabels(self, labels): if labels is not None: - nlabels, nfiles = len(labels), self.nouts - if nlabels != nfiles: - raise ValueError( - f'Invalid labels ({nlabels}): not matching number of compared files ({nfiles})') + if not isIterable(labels): + raise TypeError('Invalid labels: must be an iterable') if not all(isinstance(x, str) for x in labels): raise TypeError('Invalid labels: must be string typed') def checkSimType(self, meta): ''' Check consistency of sim types across files. ''' if meta['simkey'] != self.meta_ref['simkey']: raise ValueError('Invalid comparison: different simulation types') def checkCompValues(self, meta, comp_values): ''' Check consistency of differing values across files. ''' # Get differing values across meta dictionaries diffs = differing(self.meta_ref, meta, subdkey='meta') # Check that only one value differs if len(diffs) > 1: logger.warning('More than one differing inputs') self.comp_ref_key = None return [] # Get the key and differing values zkey, refval, val = diffs[0] # If no comparison key yet, fill it up if self.comp_ref_key is None: self.comp_ref_key = zkey self.is_unique_comp = True comp_values += [refval, val] # Otherwise, check that comparison matches the existing one else: if zkey != self.comp_ref_key: logger.warning('inconsistent differing inputs') self.comp_ref_key = None return [] else: comp_values.append(val) return comp_values def checkConsistency(self, meta, comp_values): ''' Check consistency of sim types and check differing inputs. ''' if self.meta_ref is None: self.meta_ref = meta else: self.checkSimType(meta) comp_values = self.checkCompValues(meta, comp_values) if self.comp_ref_key is None: self.is_unique_comp = False return comp_values def getCompLabels(self, comp_values): if self.comp_info is not None: comp_values = np.array(comp_values) * self.comp_info.get('factor', 1) if 'unit' in self.comp_info: p = self.comp_info.get('precision', 0) comp_values = [f"{si_format(v, p)}{self.comp_info['unit']}".replace(' ', '\ ') for v in comp_values] comp_labels = ['$\\rm{} = {}$'.format(self.comp_info['label'], x) for x in comp_values] else: comp_labels = comp_values return comp_values, comp_labels def chooseLabels(self, labels, comp_labels, full_labels): if labels is not None: return labels else: if self.is_unique_comp: return comp_labels else: return full_labels @staticmethod def getCommonLabel(lbls, seps='_'): ''' Get a common label from a list of labels, by removing parts that differ across them. ''' # Split every label according to list of separator characters, and save splitters as well splt_lbls = [re.split(f'([{seps}])', x) for x in lbls] pieces = [x[::2] for x in splt_lbls] splitters = [x[1::2] for x in splt_lbls] ncomps = len(pieces[0]) # Assert that splitters are equivalent across all labels, and reduce them to a single array assert (x == x[0] for x in splitters), 'Inconsistent splitters' splitters = np.array(splitters[0]) # Transform pieces into 2D matrix, and evaluate equality of every piece across labels pieces = np.array(pieces).T all_identical = [np.all(x == x[0]) for x in pieces] if np.sum(all_identical) < ncomps - 1: logger.warning('More than one differing inputs') return '' # Discard differing pieces and remove associated splitters pieces = pieces[all_identical, 0] splitters = splitters[all_identical[:-1]] # Remove last splitter if the last pieces were discarded if splitters.size == pieces.size: splitters = splitters[:-1] # Join common pieces and associated splitters into a single label common_lbl = '' for p, s in zip(pieces, splitters): common_lbl += f'{p}{s}' common_lbl += pieces[-1] return common_lbl.replace('( ', '(') def addExcitationInset(ax, is_excited): ''' Add text inset on axis stating excitation status. ''' ax.text( 0.7, 0.7, f'{"" if is_excited else "not "}excited', transform=ax.transAxes, ha='center', va='center', size=30, bbox=dict( boxstyle='round', fc=(0.8, 1.0, 0.8) if is_excited else (1., 0.8, 0.8) )) diff --git a/PySONIC/plt/timeseries.py b/PySONIC/plt/timeseries.py index c9082d1..449308d 100644 --- a/PySONIC/plt/timeseries.py +++ b/PySONIC/plt/timeseries.py @@ -1,511 +1,505 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2018-09-25 16:18:45 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-04-29 12:25:50 +# @Last Modified time: 2020-05-01 22:34:19 import numpy as np import matplotlib.pyplot as plt from ..postpro import detectSpikes, convertPeaksProperties from ..utils import * from .pltutils import * class TimeSeriesPlot(GenericPlot): ''' Generic interface to build a plot displaying temporal profiles of model simulations. ''' @classmethod def setTimeLabel(cls, ax, tplt, fs): return super().setXLabel(ax, tplt, fs) @classmethod def setYLabel(cls, ax, yplt, fs, grouplabel=None): if grouplabel is not None: yplt['label'] = grouplabel return super().setYLabel(ax, yplt, fs) def checkInputs(self, *args, **kwargs): raise NotImplementedError @staticmethod def getStimStates(df): try: stimstate = df['stimstate'] except KeyError: stimstate = df['states'] return stimstate.values @classmethod def getStimPulses(cls, t, states): ''' Determine the onset and offset times of pulses from a stimulation vector. :param t: time vector (s). :param states: a vector of stimulation state (ON/OFF) at each instant in time. :return: list of 3-tuples start time, end time and value of each pulse. ''' # Compute states derivatives and identify transition indexes dstates = np.diff(states) itransitions = np.where(np.abs(dstates) > 0)[0] + 1 if states[0] != 0.: itransitions = np.hstack(([0], itransitions)) if states[-1] != 0: itransitions = np.hstack((itransitions, [t.size - 1])) pulses = list(zip(t[itransitions[:-1]], t[itransitions[1:]], states[itransitions[:-1]])) return list(filter(lambda x: x[2] != 0, pulses)) def addLegend(self, fig, ax, handles, labels, fs, color=None, ls=None): lh = ax.legend(handles, labels, loc=1, fontsize=fs, frameon=False) if color is not None: for l in lh.get_lines(): l.set_color(color) if ls: for l in lh.get_lines(): l.set_linestyle(ls) @classmethod def materializeSpikes(cls, ax, data, tplt, yplt, color, mode, add_to_legend=False): ispikes, properties = detectSpikes(data) t = data['t'].values Qm = data['Qm'].values if ispikes is not None: yoffset = 5 ax.plot(t[ispikes] * tplt['factor'], Qm[ispikes] * yplt['factor'] + yoffset, 'v', color=color, label='spikes' if add_to_legend else None) if mode == 'details': ileft = properties['left_bases'] iright = properties['right_bases'] properties = convertPeaksProperties(t, properties) ax.plot(t[ileft] * tplt['factor'], Qm[ileft] * yplt['factor'] - 5, '<', color=color, label='left-bases' if add_to_legend else None) ax.plot(t[iright] * tplt['factor'], Qm[iright] * yplt['factor'] - 10, '>', color=color, label='right-bases' if add_to_legend else None) ax.vlines( x=t[ispikes] * tplt['factor'], ymin=(Qm[ispikes] - properties['prominences']) * yplt['factor'], ymax=Qm[ispikes] * yplt['factor'], color=color, linestyles='dashed', label='prominences' if add_to_legend else '') ax.hlines( y=properties['width_heights'] * yplt['factor'], xmin=properties['left_ips'] * tplt['factor'], xmax=properties['right_ips'] * tplt['factor'], color=color, linestyles='dotted', label='half-widths' if add_to_legend else '') return add_to_legend @staticmethod def prepareTime(t, tplt): if tplt['onset'] > 0.0: tonset = t.min() - 0.05 * np.ptp(t) t = np.insert(t, 0, tonset) return t * tplt['factor'] @staticmethod def getPatchesColors(x): if np.all([xx == x[0] for xx in x]): return ['#8A8A8A'] * len(x) else: xabsmax = np.abs(x).max() _, sm = setNormalizer(plt.get_cmap('RdGy'), (-xabsmax, xabsmax), 'lin') return [sm.to_rgba(xx) for xx in x] @classmethod def addPatches(cls, ax, pulses, tplt, color=None): tstart, tend, x = zip(*pulses) if color is None: colors = cls.getPatchesColors(x) else: colors = [color] * len(x) for i in range(len(pulses)): ax.axvspan(tstart[i] * tplt['factor'], tend[i] * tplt['factor'], edgecolor='none', facecolor=colors[i], alpha=0.2) @staticmethod def plotInset(inset_ax, inset, t, y, tplt, yplt, line, color, lw): inset_ax.plot(t, y, linewidth=lw, linestyle=line, color=color) return inset_ax @staticmethod def addInsetPatches(ax, inset_ax, inset, pulses, tplt, color): tstart, tend, x = [np.array([z]) for z in zip(*pulses)] # colors = cls.getPatchesColors(x) tfactor = tplt['factor'] ybottom, ytop = ax.get_ylim() cond_start = np.logical_and(tstart > (inset['xlims'][0] / tfactor), tstart < (inset['xlims'][1] / tfactor)) cond_end = np.logical_and(tend > (inset['xlims'][0] / tfactor), tend < (inset['xlims'][1] / tfactor)) cond_glob = np.logical_and(tstart < (inset['xlims'][0] / tfactor), tend > (inset['xlims'][1] / tfactor)) cond_onoff = np.logical_or(cond_start, cond_end) cond = np.logical_or(cond_onoff, cond_glob) tstart, tend, x = [z[cond] for z in [tstart, tend, x]] colors = cls.getPatchesColors(x) npatches_inset = tstart.size for i in range(npatches_inset): inset_ax.add_patch(Rectangle( (tstart[i] * tfactor, ybottom), (tend[i] - tstart[i]) * tfactor, ytop - ybottom, color=colors[i], alpha=0.1)) class CompTimeSeries(ComparativePlot, TimeSeriesPlot): ''' Interface to build a comparative plot displaying profiles of a specific output variable across different model simulations. ''' def __init__(self, outputs, varname): ''' Constructor. :param outputs: list / generator of simulator outputs to be compared. :param varname: name of variable to extract and compare ''' ComparativePlot.__init__(self, outputs, varname) def checkPatches(self, patches): - greypatch = False + self.greypatch = False if patches == 'none': - patches = [False] * self.nouts + self.patchfunc = lambda _: False elif patches == 'all': - patches = [True] * self.nouts + self.patchfunc = lambda _: True elif patches == 'one': - patches = [True] + [False] * (self.nouts - 1) - greypatch = True + self.patchfunc = lambda j: True if j == 0 else False + self.greypatch = True elif isinstance(patches, list): - npatches, nfiles = len(patches), self.nouts - if npatches != nfiles: - raise ValueError( - f'Invalid patches ({npatches}): not matching number of files ({nfiles})') if not all(isinstance(p, bool) for p in patches): raise TypeError('Invalid patch sequence: all list items must be boolean typed') + self.patchfunc = lambda j: patches[j] if len(patches) > j else False else: raise ValueError( 'Invalid patches: must be either "none", all", "one", or a boolean list') - return patches, greypatch - def checkInputs(self, lines, labels, colors, patches): + def checkInputs(self, labels, patches): self.checkLabels(labels) - lines = self.checkLines(lines) - colors = self.checkColors(colors) - patches, greypatch = self.checkPatches(patches) - return lines, labels, colors, patches, greypatch + self.checkPatches(patches) @staticmethod def createBackBone(figsize): fig, ax = plt.subplots(figsize=figsize) ax.set_zorder(0) return fig, ax @classmethod def postProcess(cls, ax, tplt, yplt, fs, meta, prettify): cls.removeSpines(ax) if 'bounds' in yplt: ymin, ymax = ax.get_ylim() ax.set_ylim(min(ymin, yplt['bounds'][0]), max(ymax, yplt['bounds'][1])) cls.setTimeLabel(ax, tplt, fs) cls.setYLabel(ax, yplt, fs) if prettify: cls.prettify(ax, xticks=(0, meta['tstim'] * tplt['factor'])) cls.setTickLabelsFontSize(ax, fs) def render(self, figsize=(11, 4), fs=10, lw=2, labels=None, colors=None, lines=None, patches='one', inset=None, frequency=1, spikes='none', cmap=None, cscale='lin', trange=None, prettify=False): ''' Render plot. :param figsize: figure size (x, y) :param fs: labels fontsize :param lw: linewidth :param labels: list of labels to use in the legend :param colors: list of colors to use for each curve :param lines: list of linestyles :param patches: string indicating whether/how to mark stimulation periods with rectangular patches :param inset: string indicating whether/how to mark an inset zooming on a particular region of the graph :param frequency: frequency at which to plot samples :param spikes: string indicating how to show spikes ("none", "marks" or "details") :param cmap: color map to use for colobar-based comparison (if not None) :param cscale: color scale to use for colobar-based comparison :param trange: optional lower and upper bounds to time axis :return: figure handle ''' - lines, labels, colors, patches, greypatch = self.checkInputs( - lines, labels, colors, patches) - + self.checkInputs(labels, patches) fcodes = [] fig, ax = self.createBackBone(figsize) if inset is not None: inset_ax = self.addInset(fig, ax, inset) # Loop through data files handles, comp_values, full_labels = [], [], [] tmin, tmax = np.inf, -np.inf for j, output in enumerate(self.outputs): + color = f'C{j}' if colors is None else colors[j] + line = '-' if lines is None else lines[j] + patch = self.patchfunc(j) # Load data try: data, meta = self.getData(output, frequency, trange) except ValueError: continue if 'tcomp' in meta: meta.pop('tcomp') # Extract model model = self.getModel(meta) fcodes.append(model.filecode(meta)) # Add label to list full_labels.append(self.figtitle(model, meta)) # Check consistency of sim types and check differing inputs comp_values = self.checkConsistency(meta, comp_values) # Extract time and stim pulses t = data['t'].values stimstate = self.getStimStates(data) pulses = self.getStimPulses(t, stimstate) tplt = self.getTimePltVar(model.tscale) t = self.prepareTime(t, tplt) # Extract y-variable pltvars = model.getPltVars() if self.varname not in pltvars: pltvars_str = ', '.join([f'"{p}"' for p in pltvars.keys()]) raise KeyError( f'Unknown plot variable: "{self.varname}". Candidates are: {pltvars_str}') yplt = pltvars[self.varname] y = extractPltVar(model, yplt, data, meta, t.size, self.varname) # Plot time series - handles.append(ax.plot(t, y, linewidth=lw, linestyle=lines[j], color=colors[j])[0]) + handles.append(ax.plot(t, y, linewidth=lw, linestyle=line, color=color)[0]) # Optional: add spikes if self.varname == 'Qm' and spikes != 'none': - self.materializeSpikes(ax, data, tplt, yplt, colors[j], spikes) + self.materializeSpikes(ax, data, tplt, yplt, color, spikes) # Plot optional inset if inset is not None: inset_ax = self.plotInset( - inset_ax, inset, t, y, tplt, yplt, lines[j], colors[j], lw) + inset_ax, inset, t, y, tplt, yplt, lines[j], color, lw) # Add optional STIM-ON patches - if patches[j]: + if patch: ybottom, ytop = ax.get_ylim() - color = None if greypatch else handles[j].get_color() - self.addPatches(ax, pulses, tplt, color=color) + patchcolor = None if self.greypatch else handles[j].get_color() + self.addPatches(ax, pulses, tplt, color=patchcolor) if inset is not None: - self.addInsetPatches(ax, inset_ax, inset, pulses, tplt, color) + self.addInsetPatches(ax, inset_ax, inset, pulses, tplt, patchcolor) tmin, tmax = min(tmin, t.min()), max(tmax, t.max()) # Get common label and add it as title common_label = self.getCommonLabel(full_labels.copy(), seps=':@,()') self.wraptitle(ax, common_label, fs=fs) # Get comp info if any if self.comp_ref_key is not None: self.comp_info = model.inputs().get(self.comp_ref_key, None) # Post-process figure self.postProcess(ax, tplt, yplt, fs, meta, prettify) ax.set_xlim(tmin, tmax) fig.tight_layout() # Materialize inset if any if inset is not None: self.materializeInset(ax, inset_ax, inset) # Add labels or colorbar legend if cmap is not None: if not self.is_unique_comp: raise ValueError('Colormap mode unavailable for multiple differing parameters') if self.comp_info is None: raise ValueError('Colormap mode unavailable for qualitative comparisons') self.addCmap( fig, cmap, handles, comp_values, self.comp_info, fs, prettify, zscale=cscale) else: comp_values, comp_labels = self.getCompLabels(comp_values) labels = self.chooseLabels(labels, comp_labels, full_labels) self.addLegend(fig, ax, handles, labels, fs) # Add window title based on common pattern common_fcode = self.getCommonLabel(fcodes.copy()) fig.canvas.set_window_title(common_fcode) return fig class GroupedTimeSeries(TimeSeriesPlot): ''' Interface to build a plot displaying profiles of several output variables arranged into specific schemes. ''' def __init__(self, outputs, pltscheme=None): ''' Constructor. :param outputs: list / generator of simulation outputs. :param varname: name of variable to extract and compare ''' super().__init__(outputs) self.pltscheme = pltscheme @staticmethod def createBackBone(pltscheme): naxes = len(pltscheme) if naxes == 1: fig, ax = plt.subplots(figsize=(11, 4)) axes = [ax] else: fig, axes = plt.subplots(naxes, 1, figsize=(11, min(3 * naxes, 9))) return fig, axes @classmethod def postProcess(cls, axes, tplt, fs, meta, prettify): for ax in axes: cls.removeSpines(ax) if prettify: cls.prettify(ax, xticks=(0, meta['pp'].tstim * tplt['factor']), yfmt=None) cls.setTickLabelsFontSize(ax, fs) for ax in axes[:-1]: ax.get_shared_x_axes().join(ax, axes[-1]) ax.set_xticklabels([]) cls.setTimeLabel(axes[-1], tplt, fs) def render(self, fs=10, lw=2, labels=None, colors=None, lines=None, patches='one', save=False, outputdir=None, fig_ext='png', frequency=1, spikes='none', trange=None, prettify=False): ''' Render plot. :param fs: labels fontsize :param lw: linewidth :param labels: list of labels to use in the legend :param colors: list of colors to use for each curve :param lines: list of linestyles :param patches: boolean indicating whether to mark stimulation periods with rectangular patches :param save: boolean indicating whether or not to save the figure(s) :param outputdir: path to output directory in which to save figure(s) :param fig_ext: string indcating figure extension ("png", "pdf", ...) :param frequency: frequency at which to plot samples :param spikes: string indicating how to show spikes ("none", "marks" or "details") :param trange: optional lower and upper bounds to time axis :return: figure handle(s) ''' if colors is None: colors = plt.get_cmap('tab10').colors figs = [] for output in self.outputs: # Load data and extract model try: data, meta = self.getData(output, frequency, trange) except ValueError: continue model = self.getModel(meta) # Extract time and stim pulses t = data['t'].values stimstate = self.getStimStates(data) pulses = self.getStimPulses(t, stimstate) tplt = self.getTimePltVar(model.tscale) t = self.prepareTime(t, tplt) # Check plot scheme if provided, otherwise generate it pltvars = model.getPltVars() if self.pltscheme is not None: for key in list(sum(list(self.pltscheme.values()), [])): if key not in pltvars: raise KeyError(f'Unknown plot variable: "{key}"') pltscheme = self.pltscheme else: pltscheme = model.pltScheme # Create figure fig, axes = self.createBackBone(pltscheme) # Loop through each subgraph for ax, (grouplabel, keys) in zip(axes, pltscheme.items()): ax_legend_spikes = False # Extract variables to plot nvars = len(keys) ax_pltvars = [pltvars[k] for k in keys] if nvars == 1: ax_pltvars[0]['color'] = 'k' ax_pltvars[0]['ls'] = '-' # Plot time series icolor = 0 for yplt, name in zip(ax_pltvars, pltscheme[grouplabel]): color = yplt.get('color', colors[icolor]) y = extractPltVar(model, yplt, data, meta, t.size, name) ax.plot(t, y, yplt.get('ls', '-'), c=color, lw=lw, label='$\\rm {}$'.format(yplt["label"])) if 'color' not in yplt: icolor += 1 # Optional: add spikes if name == 'Qm' and spikes != 'none': ax_legend_spikes = self.materializeSpikes( ax, data, tplt, yplt, color, spikes, add_to_legend=True) # Set y-axis unit and bounds self.setYLabel(ax, ax_pltvars[0].copy(), fs, grouplabel=grouplabel) if 'bounds' in ax_pltvars[0]: ymin, ymax = ax.get_ylim() ax_min = min(ymin, *[ap['bounds'][0] for ap in ax_pltvars]) ax_max = max(ymax, *[ap['bounds'][1] for ap in ax_pltvars]) ax.set_ylim(ax_min, ax_max) # Add legend if nvars > 1 or 'gate' in ax_pltvars[0]['desc'] or ax_legend_spikes: ax.legend(fontsize=fs, loc=7, ncol=nvars // 4 + 1, frameon=False) # Set x-limits and add optional patches for ax in axes: ax.set_xlim(t.min(), t.max()) if patches != 'none': self.addPatches(ax, pulses, tplt) # Post-process figure self.postProcess(axes, tplt, fs, meta, prettify) self.wraptitle(axes[0], self.figtitle(model, meta), fs=fs) fig.tight_layout() fig.canvas.set_window_title(model.filecode(meta)) # Save figure if needed (automatic or checked) if save: filecode = model.filecode(meta) if outputdir is None: raise ValueError('output directory not specified') plt_filename = f'{outputdir}/{filecode}.{fig_ext}' plt.savefig(plt_filename) logger.info(f'Saving figure as "{plt_filename}"') plt.close() figs.append(fig) return figs if __name__ == '__main__': # example of use filepaths = OpenFilesDialog('pkl')[0] comp_plot = CompTimeSeries(filepaths, 'Qm') fig = comp_plot.render( lines=['-', '--'], labels=['60 kPa', '80 kPa'], patches='one', colors=['r', 'g'], xticks=[0, 100], yticks=[-80, +50], inset={'xcoords': [5, 40], 'ycoords': [-35, 45], 'xlims': [57.5, 60.5], 'ylims': [10, 35]} ) scheme_plot = GroupedTimeSeries(filepaths) figs = scheme_plot.render() plt.show()