Caffe2 - C++ API
A deep learning, cross platform ML framework
counter_ops.cc
1 #include "counter_ops.h"
2 
3 #include "caffe2/core/blob_serialization.h"
4 
5 namespace caffe2 {
6 namespace {
15 class CounterSerializer : public BlobSerializerBase {
16  public:
17  CounterSerializer() {}
18  ~CounterSerializer() {}
19 
20  void Serialize(
21  const Blob& blob,
22  const string& name,
23  SerializationAcceptor acceptor) override {
24  CAFFE_ENFORCE(blob.IsType<std::unique_ptr<Counter<int64_t>>>());
25 
26  BlobProto blob_proto;
27  blob_proto.set_name(name);
28  blob_proto.set_type("std::unique_ptr<Counter<int64_t>>");
29  TensorProto& proto = *blob_proto.mutable_tensor();
30  proto.set_name(name);
31  proto.set_data_type(TensorProto_DataType_INT64);
32  proto.add_dims(1);
33  proto.add_int64_data(
34  blob.template Get<std::unique_ptr<Counter<int64_t>>>()->retrieve());
35  acceptor(name, blob_proto.SerializeAsString());
36  }
37 };
38 
43 class CounterDeserializer : public BlobDeserializerBase {
44  public:
45  void Deserialize(const BlobProto& proto, Blob* blob) override {
46  auto tensorProto = proto.tensor();
47  CAFFE_ENFORCE_EQ(tensorProto.dims_size(), 1, "Unexpected size of dims");
48  CAFFE_ENFORCE_EQ(tensorProto.dims(0), 1, "Unexpected value of dims");
49  CAFFE_ENFORCE_EQ(
50  tensorProto.data_type(),
51  TensorProto_DataType_INT64,
52  "Only int64_t counters supported");
53  CAFFE_ENFORCE_EQ(
54  tensorProto.int64_data_size(), 1, "Unexpected size of data");
55  *blob->GetMutable<std::unique_ptr<Counter<int64_t>>>() =
56  caffe2::make_unique<Counter<int64_t>>(tensorProto.int64_data(0));
57  }
58 };
59 }
60 
61 // TODO(jiayq): deprecate these ops & consolidate them with
62 // IterOp/AtomicIterOp
63 
64 REGISTER_CPU_OPERATOR(CreateCounter, CreateCounterOp<int64_t, CPUContext>);
65 REGISTER_CPU_OPERATOR(ResetCounter, ResetCounterOp<int64_t, CPUContext>);
66 REGISTER_CPU_OPERATOR(CountDown, CountDownOp<int64_t, CPUContext>);
67 REGISTER_CPU_OPERATOR(
68  CheckCounterDone,
69  CheckCounterDoneOp<int64_t, CPUContext>);
70 REGISTER_CPU_OPERATOR(CountUp, CountUpOp<int64_t, CPUContext>);
71 REGISTER_CPU_OPERATOR(RetrieveCount, RetrieveCountOp<int64_t, CPUContext>);
72 
73 OPERATOR_SCHEMA(CreateCounter)
74  .NumInputs(0)
75  .NumOutputs(1)
76  .SetDoc(R"DOC(
77 Creates a count-down counter with initial value specified by the 'init_count'
78 argument.
79 )DOC")
80  .Output(0, "counter", "A blob pointing to an instance of a new counter.")
81  .Arg("init_count", "Initial count for the counter, must be >= 0.");
82 
83 OPERATOR_SCHEMA(ResetCounter)
84  .NumInputs(1)
85  .NumOutputs(0, 1)
86  .SetDoc(R"DOC(
87 Resets a count-down counter with initial value specified by the 'init_count'
88 argument.
89 )DOC")
90  .Input(0, "counter", "A blob pointing to an instance of a new counter.")
91  .Output(0, "previous_value", "(optional) Previous value of the counter.")
92  .Arg("init_count", "Resets counter to this value, must be >= 0.");
93 
94 OPERATOR_SCHEMA(CountDown)
95  .NumInputs(1)
96  .NumOutputs(1)
97  .SetDoc(R"DOC(
98 If the internal count value > 0, decreases count value by 1 and outputs false,
99 otherwise outputs true.
100 )DOC")
101  .Input(0, "counter", "A blob pointing to an instance of a counter.")
102  .Output(0, "done", "false unless the internal count is zero.");
103 
104 OPERATOR_SCHEMA(CheckCounterDone)
105  .NumInputs(1)
106  .NumOutputs(1)
107  .SetDoc(R"DOC(
108 If the internal count value <= 0, outputs true, otherwise outputs false,
109 )DOC")
110  .Input(0, "counter", "A blob pointing to an instance of a counter.")
111  .Output(0, "done", "true if the internal count is zero or negative.");
112 
113 OPERATOR_SCHEMA(CountUp)
114  .NumInputs(1)
115  .NumOutputs(1)
116  .SetDoc(R"DOC(
117 Increases count value by 1 and outputs the previous value atomically
118 )DOC")
119  .Input(0, "counter", "A blob pointing to an instance of a counter.")
120  .Output(0, "previous_count", "count value BEFORE this operation");
121 
122 OPERATOR_SCHEMA(RetrieveCount)
123  .NumInputs(1)
124  .NumOutputs(1)
125  .ScalarType(TensorProto::INT64)
126  .SetDoc(R"DOC(
127 Retrieve the current value from the counter.
128 )DOC")
129  .Input(0, "counter", "A blob pointing to an instance of a counter.")
130  .Output(0, "count", "current count value.");
131 
132 SHOULD_NOT_DO_GRADIENT(CreateCounter);
133 SHOULD_NOT_DO_GRADIENT(ResetCounter);
134 SHOULD_NOT_DO_GRADIENT(CountDown);
135 SHOULD_NOT_DO_GRADIENT(CountUp);
136 SHOULD_NOT_DO_GRADIENT(RetrieveCount);
137 
138 CAFFE_KNOWN_TYPE(std::unique_ptr<Counter<int64_t>>);
139 REGISTER_BLOB_SERIALIZER(
140  (TypeMeta::Id<std::unique_ptr<Counter<int64_t>>>()),
141  CounterSerializer);
142 REGISTER_BLOB_DESERIALIZER(
143  std::unique_ptr<Counter<int64_t>>,
144  CounterDeserializer);
145 
146 } // namespace caffe2
static CAFFE2_API CaffeTypeId Id()
Returns the unique id for the given type T.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...