1 #ifndef CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_ 2 #define CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_ 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/core/tensor.h" 8 #include "caffe2/utils/math.h" 12 template <
typename Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
19 threshold_ = OperatorBase::GetSingleArgument<float>(
"threshold", 0.0);
20 CAFFE_ENFORCE_GT(threshold_, 0,
"Threshold must be greater than 0");
23 bool RunOnDevice()
override {
24 const auto& input_tensor = Input(0);
25 CAFFE_ENFORCE_GT(input_tensor.size(), 0);
26 const auto& val = Input(1);
27 CAFFE_ENFORCE_EQ(val.size(), 1);
29 const auto* input_tensor_data = input_tensor.template data<float>();
30 const auto* val_data = val.template data<float>();
32 auto* clipped = Output(0);
33 clipped->ResizeLike(input_tensor);
34 float* clipped_tensor_data = clipped->template mutable_data<float>();
36 if (InputSize() > 2) {
37 const auto& additional_threshold = Input(2);
38 CAFFE_ENFORCE_EQ(additional_threshold.
size(), 1);
40 threshold_ *= *(additional_threshold.template data<float>());
43 if (*val_data > threshold_) {
44 float ratio = threshold_ / *val_data;
46 math::Scale<float, Context>(
53 if (input_tensor_data != clipped_tensor_data) {
54 clipped->CopyFrom(input_tensor, &context_);
67 #endif // CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_
TIndex size() const
Returns the size (i.e.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...