提交 0871623a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!110 rewrite insn pattern generator in EmitInsn

Merge pull request !110 from wYann/insn_reduction_final
此差异已折叠。
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef EMIT_INSN_ARGS_CALCULATOR_H_
#define EMIT_INSN_ARGS_CALCULATOR_H_
namespace akg {
struct InsnArg {
int dst_m0{1};
int dst_m1{0};
std::vector<Expr> src_m0_list;
std::vector<Expr> src_m1_list;
int repeat{1};
int block_len{1};
int block_num{1};
int body_num{1};
int tail_len{0};
int dst_tail_offset{0};
std::vector<Expr> src_tail_offset_list;
};
struct Meta {
int block_size{0};
int src_block_size{0};
int dst_block_size{0};
int block_offset{0};
const float vec_rate{0.6};
Type src_dtype;
Type dst_dtype;
Type dtype;
bool cast{false};
bool tail{false};
bool scalar{false};
bool liner{false};
bool same_dst_src{false};
};
enum SplitStat { SUCCESS, NO_SPLIT, TAIL };
class InsnAxis {
public:
InsnAxis() = default;
InsnAxis(const For *for_stmt, const Array<StmtStoreInfo> &info_list);
virtual ~InsnAxis() = default;
bool IsValid();
void Print(const std::string &name = "");
int min{0};
int extent{0};
Var var;
int dst_stride{0};
int src_stride{0};
std::vector<int> src_stride_list;
std::vector<int> stride_list;
bool is_valid{true};
private:
Expr GetStrideByAxis(const Array<Var> &vars, const Array<Expr> &strides, Var obj_var);
};
using AxisIt = std::list<InsnAxis>::iterator;
std::list<InsnAxis> GetAxisList(const StmtInfo &for_info, const Array<StmtStoreInfo> &info_list);
Array<StmtStoreInfo> GetInfoList(const StmtStoreInfo &dst_info, const Array<StmtStoreInfo> &src_info_list);
int DivFloor(int a, int b);
void Print(std::list<InsnAxis> &axis_list);
class InsnArgsCalculator {
public:
InsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, const StmtInfo &for_info,
const std::string &intrin_name);
virtual ~InsnArgsCalculator() = default;
PatternResult ExportResult();
void CalAxis();
void InitArg();
virtual std::function<bool(const InsnAxis &)> GetStrideLambda();
virtual std::function<bool(const InsnAxis &)> GetM0LimitLambda();
virtual std::function<bool(const InsnAxis &)> GetM1LimitLambda();
std::function<bool(const InsnAxis &)> GetBlockStrideLimitLambda();
AxisIt GetAxisByLambda(const std::function<bool(const InsnAxis &)> &lambda);
InsnAxis ExtractAxis(AxisIt &it);
bool IsValid(AxisIt &it);
AxisIt GetVecAxisIt();
AxisIt GetBlockAxis();
AxisIt GetRepeatAxisIt();
InsnAxis GetRepeatAxis();
void SetArgMask(int len);
void SetArgBlockNum(int data_num);
void SetArgBlockLen(int data_len);
void SetArgM0(int dst_m0, int lsrc_m0, int rsrc_m0);
void SetArgM1(int dst_m1, int lsrc_m1, int rsrc_m1);
void SetArgRepeat(int repeat);
void BlockAxisReduction();
void RepeatAxisReduction();
void CastCaseReduction();
virtual void InsnReduction();
StmtInfo ExportForInfo();
Expr GetOffset(int stride_index);
InsnAxis GetInvalidAxis();
SplitStat SplitAxis(int extent, InsnAxis &axis);
std::list<InsnAxis> axis_list_;
protected:
InsnArg arg_;
Meta meta_;
StmtInfoList dst_info_list_;
StmtInfoList src_info_list_;
StmtStoreInfo dst_info_;
StmtInfo for_info_;
const std::string intrin_name_;
const int max_block_stride_{4};
};
class SingleVecInsnArgsCalculator : public InsnArgsCalculator {
public:
SingleVecInsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, const StmtInfo &for_info,
const std::string &intrin_name = "");
virtual ~SingleVecInsnArgsCalculator() override = default;
PatternResult GetInsnArgs();
};
class BinaryVecInsnArgsCalculator : public InsnArgsCalculator {
public:
BinaryVecInsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, const StmtInfo &for_info,
const std::string &mode, const std::string &intrin_name = "", bool expand_mask = true);
virtual ~BinaryVecInsnArgsCalculator() override = default;
PatternResult GetInsnArgs();
std::function<bool(const InsnAxis &)> GetM0LimitLambda();
std::function<bool(const InsnAxis &)> GetM1LimitLambda();
void InsnReduction();
private:
std::string mode_;
bool expand_mask_;
InsnAxis vec_axis_;
};
class LastAxisReduceInsnArgsCalculator : InsnArgsCalculator{
public:
LastAxisReduceInsnArgsCalculator(const StmtStoreInfo &dst_info, const StmtStoreInfo &src_info, const StmtInfo &for_info,
const std::string &intrin_name)
: InsnArgsCalculator({dst_info}, {src_info}, for_info, intrin_name),
dst_info(dst_info),
src_info(src_info),
for_info(for_info),
arg_info(ArgInfo(make_node<ArgInfoNode>())),
body_args(VectorArgInfo()),
tail_args(VectorArgInfo()),
intrin_name(intrin_name) {}
PatternResult GetInsnArgs();
~LastAxisReduceInsnArgsCalculator() = default;
protected:
Array<Var> GetPattern();
PatternResult GenResult(const Array<Var> &elim_var);
private:
void CalcParams();
struct Params {
Array<Var> src_var;
int block_size = 0;
int vec_max_len = 0;
int last_dim_shape = 0;
Expr insn_offset_scale_factor;
};
StmtStoreInfo dst_info;
StmtStoreInfo src_info;
StmtInfo for_info;
ArgInfo arg_info;
VectorArgInfo body_args;
VectorArgInfo tail_args;
Array<VectorArgInfo> mix_vec_arg_list;
std::string intrin_name;
Params params;
};
BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list,
const StmtInfoList &src_info_list, const StmtInfo &for_info,
StmtInfo &if_info, bool last_axis, int postfix);
ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list,
StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info,
bool enable_bisect = true);
} // namespace akg
#endif
\ No newline at end of file
此差异已折叠。
......@@ -35,7 +35,7 @@
#include "insn_info.h"
#include "insn_pattern.h"
#include "insn_emitter_multimask.h"
#include "insn_args_calculator.h"
namespace akg {
namespace ir {
/// Sort indexes
......@@ -71,8 +71,7 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
Array<Expr> call_args;
int call_cnt = 0;
if (intrin_name == "vector_dup" || intrin_name == "vadds" ||
intrin_name == "vmuls" || intrin_name == "vaxpy") {
if (intrin_name == "vector_dup" || intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == "vaxpy") {
auto GetCallInfo = [&intrin_name, &call_args, &call_cnt](const NodeRef &op) {
if (op.as<Call>() && op.as<Call>()->name == intrin_name) {
call_args = op.as<Call>()->args;
......@@ -82,8 +81,8 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
PostOrderVisit(op, GetCallInfo);
CHECK_EQ(call_cnt, 1);
}
SingleType insn_type {SingleType::SIMD};
Expr scalar_src {};
SingleType insn_type{SingleType::SIMD};
Expr scalar_src{};
if (intrin_name == "vector_dup") {
insn_type = SingleType::Vector_Dump;
src_info_list = {};
......@@ -93,10 +92,11 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
src_info_list = {src_info_list[0]};
scalar_src = call_args[1];
}
// check is single vector broadcast reduce mode exist
SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info);
auto params = generator.GetInsnArgs();
SingleVecInsnArgsCalculator args_calculator = SingleVecInsnArgsCalculator(dst_info_list, src_info_list, for_info, intrin_name);
PatternResult params = args_calculator.GetInsnArgs();
dst_info_list = params.dst_info_list;
src_info_list = params.src_info_list;
for_info = params.for_info;
......@@ -141,23 +141,16 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) {
src_info = src_info_list[0];
}
const int vec_max_len = GetVecMaxLen(dst_info->dtype_);
if (enable_bisect && GetIntConst(GetItem(src_info->shape_, -1)) > vec_max_len) {
CommentManager::GetInstance().AddComment("Bisect_optimize", "enabled");
auto wrapper =
SeparateComInfoToBisectionInfoList(dst_info_list, src_info_list, for_info, if_info, true, postfix);
return EmitCceBinaryVectorToBisectionReduction(wrapper, if_info, intrin_name);
} else {
CommentManager::GetInstance().AddComment("Pattern", arg_info.GetPattern());
ReduceLastAxisPatternGenerator generator =
ReduceLastAxisPatternGenerator(dst_info, src_info, for_info, intrin_name);
auto result = generator.GetInsnArgs();
arg_info = result.arg_info;
dst_info = result.dst_info_list[0];
src_info = result.src_info_list[0];
for_info = result.for_info;
return EmitCceBinaryVectorToReduceLastAxis(dst_info, src_info, if_info, for_info, arg_info, intrin_name);
}
CommentManager::GetInstance().AddComment("Pattern", arg_info.GetPattern());
LastAxisReduceInsnArgsCalculator args_calculator = LastAxisReduceInsnArgsCalculator(dst_info, src_info, for_info, intrin_name);
PatternResult result = args_calculator.GetInsnArgs();
arg_info = result.arg_info;
dst_info = result.dst_info_list[0];
src_info = result.src_info_list[0];
for_info = result.for_info;
return EmitCceBinaryVectorToReduceLastAxis(dst_info, src_info, if_info, for_info, arg_info, intrin_name);
}
case ARG_VECTOR_REDUCTION_BISECTION: {
CommentManager::GetInstance().AddComment("Compute_type", "reduction");
......@@ -192,7 +185,7 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
return FoldInsnWithForInfo(insn_list, if_info, for_info, stmt);
}
}
}
} // namespace ir
/// Function to emit scalar intrin
/// \param op - The input stmt to be emitted as intrin
......@@ -984,8 +977,9 @@ Stmt BinaryDropoutEmitter(const Stmt &op) {
src1.GetNode()->data_ = mask->buffer_var;
src1.GetNode()->data_alignment_ = GetInt32Const(mask->predicate);
SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info, "elewise");
auto params = generator.GetInsnArgs();
SingleVecInsnArgsCalculator args_calculator = SingleVecInsnArgsCalculator(dst_info_list, src_info_list, for_info);
PatternResult params = args_calculator.GetInsnArgs();
dst_info_list = params.dst_info_list;
src_info_list = params.src_info_list;
for_info = params.for_info;
......@@ -1484,8 +1478,10 @@ Stmt BinaryArgOpEmitter(const Stmt &op, const std::string &intrin_name) {
if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) {
src_info = src_info_list[0];
}
ReduceLastAxisPatternGenerator generator = ReduceLastAxisPatternGenerator(dst_info, src_info, for_info, intrin_name);
auto result = generator.GetInsnArgs();
LastAxisReduceInsnArgsCalculator args_calculator = LastAxisReduceInsnArgsCalculator(dst_info, src_info, for_info, intrin_name);
PatternResult result = args_calculator.GetInsnArgs();
arg_info = result.arg_info;
dst_info = result.dst_info_list[0];
src_info = result.src_info_list[0];
......
......@@ -104,10 +104,7 @@ StmtStoreInfo StmtStoreInfo::Copy() const {
StmtInfo StmtInfo::Copy() const {
auto stmt_info = StmtInfo();
stmt_info.ops_ = ops_;
for (auto var : vars_) {
auto new_var = Variable::make(var->type, var->name_hint);
stmt_info.vars_.push_back(new_var);
}
stmt_info.vars_ = vars_;
for (size_t i = 0; i < vars_.size(); ++i) {
for (size_t j = 0; j < stmt_info.ops_.size(); ++j) {
......
......@@ -276,15 +276,7 @@ struct BisectionInfoWrapper {
Map<std::string, Expr> dma_arg_info_map_;
};
struct InsnAxis {
int min{0};
int extent{0};
Var var;
int dst_stride{0};
int src_stride{0};
std::list<int> src_stride_list;
std::list<int> stride_list;
};
IterVar GetCceAxis();
......
......@@ -15,7 +15,6 @@
*/
#include "insn_pattern.h"
#include <tvm/runtime/packed_func.h>
#include <tvm/base.h>
#include <tvm/ir_pass.h>
......@@ -200,28 +199,6 @@ ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_
return arg_info;
}
/// Get first non zero shape from input shapes
/// \param dst_shape
/// \param src0_shape
/// \param src1_shape
/// \return
int PatternGenerator::GetNonZeroShape(const Expr &dst_shape, const Expr &src0_shape, const Expr &src1_shape) {
int shape = 0;
for (int val :
{GetInt32Const(dst_shape), GetInt32Const(src0_shape), src1_shape.defined() ? GetInt32Const(src1_shape) : 0}) {
if (val == 0) {
continue;
}
if (shape != 0 && val != shape) {
LOG(FATAL) << "Error: same var has different shapes. " << GetIntConst(dst_shape) << " "
<< GetIntConst(src0_shape);
}
shape = val;
}
CHECK(shape != 0) << "Error: all shapes are equal to 0.";
return shape;
}
/// In case
/// for (cc3) {
/// A[(cc3*16)] = (B[(cc3*16)] - C[(cc3*16)])
......@@ -432,25 +409,6 @@ void CleanZeroStrides(Array<StmtStoreInfo> &info_list) {
}
}
/// Swap axis in Array
/// \param var
/// \param shape
/// \param strides
/// \param idx1
/// \param idx2
void PatternGenerator::GetShapeInfoAndSwap(Array<Var> &var, Array<Expr> &shape, Array<Expr> &strides, int idx1,
int idx2) {
auto tmp_var = GetItem(var, idx1);
SetItem(var, idx1, GetItem(var, idx2));
SetItem(var, idx2, tmp_var);
auto tmp_shape = GetItem(shape, idx1);
SetItem(shape, idx1, GetItem(shape, idx2));
SetItem(shape, idx2, tmp_shape);
auto tmp_stride = GetItem(strides, idx1);
SetItem(strides, idx1, GetItem(strides, idx2));
SetItem(strides, idx2, tmp_stride);
}
/// Get insn args of load 2D intrin
/// \param intrin_name
/// \param dst_info_list
......@@ -856,6 +814,38 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
return arg_info_map;
}
/// Replace com_info's var with new for loop's var
/// \param info
/// \param old_for_info
/// \param new_for_info
void ReplaceVarWithNewForInfo(StmtStoreInfo &info, const StmtInfo &old_for_info, const StmtInfo &new_for_info) {
for (size_t i = 0; i < new_for_info.vars_.size(); ++i) {
for (size_t j = 0; j < info->var_.size(); ++j) {
if (info->var_[j]->name_hint == new_for_info.vars_[i]->name_hint) {
SetItem(info.GetNode()->var_, static_cast<int>(j), new_for_info.vars_[i]);
}
}
info.GetNode()->index_ = substitute(old_for_info.vars_[i], new_for_info.vars_[i], info->index_);
}
}
std::string GetBinaryVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
const std::string &intrin_name, bool enable_bisect) {
std::set<std::string> reduce_bisect_list = {"vadd", "vsub", "vmul", "vmax"};
std::string mode = "reduction";
if (IsElementwise(dst_info_list, src_info_list)) {
mode = "elewise";
} else if (IsBroadcast(dst_info_list, src_info_list)) {
mode = "broadcast";
} else if (IsLastAxisReduction(dst_info_list, src_info_list)) {
mode = "reduce_last_axis";
} else if (enable_bisect && reduce_bisect_list.count(intrin_name) != 0 &&
IsBisectionReduction(dst_info_list, src_info_list)) {
mode = "reduce_bisection";
}
return mode;
}
const char *const DummyLastVar = "cc_last";
TVM_REGISTER_API("cce_util.GetVecMask").set_body([](const TVMArgs args, TVMRetValue *ret) {
......
......@@ -37,220 +37,12 @@ struct PatternResult {
StmtInfo for_info;
};
class PatternGenerator {
public:
PatternGenerator(const StmtInfoList &dst_info_list, const StmtInfo &for_info)
: for_info(for_info),
not_this_pattern(-1.0f),
split_latency_coef(10.0f),
repeat_latency_coef(3.0f),
offset_latency_coef(0.1f) {
CHECK(!dst_info_list.empty());
dst_info = dst_info_list[0];
}
virtual ~PatternGenerator() = default;
virtual PatternResult GetInsnArgs() = 0;
protected:
int GetNonZeroShape(const Expr &dst_shape, const Expr &src0_shape, const Expr &src1_shape = Expr());
void GetShapeInfoAndSwap(Array<Var> &var, Array<Expr> &shape, Array<Expr> &strides, int idx1, int idx2);
virtual float Compute3DPatternMaskRate() { return not_this_pattern; }
virtual float Compute2DBlockPatternMaskRate() { return not_this_pattern; }
virtual float Compute2DPatternMaskRate() { return not_this_pattern; }
virtual float Compute1DPatternMaskRate() { return not_this_pattern; }
virtual Array<Var> Get3DPattern() { return {}; }
virtual Array<Var> Get2DBlockPattern() { return {}; }
virtual Array<Var> Get2DPattern() { return {}; }
virtual Array<Var> Get1DPattern() { return {}; }
virtual PatternResult GenResult(const Array<Var> &elim_var) = 0;
StmtStoreInfo dst_info;
StmtInfo for_info;
const float not_this_pattern;
const float split_latency_coef;
const float repeat_latency_coef;
const float offset_latency_coef;
};
class SingleVecPatternGenerator : public PatternGenerator {
public:
SingleVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
const StmtInfo &for_info, const std::string &mode = "elewise")
: PatternGenerator(dst_info_list, for_info),
arg_info(ArgInfo(make_node<ArgInfoNode>())),
body_args(VectorArgInfo()),
tail_args(VectorArgInfo()),
mode(mode) {
if (src_info_list.empty()) {
src_info = dst_info.Copy();
} else {
CHECK(!src_info_list.empty());
src_info = src_info_list[0];
}
}
~SingleVecPatternGenerator() override = default;
PatternResult GetInsnArgs() final;
protected:
float Compute3DPatternMaskRate() final;
float Compute2DBlockPatternMaskRate() final;
float Compute2DPatternMaskRate() final;
float Compute1DPatternMaskRate() final;
float Compute3DsPatternMaskRate();
float Compute2DRepeatPatternMaskRate();
Array<Var> Get3DPattern() final;
Array<Var> Get2DBlockPattern() final;
Array<Var> Get2DPattern() final;
Array<Var> Get1DPattern() final;
Array<Var> Get3DsPattern();
Array<Var> Get2DRepeatPattern();
PatternResult GenResult(const Array<Var> &elim_var) final;
private:
void CalcParams();
int GetLastDimShape(const Expr &dst_shape, const Expr &src_shape);
struct Params {
Array<Var> dst_var;
Array<Var> src_var;
Array<Expr> dst_shape;
Array<Expr> src_shape;
Array<Expr> dst_strides;
Array<Expr> src_strides;
int non_zero_shape1 = 0;
int non_zero_shape2 = 0;
int non_zero_shape3 = 0;
int all_points = 0;
int dst_block_size = 0;
int src_block_size = 0;
int mask_block_size = 0;
int dst_bits = 0;
int src_bits = 0;
int max_bits = 0;
int dst_vec_max_len = 0;
int vec_max_len = 0;
int block_offset = 0;
};
StmtStoreInfo src_info;
Params params;
ArgInfo arg_info;
VectorArgInfo body_args;
VectorArgInfo tail_args;
std::string mode;
Type data_type;
};
class BinaryVecPatternGenerator : public PatternGenerator {
public:
BinaryVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
const StmtInfo &for_info, const std::string &mode, bool expand_mask = true)
: PatternGenerator(dst_info_list, for_info),
src_info_list(src_info_list),
arg_info(ArgInfo(make_node<ArgInfoNode>())),
body_args(VectorArgInfo()),
tail_args(VectorArgInfo()),
empty_var(Var("")),
mode(mode),
expand_mask(expand_mask) {}
~BinaryVecPatternGenerator() override = default;
PatternResult GetInsnArgs() final;
protected:
float Compute3DPatternMaskRate() final;
float Compute2DBlockPatternMaskRate() final;
float Compute2DPatternMaskRate() final;
float Compute1DPatternMaskRate() final;
Array<Var> Get3DPattern() final;
Array<Var> Get2DBlockPattern() final;
Array<Var> Get2DPattern() final;
Array<Var> Get1DPattern() final;
PatternResult GenResult(const Array<Var> &elim_var) final;
private:
void CalcParams();
bool IsSamePatternComInfo(const StmtStoreInfo &info_a, const StmtStoreInfo &info_b);
bool IsNonZeroShapeEqual(const Array<Expr> &shape_list);
void AppendEmptyVar(StmtInfoList &info_list);
struct Params {
Array<Var> dst_var;
Array<Expr> dst_shape;
Array<Expr> dst_strides;
Array<Var> src_var0;
Array<Expr> src_shape0;
Array<Expr> src_strides0;
Array<Var> src_var1;
Array<Expr> src_shape1;
Array<Expr> src_strides1;
int non_zero_shape1 = 0;
int non_zero_shape2 = 0;
int non_zero_shape3 = 0;
int all_points = 0;
int block_size = 0;
int last_dim_shape = 0;
int vec_max_len = 0;
};
StmtInfoList src_info_list;
ArgInfo arg_info;
VectorArgInfo body_args;
VectorArgInfo tail_args;
Params params;
Var empty_var;
std::string mode;
bool expand_mask;
};
class ReduceLastAxisPatternGenerator : public PatternGenerator {
public:
ReduceLastAxisPatternGenerator(const StmtStoreInfo &dst_info, const StmtStoreInfo &src_info, const StmtInfo &for_info,
const std::string &intrin_name)
: PatternGenerator({dst_info}, for_info),
src_info(src_info),
arg_info(ArgInfo(make_node<ArgInfoNode>())),
body_args(VectorArgInfo()),
tail_args(VectorArgInfo()),
intrin_name(intrin_name) {}
PatternResult GetInsnArgs() final;
~ReduceLastAxisPatternGenerator() override = default;
protected:
float Compute2DBlockPatternMaskRate() final;
Array<Var> Get2DBlockPattern() final;
Array<Var> Get1DPattern() final;
PatternResult GenResult(const Array<Var> &elim_var) final;
private:
void CalcParams();
struct Params {
Array<Var> src_var;
int block_size = 0;
int vec_max_len = 0;
int last_dim_shape = 0;
Expr insn_offset_scale_factor;
};
StmtStoreInfo src_info;
ArgInfo arg_info;
VectorArgInfo body_args;
VectorArgInfo tail_args;
Array<VectorArgInfo> mix_vec_arg_list;
std::string intrin_name;
Params params;
};
std::string GetSingleVecComputationInfo(const Stmt &stmt, const std::string &intrin_name,
Array<StmtStoreInfo> &dst_info_list, Array<StmtStoreInfo> &src_info_list,
StmtInfo &if_info, StmtInfo &for_info, bool need_compact = true);
ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list,
StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info,
bool enable_bisect = true);
std::string GetBinaryVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
const std::string &intrin_name, bool enable_bisect = true);
ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_list, StmtInfo &for_info);
......@@ -277,10 +69,7 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
const StmtInfoList &src_info_list, StmtInfo &for_info,
Map<std::string, Expr> &ub_copy_pre, Map<std::string, Expr> &ub_copy_post);
BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list,
const StmtInfoList &src_info_list, const StmtInfo &for_info,
StmtInfo &if_info, bool last_axis, int postfix);
void ReplaceVarWithNewForInfo(StmtStoreInfo &info, const StmtInfo &old_for_info, const StmtInfo &new_for_info);
extern const char *const DummyLastVar;
} // namespace akg
#endif // EMIT_INSN_INSN_PATTERN_H_
此差异已折叠。
......@@ -21,6 +21,7 @@
#include "pass/ir_util.h"
#include "poly/poly_util.h"
#include "emit_insn/insn_emitter.h"
#include "emit_insn/ir_transform.h"
namespace akg {
namespace ir {
......@@ -475,6 +476,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma
}
stmt = UnalignedMad().Mutate(stmt);
stmt = RegCondition().Mutate(stmt);
stmt = ForVarUnique().Mutate(stmt);
return stmt;
}
} // namespace ir
......
......@@ -343,8 +343,12 @@ class BroadcastCalculate : public IRMutator {
};
Stmt MultiLastAxisReductions(Stmt stmt, bool is_dynamic = false) {
auto ori_stmt = stmt;
stmt = MultiLastAxisReduction().Mutate(stmt);
stmt = BroadcastCalculate(is_dynamic).Mutate(stmt);
if (!is_dynamic && !Equal(ori_stmt, stmt)) {
stmt = MergeLoops(stmt);
}
return stmt;
}
} // namespace ir
......
......@@ -21,7 +21,7 @@
#include <algorithm>
#include "emit_insn/insn_info.h"
#include "emit_insn/insn_pattern.h"
#include "emit_insn/insn_args_calculator.h"
namespace akg {
namespace ir {
......@@ -48,85 +48,63 @@ class TailSpliter : public IRMutator {
if (src_info_list.empty()) {
src_info_list = {dst_info.Copy()};
}
auto get_info_list = [](const StmtStoreInfo &dst_info, const Array<StmtStoreInfo> &src_info_list) {
Array<StmtStoreInfo> res;
res.push_back(dst_info.Copy());
for (auto it : src_info_list) {
res.push_back(it.Copy());
}
return res;
};
auto info_list = get_info_list(dst_info, src_info_list);
FillEmptyVar(info_list);
auto axis_list = GetAixsList(for_info, info_list);
auto get_last_axis_it = [](const std::list<InsnAxis> &axis_list) {
for (auto it = axis_list.begin(); it != axis_list.end(); it++) {
auto stride_list = it->stride_list;
if (!(std::any_of(stride_list.begin(), stride_list.end(), [](int stride) { return stride > 1; }) ||
std::all_of(stride_list.begin(), stride_list.end(), [](int stride) { return stride == 0; }))) {
return it;
}
}
return axis_list.end();
};
auto last_axis_it = get_last_axis_it(axis_list);
if (last_axis_it == axis_list.end()) {
return s;
}
auto last_axis = *last_axis_it;
auto last_axis_shape = last_axis.extent;
auto info_list = GetInfoList(dst_info, src_info_list);
FillEmptyVar(info_list);
int dst_block_size = GetUbBlkSize(dst_info->dtype_);
int src_block_size = GetUbBlkSize(src_info_list[0]->dtype_);
int block_size = dst_block_size > src_block_size ? dst_block_size : src_block_size;
int block_size = dst_block_size < src_block_size ? dst_block_size : src_block_size;
int cast_block_size = dst_block_size > src_block_size ? dst_block_size : src_block_size;
int vec_max_len = block_size * FULL_BLOCK_NUM;
if (last_axis_shape > vec_max_len && last_axis_shape % vec_max_len != 0) {
return Block::make(TailMake(s, last_axis, vec_max_len, false), TailMake(s, last_axis, vec_max_len, true));
}
if (last_axis_shape < vec_max_len * tail_rate_ && last_axis_shape > block_size &&
last_axis_shape % block_size != 0 && axis_list.size() > 1) {
return Block::make(TailMake(s, last_axis, block_size, false), TailMake(s, last_axis, block_size, true));
}
}
return IRMutator::Mutate_(op, s);
}
std::list<InsnAxis> GetAixsList(const StmtInfo &for_info, const Array<StmtStoreInfo> &info_list) {
std::list<InsnAxis> axis_list;
auto GetStrideByAxis = [](const Array<Var> &vars, const Array<Expr> &strides, Var obj_var) {
int index = 0;
for (auto var_it : vars) {
if (Equal(var_it, obj_var)) {
return strides[index];
auto args_calculator = InsnArgsCalculator(dst_info_list, src_info_list, for_info, "");
auto vec_axis_it = args_calculator.GetVecAxisIt();
bool cast = dst_block_size != src_block_size;
if (args_calculator.IsValid(vec_axis_it)) {
auto vec_axis = *vec_axis_it;
auto vec_axis_shape = vec_axis.extent;
if (vec_axis_shape >= vec_max_len) {
if (vec_axis_shape % vec_max_len != 0) {
return TailBlock(s, vec_axis, vec_max_len);
}
} else {
if (vec_axis_shape < vec_max_len * tail_rate_ && vec_axis_shape > cast_block_size &&
vec_axis_shape % cast_block_size != 0 && args_calculator.axis_list_.size() > 1) {
return TailBlock(s, vec_axis, cast_block_size);
}
}
index++;
}
return Expr(0);
};
for (auto it : for_info.ops_) {
InsnAxis axis;
auto for_stmt = it.as<For>();
CHECK(for_stmt);
axis.var = for_stmt->loop_var;
axis.extent = GetInt32Const(for_stmt->extent);
axis.min = GetInt32Const(for_stmt->min);
int index = 0;
for (auto it : info_list) {
auto stride = GetInt32Const(GetStrideByAxis(it->var_, it->strides_, axis.var));
axis.stride_list.push_back(stride);
if (index == 0) {
axis.dst_stride = stride;
} else {
axis.src_stride_list.push_back(stride);
if (!cast && (!args_calculator.IsValid(vec_axis_it) || vec_axis_it->extent <= cast_block_size * tail_rate_)) {
auto get_block_axis = [&](std::list<InsnAxis> &axis_list) {
InsnAxis block_axis;
block_axis.is_valid = false;
std::vector<InsnAxis> temp_axis_set;
auto block_stride_lambda = [&](int stride) { return stride % block_size == 0 && stride / block_size <= 4; };
for (auto axis : axis_list) {
if (std::all_of(axis.stride_list.begin(), axis.stride_list.end(), block_stride_lambda) &&
axis.dst_stride != 0 && axis.extent != 0 && axis.extent > FULL_BLOCK_NUM &&
axis.extent % FULL_BLOCK_NUM != 0) {
temp_axis_set.push_back(axis);
}
}
if (!temp_axis_set.empty()) {
return temp_axis_set[0];
} else {
return block_axis;
}
};
auto block_axis = get_block_axis(args_calculator.axis_list_);
if (block_axis.IsValid()) {
return TailBlock(s, block_axis, FULL_BLOCK_NUM);
}
index++;
}
axis_list.push_back(axis);
return s;
}
return axis_list;
return IRMutator::Mutate_(op, s);
}
Stmt TailBlock(const Stmt &s, const InsnAxis &tail_axis, int body_size) {
return Block::make(TailMake(s, tail_axis, body_size, false), TailMake(s, tail_axis, body_size, true));
}
Stmt TailMake(const Stmt &s, const InsnAxis &tail_axis, int body_size, bool is_tail) {
if (auto attr_stmt = s.as<AttrStmt>()) {
return AttrStmt::make(attr_stmt->node, attr_stmt->attr_key, attr_stmt->value,
......@@ -145,8 +123,7 @@ class TailSpliter : public IRMutator {
}
return For::make(for_stmt->loop_var, for_stmt->min, for_stmt->extent, for_stmt->for_type, for_stmt->device_api,
TailMake(for_stmt->body, tail_axis, body_size, is_tail));
}
}
if (s.as<Store>() && is_tail) {
return substitute(tail_axis.var, Add::make(Expr(tail_axis.extent / body_size * body_size), tail_axis.var), s);
}
......@@ -156,6 +133,20 @@ class TailSpliter : public IRMutator {
private:
const float tail_rate_{0.6};
const std::set<std::string> include_intrin_list_ = {
// binary vec
"vec_binary_add",
"vec_binary_sub",
"vec_binary_mul",
"vec_binary_min",
"vec_binary_max",
"vec_binary_div",
"vec_binary_and",
"vec_binary_or",
"vec_binary_vmadd",
"vec_binary_vmaddrelu",
"vec_binary_vmla",
// single vec
"vec_single_fabs",
"vec_single_log",
"vec_single_exp",
......@@ -165,20 +156,28 @@ class TailSpliter : public IRMutator {
"vec_single_rsqrt",
"vec_single_relu",
"vec_single_not",
// vector_scalar
"vec_single_muls",
"vec_single_adds",
// Mov
"broadcast",
"mask_broadcast",
// vector_cast
"vec_single_cast",
"vec_single_floor",
"vec_single_round",
"vec_single_ceil",
"vec_single_trunc",
// scalar case
"vector_dup",
"vmuls",
"vadds",
"vaxpy",
};
};
Stmt SplitTail(Stmt stmt) { return TailSpliter().Mutate(stmt); }
Stmt SplitTail(Stmt stmt) {
auto tail_spliter = TailSpliter();
auto first_round = tail_spliter.Mutate(stmt);
auto second_round = tail_spliter.Mutate(stmt);
return second_round;
}
} // namespace ir
} // namespace akg
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册