Commit 7af1ec66 authored by Megvii Engine Team

fix(tensor): del valid tensors when compnode finalizing

GitOrigin-RevId: bace1f2b5131f2a51de47c7a047ccfc3fd52d23d
Parent: 696d2c2e
......@@ -71,6 +71,7 @@ if sys.platform == "win32":
kernel32.SetErrorMode(old_error_mode)
from .core._imperative_rt.core2 import close as _close
from .core._imperative_rt.core2 import full_sync as _full_sync
from .core._imperative_rt.core2 import sync as _sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
......@@ -90,7 +91,7 @@ _set_fork_exec_path_for_timed_func(
_persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
_persistent_cache_impl_ins.reg()
atexit.register(_full_sync)
atexit.register(_close)
del _set_fork_exec_path_for_timed_func
......
......@@ -897,6 +897,11 @@ void init_tensor(py::module m) {
}
}
// Shared helper used by the "sync", "full_sync" and "close" bindings below:
// drains every pending task in py_task_q. The GIL is released first so that
// queued tasks which need the GIL can run while this thread blocks in
// wait_all_task_finish(); holding the GIL here could deadlock.
static constexpr auto sync_py_task_q = []{
    py::gil_scoped_release _;
    py_task_q.wait_all_task_finish();
};
m.def("set_option",
[](std::string name, size_t value){ interpreter_for_py->set_option(name, value); });
m.def("get_option",
......@@ -928,16 +933,19 @@ void init_tensor(py::module m) {
m.def("sync",
[]() {
interpreter_for_py->sync();
py_task_q.wait_all_task_finish();
},
py::call_guard<py::gil_scoped_release>());
sync_py_task_q();
});
m.def("full_sync",
[]() {
interpreter_for_py->sync();
CompNode::sync_all();
py_task_q.wait_all_task_finish();
},
py::call_guard<py::gil_scoped_release>());
sync_py_task_q();
});
m.def("close",
[]() {
interpreter_for_py->close();
sync_py_task_q();
});
py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach")
......
import subprocess
import sys
import numpy as np
import pytest
......@@ -76,3 +79,14 @@ def test_swap_drop_basic():
z.numpy()
_set_swap_flag(False)
_set_drop_flag(False)
def test_finalize():
    """Regression test for interpreter shutdown: a tensor that is still
    alive when comp nodes finalize must not crash the process.

    The scenario is exercised in a fresh interpreter via subprocess so the
    full startup/teardown sequence runs; check_call raises if the child
    exits non-zero.
    """
    script = """
import megengine
with megengine.core.option("enable_host_compute", 0):
    x = megengine.tensor(0)
    y = x + 1
    y.numpy()
"""
    command = [sys.executable, "-c", script]
    subprocess.check_call(command)
......@@ -67,6 +67,7 @@ std::shared_ptr<void> EventPool::on_comp_node_finalize() {
for (auto&& i : m_cn2pool) {
i.second.assert_all_freed();
}
m_cn2pool.clear();
return {};
}
EventPool::~EventPool() {
......
......@@ -33,6 +33,7 @@ Interpreter& Interpreter::inst() {
}
Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
mgb_assert(check_available(), "Channel already closed");
auto info = alloc();
info->desc.layout = value.layout();
info->desc.comp_node = value.comp_node();
......@@ -47,6 +48,7 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
}
Handle ChannelImpl::put(const DeviceTensorND& data) {
mgb_assert(check_available(), "Channel already closed");
auto info = alloc();
info->desc.layout = data.layout();
info->desc.comp_node = data.comp_node();
......@@ -58,6 +60,9 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
}
void ChannelImpl::del(Handle handle) {
if (!check_available()){
return;
}
mgb_assert(m_valid_handle.count(handle), "invalid handle: %p", handle);
auto* info = reinterpret_cast<TensorInfo*>(handle);
m_valid_handle.erase(handle);
......@@ -65,6 +70,7 @@ void ChannelImpl::del(Handle handle) {
}
void ChannelImpl::swap_in(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
if (m_worker_state.options.enable_swap) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
......@@ -74,6 +80,7 @@ void ChannelImpl::swap_in(Handle handle) {
}
void ChannelImpl::swap_out(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
if (m_worker_state.options.enable_swap) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
......@@ -83,6 +90,7 @@ void ChannelImpl::swap_out(Handle handle) {
}
void ChannelImpl::drop(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
if (m_worker_state.options.enable_drop) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
......@@ -201,6 +209,7 @@ void ChannelImpl::dispatch_kernel(
SmallVector<Handle> ChannelImpl::apply_op(
std::shared_ptr<OpDef> op,
const SmallVector<Handle>& inputs) {
mgb_assert(check_available(), "Channel already closed");
for (auto i : inputs) {
mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
"invalid handle: %p", i);
......@@ -237,6 +246,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
}
HostTensorND ChannelImpl::get_value(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
// TODO: maybe get_value should be done on host. i.e. delete GetValue
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
......@@ -269,6 +279,7 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
}
TensorShape ChannelImpl::get_shape(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -296,6 +307,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
}
DType ChannelImpl::get_dtype(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -308,6 +320,7 @@ DType ChannelImpl::get_dtype(Handle handle) {
}
CompNode ChannelImpl::get_device(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -320,6 +333,7 @@ CompNode ChannelImpl::get_device(Handle handle) {
}
DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
mgb_assert(check_available(), "Channel already closed");
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
......@@ -342,6 +356,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
}
void ChannelImpl::sync() {
mgb_assert(check_available(), "Channel already closed");
m_buffer.flush();
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<SyncStartEvent>();
......@@ -356,14 +371,26 @@ void ChannelImpl::sync() {
}
void ChannelImpl::close() {
if (!check_available()) {
return;
}
std::vector<Handle> valid_handles(m_valid_handle.begin(), m_valid_handle.end());
for (auto* handle: valid_handles) {
del(handle);
}
mgb_assert(m_valid_handle.empty());
mgb_log_debug("%ld tensor exists before channel close", (long)valid_handles.size());
sync();
m_closed = true;
}
// Returns the channel-side value of the named option.
// Asserts (rather than returning a default) if the channel is closed.
size_t ChannelImpl::get_option(std::string name) {
    mgb_assert(check_available(), "Channel already closed");
    auto& channel_opts = m_channel_state.options;
    return channel_opts.get_option(name);
}
// Updates the named option on the channel state, then forwards the change
// to the worker through the command buffer (SetOption) so the worker-side
// option manager is updated as well.
void ChannelImpl::set_option(std::string name, size_t value) {
    mgb_assert(check_available(), "Channel already closed");
    auto& channel_opts = m_channel_state.options;
    channel_opts.set_option(name, value);
    // Enqueue after the local update so the channel view never lags behind
    // the command it just issued.
    m_buffer.enqueue(SetOption{name, value});
}
......@@ -440,9 +467,7 @@ void ChannelImpl::real_free(TensorInfo* ptr) {
m_pool.free(ptr);
}
ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){
m_channel_state.tid = std::this_thread::get_id();
}
ChannelImpl::ChannelImpl() : m_worker(this), m_buffer(this){}
ChannelImpl::~ChannelImpl() {
close();
......@@ -562,6 +587,10 @@ void ChannelImpl::detach_users(TensorInfo* dest) {
//dest->users.clear();
}
// Returns true while the channel can still accept commands. close() sets
// m_closed, after which most entry points assert on this flag, while del()
// and close() itself simply early-return instead.
bool ChannelImpl::check_available() {
    return !m_closed;
}
void ChannelImpl::sync_device_scope(CompNode device) {
auto& prev = m_worker_state.device_scope_map[device];
auto& current = m_worker_state.scopes;
......@@ -786,9 +815,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
std::swap(profiler, m_worker_state.profiler);
auto records = profiler->stop();
auto host_map = [this](std::thread::id tid) {
if (tid == m_channel_state.tid) {
return "channel";
} else if (tid == m_worker_state.tid) {
if (tid == m_worker_state.tid) {
return "worker";
} else {
return "unknown";
......@@ -959,6 +986,7 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
}
void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
mgb_assert(check_available(), "Channel already closed");
auto profiler_option = InterpreterProfiler::Option::from_dict(option);
auto profiler = std::make_unique<InterpreterProfiler>();
profiler->set_option(profiler_option);
......@@ -968,6 +996,7 @@ void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
}
void ChannelImpl::stop_profile(std::string basename, std::string format) {
mgb_assert(check_available(), "Channel already closed");
m_buffer.flush();
auto profiler = std::make_unique<InterpreterProfiler>();
std::swap(profiler, m_channel_state.profiler);
......@@ -976,6 +1005,7 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) {
}
void ChannelImpl::push_scope(std::string name) {
mgb_assert(check_available(), "Channel already closed");
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<ChannelBeginScope>(name);
m_channel_state.scopes.push_back(name);
......@@ -984,6 +1014,7 @@ void ChannelImpl::push_scope(std::string name) {
}
void ChannelImpl::pop_scope(std::string name) {
mgb_assert(check_available(), "Channel already closed");
if (m_channel_state.profiler->is_profiling()) {
mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
m_channel_state.scopes.pop_back();
......@@ -992,14 +1023,6 @@ void ChannelImpl::pop_scope(std::string name) {
}
}
// Debug guard: aborts unless the calling thread is the channel (python)
// thread recorded in m_channel_state.tid.
// NOTE(review): the original asserted `tid != std::this_thread::get_id()`,
// i.e. that we are NOT on the channel thread — the opposite of what the
// function name promises and inconsistent with assert_in_worker(), which
// uses `==` against m_worker_state.tid. Flipped to equality.
void ChannelImpl::assert_in_channel() {
    mgb_assert(m_channel_state.tid == std::this_thread::get_id());
}
// Debug guard: aborts unless the calling thread is the worker thread
// recorded in m_worker_state.tid.
void ChannelImpl::assert_in_worker() {
    const auto current = std::this_thread::get_id();
    mgb_assert(current == m_worker_state.tid);
}
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
for (auto i : vec) {
i->pin();
......
......@@ -18,6 +18,7 @@
#include <unordered_set>
#include <variant>
#include "megbrain/comp_node.h"
#include "megbrain/utils/mempool.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/profiler.h"
......@@ -102,8 +103,7 @@ private:
const SmallVector<LogicalTensorDesc>& input_descs,
SmallVector<Handle>* outputs);
void assert_in_channel();
void assert_in_worker();
bool check_available();
void sync_device_scope(CompNode device);
......@@ -120,6 +120,8 @@ private:
std::exception_ptr m_worker_exc;
std::atomic_uint64_t m_last_id = 0;
bool m_closed = false;
struct WorkQueue : AsyncQueueSC<IdentifiedCommand, WorkQueue> {
// set max_spin=0 to prevent Queue fetch task in busy wait manner.
// this won't affect throughput when python interpreter is sending enough task,
......@@ -186,7 +188,6 @@ private:
int m_async_level = 2;
struct State {
std::thread::id tid;
OptionManager options;
std::vector<std::string> scopes;
std::unique_ptr<InterpreterProfiler> profiler;
......@@ -199,6 +200,7 @@ private:
struct ChannelState: State {};
struct WorkerState: State {
std::thread::id tid;
CompNode::UnorderedMap<std::vector<std::string>> device_scope_map;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册