From 14737e19a919663bf58c0af3be9a36a064c87407 Mon Sep 17 00:00:00 2001 From: Zhaolong Xing Date: Mon, 10 Feb 2020 17:40:46 +0800 Subject: [PATCH] [cherry-pick] [Refine Paddle-TRT INT8]: Support PaddleSlim's Resnet50 (#22485) test=develop --- .../ir_passes/tensorrt_subgraph_pass.cc | 9 ++++--- .../inference/tensorrt/convert/pool2d_op.cc | 16 +++++------ .../inference/tensorrt/convert/softmax_op.cc | 8 ++---- paddle/fluid/inference/tensorrt/op_teller.cc | 27 ++++++++++++++----- paddle/fluid/inference/tensorrt/op_teller.h | 6 +++-- 5 files changed, 40 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 38106141b69..397411ccf87 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -34,10 +34,13 @@ using framework::ir::Node; void analysis::TensorRtSubgraphPass::ApplyImpl( framework::ir::Graph *graph) const { framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph); - - auto teller = [](const framework::ir::Node *node) { + auto enable_int8 = Get("enable_int8"); + auto use_calib_mode = Get("use_calib_mode"); + bool no_calib_int8 = enable_int8 && !(use_calib_mode); + auto teller = [&](const framework::ir::Node *node) { if (!node->IsOp() || !node->Op()) return false; - return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op()); + return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op(), + no_calib_int8); }; framework::ir::SubGraphFuser fuser( diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 846e2154b11..2a4aada92e5 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -98,6 +98,14 @@ class Pool2dOpConverter : public OpConverter { nvinfer1::ILayer *layer = nullptr; + if (op_desc.HasAttr("enable_int8")) { +#if IS_TRT_VERSION_GE(5000) + CHECK(op_desc.HasAttr("X_scale")); + float input_scale = boost::get(op_desc.GetAttr("X_scale")); + engine_->SetTensorDynamicRange(input1, input_scale); +#endif + } + if (global_pooling == true) { nv_ksize.d[0] = input_shape.d[input_dims - 2]; nv_ksize.d[1] = input_shape.d[input_dims - 1]; @@ -159,14 +167,6 @@ class Pool2dOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "pool2d", {output_name}, test_mode); - - if (op_desc.HasAttr("enable_int8")) { -#if IS_TRT_VERSION_GE(5000) - CHECK(op_desc.HasAttr("X_scale")); - float input_scale = boost::get(op_desc.GetAttr("X_scale")); - engine_->SetTensorDynamicRange(input1, input_scale); -#endif - } } }; diff --git a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc index b0ae1694127..9f4a048961f 100644 --- a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc @@ -36,12 +36,8 @@ class SoftMaxOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "softmax", {output_name}, test_mode); - if (op_desc.HasAttr("out_scale")) { -#if IS_TRT_VERSION_GE(5000) - float out_scale = boost::get(op_desc.GetAttr("out_scale")); - engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale); -#endif - } + // The trt will not run int for softmax. + engine_->SetTensorDynamicRange(input1, 1.0); } }; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 39cdf5ba1af..a0720605b97 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -26,9 +26,13 @@ struct SimpleOpTypeSetTeller : public Teller { #endif } - bool operator()(const std::string& op_type, - const framework::OpDesc& desc) override { - return teller_set.count(op_type); + bool operator()(const std::string& op_type, const framework::OpDesc& desc, + bool use_no_calib_int8) override { + if (use_no_calib_int8) { + return int8_teller_set.count(op_type); + } else { + return teller_set.count(op_type); + } } private: @@ -59,13 +63,22 @@ struct SimpleOpTypeSetTeller : public Teller { "layer_norm", "multihead_matmul", }}; + + // use this set for no calib int8. + std::unordered_set int8_teller_set{ + {"mul", "conv2d", "pool2d", "relu", "depthwise_conv2d", "softmax", + "batch_norm", "elementwise_add", "leaky_relu", "fc"}}; }; -bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) { +bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc, + bool use_no_calib_int8) { // do not support the op which is labeled the `skip_quant` - if (desc.HasAttr("op_namescope") && - boost::get(desc.GetAttr("op_namescope")) == "/skip_quant_2/") + if ((desc.HasAttr("namescope") && + boost::get(desc.GetAttr("op_namescope")) == + "/skip_quant_2/") || + desc.HasAttr("skip_quant")) return false; + for (auto& teller : tellers_) { if (op_type == "pool2d" || op_type == "conv2d" || op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") { @@ -73,7 +86,7 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) { boost::get>(desc.GetAttr("paddings")); if (paddings.size() > 2) return false; } - if ((*teller)(op_type, desc)) return true; + if ((*teller)(op_type, desc, use_no_calib_int8)) return true; } return false; } diff --git a/paddle/fluid/inference/tensorrt/op_teller.h b/paddle/fluid/inference/tensorrt/op_teller.h index 7ff1d4746a1..76784c7445e 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.h +++ b/paddle/fluid/inference/tensorrt/op_teller.h @@ -31,7 +31,8 @@ namespace tensorrt { */ struct Teller { virtual bool operator()(const std::string& op_type, - const framework::OpDesc& desc) = 0; + const framework::OpDesc& desc, + bool use_no_calib_int8) = 0; virtual ~Teller() = default; }; @@ -57,7 +58,8 @@ class OpTeller { return *x; } - bool Tell(const std::string& op_type, const framework::OpDesc& desc); + bool Tell(const std::string& op_type, const framework::OpDesc& desc, + bool use_no_calib_int8 = false); private: OpTeller(); -- GitLab