From 4ccd9a0a86ad550a861c954d70e28ef15741b310 Mon Sep 17 00:00:00 2001
From: Kaipeng Deng
Date: Wed, 12 May 2021 23:09:32 +0800
Subject: [PATCH] fix dataloader exit hang when join re-enter (#32835)

* fix dataloader exit hang when join re-enter. test=develop

* double check _shutdown. test=develop
---
 .../fluid/dataloader/dataloader_iter.py       | 26 ++++++++++++-------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py
index 52ab8369859..1f928bfc8a6 100644
--- a/python/paddle/fluid/dataloader/dataloader_iter.py
+++ b/python/paddle/fluid/dataloader/dataloader_iter.py
@@ -289,10 +289,14 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):

         # if user exit python program when dataloader is still
         # iterating, resource may no release safely, so we
-        # add __del__ function to to CleanupFuncRegistrar
-        # to make sure __del__ is always called when program
+        # add _shutdown_on_exit function to to CleanupFuncRegistrar
+        # to make sure _try_shutdown_all is always called when program
         # exit for resoure releasing safely
-        CleanupFuncRegistrar.register(self.__del__)
+        # worker join may hang for in _try_shutdown_all call in atexit
+        # for main process is in atexit state in some OS, so we add
+        # timeout=1 for shutdown function call in atexit, for shutdown
+        # function call in __del__, we keep it as it is
+        CleanupFuncRegistrar.register(self._shutdown_on_exit)

     def _init_workers(self):
         # multiprocess worker and indice queue list initial as empty
@@ -363,7 +367,7 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
             self._indices_queues[worker_id].put(None)
             self._worker_status[worker_id] = False

-    def _try_shutdown_all(self):
+    def _try_shutdown_all(self, timeout=None):
         if not self._shutdown:
             try:
                 self._exit_thread_expectedly()
@@ -376,11 +380,12 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
                 for i in range(self._num_workers):
                     self._shutdown_worker(i)

-                for w in self._workers:
-                    w.join()
-                for q in self._indices_queues:
-                    q.cancel_join_thread()
-                    q.close()
+                if not self._shutdown:
+                    for w in self._workers:
+                        w.join(timeout)
+                    for q in self._indices_queues:
+                        q.cancel_join_thread()
+                        q.close()
             finally:
                 core._erase_process_pids(id(self))
                 self._shutdown = True
@@ -560,6 +565,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
     def __del__(self):
         self._try_shutdown_all()

+    def _shutdown_on_exit(self):
+        self._try_shutdown_all(1)
+
     def __next__(self):
         try:
             # _batches_outstanding here record the total batch data number
--
GitLab
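
Below is a minimal standalone sketch of the pattern this patch applies, written against plain multiprocessing rather than Paddle's internals (the TinyLoader class and _worker helper are hypothetical names, not Paddle's API): shutdown can be re-entered from both __del__ and an atexit hook, so it is guarded by a _shutdown flag that is double-checked before the blocking joins, and the atexit path joins workers with a one-second timeout so interpreter teardown cannot hang on a stuck worker.

    # Sketch only: illustrates the timed-join-on-exit pattern, not Paddle's code.
    import atexit
    import multiprocessing as mp


    def _worker(queue):
        # Keep consuming until a None sentinel arrives, then exit.
        while queue.get() is not None:
            pass


    class TinyLoader:
        def __init__(self, num_workers=2):
            self._shutdown = False
            self._queues = [mp.Queue() for _ in range(num_workers)]
            self._workers = [
                mp.Process(target=_worker, args=(q,), daemon=True)
                for q in self._queues
            ]
            for w in self._workers:
                w.start()
            # Register the timed variant for interpreter exit; __del__ keeps
            # the untimed join for the object's normal lifetime.
            atexit.register(self._shutdown_on_exit)

        def _try_shutdown_all(self, timeout=None):
            if not self._shutdown:
                try:
                    # Ask every worker to stop.
                    for q in self._queues:
                        q.put(None)
                    # Double-check the flag before the potentially blocking
                    # joins, mirroring the "double check _shutdown" follow-up.
                    if not self._shutdown:
                        for w in self._workers:
                            w.join(timeout)
                        for q in self._queues:
                            q.cancel_join_thread()
                            q.close()
                finally:
                    self._shutdown = True

        def __del__(self):
            self._try_shutdown_all()

        def _shutdown_on_exit(self):
            # Bounded join: atexit runs while the process is already tearing
            # down, so never wait on a worker for more than one second.
            self._try_shutdown_all(1)


    if __name__ == "__main__":
        loader = TinyLoader()
        loader._try_shutdown_all()

The key design choice mirrored here is that only the atexit path passes a timeout to Process.join(); the __del__ path keeps the unbounded join, so normal teardown still waits for workers to finish cleanly.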