Caffe2 - Python API
A deep learning, cross platform ML framework
cached_reader.py
## @package cached_reader
# Module caffe2.python.cached_reader
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os

from caffe2.python import core
from caffe2.python.dataio import Reader
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.task import Cluster, TaskGroup


class CachedReader(Reader):
    """
    Reader with a persistent file-backed cache.

    Example usage:
        cached_reader = CachedReader(reader)
        build_cache_step = cached_reader.build_cache('/tmp/cache.db')
        with LocalSession() as session:
            session.run(build_cache_step)

    Every time a new reader is created, build_cache is expected to be called
    before setup_ex and before the reader is used. build_cache checks whether
    the given file path exists and, if it is missing, initializes it by
    reading data from the original reader. All subsequent reads ignore the
    original reader (i.e. no additional data is read from it).
    """

    def __init__(self, reader, db_type='leveldb', name='cached_reader'):
        super(CachedReader, self).__init__(reader.schema())
        self.original_reader = reader
        self.cache_path = None
        self.ds_reader = None
        self.ds = Dataset(self._schema, name)
        self.db_type = db_type
        self.name = name
        self.field_names = self._schema.field_names()

    def setup_ex(self, init_net, finish_net):
        assert self.cache_path, 'build_cache must be called first'
        self._init_dataset(init_net)
        self._load_from_file(init_net)
        # Read the cached dataset back in fixed batches of 100 records.
        self.ds_reader = self.ds.reader(init_net, batch_size=100)

    def read(self, read_net):
        assert self.ds_reader, 'setup must be called first'
        return self.ds_reader.read(read_net)

    def has_cache(self):
        return self.cache_path and os.path.exists(self.cache_path)

    def build_cache(self, cache_path, overwrite=False):
        if not self.has_cache() or overwrite:
            self.cache_path = cache_path
        if self.has_cache() and not overwrite:
            # Cache already exists; no need to rebuild it.
            return core.execution_step('build_step', [])

        # Copy everything from the original reader into the dataset, then
        # save the dataset's blobs to the cache file.
        init_net = core.Net('init')
        self._init_dataset(init_net)
        with Cluster(), core.NameScope(self.name), TaskGroup() as copy_tg:
            pipe(self.original_reader, self.ds.writer(), num_threads=16)
            copy_step = copy_tg.to_task().get_step()
        save_net = core.Net('save')
        self._save_to_file(save_net)

        return core.execution_step('build_cache', [init_net, copy_step, save_net])

    def _init_dataset(self, init_net):
        with core.NameScope(self.name):
            self.ds.init_empty(init_net)

    def _save_to_file(self, net):
        net.Save(
            self.ds.content().field_blobs(),
            [],
            db=self.cache_path,
            db_type=self.db_type,
            blob_name_overrides=self.field_names,
            absolute_path=True,
        )

    def _load_from_file(self, net):
        net.Load(
            [],
            self.ds.content().field_blobs(),
            db=self.cache_path,
            db_type=self.db_type,
            absolute_path=True,
            source_blob_names=self.field_names,
        )
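
Example: a minimal end-to-end sketch of the usage described in the docstring above. Here `source_reader` is a stand-in for whichever caffe2.python.dataio.Reader feeds the cache, and the cache path is illustrative; LocalSession comes from caffe2.python.session.

from caffe2.python.session import LocalSession
from caffe2.python.cached_reader import CachedReader

# source_reader: any dataio.Reader whose records should be cached
# (assumed to be created elsewhere).
cached_reader = CachedReader(source_reader, db_type='leveldb')

# Build the on-disk cache once; if /tmp/cache.db already exists, build_cache
# returns an empty step and the existing cache is reused.
build_cache_step = cached_reader.build_cache('/tmp/cache.db')
with LocalSession() as session:
    session.run(build_cache_step)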