Commit 009c90a2 authored by Megvii Engine Team, committed by huangxinda

feat(mgb/gopt): modify padding policy for 4bit conv bias oprs

GitOrigin-RevId: 188a2c3728c017c77eba433211b308c9680b1dad
Parent 4eda3388
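
Editor's note: the policy change below pads conv-bias channels to a multiple of 8 when the channel count is at most 32, and to a multiple of 64 otherwise. A minimal standalone sketch of that rule follows (hypothetical helper and driver, not part of this commit or of MegEngine):

// Sketch of the padding rule added in PaddingChannelPass below; the helper
// name and the main() driver are illustrative only, not MegEngine APIs.
#include <cstddef>
#include <cstdio>

static size_t padded_channels(size_t channels) {
    // <= 32 channels: round up to a multiple of 8; otherwise to a multiple of 64.
    const size_t align = channels <= 32 ? 8 : 64;
    const size_t rem = channels % align;
    return rem == 0 ? channels : channels + (align - rem);
}

int main() {
    // Channel counts that appear in the new test case: 3 -> 8, 16 -> 16, 48 -> 64.
    for (size_t c : {3, 16, 48})
        std::printf("%zu -> %zu\n", c, padded_channels(c));
    return 0;
}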
......@@ -122,6 +122,14 @@ public:
NCHW_TO_NCHW64, //! <from nchw layout to nchw64 layout
        NCHW_TO_NCHW32, //! <from nchw layout to nchw32 layout
NCHW4_TO_NCHW64, //! <from nchw4 layout to nchw64 layout
NCHW_TO_NHWC, //! <NHWC related layout transformation
NCHW4_TO_NHWC,
NCHW32_TO_NHWC,
NCHW64_TO_NHWC,
NHWC_TO_NCHW,
NHWC_TO_NCHW4,
NHWC_TO_NCHW32,
NHWC_TO_NCHW64,
};
RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
......@@ -428,7 +436,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 32;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64) {
mgb_assert(layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64);
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 16 == 0);
......@@ -438,6 +447,75 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 64;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC) {
mgb_assert(inp_shape.ndim == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[1];
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW4_TO_NHWC) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[1] * 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW32_TO_NHWC) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[1] * 32;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW64_TO_NHWC) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 64);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[1] * 64;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW) {
mgb_assert(inp_shape.ndim == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[3];
dst[2] = inp_shape[1];
dst[3] = inp_shape[2];
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW4) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 4 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[3] / 4;
dst[2] = inp_shape[1];
dst[3] = inp_shape[2];
dst[4] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 32 == 0);
                    dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[3] / 32;
dst[2] = inp_shape[1];
dst[3] = inp_shape[2];
dst[4] = 32;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64) {
mgb_assert(layout_type() ==
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64);
mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 64 == 0);
                    dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[3] / 64;
dst[2] = inp_shape[1];
dst[3] = inp_shape[2];
dst[4] = 64;
}
return true;
};
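
Editor's note: the shape arithmetic implemented by the new NHWC branches above maps (N, H, W, C) to (N, C/xx, H, W, xx) and back. A hypothetical standalone helper (not part of this file) for the NHWC -> NCHW64 case:

// Illustrative only: mirrors the NHWC_TO_NCHW64 shape-inference branch above.
#include <array>
#include <cassert>
#include <cstddef>

static std::array<size_t, 5> nhwc_to_nchw64_shape(const std::array<size_t, 4>& nhwc) {
    assert(nhwc[3] % 64 == 0);  // channel dimension must be divisible by 64
    return {nhwc[0], nhwc[3] / 64, nhwc[1], nhwc[2], 64};
}

int main() {
    // e.g. a (16, 14, 14, 128) NHWC tensor becomes (16, 2, 14, 14, 64) in NCHW64.
    auto dst = nhwc_to_nchw64_shape({16, 14, 14, 128});
    assert(dst[1] == 2 && dst[4] == 64);
    return 0;
}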
......@@ -934,6 +1012,93 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW_TO_NHWC] = [](VarNode* inp) -> VarNode* {
megdnn::param::RelayoutFormat param;
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWC;
auto reformat = opr::RelayoutFormat::make(inp, param);
return reformat.node();
};
reformat[LayoutType::NCHW4_TO_NHWC] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 =
opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::NCHW32_TO_NHWC] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 32}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::NCHW64_TO_NHWC] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 64}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::NHWC_TO_NCHW] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto y = opr::Dimshuffle::make(x, {0, 3, 1, 2});
return y.node();
};
reformat[LayoutType::NHWC_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3) / 4, cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
return y1.node();
};
reformat[LayoutType::NHWC_TO_NCHW32] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3) / 32, cv(32)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
return y1.node();
};
reformat[LayoutType::NHWC_TO_NCHW64] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
return y1.node();
};
auto rewriter = opt.graph().make_rewriter();
auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
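
Editor's note: the NHWC-to-NCHWxx lambdas above all follow the same reshape-then-dimshuffle pattern. A hypothetical index-mapping check (not part of this pass) for the NHWC_TO_NCHW64 case, assuming the same axis order {0, 3, 1, 2, 4}:

// Illustrative only: emulates reshape (N, H, W, C) -> (N, H, W, C/64, 64)
// followed by Dimshuffle {0, 3, 1, 2, 4}, i.e. element (n, h, w, c) of the
// NHWC input lands at (n, c / 64, h, w, c % 64) in the NCHW64 output.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const size_t N = 1, H = 2, W = 2, C = 128;
    std::vector<int> nhwc(N * H * W * C), nchw64(N * (C / 64) * H * W * 64);
    for (size_t i = 0; i < nhwc.size(); ++i)
        nhwc[i] = static_cast<int>(i);
    for (size_t n = 0; n < N; ++n)
        for (size_t h = 0; h < H; ++h)
            for (size_t w = 0; w < W; ++w)
                for (size_t c = 0; c < C; ++c) {
                    const size_t src = ((n * H + h) * W + w) * C + c;
                    const size_t dst =
                            (((n * (C / 64) + c / 64) * H + h) * W + w) * 64 + c % 64;
                    nchw64[dst] = nhwc[src];
                }
    // NHWC element (0, 0, 0, 64) sits at NCHW64 position (0, 1, 0, 0, 0), offset 256.
    assert(nchw64[256] == 64);
    return 0;
}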
......@@ -4095,20 +4260,37 @@ void PaddingChannelPass::apply(OptState& opt) const {
size_t new_in_channels = new_inp[0]->shape()[1];
// pad input channels
if (padding_oprs.count(opr->input(0)->owner_opr())) {
                if (new_in_channels <= 32) {
                    if (new_in_channels % 8 == 0) {
                        size_t pad_channels = new_in_channels - in_channels;
                        inps[1] = pad_in_channels(new_inp[1], pad_channels);
                    } else {
                        size_t pad_channels_0 = 8 - (new_in_channels % 8);
                        size_t pad_channels_1 = 8 - (in_channels % 8);
                        inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
                        inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
                    }
                } else {
                    if (new_in_channels % 64 == 0) {
                        size_t pad_channels = new_in_channels - in_channels;
                        inps[1] = pad_in_channels(new_inp[1], pad_channels);
                    } else {
                        size_t pad_channels_0 = 64 - (new_in_channels % 64);
                        size_t pad_channels_1 = 64 - (in_channels % 64);
                        inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
                        inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
                    }
                }
} else {
size_t pad_channels = 0;
mgb_assert(new_in_channels == in_channels);
                if (in_channels <= 32) {
                    if (in_channels % 8)
                        pad_channels = 8 - (in_channels % 8);
                } else {
                    if (in_channels % 64)
                        pad_channels = 64 - (in_channels % 64);
                }
if (pad_channels > 0) {
inps[0] = pad_in_channels(new_inp[0], pad_channels);
inps[1] = pad_in_channels(new_inp[1], pad_channels);
......@@ -4117,8 +4299,13 @@ void PaddingChannelPass::apply(OptState& opt) const {
out_channels = inps[1]->shape()[0];
in_channels = inps[1]->shape()[1];
size_t pad_channels = 0;
            if (out_channels <= 32) {
                if (out_channels % 8)
                    pad_channels = 8 - (out_channels % 8);
            } else {
                if (out_channels % 64)
                    pad_channels = 64 - (out_channels % 64);
            }
if (pad_channels > 0) {
inps[1] = pad_out_channels(inps[1], pad_channels);
inps[2] = pad_in_channels(inps[2], pad_channels);
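
Editor's note: when out_channels is padded, the weight's output-channel dimension and the bias's channel dimension are padded together, as the pad_out_channels/pad_in_channels calls above show. A hypothetical shape-level restatement (not MegEngine code):

// Illustrative only: shape effect of padding out_channels by `pad` on a
// ConvBias weight (O, I, Kh, Kw) and bias (1, O, 1, 1).
#include <array>
#include <cstddef>

static std::array<size_t, 4> pad_weight_out(std::array<size_t, 4> w, size_t pad) {
    w[0] += pad;  // output-channel dimension
    return w;
}

static std::array<size_t, 4> pad_bias(std::array<size_t, 4> b, size_t pad) {
    b[1] += pad;  // channel dimension of a (1, C, 1, 1) bias
    return b;
}

int main() {
    auto w = pad_weight_out({48, 16, 3, 3}, 16);  // 48 -> 64 output channels
    auto b = pad_bias({1, 48, 1, 1}, 16);
    return w[0] == 64 && b[1] == 64 ? 0 : 1;
}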
......@@ -4402,20 +4589,16 @@ EnableNCHW64Pass::make_nchw64_converter() {
return new_conv.node();
}
};
    auto try_transform_to_nchw =
            [&format_map](OperatorNodeBase* opr,
                          const VarNodeArray& new_inp) -> VarNode* {
        mgb_assert(opr->input().size() == new_inp.size());
        bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::Float32 &&
                           new_inp[1]->dtype().enumv() == DTypeEnum::Float32;
        if (opr->input().size() >= 3)
            check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::Float32;
        if (opr->input().size() >= 4)
            check_dtype &= new_inp[3]->dtype().enumv() == DTypeEnum::Float32;
if (!check_dtype)
return nullptr;
auto inps = new_inp;
......@@ -4451,7 +4634,6 @@ EnableNCHW64Pass::make_nchw64_converter() {
return ret->output()[0];
};
auto try_transform_to_nchw4 =
[make_new_conv, &format_map](
OperatorNodeBase* opr,
......
......@@ -4735,6 +4735,101 @@ TEST(TestGoptInference, PaddingChannelsWithWarpPerspective) {
MGB_ASSERT_TENSOR_EQ(t1, t2);
}
TEST(TestGoptInference, PaddingChannelsB4) {
REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0");
cn.activate();
REQUIRE_CUDA_COMPUTE_CAPABILITY(7, 5);
HostTensorGenerator<dtype::Int8> gen;
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
dtype);
};
auto mkcvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};
auto x = mkvar("x", {16, 3, 14, 14}, dtype::QuantizedS8(2.5f)),
w = mkcvar("w", {16, 3, 3, 3}, dtype::QuantizedS8(2.5f)),
b = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f));
opr::ConvBias::Param param;
param.format = opr::ConvBias::Param::Format::NCHW;
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
param.stride_h = param.stride_w = 1;
param.pad_h = param.pad_w = 1;
auto y = opr::ConvBias::make(x, w, b, param, {},
OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::TypeCvt::make(y, dtype::Quantized4Asymm{20.f, 8});
opr::Pooling::Param pool;
pool.format = opr::Pooling::Param::Format::NCHW;
y = opr::Pooling::make(y, pool);
auto w1 = mkcvar("w1", {48, 16, 3, 3}, dtype::QuantizedS4(1.234f)),
b1 = mkcvar("b1", {1, 48, 1, 1}, dtype::QuantizedS32(20.f*1.234f));
auto y1 = opr::ConvBias::make(y, w1, b1, param, {},
OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
auto w2 = mkcvar("w2", {48, 48, 3, 3}, dtype::QuantizedS4(1.234f)),
b2 = mkcvar("b2", {1, 48, 1, 1}, dtype::QuantizedS32(20.f*1.234f));
auto y2 = opr::ConvBias::make(
y1, w2, b2, param, {},
OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
    auto w3 = mkcvar("w3", {16, 48, 3, 3}, dtype::QuantizedS4(1.234f)),
         b3 = mkcvar("b3", {1, 16, 1, 1}, dtype::QuantizedS32(20.f * 1.234f));
auto y3 = opr::ConvBias::make(
y2, w3, b3, param, {},
OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
using ElemMultiMode = opr::ElemwiseMultiType::Param::Mode;
auto y4 = opr::ElemwiseMultiType::make(
{y, y3}, {ElemMultiMode::QFUSE_ADD_RELU},
OperatorNodeConfig{dtype::Quantized4Asymm{20.f, 7}});
y4 = opr::TypeCvt::make(y4, dtype::Float32());
SymbolVar y4_pad;
unpack_vector(gopt::GraphOptimizer{}
.add_pass<gopt::PaddingChannelPass>()
.add_pass<gopt::ParamFusePass>()
.apply({{y4}})
.endpoint_vars(),
y4_pad);
ASSERT_EQ(y4_pad.node()->shape()[1], y4.node()->shape()[1]);
SmallVector<cg::OperatorNodeBase*> oprs;
auto cb1 = [&oprs](cg::OperatorNodeBase* opr) {
if (opr->same_type<opr::ConvBias>()) {
oprs.push_back(opr);
}
};
cg::DepOprIter{cb1}.add(y4_pad.node()->owner_opr());
ASSERT_EQ(oprs.size(), 4);
ASSERT_EQ(oprs[0]->output(0)->shape()[1], 16);
ASSERT_EQ(oprs[1]->output(0)->shape()[1], 64);
ASSERT_EQ(oprs[2]->output(0)->shape()[1], 64);
ASSERT_EQ(oprs[3]->output(0)->shape()[1], 16);
size_t nr_concat = find_opr_num<opr::Concat>(y4_pad);
ASSERT_EQ(nr_concat, 1);
cg::OperatorNodeBase* concat = nullptr;
auto cb2 = [&concat](cg::OperatorNodeBase* opr) {
if (opr->same_type<opr::Concat>()) {
concat = opr;
}
};
cg::DepOprIter{cb2}.add(y4_pad.node()->owner_opr());
ASSERT_EQ(oprs[0]->input(0)->owner_opr(), concat);
HostTensorND t1, t2;
auto func1 = graph->compile({make_callback_copy(y4, t1)});
func1->execute();
auto func2 = graph->compile({make_callback_copy(y4_pad, t2)});
func2->execute();
MGB_ASSERT_TENSOR_EQ(t1, t2);
}
#endif
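
Editor's note: for readers tracing the assertions in PaddingChannelsB4 above, under the new rule the 3-channel network input is padded to 8 channels (hence the single Concat feeding the first ConvBias), the 16-channel convolutions keep 16 output channels, and the 48-channel int4 convolutions are padded to 64. A hypothetical restatement of those counts (not part of the test):

// Illustrative only; restates the channel counts the test asserts.
#include <cassert>
#include <cstddef>

static size_t pad_to(size_t channels) {
    const size_t align = channels <= 32 ? 8 : 64;
    return (channels + align - 1) / align * align;
}

int main() {
    assert(pad_to(3) == 8);    // network input: padded once, via the Concat checked above
    assert(pad_to(16) == 16);  // first and last ConvBias keep 16 output channels
    assert(pad_to(48) == 64);  // the two 48-channel int4 ConvBias oprs grow to 64
    return 0;
}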
......