Commit 009c90a2
Authored Jun 08, 2021 by Megvii Engine Team
Committed by huangxinda on Jul 19, 2021
feat(mgb/gopt): modify padding policy for 4bit conv bias oprs
GitOrigin-RevId: 188a2c3728c017c77eba433211b308c9680b1dad
Parent: 4eda3388

Showing 2 changed files with 302 additions and 25 deletions (+302 −25)
src/gopt/impl/tensor_reformat.cpp  +207 −25
src/gopt/test/inference.cpp        +95 −0
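The substance of the change: PaddingChannelPass used to pad the channel dimensions of 4-bit conv bias operators to a multiple of 64 unconditionally; after this commit, channel counts of at most 32 are padded only to a multiple of 8, while larger counts keep the multiple-of-64 alignment, so small early layers (e.g. a 3-channel input) are no longer inflated to 64 channels. The standalone sketch below is a minimal illustration of that rule; the helper name required_pad is hypothetical and not part of the MegEngine sources.

    #include <cassert>
    #include <cstddef>

    // Hypothetical helper mirroring the padding rule introduced in
    // PaddingChannelPass: counts <= 32 align to 8, larger ones to 64.
    static std::size_t required_pad(std::size_t channels) {
        const std::size_t align = channels <= 32 ? 8 : 64;
        return channels % align ? align - channels % align : 0;
    }

    int main() {
        assert(required_pad(3) == 5);    // 3 -> 8 (the old policy padded 3 -> 64)
        assert(required_pad(16) == 0);   // already a multiple of 8
        assert(required_pad(48) == 16);  // 48 -> 64, as the new test expects
        assert(required_pad(64) == 0);   // already a multiple of 64
        return 0;
    }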
src/gopt/impl/tensor_reformat.cpp
@@ -122,6 +122,14 @@ public:
         NCHW_TO_NCHW64,   //!< from nchw layout to nchw64 layout
         NCHW_TO_NCHW32,   //!< from nchw layout to nchw32 layout
         NCHW4_TO_NCHW64,  //!< from nchw4 layout to nchw64 layout
+        NCHW_TO_NHWC,     //!< NHWC related layout transformations
+        NCHW4_TO_NHWC,
+        NCHW32_TO_NHWC,
+        NCHW64_TO_NHWC,
+        NHWC_TO_NCHW,
+        NHWC_TO_NCHW4,
+        NHWC_TO_NCHW32,
+        NHWC_TO_NCHW64,
     };
     RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
@@ -428,7 +436,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[2] = inp_shape[2];
             dst[3] = inp_shape[3];
             dst[4] = 32;
-        } else {
-            mgb_assert(layout_type() ==
-                       RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64);
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64) {
             mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 16 == 0);
@@ -438,6 +447,75 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[2] = inp_shape[2];
             dst[3] = inp_shape[3];
             dst[4] = 64;
         }
+        else if (layout_type() ==
+                 RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC) {
+            mgb_assert(inp_shape.ndim == 4);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[2];
+            dst[2] = inp_shape[3];
+            dst[3] = inp_shape[1];
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW4_TO_NHWC) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[2];
+            dst[2] = inp_shape[3];
+            dst[3] = inp_shape[1] * 4;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW32_TO_NHWC) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[2];
+            dst[2] = inp_shape[3];
+            dst[3] = inp_shape[1] * 32;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW64_TO_NHWC) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 64);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[2];
+            dst[2] = inp_shape[3];
+            dst[3] = inp_shape[1] * 64;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW) {
+            mgb_assert(inp_shape.ndim == 4);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[3];
+            dst[2] = inp_shape[1];
+            dst[3] = inp_shape[2];
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW4) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[3] / 4;
+            dst[2] = inp_shape[1];
+            dst[3] = inp_shape[2];
+            dst[4] = 4;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 32 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[3] / 32;
+            dst[2] = inp_shape[1];
+            dst[3] = inp_shape[2];
+            dst[4] = 32;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64) {
+            mgb_assert(layout_type() ==
+                       RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64);
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 64 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[3] / 64;
+            dst[2] = inp_shape[1];
+            dst[3] = inp_shape[2];
+            dst[4] = 64;
+        }
         return true;
     };
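The branches above only provide static shape inference for the new NHWC relayout placeholders; the graph rewriting itself is added to translate_pass in the next hunk. As a sanity check on the shape arithmetic, here is a minimal sketch, independent of MegEngine, of the NCHW64 -> NHWC mapping that the NCHW64_TO_NHWC branch encodes:

    #include <array>
    #include <cassert>
    #include <cstddef>

    // An NCHW64 shape {N, C/64, H, W, 64} maps to the NHWC shape
    // {N, H, W, C}, i.e. dst[3] = src[1] * 64, matching the branch above.
    int main() {
        std::array<std::size_t, 5> nchw64 = {16, 1, 14, 14, 64};  // N=16, C=64
        std::array<std::size_t, 4> nhwc = {nchw64[0], nchw64[2], nchw64[3],
                                           nchw64[1] * nchw64[4]};
        std::array<std::size_t, 4> expect = {16, 14, 14, 64};
        assert(nhwc == expect);
        return 0;
    }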
@@ -934,6 +1012,93 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
+    reformat[LayoutType::NCHW_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        megdnn::param::RelayoutFormat param;
+        param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWC;
+        auto reformat = opr::RelayoutFormat::make(inp, param);
+        return reformat.node();
+    };
+    reformat[LayoutType::NCHW4_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::NCHW32_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 32}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::NCHW64_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 64}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto y = opr::Dimshuffle::make(x, {0, 3, 1, 2});
+        return y.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                {sub(0), sub(1), sub(2), sub(3) / 4, cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
+        return y1.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW32] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                {sub(0), sub(1), sub(2), sub(3) / 32, cv(32)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
+        return y1.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW64] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                {sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
+        return y1.node();
+    };
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
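All NCHWx -> NHWC emitters above share one pattern: Dimshuffle {0, 2, 3, 1, 4} moves the channel-block axis behind the spatial axes, then a Reshape folds the (C/x, x) pair into a single C axis; the NHWC -> NCHWx emitters apply the inverse Reshape-then-Dimshuffle. The following is a hypothetical sketch of the resulting index mapping on a raw buffer, outside the graph API, for the x = 4 case:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Element (n, c, h, w) of the logical tensor sits at block c/4,
    // lane c%4 in NCHW4 storage, and at channel c in NHWC storage.
    int main() {
        const std::size_t N = 1, C = 8, H = 2, W = 2, B = 4;  // B: block size
        std::vector<int> nchw4(N * (C / B) * H * W * B);
        for (std::size_t i = 0; i < nchw4.size(); ++i)
            nchw4[i] = static_cast<int>(i);
        std::vector<int> nhwc(nchw4.size());
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t c = 0; c < C; ++c)
                for (std::size_t h = 0; h < H; ++h)
                    for (std::size_t w = 0; w < W; ++w) {
                        std::size_t src =
                                (((n * (C / B) + c / B) * H + h) * W + w) * B +
                                c % B;
                        std::size_t dst = ((n * H + h) * W + w) * C + c;
                        nhwc[dst] = nchw4[src];
                    }
        // At (h=0, w=0): lanes of block 0 come first, then block 1.
        assert(nhwc[0] == 0 && nhwc[4] == 16);
        return 0;
    }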
@@ -4095,20 +4260,37 @@ void PaddingChannelPass::apply(OptState& opt) const {
         size_t new_in_channels = new_inp[0]->shape()[1];
         // pad input channels
         if (padding_oprs.count(opr->input(0)->owner_opr())) {
-            if (new_in_channels % 64 == 0) {
-                size_t pad_channels = new_in_channels - in_channels;
-                inps[1] = pad_in_channels(new_inp[1], pad_channels);
+            if (new_in_channels <= 32) {
+                if (new_in_channels % 8 == 0) {
+                    size_t pad_channels = new_in_channels - in_channels;
+                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
+                } else {
+                    size_t pad_channels_0 = 8 - (new_in_channels % 8);
+                    size_t pad_channels_1 = 8 - (in_channels % 8);
+                    inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
+                    inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
+                }
             } else {
-                size_t pad_channels_0 = 64 - (new_in_channels % 64);
-                size_t pad_channels_1 = 64 - (in_channels % 64);
-                inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
-                inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
+                if (new_in_channels % 64 == 0) {
+                    size_t pad_channels = new_in_channels - in_channels;
+                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
+                } else {
+                    size_t pad_channels_0 = 64 - (new_in_channels % 64);
+                    size_t pad_channels_1 = 64 - (in_channels % 64);
+                    inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
+                    inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
+                }
             }
         } else {
             size_t pad_channels = 0;
             mgb_assert(new_in_channels == in_channels);
-            if (in_channels % 64)
-                pad_channels = 64 - (in_channels % 64);
+            if (in_channels <= 32) {
+                if (in_channels % 8)
+                    pad_channels = 8 - (in_channels % 8);
+            } else {
+                if (in_channels % 64)
+                    pad_channels = 64 - (in_channels % 64);
+            }
             if (pad_channels > 0) {
                 inps[0] = pad_in_channels(new_inp[0], pad_channels);
                 inps[1] = pad_in_channels(new_inp[1], pad_channels);
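In the branch where the producer of input 0 is itself a padded operator, the activation already arrives with new_in_channels >= in_channels, so usually only the weight's input-channel axis must grow to match; only when even the padded activation is misaligned are both tensors padded further. A hypothetical numeric walk-through of the aligned case under the new rule:

    #include <cassert>
    #include <cstddef>

    // Walk-through: a 4-bit conv whose activation was padded 3 -> 8
    // upstream only needs its weight's input-channel axis padded by the
    // same 5 channels (values are illustrative, not from the sources).
    int main() {
        std::size_t in_channels = 3;      // channels the weight was built for
        std::size_t new_in_channels = 8;  // channels actually arriving, padded
        assert(new_in_channels <= 32 && new_in_channels % 8 == 0);
        std::size_t pad_channels = new_in_channels - in_channels;
        assert(pad_channels == 5);        // pad only the weight, by 5 channels
        return 0;
    }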
@@ -4117,8 +4299,13 @@ void PaddingChannelPass::apply(OptState& opt) const {
             out_channels = inps[1]->shape()[0];
             in_channels = inps[1]->shape()[1];
             size_t pad_channels = 0;
-            if (out_channels % 64)
-                pad_channels = 64 - (out_channels % 64);
+            if (out_channels <= 32) {
+                if (out_channels % 8)
+                    pad_channels = 8 - (out_channels % 8);
+            } else {
+                if (out_channels % 64)
+                    pad_channels = 64 - (out_channels % 64);
+            }
             if (pad_channels > 0) {
                 inps[1] = pad_out_channels(inps[1], pad_channels);
                 inps[2] = pad_in_channels(inps[2], pad_channels);
@@ -4402,20 +4589,16 @@ EnableNCHW64Pass::make_nchw64_converter() {
                 return new_conv.node();
             }
         };
     auto try_transform_to_nchw = [&format_map](
             OperatorNodeBase* opr,
             const VarNodeArray& new_inp) -> VarNode* {
         mgb_assert(opr->input().size() == new_inp.size());
         bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::Float32 &&
                            new_inp[1]->dtype().enumv() == DTypeEnum::Float32;
-        check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::Float32;
-        check_dtype &= new_inp[3]->dtype().enumv() == DTypeEnum::Float32;
+        if (opr->input().size() >= 3)
+            check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::Float32;
+        if (opr->input().size() >= 4)
+            check_dtype &= new_inp[3]->dtype().enumv() == DTypeEnum::Float32;
         if (!check_dtype)
             return nullptr;
         auto inps = new_inp;

@@ -4451,7 +4634,6 @@ EnableNCHW64Pass::make_nchw64_converter() {
         return ret->output()[0];
     };
     auto try_transform_to_nchw4 = [make_new_conv, &format_map](
             OperatorNodeBase* opr,
src/gopt/test/inference.cpp
@@ -4735,6 +4735,101 @@ TEST(TestGoptInference, PaddingChannelsWithWarpPerspective) {
     MGB_ASSERT_TENSOR_EQ(t1, t2);
 }
 
+TEST(TestGoptInference, PaddingChannelsB4) {
+    REQUIRE_GPU(1);
+    auto cn = CompNode::load("gpu0");
+    cn.activate();
+    REQUIRE_CUDA_COMPUTE_CAPABILITY(7, 5);
+    HostTensorGenerator<dtype::Int8> gen;
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp,
+                     const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
+                dtype);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp,
+                      const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                        .rename(name),
+                dtype);
+    };
+    auto x = mkvar("x", {16, 3, 14, 14}, dtype::QuantizedS8(2.5f)),
+         w = mkcvar("w", {16, 3, 3, 3}, dtype::QuantizedS8(2.5f)),
+         b = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f));
+    opr::ConvBias::Param param;
+    param.format = opr::ConvBias::Param::Format::NCHW;
+    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
+    param.stride_h = param.stride_w = 1;
+    param.pad_h = param.pad_w = 1;
+    auto y = opr::ConvBias::make(x, w, b, param, {},
+                                 OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
+    y = opr::TypeCvt::make(y, dtype::Quantized4Asymm{20.f, 8});
+    opr::Pooling::Param pool;
+    pool.format = opr::Pooling::Param::Format::NCHW;
+    y = opr::Pooling::make(y, pool);
+    auto w1 = mkcvar("w1", {48, 16, 3, 3}, dtype::QuantizedS4(1.234f)),
+         b1 = mkcvar("b1", {1, 48, 1, 1}, dtype::QuantizedS32(20.f * 1.234f));
+    auto y1 = opr::ConvBias::make(
+            y, w1, b1, param, {},
+            OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
+    auto w2 = mkcvar("w2", {48, 48, 3, 3}, dtype::QuantizedS4(1.234f)),
+         b2 = mkcvar("b2", {1, 48, 1, 1}, dtype::QuantizedS32(20.f * 1.234f));
+    auto y2 = opr::ConvBias::make(
+            y1, w2, b2, param, {},
+            OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
+    auto w3 = mkcvar("w3", {16, 48, 3, 3}, dtype::QuantizedS4(1.234f)),
+         b3 = mkcvar("b3", {1, 16, 1, 1}, dtype::QuantizedS32(20.f * 1.234f));
+    auto y3 = opr::ConvBias::make(
+            y2, w3, b3, param, {},
+            OperatorNodeConfig{dtype::Quantized4Asymm(20.f, 8)});
+    using ElemMultiMode = opr::ElemwiseMultiType::Param::Mode;
+    auto y4 = opr::ElemwiseMultiType::make(
+            {y, y3}, {ElemMultiMode::QFUSE_ADD_RELU},
+            OperatorNodeConfig{dtype::Quantized4Asymm{20.f, 7}});
+    y4 = opr::TypeCvt::make(y4, dtype::Float32());
+    SymbolVar y4_pad;
+    unpack_vector(gopt::GraphOptimizer{}
+                          .add_pass<gopt::PaddingChannelPass>()
+                          .add_pass<gopt::ParamFusePass>()
+                          .apply({{y4}})
+                          .endpoint_vars(),
+                  y4_pad);
+    ASSERT_EQ(y4_pad.node()->shape()[1], y4.node()->shape()[1]);
+    SmallVector<cg::OperatorNodeBase*> oprs;
+    auto cb1 = [&oprs](cg::OperatorNodeBase* opr) {
+        if (opr->same_type<opr::ConvBias>()) {
+            oprs.push_back(opr);
+        }
+    };
+    cg::DepOprIter{cb1}.add(y4_pad.node()->owner_opr());
+    ASSERT_EQ(oprs.size(), 4);
+    ASSERT_EQ(oprs[0]->output(0)->shape()[1], 16);
+    ASSERT_EQ(oprs[1]->output(0)->shape()[1], 64);
+    ASSERT_EQ(oprs[2]->output(0)->shape()[1], 64);
+    ASSERT_EQ(oprs[3]->output(0)->shape()[1], 16);
+    size_t nr_concat = find_opr_num<opr::Concat>(y4_pad);
+    ASSERT_EQ(nr_concat, 1);
+    cg::OperatorNodeBase* concat = nullptr;
+    auto cb2 = [&concat](cg::OperatorNodeBase* opr) {
+        if (opr->same_type<opr::Concat>()) {
+            concat = opr;
+        }
+    };
+    cg::DepOprIter{cb2}.add(y4_pad.node()->owner_opr());
+    ASSERT_EQ(oprs[0]->input(0)->owner_opr(), concat);
+    HostTensorND t1, t2;
+    auto func1 = graph->compile({make_callback_copy(y4, t1)});
+    func1->execute();
+    auto func2 = graph->compile({make_callback_copy(y4_pad, t2)});
+    func2->execute();
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
 #endif
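Assuming MegEngine's usual gtest-based test build, the new case can be run in isolation through the standard gtest filter flag; the binary name megbrain_test is the conventional one and may differ per build configuration:

    ./megbrain_test --gtest_filter=TestGoptInference.PaddingChannelsB4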