Page MenuHomec4science

sampling_engine_pivoted.py
No OneTemporary

File Metadata

Created
Wed, Jun 5, 03:17

sampling_engine_pivoted.py

# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
from copy import deepcopy as copy
import numpy as np
from rrompy.sampling.base.sampling_engine_base_pivoted import (
SamplingEngineBasePivoted)
from rrompy.hfengines.base import MarginalProxyEngine
from rrompy.utilities.base.types import Np1D, paramVal, paramList, sampList
from rrompy.utilities.base import verbosityManager as vbMng, freepar as fp
from rrompy.utilities.exception_manager import RROMPyException
from rrompy.utilities.numerical import nextDerivativeIndices, dot
from rrompy.parameter import checkParameter, checkParameterList
__all__ = ['SamplingEnginePivoted']
class SamplingEnginePivoted(SamplingEngineBasePivoted):
    """
    Sampling engine working on pivot parameters, one marginal point at a time.

    For each marginal parameter value, the high-fidelity engine is wrapped in
    a MarginalProxyEngine (stored in self.HFEngineMarginalized) with the
    marginal directions frozen, and samples are computed over the pivot
    parameters. Samples, pivot parameters and sample counts are stored per
    marginal index j in self.samples[j], self.mus[j] and self.nsamples[j].

    NOTE(review): storage attributes and helpers such as setsample, solveLS
        and preallocateSamples are not defined here; presumably they are
        provided by SamplingEngineBasePivoted -- confirm against the base
        class.
    """

    def preprocesssamples(self, idxs:Np1D, j:int) -> sampList:
        """
        Select stored samples of marginal j at the given indices.

        Args:
            idxs: Indices of the samples to extract.
            j: Index of the marginal point.

        Returns:
            The result of calling self.samples[j] on idxs, or None if no
            samples are stored for marginal j.
        """
        if self.samples[j] is None or len(self.samples[j]) == 0:
            return None
        return self.samples[j](idxs)

    def postprocessu(self, u:sampList, j:int, overwrite : bool = False):
        """
        Store a freshly computed sample for marginal j.

        Args:
            u: Sample to store.
            j: Index of the marginal point.
            overwrite: Forwarded unchanged to setsample.
        """
        self.setsample(u, j, overwrite)

    def postprocessuBulk(self, j:int):
        """
        Hook called by iterSample once all samples of marginal j are stored.

        Does nothing in this class.
        """
        pass

    def _getSampleConcurrence(self, mu:paramVal, j:int,
                              previous:Np1D) -> sampList:
        """
        Compute a sample at a pivot point that already has stored samples.

        The right-hand side is assembled as b(mu, derIdx) minus the action of
        A(mu, derIdx - derP) on each previous sample whose derivative index
        derP is componentwise not larger than derIdx; the resulting linear
        system is then solved through solveLS. Judging by the error messages
        below, the result is a derivative of the solution at mu.

        Args:
            mu: Pivot parameter value.
            j: Index of the marginal point.
            previous: Indices of the stored samples of marginal j that
                coincide with mu (nextSample passes them sorted).

        Returns:
            The sample returned by solveLS for the assembled right-hand side.

        Raises:
            RROMPyException: If neither self.sample_state nor
                self.HFEngine.isCEye holds, or if
                self.HFEngine._isStateShiftZero is false.
        """
        if not (self.sample_state or self.HFEngine.isCEye):
            raise RROMPyException(("Derivatives of solution with non-scalar "
                                   "C not computable."))
        if not self.HFEngine._isStateShiftZero:
            raise RROMPyException(("Derivatives of solution with non-zero "
                                   "solution shift not computable."))
        # Extend the cached derivative indices so that entry len(previous)
        # exists.
        if len(previous) >= len(self._derIdxs[j]):
            self._derIdxs[j] += nextDerivativeIndices(
                                    self._derIdxs[j], self.nPivot,
                                    len(previous) + 1 - len(self._derIdxs[j]))
        derIdx = self._derIdxs[j][len(previous)]
        mu = checkParameter(mu, self.nPivot)
        samplesOld = self.preprocesssamples(previous, j)
        RHS = self.HFEngineMarginalized.b(mu, derIdx)
        # Use a dedicated loop index: the marginal index j must not be
        # rebound while it is still the meaning of the parameter.
        for i, derP in enumerate(self._derIdxs[j][: len(previous)]):
            diffP = [x - y for (x, y) in zip(derIdx, derP)]
            # Only lower-order terms (componentwise) contribute.
            if all(x >= 0 for x in diffP):
                RHS -= dot(self.HFEngineMarginalized.A(mu, diffP),
                           samplesOld[i])
        return self.solveLS(mu, RHS = RHS)

    def nextSample(self, mu:paramVal, j:int, overwrite : bool = False,
                   postprocess : bool = True) -> Np1D:
        """
        Compute and store one new sample for marginal j.

        Args:
            mu: Pivot parameter value.
            j: Index of the marginal point.
            overwrite: If True, write the new parameter into position
                self.nsamples[j] of self.mus[j] instead of appending; the
                flag is also forwarded to the sample storage routine.
            postprocess: If True, store the sample through postprocessu;
                otherwise call setsample directly.

        Returns:
            The newly stored sample, read back from self.samples[j].
        """
        mu = checkParameter(mu, self.nPivot)
        # A pivot point already present among the stored ones is handled by
        # _getSampleConcurrence rather than by a plain solve.
        muidxs = self.mus[j].findall(mu[0])
        if len(muidxs) > 0:
            u = self._getSampleConcurrence(mu, j, np.sort(muidxs))
        else:
            u = self.solveLS(mu)
        if postprocess:
            self.postprocessu(u, j, overwrite = overwrite)
        else:
            self.setsample(u, j, overwrite)
        # NOTE(review): the overwrite branch stores mu[0] while the append
        # branch passes mu itself -- confirm both are accepted by the
        # parameter-list type.
        if overwrite:
            self.mus[j][self.nsamples[j]] = mu[0]
        else:
            self.mus[j].append(mu)
        self.nsamples[j] += 1
        return self.samples[j][self.nsamples[j] - 1]

    def iterSample(self, mus:paramList, musM:paramList) -> sampList:
        """
        Compute samples at all pivot points for every marginal point.

        Args:
            mus: Pivot parameter values.
            musM: Marginal parameter values.

        Returns:
            The samples stored for the last marginal point.
            NOTE(review): only self.samples[j] for the final j is returned,
            not the samples of all marginals -- confirm this is intended.

        Raises:
            RROMPyException: If mus or musM is empty.
        """
        mus = checkParameterList(mus, self.nPivot)[0]
        musM = checkParameterList(musM, self.nMarginal)[0]
        vbMng(self, "INIT", "Starting sampling iterations.", 5)
        n = len(mus)
        m = len(musM)
        if n <= 0:
            raise RROMPyException("Number of samples must be positive.")
        if m <= 0:
            raise RROMPyException(("Number of marginal samples must be "
                                   "positive."))
        # Repeated pivot points must be sampled one by one, so that each
        # repetition can build on the previously stored samples.
        repeatedSamples = len(mus.unique()) != n
        for j in range(m):
            # Freeze the marginal directions at the j-th marginal point and
            # leave the remaining (pivot) directions free.
            muMEff = [fp] * self.HFEngine.npar
            for k, x in enumerate(self.directionMarginal):
                muMEff[x] = musM(j, k)
            self.HFEngineMarginalized = MarginalProxyEngine(self.HFEngine,
                                                            list(muMEff))
            if repeatedSamples:
                for k in range(n):
                    # Both counters are reported 1-based.
                    vbMng(self, "MAIN",
                          ("Computing sample {} / {} for marginal "
                           "{} / {}.").format(k + 1, n, j + 1, m), 10)
                    self.nextSample(mus[k], j, overwrite = (k > 0),
                                    postprocess = False)
                    if n > 1 and k == 0:
                        # Allocate storage for all n samples once the first
                        # one is available.
                        self.preallocateSamples(self.samples[j][0], mus[0],
                                                n, j)
            else:
                # All pivot points are distinct: solve for them in one call.
                self.setsample(self.solveLS(mus), j, overwrite = False)
            self.mus[j] = copy(mus)
            self.nsamples[j] = n
            if len(self.musMarginal) > j:
                self.musMarginal[j] = copy(musM[j])
            else:
                self.musMarginal.append(musM[j])
            self.postprocessuBulk(j)
        vbMng(self, "DEL", "Finished sampling iterations.", 5)
        return self.samples[j]

Event Timeline