class Config(object):
    """Hyper-parameters and file locations for the DPCNN text classifier.

    Args:
        dataset: Root directory of the dataset. All paths are built from it
            by plain string concatenation (``<dataset>/data/...`` and
            ``<dataset>/saved_dict/...``).
        embedding: File name of an ``.npz`` archive under ``<dataset>/data/``
            that holds an ``"embeddings"`` array, or the literal ``'random'``
            to use a randomly initialised embedding layer.

    Raises:
        OSError: If ``class.txt`` or the embedding archive cannot be opened.
    """

    def __init__(self, dataset, embedding):
        self.model_name = 'DPCNN'
        self.train_path = dataset + '/data/train.txt'  # training set
        self.dev_path = dataset + '/data/dev.txt'  # validation set
        self.test_path = dataset + '/data/test.txt'  # test set

        # One class name per line. Read through a context manager so the
        # handle is closed, and with an explicit encoding so the result does
        # not depend on the platform's default locale.
        with open(dataset + '/data/class.txt', encoding='utf-8') as class_file:
            self.class_list = [line.strip() for line in class_file]

        self.vocab_path = dataset + '/data/vocab.pkl'  # vocabulary
        # Checkpoint path is derived from the model name so that DPCNN does
        # not overwrite the checkpoint of another model (previously this was
        # hard-coded to 'TextCNN.ckpt').
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'

        # Pre-trained embedding matrix as a float32 tensor, or None when the
        # embedding layer is to be randomly initialised.
        if embedding != 'random':
            # An .npz archive keeps its file open until closed explicitly.
            with np.load(dataset + '/data/' + embedding) as archive:
                self.embedding_pretrained = torch.tensor(
                    archive["embeddings"].astype('float32'))
        else:
            self.embedding_pretrained = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # compute device

        self.dropout = 0.5  # dropout probability
        self.require_improvement = 1000  # stop early after this many batches without improvement
        self.num_classes = len(self.class_list)  # number of classes
        self.n_vocab = 0  # vocabulary size, assigned at run time
        self.num_epochs = 20  # number of epochs
        self.batch_size = 128  # mini-batch size
        self.pad_size = 32  # every sentence is padded / truncated to this length
        self.learning_rate = 1e-3  # learning rate
        # Embedding dimension: taken from the pre-trained matrix when there
        # is one, otherwise 300.
        if self.embedding_pretrained is not None:
            self.embed = self.embedding_pretrained.size(1)
        else:
            self.embed = 300
        self.num_filters = 250  # number of convolution filters (channels)
class Model(nn.Module):
    """Deep Pyramid Convolutional Neural Network (DPCNN) for text categorization.

    Pipeline: embedding -> region convolution spanning the full embedding
    width -> two equal-width convolutions -> repeated down-sampling blocks
    with a shortcut connection -> linear classifier.

    The single ``self.conv`` module is applied at every equal-width
    convolution, so those layers share one set of weights.

    Args:
        config: Object providing ``embedding_pretrained`` (tensor or None),
            ``n_vocab``, ``embed``, ``num_filters`` and ``num_classes``.
    """

    def __init__(self, config):
        super().__init__()
        if config.embedding_pretrained is not None:
            # Start from the pre-trained vectors but keep them trainable.
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            # Random initialisation; the last vocabulary index is the padding index.
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        # Region embedding: a 3-token window over the whole embedding width.
        self.conv_region = nn.Conv2d(1, config.num_filters, (3, config.embed), stride=1)
        self.conv = nn.Conv2d(config.num_filters, config.num_filters, (3, 1), stride=1)
        self.max_pool = nn.MaxPool2d(kernel_size=(3, 1), stride=2)
        self.padding1 = nn.ZeroPad2d((0, 0, 1, 1))  # one row on top and one at the bottom
        self.padding2 = nn.ZeroPad2d((0, 0, 0, 1))  # one row at the bottom only
        self.relu = nn.ReLU()
        self.fc = nn.Linear(config.num_filters, config.num_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Sequence whose first element is the token-id tensor of shape
                ``[batch_size, seq_len]``; remaining elements are ignored.
                ``seq_len`` must be at least 3 (the region kernel height).

        Returns:
            Tensor of shape ``[batch_size, num_classes]``.
        """
        x = x[0]
        x = self.embedding(x)  # [batch_size, seq_len, embed]
        x = x.unsqueeze(1)  # [batch_size, 1, seq_len, embed]
        x = self.conv_region(x)  # [batch_size, num_filters, seq_len-3+1, 1]

        x = self.padding1(x)  # [batch_size, num_filters, seq_len, 1]
        x = self.relu(x)
        x = self.conv(x)  # [batch_size, num_filters, seq_len-3+1, 1]
        x = self.padding1(x)  # [batch_size, num_filters, seq_len, 1]
        x = self.relu(x)
        x = self.conv(x)  # [batch_size, num_filters, seq_len-3+1, 1]
        # Each block halves the height (floor division). Continue down to a
        # height of exactly 1 so the classifier input is well-formed for any
        # sequence length; stopping at "<= 2" left a height of 2 for some
        # lengths, which then failed in the linear layer.
        while x.size(2) > 1:
            x = self._block(x)
        # Drop only the two trailing singleton dimensions. A bare squeeze()
        # would also remove the batch dimension when batch_size == 1.
        x = x.squeeze(3).squeeze(2)  # [batch_size, num_filters]
        x = self.fc(x)
        return x

    def _block(self, x):
        """Down-sample by max pooling, apply two convolutions, add the shortcut."""
        x = self.padding2(x)
        px = self.max_pool(x)  # height becomes floor(height / 2)

        x = self.padding1(px)
        x = self.relu(x)
        x = self.conv(x)

        x = self.padding1(x)
        x = self.relu(x)
        x = self.conv(x)

        # Shortcut connection around the two convolutions.
        x = x + px
        return x
seq_len = x # out = self.embedding(x) diff --git a/models/TextRNN_Att.py b/models/TextRNN_Att.py index 95316645ea6c40302ab71c2131975efeed0e7629..ef35beefc03701ae20533b1c3cb2f18bab8dd792 100644 --- a/models/TextRNN_Att.py +++ b/models/TextRNN_Att.py @@ -49,6 +49,7 @@ class Model(nn.Module): self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers, bidirectional=True, batch_first=True, dropout=config.dropout) self.tanh1 = nn.Tanh() + # self.u = nn.Parameter(torch.Tensor(config.hidden_size * 2, config.hidden_size * 2)) self.w = nn.Parameter(torch.Tensor(config.hidden_size * 2)) self.tanh2 = nn.Tanh() self.fc1 = nn.Linear(config.hidden_size * 2, config.hidden_size2) @@ -60,6 +61,7 @@ class Model(nn.Module): H, _ = self.lstm(emb) # [batch_size, seq_len, hidden_size * num_direction]=[128, 32, 256] M = self.tanh1(H) # [128, 32, 256] + # M = torch.tanh(torch.matmul(H, self.u)) alpha = F.softmax(torch.matmul(M, self.w), dim=1).unsqueeze(-1) # [128, 32, 1] out = H * alpha # [128, 32, 256] out = torch.sum(out, 1) # [128, 256] diff --git a/run.py b/run.py index c6531c1a0b441194e6f31f67f58641b2904d0ed9..d51a9db35ecd7cae4daf2cd6cdfc41e5b8e7342c 100644 --- a/run.py +++ b/run.py @@ -7,7 +7,7 @@ from importlib import import_module import argparse parser = argparse.ArgumentParser(description='Chinese Text Classification') -parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att') +parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN') parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained') parser.add_argument('--word', default=False, type=bool, help='True for word, False for char') args = parser.parse_args() @@ -20,7 +20,7 @@ if __name__ == '__main__': embedding = 'embedding_SougouNews.npz' if args.embedding == 'random': embedding = 'random' - model_name = args.model # 
'TextRCNN' # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att + model_name = args.model # 'TextRCNN' # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN if model_name == 'FastText': from utils_fasttext import build_dataset, build_iterator, get_time_dif embedding = 'random'