1 #include "caffe2/operators/order_switch_ops.h" 6 bool NHWC2NCHWOp<float, CPUContext>::RunOnDevice() {
// CPU/float implementation: converts a rank-4 tensor X laid out as NHWC
// into Y laid out as NCHW.
// X's dims 0..3 are read as (N, H, W, C); Y is resized to (N, C, H, W).
9 CAFFE_ENFORCE(X.ndim() == 4);
10 const int N = X.dim32(0), H = X.dim32(1), W = X.dim32(2), C = X.dim32(3);
11 Y->Resize(N, C, H, W);
12 const float* Xdata = X.data<
float>();
13 float* Ydata = Y->mutable_data<
float>();
// Walk X sequentially in its native (n, h, w, c) order and scatter each
// element to its NCHW position in Y. Reads are contiguous; writes jump by
// H * W elements on each step of the innermost (c) loop.
// NOTE(review): index math is done in `int`; tensors with more than INT_MAX
// elements would overflow — consider int64_t indices if that can occur.
14 for (
int n = 0; n < N; ++n) {
15 for (
int h = 0; h < H; ++h) {
16 for (
int w = 0; w < W; ++w) {
17 for (
int c = 0; c < C; ++c) {
18 Ydata[((n * C + c) * H + h) * W + w] = *(Xdata++);
27 bool NCHW2NHWCOp<float, CPUContext>::RunOnDevice() {
// CPU/float implementation: converts a rank-4 tensor X laid out as NCHW
// into Y laid out as NHWC.
// X's dims 0..3 are read as (N, C, H, W); Y is resized to (N, H, W, C).
30 CAFFE_ENFORCE(X.ndim() == 4);
31 const int N = X.dim32(0), C = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
32 Y->Resize(N, H, W, C);
33 const float* Xdata = X.data<
float>();
34 float* Ydata = Y->mutable_data<
float>();
// Walk X sequentially in its native (n, c, h, w) order and scatter each
// element to its NHWC position in Y. Reads are contiguous; writes jump by
// C elements on each step of the innermost (w) loop.
// NOTE(review): index math is done in `int`; tensors with more than INT_MAX
// elements would overflow — consider int64_t indices if that can occur.
35 for (
int n = 0; n < N; ++n) {
36 for (
int c = 0; c < C; ++c) {
37 for (
int h = 0; h < H; ++h) {
38 for (
int w = 0; w < W; ++w) {
39 Ydata[((n * H + h) * W + w) * C + c] = *(Xdata++);
// Register the float CPU implementations of the two layout-switch operators
// under the operator names NHWC2NCHW and NCHW2NHWC.
48 REGISTER_CPU_OPERATOR(NHWC2NCHW, NHWC2NCHWOp<float, CPUContext>);
49 REGISTER_CPU_OPERATOR(NCHW2NHWC, NCHW2NHWCOp<float, CPUContext>);
// Schema for NHWC2NCHW. The tensor-inference function below derives the
// output shape by permuting the input dims (N, H, W, C) -> (N, C, H, W).
51 OPERATOR_SCHEMA(NHWC2NCHW)
54 .TensorInferenceFunction([](
const OperatorDef& ,
55 const vector<TensorShape>& in) {
57 in[0].dims_size(), 4,
"Input for NHWC2NCHW must be 4 dimensional");
// Output dims in order: N (in dim 0), C (in dim 3), H (in dim 1), W (in dim 2).
58 vector<TensorShape> out(1);
59 out[0].add_dims(in[0].dims(0));
60 out[0].add_dims(in[0].dims(3));
61 out[0].add_dims(in[0].dims(1));
62 out[0].add_dims(in[0].dims(2));
66 The operator switches the order of data in a tensor from NHWC- sample index N, 67 height H, width W and channels C, to the NCHW order. 69 .Input(0, "data",
"The input data (Tensor<float>) in the NHWC order.")
73 "The output tensor (Tensor<float>) in the NCHW order.");
// Schema for NCHW2NHWC: exactly one input ("data", NCHW) and one output
// ("output", NHWC), per NumInputs(1)/NumOutputs(1) and the Input/Output docs.
75 OPERATOR_SCHEMA(NCHW2NHWC).NumInputs(1).NumOutputs(1)
77 The operator switches the order of data in a tensor from NCHW- sample index N, 78 channels C, height H and width W, to the NHWC order. 80 .Input(0, "data",
"The input data (Tensor<float>) in the NCHW order.")
81 .Output(0,
"output",
"The output tensor (Tensor<float>) in the NHWC order.");
// Gradient maker for a layout-switch op: the gradient definition takes the
// output gradient GO(0) as its input list and produces the input gradient
// GI(0) as its output list.
// NOTE(review): the gradient op type is not visible in this block; presumably
// it is the inverse layout switch — confirm.
85 using GradientMakerBase::GradientMakerBase;
86 vector<OperatorDef> GetGradientDefs()
override {
89 vector<string>{GO(0)},
90 vector<string>{GI(0)});
// Gradient maker for a layout-switch op: the gradient definition takes the
// output gradient GO(0) as its input list and produces the input gradient
// GI(0) as its output list.
// NOTE(review): the gradient op type is not visible in this block; presumably
// it is the inverse layout switch — confirm.
96 using GradientMakerBase::GradientMakerBase;
97 vector<OperatorDef> GetGradientDefs()
override {
100 vector<string>{GO(0)},
101 vector<string>{GI(0)});
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...