4 Defines a minimal set of data types that allow to represent datasets with 5 arbitrary nested structure, including objects of variable length, such as 8 This defines a columnar storage format for such datasets on top of caffe2 9 tensors. In terms of capacity of representation, it can represent most of 10 the data types supported by Parquet, ORC, DWRF file formats. 12 See comments in operator_test/dataset_ops_test.py for an example and 13 walkthrough on how to use schema to store and iterate through a structured 16 from __future__
import absolute_import
17 from __future__
import division
18 from __future__
import print_function
19 from __future__
import unicode_literals
26 from collections
import OrderedDict, namedtuple
27 from past.builtins
import basestring
28 from future.utils
import viewitems, viewkeys, viewvalues
29 from itertools
import islice
30 from six
import StringIO
32 logger = logging.getLogger(__name__)
33 logger.setLevel(logging.INFO)
38 def _join_field_name(prefix, suffix):
40 return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
def _normalize_field(field_or_type_or_blob, keep_blobs=True):
    """Clones/normalizes a field before adding it to a container."""
    # Already a Field: hand back a clone (optionally stripped of blobs).
    if isinstance(field_or_type_or_blob, Field):
        return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
    # A raw python type or a numpy dtype: wrap it in a typed Scalar.
    if type(field_or_type_or_blob) in (type, np.dtype):
        return Scalar(dtype=field_or_type_or_blob)
    # Anything else is treated as the blob/value itself.
    return Scalar(blob=field_or_type_or_blob)
59 FeatureSpec = namedtuple(
65 'feature_is_request_only',
70 FeatureSpec.__new__.__defaults__ = (
None,
None,
None,
None,
None)
75 'Metadata', [
'categorical_limit',
'expected_value',
'feature_specs']
78 """Represents additional information associated with a scalar in schema. 80 `categorical_limit` - for fields of integral type that are guaranteed to be 81 non-negative it specifies the maximum possible value plus one. It's often 82 used as a size of an embedding table. 84 `expected_value` - anticipated average value of elements in the field. 85 Usually makes sense for length fields of lists. 87 `feature_specs` - information about the features that contained in this 88 field. For example if field have more than 1 feature it can have list of 89 feature names contained in this field.""" 93 Metadata.__new__.__defaults__ = (
None,
None,
None)
97 """Represents an abstract field type in a dataset. 101 """Derived classes must call this after their initialization.""" 105 for child
in children:
106 self._field_offsets.append(offset)
107 offset += len(child.field_names())
108 self._field_offsets.append(offset)
110 def clone_schema(self):
111 return self.
clone(keep_blobs=
False)
114 """Return the children field names for this field.""" 115 raise NotImplementedError(
'Field is an abstract class.')
118 """Return the numpy.dtype for each of the children fields.""" 119 raise NotImplementedError(
'Field is an abstract class.')
122 """Return the Metadata for each of the children fields.""" 123 raise NotImplementedError(
'Field is an abstract class.')
126 """Return the list of blobs with contents for this Field. 127 Values can either be all numpy.ndarray or BlobReference. 128 If any of the fields doesn't have a blob, throws. 130 raise NotImplementedError(
'Field is an abstract class.')
133 """Return the list of all Scalar instances in the Field. 134 The order is the same as for field_names() or field_blobs()""" 135 raise NotImplementedError(
'Field is an abstract class.')
138 """Return True if every scalar of this field has blobs.""" 139 raise NotImplementedError(
'Field is an abstract class.')
142 """Clone this Field along with its children.""" 143 raise NotImplementedError(
'Field is an abstract class.')
145 def _set_parent(self, parent, relative_id):
146 self.
_parent = (parent, relative_id)
150 Returns a slice representing the range of field ids that belong to 151 this field. This slice can be used to index a list of fields. 158 >>> ('b1', Scalar()), 159 >>> ('b2', Scalar()), 163 >>> field_data = ['da', 'db1', 'db2', 'dc'] 164 >>> field_data[s.b.split()] 170 def _child_base_id(self, child_index=None):
171 """Get the base id of the given child""" 173 pos = 0
if child_index
is None else self.
_field_offsets[child_index]
175 pos += p._child_base_id(i)
179 """Equivalence of two schemas""" 186 def _pprint_impl(self, indent, str_buffer):
187 raise NotImplementedError(
'Field is an abstrct class.')
190 str_buffer = StringIO()
192 contents = str_buffer.getvalue()
198 """Represents a variable-length list. 200 Values of a list can also be complex fields such as Lists and Structs. 201 In addition to the fields exposed by its `values` field, a List exposes an 202 additional `lengths` field, which will contain the size of each list under 206 def __init__(self, values, lengths_blob=None):
207 if isinstance(lengths_blob, Field):
208 assert isinstance(lengths_blob, Scalar)
209 self.
lengths = _normalize_field(lengths_blob)
212 self.
_items = _normalize_field(values)
213 self.lengths._set_parent(self, 0)
214 self._items._set_parent(self, 1)
217 def field_names(self):
218 value_fields = self._items.field_names()
220 [
'lengths'] + [_join_field_name(
'values', v)
for v
in value_fields]
223 def field_types(self):
224 return self.lengths.field_types() + self._items.field_types()
226 def field_metadata(self):
227 return self.lengths.field_metadata() + self._items.field_metadata()
229 def field_blobs(self):
230 return self.lengths.field_blobs() + self._items.field_blobs()
232 def all_scalars(self):
233 return self.lengths.all_scalars() + self._items.all_scalars()
236 return self.lengths.has_blobs()
and self._items.has_blobs()
238 def clone(self, keep_blobs=True):
240 _normalize_field(self.
_items, keep_blobs=keep_blobs),
241 _normalize_field(self.
lengths, keep_blobs=keep_blobs)
244 def _pprint_impl(self, indent, str_buffer):
245 str_buffer.write(
' ' * indent +
"List(\n")
246 str_buffer.write(
' ' * (indent + 1) +
"lengths=\n")
247 self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
248 str_buffer.write(
' ' * (indent + 1) +
"_items=\n")
249 self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
250 str_buffer.write(
' ' * indent +
")\n")
253 """If the value of this list is a struct, 254 allow to introspect directly into its fields.""" 255 if item.startswith(
'__'):
256 raise AttributeError(item)
257 if isinstance(self.
_items, Struct):
258 return getattr(self.
_items, item)
259 elif item ==
'value' or item ==
'items':
262 raise AttributeError(
'Field not found in list: %s.' % item)
264 def __getitem__(self, item):
265 names = item.split(FIELD_SEPARATOR, 1)
268 if item ==
'lengths':
270 elif item ==
'values':
273 if names[0] ==
'values':
274 return self.
_items[names[1]]
275 raise KeyError(
'Field not found in list: %s.' % item)
279 """Represents a named list of fields sharing the same domain. 283 """ fields is a list of tuples in format of (name, field). The name is 284 a string of nested name, e.g., `a`, `a:b`, `a:b:c`. For example 301 ('d', Struct(('e', Scalar()))), 307 assert len(field) == 2
308 assert field[0],
'Field names cannot be empty' 309 assert field[0] !=
'lengths', (
310 'Struct cannot contain a field named `lengths`.' 312 fields = [(name, _normalize_field(field))
for name, field
in fields]
313 self.
fields = OrderedDict()
314 for name, field
in fields:
315 if FIELD_SEPARATOR
in name:
317 if name
not in self.
fields:
321 not isinstance(field, Struct)
or 322 not isinstance(self.
fields[name], Struct)
324 raise ValueError(
'Duplicate field name: %s' % name)
326 for id, (_, field)
in enumerate(viewitems(self.
fields)):
327 field._set_parent(self, id)
328 Field.__init__(self, viewvalues(self.
fields))
331 def _struct_from_nested_name(self, nested_name, field):
332 def create_internal(nested_name, field):
333 names = nested_name.split(FIELD_SEPARATOR, 1)
337 added_field = create_internal(names[1], field)
338 return Struct((names[0], added_field))
340 names = nested_name.split(FIELD_SEPARATOR, 1)
341 assert len(names) >= 2
342 return names[0], create_internal(names[1], field)
344 def get_children(self):
345 return list(viewitems(self.
fields))
347 def field_names(self):
349 for name, field
in viewitems(self.
fields):
350 names += [_join_field_name(name, f)
for f
in field.field_names()]
353 def field_types(self):
355 for _, field
in viewitems(self.
fields):
356 types += field.field_types()
359 def field_metadata(self):
361 for _, field
in viewitems(self.
fields):
362 metadata += field.field_metadata()
365 def field_blobs(self):
367 for _, field
in viewitems(self.
fields):
368 blobs += field.field_blobs()
371 def all_scalars(self):
373 for _, field
in viewitems(self.
fields):
374 scalars += field.all_scalars()
378 return all(field.has_blobs()
for field
in viewvalues(self.
fields))
380 def clone(self, keep_blobs=True):
381 normalized_fields = [
382 (k, _normalize_field(v, keep_blobs=keep_blobs))
383 for k, v
in viewitems(self.
fields)
385 return Struct(*normalized_fields)
387 def _get_field_by_nested_name(self, nested_name):
388 names = nested_name.split(FIELD_SEPARATOR, 1)
389 field = self.fields.get(names[0],
None)
398 return field[names[1]]
399 except (KeyError, TypeError):
402 def _pprint_impl(self, indent, str_buffer):
403 str_buffer.write(
' ' * indent +
"Struct( \n")
404 for name, field
in viewitems(self.
fields):
405 str_buffer.write(
' ' * (indent + 1) +
"{}=".format(name) +
"\n")
406 field._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
407 str_buffer.write(
' ' * indent +
") \n")
409 def __contains__(self, item):
411 return field
is not None 418 item can be a tuple or list of ints or strings, or a single 419 int or string. String item is a nested field name, e.g., "a", "a:b", 420 "a:b:c". Int item is the index of a field at the first level of the 423 if isinstance(item, list)
or isinstance(item, tuple):
424 keys = list(viewkeys(self.
fields))
429 if isinstance(k, int)
else k, self[k]
433 elif isinstance(item, int):
434 return next(islice(viewvalues(self.
fields), item,
None))
438 raise KeyError(
'field "%s" not found' % (item))
441 def get(self, item, default_value):
443 similar to python's dictionary get method, return field of item if found 444 (i.e. self.item is valid) or otherwise return default_value 446 it's a syntax sugar of python's builtin getattr method 448 return getattr(self, item, default_value)
450 def __getattr__(self, item):
451 if item.startswith(
'__'):
452 raise AttributeError(item)
454 return self.__dict__[
'fields'][item]
456 raise AttributeError(item)
458 def __setattr__(self, key, value):
463 if getattr(self,
'_frozen',
None)
and not key.startswith(
'_'):
464 raise TypeError(
'Struct.__setattr__() is disabled after __init__()')
465 super(Struct, self).__setattr__(key, value)
469 Allows to merge fields of two schema.Struct using '+' operator. 470 If two Struct have common field names, the merge is conducted 471 recursively. Here are examples: 474 s1 = Struct(('a', Scalar())) 475 s2 = Struct(('b', Scalar())) 484 ('b', Struct(('c', Scalar()))), 486 s2 = Struct(('b', Struct(('d', Scalar())))) 495 if not isinstance(other, Struct):
496 return NotImplemented
499 for name, right_field
in other.get_children():
500 if name
not in children:
501 children[name] = right_field
503 left_field = children[name]
504 children[name] = left_field + right_field
506 return Struct(*(viewitems(children)))
510 Allows to remove common fields of two schema.Struct from self by 511 using '-' operator. If two Struct have common field names, the 512 removal is conducted recursively. If a child struct has no fields 513 inside, it will be removed from its parent. Here are examples: 520 s2 = Struct(('a', Scalar())) 521 s1 - s2 == Struct(('b', Scalar())) 531 ('b', Struct(('c', Scalar()))), 556 if not isinstance(other, Struct):
557 return NotImplemented
560 for name, right_field
in other.get_children():
562 left_field = children[name]
563 if type(left_field) == type(right_field):
564 if isinstance(left_field, Struct):
565 child = left_field - right_field
566 if child.get_children():
567 children[name] = child
572 "Type of left_field, " + str(type(left_field)) +
573 ", is not the same as that of right_field, " +
574 str(type(right_field)) +
575 ", yet they have the same field name, " + name)
576 return Struct(*(children.items()))
580 """Represents a typed scalar or tensor of fixed shape. 582 A Scalar is a leaf in a schema tree, translating to exactly one tensor in 583 the dataset's underlying storage. 585 Usually, the tensor storing the actual values of this field is a 1D tensor, 586 representing a series of values in its domain. It is possible however to 587 have higher rank values stored as a Scalar, as long as all entries have 594 Scalar field of type float64. Caffe2 will expect readers and 595 datasets to expose it as a 1D tensor of doubles (vector), where 596 the size of the vector is determined by this fields' domain. 598 Scalar((np.int32, 5)) 600 Tensor field of type int32. Caffe2 will expect readers and 601 datasets to implement it as a 2D tensor (matrix) of shape (L, 5), 602 where L is determined by this fields' domain. 604 Scalar((str, (10, 20))) 606 Tensor field of type str. Caffe2 will expect readers and 607 datasets to implement it as a 3D tensor of shape (L, 10, 20), 608 where L is determined by this fields' domain. 610 If the field type is unknown at construction time, call Scalar(), that will 611 default to np.void as its dtype. 613 It is an error to pass a structured dtype to Scalar, since it would contain 614 more than one field. Instead, use from_dtype, which will construct 615 a nested `Struct` field reflecting the given dtype's structure. 617 A Scalar can also contain a blob, which represents the value of this 618 Scalar. A blob can be either a numpy.ndarray, in which case it contain the 619 actual contents of the Scalar, or a BlobReference, which represents a 620 blob living in a caffe2 Workspace. If blob of different types are passed, 621 a conversion to numpy.ndarray is attempted. 624 def __init__(self, dtype=None, blob=None, metadata=None):
626 self.
set(dtype, blob, metadata, unsafe=
True)
627 Field.__init__(self, [])
629 def field_names(self):
632 def field_type(self):
635 def field_types(self):
638 def field_metadata(self):
642 return self.
_blob is not None 644 def field_blobs(self):
645 assert self.
_blob is not None,
'Value is not set for this field.' 648 def all_scalars(self):
651 def clone(self, keep_blobs=True):
654 blob=self.
_blob if keep_blobs
else None,
659 """Gets the current blob of this Scalar field.""" 660 assert self.
_blob is not None,
'Value is not set for this field.' 664 """Shortcut for self.get()""" 671 def set_metadata(self, value):
672 assert isinstance(value, Metadata), \
673 'metadata must be Metadata, got {}'.format(type(value))
677 def _validate_metadata(self):
680 if (self._metadata.categorical_limit
is not None and 681 self.
dtype is not None):
682 assert np.issubdtype(self.
dtype, np.integer), \
683 "`categorical_limit` can be specified only in integral " + \
684 "fields but got {}".format(self.
dtype)
686 def set_value(self, blob, throw_on_type_mismatch=False, unsafe=False):
687 """Sets only the blob field still validating the existing dtype""" 688 if self.dtype.base != np.void
and throw_on_type_mismatch:
689 assert isinstance(blob, np.ndarray),
"Got {!r}".format(blob)
690 assert blob.dtype.base == self.dtype.base, (
691 "Expected {}, got {}".format(self.dtype.base, blob.dtype.base))
694 def set(self, dtype=None, blob=None, metadata=None, unsafe=False):
695 """Set the type and/or blob of this scalar. See __init__ for details. 698 dtype: can be any numpy type. If not provided and `blob` is 699 provided, it will be inferred. If no argument is provided, 700 this Scalar will be of type np.void. 701 blob: if provided, can be either a BlobReference or a 702 numpy.ndarray. If a value of different type is passed, 703 a conversion to numpy.ndarray is attempted. Strings aren't 704 accepted, since they can be ambiguous. If you want to pass 705 a string, to either BlobReference(blob) or np.array(blob). 706 metadata: optional instance of Metadata, if provided overrides 707 the metadata information of the scalar 711 "Scalar should be considered immutable. Only call Scalar.set() " 712 "on newly created Scalar with unsafe=True. This will become an " 715 if blob
is not None and isinstance(blob, basestring):
717 'Passing str blob to Scalar.set() is ambiguous. ' 718 'Do either set(blob=np.array(blob)) or ' 719 'set(blob=BlobReference(blob))' 723 if dtype
is not None:
724 dtype = np.dtype(dtype)
727 if blob
is not None and not isinstance(blob, BlobReference):
728 preserve_shape = isinstance(blob, np.ndarray)
729 if dtype
is not None and dtype != np.void:
730 blob = np.array(blob, dtype=dtype.base)
732 if blob.size == 0
and not preserve_shape:
733 blob = blob.reshape((0, ) + dtype.shape)
735 assert isinstance(blob, np.ndarray), (
736 'Invalid blob type: %s' % str(type(blob)))
740 if len(blob.shape) == 0
and not preserve_shape:
741 blob = blob.reshape((1, ))
745 if (len(blob.shape) > 1
and dtype
is not None and 746 dtype.base != np.void):
747 dtype = np.dtype((dtype.base, blob.shape[1:]))
750 dtype = np.dtype(np.void)
751 assert not dtype.fields, (
752 'Cannot create Scalar with a structured dtype. ' +
753 'Use from_dtype instead.' 757 if metadata
is not None:
761 def set_type(self, dtype):
763 if dtype
is not None:
764 self.
dtype = np.dtype(dtype)
766 self.
dtype = np.dtype(np.void)
769 def _pprint_impl(self, indent, str_buffer):
770 str_buffer.write(
' ' * (indent) +
771 'Scalar({!r}, {!r}, {!r})'.format(
776 Return the zero-indexed position of this scalar field in its schema. 777 Used in order to index into the field_blob list returned by readers or 787 values_name=
'values',
790 """A map is a List of Struct containing keys and values fields. 791 Optionally, you can provide custom name for the key and value fields. 794 Struct((keys_name, keys), (values_name, values)),
795 lengths_blob=lengths_blob
def NamedTuple(name_prefix, *fields):
    """Builds a Struct whose field names are `<name_prefix>_<index>`."""
    named = [
        ('%s_%d' % (name_prefix, idx), field)
        for idx, field in enumerate(fields)
    ]
    return Struct(*named)
806 Creates a Struct with default, sequential, field names of given types. 808 return NamedTuple(
'field', *fields)
def RawTuple(num_fields, name_prefix='field'):
    """
    Creates a tuple of `num_fields` untyped scalars.
    """
    assert isinstance(num_fields, int)
    assert num_fields >= 0
    # np.void marks each column's type as "unknown".
    return NamedTuple(name_prefix, *(num_fields * [np.void]))
def from_dtype(dtype, _outer_shape=()):
    """Constructs a Caffe2 schema from the given numpy's dtype.

    Numpy supports scalar, array-like and structured datatypes, as long as
    all the shapes are fixed. This function breaks down the given dtype into
    a Caffe2 schema containing `Struct` and `Scalar` types.

    Fields containing byte offsets are not currently supported.
    """
    if not isinstance(dtype, np.dtype):
        # dtype is either a primitive type (e.g. np.int32) or a (type, shape)
        # pair; wrap it into a real np.dtype carrying the outer shape.
        shape = _outer_shape
        dtype = np.dtype((dtype, _outer_shape))
    else:
        # Concatenate the accumulated outer shape with this dtype's own shape.
        shape = _outer_shape + dtype.shape
        if shape != dtype.shape:
            dtype = np.dtype((dtype.base, shape))

    if not dtype.fields:
        # Leaf: a plain (possibly multi-dimensional) scalar.
        return Scalar(dtype)

    struct_fields = []
    # BUG FIX: iterating `dtype.fields` directly yields only the field names
    # (it is a mapping), so the `name, (fdtype, offset)` unpacking raised;
    # the (dtype, offset) pairs come from .items(). Also, the original
    # `struct_fields += (name, field)` flattened name and field into the list
    # separately instead of appending one (name, field) pair.
    for name, (fdtype, offset) in dtype.fields.items():
        assert offset == 0, ('Fields with byte offsets are not supported.')
        struct_fields.append((name, from_dtype(fdtype, _outer_shape=shape)))
    return Struct(*struct_fields)
850 """This is a private class used to represent a Schema Node""" 852 def __init__(self, name, type_str=''):
858 def add_child(self, name, type_str=''):
860 if child.name == name
and child.type_str == type_str:
863 self.children.append(child)
868 list_names = [
'lengths',
'values']
869 map_names = [
'lengths',
'keys',
'values']
872 if self.
field is None:
879 child_names.append(child.name)
881 if (set(child_names) == set(list_names)):
883 if child.name ==
'values':
884 values_field = child.get_field()
886 lengths_field = child.get_field()
889 lengths_blob=lengths_field
893 elif (set(child_names) == set(map_names)):
895 if child.name ==
'keys':
896 key_field = child.get_field()
897 elif child.name ==
'values':
898 values_field = child.get_field()
900 lengths_field = child.get_field()
904 lengths_blob=lengths_field
912 struct_fields.append((child.name, child.get_field()))
918 def print_recursively(self):
920 child.print_recursively()
921 logger.info(
"Printing node: Name and type")
922 logger.info(self.
name)
926 def from_column_list(
927 col_names, col_types=
None,
928 col_blobs=
None, col_metadata=
None 931 Given a list of names, types, and optionally values, construct a Schema. 933 if col_types
is None:
934 col_types = [
None] * len(col_names)
935 if col_metadata
is None:
936 col_metadata = [
None] * len(col_names)
937 if col_blobs
is None:
938 col_blobs = [
None] * len(col_names)
939 assert len(col_names) == len(col_types), (
940 'col_names and col_types must have the same length.' 942 assert len(col_names) == len(col_metadata), (
943 'col_names and col_metadata must have the same length.' 945 assert len(col_names) == len(col_blobs), (
946 'col_names and col_blobs must have the same length.' 949 for col_name, col_type, col_blob, col_metadata
in zip(
950 col_names, col_types, col_blobs, col_metadata
952 columns = col_name.split(FIELD_SEPARATOR)
954 for i
in range(len(columns)):
958 if i == len(columns) - 1:
963 metadata=col_metadata
965 next = current.add_child(name, type_str)
966 if field
is not None:
970 return root.get_field()
def from_blob_list(schema, values, throw_on_type_mismatch=False):
    """
    Create a schema that clones the given schema, but containing the given
    list of values.
    """
    assert isinstance(schema, Field), 'Argument `schema` must be a Field.'
    if isinstance(values, BlobReference):
        values = [values]
    # Clone without blobs, then pour the provided values into the leaves.
    cloned = schema.clone_schema()
    leaves = cloned.all_scalars()
    assert len(leaves) == len(values), (
        'Values must have %d elements, got %d.' % (len(leaves), len(values))
    )
    for leaf, val in zip(leaves, values):
        leaf.set_value(val, throw_on_type_mismatch, unsafe=True)
    return cloned
def as_record(value):
    """Coerces `value` (Field, list/tuple, dict, type or blob) into a Field."""
    # A Field passes through untouched.
    if isinstance(value, Field):
        return value
    if isinstance(value, (list, tuple)):
        # A sequence made entirely of (name, value) pairs becomes a Struct;
        # any other sequence becomes a positional Tuple.
        is_field_list = all(
            type(f) is tuple and len(f) == 2 and isinstance(f[0], basestring)
            for f in value
        )
        if is_field_list:
            return Struct(*[(k, as_record(v)) for k, v in value])
        return Tuple(*[as_record(f) for f in value])
    if isinstance(value, dict):
        return Struct(*[(k, as_record(v)) for k, v in viewitems(value)])
    return _normalize_field(value)
def FetchRecord(blob_record, ws=None, throw_on_type_mismatch=False):
    """
    Given a record containing BlobReferences, return a new record with same
    schema, containing numpy arrays, fetched from the current active workspace.
    """
    def _fetch_one(ref):
        # No explicit workspace given: read from the global/current one.
        if ws is None:
            return workspace.FetchBlob(str(ref))
        return ws.blobs[str(ref)].fetch()

    assert isinstance(blob_record, Field)
    refs = blob_record.field_blobs()
    assert all(isinstance(r, BlobReference) for r in refs)
    arrays = [_fetch_one(r) for r in refs]
    return from_blob_list(blob_record, arrays, throw_on_type_mismatch)
def FeedRecord(blob_record, arrays, ws=None):
    """
    Given a Record containing blob_references and arrays, which is either
    a list of numpy arrays or a Record containing numpy arrays, feeds the
    record to the current workspace.
    """
    def _feed_one(ref, value):
        # No explicit workspace given: feed into the global/current one.
        if ws is None:
            workspace.FeedBlob(str(ref), value)
        else:
            ws.create_blob(str(ref))
            ws.blobs[str(ref)].feed(value)

    assert isinstance(blob_record, Field)
    refs = blob_record.field_blobs()
    assert all(isinstance(r, BlobReference) for r in refs)
    if isinstance(arrays, Field):
        # Flatten a record of arrays into the matching list of values.
        arrays = arrays.field_blobs()
    assert len(arrays) == len(refs), (
        'Values must contain exactly %d ndarrays.' % len(refs)
    )
    for ref, value in zip(refs, arrays):
        _feed_one(ref, value)
def NewRecord(net, schema):
    """
    Given a record of np.arrays, create a BlobReference for each one of them,
    returning a record containing BlobReferences. The name of each returned
    blob is NextScopedBlob(field_name), which guarantees unique name in the
    current net. Use NameScope explicitly to avoid name conflicts between
    different nets.
    """
    # A bare Scalar gets a single anonymous scoped blob.
    if isinstance(schema, Scalar):
        result = schema.clone()
        result.set_value(
            blob=net.NextScopedBlob('unnamed_scalar'),
            unsafe=True,
        )
        return result

    assert isinstance(schema, Field), 'Record must be a schema.Field instance.'
    refs = [
        net.NextScopedBlob(prefix=name)
        for name in schema.field_names()
    ]
    return from_blob_list(schema, refs)
def ConstRecord(net, array_record):
    """
    Given a record of arrays, returns a record of blobs,
    initialized with net.Const.
    """
    blob_record = NewRecord(net, array_record)
    pairs = zip(blob_record.field_blobs(), array_record.field_blobs())
    for blob, array in pairs:
        net.Const(array, blob)
    return blob_record
1092 def InitEmptyRecord(net, schema_or_record, enforce_types=False):
1093 if not schema_or_record.has_blobs():
1094 record = NewRecord(net, schema_or_record)
1096 record = schema_or_record
1098 for blob_type, blob
in zip(record.field_types(), record.field_blobs()):
1100 data_type = data_type_for_dtype(blob_type)
1101 shape = [0] + list(blob_type.shape)
1102 net.ConstantFill([], blob, shape=shape, dtype=data_type)
1104 logger.warning(
"Blob {} has type error".format(blob))
1115 net.ConstantFill([], blob, shape=[0])
# Mapping from numpy dtypes to core.DataType enum values, consumed by
# data_type_for_dtype() via `dtype.base == np_type` comparisons.
# BUG FIX: `np.str` and `np.bool` were deprecated aliases of the builtins
# (removed in NumPy >= 1.24); the builtins compare identically against
# dtype.base, so using them directly is behavior-preserving and forward-safe.
_DATA_TYPE_FOR_DTYPE = [
    (str, core.DataType.STRING),
    (np.float16, core.DataType.FLOAT16),
    (np.float32, core.DataType.FLOAT),
    (np.float64, core.DataType.DOUBLE),
    (bool, core.DataType.BOOL),
    (np.int8, core.DataType.INT8),
    (np.int16, core.DataType.INT16),
    (np.int32, core.DataType.INT32),
    (np.int64, core.DataType.INT64),
    (np.uint8, core.DataType.UINT8),
    (np.uint16, core.DataType.UINT16),
]
def is_schema_subset(schema, original_schema):
    """True iff every field name of `schema` also appears in `original_schema`."""
    original_names = set(original_schema.field_names())
    return set(schema.field_names()) <= original_names
def equal_schemas(schema,
                  original_schema,
                  check_field_names=True,
                  check_field_types=True,
                  check_field_metas=False):
    """Compares two schemas field-by-field.

    Field names, types and metadata are compared independently; each
    comparison can be toggled via the corresponding keyword argument.
    Returns True only if every enabled comparison matches.
    """
    assert isinstance(schema, Field)
    assert isinstance(original_schema, Field)

    checks = (
        (check_field_names, lambda s: s.field_names()),
        (check_field_types, lambda s: s.field_types()),
        (check_field_metas, lambda s: s.field_metadata()),
    )
    for enabled, extract in checks:
        if enabled and extract(schema) != extract(original_schema):
            return False
    return True
def schema_check(schema, previous=None):
    """
    Normalize `schema` into a record (via as_record); when `previous` is
    supplied, additionally assert that both schemas compare equal.
    Returns the normalized record.
    """
    normalized = as_record(schema)
    if previous is not None:
        assert equal_schemas(schema, previous)
    return normalized
def data_type_for_dtype(dtype):
    """Maps a numpy dtype to its core.DataType value.

    Raises:
        TypeError: if the dtype's base type has no known mapping.
    """
    for np_type, data_type in _DATA_TYPE_FOR_DTYPE:
        if dtype.base == np_type:
            return data_type
    raise TypeError('Unknown dtype: ' + str(dtype.base))
def attach_metadata_to_scalars(field, metadata):
    """Sets `metadata` on every Scalar leaf reachable from `field`."""
    for scalar in field.all_scalars():
        scalar.set_metadata(metadata)
def __init__(self, fields)
def set(self, dtype=None, blob=None, metadata=None, unsafe=False)
def __getattr__(self, item)
def set_metadata(self, value)
def get(self, item, default_value)
def _pprint_impl(self, indent, str_buffer)
def __getitem__(self, item)
def clone(self, keep_blobs=True)
def _child_base_id(self, child_index=None)
def _struct_from_nested_name(self, nested_name, field)
def _validate_metadata(self)
def set_value(self, blob, throw_on_type_mismatch=False, unsafe=False)
def _get_field_by_nested_name(self, nested_name)
def __init__(self, children)