提交 83cf4ee6 编写于 作者: M Megvii Engine Team

refactor(dnn/rocm): remove some useless includes

GitOrigin-RevId: 3d2c315a368f7307a88ba37f0674a36072281578
上级 323a4642
...@@ -174,7 +174,7 @@ template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \ ...@@ -174,7 +174,7 @@ template void argsort::forward<dtype>(const dtype*, dtype*, int*, void*, \
ARGSORT_FOREACH_CTYPE(INST_FORWARD) ARGSORT_FOREACH_CTYPE(INST_FORWARD)
INST_CUB_SORT(uint32_t) INST_CUB_SORT(uint32_t)
// INST_CUB_SORT(uint64_t) INST_CUB_SORT(uint64_t)
#undef INST_CUB_SORT #undef INST_CUB_SORT
#undef INST_FORWARD #undef INST_FORWARD
} }
......
...@@ -40,6 +40,7 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace, ...@@ -40,6 +40,7 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace,
const int* iptr_src = NULL); const int* iptr_src = NULL);
//! iterate over all supported data types //! iterate over all supported data types
// device_radix_sort does not support dt_float16 dtype(half_float::half in rocm)
#define ARGSORT_FOREACH_CTYPE(cb) \ #define ARGSORT_FOREACH_CTYPE(cb) \
cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16)) cb(float) cb(int32_t) // DNN_INC_FLOAT16(cb(dt_float16))
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
#include "./argsort.h.hip" #include "./argsort.h.hip"
#include "./backward.h.hip" #include "./backward.h.hip"
// #include "src/rocm/utils.h"
using namespace megdnn; using namespace megdnn;
using namespace rocm; using namespace rocm;
using namespace argsort; using namespace argsort;
......
...@@ -11,13 +11,9 @@ ...@@ -11,13 +11,9 @@
#include "hcc_detail/hcc_defs_prologue.h" #include "hcc_detail/hcc_defs_prologue.h"
#include "./bitonic_sort.h.hip" #include "./bitonic_sort.h.hip"
// #include "src/cuda/query_blocksize.cuh" #include "megdnn/dtype.h"
// #include "megdnn/dtype.h"
// #if __CUDACC_VER_MAJOR__ < 9
// #pragma message "warp sync disabled due to insufficient cuda version"
#define __syncwarp __syncthreads #define __syncwarp __syncthreads
// #endif
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
...@@ -84,17 +80,17 @@ struct NumTrait<int32_t> { ...@@ -84,17 +80,17 @@ struct NumTrait<int32_t> {
static __device__ __forceinline__ int32_t min() { return INT_MIN; } static __device__ __forceinline__ int32_t min() { return INT_MIN; }
}; };
// #if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
// template <> template <>
// struct NumTrait<dt_float16> { struct NumTrait<dt_float16> {
// static __device__ __forceinline__ dt_float16 max() { static __device__ __forceinline__ dt_float16 max() {
// return std::numeric_limits<dt_float16>::max(); return std::numeric_limits<dt_float16>::max();
// } }
// static __device__ __forceinline__ dt_float16 min() { static __device__ __forceinline__ dt_float16 min() {
// return std::numeric_limits<dt_float16>::lowest(); return std::numeric_limits<dt_float16>::lowest();
// } }
// }; };
// #endif #endif
struct LessThan { struct LessThan {
template <typename Key, typename Value> template <typename Key, typename Value>
...@@ -310,7 +306,7 @@ namespace rocm { ...@@ -310,7 +306,7 @@ namespace rocm {
INST(float, int); INST(float, int);
INST(int32_t, int); INST(int32_t, int);
// DNN_INC_FLOAT16(INST(dt_float16, int)); DNN_INC_FLOAT16(INST(dt_float16, int));
#undef INST #undef INST
} // namespace megdnn } // namespace megdnn
......
...@@ -18,13 +18,7 @@ ...@@ -18,13 +18,7 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#if __CUDACC_VER_MAJOR__ < 9
#pragma message "topk is a little slower on cuda earlier than 9.0"
// on cuda 9.0 and later, due to thread-divergent branches we should use
// __syncwarp; and I am too lazy to implement a correct legacy version, so just
// use __syncthreads instead for older cuda
#define __syncwarp __syncthreads #define __syncwarp __syncthreads
#endif
using namespace megdnn; using namespace megdnn;
using namespace rocm; using namespace rocm;
...@@ -256,12 +250,12 @@ static __global__ void update_prefix_and_k(const uint32_t* bucket_cnt, ...@@ -256,12 +250,12 @@ static __global__ void update_prefix_and_k(const uint32_t* bucket_cnt,
} }
} }
//if ((cumsum_bucket_cnt[NR_BUCKET] < kv) | if ((cumsum_bucket_cnt[NR_BUCKET] < kv) |
// (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) { (cumsum_bucket_cnt[i] != cumsum_bucket_cnt[i - 1] + sum)) {
// // impossible // impossible
// int* bad = 0x0; int* bad = 0x0;
// *bad = 23; *bad = 23;
//} }
} }
static uint32_t get_grid_dim_x(uint32_t length) { static uint32_t get_grid_dim_x(uint32_t length) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册