1 #ifndef CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_ 2 #define CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_ 7 #include "caffe2/core/db.h" 8 #include "caffe2/operators/prefetch_op.h" 12 template <
class Context>
15 using OperatorBase::OutputSize;
22 bool Prefetch()
override;
23 bool CopyPrefetched()
override;
27 vector<Blob> prefetched_blobs_;
29 bool shape_inferred_ =
false;
34 template <
class Context>
36 const OperatorDef& operator_def,
39 prefetched_blobs_(operator_def.output_size()),
41 OperatorBase::template GetSingleArgument<int>(
"batch_size", 0)) {}
43 template <
class Context>
45 const db::DBReader& reader = OperatorBase::Input<db::DBReader>(0);
47 if (batch_size_ == 0) {
50 reader.
Read(&key_, &value_);
52 CAFFE_ENFORCE(protos.ParseFromString(value_));
53 CAFFE_ENFORCE(protos.protos_size() == OutputSize());
54 for (
int i = 0; i < protos.protos_size(); ++i) {
55 if (protos.protos(i).has_device_detail()) {
56 protos.mutable_protos(i)->clear_device_detail();
58 deserializer.Deserialize(
60 prefetched_blobs_[i].template GetMutable<TensorCPU>());
63 vector<TensorCPU> temp_tensors(OutputSize());
64 for (
int item_id = 0; item_id < batch_size_; ++item_id) {
65 reader.
Read(&key_, &value_);
67 CAFFE_ENFORCE(protos.ParseFromString(value_));
68 CAFFE_ENFORCE(protos.protos_size() == OutputSize());
69 if (!shape_inferred_) {
71 for (
int i = 0; i < protos.protos_size(); ++i) {
73 protos.protos(i).dims().begin(), protos.protos(i).dims().end());
74 dims.insert(dims.begin(), batch_size_);
75 prefetched_blobs_[i].template GetMutable<TensorCPU>()->Resize(dims);
78 for (
int i = 0; i < protos.protos_size(); ++i) {
79 TensorCPU* dst = prefetched_blobs_[i].template GetMutable<TensorCPU>();
81 if (protos.protos(i).has_device_detail()) {
82 protos.mutable_protos(i)->clear_device_detail();
84 deserializer.Deserialize(protos.protos(i), &src);
85 DCHECK_EQ(src.size() * batch_size_, dst->
size());
86 this->context_.template CopyItems<CPUContext, CPUContext>(
91 src.nbytes() * item_id);
98 template <
class Context>
100 for (
int i = 0; i < OutputSize(); ++i) {
101 OperatorBase::Output<Tensor<Context>>(i)->CopyFrom(
102 prefetched_blobs_[i].
template Get<TensorCPU>(), &this->context_);
109 #endif // CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_ void Read(string *key, string *value) const
Read a set of key and value from the db and move to next.
A reader wrapper for DB that also allows us to serialize it.
TIndex size() const
Returns the size (i.e.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
void * raw_mutable_data(const TypeMeta &meta)
Returns a mutable raw pointer of the underlying storage.