提交 ba05e1a7 编写于 作者: C Chaochao Yan 提交者: A. Unique TensorFlower

No public description

PiperOrigin-RevId: 553700524
上级 3e15aa4a
......@@ -15,7 +15,7 @@
"""Dbof model definitions."""
import functools
from typing import Optional
from typing import Any, Optional
import tensorflow as tf
......@@ -124,7 +124,7 @@ class Dbof(layers.Layer):
)
def call(
self, inputs: tf.Tensor
self, inputs: tf.Tensor, num_frames: Any = None,
) -> tf.Tensor:
# L2 normalize input features
activation = tf.nn.l2_normalize(inputs, -1)
......@@ -147,6 +147,7 @@ class Dbof(layers.Layer):
activation = yt8m_model_utils.frame_pooling(
activation,
method=self._params.pooling_method,
num_frames=num_frames,
)
activation = self._hidden_dense(activation)
......
......@@ -50,7 +50,7 @@ class DbofTest(parameterized.TestCase, tf.test.TestCase):
)
inputs = tf.ones([2, 24, 32], dtype=tf.float32)
outputs = backbone(inputs)
outputs = backbone(inputs, num_frames=tf.constant([24, 16]))
self.assertAllEqual(outputs.shape.as_list(), [2, 20])
......
......@@ -131,10 +131,14 @@ class VideoClassificationModel(tf.keras.Model):
return cls(**config)
def call(
self, inputs_tensor: tf.Tensor, training: Any = None
self,
inputs: tf.Tensor,
num_frames: Any = None,
training: Any = None,
) -> dict[str, tf.Tensor]:
features = self.backbone(
inputs_tensor,
inputs,
num_frames=num_frames,
training=training,
)
outputs = self.head(features, training=training)
......
......@@ -55,11 +55,13 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
# batch = 2 -> arbitrary value for test.
if num_sample_frames:
inputs = np.random.rand(2, num_sample_frames, feature_dims)
num_frames = tf.constant([num_sample_frames, num_sample_frames])
else:
# Add padding frames.
inputs = np.random.rand(2, num_frames + 4, feature_dims)
num_frames = tf.constant([num_frames, num_frames + 1])
predictions = model(inputs)['predictions']
predictions = model(inputs, num_frames=num_frames)['predictions']
self.assertAllEqual([2, num_classes], predictions.numpy().shape)
def test_serialize_deserialize(self):
......
......@@ -52,7 +52,10 @@ class YT8MTask(base_task.Task):
)
# Warmup calls to build model variables.
_ = model(tf.keras.Input(common_input_shape, dtype=tf.float32))
_ = model(
inputs=tf.keras.Input(common_input_shape, dtype=tf.float32),
num_frames=tf.keras.Input([], dtype=tf.float32),
)
non_trainable_batch_norm_variables = []
non_trainable_extra_variables = []
......@@ -242,10 +245,16 @@ class YT8MTask(base_task.Task):
def _preprocess_model_inputs(
self,
inputs: dict[str, tf.Tensor],
require_num_frames: bool = True,
training: bool = True,
):
"""Preprocesses input tensors before model on device."""
extra_inputs = {
'num_frames': (
tf.reshape(inputs['num_frames'], [-1])
if require_num_frames
else None
),
'training': training,
}
return inputs['video_matrix'], extra_inputs
......@@ -286,8 +295,12 @@ class YT8MTask(base_task.Task):
Returns:
a dictionary of logs.
"""
# Will require `num_frames` if `num_sample_frames` is None since
# video_matrix is padded to max_frames in this case.
require_num_frames = self.task_config.train_data.num_sample_frames is None
inputs_tensor, extra_inputs = self._preprocess_model_inputs(
inputs,
require_num_frames=require_num_frames,
training=True,
)
labels, label_weights = self._preprocess_labels(inputs, training=True)
......@@ -361,7 +374,14 @@ class YT8MTask(base_task.Task):
Returns:
a dictionary of logs.
"""
outputs = self.inference_step(model, inputs)['predictions']
# Will require `num_frames` if `num_sample_frames` is None since
# video_matrix is padded to max_frames in this case.
require_num_frames = (
self.task_config.validation_data.num_sample_frames is None
)
outputs = self.inference_step(
model, inputs, require_num_frames=require_num_frames
)['predictions']
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
labels, label_weights = self._preprocess_labels(inputs, training=False)
outputs, labels, label_weights = self._postprocess_outputs(
......@@ -389,10 +409,10 @@ class YT8MTask(base_task.Task):
return logs
def inference_step(self, model, inputs):
def inference_step(self, model, inputs, require_num_frames=True):
"""Performs the forward step."""
model_inputs, extra_inputs = self._preprocess_model_inputs(
inputs, training=False
inputs, require_num_frames=require_num_frames, training=False
)
return model(model_inputs, **extra_inputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册