未验证 提交 2646d230 编写于 作者: G Goldie Gadde 提交者: GitHub

Merge pull request #32669 from tensorflow/ggadde-cp-19

[r2.0-CherryPick]:[tf.data] Avoid double conversion to a tensor during input normalizat…
......@@ -31,6 +31,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
......@@ -83,24 +84,31 @@ def normalize_element(element):
components = nest.flatten(element)
normalized_components = []
with ops.name_scope("normalize_element"):
# Imported here to avoid circular dependency
# Imported here to avoid circular dependency.
from tensorflow.python.data.ops import dataset_ops # pylint: disable=g-import-not-at-top
for i, t in enumerate(components):
spec = type_spec_from_value(t)
if isinstance(spec, sparse_tensor.SparseTensorSpec):
normalized_components.append(sparse_tensor.SparseTensor.from_value(t))
elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
normalized_components.append(
ragged_tensor.convert_to_tensor_or_ragged_tensor(
t, name="component_%d" % i))
elif isinstance(
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
normalized_components.append(t)
elif isinstance(t, composite_tensor.CompositeTensor):
normalized_components.append(t)
else:
try:
spec = type_spec_from_value(t, use_fallback=False)
except TypeError:
# TypeError indicates it was not possible to compute a `TypeSpec` for
# the value. As a fallback try converting the value to a tensor.
normalized_components.append(
ops.convert_to_tensor(t, name="component_%d" % i))
else:
if isinstance(spec, sparse_tensor.SparseTensorSpec):
normalized_components.append(sparse_tensor.SparseTensor.from_value(t))
elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
normalized_components.append(
ragged_tensor.convert_to_tensor_or_ragged_tensor(
t, name="component_%d" % i))
elif isinstance(
spec, (tensor_array_ops.TensorArraySpec, dataset_ops.DatasetSpec)):
normalized_components.append(t)
elif isinstance(t, composite_tensor.CompositeTensor):
normalized_components.append(t)
else:
normalized_components.append(
ops.convert_to_tensor(t, name="component_%d" % i))
return nest.pack_sequence_as(element, normalized_components)
......@@ -392,11 +400,13 @@ def are_compatible(spec1, spec2):
return True
def type_spec_from_value(element):
def type_spec_from_value(element, use_fallback=True):
"""Creates a type specification for the given value.
Args:
element: The element to create the type specification for.
use_fallback: Whether to fall back to converting the element to a tensor
in order to compute its `TypeSpec`.
Returns:
A nested structure of `TypeSpec`s that represents the type specification
......@@ -432,14 +442,16 @@ def type_spec_from_value(element):
# `element` is not a namedtuple
return tuple([type_spec_from_value(v) for v in element])
# Fallback: try converting value to a tensor.
try:
tensor = ops.convert_to_tensor(element)
spec = type_spec_from_value(tensor)
if spec is not None:
return spec
except (ValueError, TypeError):
pass
if use_fallback:
# As a fallback try converting the element to a tensor.
try:
tensor = ops.convert_to_tensor(element)
spec = type_spec_from_value(tensor)
if spec is not None:
return spec
except (ValueError, TypeError) as e:
logging.vlog(
3, "Failed to convert %r to tensor: %s" % (type(element).__name__, e))
raise TypeError("Could not build a TypeSpec for %r with type %s" %
(element, type(element).__name__))
......@@ -26,6 +26,7 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
......@@ -483,8 +484,9 @@ def type_spec_from_value(value):
spec = _type_spec_from_value(tensor)
if spec is not None:
return spec
except (ValueError, TypeError):
pass
except (ValueError, TypeError) as e:
logging.vlog(
3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e))
raise TypeError("Could not build a TypeSpec for %r with type %s" %
(value, type(value).__name__))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册