Caffe2 - Python API
A deep learning, cross platform ML framework
checkpoint.py
1 ## @package checkpoint
2 # Module caffe2.python.checkpoint
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import os
9 import logging
10 from caffe2.python import core, context
11 from caffe2.python.net_builder import ops
12 from caffe2.python.task import Node, Task, TaskGroup, TaskOutput, WorkspaceType
13 
14 logger = logging.getLogger(__name__)
15 logger.setLevel(logging.INFO)
16 
17 
@context.define_context()
class Job(object):
    """
    A Job bundles the TaskGroups executed by a JobRunner.

    Four groups are held:

    * `init_group` -- run exactly once at startup; initializes globally
      persistent blobs (model weights, accumulators, data file lists).
    * `epoch_group` -- run repeatedly after `init_group`; the loop ends when
      any stop signal registered via `add_stop_signal` is True at the end of
      an epoch.
    * `download_group` -- run once after the last epoch; gathers distributed
      parameters back together after training.
    * `exit_group` -- run once at the very end; persists the results of
      training.

    Jobs are context-driven: entering a Job makes it the current context, so
    Tasks can attach to the active Job without the object being passed
    around explicitly.

    Example of usage:

        def build_reader(partitions):
            with Job.current().init_group:
                reader = HiveReader(init_reader, ..., partitions)
                Task(step=init_reader)
            with Job.current().epoch_group:
                limited_reader = ReaderWithLimit(reader, num_iter=10000)
                data_queue = pipe(limited_reader, num_threads=8)
                Job.current().add_stop_signal(limited_reader.data_finished())
            return data_queue

        def build_hogwild_trainer(reader, model):
            with Job.current().init_group:
                Task(step=model.param_init_net)
            with Job.current().epoch_group:
                pipe(reader, processor=model, num_threads=8)
            with Job.current().exit_group:
                Task(step=model.save_model_net)

        with Job() as job:
            reader = build_reader(partitions)
            model = build_model(params)
            build_hogwild_trainer(reader, model)
    """
    def __init__(self,
                 init_group=None, epoch_group=None,
                 download_group=None, exit_group=None,
                 stop_signals=None, nodes_to_checkpoint=None):
        # Missing groups are created on the fly; the init group lives in the
        # global workspace so its blobs survive across epochs.
        self.init_group = init_group if init_group else TaskGroup(
            workspace_type=WorkspaceType.GLOBAL)
        self.epoch_group = epoch_group if epoch_group else TaskGroup()
        self.download_group = download_group if download_group else TaskGroup()
        self.exit_group = exit_group if exit_group else TaskGroup()
        self.stop_signals = stop_signals if stop_signals else []
        self._nodes_to_checkpoint = nodes_to_checkpoint

    def nodes_to_checkpoint(self):
        """Return the explicit checkpoint node list, or the nodes used by
        the init group when none was given."""
        return self._nodes_to_checkpoint or self.init_group.used_nodes()

    def compile(self, session_class):
        """Return a new Job whose task groups have been compiled for
        execution by the given session class."""
        return Job(
            init_group=session_class.compile(self.init_group),
            epoch_group=session_class.compile(self.epoch_group),
            download_group=session_class.compile(self.download_group),
            exit_group=session_class.compile(self.exit_group),
            stop_signals=self.stop_signals,
            nodes_to_checkpoint=self.nodes_to_checkpoint())

    def __enter__(self):
        # Entering the Job also enters its epoch group, so Tasks created in
        # the `with` body attach to it by default.
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        self.epoch_group.__exit__()

    def add_stop_signal(self, output):
        """Register a boolean output checked after each epoch; the epoch
        loop stops once any registered signal is True."""
        if isinstance(output, core.BlobReference):
            # Wrap a raw blob into a Task so it becomes a fetchable output.
            wrapper = Task(outputs=[output], group=self.epoch_group)
            output = wrapper.outputs()[0]
        assert isinstance(output, TaskOutput)
        self.stop_signals.append(output)
107 
108 
def get_ckpt_filename(node_name, epoch):
    """Compose the checkpoint filename for a node and epoch.

    Args:
        node_name: A string. The name of the node.
        epoch: An integer. The checkpoint epoch.

    Returns:
        A string of the form '<node_name>.<epoch>'.
    """
    return '{}.{}'.format(node_name, epoch)
120 
121 
def db_name(epoch, node_name, db_prefix, path_prefix=None):
    """Build the full db name under which checkpoint files are saved.

    Args:
        epoch: An integer. The checkpoint epoch.
        node_name: A string. The name of the node.
        db_prefix: A string. The prefix used to construct full db name.
        path_prefix: A string. Optional param used to construct db name or
            path where checkpoint files are stored.

    Returns:
        A string: the full db name where checkpoint files are saved.
    """
    # '<node_name>.<epoch>' -- same shape as get_ckpt_filename produces.
    ckpt_filename = node_name + '.' + str(epoch)
    if path_prefix:
        # A path prefix is concatenated verbatim (it may already end in a
        # separator or be a URI-style prefix), not joined.
        return path_prefix + ckpt_filename
    return os.path.join(db_prefix, ckpt_filename)
141 
142 
class CheckpointManager(object):
    """
    Controls saving and loading of workspaces on every epoch boundary of a job.
    If a CheckpointManager instance is passed to JobRunner, then JobRunner will
    call `init`, `read` and `save` at different moments in between epoch runs.

    Args:
        db_prefix: The prefix used to construct full db name. Since
            `absolute_path` is set to True, this will be used as db_name
            in SaveOp.
        node_name: Name of the node where this checkpoint_manager is used.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """
    def __init__(self, db_prefix, node_name, db_type, metadata_handler=None):
        self._db_prefix = db_prefix
        self._node_name = node_name
        self._db_type = db_type
        self._metadata_handler = metadata_handler
        # The '!!' prefix sorts first lexicographically, which makes sure
        # these blobs are the first in the checkpoint file.
        self._net = core.Net('!!checkpoint_mngr')
        self._blob_names = self._net.AddExternalInput('blob_names')
        self._names_output = None
        self._path_prefix = None
        self._path_type = None

    def init(
        self,
        nodes=None,
        retrieve_from_epoch=None,
        path_prefix=None,
        path_type=None
    ):
        """
        Initialize the checkpoint manager. Determines all blobs that need to
        be saved or loads them from a checkpoint.

        Builds a Task that will be run once after the job's `init_group` is
        run. This task will determine which blobs need to be checkpointed.
        If retrieve_from_epoch is not None, then the checkpoint metadata is
        retrieved from a previously saved checkpoint.

        Args:
            nodes: An array of nodes where this checkpoint manager is
                running. Should only contain a single node.
            retrieve_from_epoch: Set to a number to load blobs from this
                epoch.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Indicate the type of path where checkpoint files are
                stored.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')

        with Task(outputs=[self._blob_names]) as task:
            if retrieve_from_epoch is None:
                # Fresh run: snapshot the names of all blobs currently in
                # the workspace.
                ops.GetAllBlobNames(
                    [],
                    self._blob_names,
                    include_shared=False)
            else:
                # Resuming: the blob-name list is read back from the
                # checkpoint written for `retrieve_from_epoch`.
                full_db_name = db_name(retrieve_from_epoch,
                                       self._node_name, self._db_prefix,
                                       path_prefix)
                db_type = path_type or self._db_type
                logger.info("Initializing checkpoints from = %s",
                            full_db_name)
                ops.Load(
                    [], self._blob_names,
                    db=full_db_name,
                    db_type=db_type,
                    absolute_path=True)
        self._names_output = task.outputs()[0]
        return task

    def blob_list(self):
        """Return the list of blob names captured by `init` (the init task
        must have been run first)."""
        assert self._names_output
        return self._names_output.fetch().tolist()

    def load(self, epoch, path_prefix=None, path_type=None):
        """
        Build a Task that will be run by JobRunner when the job is to be
        resumed from a given epoch. This task will run a Load op that will
        load and deserialize all relevant blobs from a persistent storage.
        """
        full_db_name = db_name(
            epoch, self._node_name, self._db_prefix, path_prefix)
        db_type = path_type or self._db_type
        logger.info("Loading checkpoints from = %s", full_db_name)
        with Task() as task:
            ops.Load(
                [],
                self.blob_list(),
                db=full_db_name,
                db_type=db_type,
                absolute_path=True)
        return task

    def load_blobs_from_checkpoint(self, blob_names, epoch):
        """
        Builds a Task that loads only the necessary blobs from a checkpoint
        of the given epoch. The necessary blobs are given in the blob_names
        argument.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: The checkpoint epoch to load from.

        Returns:
            A Task which loads the specified blobs from the checkpoint of
            the given epoch.
        """
        logger.info(
            'Load from %s',
            db_name(epoch, self._node_name, self._db_prefix))
        with Task() as task:
            ops.Load(
                [],
                blob_names,
                db=db_name(epoch, self._node_name, self._db_prefix),
                db_type=self._db_type,
                absolute_path=True,
                # Missing blobs are tolerated; only the requested subset
                # needs to exist in the checkpoint.
                allow_incomplete=True)
        return task

    def check_db_exists(self, epoch):
        """Build a Task whose single output is True iff the checkpoint db
        for the given epoch exists."""
        logger.info(
            'Check existence of %s',
            db_name(epoch, self._node_name, self._db_prefix))
        with Task() as task:
            existence = ops.Const(False)
            ops.DBExists(
                [],
                [existence],
                db_name=db_name(epoch, self._node_name, self._db_prefix),
                db_type=self._db_type,
                absolute_path=True)
            task.add_output(existence)
        return task

    def save(self, epoch):
        """
        Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute a Save ops to serialize and persist
        blobs present in the global workspace.
        """
        logger.info(
            'Saving to %s',
            db_name(epoch, self._node_name, self._db_prefix))
        with Task() as task:
            ops.Save(
                self.blob_list(), [],
                db=db_name(epoch, self._node_name, self._db_prefix),
                db_type=self._db_type, absolute_path=True)
        return task

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata
                is written
        """
        if self._metadata_handler is not None:
            self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume

        Args:
            user_epoch: An integer. Optional parameter for user to
                explicitly identify the epoch-id to load checkpoint from
        Returns:
            epoch: the epoch-id to load checkpoints from
                or None if no checkpoints were written
        """
        last_epoch = user_epoch
        if self._metadata_handler is not None:
            last_epoch = self._metadata_handler.last_epoch(
                user_epoch=user_epoch)
        return last_epoch

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager

        Args:
            nodes: An array of nodes where this checkpoint manager is
                running.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Indicate the type of path where checkpoint files are
                stored.
        """
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if self._metadata_handler:
            self._metadata_handler.set_params(
                db_prefix=self._db_prefix,
                db_type=self._db_type,
                node_names=[str(self._node_name)],
                path_prefix=self._path_prefix,
                path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if Checkpoint data is accessible

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if checkpoint directory is
                accessible

        Returns:
            is_cp_accessible: A boolean. Returns True if Checkpoint data is
                accessible
        """
        if self._metadata_handler is not None:
            return self._metadata_handler.cp_accessible(epoch)
        else:
            return True
355 
356 
358  """
359  Coordinates checkpointing and checkpointing across multiple nodes.
360  Each of `init`, `load` and `save` will build TaskGroups which will
361  trigger checkpointing on each of the nodes involved in a distributed job.
362 
363  Args:
364  db_prefix: The prefix used to construct full db name. Since `absolute_path`
365  is set to True, this will be used as db_name in SaveOp.
366  db_type: Type of database to use for storing checkpoint.
367  metadata_handler: An optional object capable of reading/writing
368  checkpoint info in storage of choice.
369  """
370  def __init__(self, db_prefix, db_type, metadata_handler=None):
371  self._node_managers = None
372  self._db_prefix = db_prefix
373  self._db_type = db_type
374  self._metadata_handler = metadata_handler
375  self._path_prefix = None
376  self._path_type = None
377 
378  def _task_group(self, func, *args, **kw):
379  assert self._node_managers is not None, 'init must be called first.'
380  with TaskGroup(WorkspaceType.GLOBAL) as task_group:
381  for node, manager in self._node_managers:
382  with Node(node):
383  func(manager, *args, **kw)
384  return task_group
385 
386  """
387  Args:
388  nodes: An array of nodes where this checkpoint manager is running.
389  retrieve_from_epoch: Set to a number to load blobs from this epoch.
390  path_prefix: Used to construct db name or path where checkpoint files are
391  stored.
392  path_type: Indicate the type of path where checkpoint files are stored.
393  """
394  def init(
395  self, nodes, retrieve_from_epoch=None, path_prefix=None, path_type=None
396  ):
397  if self._node_managers is not None:
398  assert [node for node, _ in self._node_managers] == nodes
399  return TaskGroup(WorkspaceType.GLOBAL)
400  self._node_managers = []
401  for node in nodes:
402  with Node(node):
403  manager = CheckpointManager(
404  db_prefix=self._db_prefix,
405  node_name=str(node),
406  db_type=self._db_type)
407  self._node_managers.append((node, manager))
408  return self._task_group(
409  CheckpointManager.init,
410  nodes=[node],
411  retrieve_from_epoch=retrieve_from_epoch,
412  path_prefix=path_prefix,
413  path_type=path_type)
414 
415  def load(self, epoch, path_prefix=None, path_type=None):
416  return self._task_group(
417  CheckpointManager.load,
418  epoch,
419  path_prefix=path_prefix,
420  path_type=path_type)
421 
422  def load_blobs_locally(self, nodes, blob_names, epoch, session):
423  """Loads the necessary blobs from the checkpoints to the current node.
424 
425  Args:
426  blob_names: A list of strings. Each string is the name of a
427  blob.
428  epoch: An integer. The checkpoint epoch to load from.
429  session: A Session object to execute the Load ops.
430  """
431  if self._node_managers is not None:
432  assert [node for node, _ in self._node_managers] == nodes
433  else:
434  self._node_managers = []
435  for node in nodes:
436  with Node(node):
437  manager = CheckpointManager(
438  db_prefix=self._db_prefix,
439  node_name=str(node),
440  db_type=self._db_type)
441  self._node_managers.append((node, manager))
442  assert self._node_managers is not None, 'must initialize node managers'
443  for _, manager in self._node_managers:
444  existence_task = manager.check_db_exists(epoch)
445  session.run(existence_task)
446  existence = existence_task.outputs()[0].fetch()
447  if not existence:
448  logger.info('DB %s does not exist!' %
449  db_name(epoch, manager._node_name, manager._db_prefix))
450  return False
451  load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
452  session.run(load_task)
453  logger.info('Successfully loaded from checkpoints.')
454  return True
455 
456  def get_ckpt_db_name(self, node_name, epoch):
457  """Returns the DB name of the given node and the given epoch.
458 
459  The DB name is effectively the checkpoint path of the given node and
460  the given epoch.
461 
462  Args:
463  node_name: A string. The node name of interest.
464  epoch: An integer. The epoch of the checkpoint.
465 
466  Returns:
467  checkpoint_db_name: A string. The checkpoint path of the given
468  node and the given epoch.
469  """
470  for node, manager in self._node_managers:
471  if str(node) == node_name:
472  return db_name(epoch, manager._node_name, manager._db_prefix)
473 
474  def save(self, epoch):
475  """
476  Build a Task that will execute a Save ops to serialize and persist
477  blobs present in the global workspace.
478  """
479  return self._task_group(CheckpointManager.save, epoch)
480 
481  def write_checkpoint_metadata(self, epoch):
482  """
483  Write metadata for checkpoint
484 
485  Args:
486  epoch: An integer. The epoch-id for which checkpoint metadata is
487  written
488  """
489  if self._metadata_handler is not None:
490  self._metadata_handler.write(epoch=epoch)
491 
492  def get_resume_from_epoch_id(self, user_epoch=None):
493  """
494  Identify the epoch-id from which Job must resume
495 
496  Args:
497  user_epoch: An integer. Optional parameter for user to explicitly
498  identify the epoch-id to load checkpoint from
499  Retruns:
500  epoch: the epoch-id to load checkpoints from
501  or None if no checkpoints were written
502  """
503  last_epoch = user_epoch
504  if self._metadata_handler is not None:
505  last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
506  return last_epoch
507 
508  def set_params(self, nodes, path_prefix=None, path_type=None):
509  """Set parameters associated with CP manager
510 
511  Args:
512  nodes: An array of nodes where this checkpoint manager is running.
513  path_prefix: Used to construct db name or path where checkpoint files are
514  stored.
515  path_type: Indicate the type of path where checkpoint files are stored.
516  """
517  self._node_names = [str(node) for node in nodes]
518  if path_prefix:
519  self._path_prefix = path_prefix
520  if path_type:
521  self._path_type = path_type
522  if self._metadata_handler:
523  self._metadata_handler.set_params(
524  db_prefix=self._db_prefix,
525  db_type=self._db_type,
526  node_names=self._node_names,
527  path_prefix=self._path_prefix,
528  path_type=self._path_type)
529 
530  def cp_accessible(self, epoch=None):
531  """Returns True if Checkpoint data is accessible
532 
533  Args:
534  epoch: An integer. The epoch of the checkpoint. If None,
535  it implies we need to check if checkpoint directory is accessible
536 
537  Returns:
538  is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible
539  """
540  if self._metadata_handler is not None:
541  return self._metadata_handler.cp_accessible(epoch)
542  else:
543  return True
544 
545 
547  """A simple class to upload checkpoints."""
548  def build(self, epoch, checkpoint_manager):
549  """Builds the task group to upload checkpoints.
550 
551  Args:
552  epoch: An integer. The checkpoint epoch to be uploaded.
553  checkpoint_manager: Can be a CheckpointManager for single machine
554  or a MultiNodeCheckpointManager for multi-machine. The manager
555  that initializes/saves/loads checkpoints.
556 
557  Raises:
558  NotImplementedError: This base class only has the interface,
559  the implementation will be in the subclasses.
560  """
561  raise NotImplementedError()
562 
563 
class JobRunner(object):
    """
    Implement the runtime logic for jobs with checkpointing at the level of
    epoch. Can be used to run either single-host or distributed jobs. Job
    runner is a callable to be called once from the master, passing a session
    as an argument. This call will block until the Job execution is complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In
    this case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
                 upload_task_group_builder=None):
        """Initializes the JobRunner.

        Args:
            job: A Job object. The job to be executed.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The
                manager that initializes/saves/loads checkpoints.
            resume_from_epoch: An integer. The epoch to resume from.
            upload_task_group_builder: A subclass of the
                UploadTaskGroupBuilder. Creates a task group to upload
                checkpoints.
        """
        self.resume_from_epoch = resume_from_epoch
        self.checkpoint_manager = checkpoint_manager
        self.job = job
        self.upload_task_group_builder = upload_task_group_builder

    def __call__(self, session):
        """Runs the training flow.

        Args:
            session: A Session object. Valid choices are: LocalSession,
                LocalHostScheduler, and DistributedSession. It is used to
                execute one TaskGroup a time.

        Returns:
            The last epoch number that was run.
        """
        # identify the epoch we must resume from
        if self.checkpoint_manager:
            self.checkpoint_manager.set_params(
                nodes=self.job.nodes_to_checkpoint())
            self.resume_from_epoch = self.checkpoint_manager.\
                get_resume_from_epoch_id(self.resume_from_epoch)
            if self.resume_from_epoch is not None:
                logger.info('Resuming from epoch {}'.format(
                    self.resume_from_epoch))

        # Initialize all the nodes.
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            session.run(self.job.init_group)

        if self.checkpoint_manager:
            logger.info('Preparing checkpoints ...')
            session.run(self.checkpoint_manager.init(
                self.job.nodes_to_checkpoint(),
                retrieve_from_epoch=self.resume_from_epoch))
            # Save the first checkpoint before training starts, or resume
            # from a previously saved checkpoint.
            if from_scratch:
                self.save_checkpoints(0, session)
            else:
                logger.info('Loading checkpoints for epoch {} ...'.format(
                    self.resume_from_epoch))
                session.run(
                    self.checkpoint_manager.load(self.resume_from_epoch))
                logger.info('Checkpoint loaded')

        logger.info("Finished initializing")

        # Start training.
        epoch = 1 if from_scratch else self.resume_from_epoch + 1
        while True:
            logger.info('Starting epoch %d' % epoch)
            session.run(self.job.epoch_group)
            logger.info('Finished epoch %d' % epoch)
            # Fetch stop signals BEFORE checkpointing so the checkpoint
            # reflects the workspace state that produced them.
            stop_signals = [o.fetch() for o in self.job.stop_signals]

            if self.checkpoint_manager:
                self.save_checkpoints(epoch, session)

            if any(stop_signals):
                logger.info('Stopping')
                break
            epoch += 1
        logger.info('Finished training')
        # Upload the checkpoints.
        if self.upload_task_group_builder:
            upload_task_group = self.upload_task_group_builder.build(
                epoch, self.checkpoint_manager)
            session.run(upload_task_group)
            logger.info('Finished uploading the checkpoints')

        # Download the parameters to save
        session.run(self.job.download_group)
        logger.info('Finished downloading the parameters')

        # Finally run the exit step to save nets
        session.run(self.job.exit_group)
        logger.info('Finished running the exit group')
        return epoch

    def load_blobs_from_checkpoints(self, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints.

        Checkpoints store the snapshots of the workspace in each node.
        Sometimes we only need to load a subset of the blobs from the
        checkpoints. One common scenario is to load only the model blobs from
        the checkpoints for evaluation purpose. Given the names of the
        necessary blobs, this function goes over all the checkpoints of all
        the nodes, but only loads the blobs specified in the blob_names to
        the current workspace.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the load ops.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        if not self.checkpoint_manager:
            raise ValueError('Checkpoint manager is None')
        logger.info('Loading checkpoint for epoch {} ...'.format(epoch))
        return self.checkpoint_manager.load_blobs_locally(
            self.job.nodes_to_checkpoint(), blob_names, epoch, session)

    def save_checkpoints(self, epoch, session):
        """Triggers operation to save checkpoints

        This method will trigger the Save ops to serialize and persist the
        blobs present in the global workspace.

        Args:
            epoch: An integer. The checkpoint epoch-id that we are saving.
            session: A Session object to execute the save ops.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        if not self.checkpoint_manager:
            raise ValueError('Checkpoint manager is None')
        try:
            is_accessible = self.checkpoint_manager.cp_accessible(epoch=None)
            if is_accessible:
                logger.info('Saving checkpoints for epoch {}'.format(epoch))
                session.run(self.checkpoint_manager.save(epoch))
                self.checkpoint_manager.write_checkpoint_metadata(epoch)
                logger.info('Checkpoints saved')
            else:
                logger.warning("Checkpoint files cannot be accessed!")
        except Exception as ex:
            # Checkpointing is best-effort: a failed save must not abort
            # training, so log and continue.
            logger.warning("Unable to write checkpoint for epoch {}. Error={}".
                           format(epoch, ex))
722 
723 
def epoch_limiter(job, num_epochs):
    """
    Register a stop signal on `job` that becomes True once `num_epochs`
    epochs have completed, bounding the epoch loop.
    """
    with job.init_group:
        # Counter starts at num_epochs - 1 because the final CountDown
        # that reaches zero is itself the last epoch.
        setup_net = core.Net('epoch_counter_init')
        counter = setup_net.CreateCounter([], init_count=num_epochs - 1)
        Task(step=setup_net)

    with job.epoch_group:
        countdown_net = core.Net('epoch_countdown')
        done = countdown_net.CountDown(counter)
        stop_signal = Task(step=countdown_net, outputs=done).outputs()[0]
        job.add_stop_signal(stop_signal)
def set_params(self, nodes, path_prefix=None, path_type=None)
Definition: checkpoint.py:508
def init(self, nodes=None, retrieve_from_epoch=None, path_prefix=None, path_type=None)
Definition: checkpoint.py:187
def load_blobs_locally(self, nodes, blob_names, epoch, session)
Definition: checkpoint.py:422
def load_blobs_from_checkpoints(self, blob_names, epoch, session)
Definition: checkpoint.py:669
def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None, upload_task_group_builder=None)
Definition: checkpoint.py:580
def build(self, epoch, checkpoint_manager)
Definition: checkpoint.py:548
def save_checkpoints(self, epoch, session)
Definition: checkpoint.py:695
def get_resume_from_epoch_id(self, user_epoch=None)
Definition: checkpoint.py:304
def load(self, epoch, path_prefix=None, path_type=None)
Definition: checkpoint.py:221
def load_blobs_from_checkpoint(self, blob_names, epoch)
Definition: checkpoint.py:239
def set_params(self, nodes, path_prefix=None, path_type=None)
Definition: checkpoint.py:320
def get_resume_from_epoch_id(self, user_epoch=None)
Definition: checkpoint.py:492