Caffe2 - C++ API
A deep learning, cross platform ML framework
conv_transpose_unpool_op_base.h
1 #ifndef CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_
2 #define CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/operators/conv_op_shared.h"
8 #include "caffe2/operators/conv_pool_op_base.h"
9 #include "caffe2/proto/caffe2_legacy.pb.h"
10 #include "caffe2/utils/math.h"
11 
12 CAFFE2_DECLARE_bool(caffe2_force_shared_col_buffer);
13 
14 namespace caffe2 {
15 
16 template <class Context>
17 class ConvTransposeUnpoolBase : public Operator<Context> {
18  public:
19  USE_OPERATOR_CONTEXT_FUNCTIONS;
20  ConvTransposeUnpoolBase(const OperatorDef& operator_def, Workspace* ws)
21  : Operator<Context>(operator_def, ws),
22  legacy_pad_(
23  static_cast<LegacyPadding>(OperatorBase::GetSingleArgument<int>(
24  "legacy_pad",
25  LegacyPadding::NOTSET))),
26  kernel_(OperatorBase::GetRepeatedArgument<int>("kernels")),
27  stride_(OperatorBase::GetRepeatedArgument<int>("strides")),
28  pads_(OperatorBase::GetRepeatedArgument<int>("pads")),
29  adj_(OperatorBase::GetRepeatedArgument<int>("adjs")),
30  order_(StringToStorageOrder(
31  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
32  shared_buffer_(
33  OperatorBase::GetSingleArgument<int>("shared_buffer", 0)),
34  ws_(ws) {
35  // For the padding, they should either be the legacy padding strategy
36  // (VALID or SAME), or an explicit, non-negative value.
37  if (legacy_pad_ == LegacyPadding::VALID ||
38  legacy_pad_ == LegacyPadding::SAME) {
39  CAFFE_ENFORCE(
41  "If you use legacy padding VALID or SAME, you should not specify "
42  "any specific padding values.");
43  }
44  // Get old arguments values.
45  if (OperatorBase::HasArgument("kernel")) {
46  kernel_.resize(2, OperatorBase::GetSingleArgument<int>("kernel", 0));
47  } else if (
48  OperatorBase::HasArgument("kernel_h") &&
49  OperatorBase::HasArgument("kernel_w")) {
50  kernel_.push_back(OperatorBase::GetSingleArgument<int>("kernel_h", 0));
51  kernel_.push_back(OperatorBase::GetSingleArgument<int>("kernel_w", 0));
52  }
53 
54  if (OperatorBase::HasArgument("stride")) {
55  stride_.resize(2, OperatorBase::GetSingleArgument<int>("stride", 0));
56  } else if (
57  OperatorBase::HasArgument("stride_h") &&
58  OperatorBase::HasArgument("stride_w")) {
59  stride_.push_back(OperatorBase::GetSingleArgument<int>("stride_h", 0));
60  stride_.push_back(OperatorBase::GetSingleArgument<int>("stride_w", 0));
61  }
62 
63  if (OperatorBase::HasArgument("adj")) {
64  adj_.resize(2, OperatorBase::GetSingleArgument<int>("adj", 0));
65  } else if (
66  OperatorBase::HasArgument("adj_h") &&
67  OperatorBase::HasArgument("adj_w")) {
68  adj_.push_back(OperatorBase::GetSingleArgument<int>("adj_h", 0));
69  adj_.push_back(OperatorBase::GetSingleArgument<int>("adj_w", 0));
70  }
71 
72  if (OperatorBase::HasArgument("pad")) {
73  CAFFE_ENFORCE(
74  legacy_pad_ != LegacyPadding::VALID &&
75  legacy_pad_ != LegacyPadding::SAME,
76  "If you use legacy padding VALID or SAME, you should not specify "
77  "any specific padding values.");
78  pads_.resize(4, OperatorBase::GetSingleArgument<int>("pad", 0));
79  } else if (
80  OperatorBase::HasArgument("pad_t") &&
81  OperatorBase::HasArgument("pad_l") &&
82  OperatorBase::HasArgument("pad_b") &&
83  OperatorBase::HasArgument("pad_r")) {
84  CAFFE_ENFORCE(
85  legacy_pad_ != LegacyPadding::VALID &&
86  legacy_pad_ != LegacyPadding::SAME,
87  "If you use legacy padding VALID or SAME, you should not specify "
88  "any specific padding values.");
89  pads_.push_back(OperatorBase::GetSingleArgument<int>("pad_t", 0));
90  pads_.push_back(OperatorBase::GetSingleArgument<int>("pad_l", 0));
91  pads_.push_back(OperatorBase::GetSingleArgument<int>("pad_b", 0));
92  pads_.push_back(OperatorBase::GetSingleArgument<int>("pad_r", 0));
93  }
94 
95  // Fill default values.
96  if (kernel_.size() == 0) {
97  kernel_.assign({0, 0});
98  }
99 
100  if (stride_.size() == 0) {
101  stride_.resize(kernel_.size(), 1);
102  }
103 
104  if (pads_.size() == 0) {
105  pads_.resize(kernel_.size() * 2, 0);
106  }
107 
108  if (adj_.size() == 0) {
109  adj_.resize(kernel_.size(), 0);
110  }
111 
112  CAFFE_ENFORCE_EQ(stride_.size(), kernel_.size());
113  CAFFE_ENFORCE_EQ(adj_.size(), kernel_.size());
114 
115  if (legacy_pad_ != LegacyPadding::VALID &&
116  legacy_pad_ != LegacyPadding::SAME) {
117  CAFFE_ENFORCE_EQ(pads_.size(), 2 * kernel_.size());
118  }
119 
120  for (int dim = 0; dim < kernel_.size(); ++dim) {
121  CAFFE_ENFORCE_GT(kernel_[dim], 0);
122  CAFFE_ENFORCE_GT(stride_[dim], 0);
123  CAFFE_ENFORCE_GE(adj_[dim], 0);
124  CAFFE_ENFORCE_LE(adj_[dim], stride_[dim]);
125  }
126 
127  // Create shared buffer mutex in the constructor
128  // to avoid race-condition in DAGNet.
129  if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
130  createSharedBuffer<Context>(ws_);
131  }
132  }
133  // Sets the output size. The output channel is manually specified.
134  void SetOutputSize(
135  const Tensor<Context>& input,
136  Tensor<Context>* output,
137  int output_channel) {
138  CAFFE_ENFORCE(4 == input.ndim());
139  CAFFE_ENFORCE(input.size() > 0);
140  int N = input.dim32(0);
141  bool channel_first = false; // initialized to suppress compiler warning.
142  int H = 0, W = 0; // initialized to suppress compiler warning.
143  int M = 0;
144  switch (order_) {
145  case StorageOrder::NHWC:
146  channel_first = false;
147  H = input.dim32(1);
148  W = input.dim32(2);
149  M = input.dim32(3);
150  break;
151  case StorageOrder::NCHW:
152  channel_first = true;
153  M = input.dim32(1);
154  H = input.dim32(2);
155  W = input.dim32(3);
156  break;
157  default:
158  LOG(FATAL) << "Unknown Storage order: " << order_;
159  }
160  int output_height = 0, output_width = 0;
161  ComputeSizeAndPad(
162  H,
163  stride_[0],
164  kernel_[0],
165  adj_[0],
166  &pads_[0],
167  &pads_[2],
168  &output_height);
169  ComputeSizeAndPad(
170  W,
171  stride_[1],
172  kernel_[1],
173  adj_[1],
174  &pads_[1],
175  &pads_[3],
176  &output_width);
177  if (channel_first) {
178  output->Resize(N, output_channel, output_height, output_width);
179  } else {
180  output->Resize(N, output_height, output_width, output_channel);
181  }
182  VLOG(2) << "In: N " << N << " M " << M << " H " << H << " W " << W;
183  VLOG(2) << "Out: output_channel " << output_channel << " H "
184  << output_height << " W " << output_width;
185  }
186 
187  bool RunOnDevice() override {
188  switch (order_) {
189  case StorageOrder::NHWC:
190  return RunOnDeviceWithOrderNHWC();
191  case StorageOrder::NCHW:
192  return RunOnDeviceWithOrderNCHW();
193  default:
194  LOG(FATAL) << "Unknown storage order: " << order_;
195  }
196  // To suppress old compiler warnings
197  return true;
198  }
199 
200  virtual bool RunOnDeviceWithOrderNCHW() {
201  CAFFE_THROW("Not implemented");
202  }
203 
204  virtual bool RunOnDeviceWithOrderNHWC() {
205  CAFFE_THROW("Not implemented");
206  }
207 
208  virtual ~ConvTransposeUnpoolBase() {}
209 
210  private:
211  LegacyPadding legacy_pad_;
212  int pad_;
213 
214  protected:
215  vector<int> kernel_;
216  vector<int> stride_;
217  vector<int> pads_;
218  vector<int> adj_;
219  StorageOrder order_;
220  bool shared_buffer_;
221  Workspace* ws_;
222 
223  // Accessors for 2D conv params.
224 
225  inline int pad_t() const {
226  return pads_[0];
227  }
228 
229  inline int pad_l() const {
230  return pads_[1];
231  }
232 
233  inline int pad_b() const {
234  return pads_[2];
235  }
236 
237  inline int pad_r() const {
238  return pads_[3];
239  }
240 
241  inline int kernel_h() const {
242  return kernel_[0];
243  }
244 
245  inline int kernel_w() const {
246  return kernel_[1];
247  }
248 
249  inline int stride_h() const {
250  return stride_[0];
251  }
252 
253  inline int stride_w() const {
254  return stride_[1];
255  }
256 
257  inline int adj_h() const {
258  return adj_[0];
259  }
260 
261  inline int adj_w() const {
262  return adj_[1];
263  }
264 
265  inline void ComputeSizeAndPad(
266  const int in_size,
267  const int stride,
268  const int kernel,
269  const int adj,
270  int* pad_head,
271  int* pad_tail,
272  int* out_size) {
273  switch (legacy_pad_) {
274  case LegacyPadding::NOTSET:
275  CAFFE_ENFORCE(*pad_head >= 0);
276  CAFFE_ENFORCE(*pad_tail >= 0);
277  *out_size =
278  (in_size - 1) * stride + kernel + adj - *pad_head - *pad_tail;
279  break;
280  // We handle cases of LegacyPadding::VALID and LegacyPadding::SAME
281  // the same way
282  case LegacyPadding::VALID:
283  case LegacyPadding::SAME:
284  *pad_head = 0;
285  *pad_tail = 0;
286  *out_size = (in_size - 1) * stride + kernel + adj;
287  break;
288  case LegacyPadding::CAFFE_LEGACY_POOLING:
289  LOG(FATAL) << "CAFFE_LEGACY_POOLING is no longer supported.";
290  break;
291  }
292  }
293 };
294 
// Convenience macro for subclasses of ConvTransposeUnpoolBase<Context>:
// pulls the operator helper functions and the protected configuration
// members into the derived class's scope so they can be used unqualified.
#define USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context)  \
  USE_OPERATOR_FUNCTIONS(Context);                         \
  using ConvTransposeUnpoolBase<Context>::kernel_;         \
  using ConvTransposeUnpoolBase<Context>::stride_;         \
  using ConvTransposeUnpoolBase<Context>::pads_;           \
  using ConvTransposeUnpoolBase<Context>::adj_;            \
  using ConvTransposeUnpoolBase<Context>::order_;          \
  using ConvTransposeUnpoolBase<Context>::shared_buffer_;  \
  using ConvTransposeUnpoolBase<Context>::ws_
304 
305 } // namespace caffe2
306 
307 #endif // CAFFE2_OPERATORS_CONV_TRANSPOSE_UNPOOL_OP_BASE_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:93
int dim32(const int i) const
Returns the i-th dimension of the tensor in int.
Definition: tensor.h:657
TIndex size() const
Returns the size, i.e. the total number of elements in the tensor.
Definition: tensor.h:593
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:288
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37
int ndim() const
Returns the number of dimensions of the data.
Definition: tensor.h:589