Caffe2 - C++ API
A deep learning, cross platform ML framework
distance_op.h
1 #ifndef CAFFE2_OPERATORS_DISTANCE_OP_H_
2 #define CAFFE2_OPERATORS_DISTANCE_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context>
11 class SquaredL2DistanceOp : public Operator<Context> {
12  public:
13  SquaredL2DistanceOp(const OperatorDef& def, Workspace* ws)
14  : Operator<Context>(def, ws) {}
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16 
17  bool RunOnDevice() override;
18 
19  protected:
20  // Input: X, Y; Output: Distance
21 };
22 
23 template <typename T, class Context>
24 class SquaredL2DistanceGradientOp final : public Operator<Context> {
25  public:
26  SquaredL2DistanceGradientOp(const OperatorDef& def, Workspace* ws)
27  : Operator<Context>(def, ws) {}
28  USE_OPERATOR_CONTEXT_FUNCTIONS;
29 
30  bool RunOnDevice() override {
31  auto& X = Input(0);
32  auto& Y = Input(1);
33  auto& dDistance = Input(2);
34  auto* dX = Output(0);
35  auto* dY = Output(1);
36  int N = X.ndim() > 0 ? X.dim32(0) : 1;
37  int D = N > 0 ? X.size() / N : 0;
38  CAFFE_ENFORCE(X.ndim() == Y.ndim());
39  for (int i = 0; i < X.ndim(); ++i) {
40  CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i));
41  }
42  CAFFE_ENFORCE(dDistance.ndim() == 1);
43  CAFFE_ENFORCE(dDistance.dim32(0) == N);
44  dX->ResizeLike(X);
45  dY->ResizeLike(Y);
46  math::Sub<T, Context>(
47  X.size(),
48  X.template data<T>(),
49  Y.template data<T>(),
50  dX->template mutable_data<T>(),
51  &context_);
52  for (int i = 0; i < N; ++i) {
53  math::Scale<T, Context>(
54  D,
55  dDistance.template data<T>() + i,
56  dX->template data<T>() + i * D,
57  dX->template mutable_data<T>() + i * D,
58  &context_);
59  }
60  // The gradient of the other side is basically the negative.
61  math::Scale<T, Context>(
62  X.size(),
63  -1,
64  dX->template data<T>(),
65  dY->template mutable_data<T>(),
66  &context_);
67  return true;
68  }
69 
70  protected:
71  // Input: X, Y, dDistance; Output: dX, dY
72 };
73 
74 template <typename T, class Context>
75 class L1DistanceOp : public Operator<Context> {
76  public:
77  L1DistanceOp(const OperatorDef& def, Workspace* ws)
78  : Operator<Context>(def, ws) {}
79  USE_OPERATOR_CONTEXT_FUNCTIONS;
80 
81  bool RunOnDevice() override;
82 
83  protected:
84  // Input: X, Y; Output: Distance
85 };
86 
87 template <typename T, class Context>
88 class L1DistanceGradientOp : public Operator<Context> {
89  public:
90  L1DistanceGradientOp(const OperatorDef& def, Workspace* ws)
91  : Operator<Context>(def, ws) {}
92  USE_OPERATOR_CONTEXT_FUNCTIONS;
93 
94  bool RunOnDevice() override;
95 
96  protected:
97  // Input: X, Y, dDistance; Output: dX, dY
98 };
99 
100 template <typename T, class Context>
101 class DotProductOp : public Operator<Context> {
102  public:
103  DotProductOp(const OperatorDef& def, Workspace* ws)
104  : Operator<Context>(def, ws) {}
105  USE_OPERATOR_CONTEXT_FUNCTIONS;
106 
107  bool RunOnDevice() override;
108 
109  protected:
110  INPUT_TAGS(X_IN, Y_IN);
111  OUTPUT_TAGS(DOT_OUT);
112 };
113 
114 template <typename T, class Context>
115 class DotProductGradientOp final : public Operator<Context> {
116  public:
117  DotProductGradientOp(const OperatorDef& def, Workspace* ws)
118  : Operator<Context>(def, ws) {}
119  USE_OPERATOR_CONTEXT_FUNCTIONS;
120 
121  bool RunOnDevice() override;
122 
123  protected:
124  INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
125  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
126 };
127 
128 template <typename T, class Context>
129 class DotProductWithPaddingOp : public Operator<Context> {
130  public:
131  DotProductWithPaddingOp(const OperatorDef& def, Workspace* ws)
132  : Operator<Context>(def, ws),
133  pad_value_(OperatorBase::GetSingleArgument<float>("pad_value", 0.0)),
134  replicate_(OperatorBase::GetSingleArgument<bool>("replicate", false)) {}
135  USE_OPERATOR_CONTEXT_FUNCTIONS;
136 
137  bool RunOnDevice() override;
138 
139  protected:
140  float pad_value_;
141  bool replicate_;
142  INPUT_TAGS(X_IN, Y_IN);
143  OUTPUT_TAGS(DOT_OUT);
144 };
145 
146 template <typename T, class Context>
147 class CosineSimilarityOp : public Operator<Context> {
148  public:
149  CosineSimilarityOp(const OperatorDef& def, Workspace* ws)
150  : Operator<Context>(def, ws) {}
151  USE_OPERATOR_CONTEXT_FUNCTIONS;
152 
153  bool RunOnDevice() override;
154 
155  protected:
156  INPUT_TAGS(X_IN, Y_IN);
157  OUTPUT_TAGS(COS_OUT);
158 
159  private:
160  Tensor<Context> aux_;
161 };
162 
163 template <typename T, class Context>
164 class CosineSimilarityGradientOp final : public Operator<Context> {
165  public:
166  CosineSimilarityGradientOp(const OperatorDef& def, Workspace* ws)
167  : Operator<Context>(def, ws) {}
168  USE_OPERATOR_CONTEXT_FUNCTIONS;
169 
170  bool RunOnDevice() override;
171 
172  protected:
173  INPUT_TAGS(X_IN, Y_IN, DER_COS_IN);
174  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
175 
176  private:
177  Tensor<Context> aux_;
178 };
179 
180 template <typename T, class Context>
181 class DotProductWithPaddingGradientOp final : public Operator<Context> {
182  public:
183  DotProductWithPaddingGradientOp(const OperatorDef& def, Workspace* ws)
184  : Operator<Context>(def, ws),
185  pad_value_(OperatorBase::GetSingleArgument<float>("pad_value", 0.0)),
186  replicate_(OperatorBase::GetSingleArgument<bool>("replicate", false)) {}
187  USE_OPERATOR_CONTEXT_FUNCTIONS;
188 
189  bool RunOnDevice() override {
190  auto& X = Input(X_IN);
191  auto& Y = Input(Y_IN);
192  auto& dDot = Input(DER_DOT_IN);
193  auto* dX = Output(DER_X_OUT);
194  auto* dY = Output(DER_Y_OUT);
195  int N, D, DX, DY, restD;
196  if (X.size() > 0) {
197  N = X.ndim() > 0 ? X.dim32(0) : 1;
198  DX = X.size() / N;
199  DY = Y.size() / N;
200  } else {
201  N = 0;
202  DX = 0;
203  DY = 0;
204  }
205  CAFFE_ENFORCE(!replicate_ || DX % DY == 0 || DY % DX == 0);
206  D = std::min(DX, DY);
207  restD = std::max(DX, DY) - D;
208  CAFFE_ENFORCE_EQ(X.ndim(), Y.ndim());
209  CAFFE_ENFORCE_EQ(X.dim32(0), Y.dim32(0));
210  CAFFE_ENFORCE_EQ(dDot.ndim(), 1);
211  CAFFE_ENFORCE_EQ(dDot.dim32(0), N);
212  dX->ResizeLike(X);
213  dY->ResizeLike(Y);
214 
215  const auto* X_data = X.template data<T>();
216  const auto* Y_data = Y.template data<T>();
217  const auto* dDot_data = dDot.template data<T>();
218  auto* dX_data = dX->template mutable_data<T>();
219  auto* dY_data = dY->template mutable_data<T>();
220  for (int i = 0; i < N; ++i) { // TODO: multithreading
221  auto offsetX = i * DX;
222  auto offsetY = i * DY;
223  if (replicate_) {
224  // L_ for longer vector and S_ for shorter vector
225  const T *L_data, *S_data;
226  T *dL_data, *dS_data;
227  int DL, DS;
228  if (DX > DY) {
229  L_data = X_data + offsetX;
230  S_data = Y_data + offsetY;
231  dL_data = dX_data + offsetX;
232  dS_data = dY_data + offsetY;
233  DL = DX;
234  DS = DY;
235  } else {
236  L_data = Y_data + offsetY;
237  S_data = X_data + offsetX;
238  dL_data = dY_data + offsetY;
239  dS_data = dX_data + offsetX;
240  DL = DY;
241  DS = DX;
242  }
243 
244  // TODO: get rid of temp memory use
245  std::vector<T> tmp_data(DS);
246  math::Set<T, Context>(DS, 0.0, dS_data, &context_);
247  for (int j = 0; j < DL / DS; j++) {
248  math::Scale<T, Context>(
249  DS, dDot_data[i], S_data, dL_data + j * DS, &context_);
250  math::Scale<T, Context>(
251  DS, dDot_data[i], L_data + j * DS, tmp_data.data(), &context_);
252  math::Axpy<T, Context>(DS, 1.0, tmp_data.data(), dS_data, &context_);
253  }
254  } else {
255  math::Scale<T, Context>(
256  D, dDot_data[i], X_data + offsetX, dY_data + offsetY, &context_);
257  math::Scale<T, Context>(
258  D, dDot_data[i], Y_data + offsetY, dX_data + offsetX, &context_);
259  }
260 
261  if (!replicate_ && DX != DY) {
262  T* rest_data;
263  if (DX > DY) {
264  rest_data = dX_data + offsetX + D;
265  } else {
266  rest_data = dY_data + offsetY + D;
267  }
268  auto pad_gradient = dDot_data[i] * pad_value_;
269  math::Set<T, Context>(restD, pad_gradient, rest_data, &context_);
270  }
271  }
272 
273  return true;
274  }
275 
276  protected:
277  float pad_value_;
278  bool replicate_;
279  INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
280  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
281 };
282 
283 } // namespace caffe2
284 
285 #endif // CAFFE2_OPERATORS_DISTANCE_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...