Caffe2 - C++ API
A deep learning, cross-platform ML framework
iter_op.h
#ifndef CAFFE2_SGD_ITER_OP_H_
#define CAFFE2_SGD_ITER_OP_H_

#include <limits>
#include <mutex>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/stats.h"

namespace caffe2 {

inline void IncrementIter(TensorCPU* output) {
  CAFFE_ENFORCE_EQ(
      output->size(),
      1,
      "The output of IterOp exists, but not of the right size.");
  int64_t* iter = output->template mutable_data<int64_t>();
  CAFFE_ENFORCE(*iter >= 0, "Previous iteration number is negative.");
  CAFFE_ENFORCE(
      *iter < std::numeric_limits<int64_t>::max(), "Overflow will happen!");
  (*iter)++;
}

// IterOp runs an iteration counter. I cannot think of a case where we would
// need to access the iter variable on device, so this will always produce a
// tensor on the CPU side. If the blob already exists and is a tensor<int64_t>
// object, we will simply increment it (this emulates the case when we want to
// resume training). Otherwise we will have the iter starting with 0.
template <class Context>
class IterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  IterOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override {
    if (InputSize() == 0) {
      if (!OperatorBase::OutputIsType<TensorCPU>(0)) {
        // This is the first run; set the iter to start with 0.
        LOG(ERROR) << "You are using an old definition of IterOp that will "
                      "be deprecated soon. More specifically, IterOp now "
                      "requires an explicit in-place input and output.";

        auto* output = OperatorBase::Output<TensorCPU>(0);
        VLOG(1) << "Initializing iter counter.";
        output->Resize(1);
        output->template mutable_data<int64_t>()[0] = 0;
      }
    }
    IncrementIter(OperatorBase::Output<TensorCPU>(0));
    return true;
  }
};

template <class Context>
class AtomicIterOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  AtomicIterOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        stats_(std::string("atomic_iter/stats/") + operator_def.input(1)) {}

  bool RunOnDevice() override {
    auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(0);
    std::lock_guard<std::mutex> lg(*mutex);
    IncrementIter(OperatorBase::Output<TensorCPU>(0));
    CAFFE_EVENT(stats_, num_iter);
    return true;
  }

 private:
  struct AtomicIterOpStats {
    CAFFE_STAT_CTOR(AtomicIterOpStats);
    CAFFE_EXPORTED_STAT(num_iter);
  } stats_;
};

class MutexSerializer : public BlobSerializerBase {
 public:
  /**
   * Serializes a std::unique_ptr<std::mutex>. Note that this blob has to
   * contain a std::unique_ptr<std::mutex>; otherwise this function produces a
   * fatal error.
   */
  void Serialize(
      const Blob& blob,
      const string& name,
      BlobSerializerBase::SerializationAcceptor acceptor) override;
};

class MutexDeserializer : public BlobDeserializerBase {
 public:
  void Deserialize(const BlobProto& proto, Blob* blob) override;
};

} // namespace caffe2

#endif // CAFFE2_SGD_ITER_OP_H_
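
The comment above IterOp describes the intended usage: the counter lives in a single-element int64_t TensorCPU blob that the operator increments in place on every run. Below is a minimal sketch of driving it directly from C++. It assumes the operator is registered under the name "Iter" and that the legacy Blob/TensorCPU workspace calls visible in this header apply; the blob name is illustrative.

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

// Sketch only: assumes the operator is registered as "Iter" and that the
// legacy Blob::GetMutable<TensorCPU>() / Tensor::Resize() API is available.
void RunIterCounter() {
  caffe2::Workspace ws;

  // Seed the counter blob with a single int64_t element set to 0 so the
  // in-place (non-deprecated) path of IterOp is taken.
  auto* counter = ws.CreateBlob("iter")->GetMutable<caffe2::TensorCPU>();
  counter->Resize(1);
  counter->mutable_data<int64_t>()[0] = 0;

  // Explicit in-place input and output, as the deprecation message asks for.
  caffe2::OperatorDef def;
  def.set_type("Iter");
  def.add_input("iter");
  def.add_output("iter");

  // Every run bumps the counter by one: 0 -> 1 -> 2 -> 3.
  for (int i = 0; i < 3; ++i) {
    ws.RunOperatorOnce(def);
  }
  // counter->data<int64_t>()[0] is now 3.
}

Seeding the blob yourself keeps the operator on the in-place path and avoids the deprecation warning logged for the zero-input form.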
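
AtomicIterOp lets several nets or threads share one counter: it locks the std::mutex held in its first input blob before incrementing, and reports the count through an atomic_iter/stats/<counter name> stat keyed on input(1). MutexSerializer and MutexDeserializer exist so that this mutex blob can survive checkpointing. A sketch along the same lines, again with an assumed registration name ("AtomicIter") and illustrative blob names:

#include <mutex>

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

// Sketch only: assumes the operator is registered as "AtomicIter" and that the
// blob layout matches this header (input 0 = mutex, input 1 = counter,
// output 0 = the same counter, updated in place).
void RunAtomicIterCounter() {
  caffe2::Workspace ws;

  // The mutex blob that guards concurrent increments; this is the blob type
  // that MutexSerializer / MutexDeserializer know how to checkpoint.
  ws.CreateBlob("iteration_mutex")
      ->GetMutable<std::unique_ptr<std::mutex>>()
      ->reset(new std::mutex());

  // The shared counter blob: a single int64_t element starting at 0.
  auto* counter = ws.CreateBlob("iteration")->GetMutable<caffe2::TensorCPU>();
  counter->Resize(1);
  counter->mutable_data<int64_t>()[0] = 0;

  caffe2::OperatorDef def;
  def.set_type("AtomicIter");
  def.add_input("iteration_mutex");
  def.add_input("iteration");  // input(1) also names the exported stat
  def.add_output("iteration");

  ws.RunOperatorOnce(def);  // counter is now 1, incremented under the mutex
}

In a real training setup the mutex blob would normally be produced by a dedicated operator rather than filled in by hand; the direct assignment here only keeps the sketch self-contained.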