From 078704875fa6a526c4ece807170b45aa40e1a7a3 Mon Sep 17 00:00:00 2001 From: Leon Zhang Date: Thu, 30 Apr 2020 11:07:54 +0800 Subject: [PATCH] LIT.001 rule ignore PRIVLEGE statement --- advisor/heuristic.go | 9 +++++++++ advisor/heuristic_test.go | 26 +++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/advisor/heuristic.go b/advisor/heuristic.go index 6707073..780e38e 100644 --- a/advisor/heuristic.go +++ b/advisor/heuristic.go @@ -915,6 +915,15 @@ func (q *Query4Audit) RuleColCommentCheck() Rule { // RuleIPString LIT.001 func (q *Query4Audit) RuleIPString() Rule { var rule = q.RuleOK() + + for _, stmt := range q.TiStmt { + switch stmt.(type) { + case *tidb.AlterUserStmt, *tidb.CreateUserStmt, *tidb.GrantStmt, *tidb.GrantRoleStmt, + *tidb.RevokeRoleStmt, *tidb.RevokeStmt, *tidb.DropUserStmt: + return rule + } + } + re := regexp.MustCompile(`['"]\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}`) if re.FindString(q.Query) != "" { rule = HeuristicRules["LIT.001"] diff --git a/advisor/heuristic_test.go b/advisor/heuristic_test.go index f057910..8b2d63b 100644 --- a/advisor/heuristic_test.go +++ b/advisor/heuristic_test.go @@ -562,10 +562,18 @@ func TestRuleColCommentCheck(t *testing.T) { // LIT.001 func TestRuleIPString(t *testing.T) { common.Log.Debug("Entering function: %s", common.GetFunctionName()) - sqls := []string{ - "insert into tbl (IP,name) values('10.20.306.122','test')", + sqls := [][]string{ + { + "insert into tbl (IP,name) values('10.20.306.122','test')", + }, + { + `CREATE USER IF NOT EXISTS 'test'@'1.1.1.1';`, + "ALTER USER 'test'@'1.1.1.1' IDENTIFIED WITH 'mysql_native_password' AS '*xxxxx' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK;", + "GRANT SELECT ON `test`.* TO 'test'@'1.1.1.1';", + `GRANT USAGE ON *.* TO 'test'@'1.1.1.1';`, + }, } - for _, sql := range sqls { + for _, sql := range sqls[0] { q, err := NewQuery4Audit(sql) if err == nil { rule := q.RuleIPString() @@ -576,6 +584,18 @@ func TestRuleIPString(t *testing.T) { t.Error("sqlparser.Parse Error:", err) } } + + for _, sql := range sqls[1] { + q, err := NewQuery4Audit(sql) + if err == nil { + rule := q.RuleIPString() + if rule.Item != "OK" { + t.Error("Rule not match:", rule.Item, "Expect : OK") + } + } else { + t.Error("sqlparser.Parse Error:", err) + } + } common.Log.Debug("Exiting function: %s", common.GetFunctionName()) } -- GitLab