3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
10 This module provides a python-land multithreaded data input mechanism 13 Basic usage is as follows: 14 coordinator = data_workers.init_data_input_workers( 19 input_source_name="train", 25 First argument is the Caffe2 net (or model helper), and second argument 26 is a list of input blobs that are to be fed. 28 Argument 'input_source_name' is used to distinguish different sources of data, 29 such as train or test data. This is to ensure the data does not get mixed up, 30 although two nets would share blobs. 32 To do the actual data loading, one defines a "fetcher function" 33 that has call signature 34 my_fetch_fun(worker_id, batch_size) 36 Optionally, one can define an "init function" that is called once before 37 threads start, and has call signature: 38 my_init_fun(data_coordinator, global_coordinator) 40 If dont_rebatch is set to True, the data input is not batched into equal sized 41 chunks but data directly provided by fetchers is used. 43 'batch_columns' can be used to specify which dimension is the batch dimension, 44 for each of the inputs. Default is 0 for all inputs. 46 'timeout' is the timeout in seconds after which if no data is available, the 47 net will fail (default 600s = 10 mins). 49 This function returns a list of numpy arrays corresponding to the different 50 input blobs. In the example above, it would return two arrays, one for the 51 data blob and another for the labels. These arrays can have an arbitrary number 52 of elements (i.e. they do not need to match the batch size). The batch size 53 is provided for the function as a hint only. 55 For example, a fetcher function could download images from a remote service or 56 load random images from a directory on a file system. 58 For a dummy example, see the data_workers_test unit test. 60 Note that for data_parallel_models, init_data_input_workers will be called 61 for each GPU. Note that the 'coordinator' returned by the function is same 70 from itertools
import chain
77 from caffe2.proto
import caffe2_pb2
79 WorkerCoordinator, GlobalWorkerCoordinator, Worker, run_worker
# Module-level logger for the data input workers; default to INFO so the
# slow-input and lag warnings below are visible without extra configuration.
log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)
def get_worker_ids(num_workers):
    """Return the list of consecutive worker ids ``[0, num_workers)``."""
    return [worker_id for worker_id in range(num_workers)]
# NOTE(review): extraction-garbled fragment of init_data_input_workers --
# many interior lines (the full parameter list, the BatchFeeder / DataWorker /
# threading.Thread argument lists) are missing, so this text is not runnable
# as-is. The comments below describe only what the visible lines show;
# everything else should be confirmed against the complete source.
90 def init_data_input_workers(
# Visible keyword defaults: data-source label, python-side queue bound,
# and optional external metric loggers.
96 input_source_name=
"train",
97 max_buffered_batches=800,
99 external_loggers=
None,
# One process-wide coordinator object is shared across calls (module global).
104 global global_coordinator
105 device_option = scope.CurrentDeviceScope()
106 if (device_option
is None):
# Fall back to a CPU DeviceOption when no device scope is active.
107 device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)
109 metrics = Metrics(external_loggers)
110 batch_feeder = BatchFeeder(
115 scope.CurrentNameScope(),
# Python-side queue is keyed by input_source_name and bounded by
# max_buffered_batches (see GlobalCoordinator.get_queue).
117 global_coordinator.get_queue(input_source_name, max_buffered_batches),
125 coordinator = WorkerCoordinator(
126 input_source_name, init_fun, batch_feeder)
# Fresh worker ids are handed out by the global coordinator, one per thread.
130 global_coordinator.get_new_worker_id()
131 for i
in range(num_worker_threads)
# Fetcher threads: each runs a DataWorker wrapping the user's fetch_fun.
136 name=
"data_workers fetcher id {}".format(worker_id),
138 DataWorker(coordinator, worker_id, fetch_fun, metrics,
139 batch_size, batch_feeder)],
140 )
for worker_id
in worker_ids
# A separate enqueuer thread moves collected chunks into Caffe2's queues.
143 workers.append(threading.Thread(
145 name=
"Enqueuer {} {}".format(input_source_name, scope.CurrentNameScope()),
146 args=[coordinator, batch_feeder]))
147 coordinator._workers = workers
148 global_coordinator.add(coordinator)
# The same global coordinator object is returned to every caller.
150 return global_coordinator
154 def __init__(self, net, input_blob_names, batch_size,
155 device_option, namescope, input_source_name, queue,
156 metrics, dont_rebatch, batch_columns, timeout=600):
176 if batch_columns
is None:
177 batch_columns = [0
for _
in input_blob_names]
187 workspace.RunOperatorOnce(
188 core.CreateOperator(
"CloseBlobsQueue", [q], [])
194 utils.ResetBlobs(self._scratch_blob.values())
195 utils.ResetBlobs(self._scratch_status.values())
# NOTE(review): garbled fragment -- the lines between the blocking queue get
# and the timeout warning (orig. 201/203, presumably a try/except around the
# get) and the post-loop tail (orig. 208+) are missing from this extraction;
# confirm against the full source before relying on the comments below.
197 def _get(self, data_input_coordinator):
# Purpose (from visible lines): block until a chunk is available on the
# python-side queue, polling in 0.5 s slices so coordinator shutdown is
# noticed promptly.
198 start_time = time.time()
199 last_warning = time.time()
200 while data_input_coordinator.is_active():
202 return self._internal_queue.get(block=
True, timeout=0.5)
# Rate-limit the slow-input warning to at most once every 10 seconds.
204 if time.time() - last_warning > 10.0:
205 log.warning(
"** Data input is slow: (still) no data in {} secs.".format(
206 time.time() - start_time))
207 last_warning = time.time()
211 def _validate_chunk(self, chunk):
213 log.warning(
"Fetcher function returned None")
217 "Expecting data blob for each input" 219 assert isinstance(d, np.ndarray), \
220 "Fetcher function must return a numpy array" 226 "Each returned input must have equal number of samples" 230 log.warning(
"Worker provided zero length input")
235 def put(self, chunk, data_input_coordinator):
239 while data_input_coordinator.is_active():
241 qsize = self._internal_queue.qsize()
242 if qsize < 2
and (time.time() - self.
_last_warning) > LOG_INT_SECS:
243 log.warning(
"Warning, data loading lagging behind: " +
247 self._internal_queue.put(chunk, block=
True, timeout=0.5)
251 log.debug(
"Queue full: stalling fetchers...")
254 def _enqueue_batch_direct(self, data_input_coordinator):
255 data = self.
_get(data_input_coordinator)
258 if data_input_coordinator.is_active():
262 def _enqueue_batch(self, data_input_coordinator):
264 This pulls data from the python-side queue and collects them 265 into batch-sized pieces, unless dont_rebatch is set to true. 276 cur_batch[0].shape[0] == 0
or 277 cur_batch[0].shape[first_batch_col] < self.
_batch_size 278 )
and data_input_coordinator.is_active():
279 chunk = self.
_get(data_input_coordinator)
283 for j, chunk_elem
in enumerate(chunk):
284 if cur_batch[j].shape[0] == 0:
285 cur_batch[j] = chunk_elem.copy()
287 cur_batch[j] = np.append(
291 start_time = time.time()
294 if cur_batch[0].shape[0] > 0
and cur_batch[0].shape[
299 for j, b
in enumerate(cur_batch):
304 trimmed_batch.append(c)
305 cur_batch = trimmed_batch
307 self._internal_queue.put(leftover, block=
False)
311 assert cur_batch[0].shape[first_batch_col] == self.
_batch_size 313 if data_input_coordinator.is_active():
319 self._metrics.put_metric(
'enqueue_time', time.time() - start_time)
321 def _init_scratch(self):
325 scratch_name = self.
_namescope + blob_name + \
329 scratch_name +
"_status" 336 self._scratch_blob.values(), self._scratch_status.values()
340 np.array([]).astype(np.float32),
344 def _enqueue(self, blob_name, queue, data_arr):
346 Enqueue the correctly sized batch arrays to Caffe2's queue. 354 op = core.CreateOperator(
360 workspace.RunOperatorOnce(op)
362 def _create_caffe2_queues(self, net):
364 Creates queues on caffe2 side 366 def create_queue(queue_name, num_blobs, capacity):
367 workspace.RunOperatorOnce(
373 return core.ScopedBlobReference(queue_name)
380 self._queues.append(q)
382 def _create_caffe2_ops(self, net):
384 Creates dequeue-ops on caffe2 side 388 net.DequeueBlobs(q, blob_name, timeout_secs=float(self.
_timeout))
390 def _log_inputs_per_interval(self, inputs, force=False):
392 current_seconds = time.time()
394 if delta_seconds >= LOG_INT_SECS
or force:
395 inputs_per_sec = int(self.
_inputs / delta_seconds)
396 qsize = self._internal_queue.qsize()
397 log.info(
"{}/{}: {} inputs/sec".format(
402 log.info(
"-- queue: {} batches".format(qsize))
404 self._metrics.put_metric(
405 'inputs_per_sec', inputs_per_sec,
False)
406 self._metrics.put_metric(
'queue_size', qsize,
False)
407 self._metrics.put_metric(
408 'time_elapsed', delta_seconds,
False)
409 self._metrics.log_metrics()
410 self._metrics.reset_metrics()
417 GlobalWorkerCoordinator.__init__(self)
def get_queue(self, queue_name, max_buffered_batches):
    """Return the python-side queue for `queue_name`, creating it on first use.

    The queue is bounded by `max_buffered_batches` so that fetcher threads
    stall instead of buffering unlimited data when consumers lag behind.
    """
    assert isinstance(max_buffered_batches, int)
    try:
        return self._queues[queue_name]
    except KeyError:
        # First request for this input source: lazily create its queue.
        created = Queue.Queue(maxsize=max_buffered_batches)
        self._queues[queue_name] = created
        return created
def reset_data_input(self, namescope, name, net, batch_size):
    """Change the batch size of a registered input source and recreate its
    dequeue ops on `net`.

    Only coordinators whose worker name and namescope both match are touched.
    """
    log.info("Reset data input {}, batch size {}: ".format(name, batch_size))
    matching = (
        coord for coord in self._coordinators
        if coord._worker_name == name and coord._state._namescope == namescope
    )
    for coord in matching:
        state = coord._state
        state._batch_size = batch_size
        state._create_caffe2_ops(net)
444 Worker.__init__(self, coordinator, worker_id, worker_fun=worker_fun,
450 input_data = self._worker_fun(self._worker_id, self.
_batch_size)
452 self._batch_feeder.put(input_data, self._coordinator)
455 self._metrics.put_metric(
456 'fetcher_time', time.time() - self._start_time)
def enqueuer(coordinator, batch_feeder):
    """Thread target: repeatedly hand batches to the feeder until the
    coordinator is stopped."""
    still_active = coordinator.is_active
    enqueue_one = batch_feeder._enqueue_batch
    while still_active():
        enqueue_one(coordinator)
def _create_caffe2_ops(self, net)
def _enqueue_batch_direct(self, data_input_coordinator)
def _enqueue(self, blob_name, queue, data_arr)
def _log_inputs_per_interval(self, inputs, force=False)
def _create_caffe2_queues(self, net)
def _get(self, data_input_coordinator)
def _validate_chunk(self, chunk)