Page Menu · Home · c4science

xymap.py
No One · Temporary

File Metadata

Created
Wed, May 1, 01:15

xymap.py

# -*- coding: utf-8 -*-
# @Author: Theo Lemaire
# @Email: theo.lemaire@epfl.ch
# @Date: 2019-06-04 18:24:29
# @Last Modified by: Theo Lemaire
# @Last Modified time: 2021-05-31 11:54:30
import abc
import csv
from itertools import product
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import copy
from ..core import LogBatch
from ..utils import logger, isIterable, rangecode
from .pltutils import cm2inch, setNormalizer
class XYMap(LogBatch):
    ''' Generic 2D map object interface.

        A map evaluates an output (z) over every combination of two input vectors
        (x and y), logs the results through the parent LogBatch machinery, and renders
        them as a color-coded mesh. Subclasses must provide the x/y/z keys, units and
        factors, the map title, and the `compute` method.
    '''

    # Sign pairs (x, y) keyed by corner code: first letter u/l = upper/lower,
    # second letter l/r = left/right (e.g. 'lr' -> (+1, -1)).
    # NOTE(review): not used within this class; presumably consumed by subclasses.
    offset_options = {
        'lr': (1, -1),
        'ur': (1, 1),
        'll': (-1, -1),
        'ul': (-1, 1)
    }

    def __init__(self, root, xvec, yvec):
        ''' Initialize the map and register every (x, y) input combination.

            :param root: root location forwarded to LogBatch (presumably the directory
                holding the batch log file -- confirm against LogBatch)
            :param xvec: one-dimensional iterable of x values
            :param yvec: one-dimensional iterable of y values
        '''
        self.root = root
        # These assignments go through the validating property setters below.
        self.xvec = xvec
        self.yvec = yvec
        # Inputs are generated x-major: all y values for xvec[0], then for xvec[1], etc.
        super().__init__([list(pair) for pair in product(self.xvec, self.yvec)], root=root)

    def checkVector(self, name, value):
        ''' Validate an input vector and convert it to a numpy array.

            :param name: vector name, used in error messages
            :param value: candidate vector
            :return: the vector as a numpy array
            :raises ValueError: if the value is not iterable or is multi-dimensional
        '''
        if not isIterable(value):
            raise ValueError(f'{name} vector must be an iterable')
        if not isinstance(value, np.ndarray):
            value = np.asarray(value)
        if len(value.shape) > 1:
            raise ValueError(f'{name} vector must be one-dimensional')
        return value

    @property
    def in_key(self):
        ''' Key of the map's main input (the x key). '''
        return self.xkey

    @property
    def unit(self):
        ''' Unit of the map's main input (the x unit). '''
        return self.xunit

    @property
    def xvec(self):
        ''' Vector of x input values (numpy array). '''
        return self._xvec

    @xvec.setter
    def xvec(self, value):
        self._xvec = self.checkVector('x', value)

    @property
    def yvec(self):
        ''' Vector of y input values (numpy array). '''
        return self._yvec

    @yvec.setter
    def yvec(self, value):
        # Validate under the name 'y' so error messages point at the right vector.
        self._yvec = self.checkVector('y', value)

    @property
    @abc.abstractmethod
    def xkey(self):
        ''' Name of the x input. '''
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def xfactor(self):
        ''' Factor associated with the x input (semantics defined by subclasses). '''
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def xunit(self):
        ''' Unit of the x input. '''
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def ykey(self):
        ''' Name of the y input. '''
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def yfactor(self):
        ''' Factor associated with the y input (semantics defined by subclasses). '''
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def yunit(self):
        ''' Unit of the y input. '''
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def zkey(self):
        ''' Name of the mapped output. '''
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def zunit(self):
        ''' Unit of the mapped output. '''
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def zfactor(self):
        ''' Factor associated with the output (semantics defined by subclasses). '''
        raise NotImplementedError

    @property
    def out_keys(self):
        ''' Output column labels of the log file, formatted as "key (unit)". '''
        return [f'{self.zkey} ({self.zunit})']

    @property
    def in_labels(self):
        ''' Input column labels of the log file (x then y), formatted as "key (unit)". '''
        return [f'{self.xkey} ({self.xunit})', f'{self.ykey} ({self.yunit})']

    def getLogData(self):
        ''' Retrieve the batch log file data (inputs and outputs) as a dataframe.

            Rows are sorted by x then y, matching the x-major order of the batch inputs.
        '''
        return pd.read_csv(self.fpath, sep=self.delimiter).sort_values(self.in_labels)

    def getInput(self):
        ''' Retrieve the logged batch inputs as an (n, 2) array of (x, y) rows. '''
        return self.getLogData()[self.in_labels].values

    def getOutput(self):
        ''' Retrieve the logged outputs as a 2D matrix.

            The flat outputs are reshaped to (nx, ny), then transposed so that rows index
            y and columns index x, as expected by `pcolormesh` / `contour`.
            NOTE(review): assumes LogBatch.getOutput returns exactly nx * ny values in
            x-major order -- confirm against LogBatch.

            :return: (ny, nx) output matrix
        '''
        return np.reshape(super().getOutput(), (self.xvec.size, self.yvec.size)).T

    def writeLabels(self):
        ''' Write the header row (input labels then output keys), overwriting the log file. '''
        # newline='' is required by the csv module to avoid doubled '\r' on some platforms.
        with open(self.fpath, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile, delimiter=self.delimiter)
            writer.writerow([*self.in_labels, *self.out_keys])

    def isEntry(self, comb):
        ''' Check if a given input is logged in the batch log file.

            :param comb: (x, y) input combination
            :return: True if some logged row matches both x and y within the
                (self.rtol, self.atol) tolerances, False otherwise
        '''
        inputs = self.getInput()
        if len(inputs) == 0:
            return False
        xmatch = np.isclose(inputs[:, 0], comb[0], rtol=self.rtol, atol=self.atol)
        ymatch = np.isclose(inputs[:, 1], comb[1], rtol=self.rtol, atol=self.atol)
        # Both coordinates must match on the same row.
        return bool(np.any(xmatch & ymatch))

    @property
    def inputscode(self):
        ''' String describing the batch inputs (x range code and y range code joined by "_"). '''
        xcode = rangecode(self.xvec, self.xkey, self.xunit)
        ycode = rangecode(self.yvec, self.ykey, self.yunit)
        return '_'.join([xcode, ycode])

    @staticmethod
    def getScaleType(x):
        ''' Determine whether a vector follows a logarithmic or a linear distribution.

            :param x: the input vector (numpy array)
            :return: 'log' if x matches a logarithmic distribution between its extrema,
                'lin' otherwise
        '''
        xmin, xmax, nx = x.min(), x.max(), x.size
        # A logarithmic distribution needs strictly positive values; checking this first
        # avoids taking the log of zero / negative numbers.
        if xmin <= 0:
            return 'lin'
        if np.all(np.isclose(x, np.logspace(np.log10(xmin), np.log10(xmax), nx))):
            return 'log'
        # Anything that is not logarithmic is treated as linear.
        return 'lin'

    @property
    def xscale(self):
        ''' Distribution type of the x vector ('lin' or 'log'). '''
        return self.getScaleType(self.xvec)

    @property
    def yscale(self):
        ''' Distribution type of the y vector ('lin' or 'log'). '''
        return self.getScaleType(self.yvec)

    @staticmethod
    def computeMeshEdges(x, scale):
        ''' Compute the appropriate edges of a mesh that quads a linear or logarithmic distribution.

            The spacing between the first two values (in linear or log10 space) is used as
            the step, so x is assumed to be evenly spaced in that space and to contain at
            least two values.

            :param x: the input vector
            :param scale: the type of distribution ('lin' for linear, 'log' for logarithmic)
            :return: the edges vector (x.size + 1 values)
        '''
        if scale == 'log':
            x = np.log10(x)
            range_func = np.logspace
        else:
            range_func = np.linspace
        dx = x[1] - x[0]
        n = x.size + 1
        # Edges sit half a step before the first value and half a step after the last one.
        return range_func(x[0] - dx / 2, x[-1] + dx / 2, n)

    @abc.abstractmethod
    def compute(self, x):
        ''' Compute the necessary output(s) for a given inputs combination. '''
        raise NotImplementedError

    def run(self, **kwargs):
        ''' Run the batch, then rewrite the log file sorted by (x, y) inputs.

            :param kwargs: keyword arguments forwarded to LogBatch.run
        '''
        super().run(**kwargs)
        # NOTE(review): reads from `self.fpath` (via getLogData) but writes to
        # `self.filepath()`; presumably the same path -- confirm against LogBatch.
        self.getLogData().to_csv(self.filepath(), sep=self.delimiter, index=False)

    def getOnClickXY(self, event):
        ''' Get x and y values from x and y click event coordinates.

            Relies on the mesh edges (self.xedges, self.yedges) computed by `render`.

            :param event: mouse event carrying xdata / ydata coordinates
            :return: (x, y) values of the clicked cell
        '''
        x = self.xvec[np.searchsorted(self.xedges, event.xdata) - 1]
        y = self.yvec[np.searchsorted(self.yedges, event.ydata) - 1]
        return x, y

    def onClick(self, event):
        ''' Execute specific action when the user clicks on a cell in the 2D map (no-op here). '''
        pass

    @property
    @abc.abstractmethod
    def title(self):
        ''' Default title of the map. '''
        raise NotImplementedError

    def getZBounds(self):
        ''' Compute (and log) the range of the output matrix, ignoring NaNs.

            :return: (zmin, zmax) tuple
        '''
        matrix = self.getOutput()
        zmin, zmax = np.nanmin(matrix), np.nanmax(matrix)
        logger.info(
            f'{self.zkey} range: {zmin:.0f} - {zmax:.0f} {self.zunit}')
        return zmin, zmax

    def checkZbounds(self, zbounds):
        ''' Warn if the output range exceeds user-defined bounds.

            :param zbounds: (lower, upper) bounds to check the output range against
        '''
        zmin, zmax = self.getZBounds()
        if zmin < zbounds[0]:
            logger.warning(
                f'Minimal {self.zkey} ({zmin:.0f} {self.zunit}) is below defined lower bound ({zbounds[0]:.0f} {self.zunit})')
        if zmax > zbounds[1]:
            logger.warning(
                f'Maximal {self.zkey} ({zmax:.0f} {self.zunit}) is above defined upper bound ({zbounds[1]:.0f} {self.zunit})')

    @staticmethod
    def addInsets(ax, insets, fs, minimal=False):
        ''' Mark labeled points on a map axis.

            :param ax: axis object
            :param insets: dictionary of label -> (x, y) data coordinates
            :param fs: font size of the labels
            :param minimal: if True, plot the points without labels
        '''
        # Make sure the axis data limits include the inset points.
        ax.update_datalim(list(insets.values()))
        # Label offset, expressed in axis-fraction coordinates.
        xyoffset = np.array([0, 0.05])
        data_to_axis = ax.transData + ax.transAxes.inverted()
        for k, xydata in insets.items():
            ax.scatter(*xydata, s=20, facecolor='k', edgecolor='none')
            if not minimal:
                xyaxis = np.array(data_to_axis.transform(xydata))
                ax.annotate(
                    k, xy=xyaxis, xytext=xyaxis + xyoffset, xycoords=ax.transAxes,
                    fontsize=fs, arrowprops={'facecolor': 'black', 'arrowstyle': '-'},
                    ha='right')

    def render(self, xscale='lin', yscale='lin', zscale='lin', zbounds=None, fs=8, cmap='viridis',
               interactive=False, figsize=None, insets=None, inset_offset=0.05,
               extend_under=False, extend_over=False, ax=None, cbarax=None, cbarlabel='vertical',
               title=None, minimal=False, levels=None, flip=False):
        ''' Render the map as a color-coded mesh with an associated colorbar.

            :param xscale: x-axis scale and mesh-edge distribution ('lin' or 'log')
            :param yscale: y-axis scale and mesh-edge distribution ('lin' or 'log')
            :param zscale: color normalization scale, forwarded to setNormalizer
            :param zbounds: (zmin, zmax) color bounds; computed from the data if None, in
                which case extend_under and extend_over are forced to False
            :param fs: font size
            :param cmap: colormap name
            :param interactive: whether to connect mouse clicks to `onClick`
            :param figsize: figure size, used only when ax is None (default: cm2inch(12, 7))
            :param insets: optional dictionary of label -> (x, y) points to mark on the map
            :param inset_offset: currently unused
            :param extend_under: if False, values below zbounds are drawn in silver; if True,
                the colormap's default under color is kept and the colorbar is extended
            :param extend_over: same as extend_under, for values above zbounds
            :param ax: axis on which to draw (a new figure is created if None)
            :param cbarax: colorbar axis (created to the right of the map if None)
            :param cbarlabel: 'vertical' for a side colorbar label, anything else for a title
            :param title: map title (defaults to self.title)
            :param minimal: if True, hide ticks, tick labels, axis labels and inset labels
            :param levels: optional contour levels to overlay on the map
            :param flip: whether to transpose the data matrix before plotting
            :return: the figure object
        '''
        # Compute z-bounds
        if zbounds is None:
            # Data-driven bounds contain every value, so nothing needs extending.
            extend_under = False
            extend_over = False
            zbounds = self.getZBounds()
        else:
            self.checkZbounds(zbounds)

        # Compute Z normalizer (on a copy, so the registered colormap is not modified)
        mymap = copy.copy(plt.get_cmap(cmap))
        if not extend_under:
            mymap.set_under('silver')
        if not extend_over:
            mymap.set_over('silver')
        norm, sm = setNormalizer(mymap, zbounds, zscale)

        # Compute mesh edges (also stored for use by getOnClickXY)
        self.xedges = self.computeMeshEdges(self.xvec, xscale)
        self.yedges = self.computeMeshEdges(self.yvec, yscale)

        # Create figure if required
        if ax is None:
            if figsize is None:
                figsize = cm2inch(12, 7)
            fig, ax = plt.subplots(figsize=figsize)
            # Leave room on the right for the colorbar axis.
            fig.subplots_adjust(left=0.15, bottom=0.15, right=0.8, top=0.92)
        else:
            fig = ax.get_figure()

        # Set axis properties
        if title is None:
            title = self.title
        ax.set_title(title, fontsize=fs)
        if minimal:
            ax.set_xticks([])
            ax.set_yticks([])
            ax.tick_params(axis='both', which='both', bottom=False, left=False,
                           labelbottom=False, labelleft=False)
        else:
            ax.set_xlabel(f'{self.xkey} ({self.xunit})', fontsize=fs, labelpad=-0.5)
            ax.set_ylabel(f'{self.ykey} ({self.yunit})', fontsize=fs)
            for item in ax.get_xticklabels() + ax.get_yticklabels():
                item.set_fontsize(fs)
        if xscale == 'log':
            ax.set_xscale('log')
        if yscale == 'log':
            ax.set_yscale('log')

        # Retrieve data ((ny, nx) matrix)
        data = self.getOutput()
        if flip:
            # NOTE(review): only the data is transposed, not the mesh edges / vectors, so
            # this presumably only works for square maps (nx == ny).
            data = data.T

        # Plot map with specific color code
        ax.pcolormesh(self.xedges, self.yedges, data, cmap=mymap, norm=norm)

        # Add contour levels if needed (contours use cell centers, not edges)
        if levels is not None:
            CS = ax.contour(
                self.xvec, self.yvec, data, levels, colors='k')
            ax.clabel(CS, fontsize=fs, fmt=lambda x: f'{x:g}', inline_spacing=2)

        # Add potential insets
        if insets is not None:
            self.addInsets(ax, insets, fs, minimal=minimal)

        # Plot z-scale colorbar
        if cbarax is None:
            pos1 = ax.get_position()  # get the map axis position
            cbarax = fig.add_axes([pos1.x1 + 0.02, pos1.y0, 0.03, pos1.height])
        if not extend_under and not extend_over:
            extend = 'neither'
        elif extend_under and extend_over:
            extend = 'both'
        else:
            extend = 'max' if extend_over else 'min'
        self.cbar = plt.colorbar(sm, cax=cbarax, extend=extend)
        if cbarlabel == 'vertical':
            cbarax.set_ylabel(f'{self.zkey} ({self.zunit})', fontsize=fs)
        else:
            cbarax.set_title(f'{self.zkey} ({self.zunit})', fontsize=fs)
        for item in cbarax.get_yticklabels():
            item.set_fontsize(fs)

        if interactive:
            # Kept as a lambda: matplotlib's callback registry holds bound methods only
            # weakly, whereas a plain function (closing over self) is held strongly.
            fig.canvas.mpl_connect('button_press_event', lambda event: self.onClick(event))

        return fig

Event Timeline