2021-07-24 postnightly release (8152433d)

上级 cf2b6d24
......@@ -8,7 +8,6 @@ import torch
import torch._C
from torch.testing import FileCheck
from torch.jit.mobile import _load_for_lite_interpreter
from pathlib import Path
from torch.testing._internal.common_utils import (
IS_FBCODE,
......@@ -17,6 +16,7 @@ from torch.testing._internal.common_utils import (
IS_WINDOWS,
TEST_WITH_ROCM,
skipIfRocm,
find_library_location,
)
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
......@@ -74,9 +74,8 @@ class JitBackendTestCase(JitTestCase):
def setUp(self):
super().setUp()
torch_root = Path(__file__).resolve().parent.parent.parent
p = torch_root / 'build' / 'lib' / 'libjitbackend_test.so'
torch.ops.load_library(str(p))
lib_file_path = find_library_location('libjitbackend_test.so')
torch.ops.load_library(str(lib_file_path))
# Subclasses are expected to set up three variables in their setUp methods:
# module - a regular, Python version of the module being tested
# scripted_module - a scripted version of module
......@@ -492,9 +491,8 @@ class JitBackendTestCaseWithCompiler(JitTestCase):
def setUp(self):
super().setUp()
torch_root = Path(__file__).resolve().parent.parent.parent
p = torch_root / 'build' / 'lib' / 'libbackend_with_compiler.so'
torch.ops.load_library(str(p))
lib_file_path = find_library_location('libbackend_with_compiler.so')
torch.ops.load_library(str(lib_file_path))
# Subclasses are expected to set up four variables in their setUp methods:
# module - a regular, Python version of the module being tested
# scripted_module - a scripted version of module
......
......@@ -6,13 +6,18 @@ import unittest
import torch
from typing import Optional
from pathlib import Path
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS, IS_FBCODE
from torch.testing._internal.common_utils import (
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
find_library_location,
)
from torch.testing import FileCheck
if __name__ == "__main__":
......@@ -26,13 +31,8 @@ class TestTorchbind(JitTestCase):
def setUp(self):
if IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE:
raise unittest.SkipTest("non-portable load_library call used in test")
if TEST_WITH_ROCM:
torch_root = Path(torch.__file__).resolve().parent
p = torch_root / 'lib' / 'libtorchbind_test.so'
else:
torch_root = Path(__file__).resolve().parent.parent.parent
p = torch_root / 'build' / 'lib' / 'libtorchbind_test.so'
torch.ops.load_library(str(p))
lib_file_path = find_library_location('libtorchbind_test.so')
torch.ops.load_library(str(lib_file_path))
def test_torchbind(self):
def test_equality(f, cmp_key):
......
......@@ -14,7 +14,6 @@ import traceback
import warnings
import unittest
from math import sqrt
from pathlib import Path
from torch.multiprocessing import Process
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import op_db
......@@ -40,7 +39,14 @@ if sys.version_info >= (3, 7):
if sys.version_info >= (3, 7):
from fx.test_gradual_type import TypeCheckerTest # noqa: F401
from typing import Any, Callable, Dict, NamedTuple, List, Optional, Tuple, Union
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, IS_WINDOWS, IS_FBCODE, IS_MACOS
from torch.testing._internal.common_utils import (
IS_FBCODE,
IS_MACOS,
IS_WINDOWS,
TEST_WITH_ROCM,
find_library_location,
run_tests,
)
from torch.testing._internal.jit_utils import JitTestCase
from fx.named_tup import MyNamedTup
......@@ -109,9 +115,8 @@ class TestFX(JitTestCase):
def setUp(self):
if TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS:
return
torch_root = Path(__file__).resolve().parent.parent
p = torch_root / 'build' / 'lib' / 'libtorchbind_test.so'
torch.ops.load_library(str(p))
lib_file_path = find_library_location('libtorchbind_test.so')
torch.ops.load_library(str(lib_file_path))
def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
"""Check that an nn.Module's results match the GraphModule version
......
......@@ -23,7 +23,7 @@ import warnings
import random
import contextlib
import shutil
import pathlib
from pathlib import Path
import socket
import subprocess
import time
......@@ -199,7 +199,7 @@ UNITTEST_ARGS = [sys.argv[0]] + remaining
torch.manual_seed(SEED)
# CI Prefix path used only on CI environment
CI_TEST_PREFIX = str(pathlib.Path(os.getcwd()))
CI_TEST_PREFIX = str(Path(os.getcwd()))
def wait_for_process(p):
try:
......@@ -2499,3 +2499,13 @@ def has_breakpad() -> bool:
return True
except RuntimeError as e:
return False
def find_library_location(lib_name: str) -> Path:
# return the shared library file in the installed folder if exist,
# else the file in the build folder
torch_root = Path(torch.__file__).resolve().parent
path = torch_root / 'lib' / lib_name
if os.path.exists(path):
return path
torch_root = Path(__file__).resolve().parent.parent.parent
return torch_root / 'build' / 'lib' / lib_name
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册