提交 cf198a43 编写于 作者: M Megvii Engine Team

fix the bug of memory increasing during evaluation

GitOrigin-RevId: 5d7ab3be4077849e1c186ef0700a63c471700a83
上级 d89667ad
......@@ -174,17 +174,12 @@ MegEngine 提供了很方便的动静态图转换的方法,几乎无需代码
trace.enabled = True # 开启trace,使用静态图模式
le_net.eval() # 将网络设为测试模式
data = mge.tensor()
label = mge.tensor(dtype="int32")
correct = 0
total = 0
for idx, (batch_data, batch_label) in enumerate(dataloader_test):
data.set_value(batch_data)
label.set_value(batch_label)
logits = eval_func(data, net=le_net) # 测试函数
logits = eval_func(batch_data, net=le_net) # 测试函数
predicted = F.argmax(logits, axis=1)
correct += (predicted==label).sum().numpy().item()
total += label.shape[0]
predicted = logits.numpy().argmax(axis=1)
correct += (predicted==batch_label).sum()
total += batch_label.shape[0]
print("correct: {}, total: {}, accuracy: {}".format(correct, total, float(correct)/total))
......@@ -261,16 +261,14 @@ MegEngine 在GPU和CPU同时存在时默认使用GPU进行训练。用户可以
le_net.eval() # 设置为测试模式
data = mge.tensor()
label = mge.tensor(dtype="int32")
correct = 0
total = 0
for idx, (batch_data, batch_label) in enumerate(dataloader_test):
data.set_value(batch_data)
label.set_value(batch_label)
logits = le_net(data)
predicted = F.argmax(logits, axis=1)
correct += (predicted==label).sum().numpy().item()
total += label.shape[0]
predicted = logits.numpy().argmax(axis=1)
correct += (predicted==batch_label).sum()
total += batch_label.shape[0]
print("correct: {}, total: {}, accuracy: {}".format(correct, total, float(correct)/total))
测试输出如下,可以看到经过训练的 ``LeNet`` 在 MNIST 测试数据集上的准确率已经达到98.84%:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册