diff --git a/code/d2lzh_pytorch/utils.py b/code/d2lzh_pytorch/utils.py
index 03564c373756150a4388939d45a1a42252f0bf8c..cd54bad3bdf2382e2318a037ae8ede611cd89b1f 100644
--- a/code/d2lzh_pytorch/utils.py
+++ b/code/d2lzh_pytorch/utils.py
@@ -230,9 +230,8 @@ def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epo
     net = net.to(device)
     print("training on ", device)
     loss = torch.nn.CrossEntropyLoss()
-    batch_count = 0
     for epoch in range(num_epochs):
-        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
+        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
         for X, y in train_iter:
             X = X.to(device)
             y = y.to(device)
diff --git a/docs/chapter05_CNN/5.5_lenet.md b/docs/chapter05_CNN/5.5_lenet.md
index b33c90decfe0f9567d5d92d33059e24c65eb6d30..d527fd9a6893f7a683bd4f81cc8e75be10472a67 100644
--- a/docs/chapter05_CNN/5.5_lenet.md
+++ b/docs/chapter05_CNN/5.5_lenet.md
@@ -131,9 +131,8 @@ def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epo
     net = net.to(device)
     print("training on ", device)
     loss = torch.nn.CrossEntropyLoss()
-    batch_count = 0
     for epoch in range(num_epochs):
-        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
+        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
         for X, y in train_iter:
             X = X.to(device)
             y = y.to(device)
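For context: the patch moves the `batch_count` reset inside the epoch loop so that it is zeroed together with `train_l_sum`, `train_acc_sum`, and `n` at the start of each epoch. Since `batch_count` is presumably the denominator for the average training loss, keeping it cumulative across epochs would make the reported per-epoch loss shrink artificially over time. Below is a minimal sketch of how `train_ch5` would look after the patch; everything beyond the visible hunk (forward/backward pass, metric accumulation, the per-epoch print, and the `evaluate_accuracy` helper) is an assumption based on the usual shape of this training loop, not part of the diff.

```python
import time
import torch


def evaluate_accuracy(data_iter, net, device=None):
    # Simplified stand-in for the evaluate_accuracy helper in utils.py (assumption).
    if device is None:
        device = next(net.parameters()).device
    acc_sum, n = 0.0, 0
    net.eval()  # switch to evaluation mode (disables dropout etc.)
    with torch.no_grad():
        for X, y in data_iter:
            X, y = X.to(device), y.to(device)
            acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    net.train()
    return acc_sum / n


def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    loss = torch.nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # batch_count now resets every epoch along with the loss/accuracy sums,
        # so the averages printed below cover only the current epoch.
        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n,
                 test_acc, time.time() - start))
```

With this change, the loss reported for epoch k averages only the batches of epoch k, matching the per-epoch semantics of the other printed metrics.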