diff --git a/rrompy/hfengines/base/problem_engine_base.py b/rrompy/hfengines/base/problem_engine_base.py index 3c13438..5144951 100644 --- a/rrompy/hfengines/base/problem_engine_base.py +++ b/rrompy/hfengines/base/problem_engine_base.py @@ -1,362 +1,366 @@ # Copyright (C) 2018 by the RROMPy authors # # This file is part of RROMPy. # # RROMPy is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # RROMPy is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with RROMPy. If not, see . # from os import path, mkdir import fenics as fen import numpy as np from matplotlib import pyplot as plt from copy import deepcopy as copy from rrompy.utilities.base.types import (Np1D, strLst, FenFunc, Tuple, List, paramVal) from rrompy.utilities.base import (purgeList, getNewFilename, verbosityManager as vbMng) from rrompy.solver import Np2DLikeEye from rrompy.solver.fenics import L2NormMatrix, fenplot, interp_project from .boundary_conditions import BoundaryConditions from .matrix_engine_base import MatrixEngineBase from rrompy.utilities.exception_manager import RROMPyException __all__ = ['ProblemEngineBase'] class ProblemEngineBase(MatrixEngineBase): """ Generic solver for parametric problems. Attributes: verbosity: Verbosity level. BCManager: Boundary condition manager. V: Real FE space. u: Generic trial functions for variational form evaluation. v: Generic test functions for variational form evaluation. As: Scipy sparse array representation (in CSC format) of As. bs: Numpy array representation of bs. energyNormMatrix: Scipy sparse matrix representing inner product over V. energyNormDualMatrix: Scipy sparse matrix representing inner product over V'. dualityMatrix: Scipy sparse matrix representing duality V-V'. energyNormPartialDualMatrix: Scipy sparse matrix representing dual inner product between Riesz representers V-V. liftedDirichletDatum: Dofs of Dirichlet datum lifting. mu0BC: Mu value of last Dirichlet datum lifting. degree_threshold: Threshold for ufl expression interpolation degree. """ _dualityCompress = None def __init__(self, degree_threshold : int = np.inf, verbosity : int = 10, timestamp : bool = True): super().__init__(verbosity = verbosity, timestamp = timestamp) self.BCManager = BoundaryConditions("Dirichlet") self.V = fen.FunctionSpace(fen.UnitSquareMesh(10, 10), "P", 1) self.mu0BC = np.nan self.degree_threshold = degree_threshold self.npar = 0 @property def V(self): """Value of V.""" return self._V @V.setter def V(self, V): self.resetAs() self.resetbs() if not type(V).__name__ == 'FunctionSpace': raise RROMPyException("V type not recognized.") self._V = V self.u = fen.TrialFunction(V) self.v = fen.TestFunction(V) def spacedim(self): return self.V.dim() def buildEnergyNormForm(self): """ Build sparse matrix (in CSR format) representative of scalar product. """ vbMng(self, "INIT", "Assembling energy matrix.", 20) self.energyNormMatrix = L2NormMatrix(self.V) vbMng(self, "DEL", "Done assembling energy matrix.", 20) def buildDualityPairingForm(self): """Build sparse matrix (in CSR format) representative of duality.""" vbMng(self, "INIT", "Assembling duality matrix.", 20) self.dualityMatrix = Np2DLikeEye() vbMng(self, "DEL", "Done assembling duality matrix.", 20) def liftDirichletData(self, mu : paramVal = []) -> Np1D: """Lift Dirichlet datum.""" mu = self.checkParameter(mu) if mu != self.mu0BC: self.mu0BC = copy(mu) liftRe = interp_project(self.DirichletDatum[0], self.V) liftIm = interp_project(self.DirichletDatum[1], self.V) self.liftedDirichletDatum = (np.array(liftRe.vector()) + 1.j * np.array(liftIm.vector())) return self.liftedDirichletDatum def reduceQuadratureDegree(self, fun:FenFunc, name:str): """Check whether to reduce compiler parameters to degree threshold.""" if not np.isinf(self.degree_threshold): from ufl.algorithms.estimate_degrees import ( estimate_total_polynomial_degree as ETPD) try: deg = ETPD(fun) except: return False if deg > self.degree_threshold: vbMng(self, "MAIN", ("Reducing quadrature degree from {} to {} for " "{}.").format(deg, self.degree_threshold, name), 15) return True return False def iterReduceQuadratureDegree(self, funsNames:List[Tuple[FenFunc, str]]): """ Iterate reduceQuadratureDegree over list and define reduce compiler parameters. """ if funsNames is not None: for fun, name in funsNames: if self.reduceQuadratureDegree(fun, name): return {"quadrature_degree" : self.degree_threshold} return {} def plot(self, u:Np1D, warping : List[callable] = None, name : str = "u", save : str = None, what : strLst = 'all', saveFormat : str = "eps", saveDPI : int = 100, show : bool = True, fenplotArgs : dict = {}, **figspecs): """ Do some nice plots of the complex-valued function with given dofs. Args: u: numpy complex array with function dofs. warping(optional): Domain warping functions. name(optional): Name to be shown as title of the plots. Defaults to 'u'. save(optional): Where to save plot(s). Defaults to None, i.e. no saving. what(optional): Which plots to do. If list, can contain 'ABS', 'PHASE', 'REAL', 'IMAG'. If str, same plus wildcard 'ALL'. Defaults to 'ALL'. saveFormat(optional): Format for saved plot(s). Defaults to "eps". saveDPI(optional): DPI for saved plot(s). Defaults to 100. show(optional): Whether to show figure. Defaults to True. fenplotArgs(optional): Optional arguments for fenplot. figspecs(optional key args): Optional arguments for matplotlib figure creation. """ if isinstance(what, (str,)): if what.upper() == 'ALL': what = ['ABS', 'PHASE', 'REAL', 'IMAG'] else: what = [what] what = purgeList(what, ['ABS', 'PHASE', 'REAL', 'IMAG'], listname = self.name() + ".what", baselevel = 1) if len(what) == 0: return if 'figsize' not in figspecs.keys(): figspecs['figsize'] = (13. * len(what) / 4, 3) subplotcode = 100 + len(what) * 10 plt.figure(**figspecs) plt.jet() if 'ABS' in what: uAb = fen.Function(self.V) uAb.vector().set_local(np.abs(u)) subplotcode = subplotcode + 1 plt.subplot(subplotcode) p = fenplot(uAb, warping = warping, title = "|{0}|".format(name), **fenplotArgs) - plt.colorbar(p) + if self.V.mesh().geometric_dimension() > 1: + plt.colorbar(p) if 'PHASE' in what: uPh = fen.Function(self.V) uPh.vector().set_local(np.angle(u)) subplotcode = subplotcode + 1 plt.subplot(subplotcode) p = fenplot(uPh, warping = warping, title = "phase({0})".format(name), **fenplotArgs) - plt.colorbar(p) + if self.V.mesh().geometric_dimension() > 1: + plt.colorbar(p) if 'REAL' in what: uRe = fen.Function(self.V) uRe.vector().set_local(np.real(u)) subplotcode = subplotcode + 1 plt.subplot(subplotcode) p = fenplot(uRe, warping = warping, title = "Re({0})".format(name), **fenplotArgs) - plt.colorbar(p) + if self.V.mesh().geometric_dimension() > 1: + plt.colorbar(p) if 'IMAG' in what: uIm = fen.Function(self.V) uIm.vector().set_local(np.imag(u)) subplotcode = subplotcode + 1 plt.subplot(subplotcode) p = fenplot(uIm, warping = warping, title = "Im({0})".format(name), **fenplotArgs) - plt.colorbar(p) + if self.V.mesh().geometric_dimension() > 1: + plt.colorbar(p) if save is not None: save = save.strip() plt.savefig(getNewFilename("{}_fig_".format(save), saveFormat), format = saveFormat, dpi = saveDPI) if show: plt.show() plt.close() def plotmesh(self, warping : List[callable] = None, name : str = "Mesh", save : str = None, saveFormat : str = "eps", saveDPI : int = 100, show : bool = True, fenplotArgs : dict = {}, **figspecs): """ Do a nice plot of the mesh. Args: u: numpy complex array with function dofs. warping(optional): Domain warping functions. name(optional): Name to be shown as title of the plots. Defaults to 'u'. save(optional): Where to save plot(s). Defaults to None, i.e. no saving. saveFormat(optional): Format for saved plot(s). Defaults to "eps". saveDPI(optional): DPI for saved plot(s). Defaults to 100. show(optional): Whether to show figure. Defaults to True. fenplotArgs(optional): Optional arguments for fenplot. figspecs(optional key args): Optional arguments for matplotlib figure creation. """ plt.figure(**figspecs) fenplot(self.V.mesh(), warping = warping, **fenplotArgs) if save is not None: save = save.strip() plt.savefig(getNewFilename("{}_msh_".format(save), saveFormat), format = saveFormat, dpi = saveDPI) if show: plt.show() plt.close() def outParaview(self, u:Np1D, warping : List[callable] = None, name : str = "u", filename : str = "out", time : float = 0., what : strLst = 'all', forceNewFile : bool = True, folder : bool = False, filePW = None): """ Output complex-valued function with given dofs to ParaView file. Args: u: numpy complex array with function dofs. warping(optional): Domain warping functions. name(optional): Base name to be used for data output. filename(optional): Name of output file. time(optional): Timestamp. what(optional): Which plots to do. If list, can contain 'MESH', 'ABS', 'PHASE', 'REAL', 'IMAG'. If str, same plus wildcard 'ALL'. Defaults to 'ALL'. forceNewFile(optional): Whether to create new output file. folder(optional): Whether to create an additional folder layer. filePW(optional): Fenics File entity (for time series). """ if isinstance(what, (str,)): if what.upper() == 'ALL': what = ['MESH', 'ABS', 'PHASE', 'REAL', 'IMAG'] else: what = [what] what = purgeList(what, ['MESH', 'ABS', 'PHASE', 'REAL', 'IMAG'], listname = self.name() + ".what", baselevel = 1) if len(what) == 0: return if filePW is None: if folder: if not path.exists(filename + "/"): mkdir(filename) idxpath = filename.rfind("/") filename += "/" + filename[idxpath + 1 :] if forceNewFile: filePW = fen.File(getNewFilename(filename, "pvd")) else: filePW = fen.File("{}.pvd".format(filename)) if warping is not None: fen.ALE.move(self.V.mesh(), interp_project(warping[0], self.V.mesh())) if what == ['MESH']: filePW << (self.V.mesh(), time) if 'ABS' in what: uAb = fen.Function(self.V, name = "{}_ABS".format(name)) uAb.vector().set_local(np.abs(u)) filePW << (uAb, time) if 'PHASE' in what: uPh = fen.Function(self.V, name = "{}_PHASE".format(name)) uPh.vector().set_local(np.angle(u)) filePW << (uPh, time) if 'REAL' in what: uRe = fen.Function(self.V, name = "{}_REAL".format(name)) uRe.vector().set_local(np.real(u)) filePW << (uRe, time) if 'IMAG' in what: uIm = fen.Function(self.V, name = "{}_IMAG".format(name)) uIm.vector().set_local(np.imag(u)) filePW << (uIm, time) if warping is not None: fen.ALE.move(self.V.mesh(), interp_project(warping[1], self.V.mesh())) return filePW def outParaviewTimeDomain(self, u:Np1D, omega:float, warping : List[callable] = None, timeFinal : float = None, periodResolution : int = 20, name : str = "u", filename : str = "out", forceNewFile : bool = True, folder : bool = False): """ Output complex-valued function with given dofs to ParaView file, converted to time domain. Args: u: numpy complex array with function dofs. omega: frequency. warping(optional): Domain warping functions. timeFinal(optional): final time of simulation. periodResolution(optional): number of time steps per period. name(optional): Base name to be used for data output. filename(optional): Name of output file. forceNewFile(optional): Whether to create new output file. folder(optional): Whether to create an additional folder layer. """ if folder: if not path.exists(filename + "/"): mkdir(filename) idxpath = filename.rfind("/") filename += "/" + filename[idxpath + 1 :] if forceNewFile: filePW = fen.File(getNewFilename(filename, "pvd")) else: filePW = fen.File("{}.pvd".format(filename)) omega = np.abs(omega) t = 0. dt = 2. * np.pi / omega / periodResolution if timeFinal is None: timeFinal = 2. * np.pi / omega - dt if warping is not None: fen.ALE.move(self.V.mesh(), interp_project(warping[0], self.V.mesh())) for j in range(int(np.ceil(timeFinal / dt)) + 1): ut = fen.Function(self.V, name = name) ut.vector().set_local(np.real(u) * np.cos(omega * t) + np.imag(u) * np.sin(omega * t)) filePW << (ut, t) t += dt if warping is not None: fen.ALE.move(self.V.mesh(), interp_project(warping[1], self.V.mesh())) return filePW diff --git a/rrompy/hfengines/base/vector_problem_engine_base.py b/rrompy/hfengines/base/vector_problem_engine_base.py index 4f6f7fc..3cdbbbc 100644 --- a/rrompy/hfengines/base/vector_problem_engine_base.py +++ b/rrompy/hfengines/base/vector_problem_engine_base.py @@ -1,218 +1,224 @@ # Copyright (C) 2018 by the RROMPy authors # # This file is part of RROMPy. # # RROMPy is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # RROMPy is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with RROMPy. If not, see . # import fenics as fen import numpy as np from matplotlib import pyplot as plt from rrompy.utilities.base.types import Np1D, List, strLst from rrompy.utilities.base import purgeList, getNewFilename from rrompy.solver.fenics import fenplot from .problem_engine_base import ProblemEngineBase __all__ = ['VectorProblemEngineBase'] class VectorProblemEngineBase(ProblemEngineBase): """ Generic solver for parametric vector problems. Attributes: verbosity: Verbosity level. BCManager: Boundary condition manager. V: Real FE space. u: Generic trial functions for variational form evaluation. v: Generic test functions for variational form evaluation. As: Scipy sparse array representation (in CSC format) of As. bs: Numpy array representation of bs. energyNormMatrix: Scipy sparse matrix representing inner product over V. energyNormDualMatrix: Scipy sparse matrix representing inner product over V'. dualityMatrix: Scipy sparse matrix representing duality V-V'. energyNormPartialDualMatrix: Scipy sparse matrix representing dual inner product between Riesz representers V-V. liftedDirichletDatum: Dofs of Dirichlet datum lifting. mu0BC: Mu value of last Dirichlet datum lifting. degree_threshold: Threshold for ufl expression interpolation degree. """ def __init__(self, degree_threshold : int = np.inf, verbosity : int = 10, timestamp : bool = True): super().__init__(degree_threshold = degree_threshold, verbosity = verbosity, timestamp = timestamp) self.V = fen.VectorFunctionSpace(fen.UnitSquareMesh(10, 10), "P", 1) self.npar = 0 def plot(self, u:Np1D, warping : List[callable] = None, name : str = "u", save : str = None, what : strLst = 'all', saveFormat : str = "eps", saveDPI : int = 100, show : bool = True, fenplotArgs : dict = {}, **figspecs): """ Do some nice plots of the complex-valued function with given dofs. Args: u: numpy complex array with function dofs. warping(optional): Domain warping functions. name(optional): Name to be shown as title of the plots. Defaults to 'u'. save(optional): Where to save plot(s). Defaults to None, i.e. no saving. what(optional): Which plots to do. If list, can contain 'ABS', 'PHASE', 'REAL', 'IMAG'. If str, same plus wildcard 'ALL'. Defaults to 'ALL'. saveFormat(optional): Format for saved plot(s). Defaults to "eps". saveDPI(optional): DPI for saved plot(s). Defaults to. show(optional): Whether to show figure. Defaults to True. fenplotArgs(optional): Optional arguments for fenplot. figspecs(optional key args): Optional arguments for matplotlib figure creation. """ if isinstance(what, (str,)): if what.upper() == 'ALL': what = ['ABS', 'PHASE', 'REAL', 'IMAG'] else: what = [what] what = purgeList(what, ['ABS', 'PHASE', 'REAL', 'IMAG'], listname = self.name() + ".what", baselevel = 1) if 'figsize' not in figspecs.keys(): figspecs['figsize'] = (13. * max(len(what), 1) / 4, 3) if len(what) > 0: for j in range(self.V.num_sub_spaces()): subplotcode = 100 + len(what) * 10 II = self.V.sub(j).dofmap().dofs() Vj = self.V.sub(j).collapse() plt.figure(**figspecs) plt.jet() if 'ABS' in what: uAb = fen.Function(Vj) uAb.vector().set_local(np.abs(u[II])) subplotcode = subplotcode + 1 plt.subplot(subplotcode) p = fenplot(uAb, warping = warping, title = "|{}_comp{}|".format(name, j, **fenplotArgs)) - plt.colorbar(p) + if self.V.mesh().geometric_dimension() > 1: + plt.colorbar(p) if 'PHASE' in what: uPh = fen.Function(Vj) uPh.vector().set_local(np.angle(u[II])) subplotcode = subplotcode + 1 plt.subplot(subplotcode) p = fenplot(uPh, warping = warping, title = "phase({}_comp{})".format(name, j), **fenplotArgs) - plt.colorbar(p) + if self.V.mesh().geometric_dimension() > 1: + plt.colorbar(p) if 'REAL' in what: uRe = fen.Function(Vj) uRe.vector().set_local(np.real(u[II])) subplotcode = subplotcode + 1 plt.subplot(subplotcode) p = fenplot(uRe, warping = warping, title = "Re({}_comp{})".format(name, j), **fenplotArgs) - plt.colorbar(p) + if self.V.mesh().geometric_dimension() > 1: + plt.colorbar(p) if 'IMAG' in what: uIm = fen.Function(Vj) uIm.vector().set_local(np.imag(u[II])) subplotcode = subplotcode + 1 plt.subplot(subplotcode) p = fenplot(uIm, warping = warping, title = "Im({}_comp{})".format(name, j), **fenplotArgs) - plt.colorbar(p) + if self.V.mesh().geometric_dimension() > 1: + plt.colorbar(p) if save is not None: save = save.strip() plt.savefig(getNewFilename("{}_comp{}_fig_".format(save,j), saveFormat), format = saveFormat, dpi = saveDPI) if show: plt.show() plt.close() try: if len(what) > 1: figspecs['figsize'] = (2. / len(what) * figspecs['figsize'][0], figspecs['figsize'][1]) elif len(what) == 0: figspecs['figsize'] = (2. * figspecs['figsize'][0], figspecs['figsize'][1]) if len(what) == 0 or 'ABS' in what or 'REAL' in what: uVRe = fen.Function(self.V) uVRe.vector().set_local(np.real(u)) plt.figure(**figspecs) plt.jet() p = fenplot(uVRe, warping = warping, title = "{}_Re".format(name), mode = "displacement", **fenplotArgs) - plt.colorbar(p) + if self.V.mesh().geometric_dimension() > 1: + plt.colorbar(p) if save is not None: save = save.strip() plt.savefig(getNewFilename("{}_disp_Re_fig_".format(save), saveFormat), format = saveFormat, dpi = saveDPI) plt.show() plt.close() if 'ABS' in what or 'IMAG' in what: uVIm = fen.Function(self.V) uVIm.vector().set_local(np.imag(u)) plt.figure(**figspecs) plt.jet() p = fenplot(uVIm, warping = warping, title = "{}_Im".format(name), mode = "displacement", **fenplotArgs) - plt.colorbar(p) + if self.V.mesh().geometric_dimension() > 1: + plt.colorbar(p) if save is not None: save = save.strip() plt.savefig(getNewFilename("{}_disp_Im_fig_".format(save, j), saveFormat), format = saveFormat, dpi = saveDPI) if show: plt.show() plt.close() except: pass def plotmesh(self, warping : List[callable] = None, name : str = "Mesh", save : str = None, saveFormat : str = "eps", saveDPI : int = 100, show : bool = True, fenplotArgs : dict = {}, **figspecs): """ Do a nice plot of the mesh. Args: u: numpy complex array with function dofs. warping(optional): Domain warping functions. name(optional): Name to be shown as title of the plots. Defaults to 'u'. save(optional): Where to save plot(s). Defaults to None, i.e. no saving. saveFormat(optional): Format for saved plot(s). Defaults to "eps". saveDPI(optional): DPI for saved plot(s). Defaults to 100. show(optional): Whether to show figure. Defaults to True. fenplotArgs(optional): Optional arguments for fenplot. figspecs(optional key args): Optional arguments for matplotlib figure creation. """ plt.figure(**figspecs) fenplot(self.V.mesh(), warping = warping, **fenplotArgs) if save is not None: save = save.strip() plt.savefig(getNewFilename("{}_msh_".format(save), saveFormat), format = saveFormat, dpi = saveDPI) if show: plt.show() plt.close() diff --git a/rrompy/sampling/base/pod_engine.py b/rrompy/sampling/base/pod_engine.py index 82a8699..316a70e 100644 --- a/rrompy/sampling/base/pod_engine.py +++ b/rrompy/sampling/base/pod_engine.py @@ -1,124 +1,133 @@ # Copyright (C) 2018 by the RROMPy authors # # This file is part of RROMPy. # # RROMPy is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # RROMPy is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with RROMPy. If not, see . # import numpy as np from copy import deepcopy as copy from rrompy.utilities.base.types import Np1D, Np2D, Tuple, HFEng, sampList from rrompy.sampling import sampleList __all__ = ['PODEngine'] class PODEngine: """ POD engine for general matrix orthogonalization. """ def __init__(self, HFEngine:HFEng): self.HFEngine = HFEngine def name(self) -> str: return self.__class__.__name__ def __str__(self) -> str: return self.name() def __repr__(self) -> str: return self.__str__() + " at " + hex(id(self)) def GS(self, a:Np1D, Q:sampList, n : int = -1) -> Tuple[Np1D, Np1D, bool]: """ Compute 1 Gram-Schmidt step with given projector. Args: a: vector to be projected; Q: orthogonal projection matrix; n: number of columns of Q to be considered; Returns: Resulting normalized vector, coefficients of a wrt the updated basis, whether computation is ill-conditioned. """ if n == -1: n = Q.shape[1] r = np.zeros((n + 1,), dtype = Q.dtype) if n > 0: Q = Q[: n] for j in range(2): # twice is enough! nu = self.HFEngine.innerProduct(a, Q) a = a - Q.dot(nu) r[:-1] = r[:-1] + nu.flatten() r[-1] = self.HFEngine.norm(a) ill_cond = False if np.isclose(np.abs(r[-1]) / np.linalg.norm(r), 0.): ill_cond = True r[-1] = 1. a = a / r[-1] return a, r, ill_cond def generalizedQR(self, A:sampList, Q0 : sampList = None, - only_R : bool = False) -> Tuple[sampList, Np2D]: + only_R : bool = False, + genTrials : int = 10) -> Tuple[sampList, Np2D]: """ Compute generalized QR decomposition of a matrix through Householder method. Args: A: matrix to be decomposed; Q0(optional): initial orthogonal guess for Q; defaults to random; only_R(optional): whether to skip reconstruction of Q; defaults to False. + genTrials(optional): number of trials of generation of linearly + independent vector; defaults to 10. Returns: Resulting (orthogonal and )upper-triangular factor(s). """ Nh, N = A.shape B = copy(A) V = copy(A) R = np.zeros((N, N), dtype = A.dtype) if Q0 is None: Q = sampleList(np.zeros(A.shape, dtype = A.dtype) + np.random.randn(*(A.shape))) else: Q = copy(Q0) for k in range(N): - if k <= Nh: - if Q0 is None: - illC = True - while illC: - Q[k], _, illC = self.GS(np.random.randn(Nh), Q, k) - else: - Q[k] = np.zeros(Nh, dtype = Q.dtype) a = B[k] R[k, k] = self.HFEngine.norm(a) - alpha = self.HFEngine.innerProduct(a, Q[k]) - if np.isclose(np.abs(alpha), 0.): s = 1. - else: s = - alpha / np.abs(alpha) - Q[k] = s * Q[k] + if Q0 is None: + for _ in range(genTrials): + Q[k], _, illC = self.GS(np.random.randn(Nh), Q, k) + if not illC: break + else: + illC = False + if illC: + Q[k] = np.zeros(Nh, dtype = Q.dtype) + alpha = 0. + else: + alpha = self.HFEngine.innerProduct(a, Q[k]) + if np.isclose(np.abs(alpha), 0.): s = 1. + else: s = - alpha / np.abs(alpha) + Q[k] = s * Q[k] V[k], _, _ = self.GS(R[k, k] * Q[k] - a, Q, k) J = np.arange(k + 1, N) vtB = self.HFEngine.innerProduct(B[J], V[k]) B.data[:, J] -= 2 * np.outer(V[k], vtB) - R[k, J] = self.HFEngine.innerProduct(B[J], Q[k]) - B.data[:, J] -= np.outer(Q[k], R[k, J]) + if illC: + R[k, J] = 0. + else: + R[k, J] = self.HFEngine.innerProduct(B[J], Q[k]) + B.data[:, J] -= np.outer(Q[k], R[k, J]) if only_R: return R for k in range(N - 1, -1, -1): J = list(range(k, N)) vtQ = self.HFEngine.innerProduct(Q[J], V[k]) Q.data[:, J] -= 2 * np.outer(V[k], vtQ) return Q, R