未验证 提交 fa7ace7c 编写于 作者: G Guo Sheng 提交者: GitHub

Cherry pick from #21862 (#22194)

* Fix default label dim of label_smooth_op. test=develop (#21862)

* Fix unit tests of label_smooth_op's data size.
上级 c7248cda
......@@ -37,7 +37,7 @@ class LabelSmoothOp : public framework::OperatorWithKernel {
auto noise_dims = ctx->GetInputDim("PriorDist");
auto noise_numel = paddle::framework::product(noise_dims);
PADDLE_ENFORCE(
in_dims[1] == noise_numel,
in_dims[in_dims.size() - 1] == noise_numel,
"The number of elements in Input(PriorDist) must be equal to the "
"dimension of each label.");
}
......
......@@ -34,7 +34,7 @@ __global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
const T* dist_data, T* dst) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < N; idx += blockDim.x * gridDim.x) {
int dist_idx = idx - (idx / dist_numel) * dist_numel;
int dist_idx = idx % dist_numel;
dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
static_cast<T>(epsilon) * dist_data[dist_idx];
}
......@@ -56,7 +56,7 @@ class LabelSmoothGPUKernel : public framework::OpKernel<T> {
auto* out_t = ctx.Output<framework::LoDTensor>("Out");
auto* in_t = ctx.Input<framework::LoDTensor>("X");
auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
auto label_dim = in_t->dims()[1];
auto label_dim = in_t->dims()[in_t->dims().size() - 1];
auto epsilon = ctx.Attr<float>("epsilon");
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
auto size_prob = in_t->numel();
......
......@@ -27,7 +27,7 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
auto* out_t = ctx.Output<framework::LoDTensor>("Out");
auto* in_t = ctx.Input<framework::LoDTensor>("X");
auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
auto label_dim = in_t->dims()[1];
auto label_dim = in_t->dims()[in_t->dims().size() - 1];
out_t->mutable_data<T>(ctx.GetPlace());
auto epsilon = ctx.Attr<float>("epsilon");
......@@ -39,7 +39,7 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
out.device(dev) =
static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon) *
dist.broadcast(Eigen::DSizes<int, 1>(in_t->numel()));
dist.broadcast(Eigen::DSizes<int, 1>(in_t->numel() / label_dim));
} else {
out.device(dev) = static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon / label_dim);
......
......@@ -23,7 +23,7 @@ class TestLabelSmoothOp(OpTest):
def config(self):
self.op_type = "label_smooth"
self.epsilon = 0.1
batch_size, self.label_dim = 5, 10
batch_size, self.label_dim = 10, 12
self.label = np.zeros((batch_size, self.label_dim)).astype("float64")
nonzero_index = np.random.randint(self.label_dim, size=(batch_size))
self.label[np.arange(batch_size), nonzero_index] = 1
......@@ -53,5 +53,23 @@ class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp):
self.outputs = {'Out': smoothed_label}
class TestLabelSmoothOp3D(TestLabelSmoothOp):
def setUp(self):
super(TestLabelSmoothOp3D, self).setUp()
self.inputs['X'] = self.inputs['X'].reshape(
[2, -1, self.inputs['X'].shape[-1]])
self.outputs['Out'] = self.outputs['Out'].reshape(self.inputs['X']
.shape)
class TestLabelSmoothOpWithPriorDist3D(TestLabelSmoothOpWithPriorDist):
def setUp(self):
super(TestLabelSmoothOpWithPriorDist3D, self).setUp()
self.inputs['X'] = self.inputs['X'].reshape(
[2, -1, self.inputs['X'].shape[-1]])
self.outputs['Out'] = self.outputs['Out'].reshape(self.inputs['X']
.shape)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册