提交 bdb331eb 编写于 作者: L looop5

Reorder device args to match the real args order, and fix incorrect args...

Reorder device args to match the real args order, and fix incorrect args when there are multiple device functions.
上级 5237a9a0
......@@ -99,10 +99,6 @@ add_definitions(-DDMLC_LOG_CUSTOMIZE=1)
if(USE_AKG_LOG)
add_definitions(-DUSE_AKG_LOG=1)
endif()
if(NOT USE_CUDA
OR ENABLE_AKG)
add_definitions("-DFIX_INPUT_ORDER_TVM")
endif()
# Generic compilation options
include(CheckCXXCompilerFlag)
......
......@@ -234,32 +234,25 @@ class HostDeviceSplitter : public IRMutator {
}
}
#ifdef FIX_INPUT_ORDER_TVM
std::shared_ptr<LoweredFuncNode> na = std::make_shared<LoweredFuncNode>();
for (unsigned i = 0; i < (unsigned)args_real.size(); i++) {
bool match = false;
for (unsigned j = 0; j < (unsigned)n->args.size(); j++) {
if (strcmp(args_real[i].get()->name_hint.c_str(), n->args[j].get()->name_hint.c_str()) == 0) {
na->args.push_back(n->args[j]);
match = true;
break;
} else {
continue;
}
// Reorder args to match args_real
Array<Var> ordered_args;
std::unordered_set<Var, NodeHash, NodeEqual> args_set;
std::unordered_set<Var, NodeHash, NodeEqual> args_real_set;
for (size_t i = 0; i < n->args.size(); ++i) {
args_set.insert(n->args[i]);
}
for (size_t i = 0; i < args_real.size(); ++i) {
args_real_set.insert(args_real[i]);
if (args_set.find(args_real[i]) != args_set.end()) {
ordered_args.push_back(args_real[i]);
}
if (!match) {
na->args.push_back(args_real[i]);
// mark handle data type.
for (auto kv : handle_data_type_) {
if (strcmp(args_real[i].get()->name_hint.c_str(), kv.first->name_hint.c_str()) == 0) {
n->handle_data_type.Set(args_real[i], kv.second);
}
}
}
for (size_t i = 0; i < n->args.size(); ++i) {
if (args_real_set.find(n->args[i]) == args_real_set.end()) {
ordered_args.push_back(n->args[i]);
}
}
n->args = na->args;
#endif
n->args = ordered_args;
LoweredFunc f_device(n);
Array<Expr> call_args;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册