提交 a0a5fcf1 编写于 作者: M Megvii Engine Team

feat(dnn): support tf32

GitOrigin-RevId: 9e5871f933744468b91b7ab5ac6159a4b7a67084
上级 f0088335
......@@ -88,7 +88,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
#if CUDART_VERSION >= 9010
auto io16_c32 = [&]() {
#if CUDART_VERSION >= 11000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#else
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
auto zero = handle->zero_device();
auto one = handle->one_device();
cublas_check(cublasGemmBatchedEx(
......@@ -104,7 +108,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
#if CUDART_VERSION >= 9000
auto io16_c16 = [&]() {
#if CUDART_VERSION >= 11000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#else
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
auto zero = handle->zero_device_h();
auto one = handle->one_device_h();
cublas_check(cublasHgemmBatched(
......
......@@ -124,7 +124,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const
batched_igemm();
} else if (desc.dt_compute == CUBLAS_COMPUTE_16F) {
batched_hgemm();
} else if (desc.dt_compute == CUBLAS_COMPUTE_32F) {
} else if (desc.dt_compute == CUBLAS_COMPUTE_32F_FAST_TF32) {
batched_sgemm();
} else {
megdnn_throw("compute_type must be int32/float16/float32");
......
......@@ -49,18 +49,26 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
// Single-precision GEMM path: issues one cublasSgemm call on `cublas_handle`
// using the float32 buffers of tensor_a / tensor_b / tensor_c.
//
// When built against CUDA 11.0 or newer (CUDART_VERSION >= 11000) the handle is
// switched to CUBLAS_TF32_TENSOR_OP_MATH for the duration of the call and put
// back to CUBLAS_DEFAULT_MATH afterwards; on older toolkits the math mode is
// left untouched.
auto sgemm = [&]() {
// Scalars for the GEMM: `one` is passed in the alpha position and `zero` in the
// beta position of cublasSgemm below.
// NOTE(review): the *_device() names suggest these are device-resident
// constants (i.e. the handle uses device pointer mode) -- confirm.
auto zero = handle->zero_device();
auto one = handle->one_device();
#if CUDART_VERSION >= 11000
// Opt in to TF32 tensor-op math for this call only (reset after the GEMM).
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#endif
// B is passed before A and the dimensions are given as (n, m, k), with
// transposeB / transposeA selecting the respective op flags and stride[0] of
// each layout used as the leading dimension.
// NOTE(review): presumably this operand swap is the usual way of driving
// column-major cuBLAS with row-major tensors (C^T = B^T * A^T) -- confirm
// against the tensor layout convention.
cublas_check(cublasSgemm(
cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0],
args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0], zero,
args.tensor_c.ptr<dt_float32>(), args.tensor_c.layout.stride[0]));
#if CUDART_VERSION >= 11000
// Restore the default math mode so later users of the shared handle are not
// affected by the TF32 setting.
// NOTE(review): if cublas_check throws or aborts on a failed cublasSgemm, this
// reset is skipped and the handle stays in TF32 mode -- confirm whether that
// matters for callers that recover from the error.
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
#endif
};
auto sgemm_ex = [&]() {
auto zero = handle->zero_device();
auto one = handle->one_device();
#if CUDART_VERSION >= 9000
#if CUDART_VERSION >= 11000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#elif CUDART_VERSION >= 9000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
auto sgemm_ex_err = cublasSgemmEx(
......@@ -78,7 +86,9 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
};
auto hgemm = [&]() {
#if CUDART_VERSION >= 9000
#if CUDART_VERSION >= 11000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
#elif CUDART_VERSION >= 9000
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
#endif
auto one_half = handle->one_device_h();
......
......@@ -28,7 +28,7 @@ static cublasComputeType_t to_cublas_compute_type(DType tp) {
case DTypeEnum::Float16:
return CUBLAS_COMPUTE_16F;
case DTypeEnum::Float32:
return CUBLAS_COMPUTE_32F;
return CUBLAS_COMPUTE_32F_FAST_TF32;
case DTypeEnum::Int32:
case DTypeEnum::QuantizedS32:
return CUBLAS_COMPUTE_32I;
......
......@@ -107,7 +107,7 @@ void MatrixMulForwardImpl::AlgoCuBlasLt::exec(const ExecArgs& args) const {
case CUBLAS_COMPUTE_16F:
hgemm();
break;
case CUBLAS_COMPUTE_32F:
case CUBLAS_COMPUTE_32F_FAST_TF32:
sgemm();
break;
case CUBLAS_COMPUTE_32I:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册