winograd_filter_preprocess.cpp 3.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
/**
 * \file dnn/test/arm_common/winograd_filter_preprocess.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/common/checker.h"
#include "test/common/benchmarker.h"
#include "test/common/winograd_filter_preprocess.h"

#include "test/arm_common/fixture.h"

using namespace megdnn;
using namespace test;

// Exercises WinogradFilterPreprocess with Float32 source and destination over
// the default-format argument sets returned by get_args(6, 3), get_args(5, 4)
// and get_args(4, 5), followed by the MK4-packed sets returned by
// get_mk_packed_args(2, MK4, 4) and get_mk_packed_args(6, MK4, 4).
TEST_F(ARM_COMMON, WinogradFilterPreprocessF32) {
    using namespace winograd_filter_preprocess;
    Checker<WinogradFilterPreprocess> checker(handle());

    // Seed the list with the first default-format group, then append the
    // remaining groups in a fixed order (braced-init-list elements are
    // evaluated left to right, matching the original call order).
    std::vector<TestArg> args = get_args(6, 3);
    for (const auto& group :
         {get_args(5, 4), get_args(4, 5),
          get_mk_packed_args(2, param::Winograd::Format::MK4, 4),
          get_mk_packed_args(6, param::Winograd::Format::MK4, 4)}) {
        args.insert(args.end(), group.begin(), group.end());
    }

    // Tensor 0 (filter input) and tensor 1 (output) are both Float32; the
    // output shape is passed empty, presumably to be deduced by the checker.
    for (const auto& arg : args) {
        checker.set_param(arg.param);
        checker.set_dtype(0, dtype::Float32{});
        checker.set_dtype(1, dtype::Float32{});
        checker.execs({arg.src, {}});
    }
}

// Exercises WinogradFilterPreprocess on quantized filters: QuantizedS8 input
// (scale 2.5) producing QuantizedS16 output (scale 2.5), over the MK8-packed
// argument set returned by get_mk_packed_args(2, MK8, 8).
TEST_F(ARM_COMMON, WinogradFilterPreprocessQs8) {
    using namespace winograd_filter_preprocess;
    std::vector<TestArg> mk8_args =
            get_mk_packed_args(2, param::Winograd::Format::MK8, 8);
    Checker<WinogradFilterPreprocess> checker(handle());

    // Integer values drawn from [-50, 50]; one generator instance is
    // registered for tensor indices 0, 1 and 2.
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng);
    checker.set_rng(1, &int_rng);
    checker.set_rng(2, &int_rng);

    const dtype::QuantizedS8 src_dtype(2.5f);
    const dtype::QuantizedS16 dst_dtype(2.5f);
    for (const auto& arg : mk8_args) {
        checker.set_param(arg.param);
        checker.set_dtype(0, src_dtype);
        checker.set_dtype(1, dst_dtype);
        checker.execs({arg.src, {}});
    }
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// Exercises WinogradFilterPreprocess with Float16 source and destination over
// the default-format sets from get_args(6, 3), get_mk_packed_args(2, DEFAULT,
// 4) and get_args(4, 5), plus the MK8-packed set from
// get_mk_packed_args(2, MK8, 8). Inputs come from a Float16PeriodicalRNG
// seeded with 0x3c00.
TEST_F(ARM_COMMON, WinogradFilterPreprocessF16) {
    using namespace winograd_filter_preprocess;
    // Declared before the checker so it outlives every use the checker makes
    // of it. This was previously a heap allocation that was never freed.
    // set_rng receives a non-owning pointer (the Qs8 test likewise passes the
    // address of a local), so automatic storage is sufficient.
    Float16PeriodicalRNG rng(0x3c00);
    Checker<WinogradFilterPreprocess> checker(handle());
    // default
    std::vector<TestArg> args = get_args(6, 3);
    std::vector<TestArg> args_23 =
            get_mk_packed_args(2, param::Winograd::Format::DEFAULT, 4);
    std::vector<TestArg> args45 = get_args(4, 5);

    // mk8
    std::vector<TestArg> args_mk8_out2 =
            get_mk_packed_args(2, param::Winograd::Format::MK8, 8);

    args.insert(args.end(), args_23.begin(), args_23.end());
    args.insert(args.end(), args45.begin(), args45.end());
    args.insert(args.end(), args_mk8_out2.begin(), args_mk8_out2.end());

    // The generator does not depend on the argument, so register it once
    // (as the Qs8 test does) instead of on every iteration.
    checker.set_rng(0, &rng);
    for (auto&& arg : args) {
        checker.set_param(arg.param)
                .set_dtype(0, dtype::Float16{})
                .set_dtype(1, dtype::Float16{})
                .execs({arg.src, {}});
    }
}

#endif

// vim: syntax=cpp.doxygen