from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import logging
import sys

import numpy as np

from future.utils import viewitems
# Module-level logger: INFO level, emitting to stderr so the translations the
# decoder prints on stdout stay separable from diagnostics.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stderr))
25 def _weighted_sum(model, values, weight, output_name):
26 values_weights = zip(values, [weight] * len(values))
27 values_weights_flattened = [x
for v_w
in values_weights
for x
in v_w]
28 return model.net.WeightedSum(
29 values_weights_flattened,
def scope(self, scope_name, blob_name):
    """Return `blob_name` prefixed as '<scope_name>/<blob_name>', or the
    bare `blob_name` when `scope_name` is None.

    NOTE(review): the `else` branch (orig line 40) was missing from the
    extracted fragment; returning the unprefixed name is reconstructed from
    the visible `if scope_name is not None` conditional.
    """
    return (
        scope_name + '/' + blob_name
        if scope_name is not None
        else blob_name
    )
is not None 53 attention_type = model_params[
'attention']
54 assert attention_type
in [
'none',
'regular']
55 use_attention = (attention_type !=
'none')
57 with core.NameScope(scope):
58 encoder_embeddings = seq2seq_util.build_embeddings(
61 embedding_size=model_params[
'encoder_embedding_size'],
62 name=
'encoder_embeddings',
63 freeze_embeddings=
False,
68 weighted_encoder_outputs,
69 final_encoder_hidden_states,
70 final_encoder_cell_states,
71 encoder_units_per_layer,
72 ) = seq2seq_util.build_embedding_encoder(
74 encoder_params=model_params[
'encoder_type'],
75 num_decoder_layers=len(model_params[
'decoder_layer_configs']),
79 embeddings=encoder_embeddings,
80 embedding_size=model_params[
'encoder_embedding_size'],
81 use_attention=use_attention,
86 with core.NameScope(scope):
89 encoder_outputs = model.net.Tile(
91 'encoder_outputs_tiled',
96 if weighted_encoder_outputs
is not None:
97 weighted_encoder_outputs = model.net.Tile(
98 weighted_encoder_outputs,
99 'weighted_encoder_outputs_tiled',
104 decoder_embeddings = seq2seq_util.build_embeddings(
107 embedding_size=model_params[
'decoder_embedding_size'],
108 name=
'decoder_embeddings',
109 freeze_embeddings=
False,
111 embedded_tokens_t_prev = step_model.net.Gather(
112 [decoder_embeddings, previous_tokens],
113 'embedded_tokens_t_prev',
117 decoder_units_per_layer = []
118 for i, layer_config
in enumerate(model_params[
'decoder_layer_configs']):
119 num_units = layer_config[
'num_units']
120 decoder_units_per_layer.append(num_units)
122 input_size = model_params[
'decoder_embedding_size']
125 model_params[
'decoder_layer_configs'][i - 1][
'num_units']
130 input_size=input_size,
131 hidden_size=num_units,
133 memory_optimization=
False,
135 decoder_cells.append(cell)
137 with core.NameScope(scope):
138 if final_encoder_hidden_states
is not None:
139 for i
in range(len(final_encoder_hidden_states)):
140 if final_encoder_hidden_states[i]
is not None:
141 final_encoder_hidden_states[i] = model.net.Tile(
142 final_encoder_hidden_states[i],
143 'final_encoder_hidden_tiled_{}'.format(i),
147 if final_encoder_cell_states
is not None:
148 for i
in range(len(final_encoder_cell_states)):
149 if final_encoder_cell_states[i]
is not None:
150 final_encoder_cell_states[i] = model.net.Tile(
151 final_encoder_cell_states[i],
152 'final_encoder_cell_tiled_{}'.format(i),
157 seq2seq_util.build_initial_rnn_decoder_states(
159 encoder_units_per_layer=encoder_units_per_layer,
160 decoder_units_per_layer=decoder_units_per_layer,
161 final_encoder_hidden_states=final_encoder_hidden_states,
162 final_encoder_cell_states=final_encoder_cell_states,
163 use_attention=use_attention,
167 encoder_outputs=encoder_outputs,
168 encoder_output_dim=encoder_units_per_layer[-1],
169 encoder_lengths=
None,
171 attention_type=attention_type,
172 embedding_size=model_params[
'decoder_embedding_size'],
173 decoder_num_units=decoder_units_per_layer[-1],
174 decoder_cells=decoder_cells,
175 weighted_encoder_outputs=weighted_encoder_outputs,
178 states_prev = step_model.net.AddExternalInputs(*[
179 '{}/{}_prev'.format(scope, s)
180 for s
in attention_decoder.get_state_names()
182 decoder_outputs, states = attention_decoder.apply(
184 input_t=embedded_tokens_t_prev,
185 seq_lengths=fake_seq_lengths,
191 BeamSearchForwardOnly.StateConfig(
192 initial_value=initial_state,
193 state_prev_link=BeamSearchForwardOnly.LinkConfig(
198 state_link=BeamSearchForwardOnly.LinkConfig(
204 for initial_state, state_prev, state
in zip(
211 with core.NameScope(scope):
212 decoder_outputs_flattened, _ = step_model.net.Reshape(
215 'decoder_outputs_flattened',
216 'decoder_outputs_and_contexts_combination_old_shape',
218 shape=[-1, attention_decoder.get_output_dim()],
220 output_logits = seq2seq_util.output_projection(
222 decoder_outputs=decoder_outputs_flattened,
223 decoder_output_size=attention_decoder.get_output_dim(),
225 decoder_softmax_size=model_params[
'decoder_softmax_size'],
228 output_probs = step_model.net.Softmax(
232 output_log_probs = step_model.net.Log(
237 attention_weights = attention_decoder.get_attention_weights()
239 attention_weights = step_model.net.ConstantFill(
241 'zero_attention_weights_tmp_1',
244 attention_weights = step_model.net.Transpose(
246 'zero_attention_weights_tmp_2',
248 attention_weights = step_model.net.Tile(
250 'zero_attention_weights_tmp',
def build_word_rewards(self, vocab_size, word_reward, unk_reward):
    """Build a per-token reward vector of length `vocab_size`.

    Every ordinary vocabulary token receives `word_reward`; the control
    tokens PAD/GO/EOS receive 0 so they are never artificially encouraged;
    UNK receives `word_reward + unk_reward` (per the CLI help, `unk_reward`
    is typically negative).

    Returns a float32 numpy array of shape [vocab_size].
    """
    word_rewards = np.full([vocab_size], word_reward, dtype=np.float32)
    word_rewards[seq2seq_util.PAD_ID] = 0
    word_rewards[seq2seq_util.GO_ID] = 0
    word_rewards[seq2seq_util.EOS_ID] = 0
    word_rewards[seq2seq_util.UNK_ID] = word_reward + unk_reward
    # NOTE(review): the original return statement (orig line ~267) fell
    # outside this extracted fragment; returning the array is reconstructed
    # from the function's name and the array it builds — confirm against VCS.
    return word_rewards
273 self.
models = translate_params[
'ensemble_models']
274 decoding_params = translate_params[
'decoding_params']
275 self.
beam_size = decoding_params[
'beam_size']
277 assert len(self.
models) > 0
278 source_vocab = self.
models[0][
'source_vocab']
279 target_vocab = self.
models[0][
'target_vocab']
281 assert model[
'source_vocab'] == source_vocab
282 assert model[
'target_vocab'] == target_vocab
288 'model{}'.format(i)
for i
in range(len(self.
models))
293 self.
encoder_inputs = self.model.net.AddExternalInput(
'encoder_inputs')
301 fake_seq_lengths = self.model.param_init_net.ConstantFill(
306 dtype=core.DataType.INT32,
312 go_token_id=seq2seq_util.GO_ID,
313 eos_token_id=seq2seq_util.EOS_ID,
315 step_model = beam_decoder.get_step_model()
318 output_log_probs = []
319 attention_weights = []
320 for model, scope_name
in zip(
325 state_configs_per_decoder,
326 output_log_probs_per_decoder,
327 attention_weights_per_decoder,
330 step_model=step_model,
331 model_params=model[
'model_params'],
333 previous_tokens=beam_decoder.get_previous_tokens(),
334 timestep=beam_decoder.get_timestep(),
335 fake_seq_lengths=fake_seq_lengths,
337 state_configs.extend(state_configs_per_decoder)
338 output_log_probs.append(output_log_probs_per_decoder)
339 if attention_weights_per_decoder
is not None:
340 attention_weights.append(attention_weights_per_decoder)
342 assert len(attention_weights) > 0
343 num_decoders_with_attention_blob = (
344 self.model.param_init_net.ConstantFill(
346 'num_decoders_with_attention_blob',
347 value=1 / float(len(attention_weights)),
352 attention_weights_average = _weighted_sum(
354 values=attention_weights,
355 weight=num_decoders_with_attention_blob,
356 output_name=
'attention_weights_average',
359 num_decoders_blob = self.model.param_init_net.ConstantFill(
362 value=1 / float(len(output_log_probs)),
366 output_log_probs_average = _weighted_sum(
368 values=output_log_probs,
369 weight=num_decoders_blob,
370 output_name=
'output_log_probs_average',
372 word_rewards = self.model.param_init_net.ConstantFill(
377 dtype=core.DataType.FLOAT,
380 self.output_token_beam_list,
381 self.output_prev_index_beam_list,
382 self.output_score_beam_list,
383 self.output_attention_weights_beam_list,
384 ) = beam_decoder.apply(
387 log_probs=output_log_probs_average,
388 attentions=attention_weights_average,
389 state_configs=state_configs,
390 data_dependencies=[],
391 word_rewards=word_rewards,
394 workspace.RunNetOnce(self.model.param_init_net)
399 word_reward=translate_params[
'decoding_params'][
'word_reward'],
400 unk_reward=translate_params[
'decoding_params'][
'unk_reward'],
413 logger.info(
'Params created: ')
414 for param
in self.model.params:
417 def load_models(self):
419 for model, scope_name
in zip(
423 params_for_current_model = [
425 for param
in self.model.GetAllParams()
426 if str(param).startswith(scope_name)
428 assert workspace.RunOperatorOnce(core.CreateOperator(
431 db=model[
'model_file'],
433 ),
'Failed to create db {}'.format(model[
'model_file'])
434 assert workspace.RunOperatorOnce(core.CreateOperator(
437 params_for_current_model,
439 add_prefix=scope_name +
'/',
440 strip_prefix=
'gpu_0/',
442 logger.info(
'Model {} is loaded from a checkpoint {}'.format(
447 def decode(self, numberized_input, max_output_seq_len):
451 [token_id]
for token_id
in reversed(numberized_input)
452 ]).astype(dtype=np.int32),
456 np.array([len(numberized_input)]).astype(dtype=np.int32),
460 np.array([max_output_seq_len]).astype(dtype=np.int64),
463 workspace.RunNet(self.model.net)
465 num_steps = max_output_seq_len
466 score_beam_list = workspace.FetchBlob(self.output_score_beam_list)
468 workspace.FetchBlob(self.output_token_beam_list)
470 prev_index_beam_list = (
471 workspace.FetchBlob(self.output_prev_index_beam_list)
474 attention_weights_beam_list = (
475 workspace.FetchBlob(self.output_attention_weights_beam_list)
477 best_indices = (num_steps, 0)
478 for i
in range(num_steps + 1):
482 token_beam_list[i][hyp_index][0] ==
483 seq2seq_util.EOS_ID
or 487 score_beam_list[i][hyp_index][0] >
488 score_beam_list[best_indices[0]][best_indices[1]][0]
491 best_indices = (i, hyp_index)
493 i, hyp_index = best_indices
495 attention_weights_per_token = []
496 best_score = -score_beam_list[i][hyp_index][0]
498 output.append(token_beam_list[i][hyp_index][0])
499 attention_weights_per_token.append(
500 attention_weights_beam_list[i][hyp_index]
502 hyp_index = prev_index_beam_list[i][hyp_index][0]
505 attention_weights_per_token = reversed(attention_weights_per_token)
507 attention_weights_per_token = [
508 list(reversed(attention_weights))[:len(numberized_input)]
509 for attention_weights
in attention_weights_per_token
511 output = list(reversed(output))
512 return output, attention_weights_per_token, best_score
515 def run_seq2seq_beam_decoder(args, model_params, decoding_params):
516 source_vocab = seq2seq_util.gen_vocab(
520 logger.info(
'Source vocab size {}'.format(len(source_vocab)))
521 target_vocab = seq2seq_util.gen_vocab(
525 inversed_target_vocab = {v: k
for (k, v)
in viewitems(target_vocab)}
526 logger.info(
'Target vocab size {}'.format(len(target_vocab)))
529 translate_params=dict(
530 ensemble_models=[dict(
531 source_vocab=source_vocab,
532 target_vocab=target_vocab,
533 model_params=model_params,
534 model_file=args.checkpoint,
536 decoding_params=decoding_params,
539 decoder.load_models()
541 for line
in sys.stdin:
542 numerized_source_sentence = seq2seq_util.get_numberized_sentence(
546 translation, alignment, _ = decoder.decode(
547 numerized_source_sentence,
548 2 * len(numerized_source_sentence) + 5,
550 print(
' '.join([inversed_target_vocab[tid]
for tid
in translation]))
554 parser = argparse.ArgumentParser(
555 description=
'Caffe2: Seq2Seq Translation',
557 parser.add_argument(
'--source-corpus', type=str, default=
None,
558 help=
'Path to source corpus in a text file format. Each ' 559 'line in the file should contain a single sentence',
561 parser.add_argument(
'--target-corpus', type=str, default=
None,
562 help=
'Path to target corpus in a text file format',
564 parser.add_argument(
'--unk-threshold', type=int, default=50,
565 help=
'Threshold frequency under which token becomes ' 566 'labeled unknown token')
568 parser.add_argument(
'--use-bidirectional-encoder', action=
'store_true',
569 help=
'Set flag to use bidirectional recurrent network ' 571 parser.add_argument(
'--use-attention', action=
'store_true',
572 help=
'Set flag to use seq2seq with attention model')
573 parser.add_argument(
'--encoder-cell-num-units', type=int, default=512,
574 help=
'Number of cell units per encoder layer')
575 parser.add_argument(
'--encoder-num-layers', type=int, default=2,
576 help=
'Number encoder layers')
577 parser.add_argument(
'--decoder-cell-num-units', type=int, default=512,
578 help=
'Number of cell units in the decoder layer')
579 parser.add_argument(
'--decoder-num-layers', type=int, default=2,
580 help=
'Number decoder layers')
581 parser.add_argument(
'--encoder-embedding-size', type=int, default=256,
582 help=
'Size of embedding in the encoder layer')
583 parser.add_argument(
'--decoder-embedding-size', type=int, default=512,
584 help=
'Size of embedding in the decoder layer')
585 parser.add_argument(
'--decoder-softmax-size', type=int, default=
None,
586 help=
'Size of softmax layer in the decoder')
588 parser.add_argument(
'--beam-size', type=int, default=6,
589 help=
'Size of beam for the decoder')
590 parser.add_argument(
'--word-reward', type=float, default=0.0,
591 help=
'Reward per each word generated.')
592 parser.add_argument(
'--unk-reward', type=float, default=0.0,
593 help=
'Reward per each UNK token generated. ' 594 'Typically should be negative.')
596 parser.add_argument(
'--checkpoint', type=str, default=
None,
597 help=
'Path to checkpoint', required=
True)
599 args = parser.parse_args()
601 encoder_layer_configs = [
603 num_units=args.encoder_cell_num_units,
605 ] * args.encoder_num_layers
607 if args.use_bidirectional_encoder:
608 assert args.encoder_cell_num_units % 2 == 0
609 encoder_layer_configs[0][
'num_units'] /= 2
611 decoder_layer_configs = [
613 num_units=args.decoder_cell_num_units,
615 ] * args.decoder_num_layers
617 run_seq2seq_beam_decoder(
620 attention=(
'regular' if args.use_attention
else 'none'),
621 decoder_layer_configs=decoder_layer_configs,
623 encoder_layer_configs=encoder_layer_configs,
624 use_bidirectional_encoder=args.use_bidirectional_encoder,
626 encoder_embedding_size=args.encoder_embedding_size,
627 decoder_embedding_size=args.decoder_embedding_size,
628 decoder_softmax_size=args.decoder_softmax_size,
630 decoding_params=dict(
631 beam_size=args.beam_size,
632 word_reward=args.word_reward,
633 unk_reward=args.unk_reward,
638 if __name__ ==
'__main__':
Module caffe2.python.models.seq2seq.translate.
def _build_decoder(self, model, step_model, model_params, scope, previous_tokens, timestep, fake_seq_lengths)
def build_word_rewards(self, vocab_size, word_reward, unk_reward)