Unverified commit 7c950aae, authored by Hongsheng Zeng and committed by GitHub

fix compatibility of batch_norm layer (#310)

* fix compatibility of batch_norm layer

* fix yapf
Parent f5ee9a1f
@@ -46,6 +46,7 @@ from paddle.fluid.framework import Variable
from paddle.fluid.layers import *
from paddle.fluid.param_attr import ParamAttr
from parl.core.fluid.layers.attr_holder import AttrHolder
+from parl.utils import get_fluid_version
def update_attr_name(name, default_name, attr, is_bias):
@@ -497,61 +498,119 @@ def layer_norm(**kwargs):
    raise NotImplementedError()


-def batch_norm(act=None,
-               momentum=0.9,
-               epsilon=1e-05,
-               param_attr=None,
-               bias_attr=None,
-               data_layout='NCHW',
-               in_place=False,
-               name=None,
-               moving_mean_name=None,
-               moving_variance_name=None,
-               do_model_average_for_mean_and_var=False,
-               fuse_with_relu=False,
-               use_global_stats=False):
-    """
-    Return a function that creates a paddle.fluid.layers.batch_norm.
-    """
-    default_name = "batch_norm"
-    param_attr = update_attr_name(name, default_name, param_attr, False)
-    bias_attr = update_attr_name(name, default_name, bias_attr, True)
-    moving_mean_attr = update_attr_name(name, default_name + "_moving_mean",
-                                        None, False)
-    moving_variance_attr = update_attr_name(
-        name, default_name + "_moving_variance", None, False)
-
-    class BatchNorm_(LayerFunc):
-        def __init__(self):
-            super(BatchNorm_, self).__init__(
-                AttrHolder(
-                    param_attr=param_attr,
-                    bias_attr=bias_attr,
-                    moving_mean_attr=moving_mean_attr,
-                    moving_variance_attr=moving_variance_attr))
-
-        def __call__(self, input, is_test=False):
-            return layers.batch_norm(
-                input=input,
-                act=act,
-                is_test=is_test,
-                momentum=momentum,
-                epsilon=epsilon,
-                param_attr=self.attr_holder.param_attr,
-                bias_attr=self.attr_holder.bias_attr,
-                data_layout=data_layout,
-                in_place=in_place,
-                name=name,
-                moving_mean_name=self.attr_holder.moving_mean_attr.name,
-                moving_variance_name=self.attr_holder.moving_variance_attr.
-                name,
-                do_model_average_for_mean_and_var=
-                do_model_average_for_mean_and_var,
-                fuse_with_relu=fuse_with_relu,
-                use_global_stats=use_global_stats)
-
-    return BatchNorm_()
+fluid_version = get_fluid_version()
+if fluid_version >= 162 or fluid_version == 0:
+
+    def batch_norm(act=None,
+                   momentum=0.9,
+                   epsilon=1e-05,
+                   param_attr=None,
+                   bias_attr=None,
+                   data_layout='NCHW',
+                   in_place=False,
+                   name=None,
+                   moving_mean_name=None,
+                   moving_variance_name=None,
+                   do_model_average_for_mean_and_var=False,
+                   use_global_stats=False):
+        """
+        Return a function that creates a paddle.fluid.layers.batch_norm.
+        """
+        default_name = "batch_norm"
+        param_attr = update_attr_name(name, default_name, param_attr, False)
+        bias_attr = update_attr_name(name, default_name, bias_attr, True)
+        moving_mean_attr = update_attr_name(
+            name, default_name + "_moving_mean", None, False)
+        moving_variance_attr = update_attr_name(
+            name, default_name + "_moving_variance", None, False)
+
+        class BatchNorm_(LayerFunc):
+            def __init__(self):
+                super(BatchNorm_, self).__init__(
+                    AttrHolder(
+                        param_attr=param_attr,
+                        bias_attr=bias_attr,
+                        moving_mean_attr=moving_mean_attr,
+                        moving_variance_attr=moving_variance_attr))
+
+            def __call__(self, input, is_test=False):
+                return layers.batch_norm(
+                    input=input,
+                    act=act,
+                    is_test=is_test,
+                    momentum=momentum,
+                    epsilon=epsilon,
+                    param_attr=self.attr_holder.param_attr,
+                    bias_attr=self.attr_holder.bias_attr,
+                    data_layout=data_layout,
+                    in_place=in_place,
+                    name=name,
+                    moving_mean_name=self.attr_holder.moving_mean_attr.name,
+                    moving_variance_name=self.attr_holder.moving_variance_attr.
+                    name,
+                    do_model_average_for_mean_and_var=
+                    do_model_average_for_mean_and_var,
+                    use_global_stats=use_global_stats)
+
+        return BatchNorm_()
+else:
+
+    def batch_norm(act=None,
+                   momentum=0.9,
+                   epsilon=1e-05,
+                   param_attr=None,
+                   bias_attr=None,
+                   data_layout='NCHW',
+                   in_place=False,
+                   name=None,
+                   moving_mean_name=None,
+                   moving_variance_name=None,
+                   do_model_average_for_mean_and_var=False,
+                   fuse_with_relu=False,
+                   use_global_stats=False):
+        """
+        Return a function that creates a paddle.fluid.layers.batch_norm.
+        """
+        default_name = "batch_norm"
+        param_attr = update_attr_name(name, default_name, param_attr, False)
+        bias_attr = update_attr_name(name, default_name, bias_attr, True)
+        moving_mean_attr = update_attr_name(
+            name, default_name + "_moving_mean", None, False)
+        moving_variance_attr = update_attr_name(
+            name, default_name + "_moving_variance", None, False)
+
+        class BatchNorm_(LayerFunc):
+            def __init__(self):
+                super(BatchNorm_, self).__init__(
+                    AttrHolder(
+                        param_attr=param_attr,
+                        bias_attr=bias_attr,
+                        moving_mean_attr=moving_mean_attr,
+                        moving_variance_attr=moving_variance_attr))
+
+            def __call__(self, input, is_test=False):
+                return layers.batch_norm(
+                    input=input,
+                    act=act,
+                    is_test=is_test,
+                    momentum=momentum,
+                    epsilon=epsilon,
+                    param_attr=self.attr_holder.param_attr,
+                    bias_attr=self.attr_holder.bias_attr,
+                    data_layout=data_layout,
+                    in_place=in_place,
+                    name=name,
+                    moving_mean_name=self.attr_holder.moving_mean_attr.name,
+                    moving_variance_name=self.attr_holder.moving_variance_attr.
+                    name,
+                    do_model_average_for_mean_and_var=
+                    do_model_average_for_mean_and_var,
+                    fuse_with_relu=fuse_with_relu,
+                    use_global_stats=use_global_stats)
+
+        return BatchNorm_()
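
For readers unfamiliar with PARL's LayerFunc wrappers, here is a minimal usage sketch of the new batch_norm; the variable names, the input shape, and the fluid.layers.data call are illustrative assumptions, not part of this commit. The wrapper is built once (fixing its parameter names) and can then be applied repeatedly, so both calls below share the same parameters, with is_test=True switching to inference-time statistics.

# Minimal usage sketch (hypothetical names; assumes a paddle.fluid 1.x install).
import paddle.fluid as fluid
import parl.core.fluid.layers as layers

bn = layers.batch_norm(act='relu')   # builds the wrapper and fixes parameter names
image = fluid.layers.data(name='image', shape=[3, 32, 32], dtype='float32')
train_out = bn(image)                # training-mode batch statistics
eval_out = bn(image, is_test=True)   # same parameters, inference-mode statistics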
def create_parameter(shape,
......
@@ -20,7 +20,7 @@ import numpy as np
__all__ = [
    'has_func', 'action_mapping', 'to_str', 'to_byte', 'is_PY2', 'is_PY3',
    'MAX_INT32', '_HAS_FLUID', '_HAS_TORCH', '_IS_WINDOWS', '_IS_MAC',
-    'kill_process'
+    'kill_process', 'get_fluid_version'
]
......
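
The hunk above only exports get_fluid_version from parl.utils; its implementation is outside this excerpt. Judging from the fluid_version >= 162 or fluid_version == 0 check earlier in the diff, it appears to pack the installed PaddlePaddle version into an integer (for example, 1.6.2 becomes 162, and development builds map to 0). A hedged sketch of that idea, not the actual PARL code:

# Sketch only: derive an integer version code from paddle.__version__.
import re
import paddle

def get_fluid_version():
    # '1.6.2' -> 162; development builds report '0.0.0' -> 0
    digits = re.findall(r'\d+', paddle.__version__)[:3]
    return int(''.join(digits))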