Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
61f917fb
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
61f917fb
编写于
11月 21, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/cuda): add impl for fusing warp perspective and dimshuffle
GitOrigin-RevId: 51e025973f58ba75ea96765bcb8507581c6c1c25
上级
15dd5e1a
变更
13
展开全部
隐藏空白更改
内联
并排
Showing
13 changed file
with
2023 addition
and
417 deletion
+2023
-417
dnn/scripts/opr_param_defs.py
dnn/scripts/opr_param_defs.py
+6
-0
dnn/src/common/warp_perspective.cpp
dnn/src/common/warp_perspective.cpp
+150
-108
dnn/src/cuda/warp_perspective/common.h
dnn/src/cuda/warp_perspective/common.h
+17
-0
dnn/src/cuda/warp_perspective/forward.cpp
dnn/src/cuda/warp_perspective/forward.cpp
+141
-48
dnn/src/cuda/warp_perspective/forward.cu
dnn/src/cuda/warp_perspective/forward.cu
+823
-77
dnn/src/naive/warp_perspective/opr_impl.cpp
dnn/src/naive/warp_perspective/opr_impl.cpp
+220
-5
dnn/src/naive/warp_perspective/opr_impl.h
dnn/src/naive/warp_perspective/opr_impl.h
+32
-10
dnn/test/cuda/warp_perspective.cpp
dnn/test/cuda/warp_perspective.cpp
+139
-2
src/gopt/impl/framework.cpp
src/gopt/impl/framework.cpp
+143
-148
src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
+241
-0
src/gopt/include/megbrain/gopt/inference.h
src/gopt/include/megbrain/gopt/inference.h
+10
-0
src/gopt/test/inference.cpp
src/gopt/test/inference.cpp
+62
-1
src/opr/impl/imgproc.cpp
src/opr/impl/imgproc.cpp
+39
-18
未找到文件。
dnn/scripts/opr_param_defs.py
浏览文件 @
61f917fb
...
...
@@ -43,6 +43,12 @@ pdef('Axis').add_fields('int32', 'axis', 0)
Doc
(
'NCHW4_NCHW32'
,
'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'
),
Doc
(
'NCHW32_NCHW4'
,
'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'
),
Doc
(
'NCHW4_NCHW'
,
'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'
),
Doc
(
'NHWC_NCHW'
,
'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'
),
Doc
(
'NHWC_NCHW4_IC_SMALL'
,
'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'
),
Doc
(
'NCHW_NCHW4_IC_SMALL'
,
'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'
),
Doc
(
'CHWN4'
,
'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'
))
)
...
...
dnn/src/common/warp_perspective.cpp
浏览文件 @
61f917fb
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"
...
...
@@ -14,20 +15,17 @@
namespace
megdnn
{
void
WarpPerspectiveBase
::
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
{
void
WarpPerspectiveBase
::
check_layout_fwd
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
)
{
megdnn_assert_contiguous
(
mat
);
megdnn_assert_contiguous
(
src
);
megdnn_assert_contiguous
(
dst
);
auto
errmsg
=
[
&
]()
{
return
megdnn_layout_msg
(
src
)
+
", "
+
megdnn_layout_msg
(
mat
)
+
", "
+
megdnn_layout_msg
(
mat_idx
)
+
", "
+
megdnn_layout_msg
(
dst
)
+
", "
+
param_msg
();
return
megdnn_layout_msg
(
src
)
+
", "
+
megdnn_layout_msg
(
mat
)
+
", "
+
megdnn_layout_msg
(
mat_idx
)
+
", "
+
megdnn_layout_msg
(
dst
)
+
", "
+
param_msg
();
};
MEGDNN_MARK_USED_VAR
(
errmsg
);
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWCD4
||
...
...
@@ -35,9 +33,17 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
megdnn_assert
(
src
.
ndim
==
5
_z
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
dst
.
ndim
==
5
_z
,
"%s"
,
errmsg
().
c_str
());
}
else
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWC_NCHW4_IC_SMALL
||
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW_NCHW4_IC_SMALL
)
{
megdnn_assert
(
src
.
ndim
==
4
_z
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
dst
.
ndim
==
5
_z
,
"%s"
,
errmsg
().
c_str
());
}
else
{
megdnn_assert
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWC
||
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW
);
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW
||
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWC_NCHW
);
megdnn_assert
(
src
.
ndim
==
4
_z
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
dst
.
ndim
==
4
_z
,
"%s"
,
errmsg
().
c_str
());
}
...
...
@@ -45,7 +51,7 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
megdnn_assert
(
dst
.
shape
[
0
]
==
mat
.
shape
[
0
],
"%s"
,
errmsg
().
c_str
());
if
(
mat_idx
.
ndim
)
{
megdnn_assert
(
mat_idx
.
dtype
==
dtype
::
Int32
()
&&
mat_idx
.
ndim
==
1
,
"%s"
,
errmsg
().
c_str
());
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
mat
.
shape
[
0
]
==
mat_idx
.
shape
[
0
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert_contiguous
(
mat_idx
);
}
else
{
...
...
@@ -54,35 +60,103 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
megdnn_assert
(
mat
.
shape
[
1
]
==
3
_z
,
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
mat
.
shape
[
2
]
==
3
_z
,
"%s"
,
errmsg
().
c_str
());
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW
)
{
megdnn_assert
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
||
MEGDNN_FLOAT16_SELECT
(
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Float16
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
BFloat16
),
false
)
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Int8
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Uint8
||
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
),
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8"
MEGDNN_FLOAT16_SELECT
(
"/Float16/BFloat16"
,
""
)
"."
);
megdnn_assert
(
(
src
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
&&
(
src
.
dtype
==
mat
.
dtype
||
mat
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
))
||
((
src
.
dtype
.
category
()
==
DTypeCategory
::
INT
||
src
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
)
&&
mat
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
),
"The input to WarpPerspective is in NCHW format, in this "
"case, if the input dtype is floating point, the "
"transformation matrix should have same dtype as the "
"input, otherwise, it should be in Float32, %s given."
,
mat
.
dtype
.
name
());
if
(
src
.
format
==
dst
.
format
&&
dst
.
dtype
==
src
.
dtype
)
{
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW
)
{
megdnn_assert
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
||
MEGDNN_FLOAT16_SELECT
(
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Float16
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
BFloat16
),
false
)
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Int8
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Uint8
||
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
),
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8"
MEGDNN_FLOAT16_SELECT
(
"/Float16/BFloat16"
,
""
)
"."
);
megdnn_assert
(
(
src
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
&&
(
src
.
dtype
==
mat
.
dtype
||
mat
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
))
||
((
src
.
dtype
.
category
()
==
DTypeCategory
::
INT
||
src
.
dtype
.
category
()
==
DTypeCategory
::
QUANTIZED
)
&&
mat
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
),
"The input to WarpPerspective is in NCHW format, in this "
"case, if the input dtype is floating point, the "
"transformation matrix should have same dtype as the "
"input, otherwise, it should be in Float32, %s given."
,
mat
.
dtype
.
name
());
megdnn_assert
(
src
.
shape
[
1
]
==
dst
.
shape
[
1
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
dst
.
dtype
==
src
.
dtype
);
megdnn_assert
(
src
.
shape
[
1
]
==
dst
.
shape
[
1
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
param
().
imode
==
param
::
WarpPerspective
::
InterpolationMode
::
LINEAR
);
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
TRANSPARENT
);
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
ISOLATED
);
}
else
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWC
)
{
megdnn_assert
(
src
.
shape
[
3
]
==
dst
.
shape
[
3
],
"%s"
,
errmsg
().
c_str
());
}
else
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW4
)
{
megdnn_assert
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
,
"src expected QuantizedS8, but got %s"
,
src
.
dtype
.
name
());
megdnn_assert
(
mat
.
dtype
==
dtype
::
Float32
(),
"matrix dtype expected float, got %s"
,
mat
.
dtype
.
name
());
megdnn_assert
(
src
.
shape
[
4
]
==
4
&&
dst
.
shape
[
4
]
==
4
);
megdnn_assert
(
src
.
shape
[
1
]
==
dst
.
shape
[
1
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
param
().
imode
==
param
::
WarpPerspective
::
InterpolationMode
::
LINEAR
);
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
TRANSPARENT
);
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
ISOLATED
);
}
else
{
megdnn_assert
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWCD4
);
megdnn_assert
(
src
.
dtype
==
dtype
::
Float32
()
||
MEGDNN_FLOAT16_SELECT
(
(
src
.
dtype
==
dtype
::
Float16
()
||
src
.
dtype
==
dtype
::
BFloat16
()),
false
)
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
,
"WarpPerspective NHWCD4 input dtype should be "
"Float32"
MEGDNN_FLOAT16_SELECT
(
"/Float16/BFloat16"
,
""
)
",QunatizedS8, Quantized8Asymm."
);
megdnn_assert
(
(
src
.
dtype
==
mat
.
dtype
||
mat
.
dtype
==
dtype
::
Float32
()),
"The input to WarpPerspective is in NHWCD4 format, in this "
"case, if the input dtype is floating point, the "
"transformation matrix should have same dtype as the "
"input, %s given."
,
mat
.
dtype
.
name
());
//! number of channels is same
megdnn_assert
(
src
.
shape
[
2
]
==
dst
.
shape
[
2
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
param
().
imode
==
param
::
WarpPerspective
::
InterpolationMode
::
LINEAR
);
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
TRANSPARENT
);
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
ISOLATED
);
}
}
else
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWC_NCHW4_IC_SMALL
||
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW_NCHW4_IC_SMALL
)
{
megdnn_assert
((
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Uint8
),
"src expected Quantized8Asymm or Uint8, but got %s"
,
src
.
dtype
.
name
());
megdnn_assert
(
mat
.
dtype
==
dtype
::
Float32
(),
"matrix dtype expected float, got %s"
,
mat
.
dtype
.
name
());
megdnn_assert
(
dst
.
shape
[
4
]
==
4
);
megdnn_assert
(
param
().
imode
==
param
::
WarpPerspective
::
InterpolationMode
::
LINEAR
);
...
...
@@ -90,16 +164,14 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
param
::
WarpPerspective
::
BorderMode
::
TRANSPARENT
);
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
ISOLATED
);
}
else
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWC
)
{
megdnn_assert
(
src
.
shape
[
3
]
==
dst
.
shape
[
3
],
"%s"
,
errmsg
().
c_str
());
}
else
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW4
)
{
megdnn_assert
(
dst
.
dtype
==
src
.
dtype
);
megdnn_assert
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
,
"src expected QuantizedS8, but got %s"
,
src
.
dtype
.
name
());
}
else
if
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWC_NCHW
)
{
megdnn_assert
((
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Uint8
),
"src expected Quantized8Asymm or Uint8, but got %s"
,
src
.
dtype
.
name
());
megdnn_assert
(
mat
.
dtype
==
dtype
::
Float32
(),
"matrix dtype expected float, got %s"
,
mat
.
dtype
.
name
());
megdnn_assert
(
src
.
shape
[
4
]
==
4
&&
dst
.
shape
[
4
]
==
4
);
megdnn_assert
(
src
.
shape
[
1
]
==
dst
.
shape
[
1
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
src
.
shape
[
3
]
==
dst
.
shape
[
1
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
param
().
imode
==
param
::
WarpPerspective
::
InterpolationMode
::
LINEAR
);
...
...
@@ -108,40 +180,14 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
ISOLATED
);
}
else
{
megdnn_assert
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NHWCD4
);
megdnn_assert
(
src
.
dtype
==
dtype
::
Float32
()
||
MEGDNN_FLOAT16_SELECT
((
src
.
dtype
==
dtype
::
Float16
()
||
src
.
dtype
==
dtype
::
BFloat16
()),
false
)
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
,
"WarpPerspective NHWCD4 input dtype should be "
"Float32"
MEGDNN_FLOAT16_SELECT
(
"/Float16/BFloat16"
,
""
)
",QunatizedS8, Quantized8Asymm."
);
megdnn_assert
(
(
src
.
dtype
==
mat
.
dtype
||
mat
.
dtype
==
dtype
::
Float32
()),
"The input to WarpPerspective is in NHWCD4 format, in this "
"case, if the input dtype is floating point, the "
"transformation matrix should have same dtype as the "
"input, %s given."
,
mat
.
dtype
.
name
());
megdnn_assert
(
dst
.
dtype
==
src
.
dtype
);
//! number of channels is same
megdnn_assert
(
src
.
shape
[
2
]
==
dst
.
shape
[
2
],
"%s"
,
errmsg
().
c_str
());
megdnn_assert
(
param
().
imode
==
param
::
WarpPerspective
::
InterpolationMode
::
LINEAR
);
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
TRANSPARENT
);
megdnn_assert
(
param
().
bmode
!=
param
::
WarpPerspective
::
BorderMode
::
ISOLATED
);
megdnn_assert
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW
);
megdnn_assert
((
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
||
src
.
dtype
.
enumv
()
==
DTypeEnum
::
Uint8
)
&&
dst
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
);
}
megdnn_assert
(
src
.
format
==
dst
.
format
);
}
std
::
string
WarpPerspectiveBase
::
param_msg
()
const
{
std
::
string
WarpPerspectiveBase
::
param_msg
()
const
{
std
::
string
res
;
res
.
append
(
megdnn_mangle
(
"imode="
));
switch
(
param
().
imode
)
{
...
...
@@ -191,31 +237,25 @@ std::string WarpPerspectiveBase::param_msg() const
return
res
;
}
int
WarpPerspectiveBase
::
get_real_coord
(
int
p
,
int
len
)
{
int
WarpPerspectiveBase
::
get_real_coord
(
int
p
,
int
len
)
{
auto
bmode
=
param
().
bmode
;
if
(
(
unsigned
)
p
<
(
unsigned
)
len
)
if
((
unsigned
)
p
<
(
unsigned
)
len
)
;
else
if
(
bmode
==
BorderMode
::
REPLICATE
)
else
if
(
bmode
==
BorderMode
::
REPLICATE
)
p
=
p
<
0
?
0
:
len
-
1
;
else
if
(
bmode
==
BorderMode
::
REFLECT
||
bmode
==
BorderMode
::
REFLECT_101
)
{
else
if
(
bmode
==
BorderMode
::
REFLECT
||
bmode
==
BorderMode
::
REFLECT_101
)
{
int
delta
=
(
bmode
==
BorderMode
::
REFLECT_101
);
if
(
len
==
1
)
if
(
len
==
1
)
return
0
;
do
{
if
(
p
<
0
)
do
{
if
(
p
<
0
)
p
=
-
p
-
1
+
delta
;
else
p
=
len
-
1
-
(
p
-
len
)
-
delta
;
}
while
(
(
unsigned
)
p
>=
(
unsigned
)
len
);
}
else
if
(
bmode
==
BorderMode
::
WRAP
)
{
if
(
p
<
0
)
p
-=
((
p
-
len
+
1
)
/
len
)
*
len
;
}
while
((
unsigned
)
p
>=
(
unsigned
)
len
);
}
else
if
(
bmode
==
BorderMode
::
WRAP
)
{
if
(
p
<
0
)
p
-=
((
p
-
len
+
1
)
/
len
)
*
len
;
/*
if( p >= len )
p %= len;
...
...
@@ -223,18 +263,16 @@ int WarpPerspectiveBase::get_real_coord(int p, int len)
while
(
p
>=
len
)
{
p
-=
len
;
}
}
else
if
(
bmode
==
BorderMode
::
CONSTANT
)
}
else
if
(
bmode
==
BorderMode
::
CONSTANT
)
p
=
-
1
;
return
p
;
}
void
WarpPerspectiveForward
::
check_exec
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
)
{
void
WarpPerspectiveForward
::
check_exec
(
const
TensorLayout
&
src
,
const
TensorLayout
&
mat
,
const
TensorLayout
&
mat_idx
,
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
)
{
check_exec_allow_nhwc_mat_idx
(
src
,
mat
,
mat_idx
,
dst
,
workspace_in_bytes
);
}
...
...
@@ -248,7 +286,10 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
megdnn_assert
(
workspace_in_bytes
>=
required_workspace_in_bytes
);
if
(
param
().
format
!=
Param
::
Format
::
NHWC
&&
param
().
format
!=
Param
::
Format
::
NCHW
&&
param
().
format
!=
Param
::
Format
::
NCHW4
)
{
param
().
format
!=
Param
::
Format
::
NCHW4
&&
param
().
format
!=
Param
::
Format
::
NHWC_NCHW
&&
param
().
format
!=
Param
::
Format
::
NHWC_NCHW4_IC_SMALL
&&
param
().
format
!=
Param
::
Format
::
NCHW_NCHW4_IC_SMALL
)
{
megdnn_assert
(
!
mat_idx
.
ndim
,
"mat_idx not supported for current format"
);
}
...
...
@@ -263,7 +304,8 @@ void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
megdnn_assert
(
grad
.
dtype
==
dtype
::
Float32
()
MEGDNN_INC_FLOAT16
(
||
grad
.
dtype
==
dtype
::
BFloat16
()),
"Backward WarpPerspective only supports Float32/BFloat16."
);
auto
required_workspace_in_bytes
=
get_workspace_in_bytes
(
mat
,
mat_idx
,
diff
,
grad
);
auto
required_workspace_in_bytes
=
get_workspace_in_bytes
(
mat
,
mat_idx
,
diff
,
grad
);
megdnn_assert
(
workspace_in_bytes
>=
required_workspace_in_bytes
);
}
...
...
@@ -283,6 +325,6 @@ void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
megdnn_assert
(
workspace_in_bytes
>=
required_workspace_in_bytes
);
}
}
// namespace megdnn
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/cuda/warp_perspective/common.h
浏览文件 @
61f917fb
...
...
@@ -12,6 +12,7 @@
#pragma once
#include <cuda_runtime_api.h>
#include "src/common/cv/enums.h"
#include "src/cuda/utils.cuh"
#include "megcore_cdefs.h"
namespace
megdnn
{
...
...
@@ -34,6 +35,22 @@ void forward_proxy_nchw4(const ctype* src, const float* mat, const int* mat_idx,
megcore
::
AsyncErrorInfo
*
error_info
,
void
*
error_tracker
,
cudaStream_t
stream
);
template
<
typename
src_dtype
,
typename
src_ctype
,
typename
dst_ctype
>
void
forward_proxy_quint8_dimshuffle_typecvt_nchw4
(
bool
is_nhwc
,
const
src_ctype
*
src
,
const
float
*
mat
,
const
int
*
mat_idx
,
dst_ctype
*
dst
,
int
N_SRC
,
int
N_MAT
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
src_ctype
bval
,
DTypeParamImpl
<
src_dtype
>
param
,
BorderMode
bmode
,
megcore
::
AsyncErrorInfo
*
error_info
,
void
*
error_tracker
,
cudaStream_t
stream
);
template
<
typename
src_dtype
,
typename
src_ctype
,
typename
dst_ctype
>
void
forward_proxy_quint8_dimshuffle_typecvt_nchw
(
bool
is_nhwc
,
const
src_ctype
*
src
,
const
float
*
mat
,
const
int
*
mat_idx
,
dst_ctype
*
dst
,
int
N_SRC
,
int
N_MAT
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
src_ctype
bval
,
DTypeParamImpl
<
src_dtype
>
param
,
BorderMode
bmode
,
megcore
::
AsyncErrorInfo
*
error_info
,
void
*
error_tracker
,
cudaStream_t
stream
);
void
backward_data_proxy
(
const
float
*
mat
,
const
int
*
midx
,
const
float
*
diff
,
float
*
grad
,
float
*
workspace
,
int
N
,
int
N_SRC
,
int
C
,
int
IH
,
int
IW
,
int
OH
,
int
OW
,
float
bval
,
...
...
dnn/src/cuda/warp_perspective/forward.cpp
浏览文件 @
61f917fb
...
...
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/warp_perspective/opr_impl.h"
#include "src/cuda/warp_perspective/warp_perspective_cv.cuh"
...
...
@@ -166,6 +167,30 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
IW
=
src
.
layout
.
shape
[
3
];
OH
=
dst
.
layout
.
shape
[
2
];
OW
=
dst
.
layout
.
shape
[
3
];
}
else
if
(
param
().
format
==
Param
::
Format
::
NHWC_NCHW
)
{
C
=
src
.
layout
.
shape
[
3
];
IH
=
src
.
layout
.
shape
[
1
];
IW
=
src
.
layout
.
shape
[
2
];
OH
=
dst
.
layout
.
shape
[
2
];
OW
=
dst
.
layout
.
shape
[
3
];
}
else
if
(
param
().
format
==
Param
::
Format
::
NHWC_NCHW4_IC_SMALL
)
{
C
=
src
.
layout
.
shape
[
3
];
IH
=
src
.
layout
.
shape
[
1
];
IW
=
src
.
layout
.
shape
[
2
];
OH
=
dst
.
layout
.
shape
[
2
];
OW
=
dst
.
layout
.
shape
[
3
];
megdnn_assert
(
(
C
==
1
)
||
(
C
==
3
),
"NHWC_NCHW4_IC_SMALL only support C == 1 or C == 3"
);
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW_NCHW4_IC_SMALL
)
{
C
=
src
.
layout
.
shape
[
1
];
IH
=
src
.
layout
.
shape
[
2
];
IW
=
src
.
layout
.
shape
[
3
];
OH
=
dst
.
layout
.
shape
[
2
];
OW
=
dst
.
layout
.
shape
[
3
];
megdnn_assert
(
(
C
==
1
)
||
(
C
==
3
),
"NCHW_NCHW4_IC_SMALL only support C == 1 or C == 3"
);
}
else
{
megdnn_assert
(
param
().
format
==
param
::
WarpPerspective
::
Format
::
NCHW
,
...
...
@@ -180,55 +205,123 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
"unsupported interpolation mode for NCHW format"
);
auto
bval
=
param
().
border_val
;
auto
bmode
=
warp_perspective
::
get_bmode
(
param
().
bmode
);
if
(
src
.
layout
.
dtype
==
dtype
::
Float32
{})
{
warp_perspective
::
forward_proxy
(
is_nhwc
,
src
.
ptr
<
dt_float32
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_float32
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
if
(
MEGDNN_FLOAT16_SELECT
(
src
.
layout
.
dtype
==
dtype
::
Float16
(),
false
))
{
if
(
src
.
layout
.
dtype
==
dst
.
layout
.
dtype
)
{
if
(
src
.
layout
.
dtype
==
dtype
::
Float32
{})
{
warp_perspective
::
forward_proxy
(
is_nhwc
,
src
.
ptr
<
dt_float32
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_float32
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
if
(
MEGDNN_FLOAT16_SELECT
(
src
.
layout
.
dtype
==
dtype
::
Float16
(),
false
))
{
#ifndef MEGDNN_DISABLE_FLOAT16
warp_perspective
::
forward_proxy
(
is_nhwc
,
src
.
ptr
<
dt_float16
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_float16
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
static_cast
<
dt_float16
>
(
bval
),
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
warp_perspective
::
forward_proxy
(
is_nhwc
,
src
.
ptr
<
dt_float16
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_float16
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
static_cast
<
dt_float16
>
(
bval
),
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
#endif
}
else
if
(
src
.
layout
.
dtype
==
dtype
::
Uint8
())
{
warp_perspective
::
forward_proxy
<
dt_uint8
>
(
is_nhwc
,
src
.
ptr
<
dt_uint8
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_uint8
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
if
(
src
.
layout
.
dtype
==
dtype
::
Int8
())
{
megdnn_assert
(
!
is_nhwc
,
"WarpPerspective on CUDA does not support NHWC + Int8"
);
warp_perspective
::
forward_proxy
<
dt_int8
>
(
false
,
src
.
ptr
<
dt_int8
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_int8
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
/* implicit float -> int8 conversion, should be
safe */
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
)
{
megdnn_assert
(
param
().
format
==
Param
::
Format
::
NCHW4
,
"WarpPerspective on CUDA supports NCHW4 + "
"QuantizedS8 only"
);
warp_perspective
::
forward_proxy_nchw4
<
dt_int8
>
(
src
.
compatible_ptr
<
dt_int8
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
compatible_ptr
<
dt_int8
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
if
(
src
.
layout
.
dtype
==
dtype
::
Uint8
())
{
warp_perspective
::
forward_proxy
<
dt_uint8
>
(
is_nhwc
,
src
.
ptr
<
dt_uint8
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_uint8
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
if
(
src
.
layout
.
dtype
==
dtype
::
Int8
())
{
megdnn_assert
(
!
is_nhwc
,
"WarpPerspective on CUDA does not support "
"NHWC + Int8"
);
warp_perspective
::
forward_proxy
<
dt_int8
>
(
false
,
src
.
ptr
<
dt_int8
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
ptr
<
dt_int8
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
/* implicit float -> int8 conversion,
should be safe */
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
)
{
megdnn_assert
(
param
().
format
==
Param
::
Format
::
NCHW4
,
"WarpPerspective on CUDA supports NCHW4 + "
"QuantizedS8 only"
);
warp_perspective
::
forward_proxy_nchw4
<
dt_int8
>
(
src
.
compatible_ptr
<
dt_int8
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
compatible_ptr
<
dt_int8
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
}
else
if
((
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Uint8
))
{
uint8_t
zero_point
=
0
;
float
scale
=
1.
f
;
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
)
{
zero_point
=
src
.
layout
.
dtype
.
param
<
dtype
::
Quantized8Asymm
>
()
.
zero_point
;
scale
=
src
.
layout
.
dtype
.
param
<
dtype
::
Quantized8Asymm
>
()
.
scale
;
}
else
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Uint8
&&
dst
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
)
{
zero_point
=
128
;
scale
=
1.
f
;
}
DTypeParamImpl
<
dt_quint8
>
src_dtype_param
(
scale
,
zero_point
);
if
((
dst
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
dst
.
layout
.
dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
==
scale
)
&&
((
param
().
format
==
Param
::
Format
::
NCHW_NCHW4_IC_SMALL
)
||
(
param
().
format
==
Param
::
Format
::
NHWC_NCHW4_IC_SMALL
)))
{
bool
is_nhwc_ic_small
=
(
param
().
format
==
Param
::
Format
::
NHWC_NCHW4_IC_SMALL
);
warp_perspective
::
forward_proxy_quint8_dimshuffle_typecvt_nchw4
<
dt_quint8
,
dt_uint8
,
dt_int8
>
(
is_nhwc_ic_small
,
src
.
compatible_ptr
<
dt_uint8
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
compatible_ptr
<
dt_int8
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
src_dtype_param
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
else
{
megdnn_assert
(
((
dst
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
)
&&
((
param
().
format
==
Param
::
Format
::
NCHW
)
||
(
param
().
format
==
Param
::
Format
::
NHWC_NCHW
))),
"invalid format for Quantized8Asymm input"
);
bool
is_nhwc
=
(
param
().
format
==
Param
::
Format
::
NHWC_NCHW
);
warp_perspective
::
forward_proxy_quint8_dimshuffle_typecvt_nchw
<
dt_quint8
,
dt_uint8
,
dt_float32
>
(
is_nhwc
,
src
.
compatible_ptr
<
dt_uint8
>
(),
mat
.
ptr
<
dt_float32
>
(),
mat_idx
.
raw_ptr
?
mat_idx
.
ptr
<
int
>
()
:
nullptr
,
dst
.
compatible_ptr
<
dt_float32
>
(),
src
.
layout
[
0
],
mat
.
layout
[
0
],
C
,
IH
,
IW
,
OH
,
OW
,
bval
,
src_dtype_param
,
bmode
,
async_error_info
(
handle
()),
m_error_tracker
,
stream
);
}
}
else
{
megdnn_throw
(
ssprintf
(
"unsupported dtype: %s"
,
src
.
layout
.
dtype
.
name
()));
...
...
dnn/src/cuda/warp_perspective/forward.cu
浏览文件 @
61f917fb
此差异已折叠。
点击以展开。
dnn/src/naive/warp_perspective/opr_impl.cpp
浏览文件 @
61f917fb
...
...
@@ -249,6 +249,162 @@ void WarpPerspectiveForwardImpl::kern_naive_nhwcd4(
MIDOUT_END
();
}
/*!
 * \brief naive forward kernel fusing warp perspective with a layout change and
 *      an output type conversion
 *
 * One call produces a single output row: \p task_id is decomposed into the
 * batch index (task_id / OH) and the output row (task_id % OH). Every output
 * pixel is interpolated from four neighbouring source pixels with bilinear
 * weights, the zero point is subtracted (and the scale applied when dst is
 * Float32), and the result is stored through
 * rounding::RoundingConverter<dst_ctype>.
 *
 * The source is indexed as NCHW or NHWC depending on the format; the
 * destination is indexed as NCHW, or as NCHW4 for the *_NCHW4_IC_SMALL
 * formats, in which case the channels in [C, 4) are filled with zero.
 *
 * \param kern_param kernel parameters, unpacked by
 *      UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM
 * \param task_id flattened (batch, output row) index
 */
template <typename ctype, typename dst_ctype, typename mtype>
void WarpPerspectiveForwardImpl::kern_naive_dimshuffle_typecvt(
        const KernParam<ctype, mtype>& kern_param, size_t task_id) {
    MEGDNN_MARK_USED_VAR(kern_param);
    MIDOUT_BEGIN(megdnn_naive_warpperspective, ctype, mtype, midout_iv(2)) {
        UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param);
        MEGDNN_MARK_USED_VAR(N_MAT);

        const auto fmt = kern_param.format;
        const bool is_dst_nchw4 = (fmt == Format::NCHW_NCHW4_IC_SMALL) ||
                                  (fmt == Format::NHWC_NCHW4_IC_SMALL);

        //! strides of C, H, W on src and dst
        size_t sstrd[3] = {}, dstrd[3] = {};
        switch (fmt) {
            case Format::NCHW:
            case Format::NCHW_NCHW4_IC_SMALL:
                sstrd[0] = IH * IW;
                sstrd[1] = IW;
                sstrd[2] = 1;
                break;
            case Format::NHWC_NCHW:
            case Format::NHWC_NCHW4_IC_SMALL:
                sstrd[0] = 1;
                sstrd[1] = IW * C;
                sstrd[2] = C;
                break;
            default:
                megdnn_throw("bad format");
        }
        //! every supported format uses the same dst strides
        dstrd[0] = OH * OW;
        dstrd[1] = OW;
        dstrd[2] = 1;

        //! quantization parameters used to convert src values
        const auto src_enum = kern_param.src_dtype.enumv();
        const auto dst_enum = kern_param.dst_dtype.enumv();
        const bool is_dst_float = dst_enum == DTypeEnum::Float32;
        uint8_t zero_point = 0;
        float scale = 1.f;
        if (src_enum == DTypeTrait<dtype::Quantized8Asymm>::enumv) {
            auto qparam = kern_param.src_dtype
                                  .template param<dtype::Quantized8Asymm>();
            zero_point = qparam.zero_point;
            scale = qparam.scale;
        } else if (src_enum == DTypeEnum::Uint8 &&
                   dst_enum == DTypeEnum::QuantizedS8) {
            //! plain uint8 -> qint8: only shift by 128, scale stays 1
            zero_point = 128;
        }

        //! locate the batch and the output row handled by this task
        const size_t n = task_id / OH;
        const size_t oh = task_id % OH;
        mptr += n * 3 * 3;

        dst_ctype* dst_ptr = reinterpret_cast<dst_ctype*>(dptr);
        if (is_dst_nchw4) {
            dst_ptr += n * OH * OW * 4;
        } else {
            dst_ptr += n * C * OH * OW;
        }

        if (midx_ptr) {
            size_t idx = midx_ptr[n];
            megdnn_assert(
                    idx < N_SRC,
                    "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", n,
                    idx, N_SRC);
            sptr = sptr + idx * (C * IH * IW);
        } else if (n) {
            sptr += n * C * IH * IW;
        }

        //! element accessors
        auto src_at = [&](size_t c, int h, int w) -> float {
            return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w];
        };
        auto src_at_or_border = [&](size_t c, int h, int w) -> float {
            if (h == -1 || w == -1)
                return border_val;
            return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w];
        };
        auto dst_at = [&](size_t c, int h, int w) -> dst_ctype& {
            if (is_dst_nchw4) {
                return dst_ptr[((dstrd[0] * (c >> 2) + dstrd[1] * h +
                                 dstrd[2] * w)
                                << 2) +
                               (c & 0x3)];
            }
            return dst_ptr[dstrd[0] * c + dstrd[1] * h + dstrd[2] * w];
        };

        rounding::RoundingConverter<dst_ctype> output_converter;

        rep(ow, OW) {
            //! project (ow, oh) through the 3x3 matrix
            float num_w = mptr[0] * ow + mptr[1] * oh + mptr[2];
            float num_h = mptr[3] * ow + mptr[4] * oh + mptr[5];
            float denom = mptr[6] * ow + mptr[7] * oh + mptr[8];
            float alphaw = num_w / denom;
            float alphah = num_h / denom;

            const int iw0 = get_real_coord(std::floor(alphaw) + 0, IW);
            const int iw1 = get_real_coord(std::floor(alphaw) + 1, IW);
            const int ih0 = get_real_coord(std::floor(alphah) + 0, IH);
            const int ih1 = get_real_coord(std::floor(alphah) + 1, IH);

            //! keep only the fractional parts as interpolation weights
            alphaw -= floor(alphaw);
            alphah -= floor(alphah);
            const float w_lo = 1.0f - alphaw;
            const float h_lo = 1.0f - alphah;

            if (bmode == BorderMode::CONSTANT) {
                rep(c, C) {
                    float val = src_at_or_border(c, ih0, iw0) * w_lo * h_lo +
                                src_at_or_border(c, ih0, iw1) * alphaw * h_lo +
                                src_at_or_border(c, ih1, iw0) * w_lo * alphah +
                                src_at_or_border(c, ih1, iw1) * alphaw * alphah;
                    if (!std::isfinite(val))
                        val = border_val;
                    val -= zero_point;
                    if (is_dst_float)
                        val *= scale;
                    dst_at(c, oh, ow) = output_converter(val);
                }
            } else {
                rep(c, C) {
                    float val = src_at(c, ih0, iw0) * w_lo * h_lo +
                                src_at(c, ih0, iw1) * alphaw * h_lo +
                                src_at(c, ih1, iw0) * w_lo * alphah +
                                src_at(c, ih1, iw1) * alphaw * alphah;
                    val -= zero_point;
                    if (is_dst_float)
                        val *= scale;
                    dst_at(c, oh, ow) = output_converter(val);
                }
            }

            //! zero the padded channels of an NCHW4 output
            if (is_dst_nchw4) {
                for (auto c = C; c < 4; ++c) {
                    dst_at(c, oh, ow) = 0;
                }
            }
        }
    }
    MIDOUT_END();
}
// Explicit instantiations of kern_naive_dimshuffle_typecvt, one per
// (source element type, destination element type, matrix element type)
// combination listed below.
#define INST(src_ctype, dst_ctype, mat_ctype)                                \
    template void WarpPerspectiveForwardImpl::kern_naive_dimshuffle_typecvt< \
            src_ctype, dst_ctype, mat_ctype>(                                \
            const KernParam<src_ctype, mat_ctype>&, size_t);
INST(uint8_t, int8_t, float);
INST(uint8_t, float, float);
#undef INST
void
WarpPerspectiveForwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
mat
,
_megdnn_tensor_in
mat_idx
,
...
...
@@ -320,6 +476,65 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
src
.
layout
.
dtype
.
name
())
.
c_str
());
}
bool
is_fusion_dtype
=
src
.
layout
.
dtype
.
enumv
()
!=
dst
.
layout
.
dtype
.
enumv
();
bool
is_u8_or_qu8_in
=
src
.
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
Uint8
>::
enumv
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
Quantized8Asymm
>::
enumv
;
if
(
is_fusion_dtype
&&
is_u8_or_qu8_in
&&
((
param
().
format
==
Format
::
NCHW_NCHW4_IC_SMALL
)
||
(
param
().
format
==
Format
::
NHWC_NCHW4_IC_SMALL
)
||
(
param
().
format
==
Format
::
NHWC_NCHW
)
||
(
param
().
format
==
Format
::
NCHW
)))
{
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
Quantized8Asymm
>::
enumv
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
Uint8
>::
enumv
)
{
float
scale
=
1.
f
;
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
Quantized8Asymm
>::
enumv
)
{
scale
=
src
.
layout
.
dtype
.
param
<
dtype
::
Quantized8Asymm
>
().
scale
;
}
auto
kparam
=
KernParam
<
uint8_t
,
float
>::
from_tensors
(
param
().
format
,
param
().
bmode
,
param
().
border_val
,
src
,
mat
,
mat_idx
,
dst
,
workspace
);
if
(
dst
.
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
Float32
>::
enumv
)
{
auto
run
=
[
kparam
,
this
](
size_t
index
,
size_t
)
{
kern_naive_dimshuffle_typecvt
<
uint8_t
,
float
,
float
>
(
kparam
,
index
);
};
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR
(
run
,
kparam
.
oh
*
batch
);
return
;
}
else
if
((
dst
.
layout
.
dtype
.
enumv
()
==
DTypeTrait
<
dtype
::
QuantizedS8
>::
enumv
)
&&
(
dst
.
layout
.
dtype
.
param
<
dtype
::
QuantizedS8
>
().
scale
==
scale
))
{
auto
run
=
[
kparam
,
this
](
size_t
index
,
size_t
)
{
kern_naive_dimshuffle_typecvt
<
uint8_t
,
int8_t
,
float
>
(
kparam
,
index
);
};
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR
(
run
,
kparam
.
oh
*
batch
);
return
;
}
else
{
megdnn_throw
(
ssprintf
(
"Unsupported DType in "
"WarpPerspective Dimshuffle Typecvt: %s"
,
src
.
layout
.
dtype
.
name
())
.
c_str
());
}
}
megdnn_throw
(
ssprintf
(
"Unsupported input DType in "
"WarpPerspective: %s"
,
src
.
layout
.
dtype
.
name
())
.
c_str
());
}
if
(
warp
::
is_cv_available
(
src
.
layout
,
mat
.
layout
,
dst
.
layout
,
param
().
imode
,
param
().
format
))
{
MIDOUT_BEGIN
(
megdnn_naive_warpperspective
,
void
)
{
...
...
@@ -331,12 +546,12 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
megdnn_assert
(
warp
::
is_dnn_available
(
src
.
layout
,
mat
.
layout
,
dst
.
layout
,
param
().
imode
,
param
().
format
));
/*!
 * We currently use floating point for all WarpPerspective
 * computation, so even if the input ctype is one of the integer
 * types, mtype should always be float32.
 *
 * \warning This differs from \c WarpAffine, where mtype is float16
 * if the input type is float16.
 */
DISPATCH_ST
(
dtype
::
Float32
,
float
,
float
,
KERN
);
...
...
dnn/src/naive/warp_perspective/opr_impl.h
浏览文件 @
61f917fb
...
...
@@ -26,6 +26,7 @@ protected:
float
border_val
;
size_t
n_src
,
n_mat
,
c
,
ih
,
iw
,
oh
,
ow
;
ctype
*
sptr
,
*
dptr
;
DType
src_dtype
,
dst_dtype
;
mtype
*
mptr
;
int
*
midx_ptr
;
//!< can be null
Workspace
workspace
;
...
...
@@ -41,6 +42,8 @@ protected:
ret
.
bmode
=
bmode
;
ret
.
border_val
=
border_val
;
ret
.
n_src
=
src
.
layout
.
shape
[
0
];
ret
.
src_dtype
=
src
.
layout
.
dtype
;
ret
.
dst_dtype
=
dst
.
layout
.
dtype
;
if
(
mat_idx
.
raw_ptr
)
{
megdnn_assert
(
mat_idx
.
layout
.
ndim
==
1
);
ret
.
n_mat
=
mat_idx
.
layout
.
shape
[
0
];
...
...
@@ -50,7 +53,8 @@ protected:
ret
.
n_mat
=
ret
.
n_src
;
ret
.
midx_ptr
=
nullptr
;
}
if
(
format
==
Format
::
NCHW
)
{
if
(
format
==
Format
::
NCHW
||
format
==
Format
::
NCHW_NCHW4_IC_SMALL
)
{
ret
.
c
=
src
.
layout
.
shape
[
1
];
ret
.
ih
=
src
.
layout
.
shape
[
2
];
ret
.
iw
=
src
.
layout
.
shape
[
3
];
...
...
@@ -62,6 +66,13 @@ protected:
ret
.
iw
=
src
.
layout
.
shape
[
2
];
ret
.
oh
=
dst
.
layout
.
shape
[
1
];
ret
.
ow
=
dst
.
layout
.
shape
[
2
];
}
else
if
(
format
==
Format
::
NHWC_NCHW
||
format
==
Format
::
NHWC_NCHW4_IC_SMALL
)
{
ret
.
c
=
src
.
layout
.
shape
[
3
];
ret
.
ih
=
src
.
layout
.
shape
[
1
];
ret
.
iw
=
src
.
layout
.
shape
[
2
];
ret
.
oh
=
dst
.
layout
.
shape
[
2
];
ret
.
ow
=
dst
.
layout
.
shape
[
3
];
}
else
if
(
format
==
Format
::
NCHW4
)
{
ret
.
c
=
src
.
layout
.
shape
[
1
]
*
4
;
ret
.
ih
=
src
.
layout
.
shape
[
2
];
...
...
@@ -76,15 +87,16 @@ protected:
ret
.
oh
=
dst
.
layout
.
shape
[
1
];
ret
.
ow
=
dst
.
layout
.
shape
[
3
];
}
if
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
||
MEGDNN_FLOAT16_SELECT
(
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Float16
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
BFloat16
),
false
)
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Int8
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Uint8
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
)
{
if
((
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Float32
||
MEGDNN_FLOAT16_SELECT
(
(
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Float16
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
BFloat16
),
false
)
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Int8
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Uint8
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
)
&&
(
src
.
layout
.
dtype
==
dst
.
layout
.
dtype
))
{
ret
.
sptr
=
src
.
compatible_ptr
<
ctype
>
();
ret
.
mptr
=
mat
.
ptr
<
mtype
>
();
ret
.
dptr
=
dst
.
compatible_ptr
<
ctype
>
();
...
...
@@ -92,6 +104,13 @@ protected:
ret
.
sptr
=
src
.
compatible_ptr
<
ctype
>
();
ret
.
mptr
=
mat
.
ptr
<
mtype
>
();
ret
.
dptr
=
dst
.
compatible_ptr
<
ctype
>
();
}
else
if
((
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Uint8
||
src
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
)
&&
src
.
layout
.
dtype
.
enumv
()
!=
dst
.
layout
.
dtype
.
enumv
())
{
ret
.
sptr
=
src
.
compatible_ptr
<
ctype
>
();
ret
.
mptr
=
mat
.
ptr
<
mtype
>
();
ret
.
dptr
=
reinterpret_cast
<
ctype
*>
(
dst
.
raw_ptr
);
}
else
{
ret
.
sptr
=
nullptr
;
ret
.
mptr
=
nullptr
;
...
...
@@ -122,6 +141,9 @@ private:
template
<
typename
ctype
,
typename
mtype
>
void
kern_naive_nhwcd4
(
const
KernParam
<
ctype
,
mtype
>&
kern_param
,
size_t
task_id
);
//! Naive kernel for the fused warp perspective + dimshuffle + typecvt path.
//!
//! \tparam ctype element type of the source tensor
//! \tparam dst_ctype element type of the destination tensor
//! \tparam mtype element type of the transformation matrix
//! \param kern_param kernel parameters (see KernParam)
//! \param task_id index of the unit of work handled by this call
//!        (NOTE(review): presumably encodes the batch index and output row;
//!        confirm against the definition)
template <typename ctype, typename dst_ctype, typename mtype>
void kern_naive_dimshuffle_typecvt(const KernParam<ctype, mtype>& kern_param,
                                   size_t task_id);
};
class
WarpPerspectiveBackwardDataImpl
:
public
WarpPerspectiveBackwardData
{
...
...
dnn/test/cuda/warp_perspective.cpp
浏览文件 @
61f917fb
...
...
@@ -23,8 +23,7 @@ using namespace megdnn;
using
namespace
test
;
class
NanMatRNG
:
public
RNG
{
void
gen
(
const
TensorND
&
tensor_
)
override
{
void
gen
(
const
TensorND
&
tensor_
)
override
{
auto
&
gen
=
RandomState
::
generator
();
std
::
uniform_real_distribution
<
dt_float32
>
pdist3
(
1.9
f
,
2.1
f
);
std
::
uniform_real_distribution
<
dt_float32
>
pdist
(
0.9
f
,
1.1
f
);
...
...
@@ -335,6 +334,144 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW4) {
}
}
//! Exercises WarpPerspective with format NCHW_NCHW4_IC_SMALL on CUDA:
//! Quantized8Asymm source, QuantizedS8 destination, all four border modes,
//! followed by a run that additionally supplies a mat_idx tensor.
TEST_F(CUDA, WARP_PERSPECTIVE_NCHW_NCHW4_IC_SMALL) {
    using Param = WarpPerspective::Param;
    using BorderMode = WarpPerspective::BorderMode;

    Param param;
    Checker<WarpPerspectiveForward> fwd_checker(handle_cuda());
    WarpPerspectiveMatRNG mat_rng;

    param.format = Param::Format::NCHW_NCHW4_IC_SMALL;
    // These two fields are the same for every border mode below.
    param.border_val = 0.3f;
    param.imode = Param::InterpolationMode::LINEAR;

    fwd_checker.set_rng(1, &mat_rng);
    fwd_checker.set_dtype(0, dtype::Quantized8Asymm(0.1f, 128));
    fwd_checker.set_dtype(2, dtype::QuantizedS8(0.1f));

    const BorderMode border_modes[] = {BorderMode::WRAP, BorderMode::REFLECT,
                                       BorderMode::REPLICATE,
                                       BorderMode::CONSTANT};
    for (BorderMode mode : border_modes) {
        param.bmode = mode;
        fwd_checker.set_param(param);
        fwd_checker.set_epsilon(1 + 1e-3);
        fwd_checker.execs({{2, 3, 10, 11}, {2, 3, 3}, {2, 1, 11, 12, 4}});
        fwd_checker.execs({{1, 3, 25, 510}, {1, 3, 3}, {1, 1, 25, 25, 4}});
        fwd_checker.execs({{1, 3, 25, 25}, {1, 3, 3}, {1, 1, 51, 51, 4}});
        fwd_checker.execs({{1, 3, 51, 51}, {1, 3, 3}, {1, 1, 25, 25, 4}});
    }

    // Variant with an explicit mat_idx input.
    {
        Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> idx_checker(
                handle_cuda());
        constexpr int N_SRC = 5;
        UniformIntRNG mat_idx_rng{0, N_SRC - 1};

        idx_checker.set_dtype(0, dtype::Quantized8Asymm(0.1f, 128));
        idx_checker.set_rng(1, &mat_rng);
        idx_checker.set_dtype(2, dtype::Int32());
        idx_checker.set_rng(2, &mat_idx_rng);
        idx_checker.set_dtype(3, dtype::QuantizedS8(0.1f));

        param.bmode = Param::BorderMode::REFLECT;
        param.imode = Param::InterpolationMode::LINEAR;
        idx_checker.set_param(param);
        idx_checker.set_epsilon(1 + 1e-3);
        idx_checker.execs(
                {{N_SRC, 3, 10, 11}, {2, 3, 3}, {2}, {2, 1, 11, 12, 4}});
        idx_checker.execs(
                {{N_SRC, 3, 17, 13}, {123, 3, 3}, {123}, {123, 1, 16, 15, 4}});
    }
}
//! Exercises WarpPerspective with format NHWC_NCHW4_IC_SMALL on CUDA:
//! Uint8 source, QuantizedS8 destination, all four border modes, followed
//! by a run that additionally supplies a mat_idx tensor.
TEST_F(CUDA, WARP_PERSPECTIVE_NHWC_NCHW4_IC_SMALL) {
    using Param = WarpPerspective::Param;
    using BorderMode = WarpPerspective::BorderMode;

    Param param;
    Checker<WarpPerspectiveForward> fwd_checker(handle_cuda());
    WarpPerspectiveMatRNG mat_rng;

    param.format = Param::Format::NHWC_NCHW4_IC_SMALL;
    // These two fields are the same for every border mode below.
    param.border_val = 0.3f;
    param.imode = Param::InterpolationMode::LINEAR;

    fwd_checker.set_rng(1, &mat_rng);
    fwd_checker.set_dtype(0, dtype::Uint8());
    fwd_checker.set_dtype(2, dtype::QuantizedS8(1.f));

    const BorderMode border_modes[] = {BorderMode::WRAP, BorderMode::REFLECT,
                                       BorderMode::REPLICATE,
                                       BorderMode::CONSTANT};
    for (BorderMode mode : border_modes) {
        param.bmode = mode;
        fwd_checker.set_param(param);
        fwd_checker.set_epsilon(1 + 1e-3);
        fwd_checker.execs({{2, 10, 11, 3}, {2, 3, 3}, {2, 1, 11, 12, 4}});
        fwd_checker.execs({{1, 25, 510, 3}, {1, 3, 3}, {1, 1, 25, 25, 4}});
        fwd_checker.execs({{1, 25, 25, 3}, {1, 3, 3}, {1, 1, 51, 51, 4}});
        fwd_checker.execs({{1, 51, 51, 3}, {1, 3, 3}, {1, 1, 25, 25, 4}});
    }

    // Variant with an explicit mat_idx input.
    {
        Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> idx_checker(
                handle_cuda());
        constexpr int N_SRC = 5;
        UniformIntRNG mat_idx_rng{0, N_SRC - 1};

        idx_checker.set_dtype(0, dtype::Uint8());
        idx_checker.set_rng(1, &mat_rng);
        idx_checker.set_dtype(2, dtype::Int32());
        idx_checker.set_rng(2, &mat_idx_rng);
        idx_checker.set_dtype(3, dtype::QuantizedS8(1.f));

        param.bmode = Param::BorderMode::REFLECT;
        param.imode = Param::InterpolationMode::LINEAR;
        idx_checker.set_param(param);
        idx_checker.set_epsilon(1 + 1e-3);
        idx_checker.execs(
                {{N_SRC, 10, 11, 3}, {2, 3, 3}, {2}, {2, 1, 11, 12, 4}});
        idx_checker.execs(
                {{N_SRC, 17, 13, 3}, {123, 3, 3}, {123}, {123, 1, 16, 15, 4}});
    }
}
//! Exercises WarpPerspective with format NHWC_NCHW on CUDA: Uint8 source,
//! Float32 destination, all four border modes, followed by a run that
//! additionally supplies a mat_idx tensor.
TEST_F(CUDA, WARP_PERSPECTIVE_NHWC_NCHW) {
    using Param = WarpPerspective::Param;
    using BorderMode = WarpPerspective::BorderMode;

    Param param;
    Checker<WarpPerspectiveForward> fwd_checker(handle_cuda());
    WarpPerspectiveMatRNG mat_rng;

    param.format = Param::Format::NHWC_NCHW;
    // These two fields are the same for every border mode below.
    param.border_val = 0.3f;
    param.imode = Param::InterpolationMode::LINEAR;

    fwd_checker.set_rng(1, &mat_rng);
    fwd_checker.set_dtype(0, dtype::Uint8());
    fwd_checker.set_dtype(2, dtype::Float32());

    const BorderMode border_modes[] = {BorderMode::WRAP, BorderMode::REFLECT,
                                       BorderMode::REPLICATE,
                                       BorderMode::CONSTANT};
    for (BorderMode mode : border_modes) {
        param.bmode = mode;
        fwd_checker.set_param(param);
        fwd_checker.set_epsilon(1 + 1e-3);
        fwd_checker.execs({{2, 10, 11, 3}, {2, 3, 3}, {2, 3, 11, 12}});
        fwd_checker.execs({{1, 25, 510, 3}, {1, 3, 3}, {1, 3, 25, 25}});
        fwd_checker.execs({{1, 25, 25, 3}, {1, 3, 3}, {1, 3, 51, 51}});
        fwd_checker.execs({{1, 51, 51, 3}, {1, 3, 3}, {1, 3, 25, 25}});
    }

    // Variant with an explicit mat_idx input.
    {
        Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> idx_checker(
                handle_cuda());
        constexpr int N_SRC = 5;
        UniformIntRNG mat_idx_rng{0, N_SRC - 1};

        idx_checker.set_dtype(0, dtype::Uint8());
        idx_checker.set_rng(1, &mat_rng);
        idx_checker.set_dtype(2, dtype::Int32());
        idx_checker.set_rng(2, &mat_idx_rng);
        idx_checker.set_dtype(3, dtype::Float32());

        param.bmode = Param::BorderMode::REFLECT;
        param.imode = Param::InterpolationMode::LINEAR;
        idx_checker.set_param(param);
        idx_checker.set_epsilon(1 + 1e-3);
        idx_checker.execs({{N_SRC, 10, 11, 3}, {2, 3, 3}, {2}, {2, 3, 11, 12}});
        idx_checker.execs(
                {{N_SRC, 17, 13, 3}, {123, 3, 3}, {123}, {123, 3, 16, 15}});
    }
}
//! Runs the shared warp_perspective::run_int8_test helper against the CUDA
//! handle.
TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_NCHW_INT8) {
    warp_perspective::run_int8_test(handle_cuda());
}
...
...
src/gopt/impl/framework.cpp
浏览文件 @
61f917fb
此差异已折叠。
点击以展开。
src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
浏览文件 @
61f917fb
...
...
@@ -19,6 +19,7 @@
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/opr/imgproc.h"
using
namespace
mgb
;
using
namespace
gopt
;
...
...
@@ -443,4 +444,244 @@ void FuseNCHW4Int8Preprocess::apply(OptState& state) const {
};
state
.
graph
().
iter
(
on_opr
);
rewriter
.
apply_inplace
();
}
/* ==================== FuseWarpPerspectiveDimshufflePass ================= */
//! Returns the human-readable name of this pass.
const char* FuseWarpPerspectiveDimshufflePass::name() const {
    return mgb_cstr_log("Fuse warp perspective dimshuffle pass");
}
/*!
 * Rewrites the graph so that a WarpPerspective with uint8/quint8 input,
 * followed by layout and/or dtype conversion oprs, becomes a single
 * WarpPerspective that directly produces the converted output.
 *
 * Four patterns are matched, each anchored at the last opr of the chain:
 *   - warp(NCHW) + typecvt(float32)               -> warp(NCHW), f32 output
 *   - warp(NHWC) + dimshuffle(0,3,1,2) + typecvt  -> warp(NHWC_NCHW)
 *   - warp(NHWC) + dimshuffle(0,3,1,2) + relayout(NCHW_NCHW4, qint8)
 *                                           -> warp(NHWC_NCHW4_IC_SMALL)
 *   - warp(NCHW) + relayout(NCHW_NCHW4, qint8)
 *                                           -> warp(NCHW_NCHW4_IC_SMALL)
 *
 * Oprs that match none of the patterns are handled through
 * rewriter.auto_replace_outputs().
 */
void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const {
    auto rewriter = opt.graph().make_rewriter();
    auto uniq_reader_check = UniqReaderCheck{opt.graph()};

    // Builds the replacement WarpPerspective from the (already rewritten)
    // inputs of \p warp, using \p new_param and forcing the output dtype to
    // \p dst_dtype. Handles both the 3-input form and the 4-input form that
    // carries mat_idx.
    auto make_new_warp = [&rewriter](opr::WarpPerspective* warp,
                                     opr::WarpPerspective::Param new_param,
                                     megdnn::DType dst_dtype,
                                     SymbolVar& new_warp) {
        OperatorNodeConfig new_config(dst_dtype);
        if (warp->input().size() == 3) {
            auto src = rewriter.get_var(warp->input(0)),
                 mat = rewriter.get_var(warp->input(1)),
                 out_shape = rewriter.get_var(warp->input(2));
            new_warp = opr::WarpPerspective::make(src, mat, out_shape,
                                                  new_param, new_config);
        } else {
            mgb_assert(warp->input().size() == 4);
            auto src = rewriter.get_var(warp->input(0)),
                 mat = rewriter.get_var(warp->input(1)),
                 mat_idx = rewriter.get_var(warp->input(2)),
                 out_shape = rewriter.get_var(warp->input(3));
            new_warp = opr::WarpPerspective::make(src, mat, mat_idx, out_shape,
                                                  new_param, new_config);
        }
    };

    // Matches a WarpPerspective in NCHW format with uint8/quint8 input; on
    // success stores it in \p top_opr.
    auto is_warp_nchw = [&uniq_reader_check](OperatorNodeBase* bottom_opr,
                                             OperatorNodeBase*& top_opr) {
        // check warp
        auto warp = try_cast_as_op<opr::WarpPerspective>(bottom_opr);
        if (warp == nullptr)
            return false;
        auto inp_dtype = warp->input(0)->dtype();
        bool is_u8_or_qu8 = inp_dtype.enumv() == DTypeEnum::Quantized8Asymm ||
                            inp_dtype.enumv() == DTypeEnum::Uint8;

        bool is_nchw = warp->param().format ==
                       megdnn::param::WarpPerspective::Format::NCHW;
        if (!(is_u8_or_qu8 && is_nchw))
            return false;
        if (!uniq_reader_check(warp->input(0)))
            return false;

        top_opr = warp;
        return true;
    };

    // Matches Dimshuffle(0, 3, 1, 2) whose input is produced by a
    // WarpPerspective in NHWC format with uint8/quint8 input; on success
    // stores the warp in \p top_opr.
    auto is_warp_nhwc2nchw = [&uniq_reader_check](
                                     OperatorNodeBase* bottom_opr,
                                     OperatorNodeBase*& top_opr) {
        // check shuffle
        auto shuffle = try_cast_as_op<opr::Dimshuffle>(bottom_opr);
        if (shuffle == nullptr)
            return false;
        auto&& shuffle_param = shuffle->param();
        if (shuffle_param.pattern_len != 4)
            return false;
        bool is_nhwc2nchw = shuffle_param.pattern[0] == 0 &&
                            shuffle_param.pattern[1] == 3 &&
                            shuffle_param.pattern[2] == 1 &&
                            shuffle_param.pattern[3] == 2;
        if (!is_nhwc2nchw)
            return false;
        if (!uniq_reader_check(shuffle->input(0)))
            return false;

        // check warp
        auto warp = try_cast_as_op<opr::WarpPerspective>(
                shuffle->input(0)->owner_opr());
        if (warp == nullptr)
            return false;
        auto inp_dtype = warp->input(0)->dtype();
        bool is_u8_or_qu8 = inp_dtype.enumv() == DTypeEnum::Quantized8Asymm ||
                            inp_dtype.enumv() == DTypeEnum::Uint8;
        bool is_nhwc = warp->param().format ==
                       megdnn::param::WarpPerspective::Format::NHWC;
        if (!(is_u8_or_qu8 && is_nhwc))
            return false;

        top_opr = warp;
        return true;
    };

    // warp(NCHW) + typecvt(float32) -> warp(NCHW) with float32 output
    auto try_warp_nchw_typecvt = [&rewriter, &uniq_reader_check, &is_warp_nchw,
                                  &make_new_warp](OperatorNodeBase* opr) {
        // check typecvt
        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
        if (typecvt == nullptr)
            return false;
        bool is_to_f32 =
                typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32;
        if (!is_to_f32)
            return false;
        if (!uniq_reader_check(typecvt->input(0)))
            return false;

        OperatorNodeBase* top_opr = nullptr;
        if (!is_warp_nchw(typecvt->input(0)->owner_opr(), top_opr))
            return false;
        auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);
        SymbolVar new_warp;
        make_new_warp(warp, warp->param(), opr->output()[0]->dtype(),
                      new_warp);

        rewriter.replace_var(
                opr->output(0), new_warp.node(),
                mgb_cstr_log("replace warp + typecvt "
                             "to warp_dimshuffle(NCHW)"));
        return true;
    };

    // warp(NHWC) + dimshuffle + typecvt(float32) -> warp(NHWC_NCHW)
    auto try_warp_nhwc2nchw_typecvt = [&rewriter, &uniq_reader_check,
                                       &is_warp_nhwc2nchw,
                                       &make_new_warp](OperatorNodeBase* opr) {
        // check typecvt
        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
        if (typecvt == nullptr)
            return false;
        bool is_to_f32 =
                typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32;
        if (!is_to_f32)
            return false;
        if (!uniq_reader_check(typecvt->input(0)))
            return false;

        OperatorNodeBase* top_opr = nullptr;
        if (!is_warp_nhwc2nchw(typecvt->input(0)->owner_opr(), top_opr))
            return false;
        auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);
        opr::WarpPerspective::Param new_param = warp->param();
        new_param.format = megdnn::param::WarpPerspective::Format::NHWC_NCHW;
        SymbolVar new_warp;
        make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp);

        rewriter.replace_var(
                opr->output(0), new_warp.node(),
                mgb_cstr_log("replace warp + dimshuffle + "
                             "typecvt to warp_dimshuffle(NHWC_NCHW)"));
        return true;
    };

    // warp(NHWC) + dimshuffle + relayout(NCHW_NCHW4, qint8)
    //     -> warp(NHWC_NCHW4_IC_SMALL); only when the channel dim is < 4
    auto try_warp_nhwc2nchw4_typecvt = [&rewriter, &uniq_reader_check,
                                        &is_warp_nhwc2nchw,
                                        &make_new_warp](OperatorNodeBase* opr) {
        // check relayout
        auto relayout = try_cast_as_op<opr::RelayoutFormat>(opr);
        if (relayout == nullptr)
            return false;
        bool is_to_q8 =
                relayout->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        bool is_to_nchw2nchw4 = relayout->param().mode ==
                                opr::RelayoutFormat::Param::Mode::NCHW_NCHW4;
        if (!(is_to_q8 && is_to_nchw2nchw4))
            return false;
        if (!uniq_reader_check(relayout->input(0)))
            return false;

        OperatorNodeBase* top_opr = nullptr;
        if (!is_warp_nhwc2nchw(relayout->input(0)->owner_opr(), top_opr))
            return false;
        auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);
        // NHWC input: channels live in dim 3
        bool is_small_chn = warp->input(0)->shape()[3] < 4;
        if (!is_small_chn)
            return false;
        opr::WarpPerspective::Param new_param = warp->param();
        new_param.format =
                megdnn::param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL;
        SymbolVar new_warp;
        make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp);

        rewriter.replace_var(
                opr->output(0), new_warp.node(),
                mgb_cstr_log(
                        "replace warp + dimshuffle + relayout(NCHW_NCHW4) "
                        "to warp_dimshuffle(NHWC_NCHW4_IC_SMALL)"));
        return true;
    };

    // warp(NCHW) + relayout(NCHW_NCHW4, qint8)
    //     -> warp(NCHW_NCHW4_IC_SMALL); only when the channel dim is < 4
    auto try_warp_nchw2nchw4_typecvt = [&rewriter, &uniq_reader_check,
                                        &is_warp_nchw,
                                        &make_new_warp](OperatorNodeBase* opr) {
        // check relayout
        auto relayout = try_cast_as_op<opr::RelayoutFormat>(opr);
        if (relayout == nullptr)
            return false;
        bool is_to_q8 =
                relayout->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        bool is_to_nchw2nchw4 = relayout->param().mode ==
                                opr::RelayoutFormat::Param::Mode::NCHW_NCHW4;
        if (!(is_to_q8 && is_to_nchw2nchw4))
            return false;
        if (!uniq_reader_check(relayout->input(0)))
            return false;

        OperatorNodeBase* top_opr = nullptr;
        if (!is_warp_nchw(relayout->input(0)->owner_opr(), top_opr))
            return false;
        auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);
        // NCHW input: channels live in dim 1
        bool is_small_chn = warp->input(0)->shape()[1] < 4;
        if (!is_small_chn)
            return false;
        opr::WarpPerspective::Param new_param = warp->param();
        new_param.format =
                megdnn::param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL;
        SymbolVar new_warp;
        make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp);

        rewriter.replace_var(
                opr->output(0), new_warp.node(),
                mgb_cstr_log("replace warp + relayout(NCHW_NCHW4) "
                             "to warp_dimshuffle(NCHW_NCHW4_IC_SMALL)"));
        return true;
    };

    // Try each pattern in turn; the first that matches performs the
    // replacement, otherwise fall back to the default output replacement.
    auto on_opr = [&try_warp_nchw_typecvt, &try_warp_nhwc2nchw_typecvt,
                   &try_warp_nhwc2nchw4_typecvt, &try_warp_nchw2nchw4_typecvt,
                   &rewriter](OperatorNodeBase* opr) {
        if (!try_warp_nchw_typecvt(opr) && !try_warp_nhwc2nchw_typecvt(opr) &&
            !try_warp_nhwc2nchw4_typecvt(opr) &&
            !try_warp_nchw2nchw4_typecvt(opr)) {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
}
\ No newline at end of file
src/gopt/include/megbrain/gopt/inference.h
浏览文件 @
61f917fb
...
...
@@ -172,6 +172,16 @@ namespace gopt {
m_opr_replace_func
;
};
/*!
 * \brief fuse warp perspective and dimshuffle, quint8/uint8 to qint8/float
 *
 * Replaces a WarpPerspective on uint8/quint8 input that is followed by a
 * Dimshuffle, TypeCvt and/or RelayoutFormat(NCHW_NCHW4) with a single
 * WarpPerspective that produces the converted output directly.
 */
class FuseWarpPerspectiveDimshufflePass : public Pass {
public:
    //! name of this pass
    const char* name() const override;
    //! run the fusion on the graph held by \p opt
    void apply(OptState& opt) const override;
};
/*!
* \brief fuse deconv and typecvt to a deconv opr
*/
...
...
src/gopt/test/inference.cpp
浏览文件 @
61f917fb
...
...
@@ -1172,7 +1172,8 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) {
param
.
pad_h
=
param
.
pad_w
=
1
;
auto
w2
=
mkcvar
(
"w2"
,
{
4
,
4
,
3
,
3
}),
y
=
opr
::
Convolution
::
make
(
elem
,
w2
,
param
),
z
=
opr
::
AxisAddRemove
::
make
(
y
,
{
opr
::
AxisAddRemove
::
AxisDesc
::
make_add
(
0
)});
z
=
opr
::
AxisAddRemove
::
make
(
y
,
{
opr
::
AxisAddRemove
::
AxisDesc
::
make_add
(
0
)});
SymbolVar
y_opt
,
z_opt
;
auto
options
=
gopt
::
OptimizeForInferenceOptions
{};
...
...
@@ -3722,5 +3723,65 @@ TEST(TestGoptInference, PreProcessCase1) {
ASSERT_TRUE
(
y_opt
.
node
()
->
owner_opr
()
->
same_type
<
opr
::
RelayoutFormat
>
());
}
//! Builds warp(NHWC, uint8) -> dimshuffle -> typecvt(f32) -> (x - 128) ->
//! zero-pad channels to 4 -> NCHW4 dimshuffle -> typecvt(qint8), then checks
//! that optimize_for_inference with fuse_preprocess enabled collapses the
//! chain into one WarpPerspective(NHWC_NCHW4_IC_SMALL) with matching output.
TEST(TestGoptInference, WarpAndPreProcessCase) {
    REQUIRE_GPU(1);
    HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);
    auto cn = CompNode::load("gpu0");
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;

    size_t batch = 1;
    size_t channels = 3;
    size_t height = 16;
    size_t width = 16;

    // source image in NHWC layout
    auto host_img = gen({batch, height, width, channels}, cn);
    auto img = opr::Host2DeviceCopy::make(*graph, host_img);

    // perspective matrices
    auto host_mat = std::make_shared<HostTensorND>(
            cn, TensorShape{batch, 3, 3}, dtype::Float32());
    warp_perspective_mat_gen(*host_mat, batch, height, width);
    auto mat = opr::Host2DeviceCopy::make(*graph, host_mat).rename("mat");

    opr::WarpPerspective::Param warp_param;
    warp_param.format = opr::WarpPerspective::Param::Format::NHWC;
    auto warped = opr::WarpPerspective::make(
            img, mat, TensorShape{height, width}, warp_param);
    auto warped_nchw = opr::Dimshuffle::make(warped, {0, 3, 1, 2}, 4, cn);

    // convert to float and shift into the signed 8-bit value range
    auto warped_f32 = opr::TypeCvt::make(warped_nchw, dtype::Float32(), cn);
    auto shifted = warped_f32 - 128;

    // pad one all-zero channel so that there are 4 channels in total
    auto zero_scalar = DTypeScalar(dtype::Float32());
    auto zero_var = opr::ImmutableTensor::make(*graph, zero_scalar, cn);
    auto pad_channel =
            opr::Broadcast::make(zero_var, {batch, 1, height, width}, cn);
    auto concated = opr::Concat::make({shifted, pad_channel}, 1, cn);
    auto padded = concated.reshape({batch, 1, 4, height, width});

    auto nchw4 = opr::Dimshuffle::make(padded, {0, 1, 3, 4, 2}, 5, cn);
    auto y = opr::TypeCvt::make(nchw4, dtype::QuantizedS8(1.f));

    SymbolVar y_opt;
    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_fuse_preprocess();
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);

    ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::WarpPerspective>());
    ASSERT_EQ(opr::WarpPerspective::Param::Format::NHWC_NCHW4_IC_SMALL,
              find_opr<opr::WarpPerspective>(y_opt).param().format);

    graph->compile({{y_opt, {}}})
            ->to_json()
            ->writeto_fpath(output_file(
                    "TestGoptInference.WarpAndPreProcessCase.json"));

    HostTensorND host_y_opt, host_y;
    auto func = graph->compile({make_callback_copy(y, host_y),
                                make_callback_copy(y_opt, host_y_opt)});
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/impl/imgproc.cpp
浏览文件 @
61f917fb
...
...
@@ -47,7 +47,11 @@ SymbolVar WarpPerspectiveForward::make(SymbolVar i0, SymbolVar i1, SymbolVar i2,
}
//! Sets the dtype of output(0): the dtype requested through the operator
//! config when one is given, otherwise the dtype of input(0).
void WarpPerspectiveForward::init_output_dtype() {
    if (config().output_dtype().valid()) {
        // an explicit output dtype was requested through the operator config
        output(0)->dtype(config().output_dtype());
    } else {
        output(0)->dtype(input(0)->dtype());
    }
}
void
WarpPerspectiveForward
::
add_input_layout_constraint
()
{
...
...
@@ -78,23 +82,40 @@ void WarpPerspectiveForward::outshape_by_symvar_do_get_output_shape(
mat_idx_shp
.
to_string
().
c_str
());
}
//! The index of height, e.g.,[b, h, w, c], the height_idx = 1
size_t
height_idx
=
0
;
if
(
param
().
format
==
Param
::
Format
::
NCHW
||
param
().
format
==
Param
::
Format
::
NCHW4
)
{
height_idx
=
2
;
}
else
{
height_idx
=
1
;
}
dest
=
imgshp
;
dest
[
0
]
=
matshp
[
0
];
if
(
param
().
format
==
Param
::
Format
::
NHWCD4
)
{
dest
.
shape
[
height_idx
]
=
oshp2d
.
shape
[
0
];
dest
.
shape
[
height_idx
+
2
]
=
oshp2d
.
shape
[
1
];
}
else
{
for
(
int
i
=
0
;
i
<
2
;
++
i
)
dest
.
shape
[
height_idx
+
i
]
=
oshp2d
.
shape
[
i
];
switch
(
param
().
format
)
{
case
Param
::
Format
::
NCHW_NCHW4_IC_SMALL
:
case
Param
::
Format
::
NHWC_NCHW4_IC_SMALL
:
dest
.
ndim
=
5
;
dest
[
0
]
=
matshp
[
0
];
dest
.
shape
[
1
]
=
1
;
dest
.
shape
[
2
]
=
oshp2d
.
shape
[
0
];
dest
.
shape
[
3
]
=
oshp2d
.
shape
[
1
];
dest
.
shape
[
4
]
=
4
;
break
;
case
Param
::
Format
::
NHWC_NCHW
:
dest
[
0
]
=
matshp
[
0
];
dest
.
shape
[
1
]
=
imgshp
.
shape
[
3
];
dest
.
shape
[
2
]
=
oshp2d
.
shape
[
0
];
dest
.
shape
[
3
]
=
oshp2d
.
shape
[
1
];
break
;
default:
size_t
height_idx
=
0
;
if
(
param
().
format
==
Param
::
Format
::
NCHW
||
param
().
format
==
Param
::
Format
::
NCHW4
)
{
height_idx
=
2
;
}
else
{
height_idx
=
1
;
}
dest
=
imgshp
;
dest
[
0
]
=
matshp
[
0
];
if
(
param
().
format
==
Param
::
Format
::
NHWCD4
)
{
dest
.
shape
[
height_idx
]
=
oshp2d
.
shape
[
0
];
dest
.
shape
[
height_idx
+
2
]
=
oshp2d
.
shape
[
1
];
}
else
{
for
(
int
i
=
0
;
i
<
2
;
++
i
)
dest
.
shape
[
height_idx
+
i
]
=
oshp2d
.
shape
[
i
];
}
break
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录