Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
ba05e1a7
M
Models
项目概览
曾经的那一瞬间
/
Models
9 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
ba05e1a7
编写于
8月 03, 2023
作者:
C
Chaochao Yan
提交者:
A. Unique TensorFlower
8月 03, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
No public description
PiperOrigin-RevId: 553700524
上级
3e15aa4a
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
37 addition
and
10 deletion
+37
-10
official/projects/yt8m/modeling/backbones/dbof.py
official/projects/yt8m/modeling/backbones/dbof.py
+3
-2
official/projects/yt8m/modeling/backbones/dbof_test.py
official/projects/yt8m/modeling/backbones/dbof_test.py
+1
-1
official/projects/yt8m/modeling/yt8m_model.py
official/projects/yt8m/modeling/yt8m_model.py
+6
-2
official/projects/yt8m/modeling/yt8m_model_test.py
official/projects/yt8m/modeling/yt8m_model_test.py
+3
-1
official/projects/yt8m/tasks/yt8m_task.py
official/projects/yt8m/tasks/yt8m_task.py
+24
-4
未找到文件。
official/projects/yt8m/modeling/backbones/dbof.py
浏览文件 @
ba05e1a7
...
...
@@ -15,7 +15,7 @@
"""Dbof model definitions."""
import
functools
from
typing
import
Optional
from
typing
import
Any
,
Optional
import
tensorflow
as
tf
...
...
@@ -124,7 +124,7 @@ class Dbof(layers.Layer):
)
def
call
(
self
,
inputs
:
tf
.
Tensor
self
,
inputs
:
tf
.
Tensor
,
num_frames
:
Any
=
None
,
)
->
tf
.
Tensor
:
# L2 normalize input features
activation
=
tf
.
nn
.
l2_normalize
(
inputs
,
-
1
)
...
...
@@ -147,6 +147,7 @@ class Dbof(layers.Layer):
activation
=
yt8m_model_utils
.
frame_pooling
(
activation
,
method
=
self
.
_params
.
pooling_method
,
num_frames
=
num_frames
,
)
activation
=
self
.
_hidden_dense
(
activation
)
...
...
official/projects/yt8m/modeling/backbones/dbof_test.py
浏览文件 @
ba05e1a7
...
...
@@ -50,7 +50,7 @@ class DbofTest(parameterized.TestCase, tf.test.TestCase):
)
inputs
=
tf
.
ones
([
2
,
24
,
32
],
dtype
=
tf
.
float32
)
outputs
=
backbone
(
inputs
)
outputs
=
backbone
(
inputs
,
num_frames
=
tf
.
constant
([
24
,
16
])
)
self
.
assertAllEqual
(
outputs
.
shape
.
as_list
(),
[
2
,
20
])
...
...
official/projects/yt8m/modeling/yt8m_model.py
浏览文件 @
ba05e1a7
...
...
@@ -131,10 +131,14 @@ class VideoClassificationModel(tf.keras.Model):
return
cls
(
**
config
)
def
call
(
self
,
inputs_tensor
:
tf
.
Tensor
,
training
:
Any
=
None
self
,
inputs
:
tf
.
Tensor
,
num_frames
:
Any
=
None
,
training
:
Any
=
None
,
)
->
dict
[
str
,
tf
.
Tensor
]:
features
=
self
.
backbone
(
inputs_tensor
,
inputs
,
num_frames
=
num_frames
,
training
=
training
,
)
outputs
=
self
.
head
(
features
,
training
=
training
)
...
...
official/projects/yt8m/modeling/yt8m_model_test.py
浏览文件 @
ba05e1a7
...
...
@@ -55,11 +55,13 @@ class YT8MNetworkTest(parameterized.TestCase, tf.test.TestCase):
# batch = 2 -> arbitrary value for test.
if
num_sample_frames
:
inputs
=
np
.
random
.
rand
(
2
,
num_sample_frames
,
feature_dims
)
num_frames
=
tf
.
constant
([
num_sample_frames
,
num_sample_frames
])
else
:
# Add padding frames.
inputs
=
np
.
random
.
rand
(
2
,
num_frames
+
4
,
feature_dims
)
num_frames
=
tf
.
constant
([
num_frames
,
num_frames
+
1
])
predictions
=
model
(
inputs
)[
'predictions'
]
predictions
=
model
(
inputs
,
num_frames
=
num_frames
)[
'predictions'
]
self
.
assertAllEqual
([
2
,
num_classes
],
predictions
.
numpy
().
shape
)
def
test_serialize_deserialize
(
self
):
...
...
official/projects/yt8m/tasks/yt8m_task.py
浏览文件 @
ba05e1a7
...
...
@@ -52,7 +52,10 @@ class YT8MTask(base_task.Task):
)
# Warmup calls to build model variables.
_
=
model
(
tf
.
keras
.
Input
(
common_input_shape
,
dtype
=
tf
.
float32
))
_
=
model
(
inputs
=
tf
.
keras
.
Input
(
common_input_shape
,
dtype
=
tf
.
float32
),
num_frames
=
tf
.
keras
.
Input
([],
dtype
=
tf
.
float32
),
)
non_trainable_batch_norm_variables
=
[]
non_trainable_extra_variables
=
[]
...
...
@@ -242,10 +245,16 @@ class YT8MTask(base_task.Task):
def
_preprocess_model_inputs
(
self
,
inputs
:
dict
[
str
,
tf
.
Tensor
],
require_num_frames
:
bool
=
True
,
training
:
bool
=
True
,
):
"""Preprocesses input tensors before model on device."""
extra_inputs
=
{
'num_frames'
:
(
tf
.
reshape
(
inputs
[
'num_frames'
],
[
-
1
])
if
require_num_frames
else
None
),
'training'
:
training
,
}
return
inputs
[
'video_matrix'
],
extra_inputs
...
...
@@ -286,8 +295,12 @@ class YT8MTask(base_task.Task):
Returns:
a dictionary of logs.
"""
# Will require `num_frames` if `num_sample_frames` is None since
# video_matrix is padded to max_frames in this case.
require_num_frames
=
self
.
task_config
.
train_data
.
num_sample_frames
is
None
inputs_tensor
,
extra_inputs
=
self
.
_preprocess_model_inputs
(
inputs
,
require_num_frames
=
require_num_frames
,
training
=
True
,
)
labels
,
label_weights
=
self
.
_preprocess_labels
(
inputs
,
training
=
True
)
...
...
@@ -361,7 +374,14 @@ class YT8MTask(base_task.Task):
Returns:
a dictionary of logs.
"""
outputs
=
self
.
inference_step
(
model
,
inputs
)[
'predictions'
]
# Will require `num_frames` if `num_sample_frames` is None since
# video_matrix is padded to max_frames in this case.
require_num_frames
=
(
self
.
task_config
.
validation_data
.
num_sample_frames
is
None
)
outputs
=
self
.
inference_step
(
model
,
inputs
,
require_num_frames
=
require_num_frames
)[
'predictions'
]
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
labels
,
label_weights
=
self
.
_preprocess_labels
(
inputs
,
training
=
False
)
outputs
,
labels
,
label_weights
=
self
.
_postprocess_outputs
(
...
...
@@ -389,10 +409,10 @@ class YT8MTask(base_task.Task):
return
logs
def
inference_step
(
self
,
model
,
inputs
):
def
inference_step
(
self
,
model
,
inputs
,
require_num_frames
=
True
):
"""Performs the forward step."""
model_inputs
,
extra_inputs
=
self
.
_preprocess_model_inputs
(
inputs
,
training
=
False
inputs
,
require_num_frames
=
require_num_frames
,
training
=
False
)
return
model
(
model_inputs
,
**
extra_inputs
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录