From 5158fa4f3736be7b22dbcabfa09d5244478d25a1 Mon Sep 17 00:00:00 2001 From: umiswing Date: Tue, 1 Nov 2022 00:03:52 +0800 Subject: [PATCH] =?UTF-8?q?summer-ospp=202022:=20=E9=A3=9E=E6=A1=A8PaddleP?= =?UTF-8?q?addle=20Sparse=20Conv=E5=BC=80=E5=8F=91=E5=92=8C=E4=BC=98?= =?UTF-8?q?=E5=8C=96:=20gather-gemm-scatter=20fuse=20(#46679)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmake/external/cutlass.cmake | 43 ++ cmake/third_party.cmake | 10 + paddle/phi/kernels/sparse/gpu/conv_kernel.cu | 201 +++++-- .../kernels/sparse/gpu/gather_gemm_scatter.cu | 188 ++++++ .../kernels/sparse/gpu/gather_gemm_scatter.h | 555 ++++++++++++++++++ 5 files changed, 941 insertions(+), 56 deletions(-) create mode 100644 cmake/external/cutlass.cmake create mode 100644 paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu create mode 100644 paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake new file mode 100644 index 00000000000..a80a729a139 --- /dev/null +++ b/cmake/external/cutlass.cmake @@ -0,0 +1,43 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +include(ExternalProject) + +set(CUTLASS_PREFIX_DIR ${THIRD_PARTY_PATH}/cutlass) + +set(CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git) +set(CUTLASS_TAG v2.9.1) + +include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/") +include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/include/") +include_directories( + "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/util/include/") + +add_definitions("-DPADDLE_WITH_CUTLASS") + +ExternalProject_Add( + extern_cutlass + ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE} + GIT_REPOSITORY ${CUTLASS_REPOSITORY} + GIT_TAG "${CUTLASS_TAG}" + PREFIX ${CUTLASS_PREFIX_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "") + +add_library(cutlass INTERFACE) + +add_dependencies(cutlass extern_cutlass) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 06ca0d16df0..4475f5b14d2 100755 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -505,4 +505,14 @@ if(WITH_CUSPARSELT) list(APPEND third_party_deps extern_cusparselt) endif() +if(WITH_GPU + AND NOT WITH_ARM + AND NOT WIN32 + AND NOT APPLE) + if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0) + include(external/cutlass) # download, build, install cusparselt + list(APPEND third_party_deps extern_cutlass) + endif() +endif() + add_custom_target(third_party ALL DEPENDS ${third_party_deps}) diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu index 282033e62e3..e5e3cd0f5c1 100644 --- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu @@ -22,6 +22,9 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/scatter.cu.h" #include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" #include "paddle/phi/kernels/sparse/gpu/conv.cu.h" +#ifdef PADDLE_WITH_CUTLASS +#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h" +#endif #include "glog/logging.h" @@ -120,29 +123,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter); } - // 2. gather - phi::DenseTensor in_features = - phi::Empty(dev_ctx, {rulebook_len, in_channels}); - phi::DenseTensor out_features = - phi::Empty(dev_ctx, {rulebook_len, out_channels}); - T* in_features_ptr = in_features.data(); - T* out_features_ptr = out_features.data(); - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, &out_features, static_cast(0.0f)); - - Gather(dev_ctx, - x.values().data(), - rulebook_ptr, - rulebook_len, - in_channels, - in_features_ptr); - - // 3. call gemm for every werght - auto blas = phi::funcs::GetBlas(dev_ctx); - auto* out_values = out->mutable_values(); - T* out_values_ptr = out_values->data(); - set_zero(dev_ctx, out_values, static_cast(0.0f)); - if (subm) { auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); @@ -162,43 +142,152 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, out_index_ptr, unique_value_ptr); } +#ifdef PADDLE_WITH_CUTLASS + bool cutlass = true; + if (dev_ctx.GetComputeCapability() < 80) cutlass = false; + if (in_channels % 4 != 0 || out_channels % 4 != 0) { + if (std::is_same::value) cutlass = false; + if (std::is_same::value) cutlass = false; + } + if (!std::is_same::value) cutlass = false; + if (cutlass) { + auto* out_values = out->mutable_non_zero_elements(); + T* out_values_ptr = out_values->data(); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, out_values, static_cast(0.0f)); + + const T* kernel_ptr = kernel.data(); + for (int i = 0; i < kernel_size; i++) { + if (h_counter_ptr[i] <= 0) { + continue; + } - const T* kernel_ptr = kernel.data(); - for (int i = 0; i < kernel_size; i++) { - if (h_counter_ptr[i] <= 0) { - continue; + const int M = h_counter_ptr[i]; + const int K = in_channels; + const int N = out_channels; + const T* tmp_kernel_ptr = kernel_ptr + i * K * N; + const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i]; + const IntT* scatter_indices = + rulebook_ptr + rulebook_len + h_offsets_ptr[i]; + + if constexpr (std::is_same::value && + std::is_same::value) { + fp16_gather_gemm_scatter gather_gemm_scatter = + getBestFp16Kernel(M, N, K); + gather_gemm_scatter( + dev_ctx, + reinterpret_cast( + x.non_zero_elements().data()), + reinterpret_cast(tmp_kernel_ptr), + reinterpret_cast(out_values_ptr), + reinterpret_cast(out_values_ptr), + M, + N, + K, + static_cast(gather_indices), + static_cast(scatter_indices), + static_cast(1), + static_cast(1)); + } + if constexpr (std::is_same::value && + std::is_same::value) { + fp32_gather_gemm_scatter gather_gemm_scatter = + getBestFp32Kernel(M, N, K); + gather_gemm_scatter(dev_ctx, + x.non_zero_elements().data(), + tmp_kernel_ptr, + out_values_ptr, + out_values_ptr, + M, + N, + K, + gather_indices, + scatter_indices, + static_cast(1), + static_cast(1)); + } + if constexpr (std::is_same::value && + std::is_same::value) { + fp64_gather_gemm_scatter gather_gemm_scatter = + getBestFp64Kernel(M, N, K); + gather_gemm_scatter(dev_ctx, + x.non_zero_elements().data(), + tmp_kernel_ptr, + out_values_ptr, + out_values_ptr, + M, + N, + K, + gather_indices, + scatter_indices, + static_cast(1), + static_cast(1)); + } } + } else { +#endif + // 2. gather + phi::DenseTensor in_features = + phi::Empty(dev_ctx, {rulebook_len, in_channels}); + phi::DenseTensor out_features = + phi::Empty(dev_ctx, {rulebook_len, out_channels}); + T* in_features_ptr = in_features.data(); + T* out_features_ptr = out_features.data(); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, &out_features, static_cast(0.0f)); - // call gemm: (n, in_channels) * (in_channels, out_channels) - const int M = h_counter_ptr[i]; - const int K = in_channels; - const int N = out_channels; - T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels; - const T* tmp_kernel_ptr = kernel_ptr + i * K * N; - T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels; - - blas.GEMM(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - static_cast(1), - tmp_in_ptr, - tmp_kernel_ptr, - static_cast(0), - tmp_out_ptr); - } + Gather(dev_ctx, + x.values().data(), + rulebook_ptr, + rulebook_len, + in_channels, + in_features_ptr); + + // 3. call gemm for every werght + auto blas = phi::funcs::GetBlas(dev_ctx); + auto* out_values = out->mutable_values(); + T* out_values_ptr = out_values->data(); + set_zero(dev_ctx, out_values, static_cast(0.0f)); - // 4. scatter - phi::funcs::sparse::ScatterV2(dev_ctx, - out_features_ptr, - out_index.data(), - unique_value.data(), - out->nnz(), - kernel_size, - out_channels, - 1, - out_values_ptr); + const T* kernel_ptr = kernel.data(); + for (int i = 0; i < kernel_size; i++) { + if (h_counter_ptr[i] <= 0) { + continue; + } + + // call gemm: (n, in_channels) * (in_channels, out_channels) + const int M = h_counter_ptr[i]; + const int K = in_channels; + const int N = out_channels; + T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels; + const T* tmp_kernel_ptr = kernel_ptr + i * K * N; + T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels; + + blas.GEMM(CblasNoTrans, + CblasNoTrans, + M, + N, + K, + static_cast(1), + tmp_in_ptr, + tmp_kernel_ptr, + static_cast(0), + tmp_out_ptr); + } + + // 4. scatter + phi::funcs::sparse::ScatterV2(dev_ctx, + out_features_ptr, + out_index.data(), + unique_value.data(), + out->nnz(), + kernel_size, + out_channels, + 1, + out_values_ptr); +#ifdef PADDLE_WITH_CUTLASS + } +#endif } /** diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu new file mode 100644 index 00000000000..48727c8f851 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu @@ -0,0 +1,188 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_CUTLASS +#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h" +namespace phi { +namespace sparse { +fp16_gather_gemm_scatter getBestFp16Kernel(const int M, + const int N, + const int K) { + if (K == 4 && N == 16) { + return launchKernel; + } + if (K == 16 && N == 16) { + return launchKernel; + } + if (K == 16 && N == 32) { + return launchKernel; + } + if (K == 32 && N == 32) { + return launchKernel; + } + if (K == 32 && N == 64) { + return launchKernel; + } + if (K == 64 && N == 64) { + if (M > 100000) + launchKernel< + cutlass::half_t, + cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8::Gemm>; + if (M > 20000) + launchKernel< + cutlass::half_t, + cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8::Gemm>; + if (M > 15000) + return launchKernel< + cutlass::half_t, + cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8::Gemm>; + return launchKernel; + } + if (K == 128) { + if (M >= 5000) + return launchKernel< + cutlass::half_t, + cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>; + return launchKernel; + } + if (N == 128) { + return launchKernel; + } + return launchKernel; +} +fp32_gather_gemm_scatter getBestFp32Kernel(const int M, + const int N, + const int K) { + if (K == 4 && N == 16) { + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 16 && N == 16) { + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 16 && N == 32) { + if (M >= 10000) + return launchKernel< + float, + cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 32 && N == 32) { + if (M >= 10000) + return launchKernel< + float, + cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 32 && N == 64) { + if (M >= 10000) + return launchKernel< + float, + cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 64 && N == 64) { + if (M >= 15000) + return launchKernel< + float, + cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; + } + if (K == 128) { + if (M >= 100000) + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>; + if (M >= 5000) + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>; + } + if (N == 128) { + if (M >= 100000) + return launchKernel< + float, + cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>; + if (M >= 5000) + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>; + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4::Gemm>; + } + return launchKernel< + float, + cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; +} +fp64_gather_gemm_scatter getBestFp64Kernel(const int M, + const int N, + const int K) { + if (K == 4 && N == 16) { + return launchKernel; + } + if (K == 16 && N == 16) { + if (M >= 10000) + return launchKernel; + return launchKernel; + } + if (K == 16 && N == 32) { + return launchKernel; + } + if (K == 32 && N == 32) { + return launchKernel; + } + if (K == 32 && N == 64) { + return launchKernel; + } + if (K == 64 && N == 64) { + return launchKernel; + } + return launchKernel; +} + +} // namespace sparse +} // namespace phi +#endif diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h new file mode 100644 index 00000000000..462cd710340 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h @@ -0,0 +1,555 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#ifdef PADDLE_WITH_CUTLASS +#include "cutlass/arch/mma.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/util/device_memory.h" +#include "examples/common/helper.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +namespace phi { +namespace sparse { +typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx, + const cutlass::half_t* const a, + const cutlass::half_t* const b, + const cutlass::half_t* const c, + cutlass::half_t* const d, + const int m, + const int n, + const int k, + const int32_t* a_indices, + const int32_t* c_d_indices, + cutlass::half_t const alpha, + cutlass::half_t const beta); +typedef void (*fp32_gather_gemm_scatter)(const GPUContext& dev_ctx, + const float* const a, + const float* const b, + const float* const c, + float* const d, + const int m, + const int n, + const int k, + const int32_t* a_indices, + const int32_t* c_d_indices, + float const alpha, + float const beta); +typedef void (*fp64_gather_gemm_scatter)(const GPUContext& dev_ctx, + const double* const a, + const double* const b, + const double* const c, + double* const d, + const int m, + const int n, + const int k, + const int32_t* a_indices, + const int32_t* c_d_indices, + double const alpha, + double const beta); +fp16_gather_gemm_scatter getBestFp16Kernel(const int M, + const int K, + const int N); +fp32_gather_gemm_scatter getBestFp32Kernel(const int M, + const int K, + const int N); +fp64_gather_gemm_scatter getBestFp64Kernel(const int M, + const int K, + const int N); +template +void launchKernel(const GPUContext& dev_ctx, + const T* const a, + const T* const b, + const T* const c, + T* const d, + const int m, + const int n, + const int k, + const int32_t* a_indices, + const int32_t* c_d_indices, + T const alpha, + T const beta) { + cutlass::gemm::GemmCoord problem_size_real({m, n, k}); + int split_k_slices = 1; + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size_real, + split_k_slices, + {alpha, beta}, + a, + b, + c, + d, + cutlass::layout::RowMajor().capacity(problem_size_real.mk()), + cutlass::layout::RowMajor().capacity(problem_size_real.kn()), + cutlass::layout::RowMajor().capacity(problem_size_real.mn()), + cutlass::layout::RowMajor().capacity(problem_size_real.mn()), + problem_size_real.k(), + problem_size_real.n(), + problem_size_real.n(), + problem_size_real.n(), + a_indices, + nullptr, + c_d_indices}; + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + Gemm gemm_op; + cutlass::Status status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + gemm_op(dev_ctx.stream()); +} +struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_h1688gemm_64x128_32x2_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 4, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 4, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 5, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread:: + LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8 { + using Gemm = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + cutlass::half_t, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread:: + LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, + 8, + 8, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 10, + 4, + 4, + cutlass::arch::OpMultiplyAddFastF16, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 3, + 4, + 4, + cutlass::arch::OpMultiplyAddFastF16, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4, + 4, + 4, + cutlass::arch::OpMultiplyAddFastF16, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 16>, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 3, + 4, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 6, + 4, + 4, + cutlass::arch::OpMultiplyAddFastF16, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4 { + using Gemm = cutlass::gemm::device::GemmUniversal< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 3, + 4, + 4, + cutlass::arch::OpMultiplyAddFastF32, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_d884gemm_16x32_16x5_nn_align1 { + using Gemm = cutlass::gemm::device::GemmUniversal< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::RowMajor, + double, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<16, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 5, + 1, + 1, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 { + using Gemm = cutlass::gemm::device::GemmUniversal< + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::RowMajor, + double, + cutlass::layout::RowMajor, + double, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 16, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 5, + 1, + 1, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + true, + false, + true>; +}; +} // namespace sparse +} // namespace phi +#endif -- GitLab