From cc47c83caecccd2b660991bfaa09552017cbc0bf Mon Sep 17 00:00:00 2001 From: Jason Date: Wed, 1 Dec 2021 13:21:23 +0800 Subject: [PATCH] fix fc_fuse pass (#37694) * fix fc_fuse * modify cmake notest,test=windows_ci * retrigger all the ci --- paddle/fluid/framework/ir/fc_fuse_pass.cc | 26 ++++++++++++++++++ .../fluid/framework/ir/fc_fuse_pass_tester.cc | 8 +++--- .../unittests/ir/inference/CMakeLists.txt | 5 +++- .../ir/inference/test_fc_fuse_pass.py | 27 ++++++++++++++++--- 4 files changed, 57 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index bb78cdab677..e246a10961c 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -130,6 +130,32 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern); + + // Only support 2D-Tensor as weight for FC + std::vector w_shape = w->Var()->GetShape(); + size_t w_rank = w_shape.size(); + if (w_rank != 2) return; + + // axis of elementwise_add should be -1 or x_num_col_dims + auto x_num_col_dims = + BOOST_GET_CONST(int, mul->Op()->GetAttr("x_num_col_dims")); + auto axis = BOOST_GET_CONST(int, elementwise_add->Op()->GetAttr("axis")); + if (axis != -1 && axis != x_num_col_dims) return; + + // Shape of bias should be [1, out_size] or [out_size] + std::vector b_shape = bias->Var()->GetShape(); + if (b_shape.size() == 1) { + if (b_shape[0] != w_shape[1]) { + return; + } + } else if (b_shape.size() == 2) { + if (b_shape[0] != 1 || b_shape[1] != w_shape[1]) { + return; + } + } else { + return; + } + Node* relu = nullptr; Node* relu_out = nullptr; if (with_relu) { diff --git a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc index 50469110368..39b544e7160 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc @@ -55,14 +55,14 @@ TEST(FCFusePass, basic) { auto* bias_0 = layers.data("conv2d_bias_0", {}, true); auto* conv2d_out = layers.conv2d(a, filters_0, bias_0, false); auto* relu_out_0 = layers.relu(conv2d_out); - auto* weights_0 = layers.data("weights_0", {}, true); + auto* weights_0 = layers.data("weights_0", {5, 4}, true); auto* mul_out_0 = layers.mul(relu_out_0, weights_0); - auto* bias_1 = layers.data("bias_1", {}, true); + auto* bias_1 = layers.data("bias_1", {4}, true); auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1, nullptr, 1); auto* relu_out_1 = layers.relu(add_out_0); - auto* weights_1 = layers.data("weights_1", {}, true); + auto* weights_1 = layers.data("weights_1", {8, 9}, true); auto* mul_out_1 = layers.mul(relu_out_1, weights_1); - auto* bias_2 = layers.data("bias_2", {}, true); + auto* bias_2 = layers.data("bias_2", {1, 9}, true); auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2, nullptr, 1); VLOG(4) << add_out_1; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 0b127d2a11f..4126e604cc1 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -71,8 +71,11 @@ set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100) set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_nearest_interp_v2_op PROPERTIES TIMEOUT 30) + +if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU) set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT 120) -set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 120) +set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240) +endif() if (WITH_MKLDNN) set_tests_properties(test_mkldnn_prelu_op PROPERTIES TIMEOUT 300) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py index 1db3a007131..dccc29e75f0 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_fc_fuse_pass.py @@ -46,6 +46,17 @@ class TestFcFusePass(PassAutoScanTest): config = self.create_inference_config(use_gpu=True) yield config, ["fc"], (1e-5, 1e-5) + # trt static_shape + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=8, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + yield config, ['fc'], (1e-5, 1e-5) + def add_ignore_pass_case(self): # Here we put some skip rules to avoid known bugs def teller1(program_config, predictor_config): @@ -53,14 +64,22 @@ class TestFcFusePass(PassAutoScanTest): x_shape = list(program_config.inputs["mul_x"].shape) y_shape = list(program_config.weights["mul_y"].shape) bias_shape = program_config.weights["bias"].shape - if (bias_shape != [y_shape[-1], ] and - bias_shape != [1, y_shape[-1]]): + bias_shape = list(program_config.weights["bias"].shape) + + if predictor_config.tensorrt_engine_enabled(): + # TensorRT cann't handle all the situation of elementwise_add + # disable it until this problem fixed + predictor_config.exp_disable_tensorrt_ops(["elementwise_add"]) + + if bias_shape != [y_shape[-1]] and bias_shape != [1, y_shape[-1]]: return True return False def teller2(program_config, predictor_config): # TODO fuse has bug while axis != -1 - if program_config.ops[1].attrs["axis"] != -1: + axis = program_config.ops[1].attrs["axis"] + if axis != -1 and axis != program_config.ops[0].attrs[ + "x_num_col_dims"]: return True return False @@ -164,7 +183,7 @@ class TestFcFusePass(PassAutoScanTest): def test(self): self.run_and_statis( - quant=False, max_examples=300, passes=["fc_fuse_pass"]) + quant=False, max_examples=500, passes=["fc_fuse_pass"]) if __name__ == "__main__": -- GitLab