diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 287f0670a81e1675b036f583b7de117e0713b0d4..8497556c0cd77dddf7ac99e8bfea96689496565e 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -415,6 +415,12 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { const int axis = CanonicalAxis(context.Attr("axis"), rank); int axis_dim = logits->dims()[axis]; + const int n = SizeToAxis(axis, logits->dims()); + const int d = SizeFromAxis(axis, logits->dims()); + + auto* softmax_data = softmax->mutable_data(context.GetPlace()); + auto* loss_data = loss->mutable_data(context.GetPlace()); + if (axis_dim == 1) { math::SetConstant set_constant; set_constant(context.cuda_device_context(), softmax, static_cast(1)); @@ -422,12 +428,6 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { return; } - const int n = SizeToAxis(axis, logits->dims()); - const int d = SizeFromAxis(axis, logits->dims()); - - auto* softmax_data = softmax->mutable_data(context.GetPlace()); - auto* loss_data = loss->mutable_data(context.GetPlace()); - auto soft_label = context.Attr("soft_label"); auto ignore_index = context.Attr("ignore_index"); diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py index 76186704e37f9431e412815ac47b40dce6e3bd74..f33fdda8de450f176abc1714a353c3e4f4d81148 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py @@ -280,6 +280,23 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp): self.shape = [3, 5, 7, 11] +class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne( + TestSoftmaxWithCrossEntropyOp): + """ + Test softmax with cross entropy operator with discreate one-hot labels. + Given axis != -1 + """ + + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -1 + self.shape = [3, 5, 7, 1] + + class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1( TestSoftmaxWithCrossEntropyOpNoCudnnFp16): def initParams(self):