提交 bae839cf 编写于 作者: A A. Unique TensorFlower

Add SWAP (self-weighted average pooling) to the YT8M open-source code base.

PiperOrigin-RevId: 552603688
上级 e11f5294
......@@ -19,6 +19,45 @@ from typing import Any, Dict, Optional, Union
import tensorflow as tf
def weighted_average_pooling(features, weights, axis):
  """Computes the weighted mean of `features` along `axis`.

  The result is sum(weights * features) / sum(weights), with both sums taken
  over `axis`.

  Args:
    features: a tensor of at least rank 1.
    weights: a weight tensor whose shape is broadcast compatible with features.
      It doesn't have to be normalized.
    axis: the dimensions to reduce.

  Returns:
    The reduced tensor.
  """
  weighted_total = tf.reduce_sum(weights * features, axis)
  weight_total = tf.reduce_sum(weights, axis)
  # divide_no_nan is used so that a zero total weight does not produce NaN.
  return tf.math.divide_no_nan(weighted_total, weight_total)
def frame_swap(frames: tf.Tensor) -> tf.Tensor:
  """Pools the frames of a video with self-weighted average pooling (SWAP).

  Every feature is pooled independently of the others. Within a feature, each
  frame is weighted by the absolute value of that feature in the frame:
    x_pooled = (sum_i x_i * |x_i|) / (sum_i |x_i|).

  Paper: https://research.google/pubs/pub48351/

  Args:
    frames: A tensor with shape [batch_size, max_frames, feature_size].

  Returns:
    A tensor with shape [batch_size, feature_size].
  """
  # Axis 1 is the max_frames dimension, i.e. the one that gets pooled away.
  return weighted_average_pooling(
      features=frames, weights=tf.abs(frames), axis=1
  )
def frame_pooling(frames, method):
"""Pools over the frames of a video.
......@@ -39,6 +78,11 @@ def frame_pooling(frames, method):
reduced = tf.reduce_mean(frames, 1)
elif method == "max":
reduced = tf.reduce_max(frames, 1)
elif method == "swap":
# Note we assume the frames are in the shape of
# [batch_size, num_frames, feature_size]. Otherwise this function might
# fail.
reduced = frame_swap(frames)
elif method == "none":
feature_size = frames.shape_as_list()[2]
reduced = tf.reshape(frames, [-1, feature_size])
......
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for YT8M modeling utilities."""
import tensorflow as tf
from official.projects.yt8m.modeling import yt8m_model_utils
class Yt8MModelUtilsTest(tf.test.TestCase):
  """Unit tests for the pooling helpers in yt8m_model_utils."""

  def test_swap_pooling(self):
    """Checks "swap" pooling on two videos of two frames each."""
    # Laid out as [batch_size=2, num_frames=2, feature_size=3]. The first frame
    # of each video is all zeros, so it carries zero weight and the pooled
    # value equals the second frame. The first feature is zero in every frame,
    # which also covers the zero-total-weight case.
    frames = tf.constant([
        [[0.0, 0.0, 0.0], [0.0, 1.0, -1.0]],
        [[0.0, 0.0, 0.0], [0.0, 2.0, -2.0]],
    ])
    expected = [[0.0, 1.0, -1.0], [0.0, 2.0, -2.0]]

    pooled = yt8m_model_utils.frame_pooling(frames, "swap")

    self.assertAllClose(expected, pooled)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
  tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册