1 #ifndef CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_ 2 #define CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/core/tensor.h" 8 #include "caffe2/operators/rnn/recurrent_network_op.h" 9 #include "google/protobuf/text_format.h" 15 template <
class Context>
18 USE_OPERATOR_CONTEXT_FUNCTIONS;
22 prefix_ = OperatorBase::GetSingleArgument<std::string>(
"prefix",
"rnn");
26 bool RunOnDevice()
override {
28 OperatorBase::Input<detail::ScratchWorkspaces>(0);
29 const std::vector<std::shared_ptr<Workspace>>& stepWorkspaces =
30 scratch.stepWorkspaces;
32 std::vector<std::string> blob_names_vector = {};
34 for (TIndex i = 0; i < stepWorkspaces.size(); i++) {
35 Workspace* currentStepWorkspace = stepWorkspaces[i].get();
36 std::vector<std::string> blob_names = currentStepWorkspace->
LocalBlobs();
38 for (
auto& blob_name : blob_names) {
39 const Blob* currentBlob = currentStepWorkspace->
GetBlob(blob_name);
42 std::string newBlobName =
43 prefix_ + std::string(
"_") + blob_name + caffe2::to_string(i);
44 blob_names_vector.push_back(newBlobName);
47 ->template GetMutable<TensorCPU>()
48 ->ResizeLike(currentTensor);
51 ws_->
GetBlob(newBlobName)->template GetMutable<Tensor<Context>>();
52 newTensor->template CopyFrom<Context>(currentTensor);
56 auto* output = Output(0);
57 output->Resize(blob_names_vector.size());
59 blob_names_vector.begin(),
60 blob_names_vector.end(),
61 output->template mutable_data<std::string>());
72 #endif // CAFFE2_OPERATORS_RECURRENT_BLOB_FETCHER_OP_H_
Blob is a general container that hosts a typed pointer.
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information.
vector< string > LocalBlobs() const
Return list of blobs owned by this Workspace, not including blobs shared from parent workspace.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
const T & Get() const
Gets the const reference of the stored object.