未验证 提交 73672709 编写于 作者: Z Zheyue Tan 提交者: GitHub

redirect stdout to log file while initializing the remote actor instance (#294)

* redirect stdout to log file while initializing the remote actor instance

* add test for catching output in  `Actor.__init__`
上级 8c7f1922
......@@ -36,7 +36,7 @@ from parl.utils.communication import loads_argument, loads_return,\
from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError
from parl.remote.message import InitializedJob
from parl.remote.utils import load_remote_class
from parl.remote.utils import load_remote_class, redirect_stdout_to_file
class Job(object):
......@@ -315,7 +315,9 @@ class Job(object):
file_name = file_name.split(os.sep)[-1]
cls = load_remote_class(file_name, class_name, end_of_file)
args, kwargs = cloudpickle.loads(message[2])
obj = cls(*args, **kwargs)
logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path):
obj = cls(*args, **kwargs)
except Exception as e:
traceback_str = str(traceback.format_exc())
error_str = str(e)
......@@ -406,11 +408,8 @@ class Job(object):
# Redirect stdout to stdout.log temporarily
logfile_path = os.path.join(self.log_dir, 'stdout.log')
with open(logfile_path, 'a') as f:
tmp = sys.stdout
sys.stdout = f
with redirect_stdout_to_file(logfile_path):
ret = getattr(obj, function_name)(*args, **kwargs)
sys.stdout = tmp
ret = dumps_return(ret)
......
......@@ -29,7 +29,7 @@ import parl
from parl.remote.client import disconnect, get_global_client
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.utils import _IS_WINDOWS, get_free_tcp_port
from parl.utils import _IS_WINDOWS
@parl.remote_class
......@@ -38,6 +38,8 @@ class Actor(object):
self.number = number
self.arg1 = arg1
self.arg2 = arg2
print("Init actor...")
self.init_output = "Init actor...\n"
def sim_output(self, start, end):
output = ""
......@@ -48,7 +50,7 @@ class Actor(object):
print(i)
output += str(i)
output += "\n"
return output
return self.init_output + output
class TestLogServer(unittest.TestCase):
......
......@@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from contextlib import contextmanager
__all__ = ['load_remote_class']
__all__ = ['load_remote_class', 'redirect_stdout_to_file']
def simplify_code(code, end_of_file):
......@@ -66,3 +68,29 @@ def load_remote_class(file_name, class_name, end_of_file):
mod = __import__(module_name)
cls = getattr(mod, class_name)
return cls
@contextmanager
def redirect_stdout_to_file(file_path):
"""Redirect stdout (e.g., `print`) to specified file.
Example:
>>> print('test')
test
>>> with redirect_stdout_to_file('test.log'):
... print('test') # Output nothing, `test` is printed to `test.log`.
>>> print('test')
test
Args:
file_path: Path of the file to output the stdout.
"""
tmp = sys.stdout
f = open(file_path, 'a')
sys.stdout = f
try:
yield
finally:
sys.stdout = tmp
f.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册