mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Access Control: SQL filter (#43560)
* Add the accesscontrol sql filter utility Co-authored-by: Emil Tullstedt <emil.tullstedt@grafana.com>
This commit is contained in:
parent
dcd4e74c54
commit
d350ed0f35
99
pkg/services/accesscontrol/filter.go
Normal file
99
pkg/services/accesscontrol/filter.go
Normal file
@ -0,0 +1,99 @@
|
||||
package accesscontrol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
)
|
||||
|
||||
var sqlIDAcceptList = map[string]struct{}{}
|
||||
|
||||
type SQLDialect interface {
|
||||
DriverName() string
|
||||
}
|
||||
|
||||
// Filter creates a where clause to restrict the view of a query based on a users permissions
|
||||
// Scopes for a certain action will be compared against prefix:id:sqlID where prefix is the scope prefix and sqlID
|
||||
// is the id to generate scope from e.g. user.id
|
||||
func Filter(ctx context.Context, dialect SQLDialect, sqlID, prefix, action string, user *models.SignedInUser) (string, []interface{}, error) {
|
||||
if _, ok := sqlIDAcceptList[sqlID]; !ok {
|
||||
return "", nil, errors.New("sqlID is not in the accept list")
|
||||
}
|
||||
|
||||
if user.Permissions == nil || user.Permissions[user.OrgId] == nil {
|
||||
return "", nil, errors.New("missing permissions")
|
||||
}
|
||||
|
||||
scopes := user.Permissions[user.OrgId][action]
|
||||
if len(scopes) == 0 {
|
||||
return " 1 = 0", nil, nil
|
||||
}
|
||||
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch {
|
||||
case strings.Contains(dialect.DriverName(), migrator.SQLite):
|
||||
sql, args = sqliteQuery(scopes, sqlID, prefix)
|
||||
case strings.Contains(dialect.DriverName(), migrator.MySQL):
|
||||
sql, args = mysqlQuery(scopes, sqlID, prefix)
|
||||
case strings.Contains(dialect.DriverName(), migrator.Postgres):
|
||||
sql, args = postgresQuery(scopes, sqlID, prefix)
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unknown database: %s", dialect.DriverName())
|
||||
}
|
||||
|
||||
return sql, args, nil
|
||||
}
|
||||
|
||||
func sqliteQuery(scopes []string, sqlID, prefix string) (string, []interface{}) {
|
||||
args := []interface{}{prefix}
|
||||
for _, s := range scopes {
|
||||
args = append(args, s)
|
||||
}
|
||||
args = append(args, prefix, prefix, prefix)
|
||||
|
||||
return fmt.Sprintf(`
|
||||
? || ':id:' || %s IN (
|
||||
WITH t(scope) AS (
|
||||
VALUES (?)`+strings.Repeat(`, (?)`, len(scopes)-1)+`
|
||||
)
|
||||
SELECT IIF(t.scope = '*' OR t.scope = ? || ':*' OR t.scope = ? || ':id:*', ? || ':id:' || %s, t.scope) FROM t
|
||||
)
|
||||
`, sqlID, sqlID), args
|
||||
}
|
||||
|
||||
func mysqlQuery(scopes []string, sqlID, prefix string) (string, []interface{}) {
|
||||
args := []interface{}{prefix, prefix, prefix, prefix}
|
||||
for _, s := range scopes {
|
||||
args = append(args, s)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`
|
||||
CONCAT(?, ':id:', %s) IN (
|
||||
SELECT IF(t.scope = '*' OR t.scope = CONCAT(?, ':*') OR t.scope = CONCAT(?, ':id:*'), CONCAT(?, ':id:', %s), t.scope) FROM
|
||||
(SELECT ? AS scope`+strings.Repeat(" UNION ALL SELECT ?", len(scopes)-1)+`) AS t
|
||||
)
|
||||
`, sqlID, sqlID), args
|
||||
}
|
||||
|
||||
func postgresQuery(scopes []string, sqlID, prefix string) (string, []interface{}) {
|
||||
args := []interface{}{prefix, prefix, prefix, prefix}
|
||||
for _, s := range scopes {
|
||||
args = append(args, s)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`
|
||||
CONCAT(?, ':id:', %s) IN (
|
||||
SELECT
|
||||
CASE WHEN p.scope = '*' OR p.scope = CONCAT(?, ':*') OR p.scope = CONCAT(?, ':id:*') THEN CONCAT(?, ':id:', %s)
|
||||
ELSE p.scope
|
||||
END
|
||||
FROM (VALUES (?)`+strings.Repeat(", (?)", len(scopes)-1)+`) as p(scope)
|
||||
)
|
||||
`, sqlID, sqlID), args
|
||||
}
|
74
pkg/services/accesscontrol/filter_bench_test.go
Normal file
74
pkg/services/accesscontrol/filter_bench_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package accesscontrol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
)
|
||||
|
||||
func BenchmarkFilter10_10(b *testing.B) { benchmarkFilter(b, 10, 10) }
|
||||
func BenchmarkFilter100_10(b *testing.B) { benchmarkFilter(b, 100, 10) }
|
||||
func BenchmarkFilter100_100(b *testing.B) { benchmarkFilter(b, 100, 100) }
|
||||
|
||||
func benchmarkFilter(b *testing.B, numDs, numPermissions int) {
|
||||
store, permissions := setupFilterBenchmark(b, numDs, numPermissions)
|
||||
b.ResetTimer()
|
||||
|
||||
// set sqlIDAcceptList before running tests
|
||||
sqlIDAcceptList = map[string]struct{}{
|
||||
"data_source.id": {},
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
baseSql := `SELECT data_source.* FROM data_source WHERE`
|
||||
query, args, err := Filter(
|
||||
context.Background(),
|
||||
&FakeDriver{name: "sqlite3"},
|
||||
"data_source.id",
|
||||
"datasources",
|
||||
"datasources:read",
|
||||
&models.SignedInUser{OrgId: 1, Permissions: map[int64]map[string][]string{1: GroupScopesByAction(permissions)}},
|
||||
)
|
||||
require.NoError(b, err)
|
||||
|
||||
var datasources []models.DataSource
|
||||
sess := store.NewSession(context.Background())
|
||||
err = sess.SQL(baseSql+query, args...).Find(&datasources)
|
||||
require.NoError(b, err)
|
||||
sess.Close()
|
||||
require.Len(b, datasources, numPermissions)
|
||||
}
|
||||
}
|
||||
|
||||
func setupFilterBenchmark(b *testing.B, numDs, numPermissions int) (*sqlstore.SQLStore, []*Permission) {
|
||||
b.Helper()
|
||||
store := sqlstore.InitTestDB(b)
|
||||
|
||||
for i := 1; i <= numDs; i++ {
|
||||
err := store.AddDataSource(context.Background(), &models.AddDataSourceCommand{
|
||||
Name: fmt.Sprintf("ds:%d", i),
|
||||
OrgId: 1,
|
||||
})
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
if numPermissions > numDs {
|
||||
numPermissions = numDs
|
||||
}
|
||||
|
||||
permissions := make([]*Permission, 0, numPermissions)
|
||||
for i := 1; i <= numPermissions; i++ {
|
||||
permissions = append(permissions, &Permission{
|
||||
Action: "datasources:read",
|
||||
Scope: Scope("datasources", "id", strconv.Itoa(i)),
|
||||
})
|
||||
}
|
||||
|
||||
return store, permissions
|
||||
}
|
251
pkg/services/accesscontrol/filter_test.go
Normal file
251
pkg/services/accesscontrol/filter_test.go
Normal file
@ -0,0 +1,251 @@
|
||||
package accesscontrol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
)
|
||||
|
||||
type filterTest struct {
|
||||
desc string
|
||||
driverName string
|
||||
sqlID string
|
||||
action string
|
||||
prefix string
|
||||
permissions []*Permission
|
||||
expectedQuery string
|
||||
expectedArgs []interface{}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
tests := []filterTest{
|
||||
{
|
||||
desc: "should produce datasource filter with sqlite driver",
|
||||
driverName: "sqlite3",
|
||||
sqlID: "data_source.id",
|
||||
prefix: "datasources",
|
||||
action: "datasources:query",
|
||||
permissions: []*Permission{
|
||||
{Action: "datasources:query", Scope: "datasources:id:1"},
|
||||
{Action: "datasources:query", Scope: "datasources:id:2"},
|
||||
{Action: "datasources:query", Scope: "datasources:id:3"},
|
||||
{Action: "datasources:query", Scope: "datasources:id:8"},
|
||||
// Other permissions
|
||||
{Action: "datasources:write", Scope: "datasources:id:100"},
|
||||
{Action: "datasources:delete", Scope: "datasources:id:101"},
|
||||
},
|
||||
expectedQuery: `
|
||||
? || ':id:' || data_source.id IN (
|
||||
WITH t(scope) AS (
|
||||
VALUES (?), (?), (?), (?)
|
||||
)
|
||||
SELECT IIF(t.scope = '*' OR t.scope = ? || ':*' OR t.scope = ? || ':id:*', ? || ':id:' || data_source.id, t.scope) FROM t
|
||||
)
|
||||
`,
|
||||
expectedArgs: []interface{}{
|
||||
"datasources",
|
||||
"datasources:id:1",
|
||||
"datasources:id:2",
|
||||
"datasources:id:3",
|
||||
"datasources:id:8",
|
||||
"datasources",
|
||||
"datasources",
|
||||
"datasources",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should produce dashboard filter with mysql driver",
|
||||
driverName: "mysql",
|
||||
sqlID: "dashboard.id",
|
||||
prefix: "dashboards",
|
||||
action: "dashboards:read",
|
||||
permissions: []*Permission{
|
||||
{Action: "dashboards:read", Scope: "dashboards:id:1"},
|
||||
{Action: "dashboards:read", Scope: "dashboards:id:2"},
|
||||
{Action: "dashboards:read", Scope: "dashboards:id:5"},
|
||||
// Other permissions
|
||||
{Action: "dashboards:write", Scope: "dashboards:id:100"},
|
||||
{Action: "dashboards:delete", Scope: "dashboards:id:101"},
|
||||
},
|
||||
expectedQuery: `
|
||||
CONCAT(?, ':id:', dashboard.id) IN (
|
||||
SELECT IF(t.scope = '*' OR t.scope = CONCAT(?, ':*') OR t.scope = CONCAT(?, ':id:*'), CONCAT(?, ':id:', dashboard.id), t.scope) FROM
|
||||
(SELECT ? AS scope UNION ALL SELECT ? UNION ALL SELECT ?) AS t
|
||||
)
|
||||
`,
|
||||
expectedArgs: []interface{}{
|
||||
"dashboards",
|
||||
"dashboards",
|
||||
"dashboards",
|
||||
"dashboards",
|
||||
"dashboards:id:1",
|
||||
"dashboards:id:2",
|
||||
"dashboards:id:5",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should produce user filter with postgres driver",
|
||||
driverName: "postgres",
|
||||
sqlID: "user.id",
|
||||
prefix: "users",
|
||||
action: "users:read",
|
||||
permissions: []*Permission{
|
||||
{Action: "users:read", Scope: "users:id:1"},
|
||||
{Action: "users:read", Scope: "users:id:100"},
|
||||
// Other permissions
|
||||
{Action: "dashboards:write", Scope: "dashboards:id:100"},
|
||||
{Action: "dashboards:delete", Scope: "dashboards:id:101"},
|
||||
},
|
||||
expectedQuery: `
|
||||
CONCAT(?, ':id:', user.id) IN (
|
||||
SELECT
|
||||
CASE WHEN p.scope = '*' OR p.scope = CONCAT(?, ':*') OR p.scope = CONCAT(?, ':id:*') THEN CONCAT(?, ':id:', user.id)
|
||||
ELSE p.scope
|
||||
END
|
||||
FROM (VALUES (?), (?)) as p(scope)
|
||||
)
|
||||
`,
|
||||
expectedArgs: []interface{}{
|
||||
"users",
|
||||
"users",
|
||||
"users",
|
||||
"users",
|
||||
"users:id:1",
|
||||
"users:id:100",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// set sqlIDAcceptList before running tests
|
||||
sqlIDAcceptList = map[string]struct{}{
|
||||
"user.id": {},
|
||||
"dashboard.id": {},
|
||||
"data_source.id": {},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
query, args, err := Filter(
|
||||
context.Background(),
|
||||
FakeDriver{name: tt.driverName},
|
||||
tt.sqlID,
|
||||
tt.prefix,
|
||||
tt.action,
|
||||
&models.SignedInUser{OrgId: 1, Permissions: map[int64]map[string][]string{1: GroupScopesByAction(tt.permissions)}},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedQuery, query)
|
||||
|
||||
require.Len(t, args, len(tt.expectedArgs))
|
||||
for i := range tt.expectedArgs {
|
||||
assert.Equal(t, tt.expectedArgs[i], args[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type filterDatasourcesTestCase struct {
|
||||
desc string
|
||||
sqlID string
|
||||
permissions []*Permission
|
||||
expectedDataSources []string
|
||||
expectErr bool
|
||||
}
|
||||
|
||||
func TestFilter_Datasources(t *testing.T) {
|
||||
tests := []filterDatasourcesTestCase{
|
||||
{
|
||||
desc: "expect all data sources to be returned",
|
||||
sqlID: "data_source.id",
|
||||
permissions: []*Permission{
|
||||
{Action: "datasources:read", Scope: "datasources:*"},
|
||||
},
|
||||
expectedDataSources: []string{"ds:1", "ds:2", "ds:3", "ds:4", "ds:5", "ds:6", "ds:7", "ds:8", "ds:9", "ds:10"},
|
||||
},
|
||||
{
|
||||
desc: "expect no data sources to be returned",
|
||||
sqlID: "data_source.id",
|
||||
permissions: []*Permission{},
|
||||
expectedDataSources: []string{},
|
||||
},
|
||||
{
|
||||
desc: "expect data sources with id 3, 7 and 8 to be returned",
|
||||
sqlID: "data_source.id",
|
||||
permissions: []*Permission{
|
||||
{Action: "datasources:read", Scope: "datasources:id:3"},
|
||||
{Action: "datasources:read", Scope: "datasources:id:7"},
|
||||
{Action: "datasources:read", Scope: "datasources:id:8"},
|
||||
},
|
||||
expectedDataSources: []string{"ds:3", "ds:7", "ds:8"},
|
||||
},
|
||||
{
|
||||
desc: "expect error if sqlID is not in the accept list",
|
||||
sqlID: "other.id",
|
||||
permissions: []*Permission{
|
||||
{Action: "datasources:read", Scope: "datasources:id:3"},
|
||||
{Action: "datasources:read", Scope: "datasources:id:7"},
|
||||
{Action: "datasources:read", Scope: "datasources:id:8"},
|
||||
},
|
||||
expectedDataSources: []string{"ds:3", "ds:7", "ds:8"},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
// set sqlIDAcceptList before running tests
|
||||
sqlIDAcceptList = map[string]struct{}{
|
||||
"data_source.id": {},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
store := sqlstore.InitTestDB(t)
|
||||
|
||||
sess := store.NewSession(context.Background())
|
||||
defer sess.Close()
|
||||
|
||||
// seed 10 data sources
|
||||
for i := 1; i <= 10; i++ {
|
||||
err := store.AddDataSource(context.Background(), &models.AddDataSourceCommand{Name: fmt.Sprintf("ds:%d", i)})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
baseSql := `SELECT data_source.* FROM data_source WHERE`
|
||||
query, args, err := Filter(
|
||||
context.Background(),
|
||||
&FakeDriver{name: "sqlite3"},
|
||||
tt.sqlID,
|
||||
"datasources",
|
||||
"datasources:read",
|
||||
&models.SignedInUser{OrgId: 1, Permissions: map[int64]map[string][]string{1: GroupScopesByAction(tt.permissions)}},
|
||||
)
|
||||
|
||||
if !tt.expectErr {
|
||||
require.NoError(t, err)
|
||||
var datasources []models.DataSource
|
||||
err = sess.SQL(baseSql+query, args...).Find(&datasources)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, datasources, len(tt.expectedDataSources))
|
||||
for i, ds := range datasources {
|
||||
assert.Equal(t, tt.expectedDataSources[i], ds.Name)
|
||||
}
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type FakeDriver struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (f FakeDriver) DriverName() string {
|
||||
return f.name
|
||||
}
|
Loading…
Reference in New Issue
Block a user