提交 ecbf73f7 编写于 作者: G Gabriel de Marmiesse 提交者: François Chollet

[RELNOTES] [P] Write to TensorBoard every x samples. (#11152)

* Working on improving TensorFlow callbacks

* Adding batch-level TensorBoard logging (implementing the `on_batch_end` method in the `TensorBoard` class)

* Interim commit -- added notes.

* Corrected stylistic issues -- brought to compliance w/ PEP8

* Added the missing argument in the test suite.

* Added the possibility to choose how frequently tensorboard should log
the metrics and losses.

* Fixed the issue of the validation data not being displayed.

* Fixed the issue of the callback not remembering when it last
wrote to the logs.

* Removed the error check.

* Used update_freq instead of write_step.

* Forgot to change the constructor call.
上级 ae6474df
......@@ -719,6 +719,12 @@ class TensorBoard(Callback):
input) or list of Numpy arrays (if the model has multiple inputs).
Learn [more about embeddings]
(https://www.tensorflow.org/programmers_guide/embedding).
update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`, writes
the losses and metrics to TensorBoard after each batch. The same
applies for `'epoch'`. If using an integer, let's say `10000`,
the callback will write the metrics and losses to TensorBoard every
10000 samples. Note that writing too frequently to TensorBoard
can slow down your training.
"""
def __init__(self, log_dir='./logs',
......@@ -730,7 +736,8 @@ class TensorBoard(Callback):
embeddings_freq=0,
embeddings_layer_names=None,
embeddings_metadata=None,
embeddings_data=None):
embeddings_data=None,
update_freq='epoch'):
super(TensorBoard, self).__init__()
global tf, projector
try:
......@@ -769,6 +776,13 @@ class TensorBoard(Callback):
self.embeddings_metadata = embeddings_metadata or {}
self.batch_size = batch_size
self.embeddings_data = embeddings_data
if update_freq == 'batch':
# It is the same as writing as frequently as possible.
self.update_freq = 1
else:
self.update_freq = update_freq
self.samples_seen = 0
self.samples_seen_at_last_write = 0
def set_model(self, model):
self.model = model
......@@ -968,6 +982,13 @@ class TensorBoard(Callback):
i += self.batch_size
if self.update_freq == 'epoch':
index = epoch
else:
index = self.samples_seen
self._write_logs(logs, index)
def _write_logs(self, logs, index):
for name, value in logs.items():
if name in ['batch', 'size']:
continue
......@@ -978,12 +999,20 @@ class TensorBoard(Callback):
else:
summary_value.simple_value = value
summary_value.tag = name
self.writer.add_summary(summary, epoch)
self.writer.add_summary(summary, index)
self.writer.flush()
def on_train_end(self, _):
    """Release the underlying summary writer once training has finished."""
    writer = self.writer
    writer.close()
def on_batch_end(self, batch, logs=None):
    """Write the batch metrics to TensorBoard if enough samples were seen.

    Only active when `update_freq` is an integer (a sample count);
    when `update_freq == 'epoch'` the writing is done in `on_epoch_end`
    instead, so this method is a no-op.

    # Arguments
        batch: Integer, index of the current batch within the epoch.
        logs: Dict, metrics for this batch. The batch sample count is
            read from the `'size'` key (0 is assumed when missing).
    """
    if self.update_freq != 'epoch':
        # `logs` defaults to None; guard before subscripting so a caller
        # that passes no logs does not crash with a TypeError.
        logs = logs or {}
        self.samples_seen += logs.get('size', 0)
        samples_seen_since = (
            self.samples_seen - self.samples_seen_at_last_write)
        if samples_seen_since >= self.update_freq:
            self._write_logs(logs, self.samples_seen)
            # Remember the write position so the next write happens only
            # after another `update_freq` samples.
            self.samples_seen_at_last_write = self.samples_seen
class ReduceLROnPlateau(Callback):
"""Reduce learning rate when a metric has stopped improving.
......
......@@ -550,7 +550,8 @@ def test_CSVLogger(tmpdir):
assert not tmpdir.listdir()
def test_TensorBoard(tmpdir):
@pytest.mark.parametrize('update_freq', ['batch', 'epoch', 9])
def test_TensorBoard(tmpdir, update_freq):
np.random.seed(np.random.randint(1, 1e7))
filepath = str(tmpdir / 'logs')
......@@ -588,7 +589,8 @@ def test_TensorBoard(tmpdir):
embeddings_freq=embeddings_freq,
embeddings_layer_names=['dense_1'],
embeddings_data=X_test,
batch_size=5)]
batch_size=5,
update_freq=update_freq)]
# fit without validation data
model.fit(X_train, y_train, batch_size=batch_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册