Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F110722721
comm_group.hh
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Sun, Apr 27, 19:39
Size
10 KB
Mime Type
text/x-c++
Expires
Tue, Apr 29, 19:39 (1 d, 23 h)
Engine
blob
Format
Raw Data
Handle
25796171
Attached To
rLIBMULTISCALE LibMultiScale
comm_group.hh
View Options
#ifndef __LIBMULTISCALE_COMM_GROUP_HH__
#define __LIBMULTISCALE_COMM_GROUP_HH__
/* -------------------------------------------------------------------------- */
#include "lm_common.hh"
#include "lm_object.hh"
#include <algorithm>
#include <climits>
#include <memory>
#include <mpi.h>
#include <vector>
/* -------------------------------------------------------------------------- */
__BEGIN_LIBMULTISCALE__
/**
 * Lightweight record describing one processor of a communication group:
 * its rank inside the group, its rank in MPI_COMM_WORLD, and a per-peer
 * message counter (used as the MPI tag in CommGroup::send/receive).
 */
struct MPIProc {
  MPIProc() : sequence_number(0) {}
  //! rank of this processor in MPI_COMM_WORLD
  UInt getAbsoluteRank() const { return absolute_mpi_rank; };
  //! identifier of the processor (filled by the owning CommGroup)
  UInt id{0};
  //! rank inside the group communicator
  UInt mpi_rank{0};
  //! monotonically increasing counter used as the MPI tag for
  //! point-to-point messages exchanged with this peer
  UInt sequence_number;
  //! rank in MPI_COMM_WORLD (world-side image of mpi_rank)
  UInt absolute_mpi_rank{0};
};
struct CommGroup : public LMObject {
CommGroup(const LMID &id, int color, UInt nb_procs) : LMObject(id) {
DUMP("creating comm_group: " << id, DBG_INFO);
is_current_proc_in_group = color != MPI_UNDEFINED;
processors.resize(nb_procs);
MPI_Comm_split(MPI_COMM_WORLD, color, lm_my_proc_id, &mpi_comm);
MPI_Comm_group(mpi_comm, &mpi_group);
MPI_Group world_group;
MPI_Comm_group(MPI_COMM_WORLD, &world_group);
int group_ranks[nb_procs];
for (UInt i = 0; i < nb_procs; ++i)
group_ranks[i] = i;
int absolute_ranks[nb_procs];
MPI_Group_translate_ranks(mpi_group, nb_procs, group_ranks, world_group,
absolute_ranks);
for (UInt i = 0; i < nb_procs; ++i) {
processors[i].absolute_mpi_rank = absolute_ranks[i];
processors[i].mpi_rank = group_ranks[i];
}
DUMP("created comm_group: " << id, DBG_INFO);
}
bool operator==(CommGroup &grp) { return (grp.getID() == this->getID()); }
bool operator!=(CommGroup &grp) { return !(grp == *this); }
auto getMPICommGroup() { return mpi_comm; };
UInt getMyRank() const {
int rank_in_group;
MPI_Comm_rank(mpi_comm, &rank_in_group);
return rank_in_group;
}
UInt real2GroupRank(UInt real_rank);
const LMID &getID() const override {
auto &id = LMObject::getID();
if (id == "invalid")
LM_FATAL("Communication group is invalid");
return id;
};
template <typename BufType>
void allReduce(BufType &&contrib, UInt nb, const std::string &comment,
Operator op);
template <typename BufType>
void reduce(BufType &&contrib, UInt nb, const std::string &comment,
Operator op, UInt root_rank = 0);
template <typename BufType>
inline void send(BufType &&d, UInt nb, UInt to, const std::string &comment);
template <typename BufType>
inline void receive(BufType &&d, UInt nb, UInt from,
const std::string &comment);
template <typename BufType>
inline void gather(BufType &&d, UInt nb, UInt *nb_data_per_proc, UInt to,
const std::string &comment);
template <typename BufType>
inline void allGatherv(BufType &&d, UInt nb, int *rcnts, int *rcntscnts,
int *yop, const std::string &comment);
auto size() const { return processors.size(); };
auto begin() { return processors.begin(); };
auto end() { return processors.end(); };
auto &operator[](UInt i) { return processors[i]; }
inline bool isInGroup(UInt mpi_rank) const;
bool amIinGroup() { return is_current_proc_in_group; };
void synchronize() {
MPI_Barrier(mpi_comm);
};
void printself(std::ostream &os) const {
os << "Communication Group #" << this->getID();
os << ", mpi_ID: " << this->mpi_comm << ", size: " << this->size();
}
// compatibility routines
template <typename BufType>
inline void receive(BufType &&d, UInt from, const std::string &comment) {
receive(d, d.size(), from, comment);
};
template <typename BufType>
inline void send(BufType &&d, UInt to, const std::string &comment) {
send(d, d.size(), to, comment);
};
private:
//! the intra communicator of the group
MPI_Comm mpi_comm;
//! the mpi group
MPI_Group mpi_group;
//! vector of processors
std::vector<MPIProc> processors;
//! stores the color used at creation of the communicator
bool is_current_proc_in_group;
};
/* -------------------------------------------------------------------------- */
inline std::ostream &operator<<(std::ostream &os, const CommGroup &group) {
group.printself(os);
return os;
}
/* -------------------------------------------------------------------------- */
//! Maps a C++ scalar type to its MPI datatype handle. The primary
//! template is empty: using an unmapped type fails at the point where
//! value()/scalar_type is requested.
template <typename T> struct get_mpi_scalar_type {};
/* -------------------------------------------------------------------------- */
//! Maps a (possibly container/wrapper) buffer type to the MPI datatype
//! of its scalar elements; specialized below for vectors and templates.
template <typename T> struct get_mpi_type {};
/* -------------------------------------------------------------------------- */
//! Registers one C type <-> MPI datatype pair in both trait families.
#define MPI_TYPE_MAP(__C_TYPE__, __MPI_TYPE__)                                 \
  template <> struct get_mpi_scalar_type<__C_TYPE__> {                         \
    static auto value() { return __MPI_TYPE__; };                              \
    using scalar_type = __C_TYPE__;                                            \
  };                                                                           \
  template <>                                                                  \
  struct get_mpi_type<__C_TYPE__> : public get_mpi_scalar_type<__C_TYPE__> {}
MPI_TYPE_MAP(Real, MPI_DOUBLE);
// NOTE(review): UInt is mapped to the *signed* MPI_INT — confirm this is
// intentional (MPI_UNSIGNED would be the unsigned counterpart).
MPI_TYPE_MAP(UInt, MPI_INT);
#undef MPI_TYPE_MAP
/* -------------------------------------------------------------------------- */
//! any single-parameter class template over a scalar maps to its scalar
template <template <typename Tscal> class T, typename Tscal>
struct get_mpi_type<T<Tscal>> : public get_mpi_scalar_type<Tscal> {};
/* -------------------------------------------------------------------------- */
//! std::vector<T> maps to the MPI datatype of T
template <class T>
struct get_mpi_type<std::vector<T>> : public get_mpi_scalar_type<T> {};
/* -------------------------------------------------------------------------- */
//! Membership test for a world rank. Only the special groups "all" and
//! "none" are handled so far; any other group aborts (LM_TOIMPLEMENT).
inline bool CommGroup::isInGroup(UInt mpi_rank) const {
  const auto &gid = this->getID();
  if (gid == "all")
    return true;
  if (gid == "none")
    return false;
  DUMP("testing if " << mpi_rank << " is in " << *this, DBG_DETAIL);
  LM_TOIMPLEMENT;
}
/* -------------------------------------------------------------------------- */
/**
 * Typed point-to-point receive of `nb` scalars from group member `from`
 * into buffer `d`. The per-peer sequence number is used as the MPI tag
 * and incremented after reception, so messages from the same peer are
 * matched in order. Passing nb == UINT_MAX means "size unknown": the
 * message is probed first and the element count read from the status.
 */
template <typename BufType>
inline void CommGroup::receive(BufType &&d, UInt nb, UInt from,
                               const std::string &comment) {
  MPI_Status status;
  auto &group = *this;
  auto &proc = group[from];
  // strip reference/pointer qualifiers to reach the buffer's element type
  using btype = std::remove_pointer_t<std::decay_t<BufType>>;
  using scalar_type = typename get_mpi_type<btype>::scalar_type;
  DUMP("probe reception of " << typeid(scalar_type).name() << " from "
                             << proc.mpi_rank << " - seq number "
                             << proc.sequence_number << " for " << comment,
       DBG_INFO);
  if (nb == UINT_MAX) {
    // sentinel: determine the incoming message length by probing
    MPI_Probe(proc.mpi_rank, proc.sequence_number, this->mpi_comm, &status);
    int nb_tmp;
    MPI_Get_count(&status, get_mpi_type<btype>::value(), &nb_tmp);
    nb = nb_tmp;
  }
  DUMP("receiving " << nb << " " << typeid(scalar_type).name() << " from "
                    << proc.mpi_rank << " - seq number " << proc.sequence_number
                    << " for " << comment,
       DBG_INFO);
  // blocking receive; tag is the current sequence number for this peer
  MPI_Recv(&d[0], nb, get_mpi_type<btype>::value(), proc.mpi_rank,
           proc.sequence_number, this->mpi_comm, &status);
  DUMP("received " << nb << " " << typeid(scalar_type).name() << " from "
                   << proc.mpi_rank << " - seq number " << proc.sequence_number
                   << " for " << comment,
       DBG_INFO);
  // advance the tag so the next message from this peer matches in order
  ++proc.sequence_number;
}
/* -------------------------------------------------------------------------- */
/**
 * Typed point-to-point send of `nb` scalars from buffer `d` to group
 * member `dest`. The per-peer sequence number is used as the MPI tag
 * and incremented after the send, mirroring CommGroup::receive so that
 * successive messages to the same peer are matched in order.
 */
template <typename BufType>
inline void CommGroup::send(BufType &&d, UInt nb, UInt dest,
                            const std::string &comment) {
  auto &group = *this;
  auto &proc = group[dest];
  // strip reference/pointer qualifiers to reach the buffer's element type
  using btype = std::remove_pointer_t<std::decay_t<BufType>>;
  using scalar_type = typename get_mpi_type<btype>::scalar_type;
  DUMP("sending " << nb << " " << typeid(scalar_type).name() << " to "
                  << proc.mpi_rank << " - seq number " << proc.sequence_number
                  << " for " << comment,
       DBG_INFO);
  // blocking send; tag is the current sequence number for this peer
  MPI_Send(&d[0], nb, get_mpi_type<btype>::value(), proc.mpi_rank,
           proc.sequence_number, this->mpi_comm);
  DUMP("sent " << nb << " " << typeid(scalar_type).name() << " to "
               << proc.mpi_rank << " - seq number " << proc.sequence_number
               << " for " << comment,
       DBG_INFO);
  // advance the tag to stay in lockstep with the receiving side
  ++proc.sequence_number;
}
/* -------------------------------------------------------------------------- */
//! Translate a LibMultiScale reduction operator into the matching
//! MPI_Op handle; aborts on an unhandled operator.
inline auto getMPIOperator(Operator op) {
  if (op == OP_SUM)
    return MPI_SUM;
  if (op == OP_MAX)
    return MPI_MAX;
  if (op == OP_MIN)
    return MPI_MIN;
  LM_FATAL("unknown operator " << op);
  return MPI_SUM; // unreachable: LM_FATAL aborts
}
/* -------------------------------------------------------------------------- */
template <typename BufType>
void CommGroup::reduce(BufType &&contrib, UInt nb, const std::string &comment,
Operator op, UInt root_rank) {
if (!amIinGroup()) {
LM_FATAL(
"reduction can't be made if not a member of the group: " << comment);
}
using btype = std::remove_pointer_t<std::decay_t<BufType>>;
using scalar_type = typename get_mpi_type<btype>::scalar_type;
MPI_Op mpi_op = getMPIOperator(op);
scalar_type result[nb];
MPI_Reduce(&contrib[0], result, nb, get_mpi_type<btype>::value(), mpi_op, 0,
this->mpi_comm);
contrib = result;
};
/* -------------------------------------------------------------------------- */
template <typename BufType>
void CommGroup::allReduce(BufType &&contrib, UInt nb,
const std::string &comment, Operator op) {
if (!amIinGroup()) {
LM_FATAL(
"reduction can't be made if not a member of the group: " << comment);
}
using btype = std::remove_pointer_t<std::decay_t<BufType>>;
using scalar_type = typename get_mpi_type<btype>::scalar_type;
MPI_Op mpi_op = getMPIOperator(op);
std::vector<scalar_type> result(nb);
MPI_Allreduce(&contrib[0], &result[0], nb, get_mpi_type<btype>::value(),
mpi_op, this->mpi_comm);
std::copy(result.begin(), result.end(), contrib);
}
/* -------------------------------------------------------------------------- */
/**
 * Variable-sized all-gather over the group. Not implemented yet: calling
 * it aborts via LM_TOIMPLEMENT; all parameters are currently unused.
 * NOTE(review): the parameter names (`rcntscnts`, `yop`) look like
 * placeholders — presumably receive counts and displacements for
 * MPI_Allgatherv; confirm before implementing.
 */
template <typename BufType>
inline void CommGroup::allGatherv(BufType &&d, UInt nb, int *rcnts,
                                  int *rcntscnts, int *yop,
                                  const std::string &comment) {
  LM_TOIMPLEMENT;
}
/* -------------------------------------------------------------------------- */
/**
 * Gather per-processor contributions onto group member `to`. Not
 * implemented yet: calling it aborts via LM_TOIMPLEMENT; all parameters
 * are currently unused.
 */
template <typename BufType>
inline void CommGroup::gather(BufType &&d, UInt nb, UInt *nb_data_per_proc,
                              UInt to, const std::string &comment) {
  LM_TOIMPLEMENT;
}
/* -------------------------------------------------------------------------- */
__END_LIBMULTISCALE__
#endif //__LIBMULTISCALE_COMM_GROUP_HH__
Event Timeline
Log In to Comment