1 #include "caffe2/operators/spatial_batch_norm_op.h" 6 bool SpatialBNGradientOp<CPUContext>::RunOnDevice() {
// CPU backward pass of spatial batch normalization.
// Reads X (INPUT), dY (OUTPUT_GRAD), scale (SCALE) and the per-channel
// SAVED_MEAN / SAVED_INV_VAR inputs; writes dX (INPUT_GRAD), dScale
// (SCALE_GRAD) and dBias (BIAS_GRAD).
// NOTE(review): this listing is missing several original lines (e.g. the
// declaration of C, the W/D fallback values, `else` branches, closing
// braces, the switch header and the return); comments below are hedged
// wherever they depend on those lines.
7 const auto& X = Input(INPUT);
8 const auto& dY = Input(OUTPUT_GRAD);
9 const auto& scale = Input(SCALE);
// Input must be 3-D to 5-D: N, C and one to three spatial dims.
11 CAFFE_ENFORCE(X.ndim() >= 3 && X.ndim() <= 5);
// Shape extraction. C (declared on a line not visible here) is dim 1 for
// NCHW and the last dim otherwise; H, W, D are the spatial dims in order.
// W and D come from X only when ndim > 3 / ndim > 4; their fallback values
// are not visible here (presumably 1, so sample_size stays H * W * D).
12 const int N = X.dim32(0);
14 (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(X.ndim() - 1));
15 const int H = (order_ == StorageOrder::NCHW ? X.dim32(2) : X.dim32(1));
16 const int W = X.ndim() > 3
17 ? (order_ == StorageOrder::NCHW ? X.dim32(3) : X.dim32(2))
19 const int D = X.ndim() > 4
20 ? (order_ == StorageOrder::NCHW ? X.dim32(4) : X.dim32(3))
// Number of spatial elements per (n, c) pair.
23 const int sample_size = H * W * D;
// scale must be 1-D with one entry per channel.
25 CAFFE_ENFORCE_EQ(scale.ndim(), 1);
26 CAFFE_ENFORCE_EQ(scale.dim32(0), C);
// Per-channel views over scale and the saved statistics. inv_var is used as
// 1/std: it multiplies (X - mean) directly to form x_hat below.
28 ConstEigenVectorArrayMap<float> scale_arr(scale.data<
float>(), C);
29 ConstEigenVectorArrayMap<float> mean_arr(Input(SAVED_MEAN).data<float>(), C);
30 ConstEigenVectorArrayMap<float> inv_var_arr(
31 Input(SAVED_INV_VAR).data<float>(), C);
33 auto* dX = Output(INPUT_GRAD);
36 auto* dScale = Output(SCALE_GRAD);
37 auto* dBias = Output(BIAS_GRAD);
// With a single batch, dScale/dBias are sized like scale here. With
// num_batches_ > 1 they are not resized in the visible code: the schema lets
// inputs 5/6 alias outputs 1/2 in place, so they presumably already hold
// gradients accumulated over earlier batches.
39 if (num_batches_ == 1) {
40 dScale->ResizeLike(scale);
41 dBias->ResizeLike(scale);
49 EigenVectorArrayMap<float> dBias_arr(dBias->mutable_data<
float>(), C);
50 EigenVectorArrayMap<float> dScale_arr(dScale->mutable_data<
float>(), C);
// Body not visible here; presumably zero-initializes the single-batch
// accumulators before the `+=` reductions below.
52 if (num_batches_ == 1) {
// scale * inv_var / M, where M = N * sample_size is the number of elements
// reduced per channel. Kept as a lazy Eigen expression (cheap: C elements).
57 const auto scaleInvVarNHW = scale_arr * inv_var_arr / (N * sample_size);
// NCHW: each column nc of the sample_size x (N * C) maps is one (n, c)
// spatial plane (assuming the usual column-major Eigen map).
60 case StorageOrder::NCHW: {
61 ConstEigenArrayMap<float> X_arr(X.data<
float>(), sample_size, N * C);
62 ConstEigenArrayMap<float> dY_arr(dY.data<
float>(), sample_size, N * C);
63 EigenArrayMap<float> dX_arr(
64 dX->mutable_data<
float>(), sample_size, N * C);
// dX is accumulated with `+=` below, so it is presumably zeroed on a line
// not visible here. Single-batch case: sum dY and dY * x_hat over each plane
// into that channel's dBias / dScale, with x_hat = (X - mean) * inv_var.
// The channel index c is computed on a line not visible here (with this
// n-major column layout it would be nc % C -- confirm).
67 if (num_batches_ == 1) {
68 for (
int nc = 0; nc < N * C; ++nc) {
70 dBias_arr(c) += dY_arr.col(nc).sum();
72 ((X_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) * dY_arr.col(nc))
// num_batches_ > 1 (presumably the `else` of the branch above): divide the
// incoming dBias / dScale by num_batches_, i.e. average them, assuming they
// hold sums over batches.
76 for (
int c = 0; c < C; ++c) {
77 dBias_arr(c) /= num_batches_;
78 dScale_arr(c) /= num_batches_;
// Input gradient per plane:
//   dX = scale * inv_var / M * (M * dY - dBias - x_hat * dScale)
// with x_hat = (X - mean) * inv_var and M = N * sample_size.
81 for (
int nc = 0; nc < N * C; ++nc) {
83 dX_arr.col(nc) += scaleInvVarNHW(c) *
84 (dY_arr.col(nc) * N * sample_size - dBias_arr(c) -
85 (X_arr.col(nc) - mean_arr[c]) * dScale_arr(c) * inv_var_arr(c));
// NHWC: the C x (N * sample_size) maps hold one column of C channel values
// per (n, spatial position).
89 case StorageOrder::NHWC: {
90 ConstEigenArrayMap<float> X_arr(X.data<
float>(), C, N * sample_size);
91 ConstEigenArrayMap<float> dY_arr(dY.data<
float>(), C, N * sample_size);
92 EigenArrayMap<float> dX_arr(
93 dX->mutable_data<
float>(), C, N * sample_size);
// Per-channel quantities over all N * sample_size positions:
// sum(dY), X - mean, sum(dY * (X - mean)) and inv_var^2.
// NOTE(review): `const auto` binds lazy Eigen expression templates here, not
// evaluated arrays. dYRowSum and dYMulXMinusMeanRowSum are used inside the
// nhw loop below, so each iteration likely re-runs the full row reductions,
// making the loop quadratic in N * sample_size. Evaluating them once (e.g.
// `.eval()` or an explicit Eigen::ArrayXf) should avoid this -- verify.
96 const auto dYRowSum = dY_arr.rowwise().sum();
97 const auto XMinusMean = X_arr.colwise() - mean_arr;
98 const auto dYMulXMinusMeanRowSum = (dY_arr * XMinusMean).rowwise().sum();
99 const auto invVarSqr = inv_var_arr * inv_var_arr;
// Accumulate dBias (and, via the statement whose left-hand side is not
// visible here, presumably dScale) and write dX one position at a time.
// The dX formula matches the NCHW one: XMinusMean * invVarSqr * sum(dY *
// (X - mean)) equals x_hat * sum(dY * x_hat).
// NOTE(review): unlike the NCHW branch, the visible NHWC code never consults
// num_batches_. dBias is always accumulated, never averaged, and dX uses
// this batch's own reductions instead of dBias_arr/dScale_arr. Confirm
// whether NHWC is meant to support num_batches_ > 1.
100 for (
int nhw = 0; nhw < N * sample_size; ++nhw) {
101 dBias_arr += dY_arr.col(nhw);
103 (X_arr.col(nhw) - mean_arr) * inv_var_arr * dY_arr.col(nhw);
104 dX_arr.col(nhw) += scaleInvVarNHW *
105 (dY_arr.col(nhw) * N * sample_size - dYRowSum -
106 XMinusMean.col(nhw) * invVarSqr * dYMulXMinusMeanRowSum);
// Presumably the switch's `default:` label (not visible here): any other
// storage order is rejected.
111 CAFFE_THROW(
"Unknown storage order: ", order_);
// Register the CPU implementation of the SpatialBNGradient operator.
116 REGISTER_CPU_OPERATOR(SpatialBNGradient, SpatialBNGradientOp<CPUContext>);
120 OPERATOR_SCHEMA(SpatialBNGradient)
// Input 5 may share storage with output 1, and input 6 with output 2. When
// num_batches > 1 the gradient maker passes GI(1)/GI(2) (the gradients of
// forward inputs 1 and 2) as inputs 5 and 6, so they are updated in place.
// (Other schema settings of this chain are not visible in this listing.)
123 .AllowInplace({{5, 1}, {6, 2}});
128 using GradientMakerBase::GradientMakerBase;
129 vector<OperatorDef> GetGradientDefs()
override {
// Builds a single SpatialBNGradient op for a SpatialBN forward op.
// Outputs are always the gradients of forward inputs 0, 1 and 2.
// Mode is read from the is_test argument (default 0) and the num_batches
// argument (default 1). The inputs fed to the gradient op depend on the mode:
//  - first branch (presumably is_test; its `if` is not visible here):
//    forward must have 5 inputs / 1 output; forward inputs 3 and 4 are used
//    as the statistics (presumably the running mean/variance -- confirm).
//  - num_batches > 1: forward must have 7 inputs / 5 outputs; forward
//    outputs 3 and 4 are used as the saved statistics, plus the existing
//    GI(1)/GI(2), which the schema lets the gradient op update in place.
//  - otherwise: forward must have 5 inputs / 5 outputs; forward outputs 3
//    and 4 are used as the saved statistics.
132 ArgumentHelper::GetSingleArgument(def_, OpSchema::Arg_IsTest, 0);
133 int num_batches = ArgumentHelper::GetSingleArgument(def_,
"num_batches", 1);
134 vector<string> grad_outputs{GI(0), GI(1), GI(2)};
135 vector<string> grad_inputs;
141 CAFFE_ENFORCE_EQ(def_.input_size(), 5);
142 CAFFE_ENFORCE_EQ(def_.output_size(), 1);
143 grad_inputs = vector<string>{I(0), I(1), GO(0), I(3), I(4)};
144 }
else if (num_batches > 1) {
145 CAFFE_ENFORCE_EQ(def_.input_size(), 7);
146 CAFFE_ENFORCE_EQ(def_.output_size(), 5);
147 grad_inputs = vector<string>{I(0), I(1), GO(0), O(3), O(4), GI(1), GI(2)};
149 CAFFE_ENFORCE_EQ(def_.input_size(), 5);
150 CAFFE_ENFORCE_EQ(def_.output_size(), 5);
151 grad_inputs = vector<string>{I(0), I(1), GO(0), O(3), O(4)};
// Arguments of the op-def construction call; the call itself (presumably
// `return SingleGradientDef(`) is on a line not visible here.
154 "SpatialBNGradient",
"", grad_inputs, grad_outputs);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...