diff --git a/PySONIC/multicomp/benchmarks.py b/PySONIC/multicomp/benchmarks.py index f50beb1..abe575c 100644 --- a/PySONIC/multicomp/benchmarks.py +++ b/PySONIC/multicomp/benchmarks.py @@ -1,426 +1,426 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2021-05-14 19:42:00 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-06-16 22:49:14 +# @Last Modified time: 2021-06-16 23:15:25 import os import numpy as np import matplotlib.pyplot as plt from ..core import NeuronalBilayerSonophore, PulsedProtocol, Batch from ..core.drives import AcousticDrive, AcousticDriveArray from ..utils import si_format, rmse, rescale, logger, bounds from ..neurons import passiveNeuron from ..postpro import gamma from ..plt import harmonizeAxesLimits, hideSpines, hideTicks, addYscale, addXscale from .coupled_nbls import CoupledSonophores class Benchmark: tsparse_bounds = (1, -2) def __init__(self, a, nnodes, outdir=None, nodecolors=None): self.a = a self.nnodes = nnodes self.outdir = outdir if not os.path.isdir(self.outdir): os.mkdir(self.outdir) if nodecolors is None: nodecolors = plt.get_cmap('Dark2').colors self.nodecolors = nodecolors def pdict(self): return { 'a': f'{self.a * 1e9:.0f} nm', 'nnodes': f'{self.nnodes} nodes', } def pstr(self): l = [] for k, v in self.pdict().items(): if k == 'nnodes': l.append(v) else: l.append(f'{k} = {v}') return ', '.join(l) def __repr__(self): return f'{self.__class__.__name__}({self.pstr()})' def code(self): s = self.__repr__() for k in ['/', '(', ',']: s = s.replace(k, '_') for k in ['=', ' ', ')']: s = s.replace(k, '') return s def runSims(self, model, drives, tstim, covs): ''' Run full and sonic simulations for a specific combination drives, pulsed protocol and coverage fractions, and harmonize outputs. ''' Fdrive = drives[0].f assert all(x.f == Fdrive for x in drives), 'frequencies do not match' assert len(covs) == model.nnodes, 'coverages do not match model dimensions' assert len(drives) == model.nnodes, 'drives do not match model dimensions' # If not provided, compute stimulus duration from model passive properties min_ncycles = 10 ntaumax_conv = 5 if tstim is None: tstim = max(ntaumax_conv * model.taumax, min_ncycles / Fdrive) # Recast stimulus duration as finite multiple of acoustic period tstim = int(np.ceil(tstim * Fdrive)) / Fdrive # s # Pulsed protocol pp = PulsedProtocol(tstim, 0) # Simulate/Load with full and sonic methods data, meta = {}, {} for method in ['full', 'sonic']: data[method], meta[method] = model.simAndSave( drives, pp, covs, method, outdir=self.outdir, overwrite=False, minimize_output=True) # Cycle-average full solution and interpolate sonic solution along same time vector data['cycleavg'] = data['full'].cycleAveraged(1 / Fdrive) data['sonic'] = data['sonic'].interpolate(data['cycleavg'].time) # # Compute normalized charge density profiles and add them to dataset # for simkey, simdata in data.items(): # for nodekey, nodedata in simdata.items(): # nodedata['Qnorm'] = nodedata['Qm'] / model.refpneuron.Cm0 * 1e3 # mV # Return dataset return data, meta def getTime(self, data): ''' Get time vector compatible with cycle-averaged and sonic charge density vectors (with, by default, discarding of bounding artefact elements). ''' return data['cycleavg'].time[self.tsparse_bounds[0]:self.tsparse_bounds[1]] def getCharges(self, data, k, cut_bounds=True): ''' Get node-specific list of cycle-averaged and sonic charge density vectors (with, by default, discarding of bounding artefact elements). ''' Qms = np.array([data[simkey][k]['Qm'].values for simkey in ['cycleavg', 'sonic']]) if cut_bounds: Qms = Qms[:, self.tsparse_bounds[0]:self.tsparse_bounds[1]] return Qms def computeRMSE(self, data): ''' Evaluate per-node RMSE on charge density profiles. ''' return {k: rmse(*self.getCharges(data, k)) for k in data['cycleavg'].keys()} def eval_funcs(self): return { 'rmse': (self.computeRMSE, 'nC/cm2') } def computeDivergence(self, data, eval_mode, *args): ''' Compute divergence according to given eval_mode, and return max value across nodes. ''' divs = list(self.eval_funcs()[eval_mode][0](data, *args).values()) if any(np.isnan(x) for x in divs): return np.nan return max(divs) def plotQm(self, ax, data): ''' Plot charge density signals on an axis. ''' markers = {'full': '-', 'cycleavg': '--', 'sonic': '-'} alphas = {'full': 0.5, 'cycleavg': 1., 'sonic': 1.} # tplt = TimeSeriesPlot.getTimePltVar('ms') # yplt = self.model.refpneuron.getPltVars()['Qm/Cm0'] # mode = 'details' for simkey, simdata in data.items(): for i, (nodekey, nodedata) in enumerate(simdata.items()): y = nodedata['Qm'].values y[-1] = y[-2] ax.plot(nodedata.time * 1e3, y * 1e5, markers[simkey], c=self.nodecolors[i], alpha=alphas[simkey], label=f'{simkey} - {nodekey}') # if simkey == 'cycleavg': # TimeSeriesPlot.materializeSpikes(ax, nodedata, tplt, yplt, c, mode) def plotSignalsOver2DSpace(self, gridxkey, gridxvec, gridxunit, gridykey, gridyvec, gridyunit, results, pltfunc, *args, yunit='', title=None, fs=10, flipud=True, fliplr=False): ''' Plot signals over 2D space. ''' # Create grid-like figure fig, axes = plt.subplots(gridxvec.size, gridyvec.size, figsize=(9, 7)) # Re-arrange axes and labels if flipud/fliplr option is set supylabel_args = {} supxlabel_args = {'y': 1.0, 'va': 'top'} if flipud: axes = axes[::-1] supxlabel_args = {} if fliplr: axes = axes[:, ::-1] supylabel_args = {'x': 1.0, 'ha': 'right'} # Add title and general axes labels if title is not None: fig.suptitle(title, fontsize=fs + 2) fig.supxlabel(gridxkey, fontsize=fs + 2, **supxlabel_args) fig.supylabel(gridykey, fontsize=fs + 2, **supylabel_args) # Loop through the axes and plot results, while storing time ranges i = 0 tranges = [] for i, axrow in enumerate(axes): for j, ax in enumerate(axrow): hideSpines(ax) hideTicks(ax) ax.margins(0) if results[i, j] is not None: pltfunc(ax, results[i, j], *args) tranges.append(np.ptp(ax.get_xlim())) if len(np.unique(tranges)) > 1: # If more than one time range, add common x-scale to all axes tmin = min(tranges) for axrow in axes[::-1]: for ax in axrow: trans = (ax.transData + ax.transAxes.inverted()) xpoints = [trans.transform([x, 0])[0] for x in [0, tmin]] ax.plot(xpoints, [-0.05] * 2, c='k', lw=2, transform=ax.transAxes) else: # Otherwise, add x-scale only to axis opposite to origin side = 'top' if flipud else 'bottom' addXscale(axes[-1, -1], 0, 0.05, unit='ms', fmt='.0f', fs=fs, side=side) # Harmonize y-limits across all axes, and add y-scale to axis opposite to origin harmonizeAxesLimits(axes, dim='y') side = 'left' if fliplr else 'right' if yunit is not None: addYscale(axes[-1, -1], 0.05, 0, unit=yunit, fmt='.0f', fs=fs, side=side) # Set labels for xvec and yvec values along the two figure grid dimensions for ax, x in zip(axes[0, :], gridxvec): ax.set_xlabel(f'{si_format(x)}{gridxunit}', labelpad=15, fontsize=fs + 2) if not flipud: ax.xaxis.set_label_position('top') for ax, y in zip(axes[:, 0], gridyvec): if fliplr: ax.yaxis.set_label_position('right') ax.set_ylabel(f'{si_format(y)}{gridyunit}', labelpad=15, fontsize=fs + 2) # Return figure object return fig class PassiveBenchmark(Benchmark): def __init__(self, a, nnodes, Cm0, ELeak, **kwargs): super().__init__(a, nnodes, **kwargs) self.Cm0 = Cm0 self.ELeak = ELeak def pdict(self): return { **super().pdict(), 'Cm0': f'{self.Cm0 * 1e2:.1f} uF/cm2', 'ELeak': f'{self.ELeak} mV', } def getModelAndRunSims(self, drives, covs, taum, tauax): ''' Create passive model for a combination of time constants. ''' gLeak = self.Cm0 / taum ga = self.Cm0 / tauax pneuron = passiveNeuron(self.Cm0, gLeak, self.ELeak) model = CoupledSonophores([ NeuronalBilayerSonophore(self.a, pneuron) for i in range(self.nnodes)], ga) return self.runSims(model, drives, None, covs) def runSimsOverTauSpace(self, drives, covs, taum_range, tauax_range, mpi=False): ''' Run simulations over 2D time constant space. ''' queue = [[drives, covs] + x for x in Batch.createQueue(taum_range, tauax_range)] batch = Batch(self.getModelAndRunSims, queue) # batch.printQueue(queue) output = batch.run(mpi=mpi) results = [x[0] for x in output] # removing meta return np.reshape(results, (taum_range.size, tauax_range.size)).T def computeSteadyStateDivergence(self, data): ''' Evaluate per-node steady-state absolute deviation on charge density profiles. ''' return {k: np.abs(np.squeeze(np.diff(self.getCharges(data, k), axis=0)))[-1] for k in data['cycleavg'].keys()} @staticmethod def computeAreaRatio(yref, yeval, dt): # Get reference steady-state yinf = yref[-1] # Compute absolute differential signals: between reference solution and its steady-state, # and between the two solutions signals = [np.ones_like(yref) * yinf, yeval] diffsignals = [np.abs(y - yref) for y in signals] # Compute related areas areas = [np.sum(y) * dt for y in diffsignals] # Return ratio of the two areas ratio = areas[1] / areas[0] - logger.debug([f'{x * 1e5:.2f}%.ms' for x in areas], f'ratio = {ratio * 1e2:.2f}%') + logger.debug( + f"{', '.join([f'{x * 1e5:.2f}%.ms' for x in areas])}, ratio = {ratio * 1e2:.2f}%") return ratio def isExponentialChargeBuildup(self, Qm): ''' Check if charge signal corresponds to an exponential build-up. ''' if np.ptp(Qm) < 1e-5: # C/m2 logger.debug('too narrow') return False Qmin, Qmax = bounds(Qm) - Qstart, Qend = Qm[0], Qm[-1] Qbounds_check = dict(atol=1e-7, rtol=1e-5) - if not np.isclose(Qstart, Qmin, **Qbounds_check): - logger.debug('not starting from bottom') - return False - if not np.isclose(Qend, Qmax, **Qbounds_check): + # if not np.isclose(Qm[0], Qmin, **Qbounds_check): + # logger.debug('not starting from bottom') + # return False + if not np.isclose(Qm[-1], Qmax, **Qbounds_check): logger.debug('not finishing on top') return False return True def computeTransientDivergence(self, data): ''' Evaluate per-node mean absolute difference on [0, 1] normalized charge profiles. ''' d = {} t = self.getTime(data) dt = t[1] - t[0] for k in data['cycleavg'].keys(): y = self.getCharges(data, k) # If cycle-avg charge profile corresponds to an exponential build-up if self.isExponentialChargeBuildup(y[0]): # Rescale signals linearly between 0 and 1 ynorms = np.array([rescale(yy) for yy in y]) # Compute ratio between the cycle-avg steady-state convergence area and the # difference area between cycle-avg and sonic solutions d[k] = self.computeAreaRatio(*ynorms, dt) * 1e2 else: d[k] = np.nan return d def eval_funcs(self): return { **super().eval_funcs(), 'ss': (self.computeSteadyStateDivergence, 'nC/cm2', 1e5), 'transient': (self.computeTransientDivergence, '%', 1e0) } def plotSignalsOverTauSpace(self, taum_range, tauax_range, results, pltfunc=None, fs=10): if pltfunc is None: pltfunc = 'plotQm' yunit = {'plotQm': 'nC/cm2', 'plotQnorm': None}[pltfunc] title = pltfunc[4:] pltfunc = getattr(self, pltfunc) return self.plotSignalsOver2DSpace( 'taum', taum_range, 's', 'tauax', tauax_range, 's', results, pltfunc, title=title, yunit=yunit) def plotQnorm(self, ax, data): t = self.getTime(data) for i, (k, nodedata) in enumerate(data['cycleavg'].items()): dt = t[1] - t[0] y = self.getCharges(data, k) c = self.nodecolors[i] ynorms = np.array([rescale(yy) for yy in y]) for yn, marker in zip(ynorms, ['--', '-']): ax.plot(t * 1e3, yn, marker, c=c) if self.isExponentialChargeBuildup(y[0]): ax.axhline(ynorms[0][-1], ls='--', color='k') ax.fill_between(t * 1e3, *ynorms, alpha=0.5, color=c) eps = self.computeAreaRatio(*ynorms, dt) else: eps = np.nan ax.text(0.5, 0.3 * (i + 1), f'{eps * 1e2:.2f}%', c=c, transform=ax.transAxes) class FiberBenchmark(Benchmark): def __init__(self, a, nnodes, pneuron, ga, **kwargs): super().__init__(a, nnodes, **kwargs) self.model = CoupledSonophores([ NeuronalBilayerSonophore(self.a, pneuron) for i in range(self.nnodes)], ga) def pdict(self): return { **super().pdict(), 'ga': self.model.gastr, 'pneuron': self.model.refpneuron, } def getModelAndRunSims(self, Fdrive, tstim, covs, A1, A2): ''' Create passive model for a combination of time constants. ''' drives = AcousticDriveArray([AcousticDrive(Fdrive, A1), AcousticDrive(Fdrive, A2)]) return self.runSims(self.model, drives, tstim, covs) def runSimsOverAmplitudeSpace(self, Fdrive, tstim, covs, A_range, mpi=False, subset=None): ''' Run simulations over 2D time constant space. ''' # Generate 2D amplitudes meshgrid A_combs = np.meshgrid(A_range, A_range) # Set elements below main diagonal to NaN tril_idxs = np.tril_indices(A_range.size, -1) for x in A_combs: x[tril_idxs] = np.nan # Flatten the meshgrid and assemble into list of tuples A_combs = list(zip(*[x.flatten().tolist() for x in A_combs])) # Remove NaN elements A_combs = list(filter(lambda x: not any(np.isnan(xx) for xx in x), A_combs)) # Assemble queue queue = [[Fdrive, tstim, covs] + list(x) for x in A_combs] # restrict queue if subset is specified if subset is not None: queue = queue[subset[0]:subset[1] + 1] batch = Batch(self.getModelAndRunSims, queue) output = batch.run(mpi=mpi) results = [x[0] for x in output] # removing meta # Re-organize results into upper-triangle matrix new_results = np.empty((A_range.size, A_range.size), dtype=object) triu_idxs = np.triu_indices(A_range.size, 0) for *idx, res in zip(*triu_idxs, results): new_results[idx[0], idx[1]] = res return new_results def computeGamma(self, data, *args): ''' Evaluate per-node gamma on charge density profiles. ''' gamma_dict = {} resolution = list(data['cycleavg'].values())[0].dt for k in data['cycleavg'].keys(): # Get charge vectors (discarding 1st and last indexes) and compute gamma gamma_dict[k] = gamma(*self.getCharges(data, k), *args, resolution) return gamma_dict def plotQm(self, ax, data, *gamma_args): super().plotQm(ax, data) gamma_dict = self.computeGamma(data, *gamma_args) tplt = self.getTime(data) * 1e3 data_to_axis = ax.transData + ax.transAxes.inverted() tplt = data_to_axis.transform(tplt) ones = np.ones_like(tplt) for i, (nodekey, nodegamma) in enumerate(gamma_dict.items()): ax.plot(tplt[nodegamma >= 1], ones[nodegamma >= 1] + 0.02 * i, c=self.nodecolors[i], label=nodekey, transform=ax.transAxes) def plotGamma(self, ax, data, *gamma_args): gamma_dict = self.computeGamma(data, *gamma_args) tplt = self.getTime(data) * 1e3 for i, (nodekey, nodegamma) in enumerate(gamma_dict.items()): ax.plot(tplt, nodegamma, c=self.nodecolors[i], label=nodekey) ax.axhline(1, linestyle='--', c='k') def plotSignalsOverAmplitudeSpace(self, A_range, results, *args, pltfunc=None, fs=10): if pltfunc is None: pltfunc = 'plotQm' yunit = {'plotQm': 'nC/cm2', 'plotGamma': ''}[pltfunc] title = pltfunc[4:] pltfunc = getattr(self, pltfunc) return self.plotSignalsOver2DSpace( 'A1', A_range, 'Pa', 'A2', A_range, 'Pa', results, pltfunc, *args, title=title, yunit=yunit) def computeGammaDivergence(self, data, *args): return {k: np.nanmax(v) for k, v in self.computeGamma(data, *args).items()} def eval_funcs(self): return { **super().eval_funcs(), 'gamma': (self.computeGammaDivergence, '', 1e0) } diff --git a/PySONIC/plt/divmaps.py b/PySONIC/plt/divmaps.py index 1f69b63..eb12a8b 100644 --- a/PySONIC/plt/divmaps.py +++ b/PySONIC/plt/divmaps.py @@ -1,186 +1,186 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2020-06-29 18:11:24 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-06-16 19:08:45 +# @Last Modified time: 2021-06-16 23:35:43 import numpy as np import matplotlib.pyplot as plt from ..utils import logger, si_format from .xymap import XYMap class DivergenceMap(XYMap): ''' Interface to a 2D map showing divergence of the SONIC output from a cycle-averaged NICE output, for various combinations of parameters. ''' def __init__(self, benchmark, xvec, yvec, sim_args, eval_mode, eval_args, *args, **kwargs): self.benchmark = benchmark self.sim_args = sim_args self.eval_mode = eval_mode self.eval_args = eval_args super().__init__(self.benchmark.outdir, xvec, yvec, *args, **kwargs) @property def eval_mode(self): return self._eval_mode @eval_mode.setter def eval_mode(self, value): if value not in self.benchmark.eval_funcs().keys(): raise ValueError(f'unknown evalation mode: {value}') self._eval_mode = value @property def zkey(self): return self.eval_mode @property def zunit(self): return self.benchmark.eval_funcs()[self.eval_mode][1] @property def zfactor(self): if self.eval_mode == 'ss': return 1e5 else: return 1e0 @property def suffix(self): s = self.eval_mode if len(self.eval_args) > 0: - s = f'{s}_{"_".join([str(x) for x in self.eval_args])}' + s = f'{s}_{"_".join([f"{x:.2e}" for x in self.eval_args])}' return s def descPair(self, x1, x2): raise NotImplementedError def logDiv(self, x, div): ''' Log divergence for a particular inputs combination. ''' logger.info(f'{self.descPair(*x)}: {self.eval_mode} = {div:.2e} {self.zunit}') def compute(self, x): data, _ = self.benchmark.getModelAndRunSims(*self.sim_args, *x) div = self.benchmark.computeDivergence(data, self.eval_mode, *self.eval_args) self.logDiv(x, div) return div def callbackPltFunc(self): raise NotImplementedError def onClick(self, event): x = self.getOnClickXY(event) data, _ = self.benchmark.getModelAndRunSims(*self.sim_args, *x) fig, ax = plt.subplots(figsize=(4, 4)) ax.set_xlabel('time (ms)') ylabel = 'Qm (nC/cm2)' if self.eval_mode == 'transient': ylabel = 'Qm-norm' ax.set_ylabel(ylabel) for sk in ['top', 'right']: ax.spines[sk].set_visible(False) ax.set_title(self.descPair(*x)) self.callbackPltFunc()(ax, data) plt.show() def render(self, zscale='log', zbounds=(1e-1, 1e1), extend_under=True, extend_over=True, cmap='Spectral_r', figsize=(6, 4), fs=12, ax=None, **kwargs): ''' Render with default log scale, zbounds, cmap and cbar properties. ''' fig = super().render( zscale=zscale, zbounds=zbounds, extend_under=extend_under, extend_over=extend_over, cmap=cmap, figsize=figsize, fs=fs, ax=ax, **kwargs) if ax is None: fig.canvas.manager.set_window_title(f'{self.corecode()} - {self.eval_mode}') return fig class PassiveDivergenceMap(DivergenceMap): ''' Divergence map of a passive model for various combinations of membrane time constants (taum) and axial time constant (tauax) ''' xkey = 'taum' xfactor = 1e0 xunit = 's' ykey = 'tauax' yfactor = 1e0 yunit = 's' @property def title(self): return f'passive divmap - {self.eval_mode}' def corecode(self): return f'divmap_{self.benchmark.code()}' def descPair(self, taum, tauax): return f'taum = {si_format(taum, 2)}s, tauax = {si_format(tauax, 2)}s' @staticmethod def addPeriodicityLines(ax, T, dims='xy', color='k', pattern='cross'): xmin, ymin = 0, 0 xmax, ymax = 1, 1 if pattern in ['upper-square', 'lower-square']: data_to_axis = ax.transData + ax.transAxes.inverted() xc, yc = data_to_axis.transform((T, T)) if pattern == 'upper-square': xmin, ymin = xc, yc else: xmax, ymax = xc, yc if 'x' in dims: ax.axvline(T, ymin=ymin, ymax=ymax, color=color, linestyle='--', linewidth=1.5) if 'y' in dims: ax.axhline(T, xmin=xmin, xmax=xmax, color=color, linestyle='--', linewidth=1.5) def render(self, xscale='log', yscale='log', T=None, ax=None, **kwargs): ''' Render with drive periodicty indicator. ''' fig = super().render(xscale=xscale, yscale=yscale, ax=ax, **kwargs) if ax is None: ax = fig.axes[0] if T is not None: self.addPeriodicityLines(ax, T) return fig def callbackPltFunc(self): return { 'ss': self.benchmark.plotQm, 'transient': self.benchmark.plotQnorm }[self.eval_mode] class FiberDivergenceMap(DivergenceMap): ''' Divergence map of a fiber model for various combinations of acoustic pressure amplitudes in both compartments ''' xkey = 'A1' xfactor = 1e0 xunit = 'Pa' ykey = 'A2' yfactor = 1e0 yunit = 'Pa' def __init__(self, benchmark, Avec, *args, **kwargs): super().__init__(benchmark, Avec, Avec, *args, **kwargs) @property def title(self): return f'fiber divmap - {self.eval_mode}' def corecode(self): return f'divmap_{self.benchmark.code()}' def descPair(self, *amps): return f"A = {', '.join(f'{si_format(A, 2)}Pa' for A in amps)}" def compute(self, x): if x[0] < x[1]: return np.nan return super().compute(x) def render(self, Ascale='log', **kwargs): super().render(xscale=Ascale, yscale=Ascale, **kwargs) diff --git a/PySONIC/plt/xymap.py b/PySONIC/plt/xymap.py index d5da496..c374857 100644 --- a/PySONIC/plt/xymap.py +++ b/PySONIC/plt/xymap.py @@ -1,346 +1,402 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2019-06-04 18:24:29 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-06-16 20:52:44 +# @Last Modified time: 2021-06-17 13:31:22 import abc import csv from itertools import product import numpy as np import pandas as pd import matplotlib.pyplot as plt import copy +from scipy.interpolate import RectBivariateSpline from ..core import LogBatch -from ..utils import logger, isIterable, rangecode +from ..utils import logger, isIterable, rangecode, bounds from .pltutils import cm2inch, setNormalizer class XYMap(LogBatch): ''' Generic 2D map object interface. ''' offset_options = { 'lr': (1, -1), 'ur': (1, 1), 'll': (-1, -1), 'ul': (-1, 1) } def __init__(self, root, xvec, yvec): self.root = root self.xvec = xvec self.yvec = yvec super().__init__([list(pair) for pair in product(self.xvec, self.yvec)], root=root) def checkVector(self, name, value): if not isIterable(value): raise ValueError(f'{name} vector must be an iterable') if not isinstance(value, np.ndarray): value = np.asarray(value) if len(value.shape) > 1: raise ValueError(f'{name} vector must be one-dimensional') return value @property def in_key(self): return self.xkey @property def unit(self): return self.xunit @property def xvec(self): return self._xvec @xvec.setter def xvec(self, value): self._xvec = self.checkVector('x', value) @property def yvec(self): return self._yvec @yvec.setter def yvec(self, value): self._yvec = self.checkVector('x', value) @property @abc.abstractmethod def xkey(self): raise NotImplementedError @property @abc.abstractmethod def xfactor(self): raise NotImplementedError @property @abc.abstractmethod def xunit(self): raise NotImplementedError @property @abc.abstractmethod def ykey(self): raise NotImplementedError @property @abc.abstractmethod def yfactor(self): raise NotImplementedError @property @abc.abstractmethod def yunit(self): raise NotImplementedError @property @abc.abstractmethod def zkey(self): raise NotImplementedError @property @abc.abstractmethod def zunit(self): raise NotImplementedError @property @abc.abstractmethod def zfactor(self): raise NotImplementedError @property def out_keys(self): return [f'{self.zkey} ({self.zunit})'] @property def in_labels(self): return [f'{self.xkey} ({self.xunit})', f'{self.ykey} ({self.yunit})'] def getLogData(self): ''' Retrieve the batch log file data (inputs and outputs) as a dataframe. ''' return pd.read_csv(self.fpath, sep=self.delimiter).sort_values(self.in_labels) def getInput(self): ''' Retrieve the logged batch inputs as an array. ''' return self.getLogData()[self.in_labels].values def getOutput(self): - return np.reshape(super().getOutput(), (self.xvec.size, self.yvec.size)).T + ''' Return map output, shaped as an nx-by-ny matrix. ''' + return np.reshape(super().getOutput(), (self.xvec.size, self.yvec.size)) def writeLabels(self): with open(self.fpath, 'w') as csvfile: writer = csv.writer(csvfile, delimiter=self.delimiter) writer.writerow([*self.in_labels, *self.out_keys]) def isEntry(self, comb): ''' Check if a given input is logged in the batch log file. ''' inputs = self.getInput() if len(inputs) == 0: return False imatches_x = np.where(np.isclose(inputs[:, 0], comb[0], rtol=self.rtol, atol=self.atol))[0] imatches_y = np.where(np.isclose(inputs[:, 1], comb[1], rtol=self.rtol, atol=self.atol))[0] imatches = list(set(imatches_x).intersection(imatches_y)) if len(imatches) == 0: return False return True @property def inputscode(self): ''' String describing the batch inputs. ''' xcode = rangecode(self.xvec, self.xkey, self.xunit) ycode = rangecode(self.yvec, self.ykey, self.yunit) return '_'.join([xcode, ycode]) @staticmethod def getScaleType(x): xmin, xmax, nx = x.min(), x.max(), x.size if np.all(np.isclose(x, np.logspace(np.log10(xmin), np.log10(xmax), nx))): return 'log' else: return 'lin' # elif np.all(np.isclose(x, np.linspace(xmin, xmax, nx))): # return 'lin' # else: # raise ValueError('Unknown distribution type') @property def xscale(self): return self.getScaleType(self.xvec) @property def yscale(self): return self.getScaleType(self.yvec) @staticmethod def computeMeshEdges(x, scale): ''' Compute the appropriate edges of a mesh that quads a linear or logarihtmic distribution. :param x: the input vector :param scale: the type of distribution ('lin' for linear, 'log' for logarihtmic) :return: the edges vector ''' if scale == 'log': x = np.log10(x) range_func = np.logspace else: range_func = np.linspace dx = x[1] - x[0] n = x.size + 1 return range_func(x[0] - dx / 2, x[-1] + dx / 2, n) @abc.abstractmethod def compute(self, x): ''' Compute the necessary output(s) for a given inputs combination. ''' raise NotImplementedError def run(self, **kwargs): super().run(**kwargs) self.getLogData().to_csv(self.filepath(), sep=self.delimiter, index=False) def getOnClickXY(self, event): ''' Get x and y values from from x and y click event coordinates. ''' x = self.xvec[np.searchsorted(self.xedges, event.xdata) - 1] y = self.yvec[np.searchsorted(self.yedges, event.ydata) - 1] return x, y def onClickWrapper(self, event): if event.inaxes == self.ax: return self.onClick(event) def onClick(self, event): ''' Exexecute specific action when the user clicks on a cell in the 2D map. ''' pass @property @abc.abstractmethod def title(self): raise NotImplementedError def getZBounds(self): matrix = self.getOutput() * self.zfactor zmin, zmax = np.nanmin(matrix), np.nanmax(matrix) logger.info( - f'{self.zkey} range: {zmin:.0f} - {zmax:.0f} {self.zunit}') + f'{self.zkey} range: {zmin:.2f} - {zmax:.2f} {self.zunit}') return zmin, zmax def checkZbounds(self, zbounds): zmin, zmax = self.getZBounds() if zmin < zbounds[0]: logger.warning( - f'Minimal {self.zkey} ({zmin:.0f} {self.zunit}) is below defined lower bound ({zbounds[0]:.0f} {self.zunit})') + f'Minimal {self.zkey} ({zmin:.2f} {self.zunit}) is below defined lower bound ({zbounds[0]:.2f} {self.zunit})') if zmax > zbounds[1]: logger.warning( - f'Maximal {self.zkey} ({zmax:.0f} {self.zunit}) is above defined upper bound ({zbounds[1]:.0f} {self.zunit})') + f'Maximal {self.zkey} ({zmax:.2f} {self.zunit}) is above defined upper bound ({zbounds[1]:.2f} {self.zunit})') @staticmethod def addInsets(ax, insets, fs, minimal=False): ax.update_datalim(list(insets.values())) xyoffset = np.array([0, 0.05]) data_to_axis = ax.transData + ax.transAxes.inverted() for k, xydata in insets.items(): ax.scatter(*xydata, s=20, facecolor='k', edgecolor='none') if not minimal: xyaxis = np.array(data_to_axis.transform(xydata)) ax.annotate( k, xy=xyaxis, xytext=xyaxis + xyoffset, xycoords=ax.transAxes, fontsize=fs, arrowprops={'facecolor': 'black', 'arrowstyle': '-'}, ha='right') + @staticmethod + def extrapolate(xref, yref, data, xscale, yscale, xextra=None, yextra=None): + if sum([x is None for x in [xextra, yextra]]) != 1: + raise ValueError('only 1 dimension can be extrapolated at a time') + if xscale == 'log': + xref = np.log10(xref) + if yscale == 'log': + yref = np.log10(yref) + valid_data = ~np.isnan(data) + validrows, validcols = [np.all(valid_data, axis=i) for i in [1, 0]] + ref_xyz = [xref, yref, data] + if xextra is not None: + k, stackaxis, vref, vextra, vscale, other = 'x', 0, xref, xextra, xscale, yref + reverse = False + ref_xyz = ref_xyz[0][validrows], ref_xyz[1], ref_xyz[-1][validrows, :] + if yextra is not None: + k, stackaxis, vref, vextra, vscale, other = 'y', 1, yref, yextra, yscale, xref + reverse = True + ref_xyz = ref_xyz[0], ref_xyz[1][validcols], ref_xyz[-1][:, validcols] + f = RectBivariateSpline(*ref_xyz) + vmin, vmax = bounds(vref) + if any(vmin < vv < vmax for vv in vextra): + raise ValueError(f'new {k} vector must sit entirely outside of reference {k}-range') + if vscale == 'log': + vextra = np.log10(vextra) + interp_xy = [vextra, other] + if reverse: + interp_xy = interp_xy[::-1] + interp_data = f(*interp_xy) + if vextra[0] > vref.max(): + v = (vref, vextra) + data = (data, interp_data) + else: + v = (vextra, vref) + data = (interp_data, data) + v = np.hstack(v) + data = np.concatenate(data, axis=stackaxis) + + if xextra is not None: + x, y = v, yref + if yextra is not None: + x, y = xref, v + + if xscale == 'log': + x = np.power(10., x) + if yscale == 'log': + y = np.power(10., y) + return x, y, data + def render(self, xscale='lin', yscale='lin', zscale='lin', zbounds=None, fs=8, cmap='viridis', interactive=False, figsize=None, insets=None, inset_offset=0.05, extend_under=False, extend_over=False, ax=None, cbarax=None, cbarlabel='vertical', - title=None, minimal=False, levels=None, flip=False, plt_cbar=True): + title=None, minimal=False, levels=None, flip=False, plt_cbar=True, + xextra=None, yextra=None): # Compute z-bounds if zbounds is None: extend_under = False extend_over = False zbounds = self.getZBounds() else: self.checkZbounds(zbounds) # Compute Z normalizer mymap = copy.copy(plt.get_cmap(cmap)) mymap.set_bad('silver') if not extend_under: mymap.set_under('silver') if not extend_over: mymap.set_over('silver') norm, sm = setNormalizer(mymap, zbounds, zscale) - # Compute mesh edges - self.xedges = self.computeMeshEdges(self.xvec, xscale) - self.yedges = self.computeMeshEdges(self.yvec, yscale) - # Create figure if required if ax is None: if figsize is None: figsize = cm2inch(12, 7) fig, ax = plt.subplots(figsize=figsize) fig.subplots_adjust(left=0.15, bottom=0.15, right=0.8, top=0.92) else: fig = ax.get_figure() # Set axis properties if title is None: title = self.title ax.set_title(title, fontsize=fs) if minimal: ax.set_xticks([]) ax.set_yticks([]) ax.tick_params(axis='both', which='both', bottom=False, left=False, labelbottom=False, labelleft=False) else: ax.set_xlabel(f'{self.xkey} ({self.xunit})', fontsize=fs, labelpad=-0.5) ax.set_ylabel(f'{self.ykey} ({self.yunit})', fontsize=fs) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) if xscale == 'log': ax.set_xscale('log') if yscale == 'log': ax.set_yscale('log') - # Retrieve data - data = self.getOutput() * self.zfactor + # Retrieve data and extrapolate if needed + x, y, data = self.xvec, self.yvec, self.getOutput() * self.zfactor + if xextra is not None: + x, y, data = self.extrapolate(x, y, data, xscale, yscale, xextra=xextra) + elif yextra is not None: + x, y, data = self.extrapolate(x, y, data, xscale, yscale, yextra=yextra) + + # Flip data if required if flip: data = data.T - # Plot map with specific color code - ax.pcolormesh(self.xedges, self.yedges, data, cmap=mymap, norm=norm) + # Compute mesh edges and plot map with specific color code + self.xedges = self.computeMeshEdges(x, xscale) + self.yedges = self.computeMeshEdges(y, yscale) + ax.pcolormesh(self.xedges, self.yedges, data.T, cmap=mymap, norm=norm) # Add contour levels if needed if levels is not None: CS = ax.contour( - self.xvec, self.yvec, data, levels, colors='k') + x, y, data.T, levels, colors='k') ax.clabel(CS, fontsize=fs, fmt=lambda x: f'{x:g}', inline_spacing=2) # Add potential insets if insets is not None: self.addInsets(ax, insets, fs, minimal=minimal) # Plot z-scale colorbar if required if plt_cbar: if cbarax is None: pos1 = ax.get_position() # get the map axis position cbarax = fig.add_axes([pos1.x1 + 0.02, pos1.y0, 0.03, pos1.height]) if not extend_under and not extend_over: extend = 'neither' elif extend_under and extend_over: extend = 'both' else: extend = 'max' if extend_over else 'min' self.cbar = plt.colorbar(sm, cax=cbarax, extend=extend) if cbarlabel == 'vertical': cbarax.set_ylabel(f'{self.zkey} ({self.zunit})', fontsize=fs) else: cbarax.set_title(f'{self.zkey} ({self.zunit})', fontsize=fs) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) if interactive: self.ax = ax fig.canvas.mpl_connect('button_press_event', self.onClickWrapper) return fig