1 #include "caffe2/core/net.h" 2 #include "caffe2/core/net_simple.h" 5 #include <unordered_map> 6 #include <unordered_set> 8 #include "caffe2/core/operator.h" 9 #include "caffe2/core/timer.h" 10 #include "caffe2/proto/caffe2.pb.h" 11 #include "caffe2/utils/proto_utils.h" 15 CAFFE_DEFINE_REGISTRY(
// Registry of net implementations. CreateNet (below) looks entries up with
// NetRegistry()->Create(net_def->type(), net_def, ws), so a net type string
// selects the implementation. Only the first constructor-argument type is
// visible in this fragment; the registry name and remaining arguments are on
// lines not shown here.
18 const std::shared_ptr<const NetDef>&,
// NetBase constructor (fragment: the signature head, initializer-list names
// and several body lines are not visible in this view).
// Visible behavior:
//   - copies the NetDef's external_input / external_output ranges
//     (presumably into external_input_ / external_output_, which are read
//     below — TODO confirm against the missing initializer lines);
//   - checks that no operator's device_option carries a node_name;
//   - walks the ops in order and checks that each op input is either an
//     external input or an output of an earlier op;
//   - checks that every declared external output is actually available.
22 const std::shared_ptr<const NetDef>& def,
25 def->external_input().begin(),
26 def->external_input().end()),
28 def->external_output().begin(),
29 def->external_output().end()),
// Per the message below, node_name must be empty at execution time. The
// enforce-style call that consumes this condition is on a line not shown.
33 for (
const OperatorDef& op : def->op()) {
34 if (op.has_device_option()) {
36 !op.device_option().has_node_name(),
37 "node_name must be empty for all operators at execution time.");
// known_blobs: blobs available so far, seeded with the external inputs.
// remaining_output: declared external outputs not yet accounted for.
42 std::set<string> known_blobs(
43 external_input_.begin(), external_input_.end());
44 std::set<string> remaining_output(
45 external_output_.begin(), external_output_.end());
// An external output that is also an external input is already satisfied.
46 for (
const auto& blob : known_blobs) {
47 remaining_output.erase(blob);
// Ops are visited in declaration order, so an op may only consume blobs
// that are external inputs or outputs of ops that come before it.
49 for (
const OperatorDef& op : def->op()) {
50 for (
const string& in : op.input()) {
51 if (!known_blobs.count(in)) {
// An unknown input is reported as an error (message below) only when the
// net declares external inputs. Otherwise it appears to be logged at VLOG(1)
// only. NOTE(review): the enforce/throw call and the `else` between the two
// branches are on lines not visible here — confirm.
52 if (external_input_.size()) {
56 ": Source for input ",
58 " is unknown for net ",
61 ProtoDebugString(op));
65 VLOG(1) <<
"op " << op.type() <<
": input " << in <<
" is unknown.";
// Outputs of this op become available to later ops and satisfy any
// matching declared external output.
69 for (
const string& out : op.output()) {
70 known_blobs.insert(out);
71 remaining_output.erase(out);
// Every declared external output must have been an external input or produced
// by some op. remaining_output is a std::set, so "the first one" reported is
// the lexicographically smallest missing blob name.
76 remaining_output.size() == 0,
77 "Some of the blobs are declared as output but never produced by the " 80 ", the first one is ",
81 *remaining_output.begin());
// NetBase::RunAsync (fragment). Iterates over every operator returned by
// GetOperators(). The loop body and the returned bool are on lines not
// visible in this view.
84 bool NetBase::RunAsync() {
85 for (
auto& op : GetOperators()) {
// Accessor for the process-wide list of net observer creators.
// The list is a function-local static, so it is constructed on the first call
// (presumably to avoid static-initialization-order issues with registrations
// from other translation units). The return statement is not visible here;
// presumably it returns &creators — TODO confirm.
92 std::vector<NetObserverCreator>* GetNetObserverCreators() {
93 static std::vector<NetObserverCreator> creators;
// Appends `creator` to the global observer-creator list. CreateNet invokes
// every creator in that list and attaches the resulting observer to each net
// it builds. No locking is visible here: the shared vector is appended to
// directly. NOTE(review): confirm callers register creators before nets are
// created concurrently.
98 void AddGlobalNetObserverCreator(NetObserverCreator creator) {
99 GetNetObserverCreators()->push_back(creator);
100 VLOG(1) <<
"Have set a custom GlobalNetObserverCreator";
// CreateNet(const NetDef&, Workspace*) overload (fragment). Copies net_def
// into a heap-allocated NetDef owned by a shared_ptr. The forwarding call is
// not visible; presumably it delegates to the shared_ptr-based CreateNet
// overload — TODO confirm.
104 std::shared_ptr<NetDef> tmp_net_def(
new NetDef(net_def));
// CreateNet(const std::shared_ptr<const NetDef>&, Workspace*) (fragment).
// A NetDef without a `type` gets a SimpleNet. Otherwise the type string is
// looked up in NetRegistry. Every creator registered through
// AddGlobalNetObserverCreator is then called with the new net, and its
// observer is attached.
109 const std::shared_ptr<const NetDef>& net_def,
113 unique_ptr<NetBase> net;
114 if (!net_def->has_type()) {
115 net = std::unique_ptr<NetBase>(
new SimpleNet(net_def, ws));
// NOTE(review): if the registry has no entry for this type and returns null,
// net->AttachObserver below would dereference a null pointer unless a null
// check exists on the lines not visible here (original line 120) — confirm.
117 net = NetRegistry()->Create(net_def->type(), net_def, ws);
119 VLOG(1) <<
"Adding a global observer to a net";
121 auto* observer_creators = GetNetObserverCreators();
122 for (
auto& creator : *observer_creators) {
123 net->AttachObserver(creator(net.get()));
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.