diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 9fa969b9da9443f1f11a42e005689ea733af6f9f..aab4e87dd89f974c03bd2802047b00e8d7fa9d02 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -21,9 +21,9 @@ from ..utils.naming import auto_naming logger = get_logger(__name__) -def _expand_structure(key, obj): +def _expand_structure(prefix, obj): if isinstance(obj, (Tensor, Module)): - return [(key, obj)] + return [(prefix, obj)] elif isinstance(obj, (list, tuple, dict)): ret = [] if isinstance(obj, dict): @@ -37,12 +37,32 @@ def _expand_structure(key, obj): "keys for Tensor and Module must be str, error key: {}".format(k) ) for kt, vt in sub_ret: - ret.extend([(key + "." + kt, vt)]) + ret.extend([(prefix + "." + kt, vt)]) return ret else: return [] +def _access_structure(obj, key, callback=None): + key_list = key.split(".") + cur = obj + parent = None + for k in key_list: + parent = cur + if isinstance(cur, (Tensor, Module)): + cur = getattr(cur, k) + elif isinstance(cur, (list, tuple)): + k = int(k) + cur = cur[k] + elif isinstance(cur, dict): + cur = cur[k] + else: + raise ValueError( + "Unsupport value type {} to access attribute".format(type(cur)) + ) + return callback(parent, k, cur) + + def _is_parameter(obj): return isinstance(obj, Parameter) diff --git a/imperative/python/megengine/module/sequential.py b/imperative/python/megengine/module/sequential.py index e484110c3e66bc7e4a3a5d7aee25cee9fa3ed3fd..b4dbdafd2958cc533693980ccd4e4a87498f2bea 100644 --- a/imperative/python/megengine/module/sequential.py +++ b/imperative/python/megengine/module/sequential.py @@ -18,9 +18,9 @@ class Sequential(Module): Alternatively, an ordered dict of modules can also be passed in. To make it easier to understand, here is a small example: - + Examples: - + .. testcode:: import numpy as np diff --git a/imperative/python/megengine/quantization/quantize.py b/imperative/python/megengine/quantization/quantize.py index 09fa9012abc3ea298f8902b6ddb5d90cc344fa31..7d39c1d27c40ec63be5bd62536a13c5ed34c4a69 100644 --- a/imperative/python/megengine/quantization/quantize.py +++ b/imperative/python/megengine/quantization/quantize.py @@ -7,7 +7,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from copy import copy, deepcopy from functools import partial -from typing import Callable, Dict, Tuple +from typing import Callable import numpy as np @@ -19,6 +19,7 @@ from ..module import quantized as Quantized from ..module.qat import QATModule from ..module.quantized import QuantizedModule from ..tensor import Tensor +from ..utils.module_utils import set_expand_structure from .qconfig import QConfig, ema_fakequant_qconfig @@ -79,11 +80,7 @@ def quantize(module: Module, inplace: bool = True, mapping: dict = None): module._flatten(with_key=True, with_parent=True, predicate=is_qat) ): new_mod = convert_dict[type(submodule)].from_qat_module(submodule) - if isinstance(parent, Float.Sequential): - # cannnot use setattr to be compatible with Sequential's ``__setitem__`` - parent[int(key.split(".")[-1])] = new_mod - else: - setattr(parent, key.split(".")[-1], new_mod) + set_expand_structure(parent, key, new_mod) return module @@ -126,11 +123,7 @@ def quantize_qat( continue new_mod = convert_dict[type(submodule)].from_float_module(submodule) - if isinstance(parent, Float.Sequential): - # cannnot use setattr to be compatible with Sequential's ``__setitem__`` - parent[int(key.split(".")[-1])] = new_mod - else: - setattr(parent, key.split(".")[-1], new_mod) + set_expand_structure(parent, key, new_mod) propagate_qconfig(module, qconfig) return module diff --git a/imperative/python/megengine/utils/module_utils.py b/imperative/python/megengine/utils/module_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c66eb6060c61edfc59ac0cddd60676c757168ee7 --- /dev/null +++ b/imperative/python/megengine/utils/module_utils.py @@ -0,0 +1,43 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from collections import Iterable + +from ..module import Sequential +from ..module.module import Module, _access_structure +from ..tensor import Tensor + + +def get_expand_structure(obj: Module, key: str): + """ + Gets Module's attribute compatible with complex key from Module's :meth:`~.named_children`. + Supports handling structure containing list or dict. + """ + + def f(_, __, cur): + return cur + + return _access_structure(obj, key, callback=f) + + +def set_expand_structure(obj: Module, key: str, value): + """ + Sets Module's attribute compatible with complex key from Module's :meth:`~.named_children`. + Supports handling structure containing list or dict. + """ + + def f(parent, key, cur): + if isinstance(parent, (Tensor, Module)): + # cannnot use setattr to be compatible with Sequential's ``__setitem__`` + if isinstance(cur, Sequential): + parent[int(key)] = value + else: + setattr(parent, key, value) + else: + parent[key] = value + + _access_structure(obj, key, callback=f) diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index b28219be81664af95d9702a6d28bf4debd868904..05540312ef0a016612a06405c47122e4c30e28bf 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -6,8 +6,6 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import os -import tempfile from collections import OrderedDict from io import BytesIO @@ -29,7 +27,9 @@ from megengine.module import ( Sequential, Softmax, ) +from megengine.module.module import _access_structure from megengine.quantization.quantize import quantize, quantize_qat +from megengine.utils.module_utils import get_expand_structure, set_expand_structure class MLP(Module): @@ -45,146 +45,6 @@ class MLP(Module): return x -def has_gpu(num=1): - try: - mgb.comp_node("gpu{}".format(num - 1)) - except mgb.MegBrainError: - return False - - return True - - -def randomNp(*args): - for arg in args: - assert isinstance(arg, int) - return np.random.random(args) - - -def randomTorch(*args): - import torch # pylint: disable=import-outside-toplevel - - for arg in args: - assert isinstance(arg, int) - return torch.tensor(randomNp(*args), dtype=torch.float32) - - -def graph_mode(*modes): - if not set(modes).issubset({"eager", "static"}): - raise ValueError("graph mode must be in (eager, static)") - - def decorator(func): - def wrapper(*args, **kwargs): - if "eager" in set(modes): - func(*args, **kwargs) - if "static" in set(modes): - with Graph() as cg: - cg.set_option("eager_evaluation", False) - func(*args, **kwargs) - - return wrapper - - return decorator - - -def _default_compare_fn(x, y): - np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) - - -def opr_test( - cases, - func, - mode=("eager", "static", "dynamic_shape"), - compare_fn=_default_compare_fn, - ref_fn=None, - **kwargs -): - """ - mode: the list of test mode which are eager, static and dynamic_shape - will test all the cases if None. - func: the function to run opr. - compare_fn: the function to compare the result and expected, use np.testing.assert_allclose if None. - ref_fn: the function to generate expected data, should assign output if None. - cases: the list which have dict element, the list length should be 2 for dynamic shape test. - and the dict should have input, - and should have output if ref_fn is None. - should use list for multiple inputs and outputs for each case. - kwargs: The additional kwargs for opr func. - - simple examples: - - dtype = np.float32 - cases = [{"input": [10, 20]}, {"input": [20, 30]}] - opr_test(cases, - F.eye, - ref_fn=lambda n, m: np.eye(n, m).astype(dtype), - dtype=dtype) - - """ - - def check_results(results, expected): - if not isinstance(results, Tuple): - results = (results,) - for r, e in zip(results, expected): - compare_fn(r, e) - - def get_trace_fn(func, enabled, symbolic): - jit.trace.enabled = enabled - return jit.trace(func, symbolic=symbolic) - - def get_param(cases, idx): - case = cases[idx] - inp = case.get("input", None) - outp = case.get("output", None) - if inp is None: - raise ValueError("the test case should have input") - if not isinstance(inp, List): - inp = (inp,) - else: - inp = tuple(inp) - if ref_fn is not None and callable(ref_fn): - outp = ref_fn(*inp) - if outp is None: - raise ValueError("the test case should have output or reference function") - if not isinstance(outp, List): - outp = (outp,) - else: - outp = tuple(outp) - - return inp, outp - - if not set(mode).issubset({"eager", "static", "dynamic_shape"}): - raise ValueError("opr test mode must be in (eager, static, dynamic_shape)") - - if len(cases) == 0: - raise ValueError("should give one case at least") - - if "dynamic_shape" in set(mode): - if len(cases) != 2: - raise ValueError("should give 2 cases for dynamic shape test") - - if not callable(func): - raise ValueError("the input func should be callable") - - inp, outp = get_param(cases, 0) - - def run(*args, **kwargs): - return func(*args, **kwargs) - - if "eager" in set(mode): - f = get_trace_fn(run, False, False) - results = f(*inp, **kwargs) - check_results(results, outp) - - if "static" in set(mode) or "dynamic_shape" in set(mode): - f = get_trace_fn(run, True, True) - results = f(*inp, **kwargs) - check_results(results, outp) - if "dynamic_shape" in set(mode): - inp, outp = get_param(cases, 1) - results = f(*inp, **kwargs) - check_results(results, outp) - - class MyModule(Module): class InnerModule(Module): def __init__(self): @@ -306,13 +166,13 @@ def test_module_api_hooks(): post_hook_num = 0 hooks = [] - def pre_hook(module, inputs): + def pre_hook(_, inputs): nonlocal pre_hook_num pre_hook_num += 1 modified_inputs = tuple(inp + 1 for inp in inputs) return modified_inputs - def post_hook(module, inputs, outputs): + def post_hook(_, __, outputs): nonlocal post_hook_num post_hook_num += 1 outputs += 1 @@ -376,7 +236,7 @@ class MyModule2(Module): def test_expand_structure(): m = MyModule2() - assert list(m.named_modules()) == [ + rst = [ ("", m), ("a.0", m.a[0]), ("a.1.x", m.a[1]["x"]), @@ -387,6 +247,16 @@ def test_expand_structure(): ("a.2.0.bn", m.a[2][0].bn), ("bn", m.bn), ] + assert list(m.named_modules()) == rst + + for item in rst[1:]: + assert get_expand_structure(m, item[0]) == item[1] + + for item in reversed(rst[1:]): + if _access_structure(m, item[0], lambda p, k, o: isinstance(p, tuple)): + continue + set_expand_structure(m, item[0], "TEST_VALUE") + assert get_expand_structure(m, item[0]) == "TEST_VALUE" def test_flatten_others(): @@ -603,21 +473,6 @@ def test_pickle_module(): np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6) -@pytest.mark.skip(reason="under development") -def test_dump_model(): - data_shape = (2, 28) - data = Tensor(np.random.random(data_shape)) - mlp = MLP() - pred = mlp(data) - f = tempfile.NamedTemporaryFile(delete=False) - f_name = f.name - try: - mge.dump(pred, f_name) - finally: - f.close() - os.unlink(f_name) - - def test_load_quantized(): from megengine.core.tensor import dtype