提交 cad8568c 编写于 作者: M Megvii Engine Team

fix(mge/optimizer): fix optimizer's state_dict bug

GitOrigin-RevId: 67fb112fb8e4d7b295a6f4a2a2c8254002c97bbc
上级 0ed36998
......@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable
from typing import Dict
......@@ -197,10 +198,11 @@ class Optimizer(metaclass=ABCMeta):
cur_id += 1
for param, st in self._state.items():
_st = copy.copy(st)
if not keep_var:
for k, v in st.items():
st[k] = v.numpy()
state[param2id[param]] = st
_st[k] = v.numpy()
state[param2id[param]] = _st
for group in self.param_groups:
param_group = {k: v for k, v in group.items() if k != "params"}
......
......@@ -104,6 +104,10 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
)
step += 1
check_func(ori_params, net.parameters(), step)
try_state_dict = {
"net": net.state_dict(),
"opt": opt.state_dict(),
}
def test_sgd():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册