Page MenuHomec4science

static_types.hh
No OneTemporary

File Metadata

Created
Thu, Jun 27, 13:34

static_types.hh

/**
* @file
*
* @author Lucas Frérot <lucas.frerot@epfl.ch>
*
* @section LICENSE
*
* Copyright (©) 2017 EPFL (Ecole Polytechnique Fédérale de
* Lausanne) Laboratory (LSMS - Laboratoire de Simulation en Mécanique des
* Solides)
*
* Tamaas is free software: you can redistribute it and/or modify it under the
* terms of the GNU Lesser General Public License as published by the Free
* Software Foundation, either version 3 of the License, or (at your option) any
* later version.
*
* Tamaas is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
* A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
* details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with Tamaas. If not, see <http://www.gnu.org/licenses/>.
*
*/
/* -------------------------------------------------------------------------- */
#ifndef __STATIC_TYPES_HH__
#define __STATIC_TYPES_HH__
/* -------------------------------------------------------------------------- */
#include "tamaas.hh"
#include <type_traits>
__BEGIN_TAMAAS__
/// Forward declaration: StaticVector::mul takes a StaticMatrix argument,
/// while StaticMatrix itself derives from StaticVector (defined below).
template <typename T, UInt n, UInt m>
class StaticMatrix;
/// Vector class with size determined at compile-time
///
/// A StaticVector does not own its storage: it is a view on `n` contiguous
/// values of type `T` starting at the address given to the constructor.
/// The implicitly generated copy constructor / copy assignment copy the
/// wrapped pointer, NOT the values; use the element-wise operators to
/// modify values. Assigning a scalar (`v = x`) sets every component to `x`.
template <typename T, UInt n>
class StaticVector {
public:
  /// Implicit wrap constructor: views the `n` values starting at `spot`
  /// (the caller must guarantee that `n` contiguous elements live there and
  /// outlive this view)
  __device__ __host__ StaticVector(T& spot) : _mem(&spot) {}

public:
  /// Access operator (no bounds checking)
  __device__ __host__ T& operator()(UInt i) {
    // TAMAAS_ASSERT(i < n, "Access out of bounds");
    return _mem[i];
  }

  /// Access operator (no bounds checking)
  __device__ __host__ const T& operator()(UInt i) const {
    // TAMAAS_ASSERT(i < n, "Access out of bounds");
    return _mem[i];
  }

  /// L2 norm squared (sum of the squared components)
  __device__ __host__ T l2squared() const {
    // strip const so the accumulator is writable even for const views
    typename std::remove_const<T>::type res = 0;
    for (UInt i = 0; i < n; ++i)
      res += _mem[i] * _mem[i];
    return res;
  }

  /// L2 norm
  __device__ __host__ T l2norm() const { return std::sqrt(l2squared()); }

  /// Matrix-vector product: `*this = mat * vec`
  ///
  /// The result is accumulated in a local buffer and only written back once
  /// all rows are computed, so `vec` may alias `*this` (e.g. `v.mul(A, v)`).
  template <typename T1, typename T2, UInt m>
  __device__ __host__ void mul(const StaticMatrix<T1, n, m>& mat,
                               const StaticVector<T2, m>& vec) {
    using value_type = typename std::remove_const<T>::type;
    value_type res[n];
    for (UInt i = 0; i < n; ++i) {
      res[i] = 0;
      for (UInt j = 0; j < m; ++j)
        res[i] += mat(i, j) * vec(j);
    }
    // write back only after every read of vec is done (aliasing safety)
    for (UInt i = 0; i < n; ++i)
      (*this)(i) = res[i];
  }

  /* Element-wise compound assignment with another vector of the same size.
     Returns *this, like the scalar operators below. */
#define VECTOR_OP(op)                                                          \
  template <typename T1>                                                       \
  __device__ __host__ StaticVector& operator op(const StaticVector<T1, n>& o) {\
    for (UInt i = 0; i < n; ++i)                                               \
      (*this)(i) op o(i);                                                      \
    return *this;                                                              \
  }
  VECTOR_OP(+=)
  VECTOR_OP(-=)
  VECTOR_OP(*=)
  VECTOR_OP(/=)
#undef VECTOR_OP

  /* Apply a scalar to every component (including plain assignment). */
#define SCALAR_OP(op)                                                          \
  __device__ __host__ StaticVector& operator op(const T& x) {                  \
    for (UInt i = 0; i < n; ++i)                                               \
      (*this)(i) op x;                                                         \
    return *this;                                                              \
  }
  SCALAR_OP(+=)
  SCALAR_OP(-=)
  SCALAR_OP(*=)
  SCALAR_OP(/=)
  SCALAR_OP(=)
#undef SCALAR_OP

protected:
  /// Non-owning pointer to the first of `n` contiguous values
  T* _mem;
};
/// Matrix class with size determined at compile time
///
/// Non-owning view on `n * m` contiguous values stored in row-major order:
/// element (i, j) lives at offset `i * m + j`. As for StaticVector, copying
/// a StaticMatrix copies the wrapped pointer, not the values.
template <typename T, UInt n, UInt m>
class StaticMatrix : public StaticVector<T, n * m> {
public:
  using StaticVector<T, n * m>::StaticVector;

  /// A matrix is not a valid target of a matrix-vector product. Declaring a
  /// member named `mul` here hides StaticVector::mul, and this one is
  /// deleted, so `matrix.mul(...)` does not compile.
  template <UInt p>
  void mul(const StaticMatrix<T, n, p>& mat,
           const StaticVector<T, p>& vec) = delete;

public:
  /// Access operator: row i, column j (no bounds checking)
  __device__ __host__ T& operator()(UInt i, UInt j) {
    // TAMAAS_ASSERT(i < n && j < m, "Access out of bounds");
    UInt index = i * m + j;  // row-major offset
    return this->_mem[index];
  }

  /// Access operator: row i, column j (no bounds checking)
  __device__ __host__ const T& operator()(UInt i, UInt j) const {
    // TAMAAS_ASSERT(i < n && j < m, "Access out of bounds");
    UInt index = i * m + j;  // row-major offset
    return this->_mem[index];
  }

  /* Element-wise compound assignment with another n x m matrix.
     Returns *this, like the scalar operators below. */
#define VECTOR_OP(op)                                                          \
  template <typename T1>                                                       \
  __device__ __host__ StaticMatrix& operator op(                               \
      const StaticMatrix<T1, n, m>& o) {                                       \
    for (UInt i = 0; i < n; ++i)                                               \
      for (UInt j = 0; j < m; ++j)                                             \
        (*this)(i, j) op o(i, j);                                              \
    return *this;                                                              \
  }
  VECTOR_OP(+=)
  VECTOR_OP(-=)
  VECTOR_OP(*=)
  VECTOR_OP(/=)
#undef VECTOR_OP

  /* Apply a scalar to every entry (including plain assignment). */
#define SCALAR_OP(op)                                                          \
  __device__ __host__ StaticMatrix& operator op(const T& x) {                  \
    for (UInt i = 0; i < n; ++i)                                               \
      for (UInt j = 0; j < m; ++j)                                             \
        (*this)(i, j) op x;                                                    \
    return *this;                                                              \
  }
  SCALAR_OP(+=)
  SCALAR_OP(-=)
  SCALAR_OP(*=)
  SCALAR_OP(/=)
  SCALAR_OP(=)
#undef SCALAR_OP
};
__END_TAMAAS__
#endif // __STATIC_TYPES_HH__

Event Timeline