提交 7c9569e4 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mge/random): fix random seed

GitOrigin-RevId: 121f459b1b30a086776e9832e15849de9987da0e
上级 07de1571
......@@ -27,7 +27,7 @@ from megengine.core.ops.builtin import (
UniformRNG,
)
from megengine.distributed.helper import get_device_count_by_fork
from megengine.random import RNG
from megengine.random import RNG, seed, uniform
@pytest.mark.skipif(
......@@ -387,3 +387,18 @@ def test_PermutationRNG():
assert sum_result(out, lambda x: x) < 500
assert sum_result(out, np.sort) == 1000
def test_seed():
seed(10)
out1 = uniform(size=[10, 10])
out2 = uniform(size=[10, 10])
assert not (out1.numpy() == out2.numpy()).all()
seed(10)
out3 = uniform(size=[10, 10])
np.testing.assert_equal(out1.numpy(), out3.numpy())
seed(11)
out4 = uniform(size=[10, 10])
assert not (out1.numpy() == out4.numpy()).all()
......@@ -127,10 +127,8 @@ public:
auto&& glob_handle = glob_default_handles[comp_node];
if (!glob_handle) {
glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
} else if (get_seed(glob_handle) != glob_default_seed) {
inst().DnnOpManagerBase::delete_handle(glob_handle);
glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
}
mgb_assert(get_seed(glob_handle) == glob_default_seed);
return glob_handle;
}
......@@ -141,6 +139,13 @@ public:
static void set_glob_default_seed(uint64_t seed) {
MGB_LOCK_GUARD(sm_mtx);
for(auto && elem : glob_default_handles){
mgb_assert(elem.first.valid());
if(elem.second){
inst().DnnOpManagerBase::delete_handle(elem.second);
}
elem.second = inst().do_new_handle(elem.first, seed);
}
glob_default_seed = seed;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册