2 #include "caffe2/operators/conv_pool_op_base.h" 6 template <
typename Context>
9 USE_OPERATOR_FUNCTIONS(Context);
12 OPERATOR_NEEDS_FEATURE(
13 this->order_ == StorageOrder::NCHW,
14 "ChannelShuffleOp only supports NCHW order");
17 bool RunOnDeviceWithOrderNCHW()
override {
18 const auto& X = Input(0);
21 const auto C = X.dim32(1);
22 CAFFE_ENFORCE(C % this->group_ == 0,
"");
23 const auto K = C / this->group_;
24 const auto S = X.dim32(2) * X.dim32(3);
25 const auto G = this->group_;
26 for (
auto n = 0; n < X.dim32(0); ++n) {
27 for (
auto g = 0; g < G; ++g) {
30 math::CopyMatrix<Context>(
34 X.template data<float>() + g * K * S + n * C * S,
36 Y->template mutable_data<float>() + g * S + n * C * S,
46 template <
typename Context>
49 USE_OPERATOR_FUNCTIONS(Context);
52 OPERATOR_NEEDS_FEATURE(
53 this->order_ == StorageOrder::NCHW,
54 "ChannelShuffleOp only supports NCHW order");
57 bool RunOnDeviceWithOrderNCHW()
override {
58 const auto& dY = Input(0);
61 const auto C = dY.dim32(1);
62 CAFFE_ENFORCE(C % this->group_ == 0,
"");
63 const auto K = C / this->group_;
64 const auto S = dY.dim32(2) * dY.dim32(3);
65 const auto G = this->group_;
66 for (
auto n = 0; n < dY.dim32(0); ++n) {
67 for (
auto g = 0; g < G; ++g) {
70 math::CopyMatrix<Context>(
74 dY.template data<float>() + g * S + n * C * S,
76 dX->template mutable_data<float>() + g * K * S + n * C * S,
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...