Caffe2 - C++ API
A deep learning, cross-platform ML framework
net.h
1 #ifndef CAFFE2_CORE_NET_H_
2 #define CAFFE2_CORE_NET_H_
3 
4 #include <atomic>
5 #include <climits>
6 #include <cstddef>
7 #include <thread> // NOLINT
8 #include <typeinfo>
9 #include <unordered_map>
10 #include <vector>
11 
12 #include "caffe2/core/blob.h"
13 #include "caffe2/core/common.h"
14 #include "caffe2/core/logging.h"
15 #include "caffe2/core/observer.h"
16 #include "caffe2/core/operator_schema.h"
17 #include "caffe2/core/registry.h"
18 #include "caffe2/core/tensor.h"
19 #include "caffe2/core/workspace.h"
20 #include "caffe2/proto/caffe2.pb.h"
21 #include "caffe2/utils/simple_queue.h"
22 
23 namespace caffe2 {
24 
25 class NetBase;
26 typedef ObserverBase<NetBase> NetObserver;
27 typedef std::function<std::unique_ptr<NetObserver>(NetBase*)>
28  NetObserverCreator;
29 
30 class OperatorBase;
31 class Workspace;
32 
33 // Net is a thin struct that owns all the operators together with the operator
34 // contexts.
35 class NetBase : public Observable<NetBase> {
36  public:
37  NetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
38  virtual ~NetBase() noexcept {}
39 
40  virtual bool SupportsAsync() = 0;
41  inline const vector<const Event*>& events() const {
42  return events_;
43  }
44 
45  virtual void Wait() {
46  // by default just wait till all events are finished
47  for (const auto& event : events_) {
48  event->Finish();
49  }
50  }
51 
52  virtual bool Run() {
53  if (!RunAsync()) {
54  LOG(ERROR) << "Failed to execute async run";
55  return false;
56  }
57  Wait();
58  for (const Event* event : events_) {
59  if (event->Query() != EventStatus::EVENT_SUCCESS) {
60  CAFFE_THROW(event->ErrorMessage());
61  }
62  }
63  return true;
64  }
65 
66  virtual bool RunAsync();
67 
77  virtual vector<float> TEST_Benchmark(
78  const int /*warmup_runs*/,
79  const int /*main_runs*/,
80  const bool /*run_individual*/) {
81  LOG(ERROR) << "Benchmark not implemented for this net type.";
82  return vector<float>();
83  }
84 
85  inline const vector<string>& external_output() const {
86  return external_output_;
87  }
88 
89  inline const vector<string>& external_input() const {
90  return external_input_;
91  }
92 
93  /* Used to attach Observers to operators of a Net
94  *
95  * Returns pointers to objects owned with unique_ptrs.
96  * Use with caution.
97  */
98  virtual vector<OperatorBase*> GetOperators() const = 0;
99 
100  const string& Name() const {
101  return name_;
102  }
103 
104  inline const NetDef& debug_def() const {
105  CAFFE_ENFORCE(has_debug_def(), "net_def was null!");
106  return *net_def_;
107  }
108 
109  inline bool has_debug_def() const {
110  return net_def_ != nullptr;
111  }
112 
113  protected:
114  virtual bool DoRunAsync() {
115  CAFFE_THROW("Not implemented");
116  };
117 
118  vector<string> external_input_;
119  vector<string> external_output_;
120  string name_;
121  vector<const Event*> events_;
122  std::shared_ptr<const NetDef> net_def_;
123  DISABLE_COPY_AND_ASSIGN(NetBase);
124 };
125 
126 CAFFE_DECLARE_REGISTRY(
127  NetRegistry,
128  NetBase,
129  const std::shared_ptr<const NetDef>&,
130  Workspace*);
131 #define REGISTER_NET_CREATOR(key, ...) \
132  CAFFE_REGISTER_CREATOR(NetRegistry, key, __VA_ARGS__)
133 #define REGISTER_NET(name, ...) \
134  CAFFE_REGISTER_CLASS(NetRegistry, name, __VA_ARGS__)
135 
143 unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws);
144 unique_ptr<NetBase> CreateNet(
145  const std::shared_ptr<const NetDef>& net_def,
146  Workspace* ws);
147 
148 void AddGlobalNetObserverCreator(NetObserverCreator creator);
149 
150 } // namespace caffe2
151 
152 #endif // CAFFE2_CORE_NET_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
Inherit to make your class observable.
Definition: observer.h:39
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws)
Creates a network, accessing or creating blobs in the given workspace.
Definition: net.cc:103
virtual vector<float> TEST_Benchmark(const int, const int, const bool)
Benchmarks a network.
Definition: net.h:77