diff --git a/train_demo.py b/train_demo.py index 8c4657ad0062cf30180e8b8e3b5b8af1ae8d538f..1ce581fa7a1e70d3194f94dc0904b270f5794399 100644 --- a/train_demo.py +++ b/train_demo.py @@ -8,7 +8,7 @@ import torch.nn as nn import torch.optim as optim from Config_load import * from data_process import * - +import random np.random.seed(random_seed) @@ -21,13 +21,8 @@ data=json2list.getdata() # transform list to token list2token=generate_vocab_normalway(data,map_dir="words_info.json") sentences,token_list,idx2word,word2idx,vocab_size=list2token.transform() - -batch = creat_batch(batch_size,max_pred,maxlen,vocab_size,word2idx,token_list,sentences) -input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch) -input_ids, segment_ids, masked_tokens, masked_pos, isNext = \ - torch.LongTensor(input_ids), torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens), \ - torch.LongTensor(masked_pos), torch.LongTensor(isNext) -loader = Data.DataLoader(Text_file(input_ids, segment_ids, masked_tokens, masked_pos, isNext), batch_size, True) +batch = creat_batch(batch_size,max_pred,maxlen,word2idx,idx2word,token_list,0.15) +loader = Data.DataLoader(Text_file(batch), batch_size, True) model=Bert(n_layers=n_layers, vocab_size=vocab_size, @@ -39,14 +34,14 @@ model=Bert(n_layers=n_layers, dv=d_v, n_head=n_heads, n_class=2, - ) + drop=drop) if use_gpu: with torch.cuda.device(device) as device: model.to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.Adadelta(model.parameters(), lr=lr) - model.Train(epoches=epoches, + model.Train_for_mask_guess(epoches=epoches, train_data_loader=loader, optimizer=optimizer, criterion=criterion, @@ -59,7 +54,7 @@ if use_gpu: else: criterion = nn.CrossEntropyLoss() optimizer = optim.Adadelta(model.parameters(), lr=lr) - model.Train(epoches=epoches, + model.Train_for_mask_guess(epoches=epoches, train_data_loader=loader, optimizer=optimizer, criterion=criterion, @@ -68,4 +63,4 @@ else: load_dir="checkpoint/checkpoint_199.pth", use_gpu=use_gpu, device=device - ) \ No newline at end of file + )