Caffe2 - C++ API
A deep learning, cross-platform ML framework
tensor_protos_db_input.h
1 #ifndef CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
2 #define CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
3 
4 #include <iostream>
5 #include <mutex>
6 
7 #include "caffe2/core/db.h"
8 #include "caffe2/operators/prefetch_op.h"
9 
10 namespace caffe2 {
11 
12 template <class Context>
13 class TensorProtosDBInput final : public PrefetchOperator<Context> {
14  public:
15  using OperatorBase::OutputSize;
17  explicit TensorProtosDBInput(const OperatorDef& operator_def, Workspace* ws);
20  }
21 
22  bool Prefetch() override;
23  bool CopyPrefetched() override;
24 
25  private:
26  // Prefetch will always just happen on the CPU side.
27  vector<Blob> prefetched_blobs_;
28  int batch_size_;
29  bool shape_inferred_ = false;
30  string key_;
31  string value_;
32 };
33 
34 template <class Context>
36  const OperatorDef& operator_def,
37  Workspace* ws)
38  : PrefetchOperator<Context>(operator_def, ws),
39  prefetched_blobs_(operator_def.output_size()),
40  batch_size_(
41  OperatorBase::template GetSingleArgument<int>("batch_size", 0)) {}
42 
43 template <class Context>
45  const db::DBReader& reader = OperatorBase::Input<db::DBReader>(0);
46  TensorDeserializer<CPUContext> deserializer;
47  if (batch_size_ == 0) {
48  // We do not need to construct a batch. As a result, we will simply
49  // deserialize everything into the target prefetched blob.
50  reader.Read(&key_, &value_);
51  TensorProtos protos;
52  CAFFE_ENFORCE(protos.ParseFromString(value_));
53  CAFFE_ENFORCE(protos.protos_size() == OutputSize());
54  for (int i = 0; i < protos.protos_size(); ++i) {
55  if (protos.protos(i).has_device_detail()) {
56  protos.mutable_protos(i)->clear_device_detail();
57  }
58  deserializer.Deserialize(
59  protos.protos(i),
60  prefetched_blobs_[i].template GetMutable<TensorCPU>());
61  }
62  } else {
63  vector<TensorCPU> temp_tensors(OutputSize());
64  for (int item_id = 0; item_id < batch_size_; ++item_id) {
65  reader.Read(&key_, &value_);
66  TensorProtos protos;
67  CAFFE_ENFORCE(protos.ParseFromString(value_));
68  CAFFE_ENFORCE(protos.protos_size() == OutputSize());
69  if (!shape_inferred_) {
70  // First, set the shape of all the blobs.
71  for (int i = 0; i < protos.protos_size(); ++i) {
72  vector<int> dims(
73  protos.protos(i).dims().begin(), protos.protos(i).dims().end());
74  dims.insert(dims.begin(), batch_size_);
75  prefetched_blobs_[i].template GetMutable<TensorCPU>()->Resize(dims);
76  }
77  }
78  for (int i = 0; i < protos.protos_size(); ++i) {
79  TensorCPU* dst = prefetched_blobs_[i].template GetMutable<TensorCPU>();
80  TensorCPU& src = temp_tensors[i];
81  if (protos.protos(i).has_device_detail()) {
82  protos.mutable_protos(i)->clear_device_detail();
83  }
84  deserializer.Deserialize(protos.protos(i), &src);
85  DCHECK_EQ(src.size() * batch_size_, dst->size());
86  this->context_.template CopyItems<CPUContext, CPUContext>(
87  src.meta(),
88  src.size(),
89  src.raw_data(),
90  static_cast<char*>(dst->raw_mutable_data(src.meta())) +
91  src.nbytes() * item_id);
92  }
93  }
94  }
95  return true;
96 }
97 
98 template <class Context>
100  for (int i = 0; i < OutputSize(); ++i) {
101  OperatorBase::Output<Tensor<Context>>(i)->CopyFrom(
102  prefetched_blobs_[i].template Get<TensorCPU>(), &this->context_);
103  }
104  return true;
105 }
106 
107 } // namespace caffe2
108 
109 #endif // CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_
void Read(string *key, string *value) const
Reads one key/value pair from the DB and advances to the next entry.
Definition: db.h:222
A reader wrapper for DB that also allows us to serialize it.
Definition: db.h:144
TIndex size() const
Returns the size (i.e., the number of items) of the tensor.
Definition: tensor.h:593
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
void * raw_mutable_data(const TypeMeta &meta)
Returns a mutable raw pointer of the underlying storage.
Definition: tensor.h:510