Caffe2 - Python API
A deep learning, cross platform ML framework
rewrite_graph.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 import copy
7 from caffe2.proto import caffe2_pb2
8 from caffe2.python import core
9 
10 
def rewrite_init_net_simple(net):
    """Set the device type of every operator in ``net`` to MKLDNN, in place.

    Args:
        net: Net protobuf whose ``op`` entries are updated.
    """
    mkldnn = caffe2_pb2.MKLDNN
    for init_op in net.op:
        init_op.device_option.device_type = mkldnn
14 
def last_producer(ops, blob):
    """Return the index of the last operator in ``ops`` that outputs ``blob``.

    Args:
        ops: Iterable of operators, each exposing an ``output`` collection
            of blob names.
        blob: Name of the blob to look for.

    Returns:
        The largest index ``i`` such that ``blob`` is in the ``output`` of
        the ``i``-th operator.

    Raises:
        ValueError: If no operator in ``ops`` produces ``blob``.
    """
    # Scan from the end so the first match is the last producer.
    for (i, op) in reversed(list(enumerate(ops))):
        if blob in op.output:
            return i
    # Interpolate the blob name into the message. Passing the format string
    # and the name as separate exception arguments leaves "%s" unformatted.
    raise ValueError(
        "Failed to find last producer of blob, {}".format(blob))
20 
21 
def rewrite_run_net_simple(net):
    """Rewrite ``net`` in place so that all of its operators use MKLDNN.

    Simple strategy: the entire graph is assumed to be executable with MKL.
    A ``CopyCPUToMKL`` operator is prepended for ``net.external_input[0]``
    and one ``CopyMKLToCPU`` operator is appended for each blob in
    ``net.external_output``. The original operators are rewired to use the
    ``__MKL__``-suffixed temporaries for those blobs.

    Args:
        net: Net protobuf to rewrite.

    Raises:
        Exception: If ``net.external_input[0]`` is not the first input of
            the first operator.
        ValueError: Propagated from ``last_producer`` when an external
            output has no producing operator.
    """
    def mkl_tmp(name):
        # Name of the MKL-side temporary for a given blob.
        return "{}__MKL__".format(name)

    def redirect(blobs, target):
        # Copy of ``blobs`` with each occurrence of ``target`` replaced by
        # its MKL temporary; other names are kept unchanged.
        return [mkl_tmp(name) if name == target else name for name in blobs]

    input_blob = net.external_input[0]
    first_op = net.op[0]
    if first_op.input[0] != input_blob:
        raise Exception(
            "Input blob: {} is not consumed by first op: {}".format(
                input_blob, first_op))

    # Feed the first operator from the MKL copy of the external input.
    mkl_input = mkl_tmp(input_blob)
    copy_input_op = core.CreateOperator(
        "CopyCPUToMKL", input_blob, mkl_input)
    first_op.input[0] = mkl_input

    # One copy-back operator per external output.
    copy_output_ops = []
    for output_blob in net.external_output:
        copy_output_ops.append(
            core.CreateOperator(
                "CopyMKLToCPU", mkl_tmp(output_blob), output_blob))

    for output_blob in net.external_output:
        producer_idx = last_producer(net.op, output_blob)
        producer = net.op[producer_idx]
        # The last producer now writes the MKL temporary instead.
        producer.output[:] = redirect(producer.output, output_blob)
        # Rename any subsequent consumers of the output blob.
        for consumer in net.op[producer_idx + 1:]:
            consumer.input[:] = redirect(consumer.input, output_blob)

    # Rebuild the operator list: input copy, original ops, output copies.
    rewritten_ops = [copy_input_op]
    rewritten_ops.extend(net.op[:])
    rewritten_ops.extend(copy_output_ops)
    del net.op[:]
    net.op.extend(rewritten_ops)

    # Place every operator on MKLDNN and clear its engine.
    mkldnn_option = core.DeviceOption(device_type=caffe2_pb2.MKLDNN)
    for op in net.op:
        op.device_option.MergeFrom(mkldnn_option)
        op.engine = ""
61 
62 
def rewrite_model_helper_simple(model):
    """Return a deep copy of ``model`` with both of its nets rewritten for MKL.

    The rewrites are applied to the copy's ``param_init_net`` and ``net``
    protobufs; the copy is what gets returned.

    Args:
        model: Model object exposing ``param_init_net`` and ``net``, each
            with a ``Proto()`` method.

    Returns:
        The rewritten deep copy of ``model``.
    """
    rewritten = copy.deepcopy(model)
    init_proto = rewritten.param_init_net.Proto()
    # All parameter initialization should run on MKL.
    rewrite_init_net_simple(init_proto)
    run_proto = rewritten.net.Proto()
    rewrite_run_net_simple(run_proto)
    return rewritten