Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F85169393
static_types.hh
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Fri, Sep 27, 06:31
Size
24 KB
Mime Type
text/x-c++
Expires
Sun, Sep 29, 06:31 (2 d)
Engine
blob
Format
Raw Data
Handle
21106199
Attached To
rTAMAAS tamaas
static_types.hh
View Options
/**
* @file
* @section LICENSE
*
* Copyright (©) 2016-19 EPFL (École Polytechnique Fédérale de Lausanne),
* Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
/* -------------------------------------------------------------------------- */
#ifndef __STATIC_TYPES_HH__
#define __STATIC_TYPES_HH__
/* -------------------------------------------------------------------------- */
#include "tamaas.hh"
#include <array>
#include <cmath>
#include <thrust/sort.h>
#include <type_traits>
namespace
tamaas
{
/* -------------------------------------------------------------------------- */
namespace
detail
{
template
<
UInt
acc
,
UInt
n
,
UInt
...
ns
>
struct
product_tail_rec
:
product_tail_rec
<
acc
*
n
,
ns
...
>
{};
template
<
UInt
acc
,
UInt
n
>
struct
product_tail_rec
<
acc
,
n
>
:
std
::
integral_constant
<
UInt
,
acc
*
n
>
{};
template
<
UInt
N
,
UInt
n
,
UInt
...
ns
>
struct
get_rec
:
get_rec
<
N
-
1
,
ns
...
>
{};
template
<
UInt
n
,
UInt
...
ns
>
struct
get_rec
<
0
,
n
,
ns
...
>
:
std
::
integral_constant
<
UInt
,
n
>
{};
}
// namespace detail
/// Compile-time product of the integers `factors...`, available as `::value`.
/// The recursion in detail::product_tail_rec is seeded with 1.
template <UInt... factors>
struct product : detail::product_tail_rec<1, factors...> {};
/// Compile-time access to the `index`-th (zero-based) integer of `values...`,
/// available as `::value`.
template <UInt index, UInt... values>
struct get : detail::get_rec<index, values...> {};
/// Arithmetic trait used to constrain scalar operators: by default it gives
/// the same answer as std::is_arithmetic.
template <typename Scalar>
struct is_arithmetic : std::is_arithmetic<Scalar> {};

/// Any thrust::complex instantiation is reported as arithmetic.
template <typename Scalar>
struct is_arithmetic<thrust::complex<Scalar>> : std::true_type {};
/// Number of stored components of a symmetric `dim` x `dim` matrix.
/// Only declared here: defined for dim = 1, 2 and 3 below.
template <UInt dim>
struct voigt_size;

/// 1x1 symmetric matrix: 1 component
template <>
struct voigt_size<1> : std::integral_constant<UInt, 1> {};

/// 2x2 symmetric matrix: 3 components
template <>
struct voigt_size<2> : std::integral_constant<UInt, 3> {};

/// 3x3 symmetric matrix: 6 components
template <>
struct voigt_size<3> : std::integral_constant<UInt, 6> {};
/* -------------------------------------------------------------------------- */
/**
 * @brief Static Array
 *
 * This class is meant to be a small and fast object for intermediate
 * calculations, possibly on wrapped memory belonging to a grid. The support
 * type should be either a pointer or a C array (enforced by a static_assert).
 * It should not contain any virtual method.
 *
 * @tparam DataType type of the elements
 * @tparam SupportType type of the `_mem` member through which elements are
 * accessed
 * @tparam _size number of elements, fixed at compile-time
 */
template <typename DataType, typename SupportType, UInt _size>
class StaticArray {
  static_assert(std::is_array<SupportType>::value ||
                    std::is_pointer<SupportType>::value,
                "the support type of StaticArray should be either a pointer or "
                "a C-array");
  using T = DataType;
  using T_bare = typename std::remove_cv_t<T>;

public:
  using value_type = T;
  static constexpr UInt size = _size;

public:
  /// Access operator (no bounds checking)
  __device__ __host__ T& operator()(UInt i) {
    // TAMAAS_ASSERT(i < size, "Access out of bounds");
    return _mem[i];
  }

  /// Access operator (no bounds checking)
  __device__ __host__ const T& operator()(UInt i) const {
    // TAMAAS_ASSERT(i < size, "Access out of bounds");
    return _mem[i];
  }

  /// Scalar product: sum of the element-wise products with `o`
  template <typename DT, typename ST>
  __device__ __host__ T_bare dot(const StaticArray<DT, ST, size>& o) const {
    // accumulate in the type of a product of elements, convert on return
    decltype(T_bare(0) * DT(0)) res = 0;
    for (UInt i = 0; i < size; ++i)
      res += (*this)(i)*o(i);
    return res;
  }

  /// L2 norm squared
  __device__ __host__ T_bare l2squared() const { return this->dot(*this); }

  /// L2 norm
  __device__ __host__ T_bare l2norm() const { return std::sqrt(l2squared()); }

  /// Sum of all elements
  __device__ __host__ T_bare sum() const {
    T_bare res = 0;
    for (UInt i = 0; i < size; ++i)
      res += (*this)(i);
    return res;
  }

// Element-wise compound assignment with another array of the same size.
// Returns *this, like the scalar operators below.
#define VECTOR_OP(op)                                                          \
  template <typename DT, typename ST>                                          \
  __device__ __host__ StaticArray& operator op(                                \
      const StaticArray<DT, ST, size>& o) {                                    \
    for (UInt i = 0; i < size; ++i)                                            \
      (*this)(i) op o(i);                                                      \
    return *this;                                                              \
  }

  VECTOR_OP(+=)
  VECTOR_OP(-=)
  VECTOR_OP(*=)
  VECTOR_OP(/=)
#undef VECTOR_OP

// Compound assignment (and plain assignment) of every element with a scalar.
// Only enabled when is_arithmetic<T1>::value is true.
#define SCALAR_OP(op)                                                          \
  template <typename T1>                                                       \
  __device__ __host__ std::enable_if_t<is_arithmetic<T1>::value, StaticArray&> \
  operator op(const T1& x) {                                                   \
    for (UInt i = 0; i < size; ++i)                                            \
      (*this)(i) op x;                                                         \
    return *this;                                                              \
  }

  SCALAR_OP(+=)
  SCALAR_OP(-=)
  SCALAR_OP(*=)
  SCALAR_OP(/=)
  SCALAR_OP(=)
#undef SCALAR_OP

  /// Overriding the implicit copy operator: copies the elements one by one
  __device__ __host__ StaticArray& operator=(const StaticArray& o) {
    return this->copy(o);
  }

  /// Element-wise copy from an array with other data/support types
  template <typename DT, typename ST>
  __device__ __host__ StaticArray&
  operator=(const StaticArray<DT, ST, size>& o) {
    return this->copy(o);
  }

  /// Copy the elements of `o` into this array and return *this
  template <typename DT, typename ST>
  __device__ __host__ StaticArray& copy(const StaticArray<DT, ST, size>& o) {
    for (UInt i = 0; i < size; ++i)
      (*this)(i) = o(i);
    return *this;
  }

  // Iteration helpers. They carry the same __device__ __host__ qualifiers as
  // the other members so that they can be used from such functions as well.

  /// Pointer to the first element
  __device__ __host__ T* begin() { return _mem; }
  /// Pointer to the first element
  __device__ __host__ const T* begin() const { return _mem; }
  /// Pointer one past the last element
  __device__ __host__ T* end() { return _mem + size; }
  /// Pointer one past the last element
  __device__ __host__ const T* end() const { return _mem + size; }

private:
  /// Resolves to U only if the array has at least one element
  template <typename U>
  using valid_size_t = std::enable_if_t<(size > 0), U>;

public:
  /// First element
  __device__ __host__ valid_size_t<T&> front() { return *_mem; }
  /// First element
  __device__ __host__ valid_size_t<const T&> front() const { return *_mem; }
  /// Last element
  __device__ __host__ valid_size_t<T&> back() { return _mem[size - 1]; }
  /// Last element
  __device__ __host__ valid_size_t<const T&> back() const {
    return _mem[size - 1];
  }

protected:
  SupportType _mem;
};
/**
 * @brief Static Tensor
 *
 * This class implements a multi-dimensional tensor behavior on top of
 * StaticArray: the array holds product<dims...>::value elements and the access
 * operator turns a list of indices into a single offset.
 */
template <typename DataType, typename SupportType = DataType*, UInt... dims>
class StaticTensor
    : public StaticArray<DataType, SupportType, product<dims...>::value> {
  using parent = StaticArray<DataType, SupportType, product<dims...>::value>;
  using T = DataType;

public:
  /// Number of dimensions
  static constexpr UInt dim = sizeof...(dims);
  using parent::operator=;

private:
  /// Recursive step of the offset computation: add the current index, scale
  /// by the extent of the next dimension, continue with the other indices
  template <typename... Idx>
  __device__ __host__ static UInt unpackOffset(UInt offset, UInt index,
                                               Idx... rest) {
    constexpr UInt remaining = sizeof...(rest);
    constexpr UInt next_extent = get<dim - remaining, dims...>::value;
    return unpackOffset((offset + index) * next_extent, rest...);
  }

  /// Last step of the offset computation: only one index left to add
  template <typename... Idx>
  __device__ __host__ static UInt unpackOffset(UInt offset, UInt index) {
    return offset + index;
  }

public:
  /// Read-only access from a list of indices
  template <typename... Idx>
  __device__ __host__ const T& operator()(Idx... idx) const {
    const UInt flat = unpackOffset(0, idx...);
    return parent::operator()(flat);
  }

  /// Access from a list of indices
  template <typename... Idx>
  __device__ __host__ T& operator()(Idx... idx) {
    const UInt flat = unpackOffset(0, idx...);
    return parent::operator()(flat);
  }
};
/* -------------------------------------------------------------------------- */
/* Common Static Types */
/* -------------------------------------------------------------------------- */
// Forward declarations (data type, support type, size): needed because
// StaticMatrix refers to both classes before they are defined
template <typename, typename, UInt>
class StaticVector;

template <typename, typename, UInt>
class StaticSymMatrix;
/* -------------------------------------------------------------------------- */
/// Matrix class (n rows, m columns) with sizes determined at compile-time
// NOTE(review): the std::enable_if_t<n == m, ...> return types below do not
// depend on a deduced template parameter; instantiating this class with
// n != m probably fails to compile instead of hiding the members - confirm.
template <typename DataType, typename SupportType, UInt n, UInt m>
class StaticMatrix : public StaticTensor<DataType, SupportType, n, m> {
  using tensor_type = StaticTensor<DataType, SupportType, n, m>;
  using T = DataType;
  using T_bare = std::remove_cv_t<T>;

public:
  using tensor_type::operator=;

  /// Initialize from a symmetric matrix (square matrices only)
  template <typename DT, typename ST>
  __device__ __host__ std::enable_if_t<n == m>
  fromSymmetric(const StaticSymMatrix<DT, ST, n>& o);

  /// Outer product of two vectors
  template <typename DT1, typename ST1, typename DT2, typename ST2>
  __device__ __host__ void outer(const StaticVector<DT1, ST1, n>& a,
                                 const StaticVector<DT2, ST2, m>& b);

  /// Matrix-matrix product: this matrix is zeroed, then a * b is accumulated
  /// into it entry by entry
  template <typename DT1, typename ST1, typename DT2, typename ST2, UInt l>
  __device__ __host__ void mul(const StaticMatrix<DT1, ST1, n, l>& a,
                               const StaticMatrix<DT2, ST2, l, m>& b) {
    (*this) = T(0);
    for (UInt row = 0; row < n; ++row) {
      for (UInt col = 0; col < m; ++col) {
        for (UInt k = 0; k < l; ++k) {
          (*this)(row, col) += a(row, k) * b(k, col);
        }
      }
    }
  }

  /// Sum of the diagonal entries (square matrices only)
  __device__ __host__ std::enable_if_t<n == m, T_bare> trace() const {
    T_bare diag_sum{0};
    for (UInt d = 0; d < n; ++d) {
      diag_sum += (*this)(d, d);
    }
    return diag_sum;
  }

  /// Store `mat` with trace(mat) / factor subtracted from its diagonal
  /// (square matrices only); `factor` defaults to the matrix size
  template <typename DT1, typename ST1>
  __device__ __host__ std::enable_if_t<n == m>
  deviatoric(const StaticMatrix<DT1, ST1, n, m>& mat, Real factor = n) {
    auto norm_trace = mat.trace() / factor;
    for (UInt row = 0; row < n; ++row) {
      for (UInt col = 0; col < m; ++col) {
        // (row == col) is 1 on the diagonal and 0 elsewhere
        (*this)(row, col) = mat(row, col) - (row == col) * norm_trace;
      }
    }
  }
};
/* -------------------------------------------------------------------------- */
/// Vector class with size determined at compile-time
template <typename DataType, typename SupportType, UInt n>
class StaticVector : public StaticTensor<DataType, SupportType, n> {
  using tensor_type = StaticTensor<DataType, SupportType, n>;
  using T = std::remove_cv_t<DataType>;

public:
  using tensor_type::operator=;

  /// Matrix-vector product: this vector is zeroed, then mat * vec (or, when
  /// `transpose` is true, the product using mat(j, i) in place of mat(i, j))
  /// is accumulated into it
  // NOTE(review): with transpose == true, mat(j, i) is read for j < m and
  // i < n on an n x m matrix; looks valid only when n == m - confirm.
  template <bool transpose, typename DT1, typename ST1, typename DT2,
            typename ST2, UInt m>
  __device__ __host__ void mul(const StaticMatrix<DT1, ST1, n, m>& mat,
                               const StaticVector<DT2, ST2, m>& vec) {
    *this = T(0);
    for (UInt row = 0; row < n; ++row) {
      for (UInt col = 0; col < m; ++col) {
        // can be optimized
        (*this)(row) +=
            ((transpose) ? mat(col, row) : mat(row, col)) * vec(col);
      }
    }
  }
};
/* -------------------------------------------------------------------------- */
/// Symmetric matrix in Voigt notation: the n diagonal entries come first,
/// followed by the scaled off-diagonal entries
template <typename DataType, typename SupportType, UInt n>
class StaticSymMatrix
    : public StaticVector<DataType, SupportType, voigt_size<n>::value> {
  using parent = StaticVector<DataType, SupportType, voigt_size<n>::value>;
  using T = std::remove_cv_t<DataType>;

private:
  /// Apply `op(component, value)` to each stored component, where `value` is
  /// the matching entry of the symmetrized matrix `m`
  template <typename DT, typename ST, typename BinOp>
  __device__ __host__ void sym_binary(const StaticMatrix<DT, ST, n, n>& m,
                                      BinOp&& op) {
    // diagonal entries are taken as they are
    for (UInt d = 0; d < n; ++d) {
      op((*this)(d), m(d, d));
    }

    // off-diagonal entries: sqrt(2) / 2 times the sum of the two mirrored
    // entries, visited from the last column to the first
    const auto a = 0.5 * std::sqrt(2);
    UInt slot = n;
    for (UInt col = n - 1; col > 0; --col) {
      for (int row = col - 1; row >= 0; --row) {
        op((*this)(slot), a * (m(row, col) + m(col, row)));
        ++slot;
      }
    }
  }

public:
  /// Copy values from matrix and symmetrize
  template <typename DT, typename ST>
  __device__ __host__ void symmetrize(const StaticMatrix<DT, ST, n, n>& m) {
    sym_binary(m, [](auto&& v, auto&& w) { v = w; });
  }

  /// Add values from symmetrized matrix
  template <typename DT, typename ST>
  __device__ __host__ void operator+=(const StaticMatrix<DT, ST, n, n>& m) {
    sym_binary(m, [](auto&& v, auto&& w) { v += w; });
  }

  /// Trace: sum of the first n (diagonal) components
  __device__ __host__ auto trace() const {
    T diag_sum = 0;
    for (UInt d = 0; d < n; ++d) {
      diag_sum += (*this)(d);
    }
    return diag_sum;
  }

  /// Store `m` with trace(m) / factor subtracted from its diagonal
  /// components; `factor` defaults to the matrix size
  template <typename DT, typename ST>
  __device__ __host__ void deviatoric(const StaticSymMatrix<DT, ST, n>& m,
                                      Real factor = n) {
    auto tr = m.trace() / factor;
    // diagonal components
    for (UInt d = 0; d < n; ++d) {
      (*this)(d) = m(d) - tr;
    }
    // off-diagonal components are copied unchanged
    for (UInt k = n; k < voigt_size<n>::value; ++k) {
      (*this)(k) = m(k);
    }
  }

  using parent::operator+=;
  using parent::operator=;
};
/* -------------------------------------------------------------------------- */
// Implementation of initialization from symmetric matrix: expands the packed
// components of `o` into a full (square) matrix
template <typename DataType, typename SupportType, UInt n, UInt m>
template <typename DT, typename ST>
__device__ __host__ std::enable_if_t<n == m>
StaticMatrix<DataType, SupportType, n, m>::fromSymmetric(
    const StaticSymMatrix<DT, ST, n>& o) {
  // the first n components are the diagonal
  for (UInt d = 0; d < n; ++d) {
    (*this)(d, d) = o(d);
  }

  // We use Mandel notation for the vector representation: the stored
  // off-diagonal components are divided by sqrt(2) and written to both
  // mirrored entries
  const auto a = 1. / std::sqrt(2);
  UInt slot = n;
  for (UInt col = n - 1; col > 0; --col) {
    for (int row = col - 1; row >= 0; --row) {
      (*this)(col, row) = a * o(slot);
      (*this)(row, col) = (*this)(col, row);
      ++slot;
    }
  }
}
// Implementation of outer product: entry (i, j) receives a(i) * b(j)
template <typename DataType, typename SupportType, UInt n, UInt m>
template <typename DT1, typename ST1, typename DT2, typename ST2>
__device__ __host__ void StaticMatrix<DataType, SupportType, n, m>::outer(
    const StaticVector<DT1, ST1, n>& a, const StaticVector<DT2, ST2, m>& b) {
  for (UInt row = 0; row < n; ++row) {
    for (UInt col = 0; col < m; ++col) {
      (*this)(row, col) = a(row) * b(col);
    }
  }
}
/* -------------------------------------------------------------------------- */
/* On the stack static types */
/* -------------------------------------------------------------------------- */
/// Number of elements to store for a static type with dimensions `dims...`:
/// by default, the product of the dimensions
template <template <typename, typename, UInt...> class StaticParent,
          UInt... dims>
struct static_size_helper : product<dims...> {};

/// Symmetric matrices only store voigt_size<n>::value components
template <UInt n>
struct static_size_helper<StaticSymMatrix, n> : voigt_size<n> {};
/**
 * @brief Static type holding its own elements
 *
 * Derives from `StaticParent` with a C array of
 * static_size_helper<StaticParent, dims...>::value elements as support type,
 * so that the elements are part of the object itself.
 */
template <template <typename, typename, UInt...> class StaticParent,
          typename T, UInt... dims>
class Tensor
    : public StaticParent<
          T, T[static_size_helper<StaticParent, dims...>::value], dims...> {
  static constexpr UInt size = static_size_helper<StaticParent, dims...>::value;
  using parent = StaticParent<T, T[size], dims...>;

public:
  using parent::operator=;
  using parent::copy;

  /// Default constructor (does not set the elements to any value)
  __device__ __host__ Tensor() = default;

  /// Construct with default value
  __device__ __host__ Tensor(T val) { *this = val; }

  /// Construct from array
  __device__ __host__ Tensor(const std::array<T, size>& arr) {
    // we use size to ensure static loop unrolling
    for (UInt i = 0; i < size; ++i)
      this->_mem[i] = arr[i];
  }

  /// Copy from array
  __device__ __host__ Tensor& operator=(const std::array<T, size>& arr) {
    // we use size to ensure static loop unrolling
    for (UInt i = 0; i < size; ++i)
      (*this)(i) = arr[i];
    // the function is declared to return a reference: flowing off its end
    // without a return statement is undefined behavior
    return *this;
  }

  /// Construct by copy from static tensor
  template <typename DT, typename ST>
  __device__ __host__ Tensor(const StaticParent<DT, ST, dims...>& o) {
    this->copy(o);
  }
};
/// Matrix (rows x cols) holding its own elements
template <typename T, UInt rows, UInt cols>
using Matrix = Tensor<StaticMatrix, T, rows, cols>;

/// Symmetric matrix of size n holding its own components
template <typename T, UInt n>
using SymMatrix = Tensor<StaticSymMatrix, T, n>;

/// Vector of size n holding its own elements
template <typename T, UInt n>
using Vector = Tensor<StaticVector, T, n>;
/* -------------------------------------------------------------------------- */
/* Proxy Static Types */
/* -------------------------------------------------------------------------- */
/// Proxy type for tensor: a static type whose support is a pointer, i.e. it
/// accesses elements stored elsewhere
template <template <typename, typename, UInt...> class StaticParent,
          typename T, UInt... dims>
class TensorProxy : public StaticParent<T, T*, dims...> {
  using parent = StaticParent<T, T*, dims...>;

public:
  /// Matching type that holds its own elements
  using stack_type = Tensor<StaticParent, T, dims...>;

  using parent::operator=;

public:
  /// Explicit construction from data location
  __device__ __host__ explicit TensorProxy(T* spot) { parent::_mem = spot; }

  /// Explicit construction from lvalue-reference (uses the address of `spot`)
  __device__ __host__ explicit TensorProxy(T& spot) : TensorProxy(&spot) {}

  /// Construction from static tensor (points to the first element of `o`)
  template <typename DataType, typename SupportType>
  __device__ __host__
  TensorProxy(StaticParent<DataType, SupportType, dims...>& o)
      : TensorProxy(o.begin()) {}
};
/// Proxy for a rows x cols matrix
template <typename T, UInt rows, UInt cols>
using MatrixProxy = TensorProxy<StaticMatrix, T, rows, cols>;

/// Proxy for a symmetric matrix of size n
template <typename T, UInt n>
using SymMatrixProxy = TensorProxy<StaticSymMatrix, T, n>;

/// Proxy for a vector of size n
template <typename T, UInt n>
using VectorProxy = TensorProxy<StaticVector, T, n>;
/* -------------------------------------------------------------------------- */
/* -------------------------------------------------------------------------- */
/* Arithmetic operators creating temporaries */
/* -------------------------------------------------------------------------- */
/* -------------------------------------------------------------------------- */
/* -------------------------------------------------------------------------- */
/* Simple operators */
/* -------------------------------------------------------------------------- */
/// Element-wise sum of two vectors, returned as a new Vector whose element
/// type is the type of `DT1(0) + DT2(0)`
template <typename DT1, typename ST1, typename DT2, typename ST2, UInt dim>
__device__ __host__ Vector<decltype(DT1(0) + DT2(0)), dim>
operator+(const StaticVector<DT1, ST1, dim>& a,
          const StaticVector<DT2, ST2, dim>& b) {
  using result_type = Vector<decltype(DT1(0) + DT2(0)), dim>;
  result_type total(a); // copy of the left operand
  total += b;
  return total;
}
/// Element-wise difference of two vectors, returned as a new Vector whose
/// element type is the type of `DT1(0) - DT2(0)`
template <typename DT1, typename ST1, typename DT2, typename ST2, UInt dim>
__device__ __host__ Vector<decltype(DT1(0) - DT2(0)), dim>
operator-(const StaticVector<DT1, ST1, dim>& a,
          const StaticVector<DT2, ST2, dim>& b) {
  using result_type = Vector<decltype(DT1(0) - DT2(0)), dim>;
  result_type diff(a); // copy of the left operand
  diff -= b;
  return diff;
}
/// Opposite of a vector: a copy of `a` multiplied by -1
template <typename DT1, typename ST1, UInt dim>
__device__ __host__ Vector<decltype(DT1(0)), dim>
operator-(const StaticVector<DT1, ST1, dim>& a) {
  using result_type = Vector<decltype(DT1(0)), dim>;
  result_type negated(a);
  negated *= -1;
  return negated;
}
/// Element-wise sum of two matrices, returned as a new Matrix whose element
/// type is the type of `DT1(0) + DT2(0)`
template <typename DT1, typename ST1, typename DT2, typename ST2, UInt n,
          UInt m>
__device__ __host__ Matrix<decltype(DT1(0) + DT2(0)), n, m>
operator+(const StaticMatrix<DT1, ST1, n, m>& a,
          const StaticMatrix<DT2, ST2, n, m>& b) {
  using result_type = Matrix<decltype(DT1(0) + DT2(0)), n, m>;
  result_type total(a); // copy of the left operand
  total += b;
  return total;
}
/// Element-wise difference of two matrices, returned as a new Matrix whose
/// element type is the type of `DT1(0) - DT2(0)`
template <typename DT1, typename ST1, typename DT2, typename ST2, UInt n,
          UInt m>
__device__ __host__ Matrix<decltype(DT1(0) - DT2(0)), n, m>
operator-(const StaticMatrix<DT1, ST1, n, m>& a,
          const StaticMatrix<DT2, ST2, n, m>& b) {
  using result_type = Matrix<decltype(DT1(0) - DT2(0)), n, m>;
  result_type diff(a); // copy of the left operand
  diff -= b;
  return diff;
}
/// Opposite of a matrix: a copy of `a` multiplied by -1
template <typename DT1, typename ST1, UInt n, UInt m>
__device__ __host__ Matrix<decltype(DT1(0)), n, m>
operator-(const StaticMatrix<DT1, ST1, n, m>& a) {
  using result_type = Matrix<decltype(DT1(0)), n, m>;
  result_type negated(a);
  negated *= -1;
  return negated;
}
/// Component-wise sum of two symmetric matrices, returned as a new SymMatrix
/// whose element type is the type of `DT1(0) + DT2(0)`
template <typename DT1, typename ST1, typename DT2, typename ST2, UInt n>
__device__ __host__ SymMatrix<decltype(DT1(0) + DT2(0)), n>
operator+(const StaticSymMatrix<DT1, ST1, n>& a,
          const StaticSymMatrix<DT2, ST2, n>& b) {
  using result_type = SymMatrix<decltype(DT1(0) + DT2(0)), n>;
  result_type total(a); // copy of the left operand
  total += b;
  return total;
}
/// Component-wise difference of two symmetric matrices, returned as a new
/// SymMatrix whose element type is the type of `DT1(0) - DT2(0)`
template <typename DT1, typename ST1, typename DT2, typename ST2, UInt n>
__device__ __host__ SymMatrix<decltype(DT1(0) - DT2(0)), n>
operator-(const StaticSymMatrix<DT1, ST1, n>& a,
          const StaticSymMatrix<DT2, ST2, n>& b) {
  using result_type = SymMatrix<decltype(DT1(0) - DT2(0)), n>;
  result_type diff(a); // copy of the left operand
  diff -= b;
  return diff;
}
/// Opposite of a symmetric matrix: a copy of `a` multiplied by -1
template <typename DT1, typename ST1, UInt dim>
__device__ __host__ SymMatrix<decltype(DT1(0)), dim>
operator-(const StaticSymMatrix<DT1, ST1, dim>& a) {
  using result_type = SymMatrix<decltype(DT1(0)), dim>;
  result_type negated(a);
  negated *= -1;
  return negated;
}
template
<
typename
DT
,
typename
ST
,
typename
T
,
UInt
n
,
typename
=
std
::
enable_if_t
<
is_arithmetic
<
T
>::
value
>>
Vector
<
decltype
(
DT
(
0
)
*
T
(
0
)),
n
>
operator
*
(
const
StaticVector
<
DT
,
ST
,
n
>&
a
,
const
T
&
b
)
{
Vector
<
decltype
(
DT
(
0
)
*
T
(
0
)),
n
>
res
{
a
};
res
*=
b
;
return
res
;
}
// symmetry
template
<
typename
DT
,
typename
ST
,
typename
T
,
UInt
n
,
typename
=
std
::
enable_if_t
<
is_arithmetic
<
T
>::
value
>>
Vector
<
decltype
(
DT
(
0
)
*
T
(
0
)),
n
>
operator
*
(
const
T
&
b
,
const
StaticVector
<
DT
,
ST
,
n
>&
a
)
{
return
a
*
b
;
}
template
<
typename
DT
,
typename
ST
,
typename
T
,
UInt
n
,
UInt
m
,
typename
=
std
::
enable_if_t
<
is_arithmetic
<
T
>::
value
>>
Matrix
<
decltype
(
DT
(
0
)
*
T
(
0
)),
n
,
m
>
operator
*
(
const
StaticMatrix
<
DT
,
ST
,
n
,
m
>&
a
,
const
T
&
b
)
{
Matrix
<
decltype
(
DT
(
0
)
*
T
(
0
)),
n
,
m
>
res
{
a
};
res
*=
b
;
return
res
;
}
// symmetry
template
<
typename
DT
,
typename
ST
,
typename
T
,
UInt
n
,
UInt
m
,
typename
=
std
::
enable_if_t
<
is_arithmetic
<
T
>::
value
,
void
>>
Matrix
<
decltype
(
DT
(
0
)
*
T
(
0
)),
n
,
m
>
operator
*
(
const
T
&
b
,
const
StaticMatrix
<
DT
,
ST
,
n
,
m
>&
a
)
{
return
a
*
b
;
}
template
<
typename
DT
,
typename
ST
,
typename
T
,
UInt
n
,
typename
=
std
::
enable_if_t
<
is_arithmetic
<
T
>::
value
>>
SymMatrix
<
decltype
(
DT
(
0
)
*
T
(
0
)),
n
>
operator
*
(
const
StaticSymMatrix
<
DT
,
ST
,
n
>&
a
,
const
T
&
b
)
{
SymMatrix
<
decltype
(
DT
(
0
)
*
T
(
0
)),
n
>
res
{
a
};
res
*=
b
;
return
res
;
}
// symmetry
template
<
typename
DT
,
typename
ST
,
typename
T
,
UInt
n
,
typename
=
std
::
enable_if_t
<
is_arithmetic
<
T
>::
value
>>
SymMatrix
<
decltype
(
DT
(
0
)
*
T
(
0
)),
n
>
operator
*
(
const
T
&
b
,
const
StaticSymMatrix
<
DT
,
ST
,
n
>&
a
)
{
return
a
*
b
;
}
/* -------------------------------------------------------------------------- */
/* Linear algebra operators */
/* -------------------------------------------------------------------------- */
/// Matrix-vector multiplication: returns `a * b` (no transposition) as a new
/// Vector whose element type is the type of `DT1(0) * DT2(0)`
template <typename DT1, typename ST1, typename DT2, typename ST2, UInt n,
          UInt m>
__device__ __host__ Vector<decltype(DT1(0) * DT2(0)), n>
operator*(const StaticMatrix<DT1, ST1, n, m>& a,
          const StaticVector<DT2, ST2, m>& b) {
  using result_type = Vector<decltype(DT1(0) * DT2(0)), n>;
  result_type image;
  image.template mul<false>(a, b); // mul() sets every element
  return image;
}
/// Matrix-matrix multiplication: returns `a * b` as a new Matrix whose element
/// type is the type of `DT1(0) * DT2(0)`
template <typename DT1, typename ST1, typename DT2, typename ST2, UInt n,
          UInt m, UInt l>
__device__ __host__ Matrix<decltype(DT1(0) * DT2(0)), n, m>
operator*(const StaticMatrix<DT1, ST1, n, l>& a,
          const StaticMatrix<DT2, ST2, l, m>& b) {
  using result_type = Matrix<decltype(DT1(0) * DT2(0)), n, m>;
  result_type prod;
  prod.mul(a, b); // mul() sets every entry
  return prod;
}
/// Outer product of two vectors, returned as a new n x m Matrix whose element
/// type is the type of `DT1(0) * DT2(0)`
template <typename DT1, typename ST1, typename DT2, typename ST2, UInt n,
          UInt m>
__device__ __host__ Matrix<decltype(DT1(0) * DT2(0)), n, m>
outer(const StaticVector<DT1, ST1, n>& a, const StaticVector<DT2, ST2, m>& b) {
  using result_type = Matrix<decltype(DT1(0) * DT2(0)), n, m>;
  result_type prod;
  prod.outer(a, b); // the member function sets every entry
  return prod;
}
/* -------------------------------------------------------------------------- */
/* Dense/Sparse */
/* -------------------------------------------------------------------------- */
/// Full n x n matrix built from a symmetric matrix (see
/// StaticMatrix::fromSymmetric)
template <typename DT, typename ST, UInt n>
__device__ __host__ Matrix<std::remove_cv_t<DT>, n, n>
dense(const StaticSymMatrix<DT, ST, n>& m) {
  using result_type = Matrix<std::remove_cv_t<DT>, n, n>;
  result_type full;
  full.fromSymmetric(m);
  return full;
}
/// Dense form of a vector: the vector itself, returned by value
template <typename DT, typename ST, UInt n>
__device__ __host__ StaticVector<DT, ST, n>
dense(const StaticVector<DT, ST, n>& v) {
  return v;
}
/// Symmetric matrix built from a full n x n matrix (see
/// StaticSymMatrix::symmetrize)
template <typename DT, typename ST, UInt n>
__device__ __host__ SymMatrix<std::remove_cv_t<DT>, n>
symmetrize(const StaticMatrix<DT, ST, n, n>& m) {
  using result_type = SymMatrix<std::remove_cv_t<DT>, n>;
  result_type sym;
  sym.symmetrize(m);
  return sym;
}
/* -------------------------------------------------------------------------- */
/// Invariants (I1, I2, I3) of a 3x3 symmetric matrix, computed from its
/// stored components: 0-2 are the diagonal, 3-5 the off-diagonal components,
/// whose products are halved (or divided by sqrt(2) for the triple product)
template <typename DT, typename ST>
__device__ __host__ Vector<std::remove_cv_t<DT>, 3>
invariants(const StaticSymMatrix<DT, ST, 3>& m) {
  // I1 = tr(A)
  const auto first = m.trace();

  // I2 = 1/2 * (tr(A)^2 - tr(A^2))
  const auto second = m(0) * m(1) + m(1) * m(2) + m(0) * m(2) -
                      m(3) * m(3) * 0.5 - m(4) * m(4) * 0.5 -
                      m(5) * m(5) * 0.5;

  // I3 = det(A)
  const auto third = m(0) * m(1) * m(2) +
                     m(5) * m(3) * m(4) / std::sqrt(2) -
                     m(4) * m(4) * m(1) * 0.5 - m(3) * m(3) * m(0) * 0.5 -
                     m(5) * m(5) * m(2) * 0.5;

  return {{first, second, third}};
}
/// Eigenvalues of a 3x3 symmetric matrix, sorted with thrust::sort.
///
/// They are the roots of a x^3 + b x^2 + c x + d = 0, with a = 1 and
/// b, c, d = -I1, I2, -I3 taken from invariants(m). The polynomial is reduced
/// to t^3 + p t + q = 0 (x = t - b / (3 a)) and solved with the trigonometric
/// formula.
///
/// Two cases where the formula alone gives NaN are handled explicitly:
/// - p >= 0 (p is exactly zero when the three roots coincide): the formula
///   divides by p and takes sqrt(-p / 3), so all roots are set to
///   -b / (3 a), which is what the formula tends to as p goes to zero;
/// - an acos argument slightly outside [-1, 1]: it is clamped to the interval.
template <typename DT, typename ST>
__device__ __host__ Vector<std::remove_cv_t<DT>, 3>
eigenvalues(const StaticSymMatrix<DT, ST, 3>& m) {
  constexpr UInt n = 3;
  Vector<std::remove_cv_t<DT>, n> eigenv;
  auto inv = invariants(m);

  Real a = 1, b = -inv(0), c = inv(1), d = -inv(2);
  auto p = (3 * a * c - b * b) / (3 * a * a);
  auto q = (2 * b * b * b - 9 * a * b * c + 27 * a * a * d) / (27 * a * a * a);

  if (p >= 0) {
    // Degenerate case: the trigonometric formula is not usable
    for (UInt k = 0; k < n; ++k)
      eigenv(k) = -b / (3. * a);
    return eigenv;
  }

  auto arg = 3. * q / (2. * p) * std::sqrt(-3. / p);
  // Keep the argument in the domain of acos (a NaN is left untouched)
  if (arg > 1.)
    arg = 1.;
  else if (arg < -1.)
    arg = -1.;

  // Loop-invariant parts of the formula
  const auto amplitude = 2. * std::sqrt(-p / 3.);
  const auto angle = 1. / 3. * std::acos(arg);

  for (UInt k = 0; k < n; ++k)
    eigenv(k) =
        amplitude * std::cos(angle - 2. * M_PI * k / 3.) - b / (3. * a);

  thrust::sort(eigenv.begin(), eigenv.end());
  return eigenv;
}
/* -------------------------------------------------------------------------- */
/* Type traits */
/* -------------------------------------------------------------------------- */
/// Trait telling whether a type is a proxy static type: false unless a
/// specialization says otherwise
template <class Type>
struct is_proxy : std::integral_constant<bool, false> {};
template
<
template
<
typename
,
typename
,
UInt
...
>
class
StaticParent
,
typename
T
,
UInt
...
dims
>
struct
is_proxy
<
TensorProxy
<
StaticParent
,
T
,
dims
...
>>
:
std
::
true_type
{};
template
<
typename
T
,
UInt
n
,
UInt
m
>
struct
is_proxy
<
MatrixProxy
<
T
,
n
,
m
>>
:
std
::
true_type
{};
template
<
typename
T
,
UInt
n
>
struct
is_proxy
<
SymMatrixProxy
<
T
,
n
>>
:
std
::
true_type
{};
template
<
typename
T
,
UInt
n
>
struct
is_proxy
<
VectorProxy
<
T
,
n
>>
:
std
::
true_type
{};
}
// namespace tamaas
#endif
// __STATIC_TYPES_HH__
Event Timeline
Log In to Comment