Commit 64edb2fb authored by Rick Chao, committed by TensorFlower Gardener

Multi-worker tutorial: Add the workflow of MWMS+CTL example that is going to be added in the tutorial in multi_worker_tutorial_test.

Fix the flakiness of the test and re-enable in TAP.

PiperOrigin-RevId: 339966024
Change-Id: Icb866f8a7054fa88f2e474c02960982a57c542b3
Parent 3d1f1b06
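For orientation, the MWMS + custom training loop (CTL) workflow exercised by the new `testMwmsWithCtl` test below has roughly the following shape. This is only an illustrative sketch, not code from this diff: `build_model`, `make_dataset_fn`, the optimizer, the hyperparameters, and the fixed step count are placeholder choices, and it assumes `TF_CONFIG` is already set on every worker process.

# Minimal MultiWorkerMirroredStrategy + custom training loop sketch.
# Assumes TF_CONFIG is set on each worker; names below are illustrative.
import numpy as np
import tensorflow as tf


def build_model():
  return tf.keras.Sequential([
      tf.keras.layers.Input(shape=(28, 28)),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10),
  ])


def make_dataset_fn(global_batch_size):
  def dataset_fn(input_context):
    # Each worker derives its per-replica batch size and shards the data by
    # input pipeline id.
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    (x, y), _ = tf.keras.datasets.mnist.load_data()
    ds = tf.data.Dataset.from_tensor_slices(
        (x / np.float32(255), y.astype(np.int64)))
    ds = ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
    return ds.repeat().batch(batch_size)
  return dataset_fn


strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
global_batch_size = 64 * strategy.num_replicas_in_sync

# Model and optimizer variables must be created under the strategy scope.
with strategy.scope():
  model = build_model()
  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)

multi_worker_dataset = strategy.distribute_datasets_from_function(
    make_dataset_fn(global_batch_size))
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)


@tf.function
def train_step(iterator):
  def step_fn(inputs):
    x, y = inputs
    with tf.GradientTape() as tape:
      logits = model(x, training=True)
      # Average the per-example loss over the global batch so gradients are
      # scaled consistently across replicas.
      loss = tf.nn.compute_average_loss(
          loss_object(y, logits), global_batch_size=global_batch_size)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)


iterator = iter(multi_worker_dataset)
for _ in range(100):  # Placeholder number of training steps.
  train_step(iterator)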
@@ -789,29 +789,6 @@ py_library(
],
)
py_test(
name = "multi_worker_tutorial_test",
srcs = ["multi_worker_tutorial_test.py"],
python_version = "PY3",
shard_count = 5,
tags = [
"noasan", # TODO(b/156029134)
"nomsan", # TODO(b/156029134)
"notap", # TODO(b/165865820): restore when not flaky
"notsan", # TODO(b/156029134)
],
deps = [
"//tensorflow/python:platform",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:collective_all_reduce_strategy",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:multi_process_runner",
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/keras",
"//tensorflow/python/keras/optimizer_v2",
],
)
distribute_py_test(
name = "saved_model_save_load_test",
size = "medium",
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for multi-worker training tutorial."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import os
import re
import zipfile
from absl import logging
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.datasets import mnist
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import test
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as tracking_util
from tensorflow.python.util import nest
class MultiWorkerTutorialTest(parameterized.TestCase, test.TestCase):
"""Test multi-worker training flow demo'ed in go/multi-worker-with-keras."""
@contextlib.contextmanager
def skip_fetch_failure_exception(self):
try:
yield
except zipfile.BadZipfile as e:
self.skipTest('Data loading error: Bad magic number for file header.')
except Exception as e: # pylint: disable=broad-except
if 'URL fetch failure' in str(e):
self.skipTest('URL fetch error not considered failure of the test.')
else:
raise
@ds_combinations.generate(
combinations.combine(
mode=['eager'],
shard_policy=[None] + list(distribute_options.AutoShardPolicy)))
def testMultiWorkerTutorial(self, mode, shard_policy):
"""Test multi-worker training flow demo'ed in go/multi-worker-with-keras.
This test should be kept in sync with the code samples in
go/multi-worker-with-keras.
Args:
mode: Runtime mode.
shard_policy: None or any of tf.data.experimental.AutoShardPolicy for
testing.
"""
if shard_policy is distribute_options.AutoShardPolicy.FILE:
self.skipTest('TensorSliceDataset is not shardable with FILE policy.')
def mnist_dataset(batch_size):
with self.skip_fetch_failure_exception():
(x_train, y_train), _ = mnist.load_data()
# The `x` arrays are in uint8 and have values in the range [0, 255].
# We need to convert them to float32 with values in the range [0, 1]
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset = dataset_ops.DatasetV2.from_tensor_slices(
(x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
return train_dataset
def build_and_compile_cnn_model():
model = keras.Sequential([
keras.layers.Input(shape=(28, 28)),
keras.layers.Reshape(target_shape=(28, 28, 1)),
keras.layers.Conv2D(32, 3, activation='relu'),
keras.layers.Flatten(),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10)
])
model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=gradient_descent.SGD(learning_rate=0.001),
metrics=['accuracy'])
return model
per_worker_batch_size = 64
single_worker_dataset = mnist_dataset(per_worker_batch_size)
single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)
num_workers = 4
def fn(model_path, checkpoint_dir):
global_batch_size = per_worker_batch_size * num_workers
strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
with strategy.scope():
multi_worker_model = build_and_compile_cnn_model()
callbacks = [
keras.callbacks.ModelCheckpoint(
filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
]
multi_worker_dataset = mnist_dataset(global_batch_size)
if shard_policy:
options = dataset_ops.Options()
options.experimental_distribute.auto_shard_policy = shard_policy
multi_worker_dataset = multi_worker_dataset.with_options(options)
multi_worker_model.fit(
multi_worker_dataset,
epochs=2,
steps_per_epoch=20,
callbacks=callbacks)
def _is_chief(task_type, task_id):
return task_type is None or task_type == 'chief' or (
task_type == 'worker' and task_id == 0)
def _get_temp_dir(dirpath, task_id):
base_dirpath = 'workertemp_' + str(task_id)
temp_dir = os.path.join(dirpath, base_dirpath)
file_io.recursive_create_dir_v2(temp_dir)
return temp_dir
def write_filepath(filepath, task_type, task_id):
dirpath = os.path.dirname(filepath)
base = os.path.basename(filepath)
if not _is_chief(task_type, task_id):
dirpath = _get_temp_dir(dirpath, task_id)
return os.path.join(dirpath, base)
task_type, task_id = (strategy.cluster_resolver.task_type,
strategy.cluster_resolver.task_id)
write_model_path = write_filepath(model_path, task_type, task_id)
multi_worker_model.save(write_model_path)
if not _is_chief(task_type, task_id):
file_io.delete_recursively_v2(os.path.dirname(write_model_path))
# Make sure chief finishes saving before non-chief's assertions.
multi_process_runner.get_barrier().wait()
if not file_io.file_exists_v2(model_path):
raise RuntimeError()
if file_io.file_exists_v2(write_model_path) != _is_chief(
task_type, task_id):
raise RuntimeError()
loaded_model = keras.saving.save.load_model(model_path)
loaded_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)
checkpoint = tracking_util.Checkpoint(model=multi_worker_model)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = checkpoint_management.CheckpointManager(
checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
checkpoint_manager.save()
if not _is_chief(task_type, task_id):
file_io.delete_recursively_v2(write_checkpoint_dir)
# Make sure chief finishes saving before non-chief's assertions.
multi_process_runner.get_barrier().wait()
if not file_io.file_exists_v2(checkpoint_dir):
raise RuntimeError()
if file_io.file_exists_v2(write_checkpoint_dir) != _is_chief(
task_type, task_id):
raise RuntimeError()
latest_checkpoint = checkpoint_management.latest_checkpoint(
checkpoint_dir)
checkpoint.restore(latest_checkpoint)
multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)
logging.info('testMultiWorkerTutorial successfully ends')
model_path = os.path.join(self.get_temp_dir(), 'model.tf')
checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
try:
mpr_result = multi_process_runner.run(
fn,
multi_worker_test_base.create_cluster_spec(num_workers=num_workers),
args=(model_path, checkpoint_dir),
return_output=True)
except errors_impl.UnavailableError as e:
self.skipTest('Skipping error: {}: {}'.format(type(e), str(e)))
self.assertTrue(
any([
'testMultiWorkerTutorial successfully ends' in msg
for msg in mpr_result.stdout
]))
def extract_accuracy(worker_id, input_string):
match = re.match(
r'\[worker\-{}\].*accuracy: (\d+\.\d+).*'.format(worker_id),
input_string)
return None if match is None else float(match.group(1))
for worker_id in range(num_workers):
accu_result = nest.map_structure(
lambda x: extract_accuracy(worker_id, x), # pylint: disable=cell-var-from-loop
mpr_result.stdout)
self.assertTrue(
any(accu_result), 'Every worker is supposed to have an accuracy result.')
if __name__ == '__main__':
multi_process_runner.test_main()
@@ -103,3 +103,18 @@ tpu_py_test(
"//tensorflow/python:extra_py_tests_deps",
],
)
tf_py_test(
name = "multi_worker_tutorial_test",
srcs = ["multi_worker_tutorial_test.py"],
python_version = "PY3",
shard_count = 3,
tags = [
"noasan", # TODO(b/156029134)
"nomsan", # TODO(b/156029134)
"notsan", # TODO(b/156029134)
],
deps = [
"//tensorflow:tensorflow_py",
],
)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for multi-worker training tutorial."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import os
import re
import unittest
import uuid
import zipfile
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
PER_WORKER_BATCH_SIZE = 64
NUM_WORKERS = 2
NUM_EPOCHS = 2
NUM_STEPS_PER_EPOCH = 50
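# There is no dedicated 'chief' task in the cluster specs these tests create,
# so by convention worker 0 takes on chief duties such as saving the model and
# checkpoints on behalf of the cluster.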
def _is_chief(task_type, task_id):
return task_type is None or task_type == 'chief' or (task_type == 'worker' and
task_id == 0)
def _get_temp_dir(dirpath, task_id):
base_dirpath = 'workertemp_' + str(task_id)
temp_dir = os.path.join(dirpath, base_dirpath)
tf.io.gfile.makedirs(temp_dir)
return temp_dir
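# All workers take part in saving (which can involve collective ops), but only
# the chief's copy at the original path is kept: non-chief workers write to a
# per-task temporary directory that is deleted once saving is done.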
def write_filepath(filepath, task_type, task_id):
dirpath = os.path.dirname(filepath)
base = os.path.basename(filepath)
if not _is_chief(task_type, task_id):
dirpath = _get_temp_dir(dirpath, task_id)
return os.path.join(dirpath, base)
class MultiWorkerTutorialTest(parameterized.TestCase, tf.test.TestCase):
"""Test of multi-worker training flow in tutorials on tensorflow.org.
See the test method docstrings below for which tutorial each test covers.
"""
# TODO(rchao): Add a test to demonstrate gather with MWMS.
@contextlib.contextmanager
def skip_fetch_failure_exception(self):
try:
yield
except zipfile.BadZipfile as e:
# There can be a race when multiple processes are downloading the data.
# Skip the test if that results in loading errors.
self.skipTest('Data loading error: Bad magic number for file header.')
except Exception as e: # pylint: disable=broad-except
if 'URL fetch failure' in str(e):
self.skipTest('URL fetch error not considered failure of the test.')
else:
raise
def mnist_dataset(self):
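# Download MNIST under a process-unique file name so concurrently running
# worker processes do not race on the same cached file.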
path_to_use = 'mnist_{}.npz'.format(str(uuid.uuid4()))
with self.skip_fetch_failure_exception():
(x_train,
y_train), _ = tf.keras.datasets.mnist.load_data(path=path_to_use)
# The `x` arrays are in uint8 and have values in the range [0, 255].
# We need to convert them to float32 with values in the range [0, 1]
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(60000)
return train_dataset
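# Used with `strategy.distribute_datasets_from_function`: each worker builds
# its own input pipeline, shards it by input pipeline id, and batches it with
# the per-replica batch size derived from the global batch size.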
def dataset_fn(self, global_batch_size, input_context):
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
dataset = self.mnist_dataset()
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
dataset = dataset.batch(batch_size)
return dataset
def build_cnn_model(self):
return tf.keras.Sequential([
tf.keras.layers.Input(shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
def build_and_compile_cnn_model(self):
model = self.build_cnn_model()
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
metrics=['accuracy'])
return model
@tf.__internal__.test.combinations.generate(
tf.__internal__.test.combinations.combine(
mode=['eager'], tf_api_version=2))
def testSingleWorkerModelFit(self):
single_worker_dataset = self.mnist_dataset().batch(
PER_WORKER_BATCH_SIZE)
single_worker_model = self.build_and_compile_cnn_model()
single_worker_model.fit(single_worker_dataset, epochs=NUM_EPOCHS)
@tf.__internal__.test.combinations.generate(
tf.__internal__.test.combinations.combine(
mode=['eager'], tf_api_version=2))
def testMwmsWithModelFit(self, mode):
"""Test multi-worker training flow demo'ed in go/multi-worker-with-keras.
This test should be kept in sync with the code samples in
go/multi-worker-with-keras.
Args:
mode: Runtime mode.
"""
def fn(model_path, checkpoint_dir):
global_batch_size = PER_WORKER_BATCH_SIZE * NUM_WORKERS
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
multi_worker_model = self.build_and_compile_cnn_model()
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(self.get_temp_dir(), 'checkpoint'))
]
multi_worker_dataset = strategy.distribute_datasets_from_function(
lambda input_context: self.dataset_fn(global_batch_size, input_context
))
multi_worker_model.fit(
multi_worker_dataset,
epochs=NUM_EPOCHS,
steps_per_epoch=50,
callbacks=callbacks)
task_type, task_id = (strategy.cluster_resolver.task_type,
strategy.cluster_resolver.task_id)
write_model_path = write_filepath(model_path, task_type, task_id)
multi_worker_model.save(write_model_path)
if not _is_chief(task_type, task_id):
tf.io.gfile.rmtree(os.path.dirname(write_model_path))
# Make sure chief finishes saving before non-chief's assertions.
tf.__internal__.distribute.multi_process_runner.get_barrier().wait()
if not tf.io.gfile.exists(model_path):
raise RuntimeError()
if tf.io.gfile.exists(write_model_path) != _is_chief(task_type, task_id):
raise RuntimeError()
with strategy.scope():
loaded_model = tf.keras.models.load_model(model_path)
loaded_model.fit(multi_worker_dataset, epochs=1, steps_per_epoch=1)
checkpoint = tf.train.Checkpoint(model=multi_worker_model)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
checkpoint_manager.save()
if not _is_chief(task_type, task_id):
tf.io.gfile.rmtree(write_checkpoint_dir)
# Make sure chief finishes saving before non-chief's assertions.
tf.__internal__.distribute.multi_process_runner.get_barrier().wait()
if not tf.io.gfile.exists(checkpoint_dir):
raise RuntimeError()
if tf.io.gfile.exists(write_checkpoint_dir) != _is_chief(
task_type, task_id):
raise RuntimeError()
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)
multi_worker_model.fit(multi_worker_dataset, epochs=1, steps_per_epoch=1)
logging.info('testMwmsWithModelFit successfully ends')
model_path = os.path.join(self.get_temp_dir(), 'model.tf')
checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
try:
mpr_result = tf.__internal__.distribute.multi_process_runner.run(
fn,
tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
num_workers=NUM_WORKERS),
args=(model_path, checkpoint_dir),
return_output=True)
except tf.errors.UnavailableError:
self.skipTest('Skipping rare disconnection among the workers.')
self.assertTrue(
any([
'testMwmsWithModelFit successfully ends' in msg
for msg in mpr_result.stdout
]))
def extract_accuracy(worker_id, input_string):
match = re.match(
r'\[worker\-{}\].*accuracy: (\d+\.\d+).*'.format(worker_id),
input_string)
return None if match is None else float(match.group(1))
for worker_id in range(NUM_WORKERS):
accu_result = tf.nest.map_structure(
lambda x: extract_accuracy(worker_id, x), # pylint: disable=cell-var-from-loop
mpr_result.stdout)
self.assertTrue(
any(accu_result), 'Every worker is supposed to have an accuracy result.')
@tf.__internal__.test.combinations.generate(
tf.__internal__.test.combinations.combine(
mode=['eager'], tf_api_version=2))
def testMwmsWithCtl(self, mode):
"""Test multi-worker CTL training flow demo'ed in a to-be-added tutorial."""
def proc_func(checkpoint_dir):
global_batch_size = PER_WORKER_BATCH_SIZE * NUM_WORKERS
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
try:
with strategy.scope():
multi_worker_model = self.build_cnn_model()
multi_worker_dataset = strategy.distribute_datasets_from_function(
lambda input_context: self.dataset_fn(global_batch_size, # pylint: disable=g-long-lambda
input_context))
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='train_accuracy')
@tf.function
def train_step(iterator):
"""Training step function."""
def step_fn(inputs):
"""Per-Replica step function."""
x, y = inputs
with tf.GradientTape() as tape:
predictions = multi_worker_model(x, training=True)
per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE)(y, predictions)
loss = tf.nn.compute_average_loss(
per_batch_loss, global_batch_size=global_batch_size)
grads = tape.gradient(loss, multi_worker_model.trainable_variables)
optimizer.apply_gradients(
zip(grads, multi_worker_model.trainable_variables))
train_accuracy.update_state(y, predictions)
return loss
per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
return strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
epoch = tf.Variable(
initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(
initial_value=tf.constant(0, dtype=tf.dtypes.int64),
name='step_in_epoch')
task_type, task_id = (strategy.cluster_resolver.task_type,
strategy.cluster_resolver.task_id)
checkpoint = tf.train.Checkpoint(
model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type,
task_id)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, directory=write_checkpoint_dir, max_to_keep=1)
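# Restoring `epoch` and `step_in_epoch` together with the model lets a
# restarted worker resume training from the last saved position.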
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
while epoch.numpy() < NUM_EPOCHS:
iterator = iter(multi_worker_dataset)
total_loss = 0.0
num_batches = 0
while step_in_epoch.numpy() < NUM_STEPS_PER_EPOCH:
total_loss += train_step(iterator)
num_batches += 1
step_in_epoch.assign_add(1)
train_loss = total_loss / num_batches
logging.info('Epoch: %d, accuracy: %f, train_loss: %f.',
epoch.numpy(), train_accuracy.result(), train_loss)
train_accuracy.reset_states()
checkpoint_manager.save()
if not _is_chief(task_type, task_id):
tf.io.gfile.rmtree(write_checkpoint_dir)
epoch.assign_add(1)
step_in_epoch.assign(0)
except tf.errors.UnavailableError as e:
logging.info('UnavailableError occurred: %r', e)
raise unittest.SkipTest('Skipping test due to UnavailableError')
logging.info('testMwmsWithCtl successfully ends')
checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
mpr_result = tf.__internal__.distribute.multi_process_runner.run(
proc_func,
tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
num_workers=NUM_WORKERS),
return_output=True,
args=(checkpoint_dir,))
self.assertTrue(
any([
'testMwmsWithCtl successfully ends' in msg
for msg in mpr_result.stdout
]))
if __name__ == '__main__':
tf.__internal__.distribute.multi_process_runner.test_main()