From 6679bc9474d8ba01dbe19a28a553aef308371c57 Mon Sep 17 00:00:00 2001
From: Jiri Simsa
Date: Sat, 31 Oct 2020 20:47:04 -0700
Subject: [PATCH] [tf.data] Minor cleanup

PiperOrigin-RevId: 340071266
Change-Id: Ic21209a25a1f8efa1122c9cee4a8ab3b8043c308
---
 tensorflow/core/kernels/data/model_dataset_op.cc   | 12 +++++-------
 .../python/data/experimental/ops/distribute.py     |  8 ++++++++
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
index f790b4bf07f..fdf8aebc3ab 100644
--- a/tensorflow/core/kernels/data/model_dataset_op.cc
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -140,7 +140,7 @@ class ModelDatasetOp::Dataset : public DatasetBase {
       IteratorContext::Params params(ctx);
       {
         mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx));
+        TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
         params.model = model_;
         int64 now_nanos = EnvTime::NowNanos();
         RecordInput(now_nanos);
@@ -175,18 +175,16 @@ class ModelDatasetOp::Dataset : public DatasetBase {
     }
 
    private:
-    Status EnsureOptimizeThreadStarted(IteratorContext* ctx)
+    Status EnsureModelThreadStarted(IteratorContext* ctx)
         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       if (!model_thread_) {
-        std::shared_ptr<IteratorContext> new_ctx =
-            std::make_shared<IteratorContext>(*ctx);
-        model_thread_ = ctx->StartThread(
-            "tf_data_model", [this, new_ctx]() { ModelThread(new_ctx); });
+        model_thread_ =
+            ctx->StartThread("tf_data_model", [this]() { ModelThread(); });
       }
       return Status::OK();
     }
 
-    void ModelThread(const std::shared_ptr<IteratorContext>& ctx) {
+    void ModelThread() {
       int64 last_optimization_ms = 0;
       int64 optimization_period_ms = 10;
       int64 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py
index 568c01646de..a65a9d79340 100644
--- a/tensorflow/python/data/experimental/ops/distribute.py
+++ b/tensorflow/python/data/experimental/ops/distribute.py
@@ -330,6 +330,14 @@ def replicate(dataset, devices):
     return datasets
 
   with ops.colocate_with(dataset._variant_tensor):
+    # We apply options before replicating the dataset because options are
+    # currently not automatically preserved through dataset serialization and
+    # thus an explicit application of options here is needed to avoid losing
+    # `dataset` options.
+    #
+    # TODO(b/147325552): Propagating options to C++ upon their setting would
+    # allow us to preserve the options across both variant and GraphDef based
+    # serialization, avoiding the need to explicitly apply options here.
     dataset = dataset._apply_options()
     policy = dataset.options().experimental_external_state_policy
     if policy is None:
--
GitLab
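
Not part of the patch: a minimal sketch of the public-API side of the comment
added to distribute.py above. Options attached with `with_options()` live on
the Python `tf.data.Options` object until `_apply_options()` folds them into
the dataset, which is why `replicate()` applies them explicitly before the
GraphDef-based serialization. The sketch assumes the public
`tf.data.experimental.ExternalStatePolicy` export; the specific option value
is illustrative only.

    import tensorflow as tf

    # Sketch: set the external-state policy that replicate() later reads via
    # dataset.options().experimental_external_state_policy.
    options = tf.data.Options()
    options.experimental_external_state_policy = (
        tf.data.experimental.ExternalStatePolicy.WARN)

    # with_options() only records the options on the Python side; they are
    # baked into the dataset graph when _apply_options() runs, e.g. inside
    # replicate() just before the dataset is serialized.
    dataset = tf.data.Dataset.range(8).with_options(options)
    print(dataset.options().experimental_external_state_policy)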