Page MenuHomec4science

marginalize_poly_list.py
No OneTemporary

File Metadata

Created
Thu, May 2, 03:25

marginalize_poly_list.py

# Copyright (C) 2018-2020 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
import numpy as np
from scipy.sparse import csr
from rrompy.utilities.base.types import Np1D, Np2D, ListAny
from rrompy.utilities.base import freepar as fp
from .hash_derivative import (hashDerivativeToIdx as hashD,
hashIdxToDerivative as hashI)
from rrompy.parameter import checkParameter
__all__ = ['marginalizePolyList']
def _autoZeroObj(sample):
    """Return a zero term matching ``sample``: a zero ndarray like it, an
    empty CSR matrix of the same shape/dtype, or the scalar ``0.``."""
    if isinstance(sample, np.ndarray):
        return np.zeros_like(sample)
    if isinstance(sample, csr.csr_matrix):
        # csr_matrix((M, N)) builds an empty (all-zero) matrix
        return csr.csr_matrix(sample.shape, dtype = sample.dtype)
    return 0.

def _isZeroCoefficient(coeff, zeroObj) -> bool:
    """Tell whether ``coeff`` is a (numerically) zero term that may be trimmed.

    Only dense ndarrays and CSR matrices can be reported as zero; any other
    object yields False. May raise if ``coeff`` cannot be compared with
    ``zeroObj`` (e.g. incompatible shapes); the caller treats that as "keep".
    """
    if isinstance(coeff, np.ndarray):
        if np.issubdtype(coeff.dtype, np.inexact):
            atol = 2 * np.finfo(coeff.dtype).eps
        else:
            # np.finfo rejects integer/bool dtypes: require exact equality
            atol = 0.
        return bool(np.allclose(coeff, zeroObj, atol = atol))
    if isinstance(coeff, csr.csr_matrix):
        # nnz also counts explicitly stored zeros (e.g. left behind when a
        # term is multiplied by a fixed value equal to 0), so count the
        # actual nonzero entries instead
        return coeff.count_nonzero() == 0
    return False

def marginalizePolyList(objs:ListAny, marginalVals : Np1D = (fp,),
                        zeroObj : Np2D = 0.,
                        recompress : bool = True) -> ListAny:
    """Marginalize out variables in a list of polynomial coefficients.

    Entry ``objs[j]`` is associated with the multi-index
    ``hashIdxToDerivative(j, len(marginalVals))`` (presumably the exponents
    of the corresponding monomial -- confirm against hash_derivative).
    For each variable ``i``, if ``marginalVals[i] == fp`` the variable stays
    free; otherwise it is fixed to ``marginalVals[i]``. Each coefficient is
    multiplied by the product of the fixed values raised to its exponents in
    the fixed variables, and accumulated into position
    ``hashDerivativeToIdx(exponents in the free variables)`` of the result.

    Args:
        objs: coefficients, indexed by hashed multi-index.
        marginalVals: one entry per variable; ``fp`` keeps the variable free,
            any other value fixes it. A value without ``__len__`` is treated
            as a one-element list.
        zeroObj: filler for output positions that receive no contribution,
            and reference value for the trimming of dense arrays. If
            ``"auto"``, a zero matching ``objs[0]`` is built (zero ndarray,
            empty CSR matrix, or ``0.``; ``0.`` also when ``objs`` is empty).
        recompress: if True, drop trailing entries that are dense arrays
            close to ``zeroObj`` (atol 2*eps for floating dtypes, exact
            otherwise) or CSR matrices without nonzero entries.

    Returns:
        The list of marginalized coefficients.
    """
    if not hasattr(marginalVals, "__len__"): marginalVals = [marginalVals]
    nVars = len(marginalVals)
    freeLocations, fixedLocations, muFixed = [], [], []
    for i, m in enumerate(marginalVals):
        if m == fp:
            freeLocations.append(i)
        else:
            fixedLocations.append(i)
            muFixed.append(m)
    muFixed = checkParameter(muFixed, len(fixedLocations), return_data = True)
    # zeroObj may legitimately be an array, whose == against a string is
    # elementwise (and ambiguous in an if): only compare actual strings
    if isinstance(zeroObj, str) and zeroObj == "auto":
        zeroObj = _autoZeroObj(objs[0]) if len(objs) > 0 else 0.
    res = []
    for j, obj in enumerate(objs):
        derjBase = hashI(j, nVars)
        jNew = hashD([derjBase[i] for i in freeLocations])
        derjFixed = [derjBase[i] for i in fixedLocations]
        obj = np.prod(muFixed ** derjFixed) * obj
        if jNew < len(res):
            # out-of-place sum: res entries may alias the shared zeroObj
            res[jNew] = res[jNew] + obj
        else:
            # pad the skipped positions with the zero term
            res.extend([zeroObj] * (jNew - len(res)))
            res.append(obj)
    if recompress:
        while res:
            try:
                if not _isZeroCoefficient(res[-1], zeroObj): break
            except Exception:
                # best effort: stop trimming at anything not comparable
                break
            res.pop()
    return res

Event Timeline