3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
18 Simple benchmark that creates a data-parallel resnet-50 model 24 log = logging.getLogger(
"net_construct_bench")
25 log.setLevel(logging.DEBUG)
30 Add the momentum-SGD update. 32 params = train_model.GetParams()
33 assert(len(params) > 0)
34 ONE = train_model.param_init_net.ConstantFill(
35 [],
"ONE", shape=[1], value=1.0,
37 NEGONE = train_model.param_init_net.ConstantFill(
38 [],
'NEGONE', shape=[1], value=-1.0,
42 param_grad = train_model.param_to_grad[param]
43 param_momentum = train_model.param_init_net.ConstantFill(
44 [param], param +
'_momentum', value=0.0
48 train_model.net.MomentumSGD(
49 [param_grad, param_momentum, LR],
50 [param_grad, param_momentum],
56 train_model.WeightedSum(
57 [param, ONE, param_grad, NEGONE],
63 gpus = list(range(args.num_gpus))
64 log.info(
"Running on gpus: {}".format(gpus))
67 train_model = cnn.CNNModelHelper(
71 cudnn_exhaustive_search=
False 75 def create_resnet50_model_ops(model, loss_scale):
76 [softmax, loss] = resnet.create_resnet50(
83 model.Accuracy([softmax,
"label"],
"accuracy")
87 def add_parameter_update_ops(model):
88 model.AddWeightDecay(1e-4)
89 ITER = model.Iter(
"ITER")
91 LR = model.net.LearningRate(
101 def add_image_input(model):
104 start_time = time.time()
107 data_parallel_model.Parallelize_GPU(
109 input_builder_fun=add_image_input,
110 forward_pass_builder_fun=create_resnet50_model_ops,
111 param_update_builder_fun=add_parameter_update_ops,
115 ct = time.time() - start_time
116 train_model.net._CheckLookupTables()
118 log.info(
"Model create for {} gpus took: {} secs".format(len(gpus), ct))
123 parser = argparse.ArgumentParser(
124 description=
"Caffe2: Benchmark for net construction" 126 parser.add_argument(
"--num_gpus", type=int, default=1,
127 help=
"Number of GPUs.")
128 args = parser.parse_args()
133 if __name__ ==
'__main__':
134 workspace.GlobalInit([
'caffe2',
'--caffe2_log_level=2'])
137 cProfile.run(
'main()', sort=
"cumulative")
def AddMomentumParameterUpdate(train_model, LR)