Commit ae47fd4e authored by: Megvii Engine Team

fix(mge): fix none return value for attrs, add test_correctness

GitOrigin-RevId: 1bb96373f450c74ccf9e640cd4b2c73579f3c398
Parent 97d12b3e
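
The commit title also mentions a new test_correctness; the snippet below is only a rough sketch of the kind of check such a test typically performs (it is not the test added by this commit), assuming the public megengine.tensor and megengine.jit.trace APIs: outputs of the compiled (replayed) runs must match the first, untraced run.

    import numpy as np
    import megengine
    from megengine.jit import trace

    @trace(symbolic=True)
    def double(x):
        # any small traced computation works for this illustration
        return x * 2

    inp = np.arange(6, dtype="float32")
    for _ in range(3):  # first call records the trace, later calls replay it
        out = double(megengine.tensor(inp))
        np.testing.assert_allclose(out.numpy(), inp * 2)
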
@@ -20,10 +20,10 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler, common, put
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import TensorWeakRef
from ..core._imperative_rt.core2 import __make_empty_tensor as make_empty_tensor
from ..core._imperative_rt.core2 import (
TensorWeakRef,
apply,
call_level,
set_compiled,
set_symbolic,
set_tracing,
@@ -86,6 +86,9 @@ class TensorInfo:
__slots__ = (
# collected attributes
"external",
"data_read",
"shape_read",
"value_read",
"exported",
"device",
"dtype",
@@ -102,6 +105,9 @@ class TensorInfo:
def __init__(self):
self.exported = None
self.data_read = None
self.shape_read = None
self.value_read = None
self.bound_data = None
self.data_setter = None
@@ -154,7 +160,7 @@ class trace:
self._graph_opt_level = opt_level
self._symbolic_shape = symbolic_shape
self._handle2tensors = {}
self._handle2compiledtensors = {}
self._output_handles = set()
self._reset()
@@ -244,11 +250,12 @@ class trace:
# )
self._pc += 1
outputs = []
for h in ohandles:
t = CompiledTensorProxy(h)
t._dev_tensor()
self._handle2compiledtensors[h] = t
outputs = [self._handle2tensors[h] for h in ohandles]
outputs += [t._CompiledTensorProxy__tensor]
self._output_handles.update(ohandles)
self._active_tensors.update([TensorWeakRef(o) for o in outputs])
return outputs
@@ -347,11 +354,12 @@ class trace:
self._lazy_eval_links = ()
def _take_escaped_tensors(self):
escaped_tensors = tuple(self._active_tensors)
escaped_tensors = tuple(filter(lambda x: x() is not None, self._active_tensors))
self._active_tensors.clear()
return escaped_tensors
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
lazy_eval_tensors = list(filter(lambda x: x() is not None, lazy_eval_tensors))
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
self._apply_graph_options(lazy_eval_graph)
# FIXME
@@ -393,6 +401,12 @@ class trace:
if self._inputs_to_restore:
for x in self._inputs_to_restore:
x.mixin_handle = -1
for h, x in list(self._handle2tensors.items()):
info = self._tinfo[h]
info.data_read = x.data_read
info.shape_read = x.shape_read
info.value_read = x.value_read
del self._handle2tensors[h]
if self._symbolic and (
self._lazy_eval_tensors or self._lazy_eval_links
):
@@ -433,8 +447,9 @@ class trace:
raise TraceMismatchError("premature end")
if not self._symbolic or not self._untraced:
for x in self._active_tensors:
x()._dev_tensor()
x().mixin_handle = -1
if x() is not None:
x()._dev_tensor()
x().mixin_handle = -1
try:
do_enter()
@@ -581,8 +596,7 @@ class trace:
readers.append(opnode.outputs[0])
in_out_links = opnode.outputs
x = self._handle2tensors[h]
if x.data_read:
if info.data_read:
# Shape can be obtained from data so doesn't need its own
# output node. On the other hand, value is read separately
# to leverage eager h2d copy
@@ -890,7 +904,7 @@ class trace:
self._output_bindings.append(h)
else:
h = x.mixin_handle
if h not in self._handle2compiledtensors:
if h not in self._output_handles:
raise RuntimeError("output is not computed from inputs")
if h != self._output_bindings[i]:
raise TraceMismatchError(
@@ -927,8 +941,7 @@ class CompiledTensorProxy:
self.__shape = None
self.__data = None
self.__value = None
self.__tensor = active_trace._handle2tensors[handle]
self.__tensor.mixin_handle = handle
self.__tensor = make_empty_tensor()
@property
def dtype(self):
@@ -943,19 +956,19 @@ class CompiledTensorProxy:
if self._isscalar:
return ()
if self.__shape is None:
if self.__tensor.shape_read:
if self.__info.shape_read:
self.__shape = self.__info.shape_reader.get_value().shape
elif self.__tensor.data_read:
self.__shape = self.__tensor._dev_tensor().shape
elif self.__info.data_read:
self.__shape = self.__info._dev_tensor().shape
else:
raise TraceMismatchError("shape of this tensor is not read in trace")
return self.__shape
def numpy(self):
if self.__value is None:
if self.__tensor.value_read:
if self.__info.value_read:
self.__value = self.__info.value_reader.get_value()
elif self.__tensor.data_read:
elif self.__info.data_read:
self.__value = self._dev_tensor().numpy()
else:
raise TraceMismatchError("value of this tensor is not read in trace")
@@ -965,7 +978,7 @@ class CompiledTensorProxy:
def _dev_tensor(self):
if self.__data is None:
if not self.__tensor.data_read:
if not self.__info.data_read:
raise TraceMismatchError("raw data of this tensor is not read in trace")
self.__data = self.__info.data_reader.get_value()
self.__tensor._reset(RawTensor(self.__data))
@@ -53,9 +53,6 @@ bool is_tracing = false;
bool is_symbolic = false;
bool is_compiled = false;
int64_t call_level = 0;
#define SET_UNSET_PROP(mode) \
void set_##mode() { \
is_##mode = true; \
@@ -321,17 +318,22 @@ PyObject* TensorWrapper::numpy() {
auto&& type = mgr.get_infer_type(m_tensor->m_var);
using InferType = cg::static_infer::InferType;
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr;
}
auto* val = mgr.infer_value_fallible(m_tensor->m_var);
if (!val) {
PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr;
}
return py::cast(*val).attr("numpy")().release().ptr();
}
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
if (!arr) return nullptr;
if (!arr) {
PyErr_SetString(PyExc_ValueError, "tensor invalid");
return nullptr;
}
if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
mgb_assert(PyArray_Check(arr.ptr()));
return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
@@ -343,7 +345,7 @@ PyObject* TensorWrapper::varnode() {
if (m_tensor->m_var) {
return py::cast(m_tensor->m_var).release().ptr();
}
return nullptr;
return py::none().release().ptr();
}
void TensorWrapper::reset(PyObject* tensor) {
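
For context on the hunk above: a CPython attribute getter that returns nullptr without setting an exception leaves the interpreter with an error indication but no exception object, which typically surfaces as "SystemError: error return without exception set". Returning Py_None instead gives the attribute a proper Python value. A hypothetical sketch of the Python-visible effect, assuming the _varnode attribute is reachable on an ordinary eager tensor that is not bound to any graph VarNode:

    import megengine

    x = megengine.tensor([1.0, 2.0])  # eager tensor, no VarNode attached
    # After this change the attribute read simply yields None instead of
    # failing with an unset-exception error return.
    assert x._varnode is None
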
@@ -364,6 +366,7 @@ PyObject* TensorWrapper::detach() {
} else {
new_tensor = std::make_shared<Tensor>(m_tensor->m_var);
}
new_tensor->m_trace_info = m_tensor->m_trace_info;
auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
return ret.release().ptr();
@@ -628,6 +631,10 @@ WRAP_FUNC_PY35(get_device);
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
py::object make_empty_tensorwrapper() {
return TensorWrapper::make(std::move(std::make_shared<Tensor>()));
}
void init_tensor(py::module m) {
interpreter_for_py = interpreter::Interpreter::inst().create_channel();
@@ -699,7 +706,6 @@ void init_tensor(py::module m) {
m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);
m.attr("skip_tracing") = &skip_tracing;
m.attr("call_level") = &call_level;
py::class_<SharedHandle>(m, "SharedHandle")
.def(py::init<const SharedHandle&>());
@@ -711,6 +717,7 @@ void init_tensor(py::module m) {
m.def("set_compiled", &set_compiled);
m.def("unset_compiled", &unset_compiled);
m.def("__make_empty_tensor", &make_empty_tensorwrapper);
}
#undef MGE_PY_INTERFACE
@@ -74,6 +74,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
using Handle = interpreter::Interpreter::Handle;
inline Tensor() : m_handle(nullptr), m_var(nullptr) {}
inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {}
inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {}
inline explicit Tensor(cg::VarNode *var) : m_handle(nullptr), m_var(var) {}
@@ -188,7 +189,6 @@ void init_tensor(pybind11::module);
extern bool is_tracing;
extern bool is_symbolic;
extern bool is_compiled;
extern int64_t call_level;
extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
extern pybind11::object cpp_apply_backward_varnode;