Caffe2 - C++ API
A deep learning, cross-platform ML framework
top_k.h
1 #ifndef CAFFE2_OPERATORS_TOP_K_H_
2 #define CAFFE2_OPERATORS_TOP_K_H_
3 
4 #include "caffe2/core/logging.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context>
11 class TopKOp : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14 
15  TopKOp(const OperatorDef& operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws),
17  OP_SINGLE_ARG(int, "k", k_, -1),
18  OP_SINGLE_ARG(int, "axis", axis_, -1) {
19  CAFFE_ENFORCE(k_ >= 1, "k argument must be >= 1");
20  }
21 
22  ~TopKOp() {}
23 
24  bool RunOnDevice() override;
25 
26  private:
27  const int k_;
28  int axis_;
29 };
30 
31 template <typename T, class Context>
32 class TopKGradientOp : public Operator<Context> {
33  public:
34  USE_OPERATOR_CONTEXT_FUNCTIONS;
35 
36  TopKGradientOp(const OperatorDef& operator_def, Workspace* ws)
37  : Operator<Context>(operator_def, ws),
38  OP_SINGLE_ARG(int, "axis", axis_, -1) {}
39 
40  ~TopKGradientOp() {}
41 
42  bool RunOnDevice() override;
43 
44  private:
45  int axis_;
46 };
47 
48 } // namespace caffe2
49 
50 #endif // CAFFE2_OPERATORS_TOP_K_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...