from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

'''
This module provides a python-land multithreaded mechanism for executing work.

Basic usage is as follows:
   coordinator = parallel_workers.init_workers(
      my_worker_fun,
      worker_name="train"
   )

First argument is the function to run in a loop on potentially multiple threads.
It has the call signature
    worker_fun(worker_id)

Argument 'worker_name' is used to distinguish different workers,
such as workers processing train data or workers processing test data.

Optionally, one can define an "init function" that is called once before
threads start, and has call signature:
   my_init_fun(worker_coordinator, global_coordinator)

Note that for data_parallel_models, init_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is same
each time.
'''

import atexit
import collections
import logging
import threading
import time
import traceback

import six

from abc import ABCMeta, abstractmethod
# Module-level logger used by the coordinators below.
log = logging.getLogger("parallel_workers")
log.setLevel(logging.INFO)
def init_workers(
    worker_fun,
    num_worker_threads=2,
    worker_name="train",
    init_fun=None,
    external_loggers=None,
    shutdown_fun=None
):
    """Create worker threads running `worker_fun` in a loop.

    Args:
        worker_fun: callable invoked repeatedly as worker_fun(worker_id).
        num_worker_threads: number of threads to spawn.
        worker_name: label distinguishing worker groups (e.g. train/test).
        init_fun: optional callable run once before threads start, with
            signature init_fun(worker_coordinator, global_coordinator).
        external_loggers: optional sequence of logger objects passed to
            Metrics for periodic metric export.
        shutdown_fun: optional callable run when the coordinator stops.

    Returns:
        The process-wide GlobalWorkerCoordinator, which tracks this (and
        every other) WorkerCoordinator. The same object is returned on
        every call.

    NOTE(review): this block was reconstructed from a garbled source; the
    default values of num_worker_threads/worker_name are assumed — confirm
    against the original file.
    """
    global global_coordinator

    metrics = Metrics(external_loggers)

    # Create the coordinator object that owns this group of workers.
    coordinator = WorkerCoordinator(
        worker_name, init_fun, shutdown_fun=shutdown_fun)

    # Reserve globally-unique worker ids, then launch one thread per id.
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]
    workers = [
        threading.Thread(
            target=run_worker,
            name="parallel_workers worker id {}".format(worker_id),
            args=[coordinator,
                  Worker(coordinator, worker_id, worker_fun, metrics)],
        ) for worker_id in worker_ids
    ]

    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator
class Metrics(object):
    """Accumulates integer/float counters and exports them to external loggers.

    Metrics are stored in a defaultdict so unknown keys read as 0. Each
    external logger is expected to expose a log(metrics_dict) method;
    failures in a logger are reported but never propagate to the worker.
    """

    def __init__(self, external_loggers):
        # defaultdict(lambda: 0) so put_metric can use += without setup.
        self._metrics = collections.defaultdict(lambda: 0)
        self._external_loggers = external_loggers

    def reset_metrics(self):
        """Drop all accumulated values, starting every counter back at 0."""
        self._metrics = collections.defaultdict(lambda: 0)

    def log_metrics(self):
        """Push the current metrics to every external logger, best-effort.

        A failing logger is reported via print and skipped so one broken
        logger cannot take down the worker loop.
        """
        if not self._external_loggers:
            return
        for external_logger in self._external_loggers:
            try:
                external_logger.log(self._metrics)
            except Exception as e:
                print("Failed to call ExternalLogger: {}".format(e))

    def put_metric(self, key, value, count=True):
        """Add `value` to metric `key`; optionally bump '<key>_count' by 1.

        With count=True (default) a companion '<key>_count' counter tracks
        how many samples contributed, enabling average computation later.
        """
        self._metrics[key] += value
        if count:
            count_key = '{}_count'.format(key)
            self._metrics[count_key] += 1
112 six.add_metaclass(ABCMeta)
class WorkerCoordinator(object):
    """Coordinates the lifecycle of one named group of worker threads.

    Holds the group's threads, an optional State object (started/stopped/
    cleaned up alongside the workers), an optional init function run once
    before threads start, and an optional shutdown function run on stop.
    """

    def __init__(self, worker_name, init_fun, state=None, shutdown_fun=None):
        self._active = True
        self._started = False
        self._workers = []
        self._worker_name = worker_name
        self._init_fun = init_fun
        self._state = state
        self._shutdown_fun = shutdown_fun

    def is_active(self):
        """True while workers should keep looping (see run_worker)."""
        return self._active

    def init(self, global_coordinator):
        """Run the init function once, before any worker thread starts."""
        if self._init_fun and not self._started:
            data_coordinator = self
            self._init_fun(data_coordinator, global_coordinator)

    def _start(self):
        """Start all worker threads (idempotent: no-op if already started)."""
        if self._started:
            return
        self._active = True
        self._started = True
        if self._state:
            self._state.start()

        for w in self._workers:
            # Daemon threads so a hung worker cannot block interpreter exit.
            w.daemon = True
            w.start()

    def _stop(self, reason=None):
        """Signal workers to stop; run shutdown hook and stop the state.

        Args:
            reason: optional error description; logged when not None.
        """
        self._active = False
        if reason is not None:
            log.error("Data input failed due to an error: {}".format(reason))
        # Only run the shutdown hook if we actually started.
        if self._shutdown_fun and self._started:
            self._shutdown_fun()
        if self._state:
            self._state.stop()

        self._started = False

    def _wait_finish(self, cleanup=None):
        """Join worker threads with a timeout; return True if all exited.

        Never joins the current thread (a worker may itself trigger the
        stop). State cleanup runs only when every worker terminated.
        """
        print("Wait for workers to die: {}".format(self._worker_name))
        for w in self._workers:
            if w != threading.current_thread():
                # Bounded join: a worker blocked in I/O must not hang us.
                w.join(5.0)
        success = True
        for w in self._workers:
            if w.is_alive():
                print("Worker {} failed to close while waiting".format(w))
                success = False

        # Release resources held by the state only on a clean shutdown.
        if success and self._state:
            self._state.cleanup()

        print("All workers terminated: {}".format(success))
        return success
class GlobalWorkerCoordinator(object):
    """Process-wide registry of WorkerCoordinators.

    Hands out unique worker ids, starts/stops all registered coordinators,
    and installs an atexit hook so workers are shut down on interpreter
    exit.

    NOTE(review): __init__ was reconstructed from a garbled source —
    confirm the attribute set against the original file.
    """

    def __init__(self):
        self._coordinators = []
        self._fetcher_id_seq = 0
        self._worker_ids = []
        self.register_shutdown_handler()

    def add(self, coordinator):
        """Register a coordinator so start()/stop() manage it."""
        self._coordinators.append(coordinator)

    def get_new_worker_id(self):
        """Return the next unique worker id (monotonically increasing)."""
        worker_id = self._fetcher_id_seq
        self._worker_ids.append(worker_id)
        self._fetcher_id_seq += 1
        return worker_id

    def get_worker_ids(self):
        """Return the list of all worker ids handed out so far."""
        return self._worker_ids

    def start(self):
        # Two separate loops so every init function completes serially
        # before any worker thread is spawned.
        for c in self._coordinators:
            c.init(self)
        for c in self._coordinators:
            c._start()

    def stop(self):
        """Stop all coordinators; return True iff all workers exited."""
        all_success = True
        for c in self._coordinators:
            c._stop()
        for c in self._coordinators:
            success = c._wait_finish()
            all_success = all_success and success
        self._coordinators = []
        return all_success

    def stop_coordinator(self, worker_name):
        '''
        Stop a specific coordinator
        '''
        for c in self._coordinators:
            if c._worker_name == worker_name:
                c._stop()
                c._wait_finish()
        # Drop the stopped coordinators from the registry.
        self._coordinators = [
            c for c in self._coordinators
            if c._worker_name != worker_name
        ]

    def register_shutdown_handler(self):
        """Install an atexit hook that stops all workers on exit.

        NOTE: the closure keeps this instance alive for the life of the
        process — expected for a process-wide singleton.
        """
        def cleanup():
            self.stop()

        atexit.register(cleanup)
class Worker(object):
    """One unit of work: wraps worker_fun with timing and error handling.

    run_worker() drives instances of this class: start() stamps the start
    time, run() invokes the user function, handle_exception() reports a
    failure and stops the coordinator, finish() records timing metrics.

    NOTE(review): __init__/start/run were reconstructed from a garbled
    source — confirm against the original file.
    """

    def __init__(self, coordinator, worker_id, worker_fun=None, metrics=None):
        self._coordinator = coordinator
        self._worker_id = worker_id
        self._worker_fun = worker_fun
        self._metrics = metrics

    def start(self):
        # Stamp the start of one iteration; finish() computes the delta.
        self._start_time = time.time()

    def run(self):
        self._worker_fun(self._worker_id)

    def handle_exception(self, e):
        """Report a worker failure and stop the owning coordinator."""
        traceback.print_exc()
        # BUGFIX: the original called logging.exception("...", e) with no
        # %-placeholder for e, which raises a string-formatting error
        # inside the logging machinery at emit time.
        logging.exception("Exception in worker: %s", e)
        self._coordinator._stop("Exception in worker {}: {}".format(
            self._worker_id, e))

    def finish(self):
        """Record the iteration's wall time and flush metrics."""
        self._metrics.put_metric(
            'worker_time', time.time() - self._start_time)
        self._metrics.log_metrics()
def run_worker(coordinator, worker):
    """Thread target: loop the worker until its coordinator deactivates.

    Each iteration stamps start time, runs the worker function, routes any
    exception to the worker's handler (which stops the coordinator), and
    always runs finish() so timing metrics are recorded even on failure.
    """
    while coordinator.is_active():
        worker.start()
        try:
            worker.run()
        except Exception as e:
            worker.handle_exception(e)
        finally:
            # Always record metrics, even when this iteration failed.
            worker.finish()
def register_shutdown_handler(self)
def stop_coordinator(self, worker_name)