提交 fbbb98d2 编写于 作者: Eric.Lee2021's avatar Eric.Lee2021 🚴🏻

correct random seed bug

上级 356e8010
......@@ -21,7 +21,7 @@ import cv2
import time
import json
from datetime import datetime
import random
def tester(ops,epoch,model,criterion,
train_split,train_split_label,val_split,val_split_label,
use_cuda):
......@@ -190,7 +190,7 @@ def trainer(ops,f_log):
else:
flag_change_lr_cnt += 1
if flag_change_lr_cnt > 5:
if flag_change_lr_cnt > 10:
init_lr = init_lr*ops.lr_decay
set_learning_rate(optimizer, init_lr)
flag_change_lr_cnt = 0
......@@ -226,7 +226,8 @@ def trainer(ops,f_log):
step += 1
# 一个 epoch 保存连词最新的 模型
if i%(int(dataset.__len__()/ops.batch_size/2-1)) == 0 and i > 0:
# if i%(int(dataset.__len__()/ops.batch_size/2-1)) == 0 and i > 0:
if i%(1000) == 0 and i > 0:
torch.save(model_.state_dict(), ops.model_exp + 'latest.pth')
# 每间隔 5 个 epoch 进行模型保存
if (epoch%5) == 0 and (epoch > 9):
......@@ -248,6 +249,8 @@ def trainer(ops,f_log):
json.dump(epochs_loss_dict,f_loss,ensure_ascii=False,indent = 1,cls = JSON_Encoder)
f_loss.close()
set_seed(random.randint(0,65535))
except Exception as e:
print('Exception : ',e) # 打印异常
print('Exception file : ', e.__traceback__.tb_frame.f_globals['__file__'])# 发生异常所在的文件
......@@ -260,16 +263,16 @@ if __name__ == "__main__":
help = 'seed') # 设置随机种子
parser.add_argument('--model_exp', type=str, default = './model_exp',
help = 'model_exp') # 模型输出文件夹
parser.add_argument('--model', type=str, default = 'resnet_34',
parser.add_argument('--model', type=str, default = 'resnet_50',
help = 'model : resnet_18,resnet_34,resnet_50,resnet_101,resnet_152') # 模型类型
'''
注意以下3个参数与具体分类任务数据集,息息相关
'''
#---------------------------------------------------------------------------------
parser.add_argument('--train_path', type=str, default = './handpose_x_gesture_v1/',
parser.add_argument('--train_path', type=str, default = './animals10/',
help = 'train_path') # 训练集路径
parser.add_argument('--num_classes', type=int , default = 14,
parser.add_argument('--num_classes', type=int , default = 10,
help = 'num_classes') # 分类类别个数,gesture 配置为 14 , Stanford Dogs 配置为 120
parser.add_argument('--have_label_file', type=bool, default = False,
help = 'have_label_file') # 是否有配套的标注文件解析才能生成分类样本,gesture 配置为 False , Stanford Dogs 配置为 True
......@@ -293,15 +296,15 @@ if __name__ == "__main__":
help = 'learningRate_decay') # 学习率权重衰减率
parser.add_argument('--weight_decay', type=float, default = 1e-6,
help = 'weight_decay') # 优化器正则损失权重
parser.add_argument('--batch_size', type=int, default = 32,
parser.add_argument('--batch_size', type=int, default = 48,
help = 'batch_size') # 训练每批次图像数量
parser.add_argument('--dropout', type=float, default = 0.5,
help = 'dropout') # dropout
parser.add_argument('--epochs', type=int, default = 1000,
help = 'epochs') # 训练周期
parser.add_argument('--num_workers', type=int, default = 1,
parser.add_argument('--num_workers', type=int, default = 6,
help = 'num_workers') # 训练数据生成器线程数
parser.add_argument('--img_size', type=tuple , default = (192,192),
parser.add_argument('--img_size', type=tuple , default = (256,256),
help = 'img_size') # 输入模型图片尺寸
parser.add_argument('--flag_agu', type=bool , default = True,
help = 'data_augmentation') # 训练数据生成器是否进行数据扩增
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册