Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
a computer's friend
Chinese-Text-Classification-Pytorch
提交
5cf89a66
C
Chinese-Text-Classification-Pytorch
项目概览
a computer's friend
/
Chinese-Text-Classification-Pytorch
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
Chinese-Text-Classification-Pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
5cf89a66
编写于
7月 11, 2019
作者:
滴水无痕0801
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
42526613
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
23 addition
and
7 deletion
+23
-7
train_eval.py
train_eval.py
+3
-6
utils_fasttext.py
utils_fasttext.py
+20
-1
未找到文件。
train_eval.py
浏览文件 @
5cf89a66
...
...
@@ -8,6 +8,7 @@ import time
from
utils
import
get_time_dif
# 权重初始化,默认xavier
def
init_network
(
model
,
method
=
'xavier'
,
exclude
=
'embedding'
,
seed
=
123
):
for
name
,
w
in
model
.
named_parameters
():
if
exclude
not
in
name
:
...
...
@@ -20,17 +21,16 @@ def init_network(model, method='xavier', exclude='embedding', seed=123):
nn
.
init
.
normal_
(
w
)
elif
'bias'
in
name
:
nn
.
init
.
constant_
(
w
,
0
)
else
:
else
:
pass
def
train
(
config
,
model
,
train_iter
,
dev_iter
,
test_iter
):
start_time
=
time
.
time
()
model
.
train
()
# criterion = nn.CrossEntropyLoss()
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
config
.
learning_rate
)
# 学习率指数衰减,每次epoch:学习率
×= gamma
# 学习率指数衰减,每次epoch:学习率
= gamma * 学习率
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
total_batch
=
0
# 记录进行到多少batch
dev_best_loss
=
float
(
'inf'
)
...
...
@@ -43,10 +43,7 @@ def train(config, model, train_iter, dev_iter, test_iter):
for
i
,
(
trains
,
labels
)
in
enumerate
(
train_iter
):
outputs
=
model
(
trains
)
model
.
zero_grad
()
# print(outputs.size())
# print(labels.size())
loss
=
F
.
cross_entropy
(
outputs
,
labels
)
# loss = criterion(outputs, labels)
loss
.
backward
()
optimizer
.
step
()
if
total_batch
%
100
==
0
:
...
...
utils_fasttext.py
浏览文件 @
5cf89a66
...
...
@@ -133,7 +133,6 @@ class DatasetIterater(object):
return
self
.
n_batches
+
1
else
:
return
self
.
n_batches
def
build_iterator
(
dataset
,
config
):
...
...
@@ -146,3 +145,23 @@ def get_time_dif(start_time):
end_time
=
time
.
time
()
time_dif
=
end_time
-
start_time
return
timedelta
(
seconds
=
int
(
round
(
time_dif
)))
if
__name__
==
"__main__"
:
'''提取预训练词向量'''
vocab_dir
=
"./THUCNews/data/vocab.pkl"
pretrain_dir
=
"./THUCNews/data/sgns.sogou.char"
emb_dim
=
300
filename_trimmed_dir
=
"./THUCNews/data/vocab.embedding.sougou"
word_to_id
=
pkl
.
load
(
open
(
vocab_dir
,
'rb'
))
embeddings
=
np
.
random
.
rand
(
len
(
word_to_id
),
emb_dim
)
f
=
open
(
pretrain_dir
,
"r"
,
encoding
=
'UTF-8'
)
for
i
,
line
in
enumerate
(
f
.
readlines
()):
# if i == 0: # 若第一行是标题,则跳过
# continue
lin
=
line
.
strip
().
split
(
" "
)
if
lin
[
0
]
in
word_to_id
:
idx
=
word_to_id
[
lin
[
0
]]
emb
=
[
float
(
x
)
for
x
in
lin
[
1
:
301
]]
embeddings
[
idx
]
=
np
.
asarray
(
emb
,
dtype
=
'float32'
)
f
.
close
()
np
.
savez_compressed
(
filename_trimmed_dir
,
embeddings
=
embeddings
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录