mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
sql expressions: improve parser (#87277)
sql expressions: improve parser
This commit is contained in:
parent
48f77cdebe
commit
1a2bbd61fd
@ -24,21 +24,7 @@ func TablesList(rawSQL string) ([]string, error) {
|
|||||||
switch kind := stmt.(type) {
|
switch kind := stmt.(type) {
|
||||||
case *sqlparser.Select:
|
case *sqlparser.Select:
|
||||||
for _, from := range kind.From {
|
for _, from := range kind.From {
|
||||||
buf := sqlparser.NewTrackedBuffer(nil)
|
tables = append(tables, getTables(from)...)
|
||||||
from.Format(buf)
|
|
||||||
fromClause := buf.String()
|
|
||||||
upperFromClause := strings.ToUpper(fromClause)
|
|
||||||
if strings.Contains(upperFromClause, "JOIN") {
|
|
||||||
return extractTablesFrom(fromClause), nil
|
|
||||||
}
|
|
||||||
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") {
|
|
||||||
if strings.Contains(upperFromClause, " AS") {
|
|
||||||
name := stripAlias(fromClause)
|
|
||||||
tables = append(tables, name)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
tables = append(tables, fromClause)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return parseTables(rawSQL)
|
return parseTables(rawSQL)
|
||||||
@ -46,7 +32,66 @@ func TablesList(rawSQL string) ([]string, error) {
|
|||||||
if len(tables) == 0 {
|
if len(tables) == 0 {
|
||||||
return parseTables(rawSQL)
|
return parseTables(rawSQL)
|
||||||
}
|
}
|
||||||
return tables, nil
|
return validateTables(tables), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateTables(tables []string) []string {
|
||||||
|
validTables := []string{}
|
||||||
|
for _, table := range tables {
|
||||||
|
if strings.ToUpper(table) != "DUAL" {
|
||||||
|
validTables = append(validTables, table)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return validTables
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinTables(join *sqlparser.JoinTableExpr) []string {
|
||||||
|
t := getTables(join.LeftExpr)
|
||||||
|
t = append(t, getTables(join.RightExpr)...)
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTables(te sqlparser.TableExpr) []string {
|
||||||
|
tables := []string{}
|
||||||
|
switch v := te.(type) {
|
||||||
|
case *sqlparser.AliasedTableExpr:
|
||||||
|
tables = append(tables, nodeValue(v.Expr))
|
||||||
|
return tables
|
||||||
|
case *sqlparser.JoinTableExpr:
|
||||||
|
tables = append(tables, joinTables(v)...)
|
||||||
|
return tables
|
||||||
|
case *sqlparser.ParenTableExpr:
|
||||||
|
for _, e := range v.Exprs {
|
||||||
|
tables = getTables(e)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
tables = append(tables, unknownExpr(te)...)
|
||||||
|
}
|
||||||
|
return tables
|
||||||
|
}
|
||||||
|
|
||||||
|
func unknownExpr(te sqlparser.TableExpr) []string {
|
||||||
|
tables := []string{}
|
||||||
|
fromClause := nodeValue(te)
|
||||||
|
upperFromClause := strings.ToUpper(fromClause)
|
||||||
|
if strings.Contains(upperFromClause, "JOIN") {
|
||||||
|
return extractTablesFrom(fromClause)
|
||||||
|
}
|
||||||
|
if upperFromClause != "DUAL" && !strings.HasPrefix(fromClause, "(") {
|
||||||
|
if strings.Contains(upperFromClause, " AS") {
|
||||||
|
name := stripAlias(fromClause)
|
||||||
|
tables = append(tables, name)
|
||||||
|
return tables
|
||||||
|
}
|
||||||
|
tables = append(tables, fromClause)
|
||||||
|
}
|
||||||
|
return tables
|
||||||
|
}
|
||||||
|
|
||||||
|
func nodeValue(node sqlparser.SQLNode) string {
|
||||||
|
buf := sqlparser.NewTrackedBuffer(nil)
|
||||||
|
node.Format(buf)
|
||||||
|
return buf.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractTablesFrom(stmt string) []string {
|
func extractTablesFrom(stmt string) []string {
|
||||||
|
@ -89,3 +89,42 @@ func TestRightJoin(t *testing.T) {
|
|||||||
assert.Equal(t, "A", tables[0])
|
assert.Equal(t, "A", tables[0])
|
||||||
assert.Equal(t, "B", tables[1])
|
assert.Equal(t, "B", tables[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAliasWithJoin(t *testing.T) {
|
||||||
|
sql := `select * from A as X
|
||||||
|
RIGHT JOIN B ON A.name = X.name
|
||||||
|
LIMIT 10`
|
||||||
|
tables, err := TablesList((sql))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(tables))
|
||||||
|
assert.Equal(t, "A", tables[0])
|
||||||
|
assert.Equal(t, "B", tables[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAlias(t *testing.T) {
|
||||||
|
sql := `select * from A as X LIMIT 10`
|
||||||
|
tables, err := TablesList((sql))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, len(tables))
|
||||||
|
assert.Equal(t, "A", tables[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParens(t *testing.T) {
|
||||||
|
sql := `SELECT t1.Col1,
|
||||||
|
t2.Col1,
|
||||||
|
t3.Col1
|
||||||
|
FROM table1 AS t1
|
||||||
|
LEFT JOIN (
|
||||||
|
table2 AS t2
|
||||||
|
INNER JOIN table3 AS t3 ON t3.Col1 = t2.Col1
|
||||||
|
) ON t2.Col1 = t1.Col1;`
|
||||||
|
tables, err := TablesList((sql))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(tables))
|
||||||
|
assert.Equal(t, "table1", tables[0])
|
||||||
|
assert.Equal(t, "table2", tables[1])
|
||||||
|
assert.Equal(t, "table3", tables[2])
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user