# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import platform
import sys

import pytest

import megengine.functional
import megengine.module
from megengine import Parameter
from megengine.core._imperative_rt.core2 import sync
from megengine.device import get_device_count
from megengine.experimental.autograd import (
    disable_higher_order_directive,
    enable_higher_order_directive,
)
from megengine.jit import trace as _trace
from megengine.module import Linear, Module

# Put the "helpers" directory that sits next to this conftest on sys.path so
# modules inside it can be imported by the test files.
sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))

# GPU device count as reported by MegEngine's get_device_count("gpu").
# It is queried once, when this conftest is imported, and compared against
# the ``require_ngpu`` marker argument in the ``skip_by_ngpu`` fixture.
_ngpu = get_device_count("gpu")


@pytest.fixture(autouse=True)
def skip_by_ngpu(request):
    """Skip a test whose ``require_ngpu(n)`` marker asks for more GPUs than exist.

    The first positional marker argument is converted with ``int`` and
    compared against the module-level ``_ngpu`` count. Tests without the
    marker are not affected.
    """
    ngpu_marker = request.node.get_closest_marker("require_ngpu")
    if not ngpu_marker:
        return
    needed = int(ngpu_marker.args[0])
    if needed > _ngpu:
        pytest.skip("skipped for ngpu unsatisfied: {}".format(needed))


@pytest.fixture(autouse=True)
def skip_distributed(request):
    """Skip ``distributed_isolated``-marked tests on unsupported platforms.

    A marked test is skipped when ``platform.system()`` reports
    ``"Windows"`` or ``"Darwin"``. Unmarked tests, and tests on any other
    platform, run normally.
    """
    if not request.node.get_closest_marker("distributed_isolated"):
        return
    system_name = platform.system()
    if system_name not in ("Windows", "Darwin"):
        return
    pytest.skip(
        "skipped for distributed unsupported at platform: {}".format(system_name)
    )


@pytest.fixture(autouse=True)
def resolve_require_higher_order_directive(request):
    """Enable the higher-order autograd directive for tests that request it.

    For a test carrying the ``require_higher_order_directive`` marker, the
    directive is enabled before the test body runs and disabled again in the
    teardown phase after the ``yield``. Unmarked tests only pass through the
    ``yield``.
    """
    needs_directive = bool(
        request.node.get_closest_marker("require_higher_order_directive")
    )
    if needs_directive:
        enable_higher_order_directive()
    # Everything after the yield runs as fixture teardown.
    yield
    if needs_directive:
        disable_higher_order_directive()