Access Control: SQL filter (#43560)

* Add the accesscontrol sql filter utility

Co-authored-by: Emil Tullstedt <emil.tullstedt@grafana.com>
This commit is contained in:
Karl Persson 2022-01-10 14:26:57 +01:00 committed by GitHub
parent dcd4e74c54
commit d350ed0f35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 424 additions and 0 deletions

View File

@ -0,0 +1,99 @@
package accesscontrol
import (
"context"
"errors"
"fmt"
"strings"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
)
var sqlIDAcceptList = map[string]struct{}{}
type SQLDialect interface {
DriverName() string
}
// Filter creates a where clause to restrict the view of a query based on a users permissions
// Scopes for a certain action will be compared against prefix:id:sqlID where prefix is the scope prefix and sqlID
// is the id to generate scope from e.g. user.id
func Filter(ctx context.Context, dialect SQLDialect, sqlID, prefix, action string, user *models.SignedInUser) (string, []interface{}, error) {
if _, ok := sqlIDAcceptList[sqlID]; !ok {
return "", nil, errors.New("sqlID is not in the accept list")
}
if user.Permissions == nil || user.Permissions[user.OrgId] == nil {
return "", nil, errors.New("missing permissions")
}
scopes := user.Permissions[user.OrgId][action]
if len(scopes) == 0 {
return " 1 = 0", nil, nil
}
var sql string
var args []interface{}
switch {
case strings.Contains(dialect.DriverName(), migrator.SQLite):
sql, args = sqliteQuery(scopes, sqlID, prefix)
case strings.Contains(dialect.DriverName(), migrator.MySQL):
sql, args = mysqlQuery(scopes, sqlID, prefix)
case strings.Contains(dialect.DriverName(), migrator.Postgres):
sql, args = postgresQuery(scopes, sqlID, prefix)
default:
return "", nil, fmt.Errorf("unknown database: %s", dialect.DriverName())
}
return sql, args, nil
}
func sqliteQuery(scopes []string, sqlID, prefix string) (string, []interface{}) {
args := []interface{}{prefix}
for _, s := range scopes {
args = append(args, s)
}
args = append(args, prefix, prefix, prefix)
return fmt.Sprintf(`
? || ':id:' || %s IN (
WITH t(scope) AS (
VALUES (?)`+strings.Repeat(`, (?)`, len(scopes)-1)+`
)
SELECT IIF(t.scope = '*' OR t.scope = ? || ':*' OR t.scope = ? || ':id:*', ? || ':id:' || %s, t.scope) FROM t
)
`, sqlID, sqlID), args
}
func mysqlQuery(scopes []string, sqlID, prefix string) (string, []interface{}) {
args := []interface{}{prefix, prefix, prefix, prefix}
for _, s := range scopes {
args = append(args, s)
}
return fmt.Sprintf(`
CONCAT(?, ':id:', %s) IN (
SELECT IF(t.scope = '*' OR t.scope = CONCAT(?, ':*') OR t.scope = CONCAT(?, ':id:*'), CONCAT(?, ':id:', %s), t.scope) FROM
(SELECT ? AS scope`+strings.Repeat(" UNION ALL SELECT ?", len(scopes)-1)+`) AS t
)
`, sqlID, sqlID), args
}
func postgresQuery(scopes []string, sqlID, prefix string) (string, []interface{}) {
args := []interface{}{prefix, prefix, prefix, prefix}
for _, s := range scopes {
args = append(args, s)
}
return fmt.Sprintf(`
CONCAT(?, ':id:', %s) IN (
SELECT
CASE WHEN p.scope = '*' OR p.scope = CONCAT(?, ':*') OR p.scope = CONCAT(?, ':id:*') THEN CONCAT(?, ':id:', %s)
ELSE p.scope
END
FROM (VALUES (?)`+strings.Repeat(", (?)", len(scopes)-1)+`) as p(scope)
)
`, sqlID, sqlID), args
}

View File

@ -0,0 +1,74 @@
package accesscontrol
import (
"context"
"fmt"
"strconv"
"testing"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore"
)
func BenchmarkFilter10_10(b *testing.B) { benchmarkFilter(b, 10, 10) }
func BenchmarkFilter100_10(b *testing.B) { benchmarkFilter(b, 100, 10) }
func BenchmarkFilter100_100(b *testing.B) { benchmarkFilter(b, 100, 100) }
func benchmarkFilter(b *testing.B, numDs, numPermissions int) {
store, permissions := setupFilterBenchmark(b, numDs, numPermissions)
b.ResetTimer()
// set sqlIDAcceptList before running tests
sqlIDAcceptList = map[string]struct{}{
"data_source.id": {},
}
for i := 0; i < b.N; i++ {
baseSql := `SELECT data_source.* FROM data_source WHERE`
query, args, err := Filter(
context.Background(),
&FakeDriver{name: "sqlite3"},
"data_source.id",
"datasources",
"datasources:read",
&models.SignedInUser{OrgId: 1, Permissions: map[int64]map[string][]string{1: GroupScopesByAction(permissions)}},
)
require.NoError(b, err)
var datasources []models.DataSource
sess := store.NewSession(context.Background())
err = sess.SQL(baseSql+query, args...).Find(&datasources)
require.NoError(b, err)
sess.Close()
require.Len(b, datasources, numPermissions)
}
}
func setupFilterBenchmark(b *testing.B, numDs, numPermissions int) (*sqlstore.SQLStore, []*Permission) {
b.Helper()
store := sqlstore.InitTestDB(b)
for i := 1; i <= numDs; i++ {
err := store.AddDataSource(context.Background(), &models.AddDataSourceCommand{
Name: fmt.Sprintf("ds:%d", i),
OrgId: 1,
})
require.NoError(b, err)
}
if numPermissions > numDs {
numPermissions = numDs
}
permissions := make([]*Permission, 0, numPermissions)
for i := 1; i <= numPermissions; i++ {
permissions = append(permissions, &Permission{
Action: "datasources:read",
Scope: Scope("datasources", "id", strconv.Itoa(i)),
})
}
return store, permissions
}

View File

@ -0,0 +1,251 @@
package accesscontrol
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore"
)
type filterTest struct {
desc string
driverName string
sqlID string
action string
prefix string
permissions []*Permission
expectedQuery string
expectedArgs []interface{}
}
func TestFilter(t *testing.T) {
tests := []filterTest{
{
desc: "should produce datasource filter with sqlite driver",
driverName: "sqlite3",
sqlID: "data_source.id",
prefix: "datasources",
action: "datasources:query",
permissions: []*Permission{
{Action: "datasources:query", Scope: "datasources:id:1"},
{Action: "datasources:query", Scope: "datasources:id:2"},
{Action: "datasources:query", Scope: "datasources:id:3"},
{Action: "datasources:query", Scope: "datasources:id:8"},
// Other permissions
{Action: "datasources:write", Scope: "datasources:id:100"},
{Action: "datasources:delete", Scope: "datasources:id:101"},
},
expectedQuery: `
? || ':id:' || data_source.id IN (
WITH t(scope) AS (
VALUES (?), (?), (?), (?)
)
SELECT IIF(t.scope = '*' OR t.scope = ? || ':*' OR t.scope = ? || ':id:*', ? || ':id:' || data_source.id, t.scope) FROM t
)
`,
expectedArgs: []interface{}{
"datasources",
"datasources:id:1",
"datasources:id:2",
"datasources:id:3",
"datasources:id:8",
"datasources",
"datasources",
"datasources",
},
},
{
desc: "should produce dashboard filter with mysql driver",
driverName: "mysql",
sqlID: "dashboard.id",
prefix: "dashboards",
action: "dashboards:read",
permissions: []*Permission{
{Action: "dashboards:read", Scope: "dashboards:id:1"},
{Action: "dashboards:read", Scope: "dashboards:id:2"},
{Action: "dashboards:read", Scope: "dashboards:id:5"},
// Other permissions
{Action: "dashboards:write", Scope: "dashboards:id:100"},
{Action: "dashboards:delete", Scope: "dashboards:id:101"},
},
expectedQuery: `
CONCAT(?, ':id:', dashboard.id) IN (
SELECT IF(t.scope = '*' OR t.scope = CONCAT(?, ':*') OR t.scope = CONCAT(?, ':id:*'), CONCAT(?, ':id:', dashboard.id), t.scope) FROM
(SELECT ? AS scope UNION ALL SELECT ? UNION ALL SELECT ?) AS t
)
`,
expectedArgs: []interface{}{
"dashboards",
"dashboards",
"dashboards",
"dashboards",
"dashboards:id:1",
"dashboards:id:2",
"dashboards:id:5",
},
},
{
desc: "should produce user filter with postgres driver",
driverName: "postgres",
sqlID: "user.id",
prefix: "users",
action: "users:read",
permissions: []*Permission{
{Action: "users:read", Scope: "users:id:1"},
{Action: "users:read", Scope: "users:id:100"},
// Other permissions
{Action: "dashboards:write", Scope: "dashboards:id:100"},
{Action: "dashboards:delete", Scope: "dashboards:id:101"},
},
expectedQuery: `
CONCAT(?, ':id:', user.id) IN (
SELECT
CASE WHEN p.scope = '*' OR p.scope = CONCAT(?, ':*') OR p.scope = CONCAT(?, ':id:*') THEN CONCAT(?, ':id:', user.id)
ELSE p.scope
END
FROM (VALUES (?), (?)) as p(scope)
)
`,
expectedArgs: []interface{}{
"users",
"users",
"users",
"users",
"users:id:1",
"users:id:100",
},
},
}
// set sqlIDAcceptList before running tests
sqlIDAcceptList = map[string]struct{}{
"user.id": {},
"dashboard.id": {},
"data_source.id": {},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
query, args, err := Filter(
context.Background(),
FakeDriver{name: tt.driverName},
tt.sqlID,
tt.prefix,
tt.action,
&models.SignedInUser{OrgId: 1, Permissions: map[int64]map[string][]string{1: GroupScopesByAction(tt.permissions)}},
)
require.NoError(t, err)
assert.Equal(t, tt.expectedQuery, query)
require.Len(t, args, len(tt.expectedArgs))
for i := range tt.expectedArgs {
assert.Equal(t, tt.expectedArgs[i], args[i])
}
})
}
}
type filterDatasourcesTestCase struct {
desc string
sqlID string
permissions []*Permission
expectedDataSources []string
expectErr bool
}
func TestFilter_Datasources(t *testing.T) {
tests := []filterDatasourcesTestCase{
{
desc: "expect all data sources to be returned",
sqlID: "data_source.id",
permissions: []*Permission{
{Action: "datasources:read", Scope: "datasources:*"},
},
expectedDataSources: []string{"ds:1", "ds:2", "ds:3", "ds:4", "ds:5", "ds:6", "ds:7", "ds:8", "ds:9", "ds:10"},
},
{
desc: "expect no data sources to be returned",
sqlID: "data_source.id",
permissions: []*Permission{},
expectedDataSources: []string{},
},
{
desc: "expect data sources with id 3, 7 and 8 to be returned",
sqlID: "data_source.id",
permissions: []*Permission{
{Action: "datasources:read", Scope: "datasources:id:3"},
{Action: "datasources:read", Scope: "datasources:id:7"},
{Action: "datasources:read", Scope: "datasources:id:8"},
},
expectedDataSources: []string{"ds:3", "ds:7", "ds:8"},
},
{
desc: "expect error if sqlID is not in the accept list",
sqlID: "other.id",
permissions: []*Permission{
{Action: "datasources:read", Scope: "datasources:id:3"},
{Action: "datasources:read", Scope: "datasources:id:7"},
{Action: "datasources:read", Scope: "datasources:id:8"},
},
expectedDataSources: []string{"ds:3", "ds:7", "ds:8"},
expectErr: true,
},
}
// set sqlIDAcceptList before running tests
sqlIDAcceptList = map[string]struct{}{
"data_source.id": {},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
store := sqlstore.InitTestDB(t)
sess := store.NewSession(context.Background())
defer sess.Close()
// seed 10 data sources
for i := 1; i <= 10; i++ {
err := store.AddDataSource(context.Background(), &models.AddDataSourceCommand{Name: fmt.Sprintf("ds:%d", i)})
require.NoError(t, err)
}
baseSql := `SELECT data_source.* FROM data_source WHERE`
query, args, err := Filter(
context.Background(),
&FakeDriver{name: "sqlite3"},
tt.sqlID,
"datasources",
"datasources:read",
&models.SignedInUser{OrgId: 1, Permissions: map[int64]map[string][]string{1: GroupScopesByAction(tt.permissions)}},
)
if !tt.expectErr {
require.NoError(t, err)
var datasources []models.DataSource
err = sess.SQL(baseSql+query, args...).Find(&datasources)
require.NoError(t, err)
assert.Len(t, datasources, len(tt.expectedDataSources))
for i, ds := range datasources {
assert.Equal(t, tt.expectedDataSources[i], ds.Name)
}
} else {
require.Error(t, err)
}
})
}
}
type FakeDriver struct {
name string
}
func (f FakeDriver) DriverName() string {
return f.name
}