提交 98a74e4a 编写于 作者: M Megvii Engine Team

refactor(dnn): refactor opr proxy in test

GitOrigin-RevId: a1d8682e6f6957a212cc8793c2f0cc9b58f2192b
上级 57546b4c
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
......@@ -20,36 +21,126 @@ namespace test {
template <typename Opr, size_t Arity>
struct AlgoProxy;
template <typename Opr>
struct AlgoProxy<Opr, 3> {
static std::vector<typename Opr::AlgorithmInfo> get_all_algorithms_info(
Opr* opr, TensorLayoutArray& layouts) {
megdnn_assert(layouts.size() == 3);
return opr->get_all_algorithms_info(layouts[0], layouts[1], layouts[2]);
}
static typename Opr::AlgorithmInfo get_algorithm_info_heuristic(
Opr* opr, TensorLayoutArray& layouts) {
megdnn_assert(layouts.size() == 3);
return opr->get_algorithm_info_heuristic(layouts[0], layouts[1],
layouts[2]);
#define DEF_ALGO_PROXY(arity) \
template <typename Opr> \
struct AlgoProxy<Opr, arity> { \
static std::vector<typename Opr::AlgorithmInfo> \
get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \
megdnn_assert(layouts.size() == arity); \
return opr->get_all_algorithms_info(LAYOUTS); \
} \
static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \
Opr* opr, const TensorLayoutArray& layouts) { \
megdnn_assert(layouts.size() == arity); \
return opr->get_algorithm_info_heuristic(LAYOUTS); \
} \
static size_t get_workspace_in_bytes( \
Opr* opr, const TensorLayoutArray& layouts) { \
megdnn_assert(layouts.size() == arity); \
return opr->get_workspace_in_bytes(LAYOUTS); \
} \
static void exec(Opr* opr, const TensorNDArray& tensors, \
Workspace workspace) { \
megdnn_assert(tensors.size() == arity); \
return opr->exec(TENSORS, workspace); \
} \
}
};
template <typename Opr>
struct AlgoProxy<Opr, 5> {
static std::vector<typename Opr::AlgorithmInfo> get_all_algorithms_info(
Opr* opr, TensorLayoutArray& layouts) {
megdnn_assert(layouts.size() == 5);
return opr->get_all_algorithms_info(layouts[0], layouts[1], layouts[2],
layouts[3], layouts[4]);
}
static typename Opr::AlgorithmInfo get_algorithm_info_heuristic(
Opr* opr, TensorLayoutArray& layouts) {
megdnn_assert(layouts.size() == 5);
return opr->get_algorithm_info_heuristic(
layouts[0], layouts[1], layouts[2], layouts[3], layouts[4]);
}
};
#define LAYOUTS layouts[0], layouts[1], layouts[2]
#define TENSORS tensors[0], tensors[1], tensors[2]
DEF_ALGO_PROXY(3);
#undef LAYOUTS
#undef TENSORS
#define LAYOUTS layouts[0], layouts[1], layouts[2], layouts[3], layouts[4]
#define TENSORS tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]
DEF_ALGO_PROXY(5);
#undef LAYOUTS
#undef TENSORS
#define LAYOUTS \
layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5], \
layouts[6], layouts[7]
#define TENSORS \
tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5], \
tensors[6], tensors[7]
DEF_ALGO_PROXY(8);
#undef LAYOUTS
#undef TENSORS
#undef DEF_ALGO_PROXY
#define DEF_ALGO_PROXY(Opr, arity) \
template <> \
struct AlgoProxy<Opr, arity> { \
static std::vector<typename Opr::AlgorithmInfo> \
get_all_algorithms_info(Opr* opr, const TensorLayoutArray& layouts) { \
megdnn_assert(layouts.size() == arity); \
return opr->get_all_algorithms_info(LAYOUTS); \
} \
static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \
Opr* opr, const TensorLayoutArray& layouts) { \
megdnn_assert(layouts.size() == arity); \
return opr->get_algorithm_info_heuristic(LAYOUTS); \
} \
static size_t get_workspace_in_bytes( \
Opr* opr, const TensorLayoutArray& layouts, \
const typename Opr::PreprocessedFilter* preprocessed_filter = \
nullptr) { \
megdnn_assert(layouts.size() == arity); \
return opr->get_workspace_in_bytes(LAYOUTS, preprocessed_filter); \
} \
static void exec( \
Opr* opr, const TensorNDArray& tensors, \
const typename Opr::PreprocessedFilter* preprocessed_filter, \
Workspace workspace) { \
megdnn_assert(tensors.size() == arity); \
return opr->exec(TENSORS, preprocessed_filter, workspace); \
} \
static void exec(Opr* opr, const TensorNDArray& tensors, \
Workspace workspace) { \
megdnn_assert(tensors.size() == arity); \
return opr->exec(TENSORS, nullptr, workspace); \
} \
static size_t get_preprocess_workspace_in_bytes( \
Opr* opr, const TensorLayoutArray& layouts) { \
megdnn_assert(layouts.size() == arity); \
return opr->get_preprocess_workspace_in_bytes(LAYOUTS); \
} \
static SmallVector<TensorLayout> deduce_preprocessed_filter_layout( \
Opr* opr, const TensorLayoutArray& layouts) { \
megdnn_assert(layouts.size() == arity); \
return opr->deduce_preprocessed_filter_layout(LAYOUTS); \
} \
static void exec_preprocess( \
Opr* opr, const TensorNDArray& tensors, \
const TensorLayoutArray& layouts, \
Opr::PreprocessedFilter* preprocessed_filter, \
_megdnn_workspace workspace) { \
megdnn_assert(layouts.size() == arity && tensors.size() == arity); \
return opr->exec_preprocess(PREPROCESS_ARGS, preprocessed_filter, \
workspace); \
} \
};
#define LAYOUTS layouts[0], layouts[1], layouts[2]
#define TENSORS tensors[0], tensors[1], tensors[2]
#define PREPROCESS_ARGS layouts[0], tensors[1], layouts[2]
DEF_ALGO_PROXY(ConvolutionForward, 3);
#undef PREPROCESS_ARGS
#undef LAYOUTS
#undef TENSORS
#define LAYOUTS layouts[0], layouts[1], layouts[2], layouts[3], layouts[4]
#define TENSORS tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]
#define PREPROCESS_ARGS \
layouts[0], tensors[1], tensors[2], layouts[3], layouts[4]
DEF_ALGO_PROXY(ConvBias, 5);
#undef PREPROCESS_ARGS
#undef LAYOUTS
#undef TENSORS
#undef DEF_ALGO_PROXY
template <typename Opr, size_t arity = OprTrait<Opr>::arity>
struct OprAlgoProxyDefaultImpl : public AlgoProxy<Opr, arity> {};
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册