Caffe2 - C++ API
A deep learning, cross-platform ML framework
GLPredictor.cc
1 
2 #include "GLPredictor.h"
3 #include "GLContext.h"
4 #include "rewrite_net.h"
5 #include <vector>
6 
7 namespace caffe2 {
8 
9 template <class T>
10 void shareInputGLImage(Workspace* ws, const std::string& name, GLImageVector<T>* input) {
11  auto* blob = ws->GetBlob(name);
12  CAFFE_ENFORCE(blob, "Blob: ", name, " does not exist");
13  blob->ShareExternal<GLImageVector<T>>(input);
14 }
15 
16 template <class T>
17 const GLImageVector<T>* extractOutputGLImage(Workspace* ws, const std::string& name) {
18  auto* blob = ws->GetBlob(name);
19  CAFFE_ENFORCE(blob, "Blob: ", name, " does not exist");
20  return &blob->template Get<GLImageVector<T>>();
21 }
22 
23 const NetDef create_gl_run_net(const NetDef& init_net,
24  const NetDef& run_net,
25  bool use_texture_input) {
26  NetDef gl_run_net;
27  if (!tryConvertToOpenGL(init_net, run_net, &gl_run_net, use_texture_input)) {
28  CAFFE_THROW("Failed to convert model to OpenGL");
29  }
30  return gl_run_net;
31 }
32 
33 GLPredictor::GLPredictor(const NetDef& init_net,
34  const NetDef& run_net,
35  bool use_texture_input,
36  Workspace* parent)
37  : Predictor(init_net, create_gl_run_net(init_net, run_net, use_texture_input), parent) {}
38 
39 GLPredictor::~GLPredictor() {}
40 
41 template <class T>
42 bool GLPredictor::run(std::vector<GLImageVector<T>*>& inputs,
43  std::vector<const GLImageVector<T>*>* outputs) {
44  const NetDef& run_net_ = Predictor::def();
45  CAFFE_ENFORCE(inputs.size() <= run_net_.external_input_size());
46  for (auto i = 0; i < inputs.size(); ++i) {
47  shareInputGLImage<T>(Predictor::ws(), run_net_.external_input(i), inputs[i]);
48  }
49 
50  if (!Predictor::ws()->RunNet(run_net_.name())) {
51  return false;
52  }
53 
54  for (auto i = 0; i < run_net_.external_output_size(); ++i) {
55  outputs->push_back(extractOutputGLImage<T>(Predictor::ws(), run_net_.external_output(i)));
56  }
57 
58  return true;
59 }
60 
61 template bool GLPredictor::run(std::vector<GLImageVector<uint8_t>*>& inputs,
62  std::vector<const GLImageVector<uint8_t>*>* outputs);
63 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...