Caffe2 - Python API
A deep learning, cross platform ML framework
device_checker.py
1 ## @package device_checker
2 # Module caffe2.python.device_checker
3 import numpy as np
4 import copy
5 from caffe2.python import workspace
6 from future.utils import viewitems
7 
8 
class DeviceChecker(object):
    """A device checker in Python to check consistency across multiple devices.

    This is not the most efficient way to check devices, as the Python interface
    will involve a lot of copies back and forth operations. Use at your own risk.

    Each check runs the same computation once per entry of ``device_options``
    inside a scratch workspace named ``_device_check_``. It then compares every
    device's results against those of the first device option with
    ``np.allclose``.
    """

    def __init__(self, threshold, device_options):
        """Create a checker.

        Args:
          threshold: tolerance used as both ``atol`` and ``rtol`` of
            ``np.allclose`` when comparing results across devices.
          device_options: the device options to run on. Each one is copied into
            the operators' ``device_option`` field via ``CopyFrom``. The
            results of the first entry are the reference that the others are
            compared to.
        """
        self._threshold = threshold
        self._device_options = device_options

    def CheckSimple(self, op, inputs, outputs_to_check,
                    input_device_options=None):
        """Checks the operator with different device implementations.

        Inputs:
          op: the operator to be checked. It is deep-copied first, so the
            caller's operator is left unmodified.
          inputs: the input data in numpy arrays; ``inputs[i]`` is fed as the
            blob named ``op.input[i]``.
          outputs_to_check: indices into ``op.output`` of the outputs to check
            between devices.
          input_device_options: a mapping from input name to a device to use
            (instead of self._device_options)
        Outputs:
          boolean: True if it passes, False if it does not pass.
        """
        op = copy.deepcopy(op)
        input_device_options = input_device_options or {}
        output_names = [op.output[idx] for idx in outputs_to_check]

        def run_once(device_option):
            # Feed every input, run the single operator on this device and
            # fetch the outputs under comparison.
            for idx, arr in enumerate(inputs):
                input_name = op.input[idx]
                workspace.FeedBlob(
                    input_name, np.array(arr),
                    input_device_options.get(input_name, device_option))
            op.device_option.CopyFrom(device_option)
            workspace.RunOperatorOnce(op)
            return [workspace.FetchBlob(name) for name in output_names]

        results = self._run_on_each_device(run_once)
        return self._compare_results(results, output_names)

    def CheckNet(self, net, inputs=None, blobs_to_check=None, ignore=None):
        """Checks a network by inspecting all of its intermediate results, and
        sees whether they match across devices.

        Args:
          net: the network to check; its ``op`` entries are iterated and it is
            passed to ``workspace.RunNetOnce``. It is deep-copied first, so
            the caller's net keeps its original device options.
          inputs: optional mapping from blob name to value. Every entry is fed
            on each device before the net runs.
          blobs_to_check: names of the blobs to compare. Defaults to every
            output of every operator in ``net``.
          ignore: optional collection of blob names excluded from the
            comparison.
        Returns:
          True if all checked blobs match across devices, False otherwise.
        """
        # Work on a private copy: the device loop below rewrites every op's
        # device_option, which must not leak back to the caller.
        net = copy.deepcopy(net)
        if inputs is None:
            inputs = {}
        if ignore is None:
            ignore = set()
        if blobs_to_check is None:
            blobs_to_check = [name for op in net.op for name in op.output]
        blobs_to_check = [b for b in blobs_to_check if b not in ignore]

        def run_once(device_option):
            # Feed the inputs, place every op on this device, run the whole
            # net and fetch the blobs under comparison.
            for name, arr in viewitems(inputs):
                workspace.FeedBlob(name, arr, device_option)
            for op in net.op:
                op.device_option.CopyFrom(device_option)
            workspace.RunNetOnce(net)
            return [workspace.FetchBlob(name) for name in blobs_to_check]

        results = self._run_on_each_device(run_once)
        return self._compare_results(results, blobs_to_check)

    def _run_on_each_device(self, run_once):
        """Call ``run_once(device_option)`` for every device option inside the
        scratch ``_device_check_`` workspace and collect the return values.

        The scratch workspace is reset after each device. The previously
        current workspace is restored even if ``run_once`` raises.

        Returns:
          A list with one entry per device option, in order.
        """
        results = []
        old_ws_name = workspace.CurrentWorkspace()
        workspace.SwitchWorkspace("_device_check_", True)
        try:
            for device_option in self._device_options:
                try:
                    results.append(run_once(device_option))
                finally:
                    # Reset between devices so blobs from one device's run
                    # are not seen by the next one.
                    workspace.ResetWorkspace()
        finally:
            workspace.SwitchWorkspace(old_ws_name)
        return results

    def _compare_results(self, results, names):
        """Compare every device's outputs against the first device's outputs.

        Args:
          results: one list of fetched arrays per device, each aligned with
            ``names``.
          names: display name of each compared output, used in failure
            messages.
        Returns:
          True if every output of every device matches the reference
          (``results[0]``) within tolerance, False otherwise. Each mismatch
          is printed together with the maximum absolute difference.
        """
        if len(results) < 2:
            return True
        reference = results[0]
        success = True
        for i, device_outputs in enumerate(results[1:], start=1):
            for name, x, y in zip(names, device_outputs, reference):
                if not np.allclose(x, y,
                                   atol=self._threshold, rtol=self._threshold):
                    print('Failure in checking device option {}'
                          ' and output {}. The outputs are:'
                          .format(i, name))
                    print(x.flatten())
                    print(y.flatten())
                    print(np.max(np.abs(x - y)))
                    success = False
        return success
def CheckSimple(self, op, inputs, outputs_to_check, input_device_options=None)
def CheckNet(self, net, inputs=None, blobs_to_check=None, ignore=None)