diff --git a/chapter03/lenet/main.py b/chapter03/lenet/main.py
index 7f3dff9104eb0396b085a917779c8796301c2f9c..b94200e3a93acb2a0d1fa826678ab670bcce1cf6 100644
--- a/chapter03/lenet/main.py
+++ b/chapter03/lenet/main.py
@@ -94,15 +94,15 @@ if __name__ == "__main__":
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
     repeat_size = 1
     net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
-                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
     if args.mode == 'train':  # train
         ds_train = create_dataset(os.path.join(args.data_path, args.mode),
                                   batch_size=cfg.batch_size,
                                   repeat_size=repeat_size)
         print("============== Starting Training ==============")
+        config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
+                                     keep_checkpoint_max=cfg.keep_checkpoint_max)
+        ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck, directory=args.ckpt_path)
         model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor()],
                     dataset_sink_mode=args.dataset_sink_mode)
     elif args.mode == 'test':  # test