提交 a3560fa1 编写于 作者: M Megvii Engine Team

feat(gopt): add tranform to chwn4 to optimize_for_inference

GitOrigin-RevId: 4d1a9c6c8410904ea4da17a1bed2ad06ce369869
上级 1fb7d34f
......@@ -542,7 +542,8 @@ def optimize_for_inference(
use_nchw32=False,
fuse_conv_bias_with_z=False,
use_nchw88=False,
use_nchw44=False
use_nchw44=False,
use_chwn4=False
):
"""optimize computing graph for inference
......@@ -566,6 +567,8 @@ def optimize_for_inference(
times.
:param use_nchw32: whether to use NCHW32 tensor format. Mainly used for
nvidia tensorcore.
:param use_chwn4: whether to use CHWN4 tensor format. Mainly used for
nvidia tensorcore.
:return: list of transformed vars corresponding to given output vars
......@@ -589,6 +592,7 @@ def optimize_for_inference(
"use_nchw32": "nchw2nchw32",
"use_nchw88": "nchw2nchw88",
"use_nchw44": "nchw2nchw44",
"use_chwn4": "nchw42chwn4",
}.items():
if settings[k]:
assert (
......
......@@ -84,6 +84,7 @@ struct _OptimizeForInferenceOptions {
SET(nchw2nchw88, NCHW2NCHW88);
SET(nchw2nchw44, NCHW2NCHW44);
SET(nchw2nchw32, NCHW2NCHW32);
SET(nchw42chwn4, NCHW42CHWN4);
#undef SET
};
......
......@@ -254,8 +254,9 @@ def optimize_for_inference(args, outputs):
'enable_hwcd4': 'use_nhwcd4',
'enable_nchw88': 'use_nchw88',
'enable_nchw44': 'use_nchw44',
'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
'enable_nchw32': 'use_nchw32',
'enable_chwn4': 'use_chwn4',
'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z',
}
kwargs = {}
......@@ -398,6 +399,12 @@ def main():
help='transform the model format from NCHW4 to NCHW32 '
'for inference on nvidia TensoCore'
)
parser.add_argument(
'--enable-chwn4',
action='store_true',
help='transform the model format to CHWN4 '
'for inference, mainly used for nvidia tensorcore'
)
parser.add_argument(
'--enable-fuse-conv-bias-with-z',
action='store_true',
......
......@@ -724,6 +724,13 @@ void GraphOptimizer::apply_optimize_options(
add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>();
}
if (options->transform_nchw42chwn4()) {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
add_pass(EnableCHWN4Pass::make_chwn4_converter());
add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>();
}
if (options->fuse_conv_bias_nonlinearity) {
add_pass<FuseConvBiasNonlinPass>();
......
......@@ -395,6 +395,8 @@ namespace gopt {
NCHW2NCHW44, ///< compute using NCHW44 tensor format
NCHW2NCHW32, ///< compute using NCHW32 tensor format, used for
///< tensorcore
NCHW42CHWN4, ///< compute using CHWN4 tensor format, transformed
///< from NCHW4, mainly used for cuda
};
LayoutTransform layout_transform = LayoutTransform::DEFAULT;
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
......@@ -424,6 +426,7 @@ namespace gopt {
SET(nchw2nchw88, NCHW2NCHW88);
SET(nchw2nchw44, NCHW2NCHW44);
SET(nchw2nchw32, NCHW2NCHW32);
SET(nchw42chwn4, NCHW42CHWN4);
#undef SET
};
......
......@@ -2011,14 +2011,11 @@ TEST(TestGoptInference, EnableCHWN4) {
y4 = opr::TypeCvt::make(y4, dtype::Float32());
SymbolVar y_opt;
SymbolVar y_cudnn;
unpack_vector(
gopt::GraphOptimizer{}
.add_pass<gopt::FuseConvBiasNonlinPass>()
.add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter())
.add_pass<gopt::FuseConvBiasZPass>()
.apply({{y4}})
.endpoint_vars(),
y_opt);
{
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw42chwn4();
unpack_vector(gopt::optimize_for_inference({y4}, options), y_opt);
}
unpack_vector(gopt::GraphOptimizer{}
.add_pass<gopt::FuseConvBiasNonlinPass>()
.add_pass<gopt::FuseConvBiasZPass>()
......@@ -2100,13 +2097,11 @@ TEST(TestGoptInference, EnableCHWN4WarpPespective) {
auto y2 = opr::WarpPerspective::make(y1, mat_var, TensorShape{16, 16}, warp_param);
SymbolVar y_opt;
SymbolVar y_cudnn;
unpack_vector(gopt::GraphOptimizer{}
.add_pass<gopt::FuseConvBiasNonlinPass>()
.add_pass<gopt::FuseConvBiasZPass>()
.add_pass(gopt::EnableCHWN4Pass::make_chwn4_converter())
.apply({{y2}})
.endpoint_vars(),
y_opt);
{
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw42chwn4();
unpack_vector(gopt::optimize_for_inference({y2}, options), y_opt);
}
unpack_vector(gopt::GraphOptimizer{}
.add_pass<gopt::FuseConvBiasNonlinPass>()
.add_pass<gopt::FuseConvBiasZPass>()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册