Commit dc250745 authored by Megvii Engine Team

feat(mge): add python custom op

GitOrigin-RevId: 35da0bb3017bdf90f7074bc84d9f3321672aad79
Parent 60c44b08
@@ -12,6 +12,7 @@
#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/utils/mempool.h" #include "megbrain/utils/mempool.h"
#include "range/v3/all.hpp" #include "range/v3/all.hpp"
...@@ -21,6 +22,9 @@ namespace views = ranges::views; ...@@ -21,6 +22,9 @@ namespace views = ranges::views;
namespace mgb::imperative::python { namespace mgb::imperative::python {
using scoped_disable = ApplyContext::scoped_disable;
using Flags = Tensor::Flags;
namespace {
struct GradSlotWeakPtr {
@@ -78,6 +82,21 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
return result;
}
struct BackwardContext {
PyTypeObject* pytype = nullptr;
auto wrap_tensor(std::shared_ptr<Tensor> t) {
if (pytype) {
return TensorWrapper::make(pytype, std::move(t));
}
return TensorWrapper::make(std::move(t));
}
auto wrap_tensor(Tensor* t) {
return wrap_tensor(t->shared_from_this());
}
};
struct BackwardGraphWithClosure {
std::shared_ptr<BackwardGraphResult> backward_graph;
SmallVector<std::shared_ptr<Tensor>> closure;
@@ -119,7 +138,7 @@ struct BackwardGraphWithClosure {
}
template <typename T, typename R>
void operator()(BackwardContext&, T&& grads, R&& receiver) {
Tensor* args[closure.size() + grads.size()];
size_t nargs = 0;
for (auto&& t : closure) {
@@ -143,7 +162,7 @@ struct BackwardGraphWithClosure {
ApplyContext ctx;
ctx.op = backward_graph->backward;
ctx.flags = is_tracing ? Flags::TRACE : 0;
ctx.nargs = nargs;
ctx.args = args;
for (size_t i = 0; i < nargs; ++i) {
@@ -174,6 +193,47 @@ struct BackwardGraphWithClosure {
}
};
struct PythonBackward {
py::object pyfunc;
size_t input_size;
PythonBackward(py::object f, size_t nin)
: pyfunc(f), input_size(nin) {}
template <typename T, typename R>
void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
auto args = py::tuple(grads.size());
for (size_t i = 0; i < grads.size(); ++i) {
auto&& g = grads[i];
args[i] = g ? ctx.wrap_tensor(g) : py::none();
}
auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
if (input_grads.is_none()) return;
if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
if (input_size != 1) {
throw py::value_error("custom grad rule returned wrong number of grads");
}
receiver(0, tw->m_tensor);
return;
}
if (py::len(input_grads) != input_size) {
throw py::value_error("custom grad rule returned wrong number of grads");
}
for (auto [i, g] : views::enumerate(input_grads)) {
if (g.is_none()) continue;
auto* tw = TensorWrapper::try_cast(g.ptr());
if (!tw) {
throw py::type_error("custom grad rule returned non-tensor");
}
receiver(i, tw->m_tensor);
}
}
static constexpr bool input_has_grad(size_t) {return true;}
static constexpr bool output_requires_grad(size_t) {return true;}
static constexpr bool output_captured(size_t) {return true;}
};
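PythonBackward simply forwards to the Python callable produced by `_grad_rule`: it is invoked with one argument per forward output (the incoming gradient wrapped as a tensor, or None when that output received no gradient), and its return value must be None, a single tensor (only legal for single-input ops), or a sequence with exactly one entry per forward input, where None entries are skipped. A minimal sketch of a callable satisfying that contract; the names and the elementwise math are illustrative only, not part of this diff:

    # Hypothetical backward callable for a binary multiply-like op.
    # `dy` is the gradient of the single forward output (may be None);
    # the captured `x` and `y` are the forward inputs.
    def make_backward(x, y):
        def backward(dy):
            if dy is None:
                # no incoming gradient: returning None propagates nothing
                return None
            # one gradient per forward input, in input order
            return dy * y, dy * x
        return backward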
} // namespace
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
@@ -210,7 +270,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
// same length as inputs (of forward op)
SmallVector<GradSlotProducerPtr> dsts;
// encapsulates actual function to compute gradient
std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward> backward;
// a flag used during backward
bool in_ref_keeper = false;
@@ -268,6 +328,30 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra
return outputs;
}
apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
auto* op = ctx.op->try_cast_final<GenericPyOp>();
py::tuple pyin(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
}
auto grad_rule = py::getattr(op->obj, "_grad_rule");
auto pyret = (scoped_disable(Flags::GRAD),
py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr))); // comma expression
auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
return {tw->m_tensor};
}
apply_result_t ret;
ret.reserve(py::len(outputs));
for (auto&& i : outputs) {
auto* tw = TensorWrapper::try_cast(i.ptr());
mgb_assert(tw);
ret.push_back(tw->m_tensor);
}
return ret;
}
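python_grad_rule is the glue on the C++ side: it wraps the forward inputs as tensor objects, temporarily disables gradient recording via scoped_disable(Flags::GRAD), calls the op's `_grad_rule` method, and expects back an `(outputs, backward)` pair; `backward` is stored in a PythonBackward and `outputs` may be a single tensor or a sequence of tensors. A hedged sketch of such a `_grad_rule` on a Python-defined op; only the attribute name `_grad_rule` and the tuple return shape come from the code above, the rest is illustrative:

    class MulLike:  # would subclass PyOpBase in practice; see the sketch further below
        def _grad_rule(self, x, y):
            # gradient recording is already disabled here (scoped_disable above),
            # so this forward computation is not itself differentiated
            outputs = x * y
            def backward(dy):                # contract as sketched above
                return None if dy is None else (dy * y, dy * x)
            return outputs, backward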
} // namespace
apply_result_t apply_grad(ApplyContext& ctx) {
@@ -290,21 +374,23 @@ apply_result_t apply_grad(ApplyContext& ctx) {
// cleanup stale grad info
// under what condition?
tensor->m_grad_info = {};
tensor->m_flags &= ~Flags::GRAD;
}
} else {
tensor->m_flags &= ~Flags::GRAD;
}
}
ctx.flags &= ~Flags::GRAD;
if (!grad_key) {
return apply(ctx);
}
GradFnHelper grad_fn_holder;
auto outputs = ctx.op->same_type<GenericPyOp>() ?
python_grad_rule(ctx, grad_fn_holder) :
backward_graph_grad_rule(ctx, grad_fn_holder);
auto& grad_fn = grad_fn_holder.grad_fn;
if (!grad_fn) {
@@ -341,7 +427,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
grad_info.grad_fn = grad_fn;
grad_info.idx = i;
grad_info.insert_after(grad_key->free_vars_head);
outputs[i]->m_flags |= Flags::GRAD;
}
}
}
@@ -357,7 +443,7 @@ void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
if (nargs != 2) {
throw py::type_error("expect 2 arguments");
}
auto* tw = TensorWrapper::try_cast(args[0]);
if (!tw) {
throw py::type_error("argument 1 must be Tensor");
}
@@ -390,14 +476,15 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) {
grad_fn->key = shared_from_this();
grad_fn->slots.resize(1);
tensor->m_grad_info.insert_after(free_vars_head);
tensor->m_flags |= Flags::GRAD;
}
tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
}
template<typename T>
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
if (!grad) {
grad = std::forward<T>(delta);
return;
}
static ApplyContext ctx;
@@ -409,7 +496,7 @@ void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta)
ctx.args = args;
ctx.flags = grad->m_flags | delta->m_flags;
if (is_tracing) {
ctx.flags |= Flags::TRACE;
}
grad = apply(ctx)[0];
}
@@ -440,6 +527,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
}
}
BackwardContext bctx{pytype};
std::vector<std::shared_ptr<GradFn>> ref_keeper;
ref_keeper.reserve(tape.size());
// back-propagation in reverse order
@@ -456,7 +544,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
mgb_assert(0);
} else {
auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();});
backward(bctx, std::forward<decltype(grads)>(grads), grad_receiver);
}
}, grad_fn->backward);
......
@@ -14,6 +14,7 @@
#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include <Python.h> #include <Python.h>
...@@ -245,6 +246,35 @@ void _init_py_backward_graph(py::module m) { ...@@ -245,6 +246,35 @@ void _init_py_backward_graph(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second); mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second);
} }
struct PyOpBase : PyOpDef {
static PyTypeObject py_type;
static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
auto* obj = type->tp_alloc(type, 0);
if (obj) {
auto* self = reinterpret_cast<PyOpBase*>(obj);
new(&self->op) decltype(self->op);
}
return obj;
}
};
PyTypeObject PyOpBase::py_type;
void _init_py_op_base(py::module m) {
using py_op = PyOpBase;
auto& py_type = PyOpBase::py_type;
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
py_type.tp_basicsize = sizeof(py_op);
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "PyOpBase";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_op::tp_new;
mgb_assert(PyType_Ready(&py_type) >= 0);
m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
}
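_init_py_op_base exports PyOpBase (deriving from the OpDef wrapper type, with Py_TPFLAGS_BASETYPE set) so that operators can be declared as ordinary Python subclasses; together with the type_caster changes below, instances of such subclasses are wrapped in GenericPyOp whenever C++ expects an OpDef. A hedged sketch, assuming the module path matches the tp_name registered above and that megengine is importable:

    # Assumed import path, taken from the tp_name string above.
    from megengine.core._imperative_rt.ops import PyOpBase

    class MyOp(PyOpBase):
        """Operator defined purely in Python; no C++ registration needed."""
        # the _default_rule / _grad_rule hooks are sketched elsewhere in this commit

    op = MyOp()  # tp_new above default-constructs the underlying shared_ptr<OpDef> (left null)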
/*********** end of hand-write opdefs **************/
// auto generated opdefs
@@ -260,9 +290,16 @@ bool type_caster<OpDef>::load(handle src, bool convert) {
return false;
}
value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
if (!value) {
// opdef only defined in Python
value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
}
return true;
}
handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
return object(pyop->obj).release();
}
PyTypeObject* pytype;
auto& c2p = PyOp(OpDef)::ctype2pytype;
auto&& iter = c2p.find(op.dyn_typeinfo());
@@ -283,5 +320,6 @@ handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
void init_ops(py::module m) {
_init_py_op_def(m);
_init_py_backward_graph(m);
_init_py_op_base(m);
INIT_ALL_OP(m)
}
@@ -11,6 +11,7 @@
#include "megbrain/dtype.h"
#include "megbrain/common.h"
#include "megbrain/imperative/ops/utility.h"
#include "./tensor.h" #include "./tensor.h"
#include "./grad.h" #include "./grad.h"
...@@ -22,10 +23,12 @@ ...@@ -22,10 +23,12 @@
#include <pybind11/numpy.h> #include <pybind11/numpy.h>
#include <pybind11/operators.h> #include <pybind11/operators.h>
#include <range/v3/all.hpp>
#include <unordered_map>
namespace py = pybind11;
namespace views = ranges::views;
namespace mgb::imperative::python {
@@ -69,21 +72,45 @@ SET_UNSET_PROP(compiled)
bool skip_tracing = false;
Tensor::flags_t ApplyContext::global_disable = 0;
apply_result_t apply(ApplyContext& ctx) {
// emulating scalar should be put to specific op's apply, e.g.,
// elementwise, reduce, typecvt. Currently it's still handled at python
// side. It could be moved to C++ side if it has an impact on performance
auto flags = ctx.flags & ~ApplyContext::global_disable;
if (flags & Tensor::Flags::SCALAR) {
// TODO: emulate scalar
}
if (flags & Tensor::Flags::GRAD) {
return apply_grad(ctx);
}
if (flags & Tensor::Flags::TRACE) {
return apply_trace(ctx);
} else {
if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
py::tuple pyin(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
}
auto f = py::getattr(op->obj, "_default_rule");
auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
return {tw->m_tensor};
}
apply_result_t ret;
ret.reserve(py::len(pyout));
for (auto&& i : pyout) {
auto* tw = TensorWrapper::try_cast(i.ptr());
mgb_assert(tw);
ret.push_back(tw->m_tensor);
}
return ret;
}
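When neither grad nor trace handling applies, apply() short-circuits a GenericPyOp by calling the op's `_default_rule` with the input tensors and accepting either a single tensor or a sequence of tensors as the result, instead of dispatching to the interpreter. A hedged sketch of a matching forward hook; the attribute name `_default_rule` and the accepted return shapes come from the branch above, everything else is illustrative:

    class MulLike(PyOpBase):          # PyOpBase import as sketched earlier
        def _default_rule(self, x, y):
            # plain eager computation on the wrapped input tensors;
            # a single tensor or a tuple of tensors are both accepted returns
            return x * y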
SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; ++i) {
handles[i] = ctx.args[i]->m_handle.get();
@@ -125,12 +152,13 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
SmallVector<Tensor*, 64> tensors(nargs);
ctx.args = &tensors[0];
ctx.nargs = nargs;
ctx.pytype = pytype;
if (strstr(op->ob_type->tp_name, "BackwardGraph")) {
ctx.backward = true;
}
for (size_t i = 0; i < nargs; ++i) {
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
auto* t = tensors[i] = tw->m_tensor.get();
ctx.flags |= t->m_flags;
} else {
@@ -166,7 +194,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
if (nargs == 0) {
throw py::type_error("too few arguments");
}
if (auto* t = try_cast(tup[0].ptr())) {
if (nargs > 1) {
throw py::type_error("expect 1 argument");
}
@@ -211,7 +239,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
auto ret = pyf(*tup);
auto py_ret = py::reinterpret_borrow<py::list>(ret);
if (auto* t = try_cast(py_ret[0].ptr())) {
m_tensor = t->m_tensor;
}
return;
@@ -349,7 +377,7 @@ PyObject* TensorWrapper::varnode() {
}
void TensorWrapper::reset(PyObject* tensor) {
TensorWrapper* t = TensorWrapper::try_cast(tensor);
if (!t) {
throw py::type_error("expect Tensor");
}
@@ -446,7 +474,7 @@ uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
}
}
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
// Return value: New reference
@@ -507,7 +535,7 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
if (handle == Py_None) continue;
TensorWrapper* tw = TensorWrapper::try_cast(handle);
if (tw) {
mgb::DType type = tw->m_tensor->dtype();
auto&& descr = npy::dtype_mgb2np_descr(type);
@@ -562,7 +590,7 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
CompNode cn;
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
TensorWrapper* tw = TensorWrapper::try_cast(handle);
if (tw) {
if (!valid) {
cn = tw->m_tensor->comp_node();
......
@@ -124,7 +124,7 @@ struct TensorWrapper {
friend wrap_t;
inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();}
inline static TensorWrapper* try_cast(PyObject* op) {
if (!wrap_t::type().isinstance(op)) return nullptr;
return cast(op);
}
@@ -173,11 +173,26 @@ struct TensorWrapper {
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);
struct ApplyContext {
static Tensor::flags_t global_disable;
Tensor::flags_t flags;
std::shared_ptr<OpDef> op;
Tensor*const* args;
size_t nargs;
PyTypeObject* pytype = nullptr;
bool backward = false;
class scoped_disable : NonCopyableObj {
Tensor::flags_t saved_flags;
public:
scoped_disable(Tensor::flags_t flags) : saved_flags(ApplyContext::global_disable) {
ApplyContext::global_disable |= flags;
}
~scoped_disable() {
ApplyContext::global_disable = saved_flags;
}
};
};
using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;
......
@@ -85,7 +85,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
for (auto i = 0; i < tup.size(); i++) {
auto tw = TensorWrapper::try_cast(tup[i].ptr());
outputs.emplace_back(tw->m_tensor);
}
return outputs;
......
/**
* \file imperative/src/impl/ops/utility.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/utility.h"
#include "../op_trait.h"
namespace mgb::imperative {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/ops/utility.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"
#include <pybind11/pybind11.h>
namespace mgb::imperative {
struct GenericPyOp final : OpDefImplBase<GenericPyOp> {
pybind11::object obj;
GenericPyOp(pybind11::object obj_) : obj(std::move(obj_)) {};
size_t hash() const override {
return pybind11::hash(obj);
}
bool is_same_st(const Hashable& rhs) const override {
return obj.equal(static_cast<const GenericPyOp&>(rhs).obj);
}
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};
} // namespace mgb::imperative
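GenericPyOp delegates hashing and equality to the wrapped Python object (pybind11::hash and obj.equal), so OpDef identity for a Python-defined op is whatever the op class's __hash__ and __eq__ say it is. Where treating parameter-equal instances as the same op matters, the class can opt in by defining those methods; a hedged sketch with illustrative names:

    class AxisOp(PyOpBase):            # PyOpBase import as sketched earlier
        def __init__(self, axis):
            super().__init__()
            self.axis = axis
        def __hash__(self):
            return hash((type(self), self.axis))
        def __eq__(self, other):
            return type(other) is type(self) and other.axis == self.axis

    # Two instances with the same axis now hash/compare equal, so the
    # corresponding GenericPyOp wrappers compare equal as OpDefs.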