未验证 提交 eb27d8b7 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Add build_strategy in @to_static to support open pass (#34347)

* Add build_strategy in @to_static to support open pass

* fix os.environ

* add timeout

* disable test_build_strategy on openblas
上级 cf12ea51
......@@ -131,12 +131,16 @@ class PartialProgramLayer:
Layer: A Layer object that run all ops internally in static mode.
"""
def __init__(self, main_program, inputs, outputs, parameters=None):
def __init__(self, main_program, inputs, outputs, parameters=None,
**kwargs):
super(PartialProgramLayer, self).__init__()
self._inputs = NestSequence(inputs)
self._outputs = NestSequence(outputs, need_check=True)
self._params = parameters if parameters is not None else []
self._build_strategy = kwargs.get('build_strategy', BuildStrategy())
assert isinstance(self._build_strategy, BuildStrategy)
self._origin_main_program = self._verify_program(main_program)
self._tmp_scope_vec = self._create_scope_vec()
# A fake_var to handle empty input or output
......@@ -170,7 +174,11 @@ class PartialProgramLayer:
@LazyInitialized
def _train_program_id(self):
return _hash_with_id(self._train_program, self)
program_id = _hash_with_id(self._train_program, self)
core._set_cached_executor_build_strategy(program_id,
self._build_strategy)
return program_id
def _verify_program(self, main_program):
"""
......@@ -451,6 +459,6 @@ def partial_program_from(concrete_program):
if inputs and isinstance(inputs[0], layers.Layer):
inputs = inputs[1:]
return PartialProgramLayer(concrete_program.main_program, inputs,
concrete_program.outputs,
concrete_program.parameters)
return PartialProgramLayer(
concrete_program.main_program, inputs, concrete_program.outputs,
concrete_program.parameters, **concrete_program.kwargs)
......@@ -145,14 +145,13 @@ class CacheKey(object):
"""
Cached key for ProgramCache.
"""
__slots__ = [
'function_spec', 'input_args_with_spec', 'input_kwargs_with_spec',
'class_instance'
'class_instance', 'kwargs'
]
def __init__(self, function_spec, input_args_with_spec,
input_kwargs_with_spec, class_instance):
input_kwargs_with_spec, class_instance, **kwargs):
"""
Initializes a cache key.
......@@ -161,11 +160,14 @@ class CacheKey(object):
input_args_with_spec(list[InputSpec]): actual input args with some arguments replaced by InputSpec.
input_kwargs_with_spec(list[{string:InputSpec}]): actual input kwargs with some arguments replaced by InputSpec.
class_instance(object): a instance of class `Layer`.
**kwargs(dict): manage other arguments used for better scalability
"""
self.function_spec = function_spec
self.input_args_with_spec = input_args_with_spec
self.input_kwargs_with_spec = input_kwargs_with_spec
self.class_instance = class_instance
# NOTE: `kwargs` is usually not considered as basic member for `__hash__`
self.kwargs = kwargs
@classmethod
def from_func_and_args(cls, function_spec, args, kwargs, class_instance):
......@@ -235,13 +237,14 @@ class StaticFunction(object):
"""
def __init__(self, function, input_spec=None):
def __init__(self, function, input_spec=None, **kwargs):
"""
Initializes a `StaticFunction`.
Args:
function(callable): A function or method that will be converted into static program.
input_spec(list[InputSpec]): list of InputSpec to specify the `shape/dtype/name` information for each input argument, default None.
**kwargs(dict): other arguments like `build_strategy` et.al.
"""
# save the instance `self` while decorating a method of class.
if inspect.ismethod(function):
......@@ -257,6 +260,7 @@ class StaticFunction(object):
self._descriptor_cache = weakref.WeakKeyDictionary()
# Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
self._program_trans = ProgramTranslator()
self._kwargs = kwargs
def __get__(self, instance, owner):
"""
......@@ -395,7 +399,8 @@ class StaticFunction(object):
# 2. generate cache key
cache_key = CacheKey(self._function_spec, input_args_with_spec,
input_kwargs_with_spec, self._class_instance)
input_kwargs_with_spec, self._class_instance,
**self._kwargs)
# 3. check whether hit the cache or build a new program for the input arguments
concrete_program, partial_program_layer = self._program_cache[cache_key]
......@@ -586,7 +591,7 @@ class ConcreteProgram(object):
__slots__ = [
'inputs', 'outputs', 'main_program', "startup_program", "parameters",
"function"
"function", 'kwargs'
]
def __init__(self,
......@@ -595,18 +600,20 @@ class ConcreteProgram(object):
parameters,
function,
main_program,
startup_program=None):
startup_program=None,
**kwargs):
self.inputs = inputs
self.outputs = outputs
self.main_program = main_program
self.startup_program = startup_program
self.parameters = parameters
self.function = function
self.kwargs = kwargs
@staticmethod
@switch_to_static_graph
def from_func_spec(func_spec, input_spec, input_kwargs_spec,
class_instance):
def from_func_spec(func_spec, input_spec, input_kwargs_spec, class_instance,
**kwargs):
"""
Builds the main_program with specialized inputs and returns outputs
of program as fetch_list.
......@@ -635,8 +642,8 @@ class ConcreteProgram(object):
# 1. Adds `fluid.data` layers for input if needed
inputs = func_spec.to_static_inputs_with_spec(input_spec,
main_program)
kwargs = func_spec.to_static_inputs_with_spec(input_kwargs_spec,
main_program)
_kwargs = func_spec.to_static_inputs_with_spec(
input_kwargs_spec, main_program)
if class_instance:
inputs = tuple([class_instance] + list(inputs))
......@@ -649,8 +656,8 @@ class ConcreteProgram(object):
class_instance, False)), param_guard(
get_buffers(class_instance, False)):
try:
if kwargs:
outputs = static_func(*inputs, **kwargs)
if _kwargs:
outputs = static_func(*inputs, **_kwargs)
else:
outputs = static_func(*inputs)
except BaseException as e:
......@@ -675,7 +682,8 @@ class ConcreteProgram(object):
parameters=all_parameters_and_buffers,
function=dygraph_function,
main_program=main_program,
startup_program=startup_program)
startup_program=startup_program,
**kwargs)
def _extract_indeed_params_buffers(class_instance):
......@@ -702,7 +710,8 @@ class ProgramCache(object):
func_spec=cache_key.function_spec,
input_spec=cache_key.input_args_with_spec,
input_kwargs_spec=cache_key.input_kwargs_with_spec,
class_instance=cache_key.class_instance)
class_instance=cache_key.class_instance,
**cache_key.kwargs)
return concrete_program, partial_program_from(concrete_program)
def __getitem__(self, item):
......
......@@ -158,7 +158,7 @@ def copy_decorator_attrs(original_func, decorated_obj):
return decorated_obj
def declarative(function=None, input_spec=None):
def declarative(function=None, input_spec=None, build_strategy=None):
"""
Converts imperative dygraph APIs into declarative function APIs. Decorator
@declarative handles the Program and Executor of static mode and returns
......@@ -171,6 +171,12 @@ def declarative(function=None, input_spec=None):
function (callable): callable imperative function.
input_spec(list[InputSpec]|tuple[InputSpec]): list/tuple of InputSpec to specific the shape/dtype/name
information of each input Tensor.
build_strategy(BuildStrategy|None): This argument is used to compile the
converted program with the specified options, such as operators' fusion
in the computational graph and memory optimization during the execution
of the computational graph. For more information about build_strategy,
please refer to :code:`paddle.static.BuildStrategy`. The default is None.
Returns:
Tensor(s): containing the numerical result.
......@@ -206,10 +212,18 @@ def declarative(function=None, input_spec=None):
static_layer = copy_decorator_attrs(
original_func=python_func,
decorated_obj=StaticFunction(
function=python_func, input_spec=input_spec))
function=python_func,
input_spec=input_spec,
build_strategy=build_strategy))
return static_layer
build_strategy = build_strategy or BuildStrategy()
if not isinstance(build_strategy, BuildStrategy):
raise TypeError(
"Required type(build_strategy) shall be `paddle.static.BuildStrategy`, but received {}".
format(type(build_strategy).__name__))
# for usage: `declarative(foo, ...)`
if function is not None:
if isinstance(function, Layer):
......
......@@ -25,6 +25,7 @@ set_tests_properties(test_reinforcement_learning PROPERTIES TIMEOUT 120)
set_tests_properties(test_transformer PROPERTIES TIMEOUT 200)
set_tests_properties(test_bmn PROPERTIES TIMEOUT 120)
#set_tests_properties(test_mnist PROPERTIES TIMEOUT 120)
set_tests_properties(test_build_strategy PROPERTIES TIMEOUT 120)
if(NOT WIN32)
set_tests_properties(test_resnet_v2 PROPERTIES TIMEOUT 120)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import unittest
import numpy as np
from paddle.jit import ProgramTranslator
from test_resnet import ResNet, train, predict_dygraph_jit
from test_resnet import predict_dygraph, predict_static, predict_analysis_inference
program_translator = ProgramTranslator()
class TestResnetWithPass(unittest.TestCase):
def setUp(self):
self.build_strategy = paddle.static.BuildStrategy()
self.build_strategy.fuse_elewise_add_act_ops = True
self.build_strategy.fuse_bn_act_ops = True
self.build_strategy.fuse_bn_add_act_ops = True
self.build_strategy.enable_addto = True
# NOTE: for enable_addto
paddle.fluid.set_flags({"FLAGS_max_inplace_grad_add": 8})
def train(self, to_static):
program_translator.enable(to_static)
return train(to_static, self.build_strategy)
def verify_predict(self):
image = np.random.random([1, 3, 224, 224]).astype('float32')
dy_pre = predict_dygraph(image)
st_pre = predict_static(image)
dy_jit_pre = predict_dygraph_jit(image)
predictor_pre = predict_analysis_inference(image)
self.assertTrue(
np.allclose(dy_pre, st_pre),
msg="dy_pre:\n {}\n, st_pre: \n{}.".format(dy_pre, st_pre))
self.assertTrue(
np.allclose(dy_jit_pre, st_pre),
msg="dy_jit_pre:\n {}\n, st_pre: \n{}.".format(dy_jit_pre, st_pre))
self.assertTrue(
np.allclose(predictor_pre, st_pre),
msg="predictor_pre:\n {}\n, st_pre: \n{}.".format(predictor_pre,
st_pre))
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
self.assertTrue(
np.allclose(static_loss, dygraph_loss),
msg="static_loss: {} \n dygraph_loss: {}".format(static_loss,
dygraph_loss))
self.verify_predict()
def test_in_static_mode_mkldnn(self):
paddle.fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
if paddle.fluid.core.is_compiled_with_mkldnn():
train(True, self.build_strategy)
finally:
paddle.fluid.set_flags({'FLAGS_use_mkldnn': False})
class TestError(unittest.TestCase):
def test_type_error(self):
def foo(x):
out = x + 1
return out
with self.assertRaises(TypeError):
static_foo = paddle.jit.to_static(foo, build_strategy="x")
if __name__ == '__main__':
unittest.main()
......@@ -190,7 +190,6 @@ class ResNet(fluid.dygraph.Layer):
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
@declarative
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
......@@ -213,7 +212,7 @@ def reader_decorator(reader):
return __reader__
def train(to_static):
def train(to_static, build_strategy=None):
"""
Tests model decorated by `dygraph_to_static_output` in static mode. For users, the model is defined in dygraph mode and trained in static mode.
"""
......@@ -231,6 +230,8 @@ def train(to_static):
data_loader.set_sample_list_generator(train_reader)
resnet = ResNet()
if to_static:
resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy)
optimizer = optimizer_setting(parameter_list=resnet.parameters())
for epoch in range(epoch_num):
......
......@@ -96,6 +96,7 @@ disable_wincpu_test="^jit_kernel_test$|\
^test_bmn$|\
^test_mobile_net$|\
^test_resnet_v2$|\
^test_build_strategy$|\
^test_se_resnet$|\
^disable_wincpu_test$"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册