Caffe2 - C++ API
A deep learning, cross-platform ML framework
mkl_operator.h
#ifndef CAFFE2_UTILS_MKL_OPERATOR_H_
#define CAFFE2_UTILS_MKL_OPERATOR_H_

#include "caffe2/core/operator.h"
#include "caffe2/mkl/utils/mkl_dnn_cppwrapper.h"
#include "caffe2/mkl/utils/mkl_memory.h"
#include "caffe2/proto/caffe2.pb.h"

CAFFE2_DECLARE_bool(caffe2_mkl_memonger_in_use);

namespace caffe2 {

CAFFE_DECLARE_REGISTRY(
    MKLOperatorRegistry,
    OperatorBase,
    const OperatorDef&,
    Workspace*);
#define REGISTER_MKL_OPERATOR_CREATOR(key, ...) \
  CAFFE_REGISTER_CREATOR(MKLOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_MKL_OPERATOR(name, ...) \
  CAFFE_REGISTER_CLASS(MKLOperatorRegistry, name, __VA_ARGS__)
#define REGISTER_MKL_OPERATOR_STR(str_name, ...) \
  CAFFE_REGISTER_TYPED_CLASS(MKLOperatorRegistry, str_name, __VA_ARGS__)

#define REGISTER_MKL_OPERATOR_WITH_ENGINE(name, engine, ...) \
  CAFFE_REGISTER_CLASS(MKLOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)

namespace mkl {
// MKLOperator is the base scaffolding for operators that use MKLDNN. It
// provides a few helpers that are useful to MKLDNN-specific implementations.
template <typename T>
class MKLOperator : public OperatorBase {
 public:
  explicit MKLOperator(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws),
        context_(operator_def.device_option()) {}
  virtual ~MKLOperator() {}

  inline const MKLMemory<T>& Input(int idx) {
    return OperatorBase::template Input<MKLMemory<T>>(idx);
  }
  inline MKLMemory<T>* Output(int idx) {
    return OperatorBase::template Output<MKLMemory<T>>(idx);
  }

  // The Run function of Operator switches to the device and then carries out
  // the actual computation with RunOnDevice(). You should implement
  // RunOnDevice() instead of Run().
  bool Run(int /* unused */ /*stream_id*/) final {
    // Since MKLDNN does not need to do SwitchToDevice and
    // FinishDeviceComputation, this is always just a re-route to
    // RunOnDevice().
    try {
      return RunOnDevice();
    } catch (EnforceNotMet& err) {
      err.AppendMessage(getErrorMsg());
      throw;
    }
  }

  // Waits for a previous event. Note that to properly wait and run
  // asynchronously, WaitEvent, RunAsync and Record should all be executed
  // on the same CPU thread.
  void WaitEvent(const Event& ev, int /* unused */) final {
    context_.WaitEvent(ev);
  }

  void WaitEvents(const std::vector<const Event*>& events, int /* unused */)
      final {
    for (const auto& ev : events) {
      context_.WaitEvent(*ev);
    }
  }

  void RecordEvent(const char* err_msg = nullptr) final {
    if (event_) {
      context_.Record(event_.get(), err_msg);
    }
  }

  virtual bool RunOnDevice() = 0;

  inline void ExecutePrimitive() {
    MKLDNN_SAFE_CALL(mkl::dnnExecute<T>(primitive_, resources_));
  }

 protected:
  std::string getErrorMsg() {
    if (has_debug_def()) {
      return "Error from operator: " + ProtoDebugString(debug_def());
    } else {
      return "Error from operator: no op def";
    }
  }

  MKLContext context_;
  // The primitive used in the operator.
  PrimitiveWrapper<T> primitive_;
  // Size cache for all the input sizes.
  vector<vector<TIndex>> input_size_cache_;
  // An internal MKLMemory buffer. This is usually handy when we have a
  // single output from the operator. If your operator has multiple outputs,
  // you should allocate your own buffer.
  MKLMemory<T> buffer_;
  // The resources vector that we will need to use.
  void* resources_[dnnResourceNumber];
};
} // namespace mkl

#define USE_MKLOPERATOR_FUNCTIONS(T)                              \
  USE_OPERATOR_BASE_FUNCTIONS;                                    \
  /* using override */ using MKLOperator<T>::Input;               \
  /* using override */ using MKLOperator<T>::Output;              \
  /* using override */ using MKLOperator<T>::ExecutePrimitive;    \
  /* using override */ using MKLOperator<T>::primitive_;          \
  /* using override */ using MKLOperator<T>::input_size_cache_;   \
  /* using override */ using MKLOperator<T>::buffer_;             \
  /* using override */ using MKLOperator<T>::resources_

#define USE_SIMPLE_MKL_CTOR_DTOR(name, T)                \
  name(const OperatorDef& operator_def, Workspace* ws)   \
      : MKLOperator<T>(operator_def, ws) {}              \
  virtual ~name() {}

} // namespace caffe2

#endif // CAFFE2_UTILS_MKL_OPERATOR_H_
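As a usage sketch, a concrete operator derives from MKLOperator<T>, pulls in the helper members with USE_MKLOPERATOR_FUNCTIONS, implements RunOnDevice(), and registers itself with REGISTER_MKL_OPERATOR. The MKLPassThroughOp below is a hypothetical example, not part of Caffe2, and the MKLMemory<float>::CopyFrom call is an assumption for illustration; a real operator would instead set up primitive_ and resources_ and call ExecutePrimitive().

// Hypothetical example operator (not part of Caffe2), shown only to
// illustrate the MKLOperator<T> scaffolding declared above.
#include "caffe2/mkl/utils/mkl_operator.h"

namespace caffe2 {
namespace mkl {

class MKLPassThroughOp final : public MKLOperator<float> {
 public:
  USE_SIMPLE_MKL_CTOR_DTOR(MKLPassThroughOp, float);
  USE_MKLOPERATOR_FUNCTIONS(float);

  bool RunOnDevice() override {
    const MKLMemory<float>& X = Input(0);
    MKLMemory<float>* Y = Output(0);
    // A real MKLDNN operator would create primitive_ with one of the
    // dnn*Create* wrappers, point resources_ at the input/output buffers,
    // and then call ExecutePrimitive(). Here we only copy the input through,
    // assuming MKLMemory provides a CopyFrom overload for MKLMemory inputs.
    Y->CopyFrom(X);
    return true;
  }
};

} // namespace mkl

// Makes the operator available to the MKLDNN device type as "PassThrough".
REGISTER_MKL_OPERATOR(PassThrough, mkl::MKLPassThroughOp);

} // namespace caffe2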