Commit af576e9a authored by Megvii Engine Team

fix(mgb/gopt): fix auto padding for nhwc layout

GitOrigin-RevId: 038e372cbecb6f14a408c87cf2d30eba020b4605
Parent: af828ca9
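In brief: FuseNCHW4Int8Preprocess now also matches the NCHW→NHWC Dimshuffle pattern and reshapes the 5-D RelayoutFormat output back to a 4-D NHWC shape; make_aligned_tensor_shape, make_aligned_weight_shape, and make_aligned_desc move from free functions in mgb::gopt into static members of ReformatManager; and a new ReformatKey attribute, AUTO_PADDING_NHWC, pads the channel dimension of NHWC tensors to lcm(base_alignment, 32 / dtype_bits) channels so that each pixel spans a whole number of 32-bit words.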
@@ -200,6 +200,15 @@ static inline bool is_nchw_nchw4_shuffle_vec(
            param.pattern[4] == 2;
 }
 
+static inline bool is_shape_before_nhwc(const TensorShape& shape) {
+    return shape.ndim == 4 && shape[1] == 4;
+}
+
+static inline bool is_nchw_nhwc_shuffle(const opr::Dimshuffle::Param param) {
+    return param.ndim == 4 && param.pattern[0] == 0 && param.pattern[1] == 2 &&
+           param.pattern[2] == 3 && param.pattern[3] == 1;
+}
+
 template <typename T>
 static inline bool is_immutable_equal(OperatorNodeBase* opr, T val,
                                       DTypeEnum dtype_enum) {
@@ -276,14 +285,20 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
                 auto inp0 = opr->input()[0];
                 return is_shape_nchw(inp0->shape());
             }};
     SGM::Node shuffle_root{
             opr::Dimshuffle::typeinfo(),
-            {{nchwx_reshape}},
+            {{nchwx_reshape}, {broadcast_concat}},
             [](OperatorNodeBase* opr) {
                 auto& shuffle_opr = opr->cast_final<opr::Dimshuffle>();
                 auto& input_vec = shuffle_opr.input();
-                return is_shape_before_nchw4(input_vec[0]->shape()) &&
-                       is_nchw_nchw4_shuffle_vec(shuffle_opr.param());
+                bool nchw_nchw4_ok =
+                        is_shape_before_nchw4(input_vec[0]->shape()) &&
+                        is_nchw_nchw4_shuffle_vec(shuffle_opr.param());
+                bool nchw_nhwc_ok =
+                        is_shape_before_nhwc(input_vec[0]->shape()) &&
+                        is_nchw_nhwc_shuffle(shuffle_opr.param());
+                return nchw_nchw4_ok || nchw_nhwc_ok;
             }};
     return shuffle_root;
 };
@@ -382,6 +397,19 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
             auto out_node = opr::RelayoutFormat::make(
                     rewriter.get_var(src_node->output()[0]), param.mode,
                     config);
+            const auto& outshp = opr->output(0)->shape();
+            if (outshp.ndim == 4) {
+                auto shpvar = opr::GetVarShape::make(out_node);
+                auto cv = [&out_node](int v) {
+                    return out_node.make_scalar(v);
+                };
+                auto sub = [&shpvar, &cv](int idx) {
+                    return opr::IndexAt::make(shpvar, {{0, cv(idx)}});
+                };
+                auto nhwc_shp =
+                        opr::Concat::make({sub(0), sub(2), sub(3), sub(4)}, 0);
+                out_node = opr::Reshape::make(out_node, nhwc_shp);
+            }
             return out_node.node()->owner_opr();
         } else {
             return serialization::copy_opr_shallow(*opr, new_inp,
@@ -740,4 +768,4 @@ void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const {
     };
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
-}
\ No newline at end of file
+}
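Note on the hunks above: pattern {0, 2, 3, 1} is exactly the NCHW→NHWC permutation, and the Concat of shape components 0, 2, 3, 4 rebuilds a 4-D shape from the 5-D RelayoutFormat output (dropping component 1). A standalone sketch of the permutation, with hypothetical shape values not taken from the patch:

```cpp
// Hypothetical illustration (not part of the patch): applying the Dimshuffle
// pattern {0, 2, 3, 1} matched by is_nchw_nhwc_shuffle() to an NCHW shape
// yields the corresponding NHWC shape.
#include <array>
#include <cstdio>

int main() {
    const std::array<size_t, 4> nchw{8, 4, 32, 32};  // N, C, H, W
    constexpr int pattern[4] = {0, 2, 3, 1};         // NCHW -> NHWC
    std::array<size_t, 4> nhwc{};
    for (int i = 0; i < 4; ++i)
        nhwc[i] = nchw[pattern[i]];
    std::printf("%zu %zu %zu %zu\n", nhwc[0], nhwc[1], nhwc[2],
                nhwc[3]);  // prints: 8 32 32 4
    return 0;
}
```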
@@ -92,19 +92,24 @@ void LayoutTransformPass::apply(OptState& opt) const {
             bool is_parameter =
                     fmtcfg.valid() && fmtcfg.val().input_tensor_types[i] ==
                                               TensorType::WEIGHT;
-            ReformatManager::ReformatImpl reformat;
-            ReformatManager::ReformatKey key{from, to, reformat_attribute,
-                                             var->dtype().enumv(),
-                                             var->dtype().enumv()};
-            if (is_parameter) {
-                auto aligned_desc = make_aligned_desc(base_fmt, out_fmt);
-                reformat = ReformatManager::instance()
-                                   .auto_aligned_reformat_weight(
-                                           var, key, aligned_desc);
-            } else {
-                reformat = ReformatManager::instance()
-                                   .auto_aligned_reformat_featrue(
-                                           var, base_fmt, key);
-            }
-            if (from != to && !new_var->shape().is_scalar())
-                new_var = reformat({new_var});
+            // need relayout
+            if (from != to && !new_var->shape().is_scalar()) {
+                ReformatManager::ReformatImpl reformat;
+                ReformatManager::ReformatKey key{
+                        from, to, reformat_attribute, var->dtype().enumv(),
+                        var->dtype().enumv()};
+                if (is_parameter) {
+                    auto aligned_desc = ReformatManager::make_aligned_desc(
+                            base_fmt, out_fmt);
+                    reformat = ReformatManager::instance()
+                                       .auto_aligned_reformat_weight(
+                                               var, key, aligned_desc);
+                } else {
+                    reformat = ReformatManager::instance()
+                                       .auto_aligned_reformat_featrue(
+                                               var, base_fmt, key);
+                }
+                new_var = reformat({new_var});
+            }
......
@@ -165,6 +165,7 @@ public:
 private:
     static constexpr float PROFILE_TIME_OUT = 1e7;
+    using ReformatAttribute = ReformatKey::Attribute;
     /*!
      * \brief profile opr format agnostic operators (like elemwise, elemwise multi type, typecvt etc.)
      *
@@ -175,40 +176,48 @@ private:
      */
     OperatorNodeRecord profile_operator(
             const OperatorNodeBase* opr, TensorFormats base_format,
-            const SmallVector<TensorFormats>& available_tensor_formats) const;
+            const SmallVector<TensorFormats>& available_tensor_formats,
+            ReformatAttribute extra_attribute =
+                    ReformatAttribute::DEFAULT) const;
     float profile_operator(const OperatorNodeBase* opr,
                            TensorFormats base_format,
-                           TensorFormats tensor_format) const;
+                           TensorFormats tensor_format,
+                           ReformatAttribute extra_attribute =
+                                   ReformatAttribute::DEFAULT) const;
     /*!
-     * \brief profile opr format aware operators (like conv, deconv, conv_bias, etc.)
+     * \brief profile opr format aware operators (like conv, deconv, conv_bias,
+     * etc.)
      *
      * \param opr pointer to the operator node to be profiled
      * \param base_config the tensor formats configuration of base opr format
      * \param config all the available configuration
      * \return the operator node record
      */
     OperatorNodeRecord profile_operator(
             const OperatorNodeBase* opr,
             const OprTensorFormatsConfiguration& base_config,
-            const SmallVector<OprTensorFormatsConfiguration>& available_configs)
-            const;
+            const SmallVector<OprTensorFormatsConfiguration>& available_configs,
+            ReformatAttribute extra_attribute =
+                    ReformatAttribute::DEFAULT) const;
     float profile_operator(const OperatorNodeBase* opr,
                            const OprTensorFormatsConfiguration& base_config,
-                           const OprTensorFormatsConfiguration& config) const;
+                           const OprTensorFormatsConfiguration& config,
+                           ReformatAttribute extra_attribute =
+                                   ReformatAttribute::DEFAULT) const;
     /*!
      * \brief profile layout transform of the var node
      *
      * \param var pointer to the var node to be profiled
-     * \param base_format the original tensor formats in which the var node is stored
-     * \param available_tensor_formats the available tensor formats
+     * \param base_format the original tensor formats in which the var node is
+     * stored \param available_tensor_formats the available tensor formats
      * \param extra_attribute the extra attributes (options) of the problem
      * \return the var node record
      */
     VarNodeRecord profile_var_node(
             const VarNode* var, TensorFormats base_format,
             const SmallVector<TensorFormats>& available_tensor_formats,
-            ReformatKey::Attribute extra_attribute =
-                    ReformatKey::Attribute::DEFAULT) const;
+            ReformatAttribute extra_attribute =
+                    ReformatAttribute::DEFAULT) const;
     float profile_var_node(const VarNode* var, TensorFormats base_format,
                            const ReformatKey& key) const;
     int m_runs;  /// sample times of the profiler
@@ -216,20 +225,23 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
 ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
         const OperatorNodeBase* opr, TensorFormats base_format,
-        const SmallVector<TensorFormats>& available_tensor_formats) const {
+        const SmallVector<TensorFormats>& available_tensor_formats,
+        ReformatAttribute extra_attribute) const {
     OperatorNodeRecord record;
     record.opr = opr;
     auto& costs = record.costs;
     for (auto&& f : available_tensor_formats) {
         auto opr_format = tensor_formats_to_opr_format(f);
-        costs[opr_format] = profile_operator(opr, base_format, f);
+        costs[opr_format] =
+                profile_operator(opr, base_format, f, extra_attribute);
     }
     return record;
 }
 
 float ProfilerImpl::profile_operator(const OperatorNodeBase* opr,
                                      TensorFormats base_format,
-                                     TensorFormats tensor_format) const {
+                                     TensorFormats tensor_format,
+                                     ReformatAttribute extra_attribute) const {
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
     graph->options().var_sanity_check_first_run = false;
@@ -239,8 +251,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr,
         auto&& cn = var->comp_node();
         auto&& dtype = var->dtype();
         auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
-        auto aligned_tensor_shape =
-                make_aligned_tensor_shape(var, base_format, tensor_format);
+        auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
+                var, base_format, tensor_format, extra_attribute);
         dval->resize(aligned_tensor_shape);
         auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
         new_inps[i] = aligned_var.node();
@@ -263,8 +275,8 @@ float ProfilerImpl::profile_operator(
 ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
         const OperatorNodeBase* opr,
         const OprTensorFormatsConfiguration& base_config,
-        const SmallVector<OprTensorFormatsConfiguration>& available_configs)
-        const {
+        const SmallVector<OprTensorFormatsConfiguration>& available_configs,
+        ReformatAttribute extra_attribute) const {
     OperatorNodeRecord record;
     record.opr = opr;
     auto& costs = record.costs;
@@ -273,7 +285,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
         if (i.opr_format == OprFormat::NCHW &&
             opr->input(0)->dtype().enumv() != DTypeEnum::Float32)
             continue;
-        costs[i.opr_format] = profile_operator(opr, base_config, i);
+        costs[i.opr_format] =
+                profile_operator(opr, base_config, i, extra_attribute);
     }
     return record;
 }
@@ -281,7 +294,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
 float ProfilerImpl::profile_operator(
         const OperatorNodeBase* opr,
         const OprTensorFormatsConfiguration& base_config,
-        const OprTensorFormatsConfiguration& config) const {
+        const OprTensorFormatsConfiguration& config,
+        ReformatAttribute extra_attribute) const {
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
     graph->options().var_sanity_check_first_run = false;
@@ -297,18 +311,18 @@ float ProfilerImpl::profile_operator(
         TensorShape aligned_shape;
         if (config.input_tensor_types[i] == TensorType::WEIGHT) {
             mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT);
-            aligned_shape = make_aligned_weight_shape(
+            aligned_shape = ReformatManager::make_aligned_weight_shape(
                     var, base_config.input_tensor_formats[i],
                     config.input_tensor_formats[i],
-                    config.output_tensor_formats[0]);
+                    config.output_tensor_formats[0], extra_attribute);
         } else {
             mgb_assert(base_config.input_tensor_types[i] ==
                        config.input_tensor_types[i]);
             mgb_assert(base_config.input_tensor_types[i] ==
                        TensorType::FEATURE);
-            aligned_shape = make_aligned_tensor_shape(
+            aligned_shape = ReformatManager::make_aligned_tensor_shape(
                     var, base_config.input_tensor_formats[i],
-                    config.input_tensor_formats[i]);
+                    config.input_tensor_formats[i], extra_attribute);
         }
         dval->resize(aligned_shape);
         auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
@@ -357,7 +371,7 @@ float ProfilerImpl::profile_operator(
 ProfilerImpl::VarNodeRecord ProfilerImpl::profile_var_node(
         const VarNode* var, TensorFormats base_format,
         const SmallVector<TensorFormats>& available_tensor_formats,
-        ReformatKey::Attribute attribute) const {
+        ReformatAttribute attribute) const {
     VarNodeRecord record;
     record.var = var;
     auto& costs = record.costs;
@@ -379,8 +393,8 @@ float ProfilerImpl::profile_var_node(const VarNode* var,
     auto&& cn = var->comp_node();
     auto&& dtype = var->dtype();
     auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
-    auto aligned_tensor_shape =
-            make_aligned_tensor_shape(var, base_format, key.input_format);
+    auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
+            var, base_format, key.input_format, key.attribute);
     dval->resize(aligned_tensor_shape);
     auto graph = ComputingGraph::make();
     graph->options().graph_opt_level = 0;
@@ -468,13 +482,14 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
     auto base_format = problem.base_format();
     auto&& available_tensor_formats = problem.available_tensor_formats();
+    auto&& reformat_attribute = problem.attribute().reformat_attribute;
     ProfilingResult profiling_result;
     auto& opr_record = profiling_result.opr_record;
     auto& var_record = profiling_result.var_record;
     for (auto&& var : vars) {
-        var_record[var] =
-                profile_var_node(var, base_format, available_tensor_formats);
+        var_record[var] = profile_var_node(
+                var, base_format, available_tensor_formats, reformat_attribute);
     }
     for (auto&& opr : oprs) {
         auto&& opr_configs = problem.opr_configs();
@@ -482,11 +497,12 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
         if (find == opr_configs.end()) {
             if (skip_oprs.count(opr) > 0) {
                 SmallVector<TensorFormats> tensor_formats = {base_format};
-                opr_record[opr] =
-                        profile_operator(opr, base_format, tensor_formats);
+                opr_record[opr] = profile_operator(
+                        opr, base_format, tensor_formats, reformat_attribute);
             } else {
                 opr_record[opr] = profile_operator(opr, base_format,
-                                                   available_tensor_formats);
+                                                   available_tensor_formats,
+                                                   reformat_attribute);
             }
         } else {
             auto&& dispatchers = find->second;
@@ -498,7 +514,8 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(
             }
         }
         auto base_config = problem.base_config(opr);
-        opr_record[opr] = profile_operator(opr, base_config, configs);
+        opr_record[opr] = profile_operator(opr, base_config, configs,
+                                           reformat_attribute);
     }
 }
 for (auto&& rpair : opr_record) {
......
@@ -21,7 +21,7 @@ using NamedTensorShape = megdnn::NamedTensorShape;
 using Dimension = megdnn::Dimension;
 namespace {
-int gcd(const int& p, const int& q) {
+static inline int gcd(const int& p, const int& q) {
     int x = p, y = q;
     while (y != 0) {
         if (x < y) {
@@ -33,6 +33,47 @@ int gcd(const int& p, const int& q) {
     }
     return x;
 }
+
+static inline size_t extra_alignment(
+        ReformatManager::ReformatKey::Attribute attr,
+        TensorFormats target_formats, DType dt, size_t channel_alignment) {
+    using Attribute = ReformatManager::ReformatKey::Attribute;
+    if (attr & Attribute::AUTO_PADDING_NHWC) {
+        constexpr size_t alignment_in_bits = 32;
+        size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8;
+        size_t extra_alignment = alignment_in_bits >= dtype_bits
+                                         ? alignment_in_bits / dtype_bits
+                                         : 1;
+        if (target_formats == TensorFormats::NHWC)
+            channel_alignment = extra_alignment * channel_alignment /
+                                gcd(channel_alignment, extra_alignment);
+        return channel_alignment;
+    }
+    return channel_alignment;
+}
+
+static inline std::tuple<size_t, size_t> extra_alignment(
+        const ReformatManager::ReformatKey& key, DType dt,
+        size_t input_channel_alignment, size_t output_channel_alignment) {
+    using Attribute = ReformatManager::ReformatKey::Attribute;
+    if (key.attribute & Attribute::AUTO_PADDING_NHWC) {
+        constexpr size_t alignment_in_bits = 32;
+        size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8;
+        size_t extra_alignment = alignment_in_bits >= dtype_bits
+                                         ? alignment_in_bits / dtype_bits
+                                         : 1;
+        if (key.input_format == TensorFormats::NHWC)
+            input_channel_alignment =
+                    input_channel_alignment * extra_alignment /
+                    gcd(input_channel_alignment, extra_alignment);
+        if (key.output_format == TensorFormats::NHWC)
+            output_channel_alignment =
+                    output_channel_alignment * extra_alignment /
+                    gcd(output_channel_alignment, extra_alignment);
+        return {input_channel_alignment, output_channel_alignment};
+    }
+    return {input_channel_alignment, output_channel_alignment};
+}
 };  // namespace
 // =================== ReformatManager::ReformatKey ====================*/
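The two extra_alignment overloads above implement the padding rule: under AUTO_PADDING_NHWC, the channel alignment of an NHWC tensor is raised to lcm(channel_alignment, 32 / dtype_bits), so that each pixel occupies a whole number of 32-bit words. A minimal standalone sketch of that arithmetic, with a hypothetical int8 dtype:

```cpp
// Standalone sketch of the lcm computation used by extra_alignment() above;
// the int8 example values are hypothetical.
#include <cstdio>

static size_t gcd(size_t p, size_t q) { return q == 0 ? p : gcd(q, p % q); }

int main() {
    constexpr size_t alignment_in_bits = 32;
    const size_t dtype_bits = 8;                           // e.g. an int8 tensor
    const size_t extra = alignment_in_bits / dtype_bits;   // 4
    size_t channel_alignment = 1;                          // base alignment of plain NHWC
    channel_alignment =
            extra * channel_alignment / gcd(channel_alignment, extra);
    std::printf("channel alignment -> %zu\n", channel_alignment);  // prints: 4
    return 0;
}
```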
@@ -293,7 +334,8 @@ ReformatManager::ReformatImpl ReformatManager::get(
         auto rst = find->second;
         return rst;
     }
-    mgb_assert(key.attribute == Attribute::DEFAULT);
+    mgb_assert(!(key.attribute & Attribute::IMAGE2D) &&
+               !(key.attribute & Attribute::IC_SMALL));
     auto&& i = key.input_format;
     auto&& o = key.output_format;
     auto ishp = tensor_formats_to_named_tensor_shape(i);
@@ -346,6 +388,8 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_featrue(
                "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)",
                input_alignment, output_alignment,
                input_shape.to_string().c_str());
+    std::tie(input_alignment, output_alignment) = extra_alignment(
+            key, orig_var->dtype(), input_alignment, output_alignment);
     NamedTensorShape orig_shape =
             tensor_formats_to_named_tensor_shape(orig_format);
     size_t orig_channel = 0;
@@ -451,6 +495,12 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
                "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)",
                in_channel_alignment, out_channel_alignment,
                output_shape.to_string().c_str());
+    in_channel_alignment =
+            ::extra_alignment(key.attribute, key.output_format,
+                              orig_var->dtype(), in_channel_alignment);
+    out_channel_alignment =
+            ::extra_alignment(key.attribute, key.output_format,
+                              orig_var->dtype(), out_channel_alignment);
     size_t aligned_in_channel =
             divup(in_channels, in_channel_alignment) * in_channel_alignment;
     if (extra_alignment.name == out_channel_name) {
@@ -506,9 +556,9 @@ const ReformatManager& ReformatManager::instance() {
     return inst;
 }
 
-TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var,
-                                                 TensorFormats orig_formats,
-                                                 TensorFormats target_formats) {
+TensorShape ReformatManager::make_aligned_tensor_shape(
+        const VarNode* var, TensorFormats orig_formats,
+        TensorFormats target_formats, ReformatKey::Attribute extra_attribute) {
     using Dimension = megdnn::Dimension;
     static constexpr uint32_t UNDETERMINED_EXTENT =
             Dimension::UNDETERMINED_EXTENT;
@@ -545,6 +595,15 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var,
                 tshp[i] = oshp[idx] * factor;
             else
                 tshp[i] = divup(oshp[idx], factor);
+            if (name == Dimension::Name::C) {
+                size_t channel_alignment = target_shape[i].stride();
+                size_t channels = tshp[i] * channel_alignment;
+                size_t new_channel_alignment =
+                        extra_alignment(extra_attribute, target_formats,
+                                        var->dtype(), channel_alignment);
+                tshp[i] = divup(channels, new_channel_alignment) *
+                          new_channel_alignment / channel_alignment;
+            }
         } else {
             tshp[i] = target_shape[i].extent();
         }
@@ -552,11 +611,12 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var,
     return tshp;
 }
 
-TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var,
-                                                 TensorFormats orig_formats,
-                                                 TensorFormats target_formats,
-                                                 TensorFormats extra_formats) {
-    auto tshp = make_aligned_tensor_shape(var, orig_formats, target_formats);
+TensorShape ReformatManager::make_aligned_weight_shape(
+        const VarNode* var, TensorFormats orig_formats,
+        TensorFormats target_formats, TensorFormats extra_formats,
+        ReformatKey::Attribute extra_attribute) {
+    auto tshp = make_aligned_tensor_shape(var, orig_formats, target_formats,
+                                          extra_attribute);
     auto extra_shape = tensor_formats_to_named_tensor_shape(extra_formats);
     using Dimension = megdnn::Dimension;
     static constexpr uint32_t UNDETERMINED_EXTENT =
@@ -567,6 +627,9 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var,
         if (name == Dimension::Name::C &&
             extra_shape[i].extent() == UNDETERMINED_EXTENT) {
             out_channel_alignment = extra_shape[i].stride();
+            out_channel_alignment =
+                    extra_alignment(extra_attribute, target_formats,
+                                    var->dtype(), out_channel_alignment);
         }
     }
@@ -583,9 +646,8 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var,
     return tshp;
 }
 
-ReformatManager::AlignmentDesc mgb::gopt::make_aligned_desc(
+ReformatManager::AlignmentDesc ReformatManager::make_aligned_desc(
         TensorFormats weight_format, TensorFormats out_feature_format) {
-    using AlignmentDesc = ReformatManager::AlignmentDesc;
     using Name = Dimension::Name;
     auto weight_shape = tensor_formats_to_named_tensor_shape(weight_format);
     auto out_shape = tensor_formats_to_named_tensor_shape(out_feature_format);
......
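Tracing the new channel-padding branch of make_aligned_tensor_shape by hand, for a hypothetical int8 NHWC tensor of shape {1, 3, 3, 3}: the base channel alignment (the stride of the C dimension) is 1, AUTO_PADDING_NHWC raises it to 4, and the channel extent is rounded up accordingly:

```cpp
// Hypothetical worked example of the divup-based channel padding above:
// an int8 NHWC tensor {1, 3, 3, 3} gets its channel dim padded to 4.
#include <cstdio>

static size_t divup(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
    const size_t channel_alignment = 1;      // stride of C in plain NHWC
    const size_t tshp_c = 3;                 // channel extent before padding
    const size_t channels = tshp_c * channel_alignment;       // 3
    const size_t new_channel_alignment = 4;  // extra_alignment(): 32 / 8 bits
    const size_t padded = divup(channels, new_channel_alignment) *
                          new_channel_alignment / channel_alignment;
    std::printf("C: %zu -> %zu\n", tshp_c, padded);  // prints: C: 3 -> 4
    return 0;
}
```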
@@ -143,6 +143,7 @@ public:
     TensorFormats base_format() const {
         return m_ctx.attribute().base_tensor_formats;
     }
+    Attribute attribute() const { return m_ctx.attribute(); }
     /*!
      * \brief return the tensor formats configuration of an operator in the
      * default op format
......
@@ -74,6 +74,7 @@ public:
         DEFAULT = 0,
         IMAGE2D = 1 << 0,
         IC_SMALL = 1 << 1,
+        AUTO_PADDING_NHWC = 1 << 2,
     };
     TensorFormats input_format, output_format;
     DTypeEnum input_dtype, output_dtype;
@@ -124,23 +125,40 @@ public:
     ReformatImpl auto_aligned_reformat_weight(
             const VarNode* orig_var, const ReformatKey& key,
             const AlignmentDesc& extra_alignment = {}) const;
+    static TensorShape make_aligned_tensor_shape(
+            const VarNode* var, TensorFormats orig_formats,
+            TensorFormats target_formats,
+            ReformatKey::Attribute extra_attribute =
+                    ReformatKey::Attribute::DEFAULT);
+    static TensorShape make_aligned_weight_shape(
+            const VarNode* var, TensorFormats orig_formats,
+            TensorFormats target_formats, TensorFormats extra_formats,
+            ReformatKey::Attribute extra_attribute =
+                    ReformatKey::Attribute::DEFAULT);
+    static AlignmentDesc make_aligned_desc(TensorFormats weight_format,
+                                           TensorFormats out_feature_format);
     static const ReformatManager& instance();
 
 private:
     ReformatCache m_cache;
 };
+MGB_DEF_ENUM_CLASS_BIT_OPR(ReformatManager::ReformatKey::Attribute);
 
-TensorShape make_aligned_tensor_shape(const VarNode* var,
-                                      TensorFormats orig_formats,
-                                      TensorFormats target_formats);
-
-TensorShape make_aligned_weight_shape(const VarNode* var,
-                                      TensorFormats orig_formats,
-                                      TensorFormats target_formats,
-                                      TensorFormats extra_formats);
-
-ReformatManager::AlignmentDesc make_aligned_desc(
-        TensorFormats weight_format, TensorFormats out_feature_format);
+//
+//TensorShape make_aligned_tensor_shape(
+//        const VarNode* var, TensorFormats orig_formats,
+//        TensorFormats target_formats,
+//        ReformatManager::ReformatKey::Attribute extra_attribute =
+//                ReformatManager::ReformatKey::Attribute::DEFAULT);
+//
+//TensorShape make_aligned_weight_shape(
+//        const VarNode* var, TensorFormats orig_formats,
+//        TensorFormats target_formats, TensorFormats extra_formats,
+//        ReformatManager::ReformatKey::Attribute extra_attribute =
+//                ReformatManager::ReformatKey::Attribute::DEFAULT);
 
 } // namespace gopt
 } // namespace mgb
......
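Finally, since Attribute is now a bit-flag enum (note the new MGB_DEF_ENUM_CLASS_BIT_OPR line), expressions like key.attribute & Attribute::AUTO_PADDING_NHWC in the asserts and helpers above become legal. A plain-C++ sketch of that pattern, with hand-written operators standing in for the macro:

```cpp
// Standalone sketch (plain C++, no MegEngine) of the bit-flag pattern that
// MGB_DEF_ENUM_CLASS_BIT_OPR enables for ReformatKey::Attribute: flags can
// be OR-ed together and tested individually, as the new asserts in
// ReformatManager::get() do. The operators below are hypothetical stand-ins.
#include <cstdio>
#include <type_traits>

enum class Attribute : unsigned {
    DEFAULT = 0,
    IMAGE2D = 1 << 0,
    IC_SMALL = 1 << 1,
    AUTO_PADDING_NHWC = 1 << 2,
};
constexpr Attribute operator|(Attribute a, Attribute b) {
    using U = std::underlying_type_t<Attribute>;
    return Attribute(U(a) | U(b));
}
constexpr bool operator&(Attribute a, Attribute b) {
    using U = std::underlying_type_t<Attribute>;
    return (U(a) & U(b)) != 0;
}

int main() {
    const Attribute attr = Attribute::AUTO_PADDING_NHWC;
    std::printf("image2d: %d, auto_padding_nhwc: %d\n",
                attr & Attribute::IMAGE2D,
                attr & Attribute::AUTO_PADDING_NHWC);
    // prints: image2d: 0, auto_padding_nhwc: 1
    return 0;
}
```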