提交 07bc3e8d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!21 Modify SoftmaxCrossEntropyWithLogits

Merge pull request !21 from wanyiming/mod_SoftmaxCrossEntropyWithlogits
...@@ -83,7 +83,7 @@ if __name__ == "__main__": ...@@ -83,7 +83,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = AlexNet(cfg.num_classes) network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
repeat_size = 1 repeat_size = 1
# when batch_size=32, steps is 1562 # when batch_size=32, steps is 1562
lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, 1562)) lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, 1562))
......
...@@ -120,7 +120,7 @@ if __name__ == '__main__': ...@@ -120,7 +120,7 @@ if __name__ == '__main__':
epoch_size = args_opt.epoch_size epoch_size = args_opt.epoch_size
net = resnet50(args_opt.num_classes) net = resnet50(args_opt.num_classes)
ls = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction="mean") ls = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9) opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'}) model = Model(net, loss_fn=ls, optimizer=opt, metrics={'acc'})
......
...@@ -64,7 +64,7 @@ if __name__ == '__main__': ...@@ -64,7 +64,7 @@ if __name__ == '__main__':
weight=Tensor(embedding_table), weight=Tensor(embedding_table),
batch_size=cfg.batch_size) batch_size=cfg.batch_size)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum) opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
loss_cb = LossMonitor() loss_cb = LossMonitor()
......
...@@ -70,7 +70,7 @@ if __name__ == '__main__': ...@@ -70,7 +70,7 @@ if __name__ == '__main__':
if args.pre_trained: if args.pre_trained:
load_param_into_net(network, load_checkpoint(args.pre_trained)) load_param_into_net(network, load_checkpoint(args.pre_trained))
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum) opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
loss_cb = LossMonitor() loss_cb = LossMonitor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册