Caffe2 - Python API
A deep learning, cross platform ML framework
data_workers.py
1 ## @package data_workers
2 # Module caffe2.python.data_workers
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 
9 '''
10 This module provides a python-land multithreaded data input mechanism
11 for Caffe2 nets.
12 
13 Basic usage is as follows:
14  coordinator = data_workers.init_data_input_workers(
15  net,
16  ["data", "label"],
17  my_fetch_fun,
18  batch_size=32,
19  input_source_name="train",
20  dont_rebatch=False
21  )
22  ...
23  coordinator.start()
24 
25 First argument is the Caffe2 net (or model helper), and second argument
26 is list of input blobs that are to be fed.
27 
28 Argument 'input_source_name' is used to distinguish different sources of data,
29 such as train or test data. This is to ensure the data does not get mixed up,
30 even when two nets share blobs.
31 
32 To do the actual data loading, one defines a "fetcher function"
33 that has call signature
34  my_fetch_fun(worker_id, batch_size)
35 
36 Optionally, one can define an "init function" that is called once before
37 threads start, and has call signature:
38  my_init_fun(data_coordinator, global_coordinator)
39 
40 If dont_rebatch is set to True, the data input is not batched into equal sized
41 chunks but data directly provided by fetchers is used.
42 
43 'batch_columns' can be used to specify which dimension is the batch dimension,
44 for each of the inputs. Default is 0 for all inputs.
45 
46 'timeout' is the timeout in seconds after which if no data is available, the
47 net will fail (default 600s = 10 mins).
48 
49 The fetcher function returns a list of numpy arrays corresponding to the different
50 input blobs. In the example above, it would return two arrays, one for the
51 data blob and another for the labels. These arrays can have an arbitrary number
52 of elements (i.e. they do not need to match the batch size). The batch size
53 is provided for the function as a hint only.
54 
55 For example, fetcher function could download images from a remote service or
56 load random images from a directory on a file system.
57 
58 For a dummy example, see the data_workers_test unit test.
59 
60 Note that for data_parallel_models, init_data_input_workers will be called
61 for each GPU. Note that the 'coordinator' returned by the function is the
62 same object each time.
63 '''
64 
65 try:
66  import Queue
67 except ImportError:
68  # Py3
69  import queue as Queue
70 from itertools import chain
71 import logging
72 import threading
73 import numpy as np
74 import time
75 
76 from caffe2.python import workspace, core, scope, utils
77 from caffe2.proto import caffe2_pb2
78 from caffe2.python.parallel_workers import Metrics, State, \
79  WorkerCoordinator, GlobalWorkerCoordinator, Worker, run_worker
80 
# Interval, in seconds, between throughput reports and between
# "data loading lagging behind" warnings.
LOG_INT_SECS = 60

# Module logger; INFO is enabled explicitly so throughput reports show up.
log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)
84 
85 
def get_worker_ids(num_workers):
    '''Return the worker ids 0 .. num_workers - 1 as a list.'''
    return [worker_id for worker_id in range(num_workers)]
88 
89 
def init_data_input_workers(
    net,
    input_blob_names,
    fetch_fun,
    batch_size,
    num_worker_threads=2,
    input_source_name="train",
    max_buffered_batches=800,
    init_fun=None,
    external_loggers=None,
    dont_rebatch=False,
    batch_columns=None,
    timeout=600
):
    '''
    Set up python-side fetcher threads that feed `input_blob_names` of `net`.

    Args:
        net: Caffe2 net (or model helper); a DequeueBlobs op is added to it
            for every input blob.
        input_blob_names: names of the blobs to feed, one per fetched array.
        fetch_fun: fetcher called as fetch_fun(worker_id, batch_size); it
            must return one numpy array per input blob.
        batch_size: samples per batch (also passed to fetch_fun as a hint).
        num_worker_threads: number of fetcher threads to create.
        input_source_name: key of the data source (e.g. "train"); inputs
            with the same name share one python-side queue.
        max_buffered_batches: maxsize of that python-side queue; only used
            when the queue for `input_source_name` is created the first time.
        init_fun: optional init function handed to the WorkerCoordinator.
        external_loggers: passed through to Metrics.
        dont_rebatch: if True, fetched chunks are enqueued unchanged instead
            of being re-batched into `batch_size` pieces.
        batch_columns: batch dimension per input (default 0 for all inputs).
        timeout: seconds each DequeueBlobs op waits for data.

    Returns:
        The module-level global coordinator. The threads created here are
        not started; they are attached to a new WorkerCoordinator that is
        registered with the global coordinator.
    '''
    # Feed on the caller's current device scope, falling back to CPU.
    device_option = scope.CurrentDeviceScope()
    if device_option is None:
        device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)
    namescope = scope.CurrentNameScope()

    metrics = Metrics(external_loggers)
    batch_feeder = BatchFeeder(
        net,
        input_blob_names,
        batch_size,
        device_option,
        namescope,
        input_source_name,
        global_coordinator.get_queue(input_source_name, max_buffered_batches),
        metrics,
        dont_rebatch,
        batch_columns,
        timeout=timeout
    )

    # Create coordinator object
    coordinator = WorkerCoordinator(
        input_source_name, init_fun, batch_feeder)

    # One fetcher thread per worker id; ids are handed out by the global
    # coordinator (presumably unique across sources -- confirm in
    # parallel_workers).
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for _ in range(num_worker_threads)
    ]
    workers = [
        threading.Thread(
            target=run_worker,
            name="data_workers fetcher id {}".format(worker_id),
            args=[coordinator,
                  DataWorker(coordinator, worker_id, fetch_fun, metrics,
                             batch_size, batch_feeder)],
        ) for worker_id in worker_ids
    ]

    # A single enqueuer thread moves data from the python-side queue into the
    # Caffe2 queues.
    workers.append(threading.Thread(
        target=enqueuer,
        name="Enqueuer {} {}".format(input_source_name, namescope),
        args=[coordinator, batch_feeder]))
    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator
151 
152 
154  def __init__(self, net, input_blob_names, batch_size,
155  device_option, namescope, input_source_name, queue,
156  metrics, dont_rebatch, batch_columns, timeout=600):
157  self._counter = 0
158  self._input_blob_names = input_blob_names
159  self._batch_size = batch_size
160  self._internal_queue = queue
161  self._queues = []
162  self._device_option = device_option
163  self._namescope = namescope
164  self._timeout = timeout
165  self._input_source_name = input_source_name
166  self._c2_queue_capacity = 4
167  self._create_caffe2_queues(net)
168  self._create_caffe2_ops(net)
169  self._inputs = 0
170  self._prev_seconds = 0
171  self._last_warning = time.time()
172  self._dont_rebatch = dont_rebatch
173  self._init_scratch()
174  self._metrics = metrics
175 
176  if batch_columns is None:
177  batch_columns = [0 for _ in input_blob_names]
178  self._batch_columns = batch_columns
179 
180  def start(self):
181  self._inputs = 0
182  self._prev_seconds = time.time()
183 
184  def stop(self):
185  try:
186  for q in self._queues:
187  workspace.RunOperatorOnce(
188  core.CreateOperator("CloseBlobsQueue", [q], [])
189  )
190  finally:
191  self._log_inputs_per_interval(0, force=True)
192 
193  def cleanup(self):
194  utils.ResetBlobs(self._scratch_blob.values())
195  utils.ResetBlobs(self._scratch_status.values())
196 
197  def _get(self, data_input_coordinator):
198  start_time = time.time()
199  last_warning = time.time()
200  while data_input_coordinator.is_active():
201  try:
202  return self._internal_queue.get(block=True, timeout=0.5)
203  except Queue.Empty:
204  if time.time() - last_warning > 10.0:
205  log.warning("** Data input is slow: (still) no data in {} secs.".format(
206  time.time() - start_time))
207  last_warning = time.time()
208  continue
209  return None
210 
211  def _validate_chunk(self, chunk):
212  if chunk is None:
213  log.warning("Fetcher function returned None")
214  return False
215 
216  assert len(chunk) == len(self._input_blob_names), \
217  "Expecting data blob for each input"
218  for d in chunk:
219  assert isinstance(d, np.ndarray), \
220  "Fetcher function must return a numpy array"
221  if not self._dont_rebatch:
222  j = 1
223  for d in chunk[1:]:
224  assert d.shape[self._batch_columns[j]] == \
225  chunk[0].shape[self._batch_columns[0]], \
226  "Each returned input must have equal number of samples"
227  j += 1
228 
229  if len(chunk) == 0:
230  log.warning("Worker provided zero length input")
231  return False
232 
233  return True
234 
235  def put(self, chunk, data_input_coordinator):
236  if not self._validate_chunk(chunk):
237  return
238 
239  while data_input_coordinator.is_active():
240  try:
241  qsize = self._internal_queue.qsize()
242  if qsize < 2 and (time.time() - self._last_warning) > LOG_INT_SECS:
243  log.warning("Warning, data loading lagging behind: " +
244  "name={}".format(qsize, self._input_source_name))
245  self._last_warning = time.time()
246  self._counter += 1
247  self._internal_queue.put(chunk, block=True, timeout=0.5)
248  self._log_inputs_per_interval(chunk[0].shape[0])
249  return
250  except Queue.Full:
251  log.debug("Queue full: stalling fetchers...")
252  continue
253 
254  def _enqueue_batch_direct(self, data_input_coordinator):
255  data = self._get(data_input_coordinator)
256  if data is None:
257  return
258  if data_input_coordinator.is_active():
259  for b, q, c in zip(self._input_blob_names, self._queues, data):
260  self._enqueue(b, q, c)
261 
262  def _enqueue_batch(self, data_input_coordinator):
263  '''
264  This pulls data from the python-side queue and collects them
265  into batch-sized pieces, unless dont_rebatch is set to true.
266  '''
267  if self._dont_rebatch:
268  self._enqueue_batch_direct(data_input_coordinator)
269  return
270 
271  cur_batch = [np.array([]) for d in self._input_blob_names]
272  first_batch_col = self._batch_columns[0]
273 
274  # Collect data until we have a full batch size
275  while (
276  cur_batch[0].shape[0] == 0 or
277  cur_batch[0].shape[first_batch_col] < self._batch_size
278  ) and data_input_coordinator.is_active():
279  chunk = self._get(data_input_coordinator)
280  if chunk is None:
281  continue
282 
283  for j, chunk_elem in enumerate(chunk):
284  if cur_batch[j].shape[0] == 0:
285  cur_batch[j] = chunk_elem.copy()
286  else:
287  cur_batch[j] = np.append(
288  cur_batch[j], chunk_elem, axis=self._batch_columns[j]
289  )
290 
291  start_time = time.time()
292  try:
293  # Return data over the batch size back to queue
294  if cur_batch[0].shape[0] > 0 and cur_batch[0].shape[
295  first_batch_col
296  ] > self._batch_size:
297  leftover = []
298  trimmed_batch = []
299  for j, b in enumerate(cur_batch):
300  [c, l] = np.split(
301  b, [self._batch_size], axis=self._batch_columns[j]
302  )
303  leftover.append(l)
304  trimmed_batch.append(c)
305  cur_batch = trimmed_batch
306  try:
307  self._internal_queue.put(leftover, block=False)
308  except Queue.Full:
309  pass
310 
311  assert cur_batch[0].shape[first_batch_col] == self._batch_size
312 
313  if data_input_coordinator.is_active():
314  for b, q, c in zip(
315  self._input_blob_names, self._queues, cur_batch
316  ):
317  self._enqueue(b, q, c)
318  finally:
319  self._metrics.put_metric('enqueue_time', time.time() - start_time)
320 
321  def _init_scratch(self):
322  self._scratch_blob = {}
323  self._scratch_status = {}
324  for blob_name in self._input_blob_names:
325  scratch_name = self._namescope + blob_name + \
326  "_scratch_" + self._input_source_name
327  self._scratch_blob[blob_name] = core.BlobReference(scratch_name)
328  self._scratch_status[blob_name] = core.BlobReference(
329  scratch_name + "_status"
330  )
331 
332  # Feed empty arrays to the scratch blobs here, so that there won't be
333  # race conditions when calling FeedBlob (which calls wworkspace
334  # CreateBlob()) from enqueue threads
335  for b in chain(
336  self._scratch_blob.values(), self._scratch_status.values()
337  ):
338  workspace.FeedBlob(
339  b,
340  np.array([]).astype(np.float32),
341  device_option=self._device_option,
342  )
343 
344  def _enqueue(self, blob_name, queue, data_arr):
345  '''
346  Enqueue the correctly sized batch arrays to Caffe2's queue.
347  '''
348  workspace.FeedBlob(
349  self._scratch_blob[blob_name],
350  data_arr,
351  device_option=self._device_option
352  )
353 
354  op = core.CreateOperator(
355  "SafeEnqueueBlobs",
356  [queue, self._scratch_blob[blob_name]],
357  [self._scratch_blob[blob_name], self._scratch_status[blob_name]],
358  device_option=self._device_option
359  )
360  workspace.RunOperatorOnce(op)
361 
362  def _create_caffe2_queues(self, net):
363  '''
364  Creates queues on caffe2 side
365  '''
366  def create_queue(queue_name, num_blobs, capacity):
367  workspace.RunOperatorOnce(
368  core.CreateOperator(
369  "CreateBlobsQueue",
370  [], [queue_name],
371  num_blobs=1,
372  capacity=capacity))
373  return core.ScopedBlobReference(queue_name)
374 
375  for blob_name in self._input_blob_names:
376  qname = blob_name + "_c2queue" + "_" + self._input_source_name
377  q = create_queue(
378  qname, num_blobs=1, capacity=self._c2_queue_capacity
379  )
380  self._queues.append(q)
381 
382  def _create_caffe2_ops(self, net):
383  '''
384  Creates dequeue-ops on caffe2 side
385  '''
386  for q, blob_name in zip(self._queues, self._input_blob_names):
387  # Add operator to the Caffe2 network to dequeue
388  net.DequeueBlobs(q, blob_name, timeout_secs=float(self._timeout))
389 
390  def _log_inputs_per_interval(self, inputs, force=False):
391  self._inputs += inputs
392  current_seconds = time.time()
393  delta_seconds = current_seconds - self._prev_seconds
394  if delta_seconds >= LOG_INT_SECS or force:
395  inputs_per_sec = int(self._inputs / delta_seconds)
396  qsize = self._internal_queue.qsize()
397  log.info("{}/{}: {} inputs/sec".format(
398  self._input_source_name,
399  self._namescope,
400  inputs_per_sec,
401  ))
402  log.info("-- queue: {} batches".format(qsize))
403  # log and reset perf metrics
404  self._metrics.put_metric(
405  'inputs_per_sec', inputs_per_sec, False)
406  self._metrics.put_metric('queue_size', qsize, False)
407  self._metrics.put_metric(
408  'time_elapsed', delta_seconds, False)
409  self._metrics.log_metrics()
410  self._metrics.reset_metrics()
411  self._inputs = 0
412  self._prev_seconds = current_seconds
413 
414 
class GlobalCoordinator(GlobalWorkerCoordinator):
    '''
    Global worker coordinator that also owns one python-side Queue per input
    source name. Every data input with that name shares the same queue.
    '''
    def __init__(self):
        GlobalWorkerCoordinator.__init__(self)
        # input source name -> Queue.Queue of fetched chunks
        self._queues = {}

    def get_queue(self, queue_name, max_buffered_batches):
        '''
        Return the queue for `queue_name`.

        The queue is created on first use with maxsize
        `max_buffered_batches`; for an existing queue that argument is
        ignored.
        '''
        assert isinstance(max_buffered_batches, int)
        try:
            return self._queues[queue_name]
        except KeyError:
            fresh = Queue.Queue(maxsize=max_buffered_batches)
            self._queues[queue_name] = fresh
            return fresh

    def reset_data_input(self, namescope, name, net, batch_size):
        '''
        Give the data input(s) matching `name` and `namescope` a new batch
        size, and add their dequeue ops to `net`.
        '''
        log.info("Reset data input {}, batch size {}: ".format(name, batch_size))
        for coord in self._coordinators:
            if not (coord._worker_name == name and
                    coord._state._namescope == namescope):
                continue
            coord._state._batch_size = batch_size
            coord._state._create_caffe2_ops(net)
432 
433 
class DataWorker(Worker):
    '''
    Fetcher worker. Each time run() is invoked it calls the fetcher function
    and hands the result to the BatchFeeder.
    '''
    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun,
        metrics,
        batch_size,
        batch_feeder
    ):
        Worker.__init__(self, coordinator, worker_id, worker_fun=worker_fun,
                        metrics=metrics)
        self._batch_size = batch_size
        self._batch_feeder = batch_feeder

    def run(self):
        '''Fetch one chunk and pass it to the feeder for validation/queueing.'''
        fetched = self._worker_fun(self._worker_id, self._batch_size)
        feeder = self._batch_feeder
        feeder.put(fetched, self._coordinator)

    def finish(self):
        '''Record the time since _start_time as the 'fetcher_time' metric.'''
        # _start_time is set by the Worker base class (not visible here).
        elapsed = time.time() - self._start_time
        self._metrics.put_metric('fetcher_time', elapsed)
457 
458 
# Module-wide coordinator. init_data_input_workers() registers every
# coordinator it creates here and returns this object; get_queue() keeps one
# python-side queue per input source name.
global_coordinator = GlobalCoordinator()


def enqueuer(coordinator, batch_feeder):
    '''
    Thread target of the enqueuer thread created by init_data_input_workers.

    Repeatedly asks the feeder to move one batch into the Caffe2 queues
    until the coordinator reports it is no longer active.
    '''
    while True:
        if not coordinator.is_active():
            break
        batch_feeder._enqueue_batch(coordinator)
def _enqueue_batch_direct(self, data_input_coordinator)
def _enqueue(self, blob_name, queue, data_arr)
def _log_inputs_per_interval(self, inputs, force=False)
def _get(self, data_input_coordinator)