// Header for the piecewise-linear transform operator: include guard, the two
// Caffe2 core headers it depends on, and the template head (element type T,
// used for the argument vectors and tensor data, plus a Context parameter).
// NOTE(review): this listing is a lossy extraction of the header. The leading
// integers on each line appear to be original line numbers, and the gaps
// between them (3, 6-8, 10-11, 13-15, 17-18, 23-24) are lines missing here.
// The missing lines include the declaration this template head introduces
// and the constructor signature. The text is not compilable as-is; restore
// it from upstream before building.
1 #ifndef CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_ 2 #define CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 9 template <
typename T,
class Context>
12 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor body (its signature is among the missing lines).
// "binary" (default false) makes RunOnDevice call TransformBinary instead of
// TransformGeneral.
16 binary_ = OperatorBase::GetSingleArgument<bool>(
"binary",
false);
// Optional transform parameters given as repeated operator arguments. If all
// three are empty, CheckTransParamFromArg returns false, and GetTransParamData
// then reads the parameters from the BOUNDS/SLOPES/INTERCEPTS inputs instead.
19 bounds_from_arg_ = OperatorBase::GetRepeatedArgument<T>(
"bounds");
20 slopes_from_arg_ = OperatorBase::GetRepeatedArgument<T>(
"slopes");
21 intercepts_from_arg_ = OperatorBase::GetRepeatedArgument<T>(
"intercepts");
// Validate the argument-supplied parameters once, at construction, and
// remember whether they are the ones in use.
22 transform_param_from_arg_ = CheckTransParamFromArg();
// Operator entry point: runs TransformBinary() when binary_ is set,
// otherwise TransformGeneral(), and returns that call's result.
// NOTE(review): the closing brace (original line 27) and lines 28-33 are
// missing from this listing.
25 bool RunOnDevice()
override {
26 return binary_ ? TransformBinary() : TransformGeneral();
// Derives how the flat bounds/slopes/intercepts arrays split into groups of
// piecewise-linear functions.
// Layout used by the callers: every group holds num_func_per_group slopes and
// intercepts and num_func_per_group + 1 bounds. Each group therefore has
// exactly one more bound than slopes, so num_bounds - num_slopes is the
// number of groups.
// Outputs: *num_func_per_group (pieces per group) and *num_group.
// Throws (via CAFFE_ENFORCE_*) when the sizes are inconsistent.
// NOTE(review): original line 39 is missing from this listing. It is
// presumably the num_group out-parameter and the opening brace, since
// *num_group is written below. Lines 41-44, 47, 49 and 53-54 are also
// missing.
34 void InferNumFunctionsPerGroup(
35 const TIndex num_bounds,
36 const TIndex num_slopes,
37 const TIndex num_intercepts,
38 TIndex* num_func_per_group,
40 CAFFE_ENFORCE_EQ(num_slopes, num_intercepts);
45 *num_group = num_bounds - num_slopes;
46 CAFFE_ENFORCE_GT(*num_group, 0);
// NOTE(review): TransformGeneral requires num_group == M (the column count),
// so this check cannot apply unconditionally. The missing lines 47/49
// presumably wrap it in a condition (likely on binary_); confirm upstream.
48 CAFFE_ENFORCE_EQ(*num_group, 1);
// Slopes must split evenly across groups, with at least one piece per group.
50 *num_func_per_group = num_slopes / *num_group;
51 CAFFE_ENFORCE_GT(*num_func_per_group, 0);
52 CAFFE_ENFORCE_EQ(num_slopes % *num_group, 0);
// Checks that each group's bounds are in non-descending order.
// Walks the flat bounds array in num_group consecutive runs of
// num_bounds_per_group elements and tests each run with std::is_sorted.
// Callers pass num_func_per_group + 1 as num_bounds_per_group.
// NOTE(review): some original lines are missing from this listing:
//   - line 56, presumably the `const T* bounds` parameter that line 59 reads;
//   - lines 62-63 and 65-70, the return statements and closing braces.
// Presumably it returns false on the first unsorted group and true
// otherwise; confirm upstream.
55 bool CheckBoundsSorted(
57 const TIndex num_bounds_per_group,
58 const TIndex num_group) {
59 const T* start = bounds;
60 for (TIndex i = 0; i < num_group; i++) {
61 if (!std::is_sorted(start, start + num_bounds_per_group)) {
64 start += num_bounds_per_group;
// Decides whether the transform parameters come from operator arguments.
// Returns true when bounds, slopes and intercepts were all supplied as
// arguments, and false when none were. The visible check's condition
// (good_param == 0 || good_param == 3) and its message reject supplying
// only some of them.
// When all three are present, it also validates their sizes via
// InferNumFunctionsPerGroup and requires each group's bounds to be sorted.
// NOTE(review): original lines 72 (good_param's declaration), 76 (the
// enforcing call head), 81, 86-89, 92-93 and 95-96 are missing from this
// listing.
71 bool CheckTransParamFromArg() {
// Count how many of the three argument vectors are non-empty (0..3).
73 good_param += bounds_from_arg_.size() > 0;
74 good_param += slopes_from_arg_.size() > 0;
75 good_param += intercepts_from_arg_.size() > 0;
77 good_param == 0 || good_param == 3,
78 "bounds, slopes, intercepts must be all set or all not set");
79 if (good_param == 3) {
80 TIndex num_func_per_group;
82 InferNumFunctionsPerGroup(
83 bounds_from_arg_.size(),
84 slopes_from_arg_.size(),
85 intercepts_from_arg_.size(),
// Each group spans num_func_per_group + 1 bounds.
90 bounds_from_arg_.data(), num_func_per_group + 1, num_group),
91 "bounds must be sorted for each group");
94 return good_param == 3;
// Declared here, but neither a definition nor a call site is visible in this
// listing.
// NOTE(review): presumably implemented elsewhere (e.g. a context-specific
// source file); confirm before relying on its behaviour.
97 void setUpTensors(TIndex& num_func_per_group, TIndex& num_group, TIndex M);
// Resolves where the transform parameters come from. Returns raw pointers to
// bounds, slopes and intercepts through the out-parameters, plus the group
// layout from InferNumFunctionsPerGroup.
// - If transform_param_from_arg_ is set, exactly one input is allowed and the
//   pointers refer to this operator's own argument vectors.
// - Otherwise exactly four inputs are required and the pointers refer to the
//   BOUNDS, SLOPES and INTERCEPTS input tensors' data.
// The pointers are non-owning views into those vectors/tensors.
// NOTE(review): no CheckBoundsSorted call is visible on the input-blob path,
// so bounds arriving as inputs would reach std::lower_bound in
// PiecewiseLinearTransform unchecked. Confirm whether that is intended.
// NOTE(review): original lines 100-101 (presumably the bounds/slopes
// out-parameters), 104-106, 108, 117 (presumably the `else`), 128 and 131-132
// are missing from this listing.
99 void GetTransParamData(
102 const T** intercepts,
103 TIndex* num_func_per_group,
107 TIndex num_intercepts;
109 if (transform_param_from_arg_) {
// Parameters came from arguments: only the PREDICTIONS input is expected.
110 CAFFE_ENFORCE_EQ(InputSize(), 1);
111 *bounds = bounds_from_arg_.data();
112 *slopes = slopes_from_arg_.data();
113 *intercepts = intercepts_from_arg_.data();
114 num_bounds = bounds_from_arg_.size();
115 num_slopes = slopes_from_arg_.size();
116 num_intercepts = intercepts_from_arg_.size();
// Parameters come from the BOUNDS/SLOPES/INTERCEPTS input blobs.
118 CAFFE_ENFORCE_EQ(InputSize(), 4);
119 auto& bounds_input = Input(BOUNDS);
120 auto& slopes_input = Input(SLOPES);
121 auto& intercepts_input = Input(INTERCEPTS);
122 *bounds = bounds_input.template data<T>();
123 *slopes = slopes_input.template data<T>();
124 *intercepts = intercepts_input.template data<T>();
125 num_bounds = bounds_input.size();
126 num_slopes = slopes_input.size();
127 num_intercepts = intercepts_input.size();
// Validate the sizes and derive the group layout for either source.
129 InferNumFunctionsPerGroup(
130 num_bounds, num_slopes, num_intercepts, num_func_per_group, num_group);
// General path. X must be 2-D with N rows and M columns. Column j is
// transformed with the j-th group of piecewise-linear functions, so exactly
// M groups are required. Element (i, j) is addressed at offset i * M + j.
// NOTE(review): original lines 134-135, 139, 142-145, 147-148, 151 and
// 158-167 are missing from this listing. They include:
//   - the declarations of X, Y, bounds, slopes, intercepts and num_group;
//   - the head of the GetTransParamData call;
//   - the remaining arguments of the PiecewiseLinearTransform call;
//   - the closing braces and the return.
133 bool TransformGeneral() {
136 CAFFE_ENFORCE_EQ(X.ndim(), 2);
137 TIndex N = X.dim32(0);
138 TIndex M = X.dim32(1);
140 const auto* Xdata = X.template data<T>();
141 T* Ydata = Y->template mutable_data<T>();
146 TIndex num_func_per_group;
149 &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
// One function group per column.
150 CAFFE_ENFORCE_EQ(num_group, M);
152 for (TIndex j = 0; j < M; ++j) {
// With k = num_func_per_group, group j occupies k + 1 consecutive bounds
// starting at j * (k + 1), and k slopes/intercepts starting at j * k.
153 const T* bounds_group = bounds + j * (num_func_per_group + 1);
154 const T* slopes_group = slopes + j * num_func_per_group;
155 const T* intercepts_group = intercepts + j * num_func_per_group;
156 for (TIndex i = 0; i < N; ++i) {
// Arguments are on the missing lines; presumably Xdata[i * M + j] and this
// column's group pointers.
157 Ydata[i * M + j] = PiecewiseLinearTransform(
// Binary path. A single function group (num_group must be 1) is applied to a
// 1-D input or a 2-D input. A 1-D input is treated as M == 1.
// - First loop: every value is transformed (Ydata[i] from Xdata[i]).
// - Second loop: column 1 of each row is transformed, and column 0 is set to
//   1 - column 1, so the two output columns are complementary.
// NOTE(review): original lines 170, 174-175, 177, 180-183, 185-186, 189-190,
// 194-195 and 200-205 are missing from this listing. They include:
//   - Y's declaration;
//   - the enforce call whose message is visible below, presumably requiring
//     M == 1 or M == 2;
//   - the bounds/slopes/intercepts declarations;
//   - the branch that chooses between the two loops.
// The M == 1 / M == 2 split is inferred from the indexing; confirm upstream.
168 bool TransformBinary() {
169 auto& X = Input(PREDICTIONS);
171 CAFFE_ENFORCE(X.ndim() == 1 || X.ndim() == 2);
172 TIndex N = X.dim32(0);
173 TIndex M = X.ndim() == 2 ? X.dim32(1) : 1;
176 "If binary is set to true, the input must be Nx2 or Nx1 tensor");
178 const auto* Xdata = X.template data<T>();
179 T* Ydata = Y->template mutable_data<T>();
184 TIndex num_func_per_group;
187 &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
188 CAFFE_ENFORCE_EQ(num_group, 1);
// Single-column case: transform every element.
191 for (TIndex i = 0; i < N; ++i) {
192 Ydata[i] = PiecewiseLinearTransform(
193 Xdata[i], bounds, slopes, intercepts, num_func_per_group);
// Two-column case: transform column 1 and derive column 0 as its complement.
196 for (TIndex i = 0; i < N; ++i) {
197 Ydata[i * M + 1] = PiecewiseLinearTransform(
198 Xdata[i * M + 1], bounds, slopes, intercepts, num_func_per_group);
199 Ydata[i * M] = 1.0f - Ydata[i * M + 1];
// Evaluates one group's piecewise-linear function at x.
// bounds has num_func_per_group + 1 entries and must be sorted, because
// std::lower_bound requires it. Piece k (slopes[k], intercepts[k]) covers the
// interval (bounds[k], bounds[k + 1]].
// Inputs outside [bounds[0], bounds[num_func_per_group]] are clamped: the
// result is the first (or last) piece evaluated at the nearest end bound.
// NOTE(review): original lines 207-210 (the x/bounds/slopes/intercepts
// parameters), 212-214 (including y's declaration), 220-221 (the else-branch
// opener and low_bound's declaration), 224 and 226-228 (including the
// return) are missing from this listing.
206 T PiecewiseLinearTransform(
211 const TIndex num_func_per_group) {
// Below (or at) the first bound: first piece evaluated at bounds[0].
215 if (x <= bounds[0]) {
216 y = slopes[0] * bounds[0] + intercepts[0];
217 }
else if (x >= bounds[num_func_per_group]) {
// At or above the last bound: last piece evaluated at the last bound.
218 y = slopes[num_func_per_group - 1] * bounds[num_func_per_group] +
219 intercepts[num_func_per_group - 1];
// Interior case: bounds[0] < x < bounds[num_func_per_group]. For sorted
// bounds, lower_bound (first bound >= x) lands at a position in
// [1, num_func_per_group]. Stepping back one gives a piece index in
// [0, num_func_per_group - 1]. A value exactly on an interior bound uses the
// piece to its left.
222 std::lower_bound(bounds, bounds + num_func_per_group + 1, x);
223 int bounds_idx = low_bound - bounds - 1;
225 y = slopes[bounds_idx] * x + intercepts[bounds_idx];
// Operator state.
// NOTE(review): original lines 226-231, 235-238, 240-242, 244 and 246-249
// are missing from this listing. In particular, binary_ (assigned in the
// constructor body) has no visible declaration.
// Transform parameters given as repeated arguments ("bounds", "slopes",
// "intercepts"). They are empty when the parameters come from input blobs.
232 vector<T> bounds_from_arg_;
233 vector<T> slopes_from_arg_;
234 vector<T> intercepts_from_arg_;
// NOTE(review): nothing in this listing reads or writes gpu_copied_ beyond
// this initializer; presumably it is used by a context-specific
// implementation. Confirm.
239 bool gpu_copied_ =
false;
// Assigned from CheckTransParamFromArg() during construction: true when all
// three argument vectors are non-empty. GetTransParamData then uses them
// instead of the BOUNDS/SLOPES/INTERCEPTS inputs.
243 bool transform_param_from_arg_;
// Named input indices. PREDICTIONS is the data to transform. BOUNDS, SLOPES
// and INTERCEPTS carry the parameters when they are not given as arguments.
245 INPUT_TAGS(PREDICTIONS, BOUNDS, SLOPES, INTERCEPTS);
250 #endif // CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_ Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...