未验证 提交 5d96b6e0 编写于 作者: C Chen Weihang 提交者: GitHub

Add Queue.get delay for multiprocess data loader (#22604) (#22640)

上级 750c6f42
......@@ -34,8 +34,9 @@ if sys.version_info[0] == 2:
import Queue as queue
else:
import queue
# NOTE: [ avoid hanging ] This value is used in getting data from another process
MP_CHECK_TIMEOUT = 10
# NOTE: [ avoid hanging ] These values are used when getting data from another process
QUEUE_GET_TIMEOUT = 5
MAX_GET_FAILED_TIME = 12
__all__ = ['PyReader', 'DataLoader']
......@@ -485,6 +486,17 @@ class DygraphGeneratorLoader(DataLoaderBase):
signal.signal(signal.SIGCHLD, __handler__)
def _exit_thread_expectedly(self):
    """Shut the reader thread down after a normal end of data.

    Sets the thread-done event, then closes the blocking queue followed
    by the data queue.
    """
    self._thread_done_event.set()
    # Both queues are closed on the normal path, blocking queue first.
    for channel in (self._blocking_queue, self._data_queue):
        channel.close()
def _exit_thread_unexpectedly(self):
    """Tear the reader thread down after a failure.

    Sets the thread-done event, kills the blocking queue, closes the
    data queue, and finally logs an error message.
    """
    done_event = self._thread_done_event
    blocking_queue = self._blocking_queue
    data_queue = self._data_queue

    done_event.set()
    # Unlike the expected-exit path, the blocking queue is killed, not closed.
    blocking_queue.kill()
    data_queue.close()
    logging.error("DataLoader reader thread raised an exception!")
def _reader_process_loop(self):
try:
# set signal handler
......@@ -506,6 +518,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
six.reraise(*sys.exc_info())
def _reader_thread_loop_with_process(self):
get_sample_try_time = 0
while not self._thread_done_event.is_set():
try:
# NOTE: [ avoid hanging ] Even with carefully designed data dependencies
......@@ -513,10 +526,21 @@ class DygraphGeneratorLoader(DataLoaderBase):
# still happen when data in queue is corrupted (e.g., due to
# Queue.cancel_join_thread or unexpected exit). So we set a timeout whenever
# we try to get data from `data_queue`
sample = self._data_queue.get(timeout=MP_CHECK_TIMEOUT)
sample = self._data_queue.get(timeout=QUEUE_GET_TIMEOUT)
get_sample_try_time = 0
except queue.Empty:
self._thread_done_event.set()
logging.error("The reader has not read data for a long time.")
get_sample_try_time += 1
if get_sample_try_time > MAX_GET_FAILED_TIME:
self._exit_thread_unexpectedly()
raise RuntimeError(
"DataLoader reader thread has not read data for a long time (60s)."
)
else:
# NOTE: [ avoid failing quickly ] When the reader child process is under heavy load,
# it may not have had enough time to put data into the queue by the time the main process
# starts trying to get data from the queue. A failed read at this point should not be
# treated as a fatal error; a certain number of retries should be allowed first.
continue
if not self._thread_done_event.is_set():
if sample is not None:
......@@ -532,20 +556,10 @@ class DygraphGeneratorLoader(DataLoaderBase):
if not self._blocking_queue.push(array):
self._blocking_queue.close()
except:
self._thread_done_event.set()
self._blocking_queue.kill()
self._data_queue.close()
logging.warning(
"DygraphDataLoader reader thread raised an exception."
)
self._exit_thread_unexpectedly()
six.reraise(*sys.exc_info())
else:
self._thread_done_event.set()
self._blocking_queue.close()
self._data_queue.close()
else:
self._blocking_queue.kill()
self._data_queue.close()
self._exit_thread_expectedly()
def _reader_thread_loop(self):
try:
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import sys
import time
import unittest
import numpy as np
import paddle.fluid as fluid
......@@ -20,10 +21,18 @@ from paddle.fluid import core
import paddle.compat as cpt
def get_random_images_and_labels(image_shape, label_shape):
    """Return a random ``(image, label)`` pair of NumPy arrays.

    Args:
        image_shape: Shape passed as ``size`` when sampling the image.
        label_shape: Shape passed as ``size`` when sampling the label.

    Returns:
        A tuple ``(image, label)``: ``image`` is sampled with
        ``np.random.random`` and cast to ``float32``; ``label`` is sampled
        the same way and cast to ``int64``.
    """

    def _sample(shape, dtype):
        return np.random.random(size=shape).astype(dtype)

    # The image is drawn before the label, so the RNG is consumed in that order.
    return _sample(image_shape, 'float32'), _sample(label_shape, 'int64')
class TestDygraphhDataLoaderWithException(unittest.TestCase):
def setUp(self):
    """Initialise the parameters the test cases read from ``self``.

    Sets ``batch_size``, ``batch_num``, ``epoch_num`` and ``capacity``.
    """
    self.batch_size = 8
    self.batch_num = 4
    self.epoch_num = 1
    # Assigned exactly once: an earlier assignment of 2 was dead code
    # because it was overwritten before anything could read it.
    self.capacity = 5
def test_not_capacity(self):
with fluid.dygraph.guard():
......@@ -77,6 +86,34 @@ class TestDygraphhDataLoaderWithException(unittest.TestCase):
exception = ex
self.assertIsNotNone(exception)
def test_multi_process_with_get_timeout(self):
    """Run a multiprocess loader over a reader that sleeps 80s per batch.

    The test passes only if iterating the loader raises
    ``core.EnforceNotMet`` whose message contains
    "Blocking queue is killed".
    """

    def slow_batch_generator_creator(batch_size, batch_num):
        # Build a batch reader that stalls for 80 seconds before each batch.
        image_shape = [batch_size, 784]
        label_shape = [batch_size, 1]

        def __reader__():
            for _ in range(batch_num):
                time.sleep(80)
                yield get_random_images_and_labels(image_shape, label_shape)

        return __reader__

    with fluid.dygraph.guard():
        loader = fluid.io.DataLoader.from_generator(
            capacity=self.capacity, use_multiprocess=True)
        slow_reader = slow_batch_generator_creator(self.batch_size,
                                                   self.batch_num)
        loader.set_batch_generator(slow_reader, places=fluid.CPUPlace())

        caught = None
        try:
            for _ in range(self.epoch_num):
                for image, _ in loader():
                    fluid.layers.relu(image)
        except core.EnforceNotMet as err:
            self.assertIn("Blocking queue is killed",
                          cpt.get_exception_message(err))
            caught = err
        # The loop must not complete without the expected error.
        self.assertIsNotNone(caught)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册