提交 51868533 编写于 作者: M Megvii Engine Team

fix(mgb/gopt): fix opt pass elementwise operation shape issue at tranform to NCHW4

GitOrigin-RevId: c0c4e3f82ecd1149855b969027de5fe49b6efe95
上级 0a191f8d
......@@ -1670,10 +1670,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
};
auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size());
bool has_inp_changed = false;
for (size_t i = 0; i < opr->input().size(); i++) {
......@@ -2827,18 +2823,18 @@ MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr,
public:
AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);
static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);
TensorFormat inp_format() const {
return m_inp_format;
}
TensorFormat out_format() const {
return m_out_format;
}
private:
void init_output_static_infer_desc() override;
void scn_do_execute() override;
......@@ -3262,7 +3258,7 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp);
auto y2 = opr::TypeCvt::make(y1, dtype::Float32());
auto y2 = opr::TypeCvt::make(y1, dtype::Float32());
return y2.node();
};
......
......@@ -2910,7 +2910,8 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {
conv1, w2, b2, param_conv_bias, {},
OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
auto y = opr::TypeCvt::make(conv2, dtype::Float32());
auto conv2_fp32 = opr::TypeCvt::make(conv2, dtype::Float32());
auto y = conv2_fp32 + opr::TypeCvt::make(b2, dtype::Float32());
SymbolVar y_opt;
{
......@@ -4076,7 +4077,7 @@ TEST(TestGoptInference, FoldingConvDimshuffleNCHW32NCHW4) {
auto y = opr::ConvBias::make(x, w, b, param, {},
OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
param.stride_h = param.stride_w = 1;
param.stride_h = param.stride_w = 1;
y = opr::ConvBias::make(y, w1, b1, param, {},
OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::TypeCvt::make(y, dtype::Float32());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册