1 #include "caffe2/core/predictor.h" 3 #include <unordered_set> 9 void enforceIsTensor(Workspace* ws,
const std::string& name) {
10 auto blob = ws->GetBlob(name);
11 CAFFE_ENFORCE(blob,
"Blob does not exist: ", name);
13 blob->template IsType<TensorCPU>(),
"Blob is not a CPU Tensor: ", name);
16 void shareInputTensor(
18 const std::string& name,
20 enforceIsTensor(ws, name);
21 auto* blob = ws->GetBlob(name);
22 CAFFE_ENFORCE(blob,
"Blob: ", name,
" does not exist");
23 auto* tensor = blob->template GetMutable<TensorCPU>();
24 tensor->ResizeLike(*input);
25 tensor->ShareData(*input);
28 TensorCPU* extractOutputTensor(Workspace* ws,
const std::string& name) {
29 enforceIsTensor(ws, name);
30 auto* blob = ws->GetBlob(name);
31 CAFFE_ENFORCE(blob,
"Blob: ", name,
" does not exist");
32 return blob->template GetMutable<TensorCPU>();
35 const NetDef& getNet(
const MetaNetDef& def,
const std::string& name) {
36 for (
const auto& n : def.nets()) {
37 if (n.key() == name) {
41 CAFFE_THROW(
"Net not found: ", name);
44 const ::google::protobuf::RepeatedPtrField<::std::string>& getBlobs(
45 const MetaNetDef& def,
46 const std::string& name) {
47 for (
const auto& b : def.blobs()) {
48 if (b.key() == name) {
52 CAFFE_THROW(
"Blob not found: ", name);
56 Predictor::Predictor(
const MetaNetDef& def, Workspace* parent)
60 PredictorConsts::default_instance().global_init_net_type()),
61 getNet(def, PredictorConsts::default_instance().predict_net_type()),
64 getBlobs(def, PredictorConsts::default_instance().inputs_blob_type());
65 for (
const auto& input : inputs) {
66 inputNames_.insert(input);
71 const NetDef& init_net,
72 const NetDef& run_net,
74 : run_net_(run_net), ws_(parent) {
75 CAFFE_ENFORCE(ws_.RunNetOnce(init_net));
78 const auto& initialized_vec = ws_.Blobs();
79 const std::unordered_set<std::string> initialized{initialized_vec.begin(),
80 initialized_vec.end()};
81 for (
const auto& name : run_net.external_input()) {
82 if (!initialized.count(name)) {
83 auto* blob = ws_.CreateBlob(name);
84 blob->template GetMutable<TensorCPU>();
87 CAFFE_ENFORCE(ws_.CreateNet(run_net));
90 Predictor::~Predictor() {}
92 bool Predictor::run(
const TensorVector& inputs, TensorVector* outputs) {
93 CAFFE_ENFORCE(inputs.size() <= run_net_.external_input_size());
94 for (
auto i = 0; i < inputs.size(); ++i) {
95 shareInputTensor(&ws_, run_net_.external_input(i), inputs[i]);
98 if (!ws_.RunNet(run_net_.name())) {
102 outputs->resize(run_net_.external_output_size());
103 for (
auto i = 0; i < outputs->size(); ++i) {
104 (*outputs)[i] = extractOutputTensor(&ws_, run_net_.external_output(i));
109 bool Predictor::run_map(
const TensorMap& inputs, TensorVector* outputs) {
110 if (!inputNames_.empty()) {
111 CAFFE_ENFORCE_EQ(inputs.size(), inputNames_.size());
113 for (
auto input : inputs) {
114 if (!inputNames_.empty()) {
115 CAFFE_ENFORCE_GT(inputNames_.count(input.first), 0);
117 shareInputTensor(&ws_, input.first, input.second);
120 if (!ws_.RunNet(run_net_.name())) {
124 outputs->resize(run_net_.external_output_size());
125 for (
auto i = 0; i < outputs->size(); ++i) {
126 (*outputs)[i] = extractOutputTensor(&ws_, run_net_.external_output(i));
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...