diff --git a/pkg/expr/sql/parser.go b/pkg/expr/sql/parser.go index d5ea2d1f6be..48e34ea1154 100644 --- a/pkg/expr/sql/parser.go +++ b/pkg/expr/sql/parser.go @@ -26,16 +26,35 @@ func TablesList(rawSQL string) ([]string, error) { buf := sqlparser.NewTrackedBuffer(nil) t.Format(buf) table := buf.String() - if table != "dual" { + if table != "dual" && !strings.HasPrefix(table, "(") { + if strings.Contains(table, " as") { + name := stripAlias(table) + tables = append(tables, name) + continue + } tables = append(tables, buf.String()) } } default: - return nil, errors.New("not a select statement") + return parseTables(rawSQL) + } + if len(tables) == 0 { + return parseTables(rawSQL) } return tables, nil } +func stripAlias(table string) string { + tableParts := []string{} + for _, part := range strings.Split(table, " ") { + if part == "as" { + break + } + tableParts = append(tableParts, part) + } + return strings.Join(tableParts, " ") +} + // uses a simple tokenizer func parse(rawSQL string) ([]string, error) { query, err := parser.Parse(rawSQL) @@ -53,17 +72,20 @@ func parse(rawSQL string) ([]string, error) { func parseTables(rawSQL string) ([]string, error) { checkSql := strings.ToUpper(rawSQL) + rawSQL = strings.ReplaceAll(rawSQL, "\n", " ") if strings.HasPrefix(checkSql, "SELECT") || strings.HasPrefix(rawSQL, "WITH") { tables := []string{} tokens := strings.Split(rawSQL, " ") checkNext := false takeNext := false - for _, t := range tokens { - t = strings.ToUpper(t) + for _, token := range tokens { + t := strings.ToUpper(token) t = strings.TrimSpace(t) if takeNext { - tables = append(tables, t) + if !existsInList(token, tables) { + tables = append(tables, token) + } checkNext = false takeNext = false continue @@ -74,11 +96,13 @@ func parseTables(rawSQL string) ([]string, error) { continue } if strings.Contains(t, ",") { - values := strings.Split(t, ",") + values := strings.Split(token, ",") for _, v := range values { v := strings.TrimSpace(v) if v != "" { - tables = append(tables, v) + if !existsInList(token, tables) { + tables = append(tables, v) + } } else { takeNext = true break @@ -86,7 +110,9 @@ func parseTables(rawSQL string) ([]string, error) { } continue } - tables = append(tables, t) + if !existsInList(token, tables) { + tables = append(tables, token) + } checkNext = false } if t == "FROM" { @@ -97,3 +123,12 @@ func parseTables(rawSQL string) ([]string, error) { } return nil, errors.New("not a select statement") } + +func existsInList(table string, list []string) bool { + for _, t := range list { + if t == table { + return true + } + } + return false +} diff --git a/pkg/expr/sql/parser_test.go b/pkg/expr/sql/parser_test.go index 2c8e43681e6..ce018c3c0f5 100644 --- a/pkg/expr/sql/parser_test.go +++ b/pkg/expr/sql/parser_test.go @@ -11,7 +11,7 @@ func TestParse(t *testing.T) { tables, err := parseTables((sql)) assert.Nil(t, err) - assert.Equal(t, "FOO", tables[0]) + assert.Equal(t, "foo", tables[0]) } func TestParseWithComma(t *testing.T) { @@ -19,8 +19,8 @@ func TestParseWithComma(t *testing.T) { tables, err := parseTables((sql)) assert.Nil(t, err) - assert.Equal(t, "FOO", tables[0]) - assert.Equal(t, "BAR", tables[1]) + assert.Equal(t, "foo", tables[0]) + assert.Equal(t, "bar", tables[1]) } func TestParseWithCommas(t *testing.T) { @@ -28,9 +28,9 @@ func TestParseWithCommas(t *testing.T) { tables, err := parseTables((sql)) assert.Nil(t, err) - assert.Equal(t, "FOO", tables[0]) - assert.Equal(t, "BAR", tables[1]) - assert.Equal(t, "BAZ", tables[2]) + assert.Equal(t, "foo", tables[0]) + assert.Equal(t, "bar", tables[1]) + assert.Equal(t, "baz", tables[2]) } func TestArray(t *testing.T) { @@ -56,3 +56,12 @@ func TestXxx(t *testing.T) { assert.Equal(t, 0, len(tables)) } + +func TestParseSubquery(t *testing.T) { + sql := "select * from (select * from people limit 1)" + tables, err := TablesList((sql)) + assert.Nil(t, err) + + assert.Equal(t, 1, len(tables)) + assert.Equal(t, "people", tables[0]) +} diff --git a/pkg/expr/sql_command.go b/pkg/expr/sql_command.go index ec041e9abd6..bbcb942467f 100644 --- a/pkg/expr/sql_command.go +++ b/pkg/expr/sql_command.go @@ -74,7 +74,11 @@ func (gr *SQLCommand) Execute(ctx context.Context, now time.Time, vars mathexp.V allFrames := []*data.Frame{} for _, ref := range gr.varsToQuery { - results := vars[ref] + results, ok := vars[ref] + if !ok { + logger.Warn("no results found for", "ref", ref) + continue + } frames := results.Values.AsDataFrames(ref) allFrames = append(allFrames, frames...) }