diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index 320aa125df7c2131e159af89e297ba3dfc031806..d43f16b37bc858dbddffff27985d29b0fda6d810 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -71,6 +71,7 @@ if sys.platform == "win32": kernel32.SetErrorMode(old_error_mode) +from .core._imperative_rt.core2 import close as _close from .core._imperative_rt.core2 import full_sync as _full_sync from .core._imperative_rt.core2 import sync as _sync from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func @@ -90,7 +91,7 @@ _set_fork_exec_path_for_timed_func( _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() _persistent_cache_impl_ins.reg() -atexit.register(_full_sync) +atexit.register(_close) del _set_fork_exec_path_for_timed_func diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 344046f87e7bfef38320996f74a51295eb5aef0a..1b16fe5619d988fa3d1634877efd16e8beb8f4c9 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -897,6 +897,11 @@ void init_tensor(py::module m) { } } + static constexpr auto sync_py_task_q = []{ + py::gil_scoped_release _; + py_task_q.wait_all_task_finish(); + }; + m.def("set_option", [](std::string name, size_t value){ interpreter_for_py->set_option(name, value); }); m.def("get_option", @@ -928,16 +933,19 @@ void init_tensor(py::module m) { m.def("sync", []() { interpreter_for_py->sync(); - py_task_q.wait_all_task_finish(); - }, - py::call_guard()); + sync_py_task_q(); + }); m.def("full_sync", []() { interpreter_for_py->sync(); CompNode::sync_all(); - py_task_q.wait_all_task_finish(); - }, - py::call_guard()); + sync_py_task_q(); + }); + m.def("close", + []() { + interpreter_for_py->close(); + sync_py_task_q(); + }); py::handle grad_key_type = GradKeyWrapper::wrap_t::type() .def<&GradKeyWrapper::attach>("attach") diff --git a/imperative/python/test/unit/core/test_interpreter.py b/imperative/python/test/unit/core/test_interpreter.py index cee94b612bf2ba3948e4694a2e6502c5b4877ed2..5481901355e0f640e61722cedb8a6c7f5af3b802 100644 --- a/imperative/python/test/unit/core/test_interpreter.py +++ b/imperative/python/test/unit/core/test_interpreter.py @@ -1,3 +1,6 @@ +import subprocess +import sys + import numpy as np import pytest @@ -76,3 +79,14 @@ def test_swap_drop_basic(): z.numpy() _set_swap_flag(False) _set_drop_flag(False) + + +def test_finalize(): + prog = """ +import megengine +with megengine.core.option("enable_host_compute", 0): + x = megengine.tensor(0) + y = x + 1 + y.numpy() +""" + subprocess.check_call([sys.executable, "-c", prog]) diff --git a/imperative/src/impl/event_pool.cpp b/imperative/src/impl/event_pool.cpp index 4f27fa0e20de5686b8ffea4176edec3c563425e2..c6eedee5bd1d40dbc8d9f0f5f95dea5d8fc23a5c 100644 --- a/imperative/src/impl/event_pool.cpp +++ b/imperative/src/impl/event_pool.cpp @@ -67,6 +67,7 @@ std::shared_ptr EventPool::on_comp_node_finalize() { for (auto&& i : m_cn2pool) { i.second.assert_all_freed(); } + m_cn2pool.clear(); return {}; } EventPool::~EventPool() { diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 50f649f138ea981b386607ee6a3c830a2254bd8c..1843c6b222bbe739b610103994f2427a68d6bb2a 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -33,6 +33,7 @@ Interpreter& Interpreter::inst() { } Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { + mgb_assert(check_available(), "Channel already closed"); auto info = alloc(); info->desc.layout = value.layout(); info->desc.comp_node = value.comp_node(); @@ -47,6 +48,7 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { } Handle ChannelImpl::put(const DeviceTensorND& data) { + mgb_assert(check_available(), "Channel already closed"); auto info = alloc(); info->desc.layout = data.layout(); info->desc.comp_node = data.comp_node(); @@ -58,6 +60,9 @@ Handle ChannelImpl::put(const DeviceTensorND& data) { } void ChannelImpl::del(Handle handle) { + if (!check_available()){ + return; + } mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle); auto* info = reinterpret_cast(handle); m_valid_handle.erase(handle); @@ -65,6 +70,7 @@ void ChannelImpl::del(Handle handle) { } void ChannelImpl::swap_in(Handle handle) { + mgb_assert(check_available(), "Channel already closed"); if (m_worker_state.options.enable_swap) { mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); @@ -74,6 +80,7 @@ void ChannelImpl::swap_in(Handle handle) { } void ChannelImpl::swap_out(Handle handle) { + mgb_assert(check_available(), "Channel already closed"); if (m_worker_state.options.enable_swap) { mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); @@ -83,6 +90,7 @@ void ChannelImpl::swap_out(Handle handle) { } void ChannelImpl::drop(Handle handle) { + mgb_assert(check_available(), "Channel already closed"); if (m_worker_state.options.enable_drop) { mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); @@ -201,6 +209,7 @@ void ChannelImpl::dispatch_kernel( SmallVector ChannelImpl::apply_op( std::shared_ptr op, const SmallVector& inputs) { + mgb_assert(check_available(), "Channel already closed"); for (auto i : inputs) { mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(), "invalid handle: %p", i); @@ -237,6 +246,7 @@ SmallVector ChannelImpl::apply_op( } HostTensorND ChannelImpl::get_value(Handle handle) { + mgb_assert(check_available(), "Channel already closed"); // TODO: maybe get_value should be done on host. i.e. delete GetValue mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); @@ -269,6 +279,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) { } TensorShape ChannelImpl::get_shape(Handle handle) { + mgb_assert(check_available(), "Channel already closed"); mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); auto info = reinterpret_cast(handle); @@ -296,6 +307,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) { } DType ChannelImpl::get_dtype(Handle handle) { + mgb_assert(check_available(), "Channel already closed"); mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); auto info = reinterpret_cast(handle); @@ -308,6 +320,7 @@ DType ChannelImpl::get_dtype(Handle handle) { } CompNode ChannelImpl::get_device(Handle handle) { + mgb_assert(check_available(), "Channel already closed"); mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); auto info = reinterpret_cast(handle); @@ -320,6 +333,7 @@ CompNode ChannelImpl::get_device(Handle handle) { } DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { + mgb_assert(check_available(), "Channel already closed"); mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); auto info = reinterpret_cast(handle); @@ -342,6 +356,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { } void ChannelImpl::sync() { + mgb_assert(check_available(), "Channel already closed"); m_buffer.flush(); if (m_channel_state.profiler->is_profiling()) { m_channel_state.profiler->record_host(); @@ -356,14 +371,26 @@ void ChannelImpl::sync() { } void ChannelImpl::close() { + if (!check_available()) { + return; + } + std::vector valid_handles(m_valid_handle.begin(), m_valid_handle.end()); + for (auto* handle: valid_handles) { + del(handle); + } + mgb_assert(m_valid_handle.empty()); + mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size()); sync(); + m_closed = true; } size_t ChannelImpl::get_option(std::string name) { + mgb_assert(check_available(), "Channel already closed"); return m_channel_state.options.get_option(name); } void ChannelImpl::set_option(std::string name, size_t value) { + mgb_assert(check_available(), "Channel already closed"); m_channel_state.options.set_option(name, value); m_buffer.enqueue(SetOption{name, value}); } @@ -440,9 +467,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) { m_pool.free(ptr); } -ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){ - m_channel_state.tid = std::this_thread::get_id(); -} +ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){} ChannelImpl::~ChannelImpl() { close(); @@ -562,6 +587,10 @@ void ChannelImpl::detach_users(TensorInfo* dest) { //dest->users.clear(); } +bool ChannelImpl::check_available() { + return !m_closed; +} + void ChannelImpl::sync_device_scope(CompNode device) { auto& prev = m_worker_state.device_scope_map[device]; auto& current = m_worker_state.scopes; @@ -786,9 +815,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { std::swap(profiler, m_worker_state.profiler); auto records = profiler->stop(); auto host_map = [this](std::thread::id tid) { - if (tid == m_channel_state.tid) { - return "channel"; - } else if (tid == m_worker_state.tid) { + if (tid == m_worker_state.tid) { return "worker"; } else { return "unknown"; @@ -959,6 +986,7 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range) } void ChannelImpl::start_profile(std::unordered_map option) { + mgb_assert(check_available(), "Channel already closed"); auto profiler_option = InterpreterProfiler::Option::from_dict(option); auto profiler = std::make_unique(); profiler->set_option(profiler_option); @@ -968,6 +996,7 @@ void ChannelImpl::start_profile(std::unordered_map option) { } void ChannelImpl::stop_profile(std::string basename, std::string format) { + mgb_assert(check_available(), "Channel already closed"); m_buffer.flush(); auto profiler = std::make_unique(); std::swap(profiler, m_channel_state.profiler); @@ -976,6 +1005,7 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) { } void ChannelImpl::push_scope(std::string name) { + mgb_assert(check_available(), "Channel already closed"); if (m_channel_state.profiler->is_profiling()) { m_channel_state.profiler->record_host(name); m_channel_state.scopes.push_back(name); @@ -984,6 +1014,7 @@ void ChannelImpl::push_scope(std::string name) { } void ChannelImpl::pop_scope(std::string name) { + mgb_assert(check_available(), "Channel already closed"); if (m_channel_state.profiler->is_profiling()) { mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch"); m_channel_state.scopes.pop_back(); @@ -992,14 +1023,6 @@ void ChannelImpl::pop_scope(std::string name) { } } -void ChannelImpl::assert_in_channel() { - mgb_assert(m_channel_state.tid != std::this_thread::get_id()); -} - -void ChannelImpl::assert_in_worker() { - mgb_assert(m_worker_state.tid == std::this_thread::get_id()); -} - void ChannelImpl::DynamicSublinear::pin(const SmallVector& vec) { for (auto i : vec) { i->pin(); diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h index 72136fecef1a79425018fb5aafb9f12587e2e141..0d11c1c7a2fdb57929517cbd33b3435ad05f8f91 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.h +++ b/imperative/src/impl/interpreter/interpreter_impl.h @@ -18,6 +18,7 @@ #include #include +#include "megbrain/comp_node.h" #include "megbrain/utils/mempool.h" #include "megbrain/imperative/interpreter.h" #include "megbrain/imperative/profiler.h" @@ -102,8 +103,7 @@ private: const SmallVector& input_descs, SmallVector* outputs); - void assert_in_channel(); - void assert_in_worker(); + bool check_available(); void sync_device_scope(CompNode device); @@ -120,6 +120,8 @@ private: std::exception_ptr m_worker_exc; std::atomic_uint64_t m_last_id = 0; + bool m_closed = false; + struct WorkQueue : AsyncQueueSC { // set max_spin=0 to prevent Queue fetch task in busy wait manner. // this won't affect throughput when python interpreter is sending enough task, @@ -186,7 +188,6 @@ private: int m_async_level = 2; struct State { - std::thread::id tid; OptionManager options; std::vector scopes; std::unique_ptr profiler; @@ -199,6 +200,7 @@ private: struct ChannelState: State {}; struct WorkerState: State { + std::thread::id tid; CompNode::UnorderedMap> device_scope_map; };