Caffe2 - C++ API
A deep learning, cross-platform ML framework
store_ops.cc
1 #include "store_ops.h"
2 
3 namespace caffe2 {
4 
// Argument names shared by the store operators in this file.
// kBlobName: key override for StoreSet/StoreGet, counter key for StoreAdd,
//            repeated key list for StoreWait.
// kAddValue: increment applied by StoreAdd.
constexpr const char* kBlobName = "blob_name";
constexpr const char* kAddValue = "add_value";
7 
8 StoreSetOp::StoreSetOp(const OperatorDef& operator_def, Workspace* ws)
9  : Operator<CPUContext>(operator_def, ws),
10  blobName_(
11  GetSingleArgument<std::string>(kBlobName, operator_def.input(DATA))) {
12 }
13 
14 bool StoreSetOp::RunOnDevice() {
15  // Serialize and pass to store
16  auto* handler =
17  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
18  handler->set(blobName_, InputBlob(DATA).Serialize(blobName_));
19  return true;
20 }
21 
22 REGISTER_CPU_OPERATOR(StoreSet, StoreSetOp);
23 OPERATOR_SCHEMA(StoreSet)
24  .NumInputs(2)
25  .NumOutputs(0)
26  .SetDoc(R"DOC(
27 Set a blob in a store. The key is the input blob's name and the value
28 is the data in that blob. The key can be overridden by specifying the
29 'blob_name' argument.
30 )DOC")
31  .Arg("blob_name", "alternative key for the blob (optional)")
32  .Input(0, "handler", "unique_ptr<StoreHandler>")
33  .Input(1, "data", "data blob");
34 
35 StoreGetOp::StoreGetOp(const OperatorDef& operator_def, Workspace* ws)
36  : Operator<CPUContext>(operator_def, ws),
37  blobName_(GetSingleArgument<std::string>(
38  kBlobName,
39  operator_def.output(DATA))) {}
40 
41 bool StoreGetOp::RunOnDevice() {
42  // Get from store and deserialize
43  auto* handler =
44  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
45  OperatorBase::Outputs()[DATA]->Deserialize(handler->get(blobName_));
46  return true;
47 }
48 
49 REGISTER_CPU_OPERATOR(StoreGet, StoreGetOp);
50 OPERATOR_SCHEMA(StoreGet)
51  .NumInputs(1)
52  .NumOutputs(1)
53  .SetDoc(R"DOC(
54 Get a blob from a store. The key is the output blob's name. The key
55 can be overridden by specifying the 'blob_name' argument.
56 )DOC")
57  .Arg("blob_name", "alternative key for the blob (optional)")
58  .Input(0, "handler", "unique_ptr<StoreHandler>")
59  .Output(0, "data", "data blob");
60 
61 StoreAddOp::StoreAddOp(const OperatorDef& operator_def, Workspace* ws)
62  : Operator<CPUContext>(operator_def, ws),
63  blobName_(GetSingleArgument<std::string>(kBlobName, "")),
64  addValue_(GetSingleArgument<int64_t>(kAddValue, 1)) {
65  CAFFE_ENFORCE(HasArgument(kBlobName));
66 }
67 
68 bool StoreAddOp::RunOnDevice() {
69  auto* handler =
70  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
71  Output(VALUE)->Resize(1);
72  Output(VALUE)->mutable_data<int64_t>()[0] =
73  handler->add(blobName_, addValue_);
74  return true;
75 }
76 
77 REGISTER_CPU_OPERATOR(StoreAdd, StoreAddOp);
78 OPERATOR_SCHEMA(StoreAdd)
79  .NumInputs(1)
80  .NumOutputs(1)
81  .SetDoc(R"DOC(
82 Add a value to a remote counter. If the key is not set, the store
83 initializes it to 0 and then performs the add operation. The operation
84 returns the resulting counter value.
85 )DOC")
86  .Arg("blob_name", "key of the counter (required)")
87  .Arg("add_value", "value that is added (optional, default: 1)")
88  .Input(0, "handler", "unique_ptr<StoreHandler>")
89  .Output(0, "value", "the current value of the counter");
90 
91 StoreWaitOp::StoreWaitOp(const OperatorDef& operator_def, Workspace* ws)
92  : Operator<CPUContext>(operator_def, ws),
93  blobNames_(GetRepeatedArgument<std::string>(kBlobName)) {}
94 
95 bool StoreWaitOp::RunOnDevice() {
96  auto* handler =
97  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
98  if (InputSize() == 2 && Input(1).IsType<std::string>()) {
99  CAFFE_ENFORCE(
100  blobNames_.empty(), "cannot specify both argument and input blob");
101  std::vector<std::string> blobNames;
102  auto* namesPtr = Input(1).data<std::string>();
103  for (int i = 0; i < Input(1).size(); ++i) {
104  blobNames.push_back(namesPtr[i]);
105  }
106  handler->wait(blobNames);
107  } else {
108  handler->wait(blobNames_);
109  }
110  return true;
111 }
112 
113 REGISTER_CPU_OPERATOR(StoreWait, StoreWaitOp);
114 OPERATOR_SCHEMA(StoreWait)
115  .NumInputs(1, 2)
116  .NumOutputs(0)
117  .SetDoc(R"DOC(
118 Wait for the specified blob names to be set. The blob names can be passed
119 either as an input blob with blob names or as an argument.
120 )DOC")
121  .Arg("blob_names", "names of the blobs to wait for (optional)")
122  .Input(0, "handler", "unique_ptr<StoreHandler>")
123  .Input(1, "names", "names of the blobs to wait for (optional)");
124 }
Definition: types.h:72
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...