From fbbb98d27d973ce4ea3397d0b0ea467fa3be80c4 Mon Sep 17 00:00:00 2001 From: "Eric.Lee2021" <305141918@qq.com> Date: Fri, 23 Apr 2021 11:17:23 +0800 Subject: [PATCH] correct random seed bug --- train.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/train.py b/train.py index 82998e1..2214e52 100644 --- a/train.py +++ b/train.py @@ -21,7 +21,7 @@ import cv2 import time import json from datetime import datetime - +import random def tester(ops,epoch,model,criterion, train_split,train_split_label,val_split,val_split_label, use_cuda): @@ -190,7 +190,7 @@ def trainer(ops,f_log): else: flag_change_lr_cnt += 1 - if flag_change_lr_cnt > 5: + if flag_change_lr_cnt > 10: init_lr = init_lr*ops.lr_decay set_learning_rate(optimizer, init_lr) flag_change_lr_cnt = 0 @@ -226,7 +226,8 @@ def trainer(ops,f_log): step += 1 # 一个 epoch 保存连词最新的 模型 - if i%(int(dataset.__len__()/ops.batch_size/2-1)) == 0 and i > 0: + # if i%(int(dataset.__len__()/ops.batch_size/2-1)) == 0 and i > 0: + if i%(1000) == 0 and i > 0: torch.save(model_.state_dict(), ops.model_exp + 'latest.pth') # 每间隔 5 个 epoch 进行模型保存 if (epoch%5) == 0 and (epoch > 9): @@ -248,6 +249,8 @@ def trainer(ops,f_log): json.dump(epochs_loss_dict,f_loss,ensure_ascii=False,indent = 1,cls = JSON_Encoder) f_loss.close() + set_seed(random.randint(0,65535)) + except Exception as e: print('Exception : ',e) # 打印异常 print('Exception file : ', e.__traceback__.tb_frame.f_globals['__file__'])# 发生异常所在的文件 @@ -260,16 +263,16 @@ if __name__ == "__main__": help = 'seed') # 设置随机种子 parser.add_argument('--model_exp', type=str, default = './model_exp', help = 'model_exp') # 模型输出文件夹 - parser.add_argument('--model', type=str, default = 'resnet_34', + parser.add_argument('--model', type=str, default = 'resnet_50', help = 'model : resnet_18,resnet_34,resnet_50,resnet_101,resnet_152') # 模型类型 ''' 注意以下3个参数与具体分类任务数据集,息息相关 ''' #--------------------------------------------------------------------------------- - parser.add_argument('--train_path', type=str, default = './handpose_x_gesture_v1/', + parser.add_argument('--train_path', type=str, default = './animals10/', help = 'train_path') # 训练集路径 - parser.add_argument('--num_classes', type=int , default = 14, + parser.add_argument('--num_classes', type=int , default = 10, help = 'num_classes') # 分类类别个数,gesture 配置为 14 , Stanford Dogs 配置为 120 parser.add_argument('--have_label_file', type=bool, default = False, help = 'have_label_file') # 是否有配套的标注文件解析才能生成分类样本,gesture 配置为 False , Stanford Dogs 配置为 True @@ -293,15 +296,15 @@ if __name__ == "__main__": help = 'learningRate_decay') # 学习率权重衰减率 parser.add_argument('--weight_decay', type=float, default = 1e-6, help = 'weight_decay') # 优化器正则损失权重 - parser.add_argument('--batch_size', type=int, default = 32, + parser.add_argument('--batch_size', type=int, default = 48, help = 'batch_size') # 训练每批次图像数量 parser.add_argument('--dropout', type=float, default = 0.5, help = 'dropout') # dropout parser.add_argument('--epochs', type=int, default = 1000, help = 'epochs') # 训练周期 - parser.add_argument('--num_workers', type=int, default = 1, + parser.add_argument('--num_workers', type=int, default = 6, help = 'num_workers') # 训练数据生成器线程数 - parser.add_argument('--img_size', type=tuple , default = (192,192), + parser.add_argument('--img_size', type=tuple , default = (256,256), help = 'img_size') # 输入模型图片尺寸 parser.add_argument('--flag_agu', type=bool , default = True, help = 'data_augmentation') # 训练数据生成器是否进行数据扩增 -- GitLab