/**
 * \file dnn/src/cuda/convolution/backward_data/algo.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/cuda/convolution/helper.h"
#include "src/cuda/cudnn_wrapper.h"

namespace megdnn {
namespace cuda {

/*!
 * \brief base class for convolution algos
 *
 * All the algo impls should try to support non-contiguous batch dim, for group
 * conv execution.
 */
30 31 32
class ConvolutionBackwardDataImpl::AlgoBase : public Algorithm {
protected:
    ~AlgoBase() = default;
33

34 35 36 37 38 39 40 41
public:
    enum class AlgoType : uint32_t {
        CUDA_CUDNN,
        CUDA_MATMUL,
        CUDA_CHANWISE,
        CUDA_CHANWISE_SMALL,
        CUDA_BFLOAT16,
        CUDA_GROUP_CONV_GENERAL,
42 43
        CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
        CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8
44 45 46 47 48 49 50 51 52 53 54 55 56
    };
    using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;

    AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
    struct SizeArgs {
        HandleImpl* handle;
        CanonizedFilterMeta filter_meta;
        const TensorLayout *diff_layout, *grad_layout, *filter_layout;
        ConvolutionBackwardDataImpl* opr;

        std::string to_string() const;
        void init_desc(convolution::CUDNNBwdDataDescs& desc) const {
            desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
57
        }
58 59 60 61 62 63 64 65 66
        SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
                 const TensorLayout& diff, const TensorLayout& grad);
        SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
                 const CanonizedFilterMeta& filter_meta,
                 const TensorLayout& diff, const TensorLayout& grad);

        convolution::ForwardSizeArgs as_fwd_args() const {
            return {handle, grad_layout, filter_layout, filter_meta,
                    diff_layout};
67
        }
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
    };
    struct ExecArgs : public SizeArgs {
        const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
        Workspace workspace;

        ExecArgs(ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter,
                 _megdnn_tensor_in diff, _megdnn_tensor_out grad,
                 _megdnn_workspace workspace);
    };
    virtual bool is_available(const SizeArgs& args) const = 0;
    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
    virtual void exec(const ExecArgs& args) const = 0;

    bool is_available_wk(const SizeArgs& args, size_t limit) {
        return is_available(args) && get_workspace_in_bytes(args) <= limit;
    }

85 86
    bool is_available_attribute(
            const SizeArgs& args,
87 88
            const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
            const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT,
89
            size_t limit = std::numeric_limits<size_t>::max()) {
90 91 92
        return contain_attribute_all(positive_attr) &&
               !contain_attribute_any(negative_attr) &&
               is_available_wk(args, limit);
93 94 95 96 97 98 99 100 101 102 103 104 105
    }

    AlgoBase& check_workspace(const SizeArgs& args,
                              const Workspace& workspace) {
        auto req = get_workspace_in_bytes(args);
        megdnn_assert(req <= workspace.size,
                      "conv bwd data algo %s: "
                      "required workspace %zu bytes, got %zu",
                      name(), req, workspace.size);
        return *this;
    }

    virtual bool is_cudnn() const { return false; }
106 107 108 109
};

//! wrapper around one cuDNN backward-data algorithm enum value
class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase {
    cudnnConvolutionBwdDataAlgo_t m_cudnn_enum;
    //! attributes (name, reproducibility, batch dependence) looked up from
    //! CudnnAlgoPack for m_cudnn_enum
    CudnnAlgoPack::Attr m_attr;

public:
    /*!
     * \param cudnn_enum cuDNN algo to wrap; asserted to be present in
     *      CudnnAlgoPack::conv_bwd_data_algos()
     */
    AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum)
            : m_cudnn_enum(cudnn_enum) {
        // Bind the table once: find() and end() must come from the same
        // container object for the comparison to be meaningful, which separate
        // calls would not guarantee if the accessor returns by value.
        const auto& algos = CudnnAlgoPack::conv_bwd_data_algos();
        megdnn_assert(algos.find(cudnn_enum) != algos.end());
        m_attr = algos.at(cudnn_enum);
    }

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

    //! name taken from the CudnnAlgoPack attribute entry
    const char* name() const override { return m_attr.name.c_str(); }

    //! attribute mask built from the flags of the CudnnAlgoPack entry
    AlgoAttribute attribute() const override {
        auto ret = static_cast<AlgoAttribute>(0);
        if (m_attr.is_reproducible) {
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        if (m_attr.accuracy_depend_on_batch) {
            ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
        }
        return ret;
    }

    //! the wrapped cuDNN enum value
    cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; }

    bool is_cudnn() const override { return true; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN)

    //! the wrapped cuDNN enum value written into a string by
    //! serialize_write_pod
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_cudnn_enum, ret);
        return ret;
    }
};

//! im2col and matmul, with dilation
148 149 150
class ConvolutionBackwardDataImpl::AlgoMatmul final : public AlgoBase {
    template <typename T>
    static void exec_internal(const ExecArgs& args);
151

152 153 154 155
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
156

157 158 159 160
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;

161 162
    const char* name() const override { return "MATMUL"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL)
163
    AlgoAttribute attribute() const override {
164 165
        return AlgoAttribute::REPRODUCIBLE |
               AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
166
    }
167 168
};

169 170 171 172 173
class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
174

175 176
    const char* name() const override { return "CHANNEL_WISE"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE)
177 178 179
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
180 181
};

182 183 184 185 186
class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
187

188 189
    const char* name() const override { return "CHANNEL_WISE_SMALL"; }
    MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
190
    AlgoAttribute attribute() const override {
191 192
        return AlgoAttribute::REPRODUCIBLE |
               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
193
    }
194 195
};

196 197 198 199 200 201
class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;

202 203 204 205 206 207 208
    std::vector<SearchItem> get_subopr_list(
            const TensorLayoutArray& layouts,
            const OperatorBase* opr) const override;

    const char* name() const override {
        return "CONVOLUTION_BACKWARD_DATD_BFLOAT16";
    }
209 210 211 212

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
213 214 215

private:
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
216
    MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16)
217 218
};

219
//! implement group conv by another algo
220 221 222
class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final
        : public AlgoBase {
    AlgoBase* m_impl;
223 224
    std::string m_name;

225 226
public:
    AlgoGroupConvGeneral(AlgoBase* impl);
227

228 229 230
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
231

232
    const char* name() const override { return m_name.c_str(); }
233

234 235 236 237

    static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg,
                                 TensorLayout& grad_pg);
    MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL)
238
    AlgoAttribute attribute() const override {
239 240 241 242 243 244 245
        auto ret = AlgoAttribute::DEFAULT;
#define cb(attr)                               \
    if (m_impl->contain_attribute_all(attr)) { \
        ret |= attr;                           \
    }
        MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb)
#undef cb
246
        if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) {
247 248 249 250
            ret |= AlgoAttribute::REPRODUCIBLE;
        }
        return ret;
    }
251 252
};

253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
class ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm final
        : public AlgoBase {
public:
    struct AlgoParam {
        int threadblock_m;
        int threadblock_n;
        int threadblock_k;
        int warp_m;
        int warp_n;
        int warp_k;
        int stage;
        std::string to_string() {
            return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m,
                            threadblock_n, threadblock_k, warp_m, warp_n,
                            warp_k, stage);
        }
    };
    AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)
            : m_algo_param{algo_param},
              m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s",
                              m_algo_param.to_string().c_str())} {}
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8)
private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
    AlgoParam m_algo_param;
    std::string m_name;
};

289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306
class ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm final
        : public AlgoBase {
public:
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override {
        return "INT8_NCHW_DOTPROD_IMPLICIT_GEMM";
    }
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }
    MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8);
private:
    WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr,
                                         const SizeArgs& args) const;
};

307
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
308 309
    // defined in cudnn.cpp
    void fill_cudnn_algos();
310 311
    // defined in implicit_gemm_int8_nchw4_dp4a.cpp
    void fill_int8_dp4a_algos();
312

313
    AlgoBase::Mapper m_all_algos_map;
314

315 316
public:
    AlgoPack();
317

318 319 320 321 322 323
    std::vector<AlgoCUDNN> cudnn;
    AlgoMatmul matmul;
    AlgoChanwise chanwise;
    AlgoChanwiseSmall chanwise_small;
    std::vector<AlgoGroupConvGeneral> gconv;
    std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv;
324
    AlgoBFloat16 bfloat16;
325
    std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
326
    AlgoInt8NCHWDotProdImplicitGemm int8_nchw_dotprod;
327

328
    std::vector<AlgoBase*>
329 330 331
            //! all algorithms
            all_algos,
            //! non-cudnn algos, used for heuristic if cudnn is not supported
332
            non_cudnn_algos, bfloat16_algos, int8_algos;
333 334

    AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo);
335

336
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
337 338
};

339 340
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen