提交 6ea8e019 编写于 作者: M Megvii Engine Team

fix(dnn): correctly using MEGDNN_DISABLE_FLOAT16 directives

GitOrigin-RevId: c6b124f195c9fc3a830bb058797d7d5619aad72d
上级 45a9977d
...@@ -15,6 +15,7 @@ using conv_fun = std::function<void( ...@@ -15,6 +15,7 @@ using conv_fun = std::function<void(
const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param,
const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids)>; const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids)>;
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_nchw88) MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_nchw88)
MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_nchw88_stride1)
namespace { namespace {
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
......
...@@ -14,7 +14,7 @@ struct RoundingConverter<float> { ...@@ -14,7 +14,7 @@ struct RoundingConverter<float> {
} }
}; };
#ifndef MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
template <> template <>
struct RoundingConverter<half_float::half> { struct RoundingConverter<half_float::half> {
...@@ -32,7 +32,7 @@ struct RoundingConverter<half_bfloat16::bfloat16> { ...@@ -32,7 +32,7 @@ struct RoundingConverter<half_bfloat16::bfloat16> {
} }
}; };
#endif // #ifdef MEGDNN_DISABLE_FLOAT16 #endif // #if !MEGDNN_DISABLE_FLOAT16
template <> template <>
struct RoundingConverter<int8_t> { struct RoundingConverter<int8_t> {
......
...@@ -295,7 +295,7 @@ void WarpPerspectiveForwardImpl::exec( ...@@ -295,7 +295,7 @@ void WarpPerspectiveForwardImpl::exec(
m_error_tracker, stream); m_error_tracker, stream);
} else if (DNN_FLOAT16_SELECT( } else if (DNN_FLOAT16_SELECT(
src.layout.dtype == dtype::Float16(), false)) { src.layout.dtype == dtype::Float16(), false)) {
#ifndef MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
warp_perspective::forward_proxy( warp_perspective::forward_proxy(
is_nhwc, src.ptr<dt_float16>(), mat.ptr<dt_float32>(), is_nhwc, src.ptr<dt_float16>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr, mat_idx.raw_ptr() ? mat_idx.ptr<int>() : nullptr,
...@@ -563,7 +563,7 @@ void WarpPerspectiveForwardImpl::exec( ...@@ -563,7 +563,7 @@ void WarpPerspectiveForwardImpl::exec(
m_error_tracker, stream); m_error_tracker, stream);
} else if (DNN_FLOAT16_SELECT( } else if (DNN_FLOAT16_SELECT(
src.layout.dtype == dtype::Float16(), false)) { src.layout.dtype == dtype::Float16(), false)) {
#ifndef MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
SmallVector<size_t> workspace_sizes{sizeof(dt_float16*) * srcs.size()}; SmallVector<size_t> workspace_sizes{sizeof(dt_float16*) * srcs.size()};
WorkspaceBundle workspace_cpu(nullptr, workspace_sizes); WorkspaceBundle workspace_cpu(nullptr, workspace_sizes);
auto total_workspace_size = workspace_cpu.total_size_in_bytes(); auto total_workspace_size = workspace_cpu.total_size_in_bytes();
......
...@@ -1924,7 +1924,7 @@ void forward_proxy_nchw64( ...@@ -1924,7 +1924,7 @@ void forward_proxy_nchw64(
cudaStream_t); cudaStream_t);
INST(float) INST(float)
INST(uint8_t) INST(uint8_t)
#ifndef MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
INST(dt_float16) INST(dt_float16)
#endif #endif
INST(int8_t) INST(int8_t)
...@@ -1936,7 +1936,7 @@ INST(int8_t) ...@@ -1936,7 +1936,7 @@ INST(int8_t)
int, int, int, ctype, BorderMode, megcore::AsyncErrorInfo*, void*, \ int, int, int, ctype, BorderMode, megcore::AsyncErrorInfo*, void*, \
cudaStream_t); cudaStream_t);
INST(float) INST(float)
#ifndef MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
INST(dt_float16) INST(dt_float16)
#endif #endif
#undef INST #undef INST
......
...@@ -73,7 +73,7 @@ struct powci_general_even { ...@@ -73,7 +73,7 @@ struct powci_general_even {
template <size_t size> template <size_t size>
struct float_itype; struct float_itype;
#ifndef MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
template <> template <>
struct float_itype<2> { struct float_itype<2> {
using type = uint16_t; using type = uint16_t;
......
...@@ -84,7 +84,7 @@ ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors( ...@@ -84,7 +84,7 @@ ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors(
#define INST(_dtype) template struct ResizeImpl::KernParam<_dtype>; #define INST(_dtype) template struct ResizeImpl::KernParam<_dtype>;
INST(dt_float32); INST(dt_float32);
#ifndef MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
INST(dt_float16); INST(dt_float16);
#endif #endif
INST(dt_int8); INST(dt_int8);
......
...@@ -15,7 +15,7 @@ MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) { ...@@ -15,7 +15,7 @@ MegRay::DType mgb::opr::get_megray_dtype(megdnn::DType dtype) {
return MegRay::DType::MEGRAY_INT32; return MegRay::DType::MEGRAY_INT32;
case DTypeEnum::Float32: case DTypeEnum::Float32:
return MegRay::DType::MEGRAY_FLOAT32; return MegRay::DType::MEGRAY_FLOAT32;
#ifndef MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
case DTypeEnum::Float16: case DTypeEnum::Float16:
return MegRay::DType::MEGRAY_FLOAT16; return MegRay::DType::MEGRAY_FLOAT16;
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册