1 #ifndef CAFFE2_SGD_ITER_OP_H_ 2 #define CAFFE2_SGD_ITER_OP_H_ 7 #include "caffe2/core/blob_serialization.h" 8 #include "caffe2/core/context.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/stats.h" 14 inline void IncrementIter(TensorCPU* output) {
// IncrementIter: validates the int64_t iteration counter stored in `output`
// before advancing it. The name implies an in-place +1, but the increment
// statement itself is not visible in this listing (TODO confirm).
// NOTE(review): this text was extracted from a rendered source page. The
// leading numbers are original line numbers and they are non-contiguous, so
// several lines of this function are missing and it does not compile as shown.
// Size check: only its error message survives below. The enforce call that
// opens it is missing; it presumably requires `output` to hold exactly one
// element (TODO confirm).
18 "The output of IterOp exists, but not of the right size.");
// Access the counter as a mutable int64_t buffer.
19 int64_t* iter = output->template mutable_data<int64_t>();
// Reject a corrupted (negative) previous counter value.
20 CAFFE_ENFORCE(*iter >= 0,
"Previous iteration number is negative.");
// Overflow guard: the counter must be strictly below INT64_MAX, because
// incrementing a signed int64_t past its maximum is undefined behavior.
// NOTE(review): the `CAFFE_ENFORCE(` that opens this check is missing from
// this listing. Also, std::numeric_limits comes from <limits>, which is not
// among the visible includes; it presumably arrives transitively (confirm).
22 *iter < std::numeric_limits<int64_t>::max(),
"Overflow will happen!");
// Operator templated on a device Context that advances an int64_t iteration
// counter held in output 0. The class declaration line and the constructor
// are not visible in this listing. The class is presumably IterOp, which is
// the name used in its own error messages.
31 template <
class Context>
34 USE_OPERATOR_CONTEXT_FUNCTIONS;
39 bool RunOnDevice()
override {
// Legacy path, used when the operator has no inputs and output 0 is not yet a
// TensorCPU. It logs a deprecation error and initializes the counter to 0.
// Closing braces are missing from this listing, so the exact nesting is not
// visible.
40 if (InputSize() == 0) {
41 if (!OperatorBase::OutputIsType<TensorCPU>(0)) {
43 LOG(ERROR) <<
"You are using an old definition of IterOp that will " 44 "be deprecated soon. More specifically, IterOp now " 45 "requires an explicit in-place input and output.";
47 auto* output = OperatorBase::Output<TensorCPU>(0);
48 VLOG(1) <<
"Initializing iter counter.";
// NOTE(review): original line 49 is missing here. A resize of `output` to
// one element is presumably needed before writing element [0] (confirm).
50 output->template mutable_data<int64_t>()[0] = 0;
// Advance the counter in output 0 through the shared IncrementIter helper,
// which enforces the non-negative and no-overflow checks.
53 IncrementIter(OperatorBase::Output<TensorCPU>(0));
// Mutex-guarded variant of the iteration operator; the class declaration line
// is not visible in this listing. It takes a std::mutex (held in a
// std::unique_ptr) as input 0, increments the counter in output 0 while
// holding that lock, and records a num_iter stats event.
58 template <
class Context>
61 USE_OPERATOR_CONTEXT_FUNCTIONS;
// The stats key is "atomic_iter/stats/" followed by the blob name
// operator_def.input(1).
65 stats_(std::string(
"atomic_iter/stats/") + operator_def.input(1)) {}
67 bool RunOnDevice()
override {
// Input 0 holds a std::unique_ptr<std::mutex>. The lock_guard keeps it locked
// for the rest of this function, so operators that share the mutex update the
// counter one at a time.
68 auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(0);
69 std::lock_guard<std::mutex> lg(*mutex);
70 IncrementIter(OperatorBase::Output<TensorCPU>(0));
// Record one num_iter event per successful increment.
71 CAFFE_EVENT(stats_, num_iter);
// Stats struct declared with the caffe2 stats macros; it exports a single
// num_iter stat (the closing of this struct is not visible in this listing).
76 struct AtomicIterOpStats {
77 CAFFE_STAT_CTOR(AtomicIterOpStats);
78 CAFFE_EXPORTED_STAT(num_iter);
// Tail of an overriding serializer declaration whose head (class, name and
// leading parameters) is not visible. Output is presumably delivered through
// the SerializationAcceptor callback (TODO confirm).
92 BlobSerializerBase::SerializationAcceptor acceptor)
override;
// Deserializer override that presumably rebuilds `blob` from `proto`. The
// body is not visible here.
97 void Deserialize(
const BlobProto& proto,
Blob* blob)
override;
102 #endif // CAFFE2_SGD_ITER_OP_H_ Blob is a general container that hosts a typed pointer.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
BlobSerializerBase is an abstract class that serializes a blob to a string.