Commit 43c59204, authored by Megvii Engine Team, committed by huangxinda

refactor(dnn/cuda): refactor relayout format kernels

GitOrigin-RevId: ab86e6653342ae9f74dd069cfb85aabca6dc637c
Parent f41a8086
......@@ -110,35 +110,33 @@ MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage,
return (result << (shift - bits)) >> shift;
}
MEGDNN_DEVICE __forceinline__ static void transform_int4x8_to_int8(
int (&result)[8], const int& source) {
#pragma unroll
for (int i = 0; i < 8; i++) {
result[i] = unpack_integer_4bits<true>(
reinterpret_cast<unsigned const&>(source), (i << 2));
}
}
MEGDNN_DEVICE __forceinline__ static void transform_uint4x8_to_int8(
template <bool signedness>
MEGDNN_DEVICE __forceinline__ static void transform_b4x8_to_int8(
int (&result)[8], const int& source) {
#pragma unroll
for (int i = 0; i < 8; i++) {
result[i] = unpack_integer_4bits<false>(
result[i] = unpack_integer_4bits<signedness>(
reinterpret_cast<unsigned const&>(source), (i << 2));
}
}
MEGDNN_DEVICE __forceinline__ static void transform_int4x2_to_int8(
template <bool signedness>
MEGDNN_DEVICE __forceinline__ static void transform_b4x2_to_int8(
int (&result)[2], const uint8_t& source) {
result[0] = unpack_integer_4bits<true>(source, 0);
result[1] = unpack_integer_4bits<true>(source, 4);
result[0] = unpack_integer_4bits<signedness>(source, 0);
result[1] = unpack_integer_4bits<signedness>(source, 4);
}
MEGDNN_DEVICE __forceinline__ static void transform_uint4x2_to_int8(
int (&result)[2], const uint8_t& source) {
result[0] = unpack_integer_4bits<false>(source, 0);
result[1] = unpack_integer_4bits<false>(source, 4);
template <bool signedness>
MEGDNN_DEVICE __forceinline__ static int transform_int8_to_b4x8(
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
if (signedness) {
return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7);
} else {
return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7);
}
}
} // namespace integer_subbyte
} // namespace cuda
} // namespace megdnn
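The helpers above operate on eight 4-bit values packed into one 32-bit word. As a rough host-side sketch (not part of the diff), assuming element i is stored at bit offset 4 * i — the layout implied by the (i << 2) shifts — the pack/unpack round trip can be modelled as follows; the function names are illustrative only:

#include <cstdint>

static inline uint32_t pack8_nibbles_host(const int (&v)[8]) {
    uint32_t packed = 0;
    for (int i = 0; i < 8; ++i)
        packed |= (uint32_t(v[i]) & 0xFu) << (i * 4);  // keep the low 4 bits of each value
    return packed;
}

static inline void unpack8_nibbles_host(int (&out)[8], uint32_t packed, bool is_signed) {
    for (int i = 0; i < 8; ++i) {
        int nib = int((packed >> (i * 4)) & 0xFu);
        // mirror unpack_integer_4bits<true>: sign-extend the signed (int4) flavour
        out[i] = (is_signed && (nib & 0x8)) ? nib - 16 : nib;
    }
}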
......
/**
* \file dnn/src/cuda/relayout_format/cuda_post_process.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/relayout_format/relayout_format.cuh"
namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {
template <typename SrcType, typename DstType, bool same_scale>
struct CudaPostProcess;
template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, true> {
CudaPostProcess(float, uint8_t, float, uint8_t){};
inline __device__ int8_t operator()(uint8_t val) { return val - 128; }
};
template <>
struct CudaPostProcess<dtype::Uint8, dtype::QuantizedS8, false> {
CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
CudaPostProcess(float, uint8_t, float dst_scale, uint8_t) {
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
};
inline __device__ int8_t operator()(uint8_t val) {
return m_dst_type_cvt.quantize((float)val - 128.f).as_int8();
}
};
template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, false> {
CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
CudaDTypeParamImpl<dt_quint8> m_src_type_cvt;
CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale,
uint8_t) {
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
m_src_type_cvt =
CudaDTypeParamImpl<dt_quint8>(src_scale, src_zero_point);
};
inline __device__ int8_t operator()(uint8_t val) {
float med_var = m_src_type_cvt.dequantize(dt_quint8(val));
return m_dst_type_cvt.quantize(med_var).as_int8();
}
};
template <>
struct CudaPostProcess<dtype::Quantized8Asymm, dtype::QuantizedS8, true> {
uint8_t m_src_zero_point = 0;
CudaPostProcess(float, uint8_t src_zero_point, float, uint8_t) {
m_src_zero_point = src_zero_point;
};
inline __device__ int8_t operator()(uint8_t val) {
return val - m_src_zero_point;
}
};
template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, false> {
CudaDTypeParamImpl<dt_qint8> m_dst_type_cvt;
CudaDTypeParamImpl<dt_qint8> m_src_type_cvt;
CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) {
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint8>(dst_scale);
m_src_type_cvt = CudaDTypeParamImpl<dt_qint8>(src_scale);
};
inline __device__ int8_t operator()(int8_t val) {
float med_var = m_src_type_cvt.dequantize(dt_qint8(val));
return m_dst_type_cvt.quantize(med_var).as_int8();
}
};
template <>
struct CudaPostProcess<dtype::QuantizedS8, dtype::QuantizedS8, true> {
CudaPostProcess(){};
CudaPostProcess(float, uint8_t, float, uint8_t){};
inline __device__ int8_t operator()(int8_t val) { return val; }
};
template <>
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, false> {
CudaDTypeParamImpl<dt_qint32> m_dst_type_cvt;
CudaDTypeParamImpl<dt_qint32> m_src_type_cvt;
CudaPostProcess(float src_scale, int, float dst_scale, int) {
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint32>(dst_scale);
m_src_type_cvt = CudaDTypeParamImpl<dt_qint32>(src_scale);
};
inline __device__ int operator()(int val) {
float med_var = m_src_type_cvt.dequantize(dt_qint32(val));
return m_dst_type_cvt.quantize(med_var).as_int32();
}
};
template <>
struct CudaPostProcess<dtype::QuantizedS32, dtype::QuantizedS32, true> {
CudaPostProcess(float, int, float, int){};
inline __device__ int operator()(int val) { return val; }
};
template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, false> {
using SrcType = dtype::QuantizedS4;
using DstType = dtype::QuantizedS4;
CudaDTypeParamImpl<dt_qint4> m_dst_type_cvt;
CudaDTypeParamImpl<dt_qint4> m_src_type_cvt;
CudaPostProcess(float src_scale, uint8_t, float dst_scale, uint8_t) {
m_dst_type_cvt = CudaDTypeParamImpl<dt_qint4>(dst_scale);
m_src_type_cvt = CudaDTypeParamImpl<dt_qint4>(src_scale);
}
inline __device__ int8_t operator()(int8_t val) {
float intermediate = m_src_type_cvt.dequantize(dt_qint4(val));
return m_dst_type_cvt.quantize(intermediate).as_int8();
}
};
template <>
struct CudaPostProcess<dtype::QuantizedS4, dtype::QuantizedS4, true> {
using SrcType = dtype::QuantizedS4;
using DstType = dtype::QuantizedS4;
CudaPostProcess(float, uint8_t, float, uint8_t){};
inline __device__ int8_t operator()(int8_t val) { return val; }
};
template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, false> {
using SrcType = dtype::Quantized4Asymm;
using DstType = dtype::Quantized4Asymm;
CudaDTypeParamImpl<dt_quint4> m_dst_type_cvt;
CudaDTypeParamImpl<dt_quint4> m_src_type_cvt;
CudaPostProcess(float src_scale, uint8_t src_zero_point, float dst_scale,
uint8_t dst_zero_point) {
m_dst_type_cvt =
CudaDTypeParamImpl<dt_quint4>(dst_scale, dst_zero_point);
m_src_type_cvt =
CudaDTypeParamImpl<dt_quint4>(src_scale, src_zero_point);
};
inline __device__ uint8_t operator()(uint8_t val) {
float intermediate = m_src_type_cvt.dequantize(dt_quint4(val));
return m_dst_type_cvt.quantize(intermediate).as_uint8();
}
};
template <>
struct CudaPostProcess<dtype::Quantized4Asymm, dtype::Quantized4Asymm, true> {
using SrcType = dtype::Quantized4Asymm;
using DstType = dtype::Quantized4Asymm;
uint8_t m_src_zero_point = 0;
uint8_t m_dst_zero_point = 0;
CudaPostProcess(float, uint8_t src_zero_point, float,
uint8_t dst_zero_point) {
m_src_zero_point = src_zero_point;
m_dst_zero_point = dst_zero_point;
};
inline __device__ uint8_t operator()(uint8_t val) {
int result = val - m_src_zero_point + m_dst_zero_point;
result = result >= 0 ? result : 0;
result = result < 16 ? result : 15;
return static_cast<uint8_t>(result);
}
};
} // namespace internal
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
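For the same-scale Quantized4Asymm path above, the whole post-process reduces to a zero-point shift plus a clamp to the uint4 range. A hedged host-side analogue (the function name is illustrative, not part of the library):

#include <cstdint>

static inline uint8_t requantize_asymm4_same_scale(uint8_t val, uint8_t src_zp,
                                                   uint8_t dst_zp) {
    int result = int(val) - int(src_zp) + int(dst_zp);
    result = result >= 0 ? result : 0;   // clamp below to 0
    result = result < 16 ? result : 15;  // clamp above to 15
    return static_cast<uint8_t>(result);
}

// e.g. with src_zero_point = 8 and dst_zero_point = 5, an input nibble of 2
// maps to 0 (clamped) and an input nibble of 15 maps to 12.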
/**
* \file dnn/src/cuda/relayout_format/helper.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
namespace megdnn {
namespace cuda {
namespace relayout_format {
#define devfunc __forceinline__ __device__
template <int size_nbits>
devfunc int make_zero(int zero_point);
template <>
devfunc int make_zero<4>(int zero_point) {
return transform_int8_to_uint4x8(zero_point, zero_point, zero_point,
zero_point, zero_point, zero_point,
zero_point, zero_point);
}
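// For example, make_zero<4>(3) is expected to broadcast the value 3 into all
// eight nibbles of the returned word, i.e. 0x33333333, assuming
// transform_int8_to_uint4x8 stores element i at bit offset 4 * i and keeps only
// the low 4 bits of each argument; this packed zero point is what the guarded
// loads below substitute for out-of-bounds elements.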
template <typename AccessType, int LoadBytes>
struct global_load_with_zero_point;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
// The redundant mov PTX instructions force the compiler to initialize the
// destination registers with the packed zero point before the predicated
// ld.global
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 32> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
uint4* data = reinterpret_cast<uint4*>(&D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %9, 0;\n"
" mov.b32 %0, %10;\n"
" mov.b32 %1, %10;\n"
" mov.b32 %2, %10;\n"
" mov.b32 %3, %10;\n"
" mov.b32 %4, %10;\n"
" mov.b32 %5, %10;\n"
" mov.b32 %6, %10;\n"
" mov.b32 %7, %10;\n"
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n"
" @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n"
"}\n"
: "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z),
"=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y),
"=r"(data[1].z), "=r"(data[1].w)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)),
"l"(((uint8_t*)ptr) + 16));
}
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 16> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
uint4& data = reinterpret_cast<uint4&>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %5, 0;\n"
" mov.b32 %0, %6;\n"
" mov.b32 %1, %6;\n"
" mov.b32 %2, %6;\n"
" mov.b32 %3, %6;\n"
" @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
"}\n"
: "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)));
}
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 8> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
uint2& data = reinterpret_cast<uint2&>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %3, 0;\n"
" mov.b32 %0, %4;\n"
" mov.b32 %1, %4;\n"
" @p ld.global.v2.u32 {%0, %1}, [%2];\n"
"}\n"
: "=r"(data.x), "=r"(data.y)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)));
}
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 4> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
unsigned& data = reinterpret_cast<unsigned&>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %2, 0;\n"
" mov.b32 %0, %3;\n"
" @p ld.global.u32 %0, [%1];\n"
"}\n"
: "=r"(data)
: "l"(ptr), "r"((int)pred_guard),
"r"(reinterpret_cast<unsigned&>(zero_point)));
}
};
template <typename AccessType>
struct global_load_with_zero_point<AccessType, 1> {
devfunc global_load_with_zero_point(AccessType& D, void const* ptr,
bool pred_guard, int zero_point) {
if (pred_guard)
D = *(reinterpret_cast<AccessType const*>(ptr));
else {
unsigned uv = reinterpret_cast<unsigned&>(zero_point);
uint8_t& data = reinterpret_cast<uint8_t&>(D);
data = uv & 0xff;
}
}
};
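// Roughly, each specialization above behaves like the following reference
// version (shown for the 16-byte case; a sketch, not part of the diff). The
// inline PTX exists so that the zero-point fill and the predicated vector load
// are emitted without an explicit branch.
template <typename AccessType>
devfunc void global_load_16B_reference(AccessType& D, void const* ptr,
                                       bool pred_guard, int zero_point) {
    uint4& data = reinterpret_cast<uint4&>(D);
    unsigned fill = reinterpret_cast<unsigned&>(zero_point);
    data = make_uint4(fill, fill, fill, fill);  // pre-fill with the packed zero point
    if (pred_guard) {                           // only touch memory when in bounds
        data = *reinterpret_cast<uint4 const*>(ptr);
    }
}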
/////////////////////////////////////////////////////////////////////////////////////////////////
template <
/// Fragment type to store loaded data
typename AccessType,
/// The bytes of loading
int LoadBytes>
struct global_store;
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Specializations
//
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename AccessType>
struct global_store<AccessType, 32> {
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
uint4 const* data = reinterpret_cast<uint4 const*>(&D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %5, 0;\n"
" @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
" @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n"
"}\n"
:
: "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
"r"(data[0].w), "r"((int)pred_guard),
"l"(((uint8_t*)ptr) + 16), "r"(data[1].x), "r"(data[1].y),
"r"(data[1].z), "r"(data[1].w));
}
};
template <typename AccessType>
struct global_store<AccessType, 16> {
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
uint4 const& data = reinterpret_cast<uint4 const&>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %5, 0;\n"
" @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
"}\n"
:
: "l"(ptr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w),
"r"((int)pred_guard));
}
};
template <typename AccessType>
struct global_store<AccessType, 8> {
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
uint2 const& data = reinterpret_cast<uint2 const&>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %3, 0;\n"
" @p st.global.v2.u32 [%0], {%1, %2};\n"
"}\n"
:
: "l"(ptr), "r"(data.x), "r"(data.y), "r"((int)pred_guard));
}
};
template <typename AccessType>
struct global_store<AccessType, 4> {
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
uint32_t const& data = reinterpret_cast<uint32_t const&>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %2, 0;\n"
" @p st.global.u32 [%0], %1;\n"
"}\n"
:
: "l"(ptr), "r"(data), "r"((int)pred_guard));
}
};
template <typename AccessType>
struct global_store<AccessType, 2> {
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
uint16_t const& data = reinterpret_cast<uint16_t const&>(D);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %2, 0;\n"
" @p st.global.u16 [%0], %1;\n"
"}\n"
:
: "l"(ptr), "h"(data), "r"((int)pred_guard));
}
};
template <typename AccessType>
struct global_store<AccessType, 1> {
devfunc global_store(AccessType const& D, void* ptr, bool pred_guard) {
if (pred_guard)
*(reinterpret_cast<AccessType*>(ptr)) = D;
}
};
#undef devfunc
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
......@@ -39,6 +39,20 @@ void relayout_format_cuda_nchwx_nchw(const TensorND& src, const TensorND& dst,
const uint8_t src_zero_point = 0,
const uint8_t dst_zero_point = 0);
void relayout_format_cuda_nchw_nhwc(const TensorND& src, const TensorND& dst,
const cudaStream_t& stream,
const float src_scale = 1.f,
const float dst_scale = 1.f,
const uint8_t src_zero_point = 0,
const uint8_t dst_zero_point = 0);
void relayout_format_cuda_nhwc_nchw(const TensorND& src, const TensorND& dst,
const cudaStream_t& stream,
const float src_scale = 1.f,
const float dst_scale = 1.f,
const uint8_t src_zero_point = 0,
const uint8_t dst_zero_point = 0);
void relayout_format_cuda_nchw_nchw4_weight(const TensorND& src,
const TensorND& dst,
const cudaStream_t& stream);
......
/**
* \file dnn/src/cuda/relayout_format/relayout_format_kern.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/int_fastdiv.cuh"
#include "src/cuda/memory_utils.cuh"
#include "src/cuda/relayout_format/translayout.cuh"
namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {
using namespace memory;
template <typename Type_, int pack_size_, int chan_blk_, int width_,
int size_nbits_>
class TensorIteratorOverChannel {
public:
using Type = Type_;
static constexpr int pack_size = pack_size_;
static constexpr int chan_blk = chan_blk_;
static constexpr int width = width_;
static constexpr int size_nbits = size_nbits_;
static constexpr int elements_in_type =
chan_blk * width * size_nbits / (8 * sizeof(Type));
static constexpr int lane_size_in_type =
(width * pack_size * size_nbits) / (8 * sizeof(Type));
static constexpr int pack_size_in_type =
(pack_size * size_nbits) >= (8 * sizeof(Type))
? (pack_size * size_nbits / (8 * sizeof(Type)))
: (width * pack_size * size_nbits / (8 * sizeof(Type)));
static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type);
using AccessType = array_wrapper<Type, pack_size_in_type>;
using Fragment = array_wrapper<Type, elements_in_type>;
MEGDNN_HOST TensorIteratorOverChannel()
: pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {}
MEGDNN_HOST TensorIteratorOverChannel(Type* pointer_,
int chan_stride_in_elements_,
int channel_, int, int)
: pointer{pointer_},
chan_stride_in_elements{chan_stride_in_elements_},
channel{channel_} {}
MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
pointer += (c_idx / pack_size) * chan_stride_in_elements +
hw_idx * pack_size * size_nbits / (8 * sizeof(Type));
channel -= c_idx;
}
MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
size_t offset_in_type) {
pointer += offset_in_type;
}
MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
Type* pointer_ = pointer;
#pragma unroll
for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
int frag_idx = i / pack_size *
(lane_size_in_type / pack_size_in_type) +
j;
bool guard = i < channel;
global_load<AccessType, pack_size_in_byte>(
frag_ptr[frag_idx],
reinterpret_cast<void*>(pointer_ +
j * pack_size_in_type),
guard, zero_point);
}
pointer_ += chan_stride_in_elements;
}
}
MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) {
const AccessType* frag_ptr = reinterpret_cast<const AccessType*>(&frag);
Type* pointer_ = pointer;
#pragma unroll
for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
int frag_idx = i / pack_size *
(lane_size_in_type / pack_size_in_type) +
j;
bool guard = i < channel;
global_store<AccessType, pack_size_in_byte>(
frag_ptr[frag_idx],
reinterpret_cast<void*>(pointer_ +
j * pack_size_in_type),
guard);
}
pointer_ += chan_stride_in_elements;
}
}
MEGDNN_DEVICE __forceinline__ void advance() {
pointer += (chan_blk / pack_size) * chan_stride_in_elements;
channel -= chan_blk;
}
private:
Type* pointer;
int chan_stride_in_elements;
int channel;
};
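// A worked instance of the constexpr arithmetic above, for a hypothetical
// instantiation Type = char, pack_size = 64, chan_blk = 64, width = 1,
// size_nbits = 4 (an NCHW64-style 4-bit tensor):
//   elements_in_type  = 64 * 1 * 4 / (8 * 1) = 32 chars per fragment
//   lane_size_in_type = 1 * 64 * 4 / (8 * 1) = 32 chars per spatial lane
//   pack_size_in_type = 64 * 4 / (8 * 1)     = 32 chars per access
//   pack_size_in_byte = 32, so every guarded load/store in the loops above
//   moves 32 bytes at a time (the 32-byte global load/store specializations).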
template <typename Type_, int pack_size_, int chan_blk_, int width_,
int size_nbits_>
class MaskedTensorIteratorOverChannel {
public:
using Type = Type_;
static constexpr int pack_size = pack_size_;
static constexpr int chan_blk = chan_blk_;
static constexpr int width = width_;
static constexpr int size_nbits = size_nbits_;
static constexpr int elements_in_type =
chan_blk * width * size_nbits / (8 * sizeof(Type));
static constexpr int lane_size_in_type =
(width * pack_size * size_nbits) / (8 * sizeof(Type));
static constexpr int pack_size_in_type =
(pack_size * size_nbits) >= (8 * sizeof(Type))
? (pack_size * size_nbits / (8 * sizeof(Type)))
: (width * pack_size * size_nbits / (8 * sizeof(Type)));
static constexpr int pack_size_in_byte = pack_size_in_type * sizeof(Type);
static constexpr int accesses = elements_in_type / pack_size_in_type;
static constexpr int mask_size = (accesses + 32 - 1) / 32;
using AccessType = array_wrapper<Type, pack_size_in_type>;
using Fragment = array_wrapper<Type, elements_in_type>;
MEGDNN_HOST MaskedTensorIteratorOverChannel()
: pointer{nullptr}, chan_stride_in_elements{0}, channel{0} {}
MEGDNN_HOST MaskedTensorIteratorOverChannel(Type* pointer_,
int chan_stride_in_elements_,
int channel_, int bound_,
int div_)
: pointer{pointer_},
chan_stride_in_elements{chan_stride_in_elements_},
channel{channel_},
bound{bound_},
div{uint32_t(div_)} {}
MEGDNN_DEVICE __forceinline__ void initialize(int c_idx, int hw_idx) {
pointer += (c_idx / pack_size) * chan_stride_in_elements;
channel -= c_idx;
int w[lane_size_in_type / pack_size_in_type];
#pragma unroll
for (int i = 0; i < mask_size; ++i) {
mask[i] = 0;
}
#pragma unroll
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
int offset = hw_idx + j;
int h = (int)((uint32_t)(offset) / div);
w[j] = (int)((uint32_t)(offset) % div);
stride[j] = (h * bound + w[j]) * pack_size * size_nbits /
(8 * sizeof(Type));
}
#pragma unroll
for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
bool guard = (i < channel) && (w[j] < bound);
int index = (i / pack_size) *
(lane_size_in_type / pack_size_in_type) +
j;
int mask_index = (index >> 5);
int mask_shift = (index & 0x1f);
mask[mask_index] |= (guard << mask_shift);
}
}
}
MEGDNN_DEVICE __forceinline__ void add_pointer_offset(
size_t offset_in_type) {
pointer += offset_in_type;
}
MEGDNN_DEVICE __forceinline__ void load(Fragment& frag, int zero_point) {
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
Type* pointer_ = pointer;
#pragma unroll
for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
int frag_idx = i / pack_size *
(lane_size_in_type / pack_size_in_type) +
j;
int mask_index = (frag_idx >> 5);
int mask_shift = (frag_idx & 0x1f);
bool guard = (mask[mask_index] & (1 << mask_shift));
global_load<AccessType, pack_size_in_byte>(
frag_ptr[frag_idx],
reinterpret_cast<void*>(pointer_ + stride[j]), guard,
zero_point);
}
pointer_ += chan_stride_in_elements;
}
}
MEGDNN_DEVICE __forceinline__ void store(const Fragment& frag) {
const AccessType* frag_ptr = reinterpret_cast<const AccessType*>(&frag);
Type* pointer_ = pointer;
#pragma unroll
for (int i = 0; i < chan_blk; i += pack_size) {
#pragma unroll
for (int j = 0; j < lane_size_in_type / pack_size_in_type; j++) {
int frag_idx = i / pack_size *
(lane_size_in_type / pack_size_in_type) +
j;
int mask_index = (frag_idx >> 5);
int mask_shift = (frag_idx & 0x1f);
bool guard = (mask[mask_index] & (1 << mask_shift));
global_store<AccessType, pack_size_in_byte>(
frag_ptr[frag_idx],
reinterpret_cast<void*>(pointer_ + stride[j]), guard);
}
pointer_ += chan_stride_in_elements;
}
}
MEGDNN_DEVICE __forceinline__ void advance() {
pointer += (chan_blk / pack_size) * chan_stride_in_elements;
channel -= chan_blk;
}
private:
Type* pointer;
int chan_stride_in_elements;
int channel;
int bound;
Uint32Fastdiv div;
uint32_t mask[mask_size];
size_t stride[lane_size_in_type / pack_size_in_type];
};
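// The guard bits computed in initialize() are packed 32 per uint32_t. As a
// hypothetical example, if the iterator needed 48 accesses, mask_size would be
// (48 + 31) / 32 = 2, and the guard for access index 37 would live in
// mask[37 >> 5] = mask[1] at bit position 37 & 0x1f = 5.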
template <bool padding_, typename Type_, int pack_size_, int chan_blk_,
int width_, int size_nbits_>
struct TensorIteratorPolicy;
template <typename Type_, int pack_size_, int chan_blk_, int width_,
int size_nbits_>
struct TensorIteratorPolicy<true, Type_, pack_size_, chan_blk_, width_,
size_nbits_> {
using TensorIterator =
MaskedTensorIteratorOverChannel<Type_, pack_size_, chan_blk_,
width_, size_nbits_>;
};
template <typename Type_, int pack_size_, int chan_blk_, int width_,
int size_nbits_>
struct TensorIteratorPolicy<false, Type_, pack_size_, chan_blk_, width_,
size_nbits_> {
using TensorIterator =
TensorIteratorOverChannel<Type_, pack_size_, chan_blk_, width_,
size_nbits_>;
};
template <typename SrcIterator_, typename DstIterator_, typename Transpose_,
typename CudaPostProcess_>
struct RelayoutProblem {
using SrcIterator = SrcIterator_;
using DstIterator = DstIterator_;
using Transpose = Transpose_;
using CudaPostProcess = CudaPostProcess_;
MEGDNN_STATIC_ASSERT(SrcIterator::chan_blk == DstIterator::chan_blk,
"channel block mismatch");
MEGDNN_STATIC_ASSERT(SrcIterator::width == DstIterator::width,
"width block mismatch");
MEGDNN_STATIC_ASSERT(SrcIterator::size_nbits == DstIterator::size_nbits,
"size in bits of elements mismatch");
static constexpr int pack_chan = SrcIterator::chan_blk;
static constexpr int pack_width = SrcIterator::width;
using DnnSrcType = typename CudaPostProcess::SrcType;
using DnnDstType = typename CudaPostProcess::DstType;
struct Param {
SrcIterator src_iterator;
DstIterator dst_iterator;
CudaPostProcess post_process;
int n_stride_src;
int n_stride_dst;
int batch_size;
int channels;
int hw;
int zero_point;
MEGDNN_HOST MEGDNN_DEVICE Param(SrcIterator src_iterator_,
DstIterator dst_iterator_,
CudaPostProcess post_process_,
int n_stride_src_, int n_stride_dst_,
int batch_size_, int channels_, int hw_,
int zero_point_)
: src_iterator{src_iterator_},
dst_iterator{dst_iterator_},
post_process{post_process_},
n_stride_src{n_stride_src_},
n_stride_dst{n_stride_dst_},
batch_size{batch_size_},
channels{channels_},
hw{hw_},
zero_point{zero_point_} {}
};
};
template <typename RelayoutProblem_>
__global__ void relayout_kern(typename RelayoutProblem_::Param param) {
using SrcIterator = typename RelayoutProblem_::SrcIterator;
using DstIterator = typename RelayoutProblem_::DstIterator;
static constexpr int pack_chan = RelayoutProblem_::pack_chan;
static constexpr int pack_width = RelayoutProblem_::pack_width;
const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int thread_offset = thread_idx * pack_width;
const int hw_idx = (thread_offset % param.hw);
const int nc_blks = thread_offset / param.hw;
const int c_blks = (param.channels + pack_chan - 1) / pack_chan;
const int n_idx = nc_blks / c_blks;
const int c_blk_idx = nc_blks % c_blks;
const int c_idx = c_blk_idx * pack_chan;
if (n_idx < param.batch_size) {
const int src_offset = n_idx * param.n_stride_src;
const int dst_offset = n_idx * param.n_stride_dst;
param.src_iterator.add_pointer_offset(src_offset);
param.dst_iterator.add_pointer_offset(dst_offset);
param.src_iterator.initialize(c_idx, hw_idx);
param.dst_iterator.initialize(c_idx, hw_idx);
typename SrcIterator::Fragment src_frag;
typename DstIterator::Fragment dst_frag;
int zp = make_zero<SrcIterator::size_nbits>(param.zero_point);
param.src_iterator.load(src_frag, zp);
RelayoutProblem_::Transpose::trans(
reinterpret_cast<typename SrcIterator::Fragment&>(dst_frag),
src_frag, param.post_process);
param.dst_iterator.store(dst_frag);
}
}
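// A hedged worked example of the index bookkeeping above: with pack_width = 8,
// param.hw = 50176 (a 224x224 plane), pack_chan = 64 and param.channels = 96
// (so c_blks = 2), thread_idx = 10000 gives
//   thread_offset = 80000, hw_idx = 80000 % 50176 = 29824,
//   nc_blks = 80000 / 50176 = 1, n_idx = 1 / 2 = 0, c_blk_idx = 1, c_idx = 64,
// i.e. this thread handles batch 0, channel block 64..127 (the iterator guards
// mask off channels beyond 96), starting at pixel 29824.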
} // namespace internal
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
/**
* \file dnn/src/cuda/relayout_format/relayout_format_utils.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/integer_subbyte_utils.cuh"
#include "src/cuda/relayout_format/relayout_format.cuh"
namespace megdnn {
namespace cuda {
namespace relayout_format {
namespace internal {
template <typename ctype, int pack_w, typename enable = void>
struct DTypeRWHelper;
template <typename ctype>
struct DTypeRWHelper<
ctype, 1,
typename std::enable_if<std::is_same<ctype, dt_qint8>::value ||
std::is_same<ctype, dt_quint8>::value ||
std::is_same<ctype, dt_uint8>::value>::type> {
using InnerDtype = char;
using DstDtype = char4;
};
template <typename ctype>
struct DTypeRWHelper<
ctype, 4,
typename std::enable_if<std::is_same<ctype, dt_qint8>::value ||
std::is_same<ctype, dt_quint8>::value ||
std::is_same<ctype, dt_uint8>::value>::type> {
using InnerDtype = char4;
using DstDtype = char4;
};
template <>
struct DTypeRWHelper<dt_qint32, 1> {
using InnerDtype = int;
using DstDtype = int4;
};
template <>
struct DTypeRWHelper<dt_qint32, 4> {
using InnerDtype = int4;
using DstDtype = int4;
};
template <typename ctype>
struct DTypeRWHelper<
ctype, 2,
typename std::enable_if<std::is_same<ctype, dt_qint4>::value ||
std::is_same<ctype, dt_quint4>::value>::type> {
using InnerDtype = char;
using DstDtype = array_wrapper<uint8_t, 32>;
};
template <typename ctype>
struct DTypeRWHelper<
ctype, 8,
typename std::enable_if<std::is_same<ctype, dt_qint4>::value ||
std::is_same<ctype, dt_quint4>::value>::type> {
using InnerDtype = unsigned;
using DstDtype = array_wrapper<uint8_t, 32>;
};
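// A hedged illustration of how these traits are meant to be consumed: the
// relayout kernels select their per-thread inner/destination vector types from
// DTypeRWHelper, so compile-time checks such as the following (illustrative
// only) should hold for the specializations above.
static_assert(std::is_same<DTypeRWHelper<dt_qint8, 4>::InnerDtype, char4>::value,
              "8-bit dtypes with pack_w = 4 read as char4");
static_assert(std::is_same<DTypeRWHelper<dt_qint32, 1>::DstDtype, int4>::value,
              "32-bit dtypes write back as int4");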
template <typename DstType>
inline __device__ DstType make_zero_pad(const uint8_t zero_point) {
return zero_point;
}
template <>
inline __device__ char4 make_zero_pad<char4>(const uint8_t zero_point) {
char izp = reinterpret_cast<const char&>(zero_point);
return {izp, izp, izp, izp};
}
template <>
inline __device__ int4 make_zero_pad<int4>(const uint8_t zero_point) {
return {zero_point, zero_point, zero_point, zero_point};
}
template <int size_nbits>
inline __device__ int make_zero(int zero_point);
template <>
inline __device__ int make_zero<4>(int zero_point) {
return integer_subbyte::transform_int8_to_uint4x8(
zero_point, zero_point, zero_point, zero_point, zero_point,
zero_point, zero_point, zero_point);
}
template <typename DstDtype>
inline __device__ void write_helper(DstDtype* ptr, DstDtype val) {
*ptr = val;
}
template <>
inline __device__ void write_helper<char4>(char4* ptr, char4 val) {
int32_t* rel_ptr = (int32_t*)ptr;
*rel_ptr = *(int32_t*)(&val);
}
template <>
inline __device__ void write_helper<array_wrapper<uint8_t, 32>>(
array_wrapper<uint8_t, 32>* ptr, array_wrapper<uint8_t, 32> val) {
uint4 const* data = reinterpret_cast<uint4 const*>(&val);
void* ptr_ = reinterpret_cast<void*>(ptr);
asm volatile(
"{\n"
" st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
" st.global.v4.u32 [%5], {%6, %7, %8, %9};\n"
"}\n"
:
: "l"(ptr_), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
"r"(data[0].w), "l"(((uint8_t*)ptr_) + 16), "r"(data[1].x),
"r"(data[1].y), "r"(data[1].z), "r"(data[1].w));
}
} // namespace internal
} // namespace relayout_format
} // namespace cuda
} // namespace megdnn
This diff has been collapsed.
......@@ -176,60 +176,22 @@ __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat,
}
}
template <bool signedness>
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8(int s0, int s1,
int s2, int s3,
int s4, int s5,
int s6, int s7);
template <>
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<true>(
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
return transform_int8_to_int4x8(s0, s1, s2, s3, s4, s5, s6, s7);
}
template <>
MEGDNN_DEVICE __forceinline__ int transform_int8_to_bit4x8<false>(
int s0, int s1, int s2, int s3, int s4, int s5, int s6, int s7) {
return transform_int8_to_uint4x8(s0, s1, s2, s3, s4, s5, s6, s7);
}
template <bool signedness>
MEGDNN_DEVICE __forceinline__ void
transform_bit4x8_to_int8(int (&result)[8], const int& source);
template <>
MEGDNN_DEVICE __forceinline__ void
transform_bit4x8_to_int8<true>(int (&result)[8], const int& source){
transform_int4x8_to_int8(result, source);
}
template <>
MEGDNN_DEVICE __forceinline__ void
transform_bit4x8_to_int8<false>(int (&result)[8], const int& source){
transform_uint4x8_to_int8(result, source);
}
template <bool signedness, typename OutputConverter>
MEGDNN_DEVICE __forceinline__ int pack_output_func(
OutputConverter& output_converter, int (&s00)[8], int (&s01)[8],
int (&s10)[8], int (&s11)[8], float w00, float w01, float w10,
float w11) {
#define warp_perspective_transform(idx) \
static_cast<int>(output_converter(s00[idx] * w00 + \
s01[idx] * w01 + \
s10[idx] * w10 + \
s11[idx] * w11) \
#define warp_perspective_transform(idx) \
static_cast<int>(output_converter(s00[idx] * w00 + s01[idx] * w01 + \
s10[idx] * w10 + s11[idx] * w11) \
.as_storage())
return transform_int8_to_bit4x8<signedness>(
return transform_int8_to_b4x8<signedness>(
warp_perspective_transform(0), warp_perspective_transform(1),
warp_perspective_transform(2), warp_perspective_transform(3),
warp_perspective_transform(4), warp_perspective_transform(5),
warp_perspective_transform(6), warp_perspective_transform(7));
#undef warp_perspective_transform
#undef warp_perspective_transform
}
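// Worked example (hedged): with bilinear weights w00 = w01 = w10 = w11 = 0.25
// and the i-th unpacked samples being 2, 6, 4 and 4, the i-th output nibble is
// output_converter(0.25 * (2 + 6 + 4 + 4)) = output_converter(4.0), rounded and
// clamped to the 4-bit range before transform_int8_to_b4x8 re-packs all eight
// lanes into one 32-bit word.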
template <typename ctype, typename Getter, typename SrcVisitor,
......@@ -278,31 +240,31 @@ __global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat,
s[2] = __ldg(sptr_int4 + i_coor_10 + c1);
s[3] = __ldg(sptr_int4 + i_coor_11 + c1);
transform_bit4x8_to_int8<signedness>(s00, s[0].x);
transform_bit4x8_to_int8<signedness>(s01, s[1].x);
transform_bit4x8_to_int8<signedness>(s10, s[2].x);
transform_bit4x8_to_int8<signedness>(s11, s[3].x);
transform_b4x8_to_int8<signedness>(s00, s[0].x);
transform_b4x8_to_int8<signedness>(s01, s[1].x);
transform_b4x8_to_int8<signedness>(s10, s[2].x);
transform_b4x8_to_int8<signedness>(s11, s[3].x);
d.x = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, w00, w01, w10, w11);
transform_bit4x8_to_int8<signedness>(s00, s[0].y);
transform_bit4x8_to_int8<signedness>(s01, s[1].y);
transform_bit4x8_to_int8<signedness>(s10, s[2].y);
transform_bit4x8_to_int8<signedness>(s11, s[3].y);
transform_b4x8_to_int8<signedness>(s00, s[0].y);
transform_b4x8_to_int8<signedness>(s01, s[1].y);
transform_b4x8_to_int8<signedness>(s10, s[2].y);
transform_b4x8_to_int8<signedness>(s11, s[3].y);
d.y = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, w00, w01, w10, w11);
transform_bit4x8_to_int8<signedness>(s00, s[0].z);
transform_bit4x8_to_int8<signedness>(s01, s[1].z);
transform_bit4x8_to_int8<signedness>(s10, s[2].z);
transform_bit4x8_to_int8<signedness>(s11, s[3].z);
transform_b4x8_to_int8<signedness>(s00, s[0].z);
transform_b4x8_to_int8<signedness>(s01, s[1].z);
transform_b4x8_to_int8<signedness>(s10, s[2].z);
transform_b4x8_to_int8<signedness>(s11, s[3].z);
d.z = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, w00, w01, w10, w11);
transform_bit4x8_to_int8<signedness>(s00, s[0].w);
transform_bit4x8_to_int8<signedness>(s01, s[1].w);
transform_bit4x8_to_int8<signedness>(s10, s[2].w);
transform_bit4x8_to_int8<signedness>(s11, s[3].w);
transform_b4x8_to_int8<signedness>(s00, s[0].w);
transform_b4x8_to_int8<signedness>(s01, s[1].w);
transform_b4x8_to_int8<signedness>(s10, s[2].w);
transform_b4x8_to_int8<signedness>(s11, s[3].w);
d.w = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, w00, w01, w10, w11);
......@@ -403,15 +365,7 @@ __global__ void kern_const_border_nchw4(SrcVisitor src,
}
}
}
template <bool signedness>
MEGDNN_DEVICE __forceinline__ static void transform_bit4x8_to_int8(
int (&result)[8], const int& source) {
#pragma unroll
for (int i = 0; i < 8; i++) {
result[i] = unpack_integer_4bits<signedness>(
reinterpret_cast<unsigned const&>(source), (i << 2));
}
}
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_const_border_nchw64(SrcVisitor src,
......@@ -457,7 +411,7 @@ __global__ void kern_const_border_nchw64(SrcVisitor src,
bool flag00 = okh0 && okw0, flag01 = okh0 && okw1,
flag10 = okh1 && okw0, flag11 = okh1 && okw1;
int8_t bval_4 = bval.as_storage() & 0xF;
int bval_8 = transform_int8_to_bit4x8<signedness>(
int bval_8 = transform_int8_to_b4x8<signedness>(
bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4, bval_4);
int4 bval_int4;
bval_int4.x = bval_8;
......@@ -488,31 +442,31 @@ __global__ void kern_const_border_nchw64(SrcVisitor src,
s[3] = bval_int4;
}
transform_bit4x8_to_int8<signedness>(s00, s[0].x);
transform_bit4x8_to_int8<signedness>(s01, s[1].x);
transform_bit4x8_to_int8<signedness>(s10, s[2].x);
transform_bit4x8_to_int8<signedness>(s11, s[3].x);
transform_b4x8_to_int8<signedness>(s00, s[0].x);
transform_b4x8_to_int8<signedness>(s01, s[1].x);
transform_b4x8_to_int8<signedness>(s10, s[2].x);
transform_b4x8_to_int8<signedness>(s11, s[3].x);
d.x = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, w00, w01, w10, w11);
transform_bit4x8_to_int8<signedness>(s00, s[0].y);
transform_bit4x8_to_int8<signedness>(s01, s[1].y);
transform_bit4x8_to_int8<signedness>(s10, s[2].y);
transform_bit4x8_to_int8<signedness>(s11, s[3].y);
transform_b4x8_to_int8<signedness>(s00, s[0].y);
transform_b4x8_to_int8<signedness>(s01, s[1].y);
transform_b4x8_to_int8<signedness>(s10, s[2].y);
transform_b4x8_to_int8<signedness>(s11, s[3].y);
d.y = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, w00, w01, w10, w11);
transform_bit4x8_to_int8<signedness>(s00, s[0].z);
transform_bit4x8_to_int8<signedness>(s01, s[1].z);
transform_bit4x8_to_int8<signedness>(s10, s[2].z);
transform_bit4x8_to_int8<signedness>(s11, s[3].z);
transform_b4x8_to_int8<signedness>(s00, s[0].z);
transform_b4x8_to_int8<signedness>(s01, s[1].z);
transform_b4x8_to_int8<signedness>(s10, s[2].z);
transform_b4x8_to_int8<signedness>(s11, s[3].z);
d.z = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, w00, w01, w10, w11);
transform_bit4x8_to_int8<signedness>(s00, s[0].w);
transform_bit4x8_to_int8<signedness>(s01, s[1].w);
transform_bit4x8_to_int8<signedness>(s10, s[2].w);
transform_bit4x8_to_int8<signedness>(s11, s[3].w);
transform_b4x8_to_int8<signedness>(s00, s[0].w);
transform_b4x8_to_int8<signedness>(s01, s[1].w);
transform_b4x8_to_int8<signedness>(s10, s[2].w);
transform_b4x8_to_int8<signedness>(s11, s[3].w);
d.w = pack_output_func<signedness>(output_converter, s00, s01, s10,
s11, w00, w01, w10, w11);
......