/** * \file dnn/src/common/rounding_converter.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include "megdnn/dtype.h" #if MEGDNN_CC_HOST && !defined(__host__) #define MEGDNN_HOST_DEVICE_SELF_DEFINE #define __host__ #define __device__ #if __GNUC__ || __has_attribute(always_inline) #define __forceinline__ inline __attribute__((always_inline)) #else #define __forceinline__ inline #endif #endif namespace megdnn { namespace rounding { template struct RoundingConverter; template <> struct RoundingConverter { __host__ __device__ __forceinline__ float operator()(float x) const { return x; } }; #ifndef MEGDNN_DISABLE_FLOAT16 template <> struct RoundingConverter { __host__ __device__ __forceinline__ half_float::half operator()( float x) const { return static_cast(x); } }; template <> struct RoundingConverter { __host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()( float x) const { return static_cast(x); } }; #endif // #ifdef MEGDNN_DISABLE_FLOAT16 template <> struct RoundingConverter { __host__ __device__ __forceinline__ int8_t operator()(float x) const { #if MEGDNN_CC_HOST using std::round; #endif return static_cast(round(x)); } }; template <> struct RoundingConverter { __host__ __device__ __forceinline__ uint8_t operator()(float x) const { #if MEGDNN_CC_HOST using std::round; using std::max; using std::min; #endif x = min(255.0f, max(0.0f, x)); //! FIXME!!! check other places return static_cast(round(x)); } }; template <> struct RoundingConverter { __host__ __device__ __forceinline__ dt_qint4 operator()(float x) const { #if MEGDNN_CC_HOST using std::round; #endif return static_cast(round(x)); } }; template <> struct RoundingConverter { __host__ __device__ __forceinline__ dt_quint4 operator()(float x) const { #if MEGDNN_CC_HOST using std::round; #endif return static_cast(round(x)); } }; } // namespace rounding } // namespace megdnn /* vim: set ft=cpp: */