提交 a437ec8e 编写于 作者: M Megvii Engine Team

fix(src/gopt): add replace func of typecvt opr for nhwcd4 pass

GitOrigin-RevId: 801eb1dab3ccbdcf71e8a153e0d3c7c9a7dbe6db
上级 b1baee60
......@@ -1565,7 +1565,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
if (new_inp[i]->shape()[1] % 4 != 0) {
can_exec_cd4 = false;
}
//! cd4 elemwise with scalar is supported
//! cd4 elemwise with scalar is unsupported
} else if (!new_inp[i]->shape().is_scalar()) {
can_exec_cd4 = false;
}
......@@ -1627,6 +1627,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
replace_warp_perspective_opr;
......
......@@ -1265,6 +1265,55 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
//! Builds two conv -> TypeCvt(Float16) branches joined by an elemwise ADD,
//! runs optimize_for_inference with nhwcd4 enabled, then checks that the
//! convolution found in the optimized graph uses the NHWCD4 format and that
//! the optimized output equals the unoptimized one for two input shapes.
TEST(TestGoptInference, ConvertFormatNHWCD4TypeCvt) {
    NaiveMegDNNHandleScope naive_megdnn_handle;

    HostTensorGenerator<> gen;
    auto cn = CompNode::load("cpu0");
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;

    //! create a named shared device tensor filled by the generator
    auto make_weight = [&](const char* name, const TensorShape& shp) {
        auto weight = opr::SharedDeviceTensor::make(*graph, *gen(shp, cn));
        return weight.rename(name);
    };

    auto host_x = gen({8, 8, 8, 8}, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);

    opr::Convolution::Param param;
    param.pad_h = 0;
    param.pad_w = 0;

    //! first branch: conv -> typecvt to fp16
    auto w1 = make_weight("w1", {8, 8, 3, 3});
    auto conv1 = opr::Convolution::make(x, w1, param);
    auto tcvt1 = opr::TypeCvt::make(conv1, dtype::Float16());

    //! second branch: conv -> typecvt to fp16
    auto w2 = make_weight("w2", {8, 8, 3, 3});
    auto conv2 = opr::Convolution::make(x, w2, param);
    auto tcvt2 = opr::TypeCvt::make(conv2, dtype::Float16());

    using ElemMode = opr::Elemwise::Param::Mode;
    auto y = opr::Elemwise::make({tcvt1, tcvt2}, ElemMode::ADD);

    //! run the inference optimizer with the nhwcd4 conversion enabled
    SymbolVar y_opt;
    gopt::OptimizeForInferenceOptions options{};
    options.enable_nhwcd4();
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);

    const auto conv_format = find_opr<opr::Convolution>(y_opt).param().format;
    ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4, conv_format);

    //! dump the optimized graph as json
    graph->compile({{y_opt, {}}})
            ->to_json()
            ->writeto_fpath(output_file(
                    "TestGoptInference.ConvertFormatNHWCD4TypeCvt.json"));

    HostTensorND host_y;
    HostTensorND host_y_opt;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_opt, host_y_opt)});

    func->execute();
    MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);

    //! execute again after replacing the input with a different spatial size
    *host_x = *gen({8, 8, 16, 16}, cn);
    func->execute();
    MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt);
}
TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册