sql expressions: improve parser (#87277)

sql expressions: improve parser
This commit is contained in:
Scott Lepper 2024-05-03 13:08:07 +01:00 committed by GitHub
parent 48f77cdebe
commit 1a2bbd61fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 100 additions and 16 deletions

View File

@ -24,21 +24,7 @@ func TablesList(rawSQL string) ([]string, error) {
switch kind := stmt.(type) {
case *sqlparser.Select:
for _, from := range kind.From {
buf := sqlparser.NewTrackedBuffer(nil)
from.Format(buf)
fromClause := buf.String()
upperFromClause := strings.ToUpper(fromClause)
if strings.Contains(upperFromClause, "JOIN") {
return extractTablesFrom(fromClause), nil
}
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") {
if strings.Contains(upperFromClause, " AS") {
name := stripAlias(fromClause)
tables = append(tables, name)
continue
}
tables = append(tables, fromClause)
}
tables = append(tables, getTables(from)...)
}
default:
return parseTables(rawSQL)
@ -46,7 +32,66 @@ func TablesList(rawSQL string) ([]string, error) {
if len(tables) == 0 {
return parseTables(rawSQL)
}
return tables, nil
return validateTables(tables), nil
}
func validateTables(tables []string) []string {
validTables := []string{}
for _, table := range tables {
if strings.ToUpper(table) != "DUAL" {
validTables = append(validTables, table)
}
}
return validTables
}
func joinTables(join *sqlparser.JoinTableExpr) []string {
t := getTables(join.LeftExpr)
t = append(t, getTables(join.RightExpr)...)
return t
}
func getTables(te sqlparser.TableExpr) []string {
tables := []string{}
switch v := te.(type) {
case *sqlparser.AliasedTableExpr:
tables = append(tables, nodeValue(v.Expr))
return tables
case *sqlparser.JoinTableExpr:
tables = append(tables, joinTables(v)...)
return tables
case *sqlparser.ParenTableExpr:
for _, e := range v.Exprs {
tables = getTables(e)
}
default:
tables = append(tables, unknownExpr(te)...)
}
return tables
}
func unknownExpr(te sqlparser.TableExpr) []string {
tables := []string{}
fromClause := nodeValue(te)
upperFromClause := strings.ToUpper(fromClause)
if strings.Contains(upperFromClause, "JOIN") {
return extractTablesFrom(fromClause)
}
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") {
if strings.Contains(upperFromClause, " AS") {
name := stripAlias(fromClause)
tables = append(tables, name)
return tables
}
tables = append(tables, fromClause)
}
return tables
}
func nodeValue(node sqlparser.SQLNode) string {
buf := sqlparser.NewTrackedBuffer(nil)
node.Format(buf)
return buf.String()
}
func extractTablesFrom(stmt string) []string {

View File

@ -89,3 +89,42 @@ func TestRightJoin(t *testing.T) {
assert.Equal(t, "A", tables[0])
assert.Equal(t, "B", tables[1])
}
func TestAliasWithJoin(t *testing.T) {
sql := `select * from A as X
RIGHT JOIN B ON A.name = X.name
LIMIT 10`
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 2, len(tables))
assert.Equal(t, "A", tables[0])
assert.Equal(t, "B", tables[1])
}
func TestAlias(t *testing.T) {
sql := `select * from A as X LIMIT 10`
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 1, len(tables))
assert.Equal(t, "A", tables[0])
}
func TestParens(t *testing.T) {
sql := `SELECT t1.Col1,
t2.Col1,
t3.Col1
FROM table1 AS t1
LEFT JOIN (
table2 AS t2
INNER JOIN table3 AS t3 ON t3.Col1 = t2.Col1
) ON t2.Col1 = t1.Col1;`
tables, err := TablesList((sql))
assert.Nil(t, err)
assert.Equal(t, 3, len(tables))
assert.Equal(t, "table1", tables[0])
assert.Equal(t, "table2", tables[1])
assert.Equal(t, "table3", tables[2])
}