1 #ifndef CAFFE2_CORE_NET_H_ 2 #define CAFFE2_CORE_NET_H_ 9 #include <unordered_map> 12 #include "caffe2/core/blob.h" 13 #include "caffe2/core/common.h" 14 #include "caffe2/core/logging.h" 15 #include "caffe2/core/observer.h" 16 #include "caffe2/core/operator_schema.h" 17 #include "caffe2/core/registry.h" 18 #include "caffe2/core/tensor.h" 19 #include "caffe2/core/workspace.h" 20 #include "caffe2/proto/caffe2.pb.h" 21 #include "caffe2/utils/simple_queue.h" 26 typedef ObserverBase<NetBase> NetObserver;
27 typedef std::function<std::unique_ptr<NetObserver>(NetBase*)>
40 virtual bool SupportsAsync() = 0;
41 inline const vector<const Event*>& events()
const {
47 for (
const auto& event : events_) {
54 LOG(ERROR) <<
"Failed to execute async run";
58 for (
const Event* event : events_) {
59 if (event->Query() != EventStatus::EVENT_SUCCESS) {
60 CAFFE_THROW(event->ErrorMessage());
66 virtual bool RunAsync();
81 LOG(ERROR) <<
"Benchmark not implemented for this net type.";
82 return vector<float>();
85 inline const vector<string>& external_output()
const {
86 return external_output_;
89 inline const vector<string>& external_input()
const {
90 return external_input_;
98 virtual vector<OperatorBase*> GetOperators()
const = 0;
100 const string& Name()
const {
104 inline const NetDef& debug_def()
const {
105 CAFFE_ENFORCE(has_debug_def(),
"net_def was null!");
109 inline bool has_debug_def()
const {
110 return net_def_ !=
nullptr;
114 virtual bool DoRunAsync() {
115 CAFFE_THROW(
"Not implemented");
118 vector<string> external_input_;
119 vector<string> external_output_;
121 vector<const Event*> events_;
122 std::shared_ptr<const NetDef> net_def_;
123 DISABLE_COPY_AND_ASSIGN(
NetBase);
126 CAFFE_DECLARE_REGISTRY(
129 const std::shared_ptr<const NetDef>&,
131 #define REGISTER_NET_CREATOR(key, ...) \ 132 CAFFE_REGISTER_CREATOR(NetRegistry, key, __VA_ARGS__) 133 #define REGISTER_NET(name, ...) \ 134 CAFFE_REGISTER_CLASS(NetRegistry, name, __VA_ARGS__) 145 const std::shared_ptr<const NetDef>& net_def,
148 void AddGlobalNetObserverCreator(NetObserverCreator creator);
152 #endif // CAFFE2_CORE_NET_H_ Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Inherit to make your class observable.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
virtual vector< float > TEST_Benchmark(const int, const int, const bool)
Benchmarks a network.