提交 342ec908 编写于 作者: martianzhang's avatar martianzhang

fix #255 key words format to lower case

上级 07870487
...@@ -107,7 +107,7 @@ func (q *Query4Audit) RulePrefixLike() Rule { ...@@ -107,7 +107,7 @@ func (q *Query4Audit) RulePrefixLike() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch expr := node.(type) { switch expr := node.(type) {
case *sqlparser.ComparisonExpr: case *sqlparser.ComparisonExpr:
if expr.Operator == "like" { if strings.ToLower(expr.Operator) == "like" {
switch sqlval := expr.Right.(type) { switch sqlval := expr.Right.(type) {
case *sqlparser.SQLVal: case *sqlparser.SQLVal:
// prefix like with '%', '_' // prefix like with '%', '_'
...@@ -130,7 +130,7 @@ func (q *Query4Audit) RuleEqualLike() Rule { ...@@ -130,7 +130,7 @@ func (q *Query4Audit) RuleEqualLike() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch expr := node.(type) { switch expr := node.(type) {
case *sqlparser.ComparisonExpr: case *sqlparser.ComparisonExpr:
if expr.Operator == "like" { if strings.ToLower(expr.Operator) == "like" {
switch sqlval := expr.Right.(type) { switch sqlval := expr.Right.(type) {
case *sqlparser.SQLVal: case *sqlparser.SQLVal:
// not start with '%', '_' && not end with '%', '_' // not start with '%', '_' && not end with '%', '_'
...@@ -397,7 +397,7 @@ func (q *Query4Audit) RuleOrderByRand() Rule { ...@@ -397,7 +397,7 @@ func (q *Query4Audit) RuleOrderByRand() Rule {
for _, order := range n { for _, order := range n {
switch expr := order.Expr.(type) { switch expr := order.Expr.(type) {
case *sqlparser.FuncExpr: case *sqlparser.FuncExpr:
if expr.Name.String() == "rand" { if strings.ToLower(expr.Name.String()) == "rand" {
rule = HeuristicRules["CLA.002"] rule = HeuristicRules["CLA.002"]
return false, nil return false, nil
} }
...@@ -761,7 +761,7 @@ func (q *Query4Audit) RuleTblCommentCheck() Rule { ...@@ -761,7 +761,7 @@ func (q *Query4Audit) RuleTblCommentCheck() Rule {
var rule = q.RuleOK() var rule = q.RuleOK()
switch node := q.Stmt.(type) { switch node := q.Stmt.(type) {
case *sqlparser.DDL: case *sqlparser.DDL:
if node.Action != "create" { if strings.ToLower(node.Action) != "create" {
return rule return rule
} }
if node.TableSpec == nil { if node.TableSpec == nil {
...@@ -968,7 +968,7 @@ func (q *Query4Audit) RuleSQLCalcFoundRows() Rule { ...@@ -968,7 +968,7 @@ func (q *Query4Audit) RuleSQLCalcFoundRows() Rule {
var rule = q.RuleOK() var rule = q.RuleOK()
tkns := ast.Tokenizer(q.Query) tkns := ast.Tokenizer(q.Query)
for _, tkn := range tkns { for _, tkn := range tkns {
if tkn.Val == "sql_calc_found_rows" { if strings.ToLower(tkn.Val) == "sql_calc_found_rows" {
rule = HeuristicRules["KWR.001"] rule = HeuristicRules["KWR.001"]
break break
} }
...@@ -1049,7 +1049,7 @@ func (idxAdv *IndexAdvisor) RuleImpossibleOuterJoin() Rule { ...@@ -1049,7 +1049,7 @@ func (idxAdv *IndexAdvisor) RuleImpossibleOuterJoin() Rule {
for _, l1 := range idxAdv.joinCond { for _, l1 := range idxAdv.joinCond {
for _, l2 := range l1 { for _, l2 := range l1 {
if l2.Table != "" && l2.Table != "dual" { if l2.Table != "" && strings.ToLower(l2.Table) != "dual" {
joinTables = append(joinTables, l2.Table) joinTables = append(joinTables, l2.Table)
} }
} }
...@@ -1192,7 +1192,7 @@ func (q *Query4Audit) RuleImpossibleWhere() Rule { ...@@ -1192,7 +1192,7 @@ func (q *Query4Audit) RuleImpossibleWhere() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) { switch n := node.(type) {
case *sqlparser.RangeCond: case *sqlparser.RangeCond:
if n.Operator == "between" { if strings.ToLower(n.Operator) == "between" {
from := 0 from := 0
to := 0 to := 0
switch s := n.From.(type) { switch s := n.From.(type) {
...@@ -1893,7 +1893,7 @@ func (q *Query4Audit) RuleSysdate() Rule { ...@@ -1893,7 +1893,7 @@ func (q *Query4Audit) RuleSysdate() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) { switch n := node.(type) {
case *sqlparser.FuncExpr: case *sqlparser.FuncExpr:
if n.Name.String() == "sysdate" { if strings.ToLower(n.Name.String()) == "sysdate" {
rule = HeuristicRules["FUN.004"] rule = HeuristicRules["FUN.004"]
return false, nil return false, nil
} }
...@@ -2161,7 +2161,7 @@ func (idxAdv *IndexAdvisor) RuleUpdatePrimaryKey() Rule { ...@@ -2161,7 +2161,7 @@ func (idxAdv *IndexAdvisor) RuleUpdatePrimaryKey() Rule {
return rule return rule
} }
for _, idx := range idxMeta.Rows { for _, idx := range idxMeta.Rows {
if idx.KeyName == "PRIMARY" { if strings.ToLower(idx.KeyName) == "primary" {
if col.Name == idx.ColumnName { if col.Name == idx.ColumnName {
rule = HeuristicRules["CLA.016"] rule = HeuristicRules["CLA.016"]
return rule return rule
...@@ -2310,7 +2310,7 @@ func (q *Query4Audit) RuleNot() Rule { ...@@ -2310,7 +2310,7 @@ func (q *Query4Audit) RuleNot() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) { switch n := node.(type) {
case *sqlparser.ComparisonExpr: case *sqlparser.ComparisonExpr:
if strings.HasPrefix(n.Operator, "not") { if strings.HasPrefix(strings.ToLower(n.Operator), "not") {
rule = HeuristicRules["ARG.011"] rule = HeuristicRules["ARG.011"]
return false, nil return false, nil
} }
...@@ -2359,7 +2359,7 @@ func (q *Query4Audit) RuleUNIONUsage() Rule { ...@@ -2359,7 +2359,7 @@ func (q *Query4Audit) RuleUNIONUsage() Rule {
var rule = q.RuleOK() var rule = q.RuleOK()
switch s := q.Stmt.(type) { switch s := q.Stmt.(type) {
case *sqlparser.Union: case *sqlparser.Union:
if s.Type == "union" { if strings.ToLower(s.Type) == "union" {
rule = HeuristicRules["SUB.002"] rule = HeuristicRules["SUB.002"]
} }
} }
...@@ -2435,11 +2435,11 @@ func (q *Query4Audit) RuleDataDrop() Rule { ...@@ -2435,11 +2435,11 @@ func (q *Query4Audit) RuleDataDrop() Rule {
var rule = q.RuleOK() var rule = q.RuleOK()
switch s := q.Stmt.(type) { switch s := q.Stmt.(type) {
case *sqlparser.DBDDL: case *sqlparser.DBDDL:
if s.Action == "drop" { if strings.ToLower(s.Action) == "drop" {
rule = HeuristicRules["SEC.003"] rule = HeuristicRules["SEC.003"]
} }
case *sqlparser.DDL: case *sqlparser.DDL:
if s.Action == "drop" || s.Action == "truncate" { if strings.ToLower(s.Action) == "drop" || strings.ToLower(s.Action) == "truncate" {
rule = HeuristicRules["SEC.003"] rule = HeuristicRules["SEC.003"]
} }
case *sqlparser.Delete: case *sqlparser.Delete:
...@@ -2523,7 +2523,7 @@ func (q *Query4Audit) RuleTruncateTable() Rule { ...@@ -2523,7 +2523,7 @@ func (q *Query4Audit) RuleTruncateTable() Rule {
var rule = q.RuleOK() var rule = q.RuleOK()
switch s := q.Stmt.(type) { switch s := q.Stmt.(type) {
case *sqlparser.DDL: case *sqlparser.DDL:
if s.Action == "truncate" { if strings.ToLower(s.Action) == "truncate" {
rule = HeuristicRules["SEC.001"] rule = HeuristicRules["SEC.001"]
} }
} }
...@@ -2536,7 +2536,7 @@ func (q *Query4Audit) RuleIn() Rule { ...@@ -2536,7 +2536,7 @@ func (q *Query4Audit) RuleIn() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) { switch n := node.(type) {
case *sqlparser.ComparisonExpr: case *sqlparser.ComparisonExpr:
switch n.Operator { switch strings.ToLower(n.Operator) {
case "in": case "in":
switch r := n.Right.(type) { switch r := n.Right.(type) {
case sqlparser.ValTuple: case sqlparser.ValTuple:
...@@ -2842,12 +2842,12 @@ func (q *Query4Audit) RulePKNotInt() Rule { ...@@ -2842,12 +2842,12 @@ func (q *Query4Audit) RulePKNotInt() Rule {
var pk sqlparser.ColIdent var pk sqlparser.ColIdent
switch s := q.Stmt.(type) { switch s := q.Stmt.(type) {
case *sqlparser.DDL: case *sqlparser.DDL:
if s.Action == "create" { if strings.ToLower(s.Action) == "create" {
if s.TableSpec == nil { if s.TableSpec == nil {
return rule return rule
} }
for _, idx := range s.TableSpec.Indexes { for _, idx := range s.TableSpec.Indexes {
if idx.Info.Type == "primary key" { if strings.ToLower(idx.Info.Type) == "primary key" {
if len(idx.Columns) == 1 { if len(idx.Columns) == 1 {
pk = idx.Columns[0].Column pk = idx.Columns[0].Column
break break
...@@ -2864,7 +2864,7 @@ func (q *Query4Audit) RulePKNotInt() Rule { ...@@ -2864,7 +2864,7 @@ func (q *Query4Audit) RulePKNotInt() Rule {
// 主键非int, bigint类型 // 主键非int, bigint类型
for _, col := range s.TableSpec.Columns { for _, col := range s.TableSpec.Columns {
if pk.String() == col.Name.String() { if pk.String() == col.Name.String() {
switch col.Type.Type { switch strings.ToLower(col.Type.Type) {
case "int", "bigint", "integer": case "int", "bigint", "integer":
if !col.Type.Unsigned { if !col.Type.Unsigned {
rule = HeuristicRules["KEY.007"] rule = HeuristicRules["KEY.007"]
...@@ -2971,7 +2971,7 @@ func (q *Query4Audit) RuleFulltextIndex() Rule { ...@@ -2971,7 +2971,7 @@ func (q *Query4Audit) RuleFulltextIndex() Rule {
for _, tk := range tks { for _, tk := range tks {
switch tk.Type { switch tk.Type {
case ast.TokenTypeWord: case ast.TokenTypeWord:
if strings.TrimSpace(strings.ToUpper(tk.Val)) == "FULLTEXT" { if strings.TrimSpace(strings.ToLower(tk.Val)) == "fulltext" {
rule = HeuristicRules["KEY.010"] rule = HeuristicRules["KEY.010"]
} }
default: default:
...@@ -3001,8 +3001,8 @@ func (q *Query4Audit) RuleTimestampDefault() Rule { ...@@ -3001,8 +3001,8 @@ func (q *Query4Audit) RuleTimestampDefault() Rule {
if option.Tp == tidb.ColumnOptionDefaultValue { if option.Tp == tidb.ColumnOptionDefaultValue {
hasDefault = true hasDefault = true
if err := option.Restore(ctx); err == nil { if err := option.Restore(ctx); err == nil {
if strings.HasPrefix(sb.String(), `DEFAULT '0`) || if strings.HasPrefix(strings.ToLower(sb.String()), `default '0`) ||
strings.HasPrefix(sb.String(), `DEFAULT 0`) { strings.HasPrefix(strings.ToLower(sb.String()), `default 0`) {
hasDefault = false hasDefault = false
} }
} }
...@@ -3034,8 +3034,8 @@ func (q *Query4Audit) RuleTimestampDefault() Rule { ...@@ -3034,8 +3034,8 @@ func (q *Query4Audit) RuleTimestampDefault() Rule {
if option.Tp == tidb.ColumnOptionDefaultValue { if option.Tp == tidb.ColumnOptionDefaultValue {
hasDefault = true hasDefault = true
if err := option.Restore(ctx); err == nil { if err := option.Restore(ctx); err == nil {
if strings.HasPrefix(sb.String(), `DEFAULT '0`) || if strings.HasPrefix(strings.ToLower(sb.String()), `default '0`) ||
strings.HasPrefix(sb.String(), `DEFAULT 0`) { strings.HasPrefix(strings.ToLower(sb.String()), `default 0`) {
hasDefault = false hasDefault = false
} }
} }
...@@ -3464,7 +3464,7 @@ func (q *Query4Audit) RuleColumnNotAllowType() Rule { ...@@ -3464,7 +3464,7 @@ func (q *Query4Audit) RuleColumnNotAllowType() Rule {
switch s := q.Stmt.(type) { switch s := q.Stmt.(type) {
case *sqlparser.DDL: case *sqlparser.DDL:
switch s.Action { switch strings.ToLower(s.Action) {
case "create", "alter": case "create", "alter":
tks := ast.Tokenize(q.Query) tks := ast.Tokenize(q.Query)
for _, tk := range tks { for _, tk := range tks {
...@@ -3536,7 +3536,7 @@ func (q *Query4Audit) RuleNoOSCKey() Rule { ...@@ -3536,7 +3536,7 @@ func (q *Query4Audit) RuleNoOSCKey() Rule {
var rule = q.RuleOK() var rule = q.RuleOK()
switch s := q.Stmt.(type) { switch s := q.Stmt.(type) {
case *sqlparser.DDL: case *sqlparser.DDL:
if s.Action == "create" { if strings.ToLower(s.Action) == "create" {
pkReg := regexp.MustCompile(`(?i)(primary\s+key)`) pkReg := regexp.MustCompile(`(?i)(primary\s+key)`)
if !pkReg.MatchString(q.Query) { if !pkReg.MatchString(q.Query) {
ukReg := regexp.MustCompile(`(?i)(unique\s+((key)|(index)))`) ukReg := regexp.MustCompile(`(?i)(unique\s+((key)|(index)))`)
...@@ -3605,7 +3605,7 @@ func (idxAdv *IndexAdvisor) RuleMaxTextColsCount() Rule { ...@@ -3605,7 +3605,7 @@ func (idxAdv *IndexAdvisor) RuleMaxTextColsCount() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch stmt := node.(type) { switch stmt := node.(type) {
case *sqlparser.DDL: case *sqlparser.DDL:
if stmt.Action != "alter" { if strings.ToLower(stmt.Action) != "alter" {
return true, nil return true, nil
} }
......
...@@ -1575,6 +1575,7 @@ func TestRuleSysdate(t *testing.T) { ...@@ -1575,6 +1575,7 @@ func TestRuleSysdate(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName()) common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqls := []string{ sqls := []string{
`select sysdate();`, `select sysdate();`,
`select Sysdate();`,
} }
for _, sql := range sqls { for _, sql := range sqls {
q, err := NewQuery4Audit(sql) q, err := NewQuery4Audit(sql)
...@@ -2435,6 +2436,7 @@ func TestRuleInjection(t *testing.T) { ...@@ -2435,6 +2436,7 @@ func TestRuleInjection(t *testing.T) {
{ {
`select benchmark(10, rand())`, `select benchmark(10, rand())`,
`select sleep(1)`, `select sleep(1)`,
`select Sleep(1)`,
`select get_lock('lock_name', 1)`, `select get_lock('lock_name', 1)`,
`select release_lock('lock_name')`, `select release_lock('lock_name')`,
}, },
...@@ -2542,6 +2544,7 @@ func TestRuleTruncateTable(t *testing.T) { ...@@ -2542,6 +2544,7 @@ func TestRuleTruncateTable(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName()) common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqls := []string{ sqls := []string{
`TRUNCATE TABLE tbl_name;`, `TRUNCATE TABLE tbl_name;`,
`truncate TABLE tbl_name;`,
} }
for _, sql := range sqls { for _, sql := range sqls {
q, err := NewQuery4Audit(sql) q, err := NewQuery4Audit(sql)
...@@ -2861,6 +2864,7 @@ func TestRulePKNotInt(t *testing.T) { ...@@ -2861,6 +2864,7 @@ func TestRulePKNotInt(t *testing.T) {
}, },
{ {
"CREATE TABLE tbl (a int unsigned auto_increment, b int, primary key(`a`)) engine=InnoDB;", "CREATE TABLE tbl (a int unsigned auto_increment, b int, primary key(`a`)) engine=InnoDB;",
"CREATE TABLE `tb` ( `id` Bigint unsigned NOT NULL AUTO_INCREMENT COMMENT 'auto id', Primary key (`id`) ) ENGINE = InnoDB COMMENT 'comment'",
}, },
} }
for _, sql := range sqls[0] { for _, sql := range sqls[0] {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册