Caffe2 - C++ API
A deep learning, cross-platform ML framework
rowmul_op.h
1 #ifndef CAFFE2_OPERATORS_ROW_MUL_H_
2 #define CAFFE2_OPERATORS_ROW_MUL_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 // A hacky version of Mul with broadcast
12 // RowMul([mat, w], [output])
13 template <typename T, class Context>
14 class RowMulOp : public Operator<Context> {
15  public:
16  USE_OPERATOR_CONTEXT_FUNCTIONS;
17  USE_SIMPLE_CTOR_DTOR(RowMulOp);
18 
19  bool RunOnDevice() override {
20  auto& mat = Input(0);
21  auto& w = Input(1);
22  auto* output = Output(0);
23 
24  output->ResizeLike(mat);
25  T* output_data = output->template mutable_data<T>();
26  const T* mat_data = mat.template data<T>();
27  const T* w_data = w.template data<T>();
28 
29  // Dimension checking
30  CAFFE_ENFORCE_EQ(
31  w.size(),
32  mat.dim32(0),
33  "Length of w should be equal to the first dim of mat");
34 
35  auto block_size = mat.size_from_dim(1);
36  for (int i = 0; i < w.size(); i++) {
37  size_t offset = i * block_size;
38  for (int j = 0; j < block_size; j++) {
39  output_data[offset + j] = mat_data[offset + j] * w_data[i];
40  }
41  }
42 
43  return true;
44  }
45 };
46 
47 // A hacky version
48 template <typename T, class Context>
49 class ReduceTailSumOp : public Operator<Context> {
50  public:
51  USE_OPERATOR_CONTEXT_FUNCTIONS;
52  USE_SIMPLE_CTOR_DTOR(ReduceTailSumOp);
53 
54  bool RunOnDevice() override {
55  auto& mat = Input(0);
56  auto* output = Output(0);
57 
58  int N = mat.dim32(0);
59  int block_size = mat.size_from_dim(1);
60 
61  output->Resize(N);
62  T* output_data = output->template mutable_data<T>();
63  const T* mat_data = mat.template data<T>();
64 
65  for (int i = 0; i < N; i++) {
66  output_data[i] = 0;
67  size_t offset = i * block_size;
68  for (int j = 0; j < block_size; j++) {
69  output_data[i] += mat_data[offset + j];
70  }
71  }
72  return true;
73  }
74 };
75 
76 } // namespace caffe2
77 
78 #endif // CAFFE2_OPERATORS_ROW_MUL_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...