Caffe2 - C++ API
A deep learning, cross platform ML framework
channel_shuffle_op.h
1 #pragma once
2 #include "caffe2/operators/conv_pool_op_base.h"
3 
4 namespace caffe2 {
5 
6 template <typename Context>
7 class ChannelShuffleOp final : public ConvPoolOpBase<Context> {
8  public:
9  USE_OPERATOR_FUNCTIONS(Context);
10  ChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws)
11  : ConvPoolOpBase<Context>(operator_def, ws) {
12  OPERATOR_NEEDS_FEATURE(
13  this->order_ == StorageOrder::NCHW,
14  "ChannelShuffleOp only supports NCHW order");
15  }
16 
17  bool RunOnDeviceWithOrderNCHW() override {
18  const auto& X = Input(0);
19  auto* Y = Output(0);
20  Y->ResizeLike(X);
21  const auto C = X.dim32(1);
22  CAFFE_ENFORCE(C % this->group_ == 0, "");
23  const auto K = C / this->group_;
24  const auto S = X.dim32(2) * X.dim32(3);
25  const auto G = this->group_;
26  for (auto n = 0; n < X.dim32(0); ++n) {
27  for (auto g = 0; g < G; ++g) {
28  // Scatter the group g block (of size KxS) to output channels
29  // g + 0 * G, g + 1 * G, g + 2 * G, g + G * (K - 1) etc.
30  math::CopyMatrix<Context>(
31  X.itemsize(),
32  K,
33  S,
34  X.template data<float>() + g * K * S + n * C * S,
35  S,
36  Y->template mutable_data<float>() + g * S + n * C * S,
37  G * S,
38  &context_,
39  X.meta().copy());
40  }
41  }
42  return true;
43  }
44 };
45 
46 template <typename Context>
47 class ChannelShuffleGradientOp final : public ConvPoolOpBase<Context> {
48  public:
49  USE_OPERATOR_FUNCTIONS(Context);
50  ChannelShuffleGradientOp(const OperatorDef& operator_def, Workspace* ws)
51  : ConvPoolOpBase<Context>(operator_def, ws) {
52  OPERATOR_NEEDS_FEATURE(
53  this->order_ == StorageOrder::NCHW,
54  "ChannelShuffleOp only supports NCHW order");
55  }
56 
57  bool RunOnDeviceWithOrderNCHW() override {
58  const auto& dY = Input(0);
59  auto* dX = Output(0);
60  dX->ResizeLike(dY);
61  const auto C = dY.dim32(1);
62  CAFFE_ENFORCE(C % this->group_ == 0, "");
63  const auto K = C / this->group_;
64  const auto S = dY.dim32(2) * dY.dim32(3);
65  const auto G = this->group_;
66  for (auto n = 0; n < dY.dim32(0); ++n) {
67  for (auto g = 0; g < G; ++g) {
68  // Gather the group g block (of size KxS) from output channels
69  // g + 0 * G, g + 1 * G, g + 2 * G, g + G * (K - 1) etc.
70  math::CopyMatrix<Context>(
71  dY.itemsize(),
72  K,
73  S,
74  dY.template data<float>() + g * S + n * C * S,
75  G * S,
76  dX->template mutable_data<float>() + g * K * S + n * C * S,
77  S,
78  &context_,
79  dY.meta().copy());
80  }
81  }
82  return true;
83  }
84 };
85 }
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...