Caffe2 - C++ API
A deep learning, cross platform ML framework
mpi_common.h
1 #ifndef CAFFE2_MPI_MPI_COMMON_H_
2 #define CAFFE2_MPI_MPI_COMMON_H_
3 
4 #include <mpi.h>
5 #include <mutex>
6 
7 #include "caffe2/core/logging.h"
8 
9 namespace caffe2 {
10 
11 inline void CheckInitializedMPI() {
12  int flag;
13  MPI_Initialized(&flag);
14  CAFFE_ENFORCE(flag, "MPI does not seem to have been initialized.");
15 }
16 
17 template <typename T> class MPIDataTypeWrapper;
18 
19 #define MPI_DATATYPE_WRAPPER(c_type, mpi_type) \
20  template<> class MPIDataTypeWrapper<c_type> { \
21  public: \
22  inline static MPI_Datatype type() { return mpi_type; } \
23  };
24 
25 MPI_DATATYPE_WRAPPER(char, MPI_CHAR)
26 MPI_DATATYPE_WRAPPER(float, MPI_FLOAT)
27 MPI_DATATYPE_WRAPPER(double, MPI_DOUBLE)
28 // Note(Yangqing): as necessary, add more specializations.
29 #undef MPI_DATATYPE_WRAPPER
30 
// For all Caffe2 MPI calls, we wrap each call inside a lock guard on the MPI
// mutex returned here.
// Returns the mutex that MPI_CHECK locks for the duration of each wrapped MPI
// call. NOTE(review): presumably a single shared instance defined in
// mpi_common.cc -- confirm there.
32 std::mutex& MPIMutex();
33 
// MPI_CHECK(condition): evaluates the MPI call `condition` while holding
// ::caffe2::MPIMutex(), and throws via CAFFE_ENFORCE if the returned error
// code is not MPI_SUCCESS. The message includes the call site's file and
// line plus the raw MPI error code.
//
// The macro-local names are deliberately obscure. A local named e.g. `error`
// is already in scope inside its own initializer, so an argument expression
// that mentions a caller variable `error` would silently read the
// uninitialized macro local instead of the caller's variable.
#define MPI_CHECK(condition)                               \
  do {                                                     \
    std::lock_guard<std::mutex> caffe2_mpi_check_guard_(   \
        ::caffe2::MPIMutex());                             \
    const int caffe2_mpi_check_error_ = (condition);       \
    CAFFE_ENFORCE(                                         \
        caffe2_mpi_check_error_ == MPI_SUCCESS,            \
        "Caffe2 MPI Error at: ",                           \
        __FILE__,                                          \
        ":",                                               \
        __LINE__,                                          \
        ": ",                                              \
        caffe2_mpi_check_error_);                          \
  } while (0)
47 
// Gets the global MPI communicator used by Caffe2.
52 MPI_Comm GlobalMPIComm();
53 
// Sets the global MPI communicator used by Caffe2.
// NOTE(review): whether this takes ownership of new_comm, or frees a
// previously set communicator, is not visible here -- confirm in
// mpi_common.cc.
58 void SetGlobalMPIComm(MPI_Comm new_comm);
59 
// A helper function that returns the size of the given communicator.
63 int MPICommSize(MPI_Comm comm);
64 
// A helper function that returns the rank of the calling process in the
// given communicator.
68 int MPICommRank(MPI_Comm comm);
69 
74  public:
86  MPI_Comm src_comm = MPI_COMM_NULL,
87  int color = 0,
88  int rank = -1) {
89  if (src_comm == MPI_COMM_NULL) {
90  src_comm = GlobalMPIComm();
91  }
92  if (rank == -1) {
93  MPI_CHECK(MPI_Comm_rank(src_comm, &rank));
94  }
95  MPI_CHECK(MPI_Comm_split(src_comm, color, rank, &comm_));
96  MPI_CHECK(MPI_Comm_size(comm_, &size_));
97  MPI_CHECK(MPI_Comm_rank(comm_, &rank_));
98  }
99 
101  int ret;
102  MPI_CHECK(MPI_Finalized(&ret));
103  if (!ret) {
104  MPI_Comm_free(&comm_);
105  }
106  }
107 
111  inline MPI_Comm comm() const {
112  return comm_;
113  }
117  inline int size() const {
118  return size_;
119  }
123  inline int rank() const {
124  return rank_;
125  }
126 
127  private:
128  MPI_Comm comm_;
129  int size_;
130  int rank_;
131 };
132 
// Performs peer setup so that one does not need to use mpirun / mpiexec to
// run the binary.
// NOTE(review): the exact meaning of `replicas`, `role` and `job_path` is
// defined in mpi_common.cc and is not visible here -- confirm there.
148 void MPISetupPeers(
149  const int replicas,
150  const string& role,
151  const string& job_path);
152 } // namespace caffe2
153 
154 #endif // CAFFE2_MPI_MPI_COMMON_H_
MPI_Comm comm() const
Returns the common world held by the wrapper.
Definition: mpi_common.h:111
MPI_Comm GlobalMPIComm()
Gets the global MPI communicator used by Caffe2.
Definition: mpi_common.cc:20
void SetGlobalMPIComm(MPI_Comm new_comm)
Sets the global MPI communicator.
Definition: mpi_common.cc:24
int rank() const
Returns the rank of this process in the world.
Definition: mpi_common.h:123
int MPICommRank(MPI_Comm comm)
A helper function to return the rank of the given communicator.
Definition: mpi_common.cc:37
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
int size() const
Returns the size of the world.
Definition: mpi_common.h:117
int MPICommSize(MPI_Comm comm)
A helper function to return the size of the given communicator.
Definition: mpi_common.cc:31
MPICommonWorldWrapper(MPI_Comm src_comm=MPI_COMM_NULL, int color=0, int rank=-1)
Creates a common world wrapper.
Definition: mpi_common.h:85
A simple wrapper over an MPI common world.
Definition: mpi_common.h:73
void MPISetupPeers(const int replicas, const string &role, const string &job_path)
A function used to perform peer setup so one does not need to use mpirun / mpiexec to run the binary...
Definition: mpi_common.cc:94