未验证 提交 748435b8 编写于 作者: T Todd Wang 提交者: GitHub

Fixed the issue that each invocation of model.fit/evaluate/predict modifies the (#23280)

graph.

PiperOrigin-RevId: 218793646
上级 f90c2141
......@@ -97,14 +97,25 @@ from tensorflow.python.platform import tf_logging as logging
# TODO(b/114775106): temporary shim to optionally initialize the TPU
# This increases the odds our session is initialized, but shouldn't be needed.
_TEST_REWRITE_OP = None
def _maybe_initialize_tpu(session):
"""Initialize the TPU if it has not already been initialized."""
global _TEST_REWRITE_OP
try:
# Try to use cached version to avoid another ground of graph optimization.
test_rewrite_op = _TEST_REWRITE_OP
if (test_rewrite_op is None or
test_rewrite_op[0].graph != ops.get_default_graph()):
def test_op():
return constant_op.constant(1) + constant_op.constant(1)
def test_op():
return constant_op.constant(1) + constant_op.constant(1)
test_rewrite_op = tpu.rewrite(test_op)
_TEST_REWRITE_OP = test_rewrite_op
session.run(tpu.rewrite(test_op))
session.run(test_rewrite_op)
except errors.FailedPreconditionError as _:
session.run(tpu.initialize_system())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册