未验证 提交 08c5b1d1 编写于 作者: S shangliang Xu 提交者: GitHub

fix bug for num_iters in fit/evaluate (#34059)

上级 edb9aff5
......@@ -1707,7 +1707,8 @@ class Model(object):
steps = self._len_data_loader(train_loader)
self.num_iters = num_iters
if num_iters is not None and isinstance(num_iters, int):
if num_iters is not None and isinstance(num_iters, int) and isinstance(
steps, int):
assert num_iters > 0, "num_iters must be greater than 0!"
epochs = (num_iters // steps) + 1
steps = min(num_iters, steps)
......@@ -1830,7 +1831,8 @@ class Model(object):
eval_steps = self._len_data_loader(eval_loader)
self.num_iters = num_iters
if num_iters is not None and isinstance(num_iters, int):
if num_iters is not None and isinstance(num_iters, int) and isinstance(
eval_steps, int):
assert num_iters > 0, "num_iters must be greater than 0!"
eval_steps = min(num_iters, eval_steps)
self.num_iters = eval_steps
......@@ -2092,7 +2094,9 @@ class Model(object):
callbacks.on_batch_end(mode, step, logs)
if hasattr(self, 'num_iters') and self.num_iters is not None:
self.num_iters -= 1
if self.num_iters == 0:
if self.num_iters <= 0:
self.stop_training = True
del self.num_iters
break
self._reset_metrics()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册