From c5d0f74a514df9c9cf456f3e5552c93c9e54728e Mon Sep 17 00:00:00 2001 From: ShusenTang Date: Tue, 12 Nov 2019 23:34:47 +0800 Subject: [PATCH] fix bug about batch_count (#60) --- code/d2lzh_pytorch/utils.py | 3 +-- docs/chapter05_CNN/5.5_lenet.md | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/code/d2lzh_pytorch/utils.py b/code/d2lzh_pytorch/utils.py index 03564c3..cd54bad 100644 --- a/code/d2lzh_pytorch/utils.py +++ b/code/d2lzh_pytorch/utils.py @@ -230,9 +230,8 @@ def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epo net = net.to(device) print("training on ", device) loss = torch.nn.CrossEntropyLoss() - batch_count = 0 for epoch in range(num_epochs): - train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time() + train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time() for X, y in train_iter: X = X.to(device) y = y.to(device) diff --git a/docs/chapter05_CNN/5.5_lenet.md b/docs/chapter05_CNN/5.5_lenet.md index b33c90d..d527fd9 100644 --- a/docs/chapter05_CNN/5.5_lenet.md +++ b/docs/chapter05_CNN/5.5_lenet.md @@ -131,9 +131,8 @@ def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epo net = net.to(device) print("training on ", device) loss = torch.nn.CrossEntropyLoss() - batch_count = 0 for epoch in range(num_epochs): - train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time() + train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time() for X, y in train_iter: X = X.to(device) y = y.to(device) -- GitLab