diff --git a/advisor/heuristic.go b/advisor/heuristic.go index 780e38ee03826649b4d59aea3916b4f31f33fa26..a7c2e7e72982a1600a1be3e00f5e2806f60c4276 100644 --- a/advisor/heuristic.go +++ b/advisor/heuristic.go @@ -107,7 +107,7 @@ func (q *Query4Audit) RulePrefixLike() Rule { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { switch expr := node.(type) { case *sqlparser.ComparisonExpr: - if expr.Operator == "like" { + if strings.ToLower(expr.Operator) == "like" { switch sqlval := expr.Right.(type) { case *sqlparser.SQLVal: // prefix like with '%', '_' @@ -130,7 +130,7 @@ func (q *Query4Audit) RuleEqualLike() Rule { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { switch expr := node.(type) { case *sqlparser.ComparisonExpr: - if expr.Operator == "like" { + if strings.ToLower(expr.Operator) == "like" { switch sqlval := expr.Right.(type) { case *sqlparser.SQLVal: // not start with '%', '_' && not end with '%', '_' @@ -397,7 +397,7 @@ func (q *Query4Audit) RuleOrderByRand() Rule { for _, order := range n { switch expr := order.Expr.(type) { case *sqlparser.FuncExpr: - if expr.Name.String() == "rand" { + if strings.ToLower(expr.Name.String()) == "rand" { rule = HeuristicRules["CLA.002"] return false, nil } @@ -761,7 +761,7 @@ func (q *Query4Audit) RuleTblCommentCheck() Rule { var rule = q.RuleOK() switch node := q.Stmt.(type) { case *sqlparser.DDL: - if node.Action != "create" { + if strings.ToLower(node.Action) != "create" { return rule } if node.TableSpec == nil { @@ -968,7 +968,7 @@ func (q *Query4Audit) RuleSQLCalcFoundRows() Rule { var rule = q.RuleOK() tkns := ast.Tokenizer(q.Query) for _, tkn := range tkns { - if tkn.Val == "sql_calc_found_rows" { + if strings.ToLower(tkn.Val) == "sql_calc_found_rows" { rule = HeuristicRules["KWR.001"] break } @@ -1049,7 +1049,7 @@ func (idxAdv *IndexAdvisor) RuleImpossibleOuterJoin() Rule { for _, l1 := range idxAdv.joinCond { for _, l2 := range l1 { - if l2.Table != "" && l2.Table != "dual" { + if l2.Table != "" && strings.ToLower(l2.Table) != "dual" { joinTables = append(joinTables, l2.Table) } } @@ -1192,7 +1192,7 @@ func (q *Query4Audit) RuleImpossibleWhere() Rule { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { switch n := node.(type) { case *sqlparser.RangeCond: - if n.Operator == "between" { + if strings.ToLower(n.Operator) == "between" { from := 0 to := 0 switch s := n.From.(type) { @@ -1893,7 +1893,7 @@ func (q *Query4Audit) RuleSysdate() Rule { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { switch n := node.(type) { case *sqlparser.FuncExpr: - if n.Name.String() == "sysdate" { + if strings.ToLower(n.Name.String()) == "sysdate" { rule = HeuristicRules["FUN.004"] return false, nil } @@ -2161,7 +2161,7 @@ func (idxAdv *IndexAdvisor) RuleUpdatePrimaryKey() Rule { return rule } for _, idx := range idxMeta.Rows { - if idx.KeyName == "PRIMARY" { + if strings.ToLower(idx.KeyName) == "primary" { if col.Name == idx.ColumnName { rule = HeuristicRules["CLA.016"] return rule @@ -2310,7 +2310,7 @@ func (q *Query4Audit) RuleNot() Rule { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { switch n := node.(type) { case *sqlparser.ComparisonExpr: - if strings.HasPrefix(n.Operator, "not") { + if strings.HasPrefix(strings.ToLower(n.Operator), "not") { rule = HeuristicRules["ARG.011"] return false, nil } @@ -2359,7 +2359,7 @@ func (q *Query4Audit) RuleUNIONUsage() Rule { var rule = q.RuleOK() switch s := q.Stmt.(type) { case *sqlparser.Union: - if s.Type == "union" { + if strings.ToLower(s.Type) == "union" { rule = HeuristicRules["SUB.002"] } } @@ -2435,11 +2435,11 @@ func (q *Query4Audit) RuleDataDrop() Rule { var rule = q.RuleOK() switch s := q.Stmt.(type) { case *sqlparser.DBDDL: - if s.Action == "drop" { + if strings.ToLower(s.Action) == "drop" { rule = HeuristicRules["SEC.003"] } case *sqlparser.DDL: - if s.Action == "drop" || s.Action == "truncate" { + if strings.ToLower(s.Action) == "drop" || strings.ToLower(s.Action) == "truncate" { rule = HeuristicRules["SEC.003"] } case *sqlparser.Delete: @@ -2523,7 +2523,7 @@ func (q *Query4Audit) RuleTruncateTable() Rule { var rule = q.RuleOK() switch s := q.Stmt.(type) { case *sqlparser.DDL: - if s.Action == "truncate" { + if strings.ToLower(s.Action) == "truncate" { rule = HeuristicRules["SEC.001"] } } @@ -2536,7 +2536,7 @@ func (q *Query4Audit) RuleIn() Rule { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { switch n := node.(type) { case *sqlparser.ComparisonExpr: - switch n.Operator { + switch strings.ToLower(n.Operator) { case "in": switch r := n.Right.(type) { case sqlparser.ValTuple: @@ -2842,12 +2842,12 @@ func (q *Query4Audit) RulePKNotInt() Rule { var pk sqlparser.ColIdent switch s := q.Stmt.(type) { case *sqlparser.DDL: - if s.Action == "create" { + if strings.ToLower(s.Action) == "create" { if s.TableSpec == nil { return rule } for _, idx := range s.TableSpec.Indexes { - if idx.Info.Type == "primary key" { + if strings.ToLower(idx.Info.Type) == "primary key" { if len(idx.Columns) == 1 { pk = idx.Columns[0].Column break @@ -2864,7 +2864,7 @@ func (q *Query4Audit) RulePKNotInt() Rule { // 主键非int, bigint类型 for _, col := range s.TableSpec.Columns { if pk.String() == col.Name.String() { - switch col.Type.Type { + switch strings.ToLower(col.Type.Type) { case "int", "bigint", "integer": if !col.Type.Unsigned { rule = HeuristicRules["KEY.007"] @@ -2971,7 +2971,7 @@ func (q *Query4Audit) RuleFulltextIndex() Rule { for _, tk := range tks { switch tk.Type { case ast.TokenTypeWord: - if strings.TrimSpace(strings.ToUpper(tk.Val)) == "FULLTEXT" { + if strings.TrimSpace(strings.ToLower(tk.Val)) == "fulltext" { rule = HeuristicRules["KEY.010"] } default: @@ -3001,8 +3001,8 @@ func (q *Query4Audit) RuleTimestampDefault() Rule { if option.Tp == tidb.ColumnOptionDefaultValue { hasDefault = true if err := option.Restore(ctx); err == nil { - if strings.HasPrefix(sb.String(), `DEFAULT '0`) || - strings.HasPrefix(sb.String(), `DEFAULT 0`) { + if strings.HasPrefix(strings.ToLower(sb.String()), `default '0`) || + strings.HasPrefix(strings.ToLower(sb.String()), `default 0`) { hasDefault = false } } @@ -3034,8 +3034,8 @@ func (q *Query4Audit) RuleTimestampDefault() Rule { if option.Tp == tidb.ColumnOptionDefaultValue { hasDefault = true if err := option.Restore(ctx); err == nil { - if strings.HasPrefix(sb.String(), `DEFAULT '0`) || - strings.HasPrefix(sb.String(), `DEFAULT 0`) { + if strings.HasPrefix(strings.ToLower(sb.String()), `default '0`) || + strings.HasPrefix(strings.ToLower(sb.String()), `default 0`) { hasDefault = false } } @@ -3464,7 +3464,7 @@ func (q *Query4Audit) RuleColumnNotAllowType() Rule { switch s := q.Stmt.(type) { case *sqlparser.DDL: - switch s.Action { + switch strings.ToLower(s.Action) { case "create", "alter": tks := ast.Tokenize(q.Query) for _, tk := range tks { @@ -3536,7 +3536,7 @@ func (q *Query4Audit) RuleNoOSCKey() Rule { var rule = q.RuleOK() switch s := q.Stmt.(type) { case *sqlparser.DDL: - if s.Action == "create" { + if strings.ToLower(s.Action) == "create" { pkReg := regexp.MustCompile(`(?i)(primary\s+key)`) if !pkReg.MatchString(q.Query) { ukReg := regexp.MustCompile(`(?i)(unique\s+((key)|(index)))`) @@ -3605,7 +3605,7 @@ func (idxAdv *IndexAdvisor) RuleMaxTextColsCount() Rule { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { switch stmt := node.(type) { case *sqlparser.DDL: - if stmt.Action != "alter" { + if strings.ToLower(stmt.Action) != "alter" { return true, nil } diff --git a/advisor/heuristic_test.go b/advisor/heuristic_test.go index 8b2d63bdbb87a83d5f35006060d3167cdb662512..8fb25842523165f3e7885b26f0e9a98a97eeb0fe 100644 --- a/advisor/heuristic_test.go +++ b/advisor/heuristic_test.go @@ -1575,6 +1575,7 @@ func TestRuleSysdate(t *testing.T) { common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqls := []string{ `select sysdate();`, + `select Sysdate();`, } for _, sql := range sqls { q, err := NewQuery4Audit(sql) @@ -2435,6 +2436,7 @@ func TestRuleInjection(t *testing.T) { { `select benchmark(10, rand())`, `select sleep(1)`, + `select Sleep(1)`, `select get_lock('lock_name', 1)`, `select release_lock('lock_name')`, }, @@ -2542,6 +2544,7 @@ func TestRuleTruncateTable(t *testing.T) { common.Log.Debug("Entering function: %s", common.GetFunctionName()) sqls := []string{ `TRUNCATE TABLE tbl_name;`, + `truncate TABLE tbl_name;`, } for _, sql := range sqls { q, err := NewQuery4Audit(sql) @@ -2861,6 +2864,7 @@ func TestRulePKNotInt(t *testing.T) { }, { "CREATE TABLE tbl (a int unsigned auto_increment, b int, primary key(`a`)) engine=InnoDB;", + "CREATE TABLE `tb` ( `id` Bigint unsigned NOT NULL AUTO_INCREMENT COMMENT 'auto id', Primary key (`id`) ) ENGINE = InnoDB COMMENT 'comment'", }, } for _, sql := range sqls[0] {