Commit 866ab6e6 authored by Megvii Engine Team

fix(imperative/dataloader): collect garbage explicitly before starting a process in dataloader

GitOrigin-RevId: 3c3b51ad9ca02a26baf80061970d098d416a2f6e
Parent 67167cb3
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
+import gc
 import math
 import multiprocessing
 import platform
@@ -246,6 +247,7 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
             ),
             daemon=True,
         )
+        gc.collect()
         self.task_feeding_worker.start()
         self.workers = []
@@ -262,6 +264,7 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
                 ),
                 daemon=True,
             )
+            gc.collect()
             worker.start()
             self.workers.append(worker)
@@ -293,6 +296,7 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
             ),
             daemon=True,
         )
+        gc.collect()
         self.data_collecting_worker.start()
         self.__initialized = True
@@ -465,6 +469,7 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
         self.recieve_worker = multiprocessing.Process(
             target=self._worker_to_raw_data_queues, daemon=True
         )
+        gc.collect()
         self.recieve_worker.start()
         self.transform_workers = []
@@ -472,12 +477,14 @@ class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
             worker = multiprocessing.Process(
                 target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True
             )
+            gc.collect()
             worker.start()
             self.transform_workers.append(worker)
         self.collect_worker = multiprocessing.Process(
             target=self._worker_to_batch_queue, daemon=True
         )
+        gc.collect()
         self.collect_worker.start()
         self.__initialized = True
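Note on the change: gc.collect() is called in the parent immediately before each Process.start(). The commit message does not spell out the motivation, but the usual reason for this pattern is that collecting cyclic garbage before a fork keeps the child process from inheriting (and copy-on-write duplicating) objects that are already unreachable in the parent. A minimal sketch of the pattern follows; _worker_loop and task_queue are placeholder names for illustration, not the dataloader's real internals.

# Sketch only: collect garbage in the parent right before starting a worker
# process, as the diff above does. Names here are hypothetical placeholders.
import gc
import multiprocessing


def _worker_loop(task_queue):
    # Hypothetical worker body: consume tasks until a None sentinel arrives.
    while True:
        task = task_queue.get()
        if task is None:
            break
        # ... process the task ...


if __name__ == "__main__":
    task_queue = multiprocessing.Queue()
    worker = multiprocessing.Process(
        target=_worker_loop, args=(task_queue,), daemon=True
    )
    # Free unreachable cycles in the parent before the fork, so the child
    # does not inherit that garbage.
    gc.collect()
    worker.start()

    task_queue.put(None)  # signal the worker to exit
    worker.join()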