提交 bacd6196 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!20 modify lenet

Merge pull request !20 from wukesong/modify_lenet_pa
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""LeNet.""" """LeNet."""
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.initializer import Normal
class LeNet5(nn.Cell): class LeNet5(nn.Cell):
...@@ -22,22 +23,21 @@ class LeNet5(nn.Cell): ...@@ -22,22 +23,21 @@ class LeNet5(nn.Cell):
Args: Args:
num_class (int): Num classes. Default: 10. num_class (int): Num classes. Default: 10.
channel (int): Num classes. Default: 1. num_channel (int): Num channels. Default: 1.
Returns: Returns:
Tensor, output tensor Tensor, output tensor
Examples: Examples:
>>> LeNet(num_class=10, channel=1) >>> LeNet(num_class=10, num_channel=1)
""" """
def __init__(self, num_class=10, channel=1): def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.num_class = num_class self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120) self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84) self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, self.num_class) self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
......
...@@ -91,7 +91,7 @@ if __name__ == "__main__": ...@@ -91,7 +91,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = LeNet5(cfg.num_classes) network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
repeat_size = 1 repeat_size = 1
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册