2024-02-27 15:16:00 -06:00
|
|
|
package sql
|
|
|
|
|
|
|
|
import (
|
|
|
|
"errors"
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
parser "github.com/krasun/gosqlparser"
|
|
|
|
"github.com/xwb1989/sqlparser"
|
|
|
|
)
|
|
|
|
|
|
|
|
// TablesList returns a list of tables for the sql statement
|
|
|
|
func TablesList(rawSQL string) ([]string, error) {
|
|
|
|
stmt, err := sqlparser.Parse(rawSQL)
|
|
|
|
if err != nil {
|
|
|
|
tables, err := parse(rawSQL)
|
|
|
|
if err != nil {
|
|
|
|
return parseTables(rawSQL)
|
|
|
|
}
|
|
|
|
return tables, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
tables := []string{}
|
|
|
|
switch kind := stmt.(type) {
|
|
|
|
case *sqlparser.Select:
|
|
|
|
for _, t := range kind.From {
|
|
|
|
buf := sqlparser.NewTrackedBuffer(nil)
|
|
|
|
t.Format(buf)
|
|
|
|
table := buf.String()
|
2024-05-02 07:43:20 -05:00
|
|
|
if table != "dual" && !strings.HasPrefix(table, "(") {
|
|
|
|
if strings.Contains(table, " as") {
|
|
|
|
name := stripAlias(table)
|
|
|
|
tables = append(tables, name)
|
|
|
|
continue
|
|
|
|
}
|
2024-02-27 15:16:00 -06:00
|
|
|
tables = append(tables, buf.String())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
default:
|
2024-05-02 07:43:20 -05:00
|
|
|
return parseTables(rawSQL)
|
|
|
|
}
|
|
|
|
if len(tables) == 0 {
|
|
|
|
return parseTables(rawSQL)
|
2024-02-27 15:16:00 -06:00
|
|
|
}
|
|
|
|
return tables, nil
|
|
|
|
}
|
|
|
|
|
2024-05-02 07:43:20 -05:00
|
|
|
// stripAlias removes an alias from a formatted table expression such as
// "db.users as u". The input is split on single spaces, everything from the
// first word that is exactly "as" onward is dropped, and the remaining words
// are re-joined with single spaces. Input without an "as" word is returned
// unchanged.
func stripAlias(table string) string {
	words := strings.Split(table, " ")
	for i, word := range words {
		if word == "as" {
			words = words[:i]
			break
		}
	}
	return strings.Join(words, " ")
}
|
|
|
|
|
2024-02-27 15:16:00 -06:00
|
|
|
// uses a simple tokenizer
|
|
|
|
func parse(rawSQL string) ([]string, error) {
|
|
|
|
query, err := parser.Parse(rawSQL)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
if query.GetType() == parser.StatementSelect {
|
|
|
|
sel, ok := query.(*parser.Select)
|
|
|
|
if ok {
|
|
|
|
return []string{sel.Table}, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
func parseTables(rawSQL string) ([]string, error) {
|
|
|
|
checkSql := strings.ToUpper(rawSQL)
|
2024-05-02 07:43:20 -05:00
|
|
|
rawSQL = strings.ReplaceAll(rawSQL, "\n", " ")
|
2024-02-27 15:16:00 -06:00
|
|
|
if strings.HasPrefix(checkSql, "SELECT") || strings.HasPrefix(rawSQL, "WITH") {
|
|
|
|
tables := []string{}
|
|
|
|
tokens := strings.Split(rawSQL, " ")
|
|
|
|
checkNext := false
|
|
|
|
takeNext := false
|
2024-05-02 07:43:20 -05:00
|
|
|
for _, token := range tokens {
|
|
|
|
t := strings.ToUpper(token)
|
2024-02-27 15:16:00 -06:00
|
|
|
t = strings.TrimSpace(t)
|
|
|
|
|
|
|
|
if takeNext {
|
2024-05-02 07:43:20 -05:00
|
|
|
if !existsInList(token, tables) {
|
|
|
|
tables = append(tables, token)
|
|
|
|
}
|
2024-02-27 15:16:00 -06:00
|
|
|
checkNext = false
|
|
|
|
takeNext = false
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if checkNext {
|
|
|
|
if strings.Contains(t, "(") {
|
|
|
|
checkNext = false
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
if strings.Contains(t, ",") {
|
2024-05-02 07:43:20 -05:00
|
|
|
values := strings.Split(token, ",")
|
2024-02-27 15:16:00 -06:00
|
|
|
for _, v := range values {
|
|
|
|
v := strings.TrimSpace(v)
|
|
|
|
if v != "" {
|
2024-05-02 07:43:20 -05:00
|
|
|
if !existsInList(token, tables) {
|
|
|
|
tables = append(tables, v)
|
|
|
|
}
|
2024-02-27 15:16:00 -06:00
|
|
|
} else {
|
|
|
|
takeNext = true
|
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
|
|
|
continue
|
|
|
|
}
|
2024-05-02 07:43:20 -05:00
|
|
|
if !existsInList(token, tables) {
|
|
|
|
tables = append(tables, token)
|
|
|
|
}
|
2024-02-27 15:16:00 -06:00
|
|
|
checkNext = false
|
|
|
|
}
|
|
|
|
if t == "FROM" {
|
|
|
|
checkNext = true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return tables, nil
|
|
|
|
}
|
|
|
|
return nil, errors.New("not a select statement")
|
|
|
|
}
|
2024-05-02 07:43:20 -05:00
|
|
|
|
|
|
|
// existsInList reports whether table is exactly equal (case-sensitive) to one
// of the entries in list. A nil or empty list never contains anything.
func existsInList(table string, list []string) bool {
	for i := range list {
		if list[i] == table {
			return true
		}
	}
	return false
}
|