未验证 提交 4f43d51f 编写于 作者: D Double_V 提交者: GitHub

add rois_num params for roi_align_xpu op, test=kunlun (#28094)

上级 c4e18dc0
...@@ -39,14 +39,40 @@ class XPUROIAlignOpKernel : public framework::OpKernel<T> { ...@@ -39,14 +39,40 @@ class XPUROIAlignOpKernel : public framework::OpKernel<T> {
int width = in_dims[3]; int width = in_dims[3];
int rois_num = rois->dims()[0]; int rois_num = rois->dims()[0];
const T* input_data = in->data<T>(); const T* input_data = in->data<T>();
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1; framework::Tensor _roi_batch_list;
PADDLE_ENFORCE_EQ( _roi_batch_list.Resize({rois_num});
rois_batch_size, batch_size, int* rois_lod = _roi_batch_list.mutable_data<int>(ctx.GetPlace());
platform::errors::InvalidArgument( int rois_batch_size = 1;
"The rois_batch_size and imgs batch_size of roi_align_xpu OP must " if (ctx.HasInput("RoisNum")) {
"be the same. But received rois_batch_size %d , batch_size %d", auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size, batch_size)); rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of rois and the batch size of images "
" must be the same. But received the batch size of rois is %d, "
"and the batch size of images is %d",
rois_batch_size, batch_size));
auto* rois_num_data = rois_num_t->data<int>();
rois_lod[0] = 0;
for (int n = 0; n < rois_batch_size; ++n) {
rois_lod[n + 1] = rois_lod[n] + rois_num_data[n];
}
} else {
auto _rois_lod = rois->lod().back();
rois_batch_size = _rois_lod.size() - 1;
for (int n = 0; n < _rois_lod.size(); ++n) {
rois_lod[n] = _rois_lod[n];
}
PADDLE_ENFORCE_EQ(
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The rois_batch_size and imgs batch_size of roi_align_xpu OP "
"must "
"be the same. But received rois_batch_size %d , batch_size %d",
rois_batch_size, batch_size));
}
int rois_num_with_lod = rois_lod[rois_batch_size]; int rois_num_with_lod = rois_lod[rois_batch_size];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rois_num, rois_num_with_lod, rois_num, rois_num_with_lod,
......
...@@ -179,5 +179,29 @@ class TestROIAlignOp(OpTest): ...@@ -179,5 +179,29 @@ class TestROIAlignOp(OpTest):
self.check_output_with_place(place) self.check_output_with_place(place)
class TestROIAlignInLodOp(TestROIAlignOp):
def set_data(self):
self.init_test_case()
self.make_rois()
self.calc_roi_align()
seq_len = self.rois_lod[0]
self.inputs = {
'X': self.x,
'ROIs': (self.rois[:, 1:5], self.rois_lod),
'RoisNum': np.asarray(seq_len).astype('int32')
}
self.attrs = {
'spatial_scale': self.spatial_scale,
'pooled_height': self.pooled_height,
'pooled_width': self.pooled_width,
'sampling_ratio': self.sampling_ratio
}
self.outputs = {'Out': self.out_data}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册