Page Menu Home c4science

sampling_engine_standard.py
No One Temporary

File Metadata

Created
Wed, May 1, 04:41

sampling_engine_standard.py

# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
from copy import deepcopy as copy
import numpy as np
from rrompy.sampling.base.sampling_engine_base import SamplingEngineBase
from rrompy.utilities.base.types import Np1D, paramVal, paramList, sampList
from rrompy.utilities.base import verbosityManager as vbMng
from rrompy.utilities.exception_manager import RROMPyException
from rrompy.utilities.poly_fitting.polynomial import nextDerivativeIndices
from rrompy.parameter import checkParameter, checkParameterList
from rrompy.sampling import sampleList
__all__ = ['SamplingEngineStandard']

class SamplingEngineStandard(SamplingEngineBase):
    """
    Standard sampling engine: snapshots are stored as returned by the solver,
        the only post-processing being a deep copy.

    A parameter value that is sampled again is not re-solved from scratch:
        the repeated sample is the next term of the derivative sequence at
        that point (see _getSampleConcurrence).
    """

    def preprocesssamples(self, idxs:Np1D) -> sampList:
        """
        Return the stored snapshots with indices idxs.

        Args:
            idxs: Indices of the snapshots to retrieve.

        Returns:
            The selected snapshots, or None if nothing has been stored yet.
        """
        if self.samples is None or len(self.samples) == 0: return
        return self.samples(idxs)

    def postprocessu(self, u:sampList, overwrite : bool = False) -> Np1D:
        """
        Post-process a single snapshot; here just a deep copy of u.

        Args:
            u: Snapshot.
            overwrite: Unused here; nextSample always passes it.

        Returns:
            Deep copy of u.
        """
        return copy(u)

    def postprocessuBulk(self, u:sampList) -> sampList:
        """
        Post-process a batch of snapshots; here just a deep copy of u.

        Args:
            u: Snapshots.

        Returns:
            Deep copy of u.
        """
        return copy(u)

    def lastSampleManagement(self):
        """
        Hook invoked by nextSample (when lastSample is True); nothing to do in
            the standard engine.
        """
        pass

    def _getSampleConcurrence(self, mu:paramVal, previous:Np1D,
                              homogeneized : bool = False) -> sampList:
        """
        Compute the next derivative-type snapshot at mu, given the snapshots
            already stored at mu.

        The multi-index d of the new term is self._derIdxs[len(previous)]
            (the list is extended on demand). The system solved is
                A u_d = b_d - sum_p A_{d - p} u_p,
            where p runs over the multi-indices of the previous snapshots that
            are componentwise <= d. No binomial factors appear, so presumably
            A(mu, k) / b(mu, k) return Taylor coefficients of the operator and
            right-hand side; confirm against the HFEngine in use.

        Args:
            mu: Parameter value.
            previous: Sorted indices of the snapshots already computed at mu,
                listed in the same order as self._derIdxs.
            homogeneized: Whether to use the homogeneized system.

        Returns:
            Solution of the linear system above.
        """
        nPrev = len(previous)
        if nPrev >= len(self._derIdxs):
            self._derIdxs += nextDerivativeIndices(self._derIdxs,
                                                   self.HFEngine.npar,
                                                   nPrev + 1
                                                 - len(self._derIdxs))
        derIdx = self._derIdxs[nPrev]
        mu = checkParameter(mu, self.HFEngine.npar)
        samplesOld = self.preprocesssamples(previous)
        RHS = self.HFEngine.b(mu, derIdx, homogeneized = homogeneized)
        for j, derP in enumerate(self._derIdxs[: nPrev]):
            diffP = [x - y for (x, y) in zip(derIdx, derP)]
            if all(x >= 0 for x in diffP):
                # Out-of-place on purpose: "-=" would modify whatever array
                # b() handed back (possibly owned by the engine) and would
                # raise when a real RHS meets a complex correction term.
                RHS = RHS - self.HFEngine.A(mu, diffP).dot(samplesOld[j])
        return self.solveLS(mu, RHS = RHS, homogeneized = homogeneized)

    def nextSample(self, mu : paramVal = [], overwrite : bool = False,
                   homogeneized : bool = False,
                   lastSample : bool = True) -> Np1D:
        """
        Compute one snapshot at mu and store it.

        If mu already appears among the first self.nsamples stored
            parameters, the snapshot is the next derivative-type term at mu
            (see _getSampleConcurrence); otherwise the full system is solved.

        Args:
            mu: Parameter value.
            overwrite: If True, write into slot self.nsamples of the
                (preallocated) storage instead of appending.
            homogeneized: Whether to use the homogeneized system.
            lastSample: Whether to call lastSampleManagement afterwards.

        Returns:
            The post-processed snapshot.
        """
        mu = checkParameter(mu, self.HFEngine.npar)
        ns = self.nsamples
        # Only slots [0, ns) hold computed snapshots. After preallocation the
        # later slots of self.mus are not yet filled by this engine, so they
        # must not be counted as earlier samples at mu.
        muidxs = np.array([j for j in self.mus.findall(mu[0]) if j < ns],
                          dtype = int)
        if len(muidxs) > 0:
            u = self._getSampleConcurrence(mu, np.sort(muidxs), homogeneized)
        else:
            u = self.solveLS(mu, homogeneized = homogeneized)
        u = self.postprocessu(u, overwrite = overwrite)
        if overwrite:
            self.samples[ns] = u
            self.mus[ns] = mu[0]
        else:
            if ns == 0:
                self.samples = sampleList(u)
            else:
                self.samples.append(u)
            self.mus.append(mu)
        self.nsamples += 1
        if lastSample: self.lastSampleManagement()
        return u

    def iterSample(self, mus:paramList,
                   homogeneized : bool = False) -> sampList:
        """
        Reset the history and compute snapshots at all parameters in mus.

        If self.allowRepeatedSamples is set, samples are computed one at a
            time (so repeated values yield derivative-type terms). Otherwise a
            single bulk solve is performed.

        Args:
            mus: Parameter values.
            homogeneized: Whether to use the homogeneized system.

        Returns:
            The stored snapshots.

        Raises:
            RROMPyException: If mus is empty.
        """
        mus = checkParameterList(mus, self.HFEngine.npar)[0]
        n = len(mus)
        # Validate before opening the "INIT" verbosity section, so that an
        # early exit does not leave it without its matching "DEL".
        if n <= 0:
            raise RROMPyException("Number of samples must be positive.")
        vbMng(self, "INIT", "Starting sampling iterations.", 5)
        self.resetHistory()
        if self.allowRepeatedSamples:
            for j in range(n):
                vbMng(self, "MAIN",
                      "Computing sample {} / {}.".format(j + 1, n), 7)
                self.nextSample(mus[j], overwrite = (j > 0),
                                homogeneized = homogeneized,
                                lastSample = (n == j + 1))
                if n > 1 and j == 0:
                    # Storage for all n samples is reserved once the first
                    # snapshot is known; later samples overwrite in place.
                    self.preallocateSamples(self.samples[0], mus[0], n)
        else:
            self.samples = self.postprocessuBulk(self.solveLS(mus,
                                                homogeneized = homogeneized))
            self.mus = copy(mus)
            self.nsamples = n
        vbMng(self, "DEL", "Finished sampling iterations.", 5)
        return self.samples

Event Timeline