Page MenuHomec4science

tensor_la.py
No OneTemporary

File Metadata

Created
Wed, May 1, 07:08

tensor_la.py

# Copyright (C) 2018 by the RROMPy authors
#
# This file is part of RROMPy.
#
# RROMPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# RROMPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with RROMPy. If not, see <http://www.gnu.org/licenses/>.
#
import numpy as np
from numbers import Number
from rrompy.sampling.sample_list import sampleList
from rrompy.parameter.parameter_list import parameterList
# Names exported by ``from <this module> import *``.
__all__ = ['dot', 'solve']
def dot(u, v):
    """Compute the product ``u * v`` for scalars, arrays or wrapped lists.

    The steps are applied in this order:

    * if either operand is a ``numbers.Number``, return ``u * v``;
    * any ``parameterList``/``sampleList`` operand is replaced by its
      ``data`` attribute;
    * if the last axis of ``u`` has the same length as the first axis of
      ``v``, contract over that axis: ``np.tensordot(u, v, 1)`` when ``u``
      is a ``numpy`` array, ``u.dot(v)`` for any other type;
    * otherwise ``v`` is reshaped to ``u.shape[-1]`` rows, multiplied via
      ``u.dot`` and the result reshaped to
      ``u.shape[:-2] + (N * u.shape[-2],) + v.shape[1:]``, where
      ``N = v.shape[0] // u.shape[-1]``.
    """
    # Scalar shortcut: plain multiplication, no shape handling needed.
    if any(isinstance(operand, Number) for operand in (u, v)):
        return u * v

    # Unwrap project containers so only their raw data is used below.
    wrappers = (parameterList, sampleList)
    if isinstance(u, wrappers):
        u = u.data
    if isinstance(v, wrappers):
        v = v.data

    inner = u.shape[-1]
    if inner == v.shape[0]:
        # Matching inner dimensions: single-axis contraction.
        return np.tensordot(u, v, 1) if isinstance(u, np.ndarray) else u.dot(v)

    # Mismatched inner dimensions: fold v to `inner` rows, multiply, then
    # restore the trailing axes of v in the output.
    n_blocks = v.shape[0] // inner
    out_shape = u.shape[:-2] + (n_blocks * u.shape[-2],) + v.shape[1:]
    folded = v.reshape(inner, -1)
    return u.dot(folded).reshape(out_shape)
def solve(A, b, solver, kwargs):
    r"""Solve ``A x = b`` (i.e. ``A \ b``) using the supplied ``solver``.

    The docstring is a raw string so that the backslash in ``A \ b`` is
    not parsed as an (invalid) escape sequence.

    Parameters
    ----------
    A : Number, parameterList, sampleList or array-like with ``shape``
        System operator. If it is a ``numbers.Number`` the result is
        simply ``b / A`` and ``solver`` is never called.
    b : parameterList, sampleList or array-like with ``shape``
        Right-hand side. ``parameterList``/``sampleList`` instances are
        replaced by their ``data`` attribute (the same holds for ``A``).
    solver : callable
        Invoked as ``solver(A, b, kwargs)``.
    kwargs : object
        Forwarded to ``solver`` unchanged as its third positional
        argument (it is NOT unpacked with ``**``).

    Returns
    -------
    ``b / A`` for scalar ``A``; ``solver(A, b, kwargs)`` when
    ``A.shape[-1] == b.shape[0]``; otherwise the solver is applied to
    ``b`` reshaped to ``A.shape[-1]`` rows and its output is reshaped to
    ``A.shape[:-2] + (N * A.shape[-2],) + b.shape[1:]`` with
    ``N = b.shape[0] // A.shape[-1]``.
    """
    # Scalar operator: division is the whole solve.
    if isinstance(A, Number):
        return b / A

    # Unwrap project containers so only their raw data is used below.
    if isinstance(A, (parameterList, sampleList)):
        A = A.data
    if isinstance(b, (parameterList, sampleList)):
        b = b.data

    # Compatible shapes: delegate directly.
    if A.shape[-1] == b.shape[0]:
        return solver(A, b, kwargs)

    # Mismatched shapes: fold b to M rows, solve, then restore the
    # trailing axes of b in the output.
    M = A.shape[-1]
    N = b.shape[0] // M
    rshape = A.shape[:-2] + (N * A.shape[-2],) + b.shape[1:]
    return solver(A, b.reshape(M, -1), kwargs).reshape(rshape)

Event Timeline