mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
sql expressions: improve parser (#87277)
sql expressions: improve parser
This commit is contained in:
parent
48f77cdebe
commit
1a2bbd61fd
@ -24,21 +24,7 @@ func TablesList(rawSQL string) ([]string, error) {
|
||||
switch kind := stmt.(type) {
|
||||
case *sqlparser.Select:
|
||||
for _, from := range kind.From {
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
from.Format(buf)
|
||||
fromClause := buf.String()
|
||||
upperFromClause := strings.ToUpper(fromClause)
|
||||
if strings.Contains(upperFromClause, "JOIN") {
|
||||
return extractTablesFrom(fromClause), nil
|
||||
}
|
||||
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") {
|
||||
if strings.Contains(upperFromClause, " AS") {
|
||||
name := stripAlias(fromClause)
|
||||
tables = append(tables, name)
|
||||
continue
|
||||
}
|
||||
tables = append(tables, fromClause)
|
||||
}
|
||||
tables = append(tables, getTables(from)...)
|
||||
}
|
||||
default:
|
||||
return parseTables(rawSQL)
|
||||
@ -46,7 +32,66 @@ func TablesList(rawSQL string) ([]string, error) {
|
||||
if len(tables) == 0 {
|
||||
return parseTables(rawSQL)
|
||||
}
|
||||
return tables, nil
|
||||
return validateTables(tables), nil
|
||||
}
|
||||
|
||||
func validateTables(tables []string) []string {
|
||||
validTables := []string{}
|
||||
for _, table := range tables {
|
||||
if strings.ToUpper(table) != "DUAL" {
|
||||
validTables = append(validTables, table)
|
||||
}
|
||||
}
|
||||
return validTables
|
||||
}
|
||||
|
||||
func joinTables(join *sqlparser.JoinTableExpr) []string {
|
||||
t := getTables(join.LeftExpr)
|
||||
t = append(t, getTables(join.RightExpr)...)
|
||||
return t
|
||||
}
|
||||
|
||||
func getTables(te sqlparser.TableExpr) []string {
|
||||
tables := []string{}
|
||||
switch v := te.(type) {
|
||||
case *sqlparser.AliasedTableExpr:
|
||||
tables = append(tables, nodeValue(v.Expr))
|
||||
return tables
|
||||
case *sqlparser.JoinTableExpr:
|
||||
tables = append(tables, joinTables(v)...)
|
||||
return tables
|
||||
case *sqlparser.ParenTableExpr:
|
||||
for _, e := range v.Exprs {
|
||||
tables = getTables(e)
|
||||
}
|
||||
default:
|
||||
tables = append(tables, unknownExpr(te)...)
|
||||
}
|
||||
return tables
|
||||
}
|
||||
|
||||
func unknownExpr(te sqlparser.TableExpr) []string {
|
||||
tables := []string{}
|
||||
fromClause := nodeValue(te)
|
||||
upperFromClause := strings.ToUpper(fromClause)
|
||||
if strings.Contains(upperFromClause, "JOIN") {
|
||||
return extractTablesFrom(fromClause)
|
||||
}
|
||||
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") {
|
||||
if strings.Contains(upperFromClause, " AS") {
|
||||
name := stripAlias(fromClause)
|
||||
tables = append(tables, name)
|
||||
return tables
|
||||
}
|
||||
tables = append(tables, fromClause)
|
||||
}
|
||||
return tables
|
||||
}
|
||||
|
||||
func nodeValue(node sqlparser.SQLNode) string {
|
||||
buf := sqlparser.NewTrackedBuffer(nil)
|
||||
node.Format(buf)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func extractTablesFrom(stmt string) []string {
|
||||
|
@ -89,3 +89,42 @@ func TestRightJoin(t *testing.T) {
|
||||
assert.Equal(t, "A", tables[0])
|
||||
assert.Equal(t, "B", tables[1])
|
||||
}
|
||||
|
||||
func TestAliasWithJoin(t *testing.T) {
|
||||
sql := `select * from A as X
|
||||
RIGHT JOIN B ON A.name = X.name
|
||||
LIMIT 10`
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 2, len(tables))
|
||||
assert.Equal(t, "A", tables[0])
|
||||
assert.Equal(t, "B", tables[1])
|
||||
}
|
||||
|
||||
func TestAlias(t *testing.T) {
|
||||
sql := `select * from A as X LIMIT 10`
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 1, len(tables))
|
||||
assert.Equal(t, "A", tables[0])
|
||||
}
|
||||
|
||||
func TestParens(t *testing.T) {
|
||||
sql := `SELECT t1.Col1,
|
||||
t2.Col1,
|
||||
t3.Col1
|
||||
FROM table1 AS t1
|
||||
LEFT JOIN (
|
||||
table2 AS t2
|
||||
INNER JOIN table3 AS t3 ON t3.Col1 = t2.Col1
|
||||
) ON t2.Col1 = t1.Col1;`
|
||||
tables, err := TablesList((sql))
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, 3, len(tables))
|
||||
assert.Equal(t, "table1", tables[0])
|
||||
assert.Equal(t, "table2", tables[1])
|
||||
assert.Equal(t, "table3", tables[2])
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user