Page Menu · Home · c4science

parameter_sweeper.py
No One · Temporary

File Metadata

Created
Sat, May 11, 07:14

parameter_sweeper.py

# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
from copy import copy
import itertools
import csv
import numpy as np
from matplotlib import pyplot as plt
from rrompy.utilities.base.types import Np1D, DictAny, List, ROMEng
from rrompy.utilities.base import purgeList, getNewFilename, verbosityDepth
from rrompy.utilities.warning_manager import warn
__all__ = ['ParameterSweeper']
def C2R2csv(x):
    """
    Serialize a (possibly complex) array as a single CSV-safe string.

    The entries of x are flattened; each entry contributes its real part
    followed by its imaginary part, formatted with "{:.15E}" and joined by
    underscores, so the resulting string contains no commas.

    Args:
        x: Array-like of real or complex numbers (any shape).

    Returns:
        String such as "1.000000000000000E+00_2.000000000000000E+00" for
        x = [1 + 2j]; empty string for empty input.
    """
    x = np.ravel(x)
    # Interleave real and imaginary parts: re0, im0, re1, im1, ...
    z = np.column_stack((np.real(x), np.imag(x))).ravel()
    # Join explicitly: np.array2string summarizes arrays longer than numpy's
    # print threshold with "...", which would silently truncate the data.
    return "_".join("{:.15E}".format(v) for v in z)
class ParameterSweeper:
    """
    ROM approximant parameter sweeper.

    Args:
        ROMEngine(optional): Generic approximant class. Defaults to None.
        mutars(optional): Array of parameter values to sweep. Defaults to empty
            array.
        params(optional): List of parameter settings (each as a dict) to
            explore. Defaults to single empty set.
        mostExpensive(optional): String containing label of most expensive
            step, to be executed fewer times. Allowed options are 'HF' and
            'Approx' (case-insensitive); unrecognized values are replaced by
            'APPROX' with a warning. Defaults to 'HF'.

    Attributes:
        ROMEngine: Generic approximant class.
        mutars: Array of parameter values to sweep.
        params: List of parameter settings (each as a dict) to explore.
        mostExpensive: String containing label of most expensive step, to be
            executed fewer times; stored upper-case ('HF' or 'APPROX').
        filename: Name of the CSV written by the last call to sweep (only
            set once sweep has run).
    """

    # Scalar outputs recorded when sweep is called without outputs.
    allowedOutputsStandard = ["normHF", "normApp", "normRes", "normResRel",
                              "normErr", "normErrRel"]
    # All scalar outputs; this order is also the CSV column order.
    allowedOutputs = allowedOutputsStandard + ["HFFunc", "AppFunc",
                                               "ErrFunc", "ErrFuncRel"]
    # Scalar outputs plus the (string-encoded) approximant poles.
    allowedOutputsFull = allowedOutputs + ["poles"]

    def __init__(self, ROMEngine : ROMEng = None, mutars : Np1D = None,
                 params : List[DictAny] = None, mostExpensive : str = "HF"):
        # None sentinels give every instance fresh defaults; objects created
        # in the signature would be shared by all instances.
        self.ROMEngine = ROMEngine
        self.mutars = np.array([]) if mutars is None else mutars
        self.params = [{}] if params is None else params
        self.mostExpensive = mostExpensive

    def name(self) -> str:
        return self.__class__.__name__

    def __str__(self) -> str:
        return self.name()

    def __repr__(self) -> str:
        return self.__str__() + " at " + hex(id(self))

    @property
    def mostExpensive(self):
        """Value of mostExpensive."""
        return self._mostExpensive

    @mostExpensive.setter
    def mostExpensive(self, mostExpensive:str):
        mostExpensive = mostExpensive.upper()
        if mostExpensive not in ["HF", "APPROX"]:
            warn(("Value of mostExpensive not recognized. Overriding to "
                  "'APPROX'."))
            mostExpensive = "APPROX"
        self._mostExpensive = mostExpensive

    def checkValues(self) -> None:
        """
        Check if sweep can be performed.

        Raises:
            Exception: if ROMEngine is missing, or mutars or params is empty.
        """
        if self.ROMEngine is None:
            raise Exception("ROMEngine is missing. Aborting.")
        if len(self.mutars) == 0:
            raise Exception("Empty target parameter vector. Aborting.")
        if len(self.params) == 0:
            raise Exception("Empty method parameters vector. Aborting.")

    @staticmethod
    def _outputFilename(filename : str) -> str:
        """Split filename into stem/extension and pass them to getNewFilename."""
        dotPos = filename.rfind('.')
        if dotPos == -1:
            # No extension: keep the whole name (slicing with [:-1] would
            # drop its last character).
            return getNewFilename(filename)
        if dotPos == len(filename) - 1:
            # Trailing dot: drop it and use no extension.
            return getNewFilename(filename[:dotPos])
        return getNewFilename(filename[:dotPos], filename[dotPos + 1:])

    def _setApproxParameters(self, par : DictAny, allowedParams):
        """Apply the entries of par known to ROMEngine; rebuild the approximant."""
        self.ROMEngine.approxParameters = {k: par[k] for k in
                                                    par.keys() & allowedParams}
        self.ROMEngine.setupApprox()

    def _computeOutputs(self, mutar, outputs : List[str]) -> list:
        """
        Evaluate the requested scalar outputs at the target value mutar.

        Values are returned in the order of allowedOutputs (the CSV column
        order). Every underlying quantity is evaluated at most once, even
        when several requested outputs depend on it.
        """
        engine = self.ROMEngine
        memo = {}

        def value(key):
            # Lazily evaluate (and cache) the quantity labelled key.
            if key not in memo:
                memo[key] = compute[key]()
            return memo[key]

        def normErrRel():
            valNorm = value("normHF")
            return value("normErr") / valNorm

        def errFunc():
            valFunc = value("HFFunc")
            return np.abs(value("AppFunc") - valFunc)

        def errFuncRel():
            # Relative functional error |App - HF| / |HF|, mirroring
            # normErrRel; NaN when the reference functional vanishes.
            valFunc = value("HFFunc")
            if np.isclose(valFunc, 0.):
                return np.nan
            return value("ErrFunc") / np.abs(valFunc)

        compute = {
            "normHF": lambda: engine.normHF(mutar),
            "normApp": lambda: engine.normApp(mutar),
            "normRes": lambda: engine.normRes(mutar),
            "normResRel": lambda: value("normRes") / engine.normRHS(mutar),
            "normErr": lambda: engine.normErr(mutar),
            "normErrRel": normErrRel,
            "HFFunc": lambda: engine.HFEngine.functional(engine.getHF(mutar)),
            "AppFunc": lambda: engine.HFEngine.functional(
                                                        engine.getApp(mutar)),
            "ErrFunc": errFunc,
            "ErrFuncRel": errFuncRel,
        }
        return [value(key) for key in self.allowedOutputs if key in outputs]

    def sweep(self, filename : str = "out.dat", outputs : List[str] = None,
              verbose : int = 10) -> str:
        """
        Evaluate the ROM for every (parameter set, target value) pair and
        write the results to a CSV file.

        Each row holds the ROM parameters listed in ROMEngine.parameterList,
        muRe, muIm, the requested scalar outputs (in allowedOutputs order),
        the ROM type name and, if requested, the poles (written only on the
        first row of each parameter set, empty afterwards).

        Args:
            filename(optional): Base name of the output CSV; the name actually
                used is obtained from getNewFilename. Defaults to "out.dat".
            outputs(optional): Labels from allowedOutputsFull to record, a
                single such label, or the string "ALL". Defaults to None,
                i.e. allowedOutputsStandard.
            verbose(optional): Progress is reported when at least 5.
                Defaults to 10.

        Returns:
            Name of the CSV file that was written.

        Raises:
            Exception: if checkValues fails or no valid output is requested.
        """
        self.checkValues()
        if outputs is None:
            outputs = []
        if isinstance(outputs, str):
            outputs = (list(self.allowedOutputsFull)
                       if outputs.upper() == "ALL" else [outputs])
        elif len(outputs) == 0:
            outputs = list(self.allowedOutputsStandard)
        outputs = purgeList(outputs, self.allowedOutputsFull,
                            listname = self.name() + ".outputs",
                            baselevel = 1)
        if len(outputs) == 0:
            raise Exception("Empty outputs. Aborting.")
        poles = "poles" in outputs
        outParList = self.ROMEngine.parameterList
        Nparams = len(self.params)
        filename = self._outputFilename(filename)
        initial_row = (outParList + ["muRe", "muIm"]
                       + [x for x in self.allowedOutputs if x in outputs]
                       + ["type"] + (["poles"] if poles else []))
        # The most expensive step goes in the outer loop so it runs fewest
        # times.
        hfOuter = self.mostExpensive == "HF"
        if hfOuter:
            outerSet, innerSet = self.mutars, self.params
        else:
            outerSet, innerSet = self.params, self.mutars
        polesWritten = set()
        # newline="" as recommended by the csv module documentation.
        with open(filename, "w", buffering = 1, newline = "") as fout:
            writer = csv.writer(fout, delimiter = ",")
            writer.writerow(initial_row)
            for outerIdx, outerPar in enumerate(outerSet):
                if hfOuter:
                    i, mutar = outerIdx, outerPar
                else:
                    j, par = outerIdx, outerPar
                    self._setApproxParameters(par, outParList)
                for innerIdx, innerPar in enumerate(innerSet):
                    if hfOuter:
                        j, par = innerIdx, innerPar
                        self._setApproxParameters(par, outParList)
                    else:
                        i, mutar = innerIdx, innerPar
                    if verbose >= 5:
                        verbosityDepth("INIT", "Set {}/{}\tmu_{} = {:.10f}"
                                       .format(j + 1, Nparams, i, mutar))
                    outData = self._computeOutputs(mutar, outputs)
                    writeData = ([self.ROMEngine.approxParameters[parn]
                                  for parn in outParList]
                                 + [mutar.real, mutar.imag]
                                 + outData + [self.ROMEngine.name()])
                    if poles:
                        if j not in polesWritten:
                            polesWritten.add(j)
                            writeData.append(
                                         C2R2csv(self.ROMEngine.getPoles()))
                        else:
                            writeData.append("")
                    writer.writerow(str(x) for x in writeData)
                    if verbose >= 5:
                        verbosityDepth("DEL", "", end = "", inline = "")
                if verbose >= 5:
                    if hfOuter:
                        out = "Point mu_{} = {:.10f}\tdone.\n".format(i, mutar)
                    else:
                        out = "Set {}/{}\tdone.\n".format(j + 1, Nparams)
                    verbosityDepth("INIT", out)
                    verbosityDepth("DEL", "", end = "", inline = "")
        self.filename = filename
        return self.filename

    @staticmethod
    def _parseCell(cell : str):
        """Return cell as a float if it parses as one, else the raw string."""
        try:
            return float(cell)
        except ValueError:
            return cell

    @staticmethod
    def _cellMatches(cell : str, values : list) -> bool:
        """
        Check whether a CSV cell satisfies a restriction.

        The cell matches if it equals the string form of one of values, or if
        the cell and a value both parse as floats that are np.isclose.
        """
        if any(cell == str(value) for value in values):
            return True
        try:
            cellValue = float(cell)
            targets = [float(value) for value in values]
        except (TypeError, ValueError, OverflowError):
            return False
        return bool(np.any(np.isclose(cellValue, targets)))

    def read(self, filename:str, restrictions : DictAny = None,
             outputs : List[str] = None) -> DictAny:
        """
        Execute a query on a custom format CSV.

        Args:
            filename: CSV filename; its first row must hold column labels.
            restrictions(optional): Parameter configurations to output, as a
                dict mapping column labels to an accepted value or a
                list/tuple/array of accepted values. Labels missing from the
                file are ignored with a warning. The dict is not modified.
                Defaults to None, i.e. output all.
            outputs(optional): Column labels to output; labels missing from
                the file are ignored with a warning. Defaults to None, i.e.
                no output.

        Returns:
            Dictionary of desired results, with a key for each found entry of
            outputs, and a numpy 1D array as corresponding value. Cells are
            converted to float where possible; a column with any non-numeric
            cell yields a string array.
        """
        restrictions = {} if restrictions is None else restrictions
        outputs = [] if outputs is None else outputs
        with open(filename, 'r', newline = '') as f:
            reader = csv.reader(f, delimiter=',')
            header = next(reader)
            # Local copies only: the caller's restrictions are left untouched.
            restrIndices, restrValues = {}, {}
            for key, value in restrictions.items():
                if key not in header:
                    warn("Ignoring key {} from restrictions.".format(key))
                    continue
                restrIndices[key] = header.index(key)
                if isinstance(value, np.ndarray):
                    value = value.ravel().tolist()
                restrValues[key] = (list(value)
                                    if isinstance(value, (list, tuple))
                                    else [value])
            outputIndices = {}
            for key in outputs:
                if key in header:
                    outputIndices[key] = header.index(key)
                else:
                    warn("Ignoring key {} from outputs.".format(key))
            outputData = {key: [] for key in outputIndices}
            for row in reader:
                if not row:
                    continue  # blank line
                if all(self._cellMatches(row[idx], restrValues[key])
                       for key, idx in restrIndices.items()):
                    for key, idx in outputIndices.items():
                        outputData[key].append(self._parseCell(row[idx]))
        return {key: np.array(vals) for key, vals in outputData.items()}

    def _constraintSets(self, filename : str, zs : List[str]) -> List[DictAny]:
        """
        Build one restriction dict per combination of the distinct values
        taken in filename by the columns zs. Columns missing from the file
        are skipped; with no valid column, a single empty (unrestricted)
        dict is returned.
        """
        zsVals = self.read(filename, outputs = zs)
        keys = list(zsVals.keys())
        uniques = [np.unique(zsVals[key]) for key in keys]
        # A flat product works for any number of keys (including zero).
        return [dict(zip(keys, combo))
                for combo in itertools.product(*uniques)]

    @staticmethod
    def _plotCurve(xVals, yVals, label : str) -> bool:
        """
        Plot yVals against xVals on the current figure.

        Returns True when the data suggest a logarithmic y axis: no value
        is clearly negative and max / min spans more than one decade.
        """
        plt.plot(xVals, yVals, label = label)
        if np.size(yVals) == 0:
            return False
        yMin, yMax = np.min(yVals), np.max(yVals)
        if yMin <= - np.finfo(float).eps:
            return False
        # yMin may be exactly 0: the quotient is then inf (log scale chosen).
        with np.errstate(divide = "ignore", invalid = "ignore"):
            return bool(np.log10(yMax / yMin) > 1.)

    @staticmethod
    def _finishFigure(fig, logScale : bool, ylims, save, prefix : str,
                      saveFormat : str, saveDPI : int):
        """Apply axis options, legend and grid; save if requested; show; close."""
        if logScale:
            fig.get_axes()[0].set_yscale('log')
        if ylims is not None:
            plt.ylim(**ylims)
        plt.legend()
        plt.grid()
        if save is not None:
            plt.savefig(getNewFilename(prefix, saveFormat),
                        format = saveFormat, dpi = saveDPI)
        plt.show()
        plt.close()

    def plot(self, filename:str, xs:List[str], ys:List[str], zs:List[str],
             onePlot : bool = False, save : str = None,
             saveFormat : str = "eps", saveDPI : int = 100, **figspecs):
        """
        Perform plots from data in filename.

        One set of plots is produced for every combination of the values
        taken by the columns zs; combinations matching no row are skipped.

        Args:
            filename: CSV filename.
            xs: Values to put on x axes.
            ys: Values to put on y axes.
            zs: Meta-values for constraints.
            onePlot(optional): Whether to create a single figure per x.
                Defaults to False.
            save(optional): Where to save plot(s). Defaults to None, i.e. no
                saving.
            saveFormat(optional): Format for saved plot(s). Defaults to "eps".
            saveDPI(optional): DPI for saved plot(s). Defaults to 100.
            figspecs(optional key args): Optional arguments for matplotlib
                figure creation.
        """
        if save is not None:
            save = save.strip()
        for constr in self._constraintSets(filename, zs):
            data = self.read(filename, restrictions = constr, outputs = xs+ys)
            if not any(np.size(vals) for vals in data.values()):
                continue  # no row satisfies this combination
            if onePlot:
                for x in xs:
                    p = plt.figure(**figspecs)
                    logScale = False
                    for y in ys:
                        label = '{} vs {} for {}'.format(y, x, constr)
                        if self._plotCurve(data[x], data[y], label):
                            logScale = True
                    prefix = "{}_{}_vs_{}_{}".format(save, ys, x, constr)
                    self._finishFigure(p, logScale, None, save, prefix,
                                       saveFormat, saveDPI)
            else:
                for x, y in itertools.product(xs, ys):
                    p = plt.figure(**figspecs)
                    label = '{} vs {} for {}'.format(y, x, constr)
                    logScale = self._plotCurve(data[x], data[y], label)
                    prefix = "{}_{}_vs_{}_{}".format(save, y, x, constr)
                    self._finishFigure(p, logScale, None, save, prefix,
                                       saveFormat, saveDPI)

    def plotCompare(self, filenames:List[str], xs:List[str], ys:List[str],
                    zs:List[str], onePlot : bool = False, save : str = None,
                    ylims : dict = None, saveFormat : str = "eps",
                    saveDPI : int = 100, labels : List[str] = None,
                    **figspecs):
        """
        Perform plots comparing the data in several files.

        Constraint combinations are taken from the first file; for each file,
        curves whose x or y column is missing from that file are skipped.

        Args:
            filenames: CSV filenames.
            xs: Values to put on x axes.
            ys: Values to put on y axes.
            zs: Meta-values for constraints.
            onePlot(optional): Whether to create a single figure per x.
                Defaults to False.
            save(optional): Where to save plot(s). Defaults to None, i.e. no
                saving.
            ylims(optional): Keyword arguments for plt.ylim (custom y axis
                limits). If None, automatic values are kept. Defaults to
                None.
            saveFormat(optional): Format for saved plot(s). Defaults to "eps".
            saveDPI(optional): DPI for saved plot(s). Defaults to 100.
            labels(optional): Label for each dataset. Defaults to None, i.e.
                "1", "2", ... in file order.
            figspecs(optional key args): Optional arguments for matplotlib
                figure creation.
        """
        nfiles = len(filenames)
        if save is not None:
            save = save.strip()
        if labels is None:
            labels = ["{}".format(j + 1) for j in range(nfiles)]

        def plotFiles(data, x, y, constr) -> bool:
            # Plot y vs x for every file; report whether log scale is wanted.
            logScale = False
            for j, dataj in enumerate(data):
                if x not in dataj or y not in dataj:
                    continue  # column absent from this file
                l = '{} vs {} for {}, {}'.format(y, x, constr, labels[j])
                if self._plotCurve(dataj[x], dataj[y], l):
                    logScale = True
            return logScale

        for constr in self._constraintSets(filenames[0], zs):
            data = [self.read(fname, restrictions = constr,
                              outputs = xs + ys) for fname in filenames]
            if onePlot:
                for x in xs:
                    p = plt.figure(**figspecs)
                    logScale = False
                    for y in ys:
                        if plotFiles(data, x, y, constr):
                            logScale = True
                    prefix = "{}_{}_vs_{}_{}".format(save, ys, x, constr)
                    self._finishFigure(p, logScale, ylims, save, prefix,
                                       saveFormat, saveDPI)
            else:
                for x, y in itertools.product(xs, ys):
                    p = plt.figure(**figspecs)
                    logScale = plotFiles(data, x, y, constr)
                    prefix = "{}_{}_vs_{}_{}".format(save, y, x, constr)
                    self._finishFigure(p, logScale, ylims, save, prefix,
                                       saveFormat, saveDPI)

Event Timeline