未验证 提交 947e1373 编写于 作者: HansBug's avatar HansBug 😆 提交者: GitHub

Merge pull request #83 from opendilab/fix/treelize

fix(hansbug): fix bug of #82, add more unittests
......@@ -3,6 +3,7 @@ from operator import __mul__
import pytest
from treevalue import FastTreeValue
from treevalue.tree import func_treelize, TreeValue, method_treelize, classmethod_treelize, delayed
......@@ -401,3 +402,29 @@ class TestTreeFuncFunc:
'v': {'a': 12, 'b': 25, 'x': {'c': 38, 'd': 51}},
})
assert cnt_1 == 4
def test_return_treevalue(self):
def func(x):
return FastTreeValue({
'x': x, 'y': x ** 2,
})
f = FastTreeValue({
'x': func,
'y': {
'z': func,
}
})
v = FastTreeValue({'x': 2, 'y': {'z': 34}})
assert f(v) == FastTreeValue({
'x': {
'x': v.x,
'y': v.x ** 2,
},
'y': {
'z': {
'x': v.y.z,
'y': v.y.z ** 2,
}
}
})
......@@ -57,6 +57,7 @@ cdef object _c_func_treelize_run(object func, list args, dict kwargs, _e_tree_mo
cdef list _a_args
cdef dict _a_kwargs
cdef object _a_ret
if not has_tree:
_a_args = []
for v in args:
......@@ -72,7 +73,11 @@ cdef object _c_func_treelize_run(object func, list args, dict kwargs, _e_tree_mo
else:
_a_kwargs[k] = missing_func()
return func(*_a_args, **_a_kwargs)
_a_ret = func(*_a_args, **_a_kwargs)
if isinstance(_a_ret, TreeValue):
return _a_ret._detach()
else:
return _a_ret
cdef dict _d_res = {}
cdef str ak
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册