未验证 提交 badaaee6 编写于 作者: P Pei Yang 提交者: GitHub

show shape diff in wrong trt input shape errmsg, test=develop (#21451) (#21470)

上级 ccb508dc
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
...@@ -212,11 +213,25 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -212,11 +213,25 @@ class TensorRTEngineOp : public framework::OperatorBase {
i_shape.end()); i_shape.end());
std::vector<int64_t> runtime_input_shape(t_shape.begin() + 1, std::vector<int64_t> runtime_input_shape(t_shape.begin() + 1,
t_shape.end()); t_shape.end());
PADDLE_ENFORCE_EQ(model_input_shape == runtime_input_shape, true, auto comma_fold = [](std::string a, int b) {
"Input shapes are inconsistent with the model. TRT 5 " return std::move(a) + ", " + std::to_string(b);
"or lower version " };
"does not support dynamic input shapes. Please check " std::string model_input_shape_str = std::accumulate(
"your input shapes."); std::next(model_input_shape.begin()), model_input_shape.end(),
std::to_string(model_input_shape[0]), comma_fold);
std::string runtime_input_shape_str = std::accumulate(
std::next(runtime_input_shape.begin()), runtime_input_shape.end(),
std::to_string(runtime_input_shape[0]), comma_fold);
PADDLE_ENFORCE_EQ(
model_input_shape == runtime_input_shape, true,
platform::errors::InvalidArgument(
"Input shapes are inconsistent with the model. Expect [%s] in "
"model description, but got [%s] in runtime. TRT 5 "
"or lower version "
"does not support dynamic input shapes. Please check and "
"modify "
"your input shapes.",
model_input_shape_str, runtime_input_shape_str));
} }
runtime_batch = t_shape[0]; runtime_batch = t_shape[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册