Page MenuHomec4science

trained_model.py
No OneTemporary

File Metadata

Created
Wed, May 1, 13:49

trained_model.py

# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
from abc import abstractmethod
from rrompy.utilities.base.types import Np1D, paramList, sampList
from rrompy.parameter import checkParameterList
from rrompy.sampling import sampleList, emptySampleList
__all__ = ['TrainedModel']
class TrainedModel:
    """
    ABSTRACT
    ROM approximant evaluation.

    Attributes:
        data: Container with all the model data that can be pickled. The
            methods of this class read its `npar` and `projMat` attributes.
        uApprox: Approximant samples computed by the latest getApprox call
            that had to recompute; only set once getApprox has run.
        lastSolvedApprox: Parameter list uApprox was computed for; only set
            once getApprox has run.

    NOTE(review): the abstract methods are tagged with @abstractmethod, but no
        ABC metaclass is visible on this class, so the tag by itself does not
        block instantiation -- confirm whether that is intended.
    """

    def name(self) -> str:
        """Return the name of the concrete class of this instance."""
        return self.__class__.__name__

    def __str__(self) -> str:
        """Return the class name (see name)."""
        return self.name()

    def __repr__(self) -> str:
        """Return the string form followed by the hexadecimal object id."""
        return str(self) + " at " + hex(id(self))

    @abstractmethod
    def getApproxReduced(self, mu : paramList = None) -> sampList:
        """
        Evaluate reduced representation of approximant at arbitrary parameter.
        (ABSTRACT)

        Args:
            mu: Target parameter. None stands for an empty parameter list
                (it replaces the former mutable `[]` default).
        """
        pass

    def getApprox(self, mu : paramList = None) -> sampList:
        """
        Evaluate approximant at arbitrary parameter.

        The reduced representation returned by getApproxReduced is combined
        with `self.data.projMat`. The result is stored in `self.uApprox` and
        reused as long as the checked parameter list compares equal to the
        one stored in `self.lastSolvedApprox`.

        Args:
            mu: Target parameter. None stands for an empty parameter list
                (it replaces the former mutable `[]` default).

        Returns:
            The content of `self.uApprox`, holding one sample per entry of
            the checked parameter list.
        """
        # A None sentinel avoids sharing one list object across calls.
        if mu is None:
            mu = []
        mu = checkParameterList(mu, self.data.npar)[0]
        if (not hasattr(self, "lastSolvedApprox")
         or self.lastSolvedApprox != mu):
            uApproxR = self.getApproxReduced(mu)
            projMat = self.data.projMat
            self.uApprox = emptySampleList()
            self.uApprox.reset((projMat.shape[0], len(mu)), projMat.dtype)
            # The kind of projection container does not depend on the
            # parameter index, so it is tested once, outside the loop.
            projIsList = isinstance(projMat, (list, sampleList,))
            for i in range(len(mu)):
                if projIsList:
                    # Linear combination of the entries of projMat, weighted
                    # by the reduced coefficients of the i-th sample.
                    self.uApprox[i] = uApproxR[i][0] * projMat[0]
                    for j in range(1, uApproxR.shape[0]):
                        self.uApprox[i] += uApproxR[i][j] * projMat[j]
                else:
                    self.uApprox[i] = projMat.dot(uApproxR[i])
            # Remember the parameters so that a repeated request is served
            # from self.uApprox without recomputation.
            self.lastSolvedApprox = mu
        return self.uApprox

    @abstractmethod
    def getPoles(self) -> Np1D:
        """
        Obtain approximant poles.
        (ABSTRACT)

        Returns:
            Numpy complex vector of poles.
        """
        pass

Event Timeline