提交 33ca3759 编写于 作者: Stevezhangz's avatar Stevezhangz

Update bert_for_sentence_classify.py

上级 8c4f5755
......@@ -11,6 +11,7 @@ data=json2list.getdata()
list2token=generate_vocab_normalway(data,map_dir="words_info.json")
sentences,token_list,idx2word,word2idx,vocab_size=list2token.transform()
batch = creat_batch_for_wordpre(100,word2idx,token_list,maxlen=maxlen)
loader = DataLoader(word_pre_load(batch), batch_size, True)
model=Bert_classify(n_layers=n_layers,
vocab_size=vocab_size,
emb_size=d_model,
......@@ -26,4 +27,4 @@ model=Bert_classify(n_layers=n_layers,
if use_gpu:
with torch.cuda.device(device) as device:
model.to(device)
model.display(batch=batch, load_dir="checkpoint/checkpoint_199.pth")
model.display(batch=loader, load_dir="checkpoint/checkpoint_199.pth")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册