Caffe2 - C++ API
A deep-learning, cross-platform ML framework
predictor.cc
1 #include "caffe2/core/predictor.h"
2 
3 #include <unordered_set>
4 
5 namespace caffe2 {
6 
7 namespace {
8 
9 void enforceIsTensor(Workspace* ws, const std::string& name) {
10  auto blob = ws->GetBlob(name);
11  CAFFE_ENFORCE(blob, "Blob does not exist: ", name);
12  CAFFE_ENFORCE(
13  blob->template IsType<TensorCPU>(), "Blob is not a CPU Tensor: ", name);
14 }
15 
16 void shareInputTensor(
17  Workspace* ws,
18  const std::string& name,
19  TensorCPU* input) {
20  enforceIsTensor(ws, name);
21  auto* blob = ws->GetBlob(name);
22  CAFFE_ENFORCE(blob, "Blob: ", name, " does not exist");
23  auto* tensor = blob->template GetMutable<TensorCPU>();
24  tensor->ResizeLike(*input);
25  tensor->ShareData(*input);
26 }
27 
28 TensorCPU* extractOutputTensor(Workspace* ws, const std::string& name) {
29  enforceIsTensor(ws, name);
30  auto* blob = ws->GetBlob(name);
31  CAFFE_ENFORCE(blob, "Blob: ", name, " does not exist");
32  return blob->template GetMutable<TensorCPU>();
33 }
34 
35 const NetDef& getNet(const MetaNetDef& def, const std::string& name) {
36  for (const auto& n : def.nets()) {
37  if (n.key() == name) {
38  return n.value();
39  }
40  }
41  CAFFE_THROW("Net not found: ", name);
42 }
43 
44 const ::google::protobuf::RepeatedPtrField<::std::string>& getBlobs(
45  const MetaNetDef& def,
46  const std::string& name) {
47  for (const auto& b : def.blobs()) {
48  if (b.key() == name) {
49  return b.value();
50  }
51  }
52  CAFFE_THROW("Blob not found: ", name);
53 }
54 } // namespace
55 
56 Predictor::Predictor(const MetaNetDef& def, Workspace* parent)
57  : Predictor(
58  getNet(
59  def,
60  PredictorConsts::default_instance().global_init_net_type()),
61  getNet(def, PredictorConsts::default_instance().predict_net_type()),
62  parent) {
63  const auto& inputs =
64  getBlobs(def, PredictorConsts::default_instance().inputs_blob_type());
65  for (const auto& input : inputs) {
66  inputNames_.insert(input);
67  }
68 }
69 
70 Predictor::Predictor(
71  const NetDef& init_net,
72  const NetDef& run_net,
73  Workspace* parent)
74  : run_net_(run_net), ws_(parent) {
75  CAFFE_ENFORCE(ws_.RunNetOnce(init_net));
76 
77  // real model inputs can be fed later in run* functions
78  const auto& initialized_vec = ws_.Blobs();
79  const std::unordered_set<std::string> initialized{initialized_vec.begin(),
80  initialized_vec.end()};
81  for (const auto& name : run_net.external_input()) {
82  if (!initialized.count(name)) {
83  auto* blob = ws_.CreateBlob(name);
84  blob->template GetMutable<TensorCPU>();
85  }
86  }
87  CAFFE_ENFORCE(ws_.CreateNet(run_net));
88 }
89 
90 Predictor::~Predictor() {}
91 
92 bool Predictor::run(const TensorVector& inputs, TensorVector* outputs) {
93  CAFFE_ENFORCE(inputs.size() <= run_net_.external_input_size());
94  for (auto i = 0; i < inputs.size(); ++i) {
95  shareInputTensor(&ws_, run_net_.external_input(i), inputs[i]);
96  }
97 
98  if (!ws_.RunNet(run_net_.name())) {
99  return false;
100  }
101 
102  outputs->resize(run_net_.external_output_size());
103  for (auto i = 0; i < outputs->size(); ++i) {
104  (*outputs)[i] = extractOutputTensor(&ws_, run_net_.external_output(i));
105  }
106  return true;
107 }
108 
109 bool Predictor::run_map(const TensorMap& inputs, TensorVector* outputs) {
110  if (!inputNames_.empty()) {
111  CAFFE_ENFORCE_EQ(inputs.size(), inputNames_.size());
112  }
113  for (auto input : inputs) {
114  if (!inputNames_.empty()) {
115  CAFFE_ENFORCE_GT(inputNames_.count(input.first), 0);
116  }
117  shareInputTensor(&ws_, input.first, input.second);
118  }
119 
120  if (!ws_.RunNet(run_net_.name())) {
121  return false;
122  }
123 
124  outputs->resize(run_net_.external_output_size());
125  for (auto i = 0; i < outputs->size(); ++i) {
126  (*outputs)[i] = extractOutputTensor(&ws_, run_net_.external_output(i));
127  }
128  return true;
129 }
130 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...