未验证 提交 cc47c83c 编写于 作者: J Jason 提交者: GitHub

fix fc_fuse pass (#37694)

* fix fc_fuse

* modify cmake notest,test=windows_ci

* retrigger all the ci
上级 06c3cce9
......@@ -130,6 +130,32 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
// Only support 2D-Tensor as weight for FC
std::vector<int64_t> w_shape = w->Var()->GetShape();
size_t w_rank = w_shape.size();
if (w_rank != 2) return;
// axis of elementwise_add should be -1 or x_num_col_dims
auto x_num_col_dims =
BOOST_GET_CONST(int, mul->Op()->GetAttr("x_num_col_dims"));
auto axis = BOOST_GET_CONST(int, elementwise_add->Op()->GetAttr("axis"));
if (axis != -1 && axis != x_num_col_dims) return;
// Shape of bias should be [1, out_size] or [out_size]
std::vector<int64_t> b_shape = bias->Var()->GetShape();
if (b_shape.size() == 1) {
if (b_shape[0] != w_shape[1]) {
return;
}
} else if (b_shape.size() == 2) {
if (b_shape[0] != 1 || b_shape[1] != w_shape[1]) {
return;
}
} else {
return;
}
Node* relu = nullptr;
Node* relu_out = nullptr;
if (with_relu) {
......
......@@ -55,14 +55,14 @@ TEST(FCFusePass, basic) {
auto* bias_0 = layers.data("conv2d_bias_0", {}, true);
auto* conv2d_out = layers.conv2d(a, filters_0, bias_0, false);
auto* relu_out_0 = layers.relu(conv2d_out);
auto* weights_0 = layers.data("weights_0", {}, true);
auto* weights_0 = layers.data("weights_0", {5, 4}, true);
auto* mul_out_0 = layers.mul(relu_out_0, weights_0);
auto* bias_1 = layers.data("bias_1", {}, true);
auto* bias_1 = layers.data("bias_1", {4}, true);
auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1, nullptr, 1);
auto* relu_out_1 = layers.relu(add_out_0);
auto* weights_1 = layers.data("weights_1", {}, true);
auto* weights_1 = layers.data("weights_1", {8, 9}, true);
auto* mul_out_1 = layers.mul(relu_out_1, weights_1);
auto* bias_2 = layers.data("bias_2", {}, true);
auto* bias_2 = layers.data("bias_2", {1, 9}, true);
auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2, nullptr, 1);
VLOG(4) << add_out_1;
......
......@@ -71,8 +71,11 @@ set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100)
set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_trt_nearest_interp_v2_op PROPERTIES TIMEOUT 30)
if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU)
set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240)
endif()
if (WITH_MKLDNN)
set_tests_properties(test_mkldnn_prelu_op PROPERTIES TIMEOUT 300)
......
......@@ -46,6 +46,17 @@ class TestFcFusePass(PassAutoScanTest):
config = self.create_inference_config(use_gpu=True)
yield config, ["fc"], (1e-5, 1e-5)
# trt static_shape
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=8,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False)
yield config, ['fc'], (1e-5, 1e-5)
def add_ignore_pass_case(self):
# Here we put some skip rules to avoid known bugs
def teller1(program_config, predictor_config):
......@@ -53,14 +64,22 @@ class TestFcFusePass(PassAutoScanTest):
x_shape = list(program_config.inputs["mul_x"].shape)
y_shape = list(program_config.weights["mul_y"].shape)
bias_shape = program_config.weights["bias"].shape
if (bias_shape != [y_shape[-1], ] and
bias_shape != [1, y_shape[-1]]):
bias_shape = list(program_config.weights["bias"].shape)
if predictor_config.tensorrt_engine_enabled():
# TensorRT cann't handle all the situation of elementwise_add
# disable it until this problem fixed
predictor_config.exp_disable_tensorrt_ops(["elementwise_add"])
if bias_shape != [y_shape[-1]] and bias_shape != [1, y_shape[-1]]:
return True
return False
def teller2(program_config, predictor_config):
# TODO fuse has bug while axis != -1
if program_config.ops[1].attrs["axis"] != -1:
axis = program_config.ops[1].attrs["axis"]
if axis != -1 and axis != program_config.ops[0].attrs[
"x_num_col_dims"]:
return True
return False
......@@ -164,7 +183,7 @@ class TestFcFusePass(PassAutoScanTest):
def test(self):
self.run_and_statis(
quant=False, max_examples=300, passes=["fc_fuse_pass"])
quant=False, max_examples=500, passes=["fc_fuse_pass"])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册