提交 c34507f2 编写于 作者: T Tyler Scott 提交者: A. Unique TensorFlower

No public description

PiperOrigin-RevId: 559849502
上级 564ad533
......@@ -335,6 +335,7 @@ class Pix2Seq(tf.keras.Model):
inputs: tf.Tensor,
targets: Optional[tf.Tensor] = None,
training: bool = None,
use_teacher_forcing_for_eval: bool = False
) -> List[Any]:
features = self._backbone(inputs)[self._backbone_endpoint_name]
mask = tf.ones_like(features)
......@@ -350,22 +351,18 @@ class Pix2Seq(tf.keras.Model):
pos_emb = tf.cast(pos_emb, features.dtype)
tokens = None
inputs = {
"inputs": features,
"tokens": targets,
"pos_emb": pos_emb,
}
if training:
logits = self._transformer(
{
"inputs": features,
"tokens": targets,
"pos_emb": pos_emb,
},
training,
)
logits = self._transformer(inputs, training=True)
elif use_teacher_forcing_for_eval:
logits = self._transformer(inputs, training=False)
else:
tokens, logits = self._transformer.infer(
{
"inputs": features,
"tokens": targets,
"pos_emb": pos_emb,
},
inputs,
top_k=self._top_k,
top_p=self._top_p,
)
......
......@@ -30,7 +30,11 @@ class Pix2SeqTest(tf.test.TestCase):
backbone = resnet.ResNet(50, bn_trainable=False)
backbone_endpoint_name = '5'
model = pix2seq_model.Pix2Seq(
backbone, backbone_endpoint_name, max_seq_len, vocab_size, hidden_size,
backbone,
backbone_endpoint_name,
max_seq_len,
vocab_size,
hidden_size,
num_heads=num_heads,
)
_, outs = model(
......@@ -41,6 +45,32 @@ class Pix2SeqTest(tf.test.TestCase):
self.assertLen(outs, 2) # intermediate decoded outputs.
def test_forward_infer_teacher_forcing(self):
hidden_size = 256
num_heads = 8
max_seq_len = 50
vocab_size = 164
image_size = 224
batch_size = 2
backbone = resnet.ResNet(50, bn_trainable=False)
backbone_endpoint_name = '5'
model = pix2seq_model.Pix2Seq(
backbone,
backbone_endpoint_name,
max_seq_len,
vocab_size,
hidden_size,
num_heads=num_heads,
)
_, outs = model(
tf.ones((batch_size, image_size, image_size, 3)),
tf.ones((batch_size, max_seq_len), tf.int64),
training=False,
use_teacher_forcing_for_eval=True,
)
self.assertLen(outs, 2) # intermediate decoded outputs.
def test_forward_infer(self):
hidden_size = 256
num_heads = 8
......@@ -51,7 +81,11 @@ class Pix2SeqTest(tf.test.TestCase):
backbone = resnet.ResNet(50, bn_trainable=False)
backbone_endpoint_name = '5'
model = pix2seq_model.Pix2Seq(
backbone, backbone_endpoint_name, max_seq_len, vocab_size, hidden_size,
backbone,
backbone_endpoint_name,
max_seq_len,
vocab_size,
hidden_size,
num_heads=num_heads,
)
tokens, _ = model(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册