// NOTE(review): this text is a lossy extract of tt_linear_op.h. Original line
// numbers are inlined at the start of lines, and many lines are missing
// (e.g. the class declaration line, the constructor signature, the
// declarations of Y, Y_buf and cores_idx, and every closing brace).
// The comments below describe only what the visible lines show.
1 #ifndef CAFFE2_OPERATORS_TT_LINEAR_OP_H_ 2 #define CAFFE2_OPERATORS_TT_LINEAR_OP_H_ 6 #endif // CAFFE2_USE_MKL 10 #include "caffe2/core/context.h" 11 #include "caffe2/core/operator.h" 12 #include "caffe2/utils/math.h" 16 template <
typename T,
class Context,
class Engine = DefaultEngine>
// Forward operator of this header. The tt_ranks / inp_sizes / out_sizes
// arguments and the header guard suggest a tensor-train (TT) factorized
// linear layer. This is inferred from names; confirm against the operator schema.
19 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor initializer list. It reads the repeated int arguments
// "inp_sizes", "out_sizes" and "tt_ranks", and allocates a fresh Blob owned
// by Y_temp_. NOTE(review): std::make_unique<Blob>() would avoid the raw `new`.
22 inp_sizes_(OperatorBase::GetRepeatedArgument<int>(
"inp_sizes")),
23 out_sizes_(OperatorBase::GetRepeatedArgument<int>(
"out_sizes")),
24 tt_ranks_(OperatorBase::GetRepeatedArgument<int>(
"tt_ranks")),
25 Y_temp_(unique_ptr<Blob>(
new Blob())) {}
// Computes the output from input X (Input 0), bias b (Input 1) and the
// flattened cores (Input 2).
28 bool RunOnDevice()
override {
29 const auto& X = Input(0);
30 const auto& b = Input(1);
31 const auto& cores = Input(2);
// Shape validation:
// - X must have more than one dimension.
// - b must be 1-D.
// - inp_sizes and out_sizes must have the same length.
// - cores must be 1-D.
34 CAFFE_ENFORCE(X.ndim() > 1,
"Number of dimensions in X: ", X.ndim());
35 CAFFE_ENFORCE(b.ndim() == 1,
"Number of dimensions in b: ", b.ndim());
37 inp_sizes_.size() == out_sizes_.size(),
38 "inp_sizes has size: ",
40 ", out_sizes has size: ",
43 cores.ndim() == 1,
"Number of dimensions in cores: ", cores.ndim());
// NOTE(review): the `: 1` fallback is unreachable, because
// X.ndim() > 1 is already enforced above.
45 const int batch_size = X.ndim() > 1 ? X.dim32(0) : 1;
// d is the number of factors (the length of inp_sizes_).
48 const int d = inp_sizes_.size();
// Processes one core per iteration, from factor i = d - 1 down to i = 0.
// cores_idx advances forward through the flattened cores buffer (source
// lines 85 and 111). Assuming cores_idx starts at 0 (its initialization is
// not visible), the core for factor d - 1 is read first.
61 for (
int i = (d - 1); i >= 0; --i) {
// This step's core is used as a curr_rows x curr_cols block of cores.
// NOTE(review): tt_ranks_[i + 1] with i = d - 1 requires
// tt_ranks_.size() >= d + 1. No such check is visible in this extract.
62 int curr_rows = inp_sizes_[i] * tt_ranks_[i + 1];
63 int curr_cols = tt_ranks_[i] * out_sizes_[i];
// View Y_buf as (n / curr_rows) x curr_rows, and size Y to receive the
// (n / curr_rows) x curr_cols product.
66 Y_buf->Resize(Y_buf->size() / curr_rows, curr_rows);
67 Y->Resize(Y_buf->size() / curr_rows, curr_cols);
// NOTE(review): this check runs after the Resize above. After that Resize,
// size() is (old_size / curr_rows) * curr_rows, which is always a multiple
// of curr_rows, so the enforce cannot fail. To detect a size mismatch, it
// must be evaluated before the Resize.
70 CAFFE_ENFORCE(Y_buf->size() % curr_rows == 0, Y_buf->size(), curr_rows);
// Bounds check: this core's curr_rows * curr_cols block must lie within
// cores (the head of this CAFFE_ENFORCE is on a missing line).
72 cores_idx + curr_rows * curr_cols <= cores.size(),
73 cores_idx + curr_rows * curr_cols,
// Multiplies the Y_buf view by this core's block, writing into Y. The
// transpose flags and scalars are on lines missing from this extract.
// NOTE(review): this hard-codes float even though the class is templated
// on T. The bias step below uses T.
77 math::Gemm<float, Context, Engine>(
80 Y_buf->size() / curr_rows,
84 Y_buf->template data<float>(),
85 cores.template data<float>() + cores_idx,
87 Y->template mutable_data<float>(),
// NOTE(review): given the Resize of Y above, Y->size() is a multiple of
// curr_cols = tt_ranks_[i] * out_sizes_[i], so this also holds by construction.
90 CAFFE_ENFORCE(Y->size() % out_sizes_[i] == 0, Y->size(), out_sizes_[i]);
// Remaps Y's buffer with a different row/column split (presumably a
// transpose on the missing continuation lines).
// NOTE(review): the destination map (Y_mat) and the source
// ConstEigenMatrixMap both point at Y's own data. Unless the missing lines
// evaluate into a temporary (e.g. `.eval()`), an in-place transpose through
// aliasing maps is unsafe in Eigen. Confirm.
93 auto Y_mat = EigenMatrixMap<float>(
94 Y->template mutable_data<float>(),
95 Y->size() / out_sizes_[i],
97 Y_mat = ConstEigenMatrixMap<float>(
98 Y->template data<float>(),
100 Y->size() / out_sizes_[i])
// Copy this step's result into Y_buf (resized to Y's 2-D shape) as the
// input for the next iteration.
// NOTE(review): the Copy is hard-coded to CPUContext even though the class
// is templated on Context.
105 Y_buf->Resize(Y->dim32(0), Y->dim32(1));
106 context_.template Copy<float, CPUContext, CPUContext>(
108 Y->template data<float>(),
109 Y_buf->template mutable_data<float>());
// Advance past the core block just consumed.
111 cores_idx += curr_rows * curr_cols;
// After the loop: remap Y between a batch_size x (size / batch_size) view
// and a (size / batch_size) x batch_size view. The transpose itself is
// presumably on a missing line, and the aliasing note above applies here too.
// Then shape Y as (batch_size, size / batch_size).
115 auto Y_mat = EigenMatrixMap<float>(
116 Y->template mutable_data<float>(), batch_size, Y->size() / batch_size);
117 Y_mat = ConstEigenMatrixMap<float>(
118 Y->template data<float>(), Y->size() / batch_size, batch_size)
122 Y->Resize(batch_size, Y->size() / batch_size);
// Sanity check: Y's second dimension must equal the product of out_sizes_.
// NOTE(review): `int i < out_sizes_.size()` compares signed with unsigned.
125 int prod_out_sizes = 1;
126 for (
int i = 0; i < out_sizes_.size(); i++) {
127 prod_out_sizes *= out_sizes_[i];
130 Y->dim32(1) == prod_out_sizes,
131 "Output dimension of Y: ",
133 ", product of out_sizes: ",
// Bias: rebuild bias_multiplier_ when its size differs from batch_size,
// filling it with math::Set (the fill value is on a missing line, presumably
// 1). A Gemm then combines bias_multiplier_ and b into Y, presumably
// Y += bias_multiplier_ * b^T; its scalar arguments are not visible.
137 if (bias_multiplier_.size() != batch_size) {
139 bias_multiplier_.Resize(batch_size);
140 math::Set<T, Context>(
143 bias_multiplier_.template mutable_data<T>(),
146 math::Gemm<T, Context, Engine>(
153 bias_multiplier_.template data<T>(),
154 b.template data<T>(),
156 Y->template mutable_data<T>(),
// Repeated-int operator arguments, read in the constructor.
163 std::vector<int> inp_sizes_;
164 std::vector<int> out_sizes_;
165 std::vector<int> tt_ranks_;
// Blob allocated in the constructor. Its use is not visible in this extract
// (presumably it backs Y_buf). The declaration of bias_multiplier_ is also
// missing from this extract.
166 std::unique_ptr<Blob> Y_temp_;
// Second operator declared in this header (original lines 170-187). Its
// class name, constructor and RunOnDevice body are missing from this
// extract. Presumably it is the gradient counterpart of the forward
// operator above; confirm against the full source.
170 template <
typename T,
class Context,
class Engine = DefaultEngine>
173 USE_OPERATOR_CONTEXT_FUNCTIONS;
178 bool RunOnDevice()
override {
188 #endif // CAFFE2_OPERATORS_TT_LINEAR_OP_H_ Blob is a general container that hosts a typed pointer.
Tensor is the basic class in Caffe2; it stores a contiguous block of memory together with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
void ResizeLike(const Tensor< OtherContext > &src_tensor)
Resizes the tensor to the same shape as the source tensor.