提交 b62f7376 编写于 作者: W wenlihaoyu

新增ocr训练代码

上级 2c8023df
......@@ -70,8 +70,9 @@ if __name__=='__main__':
model,basemodel = get_model(height=imgH, nclass=nclass)
import os
if os.path.exists('pretrain-models/keras.hdf5'):
model.load_weights('pretrain-models/keras.hdf5')
basemodel.load_weights('pretrain-models/keras.hdf5')
##注意此处保存的是model的权重
checkpointer = ModelCheckpoint(filepath="save_model/model{epoch:02d}-{val_loss:.4f}.hdf5",monitor='val_loss', verbose=0,save_weights_only=False, save_best_only=True)
rlu = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=1, verbose=0, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0)
......
......@@ -4,7 +4,7 @@ model,basemodel = get_model(height=imgH, nclass=nclass)
import os
modelPath = '../pretrain-models/keras.hdf5'
if os.path.exists(modelPath):
model.load_weights(modelPath)
basemodel.load_weights(modelPath)
batchSize = 128
train_loader = torch.utils.data.DataLoader(
......@@ -43,7 +43,7 @@ for i in range(3):
loss = crrentLoss
path = 'save_model/model{}.h5'.format(loss)
print "save model:".format(path)
model.save(path)
basemodel.save(path)
j+=1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册