3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
13 def __init__(self, name, optype, p=None, op=None):
24 Insert input of this op 25 also maintain the output of previous op 26 p: a node or a list of node 28 if isinstance(p, list):
31 i.ops[self.
name] = self
32 elif isinstance(p, NetDefNode):
34 p.ops[self.
name] = self
36 def deleteInput(self, p):
37 if isinstance(p, NetDefNode):
44 Combine mask and weights 45 create wcsr, iw, jw, return their names 47 w = workspace.FetchBlob(weight_name)
48 w_csr = scipy.sparse.csr_matrix(w)
52 workspace.FeedBlob(weight_name +
"wcsr", wcsr)
53 workspace.FeedBlob(weight_name +
"iw", iw)
54 workspace.FeedBlob(weight_name +
"jw", jw)
55 return weight_name +
"wcsr", weight_name +
"iw", weight_name +
"jw" 60 Add trans before and after this FC_Prune->(Relu)->FC_Prune chain. 64 pre = cur.prev.itervalues().next()
67 current_blob = model.Transpose(cur.op.input[0], cur.op.input[0] +
"_trans")
69 trans_op = model.net.Proto().op[-1]
70 trans_node =
NetDefNode(trans_op.output[0],
"Transpose", pre, trans_op)
71 trans_node.visited =
True 78 if not (cur.optype ==
"FC_Prune" or cur.optype ==
"Relu"):
79 print(
"Reaching the end of the chain")
82 print(
"A FC/Relu giving more than 1 useful outputs")
83 if cur.optype ==
"FC_Prune":
86 bias_name = op.input[3]
88 current_blob = model.FC_Sparse(current_blob,
89 cur.op.output[0] +
"_Sparse",
90 wcsr, iw, jw, bias_name)
91 sps_op = model.net.Proto().op[-1]
92 sps_node =
NetDefNode(cur.op.output[0] +
"_Sparse",
95 sps_node.visited =
True 97 if cur.optype ==
"Relu":
99 current_blob = model.Relu(current_blob, current_blob)
100 rel_op = model.net.Proto().op[-1]
101 rel_node =
NetDefNode(str(current_blob),
"Relu",
103 rel_node.visited =
True 109 for _, temp
in cur.ops.iteritems():
110 if temp.optype ==
"Relu" or temp.optype ==
"FC_Prune":
115 cur = cur.ops.itervalues().next()
117 print(
"No FC/RElu children")
121 current_blob = model.Transpose(current_blob, pre.op.output[0])
122 trans_op = model.net.Proto().op[-1]
123 trans_node =
NetDefNode(str(current_blob),
"Transpose", pre_new, trans_op)
124 trans_node.visited =
True 125 cur.insertInput(trans_node)
127 print(trans_node.ops)
130 def Prune2Sparse(cur, id2node, name2id, ops, model):
133 if not cur.visited
and cur.optype ==
"FC_Prune":
137 for name, n
in cur.ops.iteritems():
138 Prune2Sparse(n, id2node, name2id, ops, model)
143 Use topological order(BFS) to print the op of a net in a list 148 for _, n
in cur.ops.iteritems():
152 bfs_queue = bfs_queue[1:]
153 op_list.append(node.op)
154 for _, n
in node.ops.iteritems():
160 def netbuilder(model):
161 print(
"Welcome to model checker")
162 proto = model.net.Proto()
165 net_root =
NetDefNode(
"net_root",
"root",
None)
167 for op_id, op
in enumerate(proto.op):
168 if op.type ==
"Print":
170 op_name =
'%s/%s (op#%d)' % (op.name, op.type, op_id) \
171 if op.name
else '%s (op#%d)' % (op.type, op_id)
174 net_id2node[op_id] = op_node
176 if_has_layer_input =
False 177 for input_name
in op.input:
178 if input_name
not in net_name2id:
182 op_node.insertInput(net_id2node[net_name2id[input_name]])
183 if_has_layer_input =
True 185 if not if_has_layer_input:
186 op_node.insertInput(net_root)
188 for output_name
in op.output:
189 net_name2id[output_name] = op_id
191 return net_root, net_name2id, net_id2node