Commit 4917534b authored by Megvii Engine Team

feat(mge/optimizer): save state's numpy value by default in `state_dict`

GitOrigin-RevId: ec7e4d56f54f724c039b462906583bf025d060e6
Parent 84f990a0
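With this change, `Optimizer.state_dict()` converts the per-parameter state (e.g. SGD's momentum buffer) to plain `numpy.ndarray` values by default, so a checkpoint can be serialized without keeping device tensors alive; passing `keep_var=True` preserves the previous behaviour of returning `Tensor` objects. A minimal usage sketch, assuming the MegEngine 1.x API (`GradManager`, `optim.SGD`); the toy parameter, loss, and file name are illustrative:

import numpy as np
import megengine as mge
import megengine.optimizer as optim
from megengine.autodiff import GradManager

# One trainable parameter and an SGD optimizer with momentum, so the
# optimizer actually carries per-parameter state worth checkpointing.
w = mge.Parameter(np.zeros((4,), dtype="float32"))
gm = GradManager()
gm.attach([w])
opt = optim.SGD([w], lr=0.1, momentum=0.9)

with gm:                       # record the forward pass
    loss = (w * 2.0 + 1.0).sum()
    gm.backward(loss)          # populate w.grad
opt.step()                     # update w, filling the momentum state
opt.clear_grad()

ckpt = opt.state_dict()        # state values are numpy arrays by default now;
                               # state_dict(keep_var=True) keeps Tensor objects
mge.save(ckpt, "sgd_state.pkl")
opt.load_state_dict(mge.load("sgd_state.pkl"))   # numpy values become Tensors again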
@@ -8,7 +8,6 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from abc import ABCMeta, abstractmethod
 from collections.abc import Iterable
-from contextlib import contextmanager
 from typing import Dict
 from typing import Iterable as Iter
 from typing import Union
@@ -180,7 +179,7 @@ class Optimizer(metaclass=ABCMeta):
                 param.grad = None
         pop_scope("clear_grad")

-    def state_dict(self) -> Dict:
+    def state_dict(self, keep_var=False) -> Dict:
         r"""
         Export the optimizer state.
@@ -198,6 +197,9 @@
                 cur_id += 1

         for param, st in self._state.items():
+            if not keep_var:
+                for k, v in st.items():
+                    st[k] = v.numpy()
             state[param2id[param]] = st

         for group in self.param_groups:
@@ -218,7 +220,6 @@
             raise ValueError(
                 "loaded state dict has a different number of parameter groups"
             )
-        parameter_map = dict()  # type: Dict
         for group_new, group_saved in zip(self.param_groups, state["param_groups"]):
             if len(group_new["params"]) != len(group_saved["params"]):
                 raise ValueError(
@@ -232,8 +233,9 @@
                 self._state[p] = state["state"][param_saved].copy()
                 for k, v in self._state[p].items():
-                    # TODO: maybe a more efficient way?
-                    self._state[p][k] = Tensor(v.numpy())
+                    if isinstance(v, Tensor):
+                        self._state[p][k] = v.detach()
+                    else:
+                        self._state[p][k] = Tensor(v)

             if set(group_new.keys()) != set(group_saved.keys()):
                 raise ValueError(
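On the loading side, `load_state_dict` now has to cope with state saved either way: a `Tensor` (from `keep_var=True`) is detached, while anything else, such as the numpy arrays produced by the new default, is wrapped back into a `Tensor`. A minimal sketch of that conversion rule in isolation; `restore_state_value` is a hypothetical helper, not part of the MegEngine API:

from megengine import Tensor

def restore_state_value(v):
    # Hypothetical helper mirroring the branch added to load_state_dict:
    # detach Tensors so the loaded state carries no autograd history,
    # and wrap numpy arrays (or scalars) back into Tensors.
    if isinstance(v, Tensor):
        return v.detach()
    return Tensor(v)

Keeping a `detach()` path also avoids the `Tensor(v.numpy())` round trip the old code used, which copied every state value through host memory even when it was already a `Tensor`.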