Caffe2 - C++ API
A deep learning, cross-platform ML framework
dataset_ops.h
1 #ifndef CAFFE2_OPERATORS_DATASET_OPS_H_
2 #define CAFFE2_OPERATORS_DATASET_OPS_H_
3 
4 #include <memory>
5 #include <mutex>
6 #include <string>
7 #include <vector>
8 #include "caffe2/core/blob.h"
9 #include "caffe2/core/blob_serialization.h"
10 #include "caffe2/core/tensor.h"
11 
12 namespace caffe2 {
13 namespace dataset_ops {
14 
// Wide signed integer type used for every internal dataset computation:
// offsets, number of entries to read, limits, and so on.
using TOffset = int64_t;
// Element type of the "lengths" tensors stored in the dataset.
using TLength = int32_t;
19 
24 class TreeIterator {
25  public:
26  struct FieldDesc {
27  int id;
28  int lengthFieldId = -1;
29  std::string name;
30  };
31 
32  explicit TreeIterator(const std::vector<std::string>& fields);
33 
34  void advance(
35  const std::vector<const TLength*>& lengths,
36  std::vector<TOffset>& offsets,
37  std::vector<TOffset>& sizes,
38  std::vector<TOffset>& limits,
39  TOffset num);
40 
41  // Corresponds to the number of fields that have "length" as its last name
42  int numLengthFields() const {
43  return lengthFieldIds_.size();
44  }
45 
46  // Corresponds to the number of length fields + 1 (for the top-level domain)
47  int numOffsetFields() const {
48  return numLengthFields() + 1;
49  }
50 
51  // Get lengthField description for the given field
52  const FieldDesc* lengthFieldFor(const FieldDesc& desc) {
53  return (desc.lengthFieldId == -1)
54  ? nullptr
55  : &fields_.at(lengthFieldIds_.at(desc.lengthFieldId));
56  }
57 
58  // Get lengthField description for the given lengthFieldId, where
59  // 0 <= lengthFieldId < numLengthFields()
60  const FieldDesc& lengthField(int lengthFieldId) {
61  return fields_.at(lengthFieldIds_.at(lengthFieldId));
62  }
63 
64  // Returns the index into the 'offset' vector for the given field.
65  int offsetFieldIdFor(const FieldDesc& fieldDesc) {
66  return fieldDesc.lengthFieldId + 1;
67  }
68 
69  // Returns the field description for all fields.
70  const std::vector<FieldDesc>& fields() {
71  return fields_;
72  }
73 
74  const std::vector<int>& lengthFieldIds() const {
75  return lengthFieldIds_;
76  }
77 
78  private:
79  // Description of each field
80  std::vector<FieldDesc> fields_;
81  // Index into fields_ above for the fields that are lengths.
82  std::vector<int> lengthFieldIds_;
83 };
84 
85 class TreeCursor {
86  public:
87  explicit TreeCursor(const TreeIterator& iterator) : it(iterator) {}
88  std::vector<TOffset> offsets;
89  std::mutex mutex_;
90  TreeIterator it;
91 };
92 
97 class TreeWalker {
98  public:
99  TreeWalker(const vector<const Blob*>& inputs, TreeCursor& cursor);
100 
101  // Returns the number of records in a dataset
102  inline TOffset size() const {
103  return limits_.at(0);
104  }
105 
106  void advance();
107 
108  private:
109  inline const TensorCPU& input(int32_t idx) const {
110  return inputs_[idx]->Get<TensorCPU>();
111  }
112 
113  // TODO: Change to fieldDesc
114  inline const TreeIterator::FieldDesc& field(int idx) const {
115  return cursor_.it.fields().at(idx);
116  }
117 
118  inline int lengthIdx(int fieldId) const {
119  return field(fieldId).lengthFieldId + 1;
120  }
121 
122  inline TOffset offset(int fieldId) const {
123  return prevOffsets_[lengthIdx(fieldId)];
124  }
125 
126  std::vector<TIndex> fieldDim(int fieldId) const;
127 
128  void* fieldPtr(int fieldId) const;
129 
130  public:
131  // Simple Proxy class to expose nicer API for field access
132  class Field {
133  public:
134  Field(TreeWalker& walker, int fieldId)
135  : walker_(walker), fieldId_(fieldId) {}
136 
137  inline std::vector<TIndex> dim() const {
138  return walker_.fieldDim(fieldId_);
139  }
140 
141  inline TIndex size() const {
142  TIndex size = 1;
143  for (const auto d : dim()) {
144  size *= d;
145  }
146  return size;
147  }
148 
149  inline const TypeMeta& meta() const {
150  return walker_.input(fieldId_).meta();
151  }
152 
153  inline void* ptr() const {
154  return walker_.fieldPtr(fieldId_);
155  }
156 
157  int fieldId() const {
158  return fieldId_;
159  }
160 
161  inline TOffset offset() const {
162  return walker_.offset(fieldId_);
163  }
164 
165  private:
166  const TreeWalker& walker_;
167  const int fieldId_;
168  };
169 
170  // Notice that a reference is returned. If advance() is called the fields will
171  // be updated to represent the new state.
172  inline const std::vector<Field>& fields() const {
173  return fields_;
174  }
175 
176  private:
177  void gatherLengthData();
178 
179  void gatherSizeLimits();
180 
181  const vector<const Blob*>& inputs_;
182  TreeCursor& cursor_;
183  std::vector<Field> fields_;
184 
185  std::vector<const TLength*> lengths_;
186  std::vector<TOffset> limits_;
187  std::vector<TOffset> sizes_;
188  std::vector<TOffset> offsets_;
189  std::vector<TOffset> prevOffsets_;
190 };
191 
192 using SharedTensorVectorPtr = std::shared_ptr<std::vector<TensorCPU>>;
193 
194 template <class Context>
195 using TensorVectorPtr = std::unique_ptr<std::vector<Tensor<Context>>>;
196 
198  public:
// Override of a virtual serialization hook: serializes the contents of `blob`
// under `name`, emitting output through `acceptor` (whose type comes from
// BlobSerializerBase). NOTE(review): the enclosing class declaration is not
// visible in this listing; presumably it derives from BlobSerializerBase and
// handles SharedTensorVectorPtr blobs — confirm.
199  void Serialize(
200  const Blob& blob,
201  const string& name,
202  BlobSerializerBase::SerializationAcceptor acceptor) override;
203 };
204 
206  public:
// Override of a virtual deserialization hook: rebuilds the contents of `blob`
// from `proto`. NOTE(review): the enclosing class declaration is not visible
// in this listing; presumably it derives from BlobDeserializerBase — confirm.
207  void Deserialize(const BlobProto& proto, Blob* blob) override;
208 };
209 
210 } // namespace dataset_ops
211 } // namespace caffe2
212 
213 #endif // CAFFE2_OPERATORS_DATASET_OPS_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Simple wrapper class allowing an easy traversal of the tensors representing the hierarchical structu...
Definition: dataset_ops.h:97
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:88
BlobSerializerBase is an abstract class that serializes a blob to a string.
Provides functionality to iterate across a list of tensors where some of those tensors represent leng...
Definition: dataset_ops.h:24