From af576e9ae079f0965b78f588d4a7d93bbf64d66f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 24 Aug 2021 19:40:32 +0800 Subject: [PATCH] fix(mgb/gopt): fix auto padding for nhwc layout GitOrigin-RevId: 038e372cbecb6f14a408c87cf2d30eba020b4605 --- src/gopt/impl/fuse_nchw4_int8_preprocess.cpp | 36 +++++++- src/gopt/impl/layout_transform_pass.cpp | 31 ++++--- src/gopt/impl/profiler_impl.cpp | 83 +++++++++++------- src/gopt/impl/reformat_manager.cpp | 86 ++++++++++++++++--- .../megbrain/gopt/global_layout_transform.h | 1 + .../include/megbrain/gopt/reformat_manager.h | 38 +++++--- 6 files changed, 203 insertions(+), 72 deletions(-) diff --git a/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp b/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp index 4e45cc508..a16cec192 100644 --- a/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp +++ b/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp @@ -200,6 +200,15 @@ static inline bool is_nchw_nchw4_shuffle_vec( param.pattern[4] == 2; } +static inline bool is_shape_before_nhwc(const TensorShape& shape) { + return shape.ndim == 4 && shape[1] == 4; +} + +static inline bool is_nchw_nhwc_shuffle(const opr::Dimshuffle::Param param) { + return param.ndim == 4 && param.pattern[0] == 0 && param.pattern[1] == 2 && + param.pattern[2] == 3 && param.pattern[3] == 1; +} + template static inline bool is_immutable_equal(OperatorNodeBase* opr, T val, DTypeEnum dtype_enum) { @@ -276,14 +285,20 @@ std::unique_ptr FuseNCHW4Int8Preprocess::make() { auto inp0 = opr->input()[0]; return is_shape_nchw(inp0->shape()); }}; + SGM::Node shuffle_root{ opr::Dimshuffle::typeinfo(), - {{nchwx_reshape}}, + {{nchwx_reshape}, {broadcast_concat}}, [](OperatorNodeBase* opr) { auto& shuffle_opr = opr->cast_final(); auto& input_vec = shuffle_opr.input(); - return is_shape_before_nchw4(input_vec[0]->shape()) && - is_nchw_nchw4_shuffle_vec(shuffle_opr.param()); + bool nchw_nchw4_ok = + is_shape_before_nchw4(input_vec[0]->shape()) && + 
is_nchw_nchw4_shuffle_vec(shuffle_opr.param()); + bool nchw_nhwc_ok = + is_shape_before_nhwc(input_vec[0]->shape()) && + is_nchw_nhwc_shuffle(shuffle_opr.param()); + return nchw_nchw4_ok || nchw_nhwc_ok; }}; return shuffle_root; }; @@ -382,6 +397,19 @@ std::unique_ptr FuseNCHW4Int8Preprocess::make() { auto out_node = opr::RelayoutFormat::make( rewriter.get_var(src_node->output()[0]), param.mode, config); + const auto& outshp = opr->output(0)->shape(); + if (outshp.ndim == 4) { + auto shpvar = opr::GetVarShape::make(out_node); + auto cv = [&out_node](int v) { + return out_node.make_scalar(v); + }; + auto sub = [&shpvar, &cv](int idx) { + return opr::IndexAt::make(shpvar, {{0, cv(idx)}}); + }; + auto nhwc_shp = + opr::Concat::make({sub(0), sub(2), sub(3), sub(4)}, 0); + out_node = opr::Reshape::make(out_node, nhwc_shp); + } return out_node.node()->owner_opr(); } else { return serialization::copy_opr_shallow(*opr, new_inp, @@ -740,4 +768,4 @@ void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const { }; opt.graph().iter(on_opr); rewriter.apply_inplace(); -} \ No newline at end of file +} diff --git a/src/gopt/impl/layout_transform_pass.cpp b/src/gopt/impl/layout_transform_pass.cpp index d994cdcc3..0c0399a1a 100644 --- a/src/gopt/impl/layout_transform_pass.cpp +++ b/src/gopt/impl/layout_transform_pass.cpp @@ -92,19 +92,24 @@ void LayoutTransformPass::apply(OptState& opt) const { bool is_parameter = fmtcfg.valid() && fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT; - ReformatManager::ReformatImpl reformat; - ReformatManager::ReformatKey key{from, to, reformat_attribute, - var->dtype().enumv(), - var->dtype().enumv()}; - if (is_parameter) { - auto aligned_desc = make_aligned_desc(base_fmt, out_fmt); - reformat = ReformatManager::instance() - .auto_aligned_reformat_weight( - var, key, aligned_desc); - } else { - reformat = ReformatManager::instance() - .auto_aligned_reformat_featrue( - var, base_fmt, key); + // need relayout + if (from != to && 
!new_var->shape().is_scalar()) { + ReformatManager::ReformatImpl reformat; + ReformatManager::ReformatKey key{ + from, to, reformat_attribute, var->dtype().enumv(), + var->dtype().enumv()}; + if (is_parameter) { + auto aligned_desc = ReformatManager::make_aligned_desc( + base_fmt, out_fmt); + reformat = ReformatManager::instance() + .auto_aligned_reformat_weight( + var, key, aligned_desc); + } else { + reformat = ReformatManager::instance() + .auto_aligned_reformat_featrue( + var, base_fmt, key); + } + new_var = reformat({new_var}); } if (from != to && !new_var->shape().is_scalar()) new_var = reformat({new_var}); diff --git a/src/gopt/impl/profiler_impl.cpp b/src/gopt/impl/profiler_impl.cpp index a760245c5..7ab62c3cf 100644 --- a/src/gopt/impl/profiler_impl.cpp +++ b/src/gopt/impl/profiler_impl.cpp @@ -165,6 +165,7 @@ public: private: static constexpr float PROFILE_TIME_OUT = 1e7; + using ReformatAttribute = ReformatKey::Attribute; /*! * \brief profile opr format agnostic operators (like elemwise, elemwise multi type, typecvt etc.) * @@ -175,40 +176,48 @@ private: */ OperatorNodeRecord profile_operator( const OperatorNodeBase* opr, TensorFormats base_format, - const SmallVector& available_tensor_formats) const; + const SmallVector& available_tensor_formats, + ReformatAttribute extra_attribute = + ReformatAttribute::DEFAULT) const; float profile_operator(const OperatorNodeBase* opr, TensorFormats base_format, - TensorFormats tensor_format) const; + TensorFormats tensor_format, + ReformatAttribute extra_attribute = + ReformatAttribute::DEFAULT) const; /*! - * \brief profile opr format aware operators (like conv, deconv, conv_bias, etc.) + * \brief profile opr format aware operators (like conv, deconv, conv_bias, + * etc.) 
* * \param opr pointer to the operator node to be profiled * \param base_config the tensor formats configuration of base opr format - * \param config all the available configuration + * \param config all the available configuration * \return the operator node record */ OperatorNodeRecord profile_operator( const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config, - const SmallVector& available_configs) - const; + const SmallVector& available_configs, + ReformatAttribute extra_attribute = + ReformatAttribute::DEFAULT) const; float profile_operator(const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config, - const OprTensorFormatsConfiguration& config) const; + const OprTensorFormatsConfiguration& config, + ReformatAttribute extra_attribute = + ReformatAttribute::DEFAULT) const; /*! * \brief profile layout transform of the var node * * \param var pointer to the var node to be profiled - * \param base_format the original tensor formats in which the var node is stored - * \param available_tensor_formats the available tensor formats + * \param base_format the original tensor formats the var node is stored in + * \param available_tensor_formats the available tensor formats * \param extra_attribute the extra attributes (options) of the problem * \return the var node record */ VarNodeRecord profile_var_node( const VarNode* var, TensorFormats base_format, const SmallVector& available_tensor_formats, - ReformatKey::Attribute extra_attribute = - ReformatKey::Attribute::DEFAULT) const; + ReformatAttribute extra_attribute = + ReformatAttribute::DEFAULT) const; float profile_var_node(const VarNode* var, TensorFormats base_format, const ReformatKey& key) const; int m_runs; /// sample times of the profiler @@ -216,20 +225,23 @@ private: ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( const OperatorNodeBase* opr, TensorFormats base_format, - const SmallVector& available_tensor_formats) const { + const SmallVector& 
available_tensor_formats, + ReformatAttribute extra_attribute) const { OperatorNodeRecord record; record.opr = opr; auto& costs = record.costs; for (auto&& f : available_tensor_formats) { auto opr_format = tensor_formats_to_opr_format(f); - costs[opr_format] = profile_operator(opr, base_format, f); + costs[opr_format] = + profile_operator(opr, base_format, f, extra_attribute); } return record; } float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, TensorFormats base_format, - TensorFormats tensor_format) const { + TensorFormats tensor_format, + ReformatAttribute extra_attribute) const { auto graph = ComputingGraph::make(); graph->options().graph_opt_level = 0; graph->options().var_sanity_check_first_run = false; @@ -239,8 +251,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, auto&& cn = var->comp_node(); auto&& dtype = var->dtype(); auto dval = std::make_shared(cn, dtype); - auto aligned_tensor_shape = - make_aligned_tensor_shape(var, base_format, tensor_format); + auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape( + var, base_format, tensor_format, extra_attribute); dval->resize(aligned_tensor_shape); auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval); new_inps[i] = aligned_var.node(); @@ -263,8 +275,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config, - const SmallVector& available_configs) - const { + const SmallVector& available_configs, + ReformatAttribute extra_attribute) const { OperatorNodeRecord record; record.opr = opr; auto& costs = record.costs; @@ -273,7 +285,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( if (i.opr_format == OprFormat::NCHW && opr->input(0)->dtype().enumv() != DTypeEnum::Float32) continue; - costs[i.opr_format] = profile_operator(opr, base_config, i); + costs[i.opr_format] = + 
profile_operator(opr, base_config, i, extra_attribute); } return record; } @@ -281,7 +294,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( float ProfilerImpl::profile_operator( const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config, - const OprTensorFormatsConfiguration& config) const { + const OprTensorFormatsConfiguration& config, + ReformatAttribute extra_attribute) const { auto graph = ComputingGraph::make(); graph->options().graph_opt_level = 0; graph->options().var_sanity_check_first_run = false; @@ -297,18 +311,18 @@ float ProfilerImpl::profile_operator( TensorShape aligned_shape; if (config.input_tensor_types[i] == TensorType::WEIGHT) { mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT); - aligned_shape = make_aligned_weight_shape( + aligned_shape = ReformatManager::make_aligned_weight_shape( var, base_config.input_tensor_formats[i], config.input_tensor_formats[i], - config.output_tensor_formats[0]); + config.output_tensor_formats[0], extra_attribute); } else { mgb_assert(base_config.input_tensor_types[i] == config.input_tensor_types[i]); mgb_assert(base_config.input_tensor_types[i] == TensorType::FEATURE); - aligned_shape = make_aligned_tensor_shape( + aligned_shape = ReformatManager::make_aligned_tensor_shape( var, base_config.input_tensor_formats[i], - config.input_tensor_formats[i]); + config.input_tensor_formats[i], extra_attribute); } dval->resize(aligned_shape); auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval); @@ -357,7 +371,7 @@ float ProfilerImpl::profile_operator( ProfilerImpl::VarNodeRecord ProfilerImpl::profile_var_node( const VarNode* var, TensorFormats base_format, const SmallVector& available_tensor_formats, - ReformatKey::Attribute attribute) const { + ReformatAttribute attribute) const { VarNodeRecord record; record.var = var; auto& costs = record.costs; @@ -379,8 +393,8 @@ float ProfilerImpl::profile_var_node(const VarNode* var, auto&& cn = 
var->comp_node(); auto&& dtype = var->dtype(); auto dval = std::make_shared(cn, dtype); - auto aligned_tensor_shape = - make_aligned_tensor_shape(var, base_format, key.input_format); + auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape( + var, base_format, key.input_format, key.attribute); dval->resize(aligned_tensor_shape); auto graph = ComputingGraph::make(); graph->options().graph_opt_level = 0; @@ -468,13 +482,14 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( auto base_format = problem.base_format(); auto&& available_tensor_formats = problem.available_tensor_formats(); + auto&& reformat_attribute = problem.attribute().reformat_attribute; ProfilingResult profiling_result; auto& opr_record = profiling_result.opr_record; auto& var_record = profiling_result.var_record; for (auto&& var : vars) { - var_record[var] = - profile_var_node(var, base_format, available_tensor_formats); + var_record[var] = profile_var_node( + var, base_format, available_tensor_formats, reformat_attribute); } for (auto&& opr : oprs) { auto&& opr_configs = problem.opr_configs(); @@ -482,11 +497,12 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( if (find == opr_configs.end()) { if (skip_oprs.count(opr) > 0) { SmallVector tensor_formats = {base_format}; - opr_record[opr] = - profile_operator(opr, base_format, tensor_formats); + opr_record[opr] = profile_operator( + opr, base_format, tensor_formats, reformat_attribute); } else { opr_record[opr] = profile_operator(opr, base_format, - available_tensor_formats); + available_tensor_formats, + reformat_attribute); } } else { auto&& dispatchers = find->second; @@ -498,7 +514,8 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( } } auto base_config = problem.base_config(opr); - opr_record[opr] = profile_operator(opr, base_config, configs); + opr_record[opr] = profile_operator(opr, base_config, configs, + reformat_attribute); } } for (auto&& rpair : opr_record) { diff --git a/src/gopt/impl/reformat_manager.cpp 
b/src/gopt/impl/reformat_manager.cpp index d69cae29b..d37868385 100644 --- a/src/gopt/impl/reformat_manager.cpp +++ b/src/gopt/impl/reformat_manager.cpp @@ -21,7 +21,7 @@ using NamedTensorShape = megdnn::NamedTensorShape; using Dimension = megdnn::Dimension; namespace { -int gcd(const int& p, const int& q) { +static inline int gcd(const int& p, const int& q) { int x = p, y = q; while (y != 0) { if (x < y) { @@ -33,6 +33,47 @@ int gcd(const int& p, const int& q) { } return x; } + +static inline size_t extra_alignment( + ReformatManager::ReformatKey::Attribute attr, + TensorFormats target_formats, DType dt, size_t channel_alignment) { + using Attribute = ReformatManager::ReformatKey::Attribute; + if (attr & Attribute::AUTO_PADDING_NHWC) { + constexpr size_t alignment_in_bits = 32; + size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; + size_t extra_alignment = alignment_in_bits >= dtype_bits + ? alignment_in_bits / dtype_bits + : 1; + if (target_formats == TensorFormats::NHWC) + channel_alignment = extra_alignment * channel_alignment / + gcd(channel_alignment, extra_alignment); + return channel_alignment; + } + return channel_alignment; +} + +static inline std::tuple extra_alignment( + const ReformatManager::ReformatKey& key, DType dt, + size_t input_channel_alignment, size_t output_channel_alignment) { + using Attribute = ReformatManager::ReformatKey::Attribute; + if (key.attribute & Attribute::AUTO_PADDING_NHWC) { + constexpr size_t alignment_in_bits = 32; + size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; + size_t extra_alignment = alignment_in_bits >= dtype_bits + ? 
alignment_in_bits / dtype_bits + : 1; + if (key.input_format == TensorFormats::NHWC) + input_channel_alignment = + input_channel_alignment * extra_alignment / + gcd(input_channel_alignment, extra_alignment); + if (key.output_format == TensorFormats::NHWC) + output_channel_alignment = + output_channel_alignment * extra_alignment / + gcd(output_channel_alignment, extra_alignment); + return {input_channel_alignment, output_channel_alignment}; + } + return {input_channel_alignment, output_channel_alignment}; +} }; // namespace // =================== ReformatManager::ReformatKey ====================*/ @@ -293,7 +334,8 @@ ReformatManager::ReformatImpl ReformatManager::get( auto rst = find->second; return rst; } - mgb_assert(key.attribute == Attribute::DEFAULT); + mgb_assert(!(key.attribute & Attribute::IMAGE2D) && + !(key.attribute & Attribute::IC_SMALL)); auto&& i = key.input_format; auto&& o = key.output_format; auto ishp = tensor_formats_to_named_tensor_shape(i); @@ -346,6 +388,8 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_featrue( "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", input_alignment, output_alignment, input_shape.to_string().c_str()); + std::tie(input_alignment, output_alignment) = extra_alignment( + key, orig_var->dtype(), input_alignment, output_alignment); NamedTensorShape orig_shape = tensor_formats_to_named_tensor_shape(orig_format); size_t orig_channel = 0; @@ -451,6 +495,12 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", in_channel_alignment, out_channel_alignment, output_shape.to_string().c_str()); + in_channel_alignment = + ::extra_alignment(key.attribute, key.output_format, + orig_var->dtype(), in_channel_alignment); + out_channel_alignment = + ::extra_alignment(key.attribute, key.output_format, + orig_var->dtype(), out_channel_alignment); size_t aligned_in_channel = divup(in_channels, in_channel_alignment) * 
in_channel_alignment; if (extra_alignment.name == out_channel_name) { @@ -506,9 +556,9 @@ const ReformatManager& ReformatManager::instance() { return inst; } -TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var, - TensorFormats orig_formats, - TensorFormats target_formats) { +TensorShape ReformatManager::make_aligned_tensor_shape( + const VarNode* var, TensorFormats orig_formats, + TensorFormats target_formats, ReformatKey::Attribute extra_attribute) { using Dimension = megdnn::Dimension; static constexpr uint32_t UNDETERMINED_EXTENT = Dimension::UNDETERMINED_EXTENT; @@ -545,6 +595,15 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var, tshp[i] = oshp[idx] * factor; else tshp[i] = divup(oshp[idx], factor); + if (name == Dimension::Name::C) { + size_t channel_alignment = target_shape[i].stride(); + size_t channels = tshp[i] * channel_alignment; + size_t new_channel_alignment = + extra_alignment(extra_attribute, target_formats, + var->dtype(), channel_alignment); + tshp[i] = divup(channels, new_channel_alignment) * + new_channel_alignment / channel_alignment; + } } else { tshp[i] = target_shape[i].extent(); } @@ -552,11 +611,12 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var, return tshp; } -TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var, - TensorFormats orig_formats, - TensorFormats target_formats, - TensorFormats extra_formats) { - auto tshp = make_aligned_tensor_shape(var, orig_formats, target_formats); +TensorShape ReformatManager::make_aligned_weight_shape( + const VarNode* var, TensorFormats orig_formats, + TensorFormats target_formats, TensorFormats extra_formats, + ReformatKey::Attribute extra_attribute) { + auto tshp = make_aligned_tensor_shape(var, orig_formats, target_formats, + extra_attribute); auto extra_shape = tensor_formats_to_named_tensor_shape(extra_formats); using Dimension = megdnn::Dimension; static constexpr uint32_t UNDETERMINED_EXTENT = @@ -567,6 +627,9 @@ 
TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var, if (name == Dimension::Name::C && extra_shape[i].extent() == UNDETERMINED_EXTENT) { out_channel_alignment = extra_shape[i].stride(); + out_channel_alignment = + extra_alignment(extra_attribute, target_formats, + var->dtype(), out_channel_alignment); } } @@ -583,9 +646,8 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var, return tshp; } -ReformatManager::AlignmentDesc mgb::gopt::make_aligned_desc( +ReformatManager::AlignmentDesc ReformatManager::make_aligned_desc( TensorFormats weight_format, TensorFormats out_feature_format) { - using AlignmentDesc = ReformatManager::AlignmentDesc; using Name = Dimension::Name; auto weight_shape = tensor_formats_to_named_tensor_shape(weight_format); auto out_shape = tensor_formats_to_named_tensor_shape(out_feature_format); diff --git a/src/gopt/include/megbrain/gopt/global_layout_transform.h b/src/gopt/include/megbrain/gopt/global_layout_transform.h index 5d7297175..50a9b6158 100644 --- a/src/gopt/include/megbrain/gopt/global_layout_transform.h +++ b/src/gopt/include/megbrain/gopt/global_layout_transform.h @@ -143,6 +143,7 @@ public: TensorFormats base_format() const { return m_ctx.attribute().base_tensor_formats; } + Attribute attribute() const { return m_ctx.attribute(); } /*! 
* \brief return the tensor formats configuration of an operator in the * default op format diff --git a/src/gopt/include/megbrain/gopt/reformat_manager.h b/src/gopt/include/megbrain/gopt/reformat_manager.h index 7464dea2a..cef608e44 100644 --- a/src/gopt/include/megbrain/gopt/reformat_manager.h +++ b/src/gopt/include/megbrain/gopt/reformat_manager.h @@ -74,6 +74,7 @@ public: DEFAULT = 0, IMAGE2D = 1 << 0, IC_SMALL = 1 << 1, + AUTO_PADDING_NHWC = 1 << 2, }; TensorFormats input_format, output_format; DTypeEnum input_dtype, output_dtype; @@ -124,23 +125,40 @@ public: ReformatImpl auto_aligned_reformat_weight( const VarNode* orig_var, const ReformatKey& key, const AlignmentDesc& extra_alignment = {}) const; + + static TensorShape make_aligned_tensor_shape( + const VarNode* var, TensorFormats orig_formats, + TensorFormats target_formats, + ReformatKey::Attribute extra_attribute = + ReformatKey::Attribute::DEFAULT); + static TensorShape make_aligned_weight_shape( + const VarNode* var, TensorFormats orig_formats, + TensorFormats target_formats, TensorFormats extra_formats, + ReformatKey::Attribute extra_attribute = + ReformatKey::Attribute::DEFAULT); + static AlignmentDesc make_aligned_desc(TensorFormats weight_format, + TensorFormats out_feature_format); + static const ReformatManager& instance(); private: ReformatCache m_cache; }; -TensorShape make_aligned_tensor_shape(const VarNode* var, - TensorFormats orig_formats, - TensorFormats target_formats); - -TensorShape make_aligned_weight_shape(const VarNode* var, - TensorFormats orig_formats, - TensorFormats target_formats, - TensorFormats extra_formats); +MGB_DEF_ENUM_CLASS_BIT_OPR(ReformatManager::ReformatKey::Attribute); +// +//TensorShape make_aligned_tensor_shape( +// const VarNode* var, TensorFormats orig_formats, +// TensorFormats target_formats, +// ReformatManager::ReformatKey::Attribute extra_attribute = +// ReformatManager::ReformatKey::Attribute::DEFAULT); +// +//TensorShape make_aligned_weight_shape( +// const 
VarNode* var, TensorFormats orig_formats, +// TensorFormats target_formats, TensorFormats extra_formats, +// ReformatManager::ReformatKey::Attribute extra_attribute = +// ReformatManager::ReformatKey::Attribute::DEFAULT); -ReformatManager::AlignmentDesc make_aligned_desc( - TensorFormats weight_format, TensorFormats out_feature_format); } // namespace gopt } // namespace mgb -- GitLab