提交 b9e152bf 编写于 作者: martianzhang's avatar martianzhang

sampling data type trading

上级 86da258b
......@@ -23,6 +23,7 @@ import (
)
func TestDigestExplainText(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
var text = `+----+-------------+---------+-------+---------------------------------------------------------+-------------------+---------+---------------------------+------+-------------+
| id | select_type | table | type | possible_keys | key | key_len | ref | rows | Extra |
+----+-------------+---------+-------+---------------------------------------------------------+-------------------+---------+---------------------------+------+-------------+
......@@ -34,4 +35,5 @@ func TestDigestExplainText(t *testing.T) {
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -35,7 +35,7 @@ func TestRuleImplicitAlias(t *testing.T) {
"select col from tbl tb where id < 1000",
},
{
"do 1",
"select 1",
},
}
for _, sql := range sqls[0] {
......
......@@ -95,8 +95,8 @@ func TestRuleImplicitConversion(t *testing.T) {
}
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
common.Config.OnlineDSN = dsn
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
// JOI.003 & JOI.004
......@@ -383,9 +383,11 @@ func TestDuplicateKeyChecker(t *testing.T) {
if len(rule) != 0 {
t.Errorf("got rules: %s", pretty.Sprint(rule))
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestMergeAdvices(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
dst := []IndexInfo{
{
Name: "test",
......@@ -405,6 +407,7 @@ func TestMergeAdvices(t *testing.T) {
if len(advise) != 1 {
t.Error(pretty.Sprint(advise))
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestIdxColsTypeCheck(t *testing.T) {
......@@ -450,13 +453,16 @@ func TestIdxColsTypeCheck(t *testing.T) {
}
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestGetRandomIndexSuffix(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
for i := 0; i < 5; i++ {
r := getRandomIndexSuffix()
if !(strings.HasPrefix(r, "_") && len(r) == 5) {
t.Errorf("getRandomIndexSuffix should return a string with prefix `_` and 5 length, but got:%s", r)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -58,7 +58,7 @@ func NewQuery4Audit(sql string, options ...string) (*Query4Audit, error) {
// vitess 语法解析不上报,以 tidb parser 为主
q.Stmt, vErr = sqlparser.Parse(sql)
if vErr != nil {
common.Log.Warn("NewQuery4Audit vitess parse Error: %s", vErr.Error())
common.Log.Warn("NewQuery4Audit vitess parse Error: %s, Query: %s", vErr.Error(), sql)
}
// TODO: charset, collation
......
......@@ -23,29 +23,37 @@ import (
)
func TestListTestSQLs(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() { ListTestSQLs() }, t.Name(), update)
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestListHeuristicRules(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() { ListHeuristicRules(HeuristicRules) }, t.Name(), update)
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestInBlackList(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
common.BlackList = []string{"select"}
if !InBlackList("select 1") {
t.Error("should be true")
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestIsIgnoreRule(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
common.Config.IgnoreRules = []string{"test"}
if !IsIgnoreRule("test") {
t.Error("should be true")
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -27,6 +27,7 @@ import (
)
func TestGetTableFromExprs(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
tbExprs := sqlparser.TableExprs{
&sqlparser.AliasedTableExpr{
Expr: sqlparser.TableName{
......@@ -40,9 +41,11 @@ func TestGetTableFromExprs(t *testing.T) {
if tb, ok := meta["db"]; !ok {
t.Errorf("no table qualifier, meta: %s", pretty.Sprint(tb))
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestGetParseTableWithStmt(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
for _, sql := range common.TestSQLs {
fmt.Println(sql)
stmt, err := sqlparser.Parse(sql)
......@@ -53,9 +56,11 @@ func TestGetParseTableWithStmt(t *testing.T) {
pretty.Println(meta)
fmt.Println()
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindCondition(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
for _, sql := range common.TestSQLs {
fmt.Println(sql)
stmt, err := sqlparser.Parse(sql)
......@@ -71,9 +76,11 @@ func TestFindCondition(t *testing.T) {
pretty.Println(inEq)
fmt.Println()
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindGroupBy(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"select a from t group by c",
}
......@@ -88,9 +95,11 @@ func TestFindGroupBy(t *testing.T) {
pretty.Println(res)
fmt.Println()
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindOrderBy(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"select a from t group by c order by d, c desc",
"select a from t group by c order by d desc",
......@@ -106,9 +115,11 @@ func TestFindOrderBy(t *testing.T) {
pretty.Println(res)
fmt.Println()
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindSubquery(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"SELECT * FROM t1 WHERE column1 = (SELECT column1 FROM (SELECT column1 FROM t2) a);",
"select column1 from t2",
......@@ -127,10 +138,11 @@ func TestFindSubquery(t *testing.T) {
fmt.Println(len(subquery))
pretty.Println(subquery)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindJoinTable(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
......@@ -151,9 +163,11 @@ func TestFindJoinTable(t *testing.T) {
joinMeta := FindJoinTable(stmt, nil)
pretty.Println(joinMeta)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindJoinCols(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
"select t from a LEFT JOIN b USING (c1, c2, c3)",
......@@ -175,9 +189,11 @@ func TestFindJoinCols(t *testing.T) {
columns := FindJoinCols(stmt)
pretty.Println(columns)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindJoinColBeWhereEQ(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
"SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
......@@ -197,9 +213,11 @@ func TestFindJoinColBeWhereEQ(t *testing.T) {
columns := FindEQColsInJoinCond(stmt)
pretty.Println(columns)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindJoinColBeWhereINEQ(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
"SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
......@@ -219,9 +237,11 @@ func TestFindJoinColBeWhereINEQ(t *testing.T) {
columns := FindINEQColsInJoinCond(stmt)
pretty.Println(columns)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindAllCondition(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
"select t from a LEFT JOIN b USING (c1, c2, c3)",
......@@ -247,9 +267,11 @@ func TestFindAllCondition(t *testing.T) {
columns := FindAllCondition(stmt)
pretty.Println(columns)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindColumn(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"select col, col2, sum(col1) from tb group by col",
"select col from tb group by col,sum(col1)",
......@@ -266,9 +288,11 @@ func TestFindColumn(t *testing.T) {
columns := FindColumn(stmt)
pretty.Println(columns)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindAllCols(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"select * from tb where a = '1' order by c",
"select * from tb where a = '1' group by c",
......@@ -296,9 +320,11 @@ func TestFindAllCols(t *testing.T) {
t.Error(fmt.Errorf("want 'c' got %v", columns))
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestGetSubqueryDepth(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
"select t from a LEFT JOIN b USING (c1, c2, c3)",
......@@ -323,9 +349,11 @@ func TestGetSubqueryDepth(t *testing.T) {
dep := GetSubqueryDepth(stmt)
fmt.Println(dep)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestAppendTable(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqlList := []string{
"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
}
......@@ -367,4 +395,5 @@ func TestAppendTable(t *testing.T) {
if meta[""].Table["customer_list"].TableAliases[0] != "l" || meta[""].Table["city"].TableAliases[0] != "c" {
t.Error("alias filed\n", pretty.Sprint(meta))
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -128,6 +128,7 @@ var TestSqlsPretty = []string{
}
func TestPretty(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() {
for _, sql := range append(TestSqlsPretty, common.TestSQLs...) {
fmt.Println(sql)
......@@ -137,9 +138,11 @@ func TestPretty(t *testing.T) {
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestIsKeyword(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
tks := map[string]bool{
"AGAINST": true,
"AUTO_INCREMENT": true,
......@@ -155,9 +158,11 @@ func TestIsKeyword(t *testing.T) {
t.Error("isKeyword:", tk)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRemoveComments(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
for _, sql := range TestSqlsPretty {
stmt, _ := sqlparser.Parse(sql)
newSQL := sqlparser.String(stmt)
......@@ -165,9 +170,11 @@ func TestRemoveComments(t *testing.T) {
fmt.Print(newSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestMysqlEscapeString(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
var strs = []map[string]string{
{
"input": "abc",
......@@ -198,4 +205,5 @@ abc`,
}
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -25,6 +25,7 @@ import (
)
func TestRewrite(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgTestDSNStatus := common.Config.TestDSN.Disable
common.Config.TestDSN.Disable = false
testSQL := []map[string]string{
......@@ -99,9 +100,11 @@ func TestRewrite(t *testing.T) {
}
}
common.Config.TestDSN.Disable = orgTestDSNStatus
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteStar2Columns(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgTestDSNStatus := common.Config.TestDSN.Disable
common.Config.TestDSN.Disable = false
testSQL := []map[string]string{
......@@ -131,9 +134,11 @@ func TestRewriteStar2Columns(t *testing.T) {
}
}
common.Config.TestDSN.Disable = orgTestDSNStatus
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteInsertColumns(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": `insert into film values(1,2,3,4,5)`,
......@@ -173,9 +178,11 @@ func TestRewriteInsertColumns(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteHaving(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": `SELECT state, COUNT(*) FROM Drivers GROUP BY state HAVING state IN ('GA', 'TX') ORDER BY state`,
......@@ -196,9 +203,11 @@ func TestRewriteHaving(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteAddOrderByNull(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": "SELECT sum(col1) FROM tbl GROUP BY col",
......@@ -211,9 +220,11 @@ func TestRewriteAddOrderByNull(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteRemoveDMLOrderBy(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": "DELETE FROM tbl WHERE col1=1 ORDER BY col",
......@@ -230,9 +241,11 @@ func TestRewriteRemoveDMLOrderBy(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteGroupByConst(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": "select 1;",
......@@ -259,9 +272,11 @@ func TestRewriteGroupByConst(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteStandard(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": "SELECT sum(col1) FROM tbl GROUP BY 1;",
......@@ -274,9 +289,11 @@ func TestRewriteStandard(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteCountStar(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": "SELECT count(col) FROM tbl GROUP BY 1;",
......@@ -293,9 +310,11 @@ func TestRewriteCountStar(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteInnoDB(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": "CREATE TABLE t1(id bigint(20) NOT NULL AUTO_INCREMENT);",
......@@ -312,9 +331,11 @@ func TestRewriteInnoDB(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteAutoIncrement(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": "CREATE TABLE t1(id bigint(20) NOT NULL AUTO_INCREMENT) ENGINE=InnoDB AUTO_INCREMENT=123802;",
......@@ -331,9 +352,11 @@ func TestRewriteAutoIncrement(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteIntWidth(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": "CREATE TABLE t1(id bigint(10) NOT NULL AUTO_INCREMENT) ENGINE=InnoDB AUTO_INCREMENT=123802;",
......@@ -358,9 +381,11 @@ func TestRewriteIntWidth(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteAlwaysTrue(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": "SELECT count(col) FROM tbl where 1=1;",
......@@ -427,10 +452,12 @@ func TestRewriteAlwaysTrue(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
// TODO:
func TestRewriteSubQuery2Join(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgTestDSNStatus := common.Config.TestDSN.Disable
common.Config.TestDSN.Disable = true
testSQL := []map[string]string{
......@@ -458,9 +485,11 @@ func TestRewriteSubQuery2Join(t *testing.T) {
}
}
common.Config.TestDSN.Disable = orgTestDSNStatus
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteDML2Select(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": "DELETE city, country FROM city INNER JOIN country using (country_id) WHERE city.city_id = 1;",
......@@ -513,9 +542,11 @@ func TestRewriteDML2Select(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteDistinctStar(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": `SELECT DISTINCT * FROM film;`,
......@@ -549,9 +580,11 @@ func TestRewriteDistinctStar(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestMergeAlterTables(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqls := []string{
// ADD|DROP INDEX
// TODO: PRIMARY KEY, [UNIQUE|FULLTEXT|SPATIAL] INDEX
......@@ -602,9 +635,11 @@ func TestMergeAlterTables(t *testing.T) {
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteUnionAll(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": `select country_id from city union select country_id from country;`,
......@@ -617,8 +652,10 @@ func TestRewriteUnionAll(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteTruncate(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": `delete from tbl;`,
......@@ -631,9 +668,11 @@ func TestRewriteTruncate(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRewriteOr2In(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": `select country_id from city where country_id = 1 or country_id = 2 or country_id = 3;`,
......@@ -672,9 +711,11 @@ func TestRewriteOr2In(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRmParenthesis(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []map[string]string{
{
"input": `select country_id from city where (country_id = 1);`,
......@@ -699,13 +740,16 @@ func TestRmParenthesis(t *testing.T) {
t.Errorf("want: %s\ngot: %s", sql["output"], rw.NewSQL)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestListRewriteRules(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() {
ListRewriteRules(RewriteRules)
}, t.Name(), update)
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -26,6 +26,7 @@ import (
)
func TestTokenize(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() {
for _, sql := range common.TestSQLs {
fmt.Println(sql)
......@@ -35,9 +36,11 @@ func TestTokenize(t *testing.T) {
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestTokenizer(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqls := []string{
"select c1,c2,c3 from t1,t2 join t3 on t1.c1=t2.c1 and t1.c3=t3.c1 where id>1000",
"select sourcetable, if(f.lastcontent = ?, f.lastupdate, f.lastcontent) as lastactivity, f.totalcount as activity, type.class as type, (f.nodeoptions & ?) as nounsubscribe from node as f inner join contenttype as type on type.contenttypeid = f.contenttypeid inner join subscribed as sd on sd.did = f.nodeid and sd.userid = ? union all select f.name as title, f.userid as keyval, ? as sourcetable, ifnull(f.lastpost, f.joindate) as lastactivity, f.posts as activity, ? as type, ? as nounsubscribe from user as f inner join userlist as ul on ul.relationid = f.userid and ul.userid = ? where ul.type = ? and ul.aq = ? order by title limit ?",
......@@ -57,9 +60,11 @@ func TestTokenizer(t *testing.T) {
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestGetQuotedString(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
var str = []string{
`"hello world"`,
"`hello world`",
......@@ -82,9 +87,11 @@ func TestGetQuotedString(t *testing.T) {
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestCompress(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() {
for _, sql := range common.TestSQLs {
fmt.Println(sql)
......@@ -94,10 +101,11 @@ func TestCompress(t *testing.T) {
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFormat(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() {
for _, sql := range common.TestSQLs {
fmt.Println(sql)
......@@ -107,9 +115,11 @@ func TestFormat(t *testing.T) {
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestSplitStatement(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
bufs := [][]byte{
[]byte("select * from test;hello"),
[]byte("select 'asd;fas', col from test;hello"),
......@@ -181,9 +191,11 @@ select col from tb;
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestLeftNewLines(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
bufs := [][]byte{
[]byte(`
select * from test;hello`),
......@@ -200,9 +212,11 @@ func TestLeftNewLines(t *testing.T) {
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestNewLines(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
bufs := [][]byte{
[]byte(`
select * from test;hello`),
......@@ -219,4 +233,5 @@ func TestNewLines(t *testing.T) {
if nil != err {
t.Fatal(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -56,6 +56,7 @@ type Configuration struct {
OnlySyntaxCheck bool `yaml:"only-syntax-check"` // 只做语法检查不输出优化建议
SamplingStatisticTarget int `yaml:"sampling-statistic-target"` // 数据采样因子,对应 PostgreSQL 的 default_statistics_target
Sampling bool `yaml:"sampling"` // 数据采样开关
SamplingCondition string `yaml:"sampling-condition"` // 指定采样条件,如:WHERE xxx LIMIT xxx;
Profiling bool `yaml:"profiling"` // 在开启数据采样的情况下,在测试环境执行进行profile
Trace bool `yaml:"trace"` // 在开启数据采样的情况下,在测试环境执行进行Trace
Explain bool `yaml:"explain"` // Explain开关
......@@ -506,6 +507,7 @@ func readCmdFlags() error {
explain := flag.Bool("explain", Config.Explain, "Explain, 是否开启Explain执行计划分析")
sampling := flag.Bool("sampling", Config.Sampling, "Sampling, 数据采样开关")
samplingStatisticTarget := flag.Int("sampling-statistic-target", Config.SamplingStatisticTarget, "SamplingStatisticTarget, 数据采样因子,对应 PostgreSQL 的 default_statistics_target")
samplingCondition := flag.String("sampling-condition", Config.SamplingCondition, "SamplingCondition, 数据采样条件,如: WHERE xxx LIMIT xxx")
delimiter := flag.String("delimiter", Config.Delimiter, "Delimiter, SQL分隔符")
// +++++++++++++++日志相关+++++++++++++++++
logLevel := flag.Int("log-level", Config.LogLevel, "LogLevel, 日志级别, [0:Emergency, 1:Alert, 2:Critical, 3:Error, 4:Warning, 5:Notice, 6:Informational, 7:Debug]")
......@@ -585,6 +587,7 @@ func readCmdFlags() error {
Config.Explain = *explain
Config.Sampling = *sampling
Config.SamplingStatisticTarget = *samplingStatisticTarget
Config.SamplingCondition = *samplingCondition
Config.LogLevel = *logLevel
if strings.HasPrefix(*logOutput, "/") {
......
......@@ -26,6 +26,10 @@ import (
var update = flag.Bool("update", false, "update .golden files")
func init() {
BaseDir = DevPath
}
func TestParseConfig(t *testing.T) {
err := ParseConfig("")
if err != nil {
......@@ -37,7 +41,7 @@ func TestReadConfigFile(t *testing.T) {
if Config == nil {
Config = new(Configuration)
}
Config.readConfigFile("../soar.yaml")
Config.readConfigFile(DevPath + "/soar.yaml")
}
func TestParseDSN(t *testing.T) {
......
......@@ -21,15 +21,11 @@ import (
"testing"
)
func init() {
BaseDir = DevPath
}
func TestLogger(t *testing.T) {
Log.Info("info")
Log.Debug("debug")
Log.Warning("warning")
Log.Error("error")
Log.Info("TestLogger_Info")
Log.Debug("TestLogger_Debug")
Log.Warning("TestLogger_Warning")
Log.Error("Warning_Error")
}
func TestCaller(t *testing.T) {
......@@ -47,7 +43,7 @@ func TestGetFunctionName(t *testing.T) {
}
func TestIfError(t *testing.T) {
err := errors.New("test")
err := errors.New("TestIfError")
LogIfError(err, "")
LogIfError(err, "func %s", "func_test")
}
......
......@@ -2332,6 +2332,7 @@ possible_keys: idx_fk_country_id,idx_country_id_city,idx_all,idx_other
}
func TestExplain(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
// TraditionalFormatExplain
for idx, sql := range sqls {
exp, err := connTest.Explain(sql, TraditionalExplainType, TraditionalFormatExplain)
......@@ -2350,9 +2351,11 @@ func TestExplain(t *testing.T) {
pretty.Println("No.:", idx, "\nOld: ", sql, "\nNew: ", exp.SQL)
pretty.Println(exp)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestParseExplainText(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
for _, content := range exp {
pretty.Println(RemoveSQLComments(content))
pretty.Println(ParseExplainText(content))
......@@ -2364,26 +2367,32 @@ func TestParseExplainText(t *testing.T) {
pretty.Println(explainInfo)
fmt.Println(err)
*/
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindTablesInJson(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
idx := 9
for _, j := range exp[idx : idx+1] {
pretty.Println(j)
findTablesInJSON(j, 0)
}
pretty.Println(len(explainJSONTables), explainJSONTables)
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFormatJsonIntoTraditional(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
idx := 11
for _, j := range exp[idx : idx+1] {
pretty.Println(j)
pretty.Println(FormatJSONIntoTraditional(j))
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestPrintMarkdownExplainTable(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain)
if err != nil {
t.Error(err)
......@@ -2395,9 +2404,11 @@ func TestPrintMarkdownExplainTable(t *testing.T) {
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestExplainInfoTranslator(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain)
if err != nil {
t.Error(err)
......@@ -2408,9 +2419,11 @@ func TestExplainInfoTranslator(t *testing.T) {
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestMySQLExplainWarnings(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain)
if err != nil {
t.Error(err)
......@@ -2421,9 +2434,11 @@ func TestMySQLExplainWarnings(t *testing.T) {
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestMySQLExplainQueryCost(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
expInfo, err := connTest.Explain("select 1", TraditionalExplainType, TraditionalFormatExplain)
if err != nil {
t.Error(err)
......@@ -2434,19 +2449,24 @@ func TestMySQLExplainQueryCost(t *testing.T) {
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestSupportExplainWrite(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
_, err := connTest.supportExplainWrite()
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestExplainAbleSQL(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
for _, sql := range sqls {
if _, err := connTest.explainAbleSQL(sql); err != nil {
t.Errorf("SQL: %s, not explain able", sql)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -104,8 +104,9 @@ func (db *Connector) Query(sql string, params ...interface{}) (QueryResult, erro
if common.Config.ShowLastQueryCost {
cost, err := db.Conn.Query("SHOW SESSION STATUS LIKE 'last_query_cost'")
if err == nil {
var varName string
if cost.Next() {
err = cost.Scan(res.QueryCost)
err = cost.Scan(&varName, &res.QueryCost)
common.LogIfError(err, "")
}
if err := cost.Close(); err != nil {
......
......@@ -48,6 +48,7 @@ func init() {
}
func TestQuery(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
res, err := connTest.Query("select 0")
if err != nil {
t.Error(err.Error())
......@@ -64,9 +65,11 @@ func TestQuery(t *testing.T) {
}
res.Rows.Close()
// TODO: timeout test
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestColumnCardinality(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database
connTest.Database = "sakila"
a := connTest.ColumnCardinality("actor", "first_name")
......@@ -74,9 +77,11 @@ func TestColumnCardinality(t *testing.T) {
t.Error("sakila.actor.first_name cardinality should in [0, 1], now it's", a)
}
connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestDangerousSQL(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testCase := map[string]bool{
"select * from tb;delete from tb;": true,
"show database;": false,
......@@ -91,9 +96,11 @@ func TestDangerousSQL(t *testing.T) {
t.Errorf("SQL:%s got:%v want:%v", sql, got, want)
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestWarningsAndQueryCost(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
common.Config.ShowWarnings = true
common.Config.ShowLastQueryCost = true
res, err := connTest.Query("explain select * from sakila.film")
......@@ -111,17 +118,21 @@ func TestWarningsAndQueryCost(t *testing.T) {
res.Warning.Close()
fmt.Println(res.QueryCost, err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestVersion(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
version, err := connTest.Version()
if err != nil {
t.Error(err.Error())
}
fmt.Println(version)
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestRemoveSQLComments(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
SQLs := []string{
`-- comment`,
`--`,
......@@ -140,9 +151,11 @@ comment*/`,
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestSingleIntValue(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
val, err := connTest.SingleIntValue("read_only")
if err != nil {
t.Error(err)
......@@ -150,13 +163,16 @@ func TestSingleIntValue(t *testing.T) {
if val < 0 {
t.Error("SingleIntValue, return should large than zero")
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestIsView(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
originalDatabase := connTest.Database
connTest.Database = "sakila"
if !connTest.IsView("actor_info") {
t.Error("actor_info should be a VIEW")
}
connTest.Database = originalDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -16,9 +16,14 @@
package database
import "testing"
import (
"testing"
"github.com/XiaoMi/soar/common"
)
func TestCurrentUser(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
user, host, err := connTest.CurrentUser()
if err != nil {
t.Error(err.Error())
......@@ -26,16 +31,21 @@ func TestCurrentUser(t *testing.T) {
if user != "root" || host != "%" {
t.Errorf("Want user: root, host: %%. Get user: %s, host: %s", user, host)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestHasSelectPrivilege(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
if !connTest.HasSelectPrivilege() {
t.Errorf("DSN: %s, User: %s, should has select privilege", connTest.Addr, connTest.User)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestHasAllPrivilege(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
if !connTest.HasAllPrivilege() {
t.Errorf("DSN: %s, User: %s, should has all privilege", connTest.Addr, connTest.User)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -19,21 +19,26 @@ package database
import (
"testing"
"github.com/XiaoMi/soar/common"
"github.com/kr/pretty"
)
func TestProfiling(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
rows, err := connTest.Profiling("select 1")
if err != nil {
t.Error(err)
}
pretty.Println(rows)
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFormatProfiling(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
res, err := connTest.Profiling("select 1")
if err != nil {
t.Error(err)
}
pretty.Println(FormatProfiling(res))
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -17,11 +17,15 @@
package database
import (
"database/sql"
"fmt"
"time"
"github.com/XiaoMi/soar/common"
"strings"
"database/sql"
"github.com/XiaoMi/soar/common"
"github.com/ziutek/mymysql/mysql"
)
/*--------------------
......@@ -44,99 +48,125 @@ import (
*--------------------
*/
// SamplingData 将数据从Remote拉取到 db 中
func (db *Connector) SamplingData(remote *Connector, tables ...string) error {
// SamplingData 将数据从 onlineConn 拉取到 db 中
func (db *Connector) SamplingData(onlineConn *Connector, database string, tables ...string) error {
var err error
if database == db.Database {
return fmt.Errorf("SamplingData the same database, From: %s/%s, To: %s/%s", onlineConn.Addr, database, db.Addr, db.Database)
}
// 计算需要泵取的数据量
wantRowsCount := 300 * common.Config.SamplingStatisticTarget
// 设置数据采样单条 SQL 中 value 的数量
// 该数值越大,在内存中缓存的data就越多,但相对的,插入时速度就越快
maxValCount := 200
for _, table := range tables {
// 表类型检查
if remote.IsView(table) {
return nil
}
tableStatus, err := remote.ShowTableStatus(table)
if err != nil {
return err
}
if len(tableStatus.Rows) == 0 {
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table)
if onlineConn.IsView(table) {
return nil
}
tableRows := tableStatus.Rows[0].Rows
if tableRows == 0 {
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table)
return nil
// generate where condition
var where string
if common.Config.SamplingCondition == "" {
tableStatus, err := onlineConn.ShowTableStatus(table)
if err != nil {
return err
}
if len(tableStatus.Rows) == 0 {
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table)
return nil
}
tableRows := tableStatus.Rows[0].Rows
if tableRows == 0 {
common.Log.Info("SamplingData, Table %s with no data, stop sampling", table)
return nil
}
factor := float64(wantRowsCount) / float64(tableRows)
common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor)
where = fmt.Sprintf("WHERE RAND() <= %f LIMIT %d", factor, wantRowsCount)
if factor >= 1 {
where = ""
}
} else {
where = common.Config.SamplingCondition
}
factor := float64(wantRowsCount) / float64(tableRows)
common.Log.Debug("SamplingData, tableRows: %d, wantRowsCount: %d, factor: %f", tableRows, wantRowsCount, factor)
err = startSampling(remote.Conn, db.Conn, db.Database, table, factor, wantRowsCount, maxValCount)
if err != nil {
common.Log.Error("(db *Connector) SamplingData Error : %v", err)
}
err = db.startSampling(onlineConn.Conn, database, table, where)
}
return nil
return err
}
// startSampling sampling data from OnlineDSN to TestDSN
// 因为涉及到的数据量问题,所以泵取与插入时同时进行的
// TODO: 加 ref link
func startSampling(conn, localConn *sql.DB, database, table string, factor float64, wants, maxValCount int) error {
// generate where condition
where := fmt.Sprintf("WHERE RAND() <= %f", factor)
if factor >= 1 {
where = ""
}
res, err := conn.Query(fmt.Sprintf("SELECT * FROM `%s`.`%s` %s LIMIT %d;", database, table, where, wants))
func (db *Connector) startSampling(onlineConn *sql.DB, database, table string, where string) error {
samplingQuery := fmt.Sprintf("SELECT * FROM `%s`.`%s` %s", database, table, where)
common.Log.Debug("startSampling with Query: %s", samplingQuery)
res, err := onlineConn.Query(samplingQuery)
if err != nil {
return err
}
// column info
// columns list
columns, err := res.Columns()
if err != nil {
return err
}
row := make(map[string][]byte, len(columns))
row := make([][]byte, len(columns))
tableFields := make([]interface{}, 0)
for _, col := range columns {
if _, ok := row[col]; ok {
tableFields = append(tableFields, row[col])
}
for i := range columns {
tableFields = append(tableFields, &row[i])
}
columnTypes, err := res.ColumnTypes()
if err != nil {
return err
}
// sampling data
var valuesStr string
var values []string
var valuesCount int
var valuesStr []string
maxValuesCount := 200 // one time insert values count, TODO: config able
columnsStr := "`" + strings.Join(columns, "`,`") + "`"
for res.Next() {
var values []string
res.Scan(tableFields...)
for _, val := range row {
values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val)))
for i, val := range row {
if val == nil {
values = append(values, "NULL")
} else {
switch columnTypes[i].DatabaseTypeName() {
case "TIMESTAMP", "DATETIME":
t, err := time.Parse(time.RFC3339, string(val))
common.LogIfWarn(err, "")
values = append(values, fmt.Sprintf(`"%s"`, mysql.TimeString(t)))
default:
values = append(values, fmt.Sprintf(`unhex("%s")`, fmt.Sprintf("%x", val)))
}
}
}
valuesStr = append(valuesStr, "("+strings.Join(values, `,`)+")")
valuesCount++
if maxValuesCount <= valuesCount {
err = db.doSampling(table, columnsStr, strings.Join(valuesStr, `,`))
if err != nil {
break
}
values = make([]string, 0)
valuesStr = make([]string, 0)
valuesCount = 0
}
valuesStr = fmt.Sprintf(`(%s)`, strings.Join(values, `,`))
doSampling(localConn, database, table, columnsStr, valuesStr)
}
res.Close()
return nil
return err
}
// 将泵取的数据转换成Insert语句并在数据库中执行
func doSampling(conn *sql.DB, dbName, table, colDef, values string) {
query := fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES %s;", dbName, table,
colDef, values)
_, err := conn.Exec(query)
if err != nil {
common.Log.Error("doSampling Error from %s.%s: %v", dbName, table, err)
// 将泵取的数据转换成 insert 语句并在 testConn 数据库中执行
func (db *Connector) doSampling(table, colDef, values string) error {
// db.Database is hashed database name
query := fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES %s;", db.Database, table, colDef, values)
res, err := db.Query(query)
if res.Rows != nil {
res.Rows.Close()
}
return err
}
/*
* Copyright 2018 Xiaomi, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"testing"
"github.com/XiaoMi/soar/common"
)
func init() {
common.BaseDir = common.DevPath
}
func TestSamplingData(t *testing.T) {
connOnline, err := NewConnector(common.Config.OnlineDSN)
if err != nil {
t.Error(err)
}
err = connTest.SamplingData(connOnline, "film")
if err != nil {
t.Error(err)
}
}
......@@ -459,27 +459,27 @@ func (db *Connector) ShowCreateTable(tableName string) (string, error) {
ddl, err := db.showCreate("table", tableName)
// 去除外键关联条件
var noConstraint []string
relationReg, _ := regexp.Compile("CONSTRAINT")
for _, line := range strings.Split(ddl, "\n") {
if relationReg.Match([]byte(line)) {
continue
}
// 去除外键语句会使DDL中多一个','导致语法错误,要把多余的逗号去除
if strings.Index(line, ")") == 0 {
lineWrongSyntax := noConstraint[len(noConstraint)-1]
// 如果')'前一句的末尾是',' 删除 ',' 保证语法正确性
if strings.Index(lineWrongSyntax, ",") == len(lineWrongSyntax)-1 {
noConstraint[len(noConstraint)-1] = lineWrongSyntax[:len(lineWrongSyntax)-1]
lines := strings.Split(ddl, "\n")
// CREATE VIEW ONLY 1 LINE
if len(lines) > 2 {
var noConstraint []string
relationReg, _ := regexp.Compile("CONSTRAINT")
for _, line := range lines[1 : len(lines)-1] {
if relationReg.Match([]byte(line)) {
continue
}
line = strings.TrimSuffix(line, ",")
noConstraint = append(noConstraint, line)
}
noConstraint = append(noConstraint, line)
// 去除外键语句会使DDL中多一个','导致语法错误,要把多余的逗号去除
ddl = fmt.Sprint(
lines[0], "\n",
strings.Join(noConstraint, ",\n"), "\n",
lines[len(lines)-1],
)
}
return strings.Join(noConstraint, "\n"), err
return ddl, err
}
// FindColumn find column
......
......@@ -26,6 +26,7 @@ import (
)
func TestShowTableStatus(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database
connTest.Database = "sakila"
ts, err := connTest.ShowTableStatus("film")
......@@ -47,9 +48,11 @@ func TestShowTableStatus(t *testing.T) {
}
pretty.Println(ts)
connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestShowTables(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database
connTest.Database = "sakila"
ts, err := connTest.ShowTables()
......@@ -66,23 +69,29 @@ func TestShowTables(t *testing.T) {
t.Error(err)
}
connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestShowCreateDatabase(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
err := common.GoldenDiff(func() {
fmt.Println(connTest.ShowCreateDatabase("sakila"))
}, t.Name(), update)
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestShowCreateTable(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database
connTest.Database = "sakila"
tables := []string{
"film",
"category",
"customer_list",
"inventory",
}
err := common.GoldenDiff(func() {
for _, table := range tables {
......@@ -97,9 +106,11 @@ func TestShowCreateTable(t *testing.T) {
t.Error(err)
}
connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestShowIndex(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database
connTest.Database = "sakila"
ti, err := connTest.ShowIndex("film")
......@@ -114,11 +125,12 @@ func TestShowIndex(t *testing.T) {
if err != nil {
t.Error(err)
}
connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestShowColumns(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
orgDatabase := connTest.Database
connTest.Database = "sakila"
ti, err := connTest.ShowColumns("actor_info")
......@@ -134,9 +146,11 @@ func TestShowColumns(t *testing.T) {
}
connTest.Database = orgDatabase
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFindColumn(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
ti, err := connTest.FindColumn("film_id", "sakila", "film")
if err != nil {
t.Error("FindColumn Error: ", err)
......@@ -147,15 +161,19 @@ func TestFindColumn(t *testing.T) {
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestIsFKey(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
if !connTest.IsForeignKey("sakila", "film", "language_id") {
t.Error("want True. got false")
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestShowReference(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
rv, err := connTest.ShowReference("sakila", "film")
if err != nil {
t.Error("ShowReference Error: ", err)
......@@ -167,4 +185,5 @@ func TestShowReference(t *testing.T) {
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -17,4 +17,19 @@ CREATE TABLE `film` (
KEY `idx_fk_language_id` (`language_id`),
KEY `idx_fk_original_language_id` (`original_language_id`)
) ENGINE=InnoDB AUTO_INCREMENT=1001 DEFAULT CHARSET=utf8
CREATE TABLE `category` (
`category_id` tinyint(3) unsigned NOT NULL AUTO_INCREMENT,
`name` varchar(25) NOT NULL,
`last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`category_id`)
) ENGINE=InnoDB AUTO_INCREMENT=17 DEFAULT CHARSET=utf8
CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`localhost` SQL SECURITY DEFINER VIEW `customer_list` AS select `cu`.`customer_id` AS `ID`,concat(`cu`.`first_name`,_utf8mb3' ',`cu`.`last_name`) AS `name`,`a`.`address` AS `address`,`a`.`postal_code` AS `zip code`,`a`.`phone` AS `phone`,`city`.`city` AS `city`,`country`.`country` AS `country`,if(`cu`.`active`,_utf8mb3'active',_utf8mb3'') AS `notes`,`cu`.`store_id` AS `SID` from (((`customer` `cu` join `address` `a` on((`cu`.`address_id` = `a`.`address_id`))) join `city` on((`a`.`city_id` = `city`.`city_id`))) join `country` on((`city`.`country_id` = `country`.`country_id`)))
CREATE TABLE `inventory` (
`inventory_id` mediumint(8) unsigned NOT NULL AUTO_INCREMENT,
`film_id` smallint(5) unsigned NOT NULL,
`store_id` tinyint(3) unsigned NOT NULL,
`last_update` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`inventory_id`),
KEY `idx_fk_film_id` (`film_id`),
KEY `idx_store_id_film_id` (`store_id`,`film_id`)
) ENGINE=InnoDB AUTO_INCREMENT=4582 DEFAULT CHARSET=utf8
......@@ -25,6 +25,7 @@ import (
)
func TestTrace(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
res, err := connTest.Trace("select 1")
if err != nil {
t.Error(err)
......@@ -36,9 +37,11 @@ func TestTrace(t *testing.T) {
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestFormatTrace(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
res, err := connTest.Trace("select 1")
if err != nil {
t.Error(err)
......@@ -50,4 +53,5 @@ func TestFormatTrace(t *testing.T) {
if err != nil {
t.Error(err)
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
......@@ -82,7 +82,7 @@ func BuildEnv() (*VirtualEnv, *database.Connector) {
common.LogIfError(err, "")
// 检查线上环境可用性版本
rEnvVersion, err := vEnv.Version()
rEnvVersion, err := connOnline.Version()
common.Config.OnlineDSN.Version = rEnvVersion
if err != nil {
common.Log.Warn("BuildEnv OnlineDSN: %s:********@%s/%s not available , Error: %s",
......@@ -245,20 +245,20 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
// 为了支持并发,需要将DB进行映射,但db.table这种形式无法保证DB的映射是正确的
// TODO:暂不支持 create db.tableName (id int) 形式的建表语句
if stmt.Table.Qualifier.String() != "" {
common.Log.Error("BuildVirtualEnv DDL Not support '.'")
common.Log.Error("BuildVirtualEnv DDL Not support db.tb format")
return false
}
for _, tb := range stmt.FromTables {
if tb.Qualifier.String() != "" {
common.Log.Error("BuildVirtualEnv DDL Not support '.'")
common.Log.Error("BuildVirtualEnv DDL Not support db.tb format")
return false
}
}
for _, tb := range stmt.ToTables {
if tb.Qualifier.String() != "" {
common.Log.Error("BuildVirtualEnv DDL Not support '.'")
common.Log.Error("BuildVirtualEnv DDL Not support db.tb format")
return false
}
}
......@@ -338,7 +338,7 @@ func (ve *VirtualEnv) BuildVirtualEnv(rEnv *database.Connector, SQLs ...string)
err = ve.createTable(tmpEnv, db, tb.TableName)
if err != nil {
common.Log.Error("BuildVirtualEnv Error : %v", err)
common.Log.Error("BuildVirtualEnv %s.%s Error : %v", db, tb.TableName, err)
return false
}
}
......@@ -453,7 +453,7 @@ func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string
res, err := ve.Query(ddl)
if err != nil {
// 有可能是用户新建表,因此线上环境查不到
common.Log.Error("createTable, %s Error : %v", tbName, err)
common.Log.Error("createTable: %s Error : %v", tbName, err)
return err
}
res.Rows.Close()
......@@ -461,13 +461,9 @@ func (ve VirtualEnv) createTable(rEnv *database.Connector, dbName, tbName string
// 泵取数据
if common.Config.Sampling {
common.Log.Debug("createTable, Start Sampling data from %s.%s to %s.%s ...", dbName, tbName, ve.DBRef[dbName], tbName)
err := ve.SamplingData(rEnv, tbName)
if err != nil {
common.Log.Error(" (ve VirtualEnv) createTable SamplingData Error: %v", err)
return err
}
err = ve.SamplingData(rEnv, dbName, tbName)
}
return nil
return err
}
// GenTableColumns 为 Rewrite 提供的结构体初始化
......
......@@ -49,6 +49,7 @@ func init() {
}
func TestNewVirtualEnv(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
testSQL := []string{
"create table t(id int,c1 varchar(20),PRIMARY KEY (id));",
"alter table t add index `idx_c1`(c1);",
......@@ -117,9 +118,11 @@ func TestNewVirtualEnv(t *testing.T) {
}
}
}, t.Name(), update)
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestCleanupTestDatabase(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
vEnv, _ := BuildEnv()
if common.Config.TestDSN.Disable {
common.Log.Warn("common.Config.TestDSN.Disable=true, by pass TestCleanupTestDatabase")
......@@ -146,9 +149,11 @@ func TestCleanupTestDatabase(t *testing.T) {
if err != nil {
t.Error("optimizer_060102150405 not exist, should not be dropped")
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestGenTableColumns(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
vEnv, rEnv := BuildEnv()
defer vEnv.CleanUp()
......@@ -214,4 +219,66 @@ func TestGenTableColumns(t *testing.T) {
}
}
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestCreateTable(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
orgSamplingCondition := common.Config.SamplingCondition
common.Config.SamplingCondition = "LIMIT 1"
vEnv, rEnv := BuildEnv()
defer vEnv.CleanUp()
// TODO: support VIEW,
tables := []string{
"actor",
// "actor_info", // VIEW
"address",
"category",
"city",
"country",
"customer",
"customer_list",
"film",
"film_actor",
"film_category",
"film_list",
"film_text",
"inventory",
"language",
"nicer_but_slower_film_list",
"payment",
"rental",
// "sales_by_film_category", // VIEW
// "sales_by_store", // VIEW
"staff",
"staff_list",
"store",
}
for _, table := range tables {
err := vEnv.createTable(rEnv, "sakila", table)
if err != nil {
t.Error(err)
}
}
common.Config.SamplingCondition = orgSamplingCondition
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
func TestCreateDatabase(t *testing.T) {
common.Log.Debug("Enter function: %s", common.GetFunctionName())
vEnv, rEnv := BuildEnv()
defer vEnv.CleanUp()
err := vEnv.createDatabase(rEnv, "sakila")
if err != nil {
t.Error(err)
}
if vEnv.DBHash("sakila") == "sakila" {
t.Errorf("database: sakila rehashed failed!")
}
if vEnv.DBHash("not_exist_db") != "not_exist_db" {
t.Errorf("database: not_exist_db rehashed!")
}
common.Log.Debug("Exiting function: %s", common.GetFunctionName())
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册