提交 a49f4a66 编写于 作者: M Megvii Engine Team

feat(dnn): add indexing_one_hot and indexing_set_one_hot opr

GitOrigin-RevId: c5406c71ffa91864c1fb0828278e51ef3af45c97
上级 1310ad49
/** /**
* \file dnn/test/common/opr_trait.h * \file dnn/src/common/opr_trait.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include <cstddef> #include <cstddef>
namespace megdnn { namespace megdnn {
namespace test {
template <typename Opr> template <typename Opr>
struct OprTrait {}; struct OprTrait {};
...@@ -114,7 +113,10 @@ DEF(FakeQuantForward, 4, true, true); ...@@ -114,7 +113,10 @@ DEF(FakeQuantForward, 4, true, true);
DEF(FakeQuantBackward, 5, true, false); DEF(FakeQuantBackward, 5, true, false);
DEF(TQTForward, 3, true, true); DEF(TQTForward, 3, true, true);
DEF(TQTBackward, 5, true, false); DEF(TQTBackward, 5, true, false);
} // namespace test DEF(PowC, 2, false, true);
DEF(UniformRNG, 1, true, true);
DEF(GaussianRNG, 1, true, true);
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#pragma once #pragma once
#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include "test/common/opr_trait.h" #include "src/common/opr_trait.h"
#include "test/common/utils.h" #include "test/common/utils.h"
namespace megdnn { namespace megdnn {
......
...@@ -11,12 +11,13 @@ ...@@ -11,12 +11,13 @@
*/ */
#pragma once #pragma once
#include "src/common/opr_trait.h"
#include "test/common/deduce_layout_proxy.h" #include "test/common/deduce_layout_proxy.h"
#include "test/common/exec_proxy.h" #include "test/common/exec_proxy.h"
#include "test/common/fast_run_cache.h" #include "test/common/fast_run_cache.h"
#include "test/common/inspect_type.h" #include "test/common/inspect_type.h"
#include "test/common/opr_algo_proxy.h" #include "test/common/opr_algo_proxy.h"
#include "test/common/opr_trait.h"
#include "test/common/timer.h" #include "test/common/timer.h"
#include "test/common/workspace_wrapper.h" #include "test/common/workspace_wrapper.h"
......
...@@ -12,13 +12,12 @@ ...@@ -12,13 +12,12 @@
#include "megdnn/handle.h" #include "megdnn/handle.h"
#include "megdnn/oprs/general.h" #include "megdnn/oprs/general.h"
#include "test/common/opr_proxy.h"
#include "src/common/opr_trait.h"
namespace megdnn { namespace megdnn {
namespace test { namespace test {
DEF(PowC, 2, false, true);
void run_powc_test(Handle* handle, DType dtype); void run_powc_test(Handle* handle, DType dtype);
} // namespace test } // namespace test
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册