Caffe2 - C++ API
A deep-learning, cross-platform ML framework
matmul_op.cc
1 #include "caffe2/operators/matmul_op.h"
2 
3 namespace caffe2 {
4 
5 REGISTER_CPU_OPERATOR(MatMul, MatMulOp<float, CPUContext>);
6 
7 OPERATOR_SCHEMA(MatMul)
8  .NumInputs(2, 3)
9  .NumOutputs(1)
10  .TensorInferenceFunction([](const OperatorDef& def,
11  const vector<TensorShape>& in) {
12  vector<TensorShape> out(1);
13  out[0].set_data_type(in[0].data_type());
14  ArgumentHelper arg_helper(def);
15  int axis_a = arg_helper.GetSingleArgument<int>("axis_a", 1);
16  int axis_b = arg_helper.GetSingleArgument<int>("axis_b", 1);
17  int trans_a = arg_helper.GetSingleArgument<bool>("trans_a", false);
18  int trans_b = arg_helper.GetSingleArgument<bool>("trans_b", false);
19  int canonical_axis_a = canonical_axis_index_(axis_a, in[0].dims().size());
20  int canonical_axis_b = canonical_axis_index_(axis_b, in[0].dims().size());
21 
22  int M = size_to_dim_(canonical_axis_a, GetDimsVector(in[0]));
23  int N = size_from_dim_(canonical_axis_b, GetDimsVector(in[1]));
24  if (trans_a) {
25  M = size_from_dim_(canonical_axis_a, GetDimsVector(in[0]));
26  }
27  if (trans_b) {
28  N = size_to_dim_(canonical_axis_b, GetDimsVector(in[1]));
29  }
30 
31  out[0].add_dims(M);
32  out[0].add_dims(N);
33 
34  return out;
35  })
36  .SetDoc(R"DOC(
37 Matrix multiplication Y = A * B, where A has size (M x K), B has size (K x N),
38 and Y will have a size (M x N).
39 )DOC")
40  .Input(0, "A", "2D matrix of size (M x K)")
41  .Input(1, "B", "2D matrix of size (K x N)")
42  .Output(0, "Y", "2D matrix of size (M x N)")
43  .Arg(
44  "axis_a",
45  "Exclusive axis that divides the first and second dimension \
46 of matrix A, default to 1")
47  .Arg(
48  "axis_b",
49  "Exclusive axis that divides the first and second dimension \
50 of matrix B, default to 1")
51  .Arg(
52  "trans_a",
53  "Pass 1 to transpose A before multiplication and after the \
54 dimension adjustment using axis_a")
55  .Arg(
56  "trans_b",
57  "Pass 1 to transpose B before multiplication and after the \
58 dimension adjustment using axis_b");
59 
60 class GetMatMulGradient : public GradientMakerBase {
61  using GradientMakerBase::GradientMakerBase;
62  vector<OperatorDef> GetGradientDefs() override {
63  CAFFE_ENFORCE_EQ(def_.input_size(), 2);
64 
65  bool axis_a = 1;
66  bool axis_b = 1;
67  bool trans_a = 0;
68  bool trans_b = 0;
69 
70  if (ArgumentHelper::HasArgument(Def(), "trans_a")) {
71  trans_a = GetArgument(Def(), "trans_a").i();
72  }
73  if (ArgumentHelper::HasArgument(Def(), "trans_b")) {
74  trans_b = GetArgument(Def(), "trans_b").i();
75  }
76  if (ArgumentHelper::HasArgument(Def(), "axis_a")) {
77  axis_a = GetArgument(Def(), "axis_a").i();
78  }
79  if (ArgumentHelper::HasArgument(Def(), "axis_b")) {
80  axis_b = GetArgument(Def(), "axis_b").i();
81  }
82 
83  if (trans_a) {
84  if (trans_b) {
85  // A'B':
86  // dA = B'G', dB = G'A'
87  return vector<OperatorDef>{
88  CreateOperatorDef(
89  "MatMul",
90  "",
91  vector<string>{I(1), GO(0), I(0)},
92  vector<string>{GI(0)},
93  vector<Argument>{MakeArgument<int>("trans_a", 1),
94  MakeArgument<int>("trans_b", 1),
95  MakeArgument<int>("axis_a", axis_b)}),
96  CreateOperatorDef(
97  "MatMul",
98  "",
99  vector<string>{GO(0), I(0), I(1)},
100  vector<string>{GI(1)},
101  vector<Argument>{MakeArgument<int>("trans_a", 1),
102  MakeArgument<int>("trans_b", 1),
103  MakeArgument<int>("axis_b", axis_a)})};
104  } else {
105  // A'B:
106  // dA = BG', dB = AG
107  return vector<OperatorDef>{
108  CreateOperatorDef(
109  "MatMul",
110  "",
111  vector<string>{I(1), GO(0), I(0)},
112  vector<string>{GI(0)},
113  vector<Argument>{MakeArgument<int>("trans_b", 1),
114  MakeArgument<int>("axis_a", axis_b)}),
115  CreateOperatorDef(
116  "MatMul",
117  "",
118  vector<string>{I(0), GO(0), I(1)},
119  vector<string>{GI(1)},
120  vector<Argument>{MakeArgument<int>("axis_a", axis_a)})};
121  }
122  } else {
123  if (trans_b) {
124  // AB':
125  // dA = GB, dB = G'A
126  return vector<OperatorDef>{
127  CreateOperatorDef(
128  "MatMul",
129  "",
130  vector<string>{GO(0), I(1), I(0)},
131  vector<string>{GI(0)},
132  vector<Argument>{MakeArgument<int>("axis_b", axis_b)}),
133  CreateOperatorDef(
134  "MatMul",
135  "",
136  vector<string>{GO(0), I(0), I(1)},
137  vector<string>{GI(1)},
138  vector<Argument>{MakeArgument<int>("trans_a", 1),
139  MakeArgument<int>("axis_b", axis_a)})};
140  } else {
141  // AB:
142  // dA = GB', dB = A'G
143  return vector<OperatorDef>{
144  CreateOperatorDef(
145  "MatMul",
146  "",
147  vector<string>{GO(0), I(1), I(0)},
148  vector<string>{GI(0)},
149  vector<Argument>{MakeArgument<int>("trans_b", 1),
150  MakeArgument<int>("axis_b", axis_b)}),
151  CreateOperatorDef(
152  "MatMul",
153  "",
154  vector<string>{I(0), GO(0), I(1)},
155  vector<string>{GI(1)},
156  vector<Argument>{MakeArgument<int>("trans_a", 1),
157  MakeArgument<int>("axis_a", axis_a)})};
158  }
159  }
160  }
161 
162  bool CopyArguments() const override {
163  return false;
164  }
165 };
166 
167 REGISTER_GRADIENT(MatMul, GetMatMulGradient);
168 
169 } // namespace caffe2
TIndex size_from_dim_(int k, const vector< TIndex > &dims)
Returns the product of all dimensions starting from dimension k.
Definition: tensor.h:40
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...