// NOTE(review): this listing looks like a scraped copy of the header: source
// line numbers are fused into the text and several lines (class heads, the
// K/N definitions, Gemm terminators) are missing. Comments below only describe
// what the visible lines show.
1 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_ 2 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 17 template <
typename T,
class Context,
class Engine=DefaultEngine>
20 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Forward pass of a fully-connected layer whose weight is given as two factor
// matrices. Reads X = Input(0), U = Input(1), V = Input(2), b = Input(3) and,
// through the Gemm calls below, computes
//   multi_buffer_ = X * V                        (M x middle)
//   Y             = multi_buffer_ * U^T + 1 * b^T (M x N)
// i.e. Y = X * (V * U^T) + b broadcast over rows. How Y is obtained/resized is
// not visible in this listing.
25 bool RunOnDevice()
override {
26 const auto& X = Input(0);
27 const auto& U = Input(1);
28 const auto& V = Input(2);
29 const auto& b = Input(3);
// Rank preconditions, enforced with CAFFE_ENFORCE (always-on checks).
34 CAFFE_ENFORCE_GE(X.ndim(), 1);
35 CAFFE_ENFORCE_GE(U.ndim(), 2);
36 CAFFE_ENFORCE_GE(V.ndim(), 2);
// Inputs of rank > 2 take a "legacy support" path that logs at VLOG(1); the
// remainder of that branch is not visible here.
37 if (X.ndim() > 2 || U.ndim() > 2 || V.ndim() > 2) {
38 VLOG(1) <<
"Using legacy support for arbitrary input and weight " 41 CAFFE_ENFORCE_EQ(b.ndim(), 1);
// M: leading dimension of X (1 for a 1-D X) — presumably the batch size.
// K and N are used below but their definitions are not visible here.
43 int M = X.ndim() > 1 ? X.dim32(0) : 1;
// NOTE(review): the second Gemm below passes U with CblasTrans as an
// N x middle matrix, which implies middle == U.dim32(1) (the gradient op
// uses U.dim32(1)). U.dim32(0) only agrees when U is square — confirm.
48 int middle = U.dim32(0);
49 CAFFE_ENFORCE_EQ(K, V.dim32(0));
50 CAFFE_ENFORCE_EQ(N, b.dim32(0));
// multi_buffer_ holds the intermediate X * V; the condition choosing the
// 2-D (M, middle) vs 1-D (middle) shape is not visible here.
53 multi_buffer_.Resize(M, middle);
56 multi_buffer_.Resize(middle);
61 T* multi_buffer_data = multi_buffer_.template mutable_data<T>();
// multi_buffer_ = X (M x K) * V (K x middle), beta = 0 overwrites.
63 math::Gemm<T, Context, Engine>(
64 CblasNoTrans, CblasNoTrans, M, middle, K, 1, X.template data<T>(),
65 V.template data<T>(), 0, multi_buffer_data,
// Y = multi_buffer_ (M x middle) * U^T (middle x N), beta = 0 overwrites.
67 math::Gemm<T, Context, Engine>(
68 CblasNoTrans, CblasTrans, M, N, middle, 1, multi_buffer_data,
69 U.template data<T>(), 0, Y->template mutable_data<T>(),
// bias_multiplier_ is a cached vector of M ones; it is only re-allocated and
// re-filled when its size no longer matches M.
72 if (bias_multiplier_.size() != M) {
74 bias_multiplier_.Resize(M);
75 math::Set<T, Context>(
76 M,
static_cast<T
>(1), bias_multiplier_.template mutable_data<T>(),
// Y += ones (M x 1) * b (1 x N): adds b to every row (beta = 1 accumulates).
79 math::Gemm<T, Context, Engine>(
80 CblasNoTrans, CblasNoTrans, M, N, 1, 1,
81 bias_multiplier_.template data<T>(), b.template data<T>(), 1,
82 Y->template mutable_data<T>(), &context_);
91 template <
typename T,
class Context,
class Engine=DefaultEngine>
94 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Gradient of the factored fully-connected op. Reads X = Input(0),
// U = Input(1), V = Input(2), dY = Input(3); writes dU = Output(0),
// dV = Output(1), db = Output(2) and, when OutputSize() == 4, dX = Output(3).
// For Y = X * V * U^T + b the chain rule gives
//   dU = dY^T * (X * V)   (N x middle)
//   dV = X^T * (dY * U)   (K x middle)
//   db = dY^T * ones(M)   (N)
//   dX = (dY * U) * V^T   (M x K)
// Output resizing and several Gemm terminators are not visible in this listing.
99 bool RunOnDevice()
override {
100 const auto& X = Input(0);
101 const auto& U = Input(1);
102 const auto& V = Input(2);
103 const auto& dY = Input(3);
// NOTE(review): shape checks here are DCHECK_* (debug-only in glog), while the
// forward op uses CAFFE_ENFORCE_*; mismatched shapes go unchecked in NDEBUG
// builds — consider CAFFE_ENFORCE for consistency.
104 DCHECK_GE(X.ndim(), 1);
105 DCHECK_GE(U.ndim(), 2);
106 DCHECK_GE(V.ndim(), 2);
107 DCHECK_LE(dY.ndim(), 2);
// M: leading dimension of X (1 for 1-D X); K: remaining elements per row.
109 int M = X.ndim() > 1 ? X.dim32(0) : 1;
111 int K = X.size() / M;
// middle is U's second dimension (U read as N x middle). N's definition is not
// visible here.
114 int middle = U.dim32(1);
115 DCHECK_EQ(K, V.dim32(0));
// The two check pairs below presumably sit in a 2-D dY / 1-D X branch whose
// condition is not visible here.
117 DCHECK_EQ(M, dY.dim32(0));
118 DCHECK_EQ(N, dY.dim32(1));
120 DCHECK_EQ(X.ndim(), 1);
121 DCHECK_EQ(N, dY.size());
123 auto* dU = Output(0);
124 auto* dV = Output(1);
125 auto* db = Output(2);
// NOTE(review): du_buffer_ is resized to (N, middle), but the Gemm below
// produces M x middle (M, middle, K) and the dU Gemm reads it as an
// M x middle operand. If M > N this over-runs the buffer; dv_buffer_ and
// dx_buffer_ are sized (M, middle) — this likely should be too. Confirm.
132 du_buffer_.Resize(N, middle);
133 T* du_buffer_data = du_buffer_.template mutable_data<T>();
// X (M x K) * V (K x middle); the output argument line is not visible here.
134 math::Gemm<T, Context, Engine>(
135 CblasNoTrans, CblasNoTrans, M, middle, K, 1,
136 X.template data<T>(), V.template data<T>(),
// dU = dY^T (N x M) * du_buffer_ (M x middle).
139 math::Gemm<T, Context, Engine>(
140 CblasTrans, CblasNoTrans, N, middle, M, 1,
141 dY.template data<T>(), du_buffer_data,
142 0, dU->template mutable_data<T>(),
146 dv_buffer_.Resize(M, middle);
147 T* dv_buffer_data = dv_buffer_.template mutable_data<T>();
// dY (M x N) * U (N x middle); the output argument line is not visible here.
148 math::Gemm<T, Context, Engine>(
149 CblasNoTrans, CblasNoTrans, M, middle, N, 1,
150 dY.template data<T>(), U.template data<T>(),
// NOTE(review): likely bug. With CblasTrans and dims (K, middle, M), operand
// A must be stored M x K (that is X's shape), but dY (M x N) is passed. B
// should be dY * U, which is what dv_buffer_data holds, but du_buffer_data
// (X * V) is passed. dv_buffer_data is otherwise unused. Expected:
//   X.template data<T>(), dv_buffer_data   =>   dV = X^T * (dY * U).
153 math::Gemm<T, Context, Engine>(
154 CblasTrans, CblasNoTrans, K, middle, M, 1,
155 dY.template data<T>(), du_buffer_data,
156 0, dV->template mutable_data<T>(),
// bias_multiplier_ is a cached vector of M ones, refilled only when M changes.
158 if (bias_multiplier_.size() != M) {
160 bias_multiplier_.Resize(M);
161 math::Set<T, Context>(
162 M,
static_cast<T
>(1),
163 bias_multiplier_.template mutable_data<T>(),
// db = dY^T * ones(M): column sums of dY (length N).
167 math::Gemv<T, Context>(
168 CblasTrans, M, N, 1, dY.template data<T>(),
169 bias_multiplier_.template data<T>(), 0,
170 db->template mutable_data<T>(),
// Optional input gradient, only computed when a fourth output is requested.
173 if (OutputSize() == 4) {
174 auto* dX = Output(3);
176 dx_buffer_.Resize(M, middle);
177 T* dx_buffer_data = dx_buffer_.template mutable_data<T>();
// dY (M x N) * U (N x middle); the output argument line is not visible here.
178 math::Gemm<T, Context, Engine>(
179 CblasNoTrans, CblasNoTrans, M, middle, N, 1,
180 dY.template data<T>(), U.template data<T>(),
// dX = dx_buffer_ (M x middle) * V^T (middle x K).
183 math::Gemm<T, Context, Engine>(
184 CblasNoTrans, CblasTrans, M, K, middle, 1,
185 dx_buffer_data, V.template data<T>(),
186 0, dX->template mutable_data<T>(),
202 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...