From c34507f2f2b24567289f4c08066dcdd0cfa3aa55 Mon Sep 17 00:00:00 2001
From: Tyler Scott
Date: Thu, 24 Aug 2023 13:20:38 -0700
Subject: [PATCH] No public description

PiperOrigin-RevId: 559849502
---
 .../pix2seq/modeling/pix2seq_model.py      | 23 +++++------
 .../pix2seq/modeling/pix2seq_model_test.py | 38 ++++++++++++++++++-
 2 files changed, 46 insertions(+), 15 deletions(-)

diff --git a/official/projects/pix2seq/modeling/pix2seq_model.py b/official/projects/pix2seq/modeling/pix2seq_model.py
index 0111d3579..b457f3bc7 100644
--- a/official/projects/pix2seq/modeling/pix2seq_model.py
+++ b/official/projects/pix2seq/modeling/pix2seq_model.py
@@ -335,6 +335,7 @@ class Pix2Seq(tf.keras.Model):
       inputs: tf.Tensor,
       targets: Optional[tf.Tensor] = None,
       training: bool = None,
+      use_teacher_forcing_for_eval: bool = False
   ) -> List[Any]:
     features = self._backbone(inputs)[self._backbone_endpoint_name]
     mask = tf.ones_like(features)
@@ -350,22 +351,18 @@ class Pix2Seq(tf.keras.Model):
     pos_emb = tf.cast(pos_emb, features.dtype)
 
     tokens = None
+    inputs = {
+        "inputs": features,
+        "tokens": targets,
+        "pos_emb": pos_emb,
+    }
     if training:
-      logits = self._transformer(
-          {
-              "inputs": features,
-              "tokens": targets,
-              "pos_emb": pos_emb,
-          },
-          training,
-      )
+      logits = self._transformer(inputs, training=True)
+    elif use_teacher_forcing_for_eval:
+      logits = self._transformer(inputs, training=False)
     else:
       tokens, logits = self._transformer.infer(
-          {
-              "inputs": features,
-              "tokens": targets,
-              "pos_emb": pos_emb,
-          },
+          inputs,
           top_k=self._top_k,
           top_p=self._top_p,
       )
diff --git a/official/projects/pix2seq/modeling/pix2seq_model_test.py b/official/projects/pix2seq/modeling/pix2seq_model_test.py
index ddf4050eb..784f43c32 100644
--- a/official/projects/pix2seq/modeling/pix2seq_model_test.py
+++ b/official/projects/pix2seq/modeling/pix2seq_model_test.py
@@ -30,7 +30,11 @@ class Pix2SeqTest(tf.test.TestCase):
     backbone = resnet.ResNet(50, bn_trainable=False)
     backbone_endpoint_name = '5'
     model = pix2seq_model.Pix2Seq(
-        backbone, backbone_endpoint_name, max_seq_len, vocab_size, hidden_size,
+        backbone,
+        backbone_endpoint_name,
+        max_seq_len,
+        vocab_size,
+        hidden_size,
         num_heads=num_heads,
     )
     _, outs = model(
@@ -41,6 +45,32 @@ class Pix2SeqTest(tf.test.TestCase):
 
     self.assertLen(outs, 2)  # intermediate decoded outputs.
 
+  def test_forward_infer_teacher_forcing(self):
+    hidden_size = 256
+    num_heads = 8
+    max_seq_len = 50
+    vocab_size = 164
+    image_size = 224
+    batch_size = 2
+    backbone = resnet.ResNet(50, bn_trainable=False)
+    backbone_endpoint_name = '5'
+    model = pix2seq_model.Pix2Seq(
+        backbone,
+        backbone_endpoint_name,
+        max_seq_len,
+        vocab_size,
+        hidden_size,
+        num_heads=num_heads,
+    )
+    _, outs = model(
+        tf.ones((batch_size, image_size, image_size, 3)),
+        tf.ones((batch_size, max_seq_len), tf.int64),
+        training=False,
+        use_teacher_forcing_for_eval=True,
+    )
+
+    self.assertLen(outs, 2)  # intermediate decoded outputs.
+
   def test_forward_infer(self):
     hidden_size = 256
     num_heads = 8
@@ -51,7 +81,11 @@ class Pix2SeqTest(tf.test.TestCase):
     backbone = resnet.ResNet(50, bn_trainable=False)
     backbone_endpoint_name = '5'
     model = pix2seq_model.Pix2Seq(
-        backbone, backbone_endpoint_name, max_seq_len, vocab_size, hidden_size,
+        backbone,
+        backbone_endpoint_name,
+        max_seq_len,
+        vocab_size,
+        hidden_size,
         num_heads=num_heads,
     )
     tokens, _ = model(
-- 
GitLab
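
Usage note (a sketch, not part of the patch): the snippet below shows how the
new use_teacher_forcing_for_eval flag might be driven from an evaluation step
to score a model with a validation loss. Model construction mirrors the unit
tests above; the eval batch, the import paths, and the cross-entropy scoring
are illustrative assumptions rather than APIs confirmed by this change.

import tensorflow as tf

from official.projects.pix2seq.modeling import pix2seq_model
from official.vision.modeling.backbones import resnet

hidden_size = 256
num_heads = 8
max_seq_len = 50
vocab_size = 164
backbone = resnet.ResNet(50, bn_trainable=False)
backbone_endpoint_name = '5'
model = pix2seq_model.Pix2Seq(
    backbone,
    backbone_endpoint_name,
    max_seq_len,
    vocab_size,
    hidden_size,
    num_heads=num_heads,
)

# Hypothetical eval batch; real code would pull this from a dataset.
eval_images = tf.ones((2, 224, 224, 3))
eval_targets = tf.ones((2, max_seq_len), tf.int64)

# With the flag set, the decoder consumes the ground-truth targets in a
# single forward pass instead of sampling autoregressively, while
# training=False keeps dropout and other training-only behavior disabled.
_, logits = model(
    eval_images,
    eval_targets,
    training=False,
    use_teacher_forcing_for_eval=True,
)

# Assumed scoring: per-token cross-entropy against the same targets. Whether
# the labels need shifting relative to the logits depends on how the decoder
# packs its targets internally.
eval_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=eval_targets, logits=logits
    )
)
print(eval_loss.numpy())

A side effect of the refactor is that the shared transformer input dict is
built once, so the training, teacher-forced eval, and autoregressive
inference branches differ only in how the transformer is invoked.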