Page MenuHomec4science

comm_group_inline_impl.hh
No OneTemporary

File Metadata

Created
Sun, Jul 7, 10:05

comm_group_inline_impl.hh

#include "comm_group.hh"
__BEGIN_LIBMULTISCALE__
/* -------------------------------------------------------------------------- */
/**
 * Map a C++ type to its MPI datatype handle.
 *
 * This primary template is only a fallback: its body is empty, so its
 * deduced return type is void. The usable mappings are the explicit
 * specializations generated by MPI_TYPE_MAP. The defaulted second template
 * parameter makes get_mpi_type<T>() non-viable (SFINAE) when T is not
 * trivially copy-constructible.
 */
template <class T, class Enable = std::enable_if_t<
                       std::is_trivially_copy_constructible<T>::value>>
auto get_mpi_type() {
  // no MPI datatype is known for T
}
/* -------------------------------------------------------------------------- */
// MPI_TYPE_MAP(cxx_type, mpi_datatype) defines the explicit specialization
// get_mpi_type<cxx_type>() returning the MPI datatype handle mpi_datatype.
// The macro is only needed for the table below and is undefined after it.
#define MPI_TYPE_MAP(cxx_type, mpi_datatype)                                   \
  template <>                                                                  \
  inline auto get_mpi_type<                                                    \
      cxx_type,                                                                \
      std::enable_if_t<std::is_trivially_copy_constructible<cxx_type>::value>>() { \
    return mpi_datatype;                                                       \
  }
/* -------------------------------------------------------------------------- */
// C++ type -> MPI datatype table.
// NOTE(review): assumes Real is double — confirm against the project types.
// NOTE(review): UInt is mapped to the signed MPI_INT. If UInt is unsigned,
// MPI_UNSIGNED would be the matching datatype (MPI_MAX/MPI_MIN differ for
// values >= 2^31). It is kept as-is because callers may rely on the signed
// semantics (e.g. UInt(-1) sentinels) — confirm before changing.
MPI_TYPE_MAP(Real, MPI_DOUBLE)
MPI_TYPE_MAP(UInt, MPI_INT)
MPI_TYPE_MAP(char, MPI_CHAR)
MPI_TYPE_MAP(MPIProc, Communicator::mpi_type_processor)
#undef MPI_TYPE_MAP
/* -------------------------------------------------------------------------- */
/**
 * Buffer adaptor used by the container overloads of the MPI wrappers when
 * the element type is trivially copy-constructible (the `false` case is
 * specialized separately).
 *
 * MPI reads and writes the wrapped container's storage directly through
 * data(), so nothing has to be decoded afterwards and unpack() is a no-op.
 * BufType must provide value_type, data(), size() and resize(n).
 */
template <typename BufType,
          bool copyable = std::is_trivially_copy_constructible<
              typename BufType::value_type>::value>
struct PackedBuffer {
  using value_type = typename BufType::value_type;

  /// Wrapped container (held by reference, never copied).
  BufType &_data;

  PackedBuffer(BufType &data) : _data(data) {}

  /// MPI datatype describing a single element.
  auto mpi_type() { return get_mpi_type<value_type>(); }

  /// Address of the container's storage, handed to MPI.
  auto *data() { return _data.data(); }

  /// Number of elements, in units of mpi_type().
  UInt size() { return _data.size(); }

  /// Resize the container to hold `sz` elements.
  void resize(UInt sz) { _data.resize(sz); }

  /// Nothing to decode: the data was transferred in place.
  void unpack() {}
};
/* -------------------------------------------------------------------------- */
template <typename T> class ContainerArray;

/**
 * PackedBuffer specialization for ContainerArray with trivially
 * copy-constructible elements.
 *
 * The array is transferred directly through data(). A received element
 * count is turned back into a row count by keeping the current number of
 * columns.
 */
template <typename T> struct PackedBuffer<ContainerArray<T>, true> {
  using BufType = ContainerArray<T>;
  using value_type = typename BufType::value_type;

  /// Wrapped array (held by reference, never copied).
  BufType &_data;

  PackedBuffer(BufType &data) : _data(data) {}

  auto mpi_type() { return get_mpi_type<value_type>(); }
  auto *data() { return _data.data(); }
  UInt size() { return _data.size(); }

  /// `sz` is a total element count; allocate sz / cols() rows of cols()
  /// columns.
  // NOTE(review): integer division — assumes sz is a multiple of cols() and
  // cols() is non-zero on the receiving side; confirm with callers.
  void resize(UInt sz) {
    const auto nb_cols = _data.cols();
    _data.resize(sz / nb_cols, nb_cols);
  }

  /// Nothing to decode: the data was transferred in place.
  void unpack() {}
};
/* -------------------------------------------------------------------------- */
/**
 * Append the raw object representation of *var (exactly sizeof(T) bytes)
 * to `sstr`.
 *
 * Unformatted output is used so that no stream formatting state (such as a
 * pending width()/fill()) can pad or alter the bytes. decode_from_sstr
 * reads them back.
 */
template <typename T> void encode_to_sstr(T *var, std::stringstream &sstr) {
  sstr.write(reinterpret_cast<const char *>(var), sizeof(T));
}
/**
 * Overwrite *var with the next sizeof(T) raw bytes of `sstr`. This is the
 * inverse of encode_to_sstr.
 *
 * Unformatted read() is required: formatted `>>` into a char skips
 * whitespace. That would silently drop payload bytes such as 0x09, 0x0A or
 * 0x20 and misalign every following field. If fewer than sizeof(T) bytes
 * remain, the stream's failbit is set and *var is only partially
 * overwritten.
 */
template <typename T> void decode_from_sstr(T *var, std::stringstream &sstr) {
  sstr.read(reinterpret_cast<char *>(var), sizeof(T));
}
/**
 * PackedBuffer for containers whose elements are not trivially
 * copy-constructible.
 *
 * Each element serializes itself through its pack()/unpack() visitor
 * methods, which are called with a lambda that encodes or decodes one field
 * at a time. The bytes are stored in `pack_data`, which MPI transfers as
 * chars. The constructor encodes the container's current contents. After a
 * receive into data(), unpack() decodes the bytes back into the container.
 */
template <typename BufType> struct PackedBuffer<BufType, false> {
  PackedBuffer(BufType &data) : _data(data) {
    std::stringstream sstr;
    for (UInt i = 0; i < _data.size(); ++i) {
      auto &c = _data[i];
      c.pack([&](auto &var) { encode_to_sstr(&var, sstr); });
    }
    const auto str = sstr.str();
    pack_data.assign(str.begin(), str.end());
  }

  using value_type = char;

  /**
   * Decode `pack_data` into the wrapped container, element by element.
   *
   * The container is never resized here, so it must already hold enough
   * elements. Decoding stops at the first of:
   *  - all bytes consumed;
   *  - no element left to fill (prevents out-of-bounds writes);
   *  - the stream failing. Without this check the loop would spin forever,
   *    because tellg() returns -1 on a failed stream.
   */
  void unpack() {
    const std::string str(pack_data.begin(), pack_data.end());
    std::stringstream sstr(str);
    const long total = static_cast<long>(str.size());
    for (UInt i = 0; i < _data.size() && sstr && sstr.tellg() < total; ++i) {
      auto &c = _data[i];
      c.unpack([&](auto &var) { decode_from_sstr(&var, sstr); });
    }
  }

  /// Start of the serialized bytes.
  char *data() { return pack_data.data(); }
  /// Number of serialized bytes.
  UInt size() { return pack_data.size(); }
  /// Resize the byte buffer (e.g. before receiving); the container is untouched.
  void resize(UInt sz) { pack_data.resize(sz); }

  BufType &_data;
  std::vector<char> pack_data;
};
/* -------------------------------------------------------------------------- */
template <typename BufType> decltype(auto) make_pack(BufType &&data) {
return PackedBuffer<std::decay_t<BufType>>(data);
}
/* -------------------------------------------------------------------------- */
/**
 * Create the communication group `id` by splitting MPI_COMM_WORLD.
 *
 * @param id       group identifier; "all" and "self" skip the exchange over
 *                 the "all" group below
 * @param color    MPI_Comm_split color; MPI_UNDEFINED means the calling
 *                 process is not a member (its mpi_comm is MPI_COMM_NULL)
 * @param nb_procs number of processes in the group; sizes `processors` and
 *                 the rank-translation tables on members
 *
 * MPI_Comm_split is collective over MPI_COMM_WORLD, and for groups other
 * than "all"/"self" an allGather over the "all" group follows. Members fill
 * `processors` from their own communicator. Non-members rebuild it from the
 * entries gathered from the members.
 */
inline CommGroup::CommGroup(const LMID &id, int color, UInt nb_procs)
    : LMObject(id) {
  DUMP("creating comm_group: " << id, DBG_INFO);
  is_current_proc_in_group = color != MPI_UNDEFINED;
  // key = lm_my_proc_id (presumably the world rank), so group ranks follow
  // the ordering of that key
  MPI_Comm_split(MPI_COMM_WORLD, color, lm_my_proc_id, &mpi_comm);

  if (mpi_comm != MPI_COMM_NULL) {
    // translate every group rank into its MPI_COMM_WORLD rank
    processors.resize(nb_procs);
    MPI_Comm_group(mpi_comm, &mpi_group);
    MPI_Group world_group;
    MPI_Comm_group(MPI_COMM_WORLD, &world_group);
    std::vector<int> group_ranks(nb_procs);
    for (UInt i = 0; i < nb_procs; ++i)
      group_ranks[i] = i;
    std::vector<int> absolute_ranks(nb_procs);
    MPI_Group_translate_ranks(mpi_group, nb_procs, group_ranks.data(),
                              world_group, absolute_ranks.data());
    // world_group is only needed for the translation. MPI_Comm_group returns
    // a new group handle, which must be released or it leaks.
    MPI_Group_free(&world_group);
    for (UInt i = 0; i < nb_procs; ++i) {
      processors[i].absolute_mpi_rank = absolute_ranks[i];
      processors[i].mpi_rank = group_ranks[i];
    }
  }

  // "all" and "self" do not take part in the exchange below
  if (id == "all" or id == "self")
    return;

  // Every process contributes its own entry for this group, or a -1/-1
  // entry when it is not a member, so that non-members learn the layout.
  auto &all_group = Communicator::getCommunicator().getGroup("all");
  std::vector<MPIProc> tmp_procs(all_group.size());
  if (mpi_comm != MPI_COMM_NULL) {
    auto rank = this->getMyRank();
    auto &tmp [[gnu::unused]] = processors[rank];
    DUMP("local proc: " << tmp.mpi_rank << " " << tmp.absolute_mpi_rank,
         DBG_INFO);
    all_group.allGather(&processors[rank], 1, tmp_procs.data(),
                        "gather proc information about group " + id);
  } else {
    MPIProc tmp;
    tmp.mpi_rank = -1;
    tmp.absolute_mpi_rank = -1;
    DUMP("local proc: " << tmp.mpi_rank << " " << tmp.absolute_mpi_rank,
         DBG_INFO);
    all_group.allGather(&tmp, 1, tmp_procs.data(),
                        "gather proc information about group " + id);
  }

  // Non-members keep only the entries sent by actual members, in the order
  // of the "all" group.
  if (mpi_comm == MPI_COMM_NULL) {
    UInt cpt = 0;
    for (auto &&p : tmp_procs) {
      if (p.mpi_rank == UInt(-1))
        continue;
      DUMP("proc: " << cpt << " " << p.mpi_rank << " " << p.absolute_mpi_rank,
           DBG_INFO);
      processors.push_back(p);
      ++cpt;
    }
  }
  DUMP("created comm_group: " << id, DBG_INFO);
}
/* -------------------------------------------------------------------------- */
/*
 * Abort with LM_FATAL when the calling process is not a member of this
 * group. It must be expanded inside a CommGroup member function that has a
 * `comment` variable in scope, because `comment` is included in the error
 * message. The do { } while (0) wrapper makes
 * `CHECK_MEMBERSHIP_MPI_ROUTINE();` a single statement, so it is safe in an
 * unbraced if/else.
 */
#define CHECK_MEMBERSHIP_MPI_ROUTINE()                                         \
  do {                                                                         \
    if (!this->amIinGroup()) {                                                 \
      LM_FATAL("MPI Routine cannot be called if not member of the group: "     \
               << comment);                                                    \
    }                                                                          \
  } while (0)
/* -------------------------------------------------------------------------- */
/**
 * Block until a message from group rank `from` is available. The message
 * must carry that peer's current sequence number as its tag.
 *
 * @return the number of T elements in the pending message.
 *
 * The message is not received and the sequence number is not advanced;
 * receive() does both.
 */
template <typename T>
inline auto CommGroup::probe(UInt from, const std::string &comment) {
  CHECK_MEMBERSHIP_MPI_ROUTINE();
  auto &peer = (*this)[from];
  DUMP("probing receive " << typeid(T).name() << " from " << peer.mpi_rank
                          << " - seq number " << peer.sequence_number
                          << " for " << comment,
       DBG_INFO);
  MPI_Status probe_status;
  MPI_Probe(peer.mpi_rank, peer.sequence_number, this->mpi_comm,
            &probe_status);
  int count;
  MPI_Get_count(&probe_status, get_mpi_type<T>(), &count);
  return count;
}
/* -------------------------------------------------------------------------- */
/**
 * Tell whether the process of rank `mpi_rank` belongs to this group.
 *
 * Only the "all" group (always true) and the "none" group (always false)
 * are handled. Any other group reaches LM_TOIMPLEMENT.
 */
inline bool CommGroup::isInGroup(UInt mpi_rank [[gnu::unused]]) const {
  const auto &group_id = this->getID();
  if (group_id == "all") {
    return true;
  } else if (group_id == "none") {
    return false;
  }
  DUMP("testing if " << mpi_rank << " is in " << *this, DBG_DETAIL);
  LM_TOIMPLEMENT;
}
/* -------------------------------------------------------------------------- */
template <typename T>
inline void CommGroup::receive(T *d, UInt nb, UInt from,
const std::string &comment [[gnu::unused]]) {
CHECK_MEMBERSHIP_MPI_ROUTINE();
MPI_Status status;
auto &group = *this;
auto &proc = group[from];
using btype = std::remove_pointer_t<std::decay_t<T>>;
DUMP("receiving " << nb << " " << typeid(T).name() << " from "
<< proc.mpi_rank << " - seq number " << proc.sequence_number
<< " for " << comment,
DBG_INFO);
MPI_Recv(d, nb, get_mpi_type<btype>(), proc.mpi_rank, proc.sequence_number,
this->mpi_comm, &status);
DUMP("received " << nb << " " << typeid(T).name() << " from " << proc.mpi_rank
<< " - seq number " << proc.sequence_number << " for "
<< comment,
DBG_INFO);
++proc.sequence_number;
}
/* -------------------------------------------------------------------------- */
/**
 * Receive a whole container from group rank `from`.
 *
 * The pending message is probed first so the buffer can be sized to its
 * element count. For trivially copy-constructible elements, `d` itself is
 * resized. For serialized (non-trivially copy-constructible) element types,
 * only the byte buffer is resized. The received bytes are then decoded into
 * `d` by unpack(), as gather() and broadcast() do, so `d` must already hold
 * enough elements.
 */
template <typename Vec>
inline void CommGroup::receive(Vec &&d, UInt from, const std::string &comment) {
  auto unpack = make_pack(d);
  auto nb = probe<typename decltype(unpack)::value_type>(from, comment);
  unpack.resize(nb);
  receive(unpack.data(), unpack.size(), from, comment);
  // Decode serialized payloads into d. This is a no-op when the data was
  // received in place.
  unpack.unpack();
}
/* -------------------------------------------------------------------------- */
/**
 * Blocking send of `nb` elements of `d` to group rank `dest`.
 *
 * The MPI tag is the peer's sequence number, which is incremented after
 * MPI_Send returns.
 */
template <typename T>
void CommGroup::send(T *d, UInt nb, UInt dest,
                     const std::string &comment [[gnu::unused]]) {
  CHECK_MEMBERSHIP_MPI_ROUTINE();
  using elem_t = std::remove_pointer_t<std::decay_t<T>>;
  auto &peer = (*this)[dest];
  DUMP("sending " << nb << " " << typeid(T).name() << " to " << peer.mpi_rank
                  << " - seq number " << peer.sequence_number << " for "
                  << comment,
       DBG_INFO);
  MPI_Send(d, nb, get_mpi_type<elem_t>(), peer.mpi_rank,
           peer.sequence_number, this->mpi_comm);
  DUMP("sent " << nb << " " << typeid(T).name() << " to " << peer.mpi_rank
               << " - seq number " << peer.sequence_number << " for "
               << comment,
       DBG_INFO);
  peer.sequence_number += 1;
}
/* -------------------------------------------------------------------------- */
template <typename Vec>
inline void CommGroup::send(Vec &&d, UInt to, const std::string &comment) {
auto pack = make_pack(d);
send(pack.data(), pack.size(), to, comment);
}
/* -------------------------------------------------------------------------- */
/**
 * Translate a LibMultiScale reduction Operator into the matching MPI_Op.
 *
 * An unsupported operator is reported through LM_FATAL. The handle starts
 * as MPI_OP_NULL so that, if LM_FATAL ever returned, the function would
 * return a well-defined null handle rather than an uninitialized one.
 */
inline auto getMPIOperator(Operator op) {
  MPI_Op mpi_op = MPI_OP_NULL;
  switch (op) {
  case OP_SUM:
    mpi_op = MPI_SUM;
    break;
  case OP_MAX:
    mpi_op = MPI_MAX;
    break;
  case OP_MIN:
    mpi_op = MPI_MIN;
    break;
  default:
    LM_FATAL("unknown operator " << op);
  }
  return mpi_op;
}
/* -------------------------------------------------------------------------- */
/**
 * Element-wise reduction of the `nb` values of `contrib` onto group rank
 * `root_rank`, using operator `op`.
 *
 * On the root, `contrib` is overwritten with the reduced values. On the
 * other ranks `contrib` is left untouched. MPI_Reduce does not fill the
 * receive buffer there, so copying it back would clobber the caller's data
 * with value-initialized elements.
 */
template <typename T>
void CommGroup::reduce(T *contrib, UInt nb, const std::string &comment,
                       Operator op, UInt root_rank) {
  CHECK_MEMBERSHIP_MPI_ROUTINE();
  MPI_Op mpi_op = getMPIOperator(op);
  std::vector<T> result(nb);
  MPI_Reduce(contrib, result.data(), nb, get_mpi_type<T>(), mpi_op, root_rank,
             this->mpi_comm);
  if (this->getMyRank() == root_rank)
    std::copy(result.begin(), result.end(), contrib);
}
/* -------------------------------------------------------------------------- */
/**
 * Element-wise reduction of the `nb` values of `contrib` across the whole
 * group with operator `op`. Every member gets the result back in `contrib`.
 */
template <typename T>
void CommGroup::allReduce(T *contrib, UInt nb, const std::string &comment,
                          Operator op) {
  CHECK_MEMBERSHIP_MPI_ROUTINE();
  MPI_Op mpi_op = getMPIOperator(op);
  std::vector<T> result(nb);
  // result.data() rather than &result[0]: indexing an empty vector (nb == 0)
  // is undefined behaviour.
  MPI_Allreduce(contrib, result.data(), nb, get_mpi_type<T>(), mpi_op,
                this->mpi_comm);
  std::copy(result.begin(), result.end(), contrib);
}
/* -------------------------------------------------------------------------- */
/**
 * Container overload of allGatherv. It forwards both containers' storage,
 * the local element count and `recv_counts` to the pointer overload.
 * recv_buffer is not resized here.
 */
template <typename BufType>
inline void CommGroup::allGatherv(BufType &&send_buffer, BufType &&recv_buffer,
                                  std::vector<UInt> &recv_counts,
                                  const std::string &comment) {
  auto *local = send_buffer.data();
  auto *gathered = recv_buffer.data();
  allGatherv(local, send_buffer.size(), gathered, recv_counts, comment);
}
/* -------------------------------------------------------------------------- */
template <typename T>
inline void CommGroup::allGatherv(T *send_buffer [[gnu::unused]],
UInt send_nb [[gnu::unused]],
T *recv_buffer [[gnu::unused]],
std::vector<UInt> &recv_counts
[[gnu::unused]],
const std::string &comment [[gnu::unused]]) {
CHECK_MEMBERSHIP_MPI_ROUTINE();
LM_TOIMPLEMENT;
}
/* -------------------------------------------------------------------------- */
/**
 * Container overload of allGather.
 *
 * Every member contributes all of `send_buffer`, and all members must send
 * the same number of elements. recv_buffer is resized to
 * send_buffer.size() * this->size() and receives the contributions of every
 * group rank, concatenated in rank order. The transfer goes through the
 * pointer overload (MPI_Allgather with equal counts). As in gather(),
 * buffers are packed so that serialized element types are decoded back by
 * unpack().
 */
template <typename BufType>
inline void CommGroup::allGather(BufType &&send_buffer, BufType &&recv_buffer,
                                 const std::string &comment) {
  recv_buffer.resize(send_buffer.size() * this->size());
  auto pack = make_pack(send_buffer);
  auto unpack = make_pack(recv_buffer);
  allGather(pack.data(), pack.size(), unpack.data(), comment);
  unpack.unpack();
}
/* -------------------------------------------------------------------------- */
/**
 * Every member contributes `nb` elements of `send_buffer`. All members
 * receive every contribution, in group-rank order, in `recv_buffer`
 * (MPI_Allgather).
 */
template <typename T>
inline void CommGroup::allGather(T *send_buffer, UInt nb, T *recv_buffer,
                                 const std::string &comment) {
  CHECK_MEMBERSHIP_MPI_ROUTINE();
  DUMP("Allgather " << nb << " " << typeid(T).name() << " for " << comment,
       DBG_INFO);
  const auto datatype = get_mpi_type<T>();
  MPI_Allgather(send_buffer, nb, datatype, recv_buffer, nb, datatype,
                this->mpi_comm);
  DUMP("Allgather done for " << comment, DBG_INFO);
}
/* -------------------------------------------------------------------------- */
/**
 * Gather `nb` elements from every group member into `recvbuf` on group rank
 * `root` (MPI_Gather). recvbuf is only significant on the root.
 */
template <typename T>
inline void CommGroup::gather(T *sendbuf, UInt nb, T *recvbuf, UInt root,
                              const std::string &comment) {
  CHECK_MEMBERSHIP_MPI_ROUTINE();
  const auto datatype = get_mpi_type<T>();
  MPI_Gather(sendbuf, nb, datatype, recvbuf, nb, datatype, root,
             this->mpi_comm);
}
/* -------------------------------------------------------------------------- */
/**
 * Container overload of gather.
 *
 * On the root, `recvbuf` is resized to sendbuf.size() * this->size(). After
 * the transfer it is decoded with unpack(), which only matters for
 * serialized element types. The other ranks only send.
 */
template <typename BufType>
inline void CommGroup::gather(BufType &&sendbuf, BufType &&recvbuf, UInt root,
                              const std::string &comment) {
  const bool is_root = (this->getMyRank() == root);
  if (is_root)
    recvbuf.resize(sendbuf.size() * this->size());
  auto packed_send = make_pack(sendbuf);
  auto packed_recv = make_pack(recvbuf);
  gather(packed_send.data(), packed_send.size(), packed_recv.data(), root,
         comment);
  if (is_root)
    packed_recv.unpack();
}
/* -------------------------------------------------------------------------- */
/// Print a one-line description: group ID, MPI communicator handle and size.
inline void CommGroup::printself(std::ostream &os) const {
  os << "Communication Group #" << this->getID() << ", mpi_ID: "
     << this->mpi_comm << ", size: " << this->size();
}
/* -------------------------------------------------------------------------- */
template <typename BufType>
inline void CommGroup::broadcast(BufType &&buf, UInt root,
const std::string &comment) {
auto pack = make_pack(buf);
broadcast(pack.data(), pack.size(), root, comment);
pack.unpack();
}
/* -------------------------------------------------------------------------- */
/**
 * Broadcast `nb` elements of `buf` from group rank `root` to all members
 * (MPI_Bcast). Every member must pass the same `nb`.
 */
template <typename T>
inline void CommGroup::broadcast(T *buf, UInt nb, UInt root,
                                 const std::string &comment) {
  CHECK_MEMBERSHIP_MPI_ROUTINE();
  DUMP("broadcast " << nb << " " << typeid(T).name() << " from " << root
                    << " for " << comment,
       DBG_INFO);
  const auto datatype = get_mpi_type<T>();
  MPI_Bcast(buf, nb, datatype, root, this->mpi_comm);
  DUMP("done broadcast " << nb << " " << typeid(T).name() << " from " << root
                         << " for " << comment,
       DBG_INFO);
}
/* -------------------------------------------------------------------------- */
/**
 * Scatter `sendbuf` from group rank `root`: each member receives `nb_recv`
 * consecutive elements into `recvbuf` (MPI_Scatter). sendbuf is only
 * significant on the root.
 */
template <typename T>
inline void CommGroup::scatter(T *sendbuf, T *recvbuf, UInt nb_recv, UInt root,
                               const std::string &comment) {
  CHECK_MEMBERSHIP_MPI_ROUTINE();
  DUMP("scatter " << nb_recv << " " << typeid(T).name() << " from " << root
                  << " for " << comment,
       DBG_INFO);
  const auto datatype = get_mpi_type<T>();
  MPI_Scatter(sendbuf, nb_recv, datatype, recvbuf, nb_recv, datatype, root,
              this->mpi_comm);
  DUMP("done scatter " << nb_recv << " " << typeid(T).name() << " from "
                       << root << " for " << comment,
       DBG_INFO);
}
/* -------------------------------------------------------------------------- */
/**
 * Container overload of scatter. Each member receives recvbuf.size()
 * elements. On the root, sendbuf must hold exactly
 * recvbuf.size() * this->size() elements (asserted).
 */
template <typename BufType>
inline void CommGroup::scatter(BufType &&sendbuf, BufType &&recvbuf, UInt root,
                               const std::string &comment) {
  const bool is_root = this->getMyRank() == root;
  if (is_root) {
    LM_ASSERT(sendbuf.size() == recvbuf.size() * this->size(),
              "buffers not having a proper size");
  }
  scatter(sendbuf.data(), recvbuf.data(), recvbuf.size(), root, comment);
}
/* -------------------------------------------------------------------------- */
/**
 * The "self" group: contains only the calling process and communicates over
 * MPI_COMM_SELF.
 *
 * NOTE(review): the base constructor is invoked with color 1, so it runs an
 * MPI_Comm_split of MPI_COMM_WORLD. The communicator it stores is then
 * replaced by MPI_COMM_SELF below without being freed here. Confirm whether
 * that handle is released elsewhere.
 */
inline SelfGroup::SelfGroup() : CommGroup("self", 1, 1) {
  mpi_comm = MPI_COMM_SELF;
  is_current_proc_in_group = true;
  processors[0].mpi_rank = 0;
  processors[0].absolute_mpi_rank = lm_my_proc_id;
}
__END_LIBMULTISCALE__

Event Timeline