提交 72531f2b 编写于 作者: M Megvii Engine Team 提交者: huangxinda

test(autograd): add more tests for higher order grad

GitOrigin-RevId: 5fc308f87a6c4cb2de9b8edb654fe7a2416333ec
上级 522e556b
......@@ -345,12 +345,12 @@ class GradManager:
def __exit__(self, exc_type, exc_val, exc_tb):
self.release()
def __or__(self, other):
    """Combine this manager with another into a :class:`GradManagerGroup`.

    Enables ``with gm | gm2:`` to record under both managers at once.
    Returns ``NotImplemented`` for non-``GradManager`` operands so Python
    can fall back to the other operand's reflected implementation.
    """
    if isinstance(other, GradManager):
        return GradManagerGroup([self, other])
    return NotImplemented

# `|` is symmetric for this use, so the reflected operator reuses it.
__ror__ = __or__
class GradManagerGroup:
......@@ -364,8 +364,6 @@ class GradManagerGroup:
return NotImplemented
return GradManagerGroup([*self._gms, *other._gms])
__and__ = merge_with
__rand__ = merge_with
__or__ = merge_with
__ror__ = merge_with
......
......@@ -468,7 +468,7 @@ PyObject* GradKeyWrapper::get_priority() {
}
//! Set the backward priority of the wrapped GradKey from a Python int.
//! (Counterpart of get_priority; exposed as the "priority" attribute.)
void GradKeyWrapper::set_priority(pybind11::handle priority) {
    m_key->priority = py::cast<int>(priority);
}
void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
......@@ -535,7 +535,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
size_t priority_backup;
CleanupGuard(GradKey* this_) : owner(this_) {
priority_backup = sm_min_priority;
sm_min_priority = owner->priority;
sm_min_priority = owner->priority + 1;
}
~CleanupGuard() {
owner->cleanup();
......@@ -636,7 +636,7 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
Py_RETURN_FALSE;
}
// Lowest-possible sentinel: with no enclosing backward in progress, every
// real key priority is above this bound (CleanupGuard raises it to
// owner->priority + 1 during backward and restores it afterwards).
int GradKey::sm_min_priority = std::numeric_limits<int>::min();
GradKey::~GradKey() {
cleanup();
......
......@@ -966,6 +966,7 @@ void init_tensor(py::module m) {
.def<&GradKeyWrapper::attach>("attach")
.def<&GradKeyWrapper::is_attached_to>("is_attached_to")
.def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name")
.def_getset<&GradKeyWrapper::get_priority, &GradKeyWrapper::set_priority>("priority")
.finalize();
if (!grad_key_type) throw py::error_already_set();
py::setattr(m, "GradKey", grad_key_type);
......
......@@ -279,3 +279,69 @@ def test_broadcast_grad(trace_mode):
func()
worker()
def test_2nd_grad_with_manager():
    """Second-order gradient of cos(x) using two nested GradManagers."""
    data = np.random.rand(10).astype("float32")
    x = mge.tensor(data)
    outer = GradManager().attach([x])
    inner = GradManager().attach([x])
    with outer:
        with inner:
            y = F.cos(x)
            # First-order: d(cos x)/dx == -sin x.
            inner.backward(y)
            np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(data), decimal=5)
        # Second-order pass accumulates d(-sin x)/dx == -cos x into x.grad.
        outer.backward(x.grad)
        np.testing.assert_almost_equal(
            x.grad.numpy(), -np.sin(data) - np.cos(data), decimal=5
        )
def test_grad_manager_group():
    """Recording under `gm | gm2` lets both managers backward the same op."""
    data = np.random.rand(10).astype("float32")
    x = mge.tensor(data)
    first = GradManager().attach([x])
    second = GradManager().attach([x])
    with first | second:
        y = F.cos(x)
        # Each backward contributes -sin(x); the results accumulate.
        first.backward(y)
        second.backward(y)
    np.testing.assert_almost_equal(x.grad.numpy(), -2 * np.sin(data), decimal=5)
    x.grad = None
def test_grad_manager_group_visibility():
    """One group member's backward must not expose extra grads to the other."""
    data = np.random.rand(10).astype("float32")
    x = mge.tensor(data)
    first = GradManager().attach([x])
    second = GradManager().attach([x])
    with first | second:
        y = F.cos(x)
        second.backward(y)
        np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(data), decimal=5)
        # `first` saw nothing differentiable in x.grad, so x.grad is unchanged.
        first.backward(x.grad)
        np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(data), decimal=5)
def test_grad_manager_visibility_by_order():
    """Entering order decides visibility: the outer manager's backward runs
    first, so the inner manager cannot differentiate through its result."""
    data = np.random.rand(10).astype("float32")
    x = mge.tensor(data)
    inner = GradManager().attach([x])
    outer = GradManager().attach([x])
    with outer:
        with inner:
            y = F.cos(x)
            outer.backward(y)
            np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(data), decimal=5)
            # `inner` records nothing differentiable in x.grad here, so the
            # accumulated gradient stays at -sin(x).
            inner.backward(x.grad)
            np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(data), decimal=5)
......@@ -126,7 +126,7 @@ def test_2nd_grad():
x.grad = None
grad2(z, ones)
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np))
np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np), decimal=5)
def test_grad_with_tensor_wrapper():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册