1 #ifndef CAFFE2_OPERATORS_IM2COL_OP_H_ 2 #define CAFFE2_OPERATORS_IM2COL_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
typename T,
class Context>
// Im2col operator. The class head and constructor signature are on lines not
// shown in this excerpt. It is templated on element type T and device Context.
// RunOnDevice rearranges each sample of a 4-D input X into column form via
// math::Im2col, honoring the "order" argument (NCHW or NHWC).
14 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor member-initializers. Each per-axis value (kernel_h_/kernel_w_,
// dilation_h_/dilation_w_, stride_h_/stride_w_) reads an argument whose name
// is on a line not shown here. When that argument is absent, it falls back to
// the shared "kernel" (default 0), "dilation" (default 1) or "stride"
// (default 1) argument. A single pad_ ("pad", default 0) is used for both the
// height and width output-size computations in RunOnDevice.
17 pad_(OperatorBase::GetSingleArgument<int>(
"pad", 0)),
18 kernel_h_(OperatorBase::GetSingleArgument<int>(
20 OperatorBase::GetSingleArgument<int>(
"kernel", 0))),
21 kernel_w_(OperatorBase::GetSingleArgument<int>(
23 OperatorBase::GetSingleArgument<int>(
"kernel", 0))),
24 dilation_h_(OperatorBase::GetSingleArgument<int>(
26 OperatorBase::GetSingleArgument<int>(
"dilation", 1))),
27 dilation_w_(OperatorBase::GetSingleArgument<int>(
29 OperatorBase::GetSingleArgument<int>(
"dilation", 1))),
30 stride_h_(OperatorBase::GetSingleArgument<int>(
32 OperatorBase::GetSingleArgument<int>(
"stride", 1))),
33 stride_w_(OperatorBase::GetSingleArgument<int>(
35 OperatorBase::GetSingleArgument<int>(
"stride", 1))),
// Storage order string, default "NCHW", converted to a StorageOrder value.
36 order_(StringToStorageOrder(
37 OperatorBase::GetSingleArgument<string>(
"order",
"NCHW"))) {
// Validate arguments. "kernel" defaults to 0, so a kernel size must be
// supplied explicitly or the first two enforces fail.
38 CAFFE_ENFORCE(kernel_h_ > 0);
39 CAFFE_ENFORCE(kernel_w_ > 0);
40 CAFFE_ENFORCE(dilation_h_ > 0);
41 CAFFE_ENFORCE(dilation_w_ > 0);
42 CAFFE_ENFORCE(stride_h_ > 0);
43 CAFFE_ENFORCE(stride_w_ > 0);
44 CAFFE_ENFORCE(pad_ >= 0);
// Computes Y from X. The input/output fetch, the dimension-extraction switch
// bodies and the Im2col argument lists are on lines not shown in this excerpt.
47 bool RunOnDevice()
override {
// Im2col is only defined here for 4-D inputs.
50 CAFFE_ENFORCE(4 == X.ndim());
// N, C, H and W are filled from X according to order_ (the assignments are
// not shown). Any order other than NCHW/NHWC throws.
52 int N = 0, C = 0, H = 0, W = 0;
54 case StorageOrder::NCHW:
60 case StorageOrder::NHWC:
67 CAFFE_THROW(
"Unknown storage order: ", order_);
// Effective (dilated) kernel extent: dilation * (kernel - 1) + 1.
70 const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
71 const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
// NOTE(review): these checks ignore padding. They reject inputs where only the
// padding makes the kernel fit (e.g. H = 2, kernel 3, pad 1 would yield
// out_h = 2 below). Confirm that requiring the unpadded input to cover the
// dilated kernel is intended.
72 CAFFE_ENFORCE(H >= dkernel_h);
73 CAFFE_ENFORCE(W >= dkernel_w);
// Standard convolution output size. The numerator is non-negative given the
// checks above and pad_ >= 0, so integer division floors.
74 const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
75 const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
// NCHW: Y is shaped {N, C * kernel_h * kernel_w, out_h, out_w}. The call that
// consumes this vector is on a line not shown (presumably Y->Resize; confirm).
78 case StorageOrder::NCHW: {
80 std::vector<TIndex>{N, C * kernel_h_ * kernel_w_, out_h, out_w});
// Per-sample element counts of X and Y.
// NOTE(review): an empty batch (N == 0) passes the 4-D check, and then these
// lines perform integer division by zero (undefined behavior). Consider
// returning early when N == 0.
82 const size_t dx = X.size() / N;
83 const size_t dy = Y->size() / N;
// One Im2col call per sample. xdata and ydata point at sample n's slice.
// NOTE(review): data<T>() and mutable_data<T>() are loop-invariant and could
// be hoisted out of the loop.
84 for (
int n = 0; n < N; ++n) {
85 const auto* xdata = X.template data<T>() + (n * dx);
86 auto* ydata = Y->template mutable_data<T>() + (n * dy);
87 math::Im2col<T, Context, StorageOrder::NCHW>(
// NHWC: Y is shaped {N, out_h, out_w, kernel_h * kernel_w * C}. The per-sample
// offsets and loop mirror the NCHW case, including the N == 0 division note.
106 case StorageOrder::NHWC: {
108 std::vector<TIndex>{N, out_h, out_w, kernel_h_ * kernel_w_ * C});
110 const size_t dx = X.size() / N;
111 const size_t dy = Y->size() / N;
112 for (
int n = 0; n < N; ++n) {
113 const auto* xdata = X.template data<T>() + (n * dx);
114 auto* ydata = Y->template mutable_data<T>() + (n * dy);
115 math::Im2col<T, Context, StorageOrder::NHWC>(
// Reached for any storage order other than NCHW / NHWC.
135 CAFFE_THROW(
"Unknown storage order: ", order_);
152 template <
typename T,
class Context>
// Col2im operator. The class head and constructor signature are on lines not
// shown in this excerpt. It is templated on element type T and device Context.
// RunOnDevice scatters column data X back into image layout Y via
// math::Col2im, honoring the "order" argument (NCHW or NHWC).
155 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor member-initializers, matching the Im2col operator. Each per-axis
// value reads an argument whose name is on a line not shown here. When that
// argument is absent, it falls back to "kernel" (default 0), "dilation"
// (default 1) or "stride" (default 1). A single pad_ ("pad", default 0) is
// used for both height and width.
158 pad_(OperatorBase::GetSingleArgument<int>(
"pad", 0)),
159 kernel_h_(OperatorBase::GetSingleArgument<int>(
161 OperatorBase::GetSingleArgument<int>(
"kernel", 0))),
162 kernel_w_(OperatorBase::GetSingleArgument<int>(
164 OperatorBase::GetSingleArgument<int>(
"kernel", 0))),
165 dilation_h_(OperatorBase::GetSingleArgument<int>(
167 OperatorBase::GetSingleArgument<int>(
"dilation", 1))),
168 dilation_w_(OperatorBase::GetSingleArgument<int>(
170 OperatorBase::GetSingleArgument<int>(
"dilation", 1))),
171 stride_h_(OperatorBase::GetSingleArgument<int>(
173 OperatorBase::GetSingleArgument<int>(
"stride", 1))),
174 stride_w_(OperatorBase::GetSingleArgument<int>(
176 OperatorBase::GetSingleArgument<int>(
"stride", 1))),
// Storage order string, default "NCHW", converted to a StorageOrder value.
177 order_(StringToStorageOrder(
178 OperatorBase::GetSingleArgument<string>(
"order",
"NCHW"))) {
// Validate arguments. "kernel" defaults to 0, so a kernel size must be
// supplied explicitly or the first two enforces fail.
179 CAFFE_ENFORCE(kernel_h_ > 0);
180 CAFFE_ENFORCE(kernel_w_ > 0);
181 CAFFE_ENFORCE(dilation_h_ > 0);
182 CAFFE_ENFORCE(dilation_w_ > 0);
183 CAFFE_ENFORCE(stride_h_ > 0);
184 CAFFE_ENFORCE(stride_w_ > 0);
185 CAFFE_ENFORCE(pad_ >= 0);
// Computes image Y from column data X. How X and Y are fetched, and how Y
// gets its shape, is on lines not shown in this excerpt (presumably Y is
// sized from another input; confirm).
188 bool RunOnDevice()
override {
// Y must already be 4-D at this point.
193 CAFFE_ENFORCE(4 == Y->ndim());
// N, C, H and W are filled according to order_ (the assignments are not
// shown; presumably from Y's dims, so confirm). Any other order throws.
195 int N = 0, C = 0, H = 0, W = 0;
197 case StorageOrder::NCHW:
203 case StorageOrder::NHWC:
210 CAFFE_THROW(
"Unknown storage order: ", order_);
// Effective (dilated) kernel extent: dilation * (kernel - 1) + 1.
213 const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
214 const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
// NOTE(review): as in the Im2col operator, these checks ignore padding and
// reject images that fit the kernel only once padded. Confirm this is intended.
215 CAFFE_ENFORCE(H >= dkernel_h);
216 CAFFE_ENFORCE(W >= dkernel_w);
// Number of kernel positions along each spatial axis.
217 const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
218 const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
// X must hold exactly one kernel_h * kernel_w * C column per output position,
// for every sample.
// NOTE(review): the right-hand side is an int product that can overflow for
// large tensors before the signed/unsigned comparison with X.size().
// Computing it in size_t would be safer.
219 CAFFE_ENFORCE(X.size() == N * kernel_h_ * kernel_w_ * C * out_h * out_w);
// Per-sample element counts, shared by both layouts below.
// NOTE(review): if N == 0 this is integer division by zero (undefined
// behavior). Consider returning early for an empty batch.
221 const size_t dx = X.size() / N;
222 const size_t dy = Y->size() / N;
// One Col2im call per sample, reading sample n's columns (xdata) and writing
// sample n's image (ydata). The argument lists are on lines not shown.
226 case StorageOrder::NCHW: {
227 for (
int n = 0; n < N; ++n) {
228 const auto* xdata = X.template data<T>() + (n * dx);
229 auto* ydata = Y->template mutable_data<T>() + (n * dy);
230 math::Col2im<T, Context, StorageOrder::NCHW>(
249 case StorageOrder::NHWC: {
250 for (
int n = 0; n < N; ++n) {
251 const auto* xdata = X.template data<T>() + (n * dx);
252 auto* ydata = Y->template mutable_data<T>() + (n * dy);
253 math::Col2im<T, Context, StorageOrder::NHWC>(
// Reached for any storage order other than NCHW / NHWC.
273 CAFFE_THROW(
"Unknown storage order: ", order_);
292 #endif // CAFFE2_OPERATORS_IM2COL_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...