Access control: Rewrite access control SQL filter (#44488)

* Rewrite access control sql filter
This commit is contained in:
Karl Persson 2022-01-27 13:06:08 +01:00 committed by GitHub
parent b42161a713
commit bf63ccbe00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 42 additions and 225 deletions

View File

@ -3,101 +3,65 @@ package accesscontrol
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
)
var sqlIDAcceptList = map[string]struct{}{
"org_user.user_id": {},
}
type SQLDialect interface {
DriverName() string
}
const denyQuery = " 1 = 0"
const allowAllQuery = " 1 = 1"
// Filter creates a where clause to restrict the view of a query based on a users permissions
// Scopes for a certain action will be compared against prefix:id:sqlID where prefix is the scope prefix and sqlID
// is the id to generate scope from e.g. user.id
func Filter(ctx context.Context, dialect SQLDialect, sqlID, prefix, action string, user *models.SignedInUser) (string, []interface{}, error) {
func Filter(ctx context.Context, sqlID, prefix, action string, user *models.SignedInUser) (string, []interface{}, error) {
if _, ok := sqlIDAcceptList[sqlID]; !ok {
return "", nil, errors.New("sqlID is not in the accept list")
return denyQuery, nil, errors.New("sqlID is not in the accept list")
}
if user.Permissions == nil || user.Permissions[user.OrgId] == nil {
return "", nil, errors.New("missing permissions")
return denyQuery, nil, errors.New("missing permissions")
}
scopes := user.Permissions[user.OrgId][action]
if len(scopes) == 0 {
return " 1 = 0", nil, nil
var hasWildcard bool
var ids []interface{}
for _, scope := range user.Permissions[user.OrgId][action] {
if strings.HasPrefix(scope, prefix) {
if id := strings.TrimPrefix(scope, prefix); id == ":*" || id == ":id:*" {
hasWildcard = true
break
}
if id, err := parseScopeID(scope); err == nil {
ids = append(ids, id)
}
}
}
var sql string
var args []interface{}
switch {
case strings.Contains(dialect.DriverName(), migrator.SQLite):
sql, args = sqliteQuery(scopes, sqlID, prefix)
case strings.Contains(dialect.DriverName(), migrator.MySQL):
sql, args = mysqlQuery(scopes, sqlID, prefix)
case strings.Contains(dialect.DriverName(), migrator.Postgres):
sql, args = postgresQuery(scopes, sqlID, prefix)
default:
return "", nil, fmt.Errorf("unknown database: %s", dialect.DriverName())
if hasWildcard {
return allowAllQuery, nil, nil
}
return sql, args, nil
if len(ids) == 0 {
return denyQuery, nil, nil
}
query := strings.Builder{}
query.WriteRune(' ')
query.WriteString(sqlID)
query.WriteString(" IN ")
query.WriteString("(?")
query.WriteString(strings.Repeat(",?", len(ids)-1))
query.WriteRune(')')
return query.String(), ids, nil
}
func sqliteQuery(scopes []string, sqlID, prefix string) (string, []interface{}) {
args := []interface{}{prefix}
for _, s := range scopes {
args = append(args, s)
}
args = append(args, prefix, prefix, prefix)
return fmt.Sprintf(`
? || ':id:' || %s IN (
WITH t(scope) AS (
VALUES (?)`+strings.Repeat(`, (?)`, len(scopes)-1)+`
)
SELECT IIF(t.scope = '*' OR t.scope = ? || ':*' OR t.scope = ? || ':id:*', ? || ':id:' || %s, t.scope) FROM t
)
`, sqlID, sqlID), args
}
func mysqlQuery(scopes []string, sqlID, prefix string) (string, []interface{}) {
args := []interface{}{prefix, prefix, prefix, prefix}
for _, s := range scopes {
args = append(args, s)
}
return fmt.Sprintf(`
CONCAT(?, ':id:', %s) IN (
SELECT IF(t.scope = '*' OR t.scope = CONCAT(?, ':*') OR t.scope = CONCAT(?, ':id:*'), CONCAT(?, ':id:', %s), t.scope) FROM
(SELECT ? AS scope`+strings.Repeat(" UNION ALL SELECT ?", len(scopes)-1)+`) AS t
)
`, sqlID, sqlID), args
}
func postgresQuery(scopes []string, sqlID, prefix string) (string, []interface{}) {
args := []interface{}{prefix, prefix, prefix, prefix}
for _, s := range scopes {
args = append(args, s)
}
return fmt.Sprintf(`
CONCAT(?, ':id:', %s) IN (
SELECT
CASE WHEN p.scope = '*' OR p.scope = CONCAT(?, ':*') OR p.scope = CONCAT(?, ':id:*') THEN CONCAT(?, ':id:', %s)
ELSE p.scope
END
FROM (VALUES (?)`+strings.Repeat(", (?)", len(scopes)-1)+`) as p(scope)
)
`, sqlID, sqlID), args
func parseScopeID(scope string) (int64, error) {
return strconv.ParseInt(scope[strings.LastIndex(scope, ":")+1:], 10, 64)
}
// SetAcceptListForTest allow us to mutate the list for blackbox testing

View File

@ -13,9 +13,11 @@ import (
"github.com/grafana/grafana/pkg/services/sqlstore"
)
func BenchmarkFilter10_10(b *testing.B) { benchmarkFilter(b, 10, 10) }
func BenchmarkFilter100_10(b *testing.B) { benchmarkFilter(b, 100, 10) }
func BenchmarkFilter100_100(b *testing.B) { benchmarkFilter(b, 100, 100) }
func BenchmarkFilter10_10(b *testing.B) { benchmarkFilter(b, 10, 10) }
func BenchmarkFilter100_10(b *testing.B) { benchmarkFilter(b, 100, 10) }
func BenchmarkFilter100_100(b *testing.B) { benchmarkFilter(b, 100, 100) }
func BenchmarkFilter1000_100(b *testing.B) { benchmarkFilter(b, 1000, 100) }
func BenchmarkFilter1000_1000(b *testing.B) { benchmarkFilter(b, 1000, 100) }
func benchmarkFilter(b *testing.B, numDs, numPermissions int) {
store, permissions := setupFilterBenchmark(b, numDs, numPermissions)
@ -31,7 +33,6 @@ func benchmarkFilter(b *testing.B, numDs, numPermissions int) {
baseSql := `SELECT data_source.* FROM data_source WHERE`
query, args, err := accesscontrol.Filter(
context.Background(),
&FakeDriver{name: "sqlite3"},
"data_source.id",
"datasources",
"datasources:read",

View File

@ -13,145 +13,6 @@ import (
"github.com/grafana/grafana/pkg/services/sqlstore"
)
type filterTest struct {
desc string
driverName string
sqlID string
action string
prefix string
permissions []*accesscontrol.Permission
expectedQuery string
expectedArgs []interface{}
}
func TestFilter(t *testing.T) {
tests := []filterTest{
{
desc: "should produce datasource filter with sqlite driver",
driverName: "sqlite3",
sqlID: "data_source.id",
prefix: "datasources",
action: "datasources:query",
permissions: []*accesscontrol.Permission{
{Action: "datasources:query", Scope: "datasources:id:1"},
{Action: "datasources:query", Scope: "datasources:id:2"},
{Action: "datasources:query", Scope: "datasources:id:3"},
{Action: "datasources:query", Scope: "datasources:id:8"},
// Other permissions
{Action: "datasources:write", Scope: "datasources:id:100"},
{Action: "datasources:delete", Scope: "datasources:id:101"},
},
expectedQuery: `
? || ':id:' || data_source.id IN (
WITH t(scope) AS (
VALUES (?), (?), (?), (?)
)
SELECT IIF(t.scope = '*' OR t.scope = ? || ':*' OR t.scope = ? || ':id:*', ? || ':id:' || data_source.id, t.scope) FROM t
)
`,
expectedArgs: []interface{}{
"datasources",
"datasources:id:1",
"datasources:id:2",
"datasources:id:3",
"datasources:id:8",
"datasources",
"datasources",
"datasources",
},
},
{
desc: "should produce dashboard filter with mysql driver",
driverName: "mysql",
sqlID: "dashboard.id",
prefix: "dashboards",
action: "dashboards:read",
permissions: []*accesscontrol.Permission{
{Action: "dashboards:read", Scope: "dashboards:id:1"},
{Action: "dashboards:read", Scope: "dashboards:id:2"},
{Action: "dashboards:read", Scope: "dashboards:id:5"},
// Other permissions
{Action: "dashboards:write", Scope: "dashboards:id:100"},
{Action: "dashboards:delete", Scope: "dashboards:id:101"},
},
expectedQuery: `
CONCAT(?, ':id:', dashboard.id) IN (
SELECT IF(t.scope = '*' OR t.scope = CONCAT(?, ':*') OR t.scope = CONCAT(?, ':id:*'), CONCAT(?, ':id:', dashboard.id), t.scope) FROM
(SELECT ? AS scope UNION ALL SELECT ? UNION ALL SELECT ?) AS t
)
`,
expectedArgs: []interface{}{
"dashboards",
"dashboards",
"dashboards",
"dashboards",
"dashboards:id:1",
"dashboards:id:2",
"dashboards:id:5",
},
},
{
desc: "should produce user filter with postgres driver",
driverName: "postgres",
sqlID: "user.id",
prefix: "users",
action: "users:read",
permissions: []*accesscontrol.Permission{
{Action: "users:read", Scope: "users:id:1"},
{Action: "users:read", Scope: "users:id:100"},
// Other permissions
{Action: "dashboards:write", Scope: "dashboards:id:100"},
{Action: "dashboards:delete", Scope: "dashboards:id:101"},
},
expectedQuery: `
CONCAT(?, ':id:', user.id) IN (
SELECT
CASE WHEN p.scope = '*' OR p.scope = CONCAT(?, ':*') OR p.scope = CONCAT(?, ':id:*') THEN CONCAT(?, ':id:', user.id)
ELSE p.scope
END
FROM (VALUES (?), (?)) as p(scope)
)
`,
expectedArgs: []interface{}{
"users",
"users",
"users",
"users",
"users:id:1",
"users:id:100",
},
},
}
// set sqlIDAcceptList before running tests
restore := accesscontrol.SetAcceptListForTest(map[string]struct{}{
"user.id": {},
"dashboard.id": {},
"data_source.id": {},
})
defer restore()
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
query, args, err := accesscontrol.Filter(
context.Background(),
FakeDriver{name: tt.driverName},
tt.sqlID,
tt.prefix,
tt.action,
&models.SignedInUser{OrgId: 1, Permissions: map[int64]map[string][]string{1: accesscontrol.GroupScopesByAction(tt.permissions)}},
)
require.NoError(t, err)
assert.Equal(t, tt.expectedQuery, query)
require.Len(t, args, len(tt.expectedArgs))
for i := range tt.expectedArgs {
assert.Equal(t, tt.expectedArgs[i], args[i])
}
})
}
}
type filterDatasourcesTestCase struct {
desc string
sqlID string
@ -221,7 +82,6 @@ func TestFilter_Datasources(t *testing.T) {
baseSql := `SELECT data_source.* FROM data_source WHERE`
query, args, err := accesscontrol.Filter(
context.Background(),
&FakeDriver{name: "sqlite3"},
tt.sqlID,
"datasources",
"datasources:read",
@ -244,11 +104,3 @@ func TestFilter_Datasources(t *testing.T) {
})
}
}
type FakeDriver struct {
name string
}
func (f FakeDriver) DriverName() string {
return f.name
}

View File

@ -119,7 +119,7 @@ func (ss *SQLStore) GetOrgUsers(ctx context.Context, query *models.GetOrgUsersQu
whereConditions = append(whereConditions, fmt.Sprintf("%s.is_service_account = %t", x.Dialect().Quote("user"), query.IsServiceAccount))
if ss.Cfg.IsFeatureToggleEnabled(featuremgmt.FlagAccesscontrol) {
q, args, err := accesscontrol.Filter(ctx, ss.Dialect, "org_user.user_id", "users", "org.users:read", query.User)
q, args, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User)
if err != nil {
return err
}
@ -182,7 +182,7 @@ func (ss *SQLStore) SearchOrgUsers(ctx context.Context, query *models.SearchOrgU
whereConditions = append(whereConditions, fmt.Sprintf("%s.is_service_account = %t", x.Dialect().Quote("user"), query.IsServiceAccount))
if ss.Cfg.IsFeatureToggleEnabled(featuremgmt.FlagAccesscontrol) {
q, args, err := accesscontrol.Filter(ctx, ss.Dialect, "org_user.user_id", "users", "org.users:read", query.User)
q, args, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User)
if err != nil {
return err
}