提交 8070f40a 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mgb/gopt): fix gopt nchwxx convert elemwise and reshape

GitOrigin-RevId: 982dee36e111bf4cc25321cf5ee8ec20d14bfce2
上级 b38e8225
......@@ -2049,16 +2049,17 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
return new_opr;
}
};
auto replace_concat_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
//! When input change and all input can convert to nchwxx, this opr will run
//! in nchwxx mode, else it will run in nchw mode, for example concat and
//! elemwise opr
auto replace_multi_inp_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
bool has_inp_changed = false;
bool can_exec_ncwxx = true;
for (size_t i = 0; i < opr->input().size(); i++) {
if (new_inp[i]->shape().ndim == 5) {
has_inp_changed = true;
break;
} else if (new_inp[i]->shape().ndim == 4) {
if (new_inp[i]->shape()[1] % pack_c_size != 0) {
can_exec_ncwxx = false;
......@@ -2095,36 +2096,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
}
};
auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
bool has_inp_changed = false;
for (size_t i = 0; i < opr->input().size(); i++) {
if (new_inp[i]->shape().ndim == 5) {
has_inp_changed = true;
break;
}
}
if (has_inp_changed) {
auto temp_inp = new_inp;
for (size_t i = 0; i < opr->input().size(); i++) {
if (new_inp[i]->shape().ndim == 4) {
auto new_var = RelayoutPlaceholder::make(
new_inp[i], src_to_nchwxx_mode);
temp_inp[i] = new_var.node();
} else {
mgb_assert((new_inp[i]->shape().ndim == 5) ||
new_inp[i]->shape().is_scalar());
}
}
return serialization::copy_opr_shallow(*opr, temp_inp,
opr->config());
} else {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
};
auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
......@@ -2146,11 +2117,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
replace_func[opr::Concat::typeinfo()] = replace_concat_opr;
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr;
replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr;
replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr;
replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr;
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_multi_inp_opr;
replace_func[opr::PowC::typeinfo()] = replace_multi_inp_opr;
//! not support yet
replace_func[opr::ConvolutionBackwardData::typeinfo()] =
relayout_inp_to_nchw;
......@@ -2164,6 +2135,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
relayout_inp_to_nchw;
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
}
std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
......
......@@ -2948,6 +2948,90 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
}
TEST(TestGoptInference, ConvertFormatNCHW44MultiInput) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name);
};
auto host_x1 = gen({1, 8, 16, 16}, cn);
auto host_x2 = gen({1, 1, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 8, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param_conv);
auto b = mkvar("b", {1, 1, 16, 16}),
y = opr::Elemwise::make({conv1 + b}, opr::Elemwise::Param::Mode::RELU);
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw44();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::Convolution>(y_opt).param().format);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.ConvertFormatNCHW44MultiInput.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
//! meybe go to winograd in x86-32, so set error 1e-1
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
}
TEST(TestGoptInference, ConvertFormatNCHW44Reshape) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name);
};
auto host_x1 = gen({1, 8, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 8, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param_conv);
auto y = opr::Reshape::make(conv1, {8, 16 * 16});
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw44();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::Convolution>(y_opt).param().format);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.ConvertFormatNCHW44Reshape.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
//! meybe go to winograd in x86-32, so set error 1e-1
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
}
TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册