Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F62157264
parameter_sweeper.py
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Sat, May 11, 07:14
Size
22 KB
Mime Type
text/x-python
Expires
Mon, May 13, 07:14 (2 d)
Engine
blob
Format
Raw Data
Handle
17612140
Attached To
R6746 RationalROMPy
parameter_sweeper.py
View Options
# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
from
copy
import
copy
import
itertools
import
csv
import
numpy
as
np
from
matplotlib
import
pyplot
as
plt
from
rrompy.utilities.base.types
import
Np1D
,
DictAny
,
List
,
ROMEng
from
rrompy.utilities.base
import
purgeList
,
getNewFilename
,
verbosityDepth
from
rrompy.utilities.warning_manager
import
warn
# Public API of this module: star-imports only expose the sweeper class; the
# C2R2csv serialization helper is not re-exported that way.
__all__ = ['ParameterSweeper']
def C2R2csv(x):
    """
    Serialize an array of (complex) numbers as '_'-separated reals.

    The entries of np.ravel(x) are written as interleaved (real, imaginary)
    pairs, each formatted with "{:.15E}", e.g.
    [1+2j] -> "1.000000000000000E+00_2.000000000000000E+00".

    Args:
        x: Array-like of (complex) numbers.

    Returns:
        The serialized string; empty for an empty input.
    """
    x = np.ravel(x)
    # Row k of the stacked array is (Re x_k, Im x_k); raveling interleaves them.
    z = np.ravel(np.stack((np.real(x), np.imag(x)), axis=1))
    # Join explicitly instead of using np.array2string: its default
    # summarization threshold (1000 entries) would replace the middle of long
    # arrays with "..." and silently drop data from the CSV.
    return "_".join("{:.15E}".format(v) for v in z)
class ParameterSweeper:
    """
    ROM approximant parameter sweeper.

    Args:
        ROMEngine(optional): Generic approximant class. Defaults to None.
        mutars(optional): Array of parameter values to sweep. Defaults to empty
            array.
        params(optional): List of parameter settings (each as a dict) to
            explore. Defaults to None, i.e. a single empty set.
        mostExpensive(optional): String containing label of most expensive
            step, to be executed fewer times. Allowed options are 'HF' and
            'Approx'. Defaults to 'HF'.

    Attributes:
        ROMEngine: Generic approximant class.
        mutars: Array of parameter values to sweep.
        params: List of parameter settings (each as a dict) to explore.
        mostExpensive: String containing label of most expensive step, to be
            executed fewer times.
        filename: Name of the CSV file written by the last sweep (only set
            after sweep has been called).
    """

    allowedOutputsStandard = ["normHF", "normApp", "normRes", "normResRel",
                              "normErr", "normErrRel"]
    allowedOutputs = allowedOutputsStandard + ["HFFunc", "AppFunc",
                                               "ErrFunc", "ErrFuncRel"]
    allowedOutputsFull = allowedOutputs + ["poles"]

    def __init__(self, ROMEngine: ROMEng = None, mutars: Np1D = np.array([]),
                 params: List[DictAny] = None, mostExpensive: str = "HF"):
        self.ROMEngine = ROMEngine
        self.mutars = mutars
        # Fresh list per instance: a shared ``[{}]`` default would leak any
        # in-place modification of ``self.params`` across sweepers.
        self.params = [{}] if params is None else params
        self.mostExpensive = mostExpensive

    def name(self) -> str:
        return self.__class__.__name__

    def __str__(self) -> str:
        return self.name()

    def __repr__(self) -> str:
        return self.__str__() + " at " + hex(id(self))

    @property
    def mostExpensive(self):
        """Value of mostExpensive."""
        return self._mostExpensive

    @mostExpensive.setter
    def mostExpensive(self, mostExpensive: str):
        mostExpensive = mostExpensive.upper()
        if mostExpensive not in ["HF", "APPROX"]:
            warn(("Value of mostExpensive not recognized. Overriding to "
                  "'APPROX'."))
            mostExpensive = "APPROX"
        self._mostExpensive = mostExpensive

    def checkValues(self) -> None:
        """
        Check if sweep can be performed.

        Raises:
            Exception: if ROMEngine is missing, or if mutars or params is
                empty.
        """
        if self.ROMEngine is None:
            raise Exception("ROMEngine is missing. Aborting.")
        if len(self.mutars) == 0:
            raise Exception("Empty target parameter vector. Aborting.")
        if len(self.params) == 0:
            raise Exception("Empty method parameters vector. Aborting.")

    def sweep(self, filename: str = "out.dat", outputs: List[str] = None,
              verbose: int = 10) -> str:
        """
        Evaluate ROMEngine for every (parameter set, mu) pair; log to CSV.

        The CSV header holds ROMEngine.parameterList, "muRe", "muIm", the
        requested outputs (in allowedOutputs order), "type" and, if requested,
        "poles". Poles are written only on the first row of each parameter
        set.

        Args:
            filename(optional): Requested output filename; the actual name is
                obtained from it via getNewFilename. Defaults to "out.dat".
            outputs(optional): Quantities to compute, among
                allowedOutputsFull; the string "all" (any case) selects all of
                them, any other string is a single output name. Defaults to
                None (or empty), i.e. allowedOutputsStandard.
            verbose(optional): Verbosity level; progress is reported when at
                least 5. Defaults to 10.

        Returns:
            Name of the written CSV file (also stored in self.filename).

        Raises:
            Exception: if checkValues fails or no valid output is requested.
        """
        self.checkValues()
        if outputs is None:
            outputs = []
        if isinstance(outputs, str):
            if outputs.upper() == "ALL":
                outputs = list(self.allowedOutputsFull)
            else:
                outputs = [outputs]
        elif len(outputs) == 0:
            outputs = list(self.allowedOutputsStandard)
        outputs = purgeList(outputs, self.allowedOutputsFull,
                            listname=self.name() + ".outputs", baselevel=1)
        if len(outputs) == 0:
            raise Exception("Empty outputs. Aborting.")
        poles = "poles" in outputs
        outParList = self.ROMEngine.parameterList
        # Needed by every parameter-set update, whether or not poles are
        # requested.
        allowedParams = self.ROMEngine.parameterList
        Nparams = len(self.params)
        polesCheckList = []
        filename = self._newOutputFilename(filename)
        initial_row = (outParList + ["muRe", "muIm"]
                       + [x for x in self.allowedOutputs if x in outputs]
                       + ["type"] + ["poles"] * poles)
        HFOuter = self.mostExpensive == "HF"
        # The most expensive step goes in the outer loop, so it runs least.
        if HFOuter:
            outerSet, innerSet = self.mutars, self.params
        else:
            outerSet, innerSet = self.params, self.mutars
        # newline="" as required by the csv module for files it writes.
        with open(filename, "w", buffering=1, newline="") as fout:
            writer = csv.writer(fout, delimiter=",")
            writer.writerow(initial_row)
            for outerIdx, outerPar in enumerate(outerSet):
                if HFOuter:
                    i, mutar = outerIdx, outerPar
                else:
                    j, par = outerIdx, outerPar
                    self._setApproxParameters(par, allowedParams)
                for innerIdx, innerPar in enumerate(innerSet):
                    if HFOuter:
                        j, par = innerIdx, innerPar
                        self._setApproxParameters(par, allowedParams)
                    else:
                        i, mutar = innerIdx, innerPar
                    if verbose >= 5:
                        verbosityDepth("INIT", "Set {}/{}\tmu_{} = {:.10f}"
                                               .format(j + 1, Nparams, i,
                                                       mutar))
                    outData = self._evaluateOutputs(mutar, outputs)
                    writeData = ([self.ROMEngine.approxParameters[parn]
                                  for parn in outParList]
                                 + [mutar.real, mutar.imag] + outData
                                 + [self.ROMEngine.name()])
                    if poles:
                        if j not in polesCheckList:
                            polesCheckList.append(j)
                            writeData.append(
                                C2R2csv(self.ROMEngine.getPoles()))
                        else:
                            writeData.append("")
                    writer.writerow(str(x) for x in writeData)
                    if verbose >= 5:
                        verbosityDepth("DEL", "", end="", inline="")
                if verbose >= 5:
                    if HFOuter:
                        out = "Point mu_{} = {:.10f}\tdone.\n".format(i, mutar)
                    else:
                        out = "Set {}/{}\tdone.\n".format(j + 1, Nparams)
                    verbosityDepth("INIT", out)
                    verbosityDepth("DEL", "", end="", inline="")
        self.filename = filename
        return self.filename

    @staticmethod
    def _newOutputFilename(filename: str) -> str:
        """Derive a fresh output filename from the requested one."""
        import os
        # splitext only looks at the last path component, so dots in
        # directory names are not mistaken for an extension.
        root, ext = os.path.splitext(filename)
        if ext in ("", "."):
            # No (or empty) extension: keep the full stem.
            return getNewFilename(root)
        return getNewFilename(root, ext[1:])

    def _setApproxParameters(self, par: DictAny, allowedParams) -> None:
        """Pass ROMEngine the entries of par it recognizes; set it up."""
        self.ROMEngine.approxParameters = {k: par[k] for k in
                                           par.keys() & allowedParams}
        self.ROMEngine.setupApprox()

    def _evaluateOutputs(self, mutar, outputs: List[str]) -> list:
        """
        Compute the requested scalar outputs (all but "poles") at mutar.

        Values follow the order of allowedOutputs, matching the CSV header
        written by sweep. Each ROMEngine quantity is evaluated at most once,
        so that, e.g., "normErr" and "normErrRel" share one normErr call.
        """
        engine = self.ROMEngine
        cache = {}

        def value(key):
            if key not in cache:
                cache[key] = getters[key]()
            return cache[key]

        def normResRel():
            nRes = value("normRes")
            return nRes / value("normRHS")

        def normErrRel():
            nHF = value("normHF")
            return value("normErr") / nHF

        def errFunc():
            fHF = value("HFFunc")
            return np.abs(value("AppFunc") - fHF)

        def errFuncRel():
            fHF = value("HFFunc")
            fErr = value("ErrFunc")
            # Relative functional error |App - HF| / |HF|, consistent with
            # normErrRel = normErr / normHF; NaN when the HF functional
            # vanishes.
            if np.isclose(fHF, 0.):
                return np.nan
            return fErr / np.abs(fHF)

        getters = {
            "normHF": lambda: engine.normHF(mutar),
            "normApp": lambda: engine.normApp(mutar),
            "normRes": lambda: engine.normRes(mutar),
            "normRHS": lambda: engine.normRHS(mutar),
            "normResRel": normResRel,
            "normErr": lambda: engine.normErr(mutar),
            "normErrRel": normErrRel,
            "HFFunc": lambda: engine.HFEngine.functional(engine.getHF(mutar)),
            "AppFunc": lambda: engine.HFEngine.functional(
                                                        engine.getApp(mutar)),
            "ErrFunc": errFunc,
            "ErrFuncRel": errFuncRel,
        }
        return [value(key) for key in self.allowedOutputs if key in outputs]

    def read(self, filename: str, restrictions: DictAny = None,
             outputs: List[str] = None) -> DictAny:
        """
        Execute a query on a custom format CSV.

        Args:
            filename: CSV filename.
            restrictions(optional): Parameter configurations to output, as a
                dict mapping column names to an accepted value or a
                list/tuple of accepted values. A cell matches if it equals an
                accepted string value, or if it and an accepted value both
                convert to floats that are np.isclose. Columns missing from
                the file are ignored with a warning. The dict is not
                modified. Defaults to None, i.e. output all rows.
            outputs(optional): Columns to output. Columns missing from the file
                are ignored with a warning. Defaults to None, i.e. no output.

        Returns:
            Dictionary of desired results, with a key for each entry of
            outputs found in the file, and a numpy 1D array as corresponding
            value (float when every selected cell parses as a float, string
            otherwise).
        """
        if restrictions is None:
            restrictions = {}
        if outputs is None:
            outputs = []
        with open(filename, 'r', newline='') as f:
            reader = csv.reader(f, delimiter=',')
            header = next(reader)
            restrIndices, accepted, acceptedFloats = {}, {}, {}
            for key, vals in restrictions.items():
                try:
                    restrIndices[key] = header.index(key)
                except ValueError:
                    warn("Ignoring key {} from restrictions.".format(key))
                    continue
                # Normalize into local lists: the caller's dict is untouched.
                vals = list(vals) if isinstance(vals, (list, tuple)) else [vals]
                accepted[key] = vals
                floats = []
                for v in vals:
                    try:
                        floats.append(float(v))
                    except (TypeError, ValueError):
                        pass
                acceptedFloats[key] = floats
            outputIndices = {}
            for key in outputs:
                try:
                    outputIndices[key] = header.index(key)
                except ValueError:
                    warn("Ignoring key {} from outputs.".format(key))
            columns = {key: [] for key in outputIndices}
            for row in reader:
                if not row:
                    continue
                if all(self._cellMatches(row[idx], accepted[key],
                                         acceptedFloats[key])
                       for key, idx in restrIndices.items()):
                    for key, idx in outputIndices.items():
                        columns[key].append(self._parseCell(row[idx]))
        return {key: np.array(vals) for key, vals in columns.items()}

    @staticmethod
    def _cellMatches(cell: str, accepted: list, acceptedFloats: list) -> bool:
        """Tell whether a CSV cell matches one of the accepted values."""
        if any(cell == v for v in accepted if isinstance(v, str)):
            return True
        if len(acceptedFloats) == 0:
            return False
        try:
            cellValue = float(cell)
        except ValueError:
            return False
        return bool(np.any(np.isclose(cellValue, acceptedFloats)))

    @staticmethod
    def _parseCell(cell: str):
        """Return cell as a float when possible, else the raw string."""
        try:
            return float(cell)
        except ValueError:
            return cell

    def _constraintList(self, filename: str, zs: List[str]) -> List[DictAny]:
        """List one restriction dict per combination of distinct zs values."""
        zsVals = self.read(filename, outputs=zs)
        keys = list(zsVals.keys())
        # tolist() yields plain Python scalars, for readable labels/filenames.
        uniques = [np.unique(zsVals[key]).tolist() for key in keys]
        # With no (valid) zs, product() yields one empty combination: a
        # single unconstrained plot.
        return [dict(zip(keys, combo))
                for combo in itertools.product(*uniques)]

    @staticmethod
    def _plotCurve(xVals, yVals, label: str) -> bool:
        """Plot one curve; return whether its y-range suggests a log scale."""
        plt.plot(xVals, yVals, label=label)
        yVals = np.asarray(yVals)
        if (yVals.size == 0 or not np.issubdtype(yVals.dtype, np.number)
         or np.min(yVals) <= - np.finfo(float).eps):
            return False
        # Nonnegative data spanning more than one decade.
        with np.errstate(divide='ignore', invalid='ignore'):
            return bool(np.log10(np.max(yVals) / np.min(yVals)) > 1.)

    @staticmethod
    def _finalizeFigure(fig, logScale: bool, ylims: dict, save: str, yTag,
                        x: str, constr: DictAny, saveFormat: str,
                        saveDPI: int):
        """Set axes, legend and grid; optionally save; then show and close."""
        if logScale:
            fig.get_axes()[0].set_yscale('log')
        if ylims is not None:
            plt.ylim(**ylims)
        plt.legend()
        plt.grid()
        if save is not None:
            prefix = "{}_{}_vs_{}_{}".format(save, yTag, x, constr)
            plt.savefig(getNewFilename(prefix, saveFormat),
                        format=saveFormat, dpi=saveDPI)
        plt.show()
        plt.close()

    def plot(self, filename: str, xs: List[str], ys: List[str], zs: List[str],
             onePlot: bool = False, save: str = None, saveFormat: str = "eps",
             saveDPI: int = 100, **figspecs):
        """
        Perform plots from data in filename.

        One set of figures is produced for each combination of the distinct
        values taken by the zs columns (a single set if zs is empty).

        Args:
            filename: CSV filename.
            xs: Values to put on x axes.
            ys: Values to put on y axes.
            zs: Meta-values for constraints.
            onePlot(optional): Whether to create a single figure per x.
                Defaults to False.
            save(optional): Where to save plot(s). Defaults to None, i.e. no
                saving.
            saveFormat(optional): Format for saved plot(s). Defaults to "eps".
            saveDPI(optional): DPI for saved plot(s). Defaults to 100.
            figspecs(optional key args): Optional arguments for matplotlib
                figure creation.
        """
        if save is not None:
            save = save.strip()
        for constr in self._constraintList(filename, zs):
            data = self.read(filename, restrictions=constr, outputs=xs + ys)
            if onePlot:
                for x in xs:
                    p = plt.figure(**figspecs)
                    logScale = False
                    for y in ys:
                        # Columns missing from the file were already warned
                        # about by read; skip them.
                        if x in data and y in data:
                            label = '{} vs {} for {}'.format(y, x, constr)
                            logScale |= self._plotCurve(data[x], data[y],
                                                        label)
                    self._finalizeFigure(p, logScale, None, save, ys, x,
                                         constr, saveFormat, saveDPI)
            else:
                for x, y in itertools.product(xs, ys):
                    p = plt.figure(**figspecs)
                    logScale = False
                    if x in data and y in data:
                        label = '{} vs {} for {}'.format(y, x, constr)
                        logScale = self._plotCurve(data[x], data[y], label)
                    self._finalizeFigure(p, logScale, None, save, y, x,
                                         constr, saveFormat, saveDPI)

    def plotCompare(self, filenames: List[str], xs: List[str], ys: List[str],
                    zs: List[str], onePlot: bool = False, save: str = None,
                    ylims: dict = None, saveFormat: str = "eps",
                    saveDPI: int = 100, labels: List[str] = None,
                    **figspecs):
        """
        Perform comparative plots from data in several files.

        Constraint combinations are taken from the zs columns of the first
        file.

        Args:
            filenames: CSV filenames.
            xs: Values to put on x axes.
            ys: Values to put on y axes.
            zs: Meta-values for constraints.
            onePlot(optional): Whether to create a single figure per x.
                Defaults to False.
            save(optional): Where to save plot(s). Defaults to None, i.e. no
                saving.
            ylims(optional): Custom y axis limits, passed as keyword arguments
                to matplotlib.pyplot.ylim. If None, automatic values are kept.
                Defaults to None.
            saveFormat(optional): Format for saved plot(s). Defaults to "eps".
            saveDPI(optional): DPI for saved plot(s). Defaults to 100.
            labels(optional): Label for each dataset. Defaults to None, i.e.
                "1", "2", ...
            figspecs(optional key args): Optional arguments for matplotlib
                figure creation.
        """
        nfiles = len(filenames)
        if save is not None:
            save = save.strip()
        if labels is None:
            labels = ["{}".format(j + 1) for j in range(nfiles)]
        for constr in self._constraintList(filenames[0], zs):
            data = [self.read(fname, restrictions=constr, outputs=xs + ys)
                    for fname in filenames]
            if onePlot:
                for x in xs:
                    p = plt.figure(**figspecs)
                    logScale = False
                    for y in ys:
                        for j, dataj in enumerate(data):
                            # Files lacking x or y (already warned about by
                            # read) contribute no curve.
                            if x in dataj and y in dataj:
                                label = '{} vs {} for {}, {}'.format(
                                                    y, x, constr, labels[j])
                                logScale |= self._plotCurve(dataj[x],
                                                            dataj[y], label)
                    self._finalizeFigure(p, logScale, ylims, save, ys, x,
                                         constr, saveFormat, saveDPI)
            else:
                for x, y in itertools.product(xs, ys):
                    p = plt.figure(**figspecs)
                    logScale = False
                    for j, dataj in enumerate(data):
                        if x in dataj and y in dataj:
                            label = '{} vs {} for {}, {}'.format(
                                                    y, x, constr, labels[j])
                            logScale |= self._plotCurve(dataj[x], dataj[y],
                                                        label)
                    self._finalizeFigure(p, logScale, ylims, save, y, x,
                                         constr, saveFormat, saveDPI)
Event Timeline
Log In to Comment