diff --git a/imperative/python/megengine/jit/xla_backend.py b/imperative/python/megengine/jit/xla_backend.py
index db1f27425c042f034bce8ec0dccc3348417d09cc..a41427eaa0c35acb8fd39cd8752eedeb5e94b2dd 100644
--- a/imperative/python/megengine/jit/xla_backend.py
+++ b/imperative/python/megengine/jit/xla_backend.py
@@ -1,6 +1,8 @@
 from collections import OrderedDict, defaultdict
 
-from .. import tensor
+import numpy as np
+
+from .. import _full_sync, tensor
 from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import Tensor as RawTensor
 from ..core._imperative_rt.core2 import (
@@ -15,16 +17,11 @@
 from ..device import get_default_device
 from ..utils.dlpack import from_dlpack, to_dlpack
 from .tracing import trace
 
-# try:
-#     from mge_xlalib.xla_extension import ArrayImpl
-
-#     from ..xla.lib import xla_client as xc
-# except ImportError:
-#     pass
-
-from mge_xlalib.xla_extension import ArrayImpl
-
-from ..xla.lib import xla_client as xc
+try:
+    from mge_xlalib.xla_extension import ArrayImpl
+    from ..xla.lib import xla_client as xc
+except ImportError as e:
+    pass
 
 xla_client_compute_stream = None
@@ -93,6 +90,39 @@ class xla_trace(trace):
     def unset_env(self):
         set_use_xla_backend(self.orig_use_xla)
 
+    def convert_params_to_xla(self):
+        from ..device import coalesce_free_memory
+        from ..utils.module_utils import get_expand_structure
+        from ..tensor import Tensor
+
+        backend = self.xla_exec.backend
+        devices = backend.local_devices()
+        _, device_id, _ = CompNode(get_default_device()).physical_locator
+        device_index = (
+            0 if len(devices) == 0 else [d.id for d in devices].index(device_id)
+        )
+        device = devices[device_index]
+        for attr, _ in self.attr_to_key.items():
+            param = get_expand_structure(attr[0], attr[1])
+            param._reset(param.to("cpux"))
+
+        for tensor, _ in self.opt_param_dict.items():
+            tensor._reset(tensor.to("cpux"))
+
+        def as_xla_array(tensor, backend, device):
+            np_array = tensor.numpy()
+            if np_array.shape == ():
+                np_array = np_array[np.newaxis]
+            xla_array = backend.buffer_from_pyval(np_array, device)
+            tensor._reset(Tensor(xla_array))
+
+        for attr, _ in self.attr_to_key.items():
+            param = get_expand_structure(attr[0], attr[1])
+            as_xla_array(param, backend, device)
+
+        for tensor, _ in self.opt_param_dict.items():
+            as_xla_array(tensor, backend, device)
+
     def compile(self):
         from ..xla import build_xla
         from ..traced_module.pytree import SUPPORTED_LEAF_TYPE, register_supported_type
@@ -102,13 +132,6 @@ class xla_trace(trace):
         from ..distributed import get_mm_server_addr, is_distributed
 
         assert self.traced
-        if self.overall:
-            for attr, _ in self.attr_to_key.items():
-                param = get_expand_structure(attr[0], attr[1])
-                param._reset(param.to("cpux"))
-
-            for tensor, _ in self.opt_param_dict.items():
-                tensor._reset(tensor.to("cpux"))
         self.xla_exec, self.inp_ids, self.out_ids = build_xla(
             self,
             return_with_io=True,
@@ -116,6 +139,8 @@ class xla_trace(trace):
             ip=get_mm_server_addr()[0] if is_distributed() else None,
             port=get_mm_server_addr()[1] + 1 if is_distributed() else None,
         )
+        if self.overall:
+            self.convert_params_to_xla()
         id2inpidx = defaultdict(list)
         id2outidx = defaultdict(list)
         for idx, id in enumerate(self.inp_ids):
diff --git a/imperative/python/megengine/xla/compile.py b/imperative/python/megengine/xla/compile.py
index 7efa4931ba519e8a04ad923ea25cd86eff2ed81e..dcc470deb4fcc1b6a046938e8e1f578d80f8a581 100644
--- a/imperative/python/megengine/xla/compile.py
+++ b/imperative/python/megengine/xla/compile.py
@@ -73,16 +73,9 @@ class InputsHandler:
             if i._is_external_value():
                 rst.append([i._external_obj()])
             else:
-                if "gpu" in i.device.physical_name:
-                    capsule = to_dlpack(i)
-                    xla_array = self.from_dlpack(capsule)
-                    rst.append([xla_array])
-                else:
-                    r = self.handler(
-                        self.local_devices, [self.input_indices[idx],], [i,]
-                    )[0]
-                    rst.append(r)
-                    i._reset(tensor(r[0]))
+                capsule = to_dlpack(i)
+                xla_array = self.from_dlpack(capsule)
+                rst.append([xla_array])
         return rst
 
     def __str__(self):
diff --git a/imperative/python/test/unit/xla/functional/test_xla_convert.py b/imperative/python/test/unit/xla/functional/test_xla_convert.py
index 50e4b9a4df35c69e22d06d32c2848252c4b2fe4b..d1a35baf987e17b02ce2a7d5fac5bfb04ada1e9a 100644
--- a/imperative/python/test/unit/xla/functional/test_xla_convert.py
+++ b/imperative/python/test/unit/xla/functional/test_xla_convert.py
@@ -3,27 +3,48 @@
 import platform
 
 import numpy as np
 import pytest
 
+import megengine.distributed as dist
 import megengine.functional as F
-import megengine.jit as jit
+import megengine.functional.distributed as fdist
 import megengine.tensor as tensor
 from megengine import autodiff, is_cuda_available
 from megengine.autodiff.grad_manager import GradManager
-from meg_xlalib.xla_extension import ArrayImpl
+from megengine.core._imperative_rt.core2 import (
+    is_external_convert,
+    set_external_convert_hook,
+)
+from megengine.jit import xla_trace
+from megengine.module import Conv2d
 
 
-def test_external_flag_set():
+@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
+@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
+@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
+def test_external_tsf_set():
+    from mge_xlalib.xla_extension import ArrayImpl
 
     @xla_trace(capture_as_const=True)
-    def test_fun():
-        pass
-
+    def test_func(inp):
+        return inp
+
     assert is_external_convert()
+    inp = tensor(np.random.random((9, 9, 32, 32)))
+    mge_inp = test_func(inp)
+    xla_inp = test_func(inp)
+    assert xla_inp._is_external_value()
+    assert isinstance(xla_inp._external_obj(), ArrayImpl)
+    assert mge_inp.shape == xla_inp.shape
+    assert mge_inp.dtype == xla_inp.dtype
+    assert not xla_inp._is_external_value()
 
 
+@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
+@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
+@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
 def test_external_value():
-    m = Conv2d(9,9, 3,groups=9)
+    m = Conv2d(9, 9, 3, groups=9)
     gm = GradManager()
     gm.attach(m.parameters())
@@ -39,8 +60,44 @@
         model.weight.grad = None
         return ig, wg
 
-    inp = tensor(np.random.random((9,9, 32, 32)))*100
+    inp = tensor(np.random.random((9, 9, 32, 32))) * 100
+
+    mge_ig, mge_wg = conv_grad(inp, m)
+    xla_ig, xla_wg = conv_grad(inp, m)
+    np.testing.assert_allclose(mge_ig.numpy(), xla_ig.numpy())
+    np.testing.assert_allclose(mge_wg.numpy(), xla_wg.numpy(), atol=1e-5)
+
+
+@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
+@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
+@pytest.mark.require_ngpu(2)
+@pytest.mark.isolated_distributed
+def test_distributed_convert():
+    from mge_xlalib.xla_extension import ArrayImpl
+
+    def tester(ishape, n_gpus, dtype=None):
+        @dist.launcher(n_gpus=n_gpus)
+        def worker(data):
+            rank = dist.get_rank()
+            inp = tensor(data[rank])
+
+            @xla_trace(without_host=True)
+            def func1(inp):
+                return fdist.all_reduce_sum(inp)
+
+            mge_rst = func1(inp)
+            xla_rst = func1(inp)
+            assert xla_rst._is_external_value()
+            assert isinstance(xla_rst._external_obj(), ArrayImpl)
+
+            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)
+            assert mge_rst.shape == xla_rst.shape
+            assert mge_rst.dtype == xla_rst.dtype
+            assert not xla_rst._is_external_value()
+
+        x = np.random.randn(*ishape).astype(dtype)
+        y = np.random.randn(*ishape).astype(dtype)
+        data = (x, y)
+        worker(data)
 
-    a, b = conv_grad(inp, m)
-    a1, b1 = conv_grad(inp, m)
-    np.testing.assert_allclose(a.numpy(), a1.numpy())
\ No newline at end of file
+    tester((16, 1, 64,), 2)