From cad8568c34727318e61a0895684f5c92ab5732bc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 31 Mar 2021 12:01:56 +0800 Subject: [PATCH] fix(mge/optimizer): fix optimizer's state_dict bug GitOrigin-RevId: 67fb112fb8e4d7b295a6f4a2a2c8254002c97bbc --- imperative/python/megengine/optimizer/optimizer.py | 6 ++++-- imperative/python/test/integration/test_optimizer.py | 4 ++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py index 1196594f5..cea3e49dc 100644 --- a/imperative/python/megengine/optimizer/optimizer.py +++ b/imperative/python/megengine/optimizer/optimizer.py @@ -6,6 +6,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import copy from abc import ABCMeta, abstractmethod from collections.abc import Iterable from typing import Dict @@ -197,10 +198,11 @@ class Optimizer(metaclass=ABCMeta): cur_id += 1 for param, st in self._state.items(): + _st = copy.copy(st) if not keep_var: for k, v in st.items(): - st[k] = v.numpy() - state[param2id[param]] = st + _st[k] = v.numpy() + state[param2id[param]] = _st for group in self.param_groups: param_group = {k: v for k, v in group.items() if k != "params"} diff --git a/imperative/python/test/integration/test_optimizer.py b/imperative/python/test/integration/test_optimizer.py index 6210233e0..fd51d567b 100644 --- a/imperative/python/test/integration/test_optimizer.py +++ b/imperative/python/test/integration/test_optimizer.py @@ -104,6 +104,10 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): ) step += 1 check_func(ori_params, net.parameters(), step) + try_state_dict = { + "net": net.state_dict(), + "opt": opt.state_dict(), + } def test_sgd(): -- GitLab