2 #include "GLPredictor.h" 4 #include "rewrite_net.h" 10 void shareInputGLImage(Workspace* ws,
const std::string& name,
GLImageVector<T>* input) {
11 auto* blob = ws->GetBlob(name);
12 CAFFE_ENFORCE(blob,
"Blob: ", name,
" does not exist");
17 const GLImageVector<T>* extractOutputGLImage(Workspace* ws,
const std::string& name) {
18 auto* blob = ws->GetBlob(name);
19 CAFFE_ENFORCE(blob,
"Blob: ", name,
" does not exist");
20 return &blob->template Get<GLImageVector<T>>();
23 const NetDef create_gl_run_net(
const NetDef& init_net,
24 const NetDef& run_net,
25 bool use_texture_input) {
27 if (!tryConvertToOpenGL(init_net, run_net, &gl_run_net, use_texture_input)) {
28 CAFFE_THROW(
"Failed to convert model to OpenGL");
33 GLPredictor::GLPredictor(
const NetDef& init_net,
34 const NetDef& run_net,
35 bool use_texture_input,
37 : Predictor(init_net, create_gl_run_net(init_net, run_net, use_texture_input), parent) {}
39 GLPredictor::~GLPredictor() {}
44 const NetDef& run_net_ = Predictor::def();
45 CAFFE_ENFORCE(inputs.size() <= run_net_.external_input_size());
46 for (
auto i = 0; i < inputs.size(); ++i) {
47 shareInputGLImage<T>(Predictor::ws(), run_net_.external_input(i), inputs[i]);
50 if (!Predictor::ws()->RunNet(run_net_.name())) {
54 for (
auto i = 0; i < run_net_.external_output_size(); ++i) {
55 outputs->push_back(extractOutputGLImage<T>(Predictor::ws(), run_net_.external_output(i)));
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...