Caffe2 - C++ API
A deep learning, cross platform ML framework
batch_matmul_op.cc
1 #include "caffe2/operators/batch_matmul_op.h"
2 #include "caffe2/core/operator_schema.h"
3 
4 namespace caffe2 {
5 
// Register the CPU implementation of the BatchMatMul operator.
REGISTER_CPU_OPERATOR(BatchMatMul, BatchMatMulOp<CPUContext>);
7 
8 vector<TensorShape> TensorInferenceForBatchMatMul(
9  const OperatorDef& def,
10  const vector<TensorShape>& in) {
11  ArgumentHelper helper(def);
12  bool broadcast = helper.GetSingleArgument<int>("broadcast", 0);
13  if (!broadcast) {
14  const auto ndim = in[0].dims_size();
15  CAFFE_ENFORCE_GE(ndim, 2);
16  int a_dim0;
17  int b_dim1;
18  if (helper.GetSingleArgument<int>("trans_a", 0)) {
19  a_dim0 = in[0].dims(ndim - 1);
20  } else {
21  a_dim0 = in[0].dims(ndim - 2);
22  }
23 
24  if (helper.GetSingleArgument<int>("trans_b", 0)) {
25  b_dim1 = in[1].dims(ndim - 2);
26  } else {
27  b_dim1 = in[1].dims(ndim - 1);
28  }
29 
30  auto output_dims = vector<TIndex>{in[0].dims().begin(), in[0].dims().end()};
31  output_dims[ndim - 2] = a_dim0;
32  output_dims[ndim - 1] = b_dim1;
33 
34  return vector<TensorShape>{
35  CreateTensorShape(vector<TIndex>{output_dims}, in[0].data_type())};
36  } else {
37  auto ndims_A = in[0].dims_size();
38  auto ndims_B = in[1].dims_size();
39  std::vector<TIndex> dims_A(ndims_A), dims_B(ndims_B);
40  for (int i = 0; i < ndims_A; ++i) {
41  dims_A[i] = in[0].dims(i);
42  }
43  for (int i = 0; i < ndims_B; ++i) {
44  dims_B[i] = in[1].dims(i);
45  }
46  bool A_broadcasted = false, B_broadcasted = false;
47  if (ndims_A == 1) {
48  dims_A.insert(dims_A.begin(), 1);
49  ndims_A = 2;
50  A_broadcasted = true;
51  }
52  if (ndims_B == 1) {
53  dims_B.push_back(1);
54  ndims_B = 2;
55  B_broadcasted = true;
56  }
57  size_t M, N;
58  if (helper.GetSingleArgument<int>("trans_a", 0)) {
59  M = dims_A[ndims_A - 1];
60  } else {
61  M = dims_A[ndims_A - 2];
62  }
63  if (helper.GetSingleArgument<int>("trans_b", 0)) {
64  N = dims_B[ndims_B - 2];
65  } else {
66  N = dims_B[ndims_B - 1];
67  }
68 
69  std::vector<TIndex> new_dims;
70  if (ndims_A >= ndims_B) {
71  new_dims.assign(dims_A.begin(), dims_A.end() - 2);
72  } else {
73  new_dims.assign(dims_B.begin(), dims_B.end() - 2);
74  }
75  if (!A_broadcasted) {
76  new_dims.push_back(M);
77  }
78  if (!B_broadcasted) {
79  new_dims.push_back(N);
80  }
81  if (A_broadcasted && B_broadcasted) {
82  new_dims.push_back(1);
83  }
84  return vector<TensorShape>{
85  CreateTensorShape(vector<TIndex>{new_dims}, in[0].data_type())};
86  }
87 }
88 
89 OpSchema::Cost CostInferenceForBatchMatMul(
90  const OperatorDef& def,
91  const vector<TensorShape>& in) {
92  ArgumentHelper helper(def);
93  struct OpSchema::Cost c;
94  const TensorShape Y = TensorInferenceForBatchMatMul(def, in)[0];
95 
96  auto ndims_A = in[0].dims_size();
97  long long nElemY = 1;
98  for (int i = 0; i < Y.dims_size(); i++) {
99  nElemY *= Y.dims(i);
100  }
101  size_t K;
102  if (helper.GetSingleArgument<int>("trans_a", 0)) {
103  K = in[0].dims(ndims_A - 2);
104  } else {
105  K = in[0].dims(ndims_A - 1);
106  }
107  c.flops = 2 * nElemY * K;
108  c.bytes_moved = nElemY * sizeof(float);
109  c.params_bytes = 0;
110  return c;
111 }
112 
113 OPERATOR_SCHEMA(BatchMatMul)
114  .NumInputs(2)
115  .NumOutputs(1)
116  .SetDoc(R"DOC(
117 Batch Matrix multiplication Yi = Ai * Bi, where A has shape (dim0, dim1, ... M, K),
118 B has shape (dim0, dim1, ... K, N), Y has shape (dim0, dim1, ... M, N) and i ranges
119 from 0 to (dim0 * dim1 ...) - 1. rank(A) == rank(B) >= 2. In case of A and B being
120 two diemnsional, it behaves like normal matrix multiplication.
121 )DOC")
122  .Input(0, "A", "tensor of shape (dim0, dim1 ... M, K)")
123  .Input(1, "B", "tensor of shpae (dim0, dim2 ... K, N)")
124  .Output(0, "Y", "tensor of shape (dim0, dim1 ... M, N)")
125  .Arg(
126  "trans_a",
127  "Pass 1 to transpose the last two dimensions of A before "
128  "doing multiplication")
129  .Arg(
130  "trans_b",
131  "Pass 1 to transpose the last two dimensions of B before "
132  "doing multiplication")
133  .Arg(
134  "broadcast",
135  "Pass 1 to allow broadcasting of dimensions. Behavior is the same as numpy.matmul. Gradient is currently not supported when running in broadcast mode.")
136  .TensorInferenceFunction(TensorInferenceForBatchMatMul)
137  .CostInferenceFunction(
138  OpSchema::CostInferenceFunctionType(CostInferenceForBatchMatMul));
139 
140 class GetBatchMatMulGradient : public GradientMakerBase {
141  using GradientMakerBase::GradientMakerBase;
142  vector<OperatorDef> GetGradientDefs() override {
143  CAFFE_ENFORCE_EQ(def_.input_size(), 2);
144 
145  bool broadcast = false;
146  if (ArgumentHelper::HasArgument(Def(), "broadcast")) {
147  broadcast = GetArgument(Def(), "broadcast").i();
148  }
149  CAFFE_ENFORCE(
150  !broadcast,
151  "Gradient is currently not supported with "
152  "broadcast=1 for BatchMatMul.");
153 
154  bool trans_a = 0;
155  bool trans_b = 0;
156 
157  if (ArgumentHelper::HasArgument(Def(), "trans_a")) {
158  trans_a = GetArgument(Def(), "trans_a").i();
159  }
160  if (ArgumentHelper::HasArgument(Def(), "trans_b")) {
161  trans_b = GetArgument(Def(), "trans_b").i();
162  }
163 
164  auto no_trans_arg = vector<Argument>();
165  auto trans_a_arg = vector<Argument>{MakeArgument<int>("trans_a", 1)};
166  auto trans_b_arg = vector<Argument>{MakeArgument<int>("trans_b", 1)};
167  auto trans_both_arg = vector<Argument>{MakeArgument<int>("trans_a", 1),
168  MakeArgument<int>("trans_b", 1)};
169 
170  if (ArgumentHelper::HasArgument(Def(), "use_scratch")) {
171  no_trans_arg.push_back(MakeArgument<int>("use_scratch", 1));
172  trans_a_arg.push_back(MakeArgument<int>("use_scratch", 1));
173  trans_b_arg.push_back(MakeArgument<int>("use_scratch", 1));
174  trans_both_arg.push_back(MakeArgument<int>("use_scratch", 1));
175  }
176 
177  if (trans_a) {
178  if (trans_b) {
179  // A'B':
180  // dA = B'G', dB = G'A'
181  return vector<OperatorDef>{CreateOperatorDef(
182  "BatchMatMul",
183  "",
184  vector<string>{I(1), GO(0)},
185  vector<string>{GI(0)},
186  trans_both_arg),
187  CreateOperatorDef(
188  "BatchMatMul",
189  "",
190  vector<string>{GO(0), I(0)},
191  vector<string>{GI(1)},
192  trans_both_arg)};
193  } else {
194  // A'B:
195  // dA = BG', dB = AG
196  return vector<OperatorDef>{CreateOperatorDef(
197  "BatchMatMul",
198  "",
199  vector<string>{I(1), GO(0)},
200  vector<string>{GI(0)},
201  trans_b_arg),
202  CreateOperatorDef(
203  "BatchMatMul",
204  "",
205  vector<string>{I(0), GO(0)},
206  vector<string>{GI(1)},
207  no_trans_arg)};
208  }
209  } else {
210  if (trans_b) {
211  // AB':
212  // dA = GB, dB = G'A
213  return vector<OperatorDef>{CreateOperatorDef(
214  "BatchMatMul",
215  "",
216  vector<string>{GO(0), I(1)},
217  vector<string>{GI(0)},
218  no_trans_arg),
219  CreateOperatorDef(
220  "BatchMatMul",
221  "",
222  vector<string>{GO(0), I(0)},
223  vector<string>{GI(1)},
224  trans_a_arg)};
225  } else {
226  // AB:
227  // dA = GB', dB = A'G
228  return vector<OperatorDef>{CreateOperatorDef(
229  "BatchMatMul",
230  "",
231  vector<string>{GO(0), I(1)},
232  vector<string>{GI(0)},
233  trans_b_arg),
234  CreateOperatorDef(
235  "BatchMatMul",
236  "",
237  vector<string>{I(0), GO(0)},
238  vector<string>{GI(1)},
239  trans_a_arg)};
240  }
241  }
242  }
243 
244  bool CopyArguments() const override {
245  return false;
246  }
247 };
248 
// Hook the gradient maker above into the gradient registry for BatchMatMul.
REGISTER_GRADIENT(BatchMatMul, GetBatchMatMulGradient);
250 
251 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
std::function< struct Cost(const OperatorDef &, const vector< TensorShape > &)> CostInferenceFunctionType
Registers a function that takes in an OperatorDef and a series of input shapes and returns the total ...