2024-02-27 15:16:00 -06:00
|
|
|
package sql
|
|
|
|
|
|
|
|
import (
|
2024-05-14 16:05:29 -05:00
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"sort"
|
2024-02-27 15:16:00 -06:00
|
|
|
"strings"
|
|
|
|
|
2024-05-14 16:05:29 -05:00
|
|
|
"github.com/jeremywohl/flatten"
|
|
|
|
"github.com/scottlepp/go-duck/duck"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Key names and key fragments looked up in the flattened (dot-separated)
// JSON AST produced for a SQL statement.
const (
	// ERROR is the key suffix of the boolean flag that, when true, marks the
	// statement as failed.
	ERROR = ".error"
	// ERROR_MESSAGE is the key suffix, sharing the same prefix as ERROR, that
	// holds the error text reported for a failed statement.
	ERROR_MESSAGE = ERROR + "_message"
	// TABLE_NAME is the key name under which a referenced table's name is stored.
	TABLE_NAME = "table_name"
)
|
|
|
|
|
|
|
|
// TablesList returns a list of tables for the sql statement
|
|
|
|
func TablesList(rawSQL string) ([]string, error) {
|
2024-05-14 16:05:29 -05:00
|
|
|
duckDB := duck.NewInMemoryDB()
|
|
|
|
cmd := fmt.Sprintf("SELECT json_serialize_sql('%s')", rawSQL)
|
|
|
|
ret, err := duckDB.RunCommands([]string{cmd})
|
2024-02-27 15:16:00 -06:00
|
|
|
if err != nil {
|
2024-05-14 16:05:29 -05:00
|
|
|
return nil, fmt.Errorf("error serializing sql: %s", err.Error())
|
2024-02-27 15:16:00 -06:00
|
|
|
}
|
2024-05-03 07:08:07 -05:00
|
|
|
|
2024-05-14 16:05:29 -05:00
|
|
|
ast := []map[string]any{}
|
|
|
|
err = json.Unmarshal([]byte(ret), &ast)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error converting json to ast: %s", err.Error())
|
2024-05-03 07:08:07 -05:00
|
|
|
}
|
|
|
|
|
2024-05-14 16:05:29 -05:00
|
|
|
return tablesFromAST(ast)
|
2024-05-03 07:08:07 -05:00
|
|
|
}
|
|
|
|
|
2024-05-14 16:05:29 -05:00
|
|
|
func tablesFromAST(ast []map[string]any) ([]string, error) {
|
|
|
|
flat, err := flatten.Flatten(ast[0], "", flatten.DotStyle)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error flattening ast: %s", err.Error())
|
2024-05-03 07:08:07 -05:00
|
|
|
}
|
2024-02-27 15:16:00 -06:00
|
|
|
|
2024-05-02 13:48:05 -05:00
|
|
|
tables := []string{}
|
2024-05-14 16:05:29 -05:00
|
|
|
for k, v := range flat {
|
|
|
|
if strings.HasSuffix(k, ERROR) {
|
|
|
|
v, ok := v.(bool)
|
|
|
|
if ok && v {
|
|
|
|
return nil, astError(k, flat)
|
|
|
|
}
|
2024-05-02 13:48:05 -05:00
|
|
|
}
|
2024-05-14 16:05:29 -05:00
|
|
|
if strings.Contains(k, TABLE_NAME) {
|
|
|
|
table, ok := v.(string)
|
|
|
|
if ok && !existsInList(table, tables) {
|
|
|
|
tables = append(tables, v.(string))
|
2024-05-02 13:48:05 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2024-05-14 16:05:29 -05:00
|
|
|
sort.Strings(tables)
|
2024-05-02 13:48:05 -05:00
|
|
|
|
2024-05-14 16:05:29 -05:00
|
|
|
return tables, nil
|
2024-05-02 07:43:20 -05:00
|
|
|
}
|
|
|
|
|
2024-05-14 16:05:29 -05:00
|
|
|
func astError(k string, flat map[string]any) error {
|
|
|
|
key := strings.Replace(k, ERROR, "", 1)
|
|
|
|
message, ok := flat[key+ERROR_MESSAGE]
|
|
|
|
if !ok {
|
|
|
|
message = "unknown error in sql"
|
2024-02-27 15:16:00 -06:00
|
|
|
}
|
2024-05-14 16:05:29 -05:00
|
|
|
return fmt.Errorf("error in sql: %s", message)
|
2024-02-27 15:16:00 -06:00
|
|
|
}
|
2024-05-02 07:43:20 -05:00
|
|
|
|
|
|
|
// existsInList reports whether list contains an element exactly equal
// (case-sensitive) to table. A nil or empty list never contains anything.
func existsInList(table string, list []string) bool {
	found := false
	for i := 0; i < len(list) && !found; i++ {
		found = list[i] == table
	}
	return found
}
|