diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 58acd3d448a540adb700424249012ffa3122c7f6..1defadb94572c1b2843a375b77777dadcf49de91 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1565,7 +1565,7 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { if (new_inp[i]->shape()[1] % 4 != 0) { can_exec_cd4 = false; } - //! cd4 elemwise with scaler is supported + //! cd4 elemwise with scalar is unsupported } else if (!new_inp[i]->shape().is_scalar()) { can_exec_cd4 = false; } @@ -1627,6 +1627,7 @@ std::unique_ptr ConvertFormatPass::make_nhwcd4_converter() { replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw; replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw; replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw; + replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; replace_func[opr::WarpPerspectiveForward::typeinfo()] = replace_warp_perspective_opr; diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 5646ea3d741348464a75ab40bd4ef2d88a9f57c6..808cc17490e05ee94d2c06860932f1bcd41f5f0c 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -1265,6 +1265,55 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) { MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); } +TEST(TestGoptInference, ConvertFormatNHWCD4TypeCvt) { + NaiveMegDNNHandleScope naive_megdnn_handle; + + HostTensorGenerator<> gen; + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + auto mkcvar = [&](const char* name, const TensorShape& shp) { + return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) + .rename(name); + }; + auto host_x = gen({8, 8, 8, 8}, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x); + + opr::Convolution::Param param; + + param.pad_h = param.pad_w = 0; + auto w1 = mkcvar("w1", {8, 8, 
3, 3}), + conv1 = opr::Convolution::make(x, w1, param), + tcvt1 = opr::TypeCvt::make(conv1, dtype::Float16()); + auto w2 = mkcvar("w2", {8, 8, 3, 3}), + conv2 = opr::Convolution::make(x, w2, param), + tcvt2 = opr::TypeCvt::make(conv2, dtype::Float16()); + auto y = opr::Elemwise::make({tcvt1, tcvt2}, opr::Elemwise::Param::Mode::ADD); + + SymbolVar y_opt; + auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_nhwcd4(); + unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); + + ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4, + find_opr(y_opt).param().format); + + graph->compile({{y_opt, {}}}) + ->to_json() + ->writeto_fpath(output_file( + "TestGoptInference.ConvertFormatNHWCD4TypeCvt.json")); + + HostTensorND host_y_opt, host_y; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); + + *host_x = *gen({8, 8, 16, 16}, cn); + func->execute(); + MGB_ASSERT_TENSOR_EQ(host_y, host_y_opt); +} + TEST(TestGoptInference, ConvertFormatNHWCD4LOCAL) { // hwcd4 is only supported in naive handle NaiveMegDNNHandleScope naive_megdnn_handle;