From 684c07d7be914b8dfff6e955e23730795548a0b1 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 17 May 2021 19:32:31 +0800
Subject: [PATCH] fix(api_cache): fix serialization for conv_desc

GitOrigin-RevId: 95dbc9c685cced46dd910997bd585363c392ccbd
---
 dnn/src/common/api_cache.h | 64 ++++++++++++++++-------------
 dnn/src/cuda/api_cache.h   | 83 ++++++++++++++++----------------------
 dnn/src/cuda/handle.cpp    | 11 ++---
 3 files changed, 77 insertions(+), 81 deletions(-)

diff --git a/dnn/src/common/api_cache.h b/dnn/src/common/api_cache.h
index 9009f5e1a..5e50ece31 100644
--- a/dnn/src/common/api_cache.h
+++ b/dnn/src/common/api_cache.h
@@ -131,12 +131,18 @@ public:
     T read_plain() {
         static_assert(std::is_trivially_copyable<T>::value, "invalid type");
         T ret;
-        memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
+        std::memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
         m_cursor += sizeof(T);
         return ret;
     }
     template <typename T>
-    void write_plain(T value) {
+    void read_plain(T* dest) {
+        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
+        std::memcpy(dest, m_buffer.data() + m_cursor, sizeof(T));
+        m_cursor += sizeof(T);
+    }
+    template <typename T>
+    void write_plain(const T& value) {
         static_assert(std::is_trivially_copyable<T>::value,
                       "type should be trivially copyable");
         m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
@@ -144,7 +150,7 @@ public:
     std::string take() { return std::move(m_buffer); }
     void reset(std::string new_buf) {
         m_cursor = 0;
-        m_buffer = new_buf;
+        m_buffer = std::move(new_buf);
     }
 };
 
@@ -153,7 +159,7 @@ struct Empty {};
 // in: seq[1, 2, ..., m]
 // out: seq[N+1, N+2, ... N+m]
 template <size_t N, size_t... Seq>
-static std::index_sequence<N + Seq...> inc_index_sequence(
+inline std::index_sequence<N + Seq...> inc_index_sequence(
         std::index_sequence<Seq...>) {
     return {};
 }
@@ -172,7 +178,7 @@ private:
 
     // deconstruct tuple and call functor
     template <typename TFunctor, size_t... Indices>
-    auto call_helper(TFunctor functor, std::index_sequence<Indices...>) {
+    auto call_helper(TFunctor&& functor, std::index_sequence<Indices...>) {
         return functor(std::get<Indices>(m_storage).value...);
     }
 
@@ -203,7 +209,7 @@ private:
     template <size_t Index, typename TArg, typename... TArgs>
     void set_values_helper(std::index_sequence<Index>, TArg&& arg,
                            TArgs&&... args) {
-        std::get<Index>(m_storage).value = arg;
+        std::get<Index>(m_storage).value = std::forward<TArg>(arg);
         set_values_helper(std::index_sequence<Index + 1>(),
                           std::forward<TArgs>(args)...);
     }
@@ -253,7 +259,7 @@ public:
     }
 
     Empty deserialize(StringSerializer& ser, Empty) {
-        value = ser.read_plain<T>();
+        ser.read_plain(&value);
         return Empty{};
     }
 };
@@ -285,7 +291,8 @@ private:
 
     template <size_t... Indices>
     static auto declbundle_helper(std::index_sequence<Indices...>)
-            -> ParamBundle<decltype(std::get<Indices>(declargs()))...> {
+            -> ParamBundle<std::remove_reference_t<decltype(
+                    std::get<Indices>(declargs()))>...> {
         return {};
     }
 
@@ -312,9 +319,11 @@ public:
     // declare new input
     template <typename TNewInput>
     auto input() {
-        using TNewInputs = decltype(
-                std::tuple_cat(std::declval<TInputs>(),
-                               std::make_tuple(std::declval<TNewInput>())));
+        static_assert(std::tuple_size<TOutputs>::value == 0,
+                      "input arg cannot be declared after output");
+        using TNewInputs =
+                decltype(std::tuple_cat(std::declval<TInputs>(),
+                                        std::declval<std::tuple<TNewInput>>()));
         return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
     }
     // declare new output
@@ -322,31 +331,29 @@ public:
     auto output() {
         using TNewOutputs = decltype(
                 std::tuple_cat(std::declval<TOutputs>(),
-                               std::make_tuple(std::declval<TNewOutput>())));
+                               std::declval<std::tuple<TNewOutput>>()));
         return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
     }
     // summary
     template <typename TFunctor>
-    function_t build(TFunctor func) {
+    function_t build(TFunctor&& func) {
+        constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
+        constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
         auto cache = std::make_shared<FunctionCache<std::string(bundle_t)>>();
         // bundle -> ser(in args)
         cache->key_mapper = [](bundle_t bundle) {
             StringSerializer ser;
-            bundle.template serialize_params<0,
-                                             std::tuple_size<TInputs>::value>(
-                    ser);
+            bundle.template serialize_params<0, n_inputs>(ser);
             return ser.take();
         };
         // bundle -> ser(out args)
-        cache->value_mapper = [=](bundle_t bundle) {
+        cache->value_mapper = [func](bundle_t bundle) {
             StringSerializer ser;
             TRet ret;
             ret.value = bundle.call_by(func);
             ret.serialize(ser, Empty{});
-            bundle.template serialize_params<
-                    std::tuple_size<TInputs>::value,
-                    std::tuple_size<TInputs>::value +
-                            std::tuple_size<TOutputs>::value>(ser);
+            bundle.template serialize_params<n_inputs, n_inputs + n_outputs>(
+                    ser);
             return ser.take();
         };
         return [=](auto&&... args) mutable {
@@ -361,8 +368,6 @@ public:
                     std::forward<decltype(args)>(args)...);
             ser.reset((*cache)(bundle));
             ret.deserialize(ser, Empty{});
-            constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
-            constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
             bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
                     ser);
             return ret.value;
@@ -394,7 +399,8 @@ public:
         return *value;
     }
     T deserialize(StringSerializer& ser, Empty) {
-        return *value = ser.read_plain<T>();
+        ser.read_plain(value);
+        return *value;
     }
 };
 
@@ -402,16 +408,20 @@ public:
 template <typename TSize, typename TItem>
 class ArrayParam {
 public:
-    TItem* value;
+    decltype(std::declval<TItem>().value)* value;
     Empty serialize(StringSerializer& ser, TSize size) {
+        TItem param;
         for (TSize i = 0; i < size; ++i) {
-            ser.write_plain(value[i]);
+            param.value = value[i];
+            param.serialize(ser, Empty{});
         }
         return Empty{};
     }
     Empty deserialize(StringSerializer& ser, TSize size) {
+        TItem param;
         for (TSize i = 0; i < size; ++i) {
-            value[i] = ser.read_plain<TItem>();
+            param.deserialize(ser, Empty{});
+            value[i] = param.value;
        }
         return Empty{};
     }
diff --git a/dnn/src/cuda/api_cache.h b/dnn/src/cuda/api_cache.h
index f58f6d75b..2299a18a2 100644
--- a/dnn/src/cuda/api_cache.h
+++ b/dnn/src/cuda/api_cache.h
@@ -20,14 +20,16 @@ class CudnnConvDescParam {
 public:
     cudnnConvolutionDescriptor_t value;
     Empty serialize(StringSerializer& ser, Empty) {
-        int nbDims = MEGDNN_MAX_NDIM;
-        int padA[MEGDNN_MAX_NDIM];
-        int strideA[MEGDNN_MAX_NDIM];
-        int dilationA[MEGDNN_MAX_NDIM];
+        constexpr int maxNbDims = CUDNN_DIM_MAX - 2;
+        int nbDims = maxNbDims;
+        int padA[maxNbDims];
+        int strideA[maxNbDims];
+        int dilationA[maxNbDims];
         cudnnConvolutionMode_t mode;
         cudnnDataType_t computeType;
-        cudnnGetConvolutionNdDescriptor(value, nbDims, &nbDims, padA, strideA,
-                                        dilationA, &mode, &computeType);
+        cudnnGetConvolutionNdDescriptor(value, maxNbDims, &nbDims, padA,
+                                        strideA, dilationA, &mode,
+                                        &computeType);
         ser.write_plain(nbDims);
         for (int i = 0; i < nbDims; ++i) {
             ser.write_plain(padA[i]);
@@ -38,23 +40,8 @@ public:
         ser.write_plain(computeType);
         return Empty{};
     }
-    Empty deserialize(StringSerializer& ser, Empty) {
-        int ndim = ser.read_plain<int>();
-        int padA[MEGDNN_MAX_NDIM];
-        int strideA[MEGDNN_MAX_NDIM];
-        int dilationA[MEGDNN_MAX_NDIM];
-        for (int i = 0; i < ndim; ++i) {
-            padA[i] = ser.read_plain<int>();
-            strideA[i] = ser.read_plain<int>();
-            dilationA[i] = ser.read_plain<int>();
-        }
-        cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
-        cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
-        cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA,
-                                        mode, computeType);
-        return Empty{};
-    }
 };
+
 class CudnnTensorDescParam {
 public:
     cudnnTensorDescriptor_t value;
@@ -63,8 +50,8 @@ public:
         cudnnDataType_t dataType;
         int dimA[MEGDNN_MAX_NDIM];
         int strideA[MEGDNN_MAX_NDIM];
-        cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA,
-                                   strideA);
+        cudnnGetTensorNdDescriptor(value, MEGDNN_MAX_NDIM, &dataType, &nbDims,
+                                   dimA, strideA);
         ser.write_plain(nbDims);
         for (int i = 0; i < nbDims; ++i) {
             ser.write_plain(dimA[i]);
@@ -73,21 +60,8 @@ public:
         ser.write_plain(dataType);
         return Empty{};
     }
-    Empty deserialize(StringSerializer& ser, Empty) {
-        int nbDims = MEGDNN_MAX_NDIM;
-        cudnnDataType_t dataType;
-        int dimA[MEGDNN_MAX_NDIM];
-        int strideA[MEGDNN_MAX_NDIM];
-        nbDims = ser.read_plain<int>();
-        for (int i = 0; i < nbDims; ++i) {
-            dimA[i] = ser.read_plain<int>();
-            strideA[i] = ser.read_plain<int>();
-        }
-        dataType = ser.read_plain<cudnnDataType_t>();
-        cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
-        return Empty{};
-    }
 };
+
 class CudnnFilterDescParam {
 public:
     cudnnFilterDescriptor_t value;
@@ -106,18 +80,29 @@ public:
         ser.write_plain(format);
         return Empty{};
     }
+};
+
+template <typename T>
+class CudnnConvAlgoPerfParam {
+public:
+    T value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        ser.write_plain(value.algo);
+        ser.write_plain(value.status);
+        ser.write_plain(value.time);
+        ser.write_plain(value.memory);
+        ser.write_plain(value.determinism);
+        ser.write_plain(value.mathType);
+        return Empty{};
+    }
+
     Empty deserialize(StringSerializer& ser, Empty) {
-        int nbDims = MEGDNN_MAX_NDIM;
-        cudnnDataType_t dataType;
-        cudnnTensorFormat_t format;
-        int filterDimA[MEGDNN_MAX_NDIM];
-        nbDims = ser.read_plain<int>();
-        for (int i = 0; i < nbDims; ++i) {
-            filterDimA[i] = ser.read_plain<int>();
-        }
-        dataType = ser.read_plain<cudnnDataType_t>();
-        format = ser.read_plain<cudnnTensorFormat_t>();
-        cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
+        ser.read_plain(&value.algo);
+        ser.read_plain(&value.status);
+        ser.read_plain(&value.time);
+        ser.read_plain(&value.memory);
+        ser.read_plain(&value.determinism);
+        ser.read_plain(&value.mathType);
         return Empty{};
     }
 };
diff --git a/dnn/src/cuda/handle.cpp b/dnn/src/cuda/handle.cpp
index e685cc5e9..2c6eb987d 100644
--- a/dnn/src/cuda/handle.cpp
+++ b/dnn/src/cuda/handle.cpp
@@ -165,7 +165,8 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
                     .input<CudnnTensorDescParam>()
                     .input<Param<int>>()
                     .output<RefParam<int>>()
-                    .output<ArrayParam<int, cudnnConvolutionFwdAlgoPerf_t>>()
+                    .output<ArrayParam<int,
+                            CudnnConvAlgoPerfParam<cudnnConvolutionFwdAlgoPerf_t>>>()
                     .ret<Param<cudnnStatus_t>>()
                     .build(&cudnnGetConvolutionForwardAlgorithm_v7);
     GetConvolutionForwardAlgorithmMaxCount =
@@ -196,8 +197,8 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
                     .input<CudnnTensorDescParam>()
                     .input<Param<int>>()
                     .output<RefParam<int>>()
-                    .output<ArrayParam<int,
-                            cudnnConvolutionBwdDataAlgoPerf_t>>()
+                    .output<ArrayParam<int,
+                            CudnnConvAlgoPerfParam<cudnnConvolutionBwdDataAlgoPerf_t>>>()
                     .ret<Param<cudnnStatus_t>>()
                     .build(&cudnnGetConvolutionBackwardDataAlgorithm_v7);
     GetConvolutionBackwardDataAlgorithmMaxCount =
@@ -228,8 +229,8 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
                     .input<CudnnFilterDescParam>()
                     .input<Param<int>>()
                     .output<RefParam<int>>()
-                    .output<ArrayParam<int,
-                            cudnnConvolutionBwdFilterAlgoPerf_t>>()
+                    .output<ArrayParam<int,
+                            CudnnConvAlgoPerfParam<cudnnConvolutionBwdFilterAlgoPerf_t>>>()
                     .ret<Param<cudnnStatus_t>>()
                     .build(&cudnnGetConvolutionBackwardFilterAlgorithm_v7);
     GetConvolutionBackwardFilterAlgorithmMaxCount =
-- 
GitLab
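
Note (not part of the patch): a minimal, standalone sketch of the caching idea this api_cache relies on -- serialize the trivially copyable inputs of an expensive query into a string key, cache the result, and replay it on a hit instead of repeating the call. The names expensive_query and cached_query are hypothetical stand-ins for the cuDNN heuristic lookups wrapped by FunctionCacheBuilder; none of this is MegEngine code.

#include <cstring>
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical expensive call, standing in for e.g. a cudnnGet*Algorithm_v7 query.
int expensive_query(int a, int b) {
    std::cout << "real call" << std::endl;
    return a * 10 + b;
}

// Cached wrapper: the inputs are memcpy'd into a string key, mirroring what
// StringSerializer::write_plain does for trivially copyable arguments.
int cached_query(int a, int b) {
    static std::unordered_map<std::string, int> cache;
    std::string key(sizeof(a) + sizeof(b), '\0');
    std::memcpy(&key[0], &a, sizeof(a));
    std::memcpy(&key[sizeof(a)], &b, sizeof(b));
    auto it = cache.find(key);
    if (it == cache.end()) {
        it = cache.emplace(key, expensive_query(a, b)).first;
    }
    return it->second;
}

int main() {
    std::cout << cached_query(1, 2) << std::endl;  // first call hits the real query
    std::cout << cached_query(1, 2) << std::endl;  // second call is served from the cache
}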