Commit 14737e19 authored by Zhaolong Xing, committed by GitHub

[cherry-pick] [Refine Paddle-TRT INT8]: Support PaddleSlim's Resnet50 (#22485)

test=develop
Parent: d2d4a02c
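This cherry-pick teaches Paddle-TRT to run PaddleSlim-quantized models (which carry per-tensor scales as op attributes such as `X_scale` and `out_scale`) in INT8 without a TensorRT calibration pass. A minimal usage sketch, assuming the public `AnalysisConfig` C++ API of this Paddle release; the model directory is hypothetical and default argument values may differ across versions:

```cpp
#include "paddle/include/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("quant_resnet50_model");  // hypothetical PaddleSlim output dir
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/500, /*device_id=*/0);
  // Requesting kInt8 with use_calib_mode=false drives the new path below:
  // the subgraph pass computes no_calib_int8 = enable_int8 && !use_calib_mode
  // and trusts PaddleSlim's recorded scales instead of calibrating.
  config.EnableTensorRtEngine(/*workspace_size=*/1 << 20,
                              /*max_batch_size=*/1,
                              /*min_subgraph_size=*/3,
                              paddle::AnalysisConfig::Precision::kInt8,
                              /*use_static=*/false,
                              /*use_calib_mode=*/false);
  auto predictor = paddle::CreatePaddlePredictor(config);
  return predictor != nullptr ? 0 : 1;
}
```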
@@ -34,10 +34,13 @@ using framework::ir::Node;
 void analysis::TensorRtSubgraphPass::ApplyImpl(
     framework::ir::Graph *graph) const {
   framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph);
-  auto teller = [](const framework::ir::Node *node) {
+  auto enable_int8 = Get<bool>("enable_int8");
+  auto use_calib_mode = Get<bool>("use_calib_mode");
+  bool no_calib_int8 = enable_int8 && !(use_calib_mode);
+  auto teller = [&](const framework::ir::Node *node) {
     if (!node->IsOp() || !node->Op()) return false;
-    return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op());
+    return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op(),
+                                             no_calib_int8);
   };
   framework::ir::SubGraphFuser fuser(
......
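The pass now reads `enable_int8` and `use_calib_mode` from its attribute map and folds them into `no_calib_int8`, which the teller lambda forwards to `OpTeller::Tell`. For context, a hedged sketch of how a caller might attach the attributes consumed by `Get<bool>(...)` before running the pass; the wiring shown follows Paddle's `ir::Pass` attribute API as we understand it and is an assumption, not part of this commit:

```cpp
#include "paddle/fluid/framework/ir/pass.h"

// Sketch: attach the flags read by Get<bool>(...) above, then run the pass.
std::unique_ptr<paddle::framework::ir::Graph> RunTrtPass(
    std::unique_ptr<paddle::framework::ir::Graph> graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "tensorrt_subgraph_pass");
  pass->Set("enable_int8", new bool(true));      // INT8 requested
  pass->Set("use_calib_mode", new bool(false));  // scales come from PaddleSlim
  graph.reset(pass->Apply(graph.release()));
  return graph;
}
```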
@@ -98,6 +98,14 @@ class Pool2dOpConverter : public OpConverter {
     nvinfer1::ILayer *layer = nullptr;

+    if (op_desc.HasAttr("enable_int8")) {
+#if IS_TRT_VERSION_GE(5000)
+      CHECK(op_desc.HasAttr("X_scale"));
+      float input_scale = boost::get<float>(op_desc.GetAttr("X_scale"));
+      engine_->SetTensorDynamicRange(input1, input_scale);
+#endif
+    }
+
     if (global_pooling == true) {
       nv_ksize.d[0] = input_shape.d[input_dims - 2];
       nv_ksize.d[1] = input_shape.d[input_dims - 1];
@@ -159,14 +167,6 @@ class Pool2dOpConverter : public OpConverter {
     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(layer, "pool2d", {output_name}, test_mode);
-    if (op_desc.HasAttr("enable_int8")) {
-#if IS_TRT_VERSION_GE(5000)
-      CHECK(op_desc.HasAttr("X_scale"));
-      float input_scale = boost::get<float>(op_desc.GetAttr("X_scale"));
-      engine_->SetTensorDynamicRange(input1, input_scale);
-#endif
-    }
   }
 };
......
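Moving the scale registration ahead of layer construction ensures the pool input's range is known before the engine fixes layer precisions. A hedged sketch of what `SetTensorDynamicRange` plausibly boils down to on the TensorRT side (the `ITensor` call is the TensorRT 5+ API; the helper name is ours):

```cpp
#include <NvInfer.h>

// Sketch: a PaddleSlim scale s becomes the symmetric range [-s, s] that
// TensorRT needs on every tensor when INT8 runs without a calibrator.
inline bool SetSymmetricRange(nvinfer1::ITensor* tensor, float scale) {
  return tensor->setDynamicRange(-scale, scale);
}
```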
@@ -36,12 +36,8 @@ class SoftMaxOpConverter : public OpConverter {
     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(layer, "softmax", {output_name}, test_mode);
-    if (op_desc.HasAttr("out_scale")) {
-#if IS_TRT_VERSION_GE(5000)
-      float out_scale = boost::get<float>(op_desc.GetAttr("out_scale"));
-      engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
-#endif
-    }
+    // TensorRT will not run softmax in INT8 mode.
+    engine_->SetTensorDynamicRange(input1, 1.0);
   }
 };
......
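The per-op `out_scale` handling is dropped because TensorRT executes softmax in FP32 even inside an INT8 engine; the fixed `1.0` range on the input only satisfies the builder's requirement that every tensor carry some dynamic range when no calibrator is attached. A hedged sketch of that constraint, using TensorRT 5 `IBuilder` API names (the function and variable names are ours):

```cpp
#include <NvInfer.h>

// Sketch: in INT8 mode with no calibrator, the network build fails for
// tensors lacking a dynamic range, even on layers TRT runs in FP32.
void ConfigureNoCalibInt8(nvinfer1::IBuilder* builder,
                          nvinfer1::ITensor* softmax_input) {
  builder->setInt8Mode(true);
  builder->setInt8Calibrator(nullptr);          // the no-calibration path
  softmax_input->setDynamicRange(-1.0f, 1.0f);  // mirrors the 1.0 above
}
```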
@@ -26,9 +26,13 @@ struct SimpleOpTypeSetTeller : public Teller {
 #endif
   }

-  bool operator()(const std::string& op_type,
-                  const framework::OpDesc& desc) override {
-    return teller_set.count(op_type);
+  bool operator()(const std::string& op_type, const framework::OpDesc& desc,
+                  bool use_no_calib_int8) override {
+    if (use_no_calib_int8) {
+      return int8_teller_set.count(op_type);
+    } else {
+      return teller_set.count(op_type);
+    }
   }

 private:
@@ -59,13 +63,22 @@ struct SimpleOpTypeSetTeller : public Teller {
       "layer_norm",
       "multihead_matmul",
   }};
+
+  // use this set for no calib int8.
+  std::unordered_set<std::string> int8_teller_set{
+      {"mul", "conv2d", "pool2d", "relu", "depthwise_conv2d", "softmax",
+       "batch_norm", "elementwise_add", "leaky_relu", "fc"}};
 };
-bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
+bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
+                    bool use_no_calib_int8) {
   // Do not support ops labeled `skip_quant`.
-  if (desc.HasAttr("op_namescope") &&
-      boost::get<std::string>(desc.GetAttr("op_namescope")) == "/skip_quant_2/")
+  if ((desc.HasAttr("op_namescope") &&
+       boost::get<std::string>(desc.GetAttr("op_namescope")) ==
+           "/skip_quant_2/") ||
+      desc.HasAttr("skip_quant"))
     return false;
   for (auto& teller : tellers_) {
     if (op_type == "pool2d" || op_type == "conv2d" ||
         op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") {
@@ -73,7 +86,7 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
           boost::get<std::vector<int>>(desc.GetAttr("paddings"));
       if (paddings.size() > 2) return false;
     }
-    if ((*teller)(op_type, desc)) return true;
+    if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
   }
   return false;
 }
......
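Because the new parameter defaults to `false` in the header (next hunks), existing call sites keep the old calibration/FP32 behavior and only the subgraph pass opts in. A hypothetical probe showing how the two modes can answer differently for the same op (the helper name is ours):

```cpp
#include "paddle/fluid/inference/tensorrt/op_teller.h"

// Hypothetical helper: an op may be eligible in the default mode but not in
// no-calib INT8, e.g. multihead_matmul is only in teller_set.
bool EligibleForTrt(const paddle::framework::OpDesc& desc,
                    bool use_no_calib_int8) {
  return paddle::inference::tensorrt::OpTeller::Global().Tell(
      desc.Type(), desc, use_no_calib_int8);
}
```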
@@ -31,7 +31,8 @@ namespace tensorrt {
  */
 struct Teller {
   virtual bool operator()(const std::string& op_type,
-                          const framework::OpDesc& desc) = 0;
+                          const framework::OpDesc& desc,
+                          bool use_no_calib_int8) = 0;

   virtual ~Teller() = default;
 };
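Any other `Teller` implementation must be updated for the widened pure-virtual signature. A minimal conforming subclass, purely illustrative and not part of this commit:

```cpp
#include "paddle/fluid/inference/tensorrt/op_teller.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Hypothetical teller that whitelists a single op type; it ignores the mode
// flag but must still accept it to satisfy the new interface.
struct ReluOnlyTeller : public Teller {
  bool operator()(const std::string& op_type, const framework::OpDesc& desc,
                  bool use_no_calib_int8) override {
    (void)desc;
    (void)use_no_calib_int8;
    return op_type == "relu";
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
```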
@@ -57,7 +58,8 @@ class OpTeller {
     return *x;
   }

-  bool Tell(const std::string& op_type, const framework::OpDesc& desc);
+  bool Tell(const std::string& op_type, const framework::OpDesc& desc,
+            bool use_no_calib_int8 = false);

 private:
   OpTeller();
......