diff --git a/pkg/services/accesscontrol/filter.go b/pkg/services/accesscontrol/filter.go index ae7fed5625c..9ea828b1ad9 100644 --- a/pkg/services/accesscontrol/filter.go +++ b/pkg/services/accesscontrol/filter.go @@ -13,18 +13,20 @@ var sqlIDAcceptList = map[string]struct{}{ "org_user.user_id": {}, } -const denyQuery = " 1 = 0" -const allowAllQuery = " 1 = 1" +var ( + denyQuery = SQLFilter{" 1 = 0", nil} + allowAllQuery = SQLFilter{" 1 = 1", nil} +) // Filter creates a where clause to restrict the view of a query based on a users permissions // Scopes for a certain action will be compared against prefix:id:sqlID where prefix is the scope prefix and sqlID // is the id to generate scope from e.g. user.id -func Filter(ctx context.Context, sqlID, prefix, action string, user *models.SignedInUser) (string, []interface{}, error) { +func Filter(ctx context.Context, sqlID, prefix, action string, user *models.SignedInUser) (SQLFilter, error) { if _, ok := sqlIDAcceptList[sqlID]; !ok { - return denyQuery, nil, errors.New("sqlID is not in the accept list") + return denyQuery, errors.New("sqlID is not in the accept list") } - if user.Permissions == nil || user.Permissions[user.OrgId] == nil { - return denyQuery, nil, errors.New("missing permissions") + if user == nil || user.Permissions == nil || user.Permissions[user.OrgId] == nil { + return denyQuery, errors.New("missing permissions") } var hasWildcard bool @@ -42,11 +44,11 @@ func Filter(ctx context.Context, sqlID, prefix, action string, user *models.Sign } if hasWildcard { - return allowAllQuery, nil, nil + return allowAllQuery, nil } if len(ids) == 0 { - return denyQuery, nil, nil + return denyQuery, nil } query := strings.Builder{} @@ -57,7 +59,7 @@ func Filter(ctx context.Context, sqlID, prefix, action string, user *models.Sign query.WriteString(strings.Repeat(",?", len(ids)-1)) query.WriteRune(')') - return query.String(), ids, nil + return SQLFilter{query.String(), ids}, nil } func parseScopeID(scope string) (int64, error) { diff --git a/pkg/services/accesscontrol/filter_bench_test.go b/pkg/services/accesscontrol/filter_bench_test.go index 7ef51af5251..126eaf4adb4 100644 --- a/pkg/services/accesscontrol/filter_bench_test.go +++ b/pkg/services/accesscontrol/filter_bench_test.go @@ -31,7 +31,7 @@ func benchmarkFilter(b *testing.B, numDs, numPermissions int) { for i := 0; i < b.N; i++ { baseSql := `SELECT data_source.* FROM data_source WHERE` - query, args, err := accesscontrol.Filter( + acFilter, err := accesscontrol.Filter( context.Background(), "data_source.id", "datasources", @@ -42,7 +42,7 @@ func benchmarkFilter(b *testing.B, numDs, numPermissions int) { var datasources []models.DataSource sess := store.NewSession(context.Background()) - err = sess.SQL(baseSql+query, args...).Find(&datasources) + err = sess.SQL(baseSql+acFilter.Where, acFilter.Args...).Find(&datasources) require.NoError(b, err) sess.Close() require.Len(b, datasources, numPermissions) diff --git a/pkg/services/accesscontrol/filter_test.go b/pkg/services/accesscontrol/filter_test.go index aa71df99b0a..2dcfbb65eed 100644 --- a/pkg/services/accesscontrol/filter_test.go +++ b/pkg/services/accesscontrol/filter_test.go @@ -104,7 +104,7 @@ func TestFilter_Datasources(t *testing.T) { } baseSql := `SELECT data_source.* FROM data_source WHERE` - query, args, err := accesscontrol.Filter( + acFilter, err := accesscontrol.Filter( context.Background(), tt.sqlID, "datasources", @@ -115,7 +115,7 @@ func TestFilter_Datasources(t *testing.T) { if !tt.expectErr { require.NoError(t, err) var datasources []models.DataSource - err = sess.SQL(baseSql+query, args...).Find(&datasources) + err = sess.SQL(baseSql+acFilter.Where, acFilter.Args...).Find(&datasources) require.NoError(t, err) assert.Len(t, datasources, len(tt.expectedDataSources)) diff --git a/pkg/services/accesscontrol/models.go b/pkg/services/accesscontrol/models.go index a1d7c62e9fa..a2925aefda5 100644 --- a/pkg/services/accesscontrol/models.go +++ b/pkg/services/accesscontrol/models.go @@ -248,6 +248,11 @@ type GetResourcesPermissionsQuery struct { OnlyManaged bool } +type SQLFilter struct { + Where string + Args []interface{} +} + const ( GlobalOrgID = 0 // Permission actions diff --git a/pkg/services/sqlstore/org_users.go b/pkg/services/sqlstore/org_users.go index 6c81c92d5a9..0af6765da50 100644 --- a/pkg/services/sqlstore/org_users.go +++ b/pkg/services/sqlstore/org_users.go @@ -119,12 +119,12 @@ func (ss *SQLStore) GetOrgUsers(ctx context.Context, query *models.GetOrgUsersQu whereConditions = append(whereConditions, fmt.Sprintf("%s.is_service_account = %t", x.Dialect().Quote("user"), query.IsServiceAccount)) if ss.Cfg.IsFeatureToggleEnabled(featuremgmt.FlagAccesscontrol) && query.User != nil { - q, args, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User) + acFilter, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User) if err != nil { return err } - whereConditions = append(whereConditions, q) - whereParams = append(whereParams, args...) + whereConditions = append(whereConditions, acFilter.Where) + whereParams = append(whereParams, acFilter.Args...) } if query.Query != "" { @@ -182,12 +182,12 @@ func (ss *SQLStore) SearchOrgUsers(ctx context.Context, query *models.SearchOrgU whereConditions = append(whereConditions, fmt.Sprintf("%s.is_service_account = %t", x.Dialect().Quote("user"), query.IsServiceAccount)) if ss.Cfg.IsFeatureToggleEnabled(featuremgmt.FlagAccesscontrol) { - q, args, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User) + acFilter, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User) if err != nil { return err } - whereConditions = append(whereConditions, q) - whereParams = append(whereParams, args...) + whereConditions = append(whereConditions, acFilter.Where) + whereParams = append(whereParams, acFilter.Args...) } if query.Query != "" {