3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
11 Due to a limitation in ReccurentNetworkOp, this layer only supports batch_size=1 12 In order to support batch_size > 1, we will have to implement the CRFUnit 13 and its gradient in C++ and handle the different batches there. 18 def __init__(self, model, num_classes, transitions_blob=None):
22 if not transitions_blob:
23 transitions_blob = self.model.param_init_net.UniformFill(
25 [core.ScopedBlobReference(
'crf_transitions')],
33 def crf_loss(self, predictions, labels, seq_lengths=None):
37 transitions_snapshot = self.model.net.Copy(
38 self.
transitions, core.ScopedBlobReference(
'transitions_snapshot')
49 labels, transitions_snapshot, seq_lengths
51 path_total_score = self.model.net.Add(
52 [path_binary_score, path_unary_score],
53 core.ScopedBlobReference(
'path_total')
56 zero_index = self.model.param_init_net.ConstantFill(
57 [], shape=[1], value=0
59 initial_state = self.model.net.Gather(
60 [predictions, zero_index],
61 core.ScopedBlobReference(
'rnn_initial'),
64 input_data, _ = self.model.net.RemovePadding(
70 input_data = self.model.net.ExpandDims(
72 core.ScopedBlobReference(
'rnn_input_data'),
77 transitions_copy = self.model.net.Copy(
78 transitions_snapshot, core.ScopedBlobReference(
'transitions_copy')
81 input_data, initial_state, transitions_copy
83 loss = self.model.net.Sub(
84 [all_paths_scores, path_total_score],
85 core.ScopedBlobReference(
'crf_loss')
89 def _pad_predictions(self, predictions):
103 b_scores = self.model.param_init_net.GivenTensorFill(
106 e_scores = self.model.param_init_net.GivenTensorFill(
110 zero_index = self.model.net.ConstantFill(
111 [], shape=[1, ], value=0
113 length = self.model.net.Gather(
114 [self.model.net.Shape([predictions]), zero_index],
116 length = self.model.net.Cast(length, to=
'int32')
117 t_range = self.model.net.LengthsRangeFill(length)
118 padding = self.model.net.ConstantFill([t_range], value=low_score)
119 padding = self.model.net.ExpandDims(padding, dims=[1])
120 padded_predictions, _ = self.model.net.Concat(
121 [predictions, padding, padding],
125 padded_predictions_concat, _ = self.model.net.Concat(
126 [b_scores, padded_predictions, e_scores],
130 return padded_predictions_concat
132 def _pad_labels(self, labels):
135 bos_i_b = self.model.param_init_net.ConstantFill(
136 [], shape=[1], value=bos_i
138 eos_i_b = self.model.param_init_net.ConstantFill(
139 [], shape=[1], value=eos_i
141 labels = self.model.net.Cast([labels], to=
'int64')
142 padded_labels, _ = self.model.net.Concat(
143 [bos_i_b, labels, eos_i_b],
149 def _path_binary_scores(self, labels, transitions, seq_lengths=None):
150 column_ids, _ = self.model.net.RemovePadding(
156 row_ids, _ = self.model.net.RemovePadding(
165 num_columns_blob = self.model.net.ConstantFill(
169 flattened_ids = self.model.net.Mul([row_ids, num_columns_blob])
170 flattened_ids = self.model.net.Add([flattened_ids, column_ids])
171 flattened_transitions = self.model.net.FlattenToVec([transitions])
172 entries = self.model.net.Gather(
173 [flattened_transitions, flattened_ids],
176 return self.model.ReduceFrontSum(entries)
178 def _gather_entries_sum(self, in_data, indices, index_size):
179 indices = self.model.net.Cast([indices], to=
'int64')
180 index_size_blob = self.model.param_init_net.ConstantFill(
185 query_one_hot = self.model.net.OneHot(
186 [indices, index_size_blob]
188 flattend_query = self.model.net.FlattenToVec(query_one_hot)
189 flattend_data = self.model.net.FlattenToVec(in_data)
190 query_scores = self.model.net.DotProduct(
191 [flattend_query, flattend_data]
193 final_sum = self.model.net.ReduceFrontSum([query_scores])
205 input_blob, initial_state, transitions_copy
207 out_last, _ = self.model.net.Reshape(
212 zero_segment_id = self.model.param_init_net.ConstantFill(
216 dtype=core.DataType.INT32,
220 accum_score = self.model.net.SortedSegmentRangeLogSumExp(
221 [out_last, zero_segment_id]
223 accum_score, _ = self.model.net.Reshape(
232 Adds the crf_net recurrent operator to the model. 234 model: model_helper.ModelHelper object new operators would be added 237 input_blob: the input sequence in a format T x N x D 238 where T is sequence size, N - batch size and D - input dimention 239 ##Only supports batch-size 1## 241 seq_lengths: blob containing sequence lengths (unused) 250 return "{}/{}".format(str(scope), str(name))
253 param_model=self.
model)
254 input_t, cell_t_prev, _ = (
255 step_model.net.AddExternalInputs(
256 core.ScopedBlobReference(
'input_t'),
257 core.ScopedBlobReference(
'cell_t_prev'),
261 zero_segment_id = step_model.param_init_net.ConstantFill(
263 [s(
'zero_segment_id')],
266 dtype=core.DataType.INT32,
270 step_model.param_init_net.AddExternalOutput(zero_segment_id)
273 prev_transpose = brew.transpose(
276 [s(
'prev_transpose')],
279 prev_tiled = step_model.net.Tile(
285 input_t_tiled = step_model.net.Tile(
287 [s(
'input_t_tiled')],
291 input_with_prev = step_model.net.Add(
292 [prev_tiled, input_t_tiled],
293 [s(
'input_with_prev')]
295 all_with_transitions = step_model.net.Add(
296 [input_with_prev, transitions],
297 [s(
'prev_with_transitions')],
301 all_with_transitions_reshaped, _ = step_model.net.Reshape(
302 all_with_transitions,
303 [s(
'all_with_transitions_reshaped'), s(
'all_with_transitions_orig')],
306 cell_t = step_model.net.SortedSegmentRangeLogSumExp(
307 [all_with_transitions_reshaped, zero_segment_id],
310 step_model.net.AddExternalOutputs(cell_t)
311 """ recurrent network """ 312 cell_input_blob = initial_state
313 out_all, out_last = recurrent.recurrent_net(
315 cell_net=step_model.net,
316 inputs=[(input_t, input_blob)],
317 initial_cell_inputs=[
318 (cell_t_prev, cell_input_blob),
324 outputs_with_grads=(1,)
328 def update_predictions(self, classes):
330 def crf_update_predictions_op(inputs, outputs):
334 predictions = inputs[0].data
335 transitions = inputs[1].data
336 predictions = inputs[0].data
337 predictions_shape = inputs[0].shape
338 outputs[0].reshape(predictions_shape)
340 trellis = np.zeros(predictions_shape)
341 backpointers = np.zeros(predictions_shape, dtype=np.int32)
342 trellis[0] = predictions[0]
344 for t
in range(1, predictions_shape[0]):
345 v = np.expand_dims(trellis[t - 1], 1) + transitions
346 trellis[t] = predictions[t] + np.max(v, 0)
347 backpointers[t] = np.argmax(v, 0)
349 viterbi = [np.argmax(trellis[-1])]
350 for bp
in reversed(backpointers[1:]):
351 viterbi.append(bp[viterbi[-1]])
354 new_predictions = np.zeros(predictions_shape)
356 for i, w_predictions
in enumerate(predictions):
358 new_predictions[i] = predictions[i]
359 old_best = np.argmax(w_predictions)
360 old_bests.append(old_best)
363 w_predictions[viterbi[i]], w_predictions[old_best] = \
364 w_predictions[old_best], w_predictions[viterbi[i]]
365 new_predictions[i] = w_predictions
367 orig_predictions = new_predictions[1:-1, 0:-2]
368 outputs[0].reshape(orig_predictions.shape)
369 outputs[0].data[...] = orig_predictions
371 new_classes = self.model.net.Python(crf_update_predictions_op)(
373 core.ScopedBlobReference(
'post_crf_classes')
def build_crf_net(self, input_blob, initial_state, transitions)
def _pad_predictions(self, predictions)
def _path_binary_scores(self, labels, transitions, seq_lengths=None)
def _crf_forward(self, input_blob, initial_state, transitions_copy, seq_lengths=None)
def _gather_entries_sum(self, in_data, indices, index_size)
def _pad_labels(self, labels)