提交 3b452d8c 编写于 作者: M Megvii Engine Team

feat(mgb): cuda conv support nhwc format and fp16 dtype

GitOrigin-RevId: b8ddcd108a4370a0b093c51bd90ebde0e007cb24
上级 10bcf757
......@@ -69,6 +69,12 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
return false;
}
if (args.src_layout->dtype.enumv() == DTypeEnum::Float16 &&
args.dst_layout->dtype.enumv() == DTypeEnum::Float16 &&
param.format == param::ConvBias::Format::NHWC) {
return false;
}
//! FIXME: conv kernel of cudnn for NCHW4_NCHW tensor format causes illegal
//! memory access errors, so we have to disable this kernel here.
if (param.format == param::ConvBias::Format::NCHW4_NCHW ||
......
......@@ -151,14 +151,14 @@ bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
if (args.handle->is_tegra_k1())
return false;
// TODO: We only support NCHW format now. It seems cuDNN provides support
// for NHWC as well.
if (args.filter_meta.format == param::Convolution::Format::NCHW4) {
if (args.filter_meta.format == param::Convolution::Format::NCHW4 ||
args.filter_meta.format == param::Convolution::Format::NCHW32) {
if (args.dst_layout->dtype.enumv() != DTypeEnum::Int8 &&
args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS8) {
return false;
}
} else if (args.filter_meta.format != param::Convolution::Format::NCHW) {
} else if (args.filter_meta.format != param::Convolution::Format::NCHW &&
args.filter_meta.format != param::Convolution::Format::NHWC) {
return false;
}
auto& fm = args.filter_meta;
......
......@@ -216,6 +216,41 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_QS8) {
}
}
// Exercises fp16 conv-bias forward on CUDA with NHWC layout:
// a dense 1x1 convolution (default compute mode, then FLOAT32 compute
// mode) followed by a grouped variant. Requires compute capability 6.1.
TEST_F(CUDA, CONV_BIAS_FORWARD_FLOAT16) {
    require_compute_capability(6, 1);
    Checker<ConvBiasForward> checker(handle_cuda());
    ConvBias::Param param;
    param.format = ConvBias::Param::Format::NHWC;
    param.nonlineMode = ConvBias::Param::NonlineMode::IDENTITY;
    // loose tolerance for half precision; all five operands
    // (src, filter, bias, z, dst) are fp16
    checker.set_epsilon(2e-2);
    for (int i = 0; i < 5; ++i) {
        checker.set_dtype(i, dtype::Float16());
    }
    {
        TensorShape src{20, 224, 224, 4};
        TensorShape filter{24, 1, 1, 4};
        TensorShape bias{1, 1, 1, 24};
        // default compute mode
        checker.set_param(param).execs({src, filter, bias, {}, {}});
        // same case, accumulating in fp32
        param.compute_mode = ConvBias::Param::ComputeMode::FLOAT32;
        checker.set_param(param).execs({src, filter, bias, {}, {}});
    }
    {
        // grouped convolution; note compute_mode stays FLOAT32 from above.
        // Filter presumably laid out as {groups, oc/g, fh, fw, ic/g} —
        // TODO confirm against GROUP-sparse filter convention.
        param.sparse = ConvBias::Param::Sparse::GROUP;
        TensorShape src{20, 224, 224, 16};
        TensorShape filter{4, 4, 1, 1, 4};
        TensorShape bias{1, 1, 1, 16};
        checker.set_param(param).execs({src, filter, bias, {}, {}});
    }
}
TEST_F(CUDA, CONV_BIAS_NCHW_QS8) {
//! not support NonlineMode::SIGMOID and NonlineMode::H_SWISH
require_compute_capability(6, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册