提交 50c4daac 编写于 作者: M Megvii Engine Team

feat(mge/interpreter): add async_level mechanism for Interpreter

GitOrigin-RevId: 8615a23b75b7e3172d724acc8f7fffd2cf9b73d5
上级 82b0f677
......@@ -77,12 +77,14 @@ void init_imperative_rt(py::module m) {
.def("get_shape", &Interpreter::Channel::get_shape)
.def("_get_dev_tensor", &Interpreter::Channel::get_dev_tensor)
.def("apply_op", &Interpreter::Channel::apply_op)
.def("config_async_level", &Interpreter::Channel::config_async_level)
.def("get_async_level", &Interpreter::Channel::get_async_level)
.def("sync", &Interpreter::Channel::sync, py::call_guard<py::gil_scoped_release>());
std::unique_ptr<Interpreter::Channel> ch = Interpreter::inst().create_channel();
m.attr("interpreter") = py::detail::make_caster<decltype(ch)>::cast(
std::move(ch), py::return_value_policy::move, {});
for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op"}) {
for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op", "config_async_level", "get_async_level"}) {
m.attr(name) = m.attr("interpreter").attr(name);
}
......
import pytest
import megengine as mge
import megengine.functional as F
from megengine.core._imperative_rt.imperative import config_async_level, get_async_level
def test_basic():
    """Check that a valid async level round-trips and an out-of-range one is rejected.

    The async level is a setting on the shared interpreter channel, so the
    value found on entry is restored afterwards to keep it from leaking into
    other tests running in the same process.
    """
    previous_level = get_async_level()
    try:
        config_async_level(2)
        assert get_async_level() == 2
        # 3 is outside the accepted range and must be refused.
        with pytest.raises(RuntimeError):
            config_async_level(3)
        # A refused value must not overwrite the level that was set before.
        assert get_async_level() == 2
    finally:
        config_async_level(previous_level)
def test_level1_infer_value():
    """At async level 1, reshape with a computed target shape must raise RuntimeError.

    The target shape is itself the output of an op (``ones * 2``); per the
    original note this makes the DepType::VALUE input unknown when reshape is
    invoked.
    """
    config_async_level(1)
    data = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32")
    ones = mge.tensor([1, 1], dtype="float32")
    # make DepType::VALUE unknown
    target_shape = ones * 2
    with pytest.raises(RuntimeError):
        F.reshape(data, target_shape)
def test_level1_infer_shape_with_unknown():
    """At async level 1, matmul on an operand built at level 2 must raise RuntimeError.

    The reshape is issued at level 2 with a computed target shape; the level
    is then lowered to 1 before the matmul is applied to its result.
    """
    config_async_level(2)
    source = mge.tensor([[1, 2, 2, 3]], dtype="float32")
    ones = mge.tensor([1, 1])
    target_shape = ones * 2
    # make DepType::SHAPE unknown
    reshaped = F.reshape(source, target_shape)

    config_async_level(1)
    rhs = mge.tensor([[1, 2]], dtype="float32")
    with pytest.raises(RuntimeError):
        F.matmul(reshaped, rhs)
......@@ -54,21 +54,25 @@ void ChannelImpl::del(void* handle) {
SmallVector<void*> ChannelImpl::apply_op(
std::shared_ptr<OpDef> op,
const SmallVector<void*>& inputs) {
SmallVector<TensorInfo*> input_infos;
input_infos.reserve(inputs.size());
SmallVector<LogicalTensorDesc> input_descs;
input_descs.reserve(inputs.size());
for (auto h : inputs) {
auto info = reinterpret_cast<TensorInfo*>(h);
for (auto i : inputs) {
auto info = reinterpret_cast<TensorInfo*>(i);
input_infos.push_back(info);
input_descs.push_back(info->desc);
}
auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs);
ApplyOp cmd{std::move(op)};
cmd.inputs.reserve(inputs.size());
for (auto i : inputs) {
cmd.inputs.push_back(reinterpret_cast<TensorInfo*>(i));
}
cmd.inputs = std::move(input_infos);
cmd.outputs.reserve(output_descs.size());
SmallVector<void*> outputs;
bool is_fallible = false;
for (auto&& desc : output_descs) {
if (desc.layout.ndim == 0) {
is_fallible = true;
}
auto info = alloc();
info->desc = desc;
m_valid_handle.insert(info);
......@@ -76,6 +80,9 @@ SmallVector<void*> ChannelImpl::apply_op(
outputs.push_back(info);
}
m_worker.add_task(std::move(cmd));
if (is_fallible && m_async_level <= 1) {
sync();
}
return outputs;
}
......@@ -162,7 +169,12 @@ void ChannelImpl::close() {
}
//! Set the async level of this channel.
//! \param level must be 0, 1 or 2; any other value fails the assertion and
//!     leaves the previously configured level unchanged.
void ChannelImpl::config_async_level(int level) {
    // The former unconditional mgb_assert(0) placeholder is gone: it made
    // every call fail, so no level could ever be configured.
    mgb_assert(level >= 0 && level <= 2, "async_level should be 0, 1 or 2");
    m_async_level = level;
}
//! Return the async level currently stored on this channel.
int ChannelImpl::get_async_level() {
    const int current_level = m_async_level;
    return current_level;
}
TensorInfo* ChannelImpl::alloc() {
......
......@@ -74,6 +74,7 @@ struct ChannelImpl : Interpreter::Channel {
void close() override;
void config_async_level(int level) override;
int get_async_level() override;
private:
TensorInfo* alloc();
......@@ -101,7 +102,11 @@ private:
ChannelImpl* m_owner;
} m_worker;
int m_async_level = 2;
//! Controls whether errors are raised exactly when an op is invoked.
//! level 2: both device-side and user-side errors are reported asynchronously;
//! level 1: user-side errors are reported synchronously;
//! level 0: both kinds of errors are reported synchronously.
int m_async_level = 1;
};
} // namespace mgb::imperative::interpreter::intl
......@@ -41,6 +41,7 @@ struct Interpreter {
virtual void close() = 0;
virtual void config_async_level(int level) = 0;
virtual int get_async_level() = 0;
};
virtual std::unique_ptr<Channel> create_channel() = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册