Unverified commit a0562813, authored by joanna.wozna.intel, committed by GitHub

Add gru qat int8 test (#50846)

* Add gru qat int8 test

* Change place of model downloading

* Update paddle/fluid/inference/tests/api/CMakeLists.txt
Co-authored-by: Sławomir Siwek <slawomir.siwek@intel.com>

* Correct flag names and add descriptions

---------
Co-authored-by: Sławomir Siwek <slawomir.siwek@intel.com>
Parent: 4557e7e8
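Context for the renamed flags: the old single `--enable_int8` switch is split into `--enable_int8_ptq` (post-training quantization via the MKL-DNN quantizer) and `--enable_int8_qat` (models already quantized with quant-aware training, which only need `EnableMkldnnInt8()`). The snippet below is a minimal sketch, not the actual test harness, of how the lexical GRU test wires these flags into an `AnalysisConfig`; the include paths and the helper name `ConfigureQuantization` are illustrative assumptions, while the individual `AnalysisConfig` calls are the ones visible in the diff.

```cpp
#include <memory>
#include <vector>

#include "gflags/gflags.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"  // assumed include path
#include "paddle/fluid/inference/api/paddle_api.h"               // paddle::PaddleTensor

DEFINE_bool(enable_int8_ptq, false,
            "Enable INT8 post-training quantization prediction");
DEFINE_bool(enable_int8_qat, false,
            "Enable INT8 quant-aware training prediction");
DEFINE_bool(fuse_multi_gru, false, "Fuse GRU ops into multi_gru (PTQ only)");

// Minimal sketch: apply the quantization mode selected by the flags to an
// AnalysisConfig, mirroring the branches in the lexical analysis test.
void ConfigureQuantization(
    paddle::AnalysisConfig *cfg,
    const std::shared_ptr<std::vector<paddle::PaddleTensor>> &warmup_data,
    int warmup_batch_size) {
  if (FLAGS_enable_int8_ptq) {
    // Post-training quantization: optionally fuse GRUs, then let the
    // MKL-DNN quantizer collect scales from the warmup batch.
    if (FLAGS_fuse_multi_gru) {
      cfg->pass_builder()->AppendPass("multi_gru_fuse_pass");
    }
    cfg->EnableMkldnnQuantizer();
    cfg->mkldnn_quantizer_config()->SetWarmupData(warmup_data);
    cfg->mkldnn_quantizer_config()->SetWarmupBatchSize(warmup_batch_size);
  } else if (FLAGS_enable_int8_qat) {
    // Quant-aware-trained model: scales are already embedded in the model,
    // so it is enough to turn on the MKL-DNN INT8 path.
    cfg->EnableMkldnnInt8();
  }
}
```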
@@ -126,7 +126,7 @@ function(inference_analysis_api_int8_test target install_dir filename)
--refer_result=${install_dir}/result.txt
--accuracy=0.8
--batch_size=5
--enable_int8=true)
--enable_int8_ptq=true)
endfunction()
function(inference_multiple_models_analysis_api_test target install_dir
@@ -159,7 +159,7 @@ function(inference_analysis_api_int8_test_run TARGET_NAME test_binary model_dir
--infer_data=${data_path}
--warmup_batch_size=${WARMUP_BATCH_SIZE}
--batch_size=50
--enable_int8=true
--enable_int8_ptq=true
--cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
--iterations=2)
endfunction()
@@ -197,7 +197,7 @@ function(inference_analysis_api_object_dection_int8_test_run TARGET_NAME
--infer_data=${data_path}
--warmup_batch_size=10
--batch_size=300
--enable_int8=true
--enable_int8_ptq=true
--cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
--iterations=1)
endfunction()
@@ -221,7 +221,7 @@ function(
fp32_model_dir
int8_model_dir
data_path
enable_quant_int8)
enable_int8_qat)
inference_analysis_test_run(
${TARGET_NAME}
COMMAND
@@ -231,8 +231,7 @@ function(
--int8_model=${int8_model_dir}
--infer_data=${data_path}
--batch_size=50
--enable_int8=true
--enable_quant_int8=${enable_quant_int8}
--enable_int8_qat=${enable_int8_qat}
--cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
--with_accuracy_layer=false
--iterations=2)
@@ -271,8 +270,15 @@ function(inference_analysis_api_lexical_bfloat16_test_run TARGET_NAME
--iterations=2)
endfunction()
function(inference_analysis_api_lexical_int8_test_run TARGET_NAME test_binary
infer_model data_path fuse_multi_gru)
function(
inference_analysis_api_lexical_int8_test_run
TARGET_NAME
test_binary
infer_model
data_path
enable_int8_ptq
enable_int8_qat
fuse_multi_gru)
inference_analysis_test_run(
${TARGET_NAME}
COMMAND
@@ -284,8 +290,9 @@ function(inference_analysis_api_lexical_int8_test_run TARGET_NAME test_binary
--cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
--with_accuracy_layer=true
--use_analysis=true
--enable_int8=true
--quantized_accuracy=0.01
--enable_int8_ptq=${enable_int8_ptq}
--enable_int8_qat=${enable_int8_qat}
--quantized_accuracy=0.015
--fuse_multi_gru=${fuse_multi_gru}
--iterations=4)
endfunction()
@@ -685,7 +692,7 @@ if(WITH_MKLDNN)
--infer_data=${IMAGENET_DATA_PATH}
--warmup_batch_size=50
--batch_size=1
--enable_int8=true
--enable_int8_ptq=true
--cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
--iterations=100
--with_accuracy_layer=false)
@@ -775,12 +782,37 @@ if(WITH_MKLDNN)
${GRU_DATA_PATH})
# run post-training quantization lexical analysis test
inference_analysis_api_lexical_int8_test_run(
test_analyzer_lexical_gru_int8 ${LEXICAL_TEST_APP} ${GRU_MODEL_PATH}
${GRU_DATA_PATH} false)
test_analyzer_lexical_gru_int8
${LEXICAL_TEST_APP}
${GRU_MODEL_PATH}
${GRU_DATA_PATH}
true # enable_int8_ptq
false # enable_int8_qat
false) # fuse_multi_gru
# run post-training quantization lexical analysis test with multi_gru fuse
inference_analysis_api_lexical_int8_test_run(
test_analyzer_lexical_gru_int8_multi_gru ${LEXICAL_TEST_APP}
${GRU_MODEL_PATH} ${GRU_DATA_PATH} true)
test_analyzer_lexical_gru_int8_multi_gru
${LEXICAL_TEST_APP}
${GRU_MODEL_PATH}
${GRU_DATA_PATH}
true # enable_int8_ptq
false # enable_int8_qat
true) # fuse_multi_gru
# run qat gru test
set(QAT_GRU_MODEL_ARCHIVE "GRU_quant_acc.tar.gz")
set(QAT_GRU_MODEL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/quant/GRU_quant2")
download_quant_data(${QAT_GRU_MODEL_DIR} ${QAT_GRU_MODEL_ARCHIVE}
cf207f8076dcfb8b74d8b6bdddf9090c)
inference_analysis_api_lexical_int8_test_run(
test_analyzer_lexical_gru_qat_int8
${LEXICAL_TEST_APP}
"${QAT_GRU_MODEL_DIR}/GRU_quant_acc"
${GRU_DATA_PATH}
false # enable_int8_ptq
true # enable_int8_qat
false) # fuse_multi_gru
### optimized FP32 vs. Quant INT8 tests
......
@@ -49,7 +49,7 @@ TEST(Analyzer_int8_image_classification, quantization) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInputs(&input_slots_all);
if (FLAGS_enable_mkldnn && FLAGS_enable_int8) {
if (FLAGS_enable_mkldnn && FLAGS_enable_int8_ptq) {
// prepare warmup batch from input data read earlier
// warmup batch size can be different than batch size
std::shared_ptr<std::vector<PaddleTensor>> warmup_data =
......
@@ -261,7 +261,7 @@ TEST(Analyzer_lexical_test, Analyzer_lexical_analysis) {
SetAnalysisConfig(&analysis_cfg, FLAGS_cpu_num_threads);
if (FLAGS_enable_bf16) {
analysis_cfg.EnableMkldnnBfloat16();
} else if (FLAGS_enable_int8) {
} else if (FLAGS_enable_int8_ptq) {
if (FLAGS_fuse_multi_gru) {
analysis_cfg.pass_builder()->AppendPass("multi_gru_fuse_pass");
}
@@ -271,6 +271,8 @@ TEST(Analyzer_lexical_test, Analyzer_lexical_analysis) {
analysis_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
analysis_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(
FLAGS_batch_size);
} else if (FLAGS_enable_int8_qat) {
analysis_cfg.EnableMkldnnInt8();
} else {
// if fp32 => disable mkldnn fc passes
// when passes are enabled dnnl error occurs for iterations==0
......
@@ -119,7 +119,7 @@ TEST(Analyzer_quant_image_classification, quantization) {
AnalysisConfig int8_cfg;
SetConfig(&int8_cfg, FLAGS_int8_model);
if (FLAGS_enable_quant_int8) int8_cfg.EnableMkldnnInt8();
if (FLAGS_enable_int8_qat) int8_cfg.EnableMkldnnInt8();
// read data from file and prepare batches with test data
std::vector<std::vector<PaddleTensor>> input_slots_all;
......
@@ -53,8 +53,12 @@ DEFINE_bool(with_accuracy_layer,
"Calculate the accuracy while label is in the input");
DEFINE_bool(enable_fp32, true, "Enable FP32 type prediction");
DEFINE_bool(enable_bf16, false, "Enable BF16 type prediction");
DEFINE_bool(enable_int8, false, "Enable INT8 type prediction");
DEFINE_bool(enable_quant_int8, false, "Enable QUANT INT8 type prediction");
DEFINE_bool(enable_int8_ptq,
false,
"Enable INT8 post-training quantization prediction");
DEFINE_bool(enable_int8_qat,
false,
"Enable INT8 quant-aware training prediction");
DEFINE_int32(warmup_batch_size, 100, "batch size for quantization warmup");
// setting iterations to 0 means processing the whole dataset
DEFINE_int32(iterations, 0, "number of batches to process");
@@ -701,7 +705,7 @@ void SummarizePerformance(const char *title_fp32,
const char *title,
float sample_latency) {
if (FLAGS_enable_fp32) SummarizePerformance(title_fp32, sample_latency_fp32);
if (FLAGS_enable_int8 || FLAGS_enable_bf16)
if (FLAGS_enable_int8_ptq || FLAGS_enable_int8_qat || FLAGS_enable_bf16)
SummarizePerformance(title, sample_latency);
}
@@ -760,7 +764,8 @@ void CompareAccuracy(
const std::vector<std::vector<PaddleTensor>> &output_slots_quant,
const std::vector<std::vector<PaddleTensor>> &output_slots_ref,
int compared_idx) {
if ((FLAGS_enable_fp32 && (FLAGS_enable_int8 || FLAGS_enable_bf16)) &&
if ((FLAGS_enable_fp32 &&
(FLAGS_enable_int8_ptq || FLAGS_enable_int8_qat || FLAGS_enable_bf16)) &&
(output_slots_quant.size() == 0 || output_slots_ref.size() == 0))
throw std::invalid_argument(
"CompareAccuracy: output_slots vector is empty.");
@@ -768,7 +773,7 @@ void CompareAccuracy(
float avg_acc_quant = 0.0;
float avg_acc_ref = 0.0;
if (FLAGS_enable_int8 || FLAGS_enable_bf16)
if (FLAGS_enable_int8_ptq || FLAGS_enable_int8_qat || FLAGS_enable_bf16)
avg_acc_quant = CompareAccuracyOne(output_slots_quant, compared_idx);
if (FLAGS_enable_fp32)
@@ -778,9 +783,11 @@ void CompareAccuracy(
if (FLAGS_enable_fp32) CHECK_GT(avg_acc_ref, 0.0);
if (FLAGS_enable_int8 || FLAGS_enable_bf16) CHECK_GT(avg_acc_quant, 0.0);
if (FLAGS_enable_int8_ptq || FLAGS_enable_int8_qat || FLAGS_enable_bf16)
CHECK_GT(avg_acc_quant, 0.0);
if (FLAGS_enable_fp32 && (FLAGS_enable_int8 || FLAGS_enable_bf16))
if (FLAGS_enable_fp32 &&
(FLAGS_enable_int8_ptq || FLAGS_enable_int8_qat || FLAGS_enable_bf16))
CHECK_LE(avg_acc_ref - avg_acc_quant, FLAGS_quantized_accuracy);
}
@@ -864,7 +871,7 @@ void CompareQuantizedAndAnalysis(
std::vector<std::vector<PaddleTensor>> quantized_outputs;
float sample_latency_int8{-1};
if (FLAGS_enable_int8) {
if (FLAGS_enable_int8_ptq || FLAGS_enable_int8_qat) {
TestOneThreadPrediction(qcfg,
inputs,
&quantized_outputs,
@@ -965,7 +972,7 @@ void CompareAnalysisAndAnalysis(
std::vector<std::vector<PaddleTensor>> int8_outputs;
float sample_latency_int8{-1};
if (FLAGS_enable_int8) {
if (FLAGS_enable_int8_ptq || FLAGS_enable_int8_qat) {
TestOneThreadPrediction(
cfg2, inputs, &int8_outputs, true, VarType::INT8, &sample_latency_int8);
}
......
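Taken together, the helper changes above mean a quantized run (PTQ or QAT) is executed whenever either flag is set, and its averaged accuracy must stay within `FLAGS_quantized_accuracy` of the FP32 reference; the lexical GRU tests relax that bound to 0.015. Below is a condensed, self-contained sketch of that gate (the flag definitions are repeated and their defaults chosen only for illustration; the function name is hypothetical, not part of the actual helper).

```cpp
#include "gflags/gflags.h"

// Repeated here only so the sketch compiles on its own; the real
// definitions live in the shared tester helper.
DEFINE_bool(enable_fp32, true, "Enable FP32 type prediction");
DEFINE_bool(enable_bf16, false, "Enable BF16 type prediction");
DEFINE_bool(enable_int8_ptq, false,
            "Enable INT8 post-training quantization prediction");
DEFINE_bool(enable_int8_qat, false,
            "Enable INT8 quant-aware training prediction");
DEFINE_double(quantized_accuracy, 0.01,
              "Tolerated accuracy drop of the quantized model (illustrative default)");

// Condensed version of the CompareAccuracy() gate: when an FP32 reference and
// any quantized/bf16 run are both enabled, the quantized model may lose at
// most FLAGS_quantized_accuracy (e.g. 0.015 for the lexical GRU tests).
bool QuantAccuracyAcceptable(float avg_acc_ref, float avg_acc_quant) {
  const bool quant_enabled =
      FLAGS_enable_int8_ptq || FLAGS_enable_int8_qat || FLAGS_enable_bf16;
  if (!(FLAGS_enable_fp32 && quant_enabled)) return true;
  return (avg_acc_ref - avg_acc_quant) <= FLAGS_quantized_accuracy;
}
```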
@@ -401,10 +401,7 @@ if(LINUX AND WITH_MKLDNN)
${NLP_LABLES_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE})
# Quant2 GRU
set(QUANT2_GRU_MODEL_ARCHIVE "GRU_quant_acc.tar.gz")
set(QUANT2_GRU_MODEL_DIR "${QUANT_INSTALL_DIR}/GRU_quant2")
download_quant_model(${QUANT2_GRU_MODEL_DIR} ${QUANT2_GRU_MODEL_ARCHIVE}
cf207f8076dcfb8b74d8b6bdddf9090c)
set(QUANT2_GRU_OPS_TO_QUANTIZE "multi_gru")
# Quant2 LSTM
......