未验证 提交 9de6d88a 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #13722 from tensor-tang/fix/release/1.0.0

cherry-pick 'bugfix: fusion lstm and gru batch,seq  mode switch'
......@@ -290,12 +290,13 @@ class FusionGRUKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
if (x->lod()[0].size() == 2) {
xx->Resize({total_T, D3});
SeqCompute(ctx);
return;
}
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
INIT_VEC_FUNC
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
......
......@@ -424,11 +424,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext;
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
if (x->lod()[0].size() == 2) {
xx->Resize({x_dims[0], D4});
SeqCompute(ctx);
return;
}
INIT_BASE_SIZES
INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册