提交 5cf89a66 编写于 作者: 滴水无痕0801

update

上级 42526613
......@@ -8,6 +8,7 @@ import time
from utils import get_time_dif
# 权重初始化,默认xavier
def init_network(model, method='xavier', exclude='embedding', seed=123):
for name, w in model.named_parameters():
if exclude not in name:
......@@ -20,17 +21,16 @@ def init_network(model, method='xavier', exclude='embedding', seed=123):
nn.init.normal_(w)
elif 'bias' in name:
nn.init.constant_(w, 0)
else:
else:
pass
def train(config, model, train_iter, dev_iter, test_iter):
start_time = time.time()
model.train()
# criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
# 学习率指数衰减,每次epoch:学习率 ×= gamma
# 学习率指数衰减,每次epoch:学习率 = gamma * 学习率
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
total_batch = 0 # 记录进行到多少batch
dev_best_loss = float('inf')
......@@ -43,10 +43,7 @@ def train(config, model, train_iter, dev_iter, test_iter):
for i, (trains, labels) in enumerate(train_iter):
outputs = model(trains)
model.zero_grad()
# print(outputs.size())
# print(labels.size())
loss = F.cross_entropy(outputs, labels)
# loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
if total_batch % 100 == 0:
......
......@@ -133,7 +133,6 @@ class DatasetIterater(object):
return self.n_batches + 1
else:
return self.n_batches
def build_iterator(dataset, config):
......@@ -146,3 +145,23 @@ def get_time_dif(start_time):
end_time = time.time()
time_dif = end_time - start_time
return timedelta(seconds=int(round(time_dif)))
if __name__ == "__main__":
    # Extract pretrained word vectors: build an embedding matrix aligned with
    # the vocabulary in vocab.pkl, filling each row from the Sogou character
    # vectors where available. Rows for words absent from the pretrained file
    # keep their random initialization.
    vocab_dir = "./THUCNews/data/vocab.pkl"
    pretrain_dir = "./THUCNews/data/sgns.sogou.char"
    emb_dim = 300
    filename_trimmed_dir = "./THUCNews/data/vocab.embedding.sougou"

    # NOTE(review): pickle.load is only safe on trusted files — this assumes
    # vocab.pkl was produced by this project.
    with open(vocab_dir, 'rb') as vocab_file:
        word_to_id = pkl.load(vocab_file)

    # Random init so out-of-pretrained-vocab words still get a vector.
    embeddings = np.random.rand(len(word_to_id), emb_dim)

    # Stream the pretrained file line by line (readlines() would load the
    # whole multi-hundred-MB file into memory); `with` guarantees the handle
    # is closed even if a line fails to parse.
    with open(pretrain_dir, "r", encoding='UTF-8') as f:
        for line in f:
            # if the first line is a header, skip it here
            lin = line.strip().split(" ")
            if lin[0] in word_to_id:
                idx = word_to_id[lin[0]]
                # Slice bound follows emb_dim instead of a hard-coded 301 so
                # changing the dimension needs only one edit.
                emb = [float(x) for x in lin[1:emb_dim + 1]]
                embeddings[idx] = np.asarray(emb, dtype='float32')

    np.savez_compressed(filename_trimmed_dir, embeddings=embeddings)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册