From 4917534b65f4f1bd3aa1c3fef49d08468837a222 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 16 Mar 2021 13:38:41 +0800
Subject: [PATCH] feat(mge/optimizer): save state's numpy value by default in `state_dict`

GitOrigin-RevId: ec7e4d56f54f724c039b462906583bf025d060e6
---
 imperative/python/megengine/optimizer/optimizer.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py
index 49c39ae76..1196594f5 100644
--- a/imperative/python/megengine/optimizer/optimizer.py
+++ b/imperative/python/megengine/optimizer/optimizer.py
@@ -8,7 +8,6 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from abc import ABCMeta, abstractmethod
 from collections.abc import Iterable
-from contextlib import contextmanager
 from typing import Dict
 from typing import Iterable as Iter
 from typing import Union
@@ -180,7 +179,7 @@ class Optimizer(metaclass=ABCMeta):
             param.grad = None
         pop_scope("clear_grad")
 
-    def state_dict(self) -> Dict:
+    def state_dict(self, keep_var=False) -> Dict:
         r"""
         Export the optimizer state.
 
@@ -198,6 +197,9 @@ class Optimizer(metaclass=ABCMeta):
                     cur_id += 1
 
         for param, st in self._state.items():
+            if not keep_var:
+                for k, v in st.items():
+                    st[k] = v.numpy()
             state[param2id[param]] = st
 
         for group in self.param_groups:
@@ -218,7 +220,6 @@ class Optimizer(metaclass=ABCMeta):
             raise ValueError(
                 "loaded state dict has a different number of parameter groups"
             )
-        parameter_map = dict()  # type: Dict
         for group_new, group_saved in zip(self.param_groups, state["param_groups"]):
             if len(group_new["params"]) != len(group_saved["params"]):
                 raise ValueError(
@@ -232,8 +233,9 @@ class Optimizer(metaclass=ABCMeta):
                 self._state[p] = state["state"][param_saved].copy()
                 for k, v in self._state[p].items():
                     if isinstance(v, Tensor):
-                        # TODO: maybe a more efficient way?
-                        self._state[p][k] = Tensor(v.numpy())
+                        self._state[p][k] = v.detach()
+                    else:
+                        self._state[p][k] = Tensor(v)
 
                 if set(group_new.keys()) != set(group_saved.keys()):
                     raise ValueError(
--
GitLab
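
A minimal usage sketch (not part of the patch), assuming a working MegEngine install: with this change, `state_dict()` converts optimizer state tensors to numpy arrays by default, so the exported dict pickles as plain data, while `keep_var=True` keeps the previous Tensor-valued form; `load_state_dict()` accepts either. The toy `Linear` model, the hyperparameters, and the file name `opt.pkl` below are hypothetical, chosen only to exercise the new flag.

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine.autodiff import GradManager

net = M.Linear(4, 2)
opt = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
gm = GradManager().attach(net.parameters())

# Run one step so the optimizer actually holds momentum state worth exporting.
with gm:
    loss = F.mean(net(mge.tensor([[1.0, 2.0, 3.0, 4.0]])) ** 2)
    gm.backward(loss)
opt.step()
opt.clear_grad()

mge.save(opt.state_dict(), "opt.pkl")          # numpy-valued state (new default)
tensor_state = opt.state_dict(keep_var=True)   # Tensor-valued state, as before

# load_state_dict accepts either form: Tensors are detached, numpy arrays are
# wrapped back into Tensors by the else-branch added in this patch.
opt.load_state_dict(mge.load("opt.pkl"))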