Caffe2 - C++ API
A deep learning, cross platform ML framework
matmul_op.h
1 #ifndef CAFFE2_OPERATORS_MATMUL_OP_H_
2 #define CAFFE2_OPERATORS_MATMUL_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context, class Engine = DefaultEngine>
11 class MatMulOp final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  MatMulOp(const OperatorDef& operator_def, Workspace* ws)
15  : Operator<Context>(operator_def, ws),
16  axis_a_(OperatorBase::GetSingleArgument<int>("axis_a", 1)),
17  axis_b_(OperatorBase::GetSingleArgument<int>("axis_b", 1)),
18  trans_a_(OperatorBase::GetSingleArgument<int>("trans_a", 0)),
19  trans_b_(OperatorBase::GetSingleArgument<int>("trans_b", 0)) {}
20  ~MatMulOp() {}
21 
22  bool RunOnDevice() override {
23  const auto& A = Input(0);
24  const auto& B = Input(1);
25  auto* Y = Output(0);
26 
27  const auto canonical_axis_a = A.canonical_axis_index(axis_a_);
28  const auto canonical_axis_b = B.canonical_axis_index(axis_b_);
29  int A_dim0 = A.size_to_dim(canonical_axis_a);
30  int A_dim1 = A.size_from_dim(canonical_axis_a);
31  int B_dim0 = B.size_to_dim(canonical_axis_b);
32  int B_dim1 = B.size_from_dim(canonical_axis_b);
33 
34  int a_dim0, a_dim1, b_dim0, b_dim1;
35 
36  if (trans_a_) {
37  a_dim0 = A_dim1;
38  a_dim1 = A_dim0;
39  } else {
40  a_dim0 = A_dim0;
41  a_dim1 = A_dim1;
42  }
43 
44  if (trans_b_) {
45  b_dim0 = B_dim1;
46  b_dim1 = B_dim0;
47  } else {
48  b_dim0 = B_dim0;
49  b_dim1 = B_dim1;
50  }
51 
52  auto dimErrorString = [&]() {
53  return MakeString(
54  "Dimension mismatch: ",
55  trans_a_ ? "trans(A): " : "A: ",
56  a_dim0,
57  " ",
58  a_dim1,
59  trans_b_ ? ", trans(B): " : ", B: ",
60  b_dim0,
61  " ",
62  b_dim1);
63  };
64  // Error checking
65  CAFFE_ENFORCE(a_dim1 == b_dim0, dimErrorString());
66 
67  Y_shape_cache_[0] = a_dim0;
68  Y_shape_cache_[1] = b_dim1;
69  Y->Resize(Y_shape_cache_);
70  CAFFE_ENFORCE(a_dim0 * b_dim1 == Y->size(), dimErrorString());
71  // Y = A * B
72  math::Gemm<T, Context, Engine>(
73  trans_a_ ? CblasTrans : CblasNoTrans,
74  trans_b_ ? CblasTrans : CblasNoTrans,
75  a_dim0,
76  b_dim1,
77  a_dim1,
78  1,
79  A.template data<T>(),
80  B.template data<T>(),
81  0,
82  Y->template mutable_data<T>(),
83  &context_);
84 
85  if (InputSize() == 3) {
86  // In gradient op, resize to input
87  Y->ResizeLike(Input(2));
88  }
89  return true;
90  }
91 
92  protected:
93  // A local vector to cache the output shape so we don't need to recreate
94  // a vector object every time we run Run().
95  vector<TIndex> Y_shape_cache_{0, 0};
96  int axis_a_{1};
97  int axis_b_{1};
98  bool trans_a_;
99  bool trans_b_;
100 };
101 
102 } // namespace caffe2
103 
104 #endif // CAFFE2_OPERATORS_MATMUL_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...