Page MenuHomec4science

parameter_sweeper.py
No OneTemporary

File Metadata

Created
Mon, May 13, 14:14

parameter_sweeper.py

# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
from copy import copy
import itertools
import csv
import warnings
import numpy as np
from matplotlib import pyplot as plt
from rrompy.utilities.base.types import Np1D, N2FSExpr, DictAny, List, ROMEng
from rrompy.utilities.base import purgeList, getNewFilename
__all__ = ['ParameterSweeper']
def C2R2csv(x):
    """
    Serialize a (possibly complex) array into a single CSV-safe string.

    The input is flattened and each entry is split into real and imaginary
    part, interleaved as Re(x_0), Im(x_0), Re(x_1), Im(x_1), ... . Each value
    is written in scientific notation with 15 decimal digits and values are
    separated by underscores, so that the result fits in one CSV cell.

    Args:
        x: Scalar or array-like of real or complex numbers.

    Returns:
        Underscore-separated string of formatted values (empty string for
        empty input).
    """
    x = np.ravel(x)
    # Column 0 holds the real parts, column 1 the imaginary parts; raveling
    # row by row produces the interleaved layout.
    z = np.ravel(np.stack((np.real(x), np.imag(x)), axis = 1))
    # Join explicitly instead of relying on np.array2string, which replaces
    # the middle of arrays longer than the print threshold by '...'.
    return "_".join("{:.15E}".format(v) for v in z)
class ParameterSweeper:
"""
ROM approximant parameter sweeper.
Args:
ROMEngine(optional): Generic approximant class. Defaults to None.
mutars(optional): Array of parameter values to sweep. Defaults to empty
array.
params(optional): List of parameter settings (each as a dict) to
explore. Defaults to single empty set.
mostExpensive(optional): String containing label of most expensive
step, to be executed fewer times. Allowed options are 'HF' and
'Approx'. Defaults to 'HF'.
normType(optional): Target norm identifier. Must be recognizable by
HSEngine norm command. Defaults to None.
Attributes:
ROMEngine: Generic approximant class.
mutars: Array of parameter values to sweep.
params: List of parameter settings (each as a dict) to explore.
mostExpensive: String containing label of most expensive step, to be
executed fewer times.
"""
allowedOutputsStandard = ["HFNorm", "AppNorm", "ErrNorm"]
allowedOutputs = allowedOutputsStandard + ["HFFunc", "AppFunc", "ErrFunc"]
allowedOutputsFull = allowedOutputs + ["poles"]
def __init__(self, ROMEngine : ROMEng = None, mutars : Np1D = np.array([]),
params : List[DictAny] = [{}], mostExpensive : str = "HF",
normType : N2FSExpr = None):
self.ROMEngine = ROMEngine
self.mutars = mutars
self.params = params
self.mostExpensive = mostExpensive
self.normType = normType
def name(self) -> str:
return self.__class__.__name__
def __str__(self) -> str:
return self.name()
@property
def mostExpensive(self):
"""Value of mostExpensive."""
return self._mostExpensive
@mostExpensive.setter
def mostExpensive(self, mostExpensive:str):
mostExpensive = mostExpensive.upper()
if mostExpensive not in ["HF", "APPROX"]:
warnings.warn(("Value of mostExpensive not recognized. Overriding "
"to 'APPROX'."), stacklevel = 2)
mostExpensive = "APPROX"
self._mostExpensive = mostExpensive
def checkValues(self) -> bool:
"""Check if sweep can be performed."""
if self.ROMEngine is None:
warnings.warn("ROMEngine is missing. Aborting.", stacklevel = 2)
return False
if len(self.mutars) == 0:
warnings.warn("Empty target parameter vector. Aborting.",
stacklevel = 2)
return False
if len(self.params) == 0:
warnings.warn("Empty method parameters vector. Aborting.",
stacklevel = 2)
return False
return True
def sweep(self, filename : str = "out.dat", outputs : List[str] = [],
verbose : int = 1):
if not self.checkValues(): return
try:
if outputs.upper() == "ALL":
outputs = self.allowedOutputsFull
except:
if len(outputs) == 0:
outputs = self.allowedOutputsStandard
outputs = purgeList(outputs, self.allowedOutputsFull,
listname = self.name() + ".outputs",
baselevel = 1)
poles = ("poles" in outputs)
if len(outputs) == 0:
warnings.warn("Empty outputs. Aborting.", stacklevel = 2)
return
outParList = self.ROMEngine.parameterList
Nparams = len(self.params)
if poles: polesCheckList = []
allowedParams = self.ROMEngine.parameterList
dotPos = filename.rfind('.')
if dotPos in [-1, len(filename) - 1]:
filename = getNewFilename(filename[:dotPos])
else:
filename = getNewFilename(filename[:dotPos], filename[dotPos + 1:])
append_write = "w"
initial_row = (outParList + ["muRe", "muIm"]
+ [x for x in self.allowedOutputs if x in outputs]
+ ["type"] + ["poles"] * poles)
with open(filename, append_write, buffering = 1) as fout:
writer = csv.writer(fout, delimiter=",")
writer.writerow(initial_row)
if self.mostExpensive == "HF":
outerSet = self.mutars
innerSet = self.params
elif self.mostExpensive == "APPROX":
outerSet = self.params
innerSet = self.mutars
for outerIdx, outerPar in enumerate(outerSet):
if self.mostExpensive == "HF":
i, mutar = outerIdx, outerPar
elif self.mostExpensive == "APPROX":
j, par = outerIdx, outerPar
self.ROMEngine.approxParameters = {k: par[k] for k in\
par.keys() & allowedParams}
self.ROMEngine.setupApprox()
for innerIdx, innerPar in enumerate(innerSet):
if self.mostExpensive == "APPROX":
i, mutar = innerIdx, innerPar
elif self.mostExpensive == "HF":
j, par = innerIdx, innerPar
self.ROMEngine.approxParameters = {k: par[k] for k in\
par.keys() & allowedParams}
self.ROMEngine.setupApprox()
if verbose >= 1:
print("Set {}/{}\tmu_{} = {:.10f}".format(j+1, Nparams,
i, mutar))
outData = []
if "HFNorm" in outputs:
val = self.ROMEngine.HFNorm(mutar, self.normType)
if isinstance(val, (list,)): val = val[0]
outData = outData + [val]
if "AppNorm" in outputs:
val = self.ROMEngine.approxNorm(mutar, self.normType)
if isinstance(val, (list,)): val = val[0]
outData = outData + [val]
if "ErrNorm" in outputs:
val = self.ROMEngine.approxError(mutar, self.normType)
if isinstance(val, (list,)): val = val[0]
outData = outData + [val]
if "HFFunc" in outputs:
outData = outData +[self.ROMEngine.HFEngine.functional(
self.ROMEngine.getHF(mutar))]
if "AppFunc" in outputs:
outData = outData +[self.ROMEngine.HFEngine.functional(
self.ROMEngine.getApp(mutar))]
if "ErrFunc" in outputs:
outData = outData +[self.ROMEngine.HFEngine.functional(
self.ROMEngine.getApp(mutar))
- self.ROMEngine.HFEngine.functional(
self.ROMEngine.getHF(mutar))]
writeData = []
for parn in outParList:
writeData = (writeData
+ [self.ROMEngine.approxParameters[parn]])
writeData = (writeData + [mutar.real, mutar.imag]
+ outData + [self.ROMEngine.name()])
if poles:
if j not in polesCheckList:
polesCheckList += [j]
writeData = writeData + [C2R2csv(
self.ROMEngine.getPoles())]
else:
writeData = writeData + [""]
writer.writerow(str(x) for x in writeData)
if verbose >= 1:
if self.mostExpensive == "APPROX":
print("Set {}/{}\tdone".format(j+1, Nparams))
elif self.mostExpensive == "HF":
print("Point mu_{} = {:.10f}\tdone".format(i, mutar))
self.filename = filename
return self.filename
def read(self, filename:str, restrictions : DictAny = {},
outputs : List[str] = []) -> DictAny:
"""
Execute a query on a custom format CSV.
Args:
filename: CSV filename.
restrictions(optional): Parameter configurations to output.
Defaults to empty dictionary, i.e. output all.
outputs(optional): Values to output. Defaults to empty list, i.e.
no output.
Returns:
Dictionary of desired results, with a key for each entry of
outputs, and a numpy 1D array as corresponding value.
"""
with open(filename, 'r') as f:
reader = csv.reader(f, delimiter=',')
header = next(reader)
restrIndices, outputIndices, outputData = {}, {}, {}
for key in restrictions.keys():
try:
restrIndices[key] = header.index(key)
if not isinstance(restrictions[key], list):
restrictions[key] = [restrictions[key]]
restrictions[key] = copy(restrictions[key])
except:
warnings.warn("Ignoring key {} from restrictions"\
.format(key), stacklevel = 2)
for key in outputs:
try:
outputIndices[key] = header.index(key)
outputData[key] = np.array([])
except:
warnings.warn("Ignoring key {} from outputs".format(key),
stacklevel = 2)
for row in reader:
restrTrue = True
for key in restrictions.keys():
if row[restrIndices[key]] == restrictions[key]:
continue
try:
if np.any(np.isclose(float(row[restrIndices[key]]),
[float(x) for x in restrictions[key]])):
continue
except: pass
restrTrue = False
if restrTrue:
for key in outputIndices.keys():
try:
val = row[outputIndices[key]]
val = float(val)
finally:
outputData[key] = np.append(outputData[key], val)
return outputData
def plot(self, filename:str, xs:List[str], ys:List[str], zs:List[str],
onePlot : bool = False):
"""
Perform plots from data in filename.
Args:
filename: CSV filename.
xs: Values to put on x axes.
ys: Values to put on y axes.
zs: Meta-values for constraints.
onePlot: Whether to create a single figure per x. Defaults to
False.
"""
zsVals = self.read(filename, outputs = zs)
zs = list(zsVals.keys())
zss = None
for key in zs:
vals = np.unique(zsVals[key])
if zss is None:
zss = copy(vals)
else:
zss = list(itertools.product(zss, vals))
lzs = len(zs)
for z in zss:
if lzs <= 1:
constr = {zs[0] : z}
else:
constr = {zs[j] : z[j] for j in range(len(zs))}
data = self.read(filename, restrictions = constr, outputs = xs+ys)
if onePlot:
for x in xs:
xVals = data[x]
plt.figure()
for y in ys:
yVals = data[y]
label = '{} vs {} for {}'.format(x, y, constr)
plt.semilogy(xVals, yVals, label = label)
plt.legend()
plt.grid()
plt.show()
plt.close()
else:
for x, y in itertools.product(xs, ys):
xVals, yVals = data[x], data[y]
label = '{} vs {} for {}'.format(x, y, constr)
plt.figure()
plt.semilogy(xVals, yVals, label = label)
plt.legend()
plt.grid()
plt.show()
plt.close()

Event Timeline