Caffe2 - C++ API
A deep learning, cross platform ML framework
event.h
1 #ifndef CAFFE2_CORE_EVENT_H_
2 #define CAFFE2_CORE_EVENT_H_
3 
4 #include "caffe2/core/common.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/proto/caffe2.pb.h"
7 
8 namespace caffe2 {
9 
10 constexpr int MaxDeviceTypes = DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
11 class Event;
12 
// Status of the operation tracked by an Event, as returned by Event::Query().
enum EventStatus {
  // Neither scheduled nor finished; Event::CanSchedule() refuses to schedule
  // children of a parent in this state.
  EVENT_INITIALIZED = 0,
  // Scheduled but not yet finished; Event::IsScheduled() tests for this.
  EVENT_SCHEDULED = 1,
  // Finished successfully; Event::IsFinished() reports true.
  EVENT_SUCCESS = 2,
  // Finished with an error; Event::IsFinished() also reports true.
  EVENT_FAILED = 3,
};
19 
20 // For the following functions, void* shall be interpreted as the corresponding
21 // context object corresponding to the device type associated with the
22 // functions.
23 
24 // Initializes event
25 typedef void (*EventCreateFunction)(const DeviceOption& option, Event*);
26 
27 // Called on event to signal that CPU part of operation is finished,
28 // Optionally accepts error message from CPU part.
29 // Should be called no more than once per event
30 typedef void (*EventRecordFunction)(Event*, const void*, const char*);
31 
32 // Waits and returns as soon as possible in order schedule next operation,
33 // e.g. for CUDA->CUDA waits only for CPU part of CUDA op,
34 // for CUDA->CPU waits till the CUDA op is fully completed.
35 // Prepares context to synchronize device part of operation.
36 // Can be called concurrently from multiple threads
37 typedef void (*EventWaitFunction)(const Event*, void*);
38 
39 // Waits till operation is fully finished,
40 // can be called concurrently from multiple threads
41 typedef void (*EventFinishFunction)(const Event*);
42 
43 // Queries current status of operation,
44 // can be called concurrently from multiple threads
45 typedef EventStatus (*EventQueryFunction)(const Event*);
46 typedef const std::string& (*EventErrorMessageFunction)(const Event*);
47 typedef void (*EventSetFinishedFunction)(const Event*, const char*);
48 typedef void (*EventResetFunction)(Event*);
49 
50 class Event {
51  public:
52  explicit Event(const DeviceOption& option)
53  : event_(), type_(option.device_type()), option_(option) {
54  CAFFE_ENFORCE_LT(type_, MaxDeviceTypes);
55  CAFFE_ENFORCE(event_creator_[type_]);
56  event_creator_[type_](option, this);
57  }
58 
59  // Nothing needs to be done in the destructor, as the event creator should
60  // set the proper destruction process for the unique_ptr.
61  ~Event() {}
62 
63  void Record(
64  int recorder_type,
65  const void* context,
66  const char* err_msg = nullptr) {
67  CAFFE_ENFORCE_EQ(
68  recorder_type,
69  type_,
70  "You are trying to record with a wrong device type.");
71  CAFFE_ENFORCE(event_recorder_[recorder_type]);
72  event_recorder_[recorder_type](this, context, err_msg);
73  }
74 
75  void Wait(int waiter_type, void* context) const {
76  CAFFE_ENFORCE(event_waiter_[waiter_type][type_]);
77  event_waiter_[waiter_type][type_](this, context);
78  }
79 
80  void Finish() const {
81  CAFFE_ENFORCE(event_finisher_[type_]);
82  event_finisher_[type_](this);
83  }
84 
85  EventStatus Query() const {
86  CAFFE_ENFORCE(event_querier_[type_]);
87  return event_querier_[type_](this);
88  }
89 
90  const std::string& ErrorMessage() const {
91  CAFFE_ENFORCE(event_err_msg_getter_[type_]);
92  return event_err_msg_getter_[type_](this);
93  }
94 
95  void Reset() {
96  CAFFE_ENFORCE(event_resetter_[type_]);
97  event_resetter_[type_](this);
98  }
99 
100  const DeviceOption& GetDeviceOption() const {
101  return option_;
102  }
103 
104  bool IsScheduled() const {
105  return Query() == EventStatus::EVENT_SCHEDULED;
106  }
107 
108  bool IsFinished() const {
109  auto status = Query();
110  return status == EventStatus::EVENT_SUCCESS ||
111  status == EventStatus::EVENT_FAILED;
112  }
113 
114  void SetFinished(const char* err_msg = nullptr) {
115  CAFFE_ENFORCE(event_finished_setter_[type_]);
116  return event_finished_setter_[type_](this, err_msg);
117  }
118 
119  // If parent op has succeeded, then we can run any child op;
120  // If parent op is in scheduled state, we need to check that:
121  // - child op supports async scheduling
122  // - there's a way to setup synchronization between async parent and
123  // child - both child and parent should use the same type of device,
124  // non-blocking synchronization between different device types is not
125  // supported
126  // If parent op is in another state (initialized or failed) then scheduling
127  // is not possible
128  bool CanSchedule(const Event& child_event, bool supports_async) const {
129  return CanSchedule(type_, Query(), child_event.GetType(), supports_async);
130  }
131 
132  static bool CanSchedule(
133  int parent_type,
134  EventStatus parent_status,
135  int child_type,
136  bool child_supports_async) {
137  if (parent_status == EventStatus::EVENT_SUCCESS) {
138  return true;
139  }
140  if (parent_status == EventStatus::EVENT_SCHEDULED) {
141  return (parent_type == child_type) && child_supports_async;
142  }
143  return false;
144  }
145 
146  int GetType() const {
147  return type_;
148  }
149 
150  // event_ is going to be accessed by the EventCreate/Record/Wait/Finish
151  // functions, but one should not use it outside the own Event functionalities.
152  // In the future we may move it to a private member.
153  std::shared_ptr<void> event_;
154 
155  private:
156  int type_;
157  DeviceOption option_;
158 
159  CAFFE2_API static EventCreateFunction event_creator_[MaxDeviceTypes];
160  CAFFE2_API static EventRecordFunction event_recorder_[MaxDeviceTypes];
161  CAFFE2_API static EventWaitFunction event_waiter_[MaxDeviceTypes]
162  [MaxDeviceTypes];
163  CAFFE2_API static EventFinishFunction event_finisher_[MaxDeviceTypes];
164 
165  CAFFE2_API static EventQueryFunction event_querier_[MaxDeviceTypes];
166  CAFFE2_API static EventErrorMessageFunction
167  event_err_msg_getter_[MaxDeviceTypes];
168  CAFFE2_API static EventSetFinishedFunction
169  event_finished_setter_[MaxDeviceTypes];
170  CAFFE2_API static EventResetFunction event_resetter_[MaxDeviceTypes];
171 
172  template <int d>
173  friend struct EventCreateFunctionRegisterer;
174  template <int d>
175  friend struct EventRecordFunctionRegisterer;
176  template <int w, int d>
177  friend struct EventWaitFunctionRegisterer;
178  template <int d>
179  friend struct EventFinishFunctionRegisterer;
180 
181  template <int d>
182  friend struct EventQueryFunctionRegisterer;
183  template <int d>
185  template <int d>
187  template <int d>
188  friend struct EventResetFunctionRegisterer;
189 };
190 
191 template <int d>
193  explicit EventCreateFunctionRegisterer(EventCreateFunction f) {
194  static_assert(d < MaxDeviceTypes, "");
195  Event::event_creator_[d] = f;
196  }
197 };
198 #define REGISTER_EVENT_CREATE_FUNCTION(d, f) \
199  namespace { \
200  static EventCreateFunctionRegisterer<d> g_event_create_##d(f); \
201  }
202 
203 template <int d>
205  explicit EventRecordFunctionRegisterer(EventRecordFunction f) {
206  static_assert(d < MaxDeviceTypes, "");
207  Event::event_recorder_[d] = f;
208  }
209 };
210 #define REGISTER_EVENT_RECORD_FUNCTION(d, f) \
211  namespace { \
212  static EventRecordFunctionRegisterer<d> g_event_record_##d(f); \
213  }
214 
215 template <int waiter_type, int event_type>
217  explicit EventWaitFunctionRegisterer(EventWaitFunction f) {
218  static_assert(waiter_type < MaxDeviceTypes, "");
219  static_assert(event_type < MaxDeviceTypes, "");
220  Event::event_waiter_[waiter_type][event_type] = f;
221  }
222 };
223 #define REGISTER_EVENT_WAIT_FUNCTION(w, d, f) \
224  namespace { \
225  static EventWaitFunctionRegisterer<w, d> g_event_wait_##w##_##d(f); \
226  }
227 
228 template <int d>
230  explicit EventQueryFunctionRegisterer(EventQueryFunction f) {
231  static_assert(d < MaxDeviceTypes, "");
232  Event::event_querier_[d] = f;
233  }
234 };
235 #define REGISTER_EVENT_QUERY_FUNCTION(d, f) \
236  namespace { \
237  static EventQueryFunctionRegisterer<d> g_event_query_##d(f); \
238  }
239 
240 template <int d>
242  explicit EventErrorMessageFunctionRegisterer(EventErrorMessageFunction f) {
243  static_assert(d < MaxDeviceTypes, "");
244  Event::event_err_msg_getter_[d] = f;
245  }
246 };
247 #define REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(d, f) \
248  namespace { \
249  static EventErrorMessageFunctionRegisterer<d> g_event_err_msg_##d(f); \
250  }
251 
252 template <int d>
254  explicit EventSetFinishedFunctionRegisterer(EventSetFinishedFunction f) {
255  static_assert(d < MaxDeviceTypes, "");
256  Event::event_finished_setter_[d] = f;
257  }
258 };
259 #define REGISTER_EVENT_SET_FINISHED_FUNCTION(d, f) \
260  namespace { \
261  static EventSetFinishedFunctionRegisterer<d> g_event_set_finished_##d(f); \
262  }
263 
264 template <int d>
266  explicit EventFinishFunctionRegisterer(EventFinishFunction f) {
267  static_assert(d < MaxDeviceTypes, "");
268  Event::event_finisher_[d] = f;
269  }
270 };
271 #define REGISTER_EVENT_FINISH_FUNCTION(d, f) \
272  namespace { \
273  static EventFinishFunctionRegisterer<d> g_event_finish_##d(f); \
274  }
275 
276 template <int d>
278  explicit EventResetFunctionRegisterer(EventResetFunction f) {
279  static_assert(d < MaxDeviceTypes, "");
280  Event::event_resetter_[d] = f;
281  }
282 };
283 #define REGISTER_EVENT_RESET_FUNCTION(d, f) \
284  namespace { \
285  static EventResetFunctionRegisterer<d> g_event_reset_##d(f); \
286  }
287 
288 } // namespace caffe2
289 
290 #endif // CAFFE2_CORE_EVENT_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...