1 #include "counter_ops.h" 3 #include "caffe2/core/blob_serialization.h" 15 class CounterSerializer :
public BlobSerializerBase {
// CounterSerializer writes a blob holding std::unique_ptr<Counter<int64_t>>
// into a BlobProto. Only the int64_t counter instantiation is accepted.
// NOTE(review): this listing is missing original lines (e.g. the access
// specifier, the start of the Serialize signature, the declaration of
// `blob_proto`, and the call that receives retrieve()); the numbers at the
// start of lines are extraction artifacts, not code.
17 CounterSerializer() {}
18 ~CounterSerializer() {}
23 SerializationAcceptor acceptor)
override {
// Reject any blob that does not hold an int64_t counter.
24 CAFFE_ENFORCE(blob.IsType<std::unique_ptr<Counter<int64_t>>>());
// Record the blob name and a type string; presumably this string has to
// match the type the deserializer is registered for below — confirm.
27 blob_proto.set_name(name);
28 blob_proto.set_type(
"std::unique_ptr<Counter<int64_t>>");
// The counter value is carried in the BlobProto's embedded TensorProto,
// tagged as INT64.
29 TensorProto& proto = *blob_proto.mutable_tensor();
31 proto.set_data_type(TensorProto_DataType_INT64);
// Store the counter's current value (retrieve()); the opening of the call
// that receives it is not visible in this listing.
34 blob.template Get<std::unique_ptr<Counter<int64_t>>>()->retrieve());
// Hand the serialized bytes to the caller-supplied acceptor under `name`.
35 acceptor(name, blob_proto.SerializeAsString());
43 class CounterDeserializer :
public BlobDeserializerBase {
// CounterDeserializer rebuilds a std::unique_ptr<Counter<int64_t>> from a
// BlobProto whose tensor has shape [1], data type INT64 and exactly one
// int64 value; any other layout is rejected via CAFFE_ENFORCE_EQ.
// NOTE(review): original lines are missing from this listing (the access
// specifier and the openers of the data-type / data-size checks); the
// numbers at the start of lines are extraction artifacts, not code.
45 void Deserialize(
const BlobProto& proto, Blob* blob)
override {
// Bind by const reference: proto.tensor() returns the embedded message by
// const reference, and plain `auto` would deep-copy the whole TensorProto
// just to read a handful of fields.
46 const auto& tensorProto = proto.tensor();
// The tensor must be one-dimensional with a single element.
47 CAFFE_ENFORCE_EQ(tensorProto.dims_size(), 1,
"Unexpected size of dims");
48 CAFFE_ENFORCE_EQ(tensorProto.dims(0), 1,
"Unexpected value of dims");
50 tensorProto.data_type(),
51 TensorProto_DataType_INT64,
52 "Only int64_t counters supported");
54 tensorProto.int64_data_size(), 1,
"Unexpected size of data");
// Replace whatever the blob held with a fresh counter initialized from
// the stored value.
55 *blob->GetMutable<std::unique_ptr<Counter<int64_t>>>() =
56 caffe2::make_unique<Counter<int64_t>>(tensorProto.int64_data(0));
// Register the CPU implementations of the counter operators; each operator
// class is instantiated for int64_t counters on CPUContext.
64 REGISTER_CPU_OPERATOR(CreateCounter, CreateCounterOp<int64_t, CPUContext>);
65 REGISTER_CPU_OPERATOR(ResetCounter, ResetCounterOp<int64_t, CPUContext>);
66 REGISTER_CPU_OPERATOR(CountDown, CountDownOp<int64_t, CPUContext>);
67 REGISTER_CPU_OPERATOR(
// NOTE(review): original line 68 (the registered operator name, presumably
// CheckCounterDone) is missing from this listing.
69 CheckCounterDoneOp<int64_t, CPUContext>);
70 REGISTER_CPU_OPERATOR(CountUp, CountUpOp<int64_t, CPUContext>);
71 REGISTER_CPU_OPERATOR(RetrieveCount, RetrieveCountOp<int64_t, CPUContext>);
// Schema for CreateCounter: produces one output blob ("counter") holding a
// newly created counter whose starting value comes from the 'init_count'
// argument (documented as >= 0).
// NOTE(review): original lines 74-76 and 78-79 are missing from this listing.
73 OPERATOR_SCHEMA(CreateCounter)
77 Creates a count-down counter with initial value specified by the 'init_count' 80 .Output(0, "counter",
"A blob pointing to an instance of a new counter.")
81 .Arg(
"init_count",
"Initial count for the counter, must be >= 0.");
// Schema for ResetCounter: takes an existing counter blob, sets it to the
// 'init_count' argument (documented as >= 0), and can optionally output the
// value the counter held before the reset.
// The input description used to say "an instance of a new counter" (copied
// from CreateCounter's output); the input here is an existing counter, as
// for every other operator that consumes a counter.
// NOTE(review): original lines 84-86 and 88-89 are missing from this listing.
83 OPERATOR_SCHEMA(ResetCounter)
87 Resets a count-down counter with initial value specified by the 'init_count' 90 .Input(0, "counter",
"A blob pointing to an instance of a counter.")
91 .Output(0,
"previous_value",
"(optional) Previous value of the counter.")
92 .Arg(
"init_count",
"Resets counter to this value, must be >= 0.");
// Schema for CountDown: reads a counter blob and outputs "done". Per its doc
// string, it decrements the count and outputs false while the count is > 0,
// and outputs true otherwise.
// NOTE(review): the output description says "zero" while the doc string's
// condition is "not > 0"; whether the count can become negative depends on
// Counter's implementation, which is not shown here — confirm.
// Original lines 95-97 and 100 are missing from this listing.
94 OPERATOR_SCHEMA(CountDown)
98 If the internal count value > 0, decreases count value by 1 and outputs false, 99 otherwise outputs true. 101 .Input(0, "counter",
"A blob pointing to an instance of a counter.")
102 .Output(0,
"done",
"false unless the internal count is zero.");
// Schema for CheckCounterDone: reads a counter blob and outputs "done",
// which per its doc string is true when the count is <= 0 and false
// otherwise. Unlike CountDown, the doc string describes no modification
// of the count.
// NOTE(review): the visible doc sentence ends with a comma; original line
// 109 is missing, so the sentence may continue there — confirm. Original
// lines 105-107 are also missing.
104 OPERATOR_SCHEMA(CheckCounterDone)
108 If the internal count value <= 0, outputs true, otherwise outputs false, 110 .Input(0, "counter",
"A blob pointing to an instance of a counter.")
111 .Output(0,
"done",
"true if the internal count is zero or negative.");
// Schema for CountUp: increments the counter by 1 and outputs the count it
// held before the increment. The doc string states this happens atomically;
// the implementation is not visible here.
// NOTE(review): original lines 114-116 and 118 are missing from this listing.
113 OPERATOR_SCHEMA(CountUp)
117 Increases count value by 1 and outputs the previous value atomically 119 .Input(0, "counter",
"A blob pointing to an instance of a counter.")
120 .Output(0,
"previous_count",
"count value BEFORE this operation");
// Schema for RetrieveCount: outputs the counter's current value.
// ScalarType(TensorProto::INT64) tags the output as INT64 (presumably for
// output type inference — confirm against OpSchema).
// NOTE(review): original lines 123-124, 126 and 128 are missing from this
// listing.
122 OPERATOR_SCHEMA(RetrieveCount)
125 .ScalarType(TensorProto::INT64)
127 Retrieve the current value from the counter. 129 .Input(0, "counter",
"A blob pointing to an instance of a counter.")
130 .Output(0,
"count",
"current count value.");
132 SHOULD_NOT_DO_GRADIENT(CreateCounter);
133 SHOULD_NOT_DO_GRADIENT(ResetCounter);
134 SHOULD_NOT_DO_GRADIENT(CountDown);
135 SHOULD_NOT_DO_GRADIENT(CountUp);
136 SHOULD_NOT_DO_GRADIENT(RetrieveCount);
// Make std::unique_ptr<Counter<int64_t>> a known Caffe2 blob type and attach
// the CounterSerializer / CounterDeserializer defined above, so counter
// blobs can be serialized and deserialized.
// NOTE(review): original lines 140-141 (the rest of the
// REGISTER_BLOB_SERIALIZER( call) are missing from this listing.
138 CAFFE_KNOWN_TYPE(std::unique_ptr<Counter<int64_t>>);
139 REGISTER_BLOB_SERIALIZER(
142 REGISTER_BLOB_DESERIALIZER(
143 std::unique_ptr<Counter<int64_t>>,
144 CounterDeserializer);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...