Caffe2 - Python API
A deep learning, cross-platform ML framework
schema.py
## @package schema
# Module caffe2.python.schema
"""
Defines a minimal set of data types for representing datasets with
arbitrary nested structure, including objects of variable length, such as
maps and lists.

This defines a columnar storage format for such datasets on top of caffe2
tensors. In terms of representational capacity, it can represent most of
the data types supported by the Parquet, ORC, and DWRF file formats.

See comments in operator_test/dataset_ops_test.py for an example and
walkthrough on how to use schema to store and iterate through a structured
in-memory dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import logging
import numpy as np
from caffe2.python import core
from caffe2.python import workspace
from caffe2.python.core import BlobReference
from collections import OrderedDict, namedtuple
from past.builtins import basestring
from future.utils import viewitems, viewkeys, viewvalues
from itertools import islice
from six import StringIO

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

FIELD_SEPARATOR = ':'


def _join_field_name(prefix, suffix):
    if prefix and suffix:
        return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
    elif prefix:
        return prefix
    elif suffix:
        return suffix
    else:
        return ''


def _normalize_field(field_or_type_or_blob, keep_blobs=True):
    """Clones/normalizes a field before adding it to a container."""
    if isinstance(field_or_type_or_blob, Field):
        return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
    elif type(field_or_type_or_blob) in (type, np.dtype):
        return Scalar(dtype=field_or_type_or_blob)
    else:
        return Scalar(blob=field_or_type_or_blob)


FeatureSpec = namedtuple(
    'FeatureSpec',
    [
        'feature_type',
        'feature_names',
        'feature_ids',
        'feature_is_request_only',
        'desired_hash_size',
    ]
)

FeatureSpec.__new__.__defaults__ = (None, None, None, None, None)


class Metadata(
    namedtuple(
        'Metadata', ['categorical_limit', 'expected_value', 'feature_specs']
    )
):
    """Represents additional information associated with a scalar in schema.

    `categorical_limit` - for fields of integral type that are guaranteed to
    be non-negative, this specifies the maximum possible value plus one. It
    is often used as the size of an embedding table.

    `expected_value` - anticipated average value of elements in the field.
    Usually makes sense for the length fields of lists.

    `feature_specs` - information about the features contained in this field.
    For example, if the field holds more than one feature, it can carry the
    list of feature names contained in it."""
    __slots__ = ()


Metadata.__new__.__defaults__ = (None, None, None)

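# Illustrative sketch (added for this doc page, not part of the original
# file): Metadata is typically attached to an integral Scalar, e.g. to record
# the size of an embedding table. Values shown are hypothetical.
#
#   >>> ids = Scalar(np.int64)
#   >>> ids.set_metadata(Metadata(categorical_limit=1000))
#   >>> ids.metadata.categorical_limit
#   1000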

class Field(object):
    """Represents an abstract field type in a dataset.
    """

    def __init__(self, children):
        """Derived classes must call this after their initialization."""
        self._parent = (None, 0)
        offset = 0
        self._field_offsets = []
        for child in children:
            self._field_offsets.append(offset)
            offset += len(child.field_names())
        self._field_offsets.append(offset)

    def clone_schema(self):
        return self.clone(keep_blobs=False)

    def field_names(self):
        """Return the children field names for this field."""
        raise NotImplementedError('Field is an abstract class.')

    def field_types(self):
        """Return the numpy.dtype for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_metadata(self):
        """Return the Metadata for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_blobs(self):
        """Return the list of blobs with contents for this Field.
        Values can either be all numpy.ndarray or BlobReference.
        If any of the fields doesn't have a blob, throws.
        """
        raise NotImplementedError('Field is an abstract class.')

    def all_scalars(self):
        """Return the list of all Scalar instances in the Field.
        The order is the same as for field_names() or field_blobs()."""
        raise NotImplementedError('Field is an abstract class.')

    def has_blobs(self):
        """Return True if every scalar of this field has blobs."""
        raise NotImplementedError('Field is an abstract class.')

    def clone(self, keep_blobs=True):
        """Clone this Field along with its children."""
        raise NotImplementedError('Field is an abstract class.')

    def _set_parent(self, parent, relative_id):
        self._parent = (parent, relative_id)

    def slice(self):
        """
        Returns a slice representing the range of field ids that belong to
        this field. This slice can be used to index a list of fields.

        E.g.:

        >>> s = Struct(
        >>>     ('a', Scalar()),
        >>>     ('b', Struct(
        >>>         ('b1', Scalar()),
        >>>         ('b2', Scalar()),
        >>>     )),
        >>>     ('c', Scalar()),
        >>> )
        >>> field_data = ['da', 'db1', 'db2', 'dc']
        >>> field_data[s.b.slice()]
        ['db1', 'db2']
        """
        base_id = self._child_base_id()
        return slice(base_id, base_id + len(self.field_names()))

    def _child_base_id(self, child_index=None):
        """Get the base id of the given child."""
        p, i = self._parent
        pos = 0 if child_index is None else self._field_offsets[child_index]
        if p:
            pos += p._child_base_id(i)
        return pos

    def __eq__(self, other):
        """Equivalence of two schemas."""
        return (
            (self.field_names() == other.field_names()) and
            (self.field_types() == other.field_types()) and
            (self.field_metadata() == other.field_metadata())
        )

    def _pprint_impl(self, indent, str_buffer):
        raise NotImplementedError('Field is an abstract class.')

    def __repr__(self):
        str_buffer = StringIO()
        self._pprint_impl(0, str_buffer)
        contents = str_buffer.getvalue()
        str_buffer.close()
        return contents


class List(Field):
    """Represents a variable-length list.

    Values of a list can also be complex fields such as Lists and Structs.
    In addition to the fields exposed by its `values` field, a List exposes an
    additional `lengths` field, which will contain the size of each list under
    the parent domain.
    """

    def __init__(self, values, lengths_blob=None):
        if isinstance(lengths_blob, Field):
            assert isinstance(lengths_blob, Scalar)
            self.lengths = _normalize_field(lengths_blob)
        else:
            self.lengths = Scalar(np.int32, lengths_blob)
        self._items = _normalize_field(values)
        self.lengths._set_parent(self, 0)
        self._items._set_parent(self, 1)
        Field.__init__(self, [self.lengths, self._items])

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] + [_join_field_name('values', v) for v in value_fields]
        )

    def field_types(self):
        return self.lengths.field_types() + self._items.field_types()

    def field_metadata(self):
        return self.lengths.field_metadata() + self._items.field_metadata()

    def field_blobs(self):
        return self.lengths.field_blobs() + self._items.field_blobs()

    def all_scalars(self):
        return self.lengths.all_scalars() + self._items.all_scalars()

    def has_blobs(self):
        return self.lengths.has_blobs() and self._items.has_blobs()

    def clone(self, keep_blobs=True):
        return List(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs)
        )

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write(' ' * indent + "List(\n")
        str_buffer.write(' ' * (indent + 1) + "lengths=\n")
        self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write(' ' * (indent + 1) + "_items=\n")
        self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write(' ' * indent + ")\n")

    def __getattr__(self, item):
        """If the value of this list is a struct,
        allow introspecting directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)

    def __getitem__(self, item):
        names = item.split(FIELD_SEPARATOR, 1)

        if len(names) == 1:
            if item == 'lengths':
                return self.lengths
            elif item == 'values':
                return self._items
        else:
            if names[0] == 'values':
                return self._items[names[1]]
        raise KeyError('Field not found in list: %s.' % item)

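# Illustrative sketch (added for this doc page, not part of the original
# file): a List of int64 values exposes a 'lengths' column plus the flattened
# 'values' column.
#
#   >>> l = List(Scalar(np.int64))
#   >>> l.field_names()
#   ['lengths', 'values']
#   >>> l.field_types()
#   [dtype('int32'), dtype('int64')]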

class Struct(Field):
    """Represents a named list of fields sharing the same domain.
    """

    def __init__(self, *fields):
        """`fields` is a list of tuples in the format (name, field). The name
        is a string of nested names, e.g., `a`, `a:b`, `a:b:c`. For example

        Struct(
            ('a', Scalar()),
            ('b:c', Scalar()),
            ('b:d:e', Scalar()),
            ('b', Struct(
                ('f', Scalar()),
            )),
        )

        is equal to

        Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('c', Scalar()),
                ('d', Struct(('e', Scalar()))),
                ('f', Scalar()),
            )),
        )
        """
        for field in fields:
            assert len(field) == 2
            assert field[0], 'Field names cannot be empty'
            assert field[0] != 'lengths', (
                'Struct cannot contain a field named `lengths`.'
            )
        fields = [(name, _normalize_field(field)) for name, field in fields]
        self.fields = OrderedDict()
        for name, field in fields:
            if FIELD_SEPARATOR in name:
                name, field = self._struct_from_nested_name(name, field)
            if name not in self.fields:
                self.fields[name] = field
                continue
            if (
                not isinstance(field, Struct) or
                not isinstance(self.fields[name], Struct)
            ):
                raise ValueError('Duplicate field name: %s' % name)
            self.fields[name] = self.fields[name] + field
        for id, (_, field) in enumerate(viewitems(self.fields)):
            field._set_parent(self, id)
        Field.__init__(self, viewvalues(self.fields))
        self._frozen = True

    def _struct_from_nested_name(self, nested_name, field):
        def create_internal(nested_name, field):
            names = nested_name.split(FIELD_SEPARATOR, 1)
            if len(names) == 1:
                added_field = field
            else:
                added_field = create_internal(names[1], field)
            return Struct((names[0], added_field))

        names = nested_name.split(FIELD_SEPARATOR, 1)
        assert len(names) >= 2
        return names[0], create_internal(names[1], field)

    def get_children(self):
        return list(viewitems(self.fields))

    def field_names(self):
        names = []
        for name, field in viewitems(self.fields):
            names += [_join_field_name(name, f) for f in field.field_names()]
        return names

    def field_types(self):
        types = []
        for _, field in viewitems(self.fields):
            types += field.field_types()
        return types

    def field_metadata(self):
        metadata = []
        for _, field in viewitems(self.fields):
            metadata += field.field_metadata()
        return metadata

    def field_blobs(self):
        blobs = []
        for _, field in viewitems(self.fields):
            blobs += field.field_blobs()
        return blobs

    def all_scalars(self):
        scalars = []
        for _, field in viewitems(self.fields):
            scalars += field.all_scalars()
        return scalars

    def has_blobs(self):
        return all(field.has_blobs() for field in viewvalues(self.fields))

    def clone(self, keep_blobs=True):
        normalized_fields = [
            (k, _normalize_field(v, keep_blobs=keep_blobs))
            for k, v in viewitems(self.fields)
        ]
        return Struct(*normalized_fields)

    def _get_field_by_nested_name(self, nested_name):
        names = nested_name.split(FIELD_SEPARATOR, 1)
        field = self.fields.get(names[0], None)

        if field is None:
            return None

        if len(names) == 1:
            return field

        try:
            return field[names[1]]
        except (KeyError, TypeError):
            return None

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write(' ' * indent + "Struct( \n")
        for name, field in viewitems(self.fields):
            str_buffer.write(' ' * (indent + 1) + "{}=".format(name) + "\n")
            field._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write(' ' * indent + ") \n")

    def __contains__(self, item):
        field = self._get_field_by_nested_name(item)
        return field is not None

    def __len__(self):
        return len(self.fields)

    def __getitem__(self, item):
        """
        item can be a tuple or list of ints or strings, or a single
        int or string. A string item is a nested field name, e.g., "a",
        "a:b", "a:b:c". An int item is the index of a field at the first
        level of the Struct.
        """
        if isinstance(item, list) or isinstance(item, tuple):
            keys = list(viewkeys(self.fields))
            return Struct(
                * [
                    (
                        keys[k]
                        if isinstance(k, int) else k, self[k]
                    ) for k in item
                ]
            )
        elif isinstance(item, int):
            return next(islice(viewvalues(self.fields), item, None))
        else:
            field = self._get_field_by_nested_name(item)
            if field is None:
                raise KeyError('field "%s" not found' % (item))
            return field

    def get(self, item, default_value):
        """
        Similar to Python's dictionary get method: returns the field of
        `item` if found (i.e. self.item is valid), otherwise returns
        `default_value`.

        It is syntactic sugar for Python's builtin getattr method.
        """
        return getattr(self, item, default_value)

    def __getattr__(self, item):
        if item.startswith('__'):
            raise AttributeError(item)
        try:
            return self.__dict__['fields'][item]
        except KeyError:
            raise AttributeError(item)

    def __setattr__(self, key, value):
        # Disable setting attributes after initialization to prevent the false
        # impression of being able to overwrite a field.
        # Setting internal states is still allowed, mainly so that _parent can
        # be set post initialization.
        if getattr(self, '_frozen', None) and not key.startswith('_'):
            raise TypeError('Struct.__setattr__() is disabled after __init__()')
        super(Struct, self).__setattr__(key, value)

    def __add__(self, other):
        """
        Allows merging the fields of two schema.Struct instances using the
        '+' operator. If the two Structs have common field names, the merge
        is conducted recursively. Here are examples:

        Example 1
        s1 = Struct(('a', Scalar()))
        s2 = Struct(('b', Scalar()))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )

        Example 2
        s1 = Struct(
            ('a', Scalar()),
            ('b', Struct(('c', Scalar()))),
        )
        s2 = Struct(('b', Struct(('d', Scalar()))))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            )),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name not in children:
                children[name] = right_field
                continue
            left_field = children[name]
            children[name] = left_field + right_field

        return Struct(*(viewitems(children)))

    def __sub__(self, other):
        """
        Allows removing the common fields of two schema.Struct instances from
        self using the '-' operator. If the two Structs have common field
        names, the removal is conducted recursively. If a child struct ends
        up with no fields inside, it will be removed from its parent. Here
        are examples:

        Example 1
        s1 = Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )
        s2 = Struct(('a', Scalar()))
        s1 - s2 == Struct(('b', Scalar()))

        Example 2
        s1 = Struct(
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            ))
        )
        s2 = Struct(
            ('b', Struct(('c', Scalar()))),
        )
        s1 - s2 == Struct(
            ('b', Struct(
                ('d', Scalar()),
            )),
        )

        Example 3
        s1 = Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('d', Scalar()),
            ))
        )
        s2 = Struct(
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            )),
        )
        s1 - s2 == Struct(
            ('a', Scalar()),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name in children:
                left_field = children[name]
                if type(left_field) == type(right_field):
                    if isinstance(left_field, Struct):
                        child = left_field - right_field
                        if child.get_children():
                            children[name] = child
                            continue
                    children.pop(name)
                else:
                    raise TypeError(
                        "Type of left_field, " + str(type(left_field)) +
                        ", is not the same as that of right_field, " +
                        str(type(right_field)) +
                        ", yet they have the same field name, " + name)
        return Struct(*(children.items()))

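# Illustrative sketch (added for this doc page, not part of the original
# file): nested names expand into sub-Structs, and '+' merges two Structs
# recursively.
#
#   >>> s = Struct(('a', Scalar(np.int32)), ('b:c', Scalar(np.float32)))
#   >>> s.field_names()
#   ['a', 'b:c']
#   >>> merged = s + Struct(('b', Struct(('d', Scalar(np.float32)))))
#   >>> merged.field_names()
#   ['a', 'b:c', 'b:d']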

class Scalar(Field):
    """Represents a typed scalar or tensor of fixed shape.

    A Scalar is a leaf in a schema tree, translating to exactly one tensor in
    the dataset's underlying storage.

    Usually, the tensor storing the actual values of this field is a 1D
    tensor, representing a series of values in its domain. It is possible,
    however, to have higher-rank values stored as a Scalar, as long as all
    entries have the same shape.

    E.g.:

        Scalar(np.float64)

            Scalar field of type float64. Caffe2 will expect readers and
            datasets to expose it as a 1D tensor of doubles (vector), where
            the size of the vector is determined by this field's domain.

        Scalar((np.int32, 5))

            Tensor field of type int32. Caffe2 will expect readers and
            datasets to implement it as a 2D tensor (matrix) of shape (L, 5),
            where L is determined by this field's domain.

        Scalar((str, (10, 20)))

            Tensor field of type str. Caffe2 will expect readers and
            datasets to implement it as a 3D tensor of shape (L, 10, 20),
            where L is determined by this field's domain.

    If the field type is unknown at construction time, call Scalar(), which
    will default to np.void as its dtype.

    It is an error to pass a structured dtype to Scalar, since it would
    contain more than one field. Instead, use from_dtype, which will
    construct a nested `Struct` field reflecting the given dtype's structure.

    A Scalar can also contain a blob, which represents the value of this
    Scalar. A blob can be either a numpy.ndarray, in which case it contains
    the actual contents of the Scalar, or a BlobReference, which represents a
    blob living in a caffe2 Workspace. If a blob of a different type is
    passed, a conversion to numpy.ndarray is attempted.
    """

    def __init__(self, dtype=None, blob=None, metadata=None):
        self._metadata = None
        self.set(dtype, blob, metadata, unsafe=True)
        Field.__init__(self, [])

    def field_names(self):
        return ['']

    def field_type(self):
        return self.dtype

    def field_types(self):
        return [self.dtype]

    def field_metadata(self):
        return [self._metadata]

    def has_blobs(self):
        return self._blob is not None

    def field_blobs(self):
        assert self._blob is not None, 'Value is not set for this field.'
        return [self._blob]

    def all_scalars(self):
        return [self]

    def clone(self, keep_blobs=True):
        return Scalar(
            dtype=self._original_dtype,
            blob=self._blob if keep_blobs else None,
            metadata=self._metadata
        )

    def get(self):
        """Gets the current blob of this Scalar field."""
        assert self._blob is not None, 'Value is not set for this field.'
        return self._blob

    def __call__(self):
        """Shortcut for self.get()"""
        return self.get()

    @property
    def metadata(self):
        return self._metadata

    def set_metadata(self, value):
        assert isinstance(value, Metadata), \
            'metadata must be Metadata, got {}'.format(type(value))
        self._metadata = value
        self._validate_metadata()

    def _validate_metadata(self):
        if self._metadata is None:
            return
        if (self._metadata.categorical_limit is not None and
                self.dtype is not None):
            assert np.issubdtype(self.dtype, np.integer), \
                "`categorical_limit` can be specified only in integral " + \
                "fields but got {}".format(self.dtype)

    def set_value(self, blob, throw_on_type_mismatch=False, unsafe=False):
        """Sets only the blob field, still validating the existing dtype."""
        if self.dtype.base != np.void and throw_on_type_mismatch:
            assert isinstance(blob, np.ndarray), "Got {!r}".format(blob)
            assert blob.dtype.base == self.dtype.base, (
                "Expected {}, got {}".format(self.dtype.base, blob.dtype.base))
        self.set(dtype=self._original_dtype, blob=blob, unsafe=unsafe)

    def set(self, dtype=None, blob=None, metadata=None, unsafe=False):
        """Set the type and/or blob of this scalar. See __init__ for details.

        Args:
            dtype: can be any numpy type. If not provided and `blob` is
                   provided, it will be inferred. If no argument is provided,
                   this Scalar will be of type np.void.
            blob:  if provided, can be either a BlobReference or a
                   numpy.ndarray. If a value of different type is passed,
                   a conversion to numpy.ndarray is attempted. Strings aren't
                   accepted, since they can be ambiguous. If you want to pass
                   a string, use either BlobReference(blob) or np.array(blob).
            metadata: optional instance of Metadata; if provided, overrides
                   the metadata information of the scalar.
        """
        if not unsafe:
            logger.warning(
                "Scalar should be considered immutable. Only call Scalar.set() "
                "on newly created Scalar with unsafe=True. This will become an "
                "error soon."
            )
        if blob is not None and isinstance(blob, basestring):
            raise ValueError(
                'Passing str blob to Scalar.set() is ambiguous. '
                'Do either set(blob=np.array(blob)) or '
                'set(blob=BlobReference(blob))'
            )

        self._original_dtype = dtype
        if dtype is not None:
            dtype = np.dtype(dtype)
        # If blob is not None and it is not a BlobReference, we assume that
        # it is actual tensor data, so we will try to cast it to a numpy array.
        if blob is not None and not isinstance(blob, BlobReference):
            preserve_shape = isinstance(blob, np.ndarray)
            if dtype is not None and dtype != np.void:
                blob = np.array(blob, dtype=dtype.base)
                # if array is empty we may need to reshape a little
                if blob.size == 0 and not preserve_shape:
                    blob = blob.reshape((0, ) + dtype.shape)
            else:
                assert isinstance(blob, np.ndarray), (
                    'Invalid blob type: %s' % str(type(blob)))

            # reshape scalars into 1D arrays
            # TODO(azzolini): figure out better way of representing this
            if len(blob.shape) == 0 and not preserve_shape:
                blob = blob.reshape((1, ))

            # infer inner shape from the blob given
            # TODO(dzhulgakov): tweak this to make it work with PackedStruct
            if (len(blob.shape) > 1 and dtype is not None and
                    dtype.base != np.void):
                dtype = np.dtype((dtype.base, blob.shape[1:]))
        # if we were still unable to infer the dtype
        if dtype is None:
            dtype = np.dtype(np.void)
        assert not dtype.fields, (
            'Cannot create Scalar with a structured dtype. ' +
            'Use from_dtype instead.'
        )
        self.dtype = dtype
        self._blob = blob
        if metadata is not None:
            self.set_metadata(metadata)
        self._validate_metadata()

    def set_type(self, dtype):
        self._original_dtype = dtype
        if dtype is not None:
            self.dtype = np.dtype(dtype)
        else:
            self.dtype = np.dtype(np.void)
        self._validate_metadata()

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write(' ' * (indent) +
                         'Scalar({!r}, {!r}, {!r})'.format(
                             self.dtype, self._blob, self._metadata) + "\n")

    def id(self):
        """
        Return the zero-indexed position of this scalar field in its schema.
        Used in order to index into the field_blob list returned by readers
        or accepted by writers.
        """
        return self._child_base_id()


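# Illustrative sketch (added for this doc page, not part of the original
# file): a (dtype, shape) pair describes a fixed-shape tensor per entry.
#
#   >>> s = Scalar((np.int32, 5))
#   >>> s.field_type().shape
#   (5,)
#   >>> s.field_type().base == np.dtype(np.int32)
#   True
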
def Map(
    keys,
    values,
    keys_name='keys',
    values_name='values',
    lengths_blob=None
):
    """A Map is a List of Structs containing the keys and values fields.
    Optionally, you can provide custom names for the key and value fields.
    """
    return List(
        Struct((keys_name, keys), (values_name, values)),
        lengths_blob=lengths_blob
    )

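# Illustrative sketch (added for this doc page, not part of the original
# file): the Map sugar shows up in the flattened column names.
#
#   >>> m = Map(Scalar(np.int32), Scalar(np.float32))
#   >>> m.field_names()
#   ['lengths', 'values:keys', 'values:values']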

def NamedTuple(name_prefix, *fields):
    return Struct(* [('%s_%d' % (name_prefix, i), field)
                     for i, field in enumerate(fields)])


def Tuple(*fields):
    """
    Creates a Struct with default, sequential field names of the given types.
    """
    return NamedTuple('field', *fields)


def RawTuple(num_fields, name_prefix='field'):
    """
    Creates a tuple of `num_fields` untyped scalars.
    """
    assert isinstance(num_fields, int)
    assert num_fields >= 0
    return NamedTuple(name_prefix, *([np.void] * num_fields))

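# Illustrative sketch (added for this doc page, not part of the original
# file): Tuple assigns sequential 'field_<i>' names.
#
#   >>> Tuple(np.int32, np.float32).field_names()
#   ['field_0', 'field_1']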

def from_dtype(dtype, _outer_shape=()):
    """Constructs a Caffe2 schema from the given numpy dtype.

    Numpy supports scalar, array-like and structured datatypes, as long as
    all the shapes are fixed. This function breaks down the given dtype into
    a Caffe2 schema containing `Struct` and `Scalar` types.

    Fields containing byte offsets are not currently supported.
    """
    if not isinstance(dtype, np.dtype):
        # wrap into an np.dtype
        shape = _outer_shape
        dtype = np.dtype((dtype, _outer_shape))
    else:
        # concatenate shapes if necessary
        shape = _outer_shape + dtype.shape
        if shape != dtype.shape:
            dtype = np.dtype((dtype.base, shape))

    if not dtype.fields:
        return Scalar(dtype)

    struct_fields = []
    for name, (fdtype, offset) in viewitems(dtype.fields):
        assert offset == 0, ('Fields with byte offsets are not supported.')
        struct_fields.append((name, from_dtype(fdtype, _outer_shape=shape)))
    return Struct(*struct_fields)

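# Illustrative sketch (added for this doc page, not part of the original
# file, and assuming the viewitems/append fixes above): a single-field record
# dtype (offset 0) becomes a Struct of one Scalar; multi-field records with
# nonzero byte offsets are rejected by the assert above.
#
#   >>> dt = np.dtype([('x', (np.int32, 2))])
#   >>> from_dtype(dt).field_names()
#   ['x']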

class _SchemaNode(object):
    """This is a private class used to represent a Schema Node."""

    def __init__(self, name, type_str=''):
        self.name = name
        self.children = []
        self.type_str = type_str
        self.field = None

    def add_child(self, name, type_str=''):
        for child in self.children:
            if child.name == name and child.type_str == type_str:
                return child
        child = _SchemaNode(name, type_str)
        self.children.append(child)
        return child

    def get_field(self):

        list_names = ['lengths', 'values']
        map_names = ['lengths', 'keys', 'values']

        if len(self.children) == 0 or self.field is not None:
            if self.field is None:
                return Struct()
            else:
                return self.field

        child_names = []
        for child in self.children:
            child_names.append(child.name)

        if (set(child_names) == set(list_names)):
            for child in self.children:
                if child.name == 'values':
                    values_field = child.get_field()
                else:
                    lengths_field = child.get_field()
            self.field = List(
                values_field,
                lengths_blob=lengths_field
            )
            self.type_str = "List"
            return self.field
        elif (set(child_names) == set(map_names)):
            for child in self.children:
                if child.name == 'keys':
                    key_field = child.get_field()
                elif child.name == 'values':
                    values_field = child.get_field()
                else:
                    lengths_field = child.get_field()
            self.field = Map(
                key_field,
                values_field,
                lengths_blob=lengths_field
            )
            self.type_str = "Map"
            return self.field

        else:
            struct_fields = []
            for child in self.children:
                struct_fields.append((child.name, child.get_field()))

            self.field = Struct(*struct_fields)
            self.type_str = "Struct"
            return self.field

    def print_recursively(self):
        for child in self.children:
            child.print_recursively()
        logger.info("Printing node: Name and type")
        logger.info(self.name)
        logger.info(self.type_str)


def from_column_list(
    col_names, col_types=None,
    col_blobs=None, col_metadata=None
):
    """
    Given a list of names, types, and optionally values, construct a Schema.
    """
    if col_types is None:
        col_types = [None] * len(col_names)
    if col_metadata is None:
        col_metadata = [None] * len(col_names)
    if col_blobs is None:
        col_blobs = [None] * len(col_names)
    assert len(col_names) == len(col_types), (
        'col_names and col_types must have the same length.'
    )
    assert len(col_names) == len(col_metadata), (
        'col_names and col_metadata must have the same length.'
    )
    assert len(col_names) == len(col_blobs), (
        'col_names and col_blobs must have the same length.'
    )
    root = _SchemaNode('root', 'Struct')
    # note: the loop variable is named col_md to avoid shadowing the
    # col_metadata list above
    for col_name, col_type, col_blob, col_md in zip(
        col_names, col_types, col_blobs, col_metadata
    ):
        columns = col_name.split(FIELD_SEPARATOR)
        current = root
        for i in range(len(columns)):
            name = columns[i]
            type_str = ''
            field = None
            if i == len(columns) - 1:
                type_str = col_type
                field = Scalar(
                    dtype=col_type,
                    blob=col_blob,
                    metadata=col_md
                )
            next = current.add_child(name, type_str)
            if field is not None:
                next.field = field
            current = next

    return root.get_field()

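# Illustrative sketch (added for this doc page, not part of the original
# file): flattened column names are reassembled into a nested schema.
#
#   >>> s = from_column_list(['a', 'b:lengths', 'b:values'],
#   ...                      [np.int32, np.int32, np.float32])
#   >>> s.field_names()
#   ['a', 'b:lengths', 'b:values']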

def from_blob_list(schema, values, throw_on_type_mismatch=False):
    """
    Create a schema that clones the given schema, but containing the given
    list of values.
    """
    assert isinstance(schema, Field), 'Argument `schema` must be a Field.'
    if isinstance(values, BlobReference):
        values = [values]
    record = schema.clone_schema()
    scalars = record.all_scalars()
    assert len(scalars) == len(values), (
        'Values must have %d elements, got %d.' % (len(scalars), len(values))
    )
    for scalar, value in zip(scalars, values):
        scalar.set_value(value, throw_on_type_mismatch, unsafe=True)
    return record

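# Illustrative sketch (added for this doc page, not part of the original
# file): pairing a schema with concrete arrays.
#
#   >>> s = Struct(('x', Scalar(np.float32)))
#   >>> r = from_blob_list(s, [np.array([1., 2.], dtype=np.float32)])
#   >>> r.x()
#   array([1., 2.], dtype=float32)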

def as_record(value):
    if isinstance(value, Field):
        return value
    elif isinstance(value, list) or isinstance(value, tuple):
        is_field_list = all(
            isinstance(f, tuple) and len(f) == 2 and
            isinstance(f[0], basestring)
            for f in value
        )
        if is_field_list:
            return Struct(* [(k, as_record(v)) for k, v in value])
        else:
            return Tuple(* [as_record(f) for f in value])
    elif isinstance(value, dict):
        return Struct(* [(k, as_record(v)) for k, v in viewitems(value)])
    else:
        return _normalize_field(value)


def FetchRecord(blob_record, ws=None, throw_on_type_mismatch=False):
    """
    Given a record containing BlobReferences, return a new record with the
    same schema, containing numpy arrays fetched from the current active
    workspace.
    """

    def fetch(v):
        if ws is None:
            return workspace.FetchBlob(str(v))
        else:
            return ws.blobs[str(v)].fetch()

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    field_arrays = [fetch(value) for value in field_blobs]
    return from_blob_list(blob_record, field_arrays, throw_on_type_mismatch)


def FeedRecord(blob_record, arrays, ws=None):
    """
    Given a record containing BlobReferences and `arrays`, which is either
    a list of numpy arrays or a record containing numpy arrays, feeds the
    record to the current workspace.
    """

    def feed(b, v):
        if ws is None:
            workspace.FeedBlob(str(b), v)
        else:
            ws.create_blob(str(b))
            ws.blobs[str(b)].feed(v)

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    if isinstance(arrays, Field):
        # TODO: check schema
        arrays = arrays.field_blobs()
    assert len(arrays) == len(field_blobs), (
        'Values must contain exactly %d ndarrays.' % len(field_blobs)
    )
    for blob, array in zip(field_blobs, arrays):
        feed(blob, array)

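# Illustrative sketch (added for this doc page, not part of the original
# file): round-tripping a record through the default workspace.
#
#   >>> net = core.Net('example')
#   >>> record = NewRecord(net, Struct(('x', Scalar(np.float32))))
#   >>> FeedRecord(record, [np.array([1., 2.], dtype=np.float32)])
#   >>> FetchRecord(record).x()
#   array([1., 2.], dtype=float32)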

def NewRecord(net, schema):
    """
    Given a record of np.arrays, create a BlobReference for each one of them,
    returning a record containing BlobReferences. The name of each returned
    blob is NextScopedBlob(field_name), which guarantees a unique name in the
    current net. Use NameScope explicitly to avoid name conflicts between
    different nets.
    """
    if isinstance(schema, Scalar):
        result = schema.clone()
        result.set_value(
            blob=net.NextScopedBlob('unnamed_scalar'),
            unsafe=True,
        )
        return result

    assert isinstance(schema, Field), 'Record must be a schema.Field instance.'
    blob_refs = [
        net.NextScopedBlob(prefix=name)
        for name in schema.field_names()
    ]
    return from_blob_list(schema, blob_refs)


def ConstRecord(net, array_record):
    """
    Given a record of arrays, returns a record of blobs,
    initialized with net.Const.
    """
    blob_record = NewRecord(net, array_record)
    for blob, array in zip(
        blob_record.field_blobs(), array_record.field_blobs()
    ):
        net.Const(array, blob)
    return blob_record


def InitEmptyRecord(net, schema_or_record, enforce_types=False):
    if not schema_or_record.has_blobs():
        record = NewRecord(net, schema_or_record)
    else:
        record = schema_or_record

    for blob_type, blob in zip(record.field_types(), record.field_blobs()):
        try:
            data_type = data_type_for_dtype(blob_type)
            shape = [0] + list(blob_type.shape)
            net.ConstantFill([], blob, shape=shape, dtype=data_type)
        except TypeError:
            logger.warning("Blob {} has type error".format(blob))
            # If data_type_for_dtype doesn't know how to resolve the given
            # numpy type to core.DataType, it can throw a TypeError (for
            # example, that would happen for unknown types such as np.void).
            # This is not a problem when the record is going to be overwritten
            # by some operator later, though it might be an issue for
            # type/shape inference.
            if enforce_types:
                raise
            # If we don't enforce types for all items we'll create a blob with
            # the default ConstantFill (FLOAT, no shape)
            net.ConstantFill([], blob, shape=[0])

    return record

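# Illustrative sketch (added for this doc page, not part of the original
# file): pre-creating empty tensors so downstream operators see valid blobs.
#
#   >>> net = core.Net('init')
#   >>> rec = InitEmptyRecord(net, Struct(('x', Scalar(np.float32))))
#   >>> rec.field_names()  # blobs are filled with empty tensors at run time
#   ['x']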

_DATA_TYPE_FOR_DTYPE = [
    (np.str, core.DataType.STRING),
    (np.float16, core.DataType.FLOAT16),
    (np.float32, core.DataType.FLOAT),
    (np.float64, core.DataType.DOUBLE),
    (np.bool, core.DataType.BOOL),
    (np.int8, core.DataType.INT8),
    (np.int16, core.DataType.INT16),
    (np.int32, core.DataType.INT32),
    (np.int64, core.DataType.INT64),
    (np.uint8, core.DataType.UINT8),
    (np.uint16, core.DataType.UINT16),
]


def is_schema_subset(schema, original_schema):
    # TODO add more checks
    return set(schema.field_names()).issubset(
        set(original_schema.field_names()))


def equal_schemas(schema,
                  original_schema,
                  check_field_names=True,
                  check_field_types=True,
                  check_field_metas=False):
    assert isinstance(schema, Field)
    assert isinstance(original_schema, Field)

    if check_field_names and (
            schema.field_names() != original_schema.field_names()):
        return False
    if check_field_types and (
            schema.field_types() != original_schema.field_types()):
        return False
    if check_field_metas and (
            schema.field_metadata() != original_schema.field_metadata()):
        return False

    return True


def schema_check(schema, previous=None):
    record = as_record(schema)
    if previous is not None:
        assert equal_schemas(schema, previous)
    return record


def data_type_for_dtype(dtype):
    for np_type, dt in _DATA_TYPE_FOR_DTYPE:
        if dtype.base == np_type:
            return dt
    raise TypeError('Unknown dtype: ' + str(dtype.base))

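# Illustrative sketch (added for this doc page, not part of the original
# file): numpy dtypes map onto core.DataType enum values.
#
#   >>> data_type_for_dtype(np.dtype(np.float32)) == core.DataType.FLOAT
#   True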

def attach_metadata_to_scalars(field, metadata):
    for f in field.all_scalars():
        f.set_metadata(metadata)