from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

from future.utils import viewitems, viewkeys, viewvalues
# Module-level logger for the device-shifting utilities below.
log = logging.getLogger("data_parallel_model_utils")
log.setLevel(logging.INFO)
def GetActivationBlobs(model):
    """
    Return the activation blob names of `model`, with the device prefix
    (e.g. "gpu_0/") stripped.

    A blob on the first device counts as an activation when it is produced
    by an op, consumed by another op, is not a parameter, and has a
    corresponding "_grad" blob consumed somewhere in the net (i.e. it
    participates in the backward pass).

    Args:
        model: a data_parallel_model-built model helper; assumed to expose
            _device_prefix, _devices, net and GetParams() -- TODO confirm.

    Returns:
        list of str: activation blob names without the device prefix.
    """
    activations = []
    # Only inspect blobs created on the first device; each device holds a
    # replica, so one device's activations are representative.
    first_gpu_prefix = "{}_{}/".format(model._device_prefix, model._devices[0])

    # Every blob that appears as an input to some operator in the net.
    all_inputs = set()
    for op in model.net.Proto().op:
        for inp in op.input:
            all_inputs.add(inp)

    params = set(model.GetParams(''))

    for op in model.net.Proto().op:
        for b in op.output:
            # Forward-pass blobs only: skip gradient blobs outright.
            if b.startswith(first_gpu_prefix) and not b.endswith("_grad"):
                # Consumed downstream, not a parameter, and has a gradient
                # blob that is itself consumed -> treat as an activation.
                if b in all_inputs and b not in params and b + "_grad" in all_inputs:
                    activations.append(stripBlobName(b))
    return activations
def _ShiftActivationDevices(model, activations, from_device, to_device):
    """
    Move the ops handling `activations` (names without the device prefix)
    from gpu `from_device` to gpu `to_device`, rewriting each op's
    device_option in place.

    Args:
        model: a data_parallel_model-built model helper.
        activations: iterable of activation blob names (no device prefix).
        from_device: int, gpu id the ops currently run on.
        to_device: int, gpu id the ops should be moved to.
    """
    prefix = "{}_{}/".format(model._device_prefix, from_device)
    activations = set([prefix + a for a in activations])
    all_activations = set([prefix + a for a in GetActivationBlobs(model)])
    ops = list(op for op in model.net.Proto().op
               if op.device_option.cuda_gpu_id == from_device)
    # Blobs we explicitly shift go to to_device; every other activation is
    # pinned to from_device so it does not get dragged along.
    device_mapping = {a: to_device for a in activations}
    device_mapping.update({b: from_device for b in all_activations
                           if b not in activations})

    # Decide each op's device from the blobs it touches: an activation blob
    # dictates the device (overriding earlier non-activation hits); if no
    # mapped blob is seen, the op keeps its current device.
    for op in ops:
        op_device = None
        for b in list(op.input) + list(op.output):
            if b in device_mapping:
                if b in all_activations or op_device is None:
                    op_device = device_mapping[b]
        if op_device is None:
            op_device = op.device_option.cuda_gpu_id
        # Propagate the decision to this op's still-unmapped blobs so that
        # downstream ops touching them land on the same device.
        for b in list(op.input) + list(op.output):
            if b not in device_mapping and b.startswith(prefix):
                device_mapping[b] = op_device
        op.device_option.cuda_gpu_id = op_device

    # Move the corresponding param-init ops to the device of the blob they
    # produce, so initialization happens where the blob now lives.
    for op in model.param_init_net.Proto().op:
        if op.output[0] in device_mapping:
            op.device_option.cuda_gpu_id = device_mapping[op.output[0]]
def ShiftActivationDevices(model, activations, shifts):
    '''
    Function to enable simple model-parallellism for data_parallel_model
    models. 'shifts' is a dictionary from_gpu -> to_gpu, and activations is
    a list of activation blobs (wout gpu_x/ prefix -- use GetActivationBlobs()).

    Operators handling these activations are shifted to the gpu declared in
    'shifts'. Also related operators such as gradient operators will be moved.
    Appropriate copy-ops are inserted.

    This allows shifting memory usage from one gpu to another, enabling bigger
    models to be trained.
    '''
    # A gpu must not appear both as a source and as a destination, or the
    # remapping would be ambiguous (chained shifts are not supported).
    assert set(viewvalues(shifts)).intersection(set(viewkeys(shifts))) == set()
    for from_device, to_device in viewitems(shifts):
        log.info(
            "Shifting {} activations from {} --> {}".
            format(len(activations), from_device, to_device)
        )
        _ShiftActivationDevices(model, activations, from_device, to_device)

    # Some ops changed devices, so insert the cross-device copy ops; the
    # blob->device mapping of the init net seeds the main-net pass.
    param_init_net, blob_to_device = core.InjectCrossDeviceCopies(model.param_init_net)
    net, _blob_to_device = core.InjectCrossDeviceCopies(model.net, blob_to_device)
    model.param_init_net = param_init_net
    # Was missing in the garbled source: without this the rewritten net is
    # computed but never installed on the model.
    model.net = net