/** * \file dnn/src/cuda/api_cache.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ #pragma once #include "src/common/api_cache.h" #include "src/cuda/cudnn_wrapper.h" namespace megdnn { class CudnnConvDescParam { public: cudnnConvolutionDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { constexpr int maxNbDims = CUDNN_DIM_MAX - 2; int nbDims = maxNbDims; int padA[maxNbDims]; int strideA[maxNbDims]; int dilationA[maxNbDims]; cudnnConvolutionMode_t mode; cudnnDataType_t computeType; cudnnGetConvolutionNdDescriptor(value, maxNbDims, &nbDims, padA, strideA, dilationA, &mode, &computeType); ser.write_plain(nbDims); for (int i = 0; i < nbDims; ++i) { ser.write_plain(padA[i]); ser.write_plain(strideA[i]); ser.write_plain(dilationA[i]); } ser.write_plain(mode); ser.write_plain(computeType); return Empty{}; } }; class CudnnTensorDescParam { public: cudnnTensorDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; int dimA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM]; cudnnGetTensorNdDescriptor(value, MEGDNN_MAX_NDIM, &dataType, &nbDims, dimA, strideA); ser.write_plain(nbDims); for (int i = 0; i < nbDims; ++i) { ser.write_plain(dimA[i]); ser.write_plain(strideA[i]); } ser.write_plain(dataType); return Empty{}; } }; class CudnnFilterDescParam { public: cudnnFilterDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; cudnnTensorFormat_t format; int filterDimA[MEGDNN_MAX_NDIM]; cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims, filterDimA); ser.write_plain(nbDims); for (int i = 0; i < nbDims; ++i) { ser.write_plain(filterDimA[i]); } ser.write_plain(dataType); ser.write_plain(format); return Empty{}; } }; template class CudnnConvAlgoPerfParam { public: T value; Empty serialize(StringSerializer& ser, Empty) { ser.write_plain(value.algo); ser.write_plain(value.status); ser.write_plain(value.time); ser.write_plain(value.memory); ser.write_plain(value.determinism); ser.write_plain(value.mathType); return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { ser.read_plain(&value.algo); ser.read_plain(&value.status); ser.read_plain(&value.time); ser.read_plain(&value.memory); ser.read_plain(&value.determinism); ser.read_plain(&value.mathType); return Empty{}; } }; } // namespace megdnn