提交 176d8b39 编写于 作者: A Afroz Mohiuddin 提交者: GitHub

Revert "Fix decoding in prepend mode (#1726)"

This reverts commit c825d126.
上级 d381f2bc
......@@ -458,6 +458,7 @@ def universal_transformer_base():
@registry.register_hparams
def universal_transformer_base_tpu():
  """Universal Transformer base hparams adjusted for TPU training.

  Returns:
    An hparams object: the base Universal Transformer configuration, run
    through the Universal-Transformer and TPU-specific updaters, with the
    per-step timing signal turned off.
  """
  hp = universal_transformer_base()
  hp = update_hparams_for_universal_transformer(hp)
  transformer.update_hparams_for_tpu(hp)
  # Per-step timing signal is switched off for the TPU configuration.
  hp.add_step_timing_signal = False
  return hp
......@@ -466,6 +467,7 @@ def universal_transformer_base_tpu():
@registry.register_hparams
def universal_transformer_big():
  """Big Universal Transformer hparams.

  Returns:
    An hparams object: the base Universal Transformer configuration, run
    through the Universal-Transformer updater, then widened to a 2048
    hidden size and an 8192 filter size.
  """
  hp = universal_transformer_base()
  hp = update_hparams_for_universal_transformer(hp)
  # Scale the model up relative to the base configuration.
  hp.hidden_size = 2048
  hp.filter_size = 8192
  return hp
......
......@@ -863,15 +863,9 @@ class Transformer(t2t_model.T2TModel):
vocab_size = tf.shape(ret)[1]
def forced_logits():
# Workaround for: tf.one_hot(
# tf.repeat(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
# -1e9)
# Can be replaced by the above in future versions (from tf 1.15).
return tf.one_hot(
tf.reshape(tf.tile(
tf.reshape(partial_targets[:, i], [-1, 1]),
[1, beam_size]), [-1]),
vocab_size, 0.0, -1e9)
tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
-1e9)
ret = tf.cond(
tf.less(i, partial_targets_length), forced_logits, lambda: ret)
......@@ -1174,6 +1168,9 @@ def fast_decode(encoder_output,
"scores": decoding log probs from the beam search,
None if using greedy decoding (beam_size=1)
}
Raises:
NotImplementedError: If beam size > 1 with partial targets.
"""
if encoder_output is not None:
batch_size = common_layers.shape_list(encoder_output)[0]
......
......@@ -927,13 +927,6 @@ def _interactive_input_tensor_to_features_dict(feature_map, hparams):
features["decode_length"] = (
IMAGE_DECODE_LENGTH if input_is_image else inputs[1])
features["inputs"] = x
# Save inputs to "partial_targets" when prepending inputs to targets. Also
# keep "inputs" as some models crash if they don't exist.
if getattr(hparams, "prepend_mode", "none") != "none":
shape = tf.shape(x)
partial_targets = tf.reshape(x, [shape[0], shape[1]])
partial_targets = tf.pad(partial_targets, [[0, 0], [0, 1]])
features["partial_targets"] = partial_targets
return features
......@@ -964,13 +957,6 @@ def _decode_input_tensor_to_features_dict(feature_map, hparams):
features["decode_length"] = (
IMAGE_DECODE_LENGTH if input_is_image else tf.shape(x)[1] + 50)
features["inputs"] = x
# Save inputs to "partial_targets" when prepending inputs to targets. Also
# keep "inputs" as some models crash if they don't exist.
if getattr(hparams, "prepend_mode", "none") != "none":
shape = tf.shape(x)
partial_targets = tf.reshape(x, [shape[0], shape[1]])
partial_targets = tf.pad(partial_targets, [[0, 0], [0, 1]])
features["partial_targets"] = partial_targets
return features
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册