diff --git a/pkg/services/accesscontrol/filter.go b/pkg/services/accesscontrol/filter.go index 97073be1472..97a14354ba7 100644 --- a/pkg/services/accesscontrol/filter.go +++ b/pkg/services/accesscontrol/filter.go @@ -3,101 +3,65 @@ package accesscontrol import ( "context" "errors" - "fmt" + "strconv" "strings" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/services/sqlstore/migrator" ) var sqlIDAcceptList = map[string]struct{}{ "org_user.user_id": {}, } -type SQLDialect interface { - DriverName() string -} +const denyQuery = " 1 = 0" +const allowAllQuery = " 1 = 1" // Filter creates a where clause to restrict the view of a query based on a users permissions // Scopes for a certain action will be compared against prefix:id:sqlID where prefix is the scope prefix and sqlID // is the id to generate scope from e.g. user.id -func Filter(ctx context.Context, dialect SQLDialect, sqlID, prefix, action string, user *models.SignedInUser) (string, []interface{}, error) { +func Filter(ctx context.Context, sqlID, prefix, action string, user *models.SignedInUser) (string, []interface{}, error) { if _, ok := sqlIDAcceptList[sqlID]; !ok { - return "", nil, errors.New("sqlID is not in the accept list") + return denyQuery, nil, errors.New("sqlID is not in the accept list") } - if user.Permissions == nil || user.Permissions[user.OrgId] == nil { - return "", nil, errors.New("missing permissions") + return denyQuery, nil, errors.New("missing permissions") } - scopes := user.Permissions[user.OrgId][action] - if len(scopes) == 0 { - return " 1 = 0", nil, nil + var hasWildcard bool + var ids []interface{} + for _, scope := range user.Permissions[user.OrgId][action] { + if strings.HasPrefix(scope, prefix) { + if id := strings.TrimPrefix(scope, prefix); id == ":*" || id == ":id:*" { + hasWildcard = true + break + } + if id, err := parseScopeID(scope); err == nil { + ids = append(ids, id) + } + } } - var sql string - var args []interface{} - - switch { - case strings.Contains(dialect.DriverName(), migrator.SQLite): - sql, args = sqliteQuery(scopes, sqlID, prefix) - case strings.Contains(dialect.DriverName(), migrator.MySQL): - sql, args = mysqlQuery(scopes, sqlID, prefix) - case strings.Contains(dialect.DriverName(), migrator.Postgres): - sql, args = postgresQuery(scopes, sqlID, prefix) - default: - return "", nil, fmt.Errorf("unknown database: %s", dialect.DriverName()) + if hasWildcard { + return allowAllQuery, nil, nil } - return sql, args, nil + if len(ids) == 0 { + return denyQuery, nil, nil + } + + query := strings.Builder{} + query.WriteRune(' ') + query.WriteString(sqlID) + query.WriteString(" IN ") + query.WriteString("(?") + query.WriteString(strings.Repeat(",?", len(ids)-1)) + query.WriteRune(')') + + return query.String(), ids, nil } -func sqliteQuery(scopes []string, sqlID, prefix string) (string, []interface{}) { - args := []interface{}{prefix} - for _, s := range scopes { - args = append(args, s) - } - args = append(args, prefix, prefix, prefix) - - return fmt.Sprintf(` - ? || ':id:' || %s IN ( - WITH t(scope) AS ( - VALUES (?)`+strings.Repeat(`, (?)`, len(scopes)-1)+` - ) - SELECT IIF(t.scope = '*' OR t.scope = ? || ':*' OR t.scope = ? || ':id:*', ? || ':id:' || %s, t.scope) FROM t - ) - `, sqlID, sqlID), args -} - -func mysqlQuery(scopes []string, sqlID, prefix string) (string, []interface{}) { - args := []interface{}{prefix, prefix, prefix, prefix} - for _, s := range scopes { - args = append(args, s) - } - - return fmt.Sprintf(` - CONCAT(?, ':id:', %s) IN ( - SELECT IF(t.scope = '*' OR t.scope = CONCAT(?, ':*') OR t.scope = CONCAT(?, ':id:*'), CONCAT(?, ':id:', %s), t.scope) FROM - (SELECT ? AS scope`+strings.Repeat(" UNION ALL SELECT ?", len(scopes)-1)+`) AS t - ) - `, sqlID, sqlID), args -} - -func postgresQuery(scopes []string, sqlID, prefix string) (string, []interface{}) { - args := []interface{}{prefix, prefix, prefix, prefix} - for _, s := range scopes { - args = append(args, s) - } - - return fmt.Sprintf(` - CONCAT(?, ':id:', %s) IN ( - SELECT - CASE WHEN p.scope = '*' OR p.scope = CONCAT(?, ':*') OR p.scope = CONCAT(?, ':id:*') THEN CONCAT(?, ':id:', %s) - ELSE p.scope - END - FROM (VALUES (?)`+strings.Repeat(", (?)", len(scopes)-1)+`) as p(scope) - ) - `, sqlID, sqlID), args +func parseScopeID(scope string) (int64, error) { + return strconv.ParseInt(scope[strings.LastIndex(scope, ":")+1:], 10, 64) } // SetAcceptListForTest allow us to mutate the list for blackbox testing diff --git a/pkg/services/accesscontrol/filter_bench_test.go b/pkg/services/accesscontrol/filter_bench_test.go index 6bb63dcd0b0..7ef51af5251 100644 --- a/pkg/services/accesscontrol/filter_bench_test.go +++ b/pkg/services/accesscontrol/filter_bench_test.go @@ -13,9 +13,11 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore" ) -func BenchmarkFilter10_10(b *testing.B) { benchmarkFilter(b, 10, 10) } -func BenchmarkFilter100_10(b *testing.B) { benchmarkFilter(b, 100, 10) } -func BenchmarkFilter100_100(b *testing.B) { benchmarkFilter(b, 100, 100) } +func BenchmarkFilter10_10(b *testing.B) { benchmarkFilter(b, 10, 10) } +func BenchmarkFilter100_10(b *testing.B) { benchmarkFilter(b, 100, 10) } +func BenchmarkFilter100_100(b *testing.B) { benchmarkFilter(b, 100, 100) } +func BenchmarkFilter1000_100(b *testing.B) { benchmarkFilter(b, 1000, 100) } +func BenchmarkFilter1000_1000(b *testing.B) { benchmarkFilter(b, 1000, 100) } func benchmarkFilter(b *testing.B, numDs, numPermissions int) { store, permissions := setupFilterBenchmark(b, numDs, numPermissions) @@ -31,7 +33,6 @@ func benchmarkFilter(b *testing.B, numDs, numPermissions int) { baseSql := `SELECT data_source.* FROM data_source WHERE` query, args, err := accesscontrol.Filter( context.Background(), - &FakeDriver{name: "sqlite3"}, "data_source.id", "datasources", "datasources:read", diff --git a/pkg/services/accesscontrol/filter_test.go b/pkg/services/accesscontrol/filter_test.go index a081a91371f..76acf541663 100644 --- a/pkg/services/accesscontrol/filter_test.go +++ b/pkg/services/accesscontrol/filter_test.go @@ -13,145 +13,6 @@ import ( "github.com/grafana/grafana/pkg/services/sqlstore" ) -type filterTest struct { - desc string - driverName string - sqlID string - action string - prefix string - permissions []*accesscontrol.Permission - expectedQuery string - expectedArgs []interface{} -} - -func TestFilter(t *testing.T) { - tests := []filterTest{ - { - desc: "should produce datasource filter with sqlite driver", - driverName: "sqlite3", - sqlID: "data_source.id", - prefix: "datasources", - action: "datasources:query", - permissions: []*accesscontrol.Permission{ - {Action: "datasources:query", Scope: "datasources:id:1"}, - {Action: "datasources:query", Scope: "datasources:id:2"}, - {Action: "datasources:query", Scope: "datasources:id:3"}, - {Action: "datasources:query", Scope: "datasources:id:8"}, - // Other permissions - {Action: "datasources:write", Scope: "datasources:id:100"}, - {Action: "datasources:delete", Scope: "datasources:id:101"}, - }, - expectedQuery: ` - ? || ':id:' || data_source.id IN ( - WITH t(scope) AS ( - VALUES (?), (?), (?), (?) - ) - SELECT IIF(t.scope = '*' OR t.scope = ? || ':*' OR t.scope = ? || ':id:*', ? || ':id:' || data_source.id, t.scope) FROM t - ) - `, - expectedArgs: []interface{}{ - "datasources", - "datasources:id:1", - "datasources:id:2", - "datasources:id:3", - "datasources:id:8", - "datasources", - "datasources", - "datasources", - }, - }, - { - desc: "should produce dashboard filter with mysql driver", - driverName: "mysql", - sqlID: "dashboard.id", - prefix: "dashboards", - action: "dashboards:read", - permissions: []*accesscontrol.Permission{ - {Action: "dashboards:read", Scope: "dashboards:id:1"}, - {Action: "dashboards:read", Scope: "dashboards:id:2"}, - {Action: "dashboards:read", Scope: "dashboards:id:5"}, - // Other permissions - {Action: "dashboards:write", Scope: "dashboards:id:100"}, - {Action: "dashboards:delete", Scope: "dashboards:id:101"}, - }, - expectedQuery: ` - CONCAT(?, ':id:', dashboard.id) IN ( - SELECT IF(t.scope = '*' OR t.scope = CONCAT(?, ':*') OR t.scope = CONCAT(?, ':id:*'), CONCAT(?, ':id:', dashboard.id), t.scope) FROM - (SELECT ? AS scope UNION ALL SELECT ? UNION ALL SELECT ?) AS t - ) - `, - expectedArgs: []interface{}{ - "dashboards", - "dashboards", - "dashboards", - "dashboards", - "dashboards:id:1", - "dashboards:id:2", - "dashboards:id:5", - }, - }, - { - desc: "should produce user filter with postgres driver", - driverName: "postgres", - sqlID: "user.id", - prefix: "users", - action: "users:read", - permissions: []*accesscontrol.Permission{ - {Action: "users:read", Scope: "users:id:1"}, - {Action: "users:read", Scope: "users:id:100"}, - // Other permissions - {Action: "dashboards:write", Scope: "dashboards:id:100"}, - {Action: "dashboards:delete", Scope: "dashboards:id:101"}, - }, - expectedQuery: ` - CONCAT(?, ':id:', user.id) IN ( - SELECT - CASE WHEN p.scope = '*' OR p.scope = CONCAT(?, ':*') OR p.scope = CONCAT(?, ':id:*') THEN CONCAT(?, ':id:', user.id) - ELSE p.scope - END - FROM (VALUES (?), (?)) as p(scope) - ) - `, - expectedArgs: []interface{}{ - "users", - "users", - "users", - "users", - "users:id:1", - "users:id:100", - }, - }, - } - - // set sqlIDAcceptList before running tests - restore := accesscontrol.SetAcceptListForTest(map[string]struct{}{ - "user.id": {}, - "dashboard.id": {}, - "data_source.id": {}, - }) - defer restore() - - for _, tt := range tests { - t.Run(tt.desc, func(t *testing.T) { - query, args, err := accesscontrol.Filter( - context.Background(), - FakeDriver{name: tt.driverName}, - tt.sqlID, - tt.prefix, - tt.action, - &models.SignedInUser{OrgId: 1, Permissions: map[int64]map[string][]string{1: accesscontrol.GroupScopesByAction(tt.permissions)}}, - ) - require.NoError(t, err) - assert.Equal(t, tt.expectedQuery, query) - - require.Len(t, args, len(tt.expectedArgs)) - for i := range tt.expectedArgs { - assert.Equal(t, tt.expectedArgs[i], args[i]) - } - }) - } -} - type filterDatasourcesTestCase struct { desc string sqlID string @@ -221,7 +82,6 @@ func TestFilter_Datasources(t *testing.T) { baseSql := `SELECT data_source.* FROM data_source WHERE` query, args, err := accesscontrol.Filter( context.Background(), - &FakeDriver{name: "sqlite3"}, tt.sqlID, "datasources", "datasources:read", @@ -244,11 +104,3 @@ func TestFilter_Datasources(t *testing.T) { }) } } - -type FakeDriver struct { - name string -} - -func (f FakeDriver) DriverName() string { - return f.name -} diff --git a/pkg/services/sqlstore/org_users.go b/pkg/services/sqlstore/org_users.go index fcfe7f54573..f0817f16d27 100644 --- a/pkg/services/sqlstore/org_users.go +++ b/pkg/services/sqlstore/org_users.go @@ -119,7 +119,7 @@ func (ss *SQLStore) GetOrgUsers(ctx context.Context, query *models.GetOrgUsersQu whereConditions = append(whereConditions, fmt.Sprintf("%s.is_service_account = %t", x.Dialect().Quote("user"), query.IsServiceAccount)) if ss.Cfg.IsFeatureToggleEnabled(featuremgmt.FlagAccesscontrol) { - q, args, err := accesscontrol.Filter(ctx, ss.Dialect, "org_user.user_id", "users", "org.users:read", query.User) + q, args, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User) if err != nil { return err } @@ -182,7 +182,7 @@ func (ss *SQLStore) SearchOrgUsers(ctx context.Context, query *models.SearchOrgU whereConditions = append(whereConditions, fmt.Sprintf("%s.is_service_account = %t", x.Dialect().Quote("user"), query.IsServiceAccount)) if ss.Cfg.IsFeatureToggleEnabled(featuremgmt.FlagAccesscontrol) { - q, args, err := accesscontrol.Filter(ctx, ss.Dialect, "org_user.user_id", "users", "org.users:read", query.User) + q, args, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User) if err != nil { return err }