1 #include "caffe2/operators/matmul_op.h" 5 REGISTER_CPU_OPERATOR(MatMul, MatMulOp<float, CPUContext>);
7 OPERATOR_SCHEMA(MatMul)
10 .TensorInferenceFunction([](
const OperatorDef& def,
11 const vector<TensorShape>& in) {
12 vector<TensorShape> out(1);
13 out[0].set_data_type(in[0].data_type());
14 ArgumentHelper arg_helper(def);
15 int axis_a = arg_helper.GetSingleArgument<
int>(
"axis_a", 1);
16 int axis_b = arg_helper.GetSingleArgument<
int>(
"axis_b", 1);
17 int trans_a = arg_helper.GetSingleArgument<
bool>(
"trans_a",
false);
18 int trans_b = arg_helper.GetSingleArgument<
bool>(
"trans_b",
false);
19 int canonical_axis_a = canonical_axis_index_(axis_a, in[0].dims().size());
20 int canonical_axis_b = canonical_axis_index_(axis_b, in[0].dims().size());
22 int M = size_to_dim_(canonical_axis_a, GetDimsVector(in[0]));
28 N = size_to_dim_(canonical_axis_b, GetDimsVector(in[1]));
37 Matrix multiplication Y = A * B, where A has size (M x K), B has size (K x N), 38 and Y will have a size (M x N). 40 .Input(0, "A",
"2D matrix of size (M x K)")
41 .Input(1,
"B",
"2D matrix of size (K x N)")
42 .Output(0,
"Y",
"2D matrix of size (M x N)")
45 "Exclusive axis that divides the first and second dimension \ 46 of matrix A, default to 1")
49 "Exclusive axis that divides the first and second dimension \ 50 of matrix B, default to 1")
53 "Pass 1 to transpose A before multiplication and after the \ 54 dimension adjustment using axis_a")
57 "Pass 1 to transpose B before multiplication and after the \ 58 dimension adjustment using axis_b");
61 using GradientMakerBase::GradientMakerBase;
62 vector<OperatorDef> GetGradientDefs()
override {
63 CAFFE_ENFORCE_EQ(def_.input_size(), 2);
70 if (ArgumentHelper::HasArgument(Def(),
"trans_a")) {
71 trans_a = GetArgument(Def(),
"trans_a").i();
73 if (ArgumentHelper::HasArgument(Def(),
"trans_b")) {
74 trans_b = GetArgument(Def(),
"trans_b").i();
76 if (ArgumentHelper::HasArgument(Def(),
"axis_a")) {
77 axis_a = GetArgument(Def(),
"axis_a").i();
79 if (ArgumentHelper::HasArgument(Def(),
"axis_b")) {
80 axis_b = GetArgument(Def(),
"axis_b").i();
87 return vector<OperatorDef>{
91 vector<string>{I(1), GO(0), I(0)},
92 vector<string>{GI(0)},
93 vector<Argument>{MakeArgument<int>(
"trans_a", 1),
94 MakeArgument<int>(
"trans_b", 1),
95 MakeArgument<int>(
"axis_a", axis_b)}),
99 vector<string>{GO(0), I(0), I(1)},
100 vector<string>{GI(1)},
101 vector<Argument>{MakeArgument<int>(
"trans_a", 1),
102 MakeArgument<int>(
"trans_b", 1),
103 MakeArgument<int>(
"axis_b", axis_a)})};
107 return vector<OperatorDef>{
111 vector<string>{I(1), GO(0), I(0)},
112 vector<string>{GI(0)},
113 vector<Argument>{MakeArgument<int>(
"trans_b", 1),
114 MakeArgument<int>(
"axis_a", axis_b)}),
118 vector<string>{I(0), GO(0), I(1)},
119 vector<string>{GI(1)},
120 vector<Argument>{MakeArgument<int>(
"axis_a", axis_a)})};
126 return vector<OperatorDef>{
130 vector<string>{GO(0), I(1), I(0)},
131 vector<string>{GI(0)},
132 vector<Argument>{MakeArgument<int>(
"axis_b", axis_b)}),
136 vector<string>{GO(0), I(0), I(1)},
137 vector<string>{GI(1)},
138 vector<Argument>{MakeArgument<int>(
"trans_a", 1),
139 MakeArgument<int>(
"axis_b", axis_a)})};
143 return vector<OperatorDef>{
147 vector<string>{GO(0), I(1), I(0)},
148 vector<string>{GI(0)},
149 vector<Argument>{MakeArgument<int>(
"trans_b", 1),
150 MakeArgument<int>(
"axis_b", axis_b)}),
154 vector<string>{I(0), GO(0), I(1)},
155 vector<string>{GI(1)},
156 vector<Argument>{MakeArgument<int>(
"trans_a", 1),
157 MakeArgument<int>(
"axis_a", axis_a)})};
162 bool CopyArguments()
const override {
TIndex size_from_dim_(int k, const vector< TIndex > &dims)
Return product of all dimensions starting from K.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...