未验证 提交 656e7a2b 编写于 作者: A Amit Patankar 提交者: GitHub

Merge pull request #21425 from saeta/fix_tpu

Refactor dependencies so keras_support can be imported directly.
......@@ -107,7 +107,6 @@ py_library(
"//tensorflow/contrib/tfprof",
"//tensorflow/contrib/timeseries",
"//tensorflow/contrib/tpu",
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
......
......@@ -2,6 +2,7 @@ tensorflow/core
tensorflow/core/kernels/boosted_trees
tensorflow/core/profiler
tensorflow/python
tensorflow/compiler/xla
tensorflow/contrib/boosted_trees/proto
tensorflow/contrib/cloud/kernels
tensorflow/contrib/decision_trees/proto
......
......@@ -272,8 +272,7 @@ py_library(
deps = [
":one_device_strategy",
":values",
"//tensorflow/contrib/tpu",
"//tensorflow/contrib/tpu:tpu_py",
"//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
......
......@@ -46,7 +46,8 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":tpu_lib",
":tpu_py",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",
"//tensorflow/contrib/training:training_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
......@@ -133,7 +134,7 @@ py_library(
tf_custom_op_py_library(
name = "tpu_py",
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
srcs = glob(["python/ops/*.py"]),
dso = [":python/ops/_tpu_ops.so"],
kernels = [
":all_ops",
......@@ -152,9 +153,13 @@ tf_custom_op_py_library(
py_library(
name = "tpu",
srcs = ["python/tpu/__init__.py"],
srcs = [
"__init__.py",
"python/tpu/__init__.py",
],
srcs_version = "PY2AND3",
deps = [
":keras_support", # split out to avoid cycle with tpu_strategy
":tpu_estimator",
":tpu_lib",
],
......@@ -166,11 +171,13 @@ py_library(
"python/tpu/keras_support.py",
],
srcs_version = "PY2AND3",
visibility = [
"//tensorflow:__subpackages__",
],
deps = [
":tpu_lib",
":tpu_py",
"//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/contrib/distribute",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
"//tensorflow/core:protos_all_py",
......
......@@ -47,6 +47,9 @@
@@TPUConfig
@@bfloat16_scope
@@TPUDistributionStrategy
@@keras_to_tpu_model
"""
from __future__ import absolute_import
......@@ -58,6 +61,8 @@ from tensorflow.contrib.tpu.python import profiler
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
from tensorflow.contrib.tpu.python.tpu.bfloat16 import *
from tensorflow.contrib.tpu.python.tpu.device_assignment import *
from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model
from tensorflow.contrib.tpu.python.tpu.keras_support import TPUDistributionStrategy
from tensorflow.contrib.tpu.python.tpu.topology import *
from tensorflow.contrib.tpu.python.tpu.tpu import *
from tensorflow.contrib.tpu.python.tpu.tpu_config import *
......
......@@ -55,7 +55,6 @@ import time
import numpy as np
from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver
from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
......@@ -82,7 +81,11 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
TPUDistributionStrategy = tpu_strategy.TPUStrategy # pylint: disable=invalid-name
# Work-around dependency cycle between DistributionStrategy and TPU lib.
def TPUDistributionStrategy(*args, **kw): # pylint: disable=invalid-name
from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top
return tpu_strategy.TPUStrategy(*args, **kw)
class TPUEmbedding(embeddings.Embedding):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册