Caffe2 - Python API
A deep learning, cross platform ML framework
resnet50_trainer.py
1 # Module caffe2.python.examples.resnet50_trainer
2 from __future__ import absolute_import
3 from __future__ import division
4 from __future__ import print_function
5 from __future__ import unicode_literals
6 
7 import argparse
8 import logging
9 import numpy as np
10 import time
11 import os
12 
13 from caffe2.python import core, workspace, experiment_util, data_parallel_model
14 from caffe2.python import data_parallel_model_utils, dyndep, optimizer
15 from caffe2.python import timeout_guard, model_helper, brew
16 from caffe2.proto import caffe2_pb2
17 
18 import caffe2.python.models.resnet as resnet
19 from caffe2.python.modeling.initializers import Initializer, PseudoFP16Initializer
22 from caffe2.python.predictor_constants import predictor_constants as predictor_constants
23 
24 '''
25 Parallelized multi-GPU distributed trainer for Resnet 50. Can be used to train
26 on imagenet data, for example.
27 
28 To run the trainer in single-machine multi-gpu mode by setting num_shards = 1.
29 
30 To run the trainer in multi-machine multi-gpu mode with M machines,
31 run the same program on all machines, specifying num_shards = M, and
32 shard_id = a unique integer in the set [0, M-1].
33 
34 For rendezvous (the trainer processes have to know about each other),
35 you can either use a directory path that is visible to all processes
36 (e.g. NFS directory), or use a Redis instance. Use the former by
37 passing the `file_store_path` argument. Use the latter by passing the
38 `redis_host` and `redis_port` arguments.
39 '''
40 
41 logging.basicConfig()
42 log = logging.getLogger("resnet50_trainer")
43 log.setLevel(logging.DEBUG)
44 
45 dyndep.InitOpsLibrary('@/caffe2/caffe2/distributed:file_store_handler_ops')
46 dyndep.InitOpsLibrary('@/caffe2/caffe2/distributed:redis_store_handler_ops')
47 
48 
49 def AddImageInput(model, reader, batch_size, img_size, dtype, is_test):
50  '''
51  The image input operator loads image and label data from the reader and
52  applies transformations to the images (random cropping, mirroring, ...).
53  '''
54  data, label = brew.image_input(
55  model,
56  reader, ["data", "label"],
57  batch_size=batch_size,
58  output_type=dtype,
59  use_gpu_transform=True if model._device_type == 1 else False,
60  use_caffe_datum=True,
61  mean=128.,
62  std=128.,
63  scale=256,
64  crop=img_size,
65  mirror=1,
66  is_test=is_test,
67  )
68 
69  data = model.StopGradient(data, data)
70 
71 
72 def AddNullInput(model, reader, batch_size, img_size, dtype):
73  '''
74  The null input function uses a gaussian fill operator to emulate real image
75  input. A label blob is hardcoded to a single value. This is useful if you
76  want to test compute throughput or don't have a dataset available.
77  '''
78  suffix = "_fp16" if dtype == "float16" else ""
79  model.param_init_net.GaussianFill(
80  [],
81  ["data" + suffix],
82  shape=[batch_size, 3, img_size, img_size],
83  )
84  if dtype == "float16":
85  model.param_init_net.FloatToHalf("data" + suffix, "data")
86 
87  model.param_init_net.ConstantFill(
88  [],
89  ["label"],
90  shape=[batch_size],
91  value=1,
92  dtype=core.DataType.INT32,
93  )
94 
95 
96 def SaveModel(args, train_model, epoch):
97  prefix = "[]_{}".format(train_model._device_prefix, train_model._devices[0])
98  predictor_export_meta = pred_exp.PredictorExportMeta(
99  predict_net=train_model.net.Proto(),
100  parameters=data_parallel_model.GetCheckpointParams(train_model),
101  inputs=[prefix + "/data"],
102  outputs=[prefix + "/softmax"],
103  shapes={
104  prefix + "/softmax": (1, args.num_labels),
105  prefix + "/data": (args.num_channels, args.image_size, args.image_size)
106  }
107  )
108 
109  # save the train_model for the current epoch
110  model_path = "%s/%s_%d.mdl" % (
111  args.file_store_path,
112  args.save_model_name,
113  epoch,
114  )
115 
116  # set db_type to be "minidb" instead of "log_file_db", which breaks
117  # the serialization in save_to_db. Need to switch back to log_file_db
118  # after migration
119  pred_exp.save_to_db(
120  db_type="minidb",
121  db_destination=model_path,
122  predictor_export_meta=predictor_export_meta,
123  )
124 
125 
126 def LoadModel(path, model):
127  '''
128  Load pretrained model from file
129  '''
130  log.info("Loading path: {}".format(path))
131  meta_net_def = pred_exp.load_from_db(path, 'minidb')
132  init_net = core.Net(pred_utils.GetNet(
133  meta_net_def, predictor_constants.GLOBAL_INIT_NET_TYPE))
134  predict_init_net = core.Net(pred_utils.GetNet(
135  meta_net_def, predictor_constants.PREDICT_INIT_NET_TYPE))
136 
137  predict_init_net.RunAllOnGPU()
138  init_net.RunAllOnGPU()
139 
140  assert workspace.RunNetOnce(predict_init_net)
141  assert workspace.RunNetOnce(init_net)
142 
143  # Hack: fix iteration counter which is in CUDA context after load model
144  itercnt = workspace.FetchBlob("optimizer_iteration")
145  workspace.FeedBlob(
146  "optimizer_iteration",
147  itercnt,
148  device_option=core.DeviceOption(caffe2_pb2.CPU, 0)
149  )
150 
151 
152 def RunEpoch(
153  args,
154  epoch,
155  train_model,
156  test_model,
157  total_batch_size,
158  num_shards,
159  expname,
160  explog,
161 ):
162  '''
163  Run one epoch of the trainer.
164  TODO: add checkpointing here.
165  '''
166  # TODO: add loading from checkpoint
167  log.info("Starting epoch {}/{}".format(epoch, args.num_epochs))
168  epoch_iters = int(args.epoch_size / total_batch_size / num_shards)
169  for i in range(epoch_iters):
170  # This timeout is required (temporarily) since CUDA-NCCL
171  # operators might deadlock when synchronizing between GPUs.
172  timeout = 600.0 if i == 0 else 60.0
173  with timeout_guard.CompleteInTimeOrDie(timeout):
174  t1 = time.time()
175  workspace.RunNet(train_model.net.Proto().name)
176  t2 = time.time()
177  dt = t2 - t1
178 
179  fmt = "Finished iteration {}/{} of epoch {} ({:.2f} images/sec)"
180  log.info(fmt.format(i + 1, epoch_iters, epoch, total_batch_size / dt))
181  prefix = "{}_{}".format(
182  train_model._device_prefix,
183  train_model._devices[0])
184  accuracy = workspace.FetchBlob(prefix + '/accuracy')
185  loss = workspace.FetchBlob(prefix + '/loss')
186  train_fmt = "Training loss: {}, accuracy: {}"
187  log.info(train_fmt.format(loss, accuracy))
188 
189  num_images = epoch * epoch_iters * total_batch_size
190  prefix = "{}_{}".format(train_model._device_prefix, train_model._devices[0])
191  accuracy = workspace.FetchBlob(prefix + '/accuracy')
192  loss = workspace.FetchBlob(prefix + '/loss')
193  learning_rate = workspace.FetchBlob(
194  data_parallel_model.GetLearningRateBlobNames(train_model)[0]
195  )
196  test_accuracy = 0
197  if (test_model is not None):
198  # Run 100 iters of testing
199  ntests = 0
200  for _ in range(0, 100):
201  workspace.RunNet(test_model.net.Proto().name)
202  for g in test_model._devices:
203  test_accuracy += np.asscalar(workspace.FetchBlob(
204  "{}_{}".format(test_model._device_prefix, g) + '/accuracy'
205  ))
206  ntests += 1
207  test_accuracy /= ntests
208  else:
209  test_accuracy = (-1)
210 
211  explog.log(
212  input_count=num_images,
213  batch_count=(i + epoch * epoch_iters),
214  additional_values={
215  'accuracy': accuracy,
216  'loss': loss,
217  'learning_rate': learning_rate,
218  'epoch': epoch,
219  'test_accuracy': test_accuracy,
220  }
221  )
222  assert loss < 40, "Exploded gradients :("
223 
224  # TODO: add checkpointing
225  return epoch + 1
226 
227 
228 def Train(args):
229  # Either use specified device list or generate one
230  if args.gpus is not None:
231  gpus = [int(x) for x in args.gpus.split(',')]
232  num_gpus = len(gpus)
233  else:
234  gpus = list(range(args.num_gpus))
235  num_gpus = args.num_gpus
236 
237  log.info("Running on GPUs: {}".format(gpus))
238 
239  # Verify valid batch size
240  total_batch_size = args.batch_size
241  batch_per_device = total_batch_size // num_gpus
242  assert \
243  total_batch_size % num_gpus == 0, \
244  "Number of GPUs must divide batch size"
245 
246  # Round down epoch size to closest multiple of batch size across machines
247  global_batch_size = total_batch_size * args.num_shards
248  epoch_iters = int(args.epoch_size / global_batch_size)
249 
250  assert \
251  epoch_iters > 0, \
252  "Epoch size must be larger than batch size times shard count"
253 
254  args.epoch_size = epoch_iters * global_batch_size
255  log.info("Using epoch size: {}".format(args.epoch_size))
256 
257  # Create ModelHelper object
258  train_arg_scope = {
259  'order': 'NCHW',
260  'use_cudnn': True,
261  'cudnn_exhaustive_search': True,
262  'ws_nbytes_limit': (args.cudnn_workspace_limit_mb * 1024 * 1024),
263  }
264  train_model = model_helper.ModelHelper(
265  name="resnet50", arg_scope=train_arg_scope
266  )
267 
268  num_shards = args.num_shards
269  shard_id = args.shard_id
270 
271  # Expect interfaces to be comma separated.
272  # Use of multiple network interfaces is not yet complete,
273  # so simply use the first one in the list.
274  interfaces = args.distributed_interfaces.split(",")
275 
276  # Rendezvous using MPI when run with mpirun
277  if os.getenv("OMPI_COMM_WORLD_SIZE") is not None:
278  num_shards = int(os.getenv("OMPI_COMM_WORLD_SIZE", 1))
279  shard_id = int(os.getenv("OMPI_COMM_WORLD_RANK", 0))
280  if num_shards > 1:
281  rendezvous = dict(
282  kv_handler=None,
283  num_shards=num_shards,
284  shard_id=shard_id,
285  engine="GLOO",
286  transport=args.distributed_transport,
287  interface=interfaces[0],
288  mpi_rendezvous=True,
289  exit_nets=None)
290 
291  elif num_shards > 1:
292  # Create rendezvous for distributed computation
293  store_handler = "store_handler"
294  if args.redis_host is not None:
295  # Use Redis for rendezvous if Redis host is specified
296  workspace.RunOperatorOnce(
297  core.CreateOperator(
298  "RedisStoreHandlerCreate", [], [store_handler],
299  host=args.redis_host,
300  port=args.redis_port,
301  prefix=args.run_id,
302  )
303  )
304  else:
305  # Use filesystem for rendezvous otherwise
306  workspace.RunOperatorOnce(
307  core.CreateOperator(
308  "FileStoreHandlerCreate", [], [store_handler],
309  path=args.file_store_path,
310  prefix=args.run_id,
311  )
312  )
313 
314  rendezvous = dict(
315  kv_handler=store_handler,
316  shard_id=shard_id,
317  num_shards=num_shards,
318  engine="GLOO",
319  transport=args.distributed_transport,
320  interface=interfaces[0],
321  exit_nets=None)
322 
323  else:
324  rendezvous = None
325 
326  # Model building functions
327  def create_resnet50_model_ops(model, loss_scale):
328  initializer = (PseudoFP16Initializer if args.dtype == 'float16'
329  else Initializer)
330 
331  with brew.arg_scope([brew.conv, brew.fc],
332  WeightInitializer=initializer,
333  BiasInitializer=initializer,
334  enable_tensor_core=args.enable_tensor_core,
335  float16_compute=args.float16_compute):
336  pred = resnet.create_resnet50(
337  model,
338  "data",
339  num_input_channels=args.num_channels,
340  num_labels=args.num_labels,
341  no_bias=True,
342  no_loss=True,
343  )
344 
345  if args.dtype == 'float16':
346  pred = model.net.HalfToFloat(pred, pred + '_fp32')
347 
348  softmax, loss = model.SoftmaxWithLoss([pred, 'label'],
349  ['softmax', 'loss'])
350  loss = model.Scale(loss, scale=loss_scale)
351  brew.accuracy(model, [softmax, "label"], "accuracy")
352  return [loss]
353 
354  def add_optimizer(model):
355  stepsz = int(30 * args.epoch_size / total_batch_size / num_shards)
356 
357  if args.float16_compute:
358  # TODO: merge with multi-prceision optimizer
359  opt = optimizer.build_fp16_sgd(
360  model,
361  args.base_learning_rate,
362  momentum=0.9,
363  nesterov=1,
364  weight_decay=args.weight_decay, # weight decay included
365  policy="step",
366  stepsize=stepsz,
367  gamma=0.1
368  )
369  else:
370  optimizer.add_weight_decay(model, args.weight_decay)
371  opt = optimizer.build_multi_precision_sgd(
372  model,
373  args.base_learning_rate,
374  momentum=0.9,
375  nesterov=1,
376  policy="step",
377  stepsize=stepsz,
378  gamma=0.1
379  )
380  return opt
381 
382  # Define add_image_input function.
383  # Depends on the "train_data" argument.
384  # Note that the reader will be shared with between all GPUS.
385  if args.train_data == "null":
386  def add_image_input(model):
387  AddNullInput(
388  model,
389  None,
390  batch_size=batch_per_device,
391  img_size=args.image_size,
392  dtype=args.dtype,
393  )
394  else:
395  reader = train_model.CreateDB(
396  "reader",
397  db=args.train_data,
398  db_type=args.db_type,
399  num_shards=num_shards,
400  shard_id=shard_id,
401  )
402 
403  def add_image_input(model):
404  AddImageInput(
405  model,
406  reader,
407  batch_size=batch_per_device,
408  img_size=args.image_size,
409  dtype=args.dtype,
410  is_test=False,
411  )
412 
413  def add_post_sync_ops(model):
414  """Add ops applied after initial parameter sync."""
415  for param_info in model.GetOptimizationParamInfo(model.GetParams()):
416  if param_info.blob_copy is not None:
417  model.param_init_net.HalfToFloat(
418  param_info.blob,
419  param_info.blob_copy[core.DataType.FLOAT]
420  )
421 
422  # Create parallelized model
423  data_parallel_model.Parallelize(
424  train_model,
425  input_builder_fun=add_image_input,
426  forward_pass_builder_fun=create_resnet50_model_ops,
427  optimizer_builder_fun=add_optimizer,
428  post_sync_builder_fun=add_post_sync_ops,
429  devices=gpus,
430  rendezvous=rendezvous,
431  optimize_gradient_memory=False,
432  cpu_device=args.use_cpu,
433  shared_model=args.use_cpu,
434  combine_spatial_bn=args.use_cpu,
435  )
436 
437  if args.model_parallel:
438  # Shift half of the activations to another GPU
439  assert workspace.NumCudaDevices() >= 2 * args.num_gpus
440  activations = data_parallel_model_utils.GetActivationBlobs(train_model)
441  data_parallel_model_utils.ShiftActivationDevices(
442  train_model,
443  activations=activations[len(activations) // 2:],
444  shifts={g: args.num_gpus + g for g in range(args.num_gpus)},
445  )
446 
447  data_parallel_model.OptimizeGradientMemory(train_model, {}, set(), False)
448 
449  workspace.RunNetOnce(train_model.param_init_net)
450  workspace.CreateNet(train_model.net)
451 
452  # Add test model, if specified
453  test_model = None
454  if (args.test_data is not None):
455  log.info("----- Create test net ----")
456  test_arg_scope = {
457  'order': "NCHW",
458  'use_cudnn': True,
459  'cudnn_exhaustive_search': True,
460  }
461  test_model = model_helper.ModelHelper(
462  name="resnet50_test", arg_scope=test_arg_scope, init_params=False
463  )
464 
465  test_reader = test_model.CreateDB(
466  "test_reader",
467  db=args.test_data,
468  db_type=args.db_type,
469  )
470 
471  def test_input_fn(model):
472  AddImageInput(
473  model,
474  test_reader,
475  batch_size=batch_per_device,
476  img_size=args.image_size,
477  dtype=args.dtype,
478  is_test=True,
479  )
480 
481  data_parallel_model.Parallelize(
482  test_model,
483  input_builder_fun=test_input_fn,
484  forward_pass_builder_fun=create_resnet50_model_ops,
485  post_sync_builder_fun=add_post_sync_ops,
486  param_update_builder_fun=None,
487  devices=gpus,
488  cpu_device=args.use_cpu,
489  )
490  workspace.RunNetOnce(test_model.param_init_net)
491  workspace.CreateNet(test_model.net)
492 
493  epoch = 0
494  # load the pre-trained model and reset epoch
495  if args.load_model_path is not None:
496  LoadModel(args.load_model_path, train_model)
497 
498  # Sync the model params
499  data_parallel_model.FinalizeAfterCheckpoint(train_model)
500 
501  # reset epoch. load_model_path should end with *_X.mdl,
502  # where X is the epoch number
503  last_str = args.load_model_path.split('_')[-1]
504  if last_str.endswith('.mdl'):
505  epoch = int(last_str[:-4])
506  log.info("Reset epoch to {}".format(epoch))
507  else:
508  log.warning("The format of load_model_path doesn't match!")
509 
510  expname = "resnet50_gpu%d_b%d_L%d_lr%.2f_v2" % (
511  args.num_gpus,
512  total_batch_size,
513  args.num_labels,
514  args.base_learning_rate,
515  )
516 
517  explog = experiment_util.ModelTrainerLog(expname, args)
518 
519  # Run the training one epoch a time
520  while epoch < args.num_epochs:
521  epoch = RunEpoch(
522  args,
523  epoch,
524  train_model,
525  test_model,
526  total_batch_size,
527  num_shards,
528  expname,
529  explog
530  )
531 
532  # Save the model for each epoch
533  SaveModel(args, train_model, epoch)
534 
535  model_path = "%s/%s_" % (
536  args.file_store_path,
537  args.save_model_name
538  )
539  # remove the saved model from the previous epoch if it exists
540  if os.path.isfile(model_path + str(epoch - 1) + ".mdl"):
541  os.remove(model_path + str(epoch - 1) + ".mdl")
542 
543 
544 def main():
545  # TODO: use argv
546  parser = argparse.ArgumentParser(
547  description="Caffe2: Resnet-50 training"
548  )
549  parser.add_argument("--train_data", type=str, default=None, required=True,
550  help="Path to training data (or 'null' to simulate)")
551  parser.add_argument("--test_data", type=str, default=None,
552  help="Path to test data")
553  parser.add_argument("--db_type", type=str, default="lmdb",
554  help="Database type (such as lmdb or leveldb)")
555  parser.add_argument("--gpus", type=str,
556  help="Comma separated list of GPU devices to use")
557  parser.add_argument("--num_gpus", type=int, default=1,
558  help="Number of GPU devices (instead of --gpus)")
559  parser.add_argument("--model_parallel", type=bool, default=False,
560  help="Split model over 2 x num_gpus")
561  parser.add_argument("--num_channels", type=int, default=3,
562  help="Number of color channels")
563  parser.add_argument("--image_size", type=int, default=227,
564  help="Input image size (to crop to)")
565  parser.add_argument("--num_labels", type=int, default=1000,
566  help="Number of labels")
567  parser.add_argument("--batch_size", type=int, default=32,
568  help="Batch size, total over all GPUs")
569  parser.add_argument("--epoch_size", type=int, default=1500000,
570  help="Number of images/epoch, total over all machines")
571  parser.add_argument("--num_epochs", type=int, default=1000,
572  help="Num epochs.")
573  parser.add_argument("--base_learning_rate", type=float, default=0.1,
574  help="Initial learning rate.")
575  parser.add_argument("--weight_decay", type=float, default=1e-4,
576  help="Weight decay (L2 regularization)")
577  parser.add_argument("--cudnn_workspace_limit_mb", type=int, default=64,
578  help="CuDNN workspace limit in MBs")
579  parser.add_argument("--num_shards", type=int, default=1,
580  help="Number of machines in distributed run")
581  parser.add_argument("--shard_id", type=int, default=0,
582  help="Shard id.")
583  parser.add_argument("--run_id", type=str,
584  help="Unique run identifier (e.g. uuid)")
585  parser.add_argument("--redis_host", type=str,
586  help="Host of Redis server (for rendezvous)")
587  parser.add_argument("--redis_port", type=int, default=6379,
588  help="Port of Redis server (for rendezvous)")
589  parser.add_argument("--file_store_path", type=str, default="/tmp",
590  help="Path to directory to use for rendezvous")
591  parser.add_argument("--save_model_name", type=str, default="resnet50_model",
592  help="Save the trained model to a given name")
593  parser.add_argument("--load_model_path", type=str, default=None,
594  help="Load previously saved model to continue training")
595  parser.add_argument("--use_cpu", type=bool, default=False,
596  help="Use CPU instead of GPU")
597  parser.add_argument('--dtype', default='float',
598  choices=['float', 'float16'],
599  help='Data type used for training')
600  parser.add_argument('--float16_compute', action='store_true',
601  help="Use float 16 compute, if available")
602  parser.add_argument('--enable_tensor_core', action='store_true',
603  help='Enable Tensor Core math for Conv and FC ops')
604  parser.add_argument("--distributed_transport", type=str, default="tcp",
605  help="Transport to use for distributed run [tcp|ibverbs]")
606  parser.add_argument("--distributed_interfaces", type=str, default="",
607  help="Network interfaces to use for distributed run")
608 
609  args = parser.parse_args()
610 
611  Train(args)
612 
613 if __name__ == '__main__':
614  workspace.GlobalInit(['caffe2', '--caffe2_log_level=2'])
615  main()