Commit 52cb4b39 authored by: Megvii Engine Team, committed by: Xu Xinran

fix(mgb/gopt): fix convert format nchw->nchw4 pass

GitOrigin-RevId: 1813753b144fa70d53f4f97f1a2a509963440d04
Parent 90d5895e
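Background for the change: NCHW4 is the layout in which the channel axis is split into blocks of four, giving a 5-dim tensor of shape {N, C/4, H, W, 4}; the pass rewrites quantized operators to consume this layout directly instead of relayouting their inputs back to NCHW. A minimal standalone sketch of the index mapping, assuming a raw float buffer and a channel count divisible by 4 (the helper below is illustrative only and not part of this patch):

#include <cstddef>
#include <vector>

// Illustrative helper: repack an NCHW buffer into NCHW4 ({N, C/4, H, W, 4}).
std::vector<float> nchw_to_nchw4(const std::vector<float>& src, size_t N,
                                 size_t C, size_t H, size_t W) {
    std::vector<float> dst(src.size());
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w) {
                    size_t src_idx = ((n * C + c) * H + h) * W + w;
                    // the innermost axis holds 4 consecutive channels
                    size_t dst_idx =
                            (((n * (C / 4) + c / 4) * H + h) * W + w) * 4 +
                            c % 4;
                    dst[dst_idx] = src[src_idx];
                }
    return dst;
}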
@@ -1598,18 +1598,103 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
}
return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
};
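// Pooling replacer: if the converted input is already in NCHW4 (ndim == 5,
// QuantizedS8), rebuild the opr with Format::NCHW4; otherwise fall back to a
// shallow copy.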
auto replace_pooling_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
using Param = opr::PoolingForward::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
mgb_assert(pooling.param().format == Format::NCHW,
"ConvertFormat pass only supports converting NCHW to NCHW4.");
if (new_inp[0]->shape().ndim == 5) {
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
auto new_param = pooling.param();
new_param.format = Format::NCHW4;
auto new_pooling =
opr::PoolingForward::make(new_inp[0], new_param, opr->config());
mgb_assert(new_pooling.shape().ndim == 5,
"ndim of out var of Pooling opr after transform must be 5 "
"(got: %zu).",
new_pooling.shape().ndim);
return new_pooling.node()->owner_opr();
}
auto new_opr =
serialization::copy_opr_shallow(*opr, new_inp, opr->config());
return new_opr;
};
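// Resize replacer: same pattern as Pooling; the second input (the target
// shape) is passed through unchanged.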
auto replace_resize_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
using Param = opr::ResizeForward::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
auto& resize = opr->cast_final_safe<opr::ResizeForward>();
mgb_assert(resize.param().format == Format::NCHW,
"ConvertFormat pass only supports converting NCHW to NCHW4.");
if (new_inp[0]->shape().ndim == 5) {
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
auto new_param = resize.param();
new_param.format = Format::NCHW4;
auto new_resize = opr::ResizeForward::make(
new_inp[0], new_inp[1], new_param, opr->config());
mgb_assert(new_resize.shape().ndim == 5,
"ndim of out var of Resize opr after transform must be 5 "
"(got: %zu).",
new_resize.shape().ndim);
return new_resize.node()->owner_opr();
}
auto new_opr =
serialization::copy_opr_shallow(*opr, new_inp, opr->config());
return new_opr;
};
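// WarpPerspective replacer: besides switching to Format::NCHW4, handle both
// the 3-input form (src, mat, out_shape; mat_idx is null) and the 4-input
// form (src, mat, mat_idx, out_shape).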
auto replace_warp_perspective_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
using Param = opr::WarpPerspective::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>();
mgb_assert(warp.param().format == Format::NCHW,
"ConvertFormat pass only supports converting NCHW to NCHW4.");
if (new_inp[0]->shape().ndim == 5) {
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
auto new_param = warp.param();
new_param.format = Format::NCHW4;
SymbolVar new_warp;
if (new_inp.size() == 3) {
new_warp = opr::WarpPerspectiveForward::make(
new_inp[0], new_inp[1], nullptr, new_inp[2], new_param,
opr->config());
} else {
mgb_assert(new_inp.size() == 4);
new_warp = opr::WarpPerspectiveForward::make(
new_inp[0], new_inp[1], new_inp[2], new_inp[3],
new_param, opr->config());
}
mgb_assert(new_warp.shape().ndim == 5,
"ndim of out var of WarpPerspective opr after transform "
"must be 5 (got: %zu).",
new_warp.shape().ndim);
return new_warp.node()->owner_opr();
}
auto new_opr =
serialization::copy_opr_shallow(*opr, new_inp, opr->config());
return new_opr;
};
auto&& replace_func = ret->m_opr_replace_func;
//! supported nchw4
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
replace_func[opr::BatchConvBias::typeinfo()] =
replace_batch_conv_bias_opr;
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
replace_warp_perspective_opr;
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr;
//! not supported nchw4
replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::ConvolutionBackwardData::typeinfo()] =
relayout_inp_to_nchw;
@@ -1619,9 +1704,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
return ret;
}
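For reference, the converter built above is reached through the standard inference-optimization entry point; a minimal usage sketch mirroring the tests below, where y stands for a quantized graph output constructed as in those tests:

SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);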
...
@@ -2430,14 +2430,16 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv1 = opr::ConvBiasForward::make(
x, w1, b1, param_conv_bias, {},
OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
// group
// icpg != 1 && ocpg != 1
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv2 = opr::ConvBiasForward::make(
conv1, w2, b2, param_conv_bias, {},
OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
auto y = opr::TypeCvt::make(conv2, dtype::Float32());
@@ -2455,8 +2457,8 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.ConvertFormatNCHW4GPU.json"));
HostTensorND host_y, host_y_opt;
auto func = graph->compile({make_callback_copy(y, host_y),
@@ -2467,6 +2469,90 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
#endif
TEST(TestGoptInference, ConvertFormatNCHW4NonConvOpr) {
auto cn = CompNode::load("xpu0");
HostTensorGenerator<dtype::Int8> gen;
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
dtype);
};
auto mkcvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};
auto mkcvarf32 = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name);
};
auto x = mkvar("x", {2, 4, 16, 16}, dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param_conv_bias;
param_conv_bias.format = opr::ConvBias::Param::Format::NCHW;
param_conv_bias.stride_h = param_conv_bias.stride_w = 1;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
param_conv_bias.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
// dense
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
auto w1 = mkcvar("w1", {8, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv1 = opr::ConvBiasForward::make(
x, w1, b1, param_conv_bias, {},
OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
// test Resize
auto shape_of = opr::GetVarShape::make(x);
auto subtensor = opr::Subtensor::make(
shape_of, {opr::Subtensor::AxisIndexer::make_interval(
0, x.make_scalar(2), None, x.make_scalar(1))});
opr::Resize::Param param_resize;
param_resize.format = opr::Resize::Param::Format::NCHW;
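// subtensor holds the (H, W) tail of x's shape; doubling it asks Resize for
// a 32x32 output from the 16x16 input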
auto resize = opr::ResizeForward::make(conv1, subtensor * 2, param_resize);
// test WarpPerspective
auto mat = mkcvarf32("mat", {2, 3, 3}),
warp = opr::WarpPerspectiveForward::make(
resize, mat, nullptr, cg::var_from_tensor_shape(x, {32, 32}));
opr::Pooling::Param pool_param;
pool_param.format = opr::Pooling::Param::Format::NCHW;
// test Pooling
auto pool = opr::Pooling::make(warp, pool_param);
// group
// icpg != 1 && ocpg != 1
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv2 = opr::ConvBiasForward::make(
pool, w2, b2, param_conv_bias, {},
OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
auto add = opr::ElemwiseMultiType::make(
{conv1, conv2}, {opr::ElemwiseMultiType::Param::Mode::QADD},
OperatorNodeConfig{dtype::QuantizedS8{1.2f}});
auto y = opr::TypeCvt::make(add, dtype::Float32());
SymbolVar y_opt;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
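// the two Dimshuffle oprs should be the NCHW->NCHW4 relayout at the graph
// input and the NCHW4->NCHW relayout near the output; Resize /
// WarpPerspective / Pooling no longer force a relayout in between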
auto nr_dimshuffle = find_opr_num<mgb::opr::Dimshuffle>(y_opt);
ASSERT_EQ(2u, nr_dimshuffle);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
find_opr<opr::ConvBias>(y_opt).param().format);
ASSERT_EQ(opr::ResizeForward::Param::Format::NCHW4,
find_opr<opr::ResizeForward>(y_opt).param().format);
ASSERT_EQ(opr::WarpPerspectiveForward::Param::Format::NCHW4,
find_opr<opr::WarpPerspectiveForward>(y_opt).param().format);
ASSERT_EQ(opr::PoolingForward::Param::Format::NCHW4,
find_opr<opr::PoolingForward>(y_opt).param().format);
}
TEST(TestGoptInference, ConvertFormatNCHW4) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
@@ -2481,7 +2567,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4) {
};
auto x = mkvar("x", {2, 4, 16, 16});
// ConvBias test dense
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
...