diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index 52be7493cf229becc24d6b83ef140e0708f479e0..60f844b27bef1f62767bc3e3613cc62bd1a75d61 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -93,10 +93,10 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''): if in_dygraph_mode(): return - from .dygraph.dygraph_to_static.program_translator import in_declarative_mode # NOTE: `in_declarative_mode` is used to determined whether this op is called under # @declarative in transformation from dygrah to static layer. We add VarBase in # expected_type to skip checking because varBase may be created and used in unusual way. + from .dygraph.base import in_declarative_mode # Need a better design to be fix this. if in_declarative_mode(): if not isinstance(expected_type, tuple): diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 460831f8745b31fef28502c3a67b12a04d765ced..f54a1629196a0c382fa27f53c8f77afadce3a17d 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -33,6 +33,17 @@ __all__ = [ 'enabled', 'to_variable' ] +# Flag that indicates whether running code under `@declarative` +_in_declarative_mode_ = False + + +def in_declarative_mode(): + """ + Return a bool value that indicates whether running code under `@declarative` + + """ + return _in_declarative_mode_ + def _switch_to_static_graph_(func): def __impl__(*args, **kwargs): @@ -45,6 +56,16 @@ def _switch_to_static_graph_(func): switch_to_static_graph = wrap_decorator(_switch_to_static_graph_) +@signature_safe_contextmanager +def _switch_declarative_mode_guard_(is_declarative=True): + + global _in_declarative_mode_ + original_val = _in_declarative_mode_ + _in_declarative_mode_ = is_declarative + yield + _in_declarative_mode_ = original_val + + @signature_safe_contextmanager def program_desc_tracing_guard(enable): tracer = framework._dygraph_tracer() @@ -63,7 +84,6 @@ _functional_dygraph_context_manager = None @signature_safe_contextmanager def param_guard(parameters): - from paddle.fluid.dygraph.dygraph_to_static.program_translator import in_declarative_mode # Note: parameters is a reference of self._parameters or self._buffers if in_declarative_mode() and not framework.in_dygraph_mode() and parameters: origin_parameters = parameters.copy() diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index d5d0e8ab88b869a4fd000c63ac29b9dc0b45c8e1..19479a190c3b9e83e267fa1a7acbfc007f34ec58 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -573,28 +573,6 @@ class StaticFunction(object): return self._function_spec -# Flag that indicates whether running code under `@declarative` -_in_declarative_mode_ = False - - -def in_declarative_mode(): - """ - Return a bool value that indicates whether running code under `@declarative` - - """ - return _in_declarative_mode_ - - -@signature_safe_contextmanager -def _switch_declarative_mode_guard_(is_declarative=True): - - global _in_declarative_mode_ - original_val = _in_declarative_mode_ - _in_declarative_mode_ = is_declarative - yield - _in_declarative_mode_ = original_val - - def _verify_init_in_dynamic_mode(class_instance): """ Verifies the instance is initialized in dynamic mode. @@ -658,6 +636,7 @@ class ConcreteProgram(object): startup_program.random_seed = framework.default_startup_program( ).random_seed + from paddle.fluid.dygraph.base import _switch_declarative_mode_guard_ with framework.program_guard(main_program, startup_program): with _switch_declarative_mode_guard_(is_declarative=True): # 1. Adds `fluid.data` layers for input if needed diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 11812398ba455053bbc6fbcfbc1dfc99df31448d..0373c1e63da81574d53248bd3da93854a755ae1a 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -31,7 +31,7 @@ from .. import unique_name from paddle.fluid import core from .layer_object_helper import LayerObjectHelper from .layer_hooks import record_program_ops_pre_hook, set_op_customized_attrs_post_hook, LayerOpsRecoder -from .base import program_desc_tracing_guard, param_guard +from .base import program_desc_tracing_guard, param_guard, in_declarative_mode from paddle.fluid import framework from ..param_attr import ParamAttr from paddle.fluid.executor import Executor, global_scope @@ -917,7 +917,6 @@ class Layer(object): # In case of ControlFlow, true_fn and false_fn will contain # parameters that may not trigger logic of `Operator` to create # them. we add this to make sure all parameters is available. - from paddle.fluid.dygraph.dygraph_to_static.program_translator import in_declarative_mode if in_declarative_mode() and not framework.in_dygraph_mode(): with param_guard(self._parameters), param_guard(self._buffers):