diff --git a/PySONIC/plt/xymap.py b/PySONIC/plt/xymap.py index 78f6620..fea85ac 100644 --- a/PySONIC/plt/xymap.py +++ b/PySONIC/plt/xymap.py @@ -1,336 +1,337 @@ # -*- coding: utf-8 -*- # @Author: Theo Lemaire # @Email: theo.lemaire@epfl.ch # @Date: 2019-06-04 18:24:29 # @Last Modified by: Theo Lemaire -# @Last Modified time: 2021-05-28 12:12:10 +# @Last Modified time: 2021-05-30 17:54:01 import abc import csv from itertools import product import numpy as np import pandas as pd import matplotlib.pyplot as plt import copy from ..core import LogBatch from ..utils import logger, isIterable, rangecode from .pltutils import cm2inch, setNormalizer class XYMap(LogBatch): ''' Generic 2D map object interface. ''' offset_options = { 'lr': (1, -1), 'ur': (1, 1), 'll': (-1, -1), 'ul': (-1, 1) } def __init__(self, root, xvec, yvec): self.root = root self.xvec = xvec self.yvec = yvec super().__init__([list(pair) for pair in product(self.xvec, self.yvec)], root=root) def checkVector(self, name, value): if not isIterable(value): raise ValueError(f'{name} vector must be an iterable') if not isinstance(value, np.ndarray): value = np.asarray(value) if len(value.shape) > 1: raise ValueError(f'{name} vector must be one-dimensional') return value @property def in_key(self): return self.xkey @property def unit(self): return self.xunit @property def xvec(self): return self._xvec @xvec.setter def xvec(self, value): self._xvec = self.checkVector('x', value) @property def yvec(self): return self._yvec @yvec.setter def yvec(self, value): self._yvec = self.checkVector('x', value) @property @abc.abstractmethod def xkey(self): raise NotImplementedError @property @abc.abstractmethod def xfactor(self): raise NotImplementedError @property @abc.abstractmethod def xunit(self): raise NotImplementedError @property @abc.abstractmethod def ykey(self): raise NotImplementedError @property @abc.abstractmethod def yfactor(self): raise NotImplementedError @property @abc.abstractmethod def yunit(self): raise NotImplementedError @property @abc.abstractmethod def zkey(self): raise NotImplementedError @property @abc.abstractmethod def zunit(self): raise NotImplementedError @property @abc.abstractmethod def zfactor(self): raise NotImplementedError @property def out_keys(self): return [f'{self.zkey} ({self.zunit})'] @property def in_labels(self): return [f'{self.xkey} ({self.xunit})', f'{self.ykey} ({self.yunit})'] def getLogData(self): ''' Retrieve the batch log file data (inputs and outputs) as a dataframe. ''' return pd.read_csv(self.fpath, sep=self.delimiter).sort_values(self.in_labels) def getInput(self): ''' Retrieve the logged batch inputs as an array. ''' return self.getLogData()[self.in_labels].values def getOutput(self): return np.reshape(super().getOutput(), (self.xvec.size, self.yvec.size)).T def writeLabels(self): with open(self.fpath, 'w') as csvfile: writer = csv.writer(csvfile, delimiter=self.delimiter) writer.writerow([*self.in_labels, *self.out_keys]) def isEntry(self, comb): ''' Check if a given input is logged in the batch log file. ''' inputs = self.getInput() if len(inputs) == 0: return False imatches_x = np.where(np.isclose(inputs[:, 0], comb[0], rtol=self.rtol, atol=self.atol))[0] imatches_y = np.where(np.isclose(inputs[:, 1], comb[1], rtol=self.rtol, atol=self.atol))[0] imatches = list(set(imatches_x).intersection(imatches_y)) if len(imatches) == 0: return False return True @property def inputscode(self): ''' String describing the batch inputs. ''' xcode = rangecode(self.xvec, self.xkey, self.xunit) ycode = rangecode(self.yvec, self.ykey, self.yunit) return '_'.join([xcode, ycode]) @staticmethod def getScaleType(x): xmin, xmax, nx = x.min(), x.max(), x.size if np.all(np.isclose(x, np.logspace(np.log10(xmin), np.log10(xmax), nx))): return 'log' else: return 'lin' # elif np.all(np.isclose(x, np.linspace(xmin, xmax, nx))): # return 'lin' # else: # raise ValueError('Unknown distribution type') @property def xscale(self): return self.getScaleType(self.xvec) @property def yscale(self): return self.getScaleType(self.yvec) @staticmethod def computeMeshEdges(x, scale): ''' Compute the appropriate edges of a mesh that quads a linear or logarihtmic distribution. :param x: the input vector :param scale: the type of distribution ('lin' for linear, 'log' for logarihtmic) :return: the edges vector ''' if scale == 'log': x = np.log10(x) range_func = np.logspace else: range_func = np.linspace dx = x[1] - x[0] n = x.size + 1 return range_func(x[0] - dx / 2, x[-1] + dx / 2, n) @abc.abstractmethod def compute(self, x): ''' Compute the necessary output(s) for a given inputs combination. ''' raise NotImplementedError def run(self, **kwargs): super().run(**kwargs) self.getLogData().to_csv(self.filepath(), sep=self.delimiter, index=False) def getOnClickXY(self, event): ''' Get x and y values from from x and y click event coordinates. ''' x = self.xvec[np.searchsorted(self.xedges, event.xdata) - 1] y = self.yvec[np.searchsorted(self.yedges, event.ydata) - 1] return x, y def onClick(self, event): ''' Exexecute specific action when the user clicks on a cell in the 2D map. ''' pass @property @abc.abstractmethod def title(self): raise NotImplementedError def getZBounds(self): matrix = self.getOutput() zmin, zmax = np.nanmin(matrix), np.nanmax(matrix) logger.info( f'{self.zkey} range: {zmin:.0f} - {zmax:.0f} {self.zunit}') return zmin, zmax def checkZbounds(self, zbounds): zmin, zmax = self.getZBounds() if zmin < zbounds[0]: logger.warning( f'Minimal {self.zkey} ({zmin:.0f} {self.zunit}) is below defined lower bound ({zbounds[0]:.0f} {self.zunit})') if zmax > zbounds[1]: logger.warning( f'Maximal {self.zkey} ({zmax:.0f} {self.zunit}) is above defined upper bound ({zbounds[1]:.0f} {self.zunit})') - def addInsets(self, ax, insets, fs, minimal=False): - xyoffset = np.array([0, 0.1]) - axis_to_data = ax.transAxes + ax.transData.inverted() - data_to_axis = axis_to_data.inverted() + @staticmethod + def addInsets(ax, insets, fs, minimal=False): + ax.update_datalim(list(insets.values())) + xyoffset = np.array([0, 0.05]) + data_to_axis = ax.transData + ax.transAxes.inverted() for k, xydata in insets.items(): ax.scatter(*xydata, s=20, facecolor='k', edgecolor='none') if not minimal: xyaxis = np.array(data_to_axis.transform(xydata)) ax.annotate( k, xy=xyaxis, xytext=xyaxis + xyoffset, xycoords=ax.transAxes, fontsize=fs, arrowprops={'facecolor': 'black', 'arrowstyle': '-'}, ha='right') def render(self, xscale='lin', yscale='lin', zscale='lin', zbounds=None, fs=8, cmap='viridis', interactive=False, figsize=None, insets=None, inset_offset=0.05, extend_under=False, extend_over=False, ax=None, cbarax=None, title=None, minimal=False, levels=None, flip=False): # Compute z-bounds if zbounds is None: extend_under = False extend_over = False zbounds = self.getZBounds() else: self.checkZbounds(zbounds) # Compute Z normalizer mymap = copy.copy(plt.get_cmap(cmap)) # mymap.set_bad('silver') if not extend_under: mymap.set_under('silver') if not extend_over: mymap.set_over('silver') norm, sm = setNormalizer(mymap, zbounds, zscale) # Compute mesh edges self.xedges = self.computeMeshEdges(self.xvec, xscale) self.yedges = self.computeMeshEdges(self.yvec, yscale) # Create figure if required if ax is None: if figsize is None: figsize = cm2inch(12, 7) fig, ax = plt.subplots(figsize=figsize) fig.subplots_adjust(left=0.15, bottom=0.15, right=0.8, top=0.92) else: fig = ax.get_figure() # Set axis properties if title is None: title = self.title ax.set_title(title, fontsize=fs) if minimal: ax.set_xticks([]) ax.set_yticks([]) ax.tick_params(axis='both', which='both', bottom=False, left=False, labelbottom=False, labelleft=False) else: ax.set_xlabel(f'{self.xkey} ({self.xunit})', fontsize=fs, labelpad=-0.5) ax.set_ylabel(f'{self.ykey} ({self.yunit})', fontsize=fs) for item in ax.get_xticklabels() + ax.get_yticklabels(): item.set_fontsize(fs) if xscale == 'log': ax.set_xscale('log') if yscale == 'log': ax.set_yscale('log') # Retrieve data data = self.getOutput() if flip: data = data.T # Plot map with specific color code ax.pcolormesh(self.xedges, self.yedges, data, cmap=mymap, norm=norm) # Add contour levels if needed if levels is not None: CS = ax.contour( self.xvec, self.yvec, data, levels, colors='k') ax.clabel(CS, fontsize=fs, fmt=lambda x: f'{x:g}', inline_spacing=2) # Add potential insets if insets is not None: self.addInsets(ax, insets, fs, minimal=minimal) # Plot z-scale colorbar if cbarax is None: pos1 = ax.get_position() # get the map axis position cbarax = fig.add_axes([pos1.x1 + 0.02, pos1.y0, 0.03, pos1.height]) if not extend_under and not extend_over: extend = 'neither' elif extend_under and extend_over: extend = 'both' else: extend = 'max' if extend_over else 'min' self.cbar = plt.colorbar(sm, cax=cbarax, extend=extend) cbarax.set_ylabel(f'{self.zkey} ({self.zunit})', fontsize=fs) for item in cbarax.get_yticklabels(): item.set_fontsize(fs) if interactive: fig.canvas.mpl_connect('button_press_event', lambda event: self.onClick(event)) return fig