1 #ifndef CAFFE2_OPERATORS_COUNTER_OPS_H 2 #define CAFFE2_OPERATORS_COUNTER_OPS_H 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/logging.h" 8 #include "caffe2/core/operator.h" 14 explicit Counter(T count) : count_(count) {}
30 T checkIfDone()
const {
31 return (count_.load() <= 0);
34 T reset(T init_count) {
35 return count_.exchange(init_count);
39 std::atomic<T> count_;
44 template <
typename T,
class Context>
47 USE_OPERATOR_CONTEXT_FUNCTIONS;
50 init_count_(OperatorBase::GetSingleArgument<T>(
"init_count", 0)) {
51 CAFFE_ENFORCE_LE(0, init_count_,
"negative init_count is not permitted.");
54 bool RunOnDevice()
override {
55 *OperatorBase::Output<std::unique_ptr<Counter<T>>>(0) =
64 template <
typename T,
class Context>
67 USE_OPERATOR_CONTEXT_FUNCTIONS;
70 init_count_(OperatorBase::GetSingleArgument<T>(
"init_count", 0)) {
71 CAFFE_ENFORCE_LE(0, init_count_,
"negative init_count is not permitted.");
74 bool RunOnDevice()
override {
75 auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
76 auto previous = counterPtr->reset(init_count_);
77 if (OutputSize() == 1) {
78 auto* output = OperatorBase::Output<TensorCPU>(0);
80 *output->template mutable_data<T>() = previous;
90 template <
typename T,
class Context>
93 USE_OPERATOR_CONTEXT_FUNCTIONS;
97 bool RunOnDevice()
override {
98 auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
99 auto* output = OperatorBase::Output<TensorCPU>(0);
100 output->Resize(std::vector<int>{});
101 *output->template mutable_data<bool>() = counterPtr->countDown();
107 template <
typename T,
class Context>
110 USE_OPERATOR_CONTEXT_FUNCTIONS;
114 bool RunOnDevice()
override {
115 auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
116 auto* output = OperatorBase::Output<TensorCPU>(0);
117 output->Resize(std::vector<int>{});
118 *output->template mutable_data<bool>() = counterPtr->checkIfDone();
124 template <
typename T,
class Context>
127 USE_OPERATOR_CONTEXT_FUNCTIONS;
131 bool RunOnDevice()
override {
132 auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
133 auto* output = OperatorBase::Output<TensorCPU>(0);
134 output->Resize(std::vector<int>{});
135 *output->template mutable_data<T>() = counterPtr->countUp();
141 template <
typename T,
class Context>
144 USE_OPERATOR_CONTEXT_FUNCTIONS;
148 bool RunOnDevice()
override {
149 auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
150 auto* output = OperatorBase::Output<TensorCPU>(0);
151 output->Resize(std::vector<int>{});
152 *output->template mutable_data<T>() = counterPtr->retrieve();
158 #endif // CAFFE2_OPERATORS_COUNTER_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...