Caffe2 - C++ API
A deep learning, cross-platform ML framework
net.cc
#include "caffe2/core/net.h"
#include "caffe2/core/net_simple.h"

#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/utils/proto_utils.h"
12 
13 namespace caffe2 {
14 
// Defines NetRegistry: the registry that maps a net type string to a NetBase
// factory. CreateNet in this file looks nets up here via net_def->type().
// Each factory receives the shared NetDef and the Workspace to build in; the
// argument list here must match the registry's declaration.
CAFFE_DEFINE_REGISTRY(
    NetRegistry,
    NetBase,
    const std::shared_ptr<const NetDef>&,
    Workspace*);
20 
21 NetBase::NetBase(
22  const std::shared_ptr<const NetDef>& def,
23  Workspace* /* unused */)
24  : external_input_(
25  def->external_input().begin(),
26  def->external_input().end()),
27  external_output_(
28  def->external_output().begin(),
29  def->external_output().end()),
30  name_(def->name()),
31  net_def_(def) {
32  // Check that node_name is empty for all ops
33  for (const OperatorDef& op : def->op()) {
34  if (op.has_device_option()) {
35  CAFFE_ENFORCE(
36  !op.device_option().has_node_name(),
37  "node_name must be empty for all operators at execution time.");
38  }
39  }
40 
41  // Go through the operators and make sure that blobs are correctly made.
42  std::set<string> known_blobs(
43  external_input_.begin(), external_input_.end());
44  std::set<string> remaining_output(
45  external_output_.begin(), external_output_.end());
46  for (const auto& blob : known_blobs) {
47  remaining_output.erase(blob);
48  }
49  for (const OperatorDef& op : def->op()) {
50  for (const string& in : op.input()) {
51  if (!known_blobs.count(in)) {
52  if (external_input_.size()) {
53  CAFFE_THROW(
54  "op ",
55  op.type(),
56  ": Source for input ",
57  in,
58  " is unknown for net ",
59  def->name(),
60  ", operator ",
61  ProtoDebugString(op));
62  } else {
63  // If we are not declaring input and output, we will simply VLOG it
64  // for debugging purposes.
65  VLOG(1) << "op " << op.type() << ": input " << in << " is unknown.";
66  }
67  }
68  }
69  for (const string& out : op.output()) {
70  known_blobs.insert(out);
71  remaining_output.erase(out);
72  }
73  }
74  // Finally, check if all declared outputs are being created.
75  CAFFE_ENFORCE(
76  remaining_output.size() == 0,
77  "Some of the blobs are declared as output but never produced by the "
78  "net ",
79  def->name(),
80  ", the first one is ",
81  *remaining_output.begin());
82 }
83 
84 bool NetBase::RunAsync() {
85  for (auto& op : GetOperators()) {
86  op->ResetEvent();
87  }
88  return DoRunAsync();
89 }
90 
91 namespace {
92 std::vector<NetObserverCreator>* GetNetObserverCreators() {
93  static std::vector<NetObserverCreator> creators;
94  return &creators;
95 }
96 } // namespace
97 
98 void AddGlobalNetObserverCreator(NetObserverCreator creator) {
99  GetNetObserverCreators()->push_back(creator);
100  VLOG(1) << "Have set a custom GlobalNetObserverCreator";
101 }
102 
103 unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws) {
104  std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
105  return CreateNet(tmp_net_def, ws);
106 }
107 
108 unique_ptr<NetBase> CreateNet(
109  const std::shared_ptr<const NetDef>& net_def,
110  Workspace* ws) {
111  // In default, we will return a simple network that just runs all operators
112  // sequentially.
113  unique_ptr<NetBase> net;
114  if (!net_def->has_type()) {
115  net = std::unique_ptr<NetBase>(new SimpleNet(net_def, ws));
116  } else {
117  net = NetRegistry()->Create(net_def->type(), net_def, ws);
118  }
119  VLOG(1) << "Adding a global observer to a net";
120  if (net) {
121  auto* observer_creators = GetNetObserverCreators();
122  for (auto& creator : *observer_creators) {
123  net->AttachObserver(creator(net.get()));
124  }
125  }
126  return net;
127 }
128 
129 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:103