diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 207882e88ff46ba251cbbb0fab2984598aa408dc..fd77ad278b5a1d413b496290c0a81b157b4a2336 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -789,29 +789,6 @@ py_library(
     ],
 )
 
-py_test(
-    name = "multi_worker_tutorial_test",
-    srcs = ["multi_worker_tutorial_test.py"],
-    python_version = "PY3",
-    shard_count = 5,
-    tags = [
-        "noasan",  # TODO(b/156029134)
-        "nomsan",  # TODO(b/156029134)
-        "notap",  # TODO(b/165865820): restore when not flaky
-        "notsan",  # TODO(b/156029134)
-    ],
-    deps = [
-        "//tensorflow/python:platform",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/distribute:collective_all_reduce_strategy",
-        "//tensorflow/python/distribute:combinations",
-        "//tensorflow/python/distribute:multi_process_runner",
-        "//tensorflow/python/distribute:multi_worker_test_base",
-        "//tensorflow/python/keras",
-        "//tensorflow/python/keras/optimizer_v2",
-    ],
-)
-
 distribute_py_test(
     name = "saved_model_save_load_test",
     size = "medium",
diff --git a/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py b/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py
deleted file mode 100644
index 14e502a9fdbe4eb43b39b64ece7e334c47c16bf3..0000000000000000000000000000000000000000
--- a/tensorflow/python/keras/distribute/multi_worker_tutorial_test.py
+++ /dev/null
@@ -1,228 +0,0 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test for multi-worker training tutorial."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import contextlib
-import os
-import re
-import zipfile
-
-from absl import logging
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.python import keras
-from tensorflow.python.data.experimental.ops import distribute_options
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.distribute import collective_all_reduce_strategy
-from tensorflow.python.distribute import combinations as ds_combinations
-from tensorflow.python.distribute import multi_process_runner
-from tensorflow.python.distribute import multi_worker_test_base
-from tensorflow.python.framework import errors_impl
-from tensorflow.python.framework import test_combinations as combinations
-from tensorflow.python.keras.datasets import mnist
-from tensorflow.python.keras.optimizer_v2 import gradient_descent
-from tensorflow.python.lib.io import file_io
-from tensorflow.python.platform import test
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training.tracking import util as tracking_util
-from tensorflow.python.util import nest
-
-
-class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
-  """Test multi-worker training flow demo'ed in go/multi-worker-with-keras."""
-
-  @contextlib.contextmanager
-  def skip_fetch_failure_exception(self):
-    try:
-      yield
-    except zipfile.BadZipfile as e:
-      self.skipTest('Data loading error: Bad magic number for file header.')
-    except Exception as e:  # pylint: disable=broad-except
-      if 'URL fetch failure' in str(e):
-        self.skipTest('URL fetch error not considered failure of the test.')
-      else:
-        raise
-
-  @ds_combinations.generate(
-      combinations.combine(
-          mode=['eager'],
-          shard_policy=[None] + list(distribute_options.AutoShardPolicy)))
-  def testMultiWorkerTutorial(self, mode, shard_policy):
-    """Test multi-worker training flow demo'ed in go/multi-worker-with-keras.
-
-    This test should be kept in sync with the code samples in
-    go/multi-worker-with-keras.
-
-    Args:
-      mode: Runtime mode.
-      shard_policy: None or any of tf.data.experimental.AutoShardPolicy for
-        testing.
-    """
-    if shard_policy is distribute_options.AutoShardPolicy.FILE:
-      self.skipTest('TensorSliceDataset is not shardable with FILE policy.')
-
-    def mnist_dataset(batch_size):
-      with self.skip_fetch_failure_exception():
-        (x_train, y_train), _ = mnist.load_data()
-      # The `x` arrays are in uint8 and have values in the range [0, 255].
-      # We need to convert them to float32 with values in the range [0, 1]
-      x_train = x_train / np.float32(255)
-      y_train = y_train.astype(np.int64)
-      train_dataset = dataset_ops.DatasetV2.from_tensor_slices(
-          (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
-      return train_dataset
-
-    def build_and_compile_cnn_model():
-      model = keras.Sequential([
-          keras.layers.Input(shape=(28, 28)),
-          keras.layers.Reshape(target_shape=(28, 28, 1)),
-          keras.layers.Conv2D(32, 3, activation='relu'),
-          keras.layers.Flatten(),
-          keras.layers.Dense(128, activation='relu'),
-          keras.layers.Dense(10)
-      ])
-      model.compile(
-          loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-          optimizer=gradient_descent.SGD(learning_rate=0.001),
-          metrics=['accuracy'])
-      return model
-
-    per_worker_batch_size = 64
-
-    single_worker_dataset = mnist_dataset(per_worker_batch_size)
-    single_worker_model = build_and_compile_cnn_model()
-    single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)
-
-    num_workers = 4
-
-    def fn(model_path, checkpoint_dir):
-      global_batch_size = per_worker_batch_size * num_workers
-      strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
-      with strategy.scope():
-        multi_worker_model = build_and_compile_cnn_model()
-
-      callbacks = [
-          keras.callbacks.ModelCheckpoint(
-              filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
-      ]
-
-      multi_worker_dataset = mnist_dataset(global_batch_size)
-      if shard_policy:
-        options = dataset_ops.Options()
-        options.experimental_distribute.auto_shard_policy = shard_policy
-        multi_worker_dataset = multi_worker_dataset.with_options(options)
-
-      multi_worker_model.fit(
-          multi_worker_dataset,
-          epochs=2,
-          steps_per_epoch=20,
-          callbacks=callbacks)
-
-      def _is_chief(task_type, task_id):
-        return task_type is None or task_type == 'chief' or (
-            task_type == 'worker' and task_id == 0)
-
-      def _get_temp_dir(dirpath, task_id):
-        base_dirpath = 'workertemp_' + str(task_id)
-        temp_dir = os.path.join(dirpath, base_dirpath)
-        file_io.recursive_create_dir_v2(temp_dir)
-        return temp_dir
-
-      def write_filepath(filepath, task_type, task_id):
-        dirpath = os.path.dirname(filepath)
-        base = os.path.basename(filepath)
-        if not _is_chief(task_type, task_id):
-          dirpath = _get_temp_dir(dirpath, task_id)
-        return os.path.join(dirpath, base)
-
-      task_type, task_id = (strategy.cluster_resolver.task_type,
-                            strategy.cluster_resolver.task_id)
-      write_model_path = write_filepath(model_path, task_type, task_id)
-
-      multi_worker_model.save(write_model_path)
-      if not _is_chief(task_type, task_id):
-        file_io.delete_recursively_v2(os.path.dirname(write_model_path))
-
-      # Make sure chief finishes saving before non-chief's assertions.
-      multi_process_runner.get_barrier().wait()
-
-      if not file_io.file_exists_v2(model_path):
-        raise RuntimeError()
-      if file_io.file_exists_v2(write_model_path) != _is_chief(
-          task_type, task_id):
-        raise RuntimeError()
-
-      loaded_model = keras.saving.save.load_model(model_path)
-      loaded_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)
-
-      checkpoint = tracking_util.Checkpoint(model=multi_worker_model)
-      write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
-      checkpoint_manager = checkpoint_management.CheckpointManager(
-          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
-
-      checkpoint_manager.save()
-      if not _is_chief(task_type, task_id):
-        file_io.delete_recursively_v2(write_checkpoint_dir)
-
-      # Make sure chief finishes saving before non-chief's assertions.
-      multi_process_runner.get_barrier().wait()
-
-      if not file_io.file_exists_v2(checkpoint_dir):
-        raise RuntimeError()
-      if file_io.file_exists_v2(write_checkpoint_dir) != _is_chief(
-          task_type, task_id):
-        raise RuntimeError()
-
-      latest_checkpoint = checkpoint_management.latest_checkpoint(
-          checkpoint_dir)
-      checkpoint.restore(latest_checkpoint)
-      multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)
-
-      logging.info('testMultiWorkerTutorial successfully ends')
-
-    model_path = os.path.join(self.get_temp_dir(), 'model.tf')
-    checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
-    try:
-      mpr_result = multi_process_runner.run(
-          fn,
-          multi_worker_test_base.create_cluster_spec(num_workers=num_workers),
-          args=(model_path, checkpoint_dir),
-          return_output=True)
-    except errors_impl.UnavailableError as e:
-      self.skipTest('Skipping error: {}: {}'.format(type(e), str(e)))
-
-    self.assertTrue(
-        any([
-            'testMultiWorkerTutorial successfully ends' in msg
-            for msg in mpr_result.stdout
-        ]))
-
-    def extract_accuracy(worker_id, input_string):
-      match = re.match(
-          r'\[worker\-{}\].*accuracy: (\d+\.\d+).*'.format(worker_id),
-          input_string)
-      return None if match is None else float(match.group(1))
-
-    for worker_id in range(num_workers):
-      accu_result = nest.map_structure(
-          lambda x: extract_accuracy(worker_id, x),  # pylint: disable=cell-var-from-loop
-          mpr_result.stdout)
-      self.assertTrue(
-          any(accu_result), 'Every worker is supposed to have accuracy result.')
-
-
-if __name__ == '__main__':
-  multi_process_runner.test_main()
diff --git a/tensorflow/python/keras/integration_test/BUILD b/tensorflow/python/keras/integration_test/BUILD
index 3b4db66ab55ccc3c10bc800b93b8e9537a2a44e5..164b731d56bf3224539d3dce7605bc5df0c0ba2c 100644
--- a/tensorflow/python/keras/integration_test/BUILD
+++ b/tensorflow/python/keras/integration_test/BUILD
@@ -103,3 +103,18 @@ tpu_py_test(
         "//tensorflow/python:extra_py_tests_deps",
     ],
 )
+
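+# The test exercises only the public `tf.*` API (plus `tf.__internal__` test
+# utilities), so the single public `//tensorflow:tensorflow_py` dependency
+# replaces the per-package internal deps the old `py_test` target required.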
+tf_py_test(
+    name = "multi_worker_tutorial_test",
+    srcs = ["multi_worker_tutorial_test.py"],
+    python_version = "PY3",
+    shard_count = 3,
+    tags = [
+        "noasan",  # TODO(b/156029134)
+        "nomsan",  # TODO(b/156029134)
+        "notsan",  # TODO(b/156029134)
+    ],
+    deps = [
+        "//tensorflow:tensorflow_py",
+    ],
+)
diff --git a/tensorflow/python/keras/integration_test/multi_worker_tutorial_test.py b/tensorflow/python/keras/integration_test/multi_worker_tutorial_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..20ba4d79af207380fe39687f04436938d5386997
--- /dev/null
+++ b/tensorflow/python/keras/integration_test/multi_worker_tutorial_test.py
@@ -0,0 +1,346 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for multi-worker training tutorial."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import os
+import re
+import unittest
+import uuid
+import zipfile
+
+from absl import logging
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf
+
+PER_WORKER_BATCH_SIZE = 64
+NUM_WORKERS = 2
+NUM_EPOCHS = 2
+NUM_STEPS_PER_EPOCH = 50
+
+
+def _is_chief(task_type, task_id):
+  return task_type is None or task_type == 'chief' or (
+      task_type == 'worker' and task_id == 0)
+
+
+def _get_temp_dir(dirpath, task_id):
+  base_dirpath = 'workertemp_' + str(task_id)
+  temp_dir = os.path.join(dirpath, base_dirpath)
+  tf.io.gfile.makedirs(temp_dir)
+  return temp_dir
+
+
+def write_filepath(filepath, task_type, task_id):
+  dirpath = os.path.dirname(filepath)
+  base = os.path.basename(filepath)
+  if not _is_chief(task_type, task_id):
+    dirpath = _get_temp_dir(dirpath, task_id)
+  return os.path.join(dirpath, base)
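+
+# Saving under MultiWorkerMirroredStrategy is a collective operation: every
+# worker participates in `model.save()`, but only the chief's copy should be
+# kept. `write_filepath` therefore redirects non-chief workers to a per-task
+# scratch directory, which they delete once the chief has finished saving.
+# Illustrative values (not produced by this test):
+#   write_filepath('/tmp/model.tf', 'worker', 0) -> '/tmp/model.tf'
+#   write_filepath('/tmp/model.tf', 'worker', 1) -> '/tmp/workertemp_1/model.tf'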
+
+
+class MultiWorkerTutorialTest(parameterized.TestCase, tf.test.TestCase):
+  """Tests of multi-worker training flows from tutorials on tensorflow.org.
+
+  See each test method's docstring for the tutorial it covers.
+  """
+
+  # TODO(rchao): Add a test to demonstrate gather with MWMS.
+
+  @contextlib.contextmanager
+  def skip_fetch_failure_exception(self):
+    try:
+      yield
+    except zipfile.BadZipfile:
+      # There can be a race when multiple processes are downloading the data.
+      # Skip the test if that results in loading errors.
+      self.skipTest('Data loading error: Bad magic number for file header.')
+    except Exception as e:  # pylint: disable=broad-except
+      if 'URL fetch failure' in str(e):
+        self.skipTest('URL fetch error not considered failure of the test.')
+      else:
+        raise
+
+  def mnist_dataset(self):
+    path_to_use = 'mnist_{}.npz'.format(str(uuid.uuid4()))
+    with self.skip_fetch_failure_exception():
+      (x_train,
+       y_train), _ = tf.keras.datasets.mnist.load_data(path=path_to_use)
+    # The `x` arrays are in uint8 and have values in the range [0, 255].
+    # We need to convert them to float32 with values in the range [0, 1].
+    x_train = x_train / np.float32(255)
+    y_train = y_train.astype(np.int64)
+    train_dataset = tf.data.Dataset.from_tensor_slices(
+        (x_train, y_train)).shuffle(60000)
+    return train_dataset
+
+  def dataset_fn(self, global_batch_size, input_context):
+    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
+    dataset = self.mnist_dataset()
+    dataset = dataset.shard(input_context.num_input_pipelines,
+                            input_context.input_pipeline_id)
+    dataset = dataset.batch(batch_size)
+    return dataset
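+
+  # `dataset_fn` follows the manual-sharding pattern from the tutorial: each
+  # input pipeline (here, one per worker) reads its own shard of the data and
+  # batches it by the per-replica batch size. Illustrative numbers, assuming
+  # NUM_WORKERS = 2, one replica per worker, and a global batch size of 128:
+  #   input_context.get_per_replica_batch_size(128)  -> 64
+  #   worker 0 reads dataset.shard(2, 0); worker 1 reads dataset.shard(2, 1)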
+
+  def build_cnn_model(self):
+    return tf.keras.Sequential([
+        tf.keras.layers.Input(shape=(28, 28)),
+        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
+        tf.keras.layers.Conv2D(32, 3, activation='relu'),
+        tf.keras.layers.Flatten(),
+        tf.keras.layers.Dense(128, activation='relu'),
+        tf.keras.layers.Dense(10)
+    ])
+
+  def build_and_compile_cnn_model(self):
+    model = self.build_cnn_model()
+    model.compile(
+        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
+        metrics=['accuracy'])
+    return model
+
+  @tf.__internal__.test.combinations.generate(
+      tf.__internal__.test.combinations.combine(
+          mode=['eager'], tf_api_version=2))
+  def testSingleWorkerModelFit(self):
+    single_worker_dataset = self.mnist_dataset().batch(PER_WORKER_BATCH_SIZE)
+    single_worker_model = self.build_and_compile_cnn_model()
+    single_worker_model.fit(single_worker_dataset, epochs=NUM_EPOCHS)
+
+  @tf.__internal__.test.combinations.generate(
+      tf.__internal__.test.combinations.combine(
+          mode=['eager'], tf_api_version=2))
+  def testMwmsWithModelFit(self, mode):
+    """Test multi-worker training flow demoed in go/multi-worker-with-keras.
+
+    This test should be kept in sync with the code samples in
+    go/multi-worker-with-keras.
+
+    Args:
+      mode: Runtime mode.
+    """
+
+    def fn(model_path, checkpoint_dir):
+      global_batch_size = PER_WORKER_BATCH_SIZE * NUM_WORKERS
+      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
+      with strategy.scope():
+        multi_worker_model = self.build_and_compile_cnn_model()
+
+      callbacks = [
+          tf.keras.callbacks.ModelCheckpoint(
+              filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
+      ]
+
+      multi_worker_dataset = strategy.distribute_datasets_from_function(
+          lambda input_context: self.dataset_fn(global_batch_size,
+                                                input_context))
+
+      multi_worker_model.fit(
+          multi_worker_dataset,
+          epochs=NUM_EPOCHS,
+          steps_per_epoch=NUM_STEPS_PER_EPOCH,
+          callbacks=callbacks)
+
+      task_type, task_id = (strategy.cluster_resolver.task_type,
+                            strategy.cluster_resolver.task_id)
+      write_model_path = write_filepath(model_path, task_type, task_id)
+
+      multi_worker_model.save(write_model_path)
+      if not _is_chief(task_type, task_id):
+        tf.io.gfile.rmtree(os.path.dirname(write_model_path))
+
+      # Make sure chief finishes saving before non-chief's assertions.
+      tf.__internal__.distribute.multi_process_runner.get_barrier().wait()
+
+      if not tf.io.gfile.exists(model_path):
+        raise RuntimeError('Model path %r should exist after saving.' %
+                           model_path)
+      if tf.io.gfile.exists(write_model_path) != _is_chief(task_type, task_id):
+        raise RuntimeError('Only the chief should retain its saved model.')
+
+      with strategy.scope():
+        loaded_model = tf.keras.models.load_model(model_path)
+        loaded_model.fit(multi_worker_dataset, epochs=1, steps_per_epoch=1)
+
+      checkpoint = tf.train.Checkpoint(model=multi_worker_model)
+      write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
+      checkpoint_manager = tf.train.CheckpointManager(
+          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
+
+      checkpoint_manager.save()
+      if not _is_chief(task_type, task_id):
+        tf.io.gfile.rmtree(write_checkpoint_dir)
+
+      # Make sure chief finishes saving before non-chief's assertions.
+      tf.__internal__.distribute.multi_process_runner.get_barrier().wait()
+
+      if not tf.io.gfile.exists(checkpoint_dir):
+        raise RuntimeError('Checkpoint directory %r should exist after '
+                           'saving.' % checkpoint_dir)
+      if tf.io.gfile.exists(write_checkpoint_dir) != _is_chief(
+          task_type, task_id):
+        raise RuntimeError('Only the chief should retain its checkpoint.')
+
+      latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
+      checkpoint.restore(latest_checkpoint)
+      multi_worker_model.fit(multi_worker_dataset, epochs=1, steps_per_epoch=1)
+
+      logging.info('testMwmsWithModelFit successfully ends')
+
+    model_path = os.path.join(self.get_temp_dir(), 'model.tf')
+    checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
+    try:
+      mpr_result = tf.__internal__.distribute.multi_process_runner.run(
+          fn,
+          tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
+              num_workers=NUM_WORKERS),
+          args=(model_path, checkpoint_dir),
+          return_output=True)
+    except tf.errors.UnavailableError:
+      self.skipTest('Skipping rare disconnection among the workers.')
+
+    self.assertTrue(
+        any('testMwmsWithModelFit successfully ends' in msg
+            for msg in mpr_result.stdout))
+
+    def extract_accuracy(worker_id, input_string):
+      match = re.match(
+          r'\[worker\-{}\].*accuracy: (\d+\.\d+).*'.format(worker_id),
+          input_string)
+      return None if match is None else float(match.group(1))
+
+    for worker_id in range(NUM_WORKERS):
+      accu_result = tf.nest.map_structure(
+          lambda x: extract_accuracy(worker_id, x),  # pylint: disable=cell-var-from-loop
+          mpr_result.stdout)
+      self.assertTrue(
+          any(accu_result), 'Every worker is supposed to have accuracy result.')
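+
+  # The custom-training-loop (CTL) test below computes per-example losses with
+  # Reduction.NONE and averages them with
+  # tf.nn.compute_average_loss(..., global_batch_size=...). Scaling by the
+  # global batch size matters because MWMS sums gradients across replicas
+  # during all-reduce; dividing by the per-replica batch size instead would
+  # scale the effective learning rate by NUM_WORKERS. Roughly:
+  #   per_example_loss = loss_fn(y, predictions)  # shape: [per_replica_batch]
+  #   loss = tf.reduce_sum(per_example_loss) / global_batch_size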
+
+  @tf.__internal__.test.combinations.generate(
+      tf.__internal__.test.combinations.combine(
+          mode=['eager'], tf_api_version=2))
+  def testMwmsWithCtl(self, mode):
+    """Test multi-worker CTL training flow demoed in a to-be-added tutorial."""
+
+    def proc_func(checkpoint_dir):
+      global_batch_size = PER_WORKER_BATCH_SIZE * NUM_WORKERS
+      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
+      try:
+        with strategy.scope():
+          multi_worker_model = self.build_cnn_model()
+
+        multi_worker_dataset = strategy.distribute_datasets_from_function(
+            lambda input_context: self.dataset_fn(global_batch_size,  # pylint: disable=g-long-lambda
+                                                  input_context))
+        optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
+        train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
+            name='train_accuracy')
+
+        @tf.function
+        def train_step(iterator):
+          """Training step function."""
+
+          def step_fn(inputs):
+            """Per-replica step function."""
+            x, y = inputs
+            with tf.GradientTape() as tape:
+              predictions = multi_worker_model(x, training=True)
+              per_example_loss = tf.keras.losses.SparseCategoricalCrossentropy(
+                  from_logits=True,
+                  reduction=tf.keras.losses.Reduction.NONE)(y, predictions)
+              loss = tf.nn.compute_average_loss(
+                  per_example_loss, global_batch_size=global_batch_size)
+
+            grads = tape.gradient(loss, multi_worker_model.trainable_variables)
+            optimizer.apply_gradients(
+                zip(grads, multi_worker_model.trainable_variables))
+            train_accuracy.update_state(y, predictions)
+
+            return loss
+
+          per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
+          return strategy.reduce(
+              tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
+
+        epoch = tf.Variable(
+            initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
+        step_in_epoch = tf.Variable(
+            initial_value=tf.constant(0, dtype=tf.dtypes.int64),
+            name='step_in_epoch')
+
+        task_type, task_id = (strategy.cluster_resolver.task_type,
+                              strategy.cluster_resolver.task_id)
+        checkpoint = tf.train.Checkpoint(
+            model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)
+        write_checkpoint_dir = write_filepath(checkpoint_dir, task_type,
+                                              task_id)
+        checkpoint_manager = tf.train.CheckpointManager(
+            checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
+
+        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
+        if latest_checkpoint:
+          checkpoint.restore(latest_checkpoint)
+
+        while epoch.numpy() < NUM_EPOCHS:
+          iterator = iter(multi_worker_dataset)
+          total_loss = 0.0
+          num_batches = 0
+
+          while step_in_epoch.numpy() < NUM_STEPS_PER_EPOCH:
+            total_loss += train_step(iterator)
+            num_batches += 1
+            step_in_epoch.assign_add(1)
+
+          train_loss = total_loss / num_batches
+          logging.info('Epoch: %d, accuracy: %f, train_loss: %f.',
+                       epoch.numpy(), train_accuracy.result(), train_loss)
+
+          train_accuracy.reset_states()
+
+          checkpoint_manager.save()
+          if not _is_chief(task_type, task_id):
+            tf.io.gfile.rmtree(write_checkpoint_dir)
+
+          epoch.assign_add(1)
+          step_in_epoch.assign(0)
+
+      except tf.errors.UnavailableError as e:
+        logging.info('UnavailableError occurred: %r', e)
+        raise unittest.SkipTest('Skipping test due to UnavailableError')
+
+      logging.info('testMwmsWithCtl successfully ends')
+
+    checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
+
+    mpr_result = tf.__internal__.distribute.multi_process_runner.run(
+        proc_func,
+        tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
+            num_workers=NUM_WORKERS),
+        return_output=True,
+        args=(checkpoint_dir,))
+
+    self.assertTrue(
+        any('testMwmsWithCtl successfully ends' in msg
+            for msg in mpr_result.stdout))
+
+
+if __name__ == '__main__':
+  tf.__internal__.distribute.multi_process_runner.test_main()
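+
+# `multi_process_runner.test_main()` is used instead of `tf.test.main()`: it
+# performs the process setup that `create_cluster_spec` and `run` above rely
+# on to give each simulated worker its own child process, emulating a real
+# multi-worker cluster within a single test binary.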