// Include guard and dependencies for the normalize operators. The trailing
// `template <` opens the <T, Context> parameter list of the first operator
// class in this header.
// NOTE(review): this listing is a scraped source view; the class name, base
// class and constructor lines are not present here — confirm against the
// real header before editing.
1 #ifndef CAFFE2_OPERATORS_NORMALIZE_OP_H_ 2 #define CAFFE2_OPERATORS_NORMALIZE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
typename T,
class Context>
// Macro defined elsewhere; presumably brings the Context-related operator
// helpers into scope for this class — TODO confirm in caffe2/core/operator.h.
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Forward pass. Reads input 0 and the integer "axis" argument (default -1).
// It views the input as slices along that axis and hands the raw buffers plus
// the layout numbers (m, n, sf) to DoNormalize. Judging by the operator name,
// each slice along `axis` is normalized (presumably to unit L2 norm) — TODO
// confirm against the DoNormalize definition, which is not shown here.
17 bool RunOnDevice()
override {
18 const auto& x = Input(0);
20 const auto* xData = x.template data<T>();
// NOTE(review): `y` is declared on a line missing from this listing;
// presumably it is Output(0) resized like `x` — confirm.
22 auto* yData = y->template mutable_data<T>();
// Resolve the "axis" argument (default -1) via canonical_axis_index;
// presumably this maps negative values to a non-negative dimension index.
24 const auto canonical_axis = x.canonical_axis_index(
25 OperatorBase::GetSingleArgument<int>(
"axis", -1));
// m: extent of the input along the chosen axis.
26 const int m = x.dim32(canonical_axis);
// n: total element count divided by m.
// NOTE(review): this divides by zero if the axis has extent 0, and the
// result of size() is stored in an int (narrowing if size() returns a wider
// type) — confirm callers cannot hit either case.
27 const int n = x.size() / m;
// sf: size_from_dim(axis + 1), presumably the product of the dimensions
// after the axis (the flat-buffer stride between consecutive elements of one
// slice) — confirm against the Tensor API.
28 const int sf = x.size_from_dim(canonical_axis + 1);
29 DoNormalize(xData, yData, m, n, sf);
// Per-operator worker declared here and defined elsewhere (presumably one
// definition per Context). The return type is on a line missing from this
// listing.
//   xData : input buffer (read-only)
//   yData : output buffer
//   m     : extent of the normalized axis
//   n     : total element count divided by m
//   sf    : size_from_dim(axis + 1), presumably the product of trailing dims
35 DoNormalize(
const T* xData, T* yData,
const int m,
const int n,
const int sf);
// Template head of the gradient operator, parameterized on element type T
// and device Context.
// NOTE(review): the class name, base class and constructor lines are missing
// from this scraped listing — confirm against the real header.
38 template <
typename T,
class Context>
// Macro defined elsewhere; presumably brings the Context-related operator
// helpers into scope for this class — TODO confirm in caffe2/core/operator.h.
41 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Backward pass. Reads the original input `x` (input 0) and the incoming
// gradient (GRAD_OUT), and sizes the GRAD_IN output like the incoming
// gradient. It computes the same layout numbers (m, n, sf) from the "axis"
// argument (default -1) and delegates to the three-buffer DoNormalize
// overload, which is not shown here.
45 bool RunOnDevice()
override {
// NOTE(review): uses the literal index 0 while GRAD_OUT is accessed by tag;
// presumably Input(0) is the INPUT tag — consider Input(INPUT) for clarity.
46 const auto& x = Input(0);
47 const auto& gOut = Input(GRAD_OUT);
// GRAD_IN is an output tag declared on a line missing from this listing.
48 auto* gIn = Output(GRAD_IN);
49 gIn->ResizeLike(gOut);
51 const auto* xData = x.template data<T>();
52 const auto* gOutData = gOut.template data<T>();
53 auto* gInData = gIn->template mutable_data<T>();
// Resolve the "axis" argument (default -1) via canonical_axis_index;
// presumably this maps negative values to a non-negative dimension index.
55 const auto canonical_axis = x.canonical_axis_index(
56 OperatorBase::GetSingleArgument<int>(
"axis", -1));
// m: extent of `x` along the chosen axis.
57 const int m = x.dim32(canonical_axis);
// n: total element count of `x` divided by m.
// NOTE(review): this divides by zero if the axis has extent 0, and size() is
// stored in an int (narrowing if it returns a wider type) — confirm.
58 const int n = x.size() / m;
// sf: size_from_dim(axis + 1), presumably the product of trailing dims.
59 const int sf = x.size_from_dim(canonical_axis + 1);
60 DoNormalize(xData, gOutData, gInData, m, n, sf);
// Names this operator's inputs in order: INPUT, then GRAD_OUT (macro defined
// elsewhere; presumably the tags map to input indices 0 and 1 — confirm).
73 INPUT_TAGS(INPUT, GRAD_OUT);
79 #endif // CAFFE2_OPERATORS_NORMALIZE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...