diff --git a/src/core/static_types.hh b/src/core/static_types.hh index 303093f..2aad759 100644 --- a/src/core/static_types.hh +++ b/src/core/static_types.hh @@ -1,328 +1,328 @@ /** * @file * * @author Lucas Frérot * * @section LICENSE * * Copyright (©) 2017 EPFL (Ecole Polytechnique Fédérale de * Lausanne) Laboratory (LSMS - Laboratoire de Simulation en Mécanique des * Solides) * * Tamaas is free software: you can redistribute it and/or modify it under the * terms of the GNU Lesser General Public License as published by the Free * Software Foundation, either version 3 of the License, or (at your option) any * later version. * * Tamaas is distributed in the hope that it will be useful, but WITHOUT ANY * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more * details. * * You should have received a copy of the GNU Lesser General Public License * along with Tamaas. If not, see . * */ /* -------------------------------------------------------------------------- */ #ifndef __STATIC_TYPES_HH__ #define __STATIC_TYPES_HH__ /* -------------------------------------------------------------------------- */ #include "tamaas.hh" __BEGIN_TAMAAS__ /* -------------------------------------------------------------------------- */ namespace detail { template struct product_tail_rec : product_tail_rec {}; template struct product_tail_rec : std::integral_constant {}; template struct get_rec : get_rec {}; template struct get_rec<0, n, ns...> : std::integral_constant {}; } template struct product : detail::product_tail_rec<1, ns...> {}; template struct get : detail::get_rec {}; template struct is_arithmetic : std::is_arithmetic {}; template struct is_arithmetic> : std::true_type {}; /** * @brief Static Array * * This class is meant to be a small and fast object for intermediate * calculations, possibly on wrapped memory belonging to a grid. Support type * show be either a pointer or a C array. It should not contain any virtual * method. */ template class StaticArray { static_assert(std::is_array::value || std::is_pointer::value, "the support type of StaticArray should be either a pointer or " "a C-array"); using T = DataType; public: /// Access operator __device__ __host__ T& operator()(UInt i) { // TAMAAS_ASSERT(i < n, "Access out of bounds"); return _mem[i]; } /// Access operator __device__ __host__ const T& operator()(UInt i) const { // TAMAAS_ASSERT(i < n, "Access out of bounds"); return _mem[i]; } /// L2 norm squared __device__ __host__ T l2squared() const { T res = 0; for (UInt i = 0; i < size; ++i) - res += _mem[i] * _mem[i]; + res += (*this)(i) * (*this)(i); return res; } /// L2 norm __device__ __host__ T l2norm() const { return std::sqrt(l2squared()); } #define VECTOR_OP(op) \ template \ __device__ __host__ void operator op(const StaticArray& o) { \ for (UInt i = 0; i < size; ++i) \ (*this)(i) op o(i); \ } VECTOR_OP(+=) VECTOR_OP(-=) VECTOR_OP(*=) VECTOR_OP(/=) #undef VECTOR_OP #define SCALAR_OP(op) \ template \ __device__ __host__ \ typename std::enable_if::value, StaticArray&>::type \ operator op(const T1& x) { \ for (UInt i = 0; i < size; ++i) \ (*this)(i) op x; \ return *this; \ } SCALAR_OP(+=) SCALAR_OP(-=) SCALAR_OP(*=) SCALAR_OP(/=) SCALAR_OP(=) #undef SCALAR_OP /// Overriding the implit copy operator StaticArray& operator=(const StaticArray& o) { return this->copy(o); } template StaticArray& copy(const StaticArray& o) { for (UInt i = 0; i < size; ++i) (*this)(i) = o(i); return *this; } protected: SupportType _mem; }; /** * @brief Static Tensor * * This class implements a multi-dimensional tensor behavior. */ template class StaticTensor : public StaticArray::value> { using parent = StaticArray::value>; using T = DataType; public: static constexpr UInt dim = sizeof...(dims); using parent::operator=; private: template static UInt unpackOffset(UInt offset, UInt index, Idx... rest) { constexpr UInt size = sizeof...(rest); offset += index; offset *= get::value; return unpackOffset(offset, rest...); } template static UInt unpackOffset(UInt offset, UInt index) { return offset + index; } public: template const T& operator()(Idx... idx) const { - return this->_mem[unpackOffset(0, idx...)]; + return parent::operator()(unpackOffset(0, idx...)); } template T& operator()(Idx... idx) { - return this->_mem[unpackOffset(0, idx...)]; + return parent::operator()(unpackOffset(0, idx...)); } }; /* -------------------------------------------------------------------------- */ /* Common Static Types */ /* -------------------------------------------------------------------------- */ template using StaticMatrix = StaticTensor; /// Vector class with size determined at compile-time template class StaticVector : public StaticTensor { using T = DataType; public: using StaticTensor::operator=; /// Matrix-vector product template __device__ __host__ void mul(const StaticMatrix& mat, const StaticVector& vec) { *this = T(0); for (UInt i = 0; i < n; ++i) for (UInt j = 0; j < m; ++j) (*this)(i) += mat(i, j) * vec(j); } }; /* -------------------------------------------------------------------------- */ /* Proxy Static Types */ /* -------------------------------------------------------------------------- */ /// Proxy type for tensor template