Caffe2 - C++ API
A deep learning, cross-platform ML framework
operator.h
1 #ifndef CAFFE2_CORE_OPERATOR_H_
2 #define CAFFE2_CORE_OPERATOR_H_
3 
4 #include <array>
5 #include <climits>
6 #include <cstddef>
7 #include <exception>
8 #include <set>
9 #include <typeinfo>
10 #include <vector>
11 
12 #include "caffe2/core/blob.h"
13 #include "caffe2/core/common.h"
14 #include "caffe2/core/net.h"
15 #include "caffe2/core/observer.h"
16 #include "caffe2/core/operator_gradient.h"
17 #include "caffe2/core/operator_schema.h"
18 #include "caffe2/core/registry.h"
19 #include "caffe2/core/tensor.h"
20 #include "caffe2/core/types.h"
21 #include "caffe2/core/workspace.h"
22 #include "caffe2/proto/caffe2.pb.h"
23 #include "caffe2/utils/proto_utils.h"
24 
25 namespace caffe2 {
26 
27 class OperatorBase;
28 typedef ObserverBase<OperatorBase> OperatorObserver;
29 
30 class OperatorBase : public Observable<OperatorBase> {
31  public:
32  explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
33  virtual ~OperatorBase() noexcept {}
34 
35  /** @brief Checks if the operator has an argument of the given name.
36   */
37  inline bool HasArgument(const string& name) const {
38  CAFFE_ENFORCE(operator_def_, "operator_def was null!");
39  return ArgumentHelper::HasArgument(*operator_def_, name);
40  }
41 
42  // Functions that deal with arguments. Basically, this allows us to map an
43  // argument name to a specific type of argument that we are trying to access.
44  template <typename T>
45  inline T GetSingleArgument(const string& name, const T& default_value) const {
46  CAFFE_ENFORCE(operator_def_, "operator_def was null!");
47  return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
48  *operator_def_, name, default_value);
49  }
50  template <typename T>
51  inline bool HasSingleArgumentOfType(const string& name) const {
52  CAFFE_ENFORCE(operator_def_, "operator_def was null!");
53  return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
54  *operator_def_, name);
55  }
56  template <typename T>
57  inline vector<T> GetRepeatedArgument(
58  const string& name,
59  const vector<T>& default_value = {}) const {
60  CAFFE_ENFORCE(operator_def_, "operator_def was null!");
61  return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
62  *operator_def_, name, default_value);
63  }
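As a concrete illustration, here is a minimal sketch (not part of operator.h; the operator and argument names are hypothetical) of a constructor using these getters:

#include "caffe2/core/operator.h"

namespace caffe2 {

class MyScaleOp final : public OperatorBase {
 public:
  MyScaleOp(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws),
        // float argument "scale", defaulting to 1.0 when absent
        scale_(GetSingleArgument<float>("scale", 1.0f)),
        // repeated int argument "shape", defaulting to an empty vector
        shape_(GetRepeatedArgument<int>("shape")) {
    if (HasArgument("scale")) {
      // Distinguishes "explicitly given" from "fell back to the default".
      CAFFE_ENFORCE(
          HasSingleArgumentOfType<float>("scale"), "scale must be a float");
    }
  }

 private:
  float scale_;
  vector<int> shape_;
};

} // namespace caffe2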
64 
65  // Get the inputs and outputs as specific types.
66  template <typename T>
67  inline const T& Input(int idx) {
68  DCHECK_LT(idx, inputs_.size());
69  try {
70  return inputs_.at(idx)->template Get<T>();
71  } catch (::caffe2::EnforceNotMet& enf) {
72  if (has_debug_def()) {
73  enf.AppendMessage(".\nOffending Blob name: ");
74  enf.AppendMessage(debug_def().input(idx));
75  enf.AppendMessage(".\n");
76  }
77  throw enf;
78  }
79  }
80 
81  template <typename T>
82  inline T* Output(int idx) {
83  return outputs_.at(idx)->template GetMutable<T>();
84  }
85 
86  template <typename T>
87  inline T* Output(int idx, T* allocated) {
88  outputs_.at(idx)->Reset(allocated);
89  return allocated;
90  }
91 
92  inline const Blob& InputBlob(int idx) {
93  return *inputs_.at(idx);
94  }
95 
96  inline Blob* OutputBlob(int idx) {
97  return outputs_.at(idx);
98  }
99 
100  template <typename T>
101  inline bool InputIsType(int idx) {
102  return inputs_.at(idx)->template IsType<T>();
103  }
104 
105  template <typename T>
106  inline bool OutputIsType(int idx) {
107  return outputs_.at(idx)->template IsType<T>();
108  }
109 
110  inline int InputSize() const {
111  return inputs_.size();
112  }
113  inline int OutputSize() const {
114  return outputs_.size();
115  }
116  inline const vector<const Blob*>& Inputs() const { return inputs_; }
117  inline const vector<Blob*>& Outputs() { return outputs_; }
118  vector<TensorShape> InputTensorShapes();
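A short sketch of the difference between typed and untyped access (the helper function below is hypothetical):

#include "caffe2/core/operator.h"

namespace caffe2 {

void LogInputs(OperatorBase* op) {
  for (int i = 0; i < op->InputSize(); ++i) {
    if (op->InputIsType<TensorCPU>(i)) {
      // Typed access; throws EnforceNotMet if the blob holds another type.
      const auto& t = op->Input<TensorCPU>(i);
      LOG(INFO) << "input " << i << ": tensor with " << t.size() << " items";
    } else {
      // Untyped access: the blob itself plus its runtime type metadata.
      LOG(INFO) << "input " << i << ": blob of type "
                << op->InputBlob(i).meta().name();
    }
  }
}

} // namespace caffe2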
119 
120  virtual void WaitEvent(const Event& ev, int stream_id = -1) {
121  ev.Finish();
122  }
123 
124  inline void Wait(const OperatorBase& other, int stream_id = -1) {
125  WaitEvent(other.event(), stream_id);
126  }
127 
128  virtual void WaitEvents(
129  const std::vector<const Event*>& events,
130  int stream_id = -1) {
131  for (const auto& ev : events) {
132  ev->Finish();
133  }
134  }
135 
136  virtual void Finish() {
137  if (event_) {
138  event_->Finish();
139  }
140  }
141 
142  virtual bool Run(int /* unused */ /*stream_id*/ = 0) {
143  CAFFE_NOT_IMPLEMENTED;
144  }
145 
146  virtual bool HasAsyncPart() const {
147  return false;
148  }
149 
150  virtual bool SupportsAsyncScheduling() const {
151  return false;
152  }
153 
154  // RunAsync, if implemented by the specific operators, will schedule the
155  // computation on the corresponding context and record the event in its
156  // event_ member object. If the specific operator does not support RunAsync,
157  // it will simply be synchronous as a fallback.
158  virtual bool RunAsync(int stream_id = 0) {
159  return Run(stream_id);
160  }
161 
162  virtual void AddRelatedBlobInfo(EnforceNotMet* err) {
163  if (!has_debug_def()) {
164  return;
165  }
166 
167  bool found_input = false;
168  if (err->caller() != nullptr) {
169  for (int i = 0; i < inputs_.size(); i++) {
170  if (inputs_[i]->GetRaw() == err->caller()) {
171  found_input = true;
172  err->AppendMessage(
173  "\n** while accessing input: " + debug_def().input(i));
174  break;
175  }
176  }
177  for (int i = 0; i < outputs_.size(); i++) {
178  if (outputs_[i]->GetRaw() == err->caller()) {
179  if (found_input) {
180  err->AppendMessage("\n OR ");
181  }
182  err->AppendMessage(
183  "\n** while accessing output: " + debug_def().output(i));
184  break;
185  }
186  }
187  }
188  }
189 
190  inline const OperatorDef& debug_def() const {
191  CAFFE_ENFORCE(has_debug_def(), "operator_def was null!");
192  return *operator_def_;
193  }
194 
195  inline void set_debug_def(
196  const std::shared_ptr<const OperatorDef>& operator_def) {
197  operator_def_ = operator_def;
198  }
199 
200  inline bool has_debug_def() const {
201  return operator_def_ != nullptr;
202  }
203 
204  public:
205  void RecordLastFailedOpNetPosition() {
206  if (net_position_ != kNoNetPositionSet) {
207  VLOG(1) << "Operator with id " << net_position_ << " failed";
208  operator_ws_->last_failed_op_net_position = net_position_;
209  } else {
210  VLOG(1) << "Failed operator doesn't have id set";
211  }
212  }
213 
214  int net_position() const {
215  return net_position_;
216  }
217 
218  void set_net_position(int idx) {
219  net_position_ = idx;
220  }
221 
222  const DeviceOption& device_option() const {
223  return device_option_;
224  }
225 
226  const Event& event() const {
227  CAFFE_ENFORCE(event_, "Event is disabled");
228  return *event_;
229  }
230 
231  Event& event() {
232  CAFFE_ENFORCE(event_, "Event is disabled");
233  return *event_;
234  }
235 
236  void ResetEvent() {
237  if (event_) {
238  event_->Reset();
239  }
240  }
241 
242  void DisableEvent() {
243  event_ = nullptr;
244  }
245 
246  bool IsEventDisabled() const {
247  return !event_;
248  }
249 
250  // Checks whether a stream is ready to execute new computation,
251  // used in the stream allocation optimization to skip streams that are
252  // currently busy. Depends on the context and operator's device; returns true by default.
253  virtual bool IsStreamFree(int /* unused */) const {
254  return true;
255  }
256 
257  const std::string& type() const {
258  CAFFE_ENFORCE(operator_def_.get() != nullptr);
259  return operator_def_->type();
260  }
261 
262  void annotate_engine(const std::string& engine) {
263  engine_ = engine;
264  }
265 
266  const std::string& engine() const {
267  return engine_;
268  }
269 
270  public:
271  static constexpr int kNoNetPositionSet = -1;
272 
273  private:
274  Workspace* operator_ws_;
275  std::shared_ptr<const OperatorDef> operator_def_;
276  DeviceOption device_option_;
277  std::string engine_;
278  vector<const Blob*> inputs_;
279  vector<Blob*> outputs_;
280 
281  int net_position_{kNoNetPositionSet};
282 
283  protected:
284  virtual void RecordEvent(const char* err_msg = nullptr) {
285  CAFFE_NOT_IMPLEMENTED;
286  }
287 
288  // An event used by asynchronous execution.
289  std::unique_ptr<Event> event_;
290 
291  DISABLE_COPY_AND_ASSIGN(OperatorBase);
292 };
293 
294 // If your operator does not need any specialized constructor or destructor,
295 // you can simply use this to save two lines of code.
296 #define USE_SIMPLE_BASE_CTOR_DTOR(name) \
297  name(const OperatorDef& operator_def, Workspace* ws) \
298  : OperatorBase(operator_def, ws) {} \
299  virtual ~name() noexcept {}
300 
301 // OP_SINGLE_ARG provides a shorter syntax for initializing member variables
302 // in the class constructors.
303 #define OP_SINGLE_ARG(type, name, variable, default) \
304  variable(OperatorBase::GetSingleArgument<type>(name, (default)))
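For example, a hypothetical operator with an integer "axis" argument could be written as follows (using the Operator<Context> template defined further down in this header):

class MyReduceOp final : public Operator<CPUContext> {
 public:
  MyReduceOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        // Expands to: axis_(OperatorBase::GetSingleArgument<int>("axis", (1)))
        OP_SINGLE_ARG(int, "axis", axis_, 1) {}

  bool RunOnDevice() override {
    return true; // actual computation elided
  }

 private:
  int axis_;
};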
305 
306 // INPUT_TAGS and OUTPUT_TAGS are optional features to name the indices of the
307 // operator's inputs and outputs, in order to avoid confusion. For example, for
308 // a fully convolutional layer that has input, weight and bias, you can define its
309 // input tags as:
310 // INPUT_TAGS(INPUT, WEIGHT, BIAS);
311 // And in the code, instead of doing
312 // auto& weight = Input(1);
313 // you can now do
314 // auto& weight = Input(WEIGHT);
315 // to make it more clear.
316 #define INPUT_TAGS(first_input, ...) \
317  enum _InputTags { first_input = 0, __VA_ARGS__ }
318 #define OUTPUT_TAGS(first_input, ...) \
319  enum _OutputTags { first_input = 0, __VA_ARGS__ }
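Putting the tags to work, a hypothetical fully-connected-style operator might look like the sketch below (USE_OPERATOR_CONTEXT_FUNCTIONS and USE_SIMPLE_CTOR_DTOR are defined further down in this header):

template <class Context>
class MyFCOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MyFCOp);
  INPUT_TAGS(INPUT, WEIGHT, BIAS); // INPUT = 0, WEIGHT = 1, BIAS = 2
  OUTPUT_TAGS(OUTPUT);             // OUTPUT = 0

  bool RunOnDevice() override {
    const auto& W = Input(WEIGHT); // equivalent to Input(1)
    auto* Y = Output(OUTPUT);
    Y->ResizeLike(W);              // real math elided
    return true;
  }
};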
320 
321 // Operator is the class that you usually want to derive, if your operator will
322 // run on different devices. You should then implement the RunOnDevice()
323 // function.
324 template <class Context>
325 class Operator : public OperatorBase {
326  public:
327  explicit Operator(const OperatorDef& operator_def, Workspace* ws)
328  : OperatorBase(operator_def, ws), context_(operator_def.device_option()) {
329  // In the constructor, we switch to the device so that the child class
330  // constructors will run on that device.
331  context_.SwitchToDevice(0);
332  }
333  ~Operator() noexcept override {}
334 
335  inline const Tensor<Context>& Input(int idx) {
336  return OperatorBase::template Input<Tensor<Context>>(idx);
337  }
338  inline Tensor<Context>* Output(int idx) {
339  return OperatorBase::template Output<Tensor<Context>>(idx);
340  }
341 
342  void WaitEvent(const Event& ev, int stream_id = -1) final {
343  if (stream_id >= 0) {
344  context_.SwitchToDevice(stream_id);
345  }
346  context_.WaitEvent(ev);
347  }
348 
349  void WaitEvents(const std::vector<const Event*>& events, int stream_id = -1)
350  final {
351  if (stream_id >= 0) {
352  context_.SwitchToDevice(stream_id);
353  }
354  for (const auto& ev : events) {
355  context_.WaitEvent(*ev);
356  }
357  }
358 
359  // The run function of Operator switches to the device, and then carries out
360  // the actual computation with RunOnDevice(). You should implement RunOnDevice
361  // instead of Run().
362  // Note: Run does not update operator's event and can be used only with
363  // non-async executors that do not rely on events
364  bool Run(int stream_id = 0) final {
365  try {
366  StartAllObservers();
367 
368  context_.SwitchToDevice(stream_id);
369  bool result = RunOnDevice();
370  if (!result) {
371  this->RecordLastFailedOpNetPosition();
372  }
373  context_.FinishDeviceComputation(); // throws on error
374 
375  StopAllObservers();
376 
377  return result;
378  } catch (EnforceNotMet& err) {
379  if (has_debug_def()) {
380  err.AppendMessage(
381  "Error from operator: \n" + ProtoDebugString(debug_def()));
382  AddRelatedBlobInfo(&err);
383  }
384  this->RecordLastFailedOpNetPosition();
385  throw;
386  } catch (...) {
387  this->RecordLastFailedOpNetPosition();
388  throw;
389  }
390  }
391 
392  bool RunAsync(int stream_id = 0) final {
393  try {
394  context_.SwitchToDevice(stream_id);
395  auto result = RunOnDevice();
396  if (result) {
397  if (HasAsyncPart()) {
398  RecordEvent();
399  } else {
400  // Manually set CPU operator's event status to finished,
401  // unless this is an async CPU operator
402  event().SetFinished();
403  }
404  } else {
405  event().SetFinished(getErrorMsg().c_str());
406  this->RecordLastFailedOpNetPosition();
407  }
408  return result;
409  } catch (EnforceNotMet& err) {
410  if (has_debug_def()) {
411  err.AppendMessage(
412  "Error from operator: \n" + ProtoDebugString(debug_def()));
413  AddRelatedBlobInfo(&err);
414  }
415  event().SetFinished(err.what());
416  this->RecordLastFailedOpNetPosition();
417  throw;
418  } catch (const std::exception& err) {
419  event().SetFinished(err.what());
420  this->RecordLastFailedOpNetPosition();
421  throw;
422  } catch (...) {
423  event().SetFinished(getErrorMsg().c_str());
424  this->RecordLastFailedOpNetPosition();
425  throw;
426  }
427  }
428 
429  bool IsStreamFree(int stream_id) const override {
430  return context_.IsStreamFree(device_option(), stream_id);
431  }
432 
433  virtual bool RunOnDevice() = 0;
434 
435  // Returns whether operator has async on device part.
436  // CUDA operators by default have async parts, CPU operators by default
437  // don't have async parts and are finished after RunOnDevice call.
438  // Events of operators that don't have async parts are automatically set
439  // to finished state by RunAsync.
440  // Defaulting to the value from context (true for CUDA, false for CPU).
441  // Override in case of async CPU operators
442  bool HasAsyncPart() const override {
443  return context_.HasAsyncPartDefault();
444  }
445 
446  // Returns whether operator's RunOnDevice schedules async on device part and
447  // can be run without waiting for parent operator's async part to be finished
448  // on the same device.
449  // Note: when true, RunOnDevice must not access the content of the input blobs
450  // as they might not be computed yet
451  // Note: when true, operator's device needs to support async scheduling:
452  // - supports concept of streams: async ops scheduled on the same stream are
453  // guaranteed to be executed in the same order they were scheduled
454  // - provides non-blocking cross device/cross stream synchronization
455  // primitives
456  //
457 // By default, an op with an async part is assumed to be schedulable
458 // asynchronously if the device supports async scheduling.
459  bool SupportsAsyncScheduling() const override {
460  return HasAsyncPart() && context_.SupportsAsyncScheduling();
461  }
462 
463  const Context* getContext() const {
464  return &context_;
465  }
466 
467  protected:
468  void RecordEvent(const char* err_msg = nullptr) final {
469  if (event_) {
470  context_.Record(event_.get(), err_msg);
471  }
472  }
473 
474  std::string getErrorMsg() {
475  if (has_debug_def()) {
476  return "Error from operator: " + ProtoDebugString(debug_def());
477  } else {
478  return "Error from operator: no op def";
479  }
480  }
481 
482  Context context_;
483 };
484 
485 #define USE_OPERATOR_BASE_FUNCTIONS \
486  /* using override */ using OperatorBase::HasArgument; \
487  /* using override */ using OperatorBase::GetSingleArgument; \
488  /* using override */ using OperatorBase::HasSingleArgumentOfType; \
489  /* using override */ using OperatorBase::GetRepeatedArgument; \
490  /* using override */ using OperatorBase::InputIsType; \
491  /* using override */ using OperatorBase::InputSize; \
492  /* using override */ using OperatorBase::OutputSize
493 
494 #define USE_OPERATOR_FUNCTIONS(context) \
495  USE_OPERATOR_BASE_FUNCTIONS; \
496  /* using override */ using Operator<context>::context_; \
497  /* using override */ using Operator<context>::Input; \
498  /* using override */ using Operator<context>::InputBlob; \
499  /* using override */ using Operator<context>::Output; \
500  /* using override */ using Operator<context>::OutputBlob
501 
502 #define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context)
503 
504 #define USE_SIMPLE_CTOR_DTOR(name) \
505  name(const OperatorDef& operator_def, Workspace* ws) \
506  : Operator<Context>(operator_def, ws) {} \
507  virtual ~name() noexcept {}
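Tying the pieces together, a complete (if trivial) CPU-only operator is sketched below; Plus1Op is a hypothetical name, not an operator shipped with Caffe2:

#include "caffe2/core/operator.h"

namespace caffe2 {

class Plus1Op final : public Operator<CPUContext> {
 public:
  Plus1Op(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    const auto& X = Input(0); // Tensor<CPUContext>, via the typed override
    auto* Y = Output(0);
    Y->ResizeLike(X);
    const float* x = X.data<float>();
    float* y = Y->mutable_data<float>();
    for (TIndex i = 0; i < X.size(); ++i) {
      y[i] = x[i] + 1.0f; // a CUDA variant would launch a kernel here instead
    }
    return true;
  }
};

} // namespace caffe2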
508 
509 // Helpers to implement runtime op polymorphism. Often it's convenient to make
510 // an op work on different input types (e.g. i32 vs i64 indices) or special-case
511 // it for particular input size (e.g. ScatterWeightedSum for block size of 1
512 // doesn't need to call Eigen).
513 //
514 // DispatchHelper provides compile-time generation of nested "if" statements,
515 // e.g. `DispatchHelper<FixedValues<1, 4>>::call(this, block_size);`
516 // unrolls into:
517 // if (block_size == 1) {
518 // return DoRunWithValue<1>();
519 // } else if (block_size == 4) {
520 // return DoRunWithValue<4>();
521 // } else {
522 // return DoRunWithValue<-1>();
523 // }
524 //
525 // The DoRunWithValue implementation can use its template argument to do "if"
526 // statements at compile time,
527 // or proxy to functions in math.h, which often provide fixed-size
528 // implementations.
529 //
530 // Similarly `TensorTypes<int32_t, int64_t>(this, Input(0))` provides branching
531 // based on type of the first input and calls DoRunWithType.
532 //
533 // Note that the same instance of the Op class is used, since the method, not
534 // the class, is templated. We might consider adding static class-level polymorphism later.
535 //
536 // Convenient macro USE_DISPATCH_HELPER is provided for declaring friendship in
537 // case DoRunWithValue or DoRunWithType are declared non-public.
538 
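A sketch of the TensorTypes flavor of dispatch; MyGatherOp and its input layout are hypothetical:

class MyGatherOp final : public Operator<CPUContext> {
 public:
  MyGatherOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<CPUContext>(operator_def, ws) {}

  bool RunOnDevice() override {
    // Branches on the runtime element type of Input(1) and calls either
    // DoRunWithType<int32_t>() or DoRunWithType<int64_t>(); throws for any
    // other type because GenericTensorImplementation is not listed.
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(1));
  }

  // Public here; if it were private, USE_DISPATCH_HELPER would be needed.
  template <typename TInd>
  bool DoRunWithType() {
    const TInd* indices = Input(1).data<TInd>();
    // ... index-type-specific work elided ...
    return indices != nullptr;
  }
};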
539 #define USE_DISPATCH_HELPER \
540  template <typename FirstArg, typename... ExtraArgs> \
541  friend struct DispatchHelper
542 
543 template <int... Values>
544 struct FixedValues {};
545 
546 template <typename... Types>
547 struct TensorTypes {};
548 
549 // Special tag that can be listed in TensorTypes to denote that a special
550 // implementation in 'RunWithOtherType' needs to be called instead of failing
551 // Obviously this needs to be the last item in lists, e.g.
552 // TensorTypes<float, double, GenericTensorImplementation>
553 struct GenericTensorImplementation {};
554 
555 // Same as TensorTypes but call DoRunWithType2
556 template <typename... Types>
557 struct TensorTypes2 {};
558 
559 template <typename Sizes, typename... ExtraArgs>
560 struct DispatchHelper;
561 
562 template <int FirstVal, int... Values, typename... ExtraArgs>
563 struct DispatchHelper<FixedValues<FirstVal, Values...>, ExtraArgs...> {
564  template <typename Op>
565  static bool call(Op* op, int value) {
566  if (FirstVal == value) {
567  return op->template DoRunWithValue<ExtraArgs..., FirstVal>();
568  }
569  return DispatchHelper<FixedValues<Values...>, ExtraArgs...>::template call<
570  Op>(op, value);
571  }
572 };
573 
574 template <typename... ExtraArgs>
575 struct DispatchHelper<FixedValues<>, ExtraArgs...> {
576  template <typename Op>
577  static bool call(Op* op, TIndex /*size*/) {
578  return op->template DoRunWithValue<ExtraArgs..., -1>();
579  }
580 };
581 
582 #define CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER( \
583  TensorTypes, DoRunWithType, DoRunWithOtherType) \
584  template <typename FirstType, typename... Types, typename... ExtraArgs> \
585  struct DispatchHelper<TensorTypes<FirstType, Types...>, ExtraArgs...> { \
586  template <typename Op> \
587  static bool call(Op* op, const TypeMeta& meta) { \
588  static_assert( \
589  !std::is_same<GenericTensorImplementation, FirstType>::value, \
590  "GenericTensorImplementation must be the last in TensorTypes list"); \
591  if (meta.Match<FirstType>()) { \
592  return op->template DoRunWithType<ExtraArgs..., FirstType>(); \
593  } \
594  return DispatchHelper<TensorTypes<Types...>, ExtraArgs...>:: \
595  template call<Op>(op, meta); \
596  } \
597  template <typename Op, typename Context> \
598  static bool call(Op* op, const Tensor<Context>& tensor) { \
599  return call<Op>(op, tensor.meta()); \
600  } \
601  template <typename Op> \
602  static bool call(Op* op, const Blob& blob) { \
603  return call<Op>(op, blob.meta()); \
604  } \
605  }; \
606  \
607  template <typename... ExtraArgs> \
608  struct DispatchHelper<TensorTypes<>, ExtraArgs...> { \
609  template <typename Op> \
610  static bool call(Op* /* unused */, const TypeMeta& meta) { \
611  CAFFE_THROW("Unsupported type of tensor: ", meta.name()); \
612  } \
613  template <typename Op, typename Context> \
614  static bool call(Op* op, const Tensor<Context>& tensor) { \
615  return call<Op>(op, tensor.meta()); \
616  } \
617  template <typename Op> \
618  static bool call(Op* op, const Blob& blob) { \
619  return call<Op>(op, blob.meta()); \
620  } \
621  }; \
622  \
623  template <typename... ExtraArgs> \
624  struct DispatchHelper< \
625  TensorTypes<GenericTensorImplementation>, \
626  ExtraArgs...> { \
627  template <typename Op> \
628  static bool call(Op* op, const TypeMeta&) { \
629  return op->template DoRunWithOtherType<ExtraArgs...>(); \
630  } \
631  template <typename Op, typename Context> \
632  static bool call(Op* op, const Tensor<Context>& tensor) { \
633  return call<Op>(op, tensor.meta()); \
634  } \
635  template <typename Op> \
636  static bool call(Op* op, const Blob& blob) { \
637  return call<Op>(op, blob.meta()); \
638  } \
639  };
640 CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(
641  TensorTypes,
642  DoRunWithType,
643  DoRunWithOtherType)
644 CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(
645  TensorTypes2,
646  DoRunWithType2,
647  DoRunWithOtherType2)
648 #undef CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER
649 
650 // The device type registry. This works in two phases:
651 // (1) gDeviceTypeRegistry() maps the device types values to the actual operator
652 // registry function.
653 // (2) Then, one can call the operator registry function to further create the
654 // operators.
655 typedef Registry<
656  std::string,
657  std::unique_ptr<OperatorBase>,
658  const OperatorDef&,
659  Workspace*>
660  OperatorRegistry;
661 typedef Registry<
662  std::string,
663  std::unique_ptr<OperatorBase>,
664  const OperatorDef&,
665  Workspace*>* (*RegistryFunction)();
666 std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry();
667 
668 struct DeviceTypeRegisterer {
669  explicit DeviceTypeRegisterer(int32_t type, RegistryFunction func) {
670  if (gDeviceTypeRegistry()->count(type)) {
671  std::cerr << "Device type " << type
672  << " registered twice. This should not happen. Did you have "
673  "duplicated numbers assigned to different devices?";
674  std::exit(1);
675  }
676  // Calling the registry function to get the actual registry pointer.
677  gDeviceTypeRegistry()->emplace(type, func());
678  }
679 };
680 
681 #define CAFFE_REGISTER_DEVICE_TYPE(type, registry_function) \
682  namespace { \
683  static DeviceTypeRegisterer CAFFE_ANONYMOUS_VARIABLE( \
684  DeviceType)(type, &registry_function); \
685  }
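For reference, the registration performed in caffe2's own operator.cc looks roughly like this (a sketch; see that file for the authoritative version):

CAFFE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry);
CAFFE_REGISTER_DEVICE_TYPE(DeviceType::CUDA, CUDAOperatorRegistry);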
686 
687 // The operator registry. Since we are not expecting a great number of devices,
688 // we will simply have an if-then type command and allocate the actual
689 // generation to device-specific registerers.
690 // Note that although we have CUDA and CUDNN here, the registerers themselves do
691 // not depend on specific cuda or cudnn libraries. This means that we will be
692 // able to compile it even when there is no cuda available - we simply do not
693 // link any cuda or cudnn operators.
694 CAFFE_DECLARE_REGISTRY(
695  CPUOperatorRegistry,
696  OperatorBase,
697  const OperatorDef&,
698  Workspace*);
699 #define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \
700  CAFFE_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__)
701 #define REGISTER_CPU_OPERATOR(name, ...) \
702  extern void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
703  static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CPU##name() { \
704  CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
705  } \
706  CAFFE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)
707 #define REGISTER_CPU_OPERATOR_STR(str_name, ...) \
708  CAFFE_REGISTER_TYPED_CLASS(CPUOperatorRegistry, str_name, __VA_ARGS__)
709 
710 #define REGISTER_CPU_OPERATOR_WITH_ENGINE(name, engine, ...) \
711  CAFFE_REGISTER_CLASS(CPUOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
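In an operator's .cc file, registration plus the schema that REGISTER_CPU_OPERATOR insists on might look like this (a sketch, reusing the hypothetical Plus1Op from above):

REGISTER_CPU_OPERATOR(Plus1, Plus1Op);
OPERATOR_SCHEMA(Plus1).NumInputs(1).NumOutputs(1);

// An engine-specific variant, picked when the op's engine is set to
// "FANCY" (an illustrative engine name):
// REGISTER_CPU_OPERATOR_WITH_ENGINE(Plus1, FANCY, FancyPlus1Op);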
712 
713 CAFFE_DECLARE_REGISTRY(
714  CUDAOperatorRegistry,
715  OperatorBase,
716  const OperatorDef&,
717  Workspace*);
718 #define REGISTER_CUDA_OPERATOR_CREATOR(key, ...) \
719  CAFFE_REGISTER_CREATOR(CUDAOperatorRegistry, key, __VA_ARGS__)
720 #define REGISTER_CUDA_OPERATOR(name, ...) \
721  extern void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
722  static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CUDA##name() { \
723  CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
724  } \
725  CAFFE_REGISTER_CLASS(CUDAOperatorRegistry, name, __VA_ARGS__)
726 #define REGISTER_CUDA_OPERATOR_STR(str_name, ...) \
727  CAFFE_REGISTER_TYPED_CLASS(CUDAOperatorRegistry, str_name, __VA_ARGS__)
728 
729 #define REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, engine, ...) \
730  CAFFE_REGISTER_CLASS( \
731  CUDAOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
732 
733 // Macros for cudnn since we use it often
734 #define REGISTER_CUDNN_OPERATOR(name, ...) \
735  REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__)
736 
737 // StaticLinkingProtector is a helper class that ensures that the Caffe2
738 // library is linked correctly with whole archives (in the case of static
739 // linking). What happens is that when CreateOperator is called for the first
740 // time, it instantiates a StaticLinkingProtector object to check if the
741 // operator registry is empty. If it is empty, this means that we are not
742 // properly linking the library.
743 //
744 // You should not need to use this class.
745 struct StaticLinkingProtector {
746  StaticLinkingProtector() {
747  const int registered_ops = CPUOperatorRegistry()->Keys().size();
748  // Note: this is a check failure instead of an exception, because if
749  // the linking is wrong, Caffe2 won't be able to run properly anyway,
750  // so it's better to fail loud.
751  // If Caffe2 is properly linked with whole archive, there should be more
752  // than zero registered ops.
753  if (registered_ops == 0) {
754  LOG(FATAL) <<
755  "You might have made a build error: the Caffe2 library does not seem "
756  "to be linked with whole-static library option. To do so, use "
757  "-Wl,-force_load (clang) or -Wl,--whole-archive (gcc) to link the "
758  "Caffe2 library.";
759  }
760  }
761 };
762 
763 // An exception that can be thrown by an operator constructor that notifies
764 // that it does not support the given setting. This can be usually used for
765 // specific engines that only implement a subset of the features required by
766 // the original operator schema.
767 // TODO(jiayq): make more feature-complete exception message.
768 class UnsupportedOperatorFeature : public std::exception {
769  public:
770  UnsupportedOperatorFeature(const string& msg) : msg_(msg) {}
771  const char* what() const noexcept override {
772  return msg_.c_str();
773  }
774 
775  private:
776  string msg_;
777 };
778 
779 // A helper macro that should ONLY be used in the operator constructor to check
780 // if needed features are met. If not, throws the UnsupportedOperatorFeature
781 // exception with the given message.
782 #define OPERATOR_NEEDS_FEATURE(condition, ...) \
783  if (!(condition)) { \
784  throw UnsupportedOperatorFeature(::caffe2::MakeString(__VA_ARGS__)); \
785  }
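A sketch of the intended use, guarding a hypothetical engine-specific constructor:

MyEngineConvOp::MyEngineConvOp(const OperatorDef& operator_def, Workspace* ws)
    : Operator<CPUContext>(operator_def, ws),
      group_(OperatorBase::GetSingleArgument<int>("group", 1)) {
  // If the condition fails, UnsupportedOperatorFeature is thrown and the
  // operator factory can fall back to another engine.
  OPERATOR_NEEDS_FEATURE(
      group_ == 1, "This engine only supports group == 1, got ", group_);
}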
786 
787 // Creates an operator with the given operator definition.
788 // Throws on error and never returns nullptr
789 unique_ptr<OperatorBase> CreateOperator(
790  const OperatorDef& operator_def,
791  Workspace* ws,
792  int net_position = OperatorBase::kNoNetPositionSet);
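Driving an operator by hand, under the assumption that a "Plus1" operator like the sketch above has been registered:

#include "caffe2/core/operator.h"

namespace caffe2 {

bool RunPlus1Once() {
  Workspace ws;
  auto* x = ws.CreateBlob("X")->GetMutable<TensorCPU>();
  x->Resize(4);
  x->mutable_data<float>(); // allocate; contents left unset in this sketch

  OperatorDef def;
  def.set_type("Plus1");
  def.add_input("X");
  def.add_output("Y");

  // Throws on error (e.g. an unregistered type); never returns nullptr.
  unique_ptr<OperatorBase> op = CreateOperator(def, &ws);
  return op->Run();
}

} // namespace caffe2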
793 
794 const std::string OpRegistryKey(
795  const std::string& op_type,
796  const std::string& engine = "");
797 
798 // User can set the preferred engines as a list of engine names, in
799 // descending order of preference.
800 using EnginePrefType = std::vector<std::string>;
801 // {device_type -> {operator_name -> EnginePrefType}}
802 using PerOpEnginePrefType =
803  CaffeMap<int, CaffeMap<std::string, EnginePrefType>>;
804 // {device_type -> EnginePrefType}
805 using GlobalEnginePrefType = CaffeMap<int, EnginePrefType>;
806 void SetPerOpEnginePref(const PerOpEnginePrefType& per_op_engine_pref);
807 void SetGlobalEnginePref(const GlobalEnginePrefType& global_engine_pref);
808 void SetEnginePref(
809  const PerOpEnginePrefType& per_op_engine_pref,
810  const GlobalEnginePrefType& global_engine_pref);
811 void SetOpEnginePref(
812  const std::string& op_type,
813  const CaffeMap<int, EnginePrefType>& op_pref);
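For example, to prefer cuDNN everywhere on GPU and specific engines for one CPU operator (the engine names must match actually registered engines; these are illustrative):

SetGlobalEnginePref({{DeviceType::CUDA, {"CUDNN"}}});
SetOpEnginePref("Conv", {{DeviceType::CPU, {"NNPACK", "EIGEN"}}});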
814 
815 TensorShape GetTensorShapeOfBlob(const Blob* b);
816 
817 TensorShapes InferBlobShapesAndTypesFromWorkspace(
818  Workspace* ws,
819  const vector<std::unique_ptr<NetDef>>& nets);
820 
821 TensorShapes InferBlobShapesAndTypesFromMap(
822  const CaffeMap<std::string, std::vector<TIndex>>& blob_dimensions,
823  const vector<std::unique_ptr<NetDef>>& nets);
824 
825 std::map<string, std::pair<DeviceOption, DeviceOption>> ValidateTensorDevices(
826  OperatorBase& op,
827  const OperatorDef& op_def);
828 
829 // Get a set of registered operator names
830 std::set<std::string> GetRegisteredOperators();
831 
832 } // namespace caffe2
833 
834 #endif // CAFFE2_CORE_OPERATOR_H_