Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8764a6c8
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
8764a6c8
编写于
9月 22, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/cuda): add volta dp4a int8 sass kernel
GitOrigin-RevId: 9fefd39678729ec185c1b09c5b4abd88ebbde3a0
上级
e296a684
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
70 addition
and
55 deletion
+70
-55
dnn/src/cuda/utils.cpp
dnn/src/cuda/utils.cpp
+52
-40
dnn/src/cuda/utils.h
dnn/src/cuda/utils.h
+13
-10
dnn/test/common/conv_bias.cpp
dnn/test/common/conv_bias.cpp
+5
-5
未找到文件。
dnn/src/cuda/utils.cpp
浏览文件 @
8764a6c8
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"
#include "src/cuda/utils.h"
...
@@ -30,49 +31,48 @@ struct DevicePropRec {
...
@@ -30,49 +31,48 @@ struct DevicePropRec {
constexpr
int
MAX_NR_DEVICE
=
32
;
constexpr
int
MAX_NR_DEVICE
=
32
;
DevicePropRec
device_prop_rec
[
MAX_NR_DEVICE
];
DevicePropRec
device_prop_rec
[
MAX_NR_DEVICE
];
const
char
*
cublasGetErrorString
(
cublasStatus_t
error
)
{
const
char
*
cublasGetErrorString
(
cublasStatus_t
error
)
{
switch
(
error
)
switch
(
error
)
{
{
case
CUBLAS_STATUS_SUCCESS
:
case
CUBLAS_STATUS_SUCCESS
:
return
"CUBLAS_STATUS_SUCCESS"
;
return
"CUBLAS_STATUS_SUCCESS"
;
case
CUBLAS_STATUS_NOT_INITIALIZED
:
case
CUBLAS_STATUS_NOT_INITIALIZED
:
return
"CUBLAS_STATUS_NOT_INITIALIZED"
;
return
"CUBLAS_STATUS_NOT_INITIALIZED"
;
case
CUBLAS_STATUS_ALLOC_FAILED
:
case
CUBLAS_STATUS_ALLOC_FAILED
:
return
"CUBLAS_STATUS_ALLOC_FAILED"
;
return
"CUBLAS_STATUS_ALLOC_FAILED"
;
case
CUBLAS_STATUS_INVALID_VALUE
:
case
CUBLAS_STATUS_INVALID_VALUE
:
return
"CUBLAS_STATUS_INVALID_VALUE"
;
return
"CUBLAS_STATUS_INVALID_VALUE"
;
case
CUBLAS_STATUS_ARCH_MISMATCH
:
case
CUBLAS_STATUS_ARCH_MISMATCH
:
return
"CUBLAS_STATUS_ARCH_MISMATCH"
;
return
"CUBLAS_STATUS_ARCH_MISMATCH"
;
case
CUBLAS_STATUS_MAPPING_ERROR
:
case
CUBLAS_STATUS_MAPPING_ERROR
:
return
"CUBLAS_STATUS_MAPPING_ERROR"
;
return
"CUBLAS_STATUS_MAPPING_ERROR"
;
case
CUBLAS_STATUS_EXECUTION_FAILED
:
case
CUBLAS_STATUS_EXECUTION_FAILED
:
return
"CUBLAS_STATUS_EXECUTION_FAILED"
;
return
"CUBLAS_STATUS_EXECUTION_FAILED"
;
case
CUBLAS_STATUS_INTERNAL_ERROR
:
case
CUBLAS_STATUS_INTERNAL_ERROR
:
return
"CUBLAS_STATUS_INTERNAL_ERROR"
;
return
"CUBLAS_STATUS_INTERNAL_ERROR"
;
case
CUBLAS_STATUS_LICENSE_ERROR
:
case
CUBLAS_STATUS_LICENSE_ERROR
:
return
"CUBLAS_STATUS_LICENSE_ERROR"
;
return
"CUBLAS_STATUS_LICENSE_ERROR"
;
case
CUBLAS_STATUS_NOT_SUPPORTED
:
case
CUBLAS_STATUS_NOT_SUPPORTED
:
return
"CUBLAS_STATUS_NOT_SUPPORTED"
;
return
"CUBLAS_STATUS_NOT_SUPPORTED"
;
}
}
return
"Unknown CUBLAS error"
;
return
"Unknown CUBLAS error"
;
}
}
}
// anonymous namespace
}
// anonymous namespace
void
cuda
::
__throw_cuda_error__
(
cudaError_t
err
,
const
char
*
msg
)
{
void
cuda
::
__throw_cuda_error__
(
cudaError_t
err
,
const
char
*
msg
)
{
auto
s
=
ssprintf
(
"cuda error %s(%d) occurred; expr: %s"
,
auto
s
=
ssprintf
(
"cuda error %s(%d) occurred; expr: %s"
,
cudaGetErrorString
(
err
),
int
(
err
),
msg
);
cudaGetErrorString
(
err
),
int
(
err
),
msg
);
megdnn_throw
(
s
.
c_str
());
megdnn_throw
(
s
.
c_str
());
}
}
void
cuda
::
__throw_cudnn_error__
(
cudnnStatus_t
err
,
const
char
*
msg
)
{
void
cuda
::
__throw_cudnn_error__
(
cudnnStatus_t
err
,
const
char
*
msg
)
{
auto
s
=
ssprintf
(
"cudnn error %s(%d) occurred; expr: %s"
,
auto
s
=
ssprintf
(
"cudnn error %s(%d) occurred; expr: %s"
,
cudnnGetErrorString
(
err
),
int
(
err
),
msg
);
cudnnGetErrorString
(
err
),
int
(
err
),
msg
);
megdnn_throw
(
s
.
c_str
());
megdnn_throw
(
s
.
c_str
());
}
}
void
cuda
::
__throw_cublas_error__
(
cublasStatus_t
err
,
const
char
*
msg
)
{
void
cuda
::
__throw_cublas_error__
(
cublasStatus_t
err
,
const
char
*
msg
)
{
auto
s
=
ssprintf
(
"cublas error %s(%d) occurred; expr: %s"
,
auto
s
=
ssprintf
(
"cublas error %s(%d) occurred; expr: %s"
,
cublasGetErrorString
(
err
),
int
(
err
),
msg
);
cublasGetErrorString
(
err
),
int
(
err
),
msg
);
megdnn_throw
(
s
.
c_str
());
megdnn_throw
(
s
.
c_str
());
}
}
...
@@ -92,17 +92,17 @@ void cuda::__throw_cutlass_error__(cutlass::Status err, const char* msg) {
...
@@ -92,17 +92,17 @@ void cuda::__throw_cutlass_error__(cutlass::Status err, const char* msg) {
megdnn_throw
(
s
.
c_str
());
megdnn_throw
(
s
.
c_str
());
}
}
void
cuda
::
report_error
(
const
char
*
msg
)
{
void
cuda
::
report_error
(
const
char
*
msg
)
{
megdnn_throw
(
msg
);
megdnn_throw
(
msg
);
MEGDNN_MARK_USED_VAR
(
msg
);
MEGDNN_MARK_USED_VAR
(
msg
);
}
}
uint32_t
cuda
::
safe_size_in_kern
(
size_t
size
)
{
uint32_t
cuda
::
safe_size_in_kern
(
size_t
size
)
{
if
(
!
size
||
size
>
Uint32Fastdiv
::
MAX_DIVIDEND
)
{
if
(
!
size
||
size
>
Uint32Fastdiv
::
MAX_DIVIDEND
)
{
megdnn_throw
(
ssprintf
(
megdnn_throw
(
"invalid size for element-wise kernel: %zu; "
ssprintf
(
"invalid size for element-wise kernel: %zu; "
"max supported size is %u"
,
"max supported size is %u"
,
size
,
Uint32Fastdiv
::
MAX_DIVIDEND
));
size
,
Uint32Fastdiv
::
MAX_DIVIDEND
));
}
}
return
size
;
return
size
;
}
}
...
@@ -111,7 +111,7 @@ cudaDeviceProp cuda::current_device_prop() {
...
@@ -111,7 +111,7 @@ cudaDeviceProp cuda::current_device_prop() {
int
dev
;
int
dev
;
cuda_check
(
cudaGetDevice
(
&
dev
));
cuda_check
(
cudaGetDevice
(
&
dev
));
megdnn_assert
(
dev
<
MAX_NR_DEVICE
,
"device number too large: %d"
,
dev
);
megdnn_assert
(
dev
<
MAX_NR_DEVICE
,
"device number too large: %d"
,
dev
);
auto
&&
rec
=
device_prop_rec
[
dev
];
auto
&&
rec
=
device_prop_rec
[
dev
];
if
(
!
rec
.
init
)
{
if
(
!
rec
.
init
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
rec
.
mtx
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
rec
.
mtx
);
if
(
!
rec
.
init
)
{
if
(
!
rec
.
init
)
{
...
@@ -137,6 +137,19 @@ size_t cuda::max_batch_x_channel_size() {
...
@@ -137,6 +137,19 @@ size_t cuda::max_batch_x_channel_size() {
return
current_device_prop
().
maxGridSize
[
2
];
return
current_device_prop
().
maxGridSize
[
2
];
}
}
uint32_t
cuda
::
param_buffer_start_address
()
{
auto
&&
device_prop
=
current_device_prop
();
int
cap
=
10
*
device_prop
.
major
+
device_prop
.
minor
;
// maxwell and pascal: 0x140
if
(
cap
>=
50
&&
cap
<
70
)
return
0x140
;
// volta ~ ampere: 0x160
else
if
(
cap
>=
70
)
return
0x160
;
megdnn_throw
(
ssprintf
(
"unsupported cuda compute capability %d"
,
cap
).
c_str
());
}
const
char
*
cuda
::
current_device_arch_name
()
{
const
char
*
cuda
::
current_device_arch_name
()
{
auto
&&
device_prop
=
current_device_prop
();
auto
&&
device_prop
=
current_device_prop
();
int
cap
=
10
*
device_prop
.
major
+
device_prop
.
minor
;
int
cap
=
10
*
device_prop
.
major
+
device_prop
.
minor
;
...
@@ -155,4 +168,3 @@ const char* cuda::current_device_arch_name() {
...
@@ -155,4 +168,3 @@ const char* cuda::current_device_arch_name() {
}
}
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/cuda/utils.h
浏览文件 @
8764a6c8
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#pragma once
#pragma once
...
@@ -24,19 +25,19 @@
...
@@ -24,19 +25,19 @@
namespace
megdnn
{
namespace
megdnn
{
namespace
cuda
{
namespace
cuda
{
static
inline
HandleImpl
*
concrete_handle
(
Handle
*
handle
)
{
static
inline
HandleImpl
*
concrete_handle
(
Handle
*
handle
)
{
return
static_cast
<
cuda
::
HandleImpl
*>
(
handle
);
return
static_cast
<
cuda
::
HandleImpl
*>
(
handle
);
}
}
static
inline
cudnnHandle_t
cudnn_handle
(
Handle
*
handle
)
{
static
inline
cudnnHandle_t
cudnn_handle
(
Handle
*
handle
)
{
return
concrete_handle
(
handle
)
->
cudnn_handle
();
return
concrete_handle
(
handle
)
->
cudnn_handle
();
}
}
static
inline
cublasHandle_t
cublas_handle
(
Handle
*
handle
)
{
static
inline
cublasHandle_t
cublas_handle
(
Handle
*
handle
)
{
return
concrete_handle
(
handle
)
->
cublas_handle
();
return
concrete_handle
(
handle
)
->
cublas_handle
();
}
}
static
inline
cudaStream_t
cuda_stream
(
Handle
*
handle
)
{
static
inline
cudaStream_t
cuda_stream
(
Handle
*
handle
)
{
return
concrete_handle
(
handle
)
->
stream
();
return
concrete_handle
(
handle
)
->
stream
();
}
}
...
@@ -44,9 +45,8 @@ static inline megcore::AsyncErrorInfo* async_error_info(Handle* handle) {
...
@@ -44,9 +45,8 @@ static inline megcore::AsyncErrorInfo* async_error_info(Handle* handle) {
return
concrete_handle
(
handle
)
->
megcore_context
().
error_info
;
return
concrete_handle
(
handle
)
->
megcore_context
().
error_info
;
}
}
static
inline
void
CUDART_CB
callback_free
(
static
inline
void
CUDART_CB
callback_free
(
cudaStream_t
/* stream */
,
cudaStream_t
/* stream */
,
cudaError_t
status
,
void
*
userData
)
cudaError_t
status
,
void
*
userData
)
{
{
cuda_check
(
status
);
cuda_check
(
status
);
free
(
userData
);
free
(
userData
);
}
}
...
@@ -64,9 +64,12 @@ bool is_compute_capability_equalto(int major, int minor);
...
@@ -64,9 +64,12 @@ bool is_compute_capability_equalto(int major, int minor);
//! third dimension
//! third dimension
size_t
max_batch_x_channel_size
();
size_t
max_batch_x_channel_size
();
//! get param buffer start address at cmem[0]
uint32_t
param_buffer_start_address
();
const
char
*
current_device_arch_name
();
const
char
*
current_device_arch_name
();
}
// namespace cuda
}
// namespace cuda
}
// namespace megdnn
}
// namespace megdnn
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/test/common/conv_bias.cpp
浏览文件 @
8764a6c8
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "test/common/conv_bias.h"
#include "test/common/conv_bias.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/opr_param_defs.h"
...
@@ -413,7 +414,7 @@ std::vector<TestArg> get_int8_nchw44_args(size_t kernel_size, size_t pack_size,
...
@@ -413,7 +414,7 @@ std::vector<TestArg> get_int8_nchw44_args(size_t kernel_size, size_t pack_size,
megdnn_assert
(
kernel_size
>
0
,
"not support kernel_size"
);
megdnn_assert
(
kernel_size
>
0
,
"not support kernel_size"
);
using
NLMode
=
param
::
ConvBias
::
NonlineMode
;
using
NLMode
=
param
::
ConvBias
::
NonlineMode
;
//
//
clang-format off
// clang-format off
for
(
auto
nlmode
:
{
NLMode
::
IDENTITY
,
NLMode
::
RELU
})
{
for
(
auto
nlmode
:
{
NLMode
::
IDENTITY
,
NLMode
::
RELU
})
{
for
(
auto
mode
:
{
param
::
ConvBias
::
Mode
::
CROSS_CORRELATION
})
{
for
(
auto
mode
:
{
param
::
ConvBias
::
Mode
::
CROSS_CORRELATION
})
{
for
(
size_t
b
:
{
1
,
2
})
{
for
(
size_t
b
:
{
1
,
2
})
{
...
@@ -795,7 +796,7 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
...
@@ -795,7 +796,7 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype,
return
z
;
return
z
;
};
};
megdnn_assert
(
rng
!=
nullptr
&&
bias_rng
!=
nullptr
);
megdnn_assert
(
rng
!=
nullptr
&&
bias_rng
!=
nullptr
);
checker
.
set_rng
(
0
,
rng
.
get
())
checker
.
set_rng
(
0
,
rng
.
get
())
.
set_rng
(
1
,
rng
.
get
())
.
set_rng
(
1
,
rng
.
get
())
.
set_rng
(
2
,
rng
.
get
())
.
set_rng
(
2
,
rng
.
get
())
.
set_rng
(
3
,
rng
.
get
());
.
set_rng
(
3
,
rng
.
get
());
...
@@ -1152,8 +1153,7 @@ void winograd_algo_extra_impl(const TensorNDArray& tensors, uint32_t m,
...
@@ -1152,8 +1153,7 @@ void winograd_algo_extra_impl(const TensorNDArray& tensors, uint32_t m,
handle
->
create_operator
<
WinogradFilterPreprocess
>
();
handle
->
create_operator
<
WinogradFilterPreprocess
>
();
winograd_preprocess_opr
->
param
().
output_block_size
=
m
;
winograd_preprocess_opr
->
param
().
output_block_size
=
m
;
winograd_preprocess_opr
->
param
().
format
=
format
;
winograd_preprocess_opr
->
param
().
format
=
format
;
winograd_preprocess_opr
->
param
().
compute_mode
=
winograd_preprocess_opr
->
param
().
compute_mode
=
param
.
compute_mode
;
param
.
compute_mode
;
TensorLayout
filter_transform_layout
;
TensorLayout
filter_transform_layout
;
winograd_preprocess_opr
->
deduce_layout
(
tensors
[
1
].
layout
,
winograd_preprocess_opr
->
deduce_layout
(
tensors
[
1
].
layout
,
filter_transform_layout
);
filter_transform_layout
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录