未验证 提交 8ca10db8 编写于 作者: 石晓伟 提交者: GitHub

make passes related to the device type, test=develop (#2012)

* make passes related to the device type, test=develop

* improve tips, test=develop
上级 13bbd2b8
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cudnn.h> #include <cudnn.h>
#include <limits>
#include <string> #include <string>
namespace paddle { namespace paddle {
...@@ -39,8 +40,8 @@ __device__ __forceinline__ half from_float<half>(float x) { ...@@ -39,8 +40,8 @@ __device__ __forceinline__ half from_float<half>(float x) {
template <> template <>
__device__ __forceinline__ int8_t from_float<int8_t>(float x) { __device__ __forceinline__ int8_t from_float<int8_t>(float x) {
x = fmaxf(x, INT8_MIN); x = fmaxf(x, std::numeric_limits<char>::min());
x = fminf(x, INT8_MAX); x = fminf(x, std::numeric_limits<char>::max());
return __float2int_rn(x); return __float2int_rn(x);
} }
......
...@@ -42,4 +42,5 @@ class ArgumentTypeDisplayPass : public DebugPass { ...@@ -42,4 +42,5 @@ class ArgumentTypeDisplayPass : public DebugPass {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(argument_type_display_pass, REGISTER_MIR_PASS(argument_type_display_pass,
paddle::lite::mir::ArgumentTypeDisplayPass); paddle::lite::mir::ArgumentTypeDisplayPass)
.SetTargets({TARGET(kAny)});
...@@ -34,4 +34,4 @@ bool RegisterDemoPass() { ...@@ -34,4 +34,4 @@ bool RegisterDemoPass() {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass); REGISTER_MIR_PASS(demo, paddle::lite::mir::DemoPass).SetTargets({TARGET(kAny)});
...@@ -69,4 +69,5 @@ class IdentityScaleEliminatePass : public ProgramPass { ...@@ -69,4 +69,5 @@ class IdentityScaleEliminatePass : public ProgramPass {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(identity_scale_eliminate_pass, REGISTER_MIR_PASS(identity_scale_eliminate_pass,
paddle::lite::mir::IdentityScaleEliminatePass); paddle::lite::mir::IdentityScaleEliminatePass)
.SetTargets({TARGET(kAny)});
...@@ -38,4 +38,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -38,4 +38,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_conv_activation_fuse_pass, REGISTER_MIR_PASS(lite_conv_activation_fuse_pass,
paddle::lite::mir::ConvActivationFusePass); paddle::lite::mir::ConvActivationFusePass)
.SetTargets({TARGET(kAny)});
...@@ -34,4 +34,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -34,4 +34,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass); REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
.SetTargets({TARGET(kAny)});
...@@ -38,4 +38,5 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -38,4 +38,5 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass, REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass,
paddle::lite::mir::ConvElementwiseFusePass); paddle::lite::mir::ConvElementwiseFusePass)
.SetTargets({TARGET(kAny)});
...@@ -33,4 +33,5 @@ void ElementwiseAddActivationFusePass::Apply( ...@@ -33,4 +33,5 @@ void ElementwiseAddActivationFusePass::Apply(
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass, REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
paddle::lite::mir::ElementwiseAddActivationFusePass); paddle::lite::mir::ElementwiseAddActivationFusePass)
.SetTargets({TARGET(kAny)});
...@@ -31,4 +31,5 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -31,4 +31,5 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass); REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
.SetTargets({TARGET(kAny)});
...@@ -35,4 +35,5 @@ void InterpolateFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -35,4 +35,5 @@ void InterpolateFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_interpolate_fuse_pass, REGISTER_MIR_PASS(lite_interpolate_fuse_pass,
paddle::lite::mir::InterpolateFusePass); paddle::lite::mir::InterpolateFusePass)
.SetTargets({TARGET(kAny)});
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "lite/core/mir/fusion/quant_dequant_fuse_pass.h" #include "lite/core/mir/fusion/quant_dequant_fuse_pass.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "lite/api/paddle_place.h"
#include "lite/core/mir/fusion/quant_dequant_op_fuser.h" #include "lite/core/mir/fusion/quant_dequant_op_fuser.h"
#include "lite/core/mir/pass_registry.h" #include "lite/core/mir/pass_registry.h"
...@@ -42,4 +43,5 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -42,4 +43,5 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass, REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass,
paddle::lite::mir::QuantDequantFusePass); paddle::lite::mir::QuantDequantFusePass)
.SetTargets({TARGET(kAny)});
...@@ -35,4 +35,5 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -35,4 +35,5 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass, REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass,
paddle::lite::mir::ShuffleChannelFusePass); paddle::lite::mir::ShuffleChannelFusePass)
.SetTargets({TARGET(kAny)});
...@@ -36,4 +36,5 @@ void TransposeSoftmaxTransposeFusePass::Apply( ...@@ -36,4 +36,5 @@ void TransposeSoftmaxTransposeFusePass::Apply(
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass, REGISTER_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass,
paddle::lite::mir::TransposeSoftmaxTransposeFusePass); paddle::lite::mir::TransposeSoftmaxTransposeFusePass)
.SetTargets({TARGET(kAny)});
...@@ -38,5 +38,5 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -38,5 +38,5 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(generate_program_pass, REGISTER_MIR_PASS(generate_program_pass, paddle::lite::mir::GenerateProgramPass)
paddle::lite::mir::GenerateProgramPass); .SetTargets({TARGET(kAny)});
...@@ -98,4 +98,5 @@ std::string Visualize(mir::SSAGraph* graph) { ...@@ -98,4 +98,5 @@ std::string Visualize(mir::SSAGraph* graph) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass); REGISTER_MIR_PASS(graph_visualze, paddle::lite::mir::GraphVisualizePass)
.SetTargets({TARGET(kAny)});
...@@ -71,4 +71,5 @@ class IoCopyKernelPickPass : public StmtPass { ...@@ -71,4 +71,5 @@ class IoCopyKernelPickPass : public StmtPass {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(io_copy_kernel_pick_pass, REGISTER_MIR_PASS(io_copy_kernel_pick_pass,
paddle::lite::mir::IoCopyKernelPickPass); paddle::lite::mir::IoCopyKernelPickPass)
.SetTargets({TARGET(kAny)});
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <set>
#include <string> #include <string>
#include "lite/core/mir/node.h" #include "lite/core/mir/node.h"
#include "lite/core/mir/ssa_graph.h" #include "lite/core/mir/ssa_graph.h"
...@@ -44,6 +46,13 @@ class Pass { ...@@ -44,6 +46,13 @@ class Pass {
void set_doc(const std::string& doc) { doc_ = doc; } void set_doc(const std::string& doc) { doc_ = doc; }
const std::string& doc() const { return doc_; } const std::string& doc() const { return doc_; }
void set_targets(const std::set<TargetType>& targets) { targets_ = targets; }
const std::set<TargetType>& targets() const { return targets_; }
bool is_supported_target(TargetType target) const {
if (targets_.find(TARGET(kAny)) != targets_.end()) return true;
return (targets_.find(target) != targets_.end());
}
Kind kind() const { return kind_; } Kind kind() const { return kind_; }
bool is_debug_pass() const { return kind_ == Kind::kDebug; } bool is_debug_pass() const { return kind_ == Kind::kDebug; }
bool is_program_pass() const { return kind_ == Kind::kProgramWise; } bool is_program_pass() const { return kind_ == Kind::kProgramWise; }
...@@ -55,6 +64,7 @@ class Pass { ...@@ -55,6 +64,7 @@ class Pass {
const Kind kind_; const Kind kind_;
std::string name_; std::string name_;
std::string doc_; std::string doc_;
std::set<TargetType> targets_;
}; };
// Different kinds. // Different kinds.
......
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
#pragma once #pragma once
#include <set>
#include <string> #include <string>
#include "lite/api/paddle_lite_factory_helper.h" #include "lite/api/paddle_lite_factory_helper.h"
#include "lite/api/paddle_place.h"
#include "lite/core/mir/pass_manager.h" #include "lite/core/mir/pass_manager.h"
namespace paddle { namespace paddle {
...@@ -24,12 +26,19 @@ namespace mir { ...@@ -24,12 +26,19 @@ namespace mir {
class PassRegistry { class PassRegistry {
public: public:
PassRegistry(const std::string& name, mir::Pass* pass) { PassRegistry(const std::string& name, mir::Pass* pass)
VLOG(2) << "Registry add MIR pass " << name; : name_(name), pass_(pass) {
PassManager::Global().AddNewPass(name, pass); PassManager::Global().AddNewPass(name_, pass_);
}
PassRegistry& SetTargets(const std::set<TargetType>& targets) {
pass_->set_targets(targets);
return *this;
} }
bool Touch() const { return true; } bool Touch() const { return true; }
private:
std::string name_;
mir::Pass* pass_;
}; };
} // namespace mir } // namespace mir
...@@ -41,4 +50,6 @@ class PassRegistry { ...@@ -41,4 +50,6 @@ class PassRegistry {
new class__); \ new class__); \
bool mir_pass_registry##name__##_fake() { \ bool mir_pass_registry##name__##_fake() { \
return mir_pass_registry##name__.Touch(); \ return mir_pass_registry##name__.Touch(); \
} } \
static paddle::lite::mir::PassRegistry mir_pass_registry_func_##name__ \
__attribute__((unused)) = mir_pass_registry##name__
...@@ -38,4 +38,5 @@ class RuntimeContextAssignPass : public StmtPass { ...@@ -38,4 +38,5 @@ class RuntimeContextAssignPass : public StmtPass {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(runtime_context_assign_pass, REGISTER_MIR_PASS(runtime_context_assign_pass,
paddle::lite::mir::RuntimeContextAssignPass); paddle::lite::mir::RuntimeContextAssignPass)
.SetTargets({TARGET(kAny)});
...@@ -132,4 +132,5 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -132,4 +132,5 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(static_kernel_pick_pass, REGISTER_MIR_PASS(static_kernel_pick_pass,
paddle::lite::mir::StaticKernelPickPass); paddle::lite::mir::StaticKernelPickPass)
.SetTargets({TARGET(kAny)});
...@@ -214,4 +214,5 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() { ...@@ -214,4 +214,5 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(generate_npu_program_pass, REGISTER_MIR_PASS(generate_npu_program_pass,
paddle::lite::mir::subgraph::GenerateNPUProgramPass); paddle::lite::mir::subgraph::GenerateNPUProgramPass)
.SetTargets({TARGET(kAny)});
...@@ -310,4 +310,5 @@ int SubgraphProgramPass::FuseSubgraph( ...@@ -310,4 +310,5 @@ int SubgraphProgramPass::FuseSubgraph(
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(subgraph_program_pass, REGISTER_MIR_PASS(subgraph_program_pass,
paddle::lite::mir::subgraph::SubgraphProgramPass); paddle::lite::mir::subgraph::SubgraphProgramPass)
.SetTargets({TARGET(kAny)});
...@@ -173,4 +173,5 @@ void TypeLayoutTransformPass::SetValidPlaces( ...@@ -173,4 +173,5 @@ void TypeLayoutTransformPass::SetValidPlaces(
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(type_layout_cast_pass, REGISTER_MIR_PASS(type_layout_cast_pass,
paddle::lite::mir::TypeLayoutTransformPass); paddle::lite::mir::TypeLayoutTransformPass)
.SetTargets({TARGET(kAny)});
...@@ -179,4 +179,5 @@ void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) { ...@@ -179,4 +179,5 @@ void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(type_precision_cast_pass, REGISTER_MIR_PASS(type_precision_cast_pass,
paddle::lite::mir::PrecisionCastPass); paddle::lite::mir::PrecisionCastPass)
.SetTargets({TARGET(kAny)});
...@@ -179,4 +179,5 @@ void TypeTargetTransformPass::SetValidPlaces( ...@@ -179,4 +179,5 @@ void TypeTargetTransformPass::SetValidPlaces(
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(type_target_cast_pass, REGISTER_MIR_PASS(type_target_cast_pass,
paddle::lite::mir::TypeTargetTransformPass); paddle::lite::mir::TypeTargetTransformPass)
.SetTargets({TARGET(kAny)});
...@@ -31,4 +31,5 @@ void VariablePlaceInferencePass::Apply(const std::unique_ptr<SSAGraph> &graph) { ...@@ -31,4 +31,5 @@ void VariablePlaceInferencePass::Apply(const std::unique_ptr<SSAGraph> &graph) {
} // namespace paddle } // namespace paddle
REGISTER_MIR_PASS(variable_place_inference_pass, REGISTER_MIR_PASS(variable_place_inference_pass,
paddle::lite::mir::VariablePlaceInferencePass); paddle::lite::mir::VariablePlaceInferencePass)
.SetTargets({TARGET(kAny)});
...@@ -153,9 +153,6 @@ class KernelRegistry final { ...@@ -153,9 +153,6 @@ class KernelRegistry final {
const std::string &name, const std::string &name,
typename KernelRegistryForTarget<Target, Precision, Layout>::creator_t typename KernelRegistryForTarget<Target, Precision, Layout>::creator_t
&&creator) { &&creator) {
VLOG(3) << "register for " << TargetToStr(Target) << ":"
<< PrecisionToStr(Precision) << "//"
<< GetKernelOffset<Target, Precision, Layout>();
using kernel_registor_t = using kernel_registor_t =
KernelRegistryForTarget<Target, Precision, Layout>; KernelRegistryForTarget<Target, Precision, Layout>;
auto &varient = registries_[GetKernelOffset<Target, Precision, Layout>()]; auto &varient = registries_[GetKernelOffset<Target, Precision, Layout>()];
...@@ -219,9 +216,6 @@ class KernelRegistor : public lite::Registor<KernelType> { ...@@ -219,9 +216,6 @@ class KernelRegistor : public lite::Registor<KernelType> {
public: public:
KernelRegistor(const std::string &op_type, const std::string &alias) KernelRegistor(const std::string &op_type, const std::string &alias)
: Registor<KernelType>([=] { : Registor<KernelType>([=] {
VLOG(3) << "Register kernel " << op_type << " for "
<< TargetToStr(target) << " " << PrecisionToStr(precision)
<< " " << DataLayoutToStr(layout) << " alias " << alias;
KernelRegistry::Global().Register<target, precision, layout>( KernelRegistry::Global().Register<target, precision, layout>(
op_type, [=]() -> std::unique_ptr<KernelType> { op_type, [=]() -> std::unique_ptr<KernelType> {
std::unique_ptr<KernelType> x(new KernelType); std::unique_ptr<KernelType> x(new KernelType);
......
...@@ -183,11 +183,22 @@ class Optimizer { ...@@ -183,11 +183,22 @@ class Optimizer {
// Specify the passes and run them. // Specify the passes and run them.
void RunPasses(const std::vector<std::string>& passes) { void RunPasses(const std::vector<std::string>& passes) {
for (auto& x : passes) { for (auto& x : passes) {
LOG(INFO) << "== Running pass " << x; LOG(INFO) << "== Running pass: " << x;
auto* pass = mir::PassManager::Global().LookUp(x); mir::Pass* pass = mir::PassManager::Global().LookUp(x);
CHECK(pass) << "Can not find pass: " << x; CHECK(pass) << "Can not find pass: " << x;
pass->Apply(graph_); bool supported = false;
LOG(INFO) << "== Running pass Done." << x; for (const auto& place : valid_places_) {
if (pass->is_supported_target(place.target)) {
supported = true;
}
}
if (!supported) {
LOG(WARNING) << "Skip " << x
<< " pass because the target does not match.";
} else {
pass->Apply(graph_);
LOG(INFO) << "== Finished running: " << x;
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册