Page Menu · Home · c4science

nearest_neighbor_interpolator.py
No One · Temporary

File Metadata

Created
Sat, May 11, 12:50

nearest_neighbor_interpolator.py

# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
import numpy as np
from copy import deepcopy as copy
from rrompy.utilities.base.types import (List, ListAny, DictAny, Np1D, Np2D,
paramList)
from rrompy.utilities.poly_fitting.interpolator import GenericInterpolator
from .val import polyval
from rrompy.utilities.numerical import dot
from rrompy.utilities.exception_manager import RROMPyAssert
from rrompy.parameter import checkParameterList
# Public API of this module: only the interpolator class is exported.
__all__ = ["NearestNeighborInterpolator"]
class NearestNeighborInterpolator(GenericInterpolator):
    """
    Interpolator evaluated by nearest-neighbor lookup among support points.

    Stores the support points (`support`), the values attached to them
    (`coeffsLocal`, first axis indexing the support points), per-direction
    weights (`directionalWeights`), the number of neighbors used in the
    lookup (`nNeighbors`) and the parameter dimension (`npar`). Evaluation
    is delegated to `polyval` from the local `.val` module; derivatives of
    positive order evaluate to zero.
    """

    def __init__(self, other = None):
        """
        Initialize the interpolator, optionally sharing another's data.

        Args:
            other(optional): interpolator whose attributes are shared by
                reference (not copied) by the new instance. If None, the
                instance is left empty and must be initialized through
                `setupByInterpolation`.
        """
        if other is None: return
        self.support = other.support
        self.coeffsLocal = other.coeffsLocal
        self.directionalWeights = other.directionalWeights
        self.nNeighbors = other.nNeighbors
        self.npar = other.npar

    @property
    def shape(self):
        """
        Shape of the value attached to each support point.

        Returns:
            The trailing dimensions `coeffsLocal.shape[1:]` (a tuple) if
            `coeffsLocal` has more than one dimension, the integer 1
            otherwise.
        """
        sh = self.coeffsLocal.shape[1 :] if self.coeffsLocal.ndim > 1 else 1
        return sh

    def __call__(self, mu:paramList, der : List[int] = None,
                 scl : Np1D = None):
        """
        Evaluate the interpolator at the parameter values `mu`.

        Args:
            mu: parameter values at which to evaluate.
            der(optional): derivative orders. If their sum is positive, a
                zero array is returned (consistent with a locally constant
                nearest-neighbor interpolant).
            scl(optional): unused here; presumably kept to match the common
                interpolator call signature.

        Returns:
            The output of `polyval` for `der` None or of non-positive sum;
            otherwise zeros of shape `coeffsLocal.shape[1:] + (len(mu),)`.
        """
        if der is not None and np.sum(der) > 0:
            return np.zeros(self.coeffsLocal.shape[1 :] + (len(mu),))
        return polyval(mu, self.coeffsLocal, self.support,
                       self.directionalWeights, self.nNeighbors)

    def __copy__(self):
        """Return a shallow copy sharing all attributes with `self`."""
        return NearestNeighborInterpolator(self)

    def __deepcopy__(self, memo):
        """Return a deep copy, honoring the `memo` dict of `copy.deepcopy`."""
        other = NearestNeighborInterpolator()
        (other.support, other.coeffsLocal, other.directionalWeights,
         other.nNeighbors, other.npar) = copy((self.support, self.coeffsLocal,
                                               self.directionalWeights,
                                               self.nNeighbors, self.npar),
                                              memo)
        return other

    def postmultiplyTensorize(self, A:Np2D):
        """
        Right-multiply the stored values by the matrix `A`, in place.

        Args:
            A: 2D array whose first dimension must equal `self.shape[-1]`.
        """
        RROMPyAssert(A.shape[0], self.shape[-1], "Shape of output")
        # `dot` is the project's numerical helper; presumably a matrix
        # product over the last axis of coeffsLocal — confirm in
        # rrompy.utilities.numerical.
        self.coeffsLocal = dot(self.coeffsLocal, A)

    def pad(self, nleft : List[int] = None, nright : List[int] = None):
        """
        Zero-pad the value dimensions of `coeffsLocal`, in place.

        The first axis (one entry per support point) is never padded.

        Args:
            nleft(optional): amount of padding before each value dimension;
                a scalar is accepted for a single value dimension. Defaults
                to no padding.
            nright(optional): amount of padding after each value dimension;
                same conventions as `nleft`.
        """
        if nleft is None: nleft = [0] * len(self.shape)
        if nright is None: nright = [0] * len(self.shape)
        if not hasattr(nleft, "__len__"): nleft = [nleft]
        if not hasattr(nright, "__len__"): nright = [nright]
        RROMPyAssert(len(self.shape), len(nleft), "Shape of output")
        RROMPyAssert(len(self.shape), len(nright), "Shape of output")
        # Leading (0, 0): the support-point axis is left untouched.
        padwidth = [(0, 0)] + [(l, r) for l, r in zip(nleft, nright)]
        self.coeffsLocal = np.pad(self.coeffsLocal, padwidth, "constant",
                                  constant_values = (0., 0.))

    def setupByInterpolation(self, support:paramList, values:ListAny,
                             directionalWeights : Np1D = None,
                             nNeighbors : int = 1, verbose : bool = True,
                             vanderCoeffs : DictAny = None):
        """
        Set up the interpolator from support points and sampled values.

        Args:
            support: support points (normalized via `checkParameterList`).
            values: values at the support points; must have one entry per
                support point.
            directionalWeights(optional): per-direction weights forwarded to
                `polyval`; defaults to all ones.
            nNeighbors(optional): number of neighbors forwarded to
                `polyval`; values below 1 are clamped to 1. Defaults to 1.
            verbose(optional): unused here.
            vanderCoeffs(optional): extra options; only the key "reorder"
                is read, as an index used to permute the support points.
                Defaults to no options.

        Returns:
            The tuple (True, None). NOTE(review): presumably a
            (success flag, message) pair shared with other interpolators;
            confirm against GenericInterpolator.
        """
        # None instead of a mutable `{}` default: the same dict object would
        # otherwise be shared across every call.
        if vanderCoeffs is None: vanderCoeffs = {}
        support = checkParameterList(support)[0]
        # Check consistency before modifying any attribute.
        RROMPyAssert(len(support), len(values), "Number of support points")
        self.support = copy(support)
        if "reorder" in vanderCoeffs:
            # NOTE(review): only the support is permuted; `values` is stored
            # as given. Confirm callers pass values already in this order.
            self.support = self.support[vanderCoeffs["reorder"]]
        self.npar = support.shape[1]
        if directionalWeights is None:
            directionalWeights = np.ones(self.npar)
        self.directionalWeights = directionalWeights
        # `__call__`, `__init__(other)` and `__deepcopy__` all read
        # `nNeighbors`, so it must be stored here; at least one neighbor is
        # required for a nearest-neighbor lookup.
        self.nNeighbors = max(1, nNeighbors)
        self.coeffsLocal = values
        return True, None

Event Timeline