Caffe2 - C++ API
A deep-learning, cross-platform ML framework
batch_matmul_op.h
1 #ifndef CAFFE2_OPERATORS_MATMUL_OP_H_
2 #define CAFFE2_OPERATORS_MATMUL_OP_H_
3 
4 #include <sstream>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
// Batched matrix multiplication: computes Y = A * B over leading batch
// dimensions, with optional transposition of each input's two innermost
// dimensions ("trans_a"/"trans_b") and optional broadcasting of batch
// dimensions ("broadcast"). Only float elements are dispatched (see
// RunOnDevice).
template <class Context, class Engine = DefaultEngine>
class BatchMatMulOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  // Reads the four integer arguments off the OperatorDef; all default to 0
  // (i.e. no transpose, no broadcast, no scratch buffer).
  BatchMatMulOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        trans_a_(OperatorBase::GetSingleArgument<int>("trans_a", 0)),
        trans_b_(OperatorBase::GetSingleArgument<int>("trans_b", 0)),
        broadcast_(OperatorBase::GetSingleArgument<int>("broadcast", 0)),
        use_scratch_(OperatorBase::GetSingleArgument<int>("use_scratch", 0)) {
    if (use_scratch_) {
      // Scratch tensor handed to math::GemmBatched below; presumably used as
      // backend workspace — semantics are defined by GemmBatched, not here.
      scratch_ = std::make_shared<Tensor<Context>>();
    }
  }

  ~BatchMatMulOp() {}

  // Dispatch on the element type of input A. Only float is listed, so any
  // other input type fails dispatch.
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
  }

  // The actual computation for element type T.
  // Inputs:  A (Input(0)), B (Input(1)) — tensors of rank >= 1.
  // Output:  Y (Output(0)) — resized here to the broadcasted result shape.
  // Returns true on success; shape mismatches abort via CAFFE_ENFORCE.
  template <typename T>
  bool DoRunWithType() {
    const auto& A = Input(0);
    const auto& B = Input(1);
    auto* Y = Output(0);

    // Local copies of dims: the 1-D promotion below mutates them without
    // touching the input tensors.
    auto ndims_A = A.ndim();
    auto dims_A = A.dims();
    auto ndims_B = B.ndim();
    auto dims_B = B.dims();

    // Error text for rank-mismatched inputs when broadcast=0. Note the
    // arguments are the two ranks (number of dimensions), not dim sizes.
    auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
      std::stringstream ss;
      ss << "Inputs with dimensions A = ";
      ss << dim1;
      ss << " and B = ";
      ss << dim2;
      ss << " is not supported with broadcast=0. Did you forget to set the "
            "broadcast flag?";
      return ss.str();
    };

    // These should all be false if we're not broadcasting.
    bool dimMismatch = ndims_A != ndims_B;
    bool dimsLessThan1D = ndims_A < 2;
    CAFFE_ENFORCE(
        broadcast_ || (!dimMismatch && !dimsLessThan1D),
        noBroadcastErrorMsg(ndims_A, ndims_B));

    auto* data_A = A.template data<T>();
    auto* data_B = B.template data<T>();

    // Error text for a K-dimension mismatch between A and B; includes the
    // transpose flags since they determine which dims are compared.
    auto dimMismatchErrorString = [](size_t dimnum1,
                                     size_t dim1,
                                     size_t dimnum2,
                                     size_t dim2,
                                     bool trans_a,
                                     bool trans_b) {
      std::stringstream ss;
      ss << "Expected dimension ";
      ss << dimnum1;
      ss << " of tensor A with value ";
      ss << dim1;
      ss << " to match dimension ";
      ss << dimnum2;
      ss << " of tensor B with value ";
      ss << dim2;
      ss << ". trans_a = ";
      ss << trans_a;
      ss << " trans_b = ";
      ss << trans_b;
      return ss.str();
    };

    if (ndims_A == 1 && ndims_B == 1) {
      // vector-vector: dot product, producing a single-element output.
      CAFFE_ENFORCE_EQ(
          dims_A[0],
          dims_B[0],
          "Vector-vector product requires each of the vectors to "
          "be the same size.");
      Y->Resize(1);
      math::Dot<T, Context>(
          dims_A[0], data_A, data_B, Y->template mutable_data<T>(), &context_);
    } else {
      // Promote 1-D operands to 2-D so the GEMM path below applies uniformly:
      // a 1-D A becomes a [1, K] row vector, a 1-D B becomes a [K, 1] column
      // vector. The flags remember to drop the inserted dim from the output.
      bool A_broadcasted = false, B_broadcasted = false;
      if (ndims_A == 1) {
        dims_A.insert(dims_A.begin(), 1);
        ndims_A = 2;
        A_broadcasted = true;
      }
      if (ndims_B == 1) {
        dims_B.push_back(1);
        ndims_B = 2;
        B_broadcasted = true;
      }
      // matrix-matrix with batches
      // [B1..., M, K] * [B2..., K, N] -> [B..., M, N]
      // In the event that A or B are one-dimensional, the trailing or leading
      // 1 is not added to the output tensor's size.

      // First step: partition the tensors into inner and outer blocks.
      // Ignoring the last two dimensions of A and B, ensure that one of the
      // tensors' dimensions is a suffix of the other. For example,
      // [4, x, x] is a suffix of [2, 3, 4, x, x]. In this example, the
      // dimensions of size 2 and 3 will be broadcasted, so we partition into
      // 2*3=6 individual instances of batched GEMM with A and B \in [4, x, x].
      size_t num_inner_dims = std::min(ndims_A, ndims_B);
      // Walk the shared suffix (beyond the trailing M/K/N dims, hence i >= 2)
      // from the back and require exact equality — only leading dims may
      // differ between A and B.
      for (size_t i = 2; i < num_inner_dims; ++i) {
        auto first_r_itr = dims_A.rbegin();
        auto second_r_itr = dims_B.rbegin();
        CAFFE_ENFORCE_EQ(
            *(first_r_itr + i),
            *(second_r_itr + i),
            dimMismatchErrorString(
                ndims_A - i - 1,
                *(first_r_itr + i),
                ndims_B - i - 1,
                *(second_r_itr + i),
                trans_a_,
                trans_b_));
      }
      // Leading dims of the higher-rank tensor that have no counterpart in
      // the lower-rank one; these are the broadcasted ("outer") batch dims.
      size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;

      // Standard M, N, and K parameters respecting GEMM API and transpose
      // flags. K_dim records which dimension index of A supplied K, purely
      // for error reporting below.
      size_t M, N, K, K_dim;
      if (trans_a_) {
        M = dims_A[ndims_A - 1];
        K = dims_A[ndims_A - 2];
        K_dim = ndims_A - 2;
      } else {
        M = dims_A[ndims_A - 2];
        K = dims_A[ndims_A - 1];
        K_dim = ndims_A - 1;
      }
      // B must agree with A on K (taking trans_b into account).
      if (trans_b_) {
        N = dims_B[ndims_B - 2];
        CAFFE_ENFORCE_EQ(
            K,
            dims_B[ndims_B - 1],
            dimMismatchErrorString(
                K_dim,
                K,
                ndims_B - 1,
                dims_B[ndims_B - 1],
                trans_a_,
                trans_b_));
      } else {
        N = dims_B[ndims_B - 1];
        CAFFE_ENFORCE_EQ(
            K,
            dims_B[ndims_B - 2],
            dimMismatchErrorString(
                K_dim,
                K,
                ndims_B - 2,
                dims_B[ndims_B - 2],
                trans_a_,
                trans_b_));
      }

      // Calculate output tensor shapes [B..., (M), (N)]
      // Batch dimensions will be broadcasted out to those of the longer tensor
      // A or B. Either M or N are optional if A or B, respectively are 1-D.
      std::vector<TIndex> new_dims;
      if (ndims_A >= ndims_B) {
        new_dims.assign(dims_A.begin(), dims_A.end() - 2);
      } else {
        new_dims.assign(dims_B.begin(), dims_B.end() - 2);
      }
      // Placeholder 1s are appended for a broadcasted (originally 1-D) side
      // so the stride computation below sees a full [B..., M, N] shape; the
      // placeholder is erased again before Resize.
      if (!A_broadcasted) {
        new_dims.push_back(M);
      } else {
        new_dims.push_back(1);
      }
      if (!B_broadcasted) {
        new_dims.push_back(N);
      } else {
        new_dims.push_back(1);
      }

      // Calculate strides. Continuing our example above,
      //   [4, M, K] * [2, 3, 4, K, N] = [2, 3, 4, M, N]
      // We calculate this as follows:
      //   1) Treat the outer batch dimensions as flattened, i.e. view the B
      //      tensor here as [6, 4, K, N] and Y as [6, 4, M, N]. The same rea-
      //      soning is analogous for the case where # dims A >= # dims B.
      //   2) Perform this operation:
      //        for i in range(6):
      //          Y[i, :, :, :] = BatchMatMul(A, B[i, :, :, :])
      size_t A_stride = 1; // How far to increment A pointer each itr
      size_t B_stride = 1; // How far to increment B pointer each itr
      size_t Y_stride = 1; // How far to increment Y pointer each itr
      // How many "inner batches" we have. That is, the product of sizes for
      // the slices excluding M, K, and N, for their respective matrices.
      size_t num_sub_batches = 1;
      if (ndims_A >= ndims_B) {
        // B is the shorter tensor: it is reused (stride 0) for every outer
        // iteration while A and Y advance by one full inner block.
        auto first_r_itr = dims_A.rbegin();
        auto output_r_itr = new_dims.rbegin();
        for (size_t i = 0; i < num_inner_dims; ++i) {
          A_stride *= *(first_r_itr + i);
          Y_stride *= *(output_r_itr + i);
          if (i >= 2) {
            num_sub_batches *= *(first_r_itr + i);
          }
        }
        B_stride = 0;
      } else {
        // Mirror image: A is the shorter tensor and is reused each iteration.
        A_stride = 0;
        auto second_r_itr = dims_B.rbegin();
        auto output_r_itr = new_dims.rbegin();
        for (size_t i = 0; i < num_inner_dims; ++i) {
          B_stride *= *(second_r_itr + i);
          Y_stride *= *(output_r_itr + i);
          if (i >= 2) {
            num_sub_batches *= *(second_r_itr + i);
          }
        }
      }

      // Product of the broadcasted leading dims = number of outer GEMM-batch
      // invocations.
      size_t num_outer_batches = 1;
      for (size_t i = 0; i < num_outer_dims; ++i) {
        num_outer_batches *= new_dims[i];
      }

      // Mutually exclusive since otherwise we would've taken the vector-vector
      // path above
      if (A_broadcasted) {
        new_dims.erase(new_dims.end() - 2);
      } else if (B_broadcasted) {
        new_dims.erase(new_dims.end() - 1);
      }

      // Allocate output tensor
      Y->Resize(new_dims);
      auto* Y_data = Y->template mutable_data<T>();

      // Zero batch dimension indicates no elements
      if (num_sub_batches == 0 || num_outer_batches == 0) {
        return true;
      }

      // TODO(T23893772): doing this in a loop is likely going to be slow on GPU
      for (size_t p = 0; p < num_outer_batches; ++p) {
        // alpha=1, beta=0: plain product, no accumulation into Y.
        math::GemmBatched<T, Context, Engine>(
            trans_a_ ? CblasTrans : CblasNoTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            num_sub_batches,
            M,
            N,
            K,
            1.0f,
            data_A + p * A_stride,
            data_B + p * B_stride,
            0.0f,
            Y_data + p * Y_stride,
            &context_,
            use_scratch_ ? scratch_.get() : nullptr);
      }
    }
    return true;
  }

 protected:
  bool trans_a_;    // transpose the last two dims of A before multiplying
  bool trans_b_;    // transpose the last two dims of B before multiplying
  bool broadcast_;  // allow rank/batch-dim mismatch between A and B
  bool use_scratch_;  // if set, pass scratch_ to GemmBatched as workspace
  std::shared_ptr<Tensor<Context>> scratch_;
};
285 
286 } // namespace caffe2
287 
288 #endif /* CAFFE2_OPERATORS_MATMUL_OP_H_ */
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...