Page MenuHomec4science

parameter_sweeper.py
No OneTemporary

File Metadata

Created
Wed, May 14, 14:01

parameter_sweeper.py

#!/usr/bin/python
# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
import os
import csv
import warnings
import numpy as np
from rrompy.utilities.types import Np1D, DictAny, List, ROMEng, Tuple
from rrompy.utilities import purgeList
__all__ = ['ParameterSweeper']
class ParameterSweeper:
    """
    ROM approximant parameter sweeper.

    Args:
        ROMEngine(optional): ROMApproximant class. Defaults to None.
        mutars(optional): Array of parameter values to sweep. Defaults to
            None, i.e. empty array.
        params(optional): List of parameter settings (each as a dict) to
            explore. Defaults to None, i.e. single empty set.
        mostExpensive(optional): String containing label of most expensive
            step, to be executed fewer times. Allowed options are 'HF' and
            'Approx'. Defaults to 'HF'.

    Attributes:
        ROMEngine: ROMApproximant class.
        mutars: Array of parameter values to sweep.
        params: List of parameter settings (each as a dict) to explore.
        mostExpensive: String containing label of most expensive step, to be
            executed fewer times.
        filename: Name of the CSV file written by the last completed call to
            sweep. Only set by sweep.
    """

    allowedOutputs = ["HFNorm", "HFFunc", "AppNorm", "AppFunc",
                      "ErrNorm", "ErrFunc"]

    def __init__(self, ROMEngine: ROMEng = None, mutars: Np1D = None,
                 params: List[DictAny] = None, mostExpensive: str = "HF"):
        self.ROMEngine = ROMEngine
        # Fresh containers per instance: a literal default would be shared by
        # every sweeper created without explicit arguments.
        self.mutars = np.array([]) if mutars is None else mutars
        self.params = [{}] if params is None else params
        self.mostExpensive = mostExpensive

    def name(self) -> str:
        """Approximant label."""
        return self.__class__.__name__

    @property
    def mostExpensive(self):
        """Value of mostExpensive."""
        return self._mostExpensive

    @mostExpensive.setter
    def mostExpensive(self, mostExpensive: str):
        # Stored upper-case; anything but 'HF'/'APPROX' falls back to 'HF'.
        mostExpensive = mostExpensive.upper()
        if mostExpensive not in ["HF", "APPROX"]:
            warnings.warn(("Value of mostExpensive not recognized. Overriding "
                           "to 'HF'."), stacklevel=2)
            mostExpensive = "HF"
        self._mostExpensive = mostExpensive

    def checkValues(self) -> bool:
        """Check if sweep can be performed."""
        if self.ROMEngine is None:
            warnings.warn("ROMEngine is missing. Aborting.", stacklevel=2)
            return False
        if len(self.mutars) == 0:
            warnings.warn("Empty target parameter vector. Aborting.",
                          stacklevel=2)
            return False
        if len(self.params) == 0:
            warnings.warn("Empty method parameters vector. Aborting.",
                          stacklevel=2)
            return False
        return True

    def _setupEngine(self, par: DictAny, allowedParams: List[str]):
        """Pass the allowed entries of par to ROMEngine and set it up."""
        self.ROMEngine.approxParameters = {
            k: par[k] for k in par.keys() & allowedParams}
        self.ROMEngine.setupApprox()

    def _computeOutputs(self, mutar, outputs: List[str]) -> list:
        """Evaluate requested outputs at mutar, in allowedOutputs order."""
        engine = self.ROMEngine
        outData = []
        if "HFNorm" in outputs:
            val = engine.HFNorm(mutar)
            if isinstance(val, list):
                val = val[0]
            outData.append(val)
        if "HFFunc" in outputs:
            outData.append(engine.HFEngine.functional(engine.solveHF(mutar)))
        if "AppNorm" in outputs:
            val = engine.approxNorm(mutar)
            if isinstance(val, list):
                val = val[0]
            outData.append(val)
        if "AppFunc" in outputs:
            outData.append(
                engine.HFEngine.functional(engine.evalApprox(mutar)))
        if "ErrNorm" in outputs:
            val = engine.approxError(mutar)
            if isinstance(val, list):
                val = val[0]
            outData.append(val)
        if "ErrFunc" in outputs:
            outData.append(engine.HFEngine.functional(
                engine.evalApprox(mutar) - engine.solveHF(mutar)))
        return outData

    def sweep(self, filename: str = "./out", outputs: List[str] = None,
              verbose: int = 1):
        """
        Run the sweep over all (params, mutars) pairs and store it as CSV.

        The header row is ROMEngine.parameterList, 'muRe', 'muIm', the
        requested entries of allowedOutputs, 'type' and (if requested)
        'poles'. One row is written for each combination.

        Args:
            filename(optional): Output file. While the path exists, a random
                digit is appended to it. Defaults to './out'.
            outputs(optional): Labels of quantities to store, or the string
                'ALL'. An empty list (or None) selects allowedOutputs.
                Defaults to None.
            verbose(optional): Progress is printed if >= 1. Defaults to 1.

        Returns:
            Name of the file written, or None if the sweep was aborted.
        """
        if not self.checkValues():
            return
        if outputs is None:
            outputs = []
        if isinstance(outputs, str):
            if outputs.upper() == "ALL":
                outputs = self.allowedOutputs + ["poles"]
        elif len(outputs) == 0:
            # Copy, so that the class attribute is never handed out.
            outputs = list(self.allowedOutputs)
        outputs = purgeList(outputs, self.allowedOutputs + ["poles"],
                            listname=self.name() + ".outputs")
        poles = ("poles" in outputs)
        if len(outputs) == 0:
            warnings.warn("Empty outputs. Aborting.", stacklevel=2)
            return
        outParList = self.ROMEngine.parameterList
        Nparams = len(self.params)
        # Never overwrite an existing file.
        while os.path.exists(filename):
            filename = filename + "{}".format(np.random.randint(10))
        initial_row = (list(outParList) + ["muRe", "muIm"]
                       + [x for x in self.allowedOutputs if x in outputs]
                       + ["type"] + (["poles"] if poles else []))
        # The setter guarantees mostExpensive is either 'HF' or 'APPROX'.
        # The most expensive step is tied to the outer loop, so that it is
        # repeated fewer times.
        hfOuter = (self.mostExpensive == "HF")
        if hfOuter:
            outerSet, innerSet = self.mutars, self.params
        else:
            outerSet, innerSet = self.params, self.mutars
        # newline="" as recommended by the csv module; buffering=1 flushes
        # every row as soon as it is complete.
        with open(filename, "w", buffering=1, newline="") as fout:
            writer = csv.writer(fout, delimiter=",")
            writer.writerow(initial_row)
            for outerIdx, outerPar in enumerate(outerSet):
                if hfOuter:
                    i, mutar = outerIdx, outerPar
                else:
                    j, par = outerIdx, outerPar
                    self._setupEngine(par, outParList)
                for innerIdx, innerPar in enumerate(innerSet):
                    if hfOuter:
                        j, par = innerIdx, innerPar
                        self._setupEngine(par, outParList)
                    else:
                        i, mutar = innerIdx, innerPar
                    if verbose >= 1:
                        print("Set {}/{}\tmu_{} = {:.10f}{}"
                              .format(j + 1, Nparams, i, mutar, " " * 15),
                              end="\r")
                    outData = self._computeOutputs(mutar, outputs)
                    writeData = [self.ROMEngine.approxParameters[parn]
                                 for parn in outParList]
                    writeData += ([mutar.real, mutar.imag] + outData
                                  + [self.ROMEngine.name()])
                    if poles:
                        # Poles take a variable number of trailing columns.
                        writeData += list(self.ROMEngine.getPoles())
                    writer.writerow([str(x) for x in writeData])
                if verbose >= 1:
                    if hfOuter:
                        print("Point mu_{} = {:.10f}\tdone{}".format(
                            i, mutar, " " * 25))
                    else:
                        print("Set {}/{}\tdone{}".format(j + 1, Nparams,
                                                         " " * 25))
        self.filename = filename
        return self.filename

    @staticmethod
    def _readMuColumns(filename: str) -> Tuple[list, list, list]:
        """
        Load the 'muRe' and 'muIm' columns of a CSV.

        Rows that are too short, have a blank 'muRe' entry, or cannot be
        converted to float are skipped as a whole, so that the three returned
        lists always stay aligned.

        Args:
            filename: CSV filename.

        Returns:
            Real parts and imaginary parts as strings (in the same format as
            in the CSV file), and the corresponding complex values.
        """
        muReStr, muImStr, muVals = [], [], []
        with open(filename, 'r', newline="") as f:
            reader = csv.reader(f, delimiter=',')
            header = next(reader)
            muReindex = header.index('muRe')
            muImindex = header.index('muIm')
            for row in reader:
                try:
                    reStr, imStr = row[muReindex], row[muImindex]
                except IndexError:
                    continue
                if reStr in [" ", ""]:
                    continue
                try:
                    val = float(reStr) + 1.j * float(imStr)
                except ValueError:
                    continue
                muReStr.append(reStr)
                muImStr.append(imStr)
                muVals.append(val)
        return muReStr, muImStr, muVals

    @staticmethod
    def _parseNumber(text) -> float:
        """Convert a CSV entry to int, else float, else nan."""
        if text is None:
            return np.nan
        try:
            return int(text)
        except ValueError:
            try:
                return float(text)
            except ValueError:
                return np.nan

    def read(self, filename: str, restrictions: DictAny = None,
             outputs: List[str] = None) -> DictAny:
        """
        Execute a query on a custom format CSV.

        Restrictions on 'muRe' and/or 'muIm' are not matched exactly: each
        requested complex point is replaced by the closest one found in the
        CSV. If only one of the two is given, the other is taken as 0.
        The restrictions argument is not modified.

        Args:
            filename: CSV filename.
            restrictions(optional): Parameter configurations to output, as
                a dictionary mapping column names to an allowed value or list
                of allowed values. Keys which are not in the CSV header are
                ignored (with a warning). Defaults to None, i.e. output all.
            outputs(optional): Values to output. Keys which are not in the
                CSV header are ignored (with a warning). Defaults to None,
                i.e. no output.

        Returns:
            Dictionary of desired results, with a key for each entry of
                outputs, and a numpy 1D array as corresponding value. Entries
                which cannot be converted to a number are stored as nan.

        Raises:
            ValueError: if the lists for 'muRe' and 'muIm' have different
                lengths, or if the CSV has no valid ('muRe', 'muIm') pair to
                match against.
        """
        if restrictions is None:
            restrictions = {}
        if outputs is None:
            outputs = []
        with open(filename, 'r', newline="") as f:
            reader = csv.reader(f, delimiter=',')
            header = next(reader)
            # restrValues is a private, stringified copy of the usable part of
            # restrictions: the caller's dictionary is left untouched.
            restrIndices, restrValues, outputIndices = {}, {}, {}
            for key, allowed in restrictions.items():
                try:
                    restrIndices[key] = header.index(key)
                except ValueError:
                    warnings.warn("Ignoring key {} from restrictions"
                                  .format(key), stacklevel=2)
                    continue
                if not isinstance(allowed, list):
                    allowed = [allowed]
                restrValues[key] = [str(x) for x in allowed]
            if 'muRe' in restrIndices or 'muIm' in restrIndices:
                if 'muRe' not in restrIndices:
                    restrIndices['muRe'] = header.index('muRe')
                    restrValues['muRe'] = [0.] * len(restrValues['muIm'])
                elif 'muIm' not in restrIndices:
                    restrIndices['muIm'] = header.index('muIm')
                    restrValues['muIm'] = [0.] * len(restrValues['muRe'])
                elif len(restrValues['muRe']) != len(restrValues['muIm']):
                    raise ValueError(("The lists of values for muRe and muIm "
                                      "must have the same length."))
                if len(restrValues['muRe']) > 0:
                    # Read the mu columns only once for all targets.
                    muReStr, muImStr, muVals = self._readMuColumns(filename)
                    if len(muVals) == 0:
                        raise ValueError(("No valid (muRe, muIm) pair found "
                                          "in {}.").format(filename))
                    muVals = np.array(muVals, dtype=complex)
                    for i in range(len(restrValues['muRe'])):
                        mu = (1.0 * float(restrValues['muRe'][i])
                              + 1.j * float(restrValues['muIm'][i]))
                        best = int(np.argmin(np.abs(muVals - mu)))
                        # Use the CSV's own strings, since rows are filtered
                        # by exact string comparison below.
                        restrValues['muRe'][i] = muReStr[best]
                        restrValues['muIm'][i] = muImStr[best]
            for key in outputs:
                try:
                    outputIndices[key] = header.index(key)
                except ValueError:
                    warnings.warn("Ignoring key {} from outputs".format(key),
                                  stacklevel=2)
            collected = {key: [] for key in outputIndices}
            # NOTE(review): 'muRe' and 'muIm' are filtered independently, so a
            # row pairing the real part of one requested point with the
            # imaginary part of another is also accepted - confirm intended.
            for row in reader:
                try:
                    keep = all(row[idx] in restrValues[key]
                               for key, idx in restrIndices.items())
                except IndexError:
                    # Row too short to hold a restricted column: no match.
                    continue
                if not keep:
                    continue
                for key, idx in outputIndices.items():
                    entry = row[idx] if idx < len(row) else None
                    collected[key].append(self._parseNumber(entry))
        return {key: np.array(vals, dtype=float)
                for key, vals in collected.items()}

Event Timeline