meta_test.go 14.4 KB
Newer Older
martianzhang's avatar
martianzhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Copyright 2018 Xiaomi, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ast

import (
	"flag"
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"testing"

	"github.com/XiaoMi/soar/common"

	"github.com/kr/pretty"
	"vitess.io/vitess/go/vt/sqlparser"
)

var (
	// update is the -update command-line flag ("update .golden files"),
	// false by default. Nothing in this file reads it; presumably another
	// _test.go file in package ast consults it — verify before removing.
	update = flag.Bool("update", false, "update .golden files")
)

func TestMain(m *testing.M) {
	// 初始化 init
martianzhang's avatar
martianzhang 已提交
36 37 38 39
	if common.DevPath == "" {
		_, file, _, _ := runtime.Caller(0)
		common.DevPath, _ = filepath.Abs(filepath.Dir(filepath.Join(file, ".."+string(filepath.Separator))))
	}
40 41 42 43 44 45 46 47 48 49 50 51 52
	common.BaseDir = common.DevPath
	err := common.ParseConfig("")
	common.LogIfError(err, "init ParseConfig")
	common.Log.Debug("ast_test init")

	// 分割线
	flag.Parse()
	m.Run()

	// 环境清理
	//
}

martianzhang's avatar
martianzhang 已提交
53
func TestGetTableFromExprs(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
54
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
55 56 57 58 59 60 61 62 63 64 65 66 67
	tbExprs := sqlparser.TableExprs{
		&sqlparser.AliasedTableExpr{
			Expr: sqlparser.TableName{
				Name:      sqlparser.NewTableIdent("table"),
				Qualifier: sqlparser.NewTableIdent("db"),
			},
			As: sqlparser.NewTableIdent("as"),
		},
	}
	meta := GetTableFromExprs(tbExprs)
	if tb, ok := meta["db"]; !ok {
		t.Errorf("no table qualifier, meta: %s", pretty.Sprint(tb))
	}
martianzhang's avatar
martianzhang 已提交
68
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
69 70 71
}

func TestGetParseTableWithStmt(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
72
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
73 74 75 76 77 78 79 80 81 82
	for _, sql := range common.TestSQLs {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
		if err != nil {
			t.Errorf("SQL Parsed error: %v", err)
		}
		meta := GetMeta(stmt, nil)
		pretty.Println(meta)
		fmt.Println()
	}
martianzhang's avatar
martianzhang 已提交
83
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
84 85 86
}

func TestFindCondition(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
87
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
88 89
	sqls := []string{
		`SELECT * FROM film WHERE length % 20 = 4;`,
90
		`select * from actor where actor_id = 1 order by if(first_name="PENELOPE", last_name, "") desc`,
martianzhang's avatar
martianzhang 已提交
91 92
	}
	for _, sql := range append(sqls, common.TestSQLs...) {
martianzhang's avatar
martianzhang 已提交
93 94
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
L
liipx 已提交
95
		// pretty.Println(stmt)
martianzhang's avatar
martianzhang 已提交
96 97 98 99 100
		if err != nil {
			panic(err)
		}
		eq := FindEQColsInWhere(stmt)
		inEq := FindINEQColsInWhere(stmt)
101
		fmt.Println("WhereEQ:")
martianzhang's avatar
martianzhang 已提交
102
		pretty.Println(eq)
103
		fmt.Println("WhereINEQ:")
martianzhang's avatar
martianzhang 已提交
104 105 106
		pretty.Println(inEq)
		fmt.Println()
	}
martianzhang's avatar
martianzhang 已提交
107
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
108 109 110
}

func TestFindGroupBy(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
111
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
112 113 114 115 116 117 118 119 120 121 122 123 124 125
	sqlList := []string{
		"select a from t group by c",
	}

	for _, sql := range sqlList {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
		if err != nil {
			panic(err)
		}
		res := FindGroupByCols(stmt)
		pretty.Println(res)
		fmt.Println()
	}
martianzhang's avatar
martianzhang 已提交
126
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
127 128 129
}

func TestFindOrderBy(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
130
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
	sqlList := []string{
		"select a from t group by c order by d, c desc",
		"select a from t group by c order by d desc",
	}

	for _, sql := range sqlList {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
		if err != nil {
			panic(err)
		}
		res := FindOrderByCols(stmt)
		pretty.Println(res)
		fmt.Println()
	}
martianzhang's avatar
martianzhang 已提交
146
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
147 148 149
}

func TestFindSubquery(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
150
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
	sqlList := []string{
		"SELECT * FROM t1 WHERE column1 = (SELECT column1 FROM (SELECT column1 FROM t2) a);",
		"select column1 from t2",
		"SELECT * FROM t1 WHERE column1 = (SELECT column1 FROM t2);",
		"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
	}

	for _, sql := range sqlList {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
		if err != nil {
			panic(err)
		}

		subquery := FindSubquery(0, stmt)
		fmt.Println(len(subquery))
		pretty.Println(subquery)
	}
martianzhang's avatar
martianzhang 已提交
169
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
170 171 172
}

func TestFindJoinTable(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
173
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
174 175 176 177 178 179 180 181 182 183 184 185
	sqlList := []string{
		"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
		"SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"SELECT * FROM t1 RIGHT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"SELECT left_tbl.* FROM left_tbl LEFT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
		"SELECT left_tbl.* FROM left_tbl RIGHT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
	}

	for _, sql := range sqlList {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
L
liipx 已提交
186
		// pretty.Println(stmt)
martianzhang's avatar
martianzhang 已提交
187 188 189 190 191 192 193
		if err != nil {
			panic(err)
		}

		joinMeta := FindJoinTable(stmt, nil)
		pretty.Println(joinMeta)
	}
martianzhang's avatar
martianzhang 已提交
194
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
195 196 197
}

func TestFindJoinCols(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
198
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
199 200 201 202 203 204 205 206 207 208 209 210 211
	sqlList := []string{
		"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"select t from a LEFT JOIN b USING (c1, c2, c3)",
		"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
		"SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"SELECT * FROM t1 RIGHT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"SELECT left_tbl.* FROM left_tbl LEFT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
		"SELECT left_tbl.* FROM left_tbl RIGHT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
	}

	for _, sql := range sqlList {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
L
liipx 已提交
212
		// pretty.Println(stmt)
martianzhang's avatar
martianzhang 已提交
213 214 215 216 217 218 219
		if err != nil {
			panic(err)
		}

		columns := FindJoinCols(stmt)
		pretty.Println(columns)
	}
martianzhang's avatar
martianzhang 已提交
220
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
221 222 223
}

func TestFindJoinColBeWhereEQ(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
224
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
225 226 227 228 229 230 231 232 233 234 235
	sqlList := []string{
		"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
		"SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"SELECT * FROM t1 RIGHT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"SELECT left_tbl.* FROM left_tbl LEFT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
		"SELECT left_tbl.* FROM left_tbl RIGHT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
	}

	for _, sql := range sqlList {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
L
liipx 已提交
236
		// pretty.Println(stmt)
martianzhang's avatar
martianzhang 已提交
237 238 239 240 241 242 243
		if err != nil {
			panic(err)
		}

		columns := FindEQColsInJoinCond(stmt)
		pretty.Println(columns)
	}
martianzhang's avatar
martianzhang 已提交
244
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
245 246 247
}

func TestFindJoinColBeWhereINEQ(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
248
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
249 250 251 252 253 254 255 256 257 258 259
	sqlList := []string{
		"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
		"SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"SELECT * FROM t1 RIGHT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b > 'b' AND t4.c = t1.c)",
		"SELECT left_tbl.* FROM left_tbl LEFT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
		"SELECT left_tbl.* FROM left_tbl RIGHT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
	}

	for _, sql := range sqlList {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
L
liipx 已提交
260
		// pretty.Println(stmt)
martianzhang's avatar
martianzhang 已提交
261 262 263 264 265 266 267
		if err != nil {
			panic(err)
		}

		columns := FindINEQColsInJoinCond(stmt)
		pretty.Println(columns)
	}
martianzhang's avatar
martianzhang 已提交
268
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
269 270 271
}

func TestFindAllCondition(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
272
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289
	sqlList := []string{
		"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"select t from a LEFT JOIN b USING (c1, c2, c3)",
		"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
		"SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"SELECT * FROM t1 RIGHT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"SELECT left_tbl.* FROM left_tbl LEFT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
		"SELECT left_tbl.* FROM left_tbl RIGHT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
		"SELECT * FROM t1 where a in ('a','b')",
		"SELECT * FROM t1 where a BETWEEN 'bar' AND 'foo'",
		"SELECT * FROM t1 where a = sum(a,b)",
		"SELECT distinct a FROM t1 where a = '2001-01-01 01:01:01'",
	}

	for _, sql := range sqlList {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
L
liipx 已提交
290
		// pretty.Println(stmt)
martianzhang's avatar
martianzhang 已提交
291 292 293 294 295 296 297
		if err != nil {
			panic(err)
		}

		columns := FindAllCondition(stmt)
		pretty.Println(columns)
	}
martianzhang's avatar
martianzhang 已提交
298
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
299 300 301
}

func TestFindColumn(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
302
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
303 304 305 306 307 308 309 310
	sqlList := []string{
		"select col, col2, sum(col1) from tb group by col",
		"select col from tb group by col,sum(col1)",
		"select col, sum(col1) from tb group by col",
	}
	for _, sql := range sqlList {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
L
liipx 已提交
311
		// pretty.Println(stmt)
martianzhang's avatar
martianzhang 已提交
312 313 314 315 316 317 318
		if err != nil {
			panic(err)
		}

		columns := FindColumn(stmt)
		pretty.Println(columns)
	}
martianzhang's avatar
martianzhang 已提交
319
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
320 321 322
}

func TestFindAllCols(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
323
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
324
	sqlList := []string{
L
liipx 已提交
325 326 327 328
		"select * from tb where a = '1' order by c",
		"select * from tb where a = '1' group by c",
		"select * from tb where c = '1' group by a",
		"select * from tb join tb2 on c = c where c = '1' group by a",
martianzhang's avatar
martianzhang 已提交
329 330
	}

L
liipx 已提交
331 332 333 334 335 336 337 338
	targets := []Expression{
		OrderByExpression,
		GroupByExpression,
		WhereExpression,
		JoinExpression,
	}

	for i, sql := range sqlList {
martianzhang's avatar
martianzhang 已提交
339 340
		stmt, err := sqlparser.Parse(sql)
		if err != nil {
L
liipx 已提交
341 342
			t.Error(err)
			return
martianzhang's avatar
martianzhang 已提交
343 344
		}

L
liipx 已提交
345 346 347 348 349
		columns := FindAllCols(stmt, targets[i])
		if columns[0].Name != "c" {
			fmt.Println(sql)
			t.Error(fmt.Errorf("want 'c' got %v", columns))
		}
martianzhang's avatar
martianzhang 已提交
350
	}
martianzhang's avatar
martianzhang 已提交
351
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
352 353 354
}

func TestGetSubqueryDepth(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
355
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379
	sqlList := []string{
		"SELECT * FROM t1 LEFT JOIN (t2 CROSS JOIN t3 CROSS JOIN t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"select t from a LEFT JOIN b USING (c1, c2, c3)",
		"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
		"SELECT * FROM t1 LEFT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"SELECT * FROM t1 RIGHT JOIN (t2, t3, t4) ON (t2.a = t1.a AND t3.b = t1.b AND t4.c = t1.c)",
		"SELECT left_tbl.* FROM left_tbl LEFT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
		"SELECT left_tbl.* FROM left_tbl RIGHT JOIN right_tbl ON left_tbl.id = right_tbl.id WHERE right_tbl.id IS NULL;",
		"SELECT * FROM t1 where a in ('a','b')",
		"SELECT * FROM t1 where a BETWEEN 'bar' AND 'foo'",
		"SELECT * FROM t1 where a = sum(a,b)",
		"SELECT distinct a FROM t1 where a = '2001-01-01 01:01:01'",
	}

	for _, sql := range sqlList {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
		if err != nil {
			t.Error("syntax check error.")
		}

		dep := GetSubqueryDepth(stmt)
		fmt.Println(dep)
	}
martianzhang's avatar
martianzhang 已提交
380
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
martianzhang's avatar
martianzhang 已提交
381
}
L
liipx 已提交
382 383

func TestAppendTable(t *testing.T) {
martianzhang's avatar
martianzhang 已提交
384
	common.Log.Debug("Entering function: %s", common.GetFunctionName())
L
liipx 已提交
385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425
	sqlList := []string{
		"select ID,name from (select address from customer_list where SID=1 order by phone limit 50,10) a join customer_list l on (a.address=l.address) join city c on (c.city=l.city) order by phone desc;",
	}

	meta := make(map[string]*common.DB)
	for _, sql := range sqlList {
		fmt.Println(sql)
		stmt, err := sqlparser.Parse(sql)
		if err != nil {
			t.Error("syntax check error.")
		}

		err = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
			switch expr := node.(type) {
			case *sqlparser.AliasedTableExpr:
				switch table := expr.Expr.(type) {
				case sqlparser.TableName:
					appendTable(table, expr.As.String(), meta)
				default:
					if meta == nil {
						meta = make(map[string]*common.DB)
					}
					if meta[""] == nil {
						meta[""] = common.NewDB("")
					}
					meta[""].Table[""] = common.NewTable("")
					meta[""].Table[""].TableAliases = append(meta[""].Table[""].TableAliases, expr.As.String())
				}
			}
			return true, nil
		}, stmt)

		if err != nil {
			t.Error(err)
		}
	}

	// 仅对第一条测试SQL进行测试,验证别名正确性
	if meta[""].Table["customer_list"].TableAliases[0] != "l" || meta[""].Table["city"].TableAliases[0] != "c" {
		t.Error("alias filed\n", pretty.Sprint(meta))
	}
martianzhang's avatar
martianzhang 已提交
426
	common.Log.Debug("Exiting function: %s", common.GetFunctionName())
L
liipx 已提交
427
}