Caffe2 - C++ API
A deep learning, cross platform ML framework
iter_op.cc
1 #include "caffe2/sgd/iter_op.h"
2 
3 namespace caffe2 {
4 
6  const Blob& blob,
7  const string& name,
8  BlobSerializerBase::SerializationAcceptor acceptor) {
9  CAFFE_ENFORCE(blob.IsType<std::unique_ptr<std::mutex>>());
10  BlobProto blob_proto;
11  blob_proto.set_name(name);
12  blob_proto.set_type("std::unique_ptr<std::mutex>");
13  blob_proto.set_content("");
14  acceptor(name, blob_proto.SerializeAsString());
15 }
16 
// Deserialize: the BlobProto argument is unnamed and never read. The blob's
// std::unique_ptr<std::mutex> slot (obtained through GetMutable) is
// overwritten with a newly constructed std::mutex.
17 void MutexDeserializer::Deserialize(const BlobProto& /* unused */, Blob* blob) {
18  *blob->GetMutable<std::unique_ptr<std::mutex>>() =
19  caffe2::make_unique<std::mutex>();
20 }
21 
// Associates the operator names Iter and AtomicIter with their CPUContext
// instantiations (IterOp / AtomicIterOp). The macro is defined outside this
// file; presumably it adds the operators to the CPU operator registry.
22 REGISTER_CPU_OPERATOR(Iter, IterOp<CPUContext>);
23 REGISTER_CPU_OPERATOR(AtomicIter, AtomicIterOp<CPUContext>);
24 
25 REGISTER_BLOB_SERIALIZER(
26  (TypeMeta::Id<std::unique_ptr<std::mutex>>()),
// Pairs the std::unique_ptr<std::mutex> type with MutexDeserializer for blob
// deserialization (macro defined outside this file).
28 REGISTER_BLOB_DESERIALIZER(std::unique_ptr<std::mutex>, MutexDeserializer);
29 
// Schema for the Iter operator: zero or one input, exactly one output.
// EnforceInplace({{0, 0}}) pairs index 0 with index 0 (presumably input 0
// with output 0, so a supplied counter is updated in place).
// The SetDoc text is the operator's user-visible documentation.
30 OPERATOR_SCHEMA(Iter)
31  .NumInputs(0, 1)
32  .NumOutputs(1)
33  .EnforceInplace({{0, 0}})
34  .SetDoc(R"DOC(
35 Stores a single integer, that gets incremented on each call to Run().
36 Useful for tracking the iteration count during SGD, for example.
37 )DOC");
38 
// Schema for the AtomicIter operator: exactly two inputs (0 = mutex,
// 1 = iter counter, per the Input() descriptions below) and one output.
// EnforceInplace({{1, 0}}) pairs index 1 with index 0 (presumably the iter
// input with the single output, so the counter is updated in place).
39 OPERATOR_SCHEMA(AtomicIter)
40  .NumInputs(2)
41  .NumOutputs(1)
42  .EnforceInplace({{1, 0}})
43  .SetDoc(R"DOC(
44 Similar to Iter, but takes a mutex as the first input to make sure that
45 updates are carried out atomically. This can be used in e.g. Hogwild sgd
46 algorithms.
47 )DOC")
48  .Input(0, "mutex", "The mutex used to do atomic increment.")
49  .Input(1, "iter", "The iter counter as an int64_t TensorCPU.");
50 
// Marks Iter and AtomicIter with NO_GRADIENT (macro defined outside this
// file); presumably declares that neither operator has a gradient operator.
51 NO_GRADIENT(Iter);
52 NO_GRADIENT(AtomicIter);
53 } // namespace caffe2
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
static CAFFE2_API CaffeTypeId Id()
Returns the unique id for the given type T.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
T * GetMutable(bool *is_new_object=nullptr)
Gets a mutable pointer to the stored object.
Definition: blob.h:101
bool IsType() const
Checks if the content stored in the blob is of type T.
Definition: blob.h:58
void Serialize(const Blob &blob, const string &name, BlobSerializerBase::SerializationAcceptor acceptor) override
Serializes a std::unique_ptr<std::mutex>.
Definition: iter_op.cc:5