diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py index a13eb877a435e57a768f479b57d99923dddf8e24..2e5bd26b30f016adb6f383c561e0c4040ddcae73 100644 --- a/imperative/python/test/unit/random/test_rng.py +++ b/imperative/python/test/unit/random/test_rng.py @@ -27,7 +27,7 @@ from megengine.core.ops.builtin import ( UniformRNG, ) from megengine.distributed.helper import get_device_count_by_fork -from megengine.random import RNG +from megengine.random import RNG, seed, uniform @pytest.mark.skipif( @@ -387,3 +387,18 @@ def test_PermutationRNG(): assert sum_result(out, lambda x: x) < 500 assert sum_result(out, np.sort) == 1000 + + +def test_seed(): + seed(10) + out1 = uniform(size=[10, 10]) + out2 = uniform(size=[10, 10]) + assert not (out1.numpy() == out2.numpy()).all() + + seed(10) + out3 = uniform(size=[10, 10]) + np.testing.assert_equal(out1.numpy(), out3.numpy()) + + seed(11) + out4 = uniform(size=[10, 10]) + assert not (out1.numpy() == out4.numpy()).all() diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index b913650e04436eb5f216640364b158e7d7374480..b5d147fc455802c0d965de28d65a848754fdd3ac 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -127,10 +127,8 @@ public: auto&& glob_handle = glob_default_handles[comp_node]; if (!glob_handle) { glob_handle = inst().do_new_handle(comp_node, glob_default_seed); - } else if (get_seed(glob_handle) != glob_default_seed) { - inst().DnnOpManagerBase::delete_handle(glob_handle); - glob_handle = inst().do_new_handle(comp_node, glob_default_seed); } + mgb_assert(get_seed(glob_handle) == glob_default_seed); return glob_handle; } @@ -141,6 +139,13 @@ public: static void set_glob_default_seed(uint64_t seed) { MGB_LOCK_GUARD(sm_mtx); + for(auto && elem : glob_default_handles){ + mgb_assert(elem.first.valid()); + if(elem.second){ + inst().DnnOpManagerBase::delete_handle(elem.second); + } + elem.second = inst().do_new_handle(elem.first, seed); + } glob_default_seed = seed; }