Commit 55e9c831 authored by Megvii Engine Team

feat(trace): add imperative mode for debug

GitOrigin-RevId: 067b7d235e107d459b4d09f4f04627676b9073cc
Parent 281ecd0b
......@@ -112,6 +112,7 @@ class trace:
without_host: if True, run the Python code of the wrapped function only on the first call,
and run the compiled graph/function on subsequent calls; if False, run the Python code every time.
Default: False
imperative: if True, use the imperative runtime to execute the captured op sequence
instead of the compiled graph; useful for debugging. Default: False
"""
third_party_backend = False
......@@ -124,6 +125,7 @@ class trace:
def __init__(
self,
function,
*,
symbolic=False,
capture_as_const=False,
record_only=False,
......@@ -134,6 +136,7 @@ class trace:
graph_opt_config: GraphOptimizationConfig = None,
symbolic_shape: bool = True,
without_host: bool = False,
imperative: bool = False,
):
self.__wrapped__ = function
self._capture_as_const = capture_as_const or record_only
......@@ -204,6 +207,7 @@ class trace:
self._trace.symbolic = symbolic or record_only
self._trace.capture_as_const = capture_as_const or record_only
self._trace.no_exec = record_only
self._trace.imperative = imperative
self._trace.options_visitor = apply_options
self._trace.profile = profiling
self._trace.array_comparator = array_comparator
......
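For context, here is a minimal usage sketch of the new flag. The decorator signature and the `imperative` parameter come from this diff; the traced function, its inputs, and the `megengine`/`megengine.jit` import paths are illustrative assumptions:

```python
# Sketch only: the decorator arguments match this diff; the traced
# function and its inputs are made up for illustration.
import numpy as np

from megengine import tensor
from megengine.jit import trace


@trace(symbolic=True, imperative=True)
def fn(a, b):
    return a * 2 + b


a = tensor(np.ones((4,), dtype="float32"))
b = tensor(np.arange(4, dtype="float32"))

out = fn(a, b)  # first call records the op sequence
out = fn(a, b)  # replay: with imperative=True the captured ops run on the
                # imperative runtime rather than the compiled graph, keeping
                # per-op execution visible for debugging
```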
......@@ -1201,6 +1201,8 @@ void init_tensor(py::module m) {
bool without_host = false;
bool check_external = true;
bool remove_unused_data_required = true;
bool imperative = false;
py::function options_visitor;
std::shared_ptr<TracingTransformation> tracing;
std::shared_ptr<CompiledTransformation> compiled;
......@@ -1257,7 +1259,7 @@ void init_tensor(py::module m) {
} else if (!self.compiled) { // traced but not compiled
using namespace std::placeholders;
self.compiled = std::make_shared<CompiledTransformation>(
*self.trace_result, self.record_input_shapes);
*self.trace_result, self.record_input_shapes, self.imperative);
self.compiled->set_value_comparator(
std::bind(&Trace::compare_value, this, _1, _2));
self.options_visitor(py::cast(&self.compiled->options()));
......@@ -1405,6 +1407,7 @@ void init_tensor(py::module m) {
.def_readwrite("symbolic", &Trace::symbolic)
.def_readwrite("capture_as_const", &Trace::capture_as_const)
.def_readwrite("no_exec", &Trace::no_exec)
.def_readwrite("imperative", &Trace::imperative)
.def_readwrite("options_visitor", &Trace::options_visitor)
.def("enter", &Trace::enter)
.def("exit", &Trace::exit)
......
......@@ -555,6 +555,9 @@ py::object _astensor1d_cpp(
c_args[flat_list.size()] = Py_None;
py::tuple inp_tup = py::reinterpret_steal<py::tuple>(
convert_inputs_cpp(NULL, c_args.data(), c_args.size()));
if (!inp_tup) {
throw py::error_already_set();
}
if (device_obj.is_none()) {
std::vector<PyObject*> inp(inp_tup.size());
for (size_t i = 0; i < inp_tup.size(); ++i) {
......
......@@ -510,14 +510,15 @@ def test_trace_warp_perspective():
("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"),
],
)
def test_trace_mismatch(normal_expr, mismatch_expr, reason):
@pytest.mark.parametrize("imperative", [True, False])
def test_trace_mismatch(normal_expr, mismatch_expr, reason, imperative):
a = tensor([1, 2, 3, 4])
b = tensor([5, 6, 7, 8])
c = tensor([9, 0, 1, 2])
mismatch = False
@trace(symbolic=True)
@trace(symbolic=True, imperative=imperative)
def fn(a, b, c):
if not mismatch:
result = eval(normal_expr)
......@@ -525,7 +526,7 @@ def test_trace_mismatch(normal_expr, mismatch_expr, reason):
result = eval(mismatch_expr)
return result
for i in range(20):
for _ in range(20):
try:
d = fn(a, b, c)
except TraceError as e:
......
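A condensed sketch of the mismatch pattern the parametrized test above drives through both runtimes; the expressions are simplified and the `TraceError` import path is an assumption based on this test file's context:

```python
# Sketch of the failure mode checked above: diverging from the recorded
# expression on replay must raise TraceError in both runtimes.
from megengine import tensor
from megengine.jit import TraceError, trace  # assumed export path for TraceError

mismatch = False


@trace(symbolic=True, imperative=True)
def fn(a, b):
    return a + b if not mismatch else a - b


a, b = tensor([1, 2]), tensor([3, 4])
fn(a, b)  # first call records "a + b"
mismatch = True
try:
    fn(a, b)  # replay sees "a - b": operator mismatch
except TraceError:
    pass
```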
......@@ -2,6 +2,7 @@
#include <chrono>
#include <exception>
#include <optional>
#include "megbrain/gopt/inference.h"
#include "megbrain/graph/helper.h"
......@@ -499,11 +500,6 @@ void CompiledTransformation::compile() {
return accessor;
};
std::vector<VarAccessor> var_accessors(m_vars.size());
auto exc_setter = std::bind(
&CompiledTransformation::set_exception, this, std::placeholders::_1);
for (auto&& accessor : var_accessors) {
accessor.exc_setter = exc_setter;
}
for (auto&& item : m_seq) {
bool require_link = bool(item.op) && mm_io_ops.count(item.op->dyn_typeinfo());
VarNodeArray input_vars;
......@@ -579,6 +575,12 @@ void CompiledTransformation::compile() {
dep_iter.add(output_spec.first);
}
}
for (auto& accessor : var_accessors) {
accessor.exc_setter = [this](std::exception_ptr exc) { set_exception(exc); };
if (m_imperative) {
accessor.node = nullptr;
}
}
m_executable = m_graph->compile(output_specs);
mgb_assert(m_executable != nullptr, "The compiled executable is nullptr.");
......@@ -601,7 +603,7 @@ void CompiledTransformation::assert_tensor_equal(ValueRef lhs, ValueRef rhs) {
trace_assert(m_value_comparator(lhs, rhs), "tensors not equal");
}
void CompiledTransformation::trace_input(size_t id, ValueRef value) {
ValueRef CompiledTransformation::trace_input(size_t id, ValueRef value) {
try {
auto& var = m_vars[id];
auto& var_accessor = m_var_accessors[id];
......@@ -626,32 +628,43 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
var_accessor.data_setter(value.dev_tensor()->as_nd());
m_setted_extern.insert(id);
}
break;
return value;
}
case VarKind::Constant: {
// expect host value here
mgb_assert(var.bound_data, "const var without data bound");
assert_tensor_equal(var.bound_data, value);
break;
// TODO: use value
return var.bound_data;
}
case VarKind::Internal: {
trace_assert(
value.is(m_value_type), "expect internal node, got external");
auto& traced_value = value.cast(m_value_type);
trace_assert(traced_value.id() == id, "input id mismatch");
break;
return traced_value.get_imperative_value();
}
default:
trace_assert(false, "unknown var kind");
}
} catch (TraceError&) {
throw;
} catch (const std::exception& exc) {
mgb_log_error("unexpected error %s", exc.what());
throw;
} catch (...) {
mgb_assert(false, "unexpected error");
}
}
auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t {
auto CompiledTransformation::trace_output(size_t id, ValueRef value)
-> TracedValue::ref_t {
auto traced_value = m_value_type.make(id, &m_vars[id], &m_var_accessors[id]);
m_weak_values.push_back(traced_value);
if (m_imperative) {
mgb_assert(value, "imperative mode requires value");
traced_value->set_imperative_value(value);
}
return traced_value;
}
......@@ -663,6 +676,9 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() {
ShapeValue::ref_t CompiledTransformation::TracedValue::shape() const {
if (!m_shape) {
trace_assert(m_accessor->shape_getter, "shape unreadable");
if (m_accessor->is_imperative()) {
return m_imperative_value.shape();
}
m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter()));
}
return m_shape;
......@@ -675,6 +691,23 @@ DTypeValue::ref_t CompiledTransformation::TracedValue::dtype() const {
CompNodeValue::ref_t CompiledTransformation::TracedValue::comp_node() const {
return m_var->device;
}
DeviceValue::ref_t CompiledTransformation::TracedValue::data() const {
trace_assert(m_accessor->data_getter, "data unreadable");
if (m_accessor->is_imperative()) {
return m_imperative_value.dev_tensor();
}
return DeviceValue::make(m_accessor->data_getter());
}
HostValue::ref_t CompiledTransformation::TracedValue::value() const {
trace_assert(m_accessor->value_getter, "value unreadable");
if (m_accessor->is_imperative()) {
return m_imperative_value.numpy();
}
return HostValue::make(m_accessor->value_getter());
}
auto CompiledTransformation::TracedValue::accessor() const -> const VarAccessor& {
return *m_accessor;
}
......@@ -684,12 +717,24 @@ ValueRefList CompiledTransformation::apply_op(
auto& item = next_instruction();
trace_assert(inputs.size() == item.inputs.size(), "input size mismatch");
trace_assert(apply_op.op().is_same(*item.op), "operator mismatch");
for (size_t i = 0; i < inputs.size(); ++i) {
trace_input(item.inputs[i], inputs[i]);
}
ValueRefList outputs(item.outputs.size());
for (size_t i = 0; i < item.outputs.size(); ++i) {
outputs[i] = trace_output(item.outputs[i]);
if (!m_imperative) {
for (size_t i = 0; i < inputs.size(); ++i) {
trace_input(item.inputs[i], inputs[i]);
}
for (size_t i = 0; i < item.outputs.size(); ++i) {
outputs[i] = trace_output(item.outputs[i], {});
}
} else {
SmallVector<ValueRef> input_values;
for (size_t i = 0; i < inputs.size(); ++i) {
input_values.push_back(trace_input(item.inputs[i], inputs[i]));
}
auto&& output_values = imperative::apply(apply_op, input_values);
mgb_assert(output_values.size() == outputs.size());
for (size_t i = 0; i < item.outputs.size(); ++i) {
outputs[i] = trace_output(item.outputs[i], output_values[i]);
}
}
return outputs;
}
......@@ -698,24 +743,22 @@ ValueRefList CompiledTransformation::apply_get_attr(
const GetAttr& get_attr, Span<ValueRef> inputs) {
if (auto* traced_value = inputs[0].as(m_value_type)) {
ValueRef output;
auto& var_accessor = traced_value->accessor();
switch (get_attr.attr()) {
case GetAttr::Shape:
output = traced_value->shape();
break;
case GetAttr::Data:
trace_assert(var_accessor.data_getter, "data unreadable");
output = DeviceValue::make(var_accessor.data_getter());
output = traced_value->data();
break;
case GetAttr::Value:
trace_assert(var_accessor.value_getter, "value unreadable");
output = HostValue::make(var_accessor.value_getter());
output = traced_value->value();
break;
case GetAttr::DType:
output = traced_value->dtype();
break;
case GetAttr::Device:
output = traced_value->comp_node();
break;
default:
break;
}
......@@ -745,8 +788,7 @@ ValueRefList CompiledTransformation::apply_create_tensor(
if (!tensor) {
tensor = imperative::apply(create_tensor, inputs)[0];
}
trace_input(input_id, tensor);
return {trace_output(output_id)};
return {trace_output(output_id, trace_input(input_id, tensor))};
}
ValueRefList CompiledTransformation::apply_transformation(
......@@ -762,21 +804,21 @@ ValueRefList CompiledTransformation::apply_transformation(
trace_assert(item.op == nullptr, "operator mismatch");
trace_assert(item.inputs.size() == 1, "inputs size mismatch");
trace_assert(item.outputs.size() == 1, "outputs size mismatch");
trace_input(item.inputs[0], inputs[0]);
auto value = trace_input(item.inputs[0], inputs[0]);
trace_assert(
trace_mark_var->mark() == m_vars[item.outputs[0]].mark,
"mark mismatch");
return {trace_output(item.outputs[0])};
return {trace_output(item.outputs[0], value)};
} else if (auto* trace_name_var = op.as<RenameValue>()) {
auto& item = next_instruction();
trace_assert(item.op == nullptr, "operator mismatch");
trace_assert(item.inputs.size() == 1, "inputs size mismatch");
trace_assert(item.outputs.size() == 1, "outputs size mismatch");
trace_input(item.inputs[0], inputs[0]);
auto value = trace_input(item.inputs[0], inputs[0]);
trace_assert(
trace_name_var->name() == m_vars[item.outputs[0]].name,
"name mismatch");
return {trace_output(item.outputs[0])};
return {trace_output(item.outputs[0], value)};
} else {
return op.fallback(inputs);
}
......@@ -786,11 +828,9 @@ void CompiledTransformation::on_unregister() noexcept {
// resolve pending values
for (auto&& weak_value : m_weak_values) {
if (auto traced_value = weak_value.lock()) {
auto& var_accessor = m_var_accessors[traced_value->id()];
auto value = ([&]() -> ValueRef {
try {
trace_assert(var_accessor.data_getter, "data unreadable");
auto dev_value = DeviceValue::make(var_accessor.data_getter());
auto dev_value = traced_value->data();
return imperative::apply(
CreateTensor(
CreateTensor::Common, dev_value->device(),
......@@ -821,10 +861,9 @@ void CompiledTransformation::wait() {
trace_assert(m_pc == m_seq.size(), "premature end");
} catch (...) {
}
mgb_assert(m_executable != nullptr);
std::unique_lock lock{m_mutex};
m_cv.wait(lock, [&] { return m_graph_status == 0; });
lock.unlock();
if (!m_imperative) {
wait_worker();
}
for (auto&& box : m_boxes) {
box->reset();
}
......@@ -839,6 +878,13 @@ void CompiledTransformation::wait() {
}
}
void CompiledTransformation::wait_worker() {
mgb_assert(m_executable != nullptr);
std::unique_lock lock{m_mutex};
m_cv.wait(lock, [&] { return m_graph_status == 0; });
lock.unlock();
}
std::exception_ptr CompiledTransformation::set_exception(
std::exception_ptr exc) noexcept {
MGB_LOCK_GUARD(m_mutex);
......
......@@ -265,12 +265,14 @@ public:
using OpKind = TraceResult::SeqItem::OpKind;
struct VarAccessor {
VarNode* node;
VarNode* node; // use imperative mode when node == nullptr
std::function<TensorShape()> shape_getter;
std::function<DeviceTensorND()> data_getter;
std::function<HostTensorND()> value_getter;
std::function<void(DeviceTensorND)> data_setter;
std::function<void(std::exception_ptr)> exc_setter;
bool is_imperative() const { return node == nullptr; }
};
class TracedValue final : public ObjectValue<TracedValue> {
......@@ -281,6 +283,7 @@ public:
mutable ShapeValue::ref_t m_shape;
mutable DTypeValue::ref_t m_dtype;
mutable CompNodeValue::ref_t m_comp_node;
mutable ValueRef m_imperative_value;
public:
TracedValue(size_t id, VarInfo* var, VarAccessor* accessor)
......@@ -289,9 +292,12 @@ public:
ShapeValue::ref_t shape() const;
DTypeValue::ref_t dtype() const;
CompNodeValue::ref_t comp_node() const;
DeviceValue::ref_t data() const;
HostValue::ref_t value() const;
const VarAccessor& accessor() const;
void set_exception(std::exception_ptr exc) const {
mgb_assert(m_accessor->exc_setter, "exc setter invalid");
m_accessor->exc_setter(exc);
}
......@@ -299,7 +305,11 @@ public:
return ssprintf("TracedValue{\"id\"=%zu}", id());
}
void clear() override {}
void clear() override { m_imperative_value = {}; }
void set_imperative_value(ValueRef value) const { m_imperative_value = value; }
ValueRef get_imperative_value() const { return m_imperative_value; }
};
private:
......@@ -322,15 +332,23 @@ private:
ComputingGraph::OutputSpec m_output_spec;
ObjectType<TracedValue> m_value_type{"TracedValue"};
std::set<size_t> m_setted_extern;
bool m_imperative = false;
public:
CompiledTransformation(TraceResult result, bool input_shape_static)
CompiledTransformation(TraceResult result, bool input_shape_static, bool imperative)
: m_seq(result.seq),
m_vars(result.vars),
m_input_shape_static(input_shape_static) {
m_input_shape_static(input_shape_static),
m_imperative(imperative) {
m_graph = ComputingGraph::make();
options().no_force_inplace = true;
options().async_exec_level = 0b100;
if (!m_imperative) {
start_worker();
}
}
void start_worker() {
m_graph_executor = std::thread([&] {
while (true) {
std::unique_lock lock{m_mutex};
......@@ -384,7 +402,7 @@ public:
* \param id
* \param value
*/
void trace_input(size_t id, ValueRef value);
ValueRef trace_input(size_t id, ValueRef value);
/**
* \brief make a placeholder for output.
......@@ -393,7 +411,7 @@ public:
* \return TracedValue::ref_t output placeholder, would be reset to real value when
* trace exits
*/
TracedValue::ref_t trace_output(size_t id);
TracedValue::ref_t trace_output(size_t id, ValueRef value);
TraceResult::SeqItem& next_instruction();
......@@ -422,6 +440,8 @@ public:
void wait();
void wait_worker();
std::exception_ptr set_exception(std::exception_ptr exc) noexcept;
template <typename T>
......@@ -431,7 +451,7 @@ public:
return box;
}
~CompiledTransformation() {
void stop_worker() {
{
MGB_LOCK_GUARD(m_mutex);
m_graph_status = 2;
......@@ -439,6 +459,12 @@ public:
m_cv.notify_all();
m_graph_executor.join();
}
~CompiledTransformation() {
if (!m_imperative) {
stop_worker();
}
}
};
} // namespace mgb::imperative