Caffe2 - C++ API
A deep learning, cross platform ML framework
text_file_reader.cc
1 #include "caffe2/core/context.h"
2 #include "caffe2/core/operator.h"
3 #include "caffe2/core/tensor.h"
4 #include "caffe2/core/types.h"
5 #include "caffe2/operators/text_file_reader_utils.h"
6 #include "caffe2/utils/string_utils.h"
7 
8 namespace caffe2 {
9 
12  const std::vector<char>& delims,
13  char escape,
14  const std::string& filename,
15  int numPasses,
16  const std::vector<int>& types)
17  : fileReader(filename),
18  tokenizer(Tokenizer(delims, escape), &fileReader, numPasses),
19  fieldTypes(types) {
20  for (const auto dt : fieldTypes) {
21  fieldMetas.push_back(
22  DataTypeToTypeMeta(static_cast<TensorProto_DataType>(dt)));
23  fieldByteSizes.push_back(fieldMetas.back().itemsize());
24  }
25  }
26 
27  FileReader fileReader;
28  BufferedTokenizer tokenizer;
29  std::vector<int> fieldTypes;
30  std::vector<TypeMeta> fieldMetas;
31  std::vector<size_t> fieldByteSizes;
32  size_t rowsRead{0};
33 
34  // hack to guarantee thread-safeness of the read op
35  // TODO(azzolini): support multi-threaded reading.
36  std::mutex globalMutex_;
37 };
38 
39 class CreateTextFileReaderOp : public Operator<CPUContext> {
40  public:
41  CreateTextFileReaderOp(const OperatorDef& operator_def, Workspace* ws)
42  : Operator<CPUContext>(operator_def, ws),
43  filename_(GetSingleArgument<string>("filename", "")),
44  numPasses_(GetSingleArgument<int>("num_passes", 1)),
45  fieldTypes_(GetRepeatedArgument<int>("field_types")) {
46  CAFFE_ENFORCE(fieldTypes_.size() > 0, "field_types arg must be non-empty");
47  }
48 
49  bool RunOnDevice() override {
50  *OperatorBase::Output<std::unique_ptr<TextFileReaderInstance>>(0) =
51  std::unique_ptr<TextFileReaderInstance>(new TextFileReaderInstance(
52  {'\n', '\t'}, '\0', filename_, numPasses_, fieldTypes_));
53  return true;
54  }
55 
56  private:
57  std::string filename_;
58  int numPasses_;
59  std::vector<int> fieldTypes_;
60 };
61 
62 inline void convert(
63  TensorProto_DataType dst_type,
64  const char* src_start,
65  const char* src_end,
66  void* dst) {
67  switch (dst_type) {
68  case TensorProto_DataType_STRING: {
69  static_cast<std::string*>(dst)->assign(src_start, src_end);
70  } break;
71  case TensorProto_DataType_FLOAT: {
72  // TODO(azzolini): avoid copy, use faster convertion
73  std::string str_copy(src_start, src_end);
74  const char* src_copy = str_copy.c_str();
75  char* src_copy_end;
76  float val = strtof(src_copy, &src_copy_end);
77  if (src_copy == src_copy_end) {
78  throw std::runtime_error("Invalid float: " + str_copy);
79  }
80  *static_cast<float*>(dst) = val;
81  } break;
82  default:
83  throw std::runtime_error("Unsupported type.");
84  }
85 }
86 
87 class TextFileReaderReadOp : public Operator<CPUContext> {
88  public:
89  TextFileReaderReadOp(const OperatorDef& operator_def, Workspace* ws)
90  : Operator<CPUContext>(operator_def, ws),
91  batchSize_(GetSingleArgument<int>("batch_size", 1)) {}
92 
93  bool RunOnDevice() override {
94  const int numFields = OutputSize();
95  CAFFE_ENFORCE(numFields > 0, "Expected at least one output.");
96 
97  auto instance =
98  OperatorBase::Input<std::unique_ptr<TextFileReaderInstance>>(0).get();
99 
100  CAFFE_ENFORCE(
101  instance->fieldTypes.size() == numFields,
102  "Invalid number of outputs. Expected " +
103  to_string(instance->fieldTypes.size()) + " got " +
104  to_string(numFields));
105 
106  // char* datas[numFields];
107  // MSVC does not allow using const int, so we will need to dynamically allocate
108  // it.
109  std::vector<char*> datas(numFields);
110  for (int i = 0; i < numFields; ++i) {
111  Output(i)->Resize(batchSize_);
112  datas[i] = (char*)Output(i)->raw_mutable_data(instance->fieldMetas[i]);
113  }
114 
115  int rowsRead = 0;
116  {
117  // TODO(azzolini): support multi-threaded reading
118  std::lock_guard<std::mutex> guard(instance->globalMutex_);
119 
120  bool finished = false;
121  Token token;
122  while (!finished && (rowsRead < batchSize_)) {
123  int field;
124  for (field = 0; field < numFields; ++field) {
125  finished = !instance->tokenizer.next(token);
126  if (finished) {
127  CAFFE_ENFORCE(
128  field == 0, "Invalid number of fields at end of file.");
129  break;
130  }
131  CAFFE_ENFORCE(
132  (field == 0 && token.startDelimId == 0) ||
133  (field > 0 && token.startDelimId == 1),
134  "Invalid number of columns at row ",
135  instance->rowsRead + rowsRead + 1);
136  const auto& meta = instance->fieldMetas[field];
137  char*& data = datas[field];
138  convert(
139  (TensorProto_DataType)instance->fieldTypes[field],
140  token.start,
141  token.end,
142  data);
143  data += instance->fieldByteSizes[field];
144  }
145  if (!finished) {
146  ++rowsRead;
147  }
148  }
149  instance->rowsRead += rowsRead;
150  }
151 
152  for (int i = 0; i < numFields; ++i) {
153  Output(i)->Shrink(rowsRead);
154  }
155  return true;
156  }
157 
158  private:
159  TIndex batchSize_;
160 };
161 
162 CAFFE_KNOWN_TYPE(std::unique_ptr<TextFileReaderInstance>);
163 
164 REGISTER_CPU_OPERATOR(CreateTextFileReader, CreateTextFileReaderOp);
165 REGISTER_CPU_OPERATOR(TextFileReaderRead, TextFileReaderReadOp);
166 
167 OPERATOR_SCHEMA(CreateTextFileReader)
168  .NumInputs(0)
169  .NumOutputs(1)
170  .SetDoc("Create a text file reader. Fields are delimited by <TAB>.")
171  .Arg("filename", "Path to the file.")
172  .Arg("num_passes", "Number of passes over the file.")
173  .Arg(
174  "field_types",
175  "List with type of each field. Type enum is found at core.DataType.")
176  .Output(0, "handler", "Pointer to the created TextFileReaderInstance.");
177 
178 OPERATOR_SCHEMA(TextFileReaderRead)
179  .NumInputs(1)
180  .NumOutputs(1, INT_MAX)
181  .SetDoc(
182  "Read a batch of rows from the given text file reader instance. "
183  "Expects the number of fields to be equal to the number of outputs. "
184  "Each output is a 1D tensor containing the values for the given field "
185  "for each row. When end of file is reached, returns empty tensors.")
186  .Input(0, "handler", "Pointer to an existing TextFileReaderInstance.")
187  .Arg("batch_size", "Maximum number of rows to read.");
188 
189 NO_GRADIENT(CreateTextFileReader);
190 NO_GRADIENT(TextFileReaderRead);
191 
192 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...