Update quota service to accept context (#45186)

This commit is contained in:
Yuriy Tseretyan 2022-02-10 16:17:50 -05:00 committed by GitHub
parent c59567a236
commit d4ac1f0ce1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 14 deletions

View File

@ -4,10 +4,12 @@ import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
"github.com/stretchr/testify/assert"
)
func TestMiddlewareQuota(t *testing.T) {
@ -246,3 +248,7 @@ type mockQuotaService struct {
func (m *mockQuotaService) QuotaReached(c *models.ReqContext, target string) (bool, error) {
return m.reached, m.err
}
func (m *mockQuotaService) CheckQuotaReached(c context.Context, target string, params *quota.ScopeParameters) (bool, error) {
return m.reached, m.err
}

View File

@ -1,8 +1,10 @@
package quota
import (
"context"
"errors"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting"
@ -15,6 +17,7 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, sqlS
Cfg: cfg,
AuthTokenService: tokenService,
SQLStore: sqlStore,
Logger: log.New("quota_service"),
}
}
@ -22,30 +25,51 @@ type QuotaService struct {
AuthTokenService models.UserTokenService
Cfg *setting.Cfg
SQLStore sqlstore.Store
Logger log.Logger
}
type Service interface {
QuotaReached(c *models.ReqContext, target string) (bool, error)
CheckQuotaReached(ctx context.Context, target string, scopeParams *ScopeParameters) (bool, error)
}
type ScopeParameters struct {
OrgId int64
UserId int64
}
// QuotaReached checks that quota is reached for a target. Runs CheckQuotaReached and take context and scope parameters from the request context
func (qs *QuotaService) QuotaReached(c *models.ReqContext, target string) (bool, error) {
if !qs.Cfg.Quota.Enabled {
return false, nil
}
// No request context means this is a background service, like LDAP Background Sync.
// TODO: we should replace the req context with a more limited interface or struct,
// something that we could easily provide from background jobs.
// No request context means this is a background service, like LDAP Background Sync
if c == nil {
return false, nil
}
var params *ScopeParameters
if c.IsSignedIn {
params = &ScopeParameters{
OrgId: c.OrgId,
UserId: c.UserId,
}
}
return qs.CheckQuotaReached(c.Req.Context(), target, params)
}
// CheckQuotaReached check that quota is reached for a target. If ScopeParameters are not defined, only global scope is checked
func (qs *QuotaService) CheckQuotaReached(ctx context.Context, target string, scopeParams *ScopeParameters) (bool, error) {
if !qs.Cfg.Quota.Enabled {
return false, nil
}
// get the list of scopes that this target is valid for. Org, User, Global
scopes, err := qs.getQuotaScopes(target)
if err != nil {
return false, err
}
for _, scope := range scopes {
c.Logger.Debug("Checking quota", "target", target, "scope", scope)
qs.Logger.Debug("Checking quota", "target", target, "scope", scope)
switch scope.Name {
case "global":
@ -56,35 +80,35 @@ func (qs *QuotaService) QuotaReached(c *models.ReqContext, target string) (bool,
return true, nil
}
if target == "session" {
usedSessions, err := qs.AuthTokenService.ActiveTokenCount(c.Req.Context())
usedSessions, err := qs.AuthTokenService.ActiveTokenCount(ctx)
if err != nil {
return false, err
}
if usedSessions > scope.DefaultLimit {
c.Logger.Debug("Sessions limit reached", "active", usedSessions, "limit", scope.DefaultLimit)
qs.Logger.Debug("Sessions limit reached", "active", usedSessions, "limit", scope.DefaultLimit)
return true, nil
}
continue
}
query := models.GetGlobalQuotaByTargetQuery{Target: scope.Target, UnifiedAlertingEnabled: qs.Cfg.UnifiedAlerting.IsEnabled()}
if err := qs.SQLStore.GetGlobalQuotaByTarget(c.Req.Context(), &query); err != nil {
if err := qs.SQLStore.GetGlobalQuotaByTarget(ctx, &query); err != nil {
return true, err
}
if query.Result.Used >= scope.DefaultLimit {
return true, nil
}
case "org":
if !c.IsSignedIn {
if scopeParams == nil {
continue
}
query := models.GetOrgQuotaByTargetQuery{
OrgId: c.OrgId,
OrgId: scopeParams.OrgId,
Target: scope.Target,
Default: scope.DefaultLimit,
UnifiedAlertingEnabled: qs.Cfg.UnifiedAlerting.IsEnabled(),
}
if err := qs.SQLStore.GetOrgQuotaByTarget(c.Req.Context(), &query); err != nil {
if err := qs.SQLStore.GetOrgQuotaByTarget(ctx, &query); err != nil {
return true, err
}
if query.Result.Limit < 0 {
@ -98,11 +122,11 @@ func (qs *QuotaService) QuotaReached(c *models.ReqContext, target string) (bool,
return true, nil
}
case "user":
if !c.IsSignedIn || c.UserId == 0 {
if scopeParams == nil || scopeParams.UserId == 0 {
continue
}
query := models.GetUserQuotaByTargetQuery{UserId: c.UserId, Target: scope.Target, Default: scope.DefaultLimit, UnifiedAlertingEnabled: qs.Cfg.UnifiedAlerting.IsEnabled()}
if err := qs.SQLStore.GetUserQuotaByTarget(c.Req.Context(), &query); err != nil {
query := models.GetUserQuotaByTargetQuery{UserId: scopeParams.UserId, Target: scope.Target, Default: scope.DefaultLimit, UnifiedAlertingEnabled: qs.Cfg.UnifiedAlerting.IsEnabled()}
if err := qs.SQLStore.GetUserQuotaByTarget(ctx, &query); err != nil {
return true, err
}
if query.Result.Limit < 0 {