1 #ifndef CAFFE2_CORE_OPERATOR_H_ 2 #define CAFFE2_CORE_OPERATOR_H_ 12 #include "caffe2/core/blob.h" 13 #include "caffe2/core/common.h" 14 #include "caffe2/core/net.h" 15 #include "caffe2/core/observer.h" 16 #include "caffe2/core/operator_gradient.h" 17 #include "caffe2/core/operator_schema.h" 18 #include "caffe2/core/registry.h" 19 #include "caffe2/core/tensor.h" 20 #include "caffe2/core/types.h" 21 #include "caffe2/core/workspace.h" 22 #include "caffe2/proto/caffe2.pb.h" 23 #include "caffe2/utils/proto_utils.h" 28 typedef ObserverBase<OperatorBase> OperatorObserver;
38 CAFFE_ENFORCE(operator_def_,
"operator_def was null!");
39 return ArgumentHelper::HasArgument(*operator_def_, name);
45 inline T GetSingleArgument(
const string& name,
const T& default_value)
const {
46 CAFFE_ENFORCE(operator_def_,
"operator_def was null!");
47 return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
48 *operator_def_, name, default_value);
51 inline bool HasSingleArgumentOfType(
const string& name)
const {
52 CAFFE_ENFORCE(operator_def_,
"operator_def was null!");
53 return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
54 *operator_def_, name);
57 inline vector<T> GetRepeatedArgument(
59 const vector<T>& default_value = {})
const {
60 CAFFE_ENFORCE(operator_def_,
"operator_def was null!");
61 return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
62 *operator_def_, name, default_value);
67 inline const T& Input(
int idx) {
68 DCHECK_LT(idx, inputs_.size());
70 return inputs_.at(idx)->template Get<T>();
72 if (has_debug_def()) {
73 enf.AppendMessage(
".\nOffending Blob name: ");
74 enf.AppendMessage(debug_def().input(idx));
75 enf.AppendMessage(
".\n");
82 inline T* Output(
int idx) {
83 return outputs_.at(idx)->template GetMutable<T>();
87 inline T* Output(
int idx, T* allocated) {
88 outputs_.at(idx)->Reset(allocated);
92 inline const Blob& InputBlob(
int idx) {
93 return *inputs_.at(idx);
96 inline Blob* OutputBlob(
int idx) {
97 return outputs_.at(idx);
100 template <
typename T>
101 inline bool InputIsType(
int idx) {
102 return inputs_.at(idx)->template IsType<T>();
105 template <
typename T>
106 inline bool OutputIsType(
int idx) {
107 return outputs_.at(idx)->template IsType<T>();
110 inline int InputSize()
const {
111 return inputs_.size();
113 inline int OutputSize()
const {
114 return outputs_.size();
116 inline const vector<const Blob*>& Inputs()
const {
return inputs_; }
117 inline const vector<Blob*>& Outputs() {
return outputs_; }
118 vector<TensorShape> InputTensorShapes();
120 virtual void WaitEvent(
const Event& ev,
int stream_id = -1) {
124 inline void Wait(
const OperatorBase& other,
int stream_id = -1) {
125 WaitEvent(other.event(), stream_id);
128 virtual void WaitEvents(
129 const std::vector<const Event*>& events,
130 int stream_id = -1) {
131 for (
const auto& ev : events) {
136 virtual void Finish() {
142 virtual bool Run(
int = 0) {
143 CAFFE_NOT_IMPLEMENTED;
146 virtual bool HasAsyncPart()
const {
150 virtual bool SupportsAsyncScheduling()
const {
158 virtual bool RunAsync(
int stream_id = 0) {
159 return Run(stream_id);
163 if (!has_debug_def()) {
168 if (err->caller() !=
nullptr) {
169 for (
int i = 0; i < inputs_.size(); i++) {
170 if (inputs_[i]->GetRaw() == err->caller()) {
173 "\n** while accessing input: " + debug_def().input(i));
177 for (
int i = 0; i < outputs_.size(); i++) {
178 if (outputs_[i]->GetRaw() == err->caller()) {
180 err->AppendMessage(
"\n OR ");
183 "\n** while accessing output: " + debug_def().output(i));
190 inline const OperatorDef& debug_def()
const {
191 CAFFE_ENFORCE(has_debug_def(),
"operator_def was null!");
192 return *operator_def_;
195 inline void set_debug_def(
196 const std::shared_ptr<const OperatorDef>& operator_def) {
197 operator_def_ = operator_def;
200 inline bool has_debug_def()
const {
201 return operator_def_ !=
nullptr;
205 void RecordLastFailedOpNetPosition() {
206 if (net_position_ != kNoNetPositionSet) {
207 VLOG(1) <<
"Operator with id " << net_position_ <<
" failed";
208 operator_ws_->last_failed_op_net_position = net_position_;
210 VLOG(1) <<
"Failed operator doesn't have id set";
214 int net_position()
const {
215 return net_position_;
218 void set_net_position(
int idx) {
222 const DeviceOption& device_option()
const {
223 return device_option_;
226 const Event& event()
const {
227 CAFFE_ENFORCE(event_,
"Event is disabled");
232 CAFFE_ENFORCE(event_,
"Event is disabled");
242 void DisableEvent() {
246 bool IsEventDisabled()
const {
253 virtual bool IsStreamFree(
int )
const {
257 const std::string& type()
const {
258 CAFFE_ENFORCE(operator_def_.get() !=
nullptr);
259 return operator_def_->type();
262 void annotate_engine(
const std::string& engine) {
266 const std::string& engine()
const {
// Sentinel meaning "this op has no recorded position inside its net".
271 static constexpr
int kNoNetPositionSet = -1;
// Shared, immutable definition this operator was created from; may be
// null when debug info is unavailable (see has_debug_def()).
275 std::shared_ptr<const OperatorDef> operator_def_;
276 DeviceOption device_option_;
// Non-owning views of the workspace blobs wired as inputs/outputs.
278 vector<const Blob*> inputs_;
279 vector<Blob*> outputs_;
281 int net_position_{kNoNetPositionSet};
284 virtual void RecordEvent(
const char* err_msg =
nullptr) {
285 CAFFE_NOT_IMPLEMENTED;
// Completion event used for async scheduling; null once DisableEvent()
// has been called.
289 std::unique_ptr<Event> event_;
// Convenience constructor/destructor macros, the OP_SINGLE_ARG argument
// fetch helper, and the INPUT_TAGS/OUTPUT_TAGS enum generators; below
// begins Operator<Context>, the device-typed operator base class.
// NOTE(review): this listing is garbled — original file line numbers are
// embedded in the text and several lines were dropped; kept verbatim.
296 #define USE_SIMPLE_BASE_CTOR_DTOR(name) \ 297 name(const OperatorDef& operator_def, Workspace* ws) \ 298 : OperatorBase(operator_def, ws) {} \ 299 virtual ~name() noexcept {} 303 #define OP_SINGLE_ARG(type, name, variable, default) \ 304 variable(OperatorBase::GetSingleArgument<type>(name, (default))) 316 #define INPUT_TAGS(first_input, ...) \ 317 enum _InputTags { first_input = 0, __VA_ARGS__ } 318 #define OUTPUT_TAGS(first_input, ...) \ 319 enum _OutputTags { first_input = 0, __VA_ARGS__ } 324 template <
class Context>
328 :
OperatorBase(operator_def, ws), context_(operator_def.device_option()) {
331 context_.SwitchToDevice(0);
336 return OperatorBase::template Input<Tensor<Context>>(idx);
339 return OperatorBase::template Output<Tensor<Context>>(idx);
342 void WaitEvent(
const Event& ev,
int stream_id = -1)
final {
343 if (stream_id >= 0) {
344 context_.SwitchToDevice(stream_id);
346 context_.WaitEvent(ev);
349 void WaitEvents(
const std::vector<const Event*>& events,
int stream_id = -1)
351 if (stream_id >= 0) {
352 context_.SwitchToDevice(stream_id);
354 for (
const auto& ev : events) {
355 context_.WaitEvent(*ev);
364 bool Run(
int stream_id = 0)
final {
368 context_.SwitchToDevice(stream_id);
369 bool result = RunOnDevice();
371 this->RecordLastFailedOpNetPosition();
373 context_.FinishDeviceComputation();
379 if (has_debug_def()) {
381 "Error from operator: \n" + ProtoDebugString(debug_def()));
382 AddRelatedBlobInfo(&err);
384 this->RecordLastFailedOpNetPosition();
387 this->RecordLastFailedOpNetPosition();
392 bool RunAsync(
int stream_id = 0)
final {
394 context_.SwitchToDevice(stream_id);
395 auto result = RunOnDevice();
397 if (HasAsyncPart()) {
402 event().SetFinished();
405 event().SetFinished(getErrorMsg().c_str());
406 this->RecordLastFailedOpNetPosition();
410 if (has_debug_def()) {
412 "Error from operator: \n" + ProtoDebugString(debug_def()));
413 AddRelatedBlobInfo(&err);
415 event().SetFinished(err.what());
416 this->RecordLastFailedOpNetPosition();
418 }
catch (
const std::exception& err) {
419 event().SetFinished(err.what());
420 this->RecordLastFailedOpNetPosition();
423 event().SetFinished(getErrorMsg().c_str());
424 this->RecordLastFailedOpNetPosition();
429 bool IsStreamFree(
int stream_id)
const override {
430 return context_.IsStreamFree(device_option(), stream_id);
433 virtual bool RunOnDevice() = 0;
442 bool HasAsyncPart()
const override {
443 return context_.HasAsyncPartDefault();
459 bool SupportsAsyncScheduling()
const override {
460 return HasAsyncPart() && context_.SupportsAsyncScheduling();
463 const Context* getContext()
const {
468 void RecordEvent(
const char* err_msg =
nullptr)
final {
470 context_.Record(event_.get(), err_msg);
474 std::string getErrorMsg() {
475 if (has_debug_def()) {
476 return "Error from operator: " + ProtoDebugString(debug_def());
478 return "Error from operator: no op def";
// Operator convenience-using macros (USE_OPERATOR_*), the USE_SIMPLE_CTOR_DTOR
// helper, USE_DISPATCH_HELPER, and the DispatchHelper machinery that
// dispatches DoRunWithValue/DoRunWithType over a compile-time list of
// fixed values or tensor types, plus the registry forward declarations.
// NOTE(review): this listing is garbled — original file line numbers are
// embedded in the text, and the struct/template declaration lines that
// name these partial specializations were dropped by the extraction;
// text kept verbatim because macro line-continuations make in-place
// edits unsafe.
485 #define USE_OPERATOR_BASE_FUNCTIONS \ 486 using OperatorBase::HasArgument; \ 487 using OperatorBase::GetSingleArgument; \ 488 using OperatorBase::HasSingleArgumentOfType; \ 489 using OperatorBase::GetRepeatedArgument; \ 490 using OperatorBase::InputIsType; \ 491 using OperatorBase::InputSize; \ 492 using OperatorBase::OutputSize 494 #define USE_OPERATOR_FUNCTIONS(context) \ 495 USE_OPERATOR_BASE_FUNCTIONS; \ 496 using Operator<context>::context_; \ 497 using Operator<context>::Input; \ 498 using Operator<context>::InputBlob; \ 499 using Operator<context>::Output; \ 500 using Operator<context>::OutputBlob 502 #define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context) 504 #define USE_SIMPLE_CTOR_DTOR(name) \ 505 name(const OperatorDef& operator_def, Workspace* ws) \ 506 : Operator<Context>(operator_def, ws) {} \ 507 virtual ~name() noexcept {} 539 #define USE_DISPATCH_HELPER \ 540 template <typename FirstArg, typename... ExtraArgs> \ 541 friend struct DispatchHelper 543 template <
int... Values>
546 template <
typename... Types>
556 template <
typename... Types>
559 template <
typename Sizes,
typename... ExtraArgs>
562 template <
int FirstVal,
int... Values,
typename... ExtraArgs>
564 template <
typename Op>
565 static bool call(Op* op,
int value) {
566 if (FirstVal == value) {
567 return op->template DoRunWithValue<ExtraArgs..., FirstVal>();
574 template <
typename... ExtraArgs>
576 template <
typename Op>
577 static bool call(Op* op, TIndex ) {
578 return op->template DoRunWithValue<ExtraArgs..., -1>();
582 #define CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER( \ 583 TensorTypes, DoRunWithType, DoRunWithOtherType) \ 584 template <typename FirstType, typename... Types, typename... ExtraArgs> \ 585 struct DispatchHelper<TensorTypes<FirstType, Types...>, ExtraArgs...> { \ 586 template <typename Op> \ 587 static bool call(Op* op, const TypeMeta& meta) { \ 589 !std::is_same<GenericTensorImplementation, FirstType>::value, \ 590 "GenericTensorImplementation must be the last in TensorTypes list"); \ 591 if (meta.Match<FirstType>()) { \ 592 return op->template DoRunWithType<ExtraArgs..., FirstType>(); \ 594 return DispatchHelper<TensorTypes<Types...>, ExtraArgs...>:: \ 595 template call<Op>(op, meta); \ 597 template <typename Op, typename Context> \ 598 static bool call(Op* op, const Tensor<Context>& tensor) { \ 599 return call<Op>(op, tensor.meta()); \ 601 template <typename Op> \ 602 static bool call(Op* op, const Blob& blob) { \ 603 return call<Op>(op, blob.meta()); \ 607 template <typename... ExtraArgs> \ 608 struct DispatchHelper<TensorTypes<>, ExtraArgs...> { \ 609 template <typename Op> \ 610 static bool call(Op* , const TypeMeta& meta) { \ 611 CAFFE_THROW("Unsupported type of tensor: ", meta.name()); \ 613 template <typename Op, typename Context> \ 614 static bool call(Op* op, const Tensor<Context>& tensor) { \ 615 return call<Op>(op, tensor.meta()); \ 617 template <typename Op> \ 618 static bool call(Op* op, const Blob& blob) { \ 619 return call<Op>(op, blob.meta()); \ 623 template <typename... 
ExtraArgs> \ 624 struct DispatchHelper< \ 625 TensorTypes<GenericTensorImplementation>, \ 627 template <typename Op> \ 628 static bool call(Op* op, const TypeMeta&) { \ 629 return op->template DoRunWithOtherType<ExtraArgs...>(); \ 631 template <typename Op, typename Context> \ 632 static bool call(Op* op, const Tensor<Context>& tensor) { \ 633 return call<Op>(op, tensor.meta()); \ 635 template <typename Op> \ 636 static bool call(Op* op, const Blob& blob) { \ 637 return call<Op>(op, blob.meta()); \ 640 CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(
644 CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(
648 #undef CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER 657 std::unique_ptr<OperatorBase>,
663 std::unique_ptr<OperatorBase>,
666 std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry();
670 if (gDeviceTypeRegistry()->count(type)) {
671 std::cerr <<
"Device type " << type
672 <<
"registered twice. This should not happen. Did you have " 673 "duplicated numbers assigned to different devices?";
677 gDeviceTypeRegistry()->emplace(type, func());
// Device-type registration macro, the CPU/CUDA operator registries with
// their REGISTER_* convenience macros, the link-time sanity check on
// registered_ops, the UnsupportedOperatorFeature exception with the
// OPERATOR_NEEDS_FEATURE macro, and the CreateOperator/OpRegistryKey
// entry points.
// NOTE(review): garbled listing kept verbatim — stray original line
// numbers are embedded, several declaration lines were dropped, and
// "®istry_function" is HTML-entity mojibake for "&registry_function";
// macro line-continuations make in-place repair unsafe here.
681 #define CAFFE_REGISTER_DEVICE_TYPE(type, registry_function) \ 683 static DeviceTypeRegisterer CAFFE_ANONYMOUS_VARIABLE( \ 684 DeviceType)(type, ®istry_function); \ 694 CAFFE_DECLARE_REGISTRY(
699 #define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \ 700 CAFFE_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__) 701 #define REGISTER_CPU_OPERATOR(name, ...) \ 702 extern void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ 703 static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CPU##name() { \ 704 CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ 706 CAFFE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__) 707 #define REGISTER_CPU_OPERATOR_STR(str_name, ...) \ 708 CAFFE_REGISTER_TYPED_CLASS(CPUOperatorRegistry, str_name, __VA_ARGS__) 710 #define REGISTER_CPU_OPERATOR_WITH_ENGINE(name, engine, ...) \ 711 CAFFE_REGISTER_CLASS(CPUOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) 713 CAFFE_DECLARE_REGISTRY(
714 CUDAOperatorRegistry,
718 #define REGISTER_CUDA_OPERATOR_CREATOR(key, ...) \ 719 CAFFE_REGISTER_CREATOR(CUDAOperatorRegistry, key, __VA_ARGS__) 720 #define REGISTER_CUDA_OPERATOR(name, ...) \ 721 extern void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ 722 static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CUDA##name() { \ 723 CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \ 725 CAFFE_REGISTER_CLASS(CUDAOperatorRegistry, name, __VA_ARGS__) 726 #define REGISTER_CUDA_OPERATOR_STR(str_name, ...) \ 727 CAFFE_REGISTER_TYPED_CLASS(CUDAOperatorRegistry, str_name, __VA_ARGS__) 729 #define REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, engine, ...) \ 730 CAFFE_REGISTER_CLASS( \ 731 CUDAOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) 734 #define REGISTER_CUDNN_OPERATOR(name, ...) \ 735 REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__) 747 const int registered_ops = CPUOperatorRegistry()->Keys().size();
753 if (registered_ops == 0) {
755 "You might have made a build error: the Caffe2 library does not seem " 756 "to be linked with whole-static library option. To do so, use " 757 "-Wl,-force_load (clang) or -Wl,--whole-archive (gcc) to link the " 771 const char* what()
const noexcept
override {
782 #define OPERATOR_NEEDS_FEATURE(condition, ...) \ 783 if (!(condition)) { \ 784 throw UnsupportedOperatorFeature(::caffe2::MakeString(__VA_ARGS__)); \ 789 unique_ptr<OperatorBase> CreateOperator(
790 const OperatorDef& operator_def,
792 int net_position = OperatorBase::kNoNetPositionSet);
794 const std::string OpRegistryKey(
795 const std::string& op_type,
796 const std::string& engine =
"");
// Engine-preference configuration aliases/setters and the shape/type
// inference entry points (declarations only; definitions live in
// operator.cc).
// NOTE(review): garbled listing kept verbatim — stray original line
// numbers are embedded and some parameter/name lines were dropped
// (e.g. the function name for the two-argument engine-pref setter).
800 using EnginePrefType = std::vector<std::string>;
// Maps device type -> op type -> ordered engine preference list.
802 using PerOpEnginePrefType =
803 CaffeMap<int, CaffeMap<std::string, EnginePrefType>>;
// Maps device type -> ordered engine preference list.
805 using GlobalEnginePrefType = CaffeMap<int, EnginePrefType>;
806 void SetPerOpEnginePref(
const PerOpEnginePrefType& per_op_engine_pref);
807 void SetGlobalEnginePref(
const GlobalEnginePrefType& global_engine_pref);
809 const PerOpEnginePrefType& per_op_engine_pref,
810 const GlobalEnginePrefType& global_engine_pref);
811 void SetOpEnginePref(
812 const std::string& op_type,
813 const CaffeMap<int, EnginePrefType>& op_pref);
815 TensorShape GetTensorShapeOfBlob(
const Blob* b);
817 TensorShapes InferBlobShapesAndTypesFromWorkspace(
819 const vector<std::unique_ptr<NetDef>>& nets);
821 TensorShapes InferBlobShapesAndTypesFromMap(
822 const CaffeMap<std::string, std::vector<TIndex>>& blob_dimensions,
823 const vector<std::unique_ptr<NetDef>>& nets);
825 std::map<string, std::pair<DeviceOption, DeviceOption>> ValidateTensorDevices(
827 const OperatorDef& op_def);
830 std::set<std::string> GetRegisteredOperators();
834 #endif // CAFFE2_CORE_OPERATOR_H_
Blob is a general container that hosts a typed pointer.
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Inherit to make your class observable.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
A template class that allows one to register classes by keys.