diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py
index 49c39ae76b461c98fb6b26c8d51488435245ce34..1196594f5ff4e976e1beb18a895eeb4be84f263f 100644
--- a/imperative/python/megengine/optimizer/optimizer.py
+++ b/imperative/python/megengine/optimizer/optimizer.py
@@ -8,7 +8,6 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from abc import ABCMeta, abstractmethod
 from collections.abc import Iterable
-from contextlib import contextmanager
 from typing import Dict
 from typing import Iterable as Iter
 from typing import Union
@@ -180,7 +179,7 @@ class Optimizer(metaclass=ABCMeta):
             param.grad = None
         pop_scope("clear_grad")

-    def state_dict(self) -> Dict:
+    def state_dict(self, keep_var=False) -> Dict:
         r"""
         Export the optimizer state.

@@ -198,6 +197,9 @@ class Optimizer(metaclass=ABCMeta):
                     cur_id += 1

         for param, st in self._state.items():
+            if not keep_var:
+                for k, v in st.items():
+                    st[k] = v.numpy()
             state[param2id[param]] = st

         for group in self.param_groups:
@@ -218,7 +220,6 @@ class Optimizer(metaclass=ABCMeta):
             raise ValueError(
                 "loaded state dict has a different number of parameter groups"
             )
-        parameter_map = dict()  # type: Dict
         for group_new, group_saved in zip(self.param_groups, state["param_groups"]):
             if len(group_new["params"]) != len(group_saved["params"]):
                 raise ValueError(
@@ -232,8 +233,9 @@ class Optimizer(metaclass=ABCMeta):
                 self._state[p] = state["state"][param_saved].copy()
                 for k, v in self._state[p].items():
                     if isinstance(v, Tensor):
-                        # TODO: maybe a more efficient way?
-                        self._state[p][k] = Tensor(v.numpy())
+                        self._state[p][k] = v.detach()
+                    else:
+                        self._state[p][k] = Tensor(v)

             if set(group_new.keys()) != set(group_saved.keys()):
                 raise ValueError(
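
For context, a minimal usage sketch of the save/load round-trip this patch changes. The model, optimizer, and file path below are illustrative, not part of the patch; `megengine.save`/`megengine.load` are the library's standard serialization helpers:

import megengine as mge
import megengine.module as M
from megengine.optimizer import SGD

net = M.Linear(4, 2)
opt = SGD(net.parameters(), lr=0.1, momentum=0.9)

# With the new default keep_var=False, per-parameter state (e.g. the
# momentum buffers) is converted to numpy arrays before export, so the
# checkpoint holds no live device Tensors.
ckpt = opt.state_dict()
mge.save(ckpt, "opt.pkl")

# load_state_dict now wraps numpy values back into Tensors, and detaches
# values that are already Tensors (keep_var=True exports) instead of
# round-tripping them through numpy as the removed TODO branch did.
opt.load_state_dict(mge.load("opt.pkl"))

Exporting numpy arrays by default presumably keeps checkpoints free of device-backed Tensor objects, while keep_var=True preserves the old Tensor-carrying format for callers that want it.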