Caffe2 - C++ API
A deep learning, cross-platform ML framework
recurrent_network_blob_fetcher_op.h
1 #ifndef CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_
2 #define CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/core/tensor.h"
8 #include "caffe2/operators/rnn/recurrent_network_op.h"
9 #include "google/protobuf/text_format.h"
10 
11 #include <string>
12 
13 namespace caffe2 {
14 
15 template <class Context>
16 class RecurrentNetworkBlobFetcherOp final : public Operator<Context> {
17  public:
18  USE_OPERATOR_CONTEXT_FUNCTIONS;
19 
20  RecurrentNetworkBlobFetcherOp(const OperatorDef& operator_def, Workspace* ws)
21  : Operator<Context>(operator_def, ws) {
22  prefix_ = OperatorBase::GetSingleArgument<std::string>("prefix", "rnn");
23  ws_ = ws;
24  }
25 
26  bool RunOnDevice() override {
27  const detail::ScratchWorkspaces& scratch =
28  OperatorBase::Input<detail::ScratchWorkspaces>(0);
29  const std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
30  scratch.stepWorkspaces;
31 
32  std::vector<std::string> blob_names_vector = {};
33 
34  for (TIndex i = 0; i < stepWorkspaces.size(); i++) {
35  Workspace* currentStepWorkspace = stepWorkspaces[i].get();
36  std::vector<std::string> blob_names = currentStepWorkspace->LocalBlobs();
37 
38  for (auto& blob_name : blob_names) {
39  const Blob* currentBlob = currentStepWorkspace->GetBlob(blob_name);
40  const auto& currentTensor = currentBlob->Get<Tensor<Context>>();
41 
42  std::string newBlobName =
43  prefix_ + std::string("_") + blob_name + caffe2::to_string(i);
44  blob_names_vector.push_back(newBlobName);
45 
46  ws_->CreateBlob(newBlobName)
47  ->template GetMutable<TensorCPU>()
48  ->ResizeLike(currentTensor);
49 
50  auto* newTensor =
51  ws_->GetBlob(newBlobName)->template GetMutable<Tensor<Context>>();
52  newTensor->template CopyFrom<Context>(currentTensor);
53  }
54  }
55 
56  auto* output = Output(0);
57  output->Resize(blob_names_vector.size());
58  std::copy(
59  blob_names_vector.begin(),
60  blob_names_vector.end(),
61  output->template mutable_data<std::string>());
62 
63  return true;
64  }
65 
66  private:
67  std::string prefix_;
68  Workspace* ws_;
69 };
70 } // namespace caffe2
71 
72 #endif // CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:104
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:93
vector< string > LocalBlobs() const
Returns the list of blobs owned by this Workspace, not including blobs shared from the parent workspace...
Definition: workspace.cc:75
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:164
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
const T & Get() const
Gets the const reference of the stored object.
Definition: blob.h:75