1 #ifndef CAFFE2_OPERATORS_MATMUL_OP_H_ 2 #define CAFFE2_OPERATORS_MATMUL_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 12 template <
class Context,
class Engine = DefaultEngine>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
18 trans_a_(OperatorBase::GetSingleArgument<int>(
"trans_a", 0)),
19 trans_b_(OperatorBase::GetSingleArgument<int>(
"trans_b", 0)),
20 broadcast_(OperatorBase::GetSingleArgument<int>(
"broadcast", 0)),
21 use_scratch_(OperatorBase::GetSingleArgument<int>(
"use_scratch", 0)) {
23 scratch_ = std::make_shared<Tensor<Context>>();
29 bool RunOnDevice()
override {
34 bool DoRunWithType() {
35 const auto& A = Input(0);
36 const auto& B = Input(1);
39 auto ndims_A = A.ndim();
40 auto dims_A = A.dims();
41 auto ndims_B = B.ndim();
42 auto dims_B = B.dims();
44 auto noBroadcastErrorMsg = [](
size_t dim1,
size_t dim2) {
46 ss <<
"Inputs with dimensions A = ";
50 ss <<
" is not supported with broadcast=0. Did you forget to set the " 56 bool dimMismatch = ndims_A != ndims_B;
57 bool dimsLessThan1D = ndims_A < 2;
59 broadcast_ || (!dimMismatch && !dimsLessThan1D),
60 noBroadcastErrorMsg(ndims_A, ndims_B));
62 auto* data_A = A.template data<T>();
63 auto* data_B = B.template data<T>();
65 auto dimMismatchErrorString = [](
size_t dimnum1,
72 ss <<
"Expected dimension ";
74 ss <<
" of tensor A with value ";
76 ss <<
" to match dimension ";
78 ss <<
" of tensor B with value ";
87 if (ndims_A == 1 && ndims_B == 1) {
92 "Vector-vector product requires each of the vectors to " 95 math::Dot<T, Context>(
96 dims_A[0], data_A, data_B, Y->template mutable_data<T>(), &context_);
98 bool A_broadcasted =
false, B_broadcasted =
false;
100 dims_A.insert(dims_A.begin(), 1);
102 A_broadcasted =
true;
107 B_broadcasted =
true;
120 size_t num_inner_dims = std::min(ndims_A, ndims_B);
121 for (
size_t i = 2; i < num_inner_dims; ++i) {
122 auto first_r_itr = dims_A.rbegin();
123 auto second_r_itr = dims_B.rbegin();
127 dimMismatchErrorString(
135 size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;
139 size_t M, N, K, K_dim;
141 M = dims_A[ndims_A - 1];
142 K = dims_A[ndims_A - 2];
145 M = dims_A[ndims_A - 2];
146 K = dims_A[ndims_A - 1];
150 N = dims_B[ndims_B - 2];
154 dimMismatchErrorString(
162 N = dims_B[ndims_B - 1];
166 dimMismatchErrorString(
178 std::vector<TIndex> new_dims;
179 if (ndims_A >= ndims_B) {
180 new_dims.assign(dims_A.begin(), dims_A.end() - 2);
182 new_dims.assign(dims_B.begin(), dims_B.end() - 2);
184 if (!A_broadcasted) {
185 new_dims.push_back(M);
187 new_dims.push_back(1);
189 if (!B_broadcasted) {
190 new_dims.push_back(N);
192 new_dims.push_back(1);
209 size_t num_sub_batches = 1;
210 if (ndims_A >= ndims_B) {
211 auto first_r_itr = dims_A.rbegin();
212 auto output_r_itr = new_dims.rbegin();
213 for (
size_t i = 0; i < num_inner_dims; ++i) {
214 A_stride *= *(first_r_itr + i);
215 Y_stride *= *(output_r_itr + i);
217 num_sub_batches *= *(first_r_itr + i);
223 auto second_r_itr = dims_B.rbegin();
224 auto output_r_itr = new_dims.rbegin();
225 for (
size_t i = 0; i < num_inner_dims; ++i) {
226 B_stride *= *(second_r_itr + i);
227 Y_stride *= *(output_r_itr + i);
229 num_sub_batches *= *(second_r_itr + i);
234 size_t num_outer_batches = 1;
235 for (
size_t i = 0; i < num_outer_dims; ++i) {
236 num_outer_batches *= new_dims[i];
242 new_dims.erase(new_dims.end() - 2);
243 }
else if (B_broadcasted) {
244 new_dims.erase(new_dims.end() - 1);
249 auto* Y_data = Y->template mutable_data<T>();
252 if (num_sub_batches == 0 || num_outer_batches == 0) {
257 for (
size_t p = 0; p < num_outer_batches; ++p) {
258 math::GemmBatched<T, Context, Engine>(
259 trans_a_ ? CblasTrans : CblasNoTrans,
260 trans_b_ ? CblasTrans : CblasNoTrans,
266 data_A + p * A_stride,
267 data_B + p * B_stride,
269 Y_data + p * Y_stride,
271 use_scratch_ ? scratch_.get() :
nullptr);
283 std::shared_ptr<Tensor<Context>> scratch_;
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of all these objects and deals with the scaffolding logistics.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.