1 #ifndef CAFFE2_OPERATORS_SQUARE_ROOT_DIVIDE_OP_H_ 2 #define CAFFE2_OPERATORS_SQUARE_ROOT_DIVIDE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
class Context>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
19 bool RunOnDevice()
override {
24 template <
typename TData>
25 bool DoRunWithType() {
// DoRunWithType2: row-wise "divide by square root" over the DATA input.
// For each index i along dimension 0 of DATA (batchSize rows, each holding
// exampleSize = size_from_dim(1) elements), reads SCALE[i], enforces it is
// non-negative, and computes multiplier = 1 / sqrt(SCALE[i]) (or 1.0 when
// SCALE[i] == 0, avoiding a division by zero). The row's input slice
// (dataPtr + i * exampleSize) and output slice (yPtr + i * exampleSize) are
// then passed to math::Scale.
// NOTE(review): this listing is missing source lines (the declaration of Y,
// the remaining math::Scale arguments, and the end of the loop/function);
// presumably math::Scale writes multiplier * row into Y's row — confirm
// against the full source.
30 template <
typename TData,
typename TScale>
31 bool DoRunWithType2() {
32 auto& data = Input(DATA);
33 auto& scale = Input(SCALE);
// Dimension 0 is the row count; all remaining dimensions are flattened into
// exampleSize elements per row.
36 size_t batchSize = data.dim(0);
37 size_t exampleSize = data.size_from_dim(1);
// SCALE must supply exactly one entry per row of DATA.
38 CAFFE_ENFORCE(batchSize == scale.dim(0), batchSize,
" != ", scale.dim(0));
39 auto* scalePtr = scale.template data<TScale>();
40 auto* dataPtr = data.template data<TData>();
41 auto* yPtr = Y->template mutable_data<TData>();
// NOTE(review): `auto i = 0` deduces int, which is compared against the
// size_t batchSize (signed/unsigned comparison); consider `size_t i`.
42 for (
auto i = 0; i < batchSize; ++i) {
// NOTE(review): this local `scale` (a TScale value) shadows the `scale`
// input reference declared above; a distinct name would be clearer.
43 auto scale = scalePtr[i];
44 CAFFE_ENFORCE(scale >= 0, scale,
" < 0");
// A zero scale leaves the row unscaled (multiplier 1.0) instead of
// dividing by sqrt(0).
45 auto multiplier = scale == 0 ? 1.0 : 1 / std::sqrt(scale);
46 math::Scale<TData, Context>(
49 dataPtr + i * exampleSize,
50 yPtr + i * exampleSize,
56 INPUT_TAGS(DATA, SCALE);
61 #endif // CAFFE2_OPERATORS_SQUARE_ROOT_DIVIDE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...