Caffe2 - C++ API
A deep learning, cross platform ML framework
tt_pad_op.h
1 #ifndef CAFFE2_OPERATORS_TT_PAD_OP_H_
2 #define CAFFE2_OPERATORS_TT_PAD_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context, class Engine = DefaultEngine>
11 class TTPadOp final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  TTPadOp(const OperatorDef& operator_def, Workspace* ws)
15  : Operator<Context>(operator_def, ws),
16  scale_(OperatorBase::GetSingleArgument<TIndex>("scale", 0)) {
17  CAFFE_ENFORCE(
18  OperatorBase::HasArgument("scale"), "Argument `scale` is missing.");
19  }
20 
21  bool RunOnDevice() override {
22  const auto& X = Input(0);
23  auto* X_pad = Output(0);
24  CAFFE_ENFORCE(&X == X_pad);
25 
26  CAFFE_ENFORCE(X.ndim() == 2, X.ndim());
27 
28  auto X_dim0 = X.dim(0);
29  auto X_dim1 = X.dim(1);
30 
31  auto* X_orig_dim0 = Output(1);
32  X_orig_dim0->Resize(1);
33  *X_orig_dim0->template mutable_data<TIndex>() = X_dim0;
34 
35  if (X_dim0 % scale_ != 0) {
36  TIndex padded_dim0 = (X_dim0 / scale_ + 1) * scale_;
37  auto dim0_diff = padded_dim0 - X_dim0;
38  // set growthPct to the upper bound percentage: (100 * scale_ / X_dim0)
39  X_pad->template Extend(dim0_diff, 100 * scale_ / X_dim0, &context_);
40 
41  auto* X_pad_data = X_pad->template mutable_data<T>();
42  TIndex X_size = X_dim0 * X_dim1;
43  memset(X_pad_data + X_size, 0, dim0_diff * X_dim1 * sizeof(T));
44  }
45 
46  return true;
47  }
48 
49  protected:
50  TIndex scale_;
51 };
52 
53 template <typename T, class Context, class Engine = DefaultEngine>
54 class TTPadGradientOp final : public Operator<Context> {
55  public:
56  USE_OPERATOR_CONTEXT_FUNCTIONS;
57  TTPadGradientOp(const OperatorDef& operator_def, Workspace* ws)
58  : Operator<Context>(operator_def, ws) {}
59 
60  bool RunOnDevice() override {
61  const auto& G = Input(0);
62  auto* output = Output(0);
63  CAFFE_ENFORCE(&G == output);
64 
65  auto old_dim0 = *Input(1).template data<TIndex>();
66  auto new_dim0 = G.dim(0);
67  auto dim1 = G.dim(1);
68 
69  if (old_dim0 < new_dim0) {
70  output->Shrink(old_dim0);
71  }
72 
73  return true;
74  }
75 };
76 
77 } // namespace caffe2
78 
79 #endif // CAFFE2_OPERATORS_TT_PAD_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37