未验证 提交 6da6ff6a 编写于 作者: W wenbin 提交者: GitHub

SimplifyWithBasicOpsPass (#33637)

* simplify_with_basic

* fix

* scale factor
上级 478ea78b
......@@ -34,6 +34,26 @@ namespace ir {
*/
class Graph;
SimplifyWithBasicOpsPass::SimplifyWithBasicOpsPass() {
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("scale")
.IsNumGE(0.f)
.IsNumLE(1.f)
.End()
.AddAttr("bias")
.IsNumEQ(0.f)
.End()
.AddAttr("bias_after_scale")
.IsNumEQ(true)
.End();
}
void SimplifyWithBasicOpsPass::ApplyImpl(Graph* graph) const {
VLOG(3) << "Simplify the Graph with basic ops.";
std::unordered_set<const Node*> del_node_set;
......@@ -145,6 +165,11 @@ bool SimplifyWithBasicOpsPass::SimplifyDropout(
new_op_desc.SetAttr("bias", static_cast<float>(0));
new_op_desc.SetAttr("bias_after_scale", true);
if (!IsCompat(new_op_desc)) {
LOG(WARNING) << "Basic ops pass in scale op compat failed.";
return false;
}
auto* scale_op_node = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(dropout_x, scale_op_node);
IR_NODE_LINK_TO(scale_op_node, dropout_out);
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h"
namespace paddle {
namespace framework {
......@@ -26,7 +26,10 @@ namespace ir {
class Graph;
class Node;
class SimplifyWithBasicOpsPass : public Pass {
class SimplifyWithBasicOpsPass : public OpCompatSensiblePass {
public:
SimplifyWithBasicOpsPass();
protected:
void ApplyImpl(Graph* graph) const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册