Commit e281f154 authored by Stevezhangz

Update data_process.py

Parent aa6be9dd
@@ -9,7 +9,7 @@ import os
import json
import thulac
import numpy as np
import torch
class general_transform_text2list:
"""
notification: All series of data process method here only support the list type sentences, so whether json or txt file
@@ -193,13 +193,79 @@ def generate_vocab_from_poem_chuci(poem_dir,map_dir):
    return sentences,id_sentence,idx2word,word2idx,vocab_size
-def creat_batch(batch_size,max_pred,maxlen,vocab_size,word2idx,token_list,sentences):
+def creat_batch(batch_size,
+                max_pred,
+                maxlen,
+                word2idx,
+                idx2word,
+                token_list,
+                pre_percent):
"""
here this mechine just have to predict several masked words. and also have to predict whether they are sequential
:param batch_size:
:param max_pred:
:param maxlen:
:param vocab_size:
:param word2idx:
:param token_list:
:param sentences:
:return: batch[In_id, seg_id, could_mask, could_mask_tok, isconnect]
"""
    batch=[]
    connect=unconnect=0
    while connect<batch_size/2 or unconnect<batch_size/2:
        # pick indices directly so duplicate sentences cannot confuse list.index()
        s1_index=randrange(len(token_list))
        s1=token_list[s1_index]
        s2_index=randrange(len(token_list))
        s2=token_list[s2_index]
        In_id=[word2idx['[CLS]']] + s1 + [word2idx['[SEP]']] + s2 + [word2idx['[SEP]']]
        seg_id=[0] * (1 + len(s1) + 1) + [1] * (len(s2) + 1)
        if len(In_id)>maxlen:
            continue  # skip pairs that would overflow the padded length
        # every position except [CLS] and [SEP] is a masking candidate
        could_mask=[]
        for seq,val in enumerate(In_id):
            if idx2word[val]!='[CLS]' and idx2word[val]!='[SEP]':
                could_mask.append(seq)
        mask_num=min(max_pred,max(int(len(could_mask)*pre_percent),1))
        # sample distinct positions; replace=False prevents masking the same slot twice
        mask_pos=[int(p) for p in np.random.choice(could_mask,int(mask_num),replace=False)]
        mask_tok=[]
        for pos in mask_pos:
            mask_tok.append(In_id[pos])  # record the original token before overwriting it
            In_id[pos]=word2idx['[MASK]']
        pad_need=maxlen-len(In_id)
        In_id.extend([0]*pad_need)
        seg_id.extend([0]*pad_need)
        if mask_num<max_pred:
            mask_tok.extend([0]*(max_pred-int(mask_num)))
            mask_pos.extend([0]*(max_pred-int(mask_num)))
        # field order matches Text_file's unpacking: tokens before positions
        if s1_index+1==s2_index and connect<batch_size/2:
            connect+=1
            batch.append([In_id,seg_id,mask_tok,mask_pos,True])
        if s1_index+1!=s2_index and unconnect<batch_size/2:
            unconnect+=1
            batch.append([In_id,seg_id,mask_tok,mask_pos,False])
    return batch
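As a quick illustration, here is a minimal sketch of calling creat_batch on a toy corpus; the tiny vocabulary and sentences below are hypothetical, invented only to show the expected input/output shapes.

# hypothetical toy setup, not part of the repository's data
words = ['[PAD]', '[CLS]', '[SEP]', '[MASK]', '我', '爱', '诗', '歌']
word2idx = {w: i for i, w in enumerate(words)}
idx2word = {i: w for i, w in enumerate(words)}
token_list = [[4, 5], [6, 7], [4, 6]]          # three tokenized sentences
batch = creat_batch(batch_size=2, max_pred=2, maxlen=12,
                    word2idx=word2idx, idx2word=idx2word,
                    token_list=token_list, pre_percent=0.15)
In_id, seg_id, mask_tok, mask_pos, isconnect = batch[0]
# In_id    -> maxlen token ids, with [MASK] at each mask_pos and 0-padding at the tail
# mask_tok -> original ids of the masked tokens, padded with 0 up to max_pred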
def creat_batch_demo(batch_size,max_pred,maxlen,vocab_size,word2idx,token_list,sentences):
"""
this demo could be found, thanks: https://codechina.csdn.net/mirrors/wmathor/nlp-tutorial/-/tree/master/5-2.BERT
:param batch_size:
:param max_pred:
:param maxlen:
:param vocab_size:
:param word2idx:
:param token_list:
:param sentences:
:return:batch
"""
    batch = []
    positive = negative = 0
    while positive != batch_size / 2 or negative != batch_size / 2:
        # randomly choose two sentences
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))  # mask 15% of tokens, at least 1, at most max_pred
@@ -234,7 +300,13 @@ def creat_batch(batch_size,max_pred,maxlen,vocab_size,word2idx,token_list,senten
class Text_file(Data.Dataset):
-    def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
+    def __init__(self, batch):
+        input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
+        input_ids = torch.LongTensor(input_ids)
+        segment_ids = torch.LongTensor(segment_ids)
+        masked_tokens = torch.LongTensor(masked_tokens)
+        masked_pos = torch.LongTensor(masked_pos)
+        isNext = torch.LongTensor(isNext)
        self.input_ids = input_ids
        self.segment_ids = segment_ids
        self.masked_tokens = masked_tokens
@@ -246,4 +318,4 @@ class Text_file(Data.Dataset):
    def __getitem__(self, idx):
        return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.isNext[
-            idx]
\ No newline at end of file
+            idx]
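Putting the pieces together, a minimal end-to-end sketch of consuming the batch; this assumes Data is torch.utils.data (as the Data.Dataset base class above implies) and that Text_file defines __len__ in the lines elided from this diff.

import torch.utils.data as Data  # the module the Data.Dataset base class refers to

dataset = Text_file(batch)       # batch as returned by creat_batch above
loader = Data.DataLoader(dataset, batch_size=2, shuffle=True)
for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
    # each field arrives as a LongTensor slice, ready for a BERT-style model
    print(input_ids.shape, masked_tokens.shape, isNext)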