Caffe2 - C++ API
A deep learning, cross-platform ML framework
piecewise_linear_transform_op.h
1 #ifndef CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
2 #define CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <typename T, class Context>
10 class PiecewiseLinearTransformOp final : public Operator<Context> {
11  public:
12  USE_OPERATOR_CONTEXT_FUNCTIONS;
13 
14  PiecewiseLinearTransformOp(const OperatorDef& operator_def, Workspace* ws)
15  : Operator<Context>(operator_def, ws) {
16  binary_ = OperatorBase::GetSingleArgument<bool>("binary", false);
17 
18  // Retrieve transform params (i.e., the linear functions).
19  bounds_from_arg_ = OperatorBase::GetRepeatedArgument<T>("bounds");
20  slopes_from_arg_ = OperatorBase::GetRepeatedArgument<T>("slopes");
21  intercepts_from_arg_ = OperatorBase::GetRepeatedArgument<T>("intercepts");
22  transform_param_from_arg_ = CheckTransParamFromArg();
23  }
24 
25  bool RunOnDevice() override {
26  return binary_ ? TransformBinary() : TransformGeneral();
27  }
28 
29  private:
30  // num_func_per_group is the number of pieces of linear functions of
31  // each group.
32  // num_group: The number of groups of linear functions. Each group is for
33  // transforming one column of predictions.
34  void InferNumFunctionsPerGroup(
35  const TIndex num_bounds,
36  const TIndex num_slopes,
37  const TIndex num_intercepts,
38  TIndex* num_func_per_group,
39  TIndex* num_group) {
40  CAFFE_ENFORCE_EQ(num_slopes, num_intercepts);
41 
42  // This is based on the facts:
43  // 1. in each group, the num of bounds minus the num of slopes is 1;
44  // 2. each group has the same number of pieces.
45  *num_group = num_bounds - num_slopes;
46  CAFFE_ENFORCE_GT(*num_group, 0);
47  if (binary_) {
48  CAFFE_ENFORCE_EQ(*num_group, 1);
49  }
50  *num_func_per_group = num_slopes / *num_group;
51  CAFFE_ENFORCE_GT(*num_func_per_group, 0);
52  CAFFE_ENFORCE_EQ(num_slopes % *num_group, 0);
53  }
54 
55  bool CheckBoundsSorted(
56  const T* bounds,
57  const TIndex num_bounds_per_group,
58  const TIndex num_group) {
59  const T* start = bounds;
60  for (TIndex i = 0; i < num_group; i++) {
61  if (!std::is_sorted(start, start + num_bounds_per_group)) {
62  return false;
63  }
64  start += num_bounds_per_group;
65  }
66  return true;
67  }
68 
69  // Returns true if the transform params from arg are valid.
70  // Otherwise, we will assume the transform params will pass from Input blobs.
71  bool CheckTransParamFromArg() {
72  int good_param = 0;
73  good_param += bounds_from_arg_.size() > 0;
74  good_param += slopes_from_arg_.size() > 0;
75  good_param += intercepts_from_arg_.size() > 0;
76  CAFFE_ENFORCE(
77  good_param == 0 || good_param == 3,
78  "bounds, slopes, intercepts must be all set or all not set");
79  if (good_param == 3) {
80  TIndex num_func_per_group;
81  TIndex num_group;
82  InferNumFunctionsPerGroup(
83  bounds_from_arg_.size(),
84  slopes_from_arg_.size(),
85  intercepts_from_arg_.size(),
86  &num_func_per_group,
87  &num_group);
88  CAFFE_ENFORCE(
89  CheckBoundsSorted(
90  bounds_from_arg_.data(), num_func_per_group + 1, num_group),
91  "bounds must be sorted for each group");
92  }
93 
94  return good_param == 3;
95  }
96 
97  void setUpTensors(TIndex& num_func_per_group, TIndex& num_group, TIndex M);
98 
99  void GetTransParamData(
100  const T** bounds,
101  const T** slopes,
102  const T** intercepts,
103  TIndex* num_func_per_group,
104  TIndex* num_group) {
105  TIndex num_bounds;
106  TIndex num_slopes;
107  TIndex num_intercepts;
108 
109  if (transform_param_from_arg_) {
110  CAFFE_ENFORCE_EQ(InputSize(), 1);
111  *bounds = bounds_from_arg_.data();
112  *slopes = slopes_from_arg_.data();
113  *intercepts = intercepts_from_arg_.data();
114  num_bounds = bounds_from_arg_.size();
115  num_slopes = slopes_from_arg_.size();
116  num_intercepts = intercepts_from_arg_.size();
117  } else {
118  CAFFE_ENFORCE_EQ(InputSize(), 4);
119  auto& bounds_input = Input(BOUNDS);
120  auto& slopes_input = Input(SLOPES);
121  auto& intercepts_input = Input(INTERCEPTS);
122  *bounds = bounds_input.template data<T>();
123  *slopes = slopes_input.template data<T>();
124  *intercepts = intercepts_input.template data<T>();
125  num_bounds = bounds_input.size();
126  num_slopes = slopes_input.size();
127  num_intercepts = intercepts_input.size();
128  }
129  InferNumFunctionsPerGroup(
130  num_bounds, num_slopes, num_intercepts, num_func_per_group, num_group);
131  }
132 
133  bool TransformGeneral() {
134  auto& X = Input(0);
135  auto* Y = Output(0);
136  CAFFE_ENFORCE_EQ(X.ndim(), 2);
137  TIndex N = X.dim32(0);
138  TIndex M = X.dim32(1);
139  Y->ResizeLike(X);
140  const auto* Xdata = X.template data<T>();
141  T* Ydata = Y->template mutable_data<T>();
142 
143  const T* bounds;
144  const T* slopes;
145  const T* intercepts;
146  TIndex num_func_per_group;
147  TIndex num_group;
148  GetTransParamData(
149  &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
150  CAFFE_ENFORCE_EQ(num_group, M);
151 
152  for (TIndex j = 0; j < M; ++j) {
153  const T* bounds_group = bounds + j * (num_func_per_group + 1);
154  const T* slopes_group = slopes + j * num_func_per_group;
155  const T* intercepts_group = intercepts + j * num_func_per_group;
156  for (TIndex i = 0; i < N; ++i) {
157  Ydata[i * M + j] = PiecewiseLinearTransform(
158  Xdata[i * M + j],
159  bounds_group,
160  slopes_group,
161  intercepts_group,
162  num_func_per_group);
163  }
164  }
165  return true;
166  }
167 
168  bool TransformBinary() {
169  auto& X = Input(PREDICTIONS);
170  auto* Y = Output(0);
171  CAFFE_ENFORCE(X.ndim() == 1 || X.ndim() == 2);
172  TIndex N = X.dim32(0);
173  TIndex M = X.ndim() == 2 ? X.dim32(1) : 1;
174  CAFFE_ENFORCE(
175  M == 1 || M == 2,
176  "If binary is set to true, the input must be Nx2 or Nx1 tensor");
177  Y->ResizeLike(X);
178  const auto* Xdata = X.template data<T>();
179  T* Ydata = Y->template mutable_data<T>();
180 
181  const T* bounds;
182  const T* slopes;
183  const T* intercepts;
184  TIndex num_func_per_group;
185  TIndex num_group;
186  GetTransParamData(
187  &bounds, &slopes, &intercepts, &num_func_per_group, &num_group);
188  CAFFE_ENFORCE_EQ(num_group, 1);
189 
190  if (M == 1) {
191  for (TIndex i = 0; i < N; ++i) {
192  Ydata[i] = PiecewiseLinearTransform(
193  Xdata[i], bounds, slopes, intercepts, num_func_per_group);
194  }
195  } else {
196  for (TIndex i = 0; i < N; ++i) {
197  Ydata[i * M + 1] = PiecewiseLinearTransform(
198  Xdata[i * M + 1], bounds, slopes, intercepts, num_func_per_group);
199  Ydata[i * M] = 1.0f - Ydata[i * M + 1];
200  }
201  }
202 
203  return true;
204  }
205 
206  T PiecewiseLinearTransform(
207  const T x,
208  const T* bounds,
209  const T* slopes,
210  const T* intercepts,
211  const TIndex num_func_per_group) {
212  T y = 0;
213  // deal with samples out of bounds
214  // make it the same as the upper/lower bound value
215  if (x <= bounds[0]) {
216  y = slopes[0] * bounds[0] + intercepts[0];
217  } else if (x >= bounds[num_func_per_group]) {
218  y = slopes[num_func_per_group - 1] * bounds[num_func_per_group] +
219  intercepts[num_func_per_group - 1];
220  } else {
221  auto low_bound =
222  std::lower_bound(bounds, bounds + num_func_per_group + 1, x);
223  int bounds_idx = low_bound - bounds - 1;
224  // compute the piecewise linear transformation as Y
225  y = slopes[bounds_idx] * x + intercepts[bounds_idx];
226  }
227  return y;
228  }
229 
230  private:
231  bool binary_;
232  vector<T> bounds_from_arg_;
233  vector<T> slopes_from_arg_;
234  vector<T> intercepts_from_arg_;
235 
236  Tensor<Context> bounds_device_;
237  Tensor<Context> intercepts_device_;
238  Tensor<Context> slopes_device_;
239  bool gpu_copied_ = false;
240 
241  // If true, the piecewise linear functions are passed through args,
242  // otherwise, they are passed through Input blobs.
243  bool transform_param_from_arg_;
244 
245  INPUT_TAGS(PREDICTIONS, BOUNDS, SLOPES, INTERCEPTS);
246 };
247 
248 } // namespace caffe2
249 
250 #endif // CAFFE2_OPERATORS_PIECEWISE_LINEAR_TRANSFORM_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...