Commit 8d248a6a authored by Megvii Engine Team, committed by huangxinda

fix(dnn/cuda): fix testcase for fallback nchw qs8 conv

GitOrigin-RevId: 646440db59f0157a3fdbd8061167f9ac04dbd422
Parent 894a2407
@@ -353,7 +353,8 @@ bool megdnn::check_bias_share_in_channel(const TensorLayout& bias,
                format == param::ConvBias::Format::NCHW4_NCHW) {
         share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 &&
                             bias[3] == 1);
-    } else if (format == param::ConvBias::Format::NHWC) {
+    } else if (format == param::ConvBias::Format::NHWC ||
+               format == param::ConvBias::Format::NCHW4_NHWC) {
         share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 &&
                             bias[2] == 1);
     } else if (format == param::ConvBias::Format::NCHW4 ||
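For reference, a minimal standalone sketch (hypothetical types, not the MegDNN API) of the rule this hunk extends: in channel-last output formats such as NHWC and the newly handled NCHW4_NHWC, a bias shared across channels must broadcast over N, H and W, i.e. have shape (1, 1, 1, C), whereas NCHW-like formats expect (1, C, 1, 1).

    #include <array>
    #include <cassert>

    // Hypothetical stand-ins for the real param::ConvBias::Format values.
    enum class Format { NCHW, NHWC, NCHW4, NCHW4_NCHW, NCHW4_NHWC };

    // dims[i] is the extent of axis i of a 4-D bias tensor.
    bool bias_share_in_channel(const std::array<size_t, 4>& dims, Format format) {
        if (format == Format::NHWC || format == Format::NCHW4_NHWC) {
            // channel-last: a shared bias must be 1 x 1 x 1 x C
            return dims[0] == 1 && dims[1] == 1 && dims[2] == 1;
        }
        // channel-first (NCHW-like): a shared bias must be 1 x C x 1 x 1
        return dims[0] == 1 && dims[2] == 1 && dims[3] == 1;
    }

    int main() {
        assert(bias_share_in_channel({1, 1, 1, 64}, Format::NCHW4_NHWC));
        assert(!bias_share_in_channel({1, 64, 1, 1}, Format::NCHW4_NHWC));
        return 0;
    }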
@@ -84,8 +84,12 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_subopr_list(
             inner_dst_layout, inner_bias_layout, inner_z_layout);
     Param inner_conv_param = o->param();
-    inner_conv_param.format = Param::Format::NCHW4;
+    if (layouts[4].dtype.enumv() == DTypeEnum::Float32) {
+        inner_conv_param.format = Param::Format::NCHW4_NCHW;
+    } else {
+        inner_conv_param.format = Param::Format::NCHW4;
+    }
     std::string param_str;
     Algorithm::serialize_write_pod(inner_conv_param, param_str);
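Here layouts[4] is the destination layout of the ConvBias operator (src, filter, bias, z, dst). When the destination is float32, the inner NCHW4 convolution cannot keep its output in the 4-channel blocked layout, so the NCHW4_NCHW variant (NCHW4 input, NCHW float output) is selected. A condensed, equivalent form of the new branch, for illustration only:

    // Equivalent ternary form of the dispatch added above.
    inner_conv_param.format = (layouts[4].dtype.enumv() == DTypeEnum::Float32)
            ? Param::Format::NCHW4_NCHW  // float32 dst: emit NCHW directly
            : Param::Format::NCHW4;      // quantized dst: stay in NCHW4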
@@ -192,9 +196,9 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec(
     inner_conv_param.format =
             dst_float ? Param::Format::NCHW4_NCHW : Param::Format::NCHW4;
     auto inner_opr = args.handle->create_operator<ConvBiasForward>();
-    inner_opr->param() = inner_conv_param;
     set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr,
                                                             inner_opr.get());
+    inner_opr->param() = inner_conv_param;
     relayout_nchw_nchw4->exec(*args.src_tensor, inner_src, {});
     relayout_weight->exec(*args.filter_tensor, inner_weight, {});
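This hunk moves the param assignment to after set_execution_policy, so the sub-operator execution policy is installed before the inner operator's param is set. The resulting order, sketched from the surrounding lines of the diff:

    // 1. create the inner NCHW4 operator
    auto inner_opr = args.handle->create_operator<ConvBiasForward>();
    // 2. propagate the execution policy (which sub-algorithm to run) first
    set_execution_policy<ConvBiasForward, ConvBiasForward*>(args.opr,
                                                            inner_opr.get());
    // 3. then set the param, so the NCHW4 / NCHW4_NCHW format is in place
    //    by the time the inner operator executes
    inner_opr->param() = inner_conv_param;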
@@ -701,9 +701,11 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) {
 TEST_F(CUDA, FALLBACK_CONV_QS8) {
     require_compute_capability_eq(7, 5);
     Checker<ConvBiasForward> checker(handle_cuda());
-    auto check = [&checker](const std::string&& algo) {
+    auto check = [&checker](const std::string&& algo,
+                            const std::string&& sub_algo) {
         checker.set_before_exec_callback(
-                conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo.c_str()));
+                conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
+                        {algo.c_str(), {sub_algo.c_str()}}));
         UniformIntRNG rng{-3, 3};
         UniformIntRNG bias_rng{-50, 50};
         checker.set_rng(0, &rng)
@@ -733,15 +735,17 @@ TEST_F(CUDA, FALLBACK_CONV_QS8) {
                       {},
                       {}});
     };
-    check("FALLBACK_CONV_NCHW_QS8");
+    check("FALLBACK_CONV_NCHW_QS8", "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM");
 }
 
 TEST_F(CUDA, FALLBACK_CONV_QS8_F32) {
     require_compute_capability_eq(7, 5);
     Checker<ConvBiasForward> checker(handle_cuda());
-    auto check = [&checker](const std::string&& algo) {
+    auto check = [&checker](const std::string&& algo,
+                            const std::string&& sub_algo) {
         checker.set_before_exec_callback(
-                conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(algo.c_str()));
+                conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
+                        {algo.c_str(), {sub_algo.c_str()}}));
         UniformIntRNG rng{-3, 3};
         UniformFloatRNG bias_rng{-50.f, 50.f};
         checker.set_rng(0, &rng)
@@ -771,7 +775,7 @@ TEST_F(CUDA, FALLBACK_CONV_QS8_F32) {
                       {},
                       {}});
     };
-    check("FALLBACK_CONV_NCHW_QS8");
+    check("FALLBACK_CONV_NCHW_QS8", "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM");
 }
 
 TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_WEIGHT_PREPROCESS) {
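Both tests now pin not only the fallback algorithm but also the sub-algorithm the fallback dispatches to, via the nested brace initializer {algo.c_str(), {sub_algo.c_str()}}. A hedged sketch (hypothetical type, not the actual test-framework API) of what that initializer expresses:

    #include <initializer_list>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for the checker's algorithm-name argument: a
    // top-level algorithm plus the names its sub-operators must resolve to.
    struct PolicyAlgoName {
        std::string name;
        std::vector<PolicyAlgoName> sub_policy;
        PolicyAlgoName(const char* n,
                       std::initializer_list<PolicyAlgoName> subs = {})
                : name(n), sub_policy(subs) {}
    };

    int main() {
        // Mirrors check("FALLBACK_CONV_NCHW_QS8",
        //               "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM") above: the outer
        // fallback must resolve its inner NCHW4 convolution to the
        // dot-product implicit-GEMM kernel.
        PolicyAlgoName expected{"FALLBACK_CONV_NCHW_QS8",
                                {{"INT8_NCHW4_DOTPROD_IMPLICIT_GEMM"}}};
        (void)expected;  // silence unused-variable warning
        return 0;
    }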