diff --git a/chapter03/lenet/lenet.py b/chapter03/lenet/lenet.py index 068719da4ee151a94c8df218bbcefe5c0a3b3aeb..1b11096d10380db771fe344d9d257a45b2fd7979 100644 --- a/chapter03/lenet/lenet.py +++ b/chapter03/lenet/lenet.py @@ -14,6 +14,7 @@ # ============================================================================ """LeNet.""" import mindspore.nn as nn +from mindspore.common.initializer import Normal class LeNet5(nn.Cell): @@ -22,22 +23,21 @@ class LeNet5(nn.Cell): Args: num_class (int): Num classes. Default: 10. - channel (int): Num classes. Default: 1. + num_channel (int): Num channels. Default: 1. Returns: Tensor, output tensor Examples: - >>> LeNet(num_class=10, channel=1) + >>> LeNet(num_class=10, num_channel=1) """ - def __init__(self, num_class=10, channel=1): + def __init__(self, num_class=10, num_channel=1): super(LeNet5, self).__init__() - self.num_class = num_class - self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid') + self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') - self.fc1 = nn.Dense(16 * 5 * 5, 120) - self.fc2 = nn.Dense(120, 84) - self.fc3 = nn.Dense(84, self.num_class) + self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) + self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) + self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.flatten = nn.Flatten() diff --git a/chapter03/lenet/main.py b/chapter03/lenet/main.py index 1dc82379a886bdf968e9641bc0e106ee5d1dbac6..efc8cf5a739bb7c982b66209d5912267e961782a 100644 --- a/chapter03/lenet/main.py +++ b/chapter03/lenet/main.py @@ -91,7 +91,7 @@ if __name__ == "__main__": context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) network = LeNet5(cfg.num_classes) - net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") repeat_size = 1 net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})