Caffe2 - Python API
A deep learning, cross platform ML framework
SparseTransformer.py
1 ## @package SparseTransformer
2 # Module caffe2.experiments.python.SparseTransformer
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 from caffe2.python import workspace
8 import scipy.sparse
9 
10 
class NetDefNode():
    """
    One node of the operator graph.

    Attributes set by the constructor:
        name:    key under which this node is stored in its neighbours' maps
        optype:  operator type string given by the caller
        ops:     dict name -> node, the nodes that take this node as input
        prev:    dict name -> node, the nodes this node takes as input
        visited: traversal flag, False on construction
        op:      whatever was passed as `op` (None by default)
    """

    def __init__(self, name, optype, p=None, op=None):
        self.name = name
        self.optype = optype
        self.op = op
        self.visited = False
        self.prev = {}
        self.ops = {}
        # Both maps must exist before the parents are linked.
        self.insertInput(p)

    def insertInput(self, p):
        """
        Register `p` as input of this node and this node as output of `p`.

        p: a NetDefNode or a list of NetDefNode; anything else (for
           instance None) is ignored.
        """
        if isinstance(p, list):
            parents = p
        elif isinstance(p, NetDefNode):
            parents = [p]
        else:
            return
        for parent in parents:
            self.prev[parent.name] = parent
            parent.ops[self.name] = self

    def deleteInput(self, p):
        """
        Remove the link between this node and its input `p`.

        Does nothing unless `p` is a NetDefNode. Raises KeyError when the
        two nodes are not linked.
        """
        if not isinstance(p, NetDefNode):
            return
        del self.prev[p.name]
        del p.ops[self.name]
40 
41 
def maskNallocate(weight_name):
    """
    Store the CSR form of a weight blob in the workspace.

    The blob called `weight_name` is fetched and converted with
    scipy.sparse.csr_matrix. Its `data`, `indptr` and `indices` arrays are
    then fed back as the blobs `weight_name + "wcsr"`, `weight_name + "iw"`
    and `weight_name + "jw"`.

    Returns the three new blob names, in that order.
    """
    csr = scipy.sparse.csr_matrix(workspace.FetchBlob(weight_name))
    blobs = (
        (weight_name + "wcsr", csr.data),
        (weight_name + "iw", csr.indptr),
        (weight_name + "jw", csr.indices),
    )
    for blob_name, values in blobs:
        workspace.FeedBlob(blob_name, values)
    return tuple(blob_name for blob_name, _ in blobs)
56 
57 
def transFCRelu(cur, id2node, name2id, ops, model):
    """
    Add trans before and after this FC_Prune->(Relu)->FC_Prune chain.

    Walks the chain that starts at `cur`. For every FC_Prune node an
    FC_Sparse op is added to `model`, for every Relu node a Relu op is
    added, and matching NetDefNode objects (already marked visited) are
    linked into the graph. One Transpose is added in front of the chain and
    one behind it, and the first node after the chain is re-attached to the
    trailing Transpose node.

    cur:   first node of the chain; assumed to be an FC_Prune node with
           exactly one input
    id2node, name2id, ops: not used by this function
    model: object providing Transpose / FC_Sparse / Relu and
           net.Proto().op; the last entry of that op list is taken to be
           the op the preceding call created

    Raises StopIteration if `cur` has no input, or if the last FC/Relu node
    of the chain has no consumer.
    """
    # 1. add trans before the start of this chain
    # assuming that cur is a FC_Prune, and it has only one input
    pre = next(iter(cur.prev.values()))
    # Create a node /op and insert it.
    # TODO(wyiming): check whether it is correct here
    chain_input = cur.op.input[0]
    current_blob = model.Transpose(chain_input, chain_input + "_trans")
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(trans_op.output[0], "Transpose", pre, trans_op)
    trans_node.visited = True
    pre_new = trans_node

    # 2. use while loop to visit the chain
    while True:
        # breakup with the parent
        cur.deleteInput(pre)
        if cur.optype not in ("FC_Prune", "Relu"):
            print("Reaching the end of the chain")
            break
        if len(cur.ops) > 1:
            print("A FC/Relu giving more than 1 useful outputs")
        if cur.optype == "FC_Prune":
            op = cur.op
            wcsr, iw, jw = maskNallocate(op.input[1])
            bias_name = op.input[3]
            sparse_out = op.output[0] + "_Sparse"
            # TODO(wyiming): create a new Op here
            current_blob = model.FC_Sparse(current_blob, sparse_out,
                                           wcsr, iw, jw, bias_name)
            sps_op = model.net.Proto().op[-1]
            sps_node = NetDefNode(sparse_out, "FC_Sparse", pre_new, sps_op)
            sps_node.visited = True
            pre_new = sps_node
        if cur.optype == "Relu":
            # The new Relu writes to the blob it reads from.
            current_blob = model.Relu(current_blob, current_blob)
            rel_op = model.net.Proto().op[-1]
            rel_node = NetDefNode(str(current_blob), "Relu",
                                  pre_new, rel_op)
            rel_node.visited = True
            pre_new = rel_node

        cur.visited = True
        pre = cur
        # Continue with an FC_Prune / Relu consumer if there is one; when
        # several match, the last one in iteration order is followed.
        next_in_chain = None
        for child in pre.ops.values():
            if child.optype in ("Relu", "FC_Prune"):
                next_in_chain = child
        if next_in_chain is None:
            # assume that there is only 1 output that is not PrintOP
            cur = next(iter(pre.ops.values()))
            cur.deleteInput(pre)
            print("No FC/RElu children")
            print(cur.op.type)
            break
        cur = next_in_chain

    # 3. add trans after this chain like 1.
    # The trailing Transpose writes to the output blob name of the last
    # chain node.
    current_blob = model.Transpose(current_blob, pre.op.output[0])
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(str(current_blob), "Transpose", pre_new, trans_op)
    trans_node.visited = True
    cur.insertInput(trans_node)
    print(cur.prev)
    print(trans_node.ops)
128 
129 
def Prune2Sparse(cur, id2node, name2id, ops, model):
    """
    Recursively walk the graph below `cur` and rewrite FC_Prune chains.

    A node that is not yet visited and has optype "FC_Prune" is handed to
    transFCRelu. Every node reached is marked visited and the walk then
    continues into each of its consumers.

    cur: node to start from
    id2node, name2id, ops, model: passed through unchanged to transFCRelu
        and to the recursive calls
    """
    # Assume that FC and Relu takes in only 1 input;
    # If not raise warning
    if not cur.visited and cur.optype == "FC_Prune":
        transFCRelu(cur, id2node, name2id, ops, model)

    cur.visited = True
    # Visiting a child may rewire cur.ops (transFCRelu detaches the child
    # from its parent and attaches a Transpose node in its place), so the
    # dict must not be iterated live. Work on snapshots instead and rescan
    # until no consumer is left that has not been walked from here.
    # The dict values keep the handled nodes alive so ids stay unique.
    handled = {}
    while True:
        pending = [n for n in cur.ops.values() if id(n) not in handled]
        if not pending:
            break
        for n in pending:
            handled[id(n)] = n
            Prune2Sparse(n, id2node, name2id, ops, model)
139 
140 
def net2list(net_root):
    """
    Collect the ops of a net in breadth-first order, starting below the root.

    net_root: node whose `ops` map holds the first layer of nodes; the op of
        net_root itself is not included.

    Returns a list with the `op` attribute of every node dequeued. There is
    no visited check, so a node reachable through several parents is
    enqueued once per parent and its op appears once per such path.
    """
    # Local import keeps this block self-contained.
    from collections import deque

    # deque gives O(1) pops from the left.
    bfs_queue = deque(net_root.ops.values())
    op_list = []
    while bfs_queue:
        node = bfs_queue.popleft()
        op_list.append(node.op)
        bfs_queue.extend(node.ops.values())

    return op_list
158 
159 
def netbuilder(model):
    """
    Build a NetDefNode graph from the ops of model.net.Proto().

    Ops of type "Print" are skipped. Each remaining op becomes a node that
    is linked to the node of the most recent op that produced each of its
    inputs. An op with no such input is attached to a synthetic root node.

    Returns (net_root, net_name2id, net_id2node):
        net_root:    the synthetic root node
        net_name2id: blob name -> id of the last op listing it as output
        net_id2node: op id -> node
    """
    print("Welcome to model checker")
    proto = model.net.Proto()
    producer_of = {}
    nodes = {}
    net_root = NetDefNode("net_root", "root", None)

    for op_id, op in enumerate(proto.op):
        if op.type == "Print":
            continue

        if op.name:
            label = '%s/%s (op#%d)' % (op.name, op.type, op_id)
        else:
            label = '%s (op#%d)' % (op.type, op_id)
        node = NetDefNode(label, op.type, op=op)
        nodes[op_id] = node

        # assume that names not seen as an output so far are non-layers
        # TODO: write a non-layer checker and log it
        layer_inputs = [blob for blob in op.input if blob in producer_of]
        for blob in layer_inputs:
            node.insertInput(nodes[producer_of[blob]])
        if not layer_inputs:
            node.insertInput(net_root)

        for blob in op.output:
            producer_of[blob] = op_id

    return net_root, producer_of, nodes
def transFCRelu(cur, id2node, name2id, ops, model)
def net2list(net_root)
def maskNallocate(weight_name)