提交 75129cf0 编写于 作者: M Megvii Engine Team

chore(mge): clean up before merge to dev

* remove dead test
* clean some codes
* fix test_fake_quant due to change of implementation

GitOrigin-RevId: f030a9966d1664cd4a75cc9a1a992174610b70c0
上级 aba0acc7
......@@ -16,7 +16,6 @@ from typing import Dict, List, Union
import numpy as np
from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
......@@ -26,6 +25,19 @@ from ..ops.builtin import OpDef
from .core import OpBase, TensorBase
def set_priority_to_id(dest_vars):
"""
For all oprs in the subgraph constructed by dest_vars,
sets its priority to id if its original priority is zero.
:param dest_vars: target vars representing the graph.
"""
dest_vec = []
for i in dest_vars:
assert isinstance(i, _imperative_rt.VarNode)
dest_vec.append(i)
_imperative_rt.graph._set_priority_to_id(dest_vec)
class Graph(_imperative_rt.ComputingGraph):
def __init__(self):
super().__init__()
......@@ -46,8 +58,8 @@ class Graph(_imperative_rt.ComputingGraph):
cache[obj] = wrapper(obj)
return cache[obj]
def set_priority_to_id(self, dest_vars):
_set_priority_to_id(_unwrap(dest_vars))
def _set_priority_to_id(self, dest_vars):
set_priority_to_id(_unwrap(dest_vars))
def compile(self, *args):
self._function = super().compile(_unwrap(args))
......
......@@ -9,7 +9,7 @@
import functools
import multiprocessing as mp
from ..core._imperative_rt import sync
from ..core._imperative_rt.core2 import sync
from .group import group_barrier, init_process_group
from .helper import get_device_count_by_fork
from .server import Server
......
......@@ -367,7 +367,7 @@ class trace:
lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
else:
lazy_eval_graph.options.graph_opt_level = 2
lazy_eval_graph.set_priority_to_id([*lazy_eval_links, *readers])
lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
lazy_eval_graph.compile(*lazy_eval_links, *readers)
lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors):
......@@ -618,7 +618,7 @@ class trace:
graph.options.graph_opt_level = self._graph_opt_level
else:
graph.options.graph_opt_level = 2
graph.set_priority_to_id([*readers, *in_out_links, *io_links])
graph._set_priority_to_id([*readers, *in_out_links, *io_links])
graph.compile(*readers, *in_out_links, *io_links)
def _reset_exec_env(self):
......
......@@ -13,6 +13,7 @@ import numpy
from ..core import _imperative_rt
from ..core._imperative_rt import OperatorNode, VarNode
from ..core.tensor import megbrain_graph as G
from ..core.tensor.megbrain_graph import set_priority_to_id
from ..tensor import Tensor
__all__ = [
......@@ -271,19 +272,6 @@ def replace_oprs(
return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
def set_priority_to_id(dest_vars):
"""
For all oprs in the subgraph constructed by dest_vars,
sets its priority to id if its original priority is zero.
:param dest_vars: target vars representing the graph.
"""
dest_vec = []
for i in dest_vars:
assert isinstance(i, VarNode)
dest_vec.append(i)
_imperative_rt.graph._set_priority_to_id(dest_vec)
def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]:
"""
Loads a serialized computing graph and run inference with input data.
......
......@@ -32,7 +32,7 @@ struct GradSlotWeakPtr {
size_t idx;
};
struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
struct BackwardGraphCache : std::unordered_map<uint64_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
std::shared_ptr<void> on_comp_node_finalize() override {
clear();
return {};
......@@ -56,7 +56,7 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
}
mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
size_t key = XXHash{}.update(buf, buf_size).digest();
uint64_t key = XXHash{}.update(buf, buf_size).digest();
auto&& iter = backward_graph_cache.find(key);
if (iter != backward_graph_cache.end()) {
......
......@@ -32,7 +32,7 @@ inline bool _is_quantize(PyArray_Descr* dtype) {
PyObject* _get_mgb_dtype(PyArray_Descr* dtype) {
// Return value: New reference.
if (!_is_quantize(dtype)) {
throw py::type_error("expact quantize dtype");
throw py::type_error("expect quantize dtype");
}
PyObject* ob = PyDict_GetItemString(dtype->metadata, "mgb_dtype");
if (!PyDict_CheckExact(ob)) {
......
......@@ -143,8 +143,10 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
// PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
// return nullptr;
// }
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "expect Op");
if (nargs < 2) {
PyErr_SetString(PyExc_TypeError,
"py_apply expects one Op and at least one tensor "
"as argument");
return nullptr;
}
......@@ -227,7 +229,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
}
}
} else {
py::detail::loader_life_support life_sup; // required to cast DType
py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType
auto data = tup[0].cast<py::array>();
DType dtype = tup[1].cast<DType>();
CompNode cn = tup[2].cast<CompNode>();
......@@ -298,7 +300,6 @@ PyObject* TensorWrapper::handle() {
void TensorWrapper::set_handle(PyObject* dest) {
auto py_dest = py::reinterpret_borrow<py::object>(dest);
SharedHandle real_dest = py_dest.cast<SharedHandle>();
auto&& t = std::move(m_tensor->m_handle);
m_tensor->m_handle = std::move(real_dest);
}
......@@ -617,7 +618,7 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
}
}
if (!valid) {
mgb_assert(0, "expact at least 1 device");
mgb_assert(0, "expect at least 1 device");
}
Py_DECREF(tuple);
return cn;
......
......@@ -88,6 +88,7 @@ def test_dist_grad():
worker()
def test_grad():
x_np = np.random.rand(10).astype("float32")
x = as_tensor(x_np)
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pytest
# from megengine.core.interpreter.hints import function
@pytest.mark.skip(reason="under rewrite")
def test_1():
@function
def f(x, p):
x = x + 1
if p:
return x * x
return x * 2
x = Tensor(0)
for _ in range(5):
assert f(x, 0).numpy() == 2
assert f(x, 1).numpy() == 1
......@@ -83,7 +83,7 @@ def test_TQT():
def _save_to(self, name="grad"):
def callback(tensor, grad):
def callback(grad):
setattr(self, name, grad)
return callback
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册