Caffe2 - C++ API
A deep learning, cross-platform ML framework
conv_transpose_op.h
1 #ifndef CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_H_
2 #define CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/operators/conv_transpose_unpool_op_base.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context>
11 class ConvTransposeOp final : public ConvTransposeUnpoolBase<Context> {
12  public:
13  USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context);
14  ConvTransposeOp(const OperatorDef& operator_def, Workspace* ws)
15  : ConvTransposeUnpoolBase<Context>(operator_def, ws) {}
16 
17  bool RunOnDeviceWithOrderNCHW() override;
18  bool RunOnDeviceWithOrderNHWC() override;
19 
20  private:
21  Tensor<Context> col_buffer_;
22  Tensor<Context> bias_multiplier_;
23  // Input: X, W, b
24  // Output: Y
25  INPUT_TAGS(INPUT, FILTER, BIAS);
26 };
27 
28 template <typename T, class Context>
29 class ConvTransposeGradientOp final : public ConvTransposeUnpoolBase<Context> {
30  public:
31  USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(Context);
32  ConvTransposeGradientOp(const OperatorDef& operator_def, Workspace* ws)
33  : ConvTransposeUnpoolBase<Context>(operator_def, ws),
34  no_bias_(OperatorBase::GetSingleArgument<bool>("no_bias", false)) {
35  CAFFE_ENFORCE(
36  !(no_bias_ && OutputSize() == 3),
37  "If bias is not present, you should not have 3 grad output.");
38  }
39 
40  bool RunOnDeviceWithOrderNCHW() override;
41  bool RunOnDeviceWithOrderNHWC() override;
42 
43  private:
44  Tensor<Context> col_buffer_;
45  Tensor<Context> bias_multiplier_;
46  const bool no_bias_;
47  // input: X, W, dY
48  // output: dW, optionally db and dX
49  INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
50  OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
51 };
52 
53 } // namespace caffe2
54 
55 #endif // CAFFE2_OPERATORS_CONV_TRANSPOSE_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...