diff --git a/PySONIC/multicomp/benchmarks.py b/PySONIC/multicomp/benchmarks.py index b13e247..3c0c369 100644 --- a/PySONIC/multicomp/benchmarks.py +++ b/PySONIC/multicomp/benchmarks.py @@ -1,353 +1,355 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2021-05-14 19:42:00 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-05-28 16:48:43 +# @Last Modified time: 2021-05-31 12:22:17 import os import numpy as np import matplotlib.pyplot as plt from ..core import NeuronalBilayerSonophore, PulsedProtocol, Batch from ..core.drives import AcousticDrive, AcousticDriveArray from ..utils import si_format, rmse, rescale from ..neurons import passiveNeuron from ..postpro import gamma from ..plt import harmonizeAxesLimits, hideSpines, hideTicks, addYscale, addXscale from .coupled_nbls import CoupledSonophores class Benchmark: - def __init__(self, a, nnodes, outdir=None): + def __init__(self, a, nnodes, outdir=None, nodecolors=None): self.a = a self.nnodes = nnodes self.outdir = outdir if not os.path.isdir(self.outdir): os.mkdir(self.outdir) + if nodecolors is None: + nodecolors = plt.get_cmap('Dark2').colors + self.nodecolors = nodecolors def pdict(self): return { 'a': f'{self.a * 1e9:.0f} nm', 'nnodes': f'{self.nnodes} nodes', } def pstr(self): l = [] for k, v in self.pdict().items(): if k == 'nnodes': l.append(v) else: l.append(f'{k} = {v}') return ', '.join(l) def __repr__(self): return f'{self.__class__.__name__}({self.pstr()})' def code(self): s = self.__repr__() for k in ['/', '(', ',']: s = s.replace(k, '_') for k in ['=', ' ', ')']: s = s.replace(k, '') return s def runSims(self, model, drives, tstim, covs): ''' Run full and sonic simulations for a specific combination drives, pulsed protocol and coverage fractions, harmonize outputs and compute normalized charge density profiles. ''' Fdrive = drives[0].f assert all(x.f == Fdrive for x in drives), 'frequencies do not match' assert len(covs) == model.nnodes, 'coverages do not match model dimensions' assert len(drives) == model.nnodes, 'drives do not match model dimensions' # If not provided, compute stimulus duration from model passive properties min_ncycles = 10 ntaumax_conv = 5 if tstim is None: tstim = max(ntaumax_conv * model.taumax, min_ncycles / Fdrive) # Recast stimulus duration as finite multiple of acoustic period tstim = int(np.ceil(tstim * Fdrive)) / Fdrive # s # Pulsed protocol pp = PulsedProtocol(tstim, 0) # Simulate/Load with full and sonic methods data, meta = {}, {} for method in ['full', 'sonic']: data[method], meta[method] = model.simAndSave( drives, pp, covs, method, outdir=self.outdir, overwrite=False, minimize_output=True) # Cycle-average full solution and interpolate sonic solution along same time vector data['cycleavg'] = data['full'].cycleAveraged(1 / Fdrive) data['sonic'] = data['sonic'].interpolate(data['cycleavg'].time) # Compute normalized charge density profiles and add them to dataset for simkey, simdata in data.items(): for nodekey, nodedata in simdata.items(): nodedata['Qnorm'] = nodedata['Qm'] / model.refpneuron.Cm0 * 1e3 # mV # Return dataset return data, meta def getQnorms(self, data, k, cut_bounds=True): ''' Get node-specific list of cycle-averaged and sonic normalized charge vectors (with, by default, discarding of bounding artefact elements). ''' Qnorms = np.array([data[simkey][k]['Qnorm'].values for simkey in ['cycleavg', 'sonic']]) if cut_bounds: Qnorms = Qnorms[:, 1:-1] return Qnorms def computeGamma(self, data, *args): ''' Evaluate per-node gamma on charge density profiles. ''' gamma_dict = {} resolution = list(data['cycleavg'].values())[0].dt for k in data['cycleavg'].keys(): # Get normalized charge vectors (discarding 1st and last indexes) and compute gamma gamma_dict[k] = gamma(*self.getQnorms(data, k), *args, resolution) # Pad gamma with nan on each side to ensure size consistency with time vector gamma_dict[k] = np.pad(gamma_dict[k], 1, mode='constant', constant_values=(np.nan,)) return gamma_dict def computeRMSE(self, data): ''' Evaluate per-node RMSE on charge density profiles. ''' return {k: rmse(*self.getQnorms(data, k)) for k in data['cycleavg'].keys()} def computeSteadyStateDeviation(self, data): ''' Evaluate per-node steady-state absolute deviation on charge density profiles. ''' return {k: np.abs(np.squeeze(np.diff(self.getQnorms(data, k), axis=0)))[-1] for k in data['cycleavg'].keys()} def computeMeanNormDifference(self, data): ''' Evaluate per-node mean absolute difference on [0, 1] normalized charge profiles. ''' d = {} for k in data['cycleavg'].keys(): y = self.getQnorms(data, k) # Rescale signals linearly between 0 and 1 ynorms = np.array([rescale(yy) for yy in y]) # Compute absolute difference signal ydiff = np.squeeze(np.abs(np.diff(ynorms, axis=0))) # Compute mean absolute difference d[k] = np.mean(ydiff) * 1e2 # % return d def computeDivergence(self, data, eval_mode, *args): ''' Compute divergence according to given eval_mode. ''' if eval_mode == 'rmse': div_dict = self.computeRMSE(data) elif eval_mode == 'gamma': div_dict = {k: np.nanmax(v) for k, v in self.computeGamma(data, *args).items()} elif eval_mode == 'ss': div_dict = self.computeSteadyStateDeviation(data) elif eval_mode == 'normdiff': div_dict = self.computeMeanNormDifference(data) return max(div_dict.values()) def plotQnorm(self, ax, data): ''' Plot normalized charge density signals on an axis. ''' markers = {'full': '-', 'cycleavg': '--', 'sonic': '-'} alphas = {'full': 0.5, 'cycleavg': 1., 'sonic': 1.} # tplt = TimeSeriesPlot.getTimePltVar('ms') # yplt = self.model.refpneuron.getPltVars()['Qm/Cm0'] # mode = 'details' for simkey, simdata in data.items(): for i, (nodekey, nodedata) in enumerate(simdata.items()): - c = f'C{i}' y = nodedata['Qnorm'].values y[-1] = y[-2] - ax.plot(nodedata.time * 1e3, y, markers[simkey], c=c, + ax.plot(nodedata.time * 1e3, y, markers[simkey], c=self.nodecolors[i], alpha=alphas[simkey], label=f'{simkey} - {nodekey}', clip_on=False) # if simkey == 'cycleavg': # TimeSeriesPlot.materializeSpikes(ax, nodedata, tplt, yplt, c, mode) def plotMeanNormDifference(self, ax, data): for i, (k, nodedata) in enumerate(data['cycleavg'].items()): t = nodedata.time[1:-1] y = self.getQnorms(data, k) - c = f'C{i}' + c = self.nodecolors[i] ynorms = np.array([rescale(yy) for yy in y]) for yn, marker in zip(ynorms, ['--', '-']): ax.plot(t * 1e3, yn, marker, c=c, clip_on=False) ax.fill_between(t * 1e3, *ynorms, alpha=0.5) ydiff = np.squeeze(np.abs(np.diff(ynorms, axis=0))) ymean = np.mean(ydiff) ax.text(0.5, 0.3 * (i + 1), f'{ymean * 1e2:.2f}%', c=c, transform=ax.transAxes) def plotGamma(self, ax, data, *gamma_args): gamma_dict = self.computeGamma(data, *gamma_args) tplt = list(data['cycleavg'].values())[0].time * 1e3 for i, (nodekey, nodegamma) in enumerate(gamma_dict.items()): - ax.plot(tplt, nodegamma, c=f'C{i}', label=nodekey, clip_on=False) + ax.plot(tplt, nodegamma, c=self.nodecolors[i], label=nodekey, clip_on=False) ax.axhline(1, linestyle='--', c='k') def plotSignalsOver2DSpace(self, gridxkey, gridxvec, gridxunit, gridykey, gridyvec, gridyunit, results, pltfunc, yunit='', title=None, fs=10, flipud=True, fliplr=False): ''' Plot signals over 2D space. ''' # Create grid-like figure - fig, axes = plt.subplots(gridxvec.size, gridyvec.size) + fig, axes = plt.subplots(gridxvec.size, gridyvec.size, figsize=(9, 7)) # Re-arrange axes and labels if flipud/fliplr option is set supylabel_args = {} supxlabel_args = {'y': 1.0, 'va': 'top'} if flipud: axes = axes[::-1] supxlabel_args = {} if fliplr: axes = axes[:, ::-1] supylabel_args = {'x': 1.0, 'ha': 'right'} # Add title and general axes labels if title is not None: fig.suptitle(title, fontsize=fs + 2) fig.supxlabel(gridxkey, fontsize=fs + 2, **supxlabel_args) fig.supylabel(gridykey, fontsize=fs + 2, **supylabel_args) # Loop through the axes and plot results, while storing time ranges i = 0 tranges = [] for i, axrow in enumerate(axes): for j, ax in enumerate(axrow): hideSpines(ax) hideTicks(ax) ax.margins(0) if results[i, j] is not None: pltfunc(ax, results[i, j]) tranges.append(np.ptp(ax.get_xlim())) if len(np.unique(tranges)) > 1: # If more than one time range, add common x-scale to all axes tmin = min(tranges) for axrow in axes[::-1]: for ax in axrow: trans = (ax.transData + ax.transAxes.inverted()) xpoints = [trans.transform([x, 0])[0] for x in [0, tmin]] ax.plot(xpoints, [-0.05] * 2, c='k', lw=2, transform=ax.transAxes, clip_on=False) else: # Otherwise, add x-scale only to axis opposite to origin side = 'top' if flipud else 'bottom' addXscale(axes[-1, -1], 0, 0.05, unit='ms', fmt='.0f', fs=fs, side=side) # Harmonize y-limits across all axes, and add y-scale to axis opposite to origin harmonizeAxesLimits(axes, dim='y') side = 'left' if fliplr else 'right' if yunit is not None: addYscale(axes[-1, -1], 0.05, 0, unit=yunit, fmt='.0f', fs=fs, side=side) # Set labels for xvec and yvec values along the two figure grid dimensions for ax, x in zip(axes[0, :], gridxvec): ax.set_xlabel(f'{si_format(x)}{gridxunit}', labelpad=15, fontsize=fs + 2) if not flipud: ax.xaxis.set_label_position('top') for ax, y in zip(axes[:, 0], gridyvec): if fliplr: ax.yaxis.set_label_position('right') ax.set_ylabel(f'{si_format(y)}{gridyunit}', labelpad=15, fontsize=fs + 2) # Return figure object return fig class PassiveBenchmark(Benchmark): def __init__(self, a, nnodes, Cm0, ELeak, **kwargs): super().__init__(a, nnodes, **kwargs) self.Cm0 = Cm0 self.ELeak = ELeak def pdict(self): return { **super().pdict(), 'Cm0': f'{self.Cm0 * 1e2:.1f} uF/cm2', 'ELeak': f'{self.ELeak} mV', } def getModelAndRunSims(self, drives, covs, taum, tauax): ''' Create passive model for a combination of time constants. ''' gLeak = self.Cm0 / taum ga = self.Cm0 / tauax pneuron = passiveNeuron(self.Cm0, gLeak, self.ELeak) model = CoupledSonophores([ NeuronalBilayerSonophore(self.a, pneuron) for i in range(self.nnodes)], ga) return self.runSims(model, drives, None, covs) def runSimsOverTauSpace(self, drives, covs, taum_range, tauax_range, mpi=False): ''' Run simulations over 2D time constant space. ''' queue = [[drives, covs] + x for x in Batch.createQueue(taum_range, tauax_range)] batch = Batch(self.getModelAndRunSims, queue) # batch.printQueue(queue) output = batch.run(mpi=mpi) results = [x[0] for x in output] # removing meta return np.reshape(results, (taum_range.size, tauax_range.size)).T def plotSignalsOverTauSpace(self, taum_range, tauax_range, results, pltfunc=None, fs=10): if pltfunc is None: pltfunc = 'plotQnorm' yunit = {'plotQnorm': 'mV', 'plotMeanNormDifference': None}[pltfunc] title = pltfunc[4:] pltfunc = getattr(self, pltfunc) return self.plotSignalsOver2DSpace( 'taum', taum_range, 's', 'tauax', tauax_range, 's', results, pltfunc, title=title, yunit=yunit) class FiberBenchmark(Benchmark): def __init__(self, a, nnodes, pneuron, ga, **kwargs): super().__init__(a, nnodes, **kwargs) self.model = CoupledSonophores([ NeuronalBilayerSonophore(self.a, pneuron) for i in range(self.nnodes)], ga) def pdict(self): return { **super().pdict(), 'ga': self.model.gastr, 'pneuron': self.model.refpneuron, } def getModelAndRunSims(self, Fdrive, tstim, covs, A1, A2): ''' Create passive model for a combination of time constants. ''' drives = AcousticDriveArray([AcousticDrive(Fdrive, A1), AcousticDrive(Fdrive, A2)]) return self.runSims(self.model, drives, tstim, covs) def runSimsOverAmplitudeSpace(self, Fdrive, tstim, covs, A_range, mpi=False): ''' Run simulations over 2D time constant space. ''' # Generate 2D amplitudes meshgrid A_combs = np.meshgrid(A_range, A_range) # Set elements below main diagonal to NaN tril_idxs = np.tril_indices(A_range.size, -1) for x in A_combs: x[tril_idxs] = np.nan # Flatten the meshgrid and assemble into list of tuples A_combs = list(zip(*[x.flatten().tolist() for x in A_combs])) # Remove NaN elements A_combs = list(filter(lambda x: not any(np.isnan(xx) for xx in x), A_combs)) # Assemble queue queue = [[Fdrive, tstim, covs] + list(x) for x in A_combs] batch = Batch(self.getModelAndRunSims, queue) output = batch.run(mpi=mpi) results = [x[0] for x in output] # removing meta # Re-organize results into upper-triangle matrix new_results = np.empty((A_range.size, A_range.size), dtype=object) triu_idxs = np.triu_indices(A_range.size, 0) for *idx, res in zip(*triu_idxs, results): new_results[idx[0], idx[1]] = res return new_results def plotSignalsOverAmplitudeSpace(self, A_range, results, *gamma_args, fs=10): plt_gamma = len(gamma_args) > 0 title = 'gamma' if plt_gamma else 'Qm/Cm0' if plt_gamma: pltfunc = lambda *args: self.plotGamma(*args, *gamma_args) yunit = '' else: pltfunc = self.plotQnorm yunit = 'mV' return self.plotSignalsOver2DSpace( 'A1', A_range, 'Pa', 'A2', A_range, 'Pa', results, pltfunc, title=title, yunit=yunit) diff --git a/PySONIC/plt/divmaps.py b/PySONIC/plt/divmaps.py index 3ffc046..1d1900b 100644 --- a/PySONIC/plt/divmaps.py +++ b/PySONIC/plt/divmaps.py @@ -1,146 +1,155 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2020-06-29 18:11:24 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-05-28 12:28:30 +# @Last Modified time: 2021-05-31 11:22:04 import numpy as np from ..utils import logger, si_format from .xymap import XYMap class DivergenceMap(XYMap): ''' Interface to a 2D map showing divergence of the SONIC output from a cycle-averaged NICE output, for various combinations of parameters. ''' zfactor = 1e0 def __init__(self, benchmark, xvec, yvec, sim_args, eval_mode, eval_args, *args, **kwargs): self.benchmark = benchmark self.sim_args = sim_args self.eval_mode = eval_mode self.eval_args = eval_args super().__init__(self.benchmark.outdir, xvec, yvec, *args, **kwargs) @property def zkey(self): return self.eval_mode @property def zunit(self): if self.eval_mode in ('rmse', 'ss'): return 'mV' elif self.eval_mode == 'gamma': return '' elif self.eval_mode == 'normdiff': return '%' else: raise ValueError(f'unknown z-unit for {self.eval_mode} mode') @property def suffix(self): s = self.eval_mode if len(self.eval_args) > 0: s = f'{s}_{"_".join([str(x) for x in self.eval_args])}' return s def descPair(self, x1, x2): raise NotImplementedError def logDiv(self, x, div): ''' Log divergence for a particular inputs combination. ''' logger.info(f'{self.descPair(*x)}: {self.eval_mode} = {div:.2e} {self.zunit}') def compute(self, x): data, _ = self.benchmark.getModelAndRunSims(*self.sim_args, *x) div = self.benchmark.computeDivergence(data, self.eval_mode, *self.eval_args) self.logDiv(x, div) return div def onClick(self, event): pass def render(self, zscale='log', zbounds=(1e-1, 1e1), extend_under=True, extend_over=True, cmap='Spectral_r', figsize=(6, 4), fs=12, ax=None, **kwargs): ''' Render with default log scale, zbounds, cmap and cbar properties. ''' fig = super().render( zscale=zscale, zbounds=zbounds, extend_under=extend_under, extend_over=extend_over, cmap=cmap, figsize=figsize, fs=fs, ax=ax, **kwargs) if ax is None: fig.canvas.manager.set_window_title(f'{self.corecode()} - {self.eval_mode}') return fig class PassiveDivergenceMap(DivergenceMap): ''' Divergence map of a passive model for various combinations of membrane time constants (taum) and axial time constant (tauax) ''' xkey = 'taum' xfactor = 1e0 xunit = 's' ykey = 'tauax' yfactor = 1e0 yunit = 's' @property def title(self): return f'passive divmap - {self.eval_mode}' def corecode(self): return f'divmap_{self.benchmark.code()}' def descPair(self, taum, tauax): return f'taum = {si_format(taum, 2)}s, tauax = {si_format(tauax, 2)}s' @staticmethod - def addPeriodicityLines(ax, T, dims='xy', color='k'): + def addPeriodicityLines(ax, T, dims='xy', color='k', pattern='cross'): + xmin, ymin = 0, 0 + xmax, ymax = 1, 1 + if pattern in ['upper-square', 'lower-square']: + data_to_axis = ax.transData + ax.transAxes.inverted() + xc, yc = data_to_axis.transform((T, T)) + if pattern == 'upper-square': + xmin, ymin = xc, yc + else: + xmax, ymax = xc, yc if 'x' in dims: - ax.axvline(T, color=color, linestyle='--', linewidth=1.5) + ax.axvline(T, ymin=ymin, ymax=ymax, color=color, linestyle='--', linewidth=1.5) if 'y' in dims: - ax.axhline(T, color=color, linestyle='--', linewidth=1.5) + ax.axhline(T, xmin=xmin, xmax=xmax, color=color, linestyle='--', linewidth=1.5) def render(self, xscale='log', yscale='log', T=None, ax=None, **kwargs): ''' Render with drive periodicty indicator. ''' fig = super().render(xscale=xscale, yscale=yscale, ax=ax, **kwargs) if ax is None: ax = fig.axes[0] if T is not None: self.addPeriodicityLines(ax, T) return fig class FiberDivergenceMap(DivergenceMap): ''' Divergence map of a fiber model for various combinations of acoustic pressure amplitudes in both compartments ''' xkey = 'A1' xfactor = 1e0 xunit = 'Pa' ykey = 'A2' yfactor = 1e0 yunit = 'Pa' def __init__(self, benchmark, Avec, *args, **kwargs): super().__init__(benchmark, Avec, Avec, *args, **kwargs) @property def title(self): return f'fiber divmap - {self.eval_mode}' def corecode(self): return f'divmap_{self.benchmark.code()}' def descPair(self, *amps): return f"A = {', '.join(f'{si_format(A, 2)}Pa' for A in amps)}" def compute(self, x): if x[0] < x[1]: return np.nan return super().compute(x) def render(self, Ascale='log', **kwargs): super().render(xscale=Ascale, yscale=Ascale, **kwargs) diff --git a/PySONIC/plt/xymap.py b/PySONIC/plt/xymap.py index fea85ac..b90bbae 100644 --- a/PySONIC/plt/xymap.py +++ b/PySONIC/plt/xymap.py @@ -1,337 +1,340 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2019-06-04 18:24:29 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-05-30 17:54:01 +# @Last Modified time: 2021-05-31 11:54:30 import abc import csv from itertools import product import numpy as np import pandas as pd import matplotlib.pyplot as plt import copy from ..core import LogBatch from ..utils import logger, isIterable, rangecode from .pltutils import cm2inch, setNormalizer class XYMap(LogBatch): ''' Generic 2D map object interface. ''' offset_options = { 'lr': (1, -1), 'ur': (1, 1), 'll': (-1, -1), 'ul': (-1, 1) } def __init__(self, root, xvec, yvec): self.root = root self.xvec = xvec self.yvec = yvec super().__init__([list(pair) for pair in product(self.xvec, self.yvec)], root=root) def checkVector(self, name, value): if not isIterable(value): raise ValueError(f'{name} vector must be an iterable') if not isinstance(value, np.ndarray): value = np.asarray(value) if len(value.shape) > 1: raise ValueError(f'{name} vector must be one-dimensional') return value @property def in_key(self): return self.xkey @property def unit(self): return self.xunit @property def xvec(self): return self._xvec @xvec.setter def xvec(self, value): self._xvec = self.checkVector('x', value) @property def yvec(self): return self._yvec @yvec.setter def yvec(self, value): self._yvec = self.checkVector('x', value) @property @abc.abstractmethod def xkey(self): raise NotImplementedError @property @abc.abstractmethod def xfactor(self): raise NotImplementedError @property @abc.abstractmethod def xunit(self): raise NotImplementedError @property @abc.abstractmethod def ykey(self): raise NotImplementedError @property @abc.abstractmethod def yfactor(self): raise NotImplementedError @property @abc.abstractmethod def yunit(self): raise NotImplementedError @property @abc.abstractmethod def zkey(self): raise NotImplementedError @property @abc.abstractmethod def zunit(self): raise NotImplementedError @property @abc.abstractmethod def zfactor(self): raise NotImplementedError @property def out_keys(self): return [f'{self.zkey} ({self.zunit})'] @property def in_labels(self): return [f'{self.xkey} ({self.xunit})', f'{self.ykey} ({self.yunit})'] def getLogData(self): ''' Retrieve the batch log file data (inputs and outputs) as a dataframe. ''' return pd.read_csv(self.fpath, sep=self.delimiter).sort_values(self.in_labels) def getInput(self): ''' Retrieve the logged batch inputs as an array. ''' return self.getLogData()[self.in_labels].values def getOutput(self): return np.reshape(super().getOutput(), (self.xvec.size, self.yvec.size)).T def writeLabels(self): with open(self.fpath, 'w') as csvfile: writer = csv.writer(csvfile, delimiter=self.delimiter) writer.writerow([*self.in_labels, *self.out_keys]) def isEntry(self, comb): ''' Check if a given input is logged in the batch log file. ''' inputs = self.getInput() if len(inputs) == 0: return False imatches_x = np.where(np.isclose(inputs[:, 0], comb[0], rtol=self.rtol, atol=self.atol))[0] imatches_y = np.where(np.isclose(inputs[:, 1], comb[1], rtol=self.rtol, atol=self.atol))[0] imatches = list(set(imatches_x).intersection(imatches_y)) if len(imatches) == 0: return False return True @property def inputscode(self): ''' String describing the batch inputs. ''' xcode = rangecode(self.xvec, self.xkey, self.xunit) ycode = rangecode(self.yvec, self.ykey, self.yunit) return '_'.join([xcode, ycode]) @staticmethod def getScaleType(x): xmin, xmax, nx = x.min(), x.max(), x.size if np.all(np.isclose(x, np.logspace(np.log10(xmin), np.log10(xmax), nx))): return 'log' else: return 'lin' # elif np.all(np.isclose(x, np.linspace(xmin, xmax, nx))): # return 'lin' # else: # raise ValueError('Unknown distribution type') @property def xscale(self): return self.getScaleType(self.xvec) @property def yscale(self): return self.getScaleType(self.yvec) @staticmethod def computeMeshEdges(x, scale): ''' Compute the appropriate edges of a mesh that quads a linear or logarihtmic distribution. :param x: the input vector :param scale: the type of distribution ('lin' for linear, 'log' for logarihtmic) :return: the edges vector ''' if scale == 'log': x = np.log10(x) range_func = np.logspace else: range_func = np.linspace dx = x[1] - x[0] n = x.size + 1 return range_func(x[0] - dx / 2, x[-1] + dx / 2, n) @abc.abstractmethod def compute(self, x): ''' Compute the necessary output(s) for a given inputs combination. ''' raise NotImplementedError def run(self, **kwargs): super().run(**kwargs) self.getLogData().to_csv(self.filepath(), sep=self.delimiter, index=False) def getOnClickXY(self, event): ''' Get x and y values from from x and y click event coordinates. ''' x = self.xvec[np.searchsorted(self.xedges, event.xdata) - 1] y = self.yvec[np.searchsorted(self.yedges, event.ydata) - 1] return x, y def onClick(self, event): ''' Exexecute specific action when the user clicks on a cell in the 2D map. ''' pass @property @abc.abstractmethod def title(self): raise NotImplementedError def getZBounds(self): matrix = self.getOutput() zmin, zmax = np.nanmin(matrix), np.nanmax(matrix) logger.info( f'{self.zkey} range: {zmin:.0f} - {zmax:.0f} {self.zunit}') return zmin, zmax def checkZbounds(self, zbounds): zmin, zmax = self.getZBounds() if zmin < zbounds[0]: logger.warning( f'Minimal {self.zkey} ({zmin:.0f} {self.zunit}) is below defined lower bound ({zbounds[0]:.0f} {self.zunit})') if zmax > zbounds[1]: logger.warning( f'Maximal {self.zkey} ({zmax:.0f} {self.zunit}) is above defined upper bound ({zbounds[1]:.0f} {self.zunit})') @staticmethod def addInsets(ax, insets, fs, minimal=False): ax.update_datalim(list(insets.values())) xyoffset = np.array([0, 0.05]) data_to_axis = ax.transData + ax.transAxes.inverted() for k, xydata in insets.items(): ax.scatter(*xydata, s=20, facecolor='k', edgecolor='none') if not minimal: xyaxis = np.array(data_to_axis.transform(xydata)) ax.annotate( k, xy=xyaxis, xytext=xyaxis + xyoffset, xycoords=ax.transAxes, fontsize=fs, arrowprops={'facecolor': 'black', 'arrowstyle': '-'}, ha='right') def render(self, xscale='lin', yscale='lin', zscale='lin', zbounds=None, fs=8, cmap='viridis', interactive=False, figsize=None, insets=None, inset_offset=0.05, - extend_under=False, extend_over=False, ax=None, cbarax=None, title=None, - minimal=False, levels=None, flip=False): + extend_under=False, extend_over=False, ax=None, cbarax=None, cbarlabel='vertical', + title=None, minimal=False, levels=None, flip=False): # Compute z-bounds if zbounds is None: extend_under = False extend_over = False zbounds = self.getZBounds() else: self.checkZbounds(zbounds) # Compute Z normalizer mymap = copy.copy(plt.get_cmap(cmap)) # mymap.set_bad('silver') if not extend_under: mymap.set_under('silver') if not extend_over: mymap.set_over('silver') norm, sm = setNormalizer(mymap, zbounds, zscale) # Compute mesh edges self.xedges = self.computeMeshEdges(self.xvec, xscale) self.yedges = self.computeMeshEdges(self.yvec, yscale) # Create figure if required if ax is None: if figsize is None: figsize = cm2inch(12, 7) fig, ax = plt.subplots(figsize=figsize) fig.subplots_adjust(left=0.15, bottom=0.15, right=0.8, top=0.92) else: fig = ax.get_figure() # Set axis properties if title is None: title = self.title ax.set_title(title, fontsize=fs) if minimal: ax.set_xticks([]) ax.set_yticks([]) ax.tick_params(axis='both', which='both', bottom=False, left=False, labelbottom=False, labelleft=False) else: ax.set_xlabel(f'{self.xkey} ({self.xunit})', fontsize=fs, labelpad=-0.5) ax.set_ylabel(f'{self.ykey} ({self.yunit})', fontsize=fs) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) if xscale == 'log': ax.set_xscale('log') if yscale == 'log': ax.set_yscale('log') # Retrieve data data = self.getOutput() if flip: data = data.T # Plot map with specific color code ax.pcolormesh(self.xedges, self.yedges, data, cmap=mymap, norm=norm) # Add contour levels if needed if levels is not None: CS = ax.contour( self.xvec, self.yvec, data, levels, colors='k') ax.clabel(CS, fontsize=fs, fmt=lambda x: f'{x:g}', inline_spacing=2) # Add potential insets if insets is not None: self.addInsets(ax, insets, fs, minimal=minimal) # Plot z-scale colorbar if cbarax is None: pos1 = ax.get_position() # get the map axis position cbarax = fig.add_axes([pos1.x1 + 0.02, pos1.y0, 0.03, pos1.height]) if not extend_under and not extend_over: extend = 'neither' elif extend_under and extend_over: extend = 'both' else: extend = 'max' if extend_over else 'min' self.cbar = plt.colorbar(sm, cax=cbarax, extend=extend) - cbarax.set_ylabel(f'{self.zkey} ({self.zunit})', fontsize=fs) + if cbarlabel == 'vertical': + cbarax.set_ylabel(f'{self.zkey} ({self.zunit})', fontsize=fs) + else: + cbarax.set_title(f'{self.zkey} ({self.zunit})', fontsize=fs) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) if interactive: fig.canvas.mpl_connect('button_press_event', lambda event: self.onClick(event)) return fig