diff --git a/PySONIC/core/batches.py b/PySONIC/core/batches.py index cef1ce2..d8e908b 100644 --- a/PySONIC/core/batches.py +++ b/PySONIC/core/batches.py @@ -1,364 +1,364 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-08-22 14:33:04 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-03-26 19:36:52 +# @Last Modified time: 2021-05-19 16:55:33 ''' Utility functions used in simulations ''' import os import abc import csv import logging import numpy as np import pandas as pd import multiprocess as mp from ..utils import logger, si_format, isIterable class Consumer(mp.Process): ''' Generic consumer process, taking tasks from a queue and outputing results in another queue. ''' def __init__(self, queue_in, queue_out): mp.Process.__init__(self) self.queue_in = queue_in self.queue_out = queue_out logger.debug('Starting %s', self.name) def run(self): while True: nextTask = self.queue_in.get() if nextTask is None: logger.debug('Exiting %s', self.name) self.queue_in.task_done() break answer = nextTask() self.queue_in.task_done() self.queue_out.put(answer) return class Worker: ''' Generic worker class calling a specific function with a given set of parameters. ''' def __init__(self, wid, func, args, kwargs, loglevel): ''' Worker constructor. :param wid: worker ID :param func: function object :param args: list of method arguments :param kwargs: dictionary of optional method arguments :param loglevel: logging level ''' self.id = wid self.func = func self.args = args self.kwargs = kwargs self.loglevel = loglevel def __call__(self): ''' Caller to the function with specific parameters. ''' logger.setLevel(self.loglevel) return self.id, self.func(*self.args, **self.kwargs) class Batch: ''' Generic interface to run batches of function calls. ''' def __init__(self, func, queue): ''' Batch constructor. :param func: function object :param queue: list of list of function parameters ''' self.func = func self.queue = queue def __call__(self, *args, **kwargs): ''' Call the internal run method. ''' return self.run(*args, **kwargs) def getNConsumers(self): ''' Determine number of consumers based on queue length and number of available CPUs. ''' return min(mp.cpu_count(), len(self.queue)) def start(self): ''' Create tasks and results queues, and start consumers. ''' mp.freeze_support() self.tasks = mp.JoinableQueue() self.results = mp.Queue() self.consumers = [Consumer(self.tasks, self.results) for i in range(self.getNConsumers())] for c in self.consumers: c.start() @staticmethod def resolve(params): if isinstance(params, list): args = params kwargs = {} elif isinstance(params, tuple): args, kwargs = params return args, kwargs def assign(self, loglevel): ''' Assign tasks to workers. ''' for i, params in enumerate(self.queue): args, kwargs = self.resolve(params) worker = Worker(i, self.func, args, kwargs, loglevel) self.tasks.put(worker, block=False) def join(self): ''' Put all tasks to None and join the queue. ''' for i in range(len(self.consumers)): self.tasks.put(None, block=False) self.tasks.join() def get(self): ''' Extract and re-order results. ''' outputs, idxs = [], [] for i in range(len(self.queue)): wid, out = self.results.get() outputs.append(out) idxs.append(wid) return [x for _, x in sorted(zip(idxs, outputs))] def stop(self): ''' Close tasks and results queues. ''' self.tasks.close() self.results.close() def run(self, mpi=False, loglevel=logging.INFO): ''' Run batch with or without multiprocessing. ''' if mpi: self.start() self.assign(loglevel) self.join() outputs = self.get() self.stop() else: outputs = [] for params in self.queue: args, kwargs = self.resolve(params) outputs.append(self.func(*args, **kwargs)) return outputs @staticmethod def createQueue(*dims): ''' Create a serialized 2D array of all parameter combinations for a series of individual parameter sweeps. :param dims: list of lists (or 1D arrays) of input parameters :return: list of parameters (list) for each simulation ''' ndims = len(dims) dims_in = [dims[1], dims[0]] inds_out = [1, 0] if ndims > 2: dims_in += dims[2:] inds_out += list(range(2, ndims)) queue = np.stack(np.meshgrid(*dims_in), -1).reshape(-1, ndims) queue = queue[:, inds_out] return queue.tolist() @staticmethod def printQueue(queue, nmax=20): if len(queue) <= nmax: for x in queue: print(x) else: for x in queue[:nmax // 2]: print(x) print(f'... {len(queue) - nmax} more entries ...') for x in queue[-nmax // 2:]: print(x) class LogBatch(metaclass=abc.ABCMeta): ''' Generic interface to a simulation batch in with real-time input:output caching in a specific log file. ''' delimiter = '\t' # csv delimiter rtol = 1e-9 atol = 1e-16 def __init__(self, inputs, root='.'): ''' Construtor. :param inputs: array of batch inputs :param root: root for IO operations ''' self.inputs = inputs self.root = root self.fpath = self.filepath() @property def root(self): return self._root @root.setter def root(self, value): if not os.path.isdir(value): raise ValueError(f'{value} is not a valid directory') self._root = value @property @abc.abstractmethod def in_key(self): ''' Input key. ''' raise NotImplementedError @property @abc.abstractmethod def out_keys(self): ''' Output keys. ''' raise NotImplementedError @property @abc.abstractmethod def suffix(self): ''' filename suffix ''' raise NotImplementedError @property @abc.abstractmethod def unit(self): ''' Input unit. ''' raise NotImplementedError @property def in_label(self): ''' Input label. ''' return f'{self.in_key} ({self.unit})' def rangecode(self, x, label, unit): ''' String describing a batch input range. ''' bounds_str = si_format([x.min(), x.max()], 1, space="") return '{0}{1}{3}-{2}{3}_{4}'.format(label.replace(' ', '_'), *bounds_str, unit, x.size) @property def inputscode(self): ''' String describing the batch inputs. ''' return self.rangecode(self.inputs, self.in_key, self.unit) @abc.abstractmethod def corecode(self): ''' String describing the batch core components. ''' raise NotImplementedError def filecode(self): ''' String fully describing the batch. ''' return f'{self.corecode()}_{self.inputscode}_{self.suffix}_results' def filename(self): ''' Batch associated filename. ''' return f'{self.filecode()}.csv' def filepath(self): ''' Batch associated filepath. ''' return os.path.join(self.root, self.filename()) def createLogFile(self): ''' Create batch log file if it does not exist. ''' if not os.path.isfile(self.fpath): logger.debug(f'creating batch log file: "{self.fpath}"') self.writeLabels() else: logger.debug(f'existing batch log file: "{self.fpath}"') def writeLabels(self): ''' Write the column labels of the batch log file. ''' with open(self.fpath, 'w') as csvfile: writer = csv.writer(csvfile, delimiter=self.delimiter) writer.writerow([self.in_label, *self.out_keys]) def writeEntry(self, entry): ''' Write a new input(s):ouput(s) entry in the batch log file. ''' with open(self.fpath, 'a', newline='') as csvfile: writer = csv.writer(csvfile, delimiter=self.delimiter) writer.writerow(entry) def getLogData(self): ''' Retrieve the batch log file data (inputs and outputs) as a dataframe. ''' return pd.read_csv(self.fpath, sep=self.delimiter).sort_values(self.in_label) def getInput(self): ''' Retrieve the logged batch inputs as an array. ''' return self.getLogData()[self.in_label].values def getSerializedOutput(self): ''' Retrieve the logged batch outputs as an array (if 1 key) or dataframe (if several). ''' if len(self.out_keys) == 1: return self.getLogData()[self.out_keys[0]].values else: return pd.DataFrame({k: self.getLogData()[k].values for k in self.out_keys}) def getOutput(self): return self.getSerializedOutput() def getEntryIndex(self, entry): ''' Get the index corresponding to a given entry. ''' inputs = self.getInput() if len(inputs) == 0: raise ValueError(f'no entries in batch') close = np.isclose(inputs, entry, rtol=self.rtol, atol=self.atol) imatches = np.where(close)[0] if len(imatches) == 0: raise ValueError(f'{entry} entry not found in batch log') elif len(imatches) > 1: raise ValueError(f'duplicate {entry} entry found in batch log') return imatches[0] def getEntryOutput(self, entry): imatch = self.getEntryIndex(entry) return self.getSerializedOutput()[imatch] def isEntry(self, value): ''' Check if a given input is logged in the batch log file. ''' try: self.getEntryIndex(value) return True except ValueError: return False @abc.abstractmethod def compute(self, x): ''' Compute the necessary output(s) for a given input. ''' raise NotImplementedError def computeAndLog(self, x): ''' Compute output(s) and log new entry only if input is not already in the log file. ''' if not self.isEntry(x): logger.debug(f'entry not found: "{x}"') out = self.compute(x) if not isIterable(x): x = [x] if not isIterable(out): out = [out] entry = [*x, *out] if not self.mpi: self.writeEntry(entry) return entry else: logger.debug(f'existing entry: "{x}"') return None def run(self, mpi=False): ''' Run the batch and return the output(s). ''' self.createLogFile() if len(self.getLogData()) < len(self.inputs): batch = Batch(self.computeAndLog, [[x] for x in self.inputs]) self.mpi = mpi outputs = batch.run(mpi=mpi, loglevel=logger.level) outputs = filter(lambda x: x is not None, outputs) if mpi: for out in outputs: self.writeEntry(out) self.mpi = False else: logger.debug('all entries already present') return self.getOutput() diff --git a/PySONIC/neurons/pas.py b/PySONIC/neurons/pas.py index c14ce3d..4c1d784 100644 --- a/PySONIC/neurons/pas.py +++ b/PySONIC/neurons/pas.py @@ -1,99 +1,99 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2020-07-07 16:56:34 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-05-19 12:16:51 +# @Last Modified time: 2021-05-19 16:32:39 import re from ..core import PointNeuron, addSonicFeatures float_pattern = r'([+-]?\d+\.?\d*)' pattern = re.compile( r'pas_Cm0_{0}uF_cm2_gLeak_{0}S_m2_ELeak_{0}mV'.format(float_pattern)) def passiveNeuron(*args): if len(args) == 1: Cm0, gLeak, ELeak = [float(x) for x in re.findall(pattern, args[0])[0]] Cm0 *= 1e-2 else: Cm0, gLeak, ELeak = args @addSonicFeatures class PassiveNeuron(PointNeuron): ''' Generic point-neuron model with only a passive current. ''' states = {} def __new__(cls, Cm0, gLeak, ELeak): ''' Initialization. :param Cm0: membrane capacitance (F/m2) :param gLeak: leakage conductance (S/m2) :param ELeak: leakage revwersal potential (mV) ''' cls.Cm0 = Cm0 cls.gLeak = gLeak cls.ELeak = ELeak return super(PassiveNeuron, cls).__new__(cls) def copy(self): return self.__class__(self.Cm0, self.gLeak, self.ELeak) def pdict(self): return { 'Cm0': f'{self.Cm0 * 1e2:.1f} uF/cm2', 'gLeak': f'{self.gLeak:.1f} S/m2', 'ELeak': f'{self.ELeak:.1f} mV' } def __repr__(self): params_str = ', '.join([f'{k} = {v}' for k, v in self.pdict().items()]) return f'{self.__class__.__name__}({params_str})' def code(self, pdict): pdict = {k: v.replace(' ', '').replace('/', '_') for k, v in pdict.items()} s = '_'.join([f'{k}_{v}' for k, v in pdict.items()]) return f'pas_{s}' @property def name(self): return self.code(self.pdict()) @property def lookup_name(self): pdict = self.pdict() del pdict['gLeak'] return self.code(pdict) @property def Cm0(self): return self._Cm0 @Cm0.setter def Cm0(self, value): self._Cm0 = value @property def Vm0(self): return self.ELeak @classmethod def derStates(cls): return {} @classmethod def steadyStates(cls): return {} @classmethod def iLeak(cls, Vm): ''' non-specific leakage current ''' return cls.gLeak * (Vm - cls.ELeak) # mA/m2 @classmethod def currents(cls): return {'iLeak': lambda Vm, _: cls.iLeak(Vm)} return PassiveNeuron(Cm0, gLeak, ELeak) diff --git a/PySONIC/plt/pltutils.py b/PySONIC/plt/pltutils.py index b7f374c..1d97ee5 100644 --- a/PySONIC/plt/pltutils.py +++ b/PySONIC/plt/pltutils.py @@ -1,491 +1,526 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2017-08-21 14:33:36 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2020-08-24 12:47:51 +# @Last Modified time: 2021-05-19 16:53:36 ''' Useful functions to generate plots. ''' import re import numpy as np import pandas as pd import matplotlib from matplotlib.patches import Polygon, Rectangle from matplotlib import cm, colors import matplotlib.pyplot as plt from ..core import getModel from ..utils import * # Matplotlib parameters matplotlib.rcParams['pdf.fonttype'] = 42 matplotlib.rcParams['ps.fonttype'] = 42 matplotlib.rcParams['font.family'] = 'arial' def getSymmetricCmap(cmap_key): cmap = plt.get_cmap(cmap_key) cl = np.vstack((cmap.colors, cmap.reversed().colors)) return colors.LinearSegmentedColormap.from_list(f'sym_{cmap_key}', cl) for k in ['viridis', 'plasma', 'inferno', 'magma', 'cividis']: for cmap_key in [k, f'{k}_r']: sym_cmap = getSymmetricCmap(cmap_key) plt.register_cmap(name=sym_cmap.name, cmap=sym_cmap) def cm2inch(*tupl): inch = 2.54 if isinstance(tupl[0], tuple): return tuple(i / inch for i in tupl[0]) else: return tuple(i / inch for i in tupl) def extractPltVar(model, pltvar, df, meta=None, nsamples=0, name=''): if 'func' in pltvar: s = pltvar['func'] if not s.startswith('meta'): s = f'model.{s}' try: var = eval(s) except AttributeError as err: if hasattr(model, 'pneuron'): var = eval(s.replace('model', 'model.pneuron')) else: raise err elif 'key' in pltvar: var = df[pltvar['key']] elif 'constant' in pltvar: var = eval(pltvar['constant']) * np.ones(nsamples) else: var = df[name] if isinstance(var, pd.Series): var = var.values var = var.copy() if var.size == nsamples - 1: var = np.insert(var, 0, var[0]) var *= pltvar.get('factor', 1) return var def setGrid(n, ncolmax=3): ''' Determine number of rows and columns in figure grid, based on number of variables to plot. ''' if n <= ncolmax: return (1, n) else: return ((n - 1) // ncolmax + 1, ncolmax) def setNormalizer(cmap, bounds, scale='lin'): norm = { 'lin': colors.Normalize, 'log': colors.LogNorm, 'symlog': colors.SymLogNorm }[scale](*bounds) sm = cm.ScalarMappable(norm=norm, cmap=cmap) sm._A = [] return norm, sm class GenericPlot: def __init__(self, outputs): ''' Constructor. :param outputs: list / generator of simulation outputs ''' try: iter(outputs) except TypeError: outputs = [outputs] self.outputs = outputs def __call__(self, *args, **kwargs): return self.render(*args, **kwargs) def figtitle(self, model, meta): return model.desc(meta) @staticmethod def wraptitle(ax, title, maxwidth=120, sep=':', fs=10, y=1.0): if len(title) > maxwidth: title = '\n'.join(title.split(sep)) y = 0.94 h = ax.set_title(title, fontsize=fs) h.set_y(y) @staticmethod def getData(entry, frequency=1, trange=None): if entry is None: raise ValueError('non-existing data') if isinstance(entry, str): data, meta = loadData(entry, frequency) else: data, meta = entry data = data.iloc[::frequency] if trange is not None: tmin, tmax = trange data = data.loc[(data['t'] >= tmin) & (data['t'] <= tmax)] return data, meta def render(self, *args, **kwargs): raise NotImplementedError @staticmethod def getSimType(fname): ''' Get sim type from filename. ''' mo = re.search('(^[A-Z]*)_(.*).pkl', fname) if not mo: raise ValueError(f'Could not find sim-key in filename: "{fname}"') return mo.group(1) @staticmethod def getModel(*args, **kwargs): return getModel(*args, **kwargs) @staticmethod def getTimePltVar(tscale): ''' Return time plot variable for a given temporal scale. ''' return { 'desc': 'time', 'label': 'time', 'unit': tscale, 'factor': {'ms': 1e3, 'us': 1e6}[tscale], 'onset': {'ms': 1e-3, 'us': 1e-6}[tscale] } @staticmethod def createBackBone(*args, **kwargs): raise NotImplementedError @staticmethod def prettify(ax, xticks=None, yticks=None, xfmt='{:.0f}', yfmt='{:+.0f}'): try: ticks = ax.get_ticks() ticks = (min(ticks), max(ticks)) ax.set_ticks(ticks) ax.set_ticklabels([xfmt.format(x) for x in ticks]) except AttributeError: if xticks is None: xticks = ax.get_xticks() xticks = (min(xticks), max(xticks)) if yticks is None: yticks = ax.get_yticks() yticks = (min(yticks), max(yticks)) ax.set_xticks(xticks) ax.set_yticks(yticks) if xfmt is not None: ax.set_xticklabels([xfmt.format(x) for x in xticks]) if yfmt is not None: ax.set_yticklabels([yfmt.format(y) for y in yticks]) @staticmethod def addInset(fig, ax, inset): ''' Create inset axis. ''' inset_ax = fig.add_axes(ax.get_position()) inset_ax.set_zorder(1) inset_ax.set_xlim(inset['xlims'][0], inset['xlims'][1]) inset_ax.set_ylim(inset['ylims'][0], inset['ylims'][1]) inset_ax.set_xticks([]) inset_ax.set_yticks([]) inset_ax.add_patch(Rectangle((inset['xlims'][0], inset['ylims'][0]), inset['xlims'][1] - inset['xlims'][0], inset['ylims'][1] - inset['ylims'][0], color='w')) return inset_ax @staticmethod def materializeInset(ax, inset_ax, inset): ''' Materialize inset with zoom boox. ''' # Re-position inset axis axpos = ax.get_position() left, right, = rescale(inset['xcoords'], ax.get_xlim()[0], ax.get_xlim()[1], axpos.x0, axpos.x0 + axpos.width) bottom, top, = rescale(inset['ycoords'], ax.get_ylim()[0], ax.get_ylim()[1], axpos.y0, axpos.y0 + axpos.height) inset_ax.set_position([left, bottom, right - left, top - bottom]) for i in inset_ax.spines.values(): i.set_linewidth(2) # Materialize inset target region with contour frame ax.plot(inset['xlims'], [inset['ylims'][0]] * 2, linestyle='-', color='k') ax.plot(inset['xlims'], [inset['ylims'][1]] * 2, linestyle='-', color='k') ax.plot([inset['xlims'][0]] * 2, inset['ylims'], linestyle='-', color='k') ax.plot([inset['xlims'][1]] * 2, inset['ylims'], linestyle='-', color='k') # Link target and inset with dashed lines if possible if inset['xcoords'][1] < inset['xlims'][0]: ax.plot([inset['xcoords'][1], inset['xlims'][0]], [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k') ax.plot([inset['xcoords'][1], inset['xlims'][0]], [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k') elif inset['xcoords'][0] > inset['xlims'][1]: ax.plot([inset['xcoords'][0], inset['xlims'][1]], [inset['ycoords'][0], inset['ylims'][0]], linestyle='--', color='k') ax.plot([inset['xcoords'][0], inset['xlims'][1]], [inset['ycoords'][1], inset['ylims'][1]], linestyle='--', color='k') else: logger.warning('Inset x-coordinates intersect with those of target region') def postProcess(self, *args, **kwargs): raise NotImplementedError @staticmethod def removeSpines(ax): for item in ['top', 'right']: ax.spines[item].set_visible(False) @staticmethod def setXTicks(ax, xticks=None): if xticks is not None: ax.set_xticks(xticks) @staticmethod def setYTicks(ax, yticks=None): if yticks is not None: ax.set_yticks(yticks) @staticmethod def setTickLabelsFontSize(ax, fs): for tick in ax.xaxis.get_major_ticks() + ax.yaxis.get_major_ticks(): tick.label.set_fontsize(fs) @staticmethod def setXLabel(ax, xplt, fs): ax.set_xlabel('$\\rm {}\ ({})$'.format(xplt["label"], xplt["unit"]), fontsize=fs) @staticmethod def setYLabel(ax, yplt, fs): ax.set_ylabel('$\\rm {}\ ({})$'.format(yplt["label"], yplt.get("unit", "")), fontsize=fs) @classmethod def addCmap(cls, fig, cmap, handles, comp_values, comp_info, fs, prettify, zscale='lin'): if all(isinstance(x, str) for x in comp_values): # If list of strings, assume that index suffixes can be extracted prefix, suffixes = extractCommonPrefix(comp_values) comp_values = [int(s) for s in suffixes] desc_str = f'{prefix}\ index' else: # Rescale comparison values and adjust unit comp_values = np.asarray(comp_values) * comp_info.get('factor', 1.) comp_factor, comp_prefix = getSIpair(comp_values, scale=zscale) comp_values /= comp_factor comp_info['unit'] = comp_prefix + comp_info['unit'] desc_str = comp_info["desc"].replace(" ", "\ ") if len(comp_info['unit']) > 0: desc_str = f"{desc_str}\ ({comp_info['unit']})" nvalues = len(comp_values) # Create colormap and normalizer try: mymap = plt.get_cmap(cmap) except ValueError: mymap = plt.get_cmap(swapFirstLetterCase(cmap)) norm, sm = setNormalizer(mymap, (min(comp_values), max(comp_values)), zscale) # Extract and adjust line colors zcolors = sm.to_rgba(comp_values) for lh, c in zip(handles, zcolors): if isIterable(lh): for item in lh: item.set_color(c) else: lh.set_color(c) # Add colorbar fig.subplots_adjust(left=0.1, right=0.8, bottom=0.15, top=0.95, hspace=0.5) cbarax = fig.add_axes([0.85, 0.15, 0.03, 0.8]) cbar_kwargs = {} if all(isinstance(x, int) for x in comp_values): dx = np.diff(comp_values) if all(x == dx[0] for x in dx): dx = dx[0] ticks = comp_values bounds = np.hstack((ticks, [max(ticks) + dx])) - dx / 2 if nvalues > 10: ticks = [ticks[0], ticks[-1]] cbar_kwargs.update({'ticks': ticks, 'boundaries': bounds, 'format': '%1i'}) cbarax.tick_params(axis='both', which='both', length=0) cbar = fig.colorbar(sm, cax=cbarax, **cbar_kwargs) fig.sm = sm # add scalar mappable as figure attribute in case of future need cbarax.set_ylabel(f'$\\rm {desc_str}$', fontsize=fs) if prettify: cls.prettify(cbar) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) class ComparativePlot(GenericPlot): def __init__(self, outputs, varname): ''' Constructor. :param outputs: list /generator of simulation outputs to be compared. :param varname: name of variable to extract and compare. ''' super().__init__(outputs) self.varname = varname self.comp_ref_key = None self.meta_ref = None self.comp_info = None self.is_unique_comp = False def checkLabels(self, labels): if labels is not None: if not isIterable(labels): raise TypeError('Invalid labels: must be an iterable') if not all(isinstance(x, str) for x in labels): raise TypeError('Invalid labels: must be string typed') def checkSimType(self, meta): ''' Check consistency of sim types across files. ''' if meta['simkey'] != self.meta_ref['simkey']: raise ValueError('Invalid comparison: different simulation types') def checkCompValues(self, meta, comp_values): ''' Check consistency of differing values across files. ''' # Get differing values across meta dictionaries diffs = differing(self.meta_ref, meta, subdkey='meta') # Check that only one value differs if len(diffs) > 1: logger.warning('More than one differing inputs') self.comp_ref_key = None return [] # Get the key and differing values zkey, refval, val = diffs[0] # If no comparison key yet, fill it up if self.comp_ref_key is None: self.comp_ref_key = zkey self.is_unique_comp = True comp_values += [refval, val] # Otherwise, check that comparison matches the existing one else: if zkey != self.comp_ref_key: logger.warning('inconsistent differing inputs') self.comp_ref_key = None return [] else: comp_values.append(val) return comp_values def checkConsistency(self, meta, comp_values): ''' Check consistency of sim types and check differing inputs. ''' if self.meta_ref is None: self.meta_ref = meta else: self.checkSimType(meta) comp_values = self.checkCompValues(meta, comp_values) if self.comp_ref_key is None: self.is_unique_comp = False return comp_values def getCompLabels(self, comp_values): if self.comp_info is not None: comp_values = np.array(comp_values) * self.comp_info.get('factor', 1) if 'unit' in self.comp_info: p = self.comp_info.get('precision', 0) comp_values = [f"{si_format(v, p)}{self.comp_info['unit']}".replace(' ', '\ ') for v in comp_values] comp_labels = ['$\\rm{} = {}$'.format(self.comp_info['label'], x) for x in comp_values] else: comp_labels = comp_values return comp_values, comp_labels def chooseLabels(self, labels, comp_labels, full_labels): if labels is not None: return labels else: if self.is_unique_comp: return comp_labels else: return full_labels @staticmethod def getCommonLabel(lbls, seps='_'): ''' Get a common label from a list of labels, by removing parts that differ across them. ''' # Split every label according to list of separator characters, and save splitters as well splt_lbls = [re.split(f'([{seps}])', x) for x in lbls] pieces = [x[::2] for x in splt_lbls] splitters = [x[1::2] for x in splt_lbls] ncomps = len(pieces[0]) # Assert that splitters are equivalent across all labels, and reduce them to a single array assert (x == x[0] for x in splitters), 'Inconsistent splitters' splitters = np.array(splitters[0]) # Transform pieces into 2D matrix, and evaluate equality of every piece across labels pieces = np.array(pieces).T all_identical = [np.all(x == x[0]) for x in pieces] if np.sum(all_identical) < ncomps - 1: logger.warning('More than one differing inputs') return '' # Discard differing pieces and remove associated splitters pieces = pieces[all_identical, 0] splitters = splitters[all_identical[:-1]] # Remove last splitter if the last pieces were discarded if splitters.size == pieces.size: splitters = splitters[:-1] # Join common pieces and associated splitters into a single label common_lbl = '' for p, s in zip(pieces, splitters): common_lbl += f'{p}{s}' common_lbl += pieces[-1] return common_lbl.replace('( ', '(') def addExcitationInset(ax, is_excited): ''' Add text inset on axis stating excitation status. ''' ax.text( 0.7, 0.7, f'{"" if is_excited else "not "}excited', transform=ax.transAxes, ha='center', va='center', size=30, bbox=dict( boxstyle='round', fc=(0.8, 1.0, 0.8) if is_excited else (1., 0.8, 0.8) )) def mirrorProp(org, new, prop): ''' Mirror an instance property onto another instance of the same class. ''' getattr(new, f'set_{prop}')(getattr(org, f'get_{prop}')()) def mirrorAxis(org_ax, new_ax, p=False): ''' Mirror content of original axis to a new axis. That includes: - position on the figure - spines properties - ticks, ticklabels, and labels - vertical spans ''' mirrorProp(org_ax, new_ax, 'position') for sk in ['bottom', 'left', 'right', 'top']: mirrorProp(org_ax.spines[sk], new_ax.spines[sk], 'visible') for prop in ['label', 'ticks', 'ticklabels']: for k in ['x', 'y']: mirrorProp(org_ax, new_ax, f'{k}{prop}') ax_children = org_ax.get_children() vspans = filter(lambda x: isinstance(x, Polygon), ax_children) for vs in vspans: props = vs.properties() xmin, xmax = [props['xy'][i][0] for i in [0, 2]] kwargs = {k: props[k] for k in ['alpha', 'edgecolor', 'facecolor']} if kwargs['edgecolor'] == (0.0, 0.0, 0.0, 0.0): kwargs['edgecolor'] = 'none' new_ax.axvspan(xmin, xmax, **kwargs) + + +def harmonizeAxesLimits(axes, dim='xy'): + ''' Harmonize x and/or y limits across an array of axes. ''' + axes = axes.flatten() + xlims, ylims = [np.inf, -np.inf], [np.inf, -np.inf] + for ax in axes: + xlims = [min(xlims[0], ax.get_xlim()[0]), max(xlims[1], ax.get_xlim()[1])] + ylims = [min(ylims[0], ax.get_ylim()[0]), max(ylims[1], ax.get_ylim()[1])] + for ax in axes: + if dim in ['xy', 'x']: + ax.set_xlim(*xlims) + if dim in ['xy', 'y']: + ax.set_ylim(*ylims) + + +def hideSpines(ax, spines='all'): + if isIterable(ax): + for item in ax: + hideSpines(item, spines=spines) + else: + if spines == 'all': + spines = ['top', 'bottom', 'right', 'left'] + for sk in spines: + ax.spines[sk].set_visible(False) + + +def hideTicks(ax, key='xy'): + if isIterable(ax): + for item in ax: + hideTicks(item, key=key) + if key in ['xy', 'x']: + ax.set_xticks([]) + if key in ['xy', 'y']: + ax.set_yticks([])