Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_38786831
tcn
提交
ec2ee103
T
tcn
项目概览
weixin_38786831
/
tcn
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tcn
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
ec2ee103
编写于
3月 08, 2019
作者:
S
Shaojie Bai
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update for pytorch 1.0 with nograd
上级
a13a6b82
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
76 addition
and
70 deletion
+76
-70
TCN/adding_problem/add_test.py
TCN/adding_problem/add_test.py
+5
-4
TCN/copy_memory/copymem_test.py
TCN/copy_memory/copymem_test.py
+9
-8
TCN/lambada_language/lambada_test.py
TCN/lambada_language/lambada_test.py
+13
-12
TCN/mnist_pixel/pmnist_test.py
TCN/mnist_pixel/pmnist_test.py
+17
-16
TCN/poly_music/music_test.py
TCN/poly_music/music_test.py
+14
-13
TCN/word_cnn/word_cnn_test.py
TCN/word_cnn/word_cnn_test.py
+18
-17
未找到文件。
TCN/adding_problem/add_test.py
浏览文件 @
ec2ee103
...
...
@@ -100,6 +100,7 @@ def train(epoch):
def
evaluate
():
model
.
eval
()
with
torch
.
no_grad
():
output
=
model
(
X_test
)
test_loss
=
F
.
mse_loss
(
output
,
Y_test
)
print
(
'
\n
Test set: Average loss: {:.6f}
\n
'
.
format
(
test_loss
.
item
()))
...
...
TCN/copy_memory/copymem_test.py
浏览文件 @
ec2ee103
...
...
@@ -85,6 +85,7 @@ optimizer = getattr(optim, args.optim)(model.parameters(), lr=lr)
def
evaluate
():
model
.
eval
()
with
torch
.
no_grad
():
out
=
model
(
test_x
.
unsqueeze
(
1
).
contiguous
())
loss
=
criterion
(
out
.
view
(
-
1
,
n_classes
),
test_y
.
view
(
-
1
))
pred
=
out
.
view
(
-
1
,
n_classes
).
data
.
max
(
1
,
keepdim
=
True
)[
1
]
...
...
TCN/lambada_language/lambada_test.py
浏览文件 @
ec2ee103
...
...
@@ -88,6 +88,7 @@ def evaluate(data_source):
total_loss
=
0
processed_data_size
=
0
correct
=
0
with
torch
.
no_grad
():
for
i
in
range
(
len
(
data_source
)):
data
,
targets
=
torch
.
LongTensor
(
data_source
[
i
]).
view
(
1
,
-
1
),
torch
.
LongTensor
([
data_source
[
i
][
-
1
]]).
view
(
1
,
-
1
)
data
,
targets
=
Variable
(
data
),
Variable
(
targets
)
...
...
TCN/mnist_pixel/pmnist_test.py
浏览文件 @
ec2ee103
...
...
@@ -97,6 +97,7 @@ def test():
model
.
eval
()
test_loss
=
0
correct
=
0
with
torch
.
no_grad
():
for
data
,
target
in
test_loader
:
if
args
.
cuda
:
data
,
target
=
data
.
cuda
(),
target
.
cuda
()
...
...
TCN/poly_music/music_test.py
浏览文件 @
ec2ee103
...
...
@@ -68,6 +68,7 @@ def evaluate(X_data, name='Eval'):
eval_idx_list
=
np
.
arange
(
len
(
X_data
),
dtype
=
"int32"
)
total_loss
=
0.0
count
=
0
with
torch
.
no_grad
():
for
idx
in
eval_idx_list
:
data_line
=
X_data
[
idx
]
x
,
y
=
Variable
(
data_line
[:
-
1
]),
Variable
(
data_line
[
1
:])
...
...
TCN/word_cnn/word_cnn_test.py
浏览文件 @
ec2ee103
...
...
@@ -92,6 +92,7 @@ def evaluate(data_source):
model
.
eval
()
total_loss
=
0
processed_data_size
=
0
with
torch
.
no_grad
():
for
i
in
range
(
0
,
data_source
.
size
(
1
)
-
1
,
args
.
validseqlen
):
if
i
+
args
.
seq_len
-
args
.
validseqlen
>=
data_source
.
size
(
1
)
-
1
:
continue
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录