1 #include "caffe2/core/context.h" 2 #include "caffe2/core/operator.h" 3 #include "caffe2/core/tensor.h" 4 #include "caffe2/core/types.h" 5 #include "caffe2/operators/text_file_reader_utils.h" 6 #include "caffe2/utils/string_utils.h" 12 const std::vector<char>& delims,
14 const std::string& filename,
16 const std::vector<int>& types)
17 : fileReader(filename),
18 tokenizer(
Tokenizer(delims, escape), &fileReader, numPasses),
20 for (
const auto dt : fieldTypes) {
22 DataTypeToTypeMeta(static_cast<TensorProto_DataType>(dt)));
23 fieldByteSizes.push_back(fieldMetas.back().itemsize());
29 std::vector<int> fieldTypes;
30 std::vector<TypeMeta> fieldMetas;
31 std::vector<size_t> fieldByteSizes;
36 std::mutex globalMutex_;
43 filename_(GetSingleArgument<string>(
"filename",
"")),
44 numPasses_(GetSingleArgument<int>(
"num_passes", 1)),
45 fieldTypes_(GetRepeatedArgument<int>(
"field_types")) {
46 CAFFE_ENFORCE(fieldTypes_.size() > 0,
"field_types arg must be non-empty");
49 bool RunOnDevice()
override {
50 *OperatorBase::Output<std::unique_ptr<TextFileReaderInstance>>(0) =
52 {
'\n',
'\t'},
'\0', filename_, numPasses_, fieldTypes_));
57 std::string filename_;
59 std::vector<int> fieldTypes_;
63 TensorProto_DataType dst_type,
64 const char* src_start,
68 case TensorProto_DataType_STRING: {
69 static_cast<std::string*
>(dst)->assign(src_start, src_end);
71 case TensorProto_DataType_FLOAT: {
73 std::string str_copy(src_start, src_end);
74 const char* src_copy = str_copy.c_str();
76 float val = strtof(src_copy, &src_copy_end);
77 if (src_copy == src_copy_end) {
78 throw std::runtime_error(
"Invalid float: " + str_copy);
80 *
static_cast<float*
>(dst) = val;
83 throw std::runtime_error(
"Unsupported type.");
91 batchSize_(GetSingleArgument<int>(
"batch_size", 1)) {}
93 bool RunOnDevice()
override {
94 const int numFields = OutputSize();
95 CAFFE_ENFORCE(numFields > 0,
"Expected at least one output.");
98 OperatorBase::Input<std::unique_ptr<TextFileReaderInstance>>(0).
get();
101 instance->fieldTypes.size() == numFields,
102 "Invalid number of outputs. Expected " +
103 to_string(instance->fieldTypes.size()) +
" got " +
104 to_string(numFields));
109 std::vector<char*> datas(numFields);
110 for (
int i = 0; i < numFields; ++i) {
111 Output(i)->Resize(batchSize_);
112 datas[i] = (
char*)Output(i)->raw_mutable_data(instance->fieldMetas[i]);
118 std::lock_guard<std::mutex> guard(instance->globalMutex_);
120 bool finished =
false;
122 while (!finished && (rowsRead < batchSize_)) {
124 for (field = 0; field < numFields; ++field) {
125 finished = !instance->tokenizer.next(token);
128 field == 0,
"Invalid number of fields at end of file.");
132 (field == 0 && token.startDelimId == 0) ||
133 (field > 0 && token.startDelimId == 1),
134 "Invalid number of columns at row ",
135 instance->rowsRead + rowsRead + 1);
136 const auto& meta = instance->fieldMetas[field];
137 char*& data = datas[field];
139 (TensorProto_DataType)instance->fieldTypes[field],
143 data += instance->fieldByteSizes[field];
149 instance->rowsRead += rowsRead;
152 for (
int i = 0; i < numFields; ++i) {
153 Output(i)->Shrink(rowsRead);
162 CAFFE_KNOWN_TYPE(std::unique_ptr<TextFileReaderInstance>);
167 OPERATOR_SCHEMA(CreateTextFileReader)
170 .SetDoc(
"Create a text file reader. Fields are delimited by <TAB>.")
171 .Arg(
"filename",
"Path to the file.")
172 .Arg(
"num_passes",
"Number of passes over the file.")
175 "List with type of each field. Type enum is found at core.DataType.")
176 .Output(0,
"handler",
"Pointer to the created TextFileReaderInstance.");
178 OPERATOR_SCHEMA(TextFileReaderRead)
180 .NumOutputs(1, INT_MAX)
182 "Read a batch of rows from the given text file reader instance. " 183 "Expects the number of fields to be equal to the number of outputs. " 184 "Each output is a 1D tensor containing the values for the given field " 185 "for each row. When end of file is reached, returns empty tensors.")
186 .Input(0,
"handler",
"Pointer to an existing TextFileReaderInstance.")
187 .Arg(
"batch_size",
"Maximum number of rows to read.");
189 NO_GRADIENT(CreateTextFileReader);
190 NO_GRADIENT(TextFileReaderRead);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...