1 #ifndef CAFFE2_UTILS_MKL_OPERATOR_H_ 2 #define CAFFE2_UTILS_MKL_OPERATOR_H_ 4 #include "caffe2/core/operator.h" 5 #include "caffe2/mkl/utils/mkl_dnn_cppwrapper.h" 6 #include "caffe2/mkl/utils/mkl_memory.h" 7 #include "caffe2/proto/caffe2.pb.h" 9 CAFFE2_DECLARE_bool(caffe2_mkl_memonger_in_use);
13 CAFFE_DECLARE_REGISTRY(
18 #define REGISTER_MKL_OPERATOR_CREATOR(key, ...) \ 19 CAFFE_REGISTER_CREATOR(MKLOperatorRegistry, key, __VA_ARGS__) 20 #define REGISTER_MKL_OPERATOR(name, ...) \ 21 CAFFE_REGISTER_CLASS(MKLOperatorRegistry, name, __VA_ARGS__) 22 #define REGISTER_MKL_OPERATOR_STR(str_name, ...) \ 23 CAFFE_REGISTER_TYPED_CLASS(MKLOperatorRegistry, str_name, __VA_ARGS__) 25 #define REGISTER_MKL_OPERATOR_WITH_ENGINE(name, engine, ...) \ 26 CAFFE_REGISTER_CLASS(MKLOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) 36 context_(operator_def.device_option()) {}
40 return OperatorBase::template Input<MKLMemory<T>>(idx);
43 return OperatorBase::template Output<MKLMemory<T>>(idx);
49 bool Run(
int )
final {
56 err.AppendMessage(getErrorMsg());
64 void WaitEvent(
const Event& ev,
int )
final {
65 context_.WaitEvent(ev);
68 void WaitEvents(
const std::vector<const Event*>& events,
int )
70 for (
const auto& ev : events) {
71 context_.WaitEvent(*ev);
75 void RecordEvent(
const char* err_msg =
nullptr)
final {
77 context_.Record(event_.get(), err_msg);
// Pure-virtual hook: every concrete MKL operator must override this with its
// actual computation and return a bool success flag.
// NOTE(review): presumably invoked from Run(); Run's body is not fully visible
// in this listing, so confirm the call site and how a false return is handled.
81 virtual bool RunOnDevice() = 0;
83 inline void ExecutePrimitive() {
84 MKLDNN_SAFE_CALL(mkl::dnnExecute<T>(primitive_, resources_));
88 std::string getErrorMsg() {
89 if (has_debug_def()) {
90 return "Error from operator: " + ProtoDebugString(debug_def());
92 return "Error from operator: no op def";
// Cached per-input dimension lists (one vector<TIndex> per input). Exposed to
// subclasses through USE_MKLOPERATOR_FUNCTIONS.
// NOTE(review): the name suggests subclasses compare incoming input sizes
// against this cache to decide whether to rebuild primitive_; no reads or
// writes are visible here, so confirm against the subclasses.
100 vector<vector<TIndex>> input_size_cache_;
// Raw resource-pointer table with dnnResourceNumber slots. ExecutePrimitive()
// passes it together with primitive_ to mkl::dnnExecute<T>.
// NOTE(review): presumably subclasses fill the relevant slots (e.g. source,
// destination, filter buffers) before executing; this class does not show it.
106 void* resources_[dnnResourceNumber];
110 #define USE_MKLOPERATOR_FUNCTIONS(T) \ 111 USE_OPERATOR_BASE_FUNCTIONS; \ 112 using MKLOperator<T>::Input; \ 113 using MKLOperator<T>::Output; \ 114 using MKLOperator<T>::ExecutePrimitive; \ 115 using MKLOperator<T>::primitive_; \ 116 using MKLOperator<T>::input_size_cache_; \ 117 using MKLOperator<T>::buffer_; \ 118 using MKLOperator<T>::resources_ 120 #define USE_SIMPLE_MKL_CTOR_DTOR(name, T) \ 121 name(const OperatorDef& operator_def, Workspace* ws) \ 122 : MKLOperator<T>(operator_def, ws) {} \ 127 #endif // CAFFE2_UTILS_MKL_OPERATOR_H_
The MKL Context, which is largely the same as the CPUContext.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
A wrapper around an opaque MKL internal resource that has certain layouts and conversion primitives s...