Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Models
提交
c34507f2
M
Models
项目概览
曾经的那一瞬间
/
Models
9 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
c34507f2
编写于
8月 24, 2023
作者:
T
Tyler Scott
提交者:
A. Unique TensorFlower
8月 24, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
No public description
PiperOrigin-RevId: 559849502
上级
564ad533
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
46 addition
and
15 deletion
+46
-15
official/projects/pix2seq/modeling/pix2seq_model.py
official/projects/pix2seq/modeling/pix2seq_model.py
+10
-13
official/projects/pix2seq/modeling/pix2seq_model_test.py
official/projects/pix2seq/modeling/pix2seq_model_test.py
+36
-2
未找到文件。
official/projects/pix2seq/modeling/pix2seq_model.py
浏览文件 @
c34507f2
...
...
@@ -335,6 +335,7 @@ class Pix2Seq(tf.keras.Model):
inputs
:
tf
.
Tensor
,
targets
:
Optional
[
tf
.
Tensor
]
=
None
,
training
:
bool
=
None
,
use_teacher_forcing_for_eval
:
bool
=
False
)
->
List
[
Any
]:
features
=
self
.
_backbone
(
inputs
)[
self
.
_backbone_endpoint_name
]
mask
=
tf
.
ones_like
(
features
)
...
...
@@ -350,22 +351,18 @@ class Pix2Seq(tf.keras.Model):
pos_emb
=
tf
.
cast
(
pos_emb
,
features
.
dtype
)
tokens
=
None
inputs
=
{
"inputs"
:
features
,
"tokens"
:
targets
,
"pos_emb"
:
pos_emb
,
}
if
training
:
logits
=
self
.
_transformer
(
{
"inputs"
:
features
,
"tokens"
:
targets
,
"pos_emb"
:
pos_emb
,
},
training
,
)
logits
=
self
.
_transformer
(
inputs
,
training
=
True
)
elif
use_teacher_forcing_for_eval
:
logits
=
self
.
_transformer
(
inputs
,
training
=
False
)
else
:
tokens
,
logits
=
self
.
_transformer
.
infer
(
{
"inputs"
:
features
,
"tokens"
:
targets
,
"pos_emb"
:
pos_emb
,
},
inputs
,
top_k
=
self
.
_top_k
,
top_p
=
self
.
_top_p
,
)
...
...
official/projects/pix2seq/modeling/pix2seq_model_test.py
浏览文件 @
c34507f2
...
...
@@ -30,7 +30,11 @@ class Pix2SeqTest(tf.test.TestCase):
backbone
=
resnet
.
ResNet
(
50
,
bn_trainable
=
False
)
backbone_endpoint_name
=
'5'
model
=
pix2seq_model
.
Pix2Seq
(
backbone
,
backbone_endpoint_name
,
max_seq_len
,
vocab_size
,
hidden_size
,
backbone
,
backbone_endpoint_name
,
max_seq_len
,
vocab_size
,
hidden_size
,
num_heads
=
num_heads
,
)
_
,
outs
=
model
(
...
...
@@ -41,6 +45,32 @@ class Pix2SeqTest(tf.test.TestCase):
self
.
assertLen
(
outs
,
2
)
# intermediate decoded outputs.
def
test_forward_infer_teacher_forcing
(
self
):
hidden_size
=
256
num_heads
=
8
max_seq_len
=
50
vocab_size
=
164
image_size
=
224
batch_size
=
2
backbone
=
resnet
.
ResNet
(
50
,
bn_trainable
=
False
)
backbone_endpoint_name
=
'5'
model
=
pix2seq_model
.
Pix2Seq
(
backbone
,
backbone_endpoint_name
,
max_seq_len
,
vocab_size
,
hidden_size
,
num_heads
=
num_heads
,
)
_
,
outs
=
model
(
tf
.
ones
((
batch_size
,
image_size
,
image_size
,
3
)),
tf
.
ones
((
batch_size
,
max_seq_len
),
tf
.
int64
),
training
=
False
,
use_teacher_forcing_for_eval
=
True
,
)
self
.
assertLen
(
outs
,
2
)
# intermediate decoded outputs.
def
test_forward_infer
(
self
):
hidden_size
=
256
num_heads
=
8
...
...
@@ -51,7 +81,11 @@ class Pix2SeqTest(tf.test.TestCase):
backbone
=
resnet
.
ResNet
(
50
,
bn_trainable
=
False
)
backbone_endpoint_name
=
'5'
model
=
pix2seq_model
.
Pix2Seq
(
backbone
,
backbone_endpoint_name
,
max_seq_len
,
vocab_size
,
hidden_size
,
backbone
,
backbone_endpoint_name
,
max_seq_len
,
vocab_size
,
hidden_size
,
num_heads
=
num_heads
,
)
tokens
,
_
=
model
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录