1 #include "caffe2/operators/batch_matmul_op.h" 2 #include "caffe2/core/operator_schema.h" 6 REGISTER_CPU_OPERATOR(BatchMatMul, BatchMatMulOp<CPUContext>);
8 vector<TensorShape> TensorInferenceForBatchMatMul(
9 const OperatorDef& def,
10 const vector<TensorShape>& in) {
11 ArgumentHelper helper(def);
12 bool broadcast = helper.GetSingleArgument<
int>(
"broadcast", 0);
14 const auto ndim = in[0].dims_size();
15 CAFFE_ENFORCE_GE(ndim, 2);
18 if (helper.GetSingleArgument<
int>(
"trans_a", 0)) {
19 a_dim0 = in[0].dims(ndim - 1);
21 a_dim0 = in[0].dims(ndim - 2);
24 if (helper.GetSingleArgument<
int>(
"trans_b", 0)) {
25 b_dim1 = in[1].dims(ndim - 2);
27 b_dim1 = in[1].dims(ndim - 1);
30 auto output_dims = vector<TIndex>{in[0].dims().begin(), in[0].dims().end()};
31 output_dims[ndim - 2] = a_dim0;
32 output_dims[ndim - 1] = b_dim1;
34 return vector<TensorShape>{
35 CreateTensorShape(vector<TIndex>{output_dims}, in[0].data_type())};
37 auto ndims_A = in[0].dims_size();
38 auto ndims_B = in[1].dims_size();
39 std::vector<TIndex> dims_A(ndims_A), dims_B(ndims_B);
40 for (
int i = 0; i < ndims_A; ++i) {
41 dims_A[i] = in[0].dims(i);
43 for (
int i = 0; i < ndims_B; ++i) {
44 dims_B[i] = in[1].dims(i);
46 bool A_broadcasted =
false, B_broadcasted =
false;
48 dims_A.insert(dims_A.begin(), 1);
58 if (helper.GetSingleArgument<
int>(
"trans_a", 0)) {
59 M = dims_A[ndims_A - 1];
61 M = dims_A[ndims_A - 2];
63 if (helper.GetSingleArgument<
int>(
"trans_b", 0)) {
64 N = dims_B[ndims_B - 2];
66 N = dims_B[ndims_B - 1];
69 std::vector<TIndex> new_dims;
70 if (ndims_A >= ndims_B) {
71 new_dims.assign(dims_A.begin(), dims_A.end() - 2);
73 new_dims.assign(dims_B.begin(), dims_B.end() - 2);
76 new_dims.push_back(M);
79 new_dims.push_back(N);
81 if (A_broadcasted && B_broadcasted) {
82 new_dims.push_back(1);
84 return vector<TensorShape>{
85 CreateTensorShape(vector<TIndex>{new_dims}, in[0].data_type())};
89 OpSchema::Cost CostInferenceForBatchMatMul(
90 const OperatorDef& def,
91 const vector<TensorShape>& in) {
92 ArgumentHelper helper(def);
93 struct OpSchema::Cost c;
94 const TensorShape Y = TensorInferenceForBatchMatMul(def, in)[0];
96 auto ndims_A = in[0].dims_size();
98 for (
int i = 0; i < Y.dims_size(); i++) {
102 if (helper.GetSingleArgument<
int>(
"trans_a", 0)) {
103 K = in[0].dims(ndims_A - 2);
105 K = in[0].dims(ndims_A - 1);
107 c.flops = 2 * nElemY * K;
108 c.bytes_moved = nElemY *
sizeof(float);
113 OPERATOR_SCHEMA(BatchMatMul)
117 Batch Matrix multiplication Yi = Ai * Bi, where A has shape (dim0, dim1, ... M, K), 118 B has shape (dim0, dim1, ... K, N), Y has shape (dim0, dim1, ... M, N) and i ranges 119 from 0 to (dim0 * dim1 ...) - 1. rank(A) == rank(B) >= 2. In case of A and B being 120 two diemnsional, it behaves like normal matrix multiplication. 122 .Input(0, "A",
"tensor of shape (dim0, dim1 ... M, K)")
123 .Input(1,
"B",
"tensor of shpae (dim0, dim2 ... K, N)")
124 .Output(0,
"Y",
"tensor of shape (dim0, dim1 ... M, N)")
127 "Pass 1 to transpose the last two dimensions of A before " 128 "doing multiplication")
131 "Pass 1 to transpose the last two dimensions of B before " 132 "doing multiplication")
135 "Pass 1 to allow broadcasting of dimensions. Behavior is the same as numpy.matmul. Gradient is currently not supported when running in broadcast mode.")
136 .TensorInferenceFunction(TensorInferenceForBatchMatMul)
137 .CostInferenceFunction(
140 class GetBatchMatMulGradient :
public GradientMakerBase {
141 using GradientMakerBase::GradientMakerBase;
142 vector<OperatorDef> GetGradientDefs()
override {
143 CAFFE_ENFORCE_EQ(def_.input_size(), 2);
145 bool broadcast =
false;
146 if (ArgumentHelper::HasArgument(Def(),
"broadcast")) {
147 broadcast = GetArgument(Def(),
"broadcast").i();
151 "Gradient is currently not supported with " 152 "broadcast=1 for BatchMatMul.");
157 if (ArgumentHelper::HasArgument(Def(),
"trans_a")) {
158 trans_a = GetArgument(Def(),
"trans_a").i();
160 if (ArgumentHelper::HasArgument(Def(),
"trans_b")) {
161 trans_b = GetArgument(Def(),
"trans_b").i();
164 auto no_trans_arg = vector<Argument>();
165 auto trans_a_arg = vector<Argument>{MakeArgument<int>(
"trans_a", 1)};
166 auto trans_b_arg = vector<Argument>{MakeArgument<int>(
"trans_b", 1)};
167 auto trans_both_arg = vector<Argument>{MakeArgument<int>(
"trans_a", 1),
168 MakeArgument<int>(
"trans_b", 1)};
170 if (ArgumentHelper::HasArgument(Def(),
"use_scratch")) {
171 no_trans_arg.push_back(MakeArgument<int>(
"use_scratch", 1));
172 trans_a_arg.push_back(MakeArgument<int>(
"use_scratch", 1));
173 trans_b_arg.push_back(MakeArgument<int>(
"use_scratch", 1));
174 trans_both_arg.push_back(MakeArgument<int>(
"use_scratch", 1));
181 return vector<OperatorDef>{CreateOperatorDef(
184 vector<string>{I(1), GO(0)},
185 vector<string>{GI(0)},
190 vector<string>{GO(0), I(0)},
191 vector<string>{GI(1)},
196 return vector<OperatorDef>{CreateOperatorDef(
199 vector<string>{I(1), GO(0)},
200 vector<string>{GI(0)},
205 vector<string>{I(0), GO(0)},
206 vector<string>{GI(1)},
213 return vector<OperatorDef>{CreateOperatorDef(
216 vector<string>{GO(0), I(1)},
217 vector<string>{GI(0)},
222 vector<string>{GO(0), I(0)},
223 vector<string>{GI(1)},
228 return vector<OperatorDef>{CreateOperatorDef(
231 vector<string>{GO(0), I(1)},
232 vector<string>{GI(0)},
237 vector<string>{I(0), GO(0)},
238 vector<string>{GI(1)},
244 bool CopyArguments()
const override {
249 REGISTER_GRADIENT(BatchMatMul, GetBatchMatMulGradient);
// NOTE(review): the three lines below are Doxygen tooltip text that leaked in
// from an HTML extraction of this file; they are not part of the source.
// Kept verbatim as comments so the file remains compilable.
//   A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
//   std::function< struct Cost(const OperatorDef &, const vector< TensorShape > &)> CostInferenceFunctionType
//   Registers a function that takes in an OperatorDef and a series of input shapes and returns the total ...