diff --git a/site_scons/detect.py b/site_scons/detect.py
index 6d756be..8c5b2cf 100644
--- a/site_scons/detect.py
+++ b/site_scons/detect.py
@@ -1,260 +1,262 @@
# -*- coding: utf-8 -*-
#
# Copyright (©) 2016-2023 EPFL (École Polytechnique Fédérale de Lausanne),
# Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
# Copyright (©) 2020-2023 Lucas Frérot
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import print_function
from SCons.Script import Configure, Dir
from SCons.Errors import StopError
# ------------------------------------------------------------------------------
def _get_path(env, ext, module_var):
    """Return a one-element list holding the ``ext`` directory of a module.

    The module root is read from ``env[module_var]``. When that variable is
    missing or is the empty string, the result is ``[""]``. Otherwise the
    root and its ``ext`` sub-directory must both exist and be directories;
    a ``RuntimeError`` is raised if either check fails.
    """
    def _is_directory(node):
        return node.exists() and node.isdir()

    # No root configured: hand back an empty path entry
    if module_var not in env or env[module_var] == "":
        return [""]

    base = Dir(env[module_var])
    if not _is_directory(base):
        raise RuntimeError("{} is set to a non-existing path '{}'".format(
            module_var, base))

    sub_dir = Dir(ext, base)
    if not _is_directory(sub_dir):
        raise RuntimeError("{} does not contain '{}' directory".format(
            module_var, ext))

    return [sub_dir]
# ------------------------------------------------------------------------------
def FindFFTW(env, components=None, precision='double', module_var='FFTW_ROOT'):
    """Locate FFTW3 and append the matching build variables to ``env``.

    ``components`` lists extra FFTW flavours to link (``threads``, ``omp``,
    ``mpi``) on top of the main library. The double precision libraries are
    always requested; ``precision`` set to ``'long double'`` or ``'float'``
    additionally requests the ``l`` or ``f`` suffixed libraries. Search
    paths are derived from ``env[module_var]``. Raises ``StopError`` for an
    unknown component or when a library/header check fails. Does nothing
    when ``env['should_configure']`` is false.
    """
    if not env.get('should_configure', True):
        return

    components = [] if components is None else components

    # Search paths; a missing lib64 directory is tolerated
    include_dirs = _get_path(env, 'include', module_var)
    library_dirs = _get_path(env, 'lib', module_var)
    try:
        library_dirs += _get_path(env, 'lib64', module_var)
    except RuntimeError:
        pass

    # RPATH shares the very same list object as LIBPATH
    fftw_vars = {
        'CPPPATH': include_dirs,
        'LIBPATH': library_dirs,
        'RPATH': library_dirs,
    }
    if 'threads' in components:
        fftw_vars['LIBS'] = ['pthread']

    # Library name templates; the placeholder receives the precision suffix
    templates = {
        'main': 'fftw3{}',
        'threads': 'fftw3{}_threads',
        'omp': 'fftw3{}_omp',
        'mpi': 'fftw3{}_mpi',
    }
    wishes = ['main'] + components

    def _library_names(suffix):
        return [templates[wish].format(suffix) for wish in wishes]

    try:
        libs = _library_names("")
    except KeyError:
        raise StopError(
            'Incompatible FFTW wishlist {0}'.format(wishes))

    # Extra precision libraries come after the double precision ones
    suffix = {'long double': 'l', 'float': 'f'}.get(precision)
    if suffix is not None:
        libs += _library_names(suffix)

    conf = Configure(env.Clone(**fftw_vars))
    for lib in libs:
        inc_names = ['fftw3.h']
        if 'mpi' in lib:
            inc_names.append('fftw3-mpi.h')
        if conf.CheckLibWithHeader(lib, inc_names, 'c++'):
            continue
        raise StopError(
            ('Failed to find library {0} or '
             'headers {1}. Check the build options "fftw_threads"'
             ' and "FFTW_ROOT".').format(lib, str(inc_names)))
    conf.Finish()

    # Checked libraries go after any LIBS entry set above
    fftw_vars['LIBS'] = fftw_vars.get('LIBS', []) + libs
    env.AppendUnique(**fftw_vars)
# ------------------------------------------------------------------------------
def FindBoost(env, headers=None, module_var='BOOST_ROOT'):
    """Find Boost and set relevant environment variables.

    ``headers`` is the header list handed to ``CheckCXXHeader``; it defaults
    to ``['boost/version.hpp']``. The include directory is derived from
    ``env[module_var]``. Raises ``StopError`` when the header check fails.
    Does nothing when ``env['should_configure']`` is false.
    """
    if not env.get('should_configure', True):
        return

    # A None sentinel avoids sharing one mutable default list across calls
    # (same convention as the ``components`` argument of FindFFTW)
    if headers is None:
        headers = ['boost/version.hpp']

    boost_vars = {}
    boost_vars['CPPPATH'] = _get_path(env, 'include', module_var)

    conf_env = env.Clone(**boost_vars)
    conf = Configure(conf_env)
    if not conf.CheckCXXHeader(headers):
        raise StopError(
            'Failed to find Boost headers {}'.format(headers))
    conf_env = conf.Finish()

    # Update modified variables
    env.AppendUnique(**boost_vars)
# ------------------------------------------------------------------------------
def FindThrust(env, backend='omp', module_var='THRUST_ROOT'):
    """Locate the Thrust headers and register the matching build variables.

    ``backend`` must be one of ``cpp``, ``omp``, ``cuda`` or ``tbb``; it is
    upper-cased into the ``THRUST_DEVICE_SYSTEM`` preprocessor definition.
    Headers are searched in the ``include`` directory under
    ``env[module_var]``, falling back to the root itself. Raises
    ``StopError`` for an unknown backend or when ``thrust/version.h`` fails
    the header check. Does nothing when ``env['should_configure']`` is
    false.
    """
    if not env.get('should_configure', True):
        return

    known_backends = ('cpp', 'omp', 'cuda', 'tbb')
    if backend not in known_backends:
        raise StopError(
            'Unknown thrust backend "{}"'.format(backend))

    try:
        include_dirs = _get_path(env, 'include', module_var)
    except RuntimeError:
        # No include/ sub-directory: use the root directory itself
        include_dirs = _get_path(env, '', module_var)

    thrust_vars = {'CPPPATH': include_dirs}
    if "clang++" in env['CXX']:
        thrust_vars['CXXFLAGS'] = ["-Wno-unused-local-typedef"]

    defines = [
        "THRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_{}".format(backend.upper())
    ]
    if backend == 'cuda':
        defines.append("TAMAAS_USE_CUDA")
    thrust_vars['CPPDEFINES'] = defines

    conf = Configure(env.Clone(**thrust_vars))
    if not conf.CheckCXXHeader('thrust/version.h'):
        raise StopError(
            'Failed to find Thrust')
    conf.Finish()

    # Propagate the validated variables to the caller's environment
    env.AppendUnique(**thrust_vars)
# ------------------------------------------------------------------------------
def FindPybind11(env, module_var='PYBIND11_ROOT'):
    """Detect pybind11 and set the appropriate build variables.

    Returns immediately when ``env['should_configure']`` is false or when
    ``env['PYBIND11_FOUND']`` is already true. Otherwise the include flags
    reported by ``${py_exec}-config --includes`` and the ``include``
    directory under ``env[module_var]`` are collected on a clone of ``env``
    and ``pybind11/pybind11.h`` is checked against them. Raises
    ``StopError`` when the header check fails. On success ``CPPPATH`` and
    ``PYBIND11_FOUND`` are appended to ``env``.
    """
    if not env.get('should_configure', True):
        return
    if env.get("PYBIND11_FOUND", False):
        return

    # Start from an empty CPPPATH so only pybind11-related paths are kept
    probe = env.Clone(CPPPATH=[])
    probe.ParseConfig('${py_exec}-config --includes')
    probe.AppendUnique(CPPPATH=[_get_path(env, 'include', module_var)])

    pybind11_vars = {'CPPPATH': probe['CPPPATH']}

    conf = Configure(probe)
    header_found = conf.CheckCXXHeader('pybind11/pybind11.h')
    if not header_found:
        raise StopError(
            "Failed to find pybind11 header.\n"
            "Install pybind11 headers or "
            "set PYBIND11_ROOT='#third-party/pybind11' "
            "to use vendored dependency."
        )
    conf.Finish()

    pybind11_vars['PYBIND11_FOUND'] = True

    # Publish the collected variables
    env.AppendUnique(**pybind11_vars)
# ------------------------------------------------------------------------------
def FindCuda(env, module_var="CUDA_ROOT"):
    """Detect cuda on clusters.

    Sets ``CUDA_TOOLKIT_PATH`` from ``env[module_var]`` when that variable
    is present (falling back to ``/opt/cuda`` otherwise), selects the
    ``cufft`` component and the ``sm_60`` architecture flag, adds the nvcc
    specific compiler flags and loads the ``nvcc`` tool. Does nothing when
    ``env['should_configure']`` is false.
    """
    if not env.get('should_configure', True):
        return

    # Test the variable that is actually read below: the former hard-coded
    # 'CUDA_ROOT' check ignored a caller-supplied module_var
    if module_var in env:
        env['CUDA_TOOLKIT_PATH'] = _get_path(env, '', module_var)[0]
    else:
        env['CUDA_TOOLKIT_PATH'] = '/opt/cuda'

    env['CUDA_COMPONENTS'] = ['cufft']
    env['CUDA_ARCH_FLAG'] = '-arch=sm_60'

    colors = env['COLOR_DICT']
    # NOTE(review): the extent of this conditional is reconstructed from a
    # listing without indentation; confirm that all three assignments
    # belong to the non-verbose branch.
    if not env['verbose']:
        env['SHCXX'] = 'nvcc'
        env['NVCCCOMSTR'] = env['SHCXXCOMSTR']
        env['SHLINKCOMSTR'] = \
            u'{0}[Linking (cuda)] {1}$TARGET'.format(colors['purple'],
                                                     colors['end'])

    env.AppendUnique(CXXFLAGS=[
        "-expt-extended-lambda",  # experimental lambda support
        "-expt-relaxed-constexpr",  # experimental constexpr
    ])

    if env['build_type'] == 'debug':
        env.AppendUnique(CXXFLAGS="-G")

    env.Tool('nvcc')
# ------------------------------------------------------------------------------
def FindGTest(env):
    """Record in ``env['has_gtest']`` whether the GoogleTest header is usable.

    The result of the ``gtest/gtest.h`` header check is stored as is; no
    error is raised when the header is missing. Does nothing when
    ``env['should_configure']`` is false.
    """
    if env.get('should_configure', True):
        checker = Configure(env)
        env['has_gtest'] = checker.CheckCXXHeader('gtest/gtest.h')
        checker.Finish()
# ------------------------------------------------------------------------------
def FindExpolit(env):
    """Check for the Expolit header and add its include path to ``env``.

    The header ``expolit/expolit`` is checked on a clone of ``env`` extended
    with the ``#/third-party/expolit/include`` path. Raises ``StopError``
    (with a hint about the git submodule) when the check fails. Does
    nothing when ``env['should_configure']`` is false.
    """
    if not env.get('should_configure', True):
        return

    expolit_vars = {"CPPPATH": "#/third-party/expolit/include"}

    # Probe on a clone so a failed check leaves the caller's env untouched
    probe = env.Clone()
    probe.AppendUnique(**expolit_vars)
    conf = Configure(probe)
    header_ok = conf.CheckCXXHeader('expolit/expolit')
    if not header_ok:
        raise StopError(
            'Failed to find Expolit header\n' +
            "Run 'git submodule update --init " +
            "third-party/expolit'")
    conf.Finish()

    env.AppendUnique(**expolit_vars)
diff --git a/src/core/fftw/interface_impl.hh b/src/core/fftw/interface_impl.hh
index 5498225..5df3c67 100644
--- a/src/core/fftw/interface_impl.hh
+++ b/src/core/fftw/interface_impl.hh
@@ -1,206 +1,293 @@
/*
* SPDX-License-Identifier: AGPL-3.0-or-later
*
* Copyright (©) 2016-2023 EPFL (École Polytechnique Fédérale de Lausanne),
* Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
* Copyright (©) 2020-2023 Lucas Frérot
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
/* -------------------------------------------------------------------------- */
#ifndef FFTW_INTERFACE_IMPL_HH
#define FFTW_INTERFACE_IMPL_HH
/* -------------------------------------------------------------------------- */
#include
#include
#include
#include
#include
namespace fftw_impl {
-template
-struct helper;
-
-template <>
-struct helper {
- using complex = fftw_complex;
- using plan = fftw_plan;
-
- static auto alloc_real(std::size_t size) { return fftw_alloc_real(size); }
- static auto alloc_complex(std::size_t size) {
- return fftw_alloc_complex(size);
- }
-};
-template <>
-struct helper {
- using complex = fftwl_complex;
- using plan = fftwl_plan;
-
- static auto alloc_real(std::size_t size) { return fftwl_alloc_real(size); }
- static auto alloc_complex(std::size_t size) {
- return fftwl_alloc_complex(size);
- }
-};
+/* -------------------------------------------------------------------------- */
+/* Precision independent functions */
+/* -------------------------------------------------------------------------- */
/// Release memory through fftw_free, whatever the pointee type.
/// The template parameter list was missing in the listing; `typename T` is
/// restored from the `T* ptr` parameter.
template <typename T>
inline auto free(T* ptr) {
  fftw_free(ptr);
}
-inline auto destroy(fftw_plan plan) { fftw_destroy_plan(plan); }
-inline auto destroy(fftwl_plan plan) { fftwl_destroy_plan(plan); }
+/// Init FFTW with threads
inline auto init_threads() { return fftw_init_threads(); }
+
+/// Set number of threads
inline auto plan_with_nthreads(int nthreads) {
return fftw_plan_with_nthreads(nthreads);
}
-inline auto cleanup_threads() { return fftw_cleanup_threads(); }
-
-/// Holder type for fftw plans
-template
-struct plan {
- typename helper::plan _plan;
-
- /// Create from plan
- explicit plan(typename helper::plan _plan = nullptr) : _plan(_plan) {}
- /// Move constructor to avoid accidental plan destruction
- plan(plan&& o) noexcept : _plan(std::exchange(o._plan, nullptr)) {}
- /// Move operator
- plan& operator=(plan&& o) noexcept {
- _plan = std::exchange(o._plan, nullptr);
- return *this;
- }
- /// Destroy plan
- ~plan() noexcept {
- if (_plan)
- fftw_impl::destroy(_plan);
- }
-
- /// For seamless use with fftw api
- operator typename helper::plan() const { return _plan; }
-};
-
-/// RAII helper for fftw_free
-template
-struct ptr {
- T* _ptr;
- ~ptr() noexcept {
- if (_ptr)
- fftw_impl::free(_ptr);
- }
-
- operator T*() { return _ptr; }
-};
+/// Cleanup threads
+inline auto cleanup_threads() { return fftw_cleanup_threads(); }
+/* -------------------------------------------------------------------------- */
+/* double precision API */
/* -------------------------------------------------------------------------- */
/// Plan a batch of real-to-complex transforms (fftw_plan_many_dft_r2c)
inline auto plan_many_forward(int rank, const int* n, int howmany, double* in,
                              const int* inembed, int istride, int idist,
                              fftw_complex* out, const int* onembed,
                              int ostride, int odist, unsigned flags) {
  return fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
                                out, onembed, ostride, odist, flags);
}

/// Plan a batch of complex-to-real transforms (fftw_plan_many_dft_c2r)
inline auto plan_many_backward(int rank, const int* n, int howmany,
                               fftw_complex* in, const int* inembed,
                               int istride, int idist, double* out,
                               const int* onembed, int ostride, int odist,
                               unsigned flags) {
  return fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
                                out, onembed, ostride, odist, flags);
}

/// Plan a 1D real-to-complex transform (fftw_plan_dft_r2c_1d)
inline auto plan_1d_forward(int n, double* in, fftw_complex* out,
                            unsigned flags) {
  return fftw_plan_dft_r2c_1d(n, in, out, flags);
}

/// Plan a 1D complex-to-real transform (fftw_plan_dft_c2r_1d)
inline auto plan_1d_backward(int n, fftw_complex* in, double* out,
                             unsigned flags) {
  return fftw_plan_dft_c2r_1d(n, in, out, flags);
}

/// Plan a 2D real-to-complex transform (fftw_plan_dft_r2c_2d)
inline auto plan_2d_forward(int n0, int n1, double* in, fftw_complex* out,
                            unsigned flags) {
  return fftw_plan_dft_r2c_2d(n0, n1, in, out, flags);
}

/// Plan a 2D complex-to-real transform (fftw_plan_dft_c2r_2d).
/// The complex input is now named `in` and the real output `out`, in line
/// with plan_1d_backward; argument order and types are unchanged.
inline auto plan_2d_backward(int n0, int n1, fftw_complex* in, double* out,
                             unsigned flags) {
  return fftw_plan_dft_c2r_2d(n0, n1, in, out, flags);
}

/// Execute a plan on the arrays it was created with
inline auto execute(fftw_plan plan) { fftw_execute(plan); }

/// Execute a real-to-complex plan on the given arrays
inline auto execute(fftw_plan plan, double* in, fftw_complex* out) {
  fftw_execute_dft_r2c(plan, in, out);
}

/// Execute a complex-to-real plan on the given arrays
inline auto execute(fftw_plan plan, fftw_complex* in, double* out) {
  fftw_execute_dft_c2r(plan, in, out);
}

/// Destroy a double precision plan
inline auto destroy(fftw_plan plan) { fftw_destroy_plan(plan); }
+
+/* -------------------------------------------------------------------------- */
+/* long double precision API */
/* -------------------------------------------------------------------------- */
/// Plan a batch of real-to-complex transforms (fftwl_plan_many_dft_r2c)
inline auto plan_many_forward(int rank, const int* n, int howmany,
                              long double* in, const int* inembed, int istride,
                              int idist, fftwl_complex* out, const int* onembed,
                              int ostride, int odist, unsigned flags) {
  return fftwl_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
                                 out, onembed, ostride, odist, flags);
}

/// Plan a batch of complex-to-real transforms (fftwl_plan_many_dft_c2r)
inline auto plan_many_backward(int rank, const int* n, int howmany,
                               fftwl_complex* in, const int* inembed,
                               int istride, int idist, long double* out,
                               const int* onembed, int ostride, int odist,
                               unsigned flags) {
  return fftwl_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
                                 out, onembed, ostride, odist, flags);
}

/// Plan a 1D real-to-complex transform (fftwl_plan_dft_r2c_1d)
inline auto plan_1d_forward(int n, long double* in, fftwl_complex* out,
                            unsigned flags) {
  return fftwl_plan_dft_r2c_1d(n, in, out, flags);
}

/// Plan a 1D complex-to-real transform (fftwl_plan_dft_c2r_1d)
inline auto plan_1d_backward(int n, fftwl_complex* in, long double* out,
                             unsigned flags) {
  return fftwl_plan_dft_c2r_1d(n, in, out, flags);
}

/// Plan a 2D real-to-complex transform (fftwl_plan_dft_r2c_2d)
inline auto plan_2d_forward(int n0, int n1, long double* in, fftwl_complex* out,
                            unsigned flags) {
  return fftwl_plan_dft_r2c_2d(n0, n1, in, out, flags);
}

/// Plan a 2D complex-to-real transform (fftwl_plan_dft_c2r_2d).
/// The complex input is now named `in` and the real output `out`, in line
/// with plan_1d_backward; argument order and types are unchanged.
inline auto plan_2d_backward(int n0, int n1, fftwl_complex* in,
                             long double* out, unsigned flags) {
  return fftwl_plan_dft_c2r_2d(n0, n1, in, out, flags);
}

/// Execute a plan on the arrays it was created with
inline auto execute(fftwl_plan plan) { fftwl_execute(plan); }

/// Execute a real-to-complex plan on the given arrays
inline auto execute(fftwl_plan plan, long double* in, fftwl_complex* out) {
  fftwl_execute_dft_r2c(plan, in, out);
}

/// Execute a complex-to-real plan on the given arrays
inline auto execute(fftwl_plan plan, fftwl_complex* in, long double* out) {
  fftwl_execute_dft_c2r(plan, in, out);
}

/// Destroy a long double precision plan
inline auto destroy(fftwl_plan plan) { fftwl_destroy_plan(plan); }
+
+/* -------------------------------------------------------------------------- */
+/* single precision API */
+/* -------------------------------------------------------------------------- */
+
+inline auto plan_many_forward(int rank, const int* n, int howmany, float* in,
+ const int* inembed, int istride, int idist,
+ fftwf_complex* out, const int* onembed,
+ int ostride, int odist, unsigned flags) {
+ return fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
+ out, onembed, ostride, odist, flags);
+}
+
+inline auto plan_many_backward(int rank, const int* n, int howmany,
+ fftwf_complex* in, const int* inembed,
+ int istride, int idist, float* out,
+ const int* onembed, int ostride, int odist,
+ unsigned flags) {
+ return fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
+ out, onembed, ostride, odist, flags);
+}
+
+inline auto plan_1d_forward(int n, float* in, fftwf_complex* out,
+ unsigned flags) {
+ return fftwf_plan_dft_r2c_1d(n, in, out, flags);
+}
+
+inline auto plan_1d_backward(int n, fftwf_complex* in, float* out,
+ unsigned flags) {
+ return fftwf_plan_dft_c2r_1d(n, in, out, flags);
+}
+
+inline auto plan_2d_forward(int n0, int n1, float* in, fftwf_complex* out,
+ unsigned flags) {
+ return fftwf_plan_dft_r2c_2d(n0, n1, in, out, flags);
+}
+
+inline auto plan_2d_backward(int n0, int n1, fftwf_complex* out, float* in,
+ unsigned flags) {
+ return fftwf_plan_dft_c2r_2d(n0, n1, out, in, flags);
+}
+
+inline auto execute(fftwf_plan plan) { fftwf_execute(plan); }
+inline auto execute(fftwf_plan plan, float* in, fftwf_complex* out) {
+ fftwf_execute_dft_r2c(plan, in, out);
+}
+
+inline auto execute(fftwf_plan plan, fftwf_complex* in, float* out) {
+ fftwf_execute_dft_c2r(plan, in, out);
+}
+
+inline auto destroy(fftwf_plan plan) { fftwf_destroy_plan(plan); }
+
+/* -------------------------------------------------------------------------- */
+/* Helper objects */
+/* -------------------------------------------------------------------------- */
+
+/// Allocation helper for different float types
+template
+struct helper;
+
+template <>
+struct helper {
+ using complex = fftw_complex;
+ using plan = fftw_plan;
+
+ static auto alloc_real(std::size_t size) { return fftw_alloc_real(size); }
+ static auto alloc_complex(std::size_t size) {
+ return fftw_alloc_complex(size);
+ }
+};
+
+template <>
+struct helper {
+ using complex = fftwl_complex;
+ using plan = fftwl_plan;
+
+ static auto alloc_real(std::size_t size) { return fftwl_alloc_real(size); }
+ static auto alloc_complex(std::size_t size) {
+ return fftwl_alloc_complex(size);
+ }
+};
+
+template <>
+struct helper {
+ using complex = fftwf_complex;
+ using plan = fftwf_plan;
+
+ static auto alloc_real(std::size_t size) { return fftwf_alloc_real(size); }
+ static auto alloc_complex(std::size_t size) {
+ return fftwf_alloc_complex(size);
+ }
+};
+
+/* -------------------------------------------------------------------------- */
+
+/// Holder type for fftw plans
+template
+struct plan {
+ typename helper::plan _plan;
+
+ /// Create from plan
+ explicit plan(typename helper::plan _plan = nullptr) : _plan(_plan) {}
+ /// Move constructor to avoid accidental plan destruction
+ plan(plan&& o) noexcept : _plan(std::exchange(o._plan, nullptr)) {}
+ /// Move operator
+ plan& operator=(plan&& o) noexcept {
+ _plan = std::exchange(o._plan, nullptr);
+ return *this;
+ }
+ /// Destroy plan
+ ~plan() noexcept {
+ if (_plan)
+ fftw_impl::destroy(_plan);
+ }
+
+ /// For seamless use with fftw api
+ operator typename helper::plan() const { return _plan; }
+};
+
+/// RAII helper for fftw_free
+template
+struct ptr {
+ T* _ptr;
+
+ ~ptr() noexcept {
+ if (_ptr)
+ fftw_impl::free(_ptr);
+ }
+
+ operator T*() { return _ptr; }
+};
+
} // namespace fftw_impl
#endif // FFTW_INTERFACE_IMPL_HH
diff --git a/src/core/mpi_interface.cpp b/src/core/mpi_interface.cpp
index f602fb1..a0464c6 100644
--- a/src/core/mpi_interface.cpp
+++ b/src/core/mpi_interface.cpp
@@ -1,66 +1,68 @@
/*
* SPDX-License-Identifier: AGPL-3.0-or-later
*
* Copyright (©) 2016-2023 EPFL (École Polytechnique Fédérale de Lausanne),
* Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
* Copyright (©) 2020-2023 Lucas Frérot
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
/* -------------------------------------------------------------------------- */
#include "mpi_interface.hh"
/* -------------------------------------------------------------------------- */
// Out-of-line definitions for the static members declared in mpi_interface.hh
// NOTE(review): `<...>` argument lists (after type_trait, operation_trait and
// ::thrust::complex) appear to have been lost in this listing — confirm
// against the real file.
namespace tamaas {
namespace mpi_dummy {
// Storage for the static member `world` declared in mpi_dummy::comm
comm comm::world;
}
#ifdef TAMAAS_USE_MPI
namespace mpi_impl {
// Slot used by sequential::enter()/exit() to save the world communicator;
// starts out as MPI_COMM_NULL
comm sequential::old_comm{MPI_COMM_NULL};
// Returns a reference to a function-local static wrapper that is
// initialised with MPI_COMM_WORLD
comm& comm::world() {
static comm _world{MPI_COMM_WORLD};
return _world;
}
// Define type traits for MPI data types
#define TYPE(t, mpi_t) \
const MPI_Datatype type_trait::value { mpi_t }
+TYPE(float, MPI_FLOAT);
TYPE(double, MPI_DOUBLE);
TYPE(int, MPI_INT);
TYPE(unsigned int, MPI_UNSIGNED);
TYPE(long double, MPI_LONG_DOUBLE);
TYPE(long, MPI_LONG);
TYPE(unsigned long, MPI_UNSIGNED_LONG);
+TYPE(::thrust::complex, MPI_CXX_FLOAT_COMPLEX);
TYPE(::thrust::complex, MPI_CXX_DOUBLE_COMPLEX);
TYPE(::thrust::complex, MPI_CXX_LONG_DOUBLE_COMPLEX);
TYPE(bool, MPI_CXX_BOOL);
#undef TYPE
// Define type traits for MPI operations
#define OPERATION(op, mpi_op) \
const MPI_Op operation_trait::value { mpi_op }
OPERATION(plus, MPI_SUM);
OPERATION(min, MPI_MIN);
OPERATION(max, MPI_MAX);
OPERATION(times, MPI_PROD);
#undef OPERATION
} // namespace mpi_impl
#endif
} // namespace tamaas
diff --git a/src/core/mpi_interface.hh b/src/core/mpi_interface.hh
index e379f90..ab0b3fa 100644
--- a/src/core/mpi_interface.hh
+++ b/src/core/mpi_interface.hh
@@ -1,293 +1,295 @@
/*
* SPDX-License-Identifier: AGPL-3.0-or-later
*
* Copyright (©) 2016-2023 EPFL (École Polytechnique Fédérale de Lausanne),
* Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
* Copyright (©) 2020-2023 Lucas Frérot
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
/* -------------------------------------------------------------------------- */
#ifndef MPI_INTERFACE_HH
#define MPI_INTERFACE_HH
/* -------------------------------------------------------------------------- */
#include "static_types.hh"
#include "tamaas.hh"
#include
#include
#include
#ifdef TAMAAS_USE_MPI
#include
#endif
/* -------------------------------------------------------------------------- */
namespace tamaas {
/* -------------------------------------------------------------------------- */
/// Contains mock mpi functions
// Stand-ins for the MPI wrappers that behave like a single-process run.
// NOTE(review): template parameter/argument lists (after `template`,
// std::forward, std::vector) appear to have been lost in this listing —
// confirm against the real file.
namespace mpi_dummy {
// Placeholder communicator; `world` is defined in mpi_interface.cpp
struct comm {
static comm world;
};
// enter()/exit() are empty here
struct sequential {
static void enter(){};
static void exit(){};
};
// Calls sequential::enter() on construction and sequential::exit() on
// destruction
struct sequential_guard {
sequential_guard() { sequential::enter(); }
~sequential_guard() { sequential::exit(); }
};
enum class thread : int { single, funneled, serialized, multiple };
// Both queries always answer false
inline bool initialized() { return false; }
inline bool finalized() { return false; }
// Ignores its arguments and returns 0
inline int init(int* /*unused*/, char*** /*unused*/) { return 0; }
// Writes thread::funneled to *provided and returns 0
inline int init_thread(int* /*unused*/, char*** /*unused*/, thread /*unused*/,
thread* provided) {
*provided = thread::funneled;
return 0;
}
inline int finalize() { return 0; }
// Ignores the error code and calls std::abort
[[noreturn]] inline void abort(int) noexcept { std::abort(); }
// Defining error symbols used in code
#ifndef TAMAAS_USE_MPI
#define MPI_ERR_TOPOLOGY 0
#endif
// Always rank 0 in a communicator of size 1
inline int rank(comm /*unused*/ = comm::world) { return 0; }
inline int size(comm /*unused*/ = comm::world) { return 1; }
// reduce/allreduce forward the value unchanged
template
inline decltype(auto) reduce(T&& val, comm /*unused*/ = comm::world) {
return std::forward(val);
}
template
inline decltype(auto) allreduce(T&& val, comm /*unused*/ = comm::world) {
return std::forward(val);
}
// gather/scatter copy `count` elements from send to recv, unless both
// pointers are equal
template
inline decltype(auto) gather(const T* send, T* recv, int count,
comm /*unused*/ = comm::world) {
if (send == recv)
return;
thrust::copy_n(send, count, recv);
}
template
inline decltype(auto) scatter(const T* send, T* recv, int count,
comm /*unused*/ = comm::world) {
if (send == recv)
return;
thrust::copy_n(send, count, recv);
}
// Ignores the two vector arguments and delegates to scatter with recvcount
template
inline decltype(auto) scatterv(const T* send,
const std::vector& /*unused*/,
const std::vector& /*unused*/, T* recv,
int recvcount, comm /*unused*/ = comm::world) {
scatter(send, recv, recvcount);
}
// No-op
template
inline decltype(auto) bcast(T* /*unused*/, int /*unused*/,
comm /*unused*/ = comm::world) {}
} // namespace mpi_dummy
/* -------------------------------------------------------------------------- */
#ifdef TAMAAS_USE_MPI
/// Contains real mpi functions
// Thin wrappers around the MPI C API.
// NOTE(review): template parameter/argument lists (after `template`,
// type_trait, operation_trait, static_cast, reinterpret_cast, std::vector,
// StaticVector, StaticMatrix, ...) appear to have been lost in this listing —
// confirm against the real file.
namespace mpi_impl {
/// MPI_Comm wrapper
struct comm {
MPI_Comm _comm;
operator MPI_Comm() const { return _comm; }
// Defined in mpi_interface.cpp
static comm& world();
};
// enter() saves the world communicator handle in old_comm and replaces it
// with MPI_COMM_SELF; exit() restores the saved handle
struct sequential {
static void enter() {
sequential::old_comm._comm = comm::world()._comm;
comm::world()._comm = MPI_COMM_SELF;
}
static void exit() { comm::world()._comm = sequential::old_comm._comm; }
static comm old_comm;
};
// Calls sequential::enter() on construction and sequential::exit() on
// destruction
struct sequential_guard {
sequential_guard() { sequential::enter(); }
~sequential_guard() { sequential::exit(); }
};
/// MPI Thread level
enum class thread : int {
single = MPI_THREAD_SINGLE,
funneled = MPI_THREAD_FUNNELED,
serialized = MPI_THREAD_SERIALIZED,
multiple = MPI_THREAD_MULTIPLE
};
// Maps a C++ type to its MPI_Datatype; the `value` members are defined in
// mpi_interface.cpp
template
struct type_trait;
#define TYPE(t, mpi_t) \
template <> \
struct type_trait { \
static const MPI_Datatype value; \
}
+TYPE(float, MPI_FLOAT);
TYPE(double, MPI_DOUBLE);
TYPE(int, MPI_INT);
TYPE(unsigned int, MPI_UNSIGNED);
TYPE(long double, MPI_LONG_DOUBLE);
TYPE(long, MPI_LONG);
TYPE(unsigned long, MPI_UNSIGNED_LONG);
+TYPE(::thrust::complex, MPI_CXX_FLOAT_COMPLEX);
TYPE(::thrust::complex, MPI_CXX_DOUBLE_COMPLEX);
TYPE(::thrust::complex, MPI_CXX_LONG_DOUBLE_COMPLEX);
TYPE(bool, MPI_CXX_BOOL);
#undef TYPE
// Maps an operation tag to its MPI_Op; the `value` members are defined in
// mpi_interface.cpp
template
struct operation_trait;
#define OPERATION(op, mpi_op) \
template <> \
struct operation_trait { \
static const MPI_Op value; \
}
OPERATION(plus, MPI_SUM);
OPERATION(min, MPI_MIN);
OPERATION(max, MPI_MAX);
OPERATION(times, MPI_PROD);
#undef OPERATION
// True when MPI_Initialized reports a non-zero flag
inline bool initialized() {
int has_init;
MPI_Initialized(&has_init);
return has_init != 0;
}
// True when MPI_Finalized reports a non-zero flag
inline bool finalized() {
int has_final;
MPI_Finalized(&has_final);
return has_final != 0;
}
inline int init(int* argc, char*** argv) { return MPI_Init(argc, argv); }
// Passes the thread enum values through casts to MPI_Init_thread
inline int init_thread(int* argc, char*** argv, thread required,
thread* provided) {
return MPI_Init_thread(argc, argv, static_cast(required),
reinterpret_cast(provided));
}
inline int finalize() { return MPI_Finalize(); }
// NOTE(review): unlike mpi_dummy::abort this overload is not marked
// [[noreturn]] — confirm whether that is intended
inline void abort(int errcode) noexcept { MPI_Abort(comm::world(), errcode); }
inline int rank(comm communicator = comm::world()) {
int rank;
MPI_Comm_rank(communicator, &rank);
return rank;
}
inline int size(comm communicator = comm::world()) {
int size;
MPI_Comm_size(communicator, &size);
return size;
}
// Reduces one value onto rank 0 and returns the local `val`.
// NOTE(review): &val is passed as both send and receive buffer; the MPI
// standard does not allow aliased buffers and offers MPI_IN_PLACE for this —
// confirm and fix against the real file.
template
inline decltype(auto) reduce(T val, comm communicator = comm::world()) {
MPI_Reduce(&val, &val, 1, type_trait::value, operation_trait::value, 0,
communicator);
return val;
}
// All-reduce of one value.
// NOTE(review): same aliased send/receive buffer as in reduce above.
template ::value or
std::is_same::value>>
inline decltype(auto) allreduce(T val, comm communicator = comm::world()) {
MPI_Allreduce(&val, &val, 1, type_trait::value, operation_trait::value,
communicator);
return val;
}
// All-reduce of n entries read from v.begin() into a separate result object
template
inline decltype(auto) allreduce(const StaticVector& v,
comm communicator = comm::world()) {
Vector res;
MPI_Allreduce(v.begin(), res.begin(), n, type_trait::value,
operation_trait::value, communicator);
return res;
}
// All-reduce of n * m entries read from v.begin() into a separate result
// object
template
inline decltype(auto) allreduce(const StaticMatrix& v,
comm communicator = comm::world()) {
Matrix res;
MPI_Allreduce(v.begin(), res.begin(), n * m, type_trait::value,
operation_trait::value, communicator);
return res;
}
// gather/scatter/scatterv/bcast below all use rank 0 as root
template
inline decltype(auto) gather(const T* send, T* recv, int count,
comm communicator = comm::world()) {
MPI_Gather(send, count, type_trait::value, recv, count,
type_trait::value, 0, communicator);
}
template
inline decltype(auto) scatter(const T* send, T* recv, int count,
comm communicator = comm::world()) {
MPI_Scatter(send, count, type_trait::value, recv, count,
type_trait::value, 0, communicator);
}
template
inline decltype(auto)
scatterv(const T* send, const std::vector& sendcounts,
const std::vector& displs, T* recv, int recvcount,
comm communicator = comm::world()) {
MPI_Scatterv(send, sendcounts.data(), displs.data(), type_trait::value,
recv, recvcount, type_trait::value, 0, communicator);
}
template
inline decltype(auto) bcast(T* buffer, int count,
comm communicator = comm::world()) {
MPI_Bcast(buffer, count, type_trait::value, 0, communicator);
}
} // namespace mpi_impl
namespace mpi = mpi_impl;
#else
namespace mpi = mpi_dummy;
#endif // TAMAAS_USE_MPI
} // namespace tamaas
/* -------------------------------------------------------------------------- */
#endif // MPI_INTERFACE_HH
diff --git a/src/core/static_types.hh b/src/core/static_types.hh
index 01edd56..70d01c0 100644
--- a/src/core/static_types.hh
+++ b/src/core/static_types.hh
@@ -1,798 +1,798 @@
/*
* SPDX-License-Identifier: AGPL-3.0-or-later
*
* Copyright (©) 2016-2023 EPFL (École Polytechnique Fédérale de Lausanne),
* Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
* Copyright (©) 2020-2023 Lucas Frérot
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published
* by the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
/* -------------------------------------------------------------------------- */
#ifndef STATIC_TYPES_HH
#define STATIC_TYPES_HH
/* -------------------------------------------------------------------------- */
#include "tamaas.hh"
#include
#include
#include
namespace tamaas {
/* -------------------------------------------------------------------------- */
// Recursive building blocks for the `product` and `get` metafunctions.
// NOTE(review): the template parameter and argument lists appear to have
// been lost in this listing, so the recursion step and the terminating
// values cannot be read from here — confirm against the real file.
namespace detail {
/// \cond HIDDEN_SYMBOLS
// Recursive case: inherits from another product_tail_rec instantiation
template
struct product_tail_rec : product_tail_rec {};
/// \endcond
// Terminating case: an std::integral_constant
template
struct product_tail_rec : std::integral_constant {};
/// \cond HIDDEN_SYMBOLS
// Recursive case: inherits from another get_rec instantiation
template
struct get_rec : get_rec {};
/// \endcond
// Terminating case, selected when the first argument is 0
template
struct get_rec<0, n, ns...> : std::integral_constant {};
} // namespace detail
// Derives from detail::product_tail_rec seeded with 1 followed by the pack ns
// NOTE(review): the template parameter list is missing in this listing
template
struct product : detail::product_tail_rec<1, ns...> {};
// Derives from detail::get_rec
// NOTE(review): the template parameter and argument lists are missing in
// this listing
template
struct get : detail::get_rec {};
// Primary template defers to std::is_arithmetic; the partial specialization
// below is std::true_type
// NOTE(review): the template argument lists are missing in this listing, so
// the type the specialization covers cannot be read from here
template
struct is_arithmetic : std::is_arithmetic {};
template
struct is_arithmetic> : std::true_type {};
// Integral constant specialized for the arguments 3, 2 and 1
// NOTE(review): the std::integral_constant arguments are missing in this
// listing, so the three values cannot be read from here
template
struct voigt_size;
template <>
struct voigt_size<3> : std::integral_constant {};
template <>
struct voigt_size<2> : std::integral_constant {};
template <>
struct voigt_size<1> : std::integral_constant {};
/* -------------------------------------------------------------------------- */
/**
* @brief Static Array
*
* This class is meant to be a small and fast object for intermediate
* calculations, possibly on wrapped memory belonging to a grid. Support type
* show be either a pointer or a C array. It should not contain any virtual
* method.
*/
// NOTE(review): template parameter/argument lists (after `template`,
// std::is_array, std::is_pointer, std::remove_cv_t, StaticArray, enable_if_t,
// valid_size_t) appear to have been lost in this listing — confirm against
// the real file.
template
class StaticArray {
static_assert(std::is_array::value or
std::is_pointer::value,
"the support type of StaticArray should be either a pointer or "
"a C-array");
using T = DataType;
using T_bare = typename std::remove_cv_t;
public:
using value_type = T;
static constexpr UInt size = _size;
public:
// Copy and move construction and move assignment are deleted; copy
// assignment is defined further down and copies element by element
StaticArray() = default;
~StaticArray() = default;
StaticArray(const StaticArray&) = delete;
StaticArray(StaticArray&&) = delete;
StaticArray& operator=(StaticArray&&) = delete;
public:
/// Access operator
__device__ __host__ auto operator()(UInt i) -> T& { return _mem[i]; }
/// Access operator
__device__ __host__ auto operator()(UInt i) const -> const T& {
return _mem[i];
}
/// Scalar product
// Accumulates in the type of T_bare(0) * DT(0) and returns the sum as T_bare
template
__device__ __host__ auto dot(const StaticArray& o) const
-> T_bare {
decltype(T_bare(0) * DT(0)) res = 0;
for (UInt i = 0; i < size; ++i)
res += (*this)(i)*o(i);
return res;
}
/// L2 norm squared
__device__ __host__ T_bare l2squared() const { return this->dot(*this); }
/// L2 norm
__device__ __host__ T_bare l2norm() const { return std::sqrt(l2squared()); }
/// Sum of all elements
__device__ __host__ T_bare sum() const {
T_bare res = 0;
for (UInt i = 0; i < size; ++i)
res += (*this)(i);
return res;
}
// Element-wise compound assignment with another StaticArray
#define VECTOR_OP(op) \
template \
__device__ __host__ void operator op(const StaticArray& o) { \
for (UInt i = 0; i < size; ++i) \
(*this)(i) op o(i); \
}
VECTOR_OP(+=)
VECTOR_OP(-=)
VECTOR_OP(*=)
VECTOR_OP(/=)
#undef VECTOR_OP
// Applies the operator with the same scalar x to every element and returns
// *this; enabled through std::enable_if_t
#define SCALAR_OP(op) \
template \
__device__ __host__ std::enable_if_t::value, StaticArray&> \
operator op(const T1& x) { \
for (UInt i = 0; i < size; ++i) \
(*this)(i) op x; \
return *this; \
}
SCALAR_OP(+=)
SCALAR_OP(-=)
SCALAR_OP(*=)
SCALAR_OP(/=)
SCALAR_OP(=)
#undef SCALAR_OP
/// Overriding the implicit copy operator
__device__ __host__ StaticArray& operator=(const StaticArray& o) {
copy(o);
return *this;
}
// Assignment from another StaticArray; delegates to copy() and returns void
template
__device__ __host__ void operator=(const StaticArray& o) {
copy(o);
}
// Copies `size` elements from o into this array and returns *this
template
__device__ __host__ StaticArray& copy(const StaticArray& o) {
for (UInt i = 0; i < size; ++i)
(*this)(i) = o(i);
return *this;
}
// Iteration bounds: _mem and _mem + size
// NOTE(review): unlike front()/back() these four are not marked
// __host__ __device__ — confirm whether that is intended
T* begin() { return _mem; }
const T* begin() const { return _mem; }
T* end() { return _mem + size; }
const T* end() const { return _mem + size; }
private:
// Resolves to U only when size > 0
template
using valid_size_t = std::enable_if_t<(size > 0), U>;
public:
// First and last element, at _mem[0] and _mem[size - 1]
__host__ __device__ valid_size_t front() { return *_mem; }
__host__ __device__ valid_size_t front() const { return *_mem; }
__host__ __device__ valid_size_t back() { return _mem[size - 1]; }
__host__ __device__ valid_size_t back() const {
return _mem[size - 1];
}
protected:
// Underlying storage of type SupportType
SupportType _mem;
};
/**
* @brief Static Tensor
*
* This class implements a multi-dimensional tensor behavior.
*
* It adds multi-index access on top of the flat storage of StaticArray: the
* indices given to operator() are folded into one linear offset which is then
* passed to the parent's single-index access operator.
*/
template
class StaticTensor
: public StaticArray::value> {
using parent = StaticArray::value>;
using T = DataType;
public:
/// Number of dimensions (number of entries in the dims pack)
static constexpr UInt dim = sizeof...(dims);
using parent::operator=;
private:
/// Recursive step of the offset computation: add the current index, scale
/// the running offset by a get::value factor, then recurse on the remaining
/// indices.
// NOTE(review): the arguments of get are not visible in this chunk;
// presumably the extent of the next dimension, which would make the last
// index the fastest-varying one - confirm.
template
__device__ __host__ static UInt unpackOffset(UInt offset, UInt index,
Idx... rest) {
// Number of indices still to be processed after this one
constexpr UInt size = sizeof...(rest);
offset += index;
offset *= get::value;
return unpackOffset(offset, rest...);
}
/// Terminal step: the last index is added without further scaling
template
__device__ __host__ static UInt unpackOffset(UInt offset, UInt index) {
return offset + index;
}
public:
/// Read-only multi-index access; indices are linearized by unpackOffset
/// starting from a zero offset
template
__device__ __host__ const T& operator()(Idx... idx) const {
return parent::operator()(unpackOffset(0, idx...));
}
/// Mutable multi-index access; same linearization as the const overload
template
__device__ __host__ T& operator()(Idx... idx) {
return parent::operator()(unpackOffset(0, idx...));
}
};
/* -------------------------------------------------------------------------- */
/* Common Static Types */
/* -------------------------------------------------------------------------- */
// Forward declaration: StaticMatrix, defined next, names StaticVector and
// StaticSymMatrix in the signatures of its outer() and fromSymmetric() members
// before those classes are defined.
template
class StaticVector;
template
class StaticSymMatrix;
/* -------------------------------------------------------------------------- */
/// Two-index static tensor. In the loops of this class n bounds the first
/// index and m bounds the second one.
template
class StaticMatrix : public StaticTensor {
using T = DataType;
// cv-unqualified type (std::remove_cv_t) used for accumulated results
using T_bare = typename std::remove_cv_t;
public:
using StaticTensor::operator=;
/// Initialize from a symmetric matrix (declared here, defined out of class)
template
__device__ __host__ std::enable_if_t
fromSymmetric(const StaticSymMatrix& o);
/// Outer product of two vectors (declared here, defined out of class)
template
__device__ __host__ void outer(const StaticVector& a,
const StaticVector& b);
/// Matrix product: (*this)(i, j) = sum over k < l of a(i, k) * b(k, j).
// *this is zeroed before accumulating, so the result is wrong if a or b
// refers to the same storage as *this.
template
__device__ __host__ void mul(const StaticMatrix& a,
const StaticMatrix& b) {
(*this) = T(0);
for (UInt i = 0; i < n; ++i)
for (UInt j = 0; j < m; ++j)
for (UInt k = 0; k < l; ++k)
(*this)(i, j) += a(i, k) * b(k, j);
}
/// Trace: sum of the diagonal entries (i, i) for i < n
// NOTE(review): the enable_if arguments are not visible in this chunk;
// presumably restricts the method to square matrices - confirm.
__device__ __host__ std::enable_if_t trace() const {
T_bare res{0};
for (UInt i = 0; i < n; ++i)
res += (*this)(i, i);
return res;
}
/// Deviatoric part: copies mat and subtracts trace(mat) / factor from each
/// diagonal entry; factor defaults to n
template
__device__ __host__ std::enable_if_t
deviatoric(const StaticMatrix& mat, Real factor = n) {
auto norm_trace = mat.trace() / factor;
for (UInt i = 0; i < n; ++i)
for (UInt j = 0; j < m; ++j)
// (i == j) converts to 1 on the diagonal and 0 elsewhere
(*this)(i, j) = mat(i, j) - (i == j) * norm_trace;
}
};
/* -------------------------------------------------------------------------- */
/// Vector class with size determined at compile-time
template
class StaticVector : public StaticTensor {
// cv-unqualified type (std::remove_cv_t), used to build the zero value in mul
using T = std::remove_cv_t;
public:
using StaticTensor::operator=;
/// Matrix-vector product: (*this)(i) = sum over j < m of M(i, j) * vec(j),
/// reading M(j, i) instead of M(i, j) when `transpose` is true.
// *this is zeroed before accumulating, so the result is wrong if vec refers
// to the same storage as *this.
template
__device__ __host__ void mul(const StaticMatrix& mat,
const StaticVector& vec) {
*this = T(0);
for (UInt i = 0; i < n; ++i)
for (UInt j = 0; j < m; ++j)
(*this)(i) +=
((transpose) ? mat(j, i) : mat(i, j)) * vec(j); // can be optimized
}
};
/* -------------------------------------------------------------------------- */
/// Symmetric matrix in Voigt notation
// Stored as a vector: the first n entries hold the diagonal, the remaining
// entries hold the off-diagonal terms in the order produced by sym_binary.
template
class StaticSymMatrix
: public StaticVector::value> {
using parent = StaticVector::value>;
using T = std::remove_cv_t;
private:
/// Walks the vector entries together with the matching symmetrized entries
/// of the full matrix m and calls op(entry, value) for each pair.
/// Diagonal: value = m(i, i). Off-diagonal: value = a * (m(i, j) + m(j, i))
/// with a = sqrt(2) / 2, i.e. sqrt(2) times the symmetric part.
template
__device__ __host__ void sym_binary(const StaticMatrix& m,
BinOp&& op) {
for (UInt i = 0; i < n; ++i)
op((*this)(i), m(i, i));
const auto a = 0.5 * std::sqrt(2);
// Off-diagonal entries start at vector index n; columns are visited from the
// last one down to 1 and, within a column, rows from j - 1 down to 0. For
// n = 3 this gives the order (1,2), (0,2), (0,1).
for (UInt j = n - 1, b = n; j > 0; --j)
// i is signed so that the i >= 0 test can become false
for (Int i = j - 1; i >= 0; --i)
op((*this)(b++), a * (m(i, j) + m(j, i)));
}
public:
/// Copy values from matrix and symmetrize
template
__device__ __host__ void symmetrize(const StaticMatrix& m) {
sym_binary(m, [](auto&& v, auto&& w) { v = w; });
}
/// Add values from symmetrized matrix
template
__device__ __host__ void operator+=(const StaticMatrix& m) {
sym_binary(m, [](auto&& v, auto&& w) { v += w; });
}
/// Trace: sum of the first n entries, which hold the diagonal
__device__ __host__ auto trace() const {
std::remove_cv_t res = 0;
for (UInt i = 0; i < n; ++i)
res += (*this)(i);
return res;
}
/// Deviatoric part of m: trace(m) / factor is subtracted from the n diagonal
/// entries, the off-diagonal entries are copied unchanged; factor defaults
/// to n
template
__device__ __host__ void deviatoric(const StaticSymMatrix& m,
Real factor = n) {
auto tr = m.trace() / factor;
for (UInt i = 0; i < n; ++i)
(*this)(i) = m(i) - tr;
for (UInt i = n; i < voigt_size::value; ++i)
(*this)(i) = m(i);
}
// Re-expose the parent's operators, which the operator+= declared in this
// class would otherwise hide
using parent::operator+=;
using parent::operator=;
};
/* -------------------------------------------------------------------------- */
// Implementation of StaticMatrix::fromSymmetric: fills the full matrix from a
// symmetric matrix stored in vector form
template
template
__device__ __host__ std::enable_if_t
StaticMatrix::fromSymmetric(
const StaticSymMatrix& o) {
// The first n vector entries are the diagonal
for (UInt i = 0; i < n; ++i)
(*this)(i, i) = o(i);
// We use Mandel notation for the vector representation: off-diagonal entries
// are scaled by 1 / sqrt(2) here, the inverse of the sqrt(2) weighting applied
// by StaticSymMatrix::sym_binary
const auto a = 1. / std::sqrt(2);
// Same traversal order as StaticSymMatrix::sym_binary; both (i, j) and (j, i)
// receive the value. i is signed so that the i >= 0 test can become false.
for (UInt j = n - 1, b = n; j > 0; --j)
for (Int i = j - 1; i >= 0; --i)
(*this)(i, j) = (*this)(j, i) = a * o(b++);
}
// Implementation of outer product: (*this)(i, j) = a(i) * b(j) for i < n and
// j < m; every entry is overwritten
template
template
__device__ __host__ void StaticMatrix::outer(
const StaticVector& a, const StaticVector& b) {
for (UInt i = 0; i < n; ++i)
for (UInt j = 0; j < m; ++j)
(*this)(i, j) = a(i) * b(j);
}
/* -------------------------------------------------------------------------- */
/* On the stack static types */
/* -------------------------------------------------------------------------- */
/// Number of elements a Tensor has to store for a given static parent type
/// (Tensor uses `value` as the extent of its C-array storage).
/// Generic case: `value` is inherited from product.
// NOTE(review): template argument lists are not visible in this chunk;
// presumably product is taken over the dims pack - confirm.
template class StaticParent,
UInt... dims>
struct static_size_helper : product {};
/// Specialization: `value` is inherited from voigt_size instead.
// NOTE(review): the specialized parent is not visible here; presumably
// StaticSymMatrix - confirm.
template
struct static_size_helper : voigt_size {};
/// Tensor type owning its data: the support type handed to StaticParent is a
/// C array whose extent is static_size_helper::value, so the elements are
/// stored inside the object itself.
template class StaticParent, typename T,
UInt... dims>
class Tensor
: public StaticParent<
T, T[static_size_helper::value], dims...> {
// Element count, used as the bound of the copy loops below
static constexpr UInt size = static_size_helper::value;
using parent = StaticParent;
public:
using parent::operator=;
using parent::copy;
/// Default constructor
Tensor() = default;
/// Construct with default value (every element is assigned val)
__device__ __host__ Tensor(T val) { *this = val; }
/// Construct from array
__device__ __host__ Tensor(const std::array& arr) {
// we use size to ensure static loop unrolling
for (UInt i = 0; i < size; ++i)
this->_mem[i] = arr[i];
}
/// Copy from array; returns *this so that assignments can be chained
__device__ __host__ Tensor& operator=(const std::array& arr) {
// we use size to ensure static loop unrolling
for (UInt i = 0; i < size; ++i)
(*this)(i) = arr[i];
// Flowing off the end of a value-returning function is undefined behavior
return *this;
}
/// Construct by copy from static tensor
template
__device__ __host__ Tensor(const StaticParent& o) {
copy(o);
}
/// Copy constructor: the parent's copy constructor is deleted, so the parent
/// is default-constructed and the elements are copied explicitly
__device__ __host__ Tensor(const Tensor& o) : parent() { copy(o); }
/// Copy assignment: element-wise copy
__device__ __host__ Tensor& operator=(const Tensor& o) {
copy(o);
return *this;
}
/// Move constructor: the storage is an in-object array, so moving copies the
/// elements
__device__ __host__ Tensor(Tensor&& o) noexcept { copy(o); }
};
// Convenience aliases of Tensor (types holding their own storage).
// NOTE(review): the alias arguments are not visible in this chunk; presumably
// Tensor over StaticMatrix, StaticSymMatrix and StaticVector respectively -
// confirm.
template
using Matrix = Tensor;
template
using SymMatrix = Tensor;
template
using Vector = Tensor;
/* -------------------------------------------------------------------------- */
/* Proxy Static Types */
/* -------------------------------------------------------------------------- */
/// Proxy type for tensor
// A proxy stores the pointer it is given (in _mem) instead of holding the
// elements itself, so reads and writes go to the memory it points to.
template class StaticParent, typename T,
UInt... dims>
class TensorProxy : public StaticParent {
using parent = StaticParent;
public:
/// Explicit construction from data location
__device__ __host__ explicit TensorProxy(T* spot) { this->_mem = spot; }
/// Explicit construction from lvalue-reference (uses the address of spot)
__device__ __host__ explicit TensorProxy(T& spot) : TensorProxy(&spot) {}
/// Construction from static tensor: the proxy points at o.begin()
template
__device__ __host__
TensorProxy(StaticParent& o)
: TensorProxy(o.begin()) {}
using parent::operator=;
/// Copy construction copies the pointer: both proxies refer to the same
/// memory
__device__ __host__ TensorProxy(const TensorProxy& o) { this->_mem = o._mem; }
/// Copy assignment copies the element values of o into the memory this proxy
/// refers to; it does not re-point the proxy
__device__ __host__ TensorProxy& operator=(const TensorProxy& o) {
this->copy(o);
return *this;
}
/// Move construction takes over o's pointer and leaves o._mem set to nullptr
__device__ __host__ TensorProxy(TensorProxy&& o) noexcept
: TensorProxy(exchange(o._mem, nullptr)) {}
public:
// NOTE(review): the Tensor arguments are not visible in this chunk;
// presumably the Tensor type with the same parent, element type and
// dimensions as this proxy - confirm.
using stack_type = Tensor;
};
// Convenience aliases of TensorProxy (types referring to external memory).
// NOTE(review): the alias arguments are not visible in this chunk; presumably
// TensorProxy over StaticMatrix, StaticSymMatrix and StaticVector respectively
// - confirm.
template
using MatrixProxy = TensorProxy;
template
using SymMatrixProxy = TensorProxy;
template
using VectorProxy = TensorProxy;
/* -------------------------------------------------------------------------- */
/* -------------------------------------------------------------------------- */
/* Arithmetic operators creating temporaries */
/* -------------------------------------------------------------------------- */
/* -------------------------------------------------------------------------- */
/* -------------------------------------------------------------------------- */
/* Simple operators */
/* -------------------------------------------------------------------------- */
/// Element-wise sum of two static vectors, returned as a new Vector; the
/// operands are left untouched
template
__device__ __host__ Vector
operator+(const StaticVector& a,
const StaticVector& b) {
// Start from a copy of a, then accumulate b in place
Vector res(a);
res += b;
return res;
}
/// Element-wise difference a - b of two static vectors, returned as a new
/// Vector; the operands are left untouched
template
__device__ __host__ Vector
operator-(const StaticVector& a,
const StaticVector& b) {
// Start from a copy of a, then subtract b in place
Vector res(a);
res -= b;
return res;
}
/// Unary minus of a static vector, returned as a new Vector
template
__device__ __host__ Vector
operator-(const StaticVector& a) {
// Copy a, then multiply every element by -1
Vector res(a);
res *= -1;
return res;
}
/// Element-wise sum of two static matrices, returned as a new Matrix; the
/// operands are left untouched
template
__device__ __host__ Matrix
operator+(const StaticMatrix& a,
const StaticMatrix& b) {
// Start from a copy of a, then accumulate b in place
Matrix res(a);
res += b;
return res;
}
/// Element-wise difference a - b of two static matrices, returned as a new
/// Matrix; the operands are left untouched
template
__device__ __host__ Matrix
operator-(const StaticMatrix& a,
const StaticMatrix& b) {
// Start from a copy of a, then subtract b in place
Matrix res(a);
res -= b;
return res;
}
/// Unary minus of a static matrix, returned as a new Matrix
template
__device__ __host__ Matrix
operator-(const StaticMatrix& a) {
// Copy a, then multiply every element by -1
Matrix res(a);
res *= -1;
return res;
}
/// Element-wise sum of two symmetric matrices in vector form, returned as a
/// new SymMatrix; the operands are left untouched
template
__device__ __host__ SymMatrix
operator+(const StaticSymMatrix& a,
const StaticSymMatrix& b) {
// Start from a copy of a, then accumulate b in place
SymMatrix res(a);
res += b;
return res;
}
/// Element-wise difference a - b of two symmetric matrices in vector form,
/// returned as a new SymMatrix; the operands are left untouched
template
__device__ __host__ SymMatrix
operator-(const StaticSymMatrix& a,
const StaticSymMatrix& b) {
// Start from a copy of a, then subtract b in place
SymMatrix res(a);
res -= b;
return res;
}
/// Unary minus of a symmetric matrix in vector form, returned as a new
/// SymMatrix
template
__device__ __host__ SymMatrix
operator-(const StaticSymMatrix& a) {
// Copy a, then multiply every element by -1
SymMatrix res(a);
res *= -1;
return res;
}
template