1 #ifndef CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_ 2 #define CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
typename T,
class Context,
class Engine = DefaultEngine>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
16 K_(OperatorBase::GetSingleArgument<TIndex>(
"K", 0)),
17 M_(OperatorBase::GetSingleArgument<TIndex>(
"M", 0)),
18 N_(OperatorBase::GetSingleArgument<TIndex>(
"N", 0)) {
24 bool RunOnDevice()
override {
25 const auto& A = Input(0);
26 const auto& B = Input(1);
29 CAFFE_ENFORCE(A.ndim() == 2, A.ndim());
31 TIndex A_size = A.size_from_dim(0);
32 TIndex B_size = B.size_from_dim(0);
36 "Argument `K` and `M` do not agree with the size of A.");
39 B_size % (K_ * N_) == 0,
40 "Argument `K` and `N` do not agree with the size of B.");
42 TIndex D_ = B_size / (K_ * N_);
44 TIndex C_size = D_ * M_ * N_;
45 C->Resize(vector<TIndex>{C_size});
47 TIndex B_stride = K_ * N_;
48 TIndex C_stride = M_ * N_;
50 const T* A_data = A.template data<T>();
51 const T* B_data = B.template data<T>();
52 T* C_data = C->template mutable_data<T>();
54 for (TIndex B_index = 0; B_index < B_size; B_index += B_stride) {
55 math::Gemm<T, Context, Engine>(
76 template <
typename T,
class Context,
class Engine = DefaultEngine>
79 USE_OPERATOR_CONTEXT_FUNCTIONS;
82 K_(OperatorBase::GetSingleArgument<TIndex>(
"K", 0)),
83 M_(OperatorBase::GetSingleArgument<TIndex>(
"M", 0)),
84 N_(OperatorBase::GetSingleArgument<TIndex>(
"N", 0)) {}
86 bool RunOnDevice()
override {
87 const auto& G = Input(0);
88 const auto& A = Input(1);
89 const auto& B = Input(2);
93 TIndex G_size = G.size_from_dim(0);
94 TIndex D_ = G_size / (M_ * N_);
96 TIndex dB_size = D_ * K_ * N_;
101 TIndex B_stride = K_ * N_;
102 TIndex G_stride = M_ * N_;
104 const T* G_data = G.template data<T>();
105 const T* A_data = A.template data<T>();
106 const T* B_data = B.template data<T>();
108 T* dA_data = dA->template mutable_data<T>();
109 T* dB_data = dB->template mutable_data<T>();
111 const T* G_ptr = G_data;
112 for (TIndex B_index = 0; B_index < dB_size; B_index += B_stride) {
113 math::Gemm<T, Context, Engine>(
119 B_index == 0 ? 0 : 1,
126 for (TIndex B_index = 0; B_index < dB_size; B_index += B_stride) {
127 math::Gemm<T, Context, Engine>(
150 #endif // CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_ Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.