diff --git a/train/keras-train/train.py b/train/keras-train/train.py
index 31ec5ae172601fb334f62875105c65c9a0c8ce6f..35ca1b402beb94215dcbdf379250dda76e2303e6 100644
--- a/train/keras-train/train.py
+++ b/train/keras-train/train.py
@@ -31,13 +31,16 @@ def one_hot(text,length=10,characters=characters):
     return label
 
 n_len = 10
-def gen(loader):
+def gen(loader,flag='train'):
     while True:
         i =0
         n = len(loader)
         for X,Y in loader:
             X = X.numpy()
             X = X.reshape((-1,imgH,imgW,1))
+            if flag=='test':
+                Y = Y.numpy()
+                Y = np.array(Y)
             Length = int(imgW/4)-1
             batchs = X.shape[0]
 
@@ -63,6 +66,7 @@ train_loader = torch.utils.data.DataLoader(
 
 test_dataset = dataset.lmdbDataset(
     root=valroot, transform=dataset.resizeNormalize((imgW, imgH)),target_transform=one_hot)
+
 test_loader = torch.utils.data.DataLoader(
     test_dataset, shuffle=True, batch_size=batchSize, num_workers=int(workers))
 
@@ -78,9 +82,9 @@ if __name__=='__main__':
 
     checkpointer = ModelCheckpoint(filepath="save_model/model{epoch:02d}-{val_loss:.4f}.hdf5",monitor='val_loss', verbose=0,save_weights_only=False, save_best_only=True)
    rlu = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=1, verbose=0, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0)
-    model.fit_generator(gen(train_loader),
-                    steps_per_epoch=10240,
+    model.fit_generator(gen(train_loader,flag='train'),
+                    steps_per_epoch=102400,
                     epochs=200,
-                    validation_data=gen(test_loader),
+                    validation_data=gen(test_loader,flag='test'),
                     callbacks=[checkpointer,rlu],
                     validation_steps=1024)
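
For reference, the change above threads a flag argument through gen() so that validation labels coming out of test_loader are converted to NumPy before the batches reach Keras, and it bumps steps_per_epoch from 10240 to 102400. Below is a minimal sketch of the generator pattern these hunks touch; the imgH/imgW values and the final yield are hypothetical stand-ins, since the rest of gen() (the Length/batchs bookkeeping) lies outside this diff.

# Illustrative sketch only: the DataLoader-to-Keras bridge that gen() implements,
# trimmed to the lines touched by this diff. imgH/imgW are hypothetical stand-ins
# for the globals defined elsewhere in train.py.
import numpy as np

imgH, imgW = 32, 280   # hypothetical image size; train.py defines its own

def gen(loader, flag='train'):
    # Keras fit_generator expects a generator that yields batches indefinitely,
    # so the finite PyTorch DataLoader is wrapped in an endless loop.
    while True:
        for X, Y in loader:
            X = X.numpy()
            X = X.reshape((-1, imgH, imgW, 1))   # NCHW tensor -> NHWC array (safe reshape since C == 1)
            if flag == 'test':
                Y = Y.numpy()      # per the diff, validation labels come back as torch tensors
                Y = np.array(Y)    # and are converted to a NumPy array for Keras
            yield X, Y             # the real gen() assembles richer model inputs past this point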