未验证 提交 3a255881 编写于 作者: C chajchaj 提交者: GitHub

fix use_softmax=False does not work, test=develop (#32035)

上级 1f8834ad
...@@ -1388,6 +1388,8 @@ def cross_entropy(input, ...@@ -1388,6 +1388,8 @@ def cross_entropy(input,
"should be '-100', but received %s, which is not allowed." % "should be '-100', but received %s, which is not allowed." %
ignore_index) ignore_index)
softmax_switch = use_softmax
input_dims = len(list(input.shape)) input_dims = len(list(input.shape))
label_dims = len(list(label.shape)) label_dims = len(list(label.shape))
if input_dims - 1 != label_dims and input_dims != label_dims: if input_dims - 1 != label_dims and input_dims != label_dims:
...@@ -1400,7 +1402,7 @@ def cross_entropy(input, ...@@ -1400,7 +1402,7 @@ def cross_entropy(input,
_, out = core.ops.softmax_with_cross_entropy( _, out = core.ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index', input, label, 'soft_label', soft_label, 'ignore_index',
ignore_index, 'numeric_stable_mode', True, 'axis', axis, ignore_index, 'numeric_stable_mode', True, 'axis', axis,
'use_softmax', use_softmax) 'softmax_switch', softmax_switch)
if weight is not None: if weight is not None:
...@@ -1482,7 +1484,7 @@ def cross_entropy(input, ...@@ -1482,7 +1484,7 @@ def cross_entropy(input,
'ignore_index': ignore_index, 'ignore_index': ignore_index,
'numeric_stable_mode': True, 'numeric_stable_mode': True,
'axis': axis, 'axis': axis,
'use_softmax': use_softmax 'softmax_switch': softmax_switch
} }
helper = LayerHelper('softmax_with_cross_entropy', **locals()) helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=input.dtype) softmax = helper.create_variable_for_type_inference(dtype=input.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册