Caffe2 - C++ API
A deep learning, cross-platform ML framework
pool_gradient_op.cc
1 #include "caffe2/operators/pool_op.h"
2 
3 namespace caffe2 {
4 
5 using std::max;
6 using std::min;
7 
8 namespace {
// These two classes are just used as template arguments passed to the
// PoolGradientOp template to instantiate the different algorithms.
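// For each pooling window, AveragePool backpropagates scale * dY to every
// input element the window covers, while MaxPool backpropagates dY only to
// the element(s) whose value equals the pooled maximum (on ties, every
// matching element receives the full gradient).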
template <typename T>
class AveragePool {
 public:
  static void process_grad(
      const T& /*x_data*/,
      const T& /*y_data*/,
      const T& dy_data,
      const T& scale,
      T& dx_data) {
    dx_data += (scale * dy_data);
  }

  static void process_grad(
      const int y_col,
      const int x_col,
      const float scale,
      ConstEigenArrayMap<float>& /*x_data*/,
      ConstEigenArrayMap<float>& /*y_data*/,
      ConstEigenArrayMap<float>& dy_data,
      EigenArrayMap<float>& dx_data) {
    dx_data.col(x_col) += scale * dy_data.col(y_col);
  }
};

template <typename T>
class MaxPool {
 public:
  static void process_grad(
      const T& x_data,
      const T& y_data,
      const T& dy_data,
      const T& /*scale*/,
      T& dx_data) {
    if (x_data == y_data) {
      dx_data += dy_data;
    }
  }

  static void process_grad(
      const int y_col,
      const int x_col,
      const float /*scale*/,
      ConstEigenArrayMap<float>& x_data,
      ConstEigenArrayMap<float>& y_data,
      ConstEigenArrayMap<float>& dy_data,
      EigenArrayMap<float>& dx_data) {
    dx_data.col(x_col) += dy_data.col(y_col) *
        (x_data.col(x_col).cwiseEqual(y_data.col(y_col)).template cast<float>());
  }
};
} // namespace

template <typename T, class Context, typename PoolType>
bool PoolGradientOp<T, Context, PoolType>::RunOnDeviceWithOrderNCHW() {
  auto& X = Input(0);
  auto& Y = Input(1);
  auto& dY = Input(2);
  auto* dX = Output(0);
  // TODO(Yangqing): Add shape checks.
  dX->ResizeLike(X);
  math::Set<float, CPUContext>(
      X.size(), 0, dX->template mutable_data<float>(), &context_);
  const float* Xdata = X.template data<float>();
  const float* Ydata = Y.template data<float>();
  const float* dYdata = dY.template data<float>();
  float* dXdata = dX->template mutable_data<float>();
  int channels = X.dim32(1);
  CAFFE_ENFORCE_EQ(channels, dY.dim32(1));
  int height = X.dim32(2);
  int width = kernel_.size() > 1 ? X.dim32(3) : 1;
  int depth = kernel_.size() > 2 ? X.dim32(4) : 1;
  vector<int> dims(X.dims().begin() + 2, X.dims().end());
  ConvPoolOpBase<CPUContext>::ComputePads(dims);
  int pooled_height = dY.dim32(2);
  int pooled_width = kernel_.size() > 1 ? dY.dim32(3) : 1;
  int pooled_depth = kernel_.size() > 2 ? dY.dim32(4) : 1;
  // The main loop
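  // Gradients are scattered back one (image, channel) slice at a time: after
  // each channel, Xdata/dXdata advance by one spatial slice (H, H*W, or
  // H*W*D) and Ydata/dYdata by one pooled slice, so the indices below are
  // relative to the current slice. The scale passed to process_grad is the
  // reciprocal of the clipped window size; only AveragePool uses it.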
  switch (kernel_.size()) {
    case 1:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int c = 0; c < channels; ++c) {
          for (int ph = 0; ph < pooled_height; ++ph) {
            int hstart = ph * stride_h() - pad_t();
            int hend = min(hstart + kernel_h(), height);
            hstart = max(hstart, 0);
            float scale = 1. / (hend - hstart);
            for (int h = hstart; h < hend; ++h) {
              PoolType::process_grad(
                  Xdata[h], Ydata[ph], dYdata[ph], scale, dXdata[h]);
            }
          }
          // offset
          Xdata += height;
          dXdata += height;
          Ydata += pooled_height;
          dYdata += pooled_height;
        }
      }
      break;
    case 2:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int c = 0; c < channels; ++c) {
          for (int ph = 0; ph < pooled_height; ++ph) {
            int hstart = ph * stride_h() - pad_t();
            int hend = min(hstart + kernel_h(), height);
            hstart = max(hstart, 0);
            for (int pw = 0; pw < pooled_width; ++pw) {
              int wstart = pw * stride_w() - pad_l();
              int wend = min(wstart + kernel_w(), width);
              wstart = max(wstart, 0);
              float scale = 1. / (hend - hstart) / (wend - wstart);
              const int pooled_index = ph * pooled_width + pw;
              for (int h = hstart; h < hend; ++h) {
                for (int w = wstart; w < wend; ++w) {
                  const int index = h * width + w;
                  PoolType::process_grad(
                      Xdata[index],
                      Ydata[pooled_index],
                      dYdata[pooled_index],
                      scale,
                      dXdata[index]);
                }
              }
            }
          }
          // offset
          Xdata += height * width;
          dXdata += height * width;
          Ydata += pooled_height * pooled_width;
          dYdata += pooled_height * pooled_width;
        }
      }
      break;
    case 3:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int c = 0; c < channels; ++c) {
          for (int ph = 0; ph < pooled_height; ++ph) {
            int hstart = ph * stride_h() - pad_t();
            int hend = min(hstart + kernel_h(), height);
            hstart = max(hstart, 0);
            for (int pw = 0; pw < pooled_width; ++pw) {
              int wstart = pw * stride_w() - pad_l();
              int wend = min(wstart + kernel_w(), width);
              wstart = max(wstart, 0);
              for (int pd = 0; pd < pooled_depth; ++pd) {
                int dstart = pd * stride_[2] - pads_[2];
                int dend = min(dstart + kernel_[2], depth);
                dstart = max(dstart, 0);
                float scale =
                    1. / (hend - hstart) / (wend - wstart) / (dend - dstart);
                const int pooled_index =
                    ph * pooled_width * pooled_depth + pw * pooled_depth + pd;
                for (int h = hstart; h < hend; ++h) {
                  for (int w = wstart; w < wend; ++w) {
                    for (int d = dstart; d < dend; ++d) {
                      const int index = h * width * depth + w * depth + d;
                      PoolType::process_grad(
                          Xdata[index],
                          Ydata[pooled_index],
                          dYdata[pooled_index],
                          scale,
                          dXdata[index]);
                    }
                  }
                }
              }
            }
          }
          // offset
          Xdata += height * width * depth;
          dXdata += height * width * depth;
          Ydata += pooled_height * pooled_width * pooled_depth;
          dYdata += pooled_height * pooled_width * pooled_depth;
        }
      }
      break;
    default:
      CAFFE_THROW("Unsupported pooling size");
      return false;
  }
  return true;
}

template <typename T, class Context, typename PoolType>
bool PoolGradientOp<T, Context, PoolType>::RunOnDeviceWithOrderNHWC() {
  auto& X = Input(0);
  auto& Y = Input(1);
  auto& dY = Input(2);
  DCHECK_EQ(dY.ndim(), kernel_.size() + 2);
  auto* dX = Output(0);
  dX->ResizeLike(X);

  int channels = X.dim32(X.ndim() - 1);
  CAFFE_ENFORCE_EQ(channels, dY.dim32(dY.ndim() - 1));
  ConstEigenArrayMap<T> Ymat(
      Y.template data<float>(), channels, Y.size() / channels);
  ConstEigenArrayMap<float> dYmat(
      dY.template data<float>(), channels, Y.size() / channels);
  ConstEigenArrayMap<float> Xmat(
      X.template data<float>(), channels, X.size() / channels);
  EigenArrayMap<float> dXmat(
      dX->template mutable_data<float>(), channels, X.size() / channels);
  dXmat.setZero();
  int height = X.dim32(1);
  int width = kernel_.size() > 1 ? X.dim32(2) : 1;
  int depth = kernel_.size() > 2 ? X.dim32(3) : 1;
  vector<int> dims(X.dims().begin() + 1, X.dims().end() - 1);
  ConvPoolOpBase<CPUContext>::ComputePads(dims);
  int pooled_height = dY.dim32(1);
  int pooled_width = kernel_.size() > 1 ? dY.dim32(2) : 1;
  int pooled_depth = kernel_.size() > 2 ? dY.dim32(3) : 1;

  // The main loop
  // Do not do openmp here: the following for loops are looping over the pooled
  // output, so if one parallelizes the outer loops, race conditions could
  // happen in the inner loops.
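  // In the NHWC layout the Eigen maps above (Xmat, Ymat, dYmat, dXmat) view
  // each tensor as a (channels x spatial-locations) array with one column per
  // spatial position of the whole batch, so process_grad updates all channels
  // of one position at a time via column operations. pool_index and
  // input_index are column indices that already include the image index n.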
  switch (kernel_.size()) {
    case 1:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int ph = 0; ph < pooled_height; ++ph) {
          int hstart = ph * stride_h() - pad_t();
          int hend = min(hstart + kernel_h(), height);
          hstart = max(hstart, 0);
          const int pool_index = n * pooled_height + ph;
          const float scale = 1. / (hend - hstart);
          for (int h = hstart; h < hend; ++h) {
            const int input_index = n * height + h;
            PoolType::process_grad(
                pool_index, input_index, scale, Xmat, Ymat, dYmat, dXmat);
          }
        }
      }
      break;
    case 2:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int ph = 0; ph < pooled_height; ++ph) {
          int hstart = ph * stride_h() - pad_t();
          int hend = min(hstart + kernel_h(), height);
          hstart = max(hstart, 0);
          for (int pw = 0; pw < pooled_width; ++pw) {
            int wstart = pw * stride_w() - pad_l();
            int wend = min(wstart + kernel_w(), width);
            wstart = max(wstart, 0);
            const int pool_index = (n * pooled_height + ph) * pooled_width + pw;
            const float scale = 1. / (hend - hstart) / (wend - wstart);
            for (int h = hstart; h < hend; ++h) {
              for (int w = wstart; w < wend; ++w) {
                const int input_index = (n * height + h) * width + w;
                PoolType::process_grad(
                    pool_index, input_index, scale, Xmat, Ymat, dYmat, dXmat);
              }
            }
          }
        }
      }
      break;
    case 3:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int ph = 0; ph < pooled_height; ++ph) {
          int hstart = ph * stride_h() - pad_t();
          int hend = min(hstart + kernel_h(), height);
          hstart = max(hstart, 0);
          for (int pw = 0; pw < pooled_width; ++pw) {
            int wstart = pw * stride_w() - pad_l();
            int wend = min(wstart + kernel_w(), width);
            wstart = max(wstart, 0);
            for (int pd = 0; pd < pooled_depth; ++pd) {
              int dstart = pd * stride_[2] - pads_[2];
              int dend = min(dstart + kernel_[2], depth);
              dstart = max(dstart, 0);
              const int pool_index =
                  ((n * pooled_height + ph) * pooled_width + pw) *
                      pooled_depth +
                  pd;
              const float scale =
                  1. / (hend - hstart) / (wend - wstart) / (dend - dstart);
              for (int h = hstart; h < hend; ++h) {
                for (int w = wstart; w < wend; ++w) {
                  for (int d = dstart; d < dend; ++d) {
                    const int input_index =
                        ((n * height + h) * width + w) * depth + d;
                    PoolType::process_grad(
                        pool_index,
                        input_index,
                        scale,
                        Xmat,
                        Ymat,
                        dYmat,
                        dXmat);
                  }
                }
              }
            }
          }
        }
      }
      break;
    default:
      CAFFE_THROW("Unsupported pooling size");
      return false;
  }
  return true;
}

REGISTER_CPU_OPERATOR(
    AveragePoolGradient,
    PoolGradientOp<float, CPUContext, AveragePool<float>>);
OPERATOR_SCHEMA(AveragePoolGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    AveragePool1DGradient,
    PoolGradientOp<float, CPUContext, AveragePool<float>>);
OPERATOR_SCHEMA(AveragePool1DGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    AveragePool2DGradient,
    PoolGradientOp<float, CPUContext, AveragePool<float>>);
OPERATOR_SCHEMA(AveragePool2DGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    AveragePool3DGradient,
    PoolGradientOp<float, CPUContext, AveragePool<float>>);
OPERATOR_SCHEMA(AveragePool3DGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    MaxPoolGradient,
    PoolGradientOp<float, CPUContext, MaxPool<float>>);
OPERATOR_SCHEMA(MaxPoolGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    MaxPool1DGradient,
    PoolGradientOp<float, CPUContext, MaxPool<float>>);
OPERATOR_SCHEMA(MaxPool1DGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    MaxPool2DGradient,
    PoolGradientOp<float, CPUContext, MaxPool<float>>);
OPERATOR_SCHEMA(MaxPool2DGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    MaxPool3DGradient,
    PoolGradientOp<float, CPUContext, MaxPool<float>>);
OPERATOR_SCHEMA(MaxPool3DGradient).NumInputs(3).NumOutputs(1);

class GetPoolGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        def_.type() + "Gradient",
        "",
        vector<string>{I(0), O(0), GO(0)},
        vector<string>{GI(0)});
  }
};
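// For each registered pooling operator <Op>, this maker emits a single
// <Op>Gradient op whose inputs are the forward input X (I(0)), the forward
// output Y (O(0)), and the gradient of the output (GO(0)), and whose only
// output is the gradient of the input (GI(0)), matching the
// NumInputs(3)/NumOutputs(1) schemas declared above.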
REGISTER_GRADIENT(AveragePool, GetPoolGradient);
REGISTER_GRADIENT(AveragePool1D, GetPoolGradient);
REGISTER_GRADIENT(AveragePool2D, GetPoolGradient);
REGISTER_GRADIENT(AveragePool3D, GetPoolGradient);
REGISTER_GRADIENT(MaxPool, GetPoolGradient);
REGISTER_GRADIENT(MaxPool1D, GetPoolGradient);
REGISTER_GRADIENT(MaxPool2D, GetPoolGradient);
REGISTER_GRADIENT(MaxPool3D, GetPoolGradient);
} // namespace caffe2
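The max-pool branch above routes each output gradient back only to the window
positions whose input value equals the pooled maximum. A minimal standalone
sketch of that rule for a single 1-D channel, with made-up data, kernel 2,
stride 2, and no padding; it is plain C++ and does not use any Caffe2 types:

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical 1-D example: X is the input, Y its max-pooled output, and
  // dY the gradient flowing back from the next layer.
  const std::vector<float> X = {1.f, 3.f, 2.f, 2.f};
  const std::vector<float> Y = {3.f, 2.f};
  const std::vector<float> dY = {0.5f, 1.0f};
  std::vector<float> dX(X.size(), 0.f);

  const int kernel = 2, stride = 2, height = 4, pooled_height = 2;
  for (int ph = 0; ph < pooled_height; ++ph) {
    int hstart = ph * stride;
    int hend = std::min(hstart + kernel, height);
    for (int h = hstart; h < hend; ++h) {
      // Same rule as MaxPool<T>::process_grad: only positions matching the
      // pooled maximum receive the gradient; ties each get the full value.
      if (X[h] == Y[ph]) {
        dX[h] += dY[ph];
      }
    }
  }
  for (float v : dX) {
    std::cout << v << " "; // prints: 0 0.5 1 1
  }
  std::cout << std::endl;
  return 0;
}

Note how the second window {2, 2} ties: both inputs equal the pooled maximum,
so both receive the full gradient of 1.0, exactly as the operator above would
compute.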