From 7c9569e4e5a2ba16fdf91f50606b525a00b14f4e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 9 Jul 2021 14:33:09 +0800 Subject: [PATCH] fix(mge/random): fix random seed GitOrigin-RevId: 121f459b1b30a086776e9832e15849de9987da0e --- imperative/python/test/unit/random/test_rng.py | 17 ++++++++++++++++- imperative/src/impl/ops/rng.cpp | 11 ++++++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py index a13eb877a..2e5bd26b3 100644 --- a/imperative/python/test/unit/random/test_rng.py +++ b/imperative/python/test/unit/random/test_rng.py @@ -27,7 +27,7 @@ from megengine.core.ops.builtin import ( UniformRNG, ) from megengine.distributed.helper import get_device_count_by_fork -from megengine.random import RNG +from megengine.random import RNG, seed, uniform @pytest.mark.skipif( @@ -387,3 +387,18 @@ def test_PermutationRNG(): assert sum_result(out, lambda x: x) < 500 assert sum_result(out, np.sort) == 1000 + + +def test_seed(): + seed(10) + out1 = uniform(size=[10, 10]) + out2 = uniform(size=[10, 10]) + assert not (out1.numpy() == out2.numpy()).all() + + seed(10) + out3 = uniform(size=[10, 10]) + np.testing.assert_equal(out1.numpy(), out3.numpy()) + + seed(11) + out4 = uniform(size=[10, 10]) + assert not (out1.numpy() == out4.numpy()).all() diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index b913650e0..b5d147fc4 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -127,10 +127,8 @@ public: auto&& glob_handle = glob_default_handles[comp_node]; if (!glob_handle) { glob_handle = inst().do_new_handle(comp_node, glob_default_seed); - } else if (get_seed(glob_handle) != glob_default_seed) { - inst().DnnOpManagerBase::delete_handle(glob_handle); - glob_handle = inst().do_new_handle(comp_node, glob_default_seed); } + mgb_assert(get_seed(glob_handle) == glob_default_seed); return glob_handle; } @@ -141,6 +139,13 @@ public: static void set_glob_default_seed(uint64_t seed) { MGB_LOCK_GUARD(sm_mtx); + for(auto && elem : glob_default_handles){ + mgb_assert(elem.first.valid()); + if(elem.second){ + inst().DnnOpManagerBase::delete_handle(elem.second); + } + elem.second = inst().do_new_handle(elem.first, seed); + } glob_default_seed = seed; } -- GitLab