Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F62491776
parameter_sweeper.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Mon, May 13, 14:14
Size
13 KB
Mime Type
text/x-python
Expires
Wed, May 15, 14:14 (2 d)
Engine
blob
Format
Raw Data
Handle
17653517
Attached To
R6746 RationalROMPy
parameter_sweeper.py
View Options
# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
from
copy
import
copy
import
itertools
import
csv
import
warnings
import
numpy
as
np
from
matplotlib
import
pyplot
as
plt
from
rrompy.utilities.base.types
import
Np1D
,
N2FSExpr
,
DictAny
,
List
,
ROMEng
from
rrompy.utilities.base
import
purgeList
,
getNewFilename
# Public API of this module: only the sweeper class is exported.
__all__ = ["ParameterSweeper"]
def C2R2csv(x):
    """
    Serialize (possibly complex) numbers to a single CSV-safe string.

    The input is flattened, and its real and imaginary parts are interleaved
    as Re(x0), Im(x0), Re(x1), Im(x1), ... Each value is written as
    "{:.15E}" and the values are separated by underscores, so the result
    contains no commas and fits in one CSV field.

    Args:
        x: Scalar or array-like of real/complex numbers.

    Returns:
        The underscore-separated string; empty string for empty input.
    """
    x = np.ravel(x)
    # Row 0 holds real parts, row 1 imaginary parts; transposing before
    # flattening interleaves them pairwise.
    z = np.ravel(np.vstack((np.real(x), np.imag(x))).T)
    # Join explicitly rather than via np.array2string: the latter summarizes
    # arrays above the print threshold (1000 entries by default) with "...",
    # silently dropping data, and is affected by global print options.
    return "_".join("{:.15E}".format(val) for val in z)
class ParameterSweeper:
    """
    ROM approximant parameter sweeper.

    Args:
        ROMEngine(optional): Generic approximant class. Defaults to None.
        mutars(optional): Array of parameter values to sweep. Defaults to empty
            array.
        params(optional): List of parameter settings (each as a dict) to
            explore. Defaults to single empty set.
        mostExpensive(optional): String containing label of most expensive
            step, to be executed fewer times. Allowed options are 'HF' and
            'Approx' (case-insensitive). Defaults to 'HF'.
        normType(optional): Target norm identifier. Must be recognizable by
            HSEngine norm command. Defaults to None.

    Attributes:
        ROMEngine: Generic approximant class.
        mutars: Array of parameter values to sweep.
        params: List of parameter settings (each as a dict) to explore.
        mostExpensive: String containing label of most expensive step, to be
            executed fewer times; stored upper-cased ('HF' or 'APPROX').
        normType: Target norm identifier.
        filename: Name of the file written by the last completed sweep (set
            only once `sweep` has finished writing).
    """

    # Outputs computed when `sweep` receives an empty outputs list.
    allowedOutputsStandard = ["HFNorm", "AppNorm", "ErrNorm"]
    # Outputs stored one per CSV column, in this column order.
    allowedOutputs = allowedOutputsStandard + ["HFFunc", "AppFunc", "ErrFunc"]
    # Everything `sweep` accepts; "poles" adds a trailing CSV column.
    allowedOutputsFull = allowedOutputs + ["poles"]

    def __init__(self, ROMEngine: ROMEng = None, mutars: Np1D = np.array([]),
                 params: List[DictAny] = [{}], mostExpensive: str = "HF",
                 normType: N2FSExpr = None):
        self.ROMEngine = ROMEngine
        self.mutars = mutars
        self.params = params
        self.mostExpensive = mostExpensive
        self.normType = normType

    def name(self) -> str:
        return self.__class__.__name__

    def __str__(self) -> str:
        return self.name()

    @property
    def mostExpensive(self):
        """Value of mostExpensive."""
        return self._mostExpensive

    @mostExpensive.setter
    def mostExpensive(self, mostExpensive: str):
        mostExpensive = mostExpensive.upper()
        if mostExpensive not in ["HF", "APPROX"]:
            warnings.warn(("Value of mostExpensive not recognized. Overriding "
                           "to 'APPROX'."), stacklevel=2)
            mostExpensive = "APPROX"
        self._mostExpensive = mostExpensive

    def checkValues(self) -> bool:
        """Check if sweep can be performed."""
        if self.ROMEngine is None:
            warnings.warn("ROMEngine is missing. Aborting.", stacklevel=2)
            return False
        if len(self.mutars) == 0:
            warnings.warn("Empty target parameter vector. Aborting.",
                          stacklevel=2)
            return False
        if len(self.params) == 0:
            warnings.warn("Empty method parameters vector. Aborting.",
                          stacklevel=2)
            return False
        return True

    @staticmethod
    def _newOutputFilename(filename: str) -> str:
        """Split filename into stem and extension and get a fresh name."""
        import os.path
        # splitext only looks at the last path component, so dots in
        # directory names are not mistaken for an extension, and a name
        # without any extension is passed through unchanged.
        root, ext = os.path.splitext(filename)
        ext = ext[1:]  # drop the leading dot ("" for "name" and "name.")
        if ext:
            return getNewFilename(root, ext)
        return getNewFilename(root)

    def _setApproxParameters(self, par: DictAny, allowedParams):
        """Load the entries of par known to ROMEngine and set it up."""
        self.ROMEngine.approxParameters = {k: par[k] for k in
                                           par.keys() & allowedParams}
        self.ROMEngine.setupApprox()

    def _evaluateOutputs(self, mutar, outputs: List[str]) -> list:
        """Compute the requested outputs at mutar, in allowedOutputs order."""
        outData = []
        # (output label, ROMEngine method computing it)
        normMethods = (("HFNorm", "HFNorm"), ("AppNorm", "approxNorm"),
                       ("ErrNorm", "approxError"))
        for outKey, methodName in normMethods:
            if outKey in outputs:
                val = getattr(self.ROMEngine, methodName)(mutar,
                                                          self.normType)
                if isinstance(val, list):
                    val = val[0]
                outData.append(val)

        # The HF/approximant functionals are computed at most once each, so
        # that "ErrFunc" reuses the values of "HFFunc"/"AppFunc".
        funcCache = {}

        def functional(getterName: str):
            if getterName not in funcCache:
                sol = getattr(self.ROMEngine, getterName)(mutar)
                funcCache[getterName] = self.ROMEngine.HFEngine.functional(sol)
            return funcCache[getterName]

        if "HFFunc" in outputs:
            outData.append(functional("getHF"))
        if "AppFunc" in outputs:
            outData.append(functional("getApp"))
        if "ErrFunc" in outputs:
            outData.append(functional("getApp") - functional("getHF"))
        return outData

    def sweep(self, filename: str = "out.dat", outputs: List[str] = [],
              verbose: int = 1):
        """
        Evaluate the requested outputs for every (parameter set, mu) pair.

        Results are written row by row to a new CSV file whose name is
        obtained from `filename` through getNewFilename. With mostExpensive
        'HF' the mu values form the outer loop; with 'APPROX' the parameter
        sets do, so that each approximant is set up only once.

        Args:
            filename(optional): Requested output filename; its extension, if
                any, is passed to getNewFilename separately. Defaults to
                'out.dat'.
            outputs(optional): List of outputs among allowedOutputsFull, or
                the string 'ALL' (any case) for all of them. Defaults to
                empty list, i.e. allowedOutputsStandard.
            verbose(optional): Print progress if >= 1. Defaults to 1.

        Returns:
            Name of the file actually written, or None if the sweep was
            aborted.
        """
        if not self.checkValues():
            return
        if isinstance(outputs, str):
            outputs = (self.allowedOutputsFull if outputs.upper() == "ALL"
                       else [outputs])
        elif len(outputs) == 0:
            outputs = self.allowedOutputsStandard
        outputs = purgeList(outputs, self.allowedOutputsFull,
                            listname=self.name() + ".outputs", baselevel=1)
        poles = "poles" in outputs
        if len(outputs) == 0:
            warnings.warn("Empty outputs. Aborting.", stacklevel=2)
            return

        outParList = self.ROMEngine.parameterList
        Nparams = len(self.params)
        filename = self._newOutputFilename(filename)
        initial_row = (outParList + ["muRe", "muIm"]
                       + [x for x in self.allowedOutputs if x in outputs]
                       + ["type"] + ["poles"] * poles)

        approxOuter = self.mostExpensive == "APPROX"
        if approxOuter:
            outerSet, innerSet = self.params, self.mutars
        else:
            outerSet, innerSet = self.mutars, self.params
        # Poles depend only on the approximant, so they are written once per
        # parameter set (indexed by j).
        polesDone = set()
        # newline="" is required by the csv module to avoid spurious blank
        # lines on platforms translating "\n".
        with open(filename, "w", buffering=1, newline="") as fout:
            writer = csv.writer(fout, delimiter=",")
            writer.writerow(initial_row)
            for outerIdx, outerPar in enumerate(outerSet):
                if approxOuter:
                    j = outerIdx
                    self._setApproxParameters(outerPar, outParList)
                else:
                    i, mutar = outerIdx, outerPar
                for innerIdx, innerPar in enumerate(innerSet):
                    if approxOuter:
                        i, mutar = innerIdx, innerPar
                    else:
                        j = innerIdx
                        self._setApproxParameters(innerPar, outParList)
                    if verbose >= 1:
                        print("Set {}/{}\tmu_{} = {:.10f}".format(
                            j + 1, Nparams, i, mutar))
                    outData = self._evaluateOutputs(mutar, outputs)
                    writeData = [self.ROMEngine.approxParameters[parn]
                                 for parn in outParList]
                    writeData += [mutar.real, mutar.imag]
                    writeData += outData
                    writeData.append(self.ROMEngine.name())
                    if poles:
                        if j in polesDone:
                            writeData.append("")
                        else:
                            polesDone.add(j)
                            writeData.append(
                                C2R2csv(self.ROMEngine.getPoles()))
                    writer.writerow([str(x) for x in writeData])
                if verbose >= 1:
                    if approxOuter:
                        print("Set {}/{}\tdone".format(j + 1, Nparams))
                    else:
                        print("Point mu_{} = {:.10f}\tdone".format(i, mutar))
        self.filename = filename
        return self.filename

    @staticmethod
    def _acceptedValues(value):
        """Return (string forms, float forms or None) of a restriction."""
        values = value if isinstance(value, list) else [value]
        strValues = [str(v) for v in values]
        try:
            floatValues = [float(v) for v in values]
        except (TypeError, ValueError):
            floatValues = None
        return strValues, floatValues

    @staticmethod
    def _matchesRestriction(cell: str, strValues, floatValues) -> bool:
        """Whether a CSV cell equals (as text) or is close to an accepted value."""
        if cell in strValues:
            return True
        if floatValues is None:
            return False
        try:
            cellValue = float(cell)
        except ValueError:
            return False
        return bool(np.any(np.isclose(cellValue, floatValues)))

    @staticmethod
    def _parseCell(cell: str):
        """Parse a CSV cell as float, else complex, else keep the string."""
        for converter in (float, complex):
            try:
                return converter(cell)
            except ValueError:
                pass
        return cell

    def read(self, filename: str, restrictions: DictAny = {},
             outputs: List[str] = []) -> DictAny:
        """
        Execute a query on a custom format CSV.

        Args:
            filename: CSV filename.
            restrictions(optional): Parameter configurations to output, as a
                dict mapping column names to an accepted value or a list of
                accepted values. A row is selected if, for every key, its
                entry equals the string form of an accepted value or is
                numerically close to one. Keys missing from the CSV header
                are ignored with a warning. The dict is not modified.
                Defaults to empty dictionary, i.e. output all.
            outputs(optional): Values to output. Keys missing from the CSV
                header are ignored with a warning. Defaults to empty list,
                i.e. no output.

        Returns:
            Dictionary of desired results, with a key for each valid entry of
            outputs, and a numpy 1D array as corresponding value. Entries are
            parsed as float, else complex; other entries are kept as strings.
        """
        with open(filename, "r", newline="") as f:
            reader = csv.reader(f, delimiter=",")
            header = next(reader)
            restrIndices, restrAccepted = {}, {}
            for key, value in restrictions.items():
                try:
                    restrIndices[key] = header.index(key)
                except ValueError:
                    warnings.warn("Ignoring key {} from restrictions"
                                  .format(key), stacklevel=2)
                    continue
                restrAccepted[key] = self._acceptedValues(value)
            outputIndices = {}
            for key in outputs:
                try:
                    outputIndices[key] = header.index(key)
                except ValueError:
                    warnings.warn("Ignoring key {} from outputs".format(key),
                                  stacklevel=2)
            # Collect into lists and convert once at the end (repeated
            # np.append would be quadratic in the number of rows).
            collected = {key: [] for key in outputIndices}
            for row in reader:
                if all(self._matchesRestriction(row[idx], *restrAccepted[key])
                       for key, idx in restrIndices.items()):
                    for key, idx in outputIndices.items():
                        collected[key].append(self._parseCell(row[idx]))
        return {key: np.array(vals) for key, vals in collected.items()}

    def plot(self, filename: str, xs: List[str], ys: List[str], zs: List[str],
             onePlot: bool = False):
        """
        Perform plots from data in filename.

        One batch of plots is produced for every combination of distinct
        values taken by the zs columns, using only the matching rows. If no
        entry of zs is a valid column, all rows are plotted in one batch.

        Args:
            filename: CSV filename.
            xs: Values to put on x axes.
            ys: Values to put on y axes.
            zs: Meta-values for constraints.
            onePlot: Whether to create a single figure per x. Defaults to
                False.
        """
        zsVals = self.read(filename, outputs=zs)
        zs = list(zsVals.keys())
        # Flat Cartesian product of the distinct values of each constraint
        # column; with no constraint columns it yields one empty tuple.
        combinations = itertools.product(*[np.unique(zsVals[key])
                                           for key in zs])
        for combination in combinations:
            constr = dict(zip(zs, combination))
            data = self.read(filename, restrictions=constr, outputs=xs + ys)
            # read() warns about and drops columns missing from the file.
            xsFound = [x for x in xs if x in data]
            ysFound = [y for y in ys if y in data]
            if onePlot:
                for x in xsFound:
                    xVals = data[x]
                    plt.figure()
                    for y in ysFound:
                        label = '{} vs {} for {}'.format(x, y, constr)
                        plt.semilogy(xVals, data[y], label=label)
                    plt.legend()
                    plt.grid()
                    plt.show()
                    plt.close()
            else:
                for x, y in itertools.product(xsFound, ysFound):
                    label = '{} vs {} for {}'.format(x, y, constr)
                    plt.figure()
                    plt.semilogy(data[x], data[y], label=label)
                    plt.legend()
                    plt.grid()
                    plt.show()
                    plt.close()
Event Timeline
Log In to Comment