From 55e9c8318c64be347b06931c7b6814f9cfabab67 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 2 Aug 2022 13:48:56 +0800 Subject: [PATCH] feat(trace): add imperative mode for debug GitOrigin-RevId: 067b7d235e107d459b4d09f4f04627676b9073cc --- imperative/python/megengine/jit/tracing.py | 4 + imperative/python/src/tensor.cpp | 5 +- imperative/python/src/tensor_utils.cpp | 3 + .../python/test/unit/jit/test_tracing.py | 7 +- imperative/src/impl/transformations/trace.cpp | 112 ++++++++++++------ .../imperative/transformations/trace.h | 40 +++++-- 6 files changed, 127 insertions(+), 44 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 114d6fca0..7ba77eca4 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -112,6 +112,7 @@ class trace: without_host: if True, will run python code of wrapped function on the first call, and run the compiled graph/function on subsequent calls. if False, will run python code every time. Default: False + imperative: if True, will use imperative runtime to execute captured op seq. Default: False """ third_party_backend = False @@ -124,6 +125,7 @@ class trace: def __init__( self, function, + *, symbolic=False, capture_as_const=False, record_only=False, @@ -134,6 +136,7 @@ class trace: graph_opt_config: GraphOptimizationConfig = None, symbolic_shape: bool = True, without_host: bool = False, + imperative: bool = False, ): self.__wrapped__ = function self._capture_as_const = capture_as_const or record_only @@ -204,6 +207,7 @@ class trace: self._trace.symbolic = symbolic or record_only self._trace.capture_as_const = capture_as_const or record_only self._trace.no_exec = record_only + self._trace.imperative = imperative self._trace.options_visitor = apply_options self._trace.profile = profiling self._trace.array_comparator = array_comparator diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 556466515..8e33415e6 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -1201,6 +1201,8 @@ void init_tensor(py::module m) { bool without_host = false; bool check_external = true; bool remove_unused_data_required = true; + bool imperative = false; + py::function options_visitor; std::shared_ptr tracing; std::shared_ptr compiled; @@ -1257,7 +1259,7 @@ void init_tensor(py::module m) { } else if (!self.compiled) { // traced but not compiled using namespace std::placeholders; self.compiled = std::make_shared( - *self.trace_result, self.record_input_shapes); + *self.trace_result, self.record_input_shapes, self.imperative); self.compiled->set_value_comparator( std::bind(&Trace::compare_value, this, _1, _2)); self.options_visitor(py::cast(&self.compiled->options())); @@ -1405,6 +1407,7 @@ void init_tensor(py::module m) { .def_readwrite("symbolic", &Trace::symbolic) .def_readwrite("capture_as_const", &Trace::capture_as_const) .def_readwrite("no_exec", &Trace::no_exec) + .def_readwrite("imperative", &Trace::imperative) .def_readwrite("options_visitor", &Trace::options_visitor) .def("enter", &Trace::enter) .def("exit", &Trace::exit) diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index a82db85b3..83b9eedf8 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -555,6 +555,9 @@ py::object _astensor1d_cpp( c_args[flat_list.size()] = Py_None; py::tuple inp_tup = py::reinterpret_steal( convert_inputs_cpp(NULL, c_args.data(), c_args.size())); + if (!inp_tup) { + throw py::error_already_set(); + } if (device_obj.is_none()) { std::vector inp(inp_tup.size()); for (size_t i = 0; i < inp_tup.size(); ++i) { diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index 12f46edaf..dda912b44 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -510,14 +510,15 @@ def test_trace_warp_perspective(): ("((a + b), (b + c))[1] + a", "((a + b), (b + c))[0] + a", "input id mismatch"), ], ) -def test_trace_mismatch(normal_expr, mismatch_expr, reason): +@pytest.mark.parametrize("imperative", [True, False]) +def test_trace_mismatch(normal_expr, mismatch_expr, reason, imperative): a = tensor([1, 2, 3, 4]) b = tensor([5, 6, 7, 8]) c = tensor([9, 0, 1, 2]) mismatch = False - @trace(symbolic=True) + @trace(symbolic=True, imperative=imperative) def fn(a, b, c): if not mismatch: result = eval(normal_expr) @@ -525,7 +526,7 @@ def test_trace_mismatch(normal_expr, mismatch_expr, reason): result = eval(mismatch_expr) return result - for i in range(20): + for _ in range(20): try: d = fn(a, b, c) except TraceError as e: diff --git a/imperative/src/impl/transformations/trace.cpp b/imperative/src/impl/transformations/trace.cpp index 5fd3a5c32..43fb7f1ce 100644 --- a/imperative/src/impl/transformations/trace.cpp +++ b/imperative/src/impl/transformations/trace.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "megbrain/gopt/inference.h" #include "megbrain/graph/helper.h" @@ -499,11 +500,6 @@ void CompiledTransformation::compile() { return accessor; }; std::vector var_accessors(m_vars.size()); - auto exc_setter = std::bind( - &CompiledTransformation::set_exception, this, std::placeholders::_1); - for (auto&& accessor : var_accessors) { - accessor.exc_setter = exc_setter; - } for (auto&& item : m_seq) { bool require_link = bool(item.op) && mm_io_ops.count(item.op->dyn_typeinfo()); VarNodeArray input_vars; @@ -579,6 +575,12 @@ void CompiledTransformation::compile() { dep_iter.add(output_spec.first); } } + for (auto& accessor : var_accessors) { + accessor.exc_setter = [this](std::exception_ptr exc) { set_exception(exc); }; + if (m_imperative) { + accessor.node = nullptr; + } + } m_executable = m_graph->compile(output_specs); mgb_assert(m_executable != nullptr, "The compiled executable is nullptr."); @@ -601,7 +603,7 @@ void CompiledTransformation::assert_tensor_equal(ValueRef lhs, ValueRef rhs) { trace_assert(m_value_comparator(lhs, rhs), "tensors not equals"); } -void CompiledTransformation::trace_input(size_t id, ValueRef value) { +ValueRef CompiledTransformation::trace_input(size_t id, ValueRef value) { try { auto& var = m_vars[id]; auto& var_accessor = m_var_accessors[id]; @@ -626,32 +628,43 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { var_accessor.data_setter(value.dev_tensor()->as_nd()); m_setted_extern.insert(id); } - break; + return value; } case VarKind::Constant: { // expect host value here mgb_assert(var.bound_data, "const var without data bound"); assert_tensor_equal(var.bound_data, value); - break; + // TODO: use value + return var.bound_data; } case VarKind::Internal: { trace_assert( value.is(m_value_type), "expect internal node, got external"); auto& traced_value = value.cast(m_value_type); trace_assert(traced_value.id() == id, "input id mismatch"); - break; + return traced_value.get_imperative_value(); } + default: + trace_assert(false, "unknown var kind"); } } catch (TraceError&) { throw; + } catch (const std::exception& exc) { + mgb_log_error("unexpected error %s", exc.what()); + throw; } catch (...) { mgb_assert(false, "unexpected error"); } } -auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t { +auto CompiledTransformation::trace_output(size_t id, ValueRef value) + -> TracedValue::ref_t { auto traced_value = m_value_type.make(id, &m_vars[id], &m_var_accessors[id]); m_weak_values.push_back(traced_value); + if (m_imperative) { + mgb_assert(value, "imperative mode requires value"); + traced_value->set_imperative_value(value); + } return traced_value; } @@ -663,6 +676,9 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() { ShapeValue::ref_t CompiledTransformation::TracedValue::shape() const { if (!m_shape) { trace_assert(m_accessor->shape_getter, "shape unreadable"); + if (m_accessor->is_imperative()) { + return m_imperative_value.shape(); + } m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter())); } return m_shape; @@ -675,6 +691,23 @@ DTypeValue::ref_t CompiledTransformation::TracedValue::dtype() const { CompNodeValue::ref_t CompiledTransformation::TracedValue::comp_node() const { return m_var->device; } + +DeviceValue::ref_t CompiledTransformation::TracedValue::data() const { + trace_assert(m_accessor->data_getter, "data unreadable"); + if (m_accessor->is_imperative()) { + return m_imperative_value.dev_tensor(); + } + return DeviceValue::make(m_accessor->data_getter()); +} + +HostValue::ref_t CompiledTransformation::TracedValue::value() const { + trace_assert(m_accessor->value_getter, "value unreadable"); + if (m_accessor->is_imperative()) { + return m_imperative_value.numpy(); + } + return HostValue::make(m_accessor->value_getter()); +} + auto CompiledTransformation::TracedValue::accessor() const -> const VarAccessor& { return *m_accessor; } @@ -684,12 +717,24 @@ ValueRefList CompiledTransformation::apply_op( auto& item = next_instruction(); trace_assert(inputs.size() == item.inputs.size(), "input size mismatch"); trace_assert(apply_op.op().is_same(*item.op), "operator mismatch"); - for (size_t i = 0; i < inputs.size(); ++i) { - trace_input(item.inputs[i], inputs[i]); - } ValueRefList outputs(item.outputs.size()); - for (size_t i = 0; i < item.outputs.size(); ++i) { - outputs[i] = trace_output(item.outputs[i]); + if (!m_imperative) { + for (size_t i = 0; i < inputs.size(); ++i) { + trace_input(item.inputs[i], inputs[i]); + } + for (size_t i = 0; i < item.outputs.size(); ++i) { + outputs[i] = trace_output(item.outputs[i], {}); + } + } else { + SmallVector input_values; + for (size_t i = 0; i < inputs.size(); ++i) { + input_values.push_back(trace_input(item.inputs[i], inputs[i])); + } + auto&& output_values = imperative::apply(apply_op, input_values); + mgb_assert(output_values.size() == outputs.size()); + for (size_t i = 0; i < item.outputs.size(); ++i) { + outputs[i] = trace_output(item.outputs[i], output_values[i]); + } } return outputs; } @@ -698,24 +743,22 @@ ValueRefList CompiledTransformation::apply_get_attr( const GetAttr& get_attr, Span inputs) { if (auto* traced_value = inputs[0].as(m_value_type)) { ValueRef output; - auto& var_accessor = traced_value->accessor(); switch (get_attr.attr()) { case GetAttr::Shape: output = traced_value->shape(); break; case GetAttr::Data: - trace_assert(var_accessor.data_getter, "data unreadable"); - output = DeviceValue::make(var_accessor.data_getter()); + output = traced_value->data(); break; case GetAttr::Value: - trace_assert(var_accessor.value_getter, "value unreadable"); - output = HostValue::make(var_accessor.value_getter()); + output = traced_value->value(); break; case GetAttr::DType: output = traced_value->dtype(); break; case GetAttr::Device: output = traced_value->comp_node(); + break; default: break; } @@ -745,8 +788,7 @@ ValueRefList CompiledTransformation::apply_create_tensor( if (!tensor) { tensor = imperative::apply(create_tensor, inputs)[0]; } - trace_input(input_id, tensor); - return {trace_output(output_id)}; + return {trace_output(output_id, trace_input(input_id, tensor))}; } ValueRefList CompiledTransformation::apply_transformation( @@ -762,21 +804,21 @@ ValueRefList CompiledTransformation::apply_transformation( trace_assert(item.op == nullptr, "operator mismatch"); trace_assert(item.inputs.size() == 1, "inputs size mismatch"); trace_assert(item.outputs.size() == 1, "inputs output mismatch"); - trace_input(item.inputs[0], inputs[0]); + auto value = trace_input(item.inputs[0], inputs[0]); trace_assert( trace_mark_var->mark() == m_vars[item.outputs[0]].mark, "mark mismatch"); - return {trace_output(item.outputs[0])}; + return {trace_output(item.outputs[0], value)}; } else if (auto* trace_name_var = op.as()) { auto& item = next_instruction(); trace_assert(item.op == nullptr, "operator mismatch"); trace_assert(item.inputs.size() == 1, "inputs size mismatch"); trace_assert(item.outputs.size() == 1, "outputs size mismatch"); - trace_input(item.inputs[0], inputs[0]); + auto value = trace_input(item.inputs[0], inputs[0]); trace_assert( trace_name_var->name() == m_vars[item.outputs[0]].name, "name mismatch"); - return {trace_output(item.outputs[0])}; + return {trace_output(item.outputs[0], value)}; } else { return op.fallback(inputs); } @@ -786,11 +828,9 @@ void CompiledTransformation::on_unregister() noexcept { // resolve pending values for (auto&& weak_value : m_weak_values) { if (auto traced_value = weak_value.lock()) { - auto& var_accessor = m_var_accessors[traced_value->id()]; auto value = ([&]() -> ValueRef { try { - trace_assert(var_accessor.data_getter, "data unreadable"); - auto dev_value = DeviceValue::make(var_accessor.data_getter()); + auto dev_value = traced_value->data(); return imperative::apply( CreateTensor( CreateTensor::Common, dev_value->device(), @@ -821,10 +861,9 @@ void CompiledTransformation::wait() { trace_assert(m_pc == m_seq.size(), "mismature end"); } catch (...) { } - mgb_assert(m_executable != nullptr); - std::unique_lock lock{m_mutex}; - m_cv.wait(lock, [&] { return m_graph_status == 0; }); - lock.unlock(); + if (!m_imperative) { + wait_worker(); + } for (auto&& box : m_boxes) { box->reset(); } @@ -839,6 +878,13 @@ void CompiledTransformation::wait() { } } +void CompiledTransformation::wait_worker() { + mgb_assert(m_executable != nullptr); + std::unique_lock lock{m_mutex}; + m_cv.wait(lock, [&] { return m_graph_status == 0; }); + lock.unlock(); +} + std::exception_ptr CompiledTransformation::set_exception( std::exception_ptr exc) noexcept { MGB_LOCK_GUARD(m_mutex); diff --git a/imperative/src/include/megbrain/imperative/transformations/trace.h b/imperative/src/include/megbrain/imperative/transformations/trace.h index 1be4d54c1..36bc3431f 100644 --- a/imperative/src/include/megbrain/imperative/transformations/trace.h +++ b/imperative/src/include/megbrain/imperative/transformations/trace.h @@ -265,12 +265,14 @@ public: using OpKind = TraceResult::SeqItem::OpKind; struct VarAccessor { - VarNode* node; + VarNode* node; // use imperative mode when node == nullptr std::function shape_getter; std::function data_getter; std::function value_getter; std::function data_setter; std::function exc_setter; + + bool is_imperative() const { return node == nullptr; } }; class TracedValue final : public ObjectValue { @@ -281,6 +283,7 @@ public: mutable ShapeValue::ref_t m_shape; mutable DTypeValue::ref_t m_dtype; mutable CompNodeValue::ref_t m_comp_node; + mutable ValueRef m_imperative_value; public: TracedValue(size_t id, VarInfo* var, VarAccessor* accessor) @@ -289,9 +292,12 @@ public: ShapeValue::ref_t shape() const; DTypeValue::ref_t dtype() const; CompNodeValue::ref_t comp_node() const; + DeviceValue::ref_t data() const; + HostValue::ref_t value() const; const VarAccessor& accessor() const; void set_exception(std::exception_ptr exc) const { + mgb_assert(m_accessor->exc_setter, "exc setter invalid"); m_accessor->exc_setter(exc); } @@ -299,7 +305,11 @@ public: return ssprintf("TracedValue{\"id\"=%zu}", id()); } - void clear() override {} + void clear() override { m_imperative_value = {}; } + + void set_imperative_value(ValueRef value) const { m_imperative_value = value; } + + ValueRef get_imperative_value() const { return m_imperative_value; } }; private: @@ -322,15 +332,23 @@ private: ComputingGraph::OutputSpec m_output_spec; ObjectType m_value_type{"TracedValue"}; std::set m_setted_extern; + bool m_imperative = false; public: - CompiledTransformation(TraceResult result, bool input_shape_static) + CompiledTransformation(TraceResult result, bool input_shape_static, bool imperative) : m_seq(result.seq), m_vars(result.vars), - m_input_shape_static(input_shape_static) { + m_input_shape_static(input_shape_static), + m_imperative(imperative) { m_graph = ComputingGraph::make(); options().no_force_inplace = true; options().async_exec_level = 0b100; + if (!m_imperative) { + start_worker(); + } + } + + void start_worker() { m_graph_executor = std::thread([&] { while (true) { std::unique_lock lock{m_mutex}; @@ -384,7 +402,7 @@ public: * \param id * \param value */ - void trace_input(size_t id, ValueRef value); + ValueRef trace_input(size_t id, ValueRef value); /** * \brief make a placeholder for output. @@ -393,7 +411,7 @@ public: * \return TracedValue::ref_t output placeholder, would be reset to real value when * trace exits */ - TracedValue::ref_t trace_output(size_t id); + TracedValue::ref_t trace_output(size_t id, ValueRef value); TraceResult::SeqItem& next_instruction(); @@ -422,6 +440,8 @@ public: void wait(); + void wait_worker(); + std::exception_ptr set_exception(std::exception_ptr exc) noexcept; template @@ -431,7 +451,7 @@ public: return box; } - ~CompiledTransformation() { + void stop_worker() { { MGB_LOCK_GUARD(m_mutex); m_graph_status = 2; @@ -439,6 +459,12 @@ public: m_cv.notify_all(); m_graph_executor.join(); } + + ~CompiledTransformation() { + if (!m_imperative) { + stop_worker(); + } + } }; } // namespace mgb::imperative -- GitLab