Caffe2 - C++ API
A deep learning, cross-platform ML framework
batch_box_cox_op.h
1 #ifndef CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
2 #define CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <class Context>
12 class BatchBoxCoxOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  BatchBoxCoxOp(const OperatorDef& operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws),
17  min_block_size_(
18  OperatorBase::GetSingleArgument<int>("min_block_size", 256)) {}
19 
20  bool RunOnDevice() override {
21  return DispatchHelper<TensorTypes<float, double>>::call(this, Input(DATA));
22  }
23 
24  template <typename T>
25  bool DoRunWithType();
26 
27  protected:
28  template <typename T>
29  void BoxCoxNaive(
30  TIndex N,
31  TIndex D,
32  const T* data_ptr,
33  const T* lambda1_ptr,
34  const T* lambda2_ptr,
35  T k_eps,
36  T* output_ptr);
37 
38 #ifdef CAFFE2_USE_MKL
39  template <typename T>
40  void BoxCoxNonzeroLambda(
41  TIndex D,
42  const T* data_ptr,
43  const T* lambda1,
44  const T* lambda2,
45  T k_eps,
46  T* output_ptr);
47 
48  template <typename T>
49  void BoxCoxZeroLambda(
50  TIndex D,
51  const T* data_ptr,
52  const T* lambda2,
53  T k_eps,
54  T* output_ptr);
55 
56  template <typename T>
57  void BoxCoxMixedLambda(
58  const T* data_ptr,
59  const vector<int>& nonzeros,
60  const vector<int>& zeros,
61  const T* lambda1,
62  const T* lambda2,
63  const T* lambda2_z,
64  T k_eps,
65  T* buffer,
66  T* output_ptr);
67 
68  vector<int> nonzeros_, zeros_;
69 
70  // Buffers used by the MKL version are cached across calls.
71  struct CachedBuffers {
72  virtual ~CachedBuffers() {}
73  int type_;
74  };
75  template <typename T>
76  struct TypedCachedBuffers : public CachedBuffers {
77  vector<T> lambda1_, lambda2_, lambda2_z_;
78  vector<T> accumulator_;
79  };
80  template <typename T>
81  TypedCachedBuffers<T>& GetBuffers();
82  unique_ptr<CachedBuffers> buffers_;
83 
84 #endif // CAFFE2_USE_MKL
85 
86  int min_block_size_;
87 
88  INPUT_TAGS(DATA, LAMBDA1, LAMBDA2);
89 };
90 
91 } // namespace caffe2
92 
93 #endif // CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...