3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
# Module-level logger, per stdlib convention (one logger per module).
logger = logging.getLogger(__name__)
# NOTE(review): setting a level on a library logger overrides user logging
# config; kept to preserve existing behavior of this module.
logger.setLevel(logging.INFO)
# NOTE(review): this class was reconstructed from a garbled extraction; the
# visible fragments match the upstream caffe2.python.checkpoint.Job class, and
# missing scaffolding (def lines, if/else) was restored to match them. Verify
# against version control before relying on the restored lines.
@context.define_context()
class Job(object):
    """
    A Job defines three TaskGroups: the `init_group`, the `epoch_group` and the
    `exit_group` which will be run by a JobRunner.

    The `init_group` will be run only once at startup. Its role is to
    initialize globally persistent blobs such as model weights, accumulators
    and data file lists.

    The `epoch_group` will be run in a loop after init_group. The loop will
    exit when any of the stop signals added with `add_stop_signal` is True
    at the end of an epoch.

    The download_group will be run only once, after all the executions of
    epoch_group finish. Its role is to collect the distribute scattered
    parameters back after training.

    The `exit_group` will be run only once at the very end of the job, the
    role of this group is to save the results of training in the end of the job.

    Jobs are context-driven, so that Tasks can be added to the active Job
    without having to explicitly pass the job object around.

    Example of usage:

    def build_reader(partitions):
        with Job.current().init_group:
            reader = HiveReader(init_reader, ..., partitions)
            Task(step=init_reader)
        with Job.current().epoch_group:
            limited_reader = ReaderWithLimit(reader, num_iter=10000)
            data_queue = pipe(limited_reader, num_threads=8)
            Job.current().add_stop_signal(limited_reader.data_finished())
        return data_queue

    def build_hogwild_trainer(reader, model):
        with Job.current().init_group:
            Task(step=model.param_init_net)
        with Job.current().epoch_group:
            pipe(reader, processor=model, num_threads=8)
        with Job.current().exit_group:
            Task(step=model.save_model_net)

    with Job() as job:
        reader = build_reader(partitions)
        model = build_model(params)
        build_hogwild_trainer(reader, model)
    """
    def __init__(self,
                 init_group=None, epoch_group=None,
                 download_group=None, exit_group=None,
                 stop_signals=None, nodes_to_checkpoint=None):
        # init_group runs in the global workspace so initialized blobs persist
        # across epoch executions.
        self.init_group = init_group or TaskGroup(
            workspace_type=WorkspaceType.GLOBAL)
        self.epoch_group = epoch_group or TaskGroup()
        self.download_group = download_group or TaskGroup()
        self.exit_group = exit_group or TaskGroup()
        self.stop_signals = stop_signals or []
        self._nodes_to_checkpoint = nodes_to_checkpoint

    def nodes_to_checkpoint(self):
        # Explicit node list wins; otherwise checkpoint every node that the
        # init_group touched.
        if self._nodes_to_checkpoint:
            return self._nodes_to_checkpoint
        else:
            return self.init_group.used_nodes()

    def compile(self, session_class):
        # NOTE(review): fragments show keyword-arg construction of a new Job
        # from compiled task groups; non-visible arguments restored -- verify.
        return Job(
            init_group=session_class.compile(self.init_group),
            epoch_group=session_class.compile(self.epoch_group),
            download_group=session_class.compile(self.download_group),
            exit_group=session_class.compile(self.exit_group),
            stop_signals=self.stop_signals,
            nodes_to_checkpoint=self._nodes_to_checkpoint)

    def __enter__(self):
        # Entering a Job makes its epoch_group the active TaskGroup context.
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        self.epoch_group.__exit__()

    def add_stop_signal(self, output):
        """Register a boolean output; the epoch loop stops when any is True."""
        if isinstance(output, core.BlobReference):
            # Wrap a raw blob into a Task output inside the epoch group so it
            # can be fetched after each epoch.
            t = Task(outputs=[output], group=self.epoch_group)
            output = t.outputs()[0]
        assert isinstance(output, TaskOutput)
        self.stop_signals.append(output)
def get_ckpt_filename(node_name, epoch):
    """Returns the checkpoint filename.

    Args:
        node_name: A string. The name of the node.
        epoch: An integer. The checkpoint epoch.

    Returns:
        ckpt_filename: A string. The filename of the checkpoint.
    """
    # Filename convention is "<node_name>.<epoch>".
    return node_name + '.' + str(epoch)
def db_name(epoch, node_name, db_prefix, path_prefix=None):
    """Returns the full db name where checkpoint files are saved.

    Args:
        epoch: An integer. The checkpoint epoch.
        node_name: A string. The name of the node.
        db_prefix: A string. The prefix used to construct full db name.
        path_prefix: A string. Optional param used to construct db name or path
            where checkpoint files are stored.

    Returns:
        db_name: A string. The absolute path of full_db_name where checkpoint
            files are saved.
    """
    # NOTE(review): if/else/return scaffolding restored from a garbled
    # extraction; both branch bodies were visible in the fragments.
    if path_prefix:
        # Caller supplied an explicit path prefix: simple concatenation.
        db_name = path_prefix + get_ckpt_filename(node_name, epoch)
    else:
        ckpt_filename = get_ckpt_filename(node_name, epoch)
        db_name = os.path.join(db_prefix, ckpt_filename)
    return db_name
# NOTE(review): this span is a corrupted extraction of the CheckpointManager
# class -- the `class` header, most method `def` lines, control flow and the
# blob-collection logic are missing, and original file line numbers are fused
# into the text. A behavior-preserving rewrite is not possible from what is
# visible here; restore this class from version control. The comments below
# only label fragments that are still identifiable.
145 Controls saving and loading of workspaces on every epoch boundary of a job. 146 If a CheckpointManager instance is passed to JobRunner, then JobRunner will 147 call `init`, `read` and `save` at different moments in between epoch runs. 150 db_prefix: The prefix used to construct full db name. Since `absolute_path` 151 is set to True, this will be used as db_name in SaveOp. 152 node_name: Name of the node where this checkpoint_manager is used. 153 db_type: Type of database to use for storing checkpoint. 154 metadata_handler: An optional object capable of reading/writing 155 checkpoint info in storage of choice. 157 def __init__(self, db_prefix, node_name, db_type, metadata_handler=None):
164 self.
_blob_names = self._net.AddExternalInput(
'blob_names')
# fragment: CheckpointManager.init (docstring and body pieces)
170 Initialize the checkpoint manager. Determines all blobs that need to be saved 171 or loads from a checkpoint. 174 nodes: An array of nodes where this checkpoint manager is running. Should 175 only contain a single node. 176 retrieve_from_epoch: Set to a number to load blobs from this epoch. 177 path_prefix: Used to construct db name or path where checkpoint files are 179 path_type: Indicate the type of path where checkpoint files are stored. 184 retrieve_from_epoch=
None,
189 Build a Task that will be run once after the job's `init_group` is run. 190 This task will determine which blobs need to be checkpointed. 191 If retrieve_from_epoch is not None, then the checkpoint metadata is 192 retrieved from a previously saved checkpoint. 194 assert nodes
is None or len(nodes) == 1, (
195 'CheckpointManager only supports single node.')
198 if retrieve_from_epoch
is None:
202 include_shared=
False)
204 full_db_name = db_name(retrieve_from_epoch,
206 db_type = path_type
or self.
_db_type 207 logger.info(
"Initializing checkpoints from = %s" 219 return self._names_output.fetch().tolist()
# fragment: CheckpointManager.load
221 def load(self, epoch, path_prefix=None, path_type=None):
223 Build a Task that will be run by JobRunner when the job is to be 224 resumed from a given epoch. This task will run a Load op that will 225 load and deserialize all relevant blobs from a persistent storage. 228 db_type = path_type
or self.
_db_type 229 logger.info(
"Loading checkpoints from = %s" % full_db_name)
241 Builds a Task that loads only the necessary blobs from a checkpoint of 242 the given epoch. The necessary blobs are given in the blob_names 246 blob_names: A list of strings. Each string is the name of a 248 epoch: The checkpoint epoch to load from. 251 A Task which loads the specified blobs from the checkpoint of the 262 allow_incomplete=
True)
# fragment: CheckpointManager.check_db_exists
265 def check_db_exists(self, epoch):
266 logger.info(
'Check existence of %s' %
269 existence = ops.Const(
False)
276 task.add_output(existence)
281 Build a Task that is run once after `init_group` and after each 282 epoch is run. This will execute a Save ops to serialize and persist 283 blobs present in the global workspace. 290 db_type=self.
_db_type, absolute_path=
True)
295 Write metadata for checkpoint 298 epoch: An integer. The epoch-id for which checkpoint metadata is 302 self._metadata_handler.write(epoch=epoch)
306 Identify the epoch-id from which Job must resume 309 user_epoch: An integer. Optional parameter for user to explicitly 310 identify the epoch-id to load checkpoint from 312 epoch: the epoch-id to load checkpoints from 313 or None if no checkpoints were written 315 last_epoch = user_epoch
317 last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
# fragment: CheckpointManager.set_params
320 def set_params(self, nodes, path_prefix=None, path_type=None):
321 """Set parameters associated with CP manager 324 nodes: An array of nodes where this checkpoint manager is running. 325 path_prefix: Used to construct db name or path where checkpoint files are 327 path_type: Indicate the type of path where checkpoint files are stored. 334 self._metadata_handler.set_params(
342 """Returns True if Checkpoint data is accessible 345 epoch: An integer. The epoch of the checkpoint. If None, 346 it implies we need to check if checkpoint directory is accessible 349 is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible 352 return self._metadata_handler.cp_accessible(epoch)
# NOTE(review): this span is a corrupted extraction of the
# MultiNodeCheckpointManager class -- the `class` header, most method `def`
# lines and loop/conditional scaffolding are missing, and original file line
# numbers are fused into the text. A behavior-preserving rewrite is not
# possible from what is visible; restore this class from version control.
# The comments below only label fragments that are still identifiable.
359 Coordinates checkpointing and checkpointing across multiple nodes. 360 Each of `init`, `load` and `save` will build TaskGroups which will 361 trigger checkpointing on each of the nodes involved in a distributed job. 364 db_prefix: The prefix used to construct full db name. Since `absolute_path` 365 is set to True, this will be used as db_name in SaveOp. 366 db_type: Type of database to use for storing checkpoint. 367 metadata_handler: An optional object capable of reading/writing 368 checkpoint info in storage of choice. 370 def __init__(self, db_prefix, db_type, metadata_handler=None):
# fragment: MultiNodeCheckpointManager._task_group
378 def _task_group(self, func, *args, **kw):
379 assert self.
_node_managers is not None,
'init must be called first.' 380 with
TaskGroup(WorkspaceType.GLOBAL)
as task_group:
383 func(manager, *args, **kw)
388 nodes: An array of nodes where this checkpoint manager is running. 389 retrieve_from_epoch: Set to a number to load blobs from this epoch. 390 path_prefix: Used to construct db name or path where checkpoint files are 392 path_type: Indicate the type of path where checkpoint files are stored. 395 self, nodes, retrieve_from_epoch=
None, path_prefix=
None, path_type=
None 407 self._node_managers.append((node, manager))
409 CheckpointManager.init,
411 retrieve_from_epoch=retrieve_from_epoch,
412 path_prefix=path_prefix,
# fragment: MultiNodeCheckpointManager.load
415 def load(self, epoch, path_prefix=None, path_type=None):
417 CheckpointManager.load,
419 path_prefix=path_prefix,
423 """Loads the necessary blobs from the checkpoints to the current node. 426 blob_names: A list of strings. Each string is the name of a 428 epoch: An integer. The checkpoint epoch to load from. 429 session: A Session object to execute the Load ops. 441 self._node_managers.append((node, manager))
442 assert self.
_node_managers is not None,
'must initialize node managers' 444 existence_task = manager.check_db_exists(epoch)
445 session.run(existence_task)
446 existence = existence_task.outputs()[0].fetch()
448 logger.info(
'DB %s does not exist!' %
449 db_name(epoch, manager._node_name, manager._db_prefix))
451 load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
452 session.run(load_task)
453 logger.info(
'Successfully loaded from checkpoints.')
457 """Returns the DB name of the given node and the given epoch. 459 The DB name is effectively the checkpoint path of the given node and 463 node_name: A string. The node name of interest. 464 epoch: An integer. The epoch of the checkpoint. 467 checkpoint_db_name: A string. The checkpoint path of the given 468 node and the given epoch. 471 if str(node) == node_name:
472 return db_name(epoch, manager._node_name, manager._db_prefix)
476 Build a Task that will execute a Save ops to serialize and persist 477 blobs present in the global workspace. 479 return self.
_task_group(CheckpointManager.save, epoch)
483 Write metadata for checkpoint 486 epoch: An integer. The epoch-id for which checkpoint metadata is 490 self._metadata_handler.write(epoch=epoch)
494 Identify the epoch-id from which Job must resume 497 user_epoch: An integer. Optional parameter for user to explicitly 498 identify the epoch-id to load checkpoint from 500 epoch: the epoch-id to load checkpoints from 501 or None if no checkpoints were written 503 last_epoch = user_epoch
505 last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
# fragment: MultiNodeCheckpointManager.set_params
508 def set_params(self, nodes, path_prefix=None, path_type=None):
509 """Set parameters associated with CP manager 512 nodes: An array of nodes where this checkpoint manager is running. 513 path_prefix: Used to construct db name or path where checkpoint files are 515 path_type: Indicate the type of path where checkpoint files are stored. 523 self._metadata_handler.set_params(
531 """Returns True if Checkpoint data is accessible 534 epoch: An integer. The epoch of the checkpoint. If None, 535 it implies we need to check if checkpoint directory is accessible 538 is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible 541 return self._metadata_handler.cp_accessible(epoch)
# NOTE(review): class header restored from a garbled extraction; the class
# name is taken from the JobRunner docstring fragment that references
# "a subclass of the UploadTaskGroupBuilder" -- verify against upstream.
class UploadTaskGroupBuilder(object):
    """A simple class to upload checkpoints."""

    def build(self, epoch, checkpoint_manager):
        """Builds the task group to upload checkpoints.

        Args:
            epoch: An integer. The checkpoint epoch to be uploaded.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.

        Raises:
            NotImplementedError: This base class only has the interface,
                the implementation will be in the subclasses.
        """
        raise NotImplementedError()
# NOTE(review): this span is a corrupted extraction of the JobRunner class --
# the `class` header, the `__call__` def line, the epoch loop structure and
# the try/except scaffolding are missing, and original file line numbers are
# fused into the text. A behavior-preserving rewrite is not possible from
# what is visible; restore this class from version control. The comments
# below only label fragments that are still identifiable.
566 Implement the runtime logic for jobs with checkpointing at the level of 567 epoch. Can be used to run either single-host or distributed jobs. Job 568 runner is a callable to be called once from the master, passing a session 569 as an argument. This call will block until the Job execution is complete. 571 If a checkpoint_manager is passed, checkpoints will be taken after 572 initialization and after each epoch execution. If, in addition, 573 `resume_from_epoch` is an epoch number, the corresponding checkpoint will 574 be loaded and job execution will continue from the given epoch. In 575 this case, the job's init_group will not be run. 577 Refer to checkpoint_test.py for an example. 579 def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
580 upload_task_group_builder=
None):
581 """Initializes the JobRunner. 584 job: A Job object. The job to be executed. 585 checkpoint_manager: Can be a CheckpointManager for single machine 586 or a MultiNodeCheckpointManager for multi-machine. The manager 587 that initializes/saves/loads checkpoints. 588 resume_from_epoch: An integer. The epoch to resume from. 589 upload_task_group_builder: A subclass of the 590 UploadTaskGroupBuilder. Creates a task group to upload 599 """Runs the training flow. 602 session: A Session object. Valid choises are: LocalSession, 603 LocalHostScheduler, and DistributedSession. It is used to 604 execute one TaskGroup a time. 608 self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint())
617 session.run(self.job.init_group)
620 logger.info(
'Preparing checkpoints ...')
621 session.run(self.checkpoint_manager.init(
622 self.job.nodes_to_checkpoint(),
629 logger.info(
'Loading checkpoints for epoch {} ...'.format(
633 logger.info(
'Checkpoint loaded')
635 logger.info(
"Finished initializing")
# fragment: epoch loop body of JobRunner.__call__
640 logger.info(
'Starting epoch %d' % epoch)
641 session.run(self.job.epoch_group)
642 logger.info(
'Finished epoch %d' % epoch)
643 stop_signals = [o.fetch()
for o
in self.job.stop_signals]
648 if any(stop_signals):
649 logger.info(
'Stopping')
652 logger.info(
'Finished training')
655 upload_task_group = self.upload_task_group_builder.build(
657 session.run(upload_task_group)
658 logger.info(
'Finished uploading the checkpoints')
661 session.run(self.job.download_group)
662 logger.info(
'Finished downloading the parameters')
665 session.run(self.job.exit_group)
666 logger.info(
'Finished running the exit group')
# fragment: JobRunner.load_blobs_from_checkpoints (docstring + body pieces)
670 """Loads the necessary blobs from the checkpoints. 672 Checkpoints store the snapshots of the workspace in each node. 673 Sometimes we only need to load a subset of the blobs from the 674 checkpoints. One common scenario is to load only the model blobs from 675 the checkpoints for evaluation purpose. Given the names of the 676 necessary blobs, this function goes over all the checkpoints of all the 677 nodes, but only loads the blobs specified in the blob_names to the 681 blob_names: A list of strings. Each string is the name of a 683 epoch: An integer. The checkpoint epoch to load from. 684 session: A Session object to execute the load ops. 687 ValueError: When the checkpoint manager is invalid. 690 raise ValueError(
'Checkpoint manager is None')
691 logger.info(
'Loading checkpoint for epoch {} ...'.format(epoch))
692 return self.checkpoint_manager.load_blobs_locally(
693 self.job.nodes_to_checkpoint(), blob_names, epoch, session)
# fragment: JobRunner.save_checkpoints (docstring + body pieces)
696 """Triggers operation to save checkpoints 698 This method will trigger the Save ops to serialize and persist the 699 blobs present in the global workspaace. 702 epoch: An integer. The checkpoint epoch-id that we are saving. 703 session: A Session object to execute the save ops. 706 ValueError: When the checkpoint manager is invalid. 709 raise ValueError(
'Checkpoint manager is None')
711 is_accessible = self.checkpoint_manager.cp_accessible(epoch=
None)
713 logger.info(
'Saving checkpoints for epoch {}'.format(epoch))
714 session.run(self.checkpoint_manager.save(epoch))
715 self.checkpoint_manager.write_checkpoint_metadata(epoch)
716 logger.info(
'Checkpoints saved')
718 logger.warning(
"Checkpoint files cannot be accessed!")
719 except Exception
as ex:
720 logger.warning(
"Unable to write checkpoint for epoch {}. Error={}".
def epoch_limiter(job, num_epochs):
    """
    Creates a task that will output True when a given
    number of epochs has finished.
    """
    # NOTE(review): the `with job.init_group:` scaffolding and the
    # Task(step=init_net) registration were restored from a garbled
    # extraction -- verify against version control.
    with job.init_group:
        # Counter starts at num_epochs - 1 so CountDown flips to True after
        # exactly num_epochs executions of the epoch group.
        init_net = core.Net('epoch_counter_init')
        counter = init_net.CreateCounter([], init_count=num_epochs - 1)
        Task(step=init_net)

    with job.epoch_group:
        epoch_net = core.Net('epoch_countdown')
        finished = epoch_net.CountDown(counter)
        output = Task(step=epoch_net, outputs=finished).outputs()[0]
        job.add_stop_signal(output)
# NOTE(review): the lines below are not executable Python -- they are a bare
# signature index (def lines with no colon or body) appended by whatever tool
# extracted this file. They duplicate method signatures of the classes above.
# Do not treat them as code; drop them when the file is restored from
# version control.
def set_params(self, nodes, path_prefix=None, path_type=None)
def nodes_to_checkpoint(self)
def init(self, nodes=None, retrieve_from_epoch=None, path_prefix=None, path_type=None)
def load_blobs_locally(self, nodes, blob_names, epoch, session)
def load_blobs_from_checkpoints(self, blob_names, epoch, session)
def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None, upload_task_group_builder=None)
upload_task_group_builder
def write_checkpoint_metadata(self, epoch)
def _task_group(self, func, args, kw)
def build(self, epoch, checkpoint_manager)
def get_ckpt_db_name(self, node_name, epoch)
def save_checkpoints(self, epoch, session)
def get_resume_from_epoch_id(self, user_epoch=None)
def cp_accessible(self, epoch=None)
def __call__(self, session)
def load(self, epoch, path_prefix=None, path_type=None)
def load_blobs_from_checkpoint(self, blob_names, epoch)
def cp_accessible(self, epoch=None)
def set_params(self, nodes, path_prefix=None, path_type=None)
def get_resume_from_epoch_id(self, user_epoch=None)
def write_checkpoint_metadata(self, epoch)