diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index e5e3fa7bf8f76d60d16f82c149b4c7ae5bb3c693..52c605d5bb49825503b031f0788008f640398d93 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1388,6 +1388,8 @@ def cross_entropy(input, "should be '-100', but received %s, which is not allowed." % ignore_index) + softmax_switch = use_softmax + input_dims = len(list(input.shape)) label_dims = len(list(label.shape)) if input_dims - 1 != label_dims and input_dims != label_dims: @@ -1400,7 +1402,7 @@ def cross_entropy(input, _, out = core.ops.softmax_with_cross_entropy( input, label, 'soft_label', soft_label, 'ignore_index', ignore_index, 'numeric_stable_mode', True, 'axis', axis, - 'use_softmax', use_softmax) + 'softmax_switch', softmax_switch) if weight is not None: @@ -1482,7 +1484,7 @@ def cross_entropy(input, 'ignore_index': ignore_index, 'numeric_stable_mode': True, 'axis': axis, - 'use_softmax': use_softmax + 'softmax_switch': softmax_switch } helper = LayerHelper('softmax_with_cross_entropy', **locals()) softmax = helper.create_variable_for_type_inference(dtype=input.dtype)