提交 6d686ff2 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(gopt/inference): allow Float32 output dtype in EnableNCHW64Pass

GitOrigin-RevId: 1891efb76f66a6abbd0a56820281b4fe91e70304
上级 7d3df995
......@@ -4330,14 +4330,20 @@ EnableNCHW64Pass::make_nchw64_converter() {
bool check_dtype =
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
if (opr->input().size() >= 3)
check_dtype &=
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
if (opr->input().size() >= 4)
check_dtype &=
new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8;
mgb_assert(opr->output().size() > 0);
bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
if (opr->input().size() >= 3) {
auto dtype_expect = dst_float ? DTypeEnum::Float32
: DTypeEnum::QuantizedS32;
check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect;
}
if (opr->input().size() >= 4) {
check_dtype &= new_inp[3]->dtype().enumv() ==
opr->output(0)->dtype().enumv();
}
if (!check_dtype)
return nullptr;
size_t out_channels = opr->input(1)->shape()[0];
size_t in_channels = opr->input(1)->shape()[1];
bool check_channels = out_channels % 4 == 0 && in_channels % 4 == 0;
......@@ -4370,12 +4376,18 @@ EnableNCHW64Pass::make_nchw64_converter() {
}
}
};
for (size_t i = 0; i < inps.size(); ++i) {
inps[i] = process(i);
// do not format bias and z when dst_float is true
bool skip = dst_float && i >= 2;
if (!skip) inps[i] = process(i);
}
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
auto ret = make_new_conv(inps, &conv_bias, Format::NCHW4);
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4));
auto ret = make_new_conv(
inps, &conv_bias,
dst_float ? Format::NCHW4_NCHW : Format::NCHW4);
if (!dst_float)
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4));
return ret;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册