AccessControl: Use an SQLFilter struct (#44887)

This commit is contained in:
Gabriel MABILLE 2022-02-07 16:18:52 +01:00 committed by GitHub
parent bdac6576e4
commit 178193c84b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 26 additions and 19 deletions

View File

@ -13,18 +13,20 @@ var sqlIDAcceptList = map[string]struct{}{
"org_user.user_id": {},
}
const denyQuery = " 1 = 0"
const allowAllQuery = " 1 = 1"
var (
denyQuery = SQLFilter{" 1 = 0", nil}
allowAllQuery = SQLFilter{" 1 = 1", nil}
)
// Filter creates a where clause to restrict the view of a query based on a users permissions
// Scopes for a certain action will be compared against prefix:id:sqlID where prefix is the scope prefix and sqlID
// is the id to generate scope from e.g. user.id
func Filter(ctx context.Context, sqlID, prefix, action string, user *models.SignedInUser) (string, []interface{}, error) {
func Filter(ctx context.Context, sqlID, prefix, action string, user *models.SignedInUser) (SQLFilter, error) {
if _, ok := sqlIDAcceptList[sqlID]; !ok {
return denyQuery, nil, errors.New("sqlID is not in the accept list")
return denyQuery, errors.New("sqlID is not in the accept list")
}
if user.Permissions == nil || user.Permissions[user.OrgId] == nil {
return denyQuery, nil, errors.New("missing permissions")
if user == nil || user.Permissions == nil || user.Permissions[user.OrgId] == nil {
return denyQuery, errors.New("missing permissions")
}
var hasWildcard bool
@ -42,11 +44,11 @@ func Filter(ctx context.Context, sqlID, prefix, action string, user *models.Sign
}
if hasWildcard {
return allowAllQuery, nil, nil
return allowAllQuery, nil
}
if len(ids) == 0 {
return denyQuery, nil, nil
return denyQuery, nil
}
query := strings.Builder{}
@ -57,7 +59,7 @@ func Filter(ctx context.Context, sqlID, prefix, action string, user *models.Sign
query.WriteString(strings.Repeat(",?", len(ids)-1))
query.WriteRune(')')
return query.String(), ids, nil
return SQLFilter{query.String(), ids}, nil
}
func parseScopeID(scope string) (int64, error) {

View File

@ -31,7 +31,7 @@ func benchmarkFilter(b *testing.B, numDs, numPermissions int) {
for i := 0; i < b.N; i++ {
baseSql := `SELECT data_source.* FROM data_source WHERE`
query, args, err := accesscontrol.Filter(
acFilter, err := accesscontrol.Filter(
context.Background(),
"data_source.id",
"datasources",
@ -42,7 +42,7 @@ func benchmarkFilter(b *testing.B, numDs, numPermissions int) {
var datasources []models.DataSource
sess := store.NewSession(context.Background())
err = sess.SQL(baseSql+query, args...).Find(&datasources)
err = sess.SQL(baseSql+acFilter.Where, acFilter.Args...).Find(&datasources)
require.NoError(b, err)
sess.Close()
require.Len(b, datasources, numPermissions)

View File

@ -104,7 +104,7 @@ func TestFilter_Datasources(t *testing.T) {
}
baseSql := `SELECT data_source.* FROM data_source WHERE`
query, args, err := accesscontrol.Filter(
acFilter, err := accesscontrol.Filter(
context.Background(),
tt.sqlID,
"datasources",
@ -115,7 +115,7 @@ func TestFilter_Datasources(t *testing.T) {
if !tt.expectErr {
require.NoError(t, err)
var datasources []models.DataSource
err = sess.SQL(baseSql+query, args...).Find(&datasources)
err = sess.SQL(baseSql+acFilter.Where, acFilter.Args...).Find(&datasources)
require.NoError(t, err)
assert.Len(t, datasources, len(tt.expectedDataSources))

View File

@ -248,6 +248,11 @@ type GetResourcesPermissionsQuery struct {
OnlyManaged bool
}
type SQLFilter struct {
Where string
Args []interface{}
}
const (
GlobalOrgID = 0
// Permission actions

View File

@ -119,12 +119,12 @@ func (ss *SQLStore) GetOrgUsers(ctx context.Context, query *models.GetOrgUsersQu
whereConditions = append(whereConditions, fmt.Sprintf("%s.is_service_account = %t", x.Dialect().Quote("user"), query.IsServiceAccount))
if ss.Cfg.IsFeatureToggleEnabled(featuremgmt.FlagAccesscontrol) && query.User != nil {
q, args, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User)
acFilter, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User)
if err != nil {
return err
}
whereConditions = append(whereConditions, q)
whereParams = append(whereParams, args...)
whereConditions = append(whereConditions, acFilter.Where)
whereParams = append(whereParams, acFilter.Args...)
}
if query.Query != "" {
@ -182,12 +182,12 @@ func (ss *SQLStore) SearchOrgUsers(ctx context.Context, query *models.SearchOrgU
whereConditions = append(whereConditions, fmt.Sprintf("%s.is_service_account = %t", x.Dialect().Quote("user"), query.IsServiceAccount))
if ss.Cfg.IsFeatureToggleEnabled(featuremgmt.FlagAccesscontrol) {
q, args, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User)
acFilter, err := accesscontrol.Filter(ctx, "org_user.user_id", "users", "org.users:read", query.User)
if err != nil {
return err
}
whereConditions = append(whereConditions, q)
whereParams = append(whereParams, args...)
whereConditions = append(whereConditions, acFilter.Where)
whereParams = append(whereParams, acFilter.Args...)
}
if query.Query != "" {