diff --git a/examples/base/solver.py b/examples/base/solver.py
index e12f25c..5532ec3 100644
--- a/examples/base/solver.py
+++ b/examples/base/solver.py
@@ -1,74 +1,75 @@
import numpy as np
from rrompy.hfengines.linear_problem import \
HelmholtzSquareBubbleProblemEngine as HSBPE
from rrompy.hfengines.linear_problem import \
HelmholtzSquareTransmissionProblemEngine as HSTPE
from rrompy.hfengines.linear_problem import \
HelmholtzBoxScatteringProblemEngine as HBSPE
from rrompy.hfengines.linear_problem import \
HelmholtzCavityScatteringProblemEngine as HCSPE
-testNo = 4
+testNo = 1
verb = 0
if testNo == 1:
solver = HSBPE(kappa = 12 ** .5, theta = np.pi / 3, n = 20,
verbosity = verb)
mu = 12.**.5
+ solver.setSolver("BICG", {"tol" : 1e-15})
uh = solver.solve(mu)
solver.plotmesh()
print(solver.norm(uh))
solver.plot(uh)
solver.plot(solver.residual(uh, mu), 'res')
###########
elif testNo in [2, -2]:
solver = HSTPE(nT = 1, nB = 2, theta = np.pi * 20 / 180, kappa = 4.,
n = 50, verbosity = verb)
mu = 4.
uref = solver.liftDirichletData(mu)
if testNo > 0:
uh = solver.solve(mu)
utot = uh - uref
else:
utot = solver.solve(mu, homogeneized = True)
uh = utot + uref
print(solver.norm(uh))
print(solver.norm(uref))
solver.plot(uh)
solver.plot(uref, name = 'u_Dir')
solver.plot(utot, name = 'u_tot')
solver.plot(solver.residual(uh, mu), 'res')
solver.plot(solver.residual(utot, mu, homogeneized = True), 'res_tot')
###########
elif testNo in [3, -3]:
solver = HBSPE(R = 5, kappa = 12**.5, theta = - np.pi * 60 / 180, n = 30,
verbosity = verb)
mu = 12**.5
uref = solver.liftDirichletData(mu)
if testNo > 0:
uh = solver.solve(mu)
utot = uh - uref
else:
utot = solver.solve(mu, homogeneized = True)
uh = utot + uref
solver.plotmesh()
print(solver.norm(uh))
print(solver.norm(utot))
solver.plot(uh)
solver.plot(utot, name = 'u_tot')
solver.plot(solver.residual(uh, mu), 'res')
solver.plot(solver.residual(utot, mu, homogeneized = True), 'res_tot')
###########
elif testNo == 4:
solver = HCSPE(kappa = 5, n = 30, verbosity = verb)
mu = 10
uh = solver.solve(mu)
solver.plotmesh()
print(solver.norm(uh))
solver.plot(uh)
solver.plot(solver.residual(uh, mu), 'res')
diff --git a/rrompy/hfengines/base/problem_engine_base.py b/rrompy/hfengines/base/problem_engine_base.py
index 16c897d..c0c4af4 100644
--- a/rrompy/hfengines/base/problem_engine_base.py
+++ b/rrompy/hfengines/base/problem_engine_base.py
@@ -1,499 +1,504 @@
# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see .
#
from abc import abstractmethod
from os import path, mkdir
import fenics as fen
import numpy as np
from scipy.sparse import csr_matrix
import scipy.sparse as scsp
-import scipy.sparse.linalg as scspla
from matplotlib import pyplot as plt
from rrompy.utilities.base.types import (Np1D, Np2D, ScOp, strLst, FenFunc,
- Tuple, List)
+ Tuple, List, DictAny)
from rrompy.utilities.base import purgeList, getNewFilename, verbosityDepth
+from rrompy.solver import setupSolver
from .boundary_conditions import BoundaryConditions
from rrompy.utilities.exception_manager import RROMPyException
__all__ = ['ProblemEngineBase']
class ProblemEngineBase:
"""
Generic solver for parametric problems.
Attributes:
verbosity: Verbosity level.
BCManager: Boundary condition manager.
V: Real FE space.
u: Generic trial functions for variational form evaluation.
v: Generic test functions for variational form evaluation.
As: Scipy sparse array representation (in CSC format) of As.
bs: Numpy array representation of bs.
energyNormMatrix: Scipy sparse matrix representing inner product over
V.
bsmu: Mu value of last bs evaluation.
liftDirichletDatamu: Mu value of last Dirichlet datum evaluation.
liftedDirichletDatum: Dofs of Dirichlet datum lifting.
mu0BC: Mu value of last Dirichlet datum lifting.
degree_threshold: Threshold for ufl expression interpolation degree.
"""
nAs, nbs = 1, 1
rescalingExp = 1.
functional = lambda self, u: 0.
def __init__(self, degree_threshold : int = np.inf, verbosity : int = 10,
timestamp : bool = True):
self.BCManager = BoundaryConditions("Dirichlet")
self.V = fen.FunctionSpace(fen.UnitSquareMesh(10, 10), "P", 1)
self.verbosity = verbosity
self.timestamp = timestamp
self.resetAs()
self.resetbs()
self.bsmu = np.nan
self.liftDirichletDatamu = np.nan
self.mu0BC = np.nan
self.degree_threshold = degree_threshold
+ self.setSolver("SPSOLVE", {"use_umfpack" : False})
def name(self) -> str:
return self.__class__.__name__
def __str__(self) -> str:
return self.name()
def __repr__(self) -> str:
return self.__str__() + " at " + hex(id(self))
def __dir_base__(self):
return [x for x in self.__dir__() if x[:2] != "__"]
@property
def V(self):
"""Value of V."""
return self._V
@V.setter
def V(self, V):
self.resetAs()
self.resetbs()
if not type(V).__name__ == 'FunctionSpace':
raise RROMPyException("V type not recognized.")
self._V = V
self.u = fen.TrialFunction(V)
self.v = fen.TestFunction(V)
def innerProduct(self, u:Np2D, v:Np2D, onlyDiag : bool = False) -> Np2D:
"""Hilbert space scalar product."""
if not hasattr(self, "energyNormMatrix"):
self.buildEnergyNormForm()
if onlyDiag:
return np.sum(self.energyNormMatrix.dot(u) * v.conj(), axis = 0)
return v.T.conj().dot(self.energyNormMatrix.dot(u))
def buildEnergyNormForm(self): # L2
"""
Build sparse matrix (in CSR format) representative of scalar product.
"""
if self.verbosity >= 20:
verbosityDepth("INIT", "Assembling energy matrix.",
timestamp = self.timestamp)
normMatFen = fen.assemble(fen.dot(self.u, self.v) * fen.dx)
normMat = fen.as_backend_type(normMatFen).mat()
self.energyNormMatrix = csr_matrix(normMat.getValuesCSR()[::-1],
shape = normMat.size)
if self.verbosity >= 20:
verbosityDepth("DEL", "Done assembling energy matrix.",
timestamp = self.timestamp)
def norm(self, u:Np2D) -> Np1D:
return np.abs(self.innerProduct(u, u, onlyDiag = True)) ** .5
def checkAInBounds(self, der : int = 0):
"""Check if derivative index is oob for operator of linear system."""
if der < 0 or der >= self.nAs:
d = self.V.dim()
return scsp.csr_matrix((np.zeros(0), np.zeros(0), np.zeros(d + 1)),
shape = (d, d), dtype = np.complex)
def checkbInBounds(self, der : int = 0, homogeneized : bool = False):
"""Check if derivative index is oob for RHS of linear system."""
if der < 0 or der >= max(self.nbs, self.nAs * homogeneized):
return np.zeros(self.V.dim(), dtype = np.complex)
def setDirichletDatum(self, mu:complex):
"""Set Dirichlet datum if parametric."""
if hasattr(self, "liftedDirichletDatum"):
self.liftDirichletDatamu = mu
def liftDirichletData(self, mu:complex) -> Np1D:
"""Lift Dirichlet datum."""
self.setDirichletDatum(mu)
if not np.isclose(self.liftDirichletDatamu, mu):
try:
liftRe = fen.interpolate(self.DirichletDatum[0], self.V)
except:
liftRe = fen.project(self.DirichletDatum[0], self.V)
try:
liftIm = fen.interpolate(self.DirichletDatum[1], self.V)
except:
liftIm = fen.project(self.DirichletDatum[1], self.V)
self.liftedDirichletDatum = (np.array(liftRe.vector())
+ 1.j * np.array(liftIm.vector()))
return self.liftedDirichletDatum
def resetAs(self):
"""Reset (derivatives of) operator of linear system."""
self.As = [None] * self.nAs
def resetbs(self):
"""Reset (derivatives of) RHS of linear system."""
self.bs = {True: [None] * max(self.nbs, self.nAs),
False: [None] * self.nbs}
def reduceQuadratureDegree(self, fun:FenFunc, name:str):
"""Check whether to reduce compiler parameters to degree threshold."""
if not np.isinf(self.degree_threshold):
from ufl.algorithms.estimate_degrees import (
estimate_total_polynomial_degree as ETPD)
try:
deg = ETPD(fun)
except:
return False
if deg > self.degree_threshold:
if self.verbosity >= 15:
verbosityDepth("MAIN", ("Reducing quadrature degree from "
"{} to {} for {}.").format(
deg,
self.degree_threshold,
name),
timestamp = self.timestamp)
return True
return False
def iterReduceQuadratureDegree(self, funsNames:List[Tuple[FenFunc, str]]):
"""
Iterate reduceQuadratureDegree over list and define reduce compiler
parameters.
"""
if funsNames is not None:
for fun, name in funsNames:
if self.reduceQuadratureDegree(fun, name):
return {"quadrature_degree" : self.degree_threshold}
return {}
@abstractmethod
def A(self, mu:complex, der : int = 0) -> ScOp:
"""Assemble (derivative of) operator of linear system."""
Anull = self.checkAInBounds(der)
if Anull is not None: return Anull
if self.As[der] is None:
self.As[der] = 0.
return self.As[der]
@abstractmethod
def b(self, mu:complex, der : int = 0,
homogeneized : bool = False) -> Np1D:
"""Assemble (derivative of) RHS of linear system."""
bnull = self.checkbInBounds(der, homogeneized)
if bnull is not None: return bnull
if self.bs[homogeneized][der] is None:
self.bs[homogeneized][der] = 0.
return self.bs[homogeneized][der]
def affineLinearSystemA(self, mu : complex = 0.) -> List[Np2D]:
"""
Assemble affine blocks of operator of linear system (just linear
blocks).
"""
As = [None] * self.nAs
for j in range(self.nAs):
As[j] = self.A(mu, j)
return As
def affineWeightsA(self, mu : complex = 0.) -> callable:
"""
Assemble affine blocks of operator of linear system (just affine
weights). Stored as strings for the sake of pickling.
"""
lambdasA = ["np.ones_like(mu)"]
mu0Eff = np.power(mu, self.rescalingExp)
for j in range(1, self.nAs):
lambdasA += ["np.power(np.power(mu, {1}) - {2}, {0})".format(j,
self.rescalingExp,
mu0Eff)]
return lambdasA
def affineBlocksA(self, mu : complex = 0.) -> Tuple[List[Np2D], callable]:
"""Assemble affine blocks of operator of linear system."""
return self.affineLinearSystemA(mu), self.affineWeightsA(mu)
def setnbsEff(self, homogeneized : bool = False):
"""Compute effective number of b terms."""
self.nbsEff = max(homogeneized * self.nAs, self.nbs)
def affineLinearSystemb(self, mu : complex = 0.,
homogeneized : bool = False) -> List[Np1D]:
"""
Assemble affine blocks of RHS of linear system (just linear blocks).
"""
self.setnbsEff(homogeneized)
bs = [None] * self.nbsEff
for j in range(self.nbsEff):
bs[j] = self.b(mu, j, homogeneized)
return bs
def affineWeightsb(self, mu : complex = 0., homogeneized : bool = False)\
-> callable:
"""
Assemble affine blocks of RHS of linear system (just affine weights).
Stored as strings for the sake of pickling.
"""
self.setnbsEff(homogeneized)
lambdasb = ["np.ones_like(mu)"]
mu0Eff = np.power(mu, self.rescalingExp)
for j in range(1, self.nbsEff):
lambdasb += ["np.power(np.power(mu, {1}) - {2}, {0})".format(j,
self.rescalingExp,
mu0Eff)]
return lambdasb
def affineBlocksb(self, mu : complex = 0., homogeneized : bool = False)\
-> Tuple[List[Np1D], callable]:
"""Assemble affine blocks of RHS of linear system."""
return (self.affineLinearSystemb(mu, homogeneized),
self.affineWeightsb(mu, homogeneized))
+ def setSolver(self, solverType:str, solverArgs : DictAny = {}):
+ """Choose solver type and parameters."""
+ self._solver, self._solverArgs = setupSolver(solverType, solverArgs)
+
def solve(self, mu:complex, RHS : Np1D = None,
homogeneized : bool = False) -> Np1D:
"""
Find solution of linear system.
Args:
mu: parameter value.
RHS: RHS of linear system. If None, defaults to that of parametric
system. Defaults to None.
"""
A = self.A(mu)
if RHS is None: RHS = self.b(mu, 0, homogeneized)
- return scspla.spsolve(A, RHS)
+ return self._solver(A, RHS, self._solverArgs)
def residual(self, u:Np1D, mu:complex,
homogeneized : bool = False) -> Np1D:
"""
Find residual of linear system for given approximate solution.
Args:
u: numpy complex array with function dofs. If None, set to 0.
mu: parameter value.
"""
A = self.A(mu)
RHS = self.b(mu, 0, homogeneized)
if u is None: return RHS
return RHS - A.dot(u)
def plot(self, u:Np1D, name : str = "u", save : str = None,
what : strLst = 'all', saveFormat : str = "eps",
saveDPI : int = 100, **figspecs):
"""
Do some nice plots of the complex-valued function with given dofs.
Args:
u: numpy complex array with function dofs.
name(optional): Name to be shown as title of the plots. Defaults to
'u'.
save(optional): Where to save plot(s). Defaults to None, i.e. no
saving.
what(optional): Which plots to do. If list, can contain 'ABS',
'PHASE', 'REAL', 'IMAG'. If str, same plus wildcard 'ALL'.
Defaults to 'ALL'.
saveFormat(optional): Format for saved plot(s). Defaults to "eps".
saveDPI(optional): DPI for saved plot(s). Defaults to 100.
figspecs(optional key args): Optional arguments for matplotlib
figure creation.
"""
if isinstance(what, (str,)):
if what.upper() == 'ALL':
what = ['ABS', 'PHASE', 'REAL', 'IMAG']
else:
what = [what]
what = purgeList(what, ['ABS', 'PHASE', 'REAL', 'IMAG'],
listname = self.name() + ".what", baselevel = 1)
if len(what) == 0: return
if 'figsize' not in figspecs.keys():
figspecs['figsize'] = (13. * len(what) / 4, 3)
subplotcode = 100 + len(what) * 10
plt.figure(**figspecs)
plt.jet()
if 'ABS' in what:
uAb = fen.Function(self.V)
uAb.vector().set_local(np.abs(u))
subplotcode = subplotcode + 1
plt.subplot(subplotcode)
p = fen.plot(uAb, title = "|{0}|".format(name))
plt.colorbar(p)
if 'PHASE' in what:
uPh = fen.Function(self.V)
uPh.vector().set_local(np.angle(u))
subplotcode = subplotcode + 1
plt.subplot(subplotcode)
p = fen.plot(uPh, title = "phase({0})".format(name))
plt.colorbar(p)
if 'REAL' in what:
uRe = fen.Function(self.V)
uRe.vector().set_local(np.real(u))
subplotcode = subplotcode + 1
plt.subplot(subplotcode)
p = fen.plot(uRe, title = "Re({0})".format(name))
plt.colorbar(p)
if 'IMAG' in what:
uIm = fen.Function(self.V)
uIm.vector().set_local(np.imag(u))
subplotcode = subplotcode + 1
plt.subplot(subplotcode)
p = fen.plot(uIm, title = "Im({0})".format(name))
plt.colorbar(p)
if save is not None:
save = save.strip()
plt.savefig(getNewFilename("{}_fig_".format(save), saveFormat),
format = saveFormat, dpi = saveDPI)
plt.show()
plt.close()
def plotmesh(self, name : str = "Mesh", save : str = None,
saveFormat : str = "eps", saveDPI : int = 100, **figspecs):
"""
Do a nice plot of the mesh.
Args:
u: numpy complex array with function dofs.
name(optional): Name to be shown as title of the plots. Defaults to
'u'.
save(optional): Where to save plot(s). Defaults to None, i.e. no
saving.
saveFormat(optional): Format for saved plot(s). Defaults to "eps".
saveDPI(optional): DPI for saved plot(s). Defaults to 100.
figspecs(optional key args): Optional arguments for matplotlib
figure creation.
"""
plt.figure(**figspecs)
fen.plot(self.V.mesh())
if save is not None:
save = save.strip()
plt.savefig(getNewFilename("{}_msh_".format(save), saveFormat),
format = saveFormat, dpi = saveDPI)
plt.show()
plt.close()
def outParaview(self, u:Np1D, name : str = "u", filename : str = "out",
time : float = 0., what : strLst = 'all',
forceNewFile : bool = True, folder : bool = False,
filePW = None):
"""
Output complex-valued function with given dofs to ParaView file.
Args:
u: numpy complex array with function dofs.
name(optional): Base name to be used for data output.
filename(optional): Name of output file.
time(optional): Timestamp.
what(optional): Which plots to do. If list, can contain 'MESH',
'ABS', 'PHASE', 'REAL', 'IMAG'. If str, same plus wildcard
'ALL'. Defaults to 'ALL'.
forceNewFile(optional): Whether to create new output file.
folder(optional): Whether to create an additional folder layer.
filePW(optional): Fenics File entity (for time series).
"""
if isinstance(what, (str,)):
if what.upper() == 'ALL':
what = ['MESH', 'ABS', 'PHASE', 'REAL', 'IMAG']
else:
what = [what]
what = purgeList(what, ['MESH', 'ABS', 'PHASE', 'REAL', 'IMAG'],
listname = self.name() + ".what", baselevel = 1)
if len(what) == 0: return
if filePW is None:
if folder:
if not path.exists(filename + "/"):
mkdir(filename)
idxpath = filename.rfind("/")
filename += "/" + filename[idxpath + 1 :]
if forceNewFile:
filePW = fen.File(getNewFilename(filename, "pvd"))
else:
filePW = fen.File("{}.pvd".format(filename))
if what == ['MESH']:
filePW << (self.V.mesh(), time)
if 'ABS' in what:
uAb = fen.Function(self.V, name = "{}_ABS".format(name))
uAb.vector().set_local(np.abs(u))
filePW << (uAb, time)
if 'PHASE' in what:
uPh = fen.Function(self.V, name = "{}_PHASE".format(name))
uPh.vector().set_local(np.angle(u))
filePW << (uPh, time)
if 'REAL' in what:
uRe = fen.Function(self.V, name = "{}_REAL".format(name))
uRe.vector().set_local(np.real(u))
filePW << (uRe, time)
if 'IMAG' in what:
uIm = fen.Function(self.V, name = "{}_IMAG".format(name))
uIm.vector().set_local(np.imag(u))
filePW << (uIm, time)
return filePW
def outParaviewTimeDomain(self, u:Np1D, omega:float,
timeFinal : float = None,
periodResolution : int = 20, name : str = "u",
filename : str = "out",
forceNewFile : bool = True,
folder : bool = False):
"""
Output complex-valued function with given dofs to ParaView file,
converted to time domain.
Args:
u: numpy complex array with function dofs.
omega: frequency.
timeFinal(optional): final time of simulation.
periodResolution(optional): number of time steps per period.
name(optional): Base name to be used for data output.
filename(optional): Name of output file.
forceNewFile(optional): Whether to create new output file.
folder(optional): Whether to create an additional folder layer.
"""
if folder:
if not path.exists(filename + "/"):
mkdir(filename)
idxpath = filename.rfind("/")
filename += "/" + filename[idxpath + 1 :]
if forceNewFile:
filePW = fen.File(getNewFilename(filename, "pvd"))
else:
filePW = fen.File("{}.pvd".format(filename))
t = 0.
dt = 2. * np.pi / omega / periodResolution
if timeFinal is None: timeFinal = 2. * np.pi / omega - dt
for j in range(int(timeFinal / dt) + 1):
ut = fen.Function(self.V, name = name)
ut.vector().set_local(np.real(u) * np.cos(omega * t)
+ np.imag(u) * np.sin(omega * t))
filePW << (ut, t)
t += dt
return filePW
diff --git a/rrompy/solver/__init__.py b/rrompy/solver/__init__.py
new file mode 100644
index 0000000..ff83f0c
--- /dev/null
+++ b/rrompy/solver/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (C) 2018 by the RROMPy authors
+#
+# This file is part of RROMPy.
+#
+# RROMPy is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# RROMPy is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with RROMPy. If not, see .
+#
+
+from .linear_solver import RROMPyLinearSolvers, setupSolver
+
+__all__ = [
+ 'RROMPyLinearSolvers',
+ 'setupSolver'
+ ]
+
+
+
+
diff --git a/rrompy/solver/linear_solver.py b/rrompy/solver/linear_solver.py
new file mode 100644
index 0000000..cf6f35f
--- /dev/null
+++ b/rrompy/solver/linear_solver.py
@@ -0,0 +1,65 @@
+# Copyright (C) 2018 by the RROMPy authors
+#
+# This file is part of RROMPy.
+#
+# RROMPy is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# RROMPy is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with RROMPy. If not, see .
+#
+
+import scipy.sparse.linalg as scspla
+from rrompy.utilities.base import purgeDict
+from rrompy.utilities.base.types import Tuple, DictAny
+from rrompy.utilities.exception_manager import RROMPyException
+
+__all__ = ['RROMPyLinearSolvers', 'setupSolver']
+
+RROMPyLinearSolvers = {
+ "SPSOLVE" : (lambda A, b, kwargs: scspla.spsolve(A, b, **kwargs),
+ ["permc_spec", "use_umfpack"]),
+ "BICG" : (lambda A, b, kwargs: scspla.bicg(A, b, **kwargs)[0],
+ ["x0", "tol", "maxiter", "M", "callback", "atol"]),
+ "BICGSTAB" : (lambda A, b, kwargs: scspla.bicgstab(A, b, **kwargs)[0],
+ ["x0", "tol", "maxiter", "M", "callback", "atol"]),
+ "CG" : (lambda A, b, kwargs: scspla.cg(A, b, **kwargs)[0],
+ ["x0", "tol", "maxiter", "M", "callback", "atol"]),
+ "CGS" : (lambda A, b, kwargs: scspla.cgs(A, b, **kwargs)[0],
+ ["x0", "tol", "maxiter", "M", "callback", "atol"]),
+ "GMRES" : (lambda A, b, kwargs: scspla.gmres(A, b, **kwargs)[0],
+ ["x0", "tol", "restart", "maxiter", "M", "callback",
+ "restrt", "atol"]),
+ "LGMRES" : (lambda A, b, kwargs: scspla.lgmres(A, b, **kwargs)[0],
+ ["x0", "tol", "maxiter", "M", "callback", "inner_m",
+ "outer_k", "outer_v", "store_outer_Av",
+ "prepend_outer_v", "atol"]),
+ "MINRES" : (lambda A, b, kwargs: scspla.minres(A, b, **kwargs)[0],
+ ["x0", "shift", "tol", "maxiter", "M", "callback",
+ "show", "check"]),
+ "QMR" : (lambda A, b, kwargs: scspla.qmr(A, b, **kwargs)[0],
+ ["x0", "tol", "maxiter", "M1", "M2", "callback",
+ "atol"]),
+ "GCROTMK" : (lambda A, b, kwargs: scspla.gcrotmk(A, b, **kwargs)[0],
+ ["x0", "tol", "maxiter", "M", "callback", "m", "k", "CU",
+ "discard_C", "truncate", "atol"])
+ }
+
+def setupSolver(solverType:str, solverArgs : DictAny = {})\
+ -> Tuple[callable, DictAny]:
+ solverType = solverType.upper()
+ if solverType not in RROMPyLinearSolvers.keys():
+ raise RROMPyException(("Solver type not recognized. Check allowed "
+ "values in RROMPyLinearSolvers.keys()."))
+ solver, solverArgsList = RROMPyLinearSolvers[solverType]
+ solverArgs = purgeDict(solverArgs, solverArgsList,
+ dictname = "{}.kwargs".format(solverType),
+ baselevel = 1)
+ return solver, solverArgs