提交 342ec908 编写于 作者: martianzhang's avatar martianzhang

fix #255 key words format to lower case

上级 07870487
......@@ -107,7 +107,7 @@ func (q *Query4Audit) RulePrefixLike() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch expr := node.(type) {
case *sqlparser.ComparisonExpr:
if expr.Operator == "like" {
if strings.ToLower(expr.Operator) == "like" {
switch sqlval := expr.Right.(type) {
case *sqlparser.SQLVal:
// prefix like with '%', '_'
......@@ -130,7 +130,7 @@ func (q *Query4Audit) RuleEqualLike() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch expr := node.(type) {
case *sqlparser.ComparisonExpr:
if expr.Operator == "like" {
if strings.ToLower(expr.Operator) == "like" {
switch sqlval := expr.Right.(type) {
case *sqlparser.SQLVal:
// not start with '%', '_' && not end with '%', '_'
......@@ -397,7 +397,7 @@ func (q *Query4Audit) RuleOrderByRand() Rule {
for _, order := range n {
switch expr := order.Expr.(type) {
case *sqlparser.FuncExpr:
if expr.Name.String() == "rand" {
if strings.ToLower(expr.Name.String()) == "rand" {
rule = HeuristicRules["CLA.002"]
return false, nil
}
......@@ -761,7 +761,7 @@ func (q *Query4Audit) RuleTblCommentCheck() Rule {
var rule = q.RuleOK()
switch node := q.Stmt.(type) {
case *sqlparser.DDL:
if node.Action != "create" {
if strings.ToLower(node.Action) != "create" {
return rule
}
if node.TableSpec == nil {
......@@ -968,7 +968,7 @@ func (q *Query4Audit) RuleSQLCalcFoundRows() Rule {
var rule = q.RuleOK()
tkns := ast.Tokenizer(q.Query)
for _, tkn := range tkns {
if tkn.Val == "sql_calc_found_rows" {
if strings.ToLower(tkn.Val) == "sql_calc_found_rows" {
rule = HeuristicRules["KWR.001"]
break
}
......@@ -1049,7 +1049,7 @@ func (idxAdv *IndexAdvisor) RuleImpossibleOuterJoin() Rule {
for _, l1 := range idxAdv.joinCond {
for _, l2 := range l1 {
if l2.Table != "" && l2.Table != "dual" {
if l2.Table != "" && strings.ToLower(l2.Table) != "dual" {
joinTables = append(joinTables, l2.Table)
}
}
......@@ -1192,7 +1192,7 @@ func (q *Query4Audit) RuleImpossibleWhere() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) {
case *sqlparser.RangeCond:
if n.Operator == "between" {
if strings.ToLower(n.Operator) == "between" {
from := 0
to := 0
switch s := n.From.(type) {
......@@ -1893,7 +1893,7 @@ func (q *Query4Audit) RuleSysdate() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) {
case *sqlparser.FuncExpr:
if n.Name.String() == "sysdate" {
if strings.ToLower(n.Name.String()) == "sysdate" {
rule = HeuristicRules["FUN.004"]
return false, nil
}
......@@ -2161,7 +2161,7 @@ func (idxAdv *IndexAdvisor) RuleUpdatePrimaryKey() Rule {
return rule
}
for _, idx := range idxMeta.Rows {
if idx.KeyName == "PRIMARY" {
if strings.ToLower(idx.KeyName) == "primary" {
if col.Name == idx.ColumnName {
rule = HeuristicRules["CLA.016"]
return rule
......@@ -2310,7 +2310,7 @@ func (q *Query4Audit) RuleNot() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) {
case *sqlparser.ComparisonExpr:
if strings.HasPrefix(n.Operator, "not") {
if strings.HasPrefix(strings.ToLower(n.Operator), "not") {
rule = HeuristicRules["ARG.011"]
return false, nil
}
......@@ -2359,7 +2359,7 @@ func (q *Query4Audit) RuleUNIONUsage() Rule {
var rule = q.RuleOK()
switch s := q.Stmt.(type) {
case *sqlparser.Union:
if s.Type == "union" {
if strings.ToLower(s.Type) == "union" {
rule = HeuristicRules["SUB.002"]
}
}
......@@ -2435,11 +2435,11 @@ func (q *Query4Audit) RuleDataDrop() Rule {
var rule = q.RuleOK()
switch s := q.Stmt.(type) {
case *sqlparser.DBDDL:
if s.Action == "drop" {
if strings.ToLower(s.Action) == "drop" {
rule = HeuristicRules["SEC.003"]
}
case *sqlparser.DDL:
if s.Action == "drop" || s.Action == "truncate" {
if strings.ToLower(s.Action) == "drop" || strings.ToLower(s.Action) == "truncate" {
rule = HeuristicRules["SEC.003"]
}
case *sqlparser.Delete:
......@@ -2523,7 +2523,7 @@ func (q *Query4Audit) RuleTruncateTable() Rule {
var rule = q.RuleOK()
switch s := q.Stmt.(type) {
case *sqlparser.DDL:
if s.Action == "truncate" {
if strings.ToLower(s.Action) == "truncate" {
rule = HeuristicRules["SEC.001"]
}
}
......@@ -2536,7 +2536,7 @@ func (q *Query4Audit) RuleIn() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) {
case *sqlparser.ComparisonExpr:
switch n.Operator {
switch strings.ToLower(n.Operator) {
case "in":
switch r := n.Right.(type) {
case sqlparser.ValTuple:
......@@ -2842,12 +2842,12 @@ func (q *Query4Audit) RulePKNotInt() Rule {
var pk sqlparser.ColIdent
switch s := q.Stmt.(type) {
case *sqlparser.DDL:
if s.Action == "create" {
if strings.ToLower(s.Action) == "create" {
if s.TableSpec == nil {
return rule
}
for _, idx := range s.TableSpec.Indexes {
if idx.Info.Type == "primary key" {
if strings.ToLower(idx.Info.Type) == "primary key" {
if len(idx.Columns) == 1 {
pk = idx.Columns[0].Column
break
......@@ -2864,7 +2864,7 @@ func (q *Query4Audit) RulePKNotInt() Rule {
// 主键非int, bigint类型
for _, col := range s.TableSpec.Columns {
if pk.String() == col.Name.String() {
switch col.Type.Type {
switch strings.ToLower(col.Type.Type) {
case "int", "bigint", "integer":
if !col.Type.Unsigned {
rule = HeuristicRules["KEY.007"]
......@@ -2971,7 +2971,7 @@ func (q *Query4Audit) RuleFulltextIndex() Rule {
for _, tk := range tks {
switch tk.Type {
case ast.TokenTypeWord:
if strings.TrimSpace(strings.ToUpper(tk.Val)) == "FULLTEXT" {
if strings.TrimSpace(strings.ToLower(tk.Val)) == "fulltext" {
rule = HeuristicRules["KEY.010"]
}
default:
......@@ -3001,8 +3001,8 @@ func (q *Query4Audit) RuleTimestampDefault() Rule {
if option.Tp == tidb.ColumnOptionDefaultValue {
hasDefault = true
if err := option.Restore(ctx); err == nil {
if strings.HasPrefix(sb.String(), `DEFAULT '0`) ||
strings.HasPrefix(sb.String(), `DEFAULT 0`) {
if strings.HasPrefix(strings.ToLower(sb.String()), `default '0`) ||
strings.HasPrefix(strings.ToLower(sb.String()), `default 0`) {
hasDefault = false
}
}
......@@ -3034,8 +3034,8 @@ func (q *Query4Audit) RuleTimestampDefault() Rule {
if option.Tp == tidb.ColumnOptionDefaultValue {
hasDefault = true
if err := option.Restore(ctx); err == nil {
if strings.HasPrefix(sb.String(), `DEFAULT '0`) ||
strings.HasPrefix(sb.String(), `DEFAULT 0`) {
if strings.HasPrefix(strings.ToLower(sb.String()), `default '0`) ||
strings.HasPrefix(strings.ToLower(sb.String()), `default 0`) {
hasDefault = false
}
}
......@@ -3464,7 +3464,7 @@ func (q *Query4Audit) RuleColumnNotAllowType() Rule {
switch s := q.Stmt.(type) {
case *sqlparser.DDL:
switch s.Action {
switch strings.ToLower(s.Action) {
case "create", "alter":
tks := ast.Tokenize(q.Query)
for _, tk := range tks {
......@@ -3536,7 +3536,7 @@ func (q *Query4Audit) RuleNoOSCKey() Rule {
var rule = q.RuleOK()
switch s := q.Stmt.(type) {
case *sqlparser.DDL:
if s.Action == "create" {
if strings.ToLower(s.Action) == "create" {
pkReg := regexp.MustCompile(`(?i)(primary\s+key)`)
if !pkReg.MatchString(q.Query) {
ukReg := regexp.MustCompile(`(?i)(unique\s+((key)|(index)))`)
......@@ -3605,7 +3605,7 @@ func (idxAdv *IndexAdvisor) RuleMaxTextColsCount() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch stmt := node.(type) {
case *sqlparser.DDL:
if stmt.Action != "alter" {
if strings.ToLower(stmt.Action) != "alter" {
return true, nil
}
......
......@@ -1575,6 +1575,7 @@ func TestRuleSysdate(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqls := []string{
`select sysdate();`,
`select Sysdate();`,
}
for _, sql := range sqls {
q, err := NewQuery4Audit(sql)
......@@ -2435,6 +2436,7 @@ func TestRuleInjection(t *testing.T) {
{
`select benchmark(10, rand())`,
`select sleep(1)`,
`select Sleep(1)`,
`select get_lock('lock_name', 1)`,
`select release_lock('lock_name')`,
},
......@@ -2542,6 +2544,7 @@ func TestRuleTruncateTable(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqls := []string{
`TRUNCATE TABLE tbl_name;`,
`truncate TABLE tbl_name;`,
}
for _, sql := range sqls {
q, err := NewQuery4Audit(sql)
......@@ -2861,6 +2864,7 @@ func TestRulePKNotInt(t *testing.T) {
},
{
"CREATE TABLE tbl (a int unsigned auto_increment, b int, primary key(`a`)) engine=InnoDB;",
"CREATE TABLE `tb` ( `id` Bigint unsigned NOT NULL AUTO_INCREMENT COMMENT 'auto id', Primary key (`id`) ) ENGINE = InnoDB COMMENT 'comment'",
},
}
for _, sql := range sqls[0] {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册