mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Merge pull request #15378 from grafana/auth_token_quotas
use authTokenService for session quotas restrictions
This commit is contained in:
commit
dcec61e1b4
@ -16,7 +16,7 @@ func (hs *HTTPServer) registerRoutes() {
|
|||||||
reqOrgAdmin := middleware.ReqOrgAdmin
|
reqOrgAdmin := middleware.ReqOrgAdmin
|
||||||
redirectFromLegacyDashboardURL := middleware.RedirectFromLegacyDashboardURL()
|
redirectFromLegacyDashboardURL := middleware.RedirectFromLegacyDashboardURL()
|
||||||
redirectFromLegacyDashboardSoloURL := middleware.RedirectFromLegacyDashboardSoloURL()
|
redirectFromLegacyDashboardSoloURL := middleware.RedirectFromLegacyDashboardSoloURL()
|
||||||
quota := middleware.Quota
|
quota := middleware.Quota(hs.QuotaService)
|
||||||
bind := binding.Bind
|
bind := binding.Bind
|
||||||
|
|
||||||
r := hs.RouteRegister
|
r := hs.RouteRegister
|
||||||
@ -286,7 +286,7 @@ func (hs *HTTPServer) registerRoutes() {
|
|||||||
|
|
||||||
dashboardRoute.Post("/calculate-diff", bind(dtos.CalculateDiffOptions{}), Wrap(CalculateDashboardDiff))
|
dashboardRoute.Post("/calculate-diff", bind(dtos.CalculateDiffOptions{}), Wrap(CalculateDashboardDiff))
|
||||||
|
|
||||||
dashboardRoute.Post("/db", bind(m.SaveDashboardCommand{}), Wrap(PostDashboard))
|
dashboardRoute.Post("/db", bind(m.SaveDashboardCommand{}), Wrap(hs.PostDashboard))
|
||||||
dashboardRoute.Get("/home", Wrap(GetHomeDashboard))
|
dashboardRoute.Get("/home", Wrap(GetHomeDashboard))
|
||||||
dashboardRoute.Get("/tags", GetDashboardTags)
|
dashboardRoute.Get("/tags", GetDashboardTags)
|
||||||
dashboardRoute.Post("/import", bind(dtos.ImportDashboardCommand{}), Wrap(ImportDashboard))
|
dashboardRoute.Post("/import", bind(dtos.ImportDashboardCommand{}), Wrap(ImportDashboard))
|
||||||
@ -294,7 +294,7 @@ func (hs *HTTPServer) registerRoutes() {
|
|||||||
dashboardRoute.Group("/id/:dashboardId", func(dashIdRoute routing.RouteRegister) {
|
dashboardRoute.Group("/id/:dashboardId", func(dashIdRoute routing.RouteRegister) {
|
||||||
dashIdRoute.Get("/versions", Wrap(GetDashboardVersions))
|
dashIdRoute.Get("/versions", Wrap(GetDashboardVersions))
|
||||||
dashIdRoute.Get("/versions/:id", Wrap(GetDashboardVersion))
|
dashIdRoute.Get("/versions/:id", Wrap(GetDashboardVersion))
|
||||||
dashIdRoute.Post("/restore", bind(dtos.RestoreDashboardVersionCommand{}), Wrap(RestoreDashboardVersion))
|
dashIdRoute.Post("/restore", bind(dtos.RestoreDashboardVersionCommand{}), Wrap(hs.RestoreDashboardVersion))
|
||||||
|
|
||||||
dashIdRoute.Group("/permissions", func(dashboardPermissionRoute routing.RouteRegister) {
|
dashIdRoute.Group("/permissions", func(dashboardPermissionRoute routing.RouteRegister) {
|
||||||
dashboardPermissionRoute.Get("/", Wrap(GetDashboardPermissionList))
|
dashboardPermissionRoute.Get("/", Wrap(GetDashboardPermissionList))
|
||||||
|
@ -18,7 +18,6 @@ import (
|
|||||||
m "github.com/grafana/grafana/pkg/models"
|
m "github.com/grafana/grafana/pkg/models"
|
||||||
"github.com/grafana/grafana/pkg/plugins"
|
"github.com/grafana/grafana/pkg/plugins"
|
||||||
"github.com/grafana/grafana/pkg/services/guardian"
|
"github.com/grafana/grafana/pkg/services/guardian"
|
||||||
"github.com/grafana/grafana/pkg/services/quota"
|
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/util"
|
"github.com/grafana/grafana/pkg/util"
|
||||||
)
|
)
|
||||||
@ -208,14 +207,14 @@ func DeleteDashboardByUID(c *m.ReqContext) Response {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func PostDashboard(c *m.ReqContext, cmd m.SaveDashboardCommand) Response {
|
func (hs *HTTPServer) PostDashboard(c *m.ReqContext, cmd m.SaveDashboardCommand) Response {
|
||||||
cmd.OrgId = c.OrgId
|
cmd.OrgId = c.OrgId
|
||||||
cmd.UserId = c.UserId
|
cmd.UserId = c.UserId
|
||||||
|
|
||||||
dash := cmd.GetDashboardModel()
|
dash := cmd.GetDashboardModel()
|
||||||
|
|
||||||
if dash.Id == 0 && dash.Uid == "" {
|
if dash.Id == 0 && dash.Uid == "" {
|
||||||
limitReached, err := quota.QuotaReached(c, "dashboard")
|
limitReached, err := hs.QuotaService.QuotaReached(c, "dashboard")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Error(500, "failed to get quota", err)
|
return Error(500, "failed to get quota", err)
|
||||||
}
|
}
|
||||||
@ -463,7 +462,7 @@ func CalculateDashboardDiff(c *m.ReqContext, apiOptions dtos.CalculateDiffOption
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RestoreDashboardVersion restores a dashboard to the given version.
|
// RestoreDashboardVersion restores a dashboard to the given version.
|
||||||
func RestoreDashboardVersion(c *m.ReqContext, apiCmd dtos.RestoreDashboardVersionCommand) Response {
|
func (hs *HTTPServer) RestoreDashboardVersion(c *m.ReqContext, apiCmd dtos.RestoreDashboardVersionCommand) Response {
|
||||||
dash, rsp := getDashboardHelper(c.OrgId, "", c.ParamsInt64(":dashboardId"), "")
|
dash, rsp := getDashboardHelper(c.OrgId, "", c.ParamsInt64(":dashboardId"), "")
|
||||||
if rsp != nil {
|
if rsp != nil {
|
||||||
return rsp
|
return rsp
|
||||||
@ -490,7 +489,7 @@ func RestoreDashboardVersion(c *m.ReqContext, apiCmd dtos.RestoreDashboardVersio
|
|||||||
saveCmd.Dashboard.Set("uid", dash.Uid)
|
saveCmd.Dashboard.Set("uid", dash.Uid)
|
||||||
saveCmd.Message = fmt.Sprintf("Restored from version %d", version.Version)
|
saveCmd.Message = fmt.Sprintf("Restored from version %d", version.Version)
|
||||||
|
|
||||||
return PostDashboard(c, saveCmd)
|
return hs.PostDashboard(c, saveCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetDashboardTags(c *m.ReqContext) {
|
func GetDashboardTags(c *m.ReqContext) {
|
||||||
|
@ -881,12 +881,16 @@ func postDashboardScenario(desc string, url string, routePattern string, mock *d
|
|||||||
Convey(desc+" "+url, func() {
|
Convey(desc+" "+url, func() {
|
||||||
defer bus.ClearBusHandlers()
|
defer bus.ClearBusHandlers()
|
||||||
|
|
||||||
|
hs := HTTPServer{
|
||||||
|
Bus: bus.GetBus(),
|
||||||
|
}
|
||||||
|
|
||||||
sc := setupScenarioContext(url)
|
sc := setupScenarioContext(url)
|
||||||
sc.defaultHandler = Wrap(func(c *m.ReqContext) Response {
|
sc.defaultHandler = Wrap(func(c *m.ReqContext) Response {
|
||||||
sc.context = c
|
sc.context = c
|
||||||
sc.context.SignedInUser = &m.SignedInUser{OrgId: cmd.OrgId, UserId: cmd.UserId}
|
sc.context.SignedInUser = &m.SignedInUser{OrgId: cmd.OrgId, UserId: cmd.UserId}
|
||||||
|
|
||||||
return PostDashboard(c, cmd)
|
return hs.PostDashboard(c, cmd)
|
||||||
})
|
})
|
||||||
|
|
||||||
origNewDashboardService := dashboards.NewService
|
origNewDashboardService := dashboards.NewService
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/cache"
|
"github.com/grafana/grafana/pkg/services/cache"
|
||||||
"github.com/grafana/grafana/pkg/services/datasources"
|
"github.com/grafana/grafana/pkg/services/datasources"
|
||||||
"github.com/grafana/grafana/pkg/services/hooks"
|
"github.com/grafana/grafana/pkg/services/hooks"
|
||||||
|
"github.com/grafana/grafana/pkg/services/quota"
|
||||||
"github.com/grafana/grafana/pkg/services/rendering"
|
"github.com/grafana/grafana/pkg/services/rendering"
|
||||||
"github.com/grafana/grafana/pkg/services/session"
|
"github.com/grafana/grafana/pkg/services/session"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
@ -55,6 +56,7 @@ type HTTPServer struct {
|
|||||||
CacheService *cache.CacheService `inject:""`
|
CacheService *cache.CacheService `inject:""`
|
||||||
DatasourceCache datasources.CacheService `inject:""`
|
DatasourceCache datasources.CacheService `inject:""`
|
||||||
AuthTokenService models.UserTokenService `inject:""`
|
AuthTokenService models.UserTokenService `inject:""`
|
||||||
|
QuotaService *quota.QuotaService `inject:""`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hs *HTTPServer) Init() error {
|
func (hs *HTTPServer) Init() error {
|
||||||
|
@ -4,18 +4,30 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/bus"
|
"github.com/grafana/grafana/pkg/bus"
|
||||||
"github.com/grafana/grafana/pkg/log"
|
"github.com/grafana/grafana/pkg/log"
|
||||||
m "github.com/grafana/grafana/pkg/models"
|
m "github.com/grafana/grafana/pkg/models"
|
||||||
|
"github.com/grafana/grafana/pkg/registry"
|
||||||
"github.com/grafana/grafana/pkg/services/quota"
|
"github.com/grafana/grafana/pkg/services/quota"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
bus.AddHandler("auth", UpsertUser)
|
registry.RegisterService(&LoginService{})
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
logger = log.New("login.ext_user")
|
logger = log.New("login.ext_user")
|
||||||
)
|
)
|
||||||
|
|
||||||
func UpsertUser(cmd *m.UpsertUserCommand) error {
|
type LoginService struct {
|
||||||
|
Bus bus.Bus `inject:""`
|
||||||
|
QuotaService *quota.QuotaService `inject:""`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ls *LoginService) Init() error {
|
||||||
|
ls.Bus.AddHandler(ls.UpsertUser)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ls *LoginService) UpsertUser(cmd *m.UpsertUserCommand) error {
|
||||||
extUser := cmd.ExternalUser
|
extUser := cmd.ExternalUser
|
||||||
|
|
||||||
userQuery := &m.GetUserByAuthInfoQuery{
|
userQuery := &m.GetUserByAuthInfoQuery{
|
||||||
@ -37,7 +49,7 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
|
|||||||
return ErrInvalidCredentials
|
return ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
limitReached, err := quota.QuotaReached(cmd.ReqContext, "user")
|
limitReached, err := ls.QuotaService.QuotaReached(cmd.ReqContext, "user")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("Error getting user quota. error: %v", err)
|
log.Warn("Error getting user quota. error: %v", err)
|
||||||
return ErrGettingUserQuota
|
return ErrGettingUserQuota
|
||||||
@ -57,7 +69,7 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
|
|||||||
AuthModule: extUser.AuthModule,
|
AuthModule: extUser.AuthModule,
|
||||||
AuthId: extUser.AuthId,
|
AuthId: extUser.AuthId,
|
||||||
}
|
}
|
||||||
if err := bus.Dispatch(cmd2); err != nil {
|
if err := ls.Bus.Dispatch(cmd2); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -78,12 +90,12 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
|
|||||||
|
|
||||||
// Sync isGrafanaAdmin permission
|
// Sync isGrafanaAdmin permission
|
||||||
if extUser.IsGrafanaAdmin != nil && *extUser.IsGrafanaAdmin != cmd.Result.IsAdmin {
|
if extUser.IsGrafanaAdmin != nil && *extUser.IsGrafanaAdmin != cmd.Result.IsAdmin {
|
||||||
if err := bus.Dispatch(&m.UpdateUserPermissionsCommand{UserId: cmd.Result.Id, IsGrafanaAdmin: *extUser.IsGrafanaAdmin}); err != nil {
|
if err := ls.Bus.Dispatch(&m.UpdateUserPermissionsCommand{UserId: cmd.Result.Id, IsGrafanaAdmin: *extUser.IsGrafanaAdmin}); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = bus.Dispatch(&m.SyncTeamsCommand{
|
err = ls.Bus.Dispatch(&m.SyncTeamsCommand{
|
||||||
User: cmd.Result,
|
User: cmd.Result,
|
||||||
ExternalUser: extUser,
|
ExternalUser: extUser,
|
||||||
})
|
})
|
||||||
|
@ -395,8 +395,11 @@ func ldapAutherScenario(desc string, fn scenarioFunc) {
|
|||||||
defer bus.ClearBusHandlers()
|
defer bus.ClearBusHandlers()
|
||||||
|
|
||||||
sc := &scenarioContext{}
|
sc := &scenarioContext{}
|
||||||
|
loginService := &LoginService{
|
||||||
|
Bus: bus.GetBus(),
|
||||||
|
}
|
||||||
|
|
||||||
bus.AddHandler("test", UpsertUser)
|
bus.AddHandler("test", loginService.UpsertUser)
|
||||||
|
|
||||||
bus.AddHandlerCtx("test", func(ctx context.Context, cmd *m.SyncTeamsCommand) error {
|
bus.AddHandlerCtx("test", func(ctx context.Context, cmd *m.SyncTeamsCommand) error {
|
||||||
return nil
|
return nil
|
||||||
|
@ -682,6 +682,7 @@ type fakeUserAuthTokenService struct {
|
|||||||
tryRotateTokenProvider func(token *m.UserToken, clientIP, userAgent string) (bool, error)
|
tryRotateTokenProvider func(token *m.UserToken, clientIP, userAgent string) (bool, error)
|
||||||
lookupTokenProvider func(unhashedToken string) (*m.UserToken, error)
|
lookupTokenProvider func(unhashedToken string) (*m.UserToken, error)
|
||||||
revokeTokenProvider func(token *m.UserToken) error
|
revokeTokenProvider func(token *m.UserToken) error
|
||||||
|
activeAuthTokenCount func() (int64, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFakeUserAuthTokenService() *fakeUserAuthTokenService {
|
func newFakeUserAuthTokenService() *fakeUserAuthTokenService {
|
||||||
@ -704,6 +705,9 @@ func newFakeUserAuthTokenService() *fakeUserAuthTokenService {
|
|||||||
revokeTokenProvider: func(token *m.UserToken) error {
|
revokeTokenProvider: func(token *m.UserToken) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
activeAuthTokenCount: func() (int64, error) {
|
||||||
|
return 10, nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -722,3 +726,7 @@ func (s *fakeUserAuthTokenService) TryRotateToken(token *m.UserToken, clientIP,
|
|||||||
func (s *fakeUserAuthTokenService) RevokeToken(token *m.UserToken) error {
|
func (s *fakeUserAuthTokenService) RevokeToken(token *m.UserToken) error {
|
||||||
return s.revokeTokenProvider(token)
|
return s.revokeTokenProvider(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *fakeUserAuthTokenService) ActiveTokenCount() (int64, error) {
|
||||||
|
return s.activeAuthTokenCount()
|
||||||
|
}
|
||||||
|
@ -9,16 +9,20 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/quota"
|
"github.com/grafana/grafana/pkg/services/quota"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Quota(target string) macaron.Handler {
|
// Quota returns a function that returns a function used to call quotaservice based on target name
|
||||||
return func(c *m.ReqContext) {
|
func Quota(quotaService *quota.QuotaService) func(target string) macaron.Handler {
|
||||||
limitReached, err := quota.QuotaReached(c, target)
|
//https://open.spotify.com/track/7bZSoBEAEEUsGEuLOf94Jm?si=T1Tdju5qRSmmR0zph_6RBw fuuuuunky
|
||||||
if err != nil {
|
return func(target string) macaron.Handler {
|
||||||
c.JsonApiErr(500, "failed to get quota", err)
|
return func(c *m.ReqContext) {
|
||||||
return
|
limitReached, err := quotaService.QuotaReached(c, target)
|
||||||
}
|
if err != nil {
|
||||||
if limitReached {
|
c.JsonApiErr(500, "failed to get quota", err)
|
||||||
c.JsonApiErr(403, fmt.Sprintf("%s Quota reached", target), nil)
|
return
|
||||||
return
|
}
|
||||||
|
if limitReached {
|
||||||
|
c.JsonApiErr(403, fmt.Sprintf("%s Quota reached", target), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,10 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/services/quota"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/bus"
|
"github.com/grafana/grafana/pkg/bus"
|
||||||
m "github.com/grafana/grafana/pkg/models"
|
m "github.com/grafana/grafana/pkg/models"
|
||||||
"github.com/grafana/grafana/pkg/services/session"
|
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
)
|
)
|
||||||
@ -13,10 +14,6 @@ import (
|
|||||||
func TestMiddlewareQuota(t *testing.T) {
|
func TestMiddlewareQuota(t *testing.T) {
|
||||||
|
|
||||||
Convey("Given the grafana quota middleware", t, func() {
|
Convey("Given the grafana quota middleware", t, func() {
|
||||||
session.GetSessionCount = func() int {
|
|
||||||
return 4
|
|
||||||
}
|
|
||||||
|
|
||||||
setting.AnonymousEnabled = false
|
setting.AnonymousEnabled = false
|
||||||
setting.Quota = setting.QuotaSettings{
|
setting.Quota = setting.QuotaSettings{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
@ -39,6 +36,12 @@ func TestMiddlewareQuota(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fakeAuthTokenService := newFakeUserAuthTokenService()
|
||||||
|
qs := "a.QuotaService{
|
||||||
|
AuthTokenService: fakeAuthTokenService,
|
||||||
|
}
|
||||||
|
QuotaFn := Quota(qs)
|
||||||
|
|
||||||
middlewareScenario("with user not logged in", func(sc *scenarioContext) {
|
middlewareScenario("with user not logged in", func(sc *scenarioContext) {
|
||||||
bus.AddHandler("globalQuota", func(query *m.GetGlobalQuotaByTargetQuery) error {
|
bus.AddHandler("globalQuota", func(query *m.GetGlobalQuotaByTargetQuery) error {
|
||||||
query.Result = &m.GlobalQuotaDTO{
|
query.Result = &m.GlobalQuotaDTO{
|
||||||
@ -48,26 +51,30 @@ func TestMiddlewareQuota(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("global quota not reached", func() {
|
Convey("global quota not reached", func() {
|
||||||
sc.m.Get("/user", Quota("user"), sc.defaultHandler)
|
sc.m.Get("/user", QuotaFn("user"), sc.defaultHandler)
|
||||||
sc.fakeReq("GET", "/user").exec()
|
sc.fakeReq("GET", "/user").exec()
|
||||||
So(sc.resp.Code, ShouldEqual, 200)
|
So(sc.resp.Code, ShouldEqual, 200)
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("global quota reached", func() {
|
Convey("global quota reached", func() {
|
||||||
setting.Quota.Global.User = 4
|
setting.Quota.Global.User = 4
|
||||||
sc.m.Get("/user", Quota("user"), sc.defaultHandler)
|
sc.m.Get("/user", QuotaFn("user"), sc.defaultHandler)
|
||||||
sc.fakeReq("GET", "/user").exec()
|
sc.fakeReq("GET", "/user").exec()
|
||||||
So(sc.resp.Code, ShouldEqual, 403)
|
So(sc.resp.Code, ShouldEqual, 403)
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("global session quota not reached", func() {
|
Convey("global session quota not reached", func() {
|
||||||
setting.Quota.Global.Session = 10
|
setting.Quota.Global.Session = 10
|
||||||
sc.m.Get("/user", Quota("session"), sc.defaultHandler)
|
sc.m.Get("/user", QuotaFn("session"), sc.defaultHandler)
|
||||||
sc.fakeReq("GET", "/user").exec()
|
sc.fakeReq("GET", "/user").exec()
|
||||||
So(sc.resp.Code, ShouldEqual, 200)
|
So(sc.resp.Code, ShouldEqual, 200)
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("global session quota reached", func() {
|
Convey("global session quota reached", func() {
|
||||||
setting.Quota.Global.Session = 1
|
setting.Quota.Global.Session = 1
|
||||||
sc.m.Get("/user", Quota("session"), sc.defaultHandler)
|
sc.m.Get("/user", QuotaFn("session"), sc.defaultHandler)
|
||||||
sc.fakeReq("GET", "/user").exec()
|
sc.fakeReq("GET", "/user").exec()
|
||||||
So(sc.resp.Code, ShouldEqual, 403)
|
So(sc.resp.Code, ShouldEqual, 403)
|
||||||
})
|
})
|
||||||
@ -95,6 +102,7 @@ func TestMiddlewareQuota(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
bus.AddHandler("userQuota", func(query *m.GetUserQuotaByTargetQuery) error {
|
bus.AddHandler("userQuota", func(query *m.GetUserQuotaByTargetQuery) error {
|
||||||
query.Result = &m.UserQuotaDTO{
|
query.Result = &m.UserQuotaDTO{
|
||||||
Target: query.Target,
|
Target: query.Target,
|
||||||
@ -103,6 +111,7 @@ func TestMiddlewareQuota(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
bus.AddHandler("orgQuota", func(query *m.GetOrgQuotaByTargetQuery) error {
|
bus.AddHandler("orgQuota", func(query *m.GetOrgQuotaByTargetQuery) error {
|
||||||
query.Result = &m.OrgQuotaDTO{
|
query.Result = &m.OrgQuotaDTO{
|
||||||
Target: query.Target,
|
Target: query.Target,
|
||||||
@ -111,45 +120,49 @@ func TestMiddlewareQuota(t *testing.T) {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("global datasource quota reached", func() {
|
Convey("global datasource quota reached", func() {
|
||||||
setting.Quota.Global.DataSource = 4
|
setting.Quota.Global.DataSource = 4
|
||||||
sc.m.Get("/ds", Quota("data_source"), sc.defaultHandler)
|
sc.m.Get("/ds", QuotaFn("data_source"), sc.defaultHandler)
|
||||||
sc.fakeReq("GET", "/ds").exec()
|
sc.fakeReq("GET", "/ds").exec()
|
||||||
So(sc.resp.Code, ShouldEqual, 403)
|
So(sc.resp.Code, ShouldEqual, 403)
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("user Org quota not reached", func() {
|
Convey("user Org quota not reached", func() {
|
||||||
setting.Quota.User.Org = 5
|
setting.Quota.User.Org = 5
|
||||||
sc.m.Get("/org", Quota("org"), sc.defaultHandler)
|
sc.m.Get("/org", QuotaFn("org"), sc.defaultHandler)
|
||||||
sc.fakeReq("GET", "/org").exec()
|
sc.fakeReq("GET", "/org").exec()
|
||||||
So(sc.resp.Code, ShouldEqual, 200)
|
So(sc.resp.Code, ShouldEqual, 200)
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("user Org quota reached", func() {
|
Convey("user Org quota reached", func() {
|
||||||
setting.Quota.User.Org = 4
|
setting.Quota.User.Org = 4
|
||||||
sc.m.Get("/org", Quota("org"), sc.defaultHandler)
|
sc.m.Get("/org", QuotaFn("org"), sc.defaultHandler)
|
||||||
sc.fakeReq("GET", "/org").exec()
|
sc.fakeReq("GET", "/org").exec()
|
||||||
So(sc.resp.Code, ShouldEqual, 403)
|
So(sc.resp.Code, ShouldEqual, 403)
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("org dashboard quota not reached", func() {
|
Convey("org dashboard quota not reached", func() {
|
||||||
setting.Quota.Org.Dashboard = 10
|
setting.Quota.Org.Dashboard = 10
|
||||||
sc.m.Get("/dashboard", Quota("dashboard"), sc.defaultHandler)
|
sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
|
||||||
sc.fakeReq("GET", "/dashboard").exec()
|
sc.fakeReq("GET", "/dashboard").exec()
|
||||||
So(sc.resp.Code, ShouldEqual, 200)
|
So(sc.resp.Code, ShouldEqual, 200)
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("org dashboard quota reached", func() {
|
Convey("org dashboard quota reached", func() {
|
||||||
setting.Quota.Org.Dashboard = 4
|
setting.Quota.Org.Dashboard = 4
|
||||||
sc.m.Get("/dashboard", Quota("dashboard"), sc.defaultHandler)
|
sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
|
||||||
sc.fakeReq("GET", "/dashboard").exec()
|
sc.fakeReq("GET", "/dashboard").exec()
|
||||||
So(sc.resp.Code, ShouldEqual, 403)
|
So(sc.resp.Code, ShouldEqual, 403)
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("org dashboard quota reached but quotas disabled", func() {
|
Convey("org dashboard quota reached but quotas disabled", func() {
|
||||||
setting.Quota.Org.Dashboard = 4
|
setting.Quota.Org.Dashboard = 4
|
||||||
setting.Quota.Enabled = false
|
setting.Quota.Enabled = false
|
||||||
sc.m.Get("/dashboard", Quota("dashboard"), sc.defaultHandler)
|
sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
|
||||||
sc.fakeReq("GET", "/dashboard").exec()
|
sc.fakeReq("GET", "/dashboard").exec()
|
||||||
So(sc.resp.Code, ShouldEqual, 200)
|
So(sc.resp.Code, ShouldEqual, 200)
|
||||||
})
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -29,4 +29,5 @@ type UserTokenService interface {
|
|||||||
LookupToken(unhashedToken string) (*UserToken, error)
|
LookupToken(unhashedToken string) (*UserToken, error)
|
||||||
TryRotateToken(token *UserToken, clientIP, userAgent string) (bool, error)
|
TryRotateToken(token *UserToken, clientIP, userAgent string) (bool, error)
|
||||||
RevokeToken(token *UserToken) error
|
RevokeToken(token *UserToken) error
|
||||||
|
ActiveTokenCount() (int64, error)
|
||||||
}
|
}
|
||||||
|
@ -35,6 +35,13 @@ func (s *UserAuthTokenService) Init() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UserAuthTokenService) ActiveTokenCount() (int64, error) {
|
||||||
|
var model userAuthToken
|
||||||
|
count, err := s.SQLStore.NewSession().Where(`created_at > ? AND rotated_at > ?`, s.createdAfterParam(), s.rotatedAfterParam()).Count(&model)
|
||||||
|
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) {
|
func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) {
|
||||||
clientIP = util.ParseIPAddress(clientIP)
|
clientIP = util.ParseIPAddress(clientIP)
|
||||||
token, err := util.RandomHex(16)
|
token, err := util.RandomHex(16)
|
||||||
@ -79,13 +86,8 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo
|
|||||||
s.log.Debug("looking up token", "unhashed", unhashedToken, "hashed", hashedToken)
|
s.log.Debug("looking up token", "unhashed", unhashedToken, "hashed", hashedToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenMaxLifetime := time.Duration(s.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour
|
|
||||||
tokenMaxInactiveLifetime := time.Duration(s.Cfg.LoginMaxInactiveLifetimeDays) * 24 * time.Hour
|
|
||||||
createdAfter := getTime().Add(-tokenMaxLifetime).Unix()
|
|
||||||
rotatedAfter := getTime().Add(-tokenMaxInactiveLifetime).Unix()
|
|
||||||
|
|
||||||
var model userAuthToken
|
var model userAuthToken
|
||||||
exists, err := s.SQLStore.NewSession().Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?", hashedToken, hashedToken, createdAfter, rotatedAfter).Get(&model)
|
exists, err := s.SQLStore.NewSession().Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?", hashedToken, hashedToken, s.createdAfterParam(), s.rotatedAfterParam()).Get(&model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -219,6 +221,16 @@ func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UserAuthTokenService) createdAfterParam() int64 {
|
||||||
|
tokenMaxLifetime := time.Duration(s.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour
|
||||||
|
return getTime().Add(-tokenMaxLifetime).Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserAuthTokenService) rotatedAfterParam() int64 {
|
||||||
|
tokenMaxInactiveLifetime := time.Duration(s.Cfg.LoginMaxInactiveLifetimeDays) * 24 * time.Hour
|
||||||
|
return getTime().Add(-tokenMaxInactiveLifetime).Unix()
|
||||||
|
}
|
||||||
|
|
||||||
func hashToken(token string) string {
|
func hashToken(token string) string {
|
||||||
hashBytes := sha256.Sum256([]byte(token + setting.SecretKey))
|
hashBytes := sha256.Sum256([]byte(token + setting.SecretKey))
|
||||||
return hex.EncodeToString(hashBytes[:])
|
return hex.EncodeToString(hashBytes[:])
|
||||||
|
@ -31,6 +31,12 @@ func TestUserAuthToken(t *testing.T) {
|
|||||||
So(userToken, ShouldNotBeNil)
|
So(userToken, ShouldNotBeNil)
|
||||||
So(userToken.AuthTokenSeen, ShouldBeFalse)
|
So(userToken.AuthTokenSeen, ShouldBeFalse)
|
||||||
|
|
||||||
|
Convey("Can count active tokens", func() {
|
||||||
|
count, err := userAuthTokenService.ActiveTokenCount()
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(count, ShouldEqual, 1)
|
||||||
|
})
|
||||||
|
|
||||||
Convey("When lookup unhashed token should return user auth token", func() {
|
Convey("When lookup unhashed token should return user auth token", func() {
|
||||||
userToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
|
userToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
@ -114,6 +120,12 @@ func TestUserAuthToken(t *testing.T) {
|
|||||||
notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
|
notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
|
||||||
So(err, ShouldEqual, models.ErrUserTokenNotFound)
|
So(err, ShouldEqual, models.ErrUserTokenNotFound)
|
||||||
So(notGood, ShouldBeNil)
|
So(notGood, ShouldBeNil)
|
||||||
|
|
||||||
|
Convey("should not find active token when expired", func() {
|
||||||
|
count, err := userAuthTokenService.ActiveTokenCount()
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(count, ShouldEqual, 0)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("when rotated_at is 5 days ago and created_at is 29 days and 23:59:59 ago should not find token", func() {
|
Convey("when rotated_at is 5 days ago and created_at is 29 days and 23:59:59 ago should not find token", func() {
|
||||||
|
@ -3,11 +3,23 @@ package quota
|
|||||||
import (
|
import (
|
||||||
"github.com/grafana/grafana/pkg/bus"
|
"github.com/grafana/grafana/pkg/bus"
|
||||||
m "github.com/grafana/grafana/pkg/models"
|
m "github.com/grafana/grafana/pkg/models"
|
||||||
"github.com/grafana/grafana/pkg/services/session"
|
"github.com/grafana/grafana/pkg/registry"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
)
|
)
|
||||||
|
|
||||||
func QuotaReached(c *m.ReqContext, target string) (bool, error) {
|
func init() {
|
||||||
|
registry.RegisterService(&QuotaService{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type QuotaService struct {
|
||||||
|
AuthTokenService m.UserTokenService `inject:""`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qs *QuotaService) Init() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (qs *QuotaService) QuotaReached(c *m.ReqContext, target string) (bool, error) {
|
||||||
if !setting.Quota.Enabled {
|
if !setting.Quota.Enabled {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
@ -30,7 +42,12 @@ func QuotaReached(c *m.ReqContext, target string) (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
if target == "session" {
|
if target == "session" {
|
||||||
usedSessions := session.GetSessionCount()
|
|
||||||
|
usedSessions, err := qs.AuthTokenService.ActiveTokenCount()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
if int64(usedSessions) > scope.DefaultLimit {
|
if int64(usedSessions) > scope.DefaultLimit {
|
||||||
c.Logger.Debug("Sessions limit reached", "active", usedSessions, "limit", scope.DefaultLimit)
|
c.Logger.Debug("Sessions limit reached", "active", usedSessions, "limit", scope.DefaultLimit)
|
||||||
return true, nil
|
return true, nil
|
||||||
|
Loading…
Reference in New Issue
Block a user