From a85531dd0f2b7a782b46ba850ec5a98b610e061e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 23 Dec 2020 14:53:10 +0800 Subject: [PATCH] feat(mgb/opr): add tqt opr GitOrigin-RevId: 49c62cd5327f55ceade849918cce13a1fef2ab39 --- dnn/include/megdnn/oprs/nn.h | 53 ++++++- dnn/scripts/opr_param_defs.py | 6 +- dnn/src/common/handle_impl.h | 4 +- dnn/src/common/tqt.cpp | 62 +++++++++ dnn/src/cuda/handle_create.cpp | 1 + dnn/src/cuda/tqt/kern.cu | 30 ++++ dnn/src/cuda/tqt/kern.cuh | 129 ++++++++++++++++++ dnn/src/cuda/tqt/opr_impl.cpp | 125 +++++++++++++++++ dnn/src/cuda/tqt/opr_impl.h | 56 ++++++++ dnn/src/naive/fake_quant/opr_impl.cpp | 1 - dnn/src/naive/handle.cpp | 1 + dnn/src/naive/tqt/opr_impl.cpp | 122 +++++++++++++++++ dnn/src/naive/tqt/opr_impl.h | 47 +++++++ dnn/test/common/deduce_layout_proxy.h | 10 ++ dnn/test/common/exec_proxy.h | 16 +++ dnn/test/common/opr_trait.h | 2 + dnn/test/common/tqt.h | 45 ++++++ dnn/test/cuda/tqt.cpp | 91 ++++++++++++ .../megengine/quantization/fake_quant.py | 51 +------ .../python/megengine/quantization/utils.py | 6 + .../test/unit/quantization/test_fake_quant.py | 50 +++---- imperative/src/impl/ops/specializations.cpp | 15 +- src/core/include/megbrain/ir/ops.td | 1 + src/opr/impl/dnn/dnn.oprdecl | 3 + src/opr/impl/dnn/dnn.sereg.h | 16 +++ src/opr/impl/dnn/tqt.cpp | 83 +++++++++++ src/opr/impl/internal/megdnn_opr_wrapper.inl | 5 + src/opr/include/megbrain/opr/dnn/tqt.h | 46 +++++++ src/opr/test/dnn/tqt.cpp | 67 +++++++++ src/serialization/impl/schema.fbs | 1 + 30 files changed, 1067 insertions(+), 78 deletions(-) create mode 100644 dnn/src/common/tqt.cpp create mode 100644 dnn/src/cuda/tqt/kern.cu create mode 100644 dnn/src/cuda/tqt/kern.cuh create mode 100644 dnn/src/cuda/tqt/opr_impl.cpp create mode 100644 dnn/src/cuda/tqt/opr_impl.h create mode 100644 dnn/src/naive/tqt/opr_impl.cpp create mode 100644 dnn/src/naive/tqt/opr_impl.h create mode 100644 dnn/test/common/tqt.h create mode 100644 dnn/test/cuda/tqt.cpp create mode 100644 src/opr/impl/dnn/tqt.cpp create mode 100644 src/opr/include/megbrain/opr/dnn/tqt.h create mode 100644 src/opr/test/dnn/tqt.cpp diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index 7a475c5a2..12514f746 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -224,8 +224,8 @@ public: const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) = 0; /** - * \brief execute weight preprocessing, read weights form filter and write to - * preprocessed_filter after preprocessed. + * \brief execute weight preprocessing, read weights form filter and write + * to preprocessed_filter after preprocessed. 
* * \praram[in] workspace the needed tmp workspace when exec_preprocess */ @@ -1684,6 +1684,55 @@ protected: const TensorLayout& grad, size_t workspace_in_bytes); }; +class TQTBase : public OperatorBase { + DEF_OPR_IMPL_CTOR(TQTBase, OperatorBase); + DEF_OPR_PARAM(TQT); + +protected: + void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output); + void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& output); +}; + +class TQTForward : public TQTBase { + DEF_OPR_IMPL(TQTForward, TQTBase, 2, 1); + +public: + virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_out output, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& input, const TensorLayout& scale, + TensorLayout& output); + virtual size_t get_workspace_in_bytes(const TensorLayout& input, + const TensorLayout& scale, + const TensorLayout& output) = 0; + +protected: + void check_exec(const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& output, size_t workspace_in_bytes); +}; +using TQT = TQTForward; + +class TQTBackward : public TQTBase { + DEF_OPR_IMPL(TQTBackward, TQTBase, 3, 2); + +public: + virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, + _megdnn_tensor_in scale, _megdnn_tensor_out grad_x, + _megdnn_tensor_out grad_s, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout& diff, + const TensorLayout& input, + const TensorLayout& scale, + const TensorLayout& grad_x, + const TensorLayout& grad_s) = 0; + +protected: + void check_exec(const TensorLayout& diff, const TensorLayout& input, + const TensorLayout& scale, const TensorLayout& grad_x, + const TensorLayout& grad_s, size_t workspace_in_bytes); +}; + } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 63540ed96..6a899d01b 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -948,5 +948,7 @@ when the ``I`` suffix is present. add_fields('int32','qmin','-2147483648'). add_fields('int32','qmax','2147483647') ) - - +(pdef('TQT'). + add_fields('int32', 'qmin', '-2147483648'). + add_fields('int32', 'qmax', '2147483647') + ) diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 5adb79de3..77227b6cc 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -202,7 +202,9 @@ private: cb(AdaptivePoolingBackward) \ cb(DctChannelSelectForward) \ cb(FakeQuantForward) \ - cb(FakeQuantBackward) + cb(FakeQuantBackward) \ + cb(TQTForward) \ + cb(TQTBackward) /*! * \brief specialize HandleImpl::create_operator for a single opr type; diff --git a/dnn/src/common/tqt.cpp b/dnn/src/common/tqt.cpp new file mode 100644 index 000000000..44199df6c --- /dev/null +++ b/dnn/src/common/tqt.cpp @@ -0,0 +1,62 @@ +/** + * \file dnn/src/common/tqt.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { + +void TQTBase::deduce_layout_fwd(const TensorLayout& input, + TensorLayout& output) { + output = TensorLayout(input, input.dtype); +} + +void TQTBase::check_layout_fwd(const TensorLayout& input, + const TensorLayout& scale, + const TensorLayout& output) { + megdnn_assert(input.dtype == dtype::Float32()); + megdnn_assert(scale.dtype == dtype::Float32()); + TensorLayout expected; + deduce_layout_fwd(input, expected); + megdnn_assert_eq_layout(expected, output); +} + +void TQTForward::deduce_layout(const TensorLayout& input, + const TensorLayout& /* scale */, + TensorLayout& output) { + deduce_layout_fwd(input, output); +} + +void TQTForward::check_exec(const TensorLayout& input, + const TensorLayout& scale, + const TensorLayout& output, + size_t workspace_in_bytes) { + check_layout_fwd(input, scale, output); + auto required_workspace_space = + get_workspace_in_bytes(input, scale, output); + megdnn_assert(workspace_in_bytes >= required_workspace_space); +} + +void TQTBackward::check_exec(const TensorLayout& diff, + const TensorLayout& input, + const TensorLayout& scale, + const TensorLayout& grad_x, + const TensorLayout& grad_s, + size_t workspace_in_bytes) { + megdnn_assert_eq_shape(diff, input); + megdnn_assert_eq_shape(grad_x, input); + auto required_worspace_space = + get_workspace_in_bytes(diff, input, scale, grad_x, grad_s); + megdnn_assert(workspace_in_bytes >= required_worspace_space); +} + +} // namespace megdnn diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index af4949fb0..ce01a1aca 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -77,6 +77,7 @@ #include "src/cuda/batch_conv_bias/opr_impl.h" #include "src/cuda/remap/opr_impl.h" #include "src/cuda/fake_quant/opr_impl.h" +#include "src/cuda/tqt/opr_impl.h" namespace megdnn { namespace cuda { diff --git a/dnn/src/cuda/tqt/kern.cu b/dnn/src/cuda/tqt/kern.cu new file mode 100644 index 000000000..b0455fdc4 --- /dev/null +++ b/dnn/src/cuda/tqt/kern.cu @@ -0,0 +1,30 @@ +/** + * \file dnn/src/cuda/tqt/kern.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "./kern.cuh" + +namespace megdnn { +namespace cuda { + +#define cb(_dtype) \ + INST_RUN_ELEMWISE(TQTKernOp::ctype>, \ + DTypeTrait<_dtype>::ctype, 1); \ + INST_RUN_ELEMWISE(TQTBwdKernOp::ctype>, \ + DTypeTrait<_dtype>::ctype, 1); \ + INST_RUN_ELEMWISE(TQTKernOpNonContig::ctype>, \ + DTypeTrait<_dtype>::ctype, 3); \ + INST_RUN_ELEMWISE(TQTBwdKernOpNonContig::ctype>, \ + DTypeTrait<_dtype>::ctype, 5); +cb(megdnn::dtype::Float32) + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/tqt/kern.cuh b/dnn/src/cuda/tqt/kern.cuh new file mode 100644 index 000000000..b26b15205 --- /dev/null +++ b/dnn/src/cuda/tqt/kern.cuh @@ -0,0 +1,129 @@ +/** + * \file dnn/src/cuda/tqt/kern.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "src/cuda/elemwise_helper.cuh" +#include "src/cuda/utils.cuh" + +#if MEGDNN_CC_HOST +#include "megdnn/oprs.h" +#endif + +namespace megdnn { +namespace cuda { + +template +struct TQTKernOp { + ctype* input; + ctype* output; + ctype qmin, qmax; + + __device__ void operator()(uint32_t idx, ctype scale) { + ctype t = powf(2, scale); + ctype x = round(input[idx] / t); + x = fmaxf(fminf(x, qmax), qmin); + output[idx] = x * t; + } + +#if MEGDNN_CC_HOST + TQTKernOp(const TensorND& input, const TensorND& output, + const TQT::Param& param) + : input{input.ptr()}, + output{output.ptr()}, + qmin(param.qmin), + qmax(param.qmax) {} +#endif +}; + +template +struct TQTBwdKernOp { + ctype* diff; + ctype* input; + ctype* grad_x; + ctype* grad_s; + ctype qmin, qmax; + + __device__ void operator()(uint32_t idx, ctype scale) { + ctype t = powf(2, scale); + ctype scaled = input[idx] / t; + ctype rounded = round(scaled); + rounded = fmaxf(fminf(rounded, qmax), qmin); + bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax; + bool mask_quant = !mask_clip; + + grad_x[idx] = diff[idx] * mask_quant; + ctype grad_quant = + diff[idx] * mask_quant * (rounded - scaled) * t * log(2.0); + ctype grad_clip = diff[idx] * mask_clip * rounded * t * log(2.0); + grad_s[idx] = grad_quant + grad_clip; + } + +#if MEGDNN_CC_HOST + TQTBwdKernOp(const TensorND& diff, const TensorND& input, + const TensorND& grad_x, const TensorND& grad_s, + const TQT::Param& param) + : diff{diff.ptr()}, + input{input.ptr()}, + grad_x{grad_x.ptr()}, + grad_s{grad_s.ptr()}, + qmin(param.qmin), + qmax(param.qmax) {} +#endif +}; + +template +struct TQTKernOpNonContig { + ctype qmin; + ctype qmax; + + __device__ void operator()(uint32_t, ctype& input, ctype& scale, + ctype& output) { + ctype t = powf(2, scale); + ctype x = round(input / t); + x = fmaxf(fminf(x, qmax), qmin); + output = x * t; + } +#if MEGDNN_CC_HOST + TQTKernOpNonContig(const TQT::Param& param) + : qmin(param.qmin), qmax(param.qmax) {} +#endif +}; + +template +struct TQTBwdKernOpNonContig { + ctype qmin; + ctype qmax; + + __device__ void operator()(uint32_t, ctype& diff, ctype& input, + ctype& scale, ctype& grad_x, ctype& grad_s) { + ctype t = powf(2, scale); + ctype scaled = input / t; + ctype rounded = round(scaled); + rounded = fmaxf(fminf(rounded, qmax), qmin); + bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax; + bool mask_quant = !mask_clip; + + grad_x = diff * mask_quant; + ctype grad_quant = + diff * mask_quant * (rounded - scaled) * t * log(2.0); + ctype grad_clip = diff * mask_clip * rounded * t * log(2.0); + grad_s = grad_quant + grad_clip; + } +#if MEGDNN_CC_HOST + TQTBwdKernOpNonContig(const TQT::Param& param) + : qmin(param.qmin), qmax(param.qmax) {} +#endif +}; + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/tqt/opr_impl.cpp b/dnn/src/cuda/tqt/opr_impl.cpp new file mode 100644 index 000000000..fe9bca3c2 --- /dev/null +++ b/dnn/src/cuda/tqt/opr_impl.cpp @@ -0,0 +1,125 @@ +/** + * \file dnn/src/cuda/tqt/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "./opr_impl.h" +#include "./kern.cuh" +#include "src/common/utils.h" +namespace megdnn { +namespace cuda { + +void TQTForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_out output, + _megdnn_workspace workspace) { + check_exec(input.layout, scale.layout, output.layout, workspace.size); + + if (!input.layout.is_contiguous() || !output.layout.is_contiguous()) + return exec_noncontig(input, scale, output); + + ElemwiseOpParamN<1> ele_param; + ele_param[0] = scale; + ele_param[0].layout = ele_param[0].layout.broadcast(input.layout); + ele_param.init_from_given_tensor(); + auto m_param = param(); + auto stream = cuda_stream(handle()); + +#define cb(DType) \ + if (input.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + run_elemwise, T, 1>(ele_param, stream, \ + {input, output, m_param}); \ + return; \ + } + cb(megdnn::dtype::Float32) +#undef cb +} + +void TQTForwardImpl::exec_noncontig(_megdnn_tensor_in input, + _megdnn_tensor_in scale, + _megdnn_tensor_out output) { + ElemwiseOpParamN<3> ele_param; + ele_param[0] = input; + ele_param[1] = scale; + ele_param[1].layout = ele_param[1].layout.broadcast(input.layout); + ele_param[2] = output; + ele_param.init_from_given_tensor(); + auto m_param = param(); + auto stream = cuda_stream(handle()); + +#define cb(DType) \ + if (input.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + run_elemwise, T, 3>(ele_param, stream, \ + {m_param}); \ + return; \ + } + cb(megdnn::dtype::Float32) +#undef cb +} + +void TQTBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, + _megdnn_tensor_in scale, _megdnn_tensor_out grad_x, + _megdnn_tensor_out grad_s, + _megdnn_workspace workspace) { + check_exec(diff.layout, input.layout, scale.layout, grad_x.layout, + grad_s.layout, workspace.size); + + if (!input.layout.is_contiguous() || !diff.layout.is_contiguous() || + !grad_x.layout.is_contiguous() || !grad_s.layout.is_contiguous()) + return exec_noncontig(diff, input, scale, grad_x, grad_s); + + ElemwiseOpParamN<1> ele_param; + ele_param[0] = scale; + ele_param[0].layout = ele_param[0].layout.broadcast(input.layout); + ele_param.init_from_given_tensor(); + auto m_param = param(); + auto stream = cuda_stream(handle()); + +#define cb(DType) \ + if (grad_x.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + run_elemwise, T, 1>( \ + ele_param, stream, {diff, input, grad_x, grad_s, m_param}); \ + return; \ + } + cb(megdnn::dtype::Float32) +#undef cb +} + +void TQTBackwardImpl::exec_noncontig(_megdnn_tensor_in diff, + _megdnn_tensor_in input, + _megdnn_tensor_in scale, + _megdnn_tensor_out grad_x, + _megdnn_tensor_out grad_s) { + ElemwiseOpParamN<5> ele_param; + ele_param[0] = diff; + ele_param[1] = input; + ele_param[2] = scale; + ele_param[2].layout = ele_param[2].layout.broadcast(input.layout); + ele_param[3] = grad_x; + ele_param[4] = grad_s; + ele_param.init_from_given_tensor(); + auto m_param = param(); + auto stream = cuda_stream(handle()); + +#define cb(DType) \ + if (input.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + run_elemwise, T, 5>(ele_param, stream, \ + {m_param}); \ + return; \ + } + cb(megdnn::dtype::Float32) +#undef cb +} + +} // namespace cuda +} // namespace megdnn \ No newline at 
end of file diff --git a/dnn/src/cuda/tqt/opr_impl.h b/dnn/src/cuda/tqt/opr_impl.h new file mode 100644 index 000000000..1759e208f --- /dev/null +++ b/dnn/src/cuda/tqt/opr_impl.h @@ -0,0 +1,56 @@ +/** + * \file dnn/src/cuda/tqt/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megdnn/oprs.h" +#include "src/cuda/utils.h" +namespace megdnn { +namespace cuda { + +class TQTForwardImpl final : public TQTForward { +public: + using TQTForward::TQTForward; + void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_out output, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, /* input */ + const TensorLayout&, /* scale */ + const TensorLayout& /* output */) override { + return 0; + } + +private: + void exec_noncontig(_megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_out output); +}; + +class TQTBackwardImpl final : public TQTBackward { +public: + using TQTBackward::TQTBackward; + void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, + _megdnn_tensor_in scale, _megdnn_tensor_out grad_x, + _megdnn_tensor_out grad_s, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& /* diff */, + const TensorLayout& /* input */, + const TensorLayout& /* scale */, + const TensorLayout& /* grad_x */, + const TensorLayout& /* grad_s */) override { + return 0; + } + +private: + void exec_noncontig(_megdnn_tensor_in diff, _megdnn_tensor_in input, + _megdnn_tensor_in scale, _megdnn_tensor_out grad_x, + _megdnn_tensor_out grad_s); +}; + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/naive/fake_quant/opr_impl.cpp b/dnn/src/naive/fake_quant/opr_impl.cpp index 466365888..d9262e81e 100644 --- a/dnn/src/naive/fake_quant/opr_impl.cpp +++ b/dnn/src/naive/fake_quant/opr_impl.cpp @@ -12,7 +12,6 @@ #include "src/naive/fake_quant/opr_impl.h" #include -#include #include "megdnn/tensor_iter.h" #include "src/common/elemwise_helper.cuh" #include "src/common/utils.h" diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index 4a91dbf9a..5e93508b8 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -80,6 +80,7 @@ #include "src/naive/warp_perspective/opr_impl.h" #include "src/naive/remap/opr_impl.h" #include "src/naive/fake_quant/opr_impl.h" +#include "src/naive/tqt/opr_impl.h" static size_t g_image2d_pitch_alignment = 1; diff --git a/dnn/src/naive/tqt/opr_impl.cpp b/dnn/src/naive/tqt/opr_impl.cpp new file mode 100644 index 000000000..beb0a7334 --- /dev/null +++ b/dnn/src/naive/tqt/opr_impl.cpp @@ -0,0 +1,122 @@ +/** + * \file dnn/src/naive/tqt/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/naive/tqt/opr_impl.h" +#include +#include "megdnn/tensor_iter.h" +#include "src/common/elemwise_helper.cuh" +#include "src/common/utils.h" +#include "src/naive/handle.h" + +namespace { +using namespace megdnn; + +template +void forward_impl(const ElemwiseOpParamN<3> src, float qmin, float qmax) { + auto inp = tensor_iter_valonly(src[0]).begin(); + auto out = tensor_iter_valonly(src[1]).begin(); + auto scale = tensor_iter_valonly(src[2]).begin(); + size_t total = src[0].layout.total_nr_elems(); + for (size_t i = 0; i < total; ++i) { + T t = pow(2, *scale); + T x = round(*inp / t); + x = x <= qmin ? qmin : x; + x = x >= qmax ? qmax : x; + *out = x * t; + ++inp; + ++out; + ++scale; + } +} + +template +void backward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) { + auto diff = tensor_iter_valonly(src[0]).begin(); + auto input = tensor_iter_valonly(src[1]).begin(); + auto scale = tensor_iter_valonly(src[2]).begin(); + auto grad_x = tensor_iter_valonly(src[3]).begin(); + auto grad_s = tensor_iter_valonly(src[4]).begin(); + size_t total = src[0].layout.total_nr_elems(); + + for (size_t i = 0; i < total; ++i) { + T t = pow(2, *scale); + T scaled = *input / t; + T rounded = round(scaled); + rounded = rounded <= qmin ? qmin : rounded; + rounded = rounded >= qmax ? qmax : rounded; + bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax; + bool mask_quant = !mask_clip; + + *grad_x = *diff * mask_quant; + T grad_quant = *diff * mask_quant * (rounded - scaled) * t * log(2.0); + T grad_clip = *diff * mask_clip * rounded * t * log(2.0); + *grad_s = grad_quant + grad_clip; + + ++input; + ++diff; + ++scale; + ++grad_x; + ++grad_s; + } +} + +} // namespace +namespace megdnn { +namespace naive { + +void TQTForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_out output, + _megdnn_workspace workspace) { + check_exec(input.layout, scale.layout, output.layout, workspace.size); + ElemwiseOpParamN<3> src; + src[0] = input; + src[1] = output; + src[2] = scale; + src[2].layout = src[2].layout.broadcast(input.layout); +#define cb(DType) \ + if (input.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + forward_impl(src, param().qmin, param().qmax)); \ + return; \ + } + cb(dtype::Float32) +#undef cb +} + +void TQTBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, + _megdnn_tensor_in scale, _megdnn_tensor_out grad_x, + _megdnn_tensor_out grad_s, + _megdnn_workspace workspace) { + check_exec(diff.layout, input.layout, scale.layout, grad_x.layout, + grad_s.layout, workspace.size); + ElemwiseOpParamN<5> src; + src[0] = diff; + src[1] = input; + src[2] = scale; + src[2].layout = src[2].layout.broadcast(input.layout); + src[3] = grad_x; + src[4] = grad_s; +#define cb(DType) \ + if (diff.layout.dtype == DType() && grad_x.layout.dtype == DType() && \ + input.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + backward_impl(src, param().qmin, param().qmax)); \ + return; \ + } + cb(dtype::Float32) +#undef cb +} + +} // namespace naive +} // namespace megdnn diff --git a/dnn/src/naive/tqt/opr_impl.h b/dnn/src/naive/tqt/opr_impl.h new file mode 100644 index 000000000..100d75111 --- /dev/null +++ b/dnn/src/naive/tqt/opr_impl.h @@ -0,0 +1,47 @@ +/** + * \file dnn/src/naive/tqt/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class TQTForwardImpl final : public TQTForward { +public: + using TQTForward::TQTForward; + void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_out output, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& /* input */, + const TensorLayout& /* scale */, + const TensorLayout& /* output */) override { + return 0; + } +}; + +class TQTBackwardImpl final : public TQTBackward { +public: + using TQTBackward::TQTBackward; + void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, + _megdnn_tensor_in scale, _megdnn_tensor_out grad_x, + _megdnn_tensor_out grad_s, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout& /* diff */, + const TensorLayout& /* input */, + const TensorLayout& /* scale */, + const TensorLayout& /* grad_x */, + const TensorLayout& /* grad_s */) override { + return 0; + } +}; + +} // namespace naive +} // namespace megdnn diff --git a/dnn/test/common/deduce_layout_proxy.h b/dnn/test/common/deduce_layout_proxy.h index 85df7627e..d0d78d678 100644 --- a/dnn/test/common/deduce_layout_proxy.h +++ b/dnn/test/common/deduce_layout_proxy.h @@ -58,6 +58,16 @@ struct DeduceLayoutProxy { } }; +template +struct DeduceLayoutProxy { + static void deduce_layout(Opr*, TensorLayoutArray&) {} +}; + +template +struct DeduceLayoutProxy { + static void deduce_layout(Opr*, TensorLayoutArray&) {} +}; + template struct DeduceLayoutProxy { static void deduce_layout(Opr*, TensorLayoutArray&) {} diff --git a/dnn/test/common/exec_proxy.h b/dnn/test/common/exec_proxy.h index e3f481ace..76e50d5d9 100644 --- a/dnn/test/common/exec_proxy.h +++ b/dnn/test/common/exec_proxy.h @@ -37,6 +37,22 @@ struct ExecProxy { tensors[5], tensors[6], tensors[7], W.workspace()); } }; + +template +struct ExecProxy { + WorkspaceWrapper W; + void exec(Opr* opr, const TensorNDArray& tensors) { + if (!W.valid()) { + W = WorkspaceWrapper(opr->handle(), 0); + } + W.update(opr->get_workspace_in_bytes( + tensors[0].layout, tensors[1].layout, tensors[2].layout, + tensors[3].layout, tensors[4].layout, tensors[5].layout)); + opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], + tensors[5], W.workspace()); + } +}; + template struct ExecProxy { WorkspaceWrapper W; diff --git a/dnn/test/common/opr_trait.h b/dnn/test/common/opr_trait.h index bf0af976f..93aa547a1 100644 --- a/dnn/test/common/opr_trait.h +++ b/dnn/test/common/opr_trait.h @@ -112,6 +112,8 @@ DEF(RemapBackwardMat, 4, true, false); DEF(DctChannelSelectForward, 4, true, true); DEF(FakeQuantForward, 4, true, true); DEF(FakeQuantBackward, 5, true, false); +DEF(TQTForward, 3, true, true); +DEF(TQTBackward, 5, true, false); } // namespace test } // namespace megdnn diff --git a/dnn/test/common/tqt.h b/dnn/test/common/tqt.h new file mode 100644 index 000000000..e19b85eb0 --- /dev/null +++ b/dnn/test/common/tqt.h @@ -0,0 +1,45 @@ +/** + * \file dnn/test/common/tqt.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once +#include "megdnn/basic_types.h" +#include "megdnn/opr_param_defs.h" + +namespace megdnn { +namespace test { +namespace tqt { + +struct TestArg { + param::TQT param; + TensorShape ishape; + TensorShape scale_shape; + TestArg(param::TQT param, TensorShape ishape, TensorShape scale_shape) + : param(param), ishape(ishape), scale_shape(scale_shape) {} +}; + +inline std::vector get_args() { + std::vector args; + param::TQT cur_param; + + cur_param.qmin = -127; + cur_param.qmax = 127; + + for (size_t i = 10; i < 30; i += 2) { + args.emplace_back(cur_param, TensorShape{10, 64, i, i}, TensorShape{1}); + } + + return args; +} + +} // namespace tqt +} // namespace test +} // namespace megdnn diff --git a/dnn/test/cuda/tqt.cpp b/dnn/test/cuda/tqt.cpp new file mode 100644 index 000000000..fdbce4db0 --- /dev/null +++ b/dnn/test/cuda/tqt.cpp @@ -0,0 +1,91 @@ +/** + * \file dnn/test/cuda/tqt.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "test/common/tqt.h" +#include "megdnn/oprs.h" +#include "test/common/checker.h" +#include "test/cuda/fixture.h" + +namespace megdnn { +namespace test { + +using namespace tqt; + +TEST_F(CUDA, TQT) { + std::vector args = get_args(); + auto dtype = dtype::Float32(); + + for (auto&& arg : args) { + auto param = arg.param; + auto ishape = arg.ishape; + auto scale_shape = arg.scale_shape; + Checker checker(handle_cuda()); + checker.set_param(param) + .set_dtype(0, dtype) + .set_dtype(1, dtype) + .set_dtype(2, dtype) + .execs({ishape, scale_shape, ishape}); + } + // test noncontiguous layout + for (auto&& arg : args) { + auto param = arg.param; + auto ishape = arg.ishape; + auto sshape = arg.scale_shape; + Checker checker(handle_cuda()); + TensorLayout ilayout( + ishape, + {(long int)(ishape[1] * ishape[2] * ishape[3] * 2), + (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1}, + dtype::Float32()); + checker.set_param(param).execl( + {ilayout, {sshape, dtype::Float32()}, ilayout}); + } +} + +TEST_F(CUDA, TQT_BACKWARD) { + std::vector args = get_args(); + auto dtype = dtype::Float32(); + + for (auto&& arg : args) { + auto param = arg.param; + auto ishape = arg.ishape; + auto scale_shape = arg.scale_shape; + Checker checker(handle_cuda()); + checker.set_param(param) + .set_dtype(0, dtype) + .set_dtype(1, dtype) + .set_dtype(2, dtype) + .set_dtype(3, dtype) + .set_dtype(4, dtype) + .execs({ishape, ishape, scale_shape, ishape, ishape}); + } + // test noncontiguous layout + for (auto&& arg : args) { + auto param = arg.param; + auto ishape = arg.ishape; + auto sshape = arg.scale_shape; + Checker checker(handle_cuda()); + TensorLayout ilayout( + ishape, + {(long int)(ishape[1] * ishape[2] * ishape[3] * 2), + (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1}, + dtype::Float32()); + checker.set_param(param).execl({ilayout, + ilayout, + {sshape, dtype::Float32()}, + ilayout, + ilayout}); + } +} + +} // namespace test +} // namespace megdnn \ No newline at end of file diff --git 
a/imperative/python/megengine/quantization/fake_quant.py b/imperative/python/megengine/quantization/fake_quant.py index e20813a72..66bc56b6d 100644 --- a/imperative/python/megengine/quantization/fake_quant.py +++ b/imperative/python/megengine/quantization/fake_quant.py @@ -15,7 +15,7 @@ from ..core.autodiff.grad import Function from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype from ..module import Module from ..tensor import Parameter, Tensor -from .utils import QuantMode, fake_quant_tensor, get_qparam_dict +from .utils import QuantMode, fake_quant_tensor, get_qparam_dict, tqt_forward class _FakeQuantize(Module): @@ -65,51 +65,6 @@ class _FakeQuantize(Module): return self.normal_foward(inp, q_dict=q_dict) -class TQT_Function(Function): - def __init__(self, lowerbound, upperbound): - super().__init__() - self.lowerbound = lowerbound - self.upperbound = upperbound - self.saved_tensors = () - - def save_for_backward(self, *tensors: Iterable[Tensor]): - """ - Saves tensors needed for gradient computation. This method should be called only - once in :meth:`~.function.Function.forward`, additional calls will replace values saved previously. - - The saved tensors can be accessed through the ``saved_tensors`` attribute. - """ - self.saved_tensors = tensors - - def forward(self, inp, scale): - t = 2 ** scale - # t = F.maximum(t, 1e-4) - inp_scaled = inp / t - inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound) - inp_rounded = F.round(inp_clipped) - inp_flq = inp_rounded * t - self.save_for_backward(inp_scaled, inp_rounded, t) - return inp_flq - - def backward(self, grad_inp_flq): - (inp_scaled, inp_rounded, t) = self.saved_tensors - mask_clip = F.logical_and( - inp_scaled < -0.5 + self.lowerbound, inp_scaled > self.upperbound + 0.5 - ) # mask for accumulating the gradients of |data_scaled|>L - mask_quant = F.logical_not(mask_clip) - grad_quant = ( - grad_inp_flq * mask_quant * (inp_rounded - inp_scaled) - ) # gradient within |data_scaled|<=L - grad_clip = ( - grad_inp_flq * mask_clip * inp_rounded - ) # gradient with | data_scaled|>L - grad_s = grad_clip.sum() + grad_quant.sum() - # dL/ds = dL/dt * t * ln(2) - grad_s = grad_s * t * math.log(2) - grad_inp = grad_inp_flq * mask_quant - return grad_inp, grad_s - - class TQT(_FakeQuantize): r""" TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds @@ -130,11 +85,11 @@ class TQT(_FakeQuantize): ), "only symmetric quantization is supported by TQT" if "scale" not in q_dict or q_dict["scale"] is None: raise AssertionError("Can not get an initialized scale") - self.scale = F.log(q_dict["scale"]) / math.log(2) + self.scale = Tensor(F.log(q_dict["scale"]) / math.log(2)) def fake_quant_forward(self, inp, q_dict=None): # when enable, TQT will do fakequant forward, finetune the scale - return TQT_Function(self.qmin, self.qmax)(inp, self.scale) + return tqt_forward(self.qmin, self.qmax, inp, self.scale) def get_qparams(self): q_dict = get_qparam_dict(QuantMode.SYMMERTIC) diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py index c3e6c9849..3c297cf7f 100644 --- a/imperative/python/megengine/quantization/utils.py +++ b/imperative/python/megengine/quantization/utils.py @@ -33,6 +33,12 @@ class Round(Function): return output_grads +def tqt_forward(qmin, qmax, inp, scale): + op = builtin.TQT(qmin=qmin, qmax=qmax) + (output,) = apply(op, inp, scale) + return output + + def register_method_to_class(cls): def decorator(func): @wraps(func) diff 
--git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py index be218e231..e0b0384d3 100644 --- a/imperative/python/test/unit/quantization/test_fake_quant.py +++ b/imperative/python/test/unit/quantization/test_fake_quant.py @@ -13,12 +13,11 @@ import megengine as mge from megengine import tensor from megengine.core.autodiff.grad import Function, Grad from megengine.core.tensor.utils import make_shape_tuple -from megengine.quantization.fake_quant import TQT_Function from megengine.quantization.internal_fake_quant import * -from megengine.quantization.utils import QuantMode, fake_quant_tensor +from megengine.quantization.utils import QuantMode, fake_quant_tensor, tqt_forward -class numpy_TQT_Function: +class TQT_numpy: def __init__(self, lowerbound, upperbound): super().__init__() self.lowerbound = lowerbound @@ -57,27 +56,32 @@ class numpy_TQT_Function: return grad_inp, grad_s -def test_TQT(): - f = TQT_Function(-127, 127) - nf = numpy_TQT_Function(-127, 127) +def test_tqt(): - def check_inp(a, b, c, a_np, b_np, c_np): - np.testing.assert_allclose( - f.forward(a, b).numpy(), - nf.forward(a_np, b_np).astype("float32"), - rtol=1e-6, - atol=1e-6, - ) - c1, c2 = f.backward(c) - c1_np, c2_np = nf.backward(c_np) - np.testing.assert_allclose(c1.numpy(), c1_np.astype("float32"), rtol=1e-6) - np.testing.assert_allclose(c2.numpy(), c2_np.astype("float32"), rtol=5e-5) - - a_np = np.random.random((4, 3)).astype("float32") - b_np = np.random.random((1)).astype("float32") - a = tensor(a_np) - b = tensor(b_np) - check_inp(a, b, b, a_np, b_np, b_np) + g = [] + + def cb(grad): + g.append(grad) + + x = np.random.normal(size=(1, 2, 3, 4)) + s = np.random.rand(1) + 1 + g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32") + + n = TQT_numpy(-127, 127) + y_np = n.forward(x, s) + g_x_np, g_s_np = n.backward(g_y) + + x = mge.tensor(x, dtype="float32") + s = mge.tensor(s, dtype="float32") + g_y = mge.tensor(g_y, dtype="float32") + grad = Grad().wrt(x, s, callback=cb) + y = tqt_forward(-127, 127, x, s) + grad(y, g_y) + g_x, g_s = g + + np.testing.assert_allclose(y.numpy(), y_np, atol=1e-6) + np.testing.assert_allclose(g_x.numpy(), g_x_np, atol=1e-6) + np.testing.assert_allclose(g_s.numpy(), g_s_np, atol=1e-6) diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index bf798101c..f609e1228 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -15,6 +15,7 @@ #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/adaptive_pooling.h" #include "megbrain/opr/dnn/fake_quant.h" +#include "megbrain/opr/dnn/tqt.h" #include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/roi_align.h" @@ -625,6 +626,18 @@ OP_TRAIT_REG(FakeQuant, FakeQuant) .apply_on_var_node(apply_on_var_node) .fallback(); }} // fake_quant +namespace { namespace tqt { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 2); + return opr::TQT::make(inputs[0], inputs[1], op.param()); +} +OP_TRAIT_REG(TQT, TQT) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // tqt namespace { namespace elemwise_multi_type { auto apply_on_var_node( const OpDef& def, @@ -636,7 +649,7 @@ auto apply_on_var_node( OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) .apply_on_var_node(apply_on_var_node) .fallback(); -}} // fake_quant +}} // elemwise_multi_type 
namespace { namespace svd { auto apply_on_var_node( diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 1b6141576..141166c25 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -232,6 +232,7 @@ def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">; def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>; def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>; +def TQT: MgbHashableOp<"TQT", [TQTParam]>; def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> { let extraArguments = (ins MgbDTypeAttr:$dtype diff --git a/src/opr/impl/dnn/dnn.oprdecl b/src/opr/impl/dnn/dnn.oprdecl index 04d8f81bb..a78d418bf 100644 --- a/src/opr/impl/dnn/dnn.oprdecl +++ b/src/opr/impl/dnn/dnn.oprdecl @@ -319,5 +319,8 @@ decl_opr('FakeQuant', inputs=[Doc('src','input tenosr'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor')], params='FakeQuant') +decl_opr('TQT', + inputs=[Doc('src','input tensor'),Doc('scale','scale tensor')], + params='TQT') # vim: ft=python diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h index 3aff31381..e0166b8b4 100644 --- a/src/opr/impl/dnn/dnn.sereg.h +++ b/src/opr/impl/dnn/dnn.sereg.h @@ -19,6 +19,7 @@ #include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/fake_quant.h" +#include "megbrain/opr/dnn/tqt.h" #include "megbrain/serialization/sereg.h" @@ -238,6 +239,19 @@ namespace serialization { } }; + template <> + struct OprMaker { + using Param = opr::TQTBackward::Param; + static cg::OperatorNodeBase* make(const Param& param, + const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + return opr::TQTBackward::make(i[0], i[1], i[2], param, config)[0] + .node() + ->owner_opr(); + } + }; + template struct MakeLocalShareCaller2 { template @@ -426,6 +440,8 @@ namespace opr { MGB_SEREG_OPR(BatchConvBiasForward, 0); MGB_SEREG_OPR(FakeQuant, 3); MGB_SEREG_OPR(FakeQuantBackward, 4); + MGB_SEREG_OPR(TQT, 2); + MGB_SEREG_OPR(TQTBackward, 3); } // namespace opr diff --git a/src/opr/impl/dnn/tqt.cpp b/src/opr/impl/dnn/tqt.cpp new file mode 100644 index 000000000..849347652 --- /dev/null +++ b/src/opr/impl/dnn/tqt.cpp @@ -0,0 +1,83 @@ +/** + * \file src/opr/impl/dnn/tqt.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "megbrain/opr/dnn/tqt.h" +#include "../internal/megdnn_opr_wrapper.inl" +#include "megbrain/graph/grad_impl.h" +#include "megbrain/opr/basic_arith_wrapper.h" +#include "megbrain/opr/internal/out_shape_by_sym_var.h" +#include "megbrain/opr/tensor_manip.h" +#include "megbrain/opr/utility.h" + +using namespace mgb; +using namespace opr; + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(TQTForward); +MEGDNN_OPR_INIT2(TQTForward, "tqt_fwd"); + +#ifdef MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(TQTForward) { + SymbolVarArray grad = TQTBackward::make(out_grad[0], opr.input(0), + opr.input(1), opr.param()); + + if (wrt_idx == 0) { + return grad[0].node(); + } else if (wrt_idx == 1) { + return reduce_sum(grad[1], GetVarShape::make(opr.input(wrt_idx))) + .node(); + } else { + return nullptr; + } +} +#endif + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(TQTBackward); + +TQTBackward::TQTBackward(VarNode* y_grad, VarNode* x, VarNode* scale, + const Param& param, const OperatorNodeConfig& config) + : Super({x->owner_graph(), config, "tqt_bwd", {y_grad, x, scale}}, 1, + true) { + init_megdnn_opr(*this, param); + add_input({y_grad, x, scale}); +} + +SymbolVarArray TQTBackward::make(SymbolVar y_grad, SymbolVar x, SymbolVar scale, + const Param& param, + const OperatorNodeConfig& config) { + auto&& out = x.node()->owner_graph() + ->insert_opr(std::make_unique( + y_grad.node(), x.node(), scale.node(), param, + config)) + ->output(); + SymbolVarArray ret(out.size()); + for (size_t i = 0; i < ret.size(); ++i) { + ret[i] = out[i]; + } + return ret; +} + +void TQTBackward::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto&& mgr = owner_graph()->static_infer_manager(); + + mgr.register_shape_infer(output(0), + ShapeInferDesc::make_identity(input(1))); + mgr.register_shape_infer(output(1), + ShapeInferDesc::make_identity(input(1))); + this->init_output_static_infer_desc_workspace( + intl::AutoAddWorkspaceNeedLimitGetter::val); +} + +void TQTBackward::init_output_dtype() { + output(0)->dtype(input(1)->dtype()); + output(1)->dtype(input(2)->dtype()); +} diff --git a/src/opr/impl/internal/megdnn_opr_wrapper.inl b/src/opr/impl/internal/megdnn_opr_wrapper.inl index f2c8ad45c..04c642b53 100644 --- a/src/opr/impl/internal/megdnn_opr_wrapper.inl +++ b/src/opr/impl/internal/megdnn_opr_wrapper.inl @@ -163,6 +163,11 @@ namespace { #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0), _o(1) #include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" +#define _NR_INPUTS 3 +#define _NR_OUTPUTS 3 +#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0), _o(1), _o(2) +#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl" + #define _NR_INPUTS 4 #define _NR_OUTPUTS 1 #define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0) diff --git a/src/opr/include/megbrain/opr/dnn/tqt.h b/src/opr/include/megbrain/opr/dnn/tqt.h new file mode 100644 index 000000000..61b395713 --- /dev/null +++ b/src/opr/include/megbrain/opr/dnn/tqt.h @@ -0,0 +1,46 @@ +/** + * \file src/opr/include/megbrain/opr/dnn/tqt.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#pragma once +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" +#include "megdnn/oprs.h" +namespace mgb { +namespace opr { + +MGB_DEFINE_OPR_CLASS(TQTForward, + intl::MegDNNOprWrapperFwd) // { +public: + TQTForward(VarNode* src, VarNode* scale, const Param& param, + const OperatorNodeConfig& config); + + static SymbolVar make(SymbolVar src, SymbolVar scale, const Param& param = {}, + const OperatorNodeConfig& config = {}); +}; +using TQT = TQTForward; + +MGB_DEFINE_OPR_CLASS(TQTBackward, + intl::MegDNNOprWrapperBwd) // { +public: + TQTBackward(VarNode* y_grad, VarNode* x, VarNode* scale, const Param& param, + const OperatorNodeConfig& config); + + static SymbolVarArray make(SymbolVar y_grad, SymbolVar x, SymbolVar scale, + const Param& param = {}, + const OperatorNodeConfig& config = {}); + +private: + void init_output_static_infer_desc() override; + void init_output_dtype() override; +}; + +} // namespace opr +} // namespace mgb diff --git a/src/opr/test/dnn/tqt.cpp b/src/opr/test/dnn/tqt.cpp new file mode 100644 index 000000000..977ea03e8 --- /dev/null +++ b/src/opr/test/dnn/tqt.cpp @@ -0,0 +1,67 @@ +/** + * \file src/opr/test/dnn/tqt.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain/opr/dnn/tqt.h" +#include "megbrain/comp_node_env.h" +#include "megbrain/test/autocheck.h" + +using namespace std; +using namespace mgb; + +namespace { + +void run() { + using Checker = AutoOprChecker<2, 1>; + + auto make_graph = + [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { + auto o0 = opr::TQTForward::make(inputs[0], inputs[1]); + return {o0}; + }; + + auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { + auto opr = MegDNNHandle::get( + CompNodeEnv::from_comp_node(CompNode::default_cpu())) + ->create_operator(); + dest[0].dtype(dtype::Float32()) + .comp_node(inp[0]->comp_node()) + .resize(inp[0]->shape()); + opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), + {}); + }; + + auto gen = [&](HostTensorND& src) { + HostTensorGenerator + src_gen(10.f); + src = *src_gen(src.shape(), src.comp_node()); + }; + + Checker::RunOptions opt; + opt.numdiff_max_err = 1e-5; + + Checker checker{make_graph, fwd}; + checker.set_input_generator(0, gen) + .set_input_generator(1, gen) + .set_input_allow_grad(0, false) + .set_input_allow_grad(1, false) + .set_output_allow_grad(0, false); + checker.run({TensorShape{1, 2, 3, 4}, TensorShape{1}}, opt) + .run({TensorShape{2, 3, 8, 8}, TensorShape{1}}, opt) + .run({TensorShape{1, 3, 4, 4}, TensorShape{1}}, opt); +} + +} // anonymous namespace + +TEST(TestOprDNN, TQTForward) { + REQUIRE_GPU(1); + run(); +} diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index 67a9b90be..403d92244 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -105,6 +105,7 @@ union OperatorParam { param.NvOf = 71, param.DctChannelSelect = 72, param.FakeQuant = 73, + param.TQT = 74, } table Operator { -- GitLab
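For reference, a summary of the math the new kernels implement; the same formulas appear in the CUDA kern.cuh, the naive backend, and the NumPy reference in test_fake_quant.py. Here x is the input, s the learnable log2 of the quantization step, [qmin, qmax] the range from the TQT param, and L denotes the training loss (notation introduced only for this summary):

    t = 2^s
    y = t * clip(round(x / t), qmin, qmax)

    dL/dx = dL/dy                 on positions where round(x/t) is not clipped, else 0
    dL/ds = dL/dy * t * ln(2) * (round(x/t) - x/t)               on non-clipped positions
    dL/ds = dL/dy * t * ln(2) * clip(round(x/t), qmin, qmax)     on clipped positions

The kernels emit dL/ds per element; the reduction to the shape of scale is performed by the reduce_sum in MGB_IMPL_OPR_GRAD(TQTForward) above (and by .sum() in the removed Python TQT_Function).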