From c7a1481a64e3de7a859bb8c28d2f24f4bd953e51 Mon Sep 17 00:00:00 2001
From: Stevezhangz
Date: Mon, 26 Apr 2021 02:05:23 +0800
Subject: [PATCH] Upload new file
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Bert_finetune.py | 236 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 236 insertions(+)
 create mode 100644 Bert_finetune.py

diff --git a/Bert_finetune.py b/Bert_finetune.py
new file mode 100644
index 0000000..cc073e4
--- /dev/null
+++ b/Bert_finetune.py
@@ -0,0 +1,236 @@
+from bert import *
+from Config_load import *
+
+# explicit imports for names used below (may also be provided by the star imports)
+import os
+import torch
+import torch.nn as nn
+
+
+# BERT encoder with a weight-tied decoder head for token (word) prediction.
+class Bert_word_pre(nn.Module):
+    def __init__(self,
+                 n_layers,
+                 vocab_size,
+                 emb_size,
+                 max_len,
+                 seg_size,
+                 dff,
+                 dk,
+                 dv,
+                 n_head,
+                 n_class,
+                 drop):
+        super(Bert_word_pre, self).__init__()
+        self.vocab_size = vocab_size
+        self.emb_size = emb_size
+        self.emb_layer = Embedding(vocab_size, emb_size, max_len, seg_size)
+        self.encoder_layer = nn.Sequential(*[basic_block(emb_size, dff, dk, dv, n_head) for i in range(n_layers)])
+        # classification head (unused by this class; forward uses fc2)
+        self.fc1 = nn.Sequential(
+            nn.Linear(emb_size, vocab_size),
+            nn.Dropout(drop),
+            nn.Tanh(),
+            nn.Linear(vocab_size, n_class)
+        )
+        # decoder head whose output weight is tied to the input embedding
+        fc2 = nn.Linear(emb_size, vocab_size)
+        fc2.weight = self.emb_layer.emb_x.weight
+        self.fc2 = nn.Sequential(
+            nn.Linear(emb_size, emb_size),
+            Grelu(),
+            fc2
+        )
+
+    def get_mask(self, In):
+        # padding mask: True where the token id is 0, expanded to (batch, len, len)
+        batch_size, length, mask = In.size()[0], In.size()[1], In
+        mask = mask.eq(0).unsqueeze(1)
+        return mask.data.expand(batch_size, length, length)
+
+    def forward(self, x):
+        mask = self.get_mask(x)
+        output = self.emb_layer(x=x, seg=None)
+        for layer in self.encoder_layer:
+            output = layer(output, mask)
+        # vocabulary logits for every position after the leading [CLS] token
+        cls = self.fc2(output[:, 1:])
+        return cls
+
+    def display(self, batch, load_dir, map_dir):
+        import json
+        if load_dir != None:
+            if os.path.exists(load_dir):
+                checkpoint = torch.load(load_dir)
+                try:
+                    self.load_state_dict(checkpoint['model'])
+                except:
+                    print("fail to load the state_dict")
+        map_file = json.load(open(map_dir, "r"))["idx2word"]
+        for x in batch:
+            pre = self(x)
+            pre = pre.data.max(2)[1][0].data.numpy()
+            transform = []
+            for i in pre:
+                try:
+                    word_pre = map_file[int(i)]
+                except:
+                    word_pre = "mistake"
+                transform.append(word_pre)
+            print("prediction_words:", transform)
+            print("prediction_token:", pre)
+
+    def Train(self, epoches, criterion, optimizer, train_data_loader, use_gpu, device,
+              eval_data_loader=None, save_dir="./checkpoint", load_dir=None, save_freq=5,
+              ):
+        import tqdm
+        if load_dir != None:
+            if os.path.exists(load_dir):
+                checkpoint = torch.load(load_dir)
+                try:
+                    self.load_state_dict(checkpoint['model'])
+                    optimizer.load_state_dict(checkpoint['optimizer'])
+                except:
+                    print("fail to load the state_dict")
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        for epc in range(epoches):
+            tq = tqdm.tqdm(train_data_loader)
+            for seq, (input_ids, target_ids) in enumerate(tq):
+                if use_gpu:
+                    input_ids, target_ids = input_ids.to(device), target_ids.to(device)
+                logits_word = self(x=input_ids)
+                # token-level cross entropy over the vocabulary; target_ids hold the
+                # ground-truth token ids for the predicted positions
+                loss_word = criterion(logits_word.reshape(-1, self.vocab_size), target_ids.reshape(-1))
+                optimizer.zero_grad()
+                loss_word.backward()
+                optimizer.step()
+                tq.set_description(f"train Epoch {epc + 1}, Batch{seq}")
+                tq.set_postfix(train_loss=loss_word.item())
+
+            if eval_data_loader != None:
+                # evaluation pass: no gradient computation and no parameter updates
+                with torch.no_grad():
+                    tq = tqdm.tqdm(eval_data_loader)
+                    for seq, (input_ids, target_ids) in enumerate(tq):
+                        if use_gpu:
+                            input_ids, target_ids = input_ids.to(device), target_ids.to(device)
+                        logits_word = self(x=input_ids)
+                        loss_word = criterion(logits_word.reshape(-1, self.vocab_size), target_ids.reshape(-1))
+                        tq.set_description(f"Eval Epoch {epc + 1}, Batch{seq}")
+                        tq.set_postfix(eval_loss=loss_word.item())
+
+            if (epc + 1) % save_freq == 0:
+                checkpoint = {'epoch': epc,
+                              'best_loss': loss_word.item(),
+                              'model': self.state_dict(),
+                              'optimizer': optimizer.state_dict()
+                              }
+                torch.save(checkpoint, save_dir + f"/checkpoint_{epc}.pth")
+
+
+# BERT encoder with a classification head on the leading [CLS] position.
+class Bert_classify(nn.Module):
+    def __init__(self,
+                 n_layers,
+                 vocab_size,
+                 emb_size,
+                 max_len,
+                 seg_size,
+                 dff,
+                 dk,
+                 dv,
+                 n_head,
+                 n_class,
+                 drop):
+        super(Bert_classify, self).__init__()
+        self.vocab_size = vocab_size
+        self.emb_size = emb_size
+        self.emb_layer = Embedding(vocab_size, emb_size, max_len, seg_size)
+        self.encoder_layer = nn.Sequential(*[basic_block(emb_size, dff, dk, dv, n_head) for i in range(n_layers)])
+        # classification head used by forward
+        self.fc1 = nn.Sequential(
+            nn.Linear(emb_size, vocab_size),
+            nn.Dropout(drop),
+            nn.Tanh(),
+            nn.Linear(vocab_size, n_class)
+        )
+        # weight-tied decoder head (unused by this class)
+        fc2 = nn.Linear(emb_size, vocab_size)
+        fc2.weight = self.emb_layer.emb_x.weight
+        self.fc2 = nn.Sequential(
+            nn.Linear(emb_size, emb_size),
+            Grelu(),
+            fc2
+        )
+
+    def get_mask(self, In):
+        # padding mask: True where the token id is 0, expanded to (batch, len, len)
+        batch_size, length, mask = In.size()[0], In.size()[1], In
+        mask = mask.eq(0).unsqueeze(1)
+        return mask.data.expand(batch_size, length, length)
+
+    def forward(self, x):
+        mask = self.get_mask(x)
+        output = self.emb_layer(x=x, seg=None)
+        for layer in self.encoder_layer:
+            output = layer(output, mask)
+        # class logits from the [CLS] (position 0) representation
+        cls = self.fc1(output[:, 0])
+        return cls
+
+    def display(self, batch, load_dir):
+        if load_dir != None:
+            if os.path.exists(load_dir):
+                checkpoint = torch.load(load_dir)
+                try:
+                    self.load_state_dict(checkpoint['model'])
+                except:
+                    print("fail to load the state_dict")
+        for i in batch:
+            logits_clsf = self(x=i)
+            print(logits_clsf)
+
+    def Train(self, epoches, criterion, optimizer, train_data_loader, use_gpu, device,
+              eval_data_loader=None, save_dir="./checkpoint", load_dir=None, save_freq=5,
+              ):
+        import tqdm
+        if load_dir != None:
+            if os.path.exists(load_dir):
+                checkpoint = torch.load(load_dir)
+                try:
+                    self.load_state_dict(checkpoint['model'])
+                    optimizer.load_state_dict(checkpoint['optimizer'])
+                except:
+                    print("fail to load the state_dict")
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        for epc in range(epoches):
+            tq = tqdm.tqdm(train_data_loader)
+            for seq, (input_ids, classi) in enumerate(tq):
+                if use_gpu:
+                    input_ids, classi = input_ids.to(device), classi.to(device)
+                logits_clsf = self(x=input_ids)
+                # sentence-level cross entropy against the class labels
+                loss_cls = criterion(logits_clsf, classi)
+                optimizer.zero_grad()
+                loss_cls.backward()
+                optimizer.step()
+                tq.set_description(f"train Epoch {epc + 1}, Batch{seq}")
+                tq.set_postfix(train_loss=loss_cls.item())
+
+            if eval_data_loader != None:
+                # evaluation pass: no gradient computation and no parameter updates
+                with torch.no_grad():
+                    tq = tqdm.tqdm(eval_data_loader)
+                    for seq, (input_ids, classi) in enumerate(tq):
+                        if use_gpu:
+                            input_ids, classi = input_ids.to(device), classi.to(device)
+                        logits_clsf = self(x=input_ids)
+                        loss_cls = criterion(logits_clsf, classi)
+                        tq.set_description(f"Eval Epoch {epc + 1}, Batch{seq}")
+                        tq.set_postfix(eval_loss=loss_cls.item())
+
+            if (epc + 1) % save_freq == 0:
+                checkpoint = {'epoch': epc,
+                              'best_loss': loss_cls.item(),
+                              'model': self.state_dict(),
+                              'optimizer': optimizer.state_dict()
+                              }
+                torch.save(checkpoint, save_dir + f"/checkpoint_{epc}.pth")
f"/checkpoint_{epc}.pth") + + + -- GitLab