// NOTE(review): this text looks like a lossy extraction of a C++ header: the
// leading numbers (1, 2, 4, ... 17) appear to be original line numbers, and the
// class-name line, constructor and closing brace are missing, so it does not
// compile as-is. Restore from the upstream header rather than editing in place.
//
// Operator fragment templated on element type T and device Context.
// Presumably SquaredL2DistanceOp (class-name line not visible) -- confirm.
// RunOnDevice() is only declared here; its definition is not in this view.
1 #ifndef CAFFE2_OPERATORS_DISTANCE_OP_H_ 2 #define CAFFE2_OPERATORS_DISTANCE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
typename T,
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Declared override; implementation lives outside this view.
17 bool RunOnDevice()
override;
// Gradient operator fragment templated on element type T and device Context.
// Presumably SquaredL2DistanceGradientOp (class-name line not visible) -- confirm.
// NOTE(review): the lines that fetch X, Y and the outputs dX, dY, several call
// arguments (element counts, operands, scale factors) and all closing braces are
// missing from this extraction; the comments below describe only what is visible.
23 template <
typename T,
class Context>
28 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Computes gradients for X and Y given the upstream gradient in input slot 2.
30 bool RunOnDevice()
override {
// Upstream gradient; must be 1-D with one entry per row (enforced below).
33 auto& dDistance = Input(2);
// N = size of the first dimension (1 for a 0-d tensor); D = elements per row.
// NOTE(review): X.size() / N is stored in an int -- confirm sizes fit in int.
36 int N = X.ndim() > 0 ? X.dim32(0) : 1;
37 int D = N > 0 ? X.size() / N : 0;
// X and Y must have identical rank and identical extent in every dimension.
38 CAFFE_ENFORCE(X.ndim() == Y.ndim());
39 for (
int i = 0; i < X.ndim(); ++i) {
40 CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i));
// dDistance must be a 1-D tensor of length N.
42 CAFFE_ENFORCE(dDistance.ndim() == 1);
43 CAFFE_ENFORCE(dDistance.dim32(0) == N);
// Element-wise subtraction written into dX; the operands are on lines not
// visible here (presumably dX = X - Y -- confirm).
46 math::Sub<T, Context>(
50 dX->template mutable_data<T>(),
// Scale row i of dX (offset i * D) in place, passing a pointer to
// dDistance[i] as the scale factor.
52 for (
int i = 0; i < N; ++i) {
53 math::Scale<T, Context>(
55 dDistance.template data<T>() + i,
56 dX->template data<T>() + i * D,
57 dX->template mutable_data<T>() + i * D,
// Write a scaled copy of dX into dY; the scale factor is on a line not visible
// here (presumably -1, i.e. dY = -dX -- confirm).
61 math::Scale<T, Context>(
64 dX->template data<T>(),
65 dY->template mutable_data<T>(),
// Operator fragment templated on element type T and device Context.
// Presumably L1DistanceOp (class-name line, constructor and closing brace are
// missing from this extraction) -- confirm against the upstream header.
74 template <
typename T,
class Context>
79 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Declared override; implementation lives outside this view.
81 bool RunOnDevice()
override;
// Gradient operator fragment templated on element type T and device Context.
// Presumably L1DistanceGradientOp (class-name line, constructor and closing
// brace are missing from this extraction) -- confirm against the upstream header.
87 template <
typename T,
class Context>
92 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Declared override; implementation lives outside this view.
94 bool RunOnDevice()
override;
// Operator fragment templated on element type T and device Context.
// Presumably DotProductOp (class-name line, constructor and closing brace are
// missing from this extraction) -- confirm against the upstream header.
100 template <
typename T,
class Context>
105 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Declared override; implementation lives outside this view.
107 bool RunOnDevice()
override;
// Symbolic slot names: two inputs (X_IN, Y_IN) and one output (DOT_OUT).
110 INPUT_TAGS(X_IN, Y_IN);
111 OUTPUT_TAGS(DOT_OUT);
// Gradient operator fragment templated on element type T and device Context.
// Presumably DotProductGradientOp (class-name line, constructor and closing
// brace are missing from this extraction) -- confirm against the upstream header.
114 template <
typename T,
class Context>
119 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Declared override; implementation lives outside this view.
121 bool RunOnDevice()
override;
// Symbolic slot names: inputs X_IN, Y_IN and the upstream gradient DER_DOT_IN;
// outputs DER_X_OUT and DER_Y_OUT.
124 INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
125 OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
// Operator fragment templated on element type T and device Context.
// Presumably DotProductWithPaddingOp (class-name line, start of the constructor,
// member declarations and closing brace are missing from this extraction).
128 template <
typename T,
class Context>
// Constructor initializers: operator argument "pad_value" (float, default 0.0)
// and "replicate" (bool, default false).
133 pad_value_(OperatorBase::GetSingleArgument<float>(
"pad_value", 0.0)),
134 replicate_(OperatorBase::GetSingleArgument<bool>(
"replicate",
false)) {}
135 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Declared override; implementation lives outside this view.
137 bool RunOnDevice()
override;
// Symbolic slot names: two inputs (X_IN, Y_IN) and one output (DOT_OUT).
142 INPUT_TAGS(X_IN, Y_IN);
143 OUTPUT_TAGS(DOT_OUT);
// Operator fragment templated on element type T and device Context.
// Presumably CosineSimilarityOp (class-name line, constructor, members and
// closing brace are missing from this extraction) -- confirm.
146 template <
typename T,
class Context>
151 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Declared override; implementation lives outside this view.
153 bool RunOnDevice()
override;
// Symbolic slot names: two inputs (X_IN, Y_IN) and one output (COS_OUT).
156 INPUT_TAGS(X_IN, Y_IN);
157 OUTPUT_TAGS(COS_OUT);
// Gradient operator fragment templated on element type T and device Context.
// Presumably CosineSimilarityGradientOp (class-name line, constructor, members
// and closing brace are missing from this extraction) -- confirm.
163 template <
typename T,
class Context>
168 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Declared override; implementation lives outside this view.
170 bool RunOnDevice()
override;
// Symbolic slot names: inputs X_IN, Y_IN and the upstream gradient DER_COS_IN;
// outputs DER_X_OUT and DER_Y_OUT.
173 INPUT_TAGS(X_IN, Y_IN, DER_COS_IN);
174 OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
// Gradient operator fragment templated on element type T and device Context.
// Presumably DotProductWithPaddingGradientOp (class-name line not visible).
// NOTE(review): this extraction is missing the start of the constructor, the
// assignment of DX/DY, the declarations of DL/DS and rest_data, every branch
// condition and all closing braces; comments describe only the visible lines.
180 template <
typename T,
class Context>
// Constructor initializers: operator argument "pad_value" (float, default 0.0)
// and "replicate" (bool, default false).
185 pad_value_(OperatorBase::GetSingleArgument<float>(
"pad_value", 0.0)),
186 replicate_(OperatorBase::GetSingleArgument<bool>(
"replicate",
false)) {}
187 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Computes dX and dY from X, Y and the per-row upstream gradient dDot.
189 bool RunOnDevice()
override {
190 auto& X = Input(X_IN);
191 auto& Y = Input(Y_IN);
192 auto& dDot = Input(DER_DOT_IN);
193 auto* dX = Output(DER_X_OUT);
194 auto* dY = Output(DER_Y_OUT);
// DX/DY are assigned on lines not visible here (presumably the per-row sizes
// of X and Y -- confirm).
195 int N, D, DX, DY, restD;
// N = size of the first dimension (1 for a 0-d tensor).
197 N = X.ndim() > 0 ? X.dim32(0) : 1;
// With replicate_, one row length must be a multiple of the other.
// NOTE(review): DX % DY is evaluated first, so DY == 0 with replicate_ set
// divides by zero before DY % DX is reached -- consider guarding.
205 CAFFE_ENFORCE(!replicate_ || DX % DY == 0 || DY % DX == 0);
// D = length shared by both rows; restD = excess length of the longer row.
206 D = std::min(DX, DY);
207 restD = std::max(DX, DY) - D;
// Same rank and same leading dimension; dDot is 1-D with length N.
208 CAFFE_ENFORCE_EQ(X.ndim(), Y.ndim());
209 CAFFE_ENFORCE_EQ(X.dim32(0), Y.dim32(0));
210 CAFFE_ENFORCE_EQ(dDot.ndim(), 1);
211 CAFFE_ENFORCE_EQ(dDot.dim32(0), N);
215 const auto* X_data = X.template data<T>();
216 const auto* Y_data = Y.template data<T>();
217 const auto* dDot_data = dDot.template data<T>();
218 auto* dX_data = dX->template mutable_data<T>();
219 auto* dY_data = dY->template mutable_data<T>();
// Process one row at a time; row i of X starts at i * DX, of Y at i * DY.
220 for (
int i = 0; i < N; ++i) {
221 auto offsetX = i * DX;
222 auto offsetY = i * DY;
// Replicate path: L_* refer to one row/gradient pair and S_* to the other.
225 const T *L_data, *S_data;
226 T *dL_data, *dS_data;
// L = row i of X (gradient dX), S = row i of Y (gradient dY).
// Branch condition not visible; presumably DX > DY -- confirm.
229 L_data = X_data + offsetX;
230 S_data = Y_data + offsetY;
231 dL_data = dX_data + offsetX;
232 dS_data = dY_data + offsetY;
// Opposite assignment: L = row i of Y (dY), S = row i of X (dX).
236 L_data = Y_data + offsetY;
237 S_data = X_data + offsetX;
238 dL_data = dY_data + offsetY;
239 dS_data = dX_data + offsetX;
// Zero dS, then for each DS-long chunk j of L:
//   dL[j*DS .. j*DS+DS) = dDot[i] * S
//   dS += dDot[i] * L[j*DS .. j*DS+DS)   (via tmp_data + Axpy)
// DL/DS are declared on lines not visible (presumably longer/shorter length).
// NOTE(review): tmp_data is a host std::vector handed to math::*<T, Context>;
// this is only valid if Context operates on host memory -- confirm.
// NOTE(review): tmp_data is reallocated for every row i; it could be hoisted
// out of the row loop.
// NOTE(review): DL / DS divides by zero if DS == 0.
245 std::vector<T> tmp_data(DS);
246 math::Set<T, Context>(DS, 0.0, dS_data, &context_);
247 for (
int j = 0; j < DL / DS; j++) {
248 math::Scale<T, Context>(
249 DS, dDot_data[i], S_data, dL_data + j * DS, &context_);
250 math::Scale<T, Context>(
251 DS, dDot_data[i], L_data + j * DS, tmp_data.data(), &context_);
252 math::Axpy<T, Context>(DS, 1.0, tmp_data.data(), dS_data, &context_);
// Non-replicate path (branch condition not visible): over the first D
// elements, dY = dDot[i] * X and dX = dDot[i] * Y.
255 math::Scale<T, Context>(
256 D, dDot_data[i], X_data + offsetX, dY_data + offsetY, &context_);
257 math::Scale<T, Context>(
258 D, dDot_data[i], Y_data + offsetY, dX_data + offsetX, &context_);
// Without replication and with unequal row lengths, the restD trailing
// elements of the longer row's gradient are set to dDot[i] * pad_value_.
261 if (!replicate_ && DX != DY) {
264 rest_data = dX_data + offsetX + D;
266 rest_data = dY_data + offsetY + D;
268 auto pad_gradient = dDot_data[i] * pad_value_;
269 math::Set<T, Context>(restD, pad_gradient, rest_data, &context_);
// Symbolic slot names: inputs X_IN, Y_IN and the upstream gradient DER_DOT_IN;
// outputs DER_X_OUT and DER_Y_OUT.
279 INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
280 OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
285 #endif // CAFFE2_OPERATORS_DISTANCE_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that records which Caffe2 modules have been loaded in the current ...