提交 4a58b178 编写于 作者: M Maarten de Vries 提交者: François Chollet

Rethrow original exception in GeneratorEnqueuer. (#8485)

* Fix randint usage in test_multiprocessing

* Rethrow original exception in GeneratorEnqueuer.

* Use multiprocessing.Manager to obtain a race-free Queue.
上级 a27b4a51
......@@ -11,6 +11,7 @@ import sys
import tarfile
import threading
import time
import traceback
import zipfile
from abc import abstractmethod
from multiprocessing.pool import ThreadPool
......@@ -553,7 +554,7 @@ class OrderedEnqueuer(SequenceEnqueuer):
yield inputs
except Exception as e:
self.stop()
raise StopIteration(e)
six.raise_from(StopIteration(e), e)
def _send_sequence(self):
"""Send current Sequence to all workers."""
......@@ -614,6 +615,7 @@ class GeneratorEnqueuer(SequenceEnqueuer):
self._use_multiprocessing = use_multiprocessing
self._threads = []
self._stop_event = None
self._manager = None
self.queue = None
self.seed = seed
......@@ -631,18 +633,27 @@ class GeneratorEnqueuer(SequenceEnqueuer):
try:
if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
generator_output = next(self._generator)
self.queue.put(generator_output)
self.queue.put((True, generator_output))
else:
time.sleep(self.wait_time)
except StopIteration:
break
except Exception:
except Exception as e:
# Can't pick tracebacks.
# As a compromise, print the traceback and pickle None instead.
if self._use_multiprocessing:
traceback.print_exc()
setattr(e, '__traceback__', None)
elif not hasattr(e, '__traceback__'):
setattr(e, '__traceback__', sys.exc_info()[2])
self.queue.put((False, e))
self._stop_event.set()
raise
break
try:
if self._use_multiprocessing:
self.queue = multiprocessing.Queue(maxsize=max_queue_size)
self._manager = multiprocessing.Manager()
self.queue = self._manager.Queue(maxsize=max_queue_size)
self._stop_event = multiprocessing.Event()
else:
self.queue = queue.Queue()
......@@ -686,9 +697,8 @@ class GeneratorEnqueuer(SequenceEnqueuer):
else:
thread.join(timeout)
if self._use_multiprocessing:
if self.queue is not None:
self.queue.close()
if self._manager:
self._manager.shutdown()
self._threads = []
self._stop_event = None
......@@ -704,12 +714,22 @@ class GeneratorEnqueuer(SequenceEnqueuer):
"""
while self.is_running():
if not self.queue.empty():
inputs = self.queue.get()
if inputs is not None:
yield inputs
success, value = self.queue.get()
# Rethrow any exceptions found in the queue
if not success:
six.reraise(value.__class__, value, value.__traceback__)
# Yield regular values
if value is not None:
yield value
else:
all_finished = all([not thread.is_alive() for thread in self._threads])
if all_finished and self.queue.empty():
raise StopIteration()
else:
time.sleep(self.wait_time)
# Make sure to rethrow the first exception in the queue, if any
while not self.queue.empty():
success, value = self.queue.get()
if not success:
six.reraise(value.__class__, value, value.__traceback__)
......@@ -182,7 +182,7 @@ def test_generator_enqueuer_fail_threads():
FaultSequence()), use_multiprocessing=False)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
with pytest.raises(StopIteration):
with pytest.raises(IndexError):
next(gen_output)
......@@ -191,7 +191,7 @@ def test_generator_enqueuer_fail_processes():
FaultSequence()), use_multiprocessing=True)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
with pytest.raises(StopIteration):
with pytest.raises(IndexError):
next(gen_output)
......
......@@ -232,7 +232,7 @@ def test_multiprocessing_fit_error():
"""Raises an exception after a few good batches"""
for i in range(good_batches):
yield (np.random.randint(batch_size, 256, (50, 2)),
np.random.randint(batch_size, 2, 50))
np.random.randint(batch_size, 12, 50))
raise RuntimeError
model = Sequential()
......@@ -241,13 +241,13 @@ def test_multiprocessing_fit_error():
samples = batch_size * (good_batches + 1)
with pytest.raises(StopIteration):
with pytest.raises(RuntimeError):
model.fit_generator(
custom_generator(), samples, 1,
workers=4, use_multiprocessing=True,
)
with pytest.raises(StopIteration):
with pytest.raises(RuntimeError):
model.fit_generator(
custom_generator(), samples, 1,
use_multiprocessing=False,
......@@ -258,25 +258,26 @@ def test_multiprocessing_fit_error():
def test_multiprocessing_evaluate_error():
batch_size = 10
good_batches = 3
workers = 4
def custom_generator():
"""Raises an exception after a few good batches"""
for i in range(good_batches):
yield (np.random.randint(batch_size, 256, (50, 2)),
np.random.randint(batch_size, 2, 50))
np.random.randint(batch_size, 12, 50))
raise RuntimeError
model = Sequential()
model.add(Dense(1, input_shape=(2, )))
model.compile(loss='mse', optimizer='adadelta')
with pytest.raises(StopIteration):
with pytest.raises(RuntimeError):
model.evaluate_generator(
custom_generator(), good_batches + 1, 1,
workers=4, use_multiprocessing=True,
custom_generator(), good_batches * workers + 1, 1,
workers=workers, use_multiprocessing=True,
)
with pytest.raises(StopIteration):
with pytest.raises(RuntimeError):
model.evaluate_generator(
custom_generator(), good_batches + 1, 1,
use_multiprocessing=False,
......@@ -299,13 +300,13 @@ def test_multiprocessing_predict_error():
model.add(Dense(1, input_shape=(5,)))
model.compile(loss='mse', optimizer='adadelta')
with pytest.raises(StopIteration):
with pytest.raises(RuntimeError):
model.predict_generator(
custom_generator(), good_batches * workers + 1, 1,
workers=workers, use_multiprocessing=True,
)
with pytest.raises(StopIteration):
with pytest.raises(RuntimeError):
model.predict_generator(
custom_generator(), good_batches + 1, 1,
use_multiprocessing=False,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册