#include "caffe2/operators/layer_norm_op.h"

namespace caffe2 {

// Eigen views over raw tensor buffers; row-major storage matches Caffe2's
// tensor layout, so mapping is copy-free.
template <typename T>
using EigenMatrixMapRowMajor = Eigen::Map<
    Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

template <typename T>
using ConstEigenMatrixMapRowMajor = Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
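
// A minimal standalone sketch (illustrative only; the buffer and names are
// ours) of how these aliases view a flat float buffer as a 2-D row-major
// matrix without copying:
//
//   std::vector<float> buf = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
//   ConstEigenMatrixMapRowMajor<float> m(buf.data(), 2, 3); // 2 rows, 3 cols
//   float mu0 = m.row(0).mean(); // (1 + 2 + 3) / 3 = 2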

template <>
template <>
bool LayerNormOp<CPUContext>::DoRunWithType<float>() {
  const auto& input = Input(0);
  auto* output = Output(0);
  auto* mean = Output(1);
  auto* stdev = Output(2);

  CAFFE_ENFORCE_GE(
      input.dims().size(), 2, "LayerNorm requires input dim >= 2");

  // Fold the dimensions before `axis_` into rows and the rest into columns,
  // viewing the input as a [left x right] matrix; each row is one feature
  // vector to normalize.
  const auto canonical_axis = input.canonical_axis_index(axis_);
  const int left = input.size_to_dim(canonical_axis);
  const int right = input.size_from_dim(canonical_axis);

  output->ResizeLike(input);
  // Mean and stdev keep the leading dimensions, plus a trailing 1.
  std::vector<TIndex> stats_dims(
      input.dims().begin(), input.dims().begin() + canonical_axis);
  stats_dims.push_back(1);
  mean->Resize(stats_dims);
  stdev->Resize(stats_dims);

  auto input_map = ConstEigenMatrixMapRowMajor<float>(
      input.template data<float>(), left, right);
  auto mean_map = EigenMatrixMapRowMajor<float>(
      mean->template mutable_data<float>(), left, 1);
  auto stdev_map = EigenMatrixMapRowMajor<float>(
      stdev->template mutable_data<float>(), left, 1);
  auto output_map = EigenMatrixMapRowMajor<float>(
      output->template mutable_data<float>(), left, right);

  // Elementwise helpers for the expressions below.
  auto sqr = [](float f) { return f * f; };
  auto add_ep = [this](float f) { return f + epsilon_; };
  auto fsqrt = [](float f) { return std::sqrt(f); };

  // Row-wise statistics. The variance is computed as E[x^2] - E[x]^2, and
  // epsilon_ is added to it before the square root to avoid division by zero.
  mean_map = input_map.rowwise().mean();
  stdev_map =
      (input_map.unaryExpr(sqr).rowwise().mean() - mean_map.unaryExpr(sqr))
          .unaryExpr(add_ep)
          .unaryExpr(fsqrt);
  // Normalize: (x - mean) / stdev, broadcasting the per-row statistics.
  output_map = (input_map - mean_map.replicate(1, right))
                   .cwiseQuotient(stdev_map.replicate(1, right));

  return true;
}
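
// Worked example (ours, not part of the op): for a row x = [1, 2, 3],
// mean = 2 and E[x^2] - mean^2 = 14/3 - 4 = 2/3, so with epsilon = 0.001
// stdev = sqrt(2/3 + 0.001) ~= 0.8171 and the normalized row is roughly
// [-1.224, 0, 1.224].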

REGISTER_CPU_OPERATOR(LayerNorm, LayerNormOp<CPUContext>);

template <>
template <>
bool LayerNormGradientOp<CPUContext>::DoRunWithType<float>() {
  // Inputs as wired by GetLayerNormGradient below: the upstream gradient,
  // the three forward outputs, and the original forward input.
  const auto& dout = Input(0);
  const auto& norm_outputs = Input(1);
  const auto& means = Input(2);
  const auto& stdev = Input(3);
  const auto& norm_inputs = Input(4);
  auto* ginput = Output(0);

  const auto canonical_axis = norm_inputs.canonical_axis_index(axis_);
  const int left = norm_inputs.size_to_dim(canonical_axis);
  const int right = norm_inputs.size_from_dim(canonical_axis);

  ginput->ResizeLike(norm_inputs);

  auto dout_map = ConstEigenMatrixMapRowMajor<float>(
      dout.template data<float>(), left, right);
  auto means_map = ConstEigenMatrixMapRowMajor<float>(
      means.template data<float>(), left, 1);
  auto stdev_map = ConstEigenMatrixMapRowMajor<float>(
      stdev.template data<float>(), left, 1);
  auto norm_inputs_map = ConstEigenMatrixMapRowMajor<float>(
      norm_inputs.template data<float>(), left, right);
  auto ginput_map = EigenMatrixMapRowMajor<float>(
      ginput->template mutable_data<float>(), left, right);

  // Elementwise helpers.
  auto sqr = [](float f) { return f * f; };
  auto recip = [](float f) { return 1.0f / f; };
  auto neg_recip = [](float f) { return -1.0f / f; };

  // Forward computed sigma = sqrt(E[x^2] - mu^2 + eps); split dL/dx into a
  // direct term plus paths through mu and sigma.
  // dL/dsigma = -1/sigma^2 * sum_j (x_j - mu) * dout_j
  auto dstdev_end_0 = stdev_map.unaryExpr(sqr).unaryExpr(neg_recip);
  auto dstdev_end_1 = (norm_inputs_map - means_map.replicate(1, right))
                          .cwiseProduct(dout_map)
                          .rowwise()
                          .sum();
  auto dstdev_end = dstdev_end_0.cwiseProduct(dstdev_end_1);

  // Direct part of dL/dmu: sum_j (-1/sigma) * dout_j
  auto dmean_end = stdev_map.unaryExpr(neg_recip)
                       .replicate(1, right)
                       .cwiseProduct(dout_map)
                       .rowwise()
                       .sum();
  // Direct part of dL/dx: dout / sigma
  auto dx_end =
      stdev_map.unaryExpr(recip).replicate(1, right).cwiseProduct(dout_map);

  // Path through sigma: dsigma/dmu = -mu/sigma, dsigma/dx_j = x_j/(N*sigma)
  auto dmean_stdev = stdev_map.unaryExpr(neg_recip)
                         .cwiseProduct(means_map)
                         .cwiseProduct(dstdev_end);
  auto dx_stdev = (1.0f / right) *
      norm_inputs_map.cwiseQuotient(stdev_map.replicate(1, right))
          .cwiseProduct(dstdev_end.replicate(1, right));

  // Path through mu: dmu/dx_j = 1/N
  auto dmean = dmean_end + dmean_stdev;
  auto dx_mean = (1.0f / right) * dmean.replicate(1, right);

  ginput_map = dx_end + dx_stdev + dx_mean;

  return true;
}
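
// A standalone finite-difference sanity check (illustrative only; plain
// Eigen, `norm` and the test values are ours) for one row of this gradient:
//
//   Eigen::RowVector3f x(1.f, 2.f, 4.f);
//   Eigen::RowVector3f d(0.5f, -1.f, 0.25f); // upstream gradient dL/dy
//   auto norm = [](Eigen::RowVector3f v) -> Eigen::RowVector3f {
//     float mu = v.mean();
//     float sigma = std::sqrt((v.array() - mu).square().mean() + 0.001f);
//     return ((v.array() - mu) / sigma).matrix();
//   };
//   // With L(v) = d.dot(norm(v)), the analytic dL/dx above should match
//   // (L(x + h * e_j) - L(x - h * e_j)) / (2 * h) for small h.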

OPERATOR_SCHEMA(LayerNormGradient).NumInputs(5).NumOutputs(1);

REGISTER_CPU_OPERATOR(LayerNormGradient, LayerNormGradientOp<CPUContext>);

namespace {

class GetLayerNormGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    // Pass the upstream gradient, the three forward outputs, and the
    // forward input to LayerNormGradient; it produces the input gradient.
    return SingleGradientDef(
        "LayerNormGradient",
        "",
        vector<string>{GO(0), O(0), O(1), O(2), I(0)},
        vector<string>{GI(0)});
  }
};

} // namespace

REGISTER_GRADIENT(LayerNorm, GetLayerNormGradient);
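
// Concretely (blob names are ours): a forward op
//   LayerNorm(["X"]) -> ["Y", "mean", "stdev"]
// gets the generated gradient op
//   LayerNormGradient(["Y_grad", "Y", "mean", "stdev", "X"]) -> ["X_grad"],
// matching the five inputs read by DoRunWithType above.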

OPERATOR_SCHEMA(LayerNorm)
    .NumInputs(1)
    .NumOutputs(3)
    .TensorInferenceFunction([](const OperatorDef& def,
                                const vector<TensorShape>& in) {
      vector<TensorShape> out(3);
      auto input_dims_long = GetDimsVector(in[0]);
      std::vector<int> input_dims(
          input_dims_long.begin(), input_dims_long.end());
      // Output 0 has the same shape as the input.
      out[0] = CreateTensorShape(input_dims, TensorProto::FLOAT);

      ArgumentHelper helper(def);

      // Mean and stdev keep the dimensions before `axis`, plus a trailing 1.
      auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
      const auto canonical_axis =
          canonical_axis_index_(axis, in[0].dims().size());
      std::vector<int> stat_dims(
          input_dims.begin(), input_dims.begin() + canonical_axis);
      stat_dims.push_back(1);
      out[1] = CreateTensorShape(stat_dims, TensorProto::FLOAT);
      out[2] = CreateTensorShape(stat_dims, TensorProto::FLOAT);
      return out;
    })
    .SetDoc(R"DOC(
Computes layer normalization as described in https://arxiv.org/pdf/1607.06450.pdf.
Given an input tensor x of shape [a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}],
this op treats dimensions a_k through a_{n-1} as feature vectors. For each
feature vector, the op computes the mean and standard deviation. Then,
it returns the normalized values (with respect to the feature vector).

Note that this op does not contain the scale and bias terms described in the
paper. Simply follow this op with an FC op to add those. Concretely, this op
implements:

h = \frac{1}{\sigma}(a - \mu)
where \mu = \frac{1}{H}\sum_{i=1}^{H} a_i
and \sigma = \sqrt{\frac{1}{H}\sum_{i=1}^{H}(a_i - \mu)^2 + \epsilon}
where H is the number of hidden units (i.e. the product of dimensions from
'axis' to the end).
)DOC")
    .Arg(
        "axis",
        "(int) default to 1; Describes axis of the inputs. Defaults to one "
        "because the 0th axis most likely describes the batch size")
    .Arg(
        "epsilon",
        "(float) default to 0.001. Small value added to the variance before "
        "the square root is taken, preventing division by zero.")
    .Input(
        0,
        "input",
        "Input tensor which layer normalization will be applied to")
    .Output(0, "output", "Normalized values")
    .Output(1, "mean", "Mean values for each feature vector")
    .Output(2, "stddev", "Standard deviations for each feature vector");

} // namespace caffe2
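
// A hedged usage sketch (ours; relies only on the protobuf OperatorDef API
// plus caffe2::CreateOperatorDef/MakeArgument from caffe2/utils/proto_utils.h).
// For an input of shape [N, C, H, W] and the default axis = 1, the schema
// above infers output: [N, C, H, W], mean: [N, 1], stddev: [N, 1].
//
//   caffe2::OperatorDef def = caffe2::CreateOperatorDef(
//       "LayerNorm",
//       "",
//       std::vector<std::string>{"X"},
//       std::vector<std::string>{"Y", "mean", "stdev"});
//   def.add_arg()->CopyFrom(caffe2::MakeArgument("axis", 1));
//   // Creating this op in a Workspace holding "X" and calling Run() on it
//   // then yields the three outputs.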