Page MenuHomec4science

nearest_neighbor_interpolator.py
No OneTemporary

File Metadata

Created
Wed, May 1, 11:51

nearest_neighbor_interpolator.py

# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
from copy import deepcopy as copy
import numpy as np
from collections.abc import Iterable
from rrompy.utilities.base.types import List, ListAny, Np1D, Np2D, paramList
from rrompy.utilities.poly_fitting.interpolator import GenericInterpolator
from .val import polyval
from rrompy.utilities.numerical import dot
from rrompy.utilities.exception_manager import RROMPyAssert
from rrompy.parameter import checkParameterList
__all__ = ['NearestNeighborInterpolator']
class NearestNeighborInterpolator(GenericInterpolator):
    """
    Function class with setup by nearest neighbor interpolation.

    Attributes (set by setupByInterpolation, or shared by reference from
    another instance through the copy constructor):
        support: parameter list of support points.
        coeffsLocal: values at the support points; axis 0 runs over the
            support points, any remaining axes form the output shape.
        nNeighbors: number of neighbors forwarded to polyval (always >= 1).
        directionalWeights: per-parameter weights forwarded to polyval.
        npar: number of columns of the support parameter list.
    """
    def __init__(self, other = None):
        """
        Build an empty interpolator, or a shallow copy of `other`.

        With other=None no attributes are set (they are filled in later by
        setupByInterpolation). Otherwise the five attributes of `other` are
        shared by reference, not copied.
        """
        if other is None: return
        self.support = other.support
        self.coeffsLocal = other.coeffsLocal
        self.nNeighbors = other.nNeighbors
        self.directionalWeights = other.directionalWeights
        self.npar = other.npar

    @property
    def shape(self):
        """
        Shape of the value attached to each support point.

        Returns coeffsLocal.shape[1:] (a tuple) when coeffsLocal has more
        than one axis, and the int 1 otherwise.
        """
        # NOTE(review): the int fallback is not a sequence, so pad() and
        # postmultiplyTensorize(), which call len() on / index self.shape,
        # raise TypeError for 1D coeffsLocal. Kept unchanged because external
        # callers may rely on the int; confirm before switching to a tuple.
        sh = self.coeffsLocal.shape[1 :] if self.coeffsLocal.ndim > 1 else 1
        return sh

    def __call__(self, mu:paramList, der : List[int] = None,
                 scl : Np1D = None):
        """
        Evaluate the interpolant at the parameter values `mu`.

        Args:
            mu: evaluation points, passed unchanged to polyval.
            der: derivative orders; when given with a positive sum, an
                all-zero array of shape coeffsLocal.shape[1:] + (len(mu),)
                is returned without evaluating polyval.
            scl: accepted but unused here (presumably kept for interface
                compatibility with other interpolators).

        Returns:
            polyval(mu, coeffsLocal, support, nNeighbors, directionalWeights),
            or the zero array described above for positive-order `der`.
        """
        if der is not None and np.sum(der) > 0:
            # NOTE(review): len(mu) counts entries of the raw input and the
            # evaluation axis is placed last; confirm both agree with the
            # layout returned by polyval.
            return np.zeros(self.coeffsLocal.shape[1 :] + (len(mu),))
        return polyval(mu, self.coeffsLocal, self.support,
                       self.nNeighbors, self.directionalWeights)

    def __copy__(self):
        """Return a shallow copy sharing all attributes with self."""
        return NearestNeighborInterpolator(self)

    def __deepcopy__(self, memo):
        """Return a copy whose five attributes are deep-copied (with memo)."""
        other = NearestNeighborInterpolator()
        (other.support, other.coeffsLocal, other.nNeighbors,
         other.directionalWeights, other.npar) = copy((self.support,
                                                       self.coeffsLocal,
                                                       self.nNeighbors,
                                                       self.directionalWeights,
                                                       self.npar), memo)
        return other

    def postmultiplyTensorize(self, A:Np2D):
        """
        Replace coeffsLocal by dot(coeffsLocal, A).

        A.shape[0] must equal self.shape[-1] (checked via RROMPyAssert).
        Presumably dot contracts the last axis of coeffsLocal with the first
        axis of A -- confirm against rrompy.utilities.numerical.dot.
        """
        RROMPyAssert(A.shape[0], self.shape[-1], "Shape of output")
        self.coeffsLocal = dot(self.coeffsLocal, A)

    def pad(self, nleft : List[int] = None, nright : List[int] = None):
        """
        Replace coeffsLocal by a zero-padded copy along its output axes.

        Axis 0 (the support points) is never padded.

        Args:
            nleft: number of zeros to prepend along each output axis; a
                scalar is wrapped into a one-element list; defaults to no
                padding on every axis.
            nright: as nleft, but the zeros are appended after the existing
                entries.
        Both lists must have len(self.shape) entries (checked via
        RROMPyAssert).
        """
        if nleft is None: nleft = [0] * len(self.shape)
        if nright is None: nright = [0] * len(self.shape)
        if not isinstance(nleft, Iterable): nleft = [nleft]
        if not isinstance(nright, Iterable): nright = [nright]
        RROMPyAssert(len(self.shape), len(nleft), "Shape of output")
        RROMPyAssert(len(self.shape), len(nright), "Shape of output")
        # (0, 0) leaves the support-point axis untouched.
        padwidth = [(0, 0)] + [(l, r) for l, r in zip(nleft, nright)]
        self.coeffsLocal = np.pad(self.coeffsLocal, padwidth, "constant",
                                  constant_values = (0., 0.))

    def setupByInterpolation(self, support:paramList, values:ListAny,
                             nNeighbors : int = 1,
                             directionalWeights : Np1D = None):
        """
        Store the data defining the nearest neighbor interpolant.

        Args:
            support: support points; normalized with checkParameterList and
                stored as a deep copy.
            values: values at the support points, one entry per point
                (checked via RROMPyAssert); stored by reference, not copied.
            nNeighbors: number of neighbors; values below 1 are raised to 1.
            directionalWeights: per-parameter weights; default 1. for each
                of the npar parameters; stored as a numpy array.

        Returns:
            (True, None), unconditionally.
        """
        support = checkParameterList(support)
        RROMPyAssert(len(support), len(values), "Number of support values")
        self.support = copy(support)
        self.npar = support.shape[1]
        self.coeffsLocal = values
        self.nNeighbors = max(1, nNeighbors)
        if directionalWeights is None: directionalWeights = [1.] * self.npar
        self.directionalWeights = np.array(directionalWeights)
        return True, None

Event Timeline