From c3f8cf04fa3ca3e24f5650c4915e4e0e8f69e0fa Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 27 Apr 2021 16:55:26 +0800 Subject: [PATCH] feat(dnn): add conv_bwd_data and conv_bwd_filter accuracy shake check GitOrigin-RevId: 4069e083d2218b8a5ce2ea77c3c7d5f81acc6149 --- dnn/src/cuda/convolution/backward_data/algo.h | 8 ++++- dnn/test/common/accuracy_shake_checker.h | 3 +- dnn/test/cuda/accuracy_shake.cpp | 35 +++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/dnn/src/cuda/convolution/backward_data/algo.h b/dnn/src/cuda/convolution/backward_data/algo.h index a4286c6db..d53b3c617 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.h +++ b/dnn/src/cuda/convolution/backward_data/algo.h @@ -236,7 +236,13 @@ public: TensorLayout& grad_pg); MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) AlgoAttribute attribute() const override { - auto ret = static_cast(0); + auto ret = AlgoAttribute::DEFAULT; +#define cb(attr) \ + if (m_impl->contain_attribute_all(attr)) { \ + ret |= attr; \ + } + MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb) +#undef cb if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { ret |= AlgoAttribute::REPRODUCIBLE; } diff --git a/dnn/test/common/accuracy_shake_checker.h b/dnn/test/common/accuracy_shake_checker.h index efbbe6b99..a1deafeda 100644 --- a/dnn/test/common/accuracy_shake_checker.h +++ b/dnn/test/common/accuracy_shake_checker.h @@ -168,7 +168,8 @@ public: AlgoProxy::arity>::get_all_algorithms_info( opr, layouts)) { if (!(algo_info.attribute & - AlgoAttribute::ACCURACY_DEPEND_ON_BATCH) && + AlgoAttribute::ACCURACY_DEPEND_ON_BATCH) && + (algo_info.attribute & AlgoAttribute::REPRODUCIBLE) && std::regex_match( algo_info.desc.name, std::regex("(.*)(" + m_policy_name.name + ")(.*)"))) { diff --git a/dnn/test/cuda/accuracy_shake.cpp b/dnn/test/cuda/accuracy_shake.cpp index ed93afe87..3f48f4e37 100644 --- a/dnn/test/cuda/accuracy_shake.cpp +++ b/dnn/test/cuda/accuracy_shake.cpp @@ -241,6 +241,41 @@ TEST_F(CUDA, SHAKE_LOCAL_SHARE) { checker.exec({{20, 16, 32, 32}, {3, 3, 16, 3, 3, 64}, {}}); } +TEST_F(CUDA, SHAKE_CONVOLUTION_BACKWARD_DATA) { + AccuracyShakeChecker checker(handle_cuda()); + NormalRNG default_rng; + checker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_rng(0, &default_rng) + .set_rng(1, &default_rng); + // ConvolutionBackwardData + checker.exec({{8, 16, 3, 3}, {64, 8, 5, 5}, {64, 16, 7, 7}}); + + // group + ConvolutionBackwardData::Param param; + param.sparse = Convolution::Param::Sparse::GROUP; + checker.set_param(param); + checker.exec({{2, 16, 32, 3, 3}, {2, 32, 5, 5}, {2, 64, 7, 7}}); + checker.exec({{2, 8, 32, 3, 3}, {64, 16, 19, 19}, {64, 64, 21, 21}}); +} + +TEST_F(CUDA, SHAKE_CONVOLUTION_BACKWARD_FILTER) { + AccuracyShakeChecker checker(handle_cuda()); + NormalRNG default_rng; + checker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_rng(0, &default_rng) + .set_rng(1, &default_rng); + // ConvolutionBackwardFilter + checker.exec({{2, 64, 7, 7}, {2, 32, 5, 5}, {32, 64, 3, 3}}); + + // group + ConvolutionBackwardFilter::Param param; + param.sparse = Convolution::Param::Sparse::GROUP; + checker.set_param(param); + checker.exec({{2, 64, 7, 7}, {2, 32, 5, 5}, {2, 16, 32, 3, 3}}); +} + } // namespace test } // namespace megdnn -- GitLab