1 #ifndef CAFFE2_CORE_EVENT_H_ 2 #define CAFFE2_CORE_EVENT_H_ 4 #include "caffe2/core/common.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/proto/caffe2.pb.h" 10 constexpr
int MaxDeviceTypes = DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
14 EVENT_INITIALIZED = 0,
25 typedef void (*EventCreateFunction)(
const DeviceOption& option, Event*);
30 typedef void (*EventRecordFunction)(Event*,
const void*,
const char*);
37 typedef void (*EventWaitFunction)(
const Event*,
void*);
41 typedef void (*EventFinishFunction)(
const Event*);
45 typedef EventStatus (*EventQueryFunction)(
const Event*);
46 typedef const std::string& (*EventErrorMessageFunction)(
const Event*);
47 typedef void (*EventSetFinishedFunction)(
const Event*,
const char*);
48 typedef void (*EventResetFunction)(Event*);
52 explicit Event(
const DeviceOption& option)
53 : event_(), type_(option.device_type()), option_(option) {
54 CAFFE_ENFORCE_LT(type_, MaxDeviceTypes);
55 CAFFE_ENFORCE(event_creator_[type_]);
56 event_creator_[type_](option,
this);
66 const char* err_msg =
nullptr) {
70 "You are trying to record with a wrong device type.");
71 CAFFE_ENFORCE(event_recorder_[recorder_type]);
72 event_recorder_[recorder_type](
this, context, err_msg);
75 void Wait(
int waiter_type,
void* context)
const {
76 CAFFE_ENFORCE(event_waiter_[waiter_type][type_]);
77 event_waiter_[waiter_type][type_](
this, context);
81 CAFFE_ENFORCE(event_finisher_[type_]);
82 event_finisher_[type_](
this);
85 EventStatus Query()
const {
86 CAFFE_ENFORCE(event_querier_[type_]);
87 return event_querier_[type_](
this);
90 const std::string& ErrorMessage()
const {
91 CAFFE_ENFORCE(event_err_msg_getter_[type_]);
92 return event_err_msg_getter_[type_](
this);
96 CAFFE_ENFORCE(event_resetter_[type_]);
97 event_resetter_[type_](
this);
100 const DeviceOption& GetDeviceOption()
const {
104 bool IsScheduled()
const {
105 return Query() == EventStatus::EVENT_SCHEDULED;
108 bool IsFinished()
const {
109 auto status = Query();
110 return status == EventStatus::EVENT_SUCCESS ||
111 status == EventStatus::EVENT_FAILED;
114 void SetFinished(
const char* err_msg =
nullptr) {
115 CAFFE_ENFORCE(event_finished_setter_[type_]);
116 return event_finished_setter_[type_](
this, err_msg);
128 bool CanSchedule(
const Event& child_event,
bool supports_async)
const {
129 return CanSchedule(type_, Query(), child_event.GetType(), supports_async);
132 static bool CanSchedule(
134 EventStatus parent_status,
136 bool child_supports_async) {
137 if (parent_status == EventStatus::EVENT_SUCCESS) {
140 if (parent_status == EventStatus::EVENT_SCHEDULED) {
141 return (parent_type == child_type) && child_supports_async;
146 int GetType()
const {
153 std::shared_ptr<void> event_;
157 DeviceOption option_;
159 CAFFE2_API
static EventCreateFunction event_creator_[MaxDeviceTypes];
160 CAFFE2_API
static EventRecordFunction event_recorder_[MaxDeviceTypes];
161 CAFFE2_API
static EventWaitFunction event_waiter_[MaxDeviceTypes]
163 CAFFE2_API
static EventFinishFunction event_finisher_[MaxDeviceTypes];
165 CAFFE2_API
static EventQueryFunction event_querier_[MaxDeviceTypes];
166 CAFFE2_API
static EventErrorMessageFunction
167 event_err_msg_getter_[MaxDeviceTypes];
168 CAFFE2_API
static EventSetFinishedFunction
169 event_finished_setter_[MaxDeviceTypes];
170 CAFFE2_API
static EventResetFunction event_resetter_[MaxDeviceTypes];
176 template <
int w,
int d>
194 static_assert(d < MaxDeviceTypes,
"");
195 Event::event_creator_[d] = f;
198 #define REGISTER_EVENT_CREATE_FUNCTION(d, f) \ 200 static EventCreateFunctionRegisterer<d> g_event_create_##d(f); \ 206 static_assert(d < MaxDeviceTypes,
"");
207 Event::event_recorder_[d] = f;
210 #define REGISTER_EVENT_RECORD_FUNCTION(d, f) \ 212 static EventRecordFunctionRegisterer<d> g_event_record_##d(f); \ 215 template <
int waiter_type,
int event_type>
218 static_assert(waiter_type < MaxDeviceTypes,
"");
219 static_assert(event_type < MaxDeviceTypes,
"");
220 Event::event_waiter_[waiter_type][event_type] = f;
223 #define REGISTER_EVENT_WAIT_FUNCTION(w, d, f) \ 225 static EventWaitFunctionRegisterer<w, d> g_event_wait_##w##_##d(f); \ 231 static_assert(d < MaxDeviceTypes,
"");
232 Event::event_querier_[d] = f;
235 #define REGISTER_EVENT_QUERY_FUNCTION(d, f) \ 237 static EventQueryFunctionRegisterer<d> g_event_query_##d(f); \ 243 static_assert(d < MaxDeviceTypes,
"");
244 Event::event_err_msg_getter_[d] = f;
247 #define REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(d, f) \ 249 static EventErrorMessageFunctionRegisterer<d> g_event_err_msg_##d(f); \ 255 static_assert(d < MaxDeviceTypes,
"");
256 Event::event_finished_setter_[d] = f;
259 #define REGISTER_EVENT_SET_FINISHED_FUNCTION(d, f) \ 261 static EventSetFinishedFunctionRegisterer<d> g_event_set_finished_##d(f); \ 267 static_assert(d < MaxDeviceTypes,
"");
268 Event::event_finisher_[d] = f;
271 #define REGISTER_EVENT_FINISH_FUNCTION(d, f) \ 273 static EventFinishFunctionRegisterer<d> g_event_finish_##d(f); \ 279 static_assert(d < MaxDeviceTypes,
"");
280 Event::event_resetter_[d] = f;
283 #define REGISTER_EVENT_RESET_FUNCTION(d, f) \ 285 static EventResetFunctionRegisterer<d> g_event_reset_##d(f); \ 290 #endif // CAFFE2_CORE_EVENT_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...