Caffe2 - C++ API
A deep learning, cross platform ML framework
tt_contraction_op.h
1 #ifndef CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_
2 #define CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context, class Engine = DefaultEngine>
11 class TTContractionOp final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  TTContractionOp(const OperatorDef& operator_def, Workspace* ws)
15  : Operator<Context>(operator_def, ws),
16  K_(OperatorBase::GetSingleArgument<TIndex>("K", 0)),
17  M_(OperatorBase::GetSingleArgument<TIndex>("M", 0)),
18  N_(OperatorBase::GetSingleArgument<TIndex>("N", 0)) {
19  CAFFE_ENFORCE(OperatorBase::HasArgument("K"), "Argument `K` is missing.");
20  CAFFE_ENFORCE(OperatorBase::HasArgument("M"), "Argument `M` is missing.");
21  CAFFE_ENFORCE(OperatorBase::HasArgument("N"), "Argument `N` is missing.");
22  }
23 
24  bool RunOnDevice() override {
25  const auto& A = Input(0);
26  const auto& B = Input(1);
27  auto* C = Output(0);
28 
29  CAFFE_ENFORCE(A.ndim() == 2, A.ndim());
30 
31  TIndex A_size = A.size_from_dim(0);
32  TIndex B_size = B.size_from_dim(0);
33 
34  CAFFE_ENFORCE(
35  K_ * M_ == A_size,
36  "Argument `K` and `M` do not agree with the size of A.");
37 
38  CAFFE_ENFORCE(
39  B_size % (K_ * N_) == 0,
40  "Argument `K` and `N` do not agree with the size of B.");
41 
42  TIndex D_ = B_size / (K_ * N_);
43 
44  TIndex C_size = D_ * M_ * N_;
45  C->Resize(vector<TIndex>{C_size});
46 
47  TIndex B_stride = K_ * N_;
48  TIndex C_stride = M_ * N_;
49 
50  const T* A_data = A.template data<T>();
51  const T* B_data = B.template data<T>();
52  T* C_data = C->template mutable_data<T>();
53 
54  for (TIndex B_index = 0; B_index < B_size; B_index += B_stride) {
55  math::Gemm<T, Context, Engine>(
56  CblasTrans,
57  CblasNoTrans,
58  M_, N_, K_, 1,
59  A_data,
60  B_data + B_index,
61  0,
62  C_data,
63  &context_);
64  C_data += C_stride;
65  }
66 
67  return true;
68  }
69 
70  protected:
71  TIndex K_;
72  TIndex M_;
73  TIndex N_;
74 };
75 
76 template <typename T, class Context, class Engine = DefaultEngine>
77 class TTContractionGradientOp final : public Operator<Context> {
78  public:
79  USE_OPERATOR_CONTEXT_FUNCTIONS;
80  TTContractionGradientOp(const OperatorDef& operator_def, Workspace* ws)
81  : Operator<Context>(operator_def, ws),
82  K_(OperatorBase::GetSingleArgument<TIndex>("K", 0)),
83  M_(OperatorBase::GetSingleArgument<TIndex>("M", 0)),
84  N_(OperatorBase::GetSingleArgument<TIndex>("N", 0)) {}
85 
86  bool RunOnDevice() override {
87  const auto& G = Input(0);
88  const auto& A = Input(1);
89  const auto& B = Input(2);
90  auto* dA = Output(0);
91  auto* dB = Output(1);
92 
93  TIndex G_size = G.size_from_dim(0);
94  TIndex D_ = G_size / (M_ * N_);
95 
96  TIndex dB_size = D_ * K_ * N_;
97 
98  dA->Resize(A.dims());
99  dB->Resize(B.dims());
100 
101  TIndex B_stride = K_ * N_;
102  TIndex G_stride = M_ * N_;
103 
104  const T* G_data = G.template data<T>();
105  const T* A_data = A.template data<T>();
106  const T* B_data = B.template data<T>();
107 
108  T* dA_data = dA->template mutable_data<T>();
109  T* dB_data = dB->template mutable_data<T>();
110 
111  const T* G_ptr = G_data;
112  for (TIndex B_index = 0; B_index < dB_size; B_index += B_stride) {
113  math::Gemm<T, Context, Engine>(
114  CblasNoTrans,
115  CblasTrans,
116  K_, M_, N_, 1,
117  B_data + B_index,
118  G_ptr,
119  B_index == 0 ? 0 : 1,
120  dA_data,
121  &context_);
122  G_ptr += G_stride;
123  }
124 
125  G_ptr = G_data;
126  for (TIndex B_index = 0; B_index < dB_size; B_index += B_stride) {
127  math::Gemm<T, Context, Engine>(
128  CblasNoTrans,
129  CblasNoTrans,
130  K_, N_, M_, 1,
131  A_data,
132  G_ptr,
133  0,
134  dB_data + B_index,
135  &context_);
136  G_ptr += G_stride;
137  }
138 
139  return true;
140  }
141 
142  protected:
143  TIndex K_;
144  TIndex M_;
145  TIndex N_;
146 };
147 
148 } // namespace caffe2
149 
150 #endif // CAFFE2_OPERATORS_TT_CONTRACTION_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37