Caffe2 - Python API
A deep learning, cross-platform ML framework
parallel_workers.py
# @package parallel_workers
# Module caffe2.python.parallel_workers
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


'''
This module provides a python-land multithreaded mechanism for executing work.

Basic usage is as follows:
   coordinator = parallel_workers.init_workers(
      my_worker_fun,
      worker_name="train"
   )
   ...
   coordinator.start()

The first argument is the function to run in a loop on potentially multiple
threads. It has the call signature
    worker_fun(worker_id)

The argument 'worker_name' is used to distinguish different workers,
such as workers processing train data or workers processing test data.

Optionally, one can define an "init function" that is called once before
the threads start, and has the call signature:
   my_init_fun(worker_coordinator, global_coordinator)

Note that for data_parallel_models, init_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is the
same each time.
'''

import logging
import threading
import atexit
import time
import collections
import six
import traceback

from abc import ABCMeta, abstractmethod

log = logging.getLogger("parallel_workers")
log.setLevel(logging.INFO)
LOG_INT_SECS = 60


def init_workers(
    worker_fun,
    num_worker_threads=2,
    worker_name="train",
    init_fun=None,
    external_loggers=None,
    shutdown_fun=None,
):
    global global_coordinator

    metrics = Metrics(external_loggers)

    # Create coordinator object
    coordinator = WorkerCoordinator(
        worker_name, init_fun, shutdown_fun=shutdown_fun)

    # Launch fetch worker threads
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]
    workers = [
        threading.Thread(
            target=run_worker,
            name="parallel_workers worker id {}".format(worker_id),
            args=[coordinator,
                  Worker(coordinator, worker_id, worker_fun, metrics)],
        ) for worker_id in worker_ids
    ]

    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator


class Metrics(object):
    def __init__(self, external_loggers):
        self._metrics = collections.defaultdict(lambda: 0)
        self._external_loggers = external_loggers

    def reset_metrics(self):
        self._metrics = collections.defaultdict(lambda: 0)

    def log_metrics(self):
        if not self._external_loggers:
            return
        for logger in self._external_loggers:
            try:
                logger.log(self._metrics)
            except Exception as e:
                print("Failed to call ExternalLogger: {}".format(e))

    def put_metric(self, key, value, count=True):
        self._metrics[key] += value
        if count:
            count_key = '{}_count'.format(key)
            self._metrics[count_key] += 1


@six.add_metaclass(ABCMeta)
class State(object):
    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def stop(self):
        pass

    @abstractmethod
    def cleanup(self):
        pass


class WorkerCoordinator(object):
    def __init__(self, worker_name, init_fun, state=None, shutdown_fun=None):
        self._active = True
        self._started = False
        self._workers = []
        self._worker_name = worker_name
        self._init_fun = init_fun
        self._state = state
        self._shutdown_fun = shutdown_fun

    def is_active(self):
        return self._active

    def init(self, global_coordinator):
        if self._init_fun and not self._started:
            data_coordinator = self
            self._init_fun(data_coordinator, global_coordinator)

    def _start(self):
        if self._started:
            return
        self._active = True
        self._started = True
        if self._state:
            self._state.start()

        for w in self._workers:
            w.daemon = True
            w.start()

    def _stop(self, reason=None):
        self._active = False
        if reason is not None:
            log.error("Data input failed due to an error: {}".format(reason))
        if self._shutdown_fun and self._started:
            self._shutdown_fun()
        if self._state:
            self._state.stop()

        self._started = False

    def _wait_finish(self, cleanup=None):
        print("Wait for workers to die: {}".format(self._worker_name))
        for w in self._workers:
            if w != threading.current_thread():
                w.join(5.0)  # don't wait forever, thread may be blocked in i/o
        success = True
        for w in self._workers:
            if w.is_alive():
                print("Worker {} failed to close while waiting".format(w))
                success = False

        # Release memory for the scratch blobs
        if success and self._state:
            self._state.cleanup()

        print("All workers terminated: {}".format(success))
        return success


class GlobalWorkerCoordinator(object):
    def __init__(self):
        self._coordinators = []
        self._fetcher_id_seq = 0
        self._worker_ids = []

    def add(self, coordinator):
        self._coordinators.append(coordinator)

    def get_new_worker_id(self):
        worker_id = self._fetcher_id_seq
        self._worker_ids.append(worker_id)
        self._fetcher_id_seq += 1
        return worker_id

    def get_worker_ids(self):
        return self._worker_ids

    def start(self):
        # Run init and start in separate loops to ensure init happens
        # serially before the worker threads are spawned.
        for c in self._coordinators:
            c.init(self)
        for c in self._coordinators:
            c._start()

    def stop(self):
        all_success = True
        for c in self._coordinators:
            c._stop()
        for c in self._coordinators:
            success = c._wait_finish()
            all_success = all_success and success
        self._coordinators = []
        return all_success

    def stop_coordinator(self, worker_name):
        '''
        Stop a specific coordinator
        '''
        for c in self._coordinators:
            if c._worker_name == worker_name:
                c._stop()
                c._wait_finish()
        self._coordinators = [
            c for c in self._coordinators
            if c._worker_name != worker_name
        ]

    def register_shutdown_handler(self):
        def cleanup():
            self.stop()

        atexit.register(cleanup)


class Worker(object):
    def __init__(
        self,
        coordinator,
        worker_id,
        worker_fun=None,
        metrics=None
    ):
        self._coordinator = coordinator
        self._worker_id = worker_id
        self._worker_fun = worker_fun
        self._metrics = metrics

    def start(self):
        self._start_time = time.time()

    def run(self):
        self._worker_fun(self._worker_id)

    def handle_exception(self, e):
        traceback.print_exc()
        logging.exception("Exception in worker: %s", e)
        self._coordinator._stop("Exception in worker {}: {}".format(
            self._worker_id, e
        ))

    def finish(self):
        self._metrics.put_metric(
            'worker_time', time.time() - self._start_time)
        self._metrics.log_metrics()


global_coordinator = GlobalWorkerCoordinator()


def run_worker(coordinator, worker):
    while coordinator.is_active():
        worker.start()
        try:
            worker.run()
        except Exception as e:
            worker.handle_exception(e)
        finally:
            worker.finish()
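Example

The sketch below shows one way to drive the API documented above. It is a minimal, illustrative example, not part of the module: 'my_worker_fun', 'my_init_fun', and 'StdoutLogger' are hypothetical names. An external logger only needs to expose a log(metrics_dict) method, since Metrics.log_metrics() calls logger.log(...) on each registered logger.

import time

from caffe2.python import parallel_workers


class StdoutLogger(object):
    # Hypothetical external logger; any object with a log(metrics) method works.
    def log(self, metrics):
        print("metrics: {}".format(dict(metrics)))


def my_worker_fun(worker_id):
    # Called repeatedly by run_worker() while the coordinator is active.
    # Real code would fetch or enqueue a batch of data here.
    time.sleep(0.1)


def my_init_fun(worker_coordinator, global_coordinator):
    # Called once per coordinator, before any worker thread is started.
    print("initializing workers")


coordinator = parallel_workers.init_workers(
    my_worker_fun,
    num_worker_threads=4,
    worker_name="train",
    init_fun=my_init_fun,
    external_loggers=[StdoutLogger()],
)
coordinator.register_shutdown_handler()  # atexit hook that calls stop()
coordinator.start()  # runs init_fun, then spawns the worker threads

# ... training loop ...

coordinator.stop()  # signals workers and joins them (5-second timeout per thread)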