// Header guarded by CAFFE2_OPERATORS_MATMUL_OP_H_; the visible code multiplies
// two inputs with math::Gemm and writes the product into an output Y.
// NOTE(review): every line starts with a number that appears to be a line
// number left over from the original listing; gaps in that numbering suggest
// that some lines (e.g. the class declaration and constructor head) are not
// visible in this view.
1 #ifndef CAFFE2_OPERATORS_MATMUL_OP_H_ 2 #define CAFFE2_OPERATORS_MATMUL_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
// T, Context and Engine are forwarded to math::Gemm below; T is also the
// type requested from Y->mutable_data<T>(). Engine defaults to DefaultEngine.
typename T,
class Context,
class Engine = DefaultEngine>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Member initializers (the constructor head is not visible here). Each value
// is read as a single int argument: "axis_a" and "axis_b" default to 1,
// "trans_a" and "trans_b" default to 0.
16 axis_a_(OperatorBase::GetSingleArgument<int>(
"axis_a", 1)),
17 axis_b_(OperatorBase::GetSingleArgument<int>(
"axis_b", 1)),
18 trans_a_(OperatorBase::GetSingleArgument<int>(
"trans_a", 0)),
19 trans_b_(OperatorBase::GetSingleArgument<int>(
"trans_b", 0)) {}
// Runs the multiplication: A = Input(0), B = Input(1), result goes to Y
// (Y's declaration is on a line not visible in this view).
22 bool RunOnDevice()
override {
23 const auto& A = Input(0);
24 const auto& B = Input(1);
// Pass the configured axes through canonical_axis_index, then split each
// input's size at that axis into two extents (dim0 = size_to_dim(axis),
// dim1 = size_from_dim(axis)) -- presumably the row/column counts of a 2-D
// view of the tensor; confirm against the tensor API.
// NOTE(review): the extents are stored in int while the shape cache below uses
// TIndex; if these helpers return a wider type this narrows -- confirm.
27 const auto canonical_axis_a = A.canonical_axis_index(axis_a_);
28 const auto canonical_axis_b = B.canonical_axis_index(axis_b_);
29 int A_dim0 = A.size_to_dim(canonical_axis_a);
30 int A_dim1 = A.size_from_dim(canonical_axis_a);
31 int B_dim0 = B.size_to_dim(canonical_axis_b);
32 int B_dim1 = B.size_from_dim(canonical_axis_b);
// Dimensions actually used for the product; the code that assigns them is not
// visible in this view (presumably they account for trans_a_ / trans_b_).
34 int a_dim0, a_dim1, b_dim0, b_dim1;
// Helper that produces the "Dimension mismatch" message for the checks below;
// the operand labels switch to "trans(A)" / "trans(B)" when the corresponding
// transpose flag is set.
52 auto dimErrorString = [&]() {
54 "Dimension mismatch: ",
55 trans_a_ ?
"trans(A): " :
"A: ",
59 trans_b_ ?
", trans(B): " :
", B: ",
// The inner dimensions must agree before multiplying.
65 CAFFE_ENFORCE(a_dim1 == b_dim0, dimErrorString());
// Output shape is [a_dim0, b_dim1]; the member shape vector is reused for the
// resize.
67 Y_shape_cache_[0] = a_dim0;
68 Y_shape_cache_[1] = b_dim1;
69 Y->Resize(Y_shape_cache_);
// Check that the resized output holds exactly a_dim0 * b_dim1 elements.
// NOTE(review): the product is computed in int and could overflow for very
// large shapes -- confirm whether that matters for supported inputs.
70 CAFFE_ENFORCE(a_dim0 * b_dim1 == Y->size(), dimErrorString());
// The transpose flags select CblasTrans / CblasNoTrans for each operand; the
// remaining Gemm arguments are on lines not visible in this view.
72 math::Gemm<T, Context, Engine>(
73 trans_a_ ? CblasTrans : CblasNoTrans,
74 trans_b_ ? CblasTrans : CblasNoTrans,
82 Y->template mutable_data<T>(),
// When a third input is supplied, Y is resized like Input(2).
85 if (InputSize() == 3) {
87 Y->ResizeLike(Input(2));
// Two-element shape vector, initialised to {0, 0}; RunOnDevice overwrites both
// entries before resizing Y.
95 vector<TIndex> Y_shape_cache_{0, 0};
104 #endif // CAFFE2_OPERATORS_MATMUL_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...