diff --git a/lite/core/mir/fusion/conv_conv_fuse_pass.cc b/lite/core/mir/fusion/conv_conv_fuse_pass.cc index b2c5d8d15ab95fbcc43adc01c4189ae83b1316ed..e7f816ae4c99b3d27e9473c0937936a2f25a232b 100644 --- a/lite/core/mir/fusion/conv_conv_fuse_pass.cc +++ b/lite/core/mir/fusion/conv_conv_fuse_pass.cc @@ -27,7 +27,7 @@ namespace mir { void ConvConvFusePass::Apply(const std::unique_ptr& graph) { // initialze fuser params std::vector conv_has_bias_cases{true, false}; - std::vector conv_type_cases{"conv2d", "depthwise_conv2d"}; + std::vector conv_type_cases{"conv2d"}; bool has_int8 = false; bool has_weight_quant = false; for (auto& place : graph->valid_places()) { diff --git a/lite/core/mir/fusion/conv_conv_fuser.cc b/lite/core/mir/fusion/conv_conv_fuser.cc index f2e24d06fa089ea4f575116d26f333060757e789..2393ff533007460f6f3d15dce11ef73ca09e802b 100644 --- a/lite/core/mir/fusion/conv_conv_fuser.cc +++ b/lite/core/mir/fusion/conv_conv_fuser.cc @@ -132,8 +132,8 @@ void ConvConvFuser::BuildPattern() { VLOG(5) << "The kernel size of the second conv must be 1x1"; continue; } - if (groups1 != 1) { - VLOG(5) << "The groups of weight1_dim must be 1"; + if (groups0 != 1 || groups1 != 1) { + VLOG(5) << "The all groups of weight_dim must be 1"; continue; } if (ch_out_0 != ch_in_1) {