Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
比较版本
6af4a32e1781e213477ea0f9e866576f36afc123...374fed32150cdc0c2cce03d83846a7292cfacbe9
MegEngine
项目概览
MegEngine 天元
/
MegEngine
8 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
源分支
374fed32150cdc0c2cce03d83846a7292cfacbe9
选择Git版本
...
目标分支
6af4a32e1781e213477ea0f9e866576f36afc123
选择Git版本
比较
Commits (6)
https://gitcode.net/megvii/megengine/-/commit/e3df532c8de5b455b70b1ca6d5ebb3df39a91680
Merge remote-tracking branch 'tp/fix-spell-error'
2021-07-22T19:00:50+08:00
megvii-mge
megengine@megvii.com
https://gitcode.net/megvii/megengine/-/commit/d4ea756b33d17675314b06be08c52a88ca1aa771
perf(mgb): disable FoldingConvBiasDimshufflePass in cuda10 for performance
2021-07-22T19:03:19+08:00
Megvii Engine Team
megengine@megvii.com
GitOrigin-RevId: d1b95a6f01ba73f98c0094e00fee3e61e9139628
https://gitcode.net/megvii/megengine/-/commit/c816ddb9a3380d567f2859127992dbe849647a55
chore(mge): run get_device_count("gpu") in subprocess
2021-07-22T19:03:19+08:00
Megvii Engine Team
megengine@megvii.com
GitOrigin-RevId: 0f0dc001cfc45fc0d04de1a86c27f8bba8185d6b
https://gitcode.net/megvii/megengine/-/commit/a2bdee6233f82187b01e86ea55502cc9c4373110
chore(release): bump version
2021-07-22T19:03:19+08:00
Megvii Engine Team
megengine@megvii.com
GitOrigin-RevId: a016ea9d564a32adecaf6a80bc01a8ddb95d88db
https://gitcode.net/megvii/megengine/-/commit/da4818275c7da4e3bcefc5313a46ddf2430fda69
feat(mge/third_party): update cutlass version
2021-07-22T19:03:19+08:00
huangxinda
megengine@megvii.com
https://gitcode.net/megvii/megengine/-/commit/374fed32150cdc0c2cce03d83846a7292cfacbe9
feat(mge/third_party): update MegRay version
2021-07-22T19:03:19+08:00
huangxinda
megengine@megvii.com
隐藏空白更改
内联
并排
Showing
29 changed file
with
179 addition
and
89 deletion
+179
-89
dnn/test/cuda/conv_bias_int8.cpp
dnn/test/cuda/conv_bias_int8.cpp
+40
-0
imperative/python/megengine/distributed/helper.py
imperative/python/megengine/distributed/helper.py
+0
-18
imperative/python/megengine/distributed/launcher.py
imperative/python/megengine/distributed/launcher.py
+3
-4
imperative/python/megengine/functional/tensor.py
imperative/python/megengine/functional/tensor.py
+2
-2
imperative/python/megengine/quantization/fake_quant.py
imperative/python/megengine/quantization/fake_quant.py
+2
-2
imperative/python/test/conftest.py
imperative/python/test/conftest.py
+2
-2
imperative/python/test/integration/test_param_pack.py
imperative/python/test/integration/test_param_pack.py
+0
-1
imperative/python/test/unit/autodiff/test_grad_manger.py
imperative/python/test/unit/autodiff/test_grad_manger.py
+0
-1
imperative/python/test/unit/core/test_autodiff.py
imperative/python/test/unit/core/test_autodiff.py
+0
-1
imperative/python/test/unit/core/test_dtype_quant.py
imperative/python/test/unit/core/test_dtype_quant.py
+2
-3
imperative/python/test/unit/distributed/test_distributed.py
imperative/python/test/unit/distributed/test_distributed.py
+1
-5
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+3
-6
imperative/python/test/unit/functional/test_functional_distributed.py
...ython/test/unit/functional/test_functional_distributed.py
+0
-1
imperative/python/test/unit/functional/test_tensor.py
imperative/python/test/unit/functional/test_tensor.py
+0
-1
imperative/python/test/unit/module/test_batchnorm.py
imperative/python/test/unit/module/test_batchnorm.py
+0
-1
imperative/python/test/unit/module/test_qat.py
imperative/python/test/unit/module/test_qat.py
+2
-4
imperative/python/test/unit/quantization/test_observer.py
imperative/python/test/unit/quantization/test_observer.py
+3
-3
imperative/python/test/unit/quantization/test_op.py
imperative/python/test/unit/quantization/test_op.py
+2
-2
imperative/python/test/unit/random/test_rng.py
imperative/python/test/unit/random/test_rng.py
+13
-13
imperative/python/test/unit/utils/test_network_node.py
imperative/python/test/unit/utils/test_network_node.py
+2
-3
imperative/python/version_template.py
imperative/python/version_template.py
+1
-1
src/core/impl/comp_node/cuda/comp_node.cpp
src/core/impl/comp_node/cuda/comp_node.cpp
+84
-7
src/core/test/graph/eager_eval.cpp
src/core/test/graph/eager_eval.cpp
+1
-1
src/gopt/impl/framework.cpp
src/gopt/impl/framework.cpp
+4
-0
src/gopt/impl/tensor_reformat.cpp
src/gopt/impl/tensor_reformat.cpp
+3
-4
src/gopt/include/megbrain/gopt/inference.h
src/gopt/include/megbrain/gopt/inference.h
+6
-0
src/gopt/test/inference.cpp
src/gopt/test/inference.cpp
+1
-1
third_party/MegRay
third_party/MegRay
+1
-1
third_party/cutlass
third_party/cutlass
+1
-1
未找到文件。
dnn/test/cuda/conv_bias_int8.cpp
浏览文件 @
374fed32
...
...
@@ -1060,6 +1060,46 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4_SMALL_CHANNEL) {
param
::
ConvBias
::
Format
::
CHWN4
);
}
/*!
 * Benchmark for the int8 conv-bias forward operator with the NCHW4_NCHW
 * format and IDENTITY non-linearity, pinned to the
 * "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM" algorithm via ConvBiasAlgoChecker.
 *
 * Each shape triple is executed RUNS times and the averaged time reported by
 * the benchmarker is printed; nothing is asserted on the result.
 */
TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_NCHW4_NCHW) {
    CUBenchmarker<ConvBiasForward> benchmarker(handle_cuda());
    // number of repetitions per shape; also the divisor for the average below
    const size_t RUNS = 1000;
    benchmarker.set_display(false).set_times(RUNS);

    using namespace conv_bias;

    // random generator shared by tensor 0 and tensor 1 (see set_rng below)
    UniformIntRNG int_rng{-3, 3};

    ConvBias::Param param;
    param.format = ConvBias::Param::Format::NCHW4_NCHW;
    param.nonlineMode = ConvBias::Param::NonlineMode::IDENTITY;

    // make sure the expected algorithm is the one being measured
    benchmarker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
                    "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM"));

    // tensors 0/1 are quantized int8, tensors 2..4 are float32
    benchmarker.set_dtype(0, dtype::QuantizedS8(1.9980618f))
            .set_dtype(1, dtype::QuantizedS8(1.9980927f))
            .set_dtype(2, dtype::Float32())
            .set_dtype(3, dtype::Float32())
            .set_dtype(4, dtype::Float32())
            .set_rng(0, &int_rng)
            .set_rng(1, &int_rng)
            .set_param(param);

    // run one case: shapes[0..2] are forwarded as the first three operands,
    // the remaining two operands are left as empty shapes
    auto run = [&](const TensorShapeArray& shapes) {
        auto time_in_ms =
                benchmarker.execs({shapes[0], shapes[1], shapes[2], {}, {}}) /
                RUNS;

        // NOTE(review): shapes[2] is printed under the "dst" label, but the
        // {1, 32, 1, 1} values passed below look like a bias shape -- confirm
        // before relying on this log line.
        printf("src=%s, filter=%s, dst=%s, time=%.2f\n",
               shapes[0].to_string().c_str(), shapes[1].to_string().c_str(),
               shapes[2].to_string().c_str(), time_in_ms);
    };

    run({{16, 16, 224, 224, 4}, {32, 16, 3, 3, 4}, {1, 32, 1, 1}});
    run({{16, 16, 92, 160, 4}, {32, 16, 3, 3, 4}, {1, 32, 1, 1}});
    run({{16, 16, 46, 80, 4}, {32, 16, 3, 3, 4}, {1, 32, 1, 1}});
}
#if CUDA_VERSION >= 10020
TEST_F
(
CUDA
,
BENCHMARK_CUTLASS_CONV_BIAS_INT8_NCHW32
)
{
...
...
imperative/python/megengine/distributed/helper.py
浏览文件 @
374fed32
...
...
@@ -181,11 +181,6 @@ def synchronized(func: Callable):
return
wrapper
def _get_device_count_worker(queue, device_type):
    """
    Query the device count for ``device_type`` and publish it on ``queue``.

    :param queue: object with a ``put`` method that receives the count.
    :param device_type: device type forwarded to ``get_device_count``.
    """
    queue.put(get_device_count(device_type))
def
_check_device_initialized
(
device_type
:
str
,
rank
:
int
):
try
:
test
=
Tensor
(
1
,
device
=
(
device_type
+
str
(
rank
)))
...
...
@@ -198,19 +193,6 @@ def _check_device_initialized(device_type: str, rank: int):
raise
RuntimeError
(
errmsg
)
def get_device_count_by_fork(device_type: str):
    """
    Get device count in a separate child process.
    See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork
    for more information.

    :param device_type: device type forwarded to the worker, e.g. ``"gpu"``.
    :return: the value the worker process put on the queue.
    """
    q = mp.Queue()
    p = mp.Process(target=_get_device_count_worker, args=(q, device_type))
    p.start()
    # Fetch the result before joining: a child that has put items on a
    # multiprocessing queue may not terminate until they have been consumed,
    # so joining first is the documented deadlock pattern.
    num = q.get()
    p.join()
    return num
def
bcast_list_
(
inps
:
list
,
group
:
Group
=
WORLD
):
"""
Broadcast tensors between given group.
...
...
imperative/python/megengine/distributed/launcher.py
浏览文件 @
374fed32
...
...
@@ -13,9 +13,10 @@ import queue
from
..
import
_exit
from
..core._imperative_rt.core2
import
full_sync
from
..device
import
get_device_count
from
..logger
import
get_logger
from
.group
import
_set_machine_ranks
,
group_barrier
,
init_process_group
from
.helper
import
_check_device_initialized
,
get_device_count_by_fork
from
.helper
import
_check_device_initialized
from
.server
import
Client
,
Server
WARN_SUBPROCESS_EXIT_WITHOUT_RETURN
=
(
...
...
@@ -91,9 +92,7 @@ class launcher:
backend
=
"auto"
,
):
self
.
func
=
func
self
.
n_gpus
=
(
n_gpus
if
n_gpus
is
not
None
else
get_device_count_by_fork
(
device_type
)
)
self
.
n_gpus
=
n_gpus
if
n_gpus
is
not
None
else
get_device_count
(
device_type
)
self
.
world_size
=
world_size
if
world_size
is
not
None
else
self
.
n_gpus
self
.
rank_start
=
rank_start
self
.
master_ip
=
master_ip
...
...
imperative/python/megengine/functional/tensor.py
浏览文件 @
374fed32
...
...
@@ -1188,11 +1188,11 @@ def copy(inp, device=None):
import numpy as np
import platform
from megengine import tensor
from megengine.d
istributed.helper import get_device_count_by_fork
from megengine.d
evice import get_device_count
import megengine.functional as F
x = tensor([1, 2, 3], np.int32)
if 1 == get_device_count
_by_fork
("gpu"):
if 1 == get_device_count("gpu"):
y = F.copy(x, "cpu1")
print(y.numpy())
else:
...
...
imperative/python/megengine/quantization/fake_quant.py
浏览文件 @
374fed32
...
...
@@ -61,14 +61,14 @@ class _FakeQuantize(Module):
def
fake_quant_forward
(
self
,
inp
,
qparams
:
QParams
=
None
):
raise
NotImplementedError
def
normal_foward
(
self
,
inp
,
qparams
:
QParams
=
None
):
def
normal_fo
r
ward
(
self
,
inp
,
qparams
:
QParams
=
None
):
return
inp
def
forward
(
self
,
inp
,
qparams
:
QParams
=
None
):
if
self
.
enabled
:
return
self
.
fake_quant_forward
(
inp
,
qparams
=
qparams
)
else
:
return
self
.
normal_foward
(
inp
,
qparams
=
qparams
)
return
self
.
normal_fo
r
ward
(
inp
,
qparams
=
qparams
)
class
TQT
(
_FakeQuantize
,
QParamsModuleMixin
):
...
...
imperative/python/test/conftest.py
浏览文件 @
374fed32
...
...
@@ -15,7 +15,7 @@ import megengine.functional
import
megengine.module
from
megengine
import
Parameter
from
megengine.core._imperative_rt.core2
import
sync
from
megengine.d
istributed.helper
import
get_device_count_by_fork
from
megengine.d
evice
import
get_device_count
from
megengine.experimental.autograd
import
(
disable_higher_order_directive
,
enable_higher_order_directive
,
...
...
@@ -25,7 +25,7 @@ from megengine.module import Linear, Module
sys
.
path
.
append
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"helpers"
))
_ngpu
=
get_device_count
_by_fork
(
"gpu"
)
_ngpu
=
get_device_count
(
"gpu"
)
@
pytest
.
fixture
(
autouse
=
True
)
...
...
imperative/python/test/integration/test_param_pack.py
浏览文件 @
374fed32
...
...
@@ -16,7 +16,6 @@ import megengine.autodiff as ad
import
megengine.distributed
as
dist
import
megengine.optimizer
as
optimizer
from
megengine
import
Parameter
,
tensor
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.module
import
Module
from
megengine.optimizer
import
SGD
...
...
imperative/python/test/unit/autodiff/test_grad_manger.py
浏览文件 @
374fed32
...
...
@@ -18,7 +18,6 @@ import megengine.functional as F
import
megengine.module
as
M
import
megengine.optimizer
as
optim
from
megengine.autodiff
import
GradManager
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.jit
import
trace
...
...
imperative/python/test/unit/core/test_autodiff.py
浏览文件 @
374fed32
...
...
@@ -20,7 +20,6 @@ from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
from
megengine.core._imperative_rt.core2
import
TensorWeakRef
,
apply
,
sync
from
megengine.core.autodiff.grad
import
Grad
from
megengine.core.ops.builtin
import
Elemwise
,
Identity
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.functional.distributed
import
remote_recv
,
remote_send
...
...
imperative/python/test/unit/core/test_dtype_quant.py
浏览文件 @
374fed32
...
...
@@ -31,7 +31,7 @@ from megengine.core.tensor.dtype import (
quint4
,
quint8
,
)
from
megengine.d
istributed.helper
import
get_device_count_by_fork
from
megengine.d
evice
import
get_device_count
from
megengine.tensor
import
Tensor
...
...
@@ -184,8 +184,7 @@ def test_dtype_int4_ffi_handle():
@
pytest
.
mark
.
skipif
(
get_device_count_by_fork
(
"gpu"
)
!=
0
,
reason
=
"TypeCvt to quint4 is not supported on GPU"
,
get_device_count
(
"gpu"
)
!=
0
,
reason
=
"TypeCvt to quint4 is not supported on GPU"
,
)
def
test_quint4_typecvt
():
device
=
"xpux"
...
...
imperative/python/test/unit/distributed/test_distributed.py
浏览文件 @
374fed32
...
...
@@ -17,11 +17,7 @@ import megengine as mge
import
megengine.distributed
as
dist
from
megengine.core.ops.builtin
import
CollectiveComm
,
ParamPackConcat
,
ParamPackSplit
from
megengine.device
import
get_default_device
from
megengine.distributed.helper
import
(
get_device_count_by_fork
,
param_pack_concat
,
param_pack_split
,
)
from
megengine.distributed.helper
import
param_pack_concat
,
param_pack_split
def
_assert_q_empty
(
q
):
...
...
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
374fed32
...
...
@@ -22,8 +22,7 @@ from megengine import Parameter, Tensor, is_cuda_available, tensor
from
megengine.core._trace_option
import
use_symbolic_shape
from
megengine.core.autodiff.grad
import
Grad
from
megengine.core.tensor.utils
import
make_shape_tuple
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.jit
import
trace
from
megengine.device
import
get_device_count
def
test_where
():
...
...
@@ -613,7 +612,7 @@ def test_nms():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"gpu"
)
>
0
,
reason
=
"cuda does not support nchw int8"
get_device_count
(
"gpu"
)
>
0
,
reason
=
"cuda does not support nchw int8"
)
def
test_conv_bias
():
inp_scale
=
1.5
...
...
@@ -715,9 +714,7 @@ def test_conv_bias():
run
(
10
,
36
,
8
,
46
,
26
,
2
,
2
,
2
,
1
,
1
,
2
,
True
,
"relu"
)
@
pytest
.
mark
.
skipif
(
get_device_count_by_fork
(
"gpu"
)
>
0
,
reason
=
"no int8 algorithm on cuda"
)
@
pytest
.
mark
.
skipif
(
get_device_count
(
"gpu"
)
>
0
,
reason
=
"no int8 algorithm on cuda"
)
def
test_batch_conv_bias
():
inp_scale
=
1.5
w_scale
=
2.5
...
...
imperative/python/test/unit/functional/test_functional_distributed.py
浏览文件 @
374fed32
...
...
@@ -16,7 +16,6 @@ import megengine.distributed as dist
from
megengine
import
Parameter
,
tensor
from
megengine.core._imperative_rt.core2
import
sync
from
megengine.device
import
get_default_device
,
set_default_device
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.functional.distributed
import
(
all_gather
,
all_reduce_max
,
...
...
imperative/python/test/unit/functional/test_tensor.py
浏览文件 @
374fed32
...
...
@@ -18,7 +18,6 @@ from megengine import tensor
from
megengine.core._trace_option
import
use_symbolic_shape
from
megengine.core.tensor
import
megbrain_graph
as
G
from
megengine.core.tensor.utils
import
astensor1d
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.jit
import
trace
from
megengine.utils.network
import
Network
,
set_symbolic_shape
from
megengine.utils.network_node
import
VarNode
...
...
imperative/python/test/unit/module/test_batchnorm.py
浏览文件 @
374fed32
...
...
@@ -16,7 +16,6 @@ import megengine as mge
import
megengine.distributed
as
dist
from
megengine
import
Tensor
from
megengine.core._trace_option
import
use_symbolic_shape
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.module
import
BatchNorm1d
,
BatchNorm2d
,
SyncBatchNorm
_assert_allclose
=
functools
.
partial
(
np
.
testing
.
assert_allclose
,
atol
=
5e-6
,
rtol
=
5e-6
)
...
...
imperative/python/test/unit/module/test_qat.py
浏览文件 @
374fed32
...
...
@@ -6,7 +6,7 @@ import pytest
import
megengine.utils.comp_graph_tools
as
cgtools
from
megengine
import
jit
,
tensor
from
megengine.d
istributed.helper
import
get_device_count_by_fork
from
megengine.d
evice
import
get_device_count
from
megengine.functional
import
expand_dims
from
megengine.module
import
(
BatchMatMulActivation
,
...
...
@@ -101,9 +101,7 @@ def test_qat_conv():
np
.
testing
.
assert_allclose
(
normal_outputs
.
numpy
(),
qat_outputs
.
numpy
())
@
pytest
.
mark
.
skipif
(
get_device_count_by_fork
(
"gpu"
)
>
0
,
reason
=
"no int8 algorithm on cuda"
)
@
pytest
.
mark
.
skipif
(
get_device_count
(
"gpu"
)
>
0
,
reason
=
"no int8 algorithm on cuda"
)
def
test_qat_batchmatmul_activation
():
batch
=
4
in_features
=
8
...
...
imperative/python/test/unit/quantization/test_observer.py
浏览文件 @
374fed32
...
...
@@ -13,7 +13,7 @@ import pytest
import
megengine
as
mge
import
megengine.distributed
as
dist
from
megengine.d
istributed.helper
import
get_device_count_by_fork
from
megengine.d
evice
import
get_device_count
from
megengine.quantization
import
QuantMode
,
create_qparams
from
megengine.quantization.observer
import
(
ExponentialMovingAverageObserver
,
...
...
@@ -78,7 +78,7 @@ def test_passive_observer():
@
pytest
.
mark
.
require_ngpu
(
2
)
@
pytest
.
mark
.
isolated_distributed
def
test_sync_min_max_observer
():
word_size
=
get_device_count
_by_fork
(
"gpu"
)
word_size
=
get_device_count
(
"gpu"
)
x
=
np
.
random
.
rand
(
3
*
word_size
,
3
,
3
,
3
).
astype
(
"float32"
)
np_min
,
np_max
=
x
.
min
(),
x
.
max
()
...
...
@@ -96,7 +96,7 @@ def test_sync_min_max_observer():
@
pytest
.
mark
.
require_ngpu
(
2
)
@
pytest
.
mark
.
isolated_distributed
def
test_sync_exponential_moving_average_observer
():
word_size
=
get_device_count
_by_fork
(
"gpu"
)
word_size
=
get_device_count
(
"gpu"
)
t
=
np
.
random
.
rand
()
x1
=
np
.
random
.
rand
(
3
*
word_size
,
3
,
3
,
3
).
astype
(
"float32"
)
x2
=
np
.
random
.
rand
(
3
*
word_size
,
3
,
3
,
3
).
astype
(
"float32"
)
...
...
imperative/python/test/unit/quantization/test_op.py
浏览文件 @
374fed32
...
...
@@ -12,7 +12,7 @@ import pytest
import
megengine
as
mge
import
megengine.functional
as
F
from
megengine.core.tensor
import
dtype
from
megengine.d
istributed.helper
import
get_device_count_by_fork
from
megengine.d
evice
import
get_device_count
from
megengine.functional.elemwise
import
_elemwise_multi_type
,
_elwise
from
megengine.quantization
import
QuantMode
,
create_qparams
...
...
@@ -68,7 +68,7 @@ def test_elemwise(kind):
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"gpu"
)
>
0
,
reason
=
"cuda does not support nchw int8"
get_device_count
(
"gpu"
)
>
0
,
reason
=
"cuda does not support nchw int8"
)
def
test_conv_bias
():
inp_scale
=
np
.
float32
(
np
.
random
.
rand
()
+
1
)
...
...
imperative/python/test/unit/random/test_rng.py
浏览文件 @
374fed32
...
...
@@ -26,12 +26,12 @@ from megengine.core.ops.builtin import (
PoissonRNG
,
UniformRNG
,
)
from
megengine.d
istributed.helper
import
get_device_count_by_fork
from
megengine.d
evice
import
get_device_count
from
megengine.random
import
RNG
,
seed
,
uniform
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_gaussian_op
():
shape
=
(
...
...
@@ -61,7 +61,7 @@ def test_gaussian_op():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_uniform_op
():
shape
=
(
...
...
@@ -89,7 +89,7 @@ def test_uniform_op():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_gamma_op
():
_shape
,
_scale
=
2
,
0.8
...
...
@@ -117,7 +117,7 @@ def test_gamma_op():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_beta_op
():
_alpha
,
_beta
=
2
,
0.8
...
...
@@ -148,7 +148,7 @@ def test_beta_op():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_poisson_op
():
lam
=
F
.
full
([
8
,
9
,
11
,
12
],
value
=
2
,
dtype
=
"float32"
)
...
...
@@ -171,7 +171,7 @@ def test_poisson_op():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
get_device_count
(
"xpu"
)
<=
2
,
reason
=
"xpu counts need > 2"
,
)
def
test_permutation_op
():
n
=
1000
...
...
@@ -205,7 +205,7 @@ def test_permutation_op():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
get_device_count
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
)
def
test_UniformRNG
():
m1
=
RNG
(
seed
=
111
,
device
=
"xpu0"
)
...
...
@@ -233,7 +233,7 @@ def test_UniformRNG():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
get_device_count
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
)
def
test_NormalRNG
():
m1
=
RNG
(
seed
=
111
,
device
=
"xpu0"
)
...
...
@@ -262,7 +262,7 @@ def test_NormalRNG():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
get_device_count
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
)
def
test_GammaRNG
():
m1
=
RNG
(
seed
=
111
,
device
=
"xpu0"
)
...
...
@@ -295,7 +295,7 @@ def test_GammaRNG():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
get_device_count
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
)
def
test_BetaRNG
():
m1
=
RNG
(
seed
=
111
,
device
=
"xpu0"
)
...
...
@@ -330,7 +330,7 @@ def test_BetaRNG():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
get_device_count
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
)
def
test_PoissonRNG
():
m1
=
RNG
(
seed
=
111
,
device
=
"xpu0"
)
...
...
@@ -359,7 +359,7 @@ def test_PoissonRNG():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
get_device_count
(
"xpu"
)
<=
1
,
reason
=
"xpu counts need > 1"
,
)
def
test_PermutationRNG
():
m1
=
RNG
(
seed
=
111
,
device
=
"xpu0"
)
...
...
imperative/python/test/unit/utils/test_network_node.py
浏览文件 @
374fed32
...
...
@@ -13,8 +13,7 @@ import megengine.random as rand
from
megengine.core._imperative_rt.core2
import
apply
from
megengine.core._wrap
import
Device
from
megengine.core.ops
import
builtin
from
megengine.device
import
is_cuda_available
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.device
import
get_device_count
,
is_cuda_available
from
megengine.functional.external
import
tensorrt_runtime_opr
from
megengine.jit.tracing
import
trace
from
megengine.tensor
import
Tensor
...
...
@@ -273,7 +272,7 @@ def test_deformable_ps_roi_pooling():
@
pytest
.
mark
.
skipif
(
get_device_count
_by_fork
(
"gpu"
)
>
0
,
get_device_count
(
"gpu"
)
>
0
,
reason
=
"does not support int8 when gpu compute capability less than 6.1"
,
)
def
test_convbias
():
...
...
imperative/python/version_template.py
浏览文件 @
374fed32
...
...
@@ -6,5 +6,5 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
__version__
=
"1.
5
.0.dev"
__version__
=
"1.
6
.0.dev"
src/core/impl/comp_node/cuda/comp_node.cpp
浏览文件 @
374fed32
...
...
@@ -27,8 +27,14 @@ using namespace mgb;
#include <thread>
#include <cuda.h>
#include <cuda_runtime.h>
#ifdef __unix__
#include <unistd.h>
#include <sys/wait.h>
#endif
using
CudaCompNodeImpl
=
CudaCompNode
::
CompNodeImpl
;
namespace
{
...
...
@@ -700,19 +706,90 @@ void CudaCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) {
/* ===================== CudaCompNode static methods ===================== */
namespace
{
#ifndef __unix__
/*!
 * \brief query the number of CUDA devices (variant used when __unix__ is
 *      not defined)
 *
 * This variant initializes the driver API in the calling process and then
 * queries the device count.
 *
 * \param[out] pcnt receives the device count on success
 * \return the error from cuInit() if initialization fails, otherwise the
 *      result of cuDeviceGetCount()
 */
CUresult get_device_count_forksafe(int* pcnt) {
    // report the real initialization failure instead of letting
    // cuDeviceGetCount() replace it with a less specific error; this matches
    // what the fork-based variant returns
    auto err = cuInit(0);
    if (err != CUDA_SUCCESS)
        return err;
    return cuDeviceGetCount(pcnt);
}
#else
/*!
 * \brief owns a file descriptor and closes it on destruction
 *
 * The descriptor can also be released early with close(); after that the
 * destructor is a no-op. -1 marks "no descriptor held".
 */
struct RAIICloseFD : NonCopyableObj {
    int m_fd = -1;

    //! take ownership of \p fd; explicit so a plain int is never silently
    //! converted into an owning guard
    explicit RAIICloseFD(int fd) : m_fd(fd) {}
    ~RAIICloseFD() { close(); }

    //! close the descriptor if one is still held; safe to call repeatedly
    void close() {
        if (m_fd != -1) {
            ::close(m_fd);
            m_fd = -1;
        }
    }
};
/*!
 * \brief query the number of CUDA devices without calling cuInit() in the
 *      calling process
 *
 * If the driver API is already initialized, cuDeviceGetCount() is answered
 * directly. Otherwise a child process is forked, cuInit() and
 * cuDeviceGetCount() run there, and the error code plus (on success) the
 * count are sent back through a pipe.
 *
 * \param[out] pcnt receives the device count on success
 * \return CUDA_SUCCESS on success; otherwise the error observed in the child
 *      or by the retried cuDeviceGetCount() in this process
 */
CUresult get_device_count_forksafe(int* pcnt) {
    auto err = cuDeviceGetCount(pcnt);
    if (err != CUDA_ERROR_NOT_INITIALIZED)
        return err;

    // cuInit not called, call it in child process
    int fd[2];
    mgb_assert(pipe(fd) == 0, "pipe() failed");
    int fdr = fd[0], fdw = fd[1];
    RAIICloseFD fdr_guard(fdr);
    RAIICloseFD fdw_guard(fdw);

    auto cpid = fork();
    mgb_assert(cpid != -1, "fork() failed");

    if (cpid == 0) {
        // child: only writes, so drop the read end
        fdr_guard.close();
        do {
            err = cuInit(0);
            if (err != CUDA_SUCCESS)
                break;
            err = cuDeviceGetCount(pcnt);
        } while (0);

        // protocol: error code first, then the count only on success
        auto sz = write(fdw, &err, sizeof(err));
        if (sz == sizeof(err) && err == CUDA_SUCCESS) {
            sz = write(fdw, pcnt, sizeof(*pcnt));
        }
        fdw_guard.close();
        std::quick_exit(0);
    }

    // parent: close the write end so read() sees EOF if the child dies
    // without writing
    fdw_guard.close();

    auto sz_err = read(fdr, &err, sizeof(err));
    decltype(sz_err) sz_cnt = -1;
    if (sz_err == sizeof(err) && err == CUDA_SUCCESS) {
        sz_cnt = read(fdr, pcnt, sizeof(*pcnt));
    }

    // reap the child before any assertion below can throw, otherwise every
    // call would leave a zombie process behind; the exit status is not
    // needed because the result already travelled through the pipe
    waitpid(cpid, nullptr, 0);

    mgb_assert(sz_err == sizeof(err), "failed to read error code from child");
    if (err == CUDA_SUCCESS) {
        mgb_assert(
                sz_cnt == sizeof(*pcnt),
                "failed to read device count from child");
        return err;
    }

    // try again, maybe another thread called cuInit while we fork
    auto err2 = cuDeviceGetCount(pcnt);
    if (err2 == CUDA_SUCCESS)
        return err2;
    // still uninitialized here: the child's error is the more informative one
    if (err2 == CUDA_ERROR_NOT_INITIALIZED)
        return err;
    return err2;
}
#endif
/*!
 * \brief description string for a CUDA driver API error code
 * \return the string provided by cuGetErrorString(), or
 *      "unknown cuda error" if it leaves the output pointer null
 */
const char* cu_get_error_string(CUresult err) {
    const char* msg = nullptr;
    cuGetErrorString(err, &msg);
    return msg ? msg : "unknown cuda error";
}
}
// namespace
bool
CudaCompNode
::
available
()
{
static
int
result
=
-
1
;
static
Spinlock
mtx
;
MGB_LOCK_GUARD
(
mtx
);
if
(
result
==
-
1
)
{
int
ndev
=
-
1
;
auto
err
=
cudaGetDeviceCount
(
&
ndev
);
result
=
err
==
cudaSuccess
&&
ndev
>
0
;
auto
err
=
get_device_count_forksafe
(
&
ndev
);
result
=
err
==
CUDA_SUCCESS
&&
ndev
>
0
;
if
(
!
result
)
{
mgb_log_warn
(
"cuda unavailable: %s(%d) ndev=%d"
,
cu
daGetErrorS
tring
(
err
),
static_cast
<
int
>
(
err
),
ndev
);
cu
_get_error_s
tring
(
err
),
static_cast
<
int
>
(
err
),
ndev
);
}
if
(
err
==
cudaErrorInitializationError
)
{
if
(
err
==
CUDA_ERROR_NOT_INITIALIZED
)
{
mgb_throw
(
std
::
runtime_error
,
"cuda initialization error."
);
}
}
...
...
@@ -857,11 +934,11 @@ size_t CudaCompNode::get_device_count(bool warn) {
static
Spinlock
mtx
;
MGB_LOCK_GUARD
(
mtx
);
if
(
cnt
==
-
1
)
{
auto
err
=
cudaGetDeviceCount
(
&
cnt
);
if
(
err
!=
cudaSuccess
)
{
auto
err
=
get_device_count_forksafe
(
&
cnt
);
if
(
err
!=
CUDA_SUCCESS
)
{
if
(
warn
)
mgb_log_error
(
"cudaGetDeviceCount failed: %s (err %d)"
,
cu
daGetErrorS
tring
(
err
),
int
(
err
));
cu
_get_error_s
tring
(
err
),
int
(
err
));
cnt
=
0
;
}
mgb_assert
(
cnt
>=
0
);
...
...
src/core/test/graph/eager_eval.cpp
浏览文件 @
374fed32
...
...
@@ -495,7 +495,7 @@ TEST_F(TestGraphEagerReeval, MemoryAlloc) {
// | +--> 8 -> 8
// |-----> x1(fwd) -> 8 -|
// total usage : 63 + (16 after the first iteration)
// x might has iteration i's memory, but x0/x1 foward i-1's memory
// x might has iteration i's memory, but x0/x1 fo
r
ward i-1's memory
size_t
length
=
reserve
/
(
sizeof
(
dt_int32
)
*
5
*
16
);
auto
host_x
=
std
::
make_shared
<
HostTensorND
>
(
cn
,
dtype
::
Int32
());
HostTensorND
host_val
;
...
...
src/gopt/impl/framework.cpp
浏览文件 @
374fed32
...
...
@@ -772,7 +772,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
add_pass
<
RemoveRedundantTypeCvtPass
>
();
add_pass
(
FuseNCHW4Int8Preprocess
::
make
());
add_pass
<
FuseWarpPerspectiveDimshufflePass
>
();
#if CUDA_VERSION >= 10020
add_pass
<
FoldingConvBiasDimshufflePass
>
();
#endif
});
cb
(
chwn4
,
{
add_pass
<
FuseConvBiasNonlinPass
>
();
...
...
@@ -791,7 +793,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
add_pass
<
RemoveRedundantTypeCvtPass
>
();
add_pass
(
FuseNCHW4Int8Preprocess
::
make
());
add_pass
<
FuseWarpPerspectiveDimshufflePass
>
();
#if CUDA_VERSION >= 10020
add_pass
<
FoldingConvBiasDimshufflePass
>
();
#endif
});
cb
(
fuse_conv_bias_nonlinearity
,
{
add_pass
<
FuseConvBiasNonlinPass
>
();
});
...
...
src/gopt/impl/tensor_reformat.cpp
浏览文件 @
374fed32
...
...
@@ -3638,6 +3638,7 @@ void ShuffleShuffleRemovePass::apply(OptState& opt) const {
MIDOUT_E
}
#if CUDA_VERSION >= 10020
/* ==================== FoldingConvBiasDimshufflePass ================= */
const
char
*
FoldingConvBiasDimshufflePass
::
name
()
const
{
return
mgb_cstr_log
(
"folding conv bias dimshuffle pass"
);
...
...
@@ -4068,20 +4069,17 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
return
true
;
};
MGB_MARK_USED_VAR
(
try_conv_reformat_nchw322nchw4
);
MGB_MARK_USED_VAR
(
try_conv_reformat_nchw42nchw32
);
auto
on_opr
=
[
&
try_conv_dimshuffle_reshape_typecvt
,
&
try_conv_reformat_nchw42nchw32
,
&
try_conv_reformat_nchw42nhwc
,
#if CUDA_VERSION >= 10020
&
try_conv_reformat_nchw322nchw4
,
#endif
&
rewriter
](
OperatorNodeBase
*
opr
)
{
if
(
!
try_conv_dimshuffle_reshape_typecvt
(
opr
)
&&
!
try_conv_reformat_nchw42nchw32
(
opr
)
&&
!
try_conv_reformat_nchw42nhwc
(
opr
)
#if CUDA_VERSION >= 10020
&&
!
try_conv_reformat_nchw322nchw4
(
opr
)
#endif
)
{
rewriter
.
auto_replace_outputs
(
opr
);
}
...
...
@@ -4091,6 +4089,7 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
MIDOUT_E
}
#endif
/* ==================== PaddingChannelPass ================= */
const
char
*
PaddingChannelPass
::
name
()
const
{
...
...
src/gopt/include/megbrain/gopt/inference.h
浏览文件 @
374fed32
...
...
@@ -16,6 +16,10 @@
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#if MGB_CUDA
#include <cuda.h>
#endif
namespace
mgb
{
namespace
gopt
{
...
...
@@ -427,11 +431,13 @@ namespace gopt {
void
apply
(
OptState
&
opt
)
const
override
;
};
#if CUDA_VERSION >= 10020
/*!
 * \brief graph optimization pass whose name() reports
 *      "folding conv bias dimshuffle pass"
 *
 * The declaration sits inside an `#if CUDA_VERSION >= 10020` guard, so the
 * pass does not exist in builds where that condition is false.
 * NOTE(review): judging by the helper names used in apply(), this presumably
 * folds dimshuffle / reformat / typecvt operators into a preceding conv-bias
 * -- confirm against the implementation.
 */
class FoldingConvBiasDimshufflePass final : public Pass {
public:
    //! name under which the pass is reported
    const char* name() const override;
    //! run the pass on the given optimizer state
    void apply(OptState& opt) const override;
};
#endif
/*!
* \brief padding channel to enable fast int8/int4 support
...
...
src/gopt/test/inference.cpp
浏览文件 @
374fed32
...
...
@@ -4155,6 +4155,7 @@ TEST(TestGoptInference, WarpAndPreProcessCase1) {
MGB_ASSERT_TENSOR_NEAR
(
host_y
,
host_y_opt
,
1e-5
);
}
#if CUDA_VERSION >= 10020
TEST
(
TestGoptInference
,
FoldingConvDimshuffle
)
{
REQUIRE_GPU
(
1
);
auto
cn
=
CompNode
::
load
(
"gpu0"
);
...
...
@@ -4307,7 +4308,6 @@ TEST(TestGoptInference, FoldingConvDimshuffleNCHW4NCHW32) {
MGB_ASSERT_TENSOR_EQ
(
host_y_fuse
,
host_y_non_fuse
);
}
#if CUDA_VERSION >= 10020
TEST
(
TestGoptInference
,
FoldingConvDimshuffleNCHW32NCHW4
)
{
REQUIRE_GPU
(
1
);
auto
cn
=
CompNode
::
load
(
"gpu0"
);
...
...
MegRay
@
5f0e18e7
比较
eb8365c1
...
5f0e18e7
Subproject commit
eb8365c1015624348dbbb0a3ed97eecb40643e9e
Subproject commit
5f0e18e73ced4689f82f1e3b90924bbcc35bd56a
cutlass
@
baee1355
比较
eafd7f8d
...
baee1355
Subproject commit
eafd7f8d33d114ef2569a1fc2b851b2325edee46
Subproject commit
baee1355d997b7bae8f93d25ba4bf41a61030599