diff --git a/official/nlp/modeling/layers/masked_lm.py b/official/nlp/modeling/layers/masked_lm.py index dbd71a141f0222f5e8583fb52489c04027c3eb35..a33000b04961d773a5cdc185b258142a07b402c3 100644 --- a/official/nlp/modeling/layers/masked_lm.py +++ b/official/nlp/modeling/layers/masked_lm.py @@ -81,10 +81,16 @@ class MaskedLM(tf.keras.layers.Layer): lm_data = self.layer_norm(lm_data) lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True) logits = tf.nn.bias_add(lm_data, self.bias) - masked_positions_length = masked_positions.shape.as_list()[1] or tf.shape( - masked_positions)[1] - logits = tf.reshape(logits, - [-1, masked_positions_length, self._vocab_size]) + masked_positions_length = ( + masked_positions.shape.as_list()[1] or tf.shape(masked_positions)[1] + ) + batch_size = ( + masked_positions.shape.as_list()[0] or tf.shape(masked_positions)[0] + ) + logits = tf.reshape( + logits, + [batch_size, masked_positions_length, self._vocab_size], + ) if self._output_type == 'logits': return logits return tf.nn.log_softmax(logits) diff --git a/official/nlp/modeling/layers/masked_lm_test.py b/official/nlp/modeling/layers/masked_lm_test.py index 02c7e64ed22b8fc7ca80c3f7e6aa496ad4e74795..d4e144e396ba660fbfb8f85c5cb92307c1d48b6f 100644 --- a/official/nlp/modeling/layers/masked_lm_test.py +++ b/official/nlp/modeling/layers/masked_lm_test.py @@ -13,7 +13,7 @@ # limitations under the License. """Tests for masked language model network.""" - +from absl.testing import parameterized import numpy as np import tensorflow as tf @@ -21,7 +21,7 @@ from official.nlp.modeling.layers import masked_lm from official.nlp.modeling.networks import bert_encoder -class MaskedLMTest(tf.test.TestCase): +class MaskedLMTest(tf.test.TestCase, parameterized.TestCase): def create_layer(self, vocab_size, @@ -110,11 +110,20 @@ class MaskedLMTest(tf.test.TestCase): self.assertEqual(expected_output_shape, outputs.shape) self.assertAllClose(ref_outputs, outputs) - def test_layer_invocation(self): + @parameterized.named_parameters( + dict( + testcase_name='default', + num_predictions=21, + ), + dict( + testcase_name='zero_predictions', + num_predictions=0, + ), + ) + def test_layer_invocation(self, num_predictions): vocab_size = 100 sequence_length = 32 hidden_size = 64 - num_predictions = 21 test_layer = self.create_layer( vocab_size=vocab_size, hidden_size=hidden_size) @@ -131,7 +140,9 @@ class MaskedLMTest(tf.test.TestCase): (batch_size, sequence_length, hidden_size)) masked_position_data = np.random.randint( 2, size=(batch_size, num_predictions)) - _ = model.predict([lm_input_data, masked_position_data]) + res = model.predict([lm_input_data, masked_position_data]) + expected_shape = (batch_size, num_predictions, vocab_size) + self.assertEqual(expected_shape, res.shape) def test_unknown_output_type_fails(self): with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):