Commit 7af1ec66 authored by Megvii Engine Team

fix(tensor): del valid tensors when compnode finalizing

GitOrigin-RevId: bace1f2b5131f2a51de47c7a047ccfc3fd52d23d
Parent 696d2c2e
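In short: the atexit hook switches from _full_sync to the new _close, so tensor handles that are still alive at interpreter shutdown are deleted before the compute nodes are finalized instead of afterwards. The Python sketch below only illustrates that shutdown order under assumed names (Channel, put, delete are placeholders, not MegEngine internals); the real work happens in ChannelImpl::close() further down in this diff.

import atexit

class Channel:
    """Placeholder channel for this sketch; not the MegEngine implementation."""
    def __init__(self):
        self.valid_handles = set()
        self.closed = False

    def put(self, value):
        assert not self.closed, "Channel already closed"
        handle = object()
        self.valid_handles.add(handle)
        return handle

    def delete(self, handle):
        if self.closed:  # mirrors ChannelImpl::del becoming a no-op after close
            return
        self.valid_handles.discard(handle)

    def close(self):
        if self.closed:
            return
        for handle in list(self.valid_handles):  # free every surviving tensor first
            self.delete(handle)
        self.closed = True  # later operations hit the "already closed" assertion

channel = Channel()
atexit.register(channel.close)  # analogous to atexit.register(_close) in the hunk below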
@@ -71,6 +71,7 @@ if sys.platform == "win32":
     kernel32.SetErrorMode(old_error_mode)
 
+from .core._imperative_rt.core2 import close as _close
 from .core._imperative_rt.core2 import full_sync as _full_sync
 from .core._imperative_rt.core2 import sync as _sync
 from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
@@ -90,7 +91,7 @@ _set_fork_exec_path_for_timed_func(
 _persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
 _persistent_cache_impl_ins.reg()
 
-atexit.register(_full_sync)
+atexit.register(_close)
 
 del _set_fork_exec_path_for_timed_func
......
@@ -897,6 +897,11 @@ void init_tensor(py::module m) {
         }
     }
 
+    static constexpr auto sync_py_task_q = []{
+        py::gil_scoped_release _;
+        py_task_q.wait_all_task_finish();
+    };
+
     m.def("set_option",
           [](std::string name, size_t value){ interpreter_for_py->set_option(name, value); });
     m.def("get_option",
@@ -928,16 +933,19 @@ void init_tensor(py::module m) {
     m.def("sync",
           []() {
               interpreter_for_py->sync();
-              py_task_q.wait_all_task_finish();
-          },
-          py::call_guard<py::gil_scoped_release>());
+              sync_py_task_q();
+          });
     m.def("full_sync",
           []() {
               interpreter_for_py->sync();
               CompNode::sync_all();
-              py_task_q.wait_all_task_finish();
-          },
-          py::call_guard<py::gil_scoped_release>());
+              sync_py_task_q();
+          });
+    m.def("close",
+          []() {
+              interpreter_for_py->close();
+              sync_py_task_q();
+          });
 
     py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
         .def<&GradKeyWrapper::attach>("attach")
......
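The new close binding pairs interpreter_for_py->close() with draining the Python task queue, reusing the same sync_py_task_q helper as sync and full_sync. Below is a hedged usage sketch, assuming only the names this diff exports from core2 (it is not part of the commit): after close() runs, further tensor operations trip the new "Channel already closed" assertions, so it is intended only for process shutdown.

# Illustrative contrast of the old and new atexit hooks; run at the very end of a program.
from megengine.core._imperative_rt.core2 import full_sync, close

full_sync()  # old hook: wait for pending work; the interpreter stays usable afterwards
close()      # new hook: additionally closes the channel and deletes remaining tensor handles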
+import subprocess
+import sys
 import numpy as np
 import pytest
@@ -76,3 +79,14 @@ def test_swap_drop_basic():
     z.numpy()
     _set_swap_flag(False)
     _set_drop_flag(False)
+
+
+def test_finalize():
+    prog = """
+import megengine
+with megengine.core.option("enable_host_compute", 0):
+    x = megengine.tensor(0)
+    y = x + 1
+    y.numpy()
+"""
+    subprocess.check_call([sys.executable, "-c", prog])
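The regression test shells out to a fresh interpreter on purpose: the bug only shows up while the interpreter tears down, after any in-process assertion has already passed, and a crash during CompNode finalization is visible only as a non-zero exit status of the child. A small self-contained illustration of that mechanism (the crashing child program here is hypothetical, not MegEngine code):

import subprocess
import sys

crashing = "import os; os._exit(134)"  # stand-in for a process that aborts during teardown
try:
    subprocess.check_call([sys.executable, "-c", crashing])
except subprocess.CalledProcessError as e:
    print("child interpreter failed with exit code", e.returncode)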
@@ -67,6 +67,7 @@ std::shared_ptr<void> EventPool::on_comp_node_finalize() {
     for (auto&& i : m_cn2pool) {
         i.second.assert_all_freed();
     }
+    m_cn2pool.clear();
     return {};
 }
 
 EventPool::~EventPool() {
......
@@ -33,6 +33,7 @@ Interpreter& Interpreter::inst() {
 }
 
 Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
+    mgb_assert(check_available(), "Channel already closed");
     auto info = alloc();
     info->desc.layout = value.layout();
     info->desc.comp_node = value.comp_node();
@@ -47,6 +48,7 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
 }
 
 Handle ChannelImpl::put(const DeviceTensorND& data) {
+    mgb_assert(check_available(), "Channel already closed");
     auto info = alloc();
     info->desc.layout = data.layout();
     info->desc.comp_node = data.comp_node();
@@ -58,6 +60,9 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
 }
 
 void ChannelImpl::del(Handle handle) {
+    if (!check_available()){
+        return;
+    }
     mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
     auto* info = reinterpret_cast<TensorInfo*>(handle);
     m_valid_handle.erase(handle);
@@ -65,6 +70,7 @@ void ChannelImpl::del(Handle handle) {
 }
 
 void ChannelImpl::swap_in(Handle handle) {
+    mgb_assert(check_available(), "Channel already closed");
     if (m_worker_state.options.enable_swap) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -74,6 +80,7 @@ void ChannelImpl::swap_in(Handle handle) {
 }
 
 void ChannelImpl::swap_out(Handle handle) {
+    mgb_assert(check_available(), "Channel already closed");
     if (m_worker_state.options.enable_swap) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -83,6 +90,7 @@ void ChannelImpl::swap_out(Handle handle) {
 }
 
 void ChannelImpl::drop(Handle handle) {
+    mgb_assert(check_available(), "Channel already closed");
     if (m_worker_state.options.enable_drop) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -201,6 +209,7 @@ void ChannelImpl::dispatch_kernel(
 SmallVector<Handle> ChannelImpl::apply_op(
         std::shared_ptr<OpDef> op,
         const SmallVector<Handle>& inputs) {
+    mgb_assert(check_available(), "Channel already closed");
     for (auto i : inputs) {
         mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
                 "invalid handle: %p", i);
@@ -237,6 +246,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
 }
 
 HostTensorND ChannelImpl::get_value(Handle handle) {
+    mgb_assert(check_available(), "Channel already closed");
     // TODO: maybe get_value should be done on host. i.e. delete GetValue
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
@@ -269,6 +279,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
 }
 
 TensorShape ChannelImpl::get_shape(Handle handle) {
+    mgb_assert(check_available(), "Channel already closed");
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -296,6 +307,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
 }
 
 DType ChannelImpl::get_dtype(Handle handle) {
+    mgb_assert(check_available(), "Channel already closed");
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -308,6 +320,7 @@ DType ChannelImpl::get_dtype(Handle handle) {
 }
 
 CompNode ChannelImpl::get_device(Handle handle) {
+    mgb_assert(check_available(), "Channel already closed");
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -320,6 +333,7 @@ CompNode ChannelImpl::get_device(Handle handle) {
 }
 
 DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
+    mgb_assert(check_available(), "Channel already closed");
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -342,6 +356,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
 }
 
 void ChannelImpl::sync() {
+    mgb_assert(check_available(), "Channel already closed");
     m_buffer.flush();
     if (m_channel_state.profiler->is_profiling()) {
         m_channel_state.profiler->record_host<SyncStartEvent>();
@@ -356,14 +371,26 @@ void ChannelImpl::sync() {
 }
 
 void ChannelImpl::close() {
+    if (!check_available()) {
+        return;
+    }
+    std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
+    for (auto* handle: valid_handles) {
+        del(handle);
+    }
+    mgb_assert(m_valid_handle.empty());
+    mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
     sync();
+    m_closed = true;
 }
 
 size_t ChannelImpl::get_option(std::string name) {
+    mgb_assert(check_available(), "Channel already closed");
     return m_channel_state.options.get_option(name);
 }
 
 void ChannelImpl::set_option(std::string name, size_t value) {
+    mgb_assert(check_available(), "Channel already closed");
     m_channel_state.options.set_option(name, value);
     m_buffer.enqueue(SetOption{name, value});
 }
@@ -440,9 +467,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
     m_pool.free(ptr);
 }
 
-ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){
-    m_channel_state.tid = std::this_thread::get_id();
-}
+ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){}
 
 ChannelImpl::~ChannelImpl() {
     close();
@@ -562,6 +587,10 @@ void ChannelImpl::detach_users(TensorInfo* dest) {
     //dest->users.clear();
 }
 
+bool ChannelImpl::check_available() {
+    return !m_closed;
+}
+
 void ChannelImpl::sync_device_scope(CompNode device) {
     auto& prev = m_worker_state.device_scope_map[device];
     auto& current = m_worker_state.scopes;
@@ -786,9 +815,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
         std::swap(profiler, m_worker_state.profiler);
         auto records = profiler->stop();
         auto host_map = [this](std::thread::id tid) {
-            if (tid == m_channel_state.tid) {
-                return "channel";
-            } else if (tid == m_worker_state.tid) {
+            if (tid == m_worker_state.tid) {
                 return "worker";
             } else {
                 return "unknown";
@@ -959,6 +986,7 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
 }
 
 void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
+    mgb_assert(check_available(), "Channel already closed");
     auto profiler_option = InterpreterProfiler::Option::from_dict(option);
     auto profiler = std::make_unique<InterpreterProfiler>();
     profiler->set_option(profiler_option);
@@ -968,6 +996,7 @@ void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
 }
 
 void ChannelImpl::stop_profile(std::string basename, std::string format) {
+    mgb_assert(check_available(), "Channel already closed");
     m_buffer.flush();
     auto profiler = std::make_unique<InterpreterProfiler>();
     std::swap(profiler, m_channel_state.profiler);
@@ -976,6 +1005,7 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) {
 }
 
 void ChannelImpl::push_scope(std::string name) {
+    mgb_assert(check_available(), "Channel already closed");
     if (m_channel_state.profiler->is_profiling()) {
         m_channel_state.profiler->record_host<ChannelBeginScope>(name);
         m_channel_state.scopes.push_back(name);
@@ -984,6 +1014,7 @@ void ChannelImpl::push_scope(std::string name) {
 }
 
 void ChannelImpl::pop_scope(std::string name) {
+    mgb_assert(check_available(), "Channel already closed");
     if (m_channel_state.profiler->is_profiling()) {
         mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
         m_channel_state.scopes.pop_back();
@@ -992,14 +1023,6 @@ void ChannelImpl::pop_scope(std::string name) {
     }
 }
 
-void ChannelImpl::assert_in_channel() {
-    mgb_assert(m_channel_state.tid != std::this_thread::get_id());
-}
-
-void ChannelImpl::assert_in_worker() {
-    mgb_assert(m_worker_state.tid == std::this_thread::get_id());
-}
-
 void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
     for (auto i : vec) {
         i->pin();
......
@@ -18,6 +18,7 @@
 #include <unordered_set>
 #include <variant>
 
+#include "megbrain/comp_node.h"
 #include "megbrain/utils/mempool.h"
 #include "megbrain/imperative/interpreter.h"
 #include "megbrain/imperative/profiler.h"
@@ -102,8 +103,7 @@ private:
             const SmallVector<LogicalTensorDesc>& input_descs,
             SmallVector<Handle>* outputs);
 
-    void assert_in_channel();
-    void assert_in_worker();
+    bool check_available();
 
     void sync_device_scope(CompNode device);
@@ -120,6 +120,8 @@ private:
     std::exception_ptr m_worker_exc;
     std::atomic_uint64_t m_last_id = 0;
 
+    bool m_closed = false;
+
     struct WorkQueue : AsyncQueueSC<IdentifiedCommand, WorkQueue> {
         // set max_spin=0 to prevent Queue fetch task in busy wait manner.
        // this won't affect throughput when python interpreter is sending enough task,
@@ -186,7 +188,6 @@ private:
     int m_async_level = 2;
 
     struct State {
-        std::thread::id tid;
         OptionManager options;
         std::vector<std::string> scopes;
         std::unique_ptr<InterpreterProfiler> profiler;
@@ -199,6 +200,7 @@ private:
     struct ChannelState: State {};
 
     struct WorkerState: State {
+        std::thread::id tid;
         CompNode::UnorderedMap<std::vector<std::string>> device_scope_map;
     };
......