From 833a0157325125f89918127f07b233179050f061 Mon Sep 17 00:00:00 2001 From: guosheng Date: Mon, 20 Apr 2020 21:39:01 +0800 Subject: [PATCH] Use create_global_var instead of fill_constant in __init__ to make it compatible between dygraph and static-graph. --- seq2seq/predict.py | 2 +- seq2seq/reader.py | 1 + seq2seq/seq2seq_attn.py | 4 +++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/seq2seq/predict.py b/seq2seq/predict.py index c51eed2..c9120bf 100644 --- a/seq2seq/predict.py +++ b/seq2seq/predict.py @@ -113,7 +113,7 @@ def do_predict(args): for data in data_loader(): finished_seq = model.test(inputs=flatten(data))[0] finished_seq = finished_seq[:, :, np.newaxis] if len( - finished_seq.shape == 2) else finished_seq + finished_seq.shape) == 2 else finished_seq finished_seq = np.transpose(finished_seq, [0, 2, 1]) for ins in finished_seq: for beam_idx, beam in enumerate(ins): diff --git a/seq2seq/reader.py b/seq2seq/reader.py index a6fa73f..26f5d6a 100644 --- a/seq2seq/reader.py +++ b/seq2seq/reader.py @@ -168,6 +168,7 @@ class SampleInfo(object): def __init__(self, i, lens): self.i = i self.lens = lens + self.max_len = lens[0] def get_ranges(self, min_length=None, max_length=None, truncate=False): ranges = [] diff --git a/seq2seq/seq2seq_attn.py b/seq2seq/seq2seq_attn.py index 507c72a..136b474 100644 --- a/seq2seq/seq2seq_attn.py +++ b/seq2seq/seq2seq_attn.py @@ -247,6 +247,8 @@ class GreedyEmbeddingHelper(fluid.layers.GreedyEmbeddingHelper): self.start_token_value = start_tokens super(GreedyEmbeddingHelper, self).__init__(embedding_fn, start_tokens, end_token) + self.end_token = fluid.layers.create_global_var( + shape=[1], dtype="int64", value=end_token, persistable=True) def initialize(self, batch_ref=None): if getattr(self, "need_convert_start_tokens", False): @@ -319,7 +321,7 @@ class AttentionGreedyInferModel(AttentionModel): encoder_padding_mask = (src_mask - 1.0) * 1e9 encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1]) - # dynamic decoding with beam search + # dynamic decoding with greedy search rs, _ = self.greedy_search_decoder( inits=decoder_initial_states, encoder_output=encoder_output, -- GitLab