diff --git a/src/core/loops/loop_utils.hh b/src/core/loops/loop_utils.hh
index 442b36e..a8027bf 100644
--- a/src/core/loops/loop_utils.hh
+++ b/src/core/loops/loop_utils.hh
@@ -1,97 +1,96 @@
 /*
  * SPDX-License-Identifier: AGPL-3.0-or-later
  *
  * Copyright (©) 2016-2022 EPFL (École Polytechnique Fédérale de Lausanne),
  * Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
  *
  * This program is free software: you can redistribute it and/or modify
  * it under the terms of the GNU Affero General Public License as published
  * by the Free Software Foundation, either version 3 of the License, or
  * (at your option) any later version.
  *
  * This program is distributed in the hope that it will be useful,
  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  * GNU Affero General Public License for more details.
  *
  * You should have received a copy of the GNU Affero General Public License
  * along with this program. If not, see <https://www.gnu.org/licenses/>.
  *
  */
 /* -------------------------------------------------------------------------- */
 #ifndef LOOP_UTILS_HH
 #define LOOP_UTILS_HH
 /* -------------------------------------------------------------------------- */
-#include "loop.hh"
 #include "tamaas.hh"
 #include <limits>
 #include <thrust/functional.h>
 #include <tuple>
 /* -------------------------------------------------------------------------- */

 namespace tamaas {

 namespace detail {

 template <bool only_points, typename... Grids>
 UInt loopSize(Grids&&... grids) {
   return (only_points)
              ? std::get<0>(std::forward_as_tuple(grids...)).getNbPoints()
              : std::get<0>(std::forward_as_tuple(grids...)).dataSize();
 }

 template <typename... T>
 void areAllEqual(bool /*unused*/, std::ptrdiff_t /*unused*/) {}

 template <typename... T>
 void areAllEqual(bool result, std::ptrdiff_t prev, std::ptrdiff_t current) {
   if (!(result && prev == current))
     TAMAAS_EXCEPTION(
         "Cannot loop over ranges that do not have the same size!");
 }

 template <typename... Sizes>
 void areAllEqual(bool result, std::ptrdiff_t prev, std::ptrdiff_t current,
                  Sizes... rest) {
   areAllEqual(result && prev == current, current, rest...);
 }

 } // namespace detail

 template <typename... Ranges>
 void checkLoopSize(Ranges&&... ranges) {
   detail::areAllEqual(true, (ranges.end() - ranges.begin())...);
 }

 namespace detail {

 template <operation op, typename ReturnType>
 struct reduction_helper;

 template <typename ReturnType>
 struct reduction_helper<operation::plus, ReturnType>
     : public thrust::plus<ReturnType> {
   ReturnType init() const { return ReturnType(0); }
 };

 template <typename ReturnType>
 struct reduction_helper<operation::times, ReturnType>
     : public thrust::multiplies<ReturnType> {
   ReturnType init() const { return ReturnType(1); }
 };

 template <typename ReturnType>
 struct reduction_helper<operation::min, ReturnType>
     : public thrust::minimum<ReturnType> {
   ReturnType init() const { return std::numeric_limits<ReturnType>::max(); }
 };

 template <typename ReturnType>
 struct reduction_helper<operation::max, ReturnType>
     : public thrust::maximum<ReturnType> {
   ReturnType init() const { return std::numeric_limits<ReturnType>::lowest(); }
 };

 } // namespace detail

 } // namespace tamaas

 #endif // LOOP_UTILS_HH
diff --git a/src/core/mpi_interface.hh b/src/core/mpi_interface.hh
index 2fee3a9..6b51092 100644
--- a/src/core/mpi_interface.hh
+++ b/src/core/mpi_interface.hh
@@ -1,279 +1,279 @@
 /*
  * SPDX-License-Identifier: AGPL-3.0-or-later
  *
  * Copyright (©) 2016-2022 EPFL (École Polytechnique Fédérale de Lausanne),
  * Laboratory (LSMS - Laboratoire de Simulation en Mécanique des Solides)
  *
  * This program is free software: you can redistribute it and/or modify
  * it under the terms of the GNU Affero General Public License as published
  * by the Free Software Foundation, either version 3 of the License, or
  * (at your option) any later version.
  *
  * This program is distributed in the hope that it will be useful,
  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  * GNU Affero General Public License for more details.
  *
  * You should have received a copy of the GNU Affero General Public License
  * along with this program. If not, see <https://www.gnu.org/licenses/>.
  *
  */
 /* -------------------------------------------------------------------------- */
 #ifndef MPI_INTERFACE_HH
 #define MPI_INTERFACE_HH
 /* -------------------------------------------------------------------------- */
 #include "static_types.hh"
 #include "tamaas.hh"
 #include <type_traits>
 #include <vector>
 #ifdef TAMAAS_USE_MPI
 #include <mpi.h>
 #endif
 /* -------------------------------------------------------------------------- */

 namespace tamaas {
 /* -------------------------------------------------------------------------- */

 /// Contains mock mpi functions
 namespace mpi_dummy {

 struct comm {
   static comm world;
 };

 struct sequential {
   static void enter(){};
   static void exit(){};
 };

 struct sequential_guard {
   sequential_guard() { sequential::enter(); }
   ~sequential_guard() { sequential::exit(); }
 };

 enum class thread : int { single, funneled, serialized, multiple };

 inline bool initialized() { return false; }
 inline bool finalized() { return false; }
 inline int init(int* /*unused*/, char*** /*unused*/) { return 0; }
 inline int init_thread(int* /*unused*/, char*** /*unused*/, thread /*unused*/,
                        thread* provided) {
   *provided = thread::funneled;
   return 0;
 }
 inline int finalize() { return 0; }
 inline int rank(comm /*unused*/ = comm::world) { return 0; }
 inline int size(comm /*unused*/ = comm::world) { return 1; }

 template <operation op, typename T>
 inline decltype(auto) reduce(T&& val, comm /*unused*/ = comm::world) {
   return std::forward<T>(val);
 }

 template <operation op, typename T>
 inline decltype(auto) allreduce(T&& val, comm /*unused*/ = comm::world) {
   return std::forward<T>(val);
 }

 template <typename T>
 inline decltype(auto) gather(const T* send, T* recv, int count,
                              comm /*unused*/ = comm::world) {
   if (send == recv)
     return;
   thrust::copy_n(send, count, recv);
 }

 template <typename T>
 inline decltype(auto) scatter(const T* send, T* recv, int count,
                               comm /*unused*/ = comm::world) {
   if (send == recv)
     return;
   thrust::copy_n(send, count, recv);
 }

 template <typename T>
-inline decltype(auto)
-scatterv(const T* send, const std::vector<int>& /*unused*/,
-         const std::vector<int>& /*unused*/, T* recv, int recvcount,
-         comm /*unused*/ = comm::world) {
+inline decltype(auto) scatterv(const T* send,
+                               const std::vector<int>& /*unused*/,
+                               const std::vector<int>& /*unused*/, T* recv,
+                               int recvcount, comm /*unused*/ = comm::world) {
   scatter(send, recv, recvcount);
 }

 template <typename T>
 inline decltype(auto) bcast(T* /*unused*/, int /*unused*/,
                             comm /*unused*/ = comm::world) {}

 } // namespace mpi_dummy
 /* -------------------------------------------------------------------------- */

 #ifdef TAMAAS_USE_MPI

 /// Contains real mpi functions
 namespace mpi_impl {

 /// MPI_Comm wrapper
 struct comm {
   MPI_Comm _comm;
-  operator MPI_Comm() { return _comm; }
+  operator MPI_Comm() const { return _comm; }
   static comm& world();
 };

 struct sequential {
   static void enter() { comm::world()._comm = MPI_COMM_SELF; }
   static void exit() { comm::world()._comm = MPI_COMM_WORLD; }
 };

 struct sequential_guard {
   sequential_guard() { sequential::enter(); }
   ~sequential_guard() { sequential::exit(); }
 };

 /// MPI Thread level
 enum class thread : int {
   single = MPI_THREAD_SINGLE,
   funneled = MPI_THREAD_FUNNELED,
   serialized = MPI_THREAD_SERIALIZED,
   multiple = MPI_THREAD_MULTIPLE
 };

 template <typename T>
 struct type_trait;

 #define TYPE(t, mpi_t)                                                         \
   template <>                                                                  \
   struct type_trait<t> {                                                       \
     static constexpr MPI_Datatype value = mpi_t;                               \
   }

 TYPE(double, MPI_DOUBLE);
 TYPE(int, MPI_INT);
 TYPE(unsigned int, MPI_UNSIGNED);
 TYPE(long double, MPI_LONG_DOUBLE);
 TYPE(long, MPI_LONG);
 TYPE(unsigned long, MPI_UNSIGNED_LONG);
 TYPE(::thrust::complex<double>, MPI_CXX_DOUBLE_COMPLEX);
 TYPE(::thrust::complex<long double>, MPI_CXX_LONG_DOUBLE_COMPLEX);
 TYPE(bool, MPI_CXX_BOOL);
 #undef TYPE

 template <operation op>
 struct operation_trait;

 #define OPERATION(op, mpi_op)                                                  \
   template <>                                                                  \
   struct operation_trait<operation::op> {                                      \
     static constexpr MPI_Op value = mpi_op;                                    \
   }

 OPERATION(plus, MPI_SUM);
 OPERATION(min, MPI_MIN);
 OPERATION(max, MPI_MAX);
 OPERATION(times, MPI_PROD);
 #undef OPERATION

 inline bool initialized() {
   int has_init;
   MPI_Initialized(&has_init);
   return has_init != 0;
 }

 inline bool finalized() {
   int has_final;
   MPI_Finalized(&has_final);
   return has_final != 0;
 }

 inline int init(int* argc, char*** argv) { return MPI_Init(argc, argv); }
 inline int init_thread(int* argc, char*** argv, thread required,
                        thread* provided) {
   return MPI_Init_thread(argc, argv, static_cast<int>(required),
                          reinterpret_cast<int*>(provided));
 }
 inline int finalize() { return MPI_Finalize(); }

 inline int rank(comm communicator = comm::world()) {
   int rank;
   MPI_Comm_rank(communicator, &rank);
   return rank;
 }

 inline int size(comm communicator = comm::world()) {
   int size;
   MPI_Comm_size(communicator, &size);
   return size;
 }

 template <operation op, typename T>
 inline decltype(auto) reduce(T val, comm communicator = comm::world()) {
   MPI_Reduce(&val, &val, 1, type_trait<T>::value, operation_trait<op>::value,
              0, communicator);
   return val;
 }

 template <operation op, typename T,
           typename = std::enable_if_t<std::is_arithmetic<T>::value or
                                       std::is_same<T, bool>::value>>
 inline decltype(auto) allreduce(T val, comm communicator = comm::world()) {
   MPI_Allreduce(&val, &val, 1, type_trait<T>::value,
                 operation_trait<op>::value, communicator);
   return val;
 }

 template <operation op, typename T, UInt n>
 inline decltype(auto) allreduce(const StaticVector<T, n>& v,
                                 comm communicator = comm::world()) {
   Vector<T, n> res;
   MPI_Allreduce(v.begin(), res.begin(), n, type_trait<T>::value,
                 operation_trait<op>::value, communicator);
   return res;
 }

 template <operation op, typename T, UInt n, UInt m>
 inline decltype(auto) allreduce(const StaticMatrix<T, n, m>& v,
                                 comm communicator = comm::world()) {
   Matrix<T, n, m> res;
   MPI_Allreduce(v.begin(), res.begin(), n * m, type_trait<T>::value,
                 operation_trait<op>::value, communicator);
   return res;
 }

 template <typename T>
 inline decltype(auto) gather(const T* send, T* recv, int count,
                              comm communicator = comm::world()) {
   MPI_Gather(send, count, type_trait<T>::value, recv, count,
              type_trait<T>::value, 0, communicator);
 }

 template <typename T>
 inline decltype(auto) scatter(const T* send, T* recv, int count,
                               comm communicator = comm::world()) {
   MPI_Scatter(send, count, type_trait<T>::value, recv, count,
               type_trait<T>::value, 0, communicator);
 }

 template <typename T>
 inline decltype(auto) scatterv(const T* send,
                                const std::vector<int>& sendcounts,
                                const std::vector<int>& displs, T* recv,
                                int recvcount,
                                comm communicator = comm::world()) {
   MPI_Scatterv(send, sendcounts.data(), displs.data(), type_trait<T>::value,
                recv, recvcount, type_trait<T>::value, 0, communicator);
 }

 template <typename T>
 inline decltype(auto) bcast(T* buffer, int count,
                             comm communicator = comm::world()) {
   MPI_Bcast(buffer, count, type_trait<T>::value, 0, communicator);
 }

 } // namespace mpi_impl

 namespace mpi = mpi_impl;

 #else

 namespace mpi = mpi_dummy;

 #endif // TAMAAS_USE_MPI

 } // namespace tamaas
 /* -------------------------------------------------------------------------- */

 #endif // MPI_INTERFACE_HH
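For reviewers unfamiliar with loop_utils.hh, here is a minimal host-side sketch of how the reduction_helper traits touched above combine with thrust functors. It is not part of the patch: the fold function is hypothetical, and it assumes the operation enum that tamaas.hh provides. Each specialization pairs a thrust binary functor with the neutral element returned by init(), e.g. operation::min starts from std::numeric_limits<T>::max().

```cpp
#include "loop_utils.hh"
#include <vector>

using namespace tamaas;

// Hypothetical fold using the same trait the Loop reductions rely on.
template <operation op, typename T>
T fold(const std::vector<T>& values) {
  detail::reduction_helper<op, T> helper;
  T acc = helper.init();   // neutral element for the chosen operation
  for (const auto& v : values)
    acc = helper(acc, v);  // thrust functors are plain callables on the host
  return acc;
}

// Example: fold<operation::max>(heights) returns the largest entry,
// or std::numeric_limits<T>::lowest() for an empty vector.
```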
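And a similar sketch for the mpi wrapper in mpi_interface.hh, again not part of the patch; global_sum and is_root are hypothetical helpers and the operation enum again comes from tamaas.hh. With TAMAAS_USE_MPI defined, mpi:: resolves to mpi_impl and the call maps onto MPI_Allreduce; otherwise the mpi_dummy fallback returns the value unchanged, so the same code compiles for serial and parallel builds.

```cpp
#include "mpi_interface.hh"

using namespace tamaas;

// Hypothetical helper: sum a per-rank partial result over all ranks.
double global_sum(double local_sum) {
  return mpi::allreduce<operation::plus>(local_sum);
}

// Rank-dependent logic, e.g. restricting output to the root rank.
bool is_root() { return mpi::rank() == 0; }
```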