提交 833a0157 编写于 作者: G guosheng

Use create_global_var instead of fill_constant in __init__ to make it...

Use create_global_var instead of fill_constant in __init__ to make it compatible between dygraph and static-graph.
上级 bc039c59
......@@ -113,7 +113,7 @@ def do_predict(args):
for data in data_loader():
finished_seq = model.test(inputs=flatten(data))[0]
finished_seq = finished_seq[:, :, np.newaxis] if len(
finished_seq.shape == 2) else finished_seq
finished_seq.shape) == 2 else finished_seq
finished_seq = np.transpose(finished_seq, [0, 2, 1])
for ins in finished_seq:
for beam_idx, beam in enumerate(ins):
......
......@@ -168,6 +168,7 @@ class SampleInfo(object):
def __init__(self, i, lens):
self.i = i
self.lens = lens
self.max_len = lens[0]
def get_ranges(self, min_length=None, max_length=None, truncate=False):
ranges = []
......
......@@ -247,6 +247,8 @@ class GreedyEmbeddingHelper(fluid.layers.GreedyEmbeddingHelper):
self.start_token_value = start_tokens
super(GreedyEmbeddingHelper, self).__init__(embedding_fn, start_tokens,
end_token)
self.end_token = fluid.layers.create_global_var(
shape=[1], dtype="int64", value=end_token, persistable=True)
def initialize(self, batch_ref=None):
if getattr(self, "need_convert_start_tokens", False):
......@@ -319,7 +321,7 @@ class AttentionGreedyInferModel(AttentionModel):
encoder_padding_mask = (src_mask - 1.0) * 1e9
encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1])
# dynamic decoding with beam search
# dynamic decoding with greedy search
rs, _ = self.greedy_search_decoder(
inits=decoder_initial_states,
encoder_output=encoder_output,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册