Caffe2 - C++ API
A deep learning, cross-platform ML framework
GLSigmoid.cc
1 
2 #include "../core/GLFilter.h"
3 #include "../core/GLImage.h"
4 #include "../core/ImageAllocator.h"
5 
6 #include "caffe2/core/operator.h"
7 #include "caffe2/core/timer.h"
8 #include <iostream>
9 #include <vector>
10 
// Activation computed by GLSigmoid. The enumerators are intentionally unscoped:
// the rest of this file refers to them as plain `Sigmoid` / `Tanh`, and OpType
// is used as a non-type template parameter of OpenGLSigmoidOp.
enum OpType { Sigmoid, Tanh };
12 
13 class GLSigmoid : public GLFilter {
14  public:
15  binding* inputData;
16  binding* outputSize;
17 
18  GLSigmoid(OpType opType)
19  : GLFilter("GLSigmoid",
20  vertex_shader,
21  fragment_shader,
22  {BINDING(outputSize), BINDING(inputData)},
23  {/* no uniform blocks */},
24  {/* no attributes */},
25  {{"SIGMOID", caffe2::to_string(opType == Sigmoid)},
26  {"TANH", caffe2::to_string(opType == Tanh)}}) {}
27 
28  template <typename T>
29  void sigmoid(const GLImageVector<T>& input_images, const GLImageVector<T>& output_images);
30 
31  static const char* fragment_shader;
32 };
33 
34 // MARK: GLSL
35 
36 const char* GLSigmoid::fragment_shader = R"GLSL(#version 300 es
37 #define SIGMOID $(SIGMOID)
38 #define TANH $(TANH)
39 
40 precision mediump float;
41 precision mediump int;
42 
43 in highp vec2 v_texCoord;
44 
45 uniform ivec2 outputSize;
46 
47 TEXTURE_INPUT(inputData);
48 TEXTURE_OUTPUT(0, outputData);
49 
50 void main() {
51  ivec2 texelCoord = ivec2(v_texCoord * vec2(outputSize));
52  vec4 value = TEXTURE_LOAD(inputData, ivec2(texelCoord));
53 #if SIGMOID
54  value = vec4(1.0) / (vec4(1.0) + exp(-value));
55  outputData = TEXTURE_STORE(value);
56 #elif TANH
57  value = tanh(value);
58  outputData = TEXTURE_STORE(value);
59 #endif
60 }
61 
62 )GLSL";
63 
64 template <typename T>
65 void GLSigmoid::sigmoid(const GLImageVector<T>& input_images,
66  const GLImageVector<T>& output_images) {
67  for (int i = 0; i < input_images.size(); i++) {
68  auto input_image = input_images[i];
69  auto output_image = output_images[i];
70  int input_slices = input_image->slices;
71  int output_slices = output_image->slices;
72 
73  for (int is = 0; is < input_slices; is++) {
74  run(std::vector<texture_attachment>({{input_image->textures[is], inputData}}),
75  {output_image->textures.begin() + is, output_image->textures.begin() + is + 1},
76  [&]() { glUniform2i(outputSize->location, output_image->width, output_image->height); },
77  output_image->width,
78  output_image->height);
79  }
80  }
81 }
82 
83 namespace caffe2 {
84 template <typename T, OpType opType>
85 class OpenGLSigmoidOp final : public Operator<CPUContext>, ImageAllocator<T> {
86  public:
87  OpenGLSigmoidOp(const OperatorDef& operator_def, Workspace* ws)
88  : Operator<CPUContext>(operator_def, ws) {}
89 
90  bool RunOnDevice() override {
91  const GLImageVector<T>& input = Inputs()[0]->template Get<GLImageVector<T>>();
92  const int num_images = input.size();
93  const int input_channels = input.channels();
94  const int input_width = input.width();
95  const int input_height = input.height();
96 
97  const int output_channels = input_channels;
98  const int output_width = input_width;
99  const int output_height = input_height;
100 
101  int is_last = OperatorBase::GetSingleArgument<int>("is_last", 0);
102 
104  num_images, output_width, output_height, output_channels, is_last);
105 
106  if (!_sigmoid) {
107  _sigmoid.reset(new GLSigmoid(opType));
108  }
109 
110  _sigmoid->sigmoid(input, *output);
111 
112  Outputs()[0]->Reset(output);
113 
114  return true;
115  }
116 
117  private:
118  std::unique_ptr<GLSigmoid> _sigmoid;
119 };
120 
121 REGISTER_CPU_OPERATOR(OpenGLSigmoid, OpenGLSigmoidOp<float16_t, Sigmoid>);
122 OPERATOR_SCHEMA(OpenGLSigmoid)
123  .NumInputs(1)
124  .NumOutputs(1)
125  .AllowInplace({{0, 0}})
126  .IdenticalTypeAndShape();
127 
128 REGISTER_CPU_OPERATOR(OpenGLTanh, OpenGLSigmoidOp<float16_t, Tanh>);
129 OPERATOR_SCHEMA(OpenGLTanh)
130  .NumInputs(1)
131  .NumOutputs(1)
132  .AllowInplace({{0, 0}})
133  .IdenticalTypeAndShape();
134 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...