Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
a computer's friend
Chinese-Text-Classification-Pytorch
提交
6f5d4856
C
Chinese-Text-Classification-Pytorch
项目概览
a computer's friend
/
Chinese-Text-Classification-Pytorch
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
C
Chinese-Text-Classification-Pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
6f5d4856
编写于
7月 15, 2019
作者:
滴水无痕0801
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
a3354345
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
96 addition
and
3 deletion
+96
-3
README.md
README.md
+4
-0
models/DPCNN.py
models/DPCNN.py
+87
-0
models/TextRNN.py
models/TextRNN.py
+1
-1
models/TextRNN_Att.py
models/TextRNN_Att.py
+2
-0
run.py
run.py
+2
-2
未找到文件。
README.md
浏览文件 @
6f5d4856
...
@@ -43,6 +43,7 @@ TextRNN|91.12%|BiLSTM
...
@@ -43,6 +43,7 @@ TextRNN|91.12%|BiLSTM
TextRNN_Att|90.90%|BiLSTM+Attention
TextRNN_Att|90.90%|BiLSTM+Attention
TextRCNN|91.54%|BiLSTM+池化
TextRCNN|91.54%|BiLSTM+池化
FastText|92.23%|bow+bigram+trigram, 效果出奇的好
FastText|92.23%|bow+bigram+trigram, 效果出奇的好
DPCNN|91.25%|深层金字塔CNN
## 使用说明
## 使用说明
```
```
...
@@ -61,6 +62,9 @@ python run.py --model TextRCNN
...
@@ -61,6 +62,9 @@ python run.py --model TextRCNN
# FastText, embedding层是随机初始化的
# FastText, embedding层是随机初始化的
python run.py --model FastText --embedding random
python run.py --model FastText --embedding random
# DPCNN
python run.py --model DPCNN
```
```
### 参数
### 参数
...
...
models/DPCNN.py
0 → 100644
浏览文件 @
6f5d4856
# coding: UTF-8
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
numpy
as
np
class Config(object):
    """Configuration parameters for the DPCNN text classifier.

    Args:
        dataset: Root directory of the dataset. ``<dataset>/data/class.txt``
            is read at construction time; the other paths are only recorded.
        embedding: File name of a pre-trained embedding archive located in
            ``<dataset>/data/`` (its ``"embeddings"`` entry is loaded), or the
            string ``'random'`` to skip loading pre-trained vectors.
    """

    def __init__(self, dataset, embedding):
        self.train_path = dataset + '/data/train.txt'    # training set
        self.dev_path = dataset + '/data/dev.txt'        # validation set
        self.test_path = dataset + '/data/test.txt'      # test set

        # Class names, one per line. The context manager closes the file
        # handle instead of leaving it to the garbage collector.
        with open(dataset + '/data/class.txt') as class_file:
            self.class_list = [line.strip() for line in class_file.readlines()]

        self.vocab_path = dataset + '/data/vocab.pkl'    # vocabulary
        # Trained weights. This used to point at 'TextCNN.ckpt', which made
        # DPCNN overwrite the checkpoint written by the TextCNN model.
        self.save_path = dataset + '/saved_dict/DPCNN.ckpt'

        # Pre-trained embedding matrix, or None for random initialisation.
        if embedding != 'random':
            self.embedding_pretrained = torch.tensor(
                np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))
        else:
            self.embedding_pretrained = None

        # Use the GPU when one is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.dropout = 0.5                        # dropout rate
        self.require_improvement = 1000           # stop early after this many batches without improvement
        self.num_classes = len(self.class_list)   # number of classes
        self.n_vocab = 0                          # vocabulary size, assigned at run time
        self.num_epochs = 20                      # number of epochs
        self.batch_size = 128                     # mini-batch size
        self.pad_size = 32                        # every sentence is padded/truncated to this length
        self.learning_rate = 1e-3                 # learning rate

        # Embedding dimension: taken from the pre-trained matrix when there is
        # one, otherwise 300.
        if self.embedding_pretrained is not None:
            self.embed = self.embedding_pretrained.size(1)
        else:
            self.embed = 300
        self.num_filters = 250                    # number of convolution filters (channels)
'''Deep Pyramid Convolutional Neural Networks for Text Categorization'''
class Model(nn.Module):
    """Deep Pyramid CNN (DPCNN) text classifier.

    A region convolution over the embedded sequence is followed by two
    same-length convolutions and then by repeated down-sampling blocks that
    halve the sequence axis until a single position remains; that vector is
    fed to a linear classifier.

    Args:
        config: Object providing ``embedding_pretrained``, ``n_vocab``,
            ``embed``, ``num_filters`` and ``num_classes``.
    """

    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            # Start from the pre-trained vectors but keep them trainable.
            self.embedding = nn.Embedding.from_pretrained(
                config.embedding_pretrained, freeze=False)
        else:
            # Randomly initialised table; the last vocabulary index is the padding index.
            self.embedding = nn.Embedding(
                config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        # Region embedding: each kernel spans 3 positions and the full embedding width.
        self.conv_region = nn.Conv2d(1, config.num_filters, (3, config.embed), stride=1)
        # Convolution over 3 positions, shared by every later layer.
        self.conv = nn.Conv2d(config.num_filters, config.num_filters, (3, 1), stride=1)
        self.max_pool = nn.MaxPool2d(kernel_size=(3, 1), stride=2)
        self.padding1 = nn.ZeroPad2d((0, 0, 1, 1))  # one row on top and one at the bottom
        self.padding2 = nn.ZeroPad2d((0, 0, 0, 1))  # one row at the bottom only
        self.relu = nn.ReLU()
        self.fc = nn.Linear(config.num_filters, config.num_classes)

    def forward(self, x):
        """Compute class scores for a batch.

        Args:
            x: Indexable input whose first element holds the token indices,
                expected as ``[batch_size, seq_len]`` with ``seq_len >= 3``;
                any further elements are ignored.

        Returns:
            Tensor of shape ``[batch_size, num_classes]``.
        """
        x = x[0]
        x = self.embedding(x)    # [batch_size, seq_len, embed]
        x = x.unsqueeze(1)       # [batch_size, 1, seq_len, embed]
        x = self.conv_region(x)  # [batch_size, num_filters, seq_len-3+1, 1]

        x = self.padding1(x)     # [batch_size, num_filters, seq_len, 1]
        x = self.relu(x)
        x = self.conv(x)         # [batch_size, num_filters, seq_len-3+1, 1]
        x = self.padding1(x)     # [batch_size, num_filters, seq_len, 1]
        x = self.relu(x)
        x = self.conv(x)         # [batch_size, num_filters, seq_len-3+1, 1]

        # Each block maps a sequence axis of length h to h // 2. Loop until it
        # is exactly 1: stopping at "<= 2" could leave two rows (e.g. h == 4),
        # which the linear layer cannot consume.
        while x.size(2) > 1:
            x = self._block(x)

        # Drop only the two trailing singleton axes. A bare squeeze() would
        # also remove the batch axis when batch_size == 1.
        x = x.squeeze(3).squeeze(2)  # [batch_size, num_filters]
        x = self.fc(x)               # [batch_size, num_classes]
        return x

    def _block(self, x):
        """Down-sample the sequence axis, then apply two convolutions with a shortcut.

        Args:
            x: Tensor of shape ``[batch_size, num_filters, h, 1]``.

        Returns:
            Tensor of shape ``[batch_size, num_filters, h // 2, 1]``.
        """
        x = self.padding2(x)
        px = self.max_pool(x)  # pooled activations, reused by the shortcut

        x = self.padding1(px)
        x = self.relu(x)
        x = self.conv(x)

        x = self.padding1(x)
        x = self.relu(x)
        x = self.conv(x)

        # Shortcut connection
        x = x + px
        return x
models/TextRNN.py
浏览文件 @
6f5d4856
...
@@ -56,7 +56,7 @@ class Model(nn.Module):
...
@@ -56,7 +56,7 @@ class Model(nn.Module):
out
=
self
.
fc
(
out
[:,
-
1
,
:])
# 句子最后时刻的 hidden state
out
=
self
.
fc
(
out
[:,
-
1
,
:])
# 句子最后时刻的 hidden state
return
out
return
out
'''变长RNN'''
'''变长RNN
,效果差不多,甚至还低了点...
'''
# def forward(self, x):
# def forward(self, x):
# x, seq_len = x
# x, seq_len = x
# out = self.embedding(x)
# out = self.embedding(x)
...
...
models/TextRNN_Att.py
浏览文件 @
6f5d4856
...
@@ -49,6 +49,7 @@ class Model(nn.Module):
...
@@ -49,6 +49,7 @@ class Model(nn.Module):
self
.
lstm
=
nn
.
LSTM
(
config
.
embed
,
config
.
hidden_size
,
config
.
num_layers
,
self
.
lstm
=
nn
.
LSTM
(
config
.
embed
,
config
.
hidden_size
,
config
.
num_layers
,
bidirectional
=
True
,
batch_first
=
True
,
dropout
=
config
.
dropout
)
bidirectional
=
True
,
batch_first
=
True
,
dropout
=
config
.
dropout
)
self
.
tanh1
=
nn
.
Tanh
()
self
.
tanh1
=
nn
.
Tanh
()
# self.u = nn.Parameter(torch.Tensor(config.hidden_size * 2, config.hidden_size * 2))
self
.
w
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
*
2
))
self
.
w
=
nn
.
Parameter
(
torch
.
Tensor
(
config
.
hidden_size
*
2
))
self
.
tanh2
=
nn
.
Tanh
()
self
.
tanh2
=
nn
.
Tanh
()
self
.
fc1
=
nn
.
Linear
(
config
.
hidden_size
*
2
,
config
.
hidden_size2
)
self
.
fc1
=
nn
.
Linear
(
config
.
hidden_size
*
2
,
config
.
hidden_size2
)
...
@@ -60,6 +61,7 @@ class Model(nn.Module):
...
@@ -60,6 +61,7 @@ class Model(nn.Module):
H
,
_
=
self
.
lstm
(
emb
)
# [batch_size, seq_len, hidden_size * num_direction]=[128, 32, 256]
H
,
_
=
self
.
lstm
(
emb
)
# [batch_size, seq_len, hidden_size * num_direction]=[128, 32, 256]
M
=
self
.
tanh1
(
H
)
# [128, 32, 256]
M
=
self
.
tanh1
(
H
)
# [128, 32, 256]
# M = torch.tanh(torch.matmul(H, self.u))
alpha
=
F
.
softmax
(
torch
.
matmul
(
M
,
self
.
w
),
dim
=
1
).
unsqueeze
(
-
1
)
# [128, 32, 1]
alpha
=
F
.
softmax
(
torch
.
matmul
(
M
,
self
.
w
),
dim
=
1
).
unsqueeze
(
-
1
)
# [128, 32, 1]
out
=
H
*
alpha
# [128, 32, 256]
out
=
H
*
alpha
# [128, 32, 256]
out
=
torch
.
sum
(
out
,
1
)
# [128, 256]
out
=
torch
.
sum
(
out
,
1
)
# [128, 256]
...
...
run.py
浏览文件 @
6f5d4856
...
@@ -7,7 +7,7 @@ from importlib import import_module
...
@@ -7,7 +7,7 @@ from importlib import import_module
import
argparse
import
argparse
parser
=
argparse
.
ArgumentParser
(
description
=
'Chinese Text Classification'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Chinese Text Classification'
)
parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att'
)
parser
.
add_argument
(
'--model'
,
type
=
str
,
required
=
True
,
help
=
'choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att
, DPCNN
'
)
parser
.
add_argument
(
'--embedding'
,
default
=
'pre_trained'
,
type
=
str
,
help
=
'random or pre_trained'
)
parser
.
add_argument
(
'--embedding'
,
default
=
'pre_trained'
,
type
=
str
,
help
=
'random or pre_trained'
)
parser
.
add_argument
(
'--word'
,
default
=
False
,
type
=
bool
,
help
=
'True for word, False for char'
)
parser
.
add_argument
(
'--word'
,
default
=
False
,
type
=
bool
,
help
=
'True for word, False for char'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -20,7 +20,7 @@ if __name__ == '__main__':
...
@@ -20,7 +20,7 @@ if __name__ == '__main__':
embedding
=
'embedding_SougouNews.npz'
embedding
=
'embedding_SougouNews.npz'
if
args
.
embedding
==
'random'
:
if
args
.
embedding
==
'random'
:
embedding
=
'random'
embedding
=
'random'
model_name
=
args
.
model
# 'TextRCNN' # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att
model_name
=
args
.
model
# 'TextRCNN' # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att
, DPCNN
if
model_name
==
'FastText'
:
if
model_name
==
'FastText'
:
from
utils_fasttext
import
build_dataset
,
build_iterator
,
get_time_dif
from
utils_fasttext
import
build_dataset
,
build_iterator
,
get_time_dif
embedding
=
'random'
embedding
=
'random'
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录