Caffe2 - C++ API
A deep learning, cross-platform ML framework
counter_ops.h
1 #ifndef CAFFE2_OPERATORS_COUNTER_OPS_H
2 #define CAFFE2_OPERATORS_COUNTER_OPS_H
3 
4 #include <atomic>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/operator.h"
9 
10 namespace caffe2 {
/**
 * Integer counter stored in a std::atomic<T>.
 *
 * Every member function performs exactly one atomic operation on the stored
 * value (post-decrement, post-increment, load, or exchange). Sequences of
 * calls are NOT atomic as a group.
 *
 * The value is allowed to go negative: countDown() always decrements, even
 * when the count is already <= 0.
 */
template <typename T>
class Counter {
 public:
  explicit Counter(T count) : count_(count) {}

  /// Atomically decrements the count.
  /// @return false if the count was > 0 before the decrement, true otherwise
  ///         (i.e. true once the counter had already run out).
  bool countDown() {
    if (count_-- > 0) {
      return false;
    }
    return true;
  }

  /// Atomically increments the count.
  /// @return the value held *before* the increment.
  T countUp() {
    return count_++;
  }

  /// @return the current value (atomic load).
  T retrieve() const {
    return count_.load();
  }

  /// @return true when the current value is <= 0.
  /// Returns bool (previously declared as T, which needlessly widened the
  /// predicate to the counter's integer type).
  bool checkIfDone() const {
    return count_.load() <= 0;
  }

  /// Atomically replaces the count with init_count.
  /// @return the value held before the reset.
  T reset(T init_count) {
    return count_.exchange(init_count);
  }

 private:
  std::atomic<T> count_;
};
41 
42 // TODO(jiayq): deprecate these ops & consolidate them with IterOp/AtomicIterOp
43 
44 template <typename T, class Context>
45 class CreateCounterOp final : public Operator<Context> {
46  public:
47  USE_OPERATOR_CONTEXT_FUNCTIONS;
48  CreateCounterOp(const OperatorDef& operator_def, Workspace* ws)
49  : Operator<Context>(operator_def, ws),
50  init_count_(OperatorBase::GetSingleArgument<T>("init_count", 0)) {
51  CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
52  }
53 
54  bool RunOnDevice() override {
55  *OperatorBase::Output<std::unique_ptr<Counter<T>>>(0) =
56  std::unique_ptr<Counter<T>>(new Counter<T>(init_count_));
57  return true;
58  }
59 
60  private:
61  T init_count_ = 0;
62 };
63 
64 template <typename T, class Context>
65 class ResetCounterOp final : public Operator<Context> {
66  public:
67  USE_OPERATOR_CONTEXT_FUNCTIONS;
68  ResetCounterOp(const OperatorDef& operator_def, Workspace* ws)
69  : Operator<Context>(operator_def, ws),
70  init_count_(OperatorBase::GetSingleArgument<T>("init_count", 0)) {
71  CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
72  }
73 
74  bool RunOnDevice() override {
75  auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
76  auto previous = counterPtr->reset(init_count_);
77  if (OutputSize() == 1) {
78  auto* output = OperatorBase::Output<TensorCPU>(0);
79  output->Resize();
80  *output->template mutable_data<T>() = previous;
81  }
82  return true;
83  }
84 
85  private:
86  T init_count_;
87 };
88 
89 // Will always use TensorCPU regardless the Context
90 template <typename T, class Context>
91 class CountDownOp final : public Operator<Context> {
92  public:
93  USE_OPERATOR_CONTEXT_FUNCTIONS;
94  CountDownOp(const OperatorDef& operator_def, Workspace* ws)
95  : Operator<Context>(operator_def, ws) {}
96 
97  bool RunOnDevice() override {
98  auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
99  auto* output = OperatorBase::Output<TensorCPU>(0);
100  output->Resize(std::vector<int>{});
101  *output->template mutable_data<bool>() = counterPtr->countDown();
102  return true;
103  }
104 };
105 
106 // Will always use TensorCPU regardless the Context
107 template <typename T, class Context>
108 class CheckCounterDoneOp final : public Operator<Context> {
109  public:
110  USE_OPERATOR_CONTEXT_FUNCTIONS;
111  CheckCounterDoneOp(const OperatorDef& operator_def, Workspace* ws)
112  : Operator<Context>(operator_def, ws) {}
113 
114  bool RunOnDevice() override {
115  auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
116  auto* output = OperatorBase::Output<TensorCPU>(0);
117  output->Resize(std::vector<int>{});
118  *output->template mutable_data<bool>() = counterPtr->checkIfDone();
119  return true;
120  }
121 };
122 
123 // Will always use TensorCPU regardless the Context
124 template <typename T, class Context>
125 class CountUpOp final : public Operator<Context> {
126  public:
127  USE_OPERATOR_CONTEXT_FUNCTIONS;
128  CountUpOp(const OperatorDef& operator_def, Workspace* ws)
129  : Operator<Context>(operator_def, ws) {}
130 
131  bool RunOnDevice() override {
132  auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
133  auto* output = OperatorBase::Output<TensorCPU>(0);
134  output->Resize(std::vector<int>{});
135  *output->template mutable_data<T>() = counterPtr->countUp();
136  return true;
137  }
138 };
139 
140 // Will always use TensorCPU regardless the Context
141 template <typename T, class Context>
142 class RetrieveCountOp final : public Operator<Context> {
143  public:
144  USE_OPERATOR_CONTEXT_FUNCTIONS;
145  RetrieveCountOp(const OperatorDef& operator_def, Workspace* ws)
146  : Operator<Context>(operator_def, ws) {}
147 
148  bool RunOnDevice() override {
149  auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
150  auto* output = OperatorBase::Output<TensorCPU>(0);
151  output->Resize(std::vector<int>{});
152  *output->template mutable_data<T>() = counterPtr->retrieve();
153  return true;
154  }
155 };
156 
157 } // namespace caffe2
158 #endif // CAFFE2_OPERATORS_COUNTER_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...