From ecbf73f72b59f8f5c8746de63270aa1fb3ad7524 Mon Sep 17 00:00:00 2001
From: Gabriel de Marmiesse
Date: Mon, 1 Oct 2018 21:00:01 +0200
Subject: [PATCH] [RELNOTES] [P] Write to TensorBoard every x samples. (#11152)

* Working on improving TensorFlow callbacks

* Added batch-level TensorBoard logging (implemented the `on_batch_end`
  method on the `TensorBoard` class)

* Interim commit -- added notes.

* Corrected stylistic issues -- brought into compliance with PEP8.

* Added the missing argument in the test suite.

* Added the ability to choose how frequently TensorBoard should log the
  metrics and losses.

* Fixed the issue of the validation data not being displayed.

* Fixed the callback not remembering when it last wrote to the logs.

* Removed the error check.

* Used `update_freq` instead of `write_step`.

* Forgot to change the constructor call.
---
 keras/callbacks.py            | 33 +++++++++++++++++++++++++++++++--
 tests/keras/test_callbacks.py |  6 ++++--
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/keras/callbacks.py b/keras/callbacks.py
index c695fbea0..858b0dba9 100644
--- a/keras/callbacks.py
+++ b/keras/callbacks.py
@@ -719,6 +719,12 @@ class TensorBoard(Callback):
             input) or list of Numpy arrays (if the model has multiple
             inputs). Learn [more about embeddings]
             (https://www.tensorflow.org/programmers_guide/embedding).
+        update_freq: `'batch'`, `'epoch'`, or integer. When using `'batch'`,
+            writes the losses and metrics to TensorBoard after each batch;
+            the same applies for `'epoch'`. If using an integer, say `10000`,
+            the callback writes the metrics and losses to TensorBoard every
+            10000 samples. Note that writing too frequently to TensorBoard
+            can slow down your training.
     """

     def __init__(self, log_dir='./logs',
@@ -730,7 +736,8 @@ class TensorBoard(Callback):
                  embeddings_freq=0,
                  embeddings_layer_names=None,
                  embeddings_metadata=None,
-                 embeddings_data=None):
+                 embeddings_data=None,
+                 update_freq='epoch'):
         super(TensorBoard, self).__init__()
         global tf, projector
         try:
@@ -769,6 +776,13 @@ class TensorBoard(Callback):
         self.embeddings_metadata = embeddings_metadata or {}
         self.batch_size = batch_size
         self.embeddings_data = embeddings_data
+        if update_freq == 'batch':
+            # It is the same as writing as frequently as possible.
+            self.update_freq = 1
+        else:
+            self.update_freq = update_freq
+        self.samples_seen = 0
+        self.samples_seen_at_last_write = 0

     def set_model(self, model):
         self.model = model
@@ -968,6 +982,13 @@ class TensorBoard(Callback):

                 i += self.batch_size

+        if self.update_freq == 'epoch':
+            index = epoch
+        else:
+            index = self.samples_seen
+        self._write_logs(logs, index)
+
+    def _write_logs(self, logs, index):
         for name, value in logs.items():
             if name in ['batch', 'size']:
                 continue
@@ -978,12 +999,20 @@ class TensorBoard(Callback):
             else:
                 summary_value.simple_value = value
             summary_value.tag = name
-            self.writer.add_summary(summary, epoch)
+            self.writer.add_summary(summary, index)
         self.writer.flush()

     def on_train_end(self, _):
         self.writer.close()

+    def on_batch_end(self, batch, logs=None):
+        if self.update_freq != 'epoch':
+            self.samples_seen += logs['size']
+            samples_seen_since = self.samples_seen - self.samples_seen_at_last_write
+            if samples_seen_since >= self.update_freq:
+                self._write_logs(logs, self.samples_seen)
+                self.samples_seen_at_last_write = self.samples_seen
+

 class ReduceLROnPlateau(Callback):
     """Reduce learning rate when a metric has stopped improving.
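Before the test changes, a quick illustration of the write cadence implemented above. This sketch reproduces only the sample-counting bookkeeping of `on_batch_end`; the `FrequencyTracker` class and the batch size of 4096 are hypothetical stand-ins, not part of the patch.

# Hypothetical sketch (not part of the patch): reproduces only the
# sample-counting bookkeeping of TensorBoard.on_batch_end above.
class FrequencyTracker(object):
    def __init__(self, update_freq='epoch'):
        # 'batch' is equivalent to writing as often as possible.
        self.update_freq = 1 if update_freq == 'batch' else update_freq
        self.samples_seen = 0
        self.samples_seen_at_last_write = 0

    def on_batch_end(self, batch_size):
        # Returns True whenever a write to TensorBoard would occur.
        if self.update_freq == 'epoch':
            return False  # 'epoch' mode only writes from on_epoch_end
        self.samples_seen += batch_size
        if self.samples_seen - self.samples_seen_at_last_write >= self.update_freq:
            self.samples_seen_at_last_write = self.samples_seen
            return True
        return False

tracker = FrequencyTracker(update_freq=10000)
for step in range(10):
    if tracker.on_batch_end(batch_size=4096):
        # Fires after 12288, 24576 and 36864 samples: every third batch,
        # since three batches of 4096 are the first to cover 10000 new samples.
        print('write at sample', tracker.samples_seen)

The tests diff follows.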
diff --git a/tests/keras/test_callbacks.py b/tests/keras/test_callbacks.py
index 04c487ac1..c8e71df9c 100644
--- a/tests/keras/test_callbacks.py
+++ b/tests/keras/test_callbacks.py
@@ -550,7 +550,8 @@ def test_CSVLogger(tmpdir):
     assert not tmpdir.listdir()


-def test_TensorBoard(tmpdir):
+@pytest.mark.parametrize('update_freq', ['batch', 'epoch', 9])
+def test_TensorBoard(tmpdir, update_freq):
     np.random.seed(np.random.randint(1, 1e7))
     filepath = str(tmpdir / 'logs')

@@ -588,7 +589,8 @@ def test_TensorBoard(tmpdir):
         embeddings_freq=embeddings_freq,
         embeddings_layer_names=['dense_1'],
         embeddings_data=X_test,
-        batch_size=5)]
+        batch_size=5,
+        update_freq=update_freq)]

     # fit without validation data
     model.fit(X_train, y_train, batch_size=batch_size,
--
GitLab
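For context, a minimal usage sketch of the new `update_freq` argument, assuming Keras with the TensorFlow backend and this patch applied; the model, data, and log directory are placeholders, not taken from the patch.

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import TensorBoard

# Placeholder data and model, just to exercise the callback.
x = np.random.random((1000, 20))
y = np.random.randint(2, size=(1000, 1))

model = Sequential([Dense(16, activation='relu', input_dim=20),
                    Dense(1, activation='sigmoid')])
model.compile(optimizer='rmsprop', loss='binary_crossentropy')

# update_freq=500 writes losses/metrics to TensorBoard every 500 samples,
# instead of once per epoch ('epoch', the default) or after every batch ('batch').
tensorboard = TensorBoard(log_dir='./logs', update_freq=500)
model.fit(x, y, batch_size=32, epochs=5, callbacks=[tensorboard])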