提交 05483cd4 编写于 作者: S Scott Zhu 提交者: TensorFlower Gardener

Fork the tf function related util to keras.

PiperOrigin-RevId: 339983364
Change-Id: I877f2394f13b899ace0bb2893e6cb5f073b03458
上级 0509af7e
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import copy
import functools
import warnings
from tensorflow.python.eager import context
......@@ -28,11 +29,12 @@ from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
# Avoid breaking users who directly import this symbol from this file.
......@@ -541,7 +543,7 @@ class Layer(base_layer.Layer):
try:
call_has_scope_arg = self._call_has_scope_arg
except AttributeError:
self._call_fn_args = function_utils.fn_args(self.call)
self._call_fn_args = fn_args(self.call)
self._call_has_scope_arg = 'scope' in self._call_fn_args
call_has_scope_arg = self._call_has_scope_arg
if call_has_scope_arg:
......@@ -595,3 +597,35 @@ def _add_elements_to_collection(elements, collection_list):
for element in elements:
if id(element) not in collection_set:
collection.append(element)
def fn_args(fn):
"""Get argument names for function-like object.
Args:
fn: Function, or function-like object (e.g., result of `functools.partial`).
Returns:
`tuple` of string argument names.
Raises:
ValueError: if partial function has positionally bound arguments
"""
if isinstance(fn, functools.partial):
args = fn_args(fn.func)
args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
else:
if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
fn = fn.__call__
args = tf_inspect.getfullargspec(fn).args
if is_bound_method(fn) and args:
# If it's a bound method, it may or may not have a self/cls first
# argument; for example, self could be captured in *args.
# If it does have a positional argument, it is self/cls.
args.pop(0)
return tuple(args)
def is_bound_method(fn):
_, fn = tf_decorator.unwrap(fn)
return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册