未验证 提交 c322e6fb 编写于 作者: B Bo Zhou 提交者: GitHub

Importing submodules in xparl (#373)

* fix #370

* add unit tests for importing submodules in xparl

* add comment

* revert scripts.py

* yapf

* unit test

* fix relative path bug

* read files using relative path

* remove print

* update CMake

* add try times

* fix the bug in log_server_test.py

* fix relative path bug&add unit test

* update utils.py

* update Cmake

* fix pyc bug

* Update utils.py
上级 28ea49fc
......@@ -30,10 +30,20 @@ function(py_test TARGET_NAME)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS ENVS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND python -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 300)
if (${FILE_NAME} MATCHES ".*abs_test.py")
add_test(NAME ${TARGET_NAME}"_with_abs_path"
COMMAND python -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
set_tests_properties(${TARGET_NAME}"_with_abs_path" PROPERTIES TIMEOUT 300)
else()
get_filename_component(WORKING_DIR ${py_test_SRCS} DIRECTORY)
get_filename_component(FILE_NAME ${py_test_SRCS} NAME)
get_filename_component(COMBINED_PATH ${CMAKE_CURRENT_SOURCE_DIR}/${WORKING_DIR} ABSOLUTE)
add_test(NAME ${TARGET_NAME}
COMMAND python -u ${FILE_NAME} ${py_test_ARGS}
WORKING_DIRECTORY ${COMBINED_PATH})
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 300)
endif()
endfunction()
function(import_test TARGET_NAME)
......
......@@ -50,7 +50,6 @@ class Client(object):
distributed_files (list): A list of files to be distributed at all
remote instances(e,g. the configuration
file for initialization) .
"""
self.master_address = master_address
self.process_id = process_id
......@@ -95,34 +94,39 @@ class Client(object):
pyfiles['python_files'] = {}
pyfiles['other_files'] = {}
code_files = filter(lambda x: x.endswith('.py'), os.listdir('./'))
try:
for file in code_files:
assert os.path.exists(file)
with open(file, 'rb') as code_file:
code = code_file.read()
pyfiles['python_files'][file] = code
for file in distributed_files:
assert os.path.exists(file)
assert not os.path.isabs(
file
), "[XPARL] Please do not distribute a file with absolute path."
with open(file, 'rb') as f:
content = f.read()
pyfiles['other_files'][file] = content
# append entry file to code list
main_file = sys.argv[0]
with open(main_file, 'rb') as code_file:
main_file = sys.argv[0]
main_folder = './'
sep = os.sep
if sep in main_file:
main_folder = sep.join(main_file.split(sep)[:-1])
code_files = filter(lambda x: x.endswith('.py'),
os.listdir(main_folder))
for file_name in code_files:
file_path = os.path.join(main_folder, file_name)
assert os.path.exists(file_path)
with open(file_path, 'rb') as code_file:
code = code_file.read()
# parl/remote/remote_decorator.py -> remote_decorator.py
file_name = main_file.split(os.sep)[-1]
pyfiles['python_files'][file_name] = code
except AssertionError as e:
raise Exception(
'Failed to create the client, the file {} does not exist.'.
format(file))
# append entry file to code list
assert os.path.isfile(
main_file
), "[xparl] error occurs when distributing files. cannot find the entry file:{} in current working directory: {}".format(
main_file, os.getcwd())
with open(main_file, 'rb') as code_file:
code = code_file.read()
# parl/remote/remote_decorator.py -> remote_decorator.py
file_name = main_file.split(os.sep)[-1]
pyfiles['python_files'][file_name] = code
for file_name in distributed_files:
assert os.path.exists(file_name)
assert not os.path.isabs(
file_name
), "[XPARL] Please do not distribute a file with absolute path."
with open(file_name, 'rb') as f:
content = f.read()
pyfiles['other_files'][file_name] = content
return cloudpickle.dumps(pyfiles)
def _create_sockets(self, master_address):
......
......@@ -311,8 +311,6 @@ class Job(object):
try:
file_name, class_name, end_of_file = cloudpickle.loads(
message[1])
#/home/nlp-ol/Firework/baidu/nlp/evokit/python_api/es_agent -> es_agent
file_name = file_name.split(os.sep)[-1]
cls = load_remote_class(file_name, class_name, end_of_file)
args, kwargs = cloudpickle.loads(message[2])
logfile_path = os.path.join(self.log_dir, 'stdout.log')
......
......@@ -19,6 +19,7 @@ import time
import zmq
import numpy as np
import inspect
import sys
from parl.utils import get_ip_address, logger, to_str, to_byte
from parl.utils.communication import loads_argument, loads_return,\
......@@ -27,6 +28,7 @@ from parl.remote import remote_constants
from parl.remote.exceptions import RemoteError, RemoteAttributeError,\
RemoteDeserializeError, RemoteSerializeError, ResourceError
from parl.remote.client import get_global_client
from parl.remote.utils import locate_remote_file
def remote_class(*args, **kwargs):
......@@ -120,13 +122,22 @@ def remote_class(*args, **kwargs):
self.job_shutdown = False
self.send_file(self.job_socket)
file_name = inspect.getfile(cls)[:-3]
module_path = inspect.getfile(cls)
if module_path.endswith('pyc'):
module_path = module_path[:-4]
elif module_path.endswith('py'):
module_path = module_path[:-3]
else:
raise FileNotFoundError(
"cannot not find the module:{}".format(module_path))
res = inspect.getfile(cls)
file_path = locate_remote_file(module_path)
cls_source = inspect.getsourcelines(cls)
end_of_file = cls_source[1] + len(cls_source[0])
class_name = cls.__name__
self.job_socket.send_multipart([
remote_constants.INIT_OBJECT_TAG,
cloudpickle.dumps([file_name, class_name, end_of_file]),
cloudpickle.dumps([file_path, class_name, end_of_file]),
cloudpickle.dumps([args, kwargs]),
])
message = self.job_socket.recv_multipart()
......
......@@ -24,6 +24,7 @@ import time
import unittest
import requests
requests.adapters.DEFAULT_RETRIES = 5
import parl
from parl.remote.client import disconnect, get_global_client
......@@ -128,7 +129,6 @@ class TestLogServer(unittest.TestCase):
monitor_file = __file__.replace('log_server_test.pyc', '../monitor.py')
monitor_file = monitor_file.replace('log_server_test.py',
'../monitor.py')
command = [
sys.executable, monitor_file, "--monitor_port",
str(monitor_port), "--address", "localhost:" + str(master_port)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
@parl.remote_class
class B(object):
    """Small class wrapped by ``parl.remote_class`` that exposes ``add_sum``."""

    def add_sum(self, a, b):
        """Return the result of ``a + b``."""
        total = a + b
        return total
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import parl
import time
import threading
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
class TestImport(unittest.TestCase):
    """Check that a remote class imported from ``Module2`` can be used."""

    def tearDown(self):
        # Runs after every test; calls the client module's disconnect().
        disconnect()

    def test_import_local_module(self):
        """Instantiate ``Module2.B`` through xparl and call ``add_sum``.

        The master and worker are always shut down, even when connecting,
        instantiating the class or the assertion fails; otherwise the
        master thread would be left running after a failed test.
        """
        from Module2 import B
        port = 8448
        # Single source of truth for the address used by worker and client.
        address = 'localhost:{}'.format(port)
        master = Master(port=port)
        th = threading.Thread(target=master.run)
        th.start()
        try:
            # NOTE(review): presumably gives the master time to start.
            time.sleep(1)
            worker = Worker(address, 1)
            try:
                # NOTE(review): presumably lets the worker register itself.
                time.sleep(10)
                parl.connect(address)
                obj = B()
                res = obj.add_sum(10, 5)
                self.assertEqual(res, 15)
            finally:
                worker.exit()
        finally:
            master.exit()
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import parl
import time
import threading
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
class TestImport(unittest.TestCase):
    """Check that remote classes from local modules and sub-packages work."""

    def tearDown(self):
        # Runs after every test; calls the client module's disconnect().
        disconnect()

    def _check_remote_add_sum(self, remote_cls, port, distributed_files=None):
        """Start a master and a worker, then call ``add_sum`` remotely.

        Args:
            remote_cls: Callable that builds the remote object (a class
                decorated with ``parl.remote_class``).
            port: Port the master listens on; also used by the worker and
                the client connection.
            distributed_files: Optional list of files forwarded to
                ``parl.connect``. When ``None`` the argument is not passed.

        The worker and master are shut down in ``finally`` blocks so that a
        failing connection or assertion does not leave them running.
        """
        address = 'localhost:{}'.format(port)
        master = Master(port=port)
        th = threading.Thread(target=master.run)
        th.start()
        try:
            # NOTE(review): presumably gives the master time to start.
            time.sleep(1)
            worker = Worker(address, 1)
            try:
                # NOTE(review): presumably lets the worker register itself.
                time.sleep(10)
                if distributed_files is None:
                    parl.connect(address)
                else:
                    parl.connect(address, distributed_files=distributed_files)
                obj = remote_cls()
                res = obj.add_sum(10, 5)
                self.assertEqual(res, 15)
            finally:
                worker.exit()
        finally:
            master.exit()

    def test_import_local_module(self):
        """Remote class imported from a module next to the test file."""
        from Module2 import B
        self._check_remote_add_sum(B, 8442)

    def test_import_subdir_module_0(self):
        """Remote class reached through ``from subdir import Module``."""
        from subdir import Module
        self._check_remote_add_sum(
            Module.A,
            8443,
            distributed_files=['./subdir/Module.py', './subdir/__init__.py'])

    def test_import_subdir_module_1(self):
        """Remote class reached through ``from subdir.Module import A``."""
        from subdir.Module import A
        self._check_remote_add_sum(
            A,
            8444,
            distributed_files=['./subdir/Module.py', './subdir/__init__.py'])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
@parl.remote_class
class A(object):
    """Small class wrapped by ``parl.remote_class`` that exposes ``add_sum``."""

    def add_sum(self, a, b):
        """Return the result of ``a + b``."""
        total = a + b
        return total
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -13,8 +13,11 @@
# limitations under the License.
import sys
from contextlib import contextmanager
import os
__all__ = ['load_remote_class', 'redirect_stdout_to_file']
__all__ = [
'load_remote_class', 'redirect_stdout_to_file', 'locate_remote_file'
]
def simplify_code(code, end_of_file):
......@@ -32,7 +35,7 @@ def simplify_code(code, end_of_file):
def data_process():
XXXX
------------------>
The last two lines of the above code block will be removed as they are not class related.
The last two lines of the above code block will be removed as they are not class-related.
"""
to_write_lines = []
for i, line in enumerate(code):
......@@ -60,12 +63,18 @@ def load_remote_class(file_name, class_name, end_of_file):
with open(file_name + '.py') as t_file:
code = t_file.readlines()
code = simplify_code(code, end_of_file)
module_name = 'xparl_' + file_name
tmp_file_name = 'xparl_' + file_name + '.py'
#folder/xx.py -> folder/xparl_xx.py
file_name = file_name.split(os.sep)
prefix = os.sep.join(file_name[:-1])
if prefix == "":
prefix = '.'
module_name = prefix + os.sep + 'xparl_' + file_name[-1]
tmp_file_name = module_name + '.py'
with open(tmp_file_name, 'w') as t_file:
for line in code:
t_file.write(line)
mod = __import__(module_name)
module_name = module_name.lstrip('.' + os.sep).replace(os.sep, '.')
mod = __import__(module_name, globals(), locals(), [class_name], 0)
cls = getattr(mod, class_name)
return cls
......@@ -74,6 +83,9 @@ def load_remote_class(file_name, class_name, end_of_file):
def redirect_stdout_to_file(file_path):
"""Redirect stdout (e.g., `print`) to specified file.
Args:
file_path: Path of the file to output the stdout.
Example:
>>> print('test')
test
......@@ -81,10 +93,6 @@ def redirect_stdout_to_file(file_path):
... print('test') # Output nothing, `test` is printed to `test.log`.
>>> print('test')
test
Args:
file_path: Path of the file to output the stdout.
"""
tmp = sys.stdout
f = open(file_path, 'a')
......@@ -94,3 +102,33 @@ def redirect_stdout_to_file(file_path):
finally:
sys.stdout = tmp
f.close()
def locate_remote_file(module_path):
    """Return the path of a module relative to the entry file's directory.

    xparl has to locate the file that holds the class decorated by
    ``parl.remote_class``. The entry file is ``sys.argv[0]``; its directory
    is the first entry of ``sys.path`` that contains a file with that name.

    Args:
        module_path: Path of the module without the file extension. An
            absolute path is converted; a relative path is returned as is.

    Returns:
        The module path relative to the entry directory, prefixed with
        ``./`` when ``module_path`` was absolute.

    Raises:
        FileNotFoundError: If the entry file is not found in ``sys.path``,
            or an absolute ``module_path`` is not inside the entry directory.

    Example:
        module_path: /home/user/dir/subdir/my_module
        entry_file: /home/user/dir/main.py
        --------> relative_path: ./subdir/my_module
    """
    entry_file = os.path.basename(sys.argv[0])
    entry_path = None
    for path in sys.path:
        if os.path.isfile(os.path.join(path, entry_file)):
            entry_path = path
            break
    if entry_path is None:
        raise FileNotFoundError("cannot locate the remote file")

    if not os.path.isabs(module_path):
        return module_path

    # Normalise the entry directory: sys.path entries may be '' (the current
    # directory) or carry a trailing separator, which breaks raw slicing.
    entry_dir = os.path.abspath(entry_path)
    try:
        relative = os.path.relpath(module_path, entry_dir)
    except ValueError:
        # Raised on Windows when the two paths live on different drives.
        raise FileNotFoundError("cannot locate the remote file")
    # Compare whole path components so that a sibling such as "dir2" is not
    # mistaken for a child of "dir".
    if relative == os.pardir or relative.startswith(os.pardir + os.sep):
        raise FileNotFoundError("cannot locate the remote file")
    return os.path.join(os.curdir, relative)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册