提交 554ce352 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mgb/gopt): add nchw44 optpass

GitOrigin-RevId: dc38724558b0c6635ea9a3137e1c0d0acc665a0f
上级 7d1e1f9a
......@@ -541,7 +541,8 @@ def optimize_for_inference(
fuse_conv_bias_nonlinearity=False,
use_tensor_core=False,
fuse_conv_bias_with_z=False,
use_nchw88=False
use_nchw88=False,
use_nchw44=False
):
"""optimize computing graph for inference
......@@ -559,7 +560,9 @@ def optimize_for_inference(
OpenCL devices
:param fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
into one opr. This is supported only in NHWCD4 format.
:param use_nchw88: whether to use NCHW4 tensor format. This maybe faster some
:param use_nchw88: whether to use NCHW88 tensor format. This maybe faster some
times.
:param use_nchw44: whether to use NCHW44 tensor format. This maybe faster some
times.
......@@ -577,6 +580,7 @@ def optimize_for_inference(
"use_tensor_core",
"fuse_conv_bias_with_z",
"use_nchw88",
"use_nchw44",
]:
if settings[i]:
getattr(opt, "enable_{}".format(i))()
......
......@@ -79,6 +79,7 @@ struct _OptimizeForInferenceOptions {
SET(use_tensor_core);
SET(fuse_conv_bias_with_z);
SET(use_nchw88);
SET(use_nchw44);
#undef SET
};
......
......@@ -253,6 +253,7 @@ def optimize_for_inference(args, outputs):
'enable_ioc16': 'f16_io_comp',
'enable_hwcd4': 'use_nhwcd4',
'enable_nchw88': 'use_nchw88',
'enable_nchw44': 'use_nchw44',
'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
'enable_tensorcore': 'use_tensor_core',
'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z',
......@@ -385,6 +386,12 @@ def main():
help='transform the model format from NCHW to NCHW88 '
'for inference'
)
parser.add_argument(
'--enable-nchw44',
action='store_true',
help='transform the model format from NCHW to NCHW44 '
'for inference'
)
parser.add_argument(
'--enable-tensorcore',
action='store_true',
......
......@@ -700,6 +700,9 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
if (inference_opt->use_nchw88) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
}
if (inference_opt->use_nchw44) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
}
if (inference_opt->use_tensor_core) {
mgb_assert(inference_opt->fuse_conv_bias_nonlinearity,
"enable tensor core should fuse conv bias activation "
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/gopt/inference.h"
......@@ -63,7 +64,10 @@ public:
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout
CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout
NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout
NCHW_TO_NCHW44, //!< from nchw layout to nchw44 layout
NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout
NCHW44_TO_NCHW, //!< from nchw44 layout to nchw layout
WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88
//!< layout
WEIGHT_NCHW_TO_NCHW88_GROUP, //!< group weight from nchw layout to
......@@ -73,6 +77,16 @@ public:
//!< the weight layout of input is nchw output is nchw88, special for
//!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8}
WEIGHT_HYBIRD_NCHW_NCHW88,
WEIGHT_NCHW_TO_NCHW44_DENSE, //!< weight from nchw layout to nchw44
//!< layout
WEIGHT_NCHW_TO_NCHW44_GROUP, //!< group weight from nchw layout to
//!< nchw44 layout
WEIGHT_NCHW_TO_NCHW44_CHAN, //!< channel wise weight from nchw layout
//!< to nchw44 layout
//!< the weight layout of input is nchw output is nchw44, special for
//!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
WEIGHT_HYBIRD_NCHW_NCHW44,
};
RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
......@@ -203,10 +217,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[3] = inp_shape[3];
dst[4] = inp_shape[4];
dst[5] = 8;
} else {
mgb_assert(
layout_type() ==
RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88);
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0] / 8;
......@@ -214,6 +226,68 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[2] = inp_shape[3];
dst[3] = inp_shape[1];
dst[4] = 8;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW44) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_DENSE) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
inp_shape[1] % 4 == 0);
dst.ndim = 6;
dst[0] = inp_shape[0] / 4;
dst[1] = inp_shape[1] / 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
dst[5] = 4;
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_GROUP) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 &&
inp_shape[2] % 4 == 0);
dst.ndim = 7;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 4;
dst[2] = inp_shape[2] / 4;
dst[3] = inp_shape[3];
dst[4] = inp_shape[4];
dst[5] = 4;
dst[6] = 4;
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_CHAN) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 &&
inp_shape[2] == 1 && inp_shape[0] % 4 == 0);
dst.ndim = 6;
dst[0] = inp_shape[0] / 4;
dst[1] = inp_shape[1];
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = inp_shape[4];
dst[5] = 4;
} else {
mgb_assert(
layout_type() ==
RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44);
mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0] / 4;
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[1];
dst[4] = 4;
}
return true;
};
......@@ -418,6 +492,104 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW_TO_NCHW44] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW44_TO_NCHW] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_GROUP] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
cv(4), sub(3), sub(4)},
0),
tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
sub(4), cv(4), cv(4)},
0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_CHAN] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0) / 4, cv(4), sub(1), sub(2), sub(3), sub(4)}, 0),
tshp1 = opr::Concat::make(
{sub(0) / 4, sub(1), sub(2), sub(3), sub(4), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0) / 4, cv(4), sub(1), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0) / 4, sub(2), sub(3), sub(1), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
auto rewriter = opt.graph().make_rewriter();
auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
......@@ -1071,16 +1243,24 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const {
if (!orig_var->shape().eq_shape(new_var->shape())) {
return RelayoutPlaceholder::make(
new_var, RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
.node();
if (m_pack_c_size == 8) {
return RelayoutPlaceholder::make(
new_var,
RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
.node();
} else if (m_pack_c_size == 4) {
return RelayoutPlaceholder::make(
new_var,
RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW)
.node();
}
}
return new_var;
}
std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
size_t pack_c_size) {
auto ret = std::make_unique<EnableNchwxxPass>();
auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
//! First is whether the conv can trans to nchwxx, second is the filter
//! trans mode
......@@ -1102,8 +1282,18 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
megdnn::param::Pooling::Format pooling_format =
megdnn::param::Pooling::Format::NCHW88;
std::string convter_pass_name = "conv_format_nchw88";
mgb_assert(pack_c_size == static_cast<size_t>(8),
"The ConvertFormatPass to nchwxx only support NCHW88 now !");
if (pack_c_size == 4) {
weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW44;
src_to_nchw_mode = RelayoutMode::NCHW44_TO_NCHW;
conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
conv_format = megdnn::param::ConvolutionV0::Format::NCHW44;
pooling_format = megdnn::param::Pooling::Format::NCHW44;
convter_pass_name = "conv_format_nchw44";
}
auto test_trans_nchwxx =
[pack_c_size, weight_to_nchwxx_mode_dense,
weight_to_nchwxx_mode_group, weight_to_nchwxx_mode_chan,
......@@ -1297,7 +1487,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
auto new_param = conv_bias_opr.param();
new_param.format = conv_bias_format;
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, new_param,
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
......@@ -1330,6 +1520,51 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
}
};
auto replace_concat_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
bool has_inp_changed = false;
bool can_exec_ncwxx = true;
for (size_t i = 0; i < opr->input().size(); i++) {
if (new_inp[i]->shape().ndim == 5) {
has_inp_changed = true;
break;
} else if (new_inp[i]->shape().ndim == 4) {
if (new_inp[i]->shape()[1] % pack_c_size != 0) {
can_exec_ncwxx = false;
}
}
}
if (has_inp_changed) {
auto temp_inp = new_inp;
if (can_exec_ncwxx) {
for (size_t i = 0; i < opr->input().size(); i++) {
if (new_inp[i]->shape().ndim == 4) {
auto new_var = RelayoutPlaceholder::make(
new_inp[i], src_to_nchwxx_mode);
temp_inp[i] = new_var.node();
} else {
mgb_assert((new_inp[i]->shape().ndim == 5) ||
new_inp[i]->shape().is_scalar());
}
}
} else {
for (size_t i = 0; i < opr->input().size(); i++) {
if (new_inp[i]->shape().ndim == 5) {
auto new_var = RelayoutPlaceholder::make(
new_inp[i], src_to_nchw_mode);
temp_inp[i] = new_var.node();
}
}
}
return serialization::copy_opr_shallow(*opr, temp_inp,
opr->config());
} else {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
};
auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
......@@ -1382,6 +1617,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
replace_func[opr::Concat::typeinfo()] = replace_concat_opr;
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
......@@ -1390,13 +1626,10 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
replace_func[opr::ConvolutionBackwardData::typeinfo()] =
relayout_inp_to_nchw;
replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
......
......@@ -234,16 +234,18 @@ namespace gopt {
*/
class EnableNchwxxPass final : public TensorReformatPass {
std::string m_name = "tensor_format_nchwxx";
size_t m_pack_c_size;
VarNode* on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const override;
//! the flag for conv to transform to nchwxx
enum class TransType {
TRANS_PURE_NCHWXX, //!< weight and src all trans to nchw88
TRANS_HYBIRD_NCHWXX, //!< input is nchw, output is nchw88
TRANS_PURE_NCHWXX, //!< weight and src all trans to nchwxx
TRANS_HYBIRD_NCHWXX, //!< input is nchw, output is nchwxx
TRANS_NONE, //!< no need trans
};
public:
EnableNchwxxPass(size_t pack_c_size) : m_pack_c_size(pack_c_size) {}
const char* name() const override {
return mgb_cstr_log(m_name.c_str());
}
......@@ -265,6 +267,8 @@ namespace gopt {
bool use_nhwcd4 = false;
//! whether to compute using NCHW88 tensor format
bool use_nchw88 = false;
//! whether to compute using NCHW44 tensor format
bool use_nchw44 = false;
//! whether to enable tensor core
bool use_tensor_core = false;
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
......@@ -283,6 +287,7 @@ namespace gopt {
SET(use_tensor_core);
SET(fuse_conv_bias_with_z);
SET(use_nchw88);
SET(use_nchw44);
#undef SET
};
......
......@@ -2325,5 +2325,86 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
}
TEST(TestGoptInference, ConvertFormatNCHW44) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name);
};
auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
//!Hybrid nchw88 mode
opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 3, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param_conv);
//!channel wise
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}),
conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias);
//! group
auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}),
conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias);
auto shape_of = opr::GetVarShape::make(conv3);
auto subtensor = opr::Subtensor::make(
shape_of, {opr::Subtensor::AxisIndexer::make_interval(
0, x.make_scalar(2), None, x.make_scalar(1))});
opr::Resize::Param param_resize;
param_resize.format = opr::Resize::Param::Format::NCHW;
auto resize = opr::ResizeForward::make(conv3, subtensor * 2, param_resize);
auto mat = mkcvar("mat", {2, 3, 3}),
warp = opr::WarpPerspectiveForward::make(
resize, mat, nullptr, cg::var_from_tensor_shape(x, {4, 4}));
auto b = mkvar("b", {1, 8, 1, 1}),
elem = opr::Elemwise::make({warp + b},
opr::Elemwise::Param::Mode::RELU);
//! Dense
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias);
auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}),
conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias);
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias);
SymbolVar y_opt;
unpack_vector(
gopt::optimize_for_inference(
{y},
gopt::OptimizeForInferenceOptions{}.enable_use_nchw44()),
y_opt);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt).param().format);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(
output_file("TestGoptInference.ConvertFormatNCHW44.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
//! meybe go to winograd in x86-32, so set error 1e-1
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
*host_x = *gen({2, 3, 32, 32}, cn);
func->execute();
//! meybe go to winograd in x86-32, so set error 1e-1
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -99,7 +99,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
group = filter_shape[0];
}
if (param.format == Param::Format::NCHW88) {
//! if channel wise weight layout is {group/8, 1, 1, FH, FW, 8}
//! if channel wise weight layout is {group/8, FH, FW, 1, 1, 8}
if (filter_shape[1] == 1 && filter_shape[2] == 1) {
group *= 8;
}
......@@ -107,6 +107,15 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
src_shape[1] / group * 2;
return hybird_nchwx ? computation : computation * 8;
}
if (param.format == Param::Format::NCHW44) {
//! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4}
if (filter_shape[1] == 1 && filter_shape[2] == 1) {
group *= 4;
}
size_t computation = dst_shape.total_nr_elems() * fh * fw *
src_shape[1] / group * 2;
return hybird_nchwx ? computation : computation * 4;
}
if (param.format == Param::Format::NCHW32) {
return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 32 /
group * 2;
......@@ -135,6 +144,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
};
if (param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 ||
param.format == Param::Format::NCHW32) {
return eval_conv_computation_nchwx();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册