Caffe2 - C++ API
A deep learning, cross platform ML framework
clip_tensor_op.h
1 #ifndef CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_
2 #define CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_
3 
4 #include <vector>
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/core/tensor.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <typename Context>
13 class ClipTensorByScalingOp final : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16 
17  ClipTensorByScalingOp(const OperatorDef& operator_def, Workspace* ws)
18  : Operator<Context>(operator_def, ws) {
19  threshold_ = OperatorBase::GetSingleArgument<float>("threshold", 0.0);
20  CAFFE_ENFORCE_GT(threshold_, 0, "Threshold must be greater than 0");
21  }
22 
23  bool RunOnDevice() override {
24  const auto& input_tensor = Input(0);
25  CAFFE_ENFORCE_GT(input_tensor.size(), 0);
26  const auto& val = Input(1);
27  CAFFE_ENFORCE_EQ(val.size(), 1);
28 
29  const auto* input_tensor_data = input_tensor.template data<float>();
30  const auto* val_data = val.template data<float>();
31 
32  auto* clipped = Output(0);
33  clipped->ResizeLike(input_tensor);
34  float* clipped_tensor_data = clipped->template mutable_data<float>();
35 
36  if (InputSize() > 2) {
37  const auto& additional_threshold = Input(2);
38  CAFFE_ENFORCE_EQ(additional_threshold.size(), 1);
39 
40  threshold_ *= *(additional_threshold.template data<float>());
41  }
42 
43  if (*val_data > threshold_) {
44  float ratio = threshold_ / *val_data;
45 
46  math::Scale<float, Context>(
47  clipped->size(),
48  ratio,
49  input_tensor_data,
50  clipped_tensor_data,
51  &context_);
52  } else {
53  if (input_tensor_data != clipped_tensor_data) {
54  clipped->CopyFrom(input_tensor, &context_);
55  }
56  }
57 
58  return true;
59  }
60 
61  private:
62  float threshold_;
63 };
64 
65 } // namespace caffe2
66 
67 #endif // CAFFE2_OPERATORS_CLIP_TENSOR_OP_H_
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Definition: tensor.h:593
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...