Caffe2 - C++ API
A deep learning, cross-platform ML framework
tt_linear_op.h
1 #ifndef CAFFE2_OPERATORS_TT_LINEAR_OP_H_
2 #define CAFFE2_OPERATORS_TT_LINEAR_OP_H_
3 
4 #ifdef CAFFE2_USE_MKL
5 #include <mkl.h>
6 #endif // CAFFE2_USE_MKL
7 
8 #include "Eigen/Core"
9 #include "Eigen/Dense"
10 #include "caffe2/core/context.h"
11 #include "caffe2/core/operator.h"
12 #include "caffe2/utils/math.h"
13 
14 namespace caffe2 {
15 
16 template <typename T, class Context, class Engine = DefaultEngine>
17 class TTLinearOp final : public Operator<Context> {
18  public:
19  USE_OPERATOR_CONTEXT_FUNCTIONS;
20  TTLinearOp(const OperatorDef& operator_def, Workspace* ws)
21  : Operator<Context>(operator_def, ws),
22  inp_sizes_(OperatorBase::GetRepeatedArgument<int>("inp_sizes")),
23  out_sizes_(OperatorBase::GetRepeatedArgument<int>("out_sizes")),
24  tt_ranks_(OperatorBase::GetRepeatedArgument<int>("tt_ranks")),
25  Y_temp_(unique_ptr<Blob>(new Blob())) {}
26  ~TTLinearOp() {}
27 
28  bool RunOnDevice() override {
29  const auto& X = Input(0); // Input array
30  const auto& b = Input(1); // Bias array
31  const auto& cores = Input(2); // 1D array containing the TT-cores
32  auto* Y = Output(0);
33 
34  CAFFE_ENFORCE(X.ndim() > 1, "Number of dimensions in X: ", X.ndim());
35  CAFFE_ENFORCE(b.ndim() == 1, "Number of dimensions in b: ", b.ndim());
36  CAFFE_ENFORCE(
37  inp_sizes_.size() == out_sizes_.size(),
38  "inp_sizes has size: ",
39  inp_sizes_.size(),
40  ", out_sizes has size: ",
41  out_sizes_.size());
42  CAFFE_ENFORCE(
43  cores.ndim() == 1, "Number of dimensions in cores: ", cores.ndim());
44  // batch size
45  const int batch_size = X.ndim() > 1 ? X.dim32(0) : 1;
46 
47  // dimension d of tensors
48  const int d = inp_sizes_.size();
49 
50  // Keep track of index of current core in multiplication
51  int cores_idx = 0;
52 
53  // Temporary buffer to facilitate multiplication of TT-cores with input
54  auto Y_buf = Y_temp_->GetMutable<Tensor<Context>>();
55  Y_buf->ResizeLike(X);
56  Y_buf->CopyFrom(X);
57 
58  // The overall forward pass involves multiplication with each core, where
59  // each core has sizes dictated by inp_sizes_ and out_sizes_. Each core thus
60  // has size inp_sizes_[i] * tt_ranks_[i] * tt_ranks_[i + 1] * out_sizes_[i].
61  for (int i = (d - 1); i >= 0; --i) {
62  int curr_rows = inp_sizes_[i] * tt_ranks_[i + 1];
63  int curr_cols = tt_ranks_[i] * out_sizes_[i];
64 
65  // TODO Replace by Reshape(), once wrappers are written
66  Y_buf->Resize(Y_buf->size() / curr_rows, curr_rows);
67  Y->Resize(Y_buf->size() / curr_rows, curr_cols);
68 
69  // Defensive checks
70  CAFFE_ENFORCE(Y_buf->size() % curr_rows == 0, Y_buf->size(), curr_rows);
71  CAFFE_ENFORCE(
72  cores_idx + curr_rows * curr_cols <= cores.size(),
73  cores_idx + curr_rows * curr_cols,
74  cores.size());
75 
76  // Multiply ith core with the intermediate output
77  math::Gemm<float, Context, Engine>(
78  CblasNoTrans,
79  CblasNoTrans,
80  Y_buf->size() / curr_rows,
81  curr_cols,
82  curr_rows,
83  1,
84  Y_buf->template data<float>(),
85  cores.template data<float>() + cores_idx,
86  0,
87  Y->template mutable_data<float>(),
88  &context_);
89 
90  CAFFE_ENFORCE(Y->size() % out_sizes_[i] == 0, Y->size(), out_sizes_[i]);
91 
92  // TODO Add GPU support by writing a generic wrapper.
93  auto Y_mat = EigenMatrixMap<float>(
94  Y->template mutable_data<float>(),
95  Y->size() / out_sizes_[i],
96  out_sizes_[i]);
97  Y_mat = ConstEigenMatrixMap<float>(
98  Y->template data<float>(),
99  out_sizes_[i],
100  Y->size() / out_sizes_[i])
101  .transpose()
102  .eval();
103 
104  // Resize operation
105  Y_buf->Resize(Y->dim32(0), Y->dim32(1));
106  context_.template Copy<float, CPUContext, CPUContext>(
107  Y->size(),
108  Y->template data<float>(),
109  Y_buf->template mutable_data<float>());
110 
111  cores_idx += curr_rows * curr_cols;
112  }
113 
114  // TODO Add GPU support by writing a generic wrapper.
115  auto Y_mat = EigenMatrixMap<float>(
116  Y->template mutable_data<float>(), batch_size, Y->size() / batch_size);
117  Y_mat = ConstEigenMatrixMap<float>(
118  Y->template data<float>(), Y->size() / batch_size, batch_size)
119  .transpose()
120  .eval();
121  // TODO Replace by Reshape(), once wrappers are written
122  Y->Resize(batch_size, Y->size() / batch_size);
123 
124  // Check that output size of Y is the element-wise product of out_sizes
125  int prod_out_sizes = 1;
126  for (int i = 0; i < out_sizes_.size(); i++) {
127  prod_out_sizes *= out_sizes_[i];
128  }
129  CAFFE_ENFORCE(
130  Y->dim32(1) == prod_out_sizes,
131  "Output dimension of Y: ",
132  Y->dim32(1),
133  ", product of out_sizes: ",
134  prod_out_sizes);
135 
136  // Add bias term
137  if (bias_multiplier_.size() != batch_size) {
138  // If the helper bias multiplier is not M, reshape and fill it with one.
139  bias_multiplier_.Resize(batch_size);
140  math::Set<T, Context>(
141  batch_size,
142  static_cast<T>(1),
143  bias_multiplier_.template mutable_data<T>(),
144  &context_);
145  }
146  math::Gemm<T, Context, Engine>(
147  CblasNoTrans,
148  CblasNoTrans,
149  Y->dim32(0),
150  Y->dim32(1),
151  1,
152  1,
153  bias_multiplier_.template data<T>(),
154  b.template data<T>(),
155  1,
156  Y->template mutable_data<T>(),
157  &context_);
158  return true;
159  }
160 
161  protected:
162  Tensor<Context> bias_multiplier_;
163  std::vector<int> inp_sizes_;
164  std::vector<int> out_sizes_;
165  std::vector<int> tt_ranks_;
166  std::unique_ptr<Blob> Y_temp_;
167 };
168 
169 // TODO: Complete after verifying utility of TT-layer's forward pass.
170 template <typename T, class Context, class Engine = DefaultEngine>
171 class TTLinearGradientOp : public Operator<Context> {
172  public:
173  USE_OPERATOR_CONTEXT_FUNCTIONS;
174  TTLinearGradientOp(const OperatorDef& operator_def, Workspace* ws)
175  : Operator<Context>(operator_def, ws) {}
176  ~TTLinearGradientOp() {}
177 
178  bool RunOnDevice() override {
179  return false;
180  }
181 
182  protected:
183  Tensor<Context> bias_multiplier_;
184 };
185 
186 } // namespace caffe2
187 
188 #endif // CAFFE2_OPERATORS_TT_LINEAR_OP_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
Tensor is the basic class in Caffe2 that stores a contiguous block of memory along with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
void ResizeLike(const Tensor< OtherContext > &src_tensor)
Resize the tensor like the source tensor.
Definition: tensor.h:315