From 024dc819367685804e73388758632fda81e5224d Mon Sep 17 00:00:00 2001 From: gengdongjie Date: Wed, 12 Aug 2020 09:42:33 +0800 Subject: [PATCH] support checkpoint path configuration for resnet --- chapter05/resnet/resnet_cifar.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chapter05/resnet/resnet_cifar.py b/chapter05/resnet/resnet_cifar.py index 5f2a6fa..fbb0f9c 100644 --- a/chapter05/resnet/resnet_cifar.py +++ b/chapter05/resnet/resnet_cifar.py @@ -128,7 +128,8 @@ if __name__ == '__main__': dataset = create_dataset() batch_num = dataset.get_dataset_size() config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10) - ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck) + checkpoint_path = args_opt.checkpoint_path if args_opt.checkpoint_path is not None else "./" + ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory=checkpoint_path, config=config_ck) loss_cb = LossMonitor() model.train(epoch_size, dataset, callbacks=[ckpoint_cb, loss_cb]) -- GitLab