From fef1e1d5bc8c1ea7d6deb6a8ef808988cdec61aa Mon Sep 17 00:00:00 2001 From: Karl Persson Date: Fri, 18 Nov 2022 09:56:06 +0100 Subject: [PATCH] Auth: Refactor auth package (#58920) * Auth: move interface to its own file * Auth: move to test package * Auth: move quota consts to auth file * Auth: move service to impl package * Auth: move interfaces and related models to auth package * Auth: Create sub package and type alias to avoid circular dependency --- pkg/api/admin_users.go | 7 +-- pkg/api/admin_users_test.go | 13 ++--- pkg/api/common_test.go | 6 +-- pkg/api/http_server.go | 5 +- pkg/api/ldap_debug_test.go | 4 +- pkg/api/login.go | 7 +-- pkg/api/login_test.go | 12 ++--- pkg/api/user_token.go | 13 ++--- pkg/api/user_token_test.go | 41 +++++++------- pkg/cmd/grafana-cli/runner/wire.go | 5 +- pkg/cmd/grafana-cli/runner/wireexts_oss.go | 7 +-- pkg/middleware/auth.go | 5 +- pkg/middleware/middleware_test.go | 39 +++++++------- pkg/middleware/org_redirect_test.go | 10 ++-- pkg/middleware/quota_test.go | 6 +-- pkg/middleware/recovery_test.go | 4 +- pkg/middleware/testing.go | 6 +-- pkg/models/context.go | 3 +- pkg/models/usertoken/user_token.go | 26 +++++++++ .../backgroundsvcs/background_services.go | 4 +- pkg/server/wire.go | 5 +- pkg/server/wireexts_oss.go | 7 +-- pkg/services/accesscontrol/middleware.go | 6 +-- .../user_token.go => services/auth/auth.go} | 40 ++++++-------- .../auth/{ => authimpl}/auth_token.go | 54 +++++++++---------- .../auth/{ => authimpl}/auth_token_test.go | 26 ++++----- pkg/services/auth/{ => authimpl}/model.go | 16 ++---- .../auth/{ => authimpl}/token_cleanup.go | 2 +- .../auth/{ => authimpl}/token_cleanup_test.go | 2 +- pkg/services/auth/{ => authtest}/testing.go | 47 ++++++++-------- .../contexthandler/auth_proxy_test.go | 4 +- pkg/services/contexthandler/contexthandler.go | 10 ++-- .../contexthandler/contexthandler_test.go | 18 ++++--- pkg/services/ngalert/api/util_test.go | 3 +- pkg/services/quota/quotaimpl/quota_test.go | 3 +- 35 files changed, 245 insertions(+), 221 deletions(-) create mode 100644 pkg/models/usertoken/user_token.go rename pkg/{models/user_token.go => services/auth/auth.go} (74%) rename pkg/services/auth/{ => authimpl}/auth_token.go (91%) rename pkg/services/auth/{ => authimpl}/auth_token_test.go (96%) rename pkg/services/auth/{ => authimpl}/model.go (77%) rename pkg/services/auth/{ => authimpl}/token_cleanup.go (99%) rename pkg/services/auth/{ => authimpl}/token_cleanup_test.go (99%) rename pkg/services/auth/{ => authtest}/testing.go (77%) diff --git a/pkg/api/admin_users.go b/pkg/api/admin_users.go index e0b244bfab4..bf284cddb33 100644 --- a/pkg/api/admin_users.go +++ b/pkg/api/admin_users.go @@ -14,6 +14,7 @@ import ( "github.com/grafana/grafana/pkg/infra/metrics" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/accesscontrol" + "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/web" @@ -416,7 +417,7 @@ func (hs *HTTPServer) AdminGetUserAuthTokens(c *models.ReqContext) response.Resp // 404: notFoundError // 500: internalServerError func (hs *HTTPServer) AdminRevokeUserAuthToken(c *models.ReqContext) response.Response { - cmd := models.RevokeAuthTokenCmd{} + cmd := auth.RevokeAuthTokenCmd{} if err := web.Bind(c.Req, &cmd); err != nil { return response.Error(http.StatusBadRequest, "bad request data", err) } @@ -476,7 +477,7 @@ type AdminLogoutUserParams struct { type AdminRevokeUserAuthTokenParams struct { // in:body // required:true - Body models.RevokeAuthTokenCmd `json:"body"` + Body auth.RevokeAuthTokenCmd `json:"body"` // in:path // required:true UserID int64 `json:"user_id"` @@ -508,5 +509,5 @@ type AdminCreateUserResponseResponse struct { // swagger:response adminGetUserAuthTokensResponse type AdminGetUserAuthTokensResponse struct { // in:body - Body []*models.UserToken `json:"body"` + Body []*auth.UserToken `json:"body"` } diff --git a/pkg/api/admin_users_test.go b/pkg/api/admin_users_test.go index 41f79831f74..3de90077df9 100644 --- a/pkg/api/admin_users_test.go +++ b/pkg/api/admin_users_test.go @@ -13,6 +13,7 @@ import ( "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/login/loginservice" "github.com/grafana/grafana/pkg/services/login/logintest" "github.com/grafana/grafana/pkg/services/org" @@ -65,7 +66,7 @@ func TestAdminAPIEndpoint(t *testing.T) { }) t.Run("When a server admin attempts to revoke an auth token for a non-existing user", func(t *testing.T) { - cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2} + cmd := auth.RevokeAuthTokenCmd{AuthTokenId: 2} mockUser := usertest.NewUserServiceFake() mockUser.ExpectedError = user.ErrUserNotFound adminRevokeUserAuthTokenScenario(t, "Should return not found when calling POST on", @@ -263,7 +264,7 @@ func putAdminScenario(t *testing.T, desc string, url string, routePattern string func adminLogoutUserScenario(t *testing.T, desc string, url string, routePattern string, fn scenarioFunc, userService *usertest.FakeUserService) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { hs := HTTPServer{ - AuthTokenService: auth.NewFakeUserAuthTokenService(), + AuthTokenService: authtest.NewFakeUserAuthTokenService(), userService: userService, } @@ -285,9 +286,9 @@ func adminLogoutUserScenario(t *testing.T, desc string, url string, routePattern }) } -func adminRevokeUserAuthTokenScenario(t *testing.T, desc string, url string, routePattern string, cmd models.RevokeAuthTokenCmd, fn scenarioFunc, userService user.Service) { +func adminRevokeUserAuthTokenScenario(t *testing.T, desc string, url string, routePattern string, cmd auth.RevokeAuthTokenCmd, fn scenarioFunc, userService user.Service) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { - fakeAuthTokenService := auth.NewFakeUserAuthTokenService() + fakeAuthTokenService := authtest.NewFakeUserAuthTokenService() hs := HTTPServer{ AuthTokenService: fakeAuthTokenService, @@ -315,7 +316,7 @@ func adminRevokeUserAuthTokenScenario(t *testing.T, desc string, url string, rou func adminGetUserAuthTokensScenario(t *testing.T, desc string, url string, routePattern string, fn scenarioFunc, userService *usertest.FakeUserService) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { - fakeAuthTokenService := auth.NewFakeUserAuthTokenService() + fakeAuthTokenService := authtest.NewFakeUserAuthTokenService() hs := HTTPServer{ AuthTokenService: fakeAuthTokenService, @@ -341,7 +342,7 @@ func adminGetUserAuthTokensScenario(t *testing.T, desc string, url string, route func adminDisableUserScenario(t *testing.T, desc string, action string, url string, routePattern string, fn scenarioFunc) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { - fakeAuthTokenService := auth.NewFakeUserAuthTokenService() + fakeAuthTokenService := authtest.NewFakeUserAuthTokenService() authInfoService := &logintest.AuthInfoServiceFake{} diff --git a/pkg/api/common_test.go b/pkg/api/common_test.go index e6f0c159401..39be06adbc5 100644 --- a/pkg/api/common_test.go +++ b/pkg/api/common_test.go @@ -28,7 +28,7 @@ import ( accesscontrolmock "github.com/grafana/grafana/pkg/services/accesscontrol/mock" "github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol" "github.com/grafana/grafana/pkg/services/annotations/annotationstest" - "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" "github.com/grafana/grafana/pkg/services/contexthandler/ctxkey" @@ -181,7 +181,7 @@ type scenarioContext struct { defaultHandler web.Handler req *http.Request url string - userAuthTokenService *auth.FakeUserAuthTokenService + userAuthTokenService *authtest.FakeUserAuthTokenService sqlStore sqlstore.Store authInfoService *logintest.AuthInfoServiceFake dashboardVersionService dashver.Service @@ -207,7 +207,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{ Name: "database", } - userAuthTokenSvc := auth.NewFakeUserAuthTokenService() + userAuthTokenSvc := authtest.NewFakeUserAuthTokenService() renderSvc := &fakeRenderService{} authJWTSvc := models.NewFakeJWTService() tracer := tracing.InitializeTracerForTest() diff --git a/pkg/api/http_server.go b/pkg/api/http_server.go index f2cfce24fc1..7320b71b2cd 100644 --- a/pkg/api/http_server.go +++ b/pkg/api/http_server.go @@ -15,6 +15,7 @@ import ( "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/middleware/csrf" + "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/folder" "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/services/querylibrary" @@ -120,7 +121,7 @@ type HTTPServer struct { navTreeService navtree.Service CacheService *localcache.CacheService DataSourceCache datasources.CacheService - AuthTokenService models.UserTokenService + AuthTokenService auth.UserTokenService QuotaService quota.Service RemoteCacheService *remotecache.RemoteCache ProvisioningService provisioning.ProvisioningService @@ -220,7 +221,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi pluginRequestValidator models.PluginRequestValidator, pluginStaticRouteResolver plugins.StaticRouteResolver, pluginDashboardService plugindashboards.Service, pluginStore plugins.Store, pluginClient plugins.Client, pluginErrorResolver plugins.ErrorResolver, pluginInstaller plugins.Installer, settingsProvider setting.Provider, - dataSourceCache datasources.CacheService, userTokenService models.UserTokenService, + dataSourceCache datasources.CacheService, userTokenService auth.UserTokenService, cleanUpService *cleanup.CleanUpService, shortURLService shorturls.Service, queryHistoryService queryhistory.Service, correlationsService correlations.Service, thumbService thumbs.Service, remoteCache *remotecache.RemoteCache, provisioningService provisioning.ProvisioningService, loginService login.Service, authenticator loginpkg.Authenticator, accessControl accesscontrol.AccessControl, diff --git a/pkg/api/ldap_debug_test.go b/pkg/api/ldap_debug_test.go index 5a93f53e3ba..95b92e4e06f 100644 --- a/pkg/api/ldap_debug_test.go +++ b/pkg/api/ldap_debug_test.go @@ -15,7 +15,7 @@ import ( "github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/accesscontrol" - "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/ldap" "github.com/grafana/grafana/pkg/services/login/loginservice" "github.com/grafana/grafana/pkg/services/login/logintest" @@ -379,7 +379,7 @@ func postSyncUserWithLDAPContext(t *testing.T, requestURL string, preHook func(* hs := &HTTPServer{ Cfg: sc.cfg, - AuthTokenService: auth.NewFakeUserAuthTokenService(), + AuthTokenService: authtest.NewFakeUserAuthTokenService(), Login: loginservice.LoginServiceMock{}, authInfoService: sc.authInfoService, userService: userService, diff --git a/pkg/api/login.go b/pkg/api/login.go index 20baa31bde5..9b23ad05b6e 100644 --- a/pkg/api/login.go +++ b/pkg/api/login.go @@ -16,6 +16,7 @@ import ( "github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/middleware/cookies" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/auth" loginService "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/secrets" "github.com/grafana/grafana/pkg/services/user" @@ -227,7 +228,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext) response.Response { err = hs.loginUserWithUser(usr, c) if err != nil { - var createTokenErr *models.CreateTokenErr + var createTokenErr *auth.CreateTokenErr if errors.As(err, &createTokenErr) { resp = response.Error(createTokenErr.StatusCode, createTokenErr.ExternalErr, createTokenErr.InternalErr) } else { @@ -299,7 +300,7 @@ func (hs *HTTPServer) Logout(c *models.ReqContext) { } err := hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken, false) - if err != nil && !errors.Is(err, models.ErrUserTokenNotFound) { + if err != nil && !errors.Is(err, auth.ErrUserTokenNotFound) { hs.log.Error("failed to revoke auth token", "error", err) } @@ -370,7 +371,7 @@ func (hs *HTTPServer) samlSingleLogoutEnabled() bool { } func getLoginExternalError(err error) string { - var createTokenErr *models.CreateTokenErr + var createTokenErr *auth.CreateTokenErr if errors.As(err, &createTokenErr) { return createTokenErr.ExternalErr } diff --git a/pkg/api/login_test.go b/pkg/api/login_test.go index 27452b9bd09..ea30addb451 100644 --- a/pkg/api/login_test.go +++ b/pkg/api/login_test.go @@ -12,8 +12,6 @@ import ( "strings" "testing" - loginservice "github.com/grafana/grafana/pkg/services/login" - "github.com/grafana/grafana/pkg/services/navtree" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,9 +23,11 @@ import ( "github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/hooks" "github.com/grafana/grafana/pkg/services/licensing" + loginservice "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/navtree" "github.com/grafana/grafana/pkg/services/secrets" "github.com/grafana/grafana/pkg/services/secrets/fakes" secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager" @@ -323,7 +323,7 @@ func TestLoginPostRedirect(t *testing.T) { Cfg: setting.NewCfg(), HooksService: &hooks.HooksService{}, License: &licensing.OSSLicensingService{}, - AuthTokenService: auth.NewFakeUserAuthTokenService(), + AuthTokenService: authtest.NewFakeUserAuthTokenService(), } hs.Cfg.CookieSecure = true @@ -564,7 +564,7 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte Cfg: sc.cfg, SettingsProvider: &setting.OSSImpl{Cfg: sc.cfg}, License: &licensing.OSSLicensingService{}, - AuthTokenService: auth.NewFakeUserAuthTokenService(), + AuthTokenService: authtest.NewFakeUserAuthTokenService(), log: log.New("hello"), SocialService: &mockSocialService{}, } @@ -602,7 +602,7 @@ func TestLoginPostRunLokingHook(t *testing.T) { log: log.New("test"), Cfg: setting.NewCfg(), License: &licensing.OSSLicensingService{}, - AuthTokenService: auth.NewFakeUserAuthTokenService(), + AuthTokenService: authtest.NewFakeUserAuthTokenService(), HooksService: hookService, } diff --git a/pkg/api/user_token.go b/pkg/api/user_token.go index 772cb0ff183..3e12fca2d2e 100644 --- a/pkg/api/user_token.go +++ b/pkg/api/user_token.go @@ -9,6 +9,7 @@ import ( "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/response" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/web" @@ -43,7 +44,7 @@ func (hs *HTTPServer) GetUserAuthTokens(c *models.ReqContext) response.Response // 403: forbiddenError // 500: internalServerError func (hs *HTTPServer) RevokeUserAuthToken(c *models.ReqContext) response.Response { - cmd := models.RevokeAuthTokenCmd{} + cmd := auth.RevokeAuthTokenCmd{} if err := web.Bind(c.Req, &cmd); err != nil { return response.Error(http.StatusBadRequest, "bad request data", err) } @@ -143,7 +144,7 @@ func (hs *HTTPServer) getUserAuthTokensInternal(c *models.ReqContext, userID int return response.JSON(http.StatusOK, result) } -func (hs *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, userID int64, cmd models.RevokeAuthTokenCmd) response.Response { +func (hs *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, userID int64, cmd auth.RevokeAuthTokenCmd) response.Response { userQuery := user.GetUserByIDQuery{ID: userID} _, err := hs.userService.GetByID(c.Req.Context(), &userQuery) if err != nil { @@ -155,7 +156,7 @@ func (hs *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, userID i token, err := hs.AuthTokenService.GetUserToken(c.Req.Context(), userID, cmd.AuthTokenId) if err != nil { - if errors.Is(err, models.ErrUserTokenNotFound) { + if errors.Is(err, auth.ErrUserTokenNotFound) { return response.Error(404, "User auth token not found", err) } return response.Error(500, "Failed to get user auth token", err) @@ -167,7 +168,7 @@ func (hs *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, userID i err = hs.AuthTokenService.RevokeToken(c.Req.Context(), token, false) if err != nil { - if errors.Is(err, models.ErrUserTokenNotFound) { + if errors.Is(err, auth.ErrUserTokenNotFound) { return response.Error(404, "User auth token not found", err) } return response.Error(500, "Failed to revoke user auth token", err) @@ -182,11 +183,11 @@ func (hs *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, userID i type RevokeUserAuthTokenParams struct { // in:body // required:true - Body models.RevokeAuthTokenCmd `json:"body"` + Body auth.RevokeAuthTokenCmd `json:"body"` } // swagger:response getUserAuthTokensResponse type GetUserAuthTokensResponse struct { // in:body - Body []*models.UserToken `json:"body"` + Body []*auth.UserToken `json:"body"` } diff --git a/pkg/api/user_token_test.go b/pkg/api/user_token_test.go index 2dd7d7e7fbd..093a27011b9 100644 --- a/pkg/api/user_token_test.go +++ b/pkg/api/user_token_test.go @@ -12,6 +12,7 @@ import ( "github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user/usertest" @@ -20,7 +21,7 @@ import ( func TestUserTokenAPIEndpoint(t *testing.T) { userMock := usertest.NewUserServiceFake() t.Run("When current user attempts to revoke an auth token for a non-existing user", func(t *testing.T) { - cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2} + cmd := auth.RevokeAuthTokenCmd{AuthTokenId: 2} userMock.ExpectedError = user.ErrUserNotFound revokeUserAuthTokenScenario(t, "Should return not found when calling POST on", "/api/user/revoke-auth-token", "/api/user/revoke-auth-token", cmd, 200, func(sc *scenarioContext) { @@ -59,15 +60,15 @@ func TestUserTokenAPIEndpoint(t *testing.T) { }) t.Run("When revoke an auth token for a user", func(t *testing.T) { - cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2} - token := &models.UserToken{Id: 1} + cmd := auth.RevokeAuthTokenCmd{AuthTokenId: 2} + token := &auth.UserToken{Id: 1} mockUser := &usertest.FakeUserService{ ExpectedUser: &user.User{ID: 200}, } revokeUserAuthTokenInternalScenario(t, "Should be successful", cmd, 200, token, func(sc *scenarioContext) { - sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) { - return &models.UserToken{Id: 2}, nil + sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) { + return &auth.UserToken{Id: 2}, nil } sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() assert.Equal(t, 200, sc.resp.Code) @@ -75,11 +76,11 @@ func TestUserTokenAPIEndpoint(t *testing.T) { }) t.Run("When revoke the active auth token used by himself", func(t *testing.T) { - cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2} - token := &models.UserToken{Id: 2} + cmd := auth.RevokeAuthTokenCmd{AuthTokenId: 2} + token := &auth.UserToken{Id: 2} mockUser := usertest.NewUserServiceFake() revokeUserAuthTokenInternalScenario(t, "Should not be successful", cmd, testUserID, token, func(sc *scenarioContext) { - sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) { + sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) { return token, nil } sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() @@ -88,10 +89,10 @@ func TestUserTokenAPIEndpoint(t *testing.T) { }) t.Run("When gets auth tokens for a user", func(t *testing.T) { - currentToken := &models.UserToken{Id: 1} + currentToken := &auth.UserToken{Id: 1} mockUser := usertest.NewUserServiceFake() getUserAuthTokensInternalScenario(t, "Should be successful", currentToken, func(sc *scenarioContext) { - tokens := []*models.UserToken{ + tokens := []*auth.UserToken{ { Id: 1, ClientIp: "127.0.0.1", @@ -107,7 +108,7 @@ func TestUserTokenAPIEndpoint(t *testing.T) { SeenAt: 0, }, } - sc.userAuthTokenService.GetUserTokensProvider = func(ctx context.Context, userId int64) ([]*models.UserToken, error) { + sc.userAuthTokenService.GetUserTokensProvider = func(ctx context.Context, userId int64) ([]*auth.UserToken, error) { return tokens, nil } sc.fakeReqWithParams("GET", sc.url, map[string]string{}).exec() @@ -145,10 +146,10 @@ func TestUserTokenAPIEndpoint(t *testing.T) { }) } -func revokeUserAuthTokenScenario(t *testing.T, desc string, url string, routePattern string, cmd models.RevokeAuthTokenCmd, +func revokeUserAuthTokenScenario(t *testing.T, desc string, url string, routePattern string, cmd auth.RevokeAuthTokenCmd, userId int64, fn scenarioFunc, userService user.Service) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { - fakeAuthTokenService := auth.NewFakeUserAuthTokenService() + fakeAuthTokenService := authtest.NewFakeUserAuthTokenService() hs := HTTPServer{ AuthTokenService: fakeAuthTokenService, @@ -175,7 +176,7 @@ func revokeUserAuthTokenScenario(t *testing.T, desc string, url string, routePat func getUserAuthTokensScenario(t *testing.T, desc string, url string, routePattern string, userId int64, fn scenarioFunc, userService user.Service) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { - fakeAuthTokenService := auth.NewFakeUserAuthTokenService() + fakeAuthTokenService := authtest.NewFakeUserAuthTokenService() hs := HTTPServer{ AuthTokenService: fakeAuthTokenService, @@ -202,7 +203,7 @@ func getUserAuthTokensScenario(t *testing.T, desc string, url string, routePatte func logoutUserFromAllDevicesInternalScenario(t *testing.T, desc string, userId int64, fn scenarioFunc, userService user.Service) { t.Run(desc, func(t *testing.T) { hs := HTTPServer{ - AuthTokenService: auth.NewFakeUserAuthTokenService(), + AuthTokenService: authtest.NewFakeUserAuthTokenService(), userService: userService, } @@ -222,10 +223,10 @@ func logoutUserFromAllDevicesInternalScenario(t *testing.T, desc string, userId }) } -func revokeUserAuthTokenInternalScenario(t *testing.T, desc string, cmd models.RevokeAuthTokenCmd, userId int64, - token *models.UserToken, fn scenarioFunc, userService user.Service) { +func revokeUserAuthTokenInternalScenario(t *testing.T, desc string, cmd auth.RevokeAuthTokenCmd, userId int64, + token *auth.UserToken, fn scenarioFunc, userService user.Service) { t.Run(desc, func(t *testing.T) { - fakeAuthTokenService := auth.NewFakeUserAuthTokenService() + fakeAuthTokenService := authtest.NewFakeUserAuthTokenService() hs := HTTPServer{ AuthTokenService: fakeAuthTokenService, @@ -248,9 +249,9 @@ func revokeUserAuthTokenInternalScenario(t *testing.T, desc string, cmd models.R }) } -func getUserAuthTokensInternalScenario(t *testing.T, desc string, token *models.UserToken, fn scenarioFunc, userService user.Service) { +func getUserAuthTokensInternalScenario(t *testing.T, desc string, token *auth.UserToken, fn scenarioFunc, userService user.Service) { t.Run(desc, func(t *testing.T) { - fakeAuthTokenService := auth.NewFakeUserAuthTokenService() + fakeAuthTokenService := authtest.NewFakeUserAuthTokenService() hs := HTTPServer{ AuthTokenService: fakeAuthTokenService, diff --git a/pkg/cmd/grafana-cli/runner/wire.go b/pkg/cmd/grafana-cli/runner/wire.go index 82efb1ce69e..8ee0af3f6a5 100644 --- a/pkg/cmd/grafana-cli/runner/wire.go +++ b/pkg/cmd/grafana-cli/runner/wire.go @@ -7,6 +7,7 @@ import ( "context" "github.com/google/wire" + "github.com/grafana/grafana/pkg/services/auth/authimpl" "github.com/grafana/grafana/pkg/tsdb/parca" "github.com/grafana/grafana/pkg/tsdb/phlare" @@ -253,8 +254,8 @@ var wireSet = wire.NewSet( influxdb.ProvideService, wire.Bind(new(social.Service), new(*social.SocialService)), oauthtoken.ProvideService, - auth.ProvideActiveAuthTokenService, - wire.Bind(new(auth.ActiveTokenService), new(*auth.ActiveAuthTokenService)), + authimpl.ProvideActiveAuthTokenService, + wire.Bind(new(auth.ActiveTokenService), new(*authimpl.ActiveAuthTokenService)), wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)), tempo.ProvideService, loki.ProvideService, diff --git a/pkg/cmd/grafana-cli/runner/wireexts_oss.go b/pkg/cmd/grafana-cli/runner/wireexts_oss.go index 407ce3c1945..dfc36b3cd26 100644 --- a/pkg/cmd/grafana-cli/runner/wireexts_oss.go +++ b/pkg/cmd/grafana-cli/runner/wireexts_oss.go @@ -17,6 +17,7 @@ import ( "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" "github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol" "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authimpl" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/datasources/permissions" datasourceservice "github.com/grafana/grafana/pkg/services/datasources/service" @@ -48,9 +49,9 @@ var wireExtsSet = wire.NewSet( wire.Bind(new(setting.Provider), new(*setting.OSSImpl)), osskmsproviders.ProvideService, wire.Bind(new(kmsproviders.Service), new(osskmsproviders.Service)), - auth.ProvideUserAuthTokenService, - wire.Bind(new(models.UserTokenService), new(*auth.UserAuthTokenService)), - wire.Bind(new(models.UserTokenBackgroundService), new(*auth.UserAuthTokenService)), + authimpl.ProvideUserAuthTokenService, + wire.Bind(new(auth.UserTokenService), new(*authimpl.UserAuthTokenService)), + wire.Bind(new(auth.UserTokenBackgroundService), new(*authimpl.UserAuthTokenService)), acimpl.ProvideService, wire.Bind(new(accesscontrol.Service), new(*acimpl.Service)), wire.Bind(new(accesscontrol.RoleRegistry), new(*acimpl.Service)), diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go index 7ab6274bc4c..b7e00465586 100644 --- a/pkg/middleware/auth.go +++ b/pkg/middleware/auth.go @@ -10,6 +10,7 @@ import ( "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/middleware/cookies" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/dashboards" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/team" @@ -42,7 +43,7 @@ func notAuthorized(c *models.ReqContext) { c.Redirect(setting.AppSubUrl + "/login") } -func tokenRevoked(c *models.ReqContext, err *models.TokenRevokedError) { +func tokenRevoked(c *models.ReqContext, err *auth.TokenRevokedError) { if c.IsApiRequest() { c.JSON(401, map[string]interface{}{ "message": "Token revoked", @@ -117,7 +118,7 @@ func Auth(options *AuthOptions) web.Handler { requireLogin := !c.AllowAnonymous || forceLogin || options.ReqNoAnonynmous if !c.IsSignedIn && options.ReqSignedIn && requireLogin { - var revokedErr *models.TokenRevokedError + var revokedErr *auth.TokenRevokedError if errors.As(c.LookupTokenErr, &revokedErr) { tokenRevoked(c, revokedErr) return diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index fec44973196..94efe1c7ee4 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -28,6 +28,7 @@ import ( "github.com/grafana/grafana/pkg/services/apikey" "github.com/grafana/grafana/pkg/services/apikey/apikeytest" "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" "github.com/grafana/grafana/pkg/services/featuremgmt" @@ -264,8 +265,8 @@ func TestMiddlewareContext(t *testing.T) { sc.withTokenSessionCookie("token") sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID} - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return &auth.UserToken{ UserId: userID, UnhashedToken: unhashedToken, }, nil @@ -288,14 +289,14 @@ func TestMiddlewareContext(t *testing.T) { sc.withTokenSessionCookie("token") sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID} - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return &auth.UserToken{ UserId: userID, UnhashedToken: "", }, nil } - sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *models.UserToken, + sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) { userToken.UnhashedToken = "rotated" return true, nil @@ -371,8 +372,8 @@ func TestMiddlewareContext(t *testing.T) { middlewareScenario(t, "Invalid/expired auth token in cookie", func(t *testing.T, sc *scenarioContext) { sc.withTokenSessionCookie("token") - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return nil, models.ErrUserTokenNotFound + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return nil, auth.ErrUserTokenNotFound } sc.fakeReq("GET", "/").exec() @@ -391,8 +392,8 @@ func TestMiddlewareContext(t *testing.T) { sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID} sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(11 * time.Second)} - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return &auth.UserToken{ UserId: userID, UnhashedToken: unhashedToken, }, nil @@ -424,8 +425,8 @@ func TestMiddlewareContext(t *testing.T) { OAuthRefreshToken: "refresh_token"} sc.oauthTokenService.ExpectedErrors = map[string]error{"TryTokenRefresh": errors.New("error")} - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return &auth.UserToken{ UserId: userID, UnhashedToken: unhashedToken, }, nil @@ -454,8 +455,8 @@ func TestMiddlewareContext(t *testing.T) { sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID} sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(-5 * time.Second), OAuthRefreshToken: "refreshtoken"} - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return &auth.UserToken{ UserId: userID, UnhashedToken: unhashedToken, }, nil @@ -481,8 +482,8 @@ func TestMiddlewareContext(t *testing.T) { sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID} sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID} - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return &auth.UserToken{ UserId: userID, UnhashedToken: unhashedToken, }, nil @@ -819,14 +820,14 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func( sc.userService = usertest.NewUserServiceFake() sc.orgService = orgtest.NewOrgServiceFake() sc.apiKeyService = &apikeytest.Service{} - sc.oauthTokenService = &auth.FakeOAuthTokenService{} + sc.oauthTokenService = &authtest.FakeOAuthTokenService{} ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService, sc.oauthTokenService) sc.sqlStore = ctxHdlr.SQLStore sc.contextHandler = ctxHdlr sc.m.Use(ctxHdlr.Middleware) sc.m.Use(OrgRedirect(sc.cfg, sc.userService)) - sc.userAuthTokenService = ctxHdlr.AuthTokenService.(*auth.FakeUserAuthTokenService) + sc.userAuthTokenService = ctxHdlr.AuthTokenService.(*authtest.FakeUserAuthTokenService) sc.jwtAuthService = ctxHdlr.JWTAuthService.(*models.FakeJWTService) sc.remoteCacheService = ctxHdlr.RemoteCache @@ -856,7 +857,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func( func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *dbtest.FakeDB, loginService *loginservice.LoginServiceMock, apiKeyService *apikeytest.Service, userService *usertest.FakeUserService, orgService *orgtest.FakeOrgService, - oauthTokenService *auth.FakeOAuthTokenService, + oauthTokenService *authtest.FakeOAuthTokenService, ) *contexthandler.ContextHandler { t.Helper() @@ -868,7 +869,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *dbtest.Fake } remoteCacheSvc := remotecache.NewFakeStore(t) - userAuthTokenSvc := auth.NewFakeUserAuthTokenService() + userAuthTokenSvc := authtest.NewFakeUserAuthTokenService() renderSvc := &fakeRenderService{} authJWTSvc := models.NewFakeJWTService() tracer := tracing.InitializeTracerForTest() diff --git a/pkg/middleware/org_redirect_test.go b/pkg/middleware/org_redirect_test.go index 8d1460983df..810f04e56a1 100644 --- a/pkg/middleware/org_redirect_test.go +++ b/pkg/middleware/org_redirect_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/user" ) @@ -48,8 +48,8 @@ func TestOrgRedirectMiddleware(t *testing.T) { middlewareScenario(t, tc.desc, func(t *testing.T, sc *scenarioContext) { sc.withTokenSessionCookie("token") sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12} - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return &auth.UserToken{ UserId: 0, UnhashedToken: "", }, nil @@ -68,8 +68,8 @@ func TestOrgRedirectMiddleware(t *testing.T) { sc.userService.ExpectedSetUsingOrgError = fmt.Errorf("") sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12} - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return &auth.UserToken{ UserId: 12, UnhashedToken: "", }, nil diff --git a/pkg/middleware/quota_test.go b/pkg/middleware/quota_test.go index 446b7842933..156a626b10e 100644 --- a/pkg/middleware/quota_test.go +++ b/pkg/middleware/quota_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/quota/quotatest" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" @@ -55,8 +55,8 @@ func TestMiddlewareQuota(t *testing.T) { setUp := func(sc *scenarioContext) { sc.withTokenSessionCookie("token") sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: 12} - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return &auth.UserToken{ UserId: 12, UnhashedToken: "", }, nil diff --git a/pkg/middleware/recovery_test.go b/pkg/middleware/recovery_test.go index 866eeaa490f..1a8fe8537c5 100644 --- a/pkg/middleware/recovery_test.go +++ b/pkg/middleware/recovery_test.go @@ -10,7 +10,7 @@ import ( "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" ) @@ -65,7 +65,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) { sc.m.Use(AddDefaultResponseHeaders(cfg)) sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]")) - sc.userAuthTokenService = auth.NewFakeUserAuthTokenService() + sc.userAuthTokenService = authtest.NewFakeUserAuthTokenService() sc.remoteCacheService = remotecache.NewFakeStore(t) contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil, nil) diff --git a/pkg/middleware/testing.go b/pkg/middleware/testing.go index a091f9118fe..7e858cfd845 100644 --- a/pkg/middleware/testing.go +++ b/pkg/middleware/testing.go @@ -13,7 +13,7 @@ import ( "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/apikey/apikeytest" - "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/contexthandler/ctxkey" "github.com/grafana/grafana/pkg/services/login/loginservice" @@ -36,7 +36,7 @@ type scenarioContext struct { handlerFunc handlerFunc defaultHandler web.Handler url string - userAuthTokenService *auth.FakeUserAuthTokenService + userAuthTokenService *authtest.FakeUserAuthTokenService jwtAuthService *models.FakeJWTService remoteCacheService *remotecache.RemoteCache cfg *setting.Cfg @@ -46,7 +46,7 @@ type scenarioContext struct { loginService *loginservice.LoginServiceMock apiKeyService *apikeytest.Service userService *usertest.FakeUserService - oauthTokenService *auth.FakeOAuthTokenService + oauthTokenService *authtest.FakeOAuthTokenService orgService *orgtest.FakeOrgService req *http.Request diff --git a/pkg/models/context.go b/pkg/models/context.go index dd07b7236b8..432dee40e99 100644 --- a/pkg/models/context.go +++ b/pkg/models/context.go @@ -5,6 +5,7 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/tracing" + "github.com/grafana/grafana/pkg/models/usertoken" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" @@ -15,7 +16,7 @@ import ( type ReqContext struct { *web.Context *user.SignedInUser - UserToken *UserToken + UserToken *usertoken.UserToken IsSignedIn bool IsRenderCall bool diff --git a/pkg/models/usertoken/user_token.go b/pkg/models/usertoken/user_token.go new file mode 100644 index 00000000000..beb2ac1355f --- /dev/null +++ b/pkg/models/usertoken/user_token.go @@ -0,0 +1,26 @@ +package usertoken + +type TokenRevokedError struct { + UserID int64 + TokenID int64 + MaxConcurrentSessions int64 +} + +func (e *TokenRevokedError) Error() string { return "user token revoked" } + +// UserToken represents a user token +type UserToken struct { + Id int64 + UserId int64 + AuthToken string + PrevAuthToken string + UserAgent string + ClientIp string + AuthTokenSeen bool + SeenAt int64 + RotatedAt int64 + CreatedAt int64 + UpdatedAt int64 + RevokedAt int64 + UnhashedToken string +} diff --git a/pkg/server/backgroundsvcs/background_services.go b/pkg/server/backgroundsvcs/background_services.go index eaaa092aa08..2582d680327 100644 --- a/pkg/server/backgroundsvcs/background_services.go +++ b/pkg/server/backgroundsvcs/background_services.go @@ -7,10 +7,10 @@ import ( "github.com/grafana/grafana/pkg/infra/tracing" uss "github.com/grafana/grafana/pkg/infra/usagestats/service" "github.com/grafana/grafana/pkg/infra/usagestats/statscollector" - "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins/manager/process" "github.com/grafana/grafana/pkg/registry" "github.com/grafana/grafana/pkg/services/alerting" + "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/cleanup" "github.com/grafana/grafana/pkg/services/dashboardsnapshots" "github.com/grafana/grafana/pkg/services/grpcserver" @@ -38,7 +38,7 @@ import ( func ProvideBackgroundServiceRegistry( httpServer *api.HTTPServer, ng *ngalert.AlertNG, cleanup *cleanup.CleanUpService, live *live.GrafanaLive, pushGateway *pushhttp.Gateway, notifications *notifications.NotificationService, processManager *process.Manager, - rendering *rendering.RenderingService, tokenService models.UserTokenBackgroundService, tracing tracing.Tracer, + rendering *rendering.RenderingService, tokenService auth.UserTokenBackgroundService, tracing tracing.Tracer, provisioning *provisioning.ProvisioningServiceImpl, alerting *alerting.AlertEngine, usageStats *uss.UsageStats, statsCollector *statscollector.Service, grafanaUpdateChecker *updatechecker.GrafanaService, pluginsUpdateChecker *updatechecker.PluginsService, metrics *metrics.InternalMetricsService, diff --git a/pkg/server/wire.go b/pkg/server/wire.go index 57338f5d46c..4d2802abbe8 100644 --- a/pkg/server/wire.go +++ b/pkg/server/wire.go @@ -49,6 +49,7 @@ import ( "github.com/grafana/grafana/pkg/services/annotations/annotationsimpl" "github.com/grafana/grafana/pkg/services/apikey/apikeyimpl" "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authimpl" "github.com/grafana/grafana/pkg/services/auth/jwt" "github.com/grafana/grafana/pkg/services/cleanup" "github.com/grafana/grafana/pkg/services/comments" @@ -271,8 +272,8 @@ var wireBasicSet = wire.NewSet( influxdb.ProvideService, wire.Bind(new(social.Service), new(*social.SocialService)), oauthtoken.ProvideService, - auth.ProvideActiveAuthTokenService, - wire.Bind(new(auth.ActiveTokenService), new(*auth.ActiveAuthTokenService)), + authimpl.ProvideActiveAuthTokenService, + wire.Bind(new(auth.ActiveTokenService), new(*authimpl.ActiveAuthTokenService)), wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)), tempo.ProvideService, loki.ProvideService, diff --git a/pkg/server/wireexts_oss.go b/pkg/server/wireexts_oss.go index 5c366331875..6fae56eaaca 100644 --- a/pkg/server/wireexts_oss.go +++ b/pkg/server/wireexts_oss.go @@ -17,6 +17,7 @@ import ( "github.com/grafana/grafana/pkg/services/accesscontrol/acimpl" "github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol" "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authimpl" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/datasources/permissions" datasourceservice "github.com/grafana/grafana/pkg/services/datasources/service" @@ -39,9 +40,9 @@ import ( ) var wireExtsBasicSet = wire.NewSet( - auth.ProvideUserAuthTokenService, - wire.Bind(new(models.UserTokenService), new(*auth.UserAuthTokenService)), - wire.Bind(new(models.UserTokenBackgroundService), new(*auth.UserAuthTokenService)), + authimpl.ProvideUserAuthTokenService, + wire.Bind(new(auth.UserTokenService), new(*authimpl.UserAuthTokenService)), + wire.Bind(new(auth.UserTokenBackgroundService), new(*authimpl.UserAuthTokenService)), licensing.ProvideService, wire.Bind(new(models.Licensing), new(*licensing.OSSLicensingService)), setting.ProvideProvider, diff --git a/pkg/services/accesscontrol/middleware.go b/pkg/services/accesscontrol/middleware.go index c0d3868e266..99b93d329dd 100644 --- a/pkg/services/accesscontrol/middleware.go +++ b/pkg/services/accesscontrol/middleware.go @@ -14,8 +14,8 @@ import ( "time" "github.com/grafana/grafana/pkg/middleware/cookies" - "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/models/usertoken" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" @@ -41,7 +41,7 @@ func Middleware(ac AccessControl) func(web.Handler, Evaluator) web.Handler { } } - var revokedErr *models.TokenRevokedError + var revokedErr *usertoken.TokenRevokedError if errors.As(c.LookupTokenErr, &revokedErr) { unauthorized(c, revokedErr) return @@ -111,7 +111,7 @@ func unauthorized(c *models.ReqContext, err error) { "message": "Unauthorized", } - var revokedErr *models.TokenRevokedError + var revokedErr *usertoken.TokenRevokedError if errors.As(err, &revokedErr) { response["message"] = "Token revoked" response["error"] = map[string]interface{}{ diff --git a/pkg/models/user_token.go b/pkg/services/auth/auth.go similarity index 74% rename from pkg/models/user_token.go rename to pkg/services/auth/auth.go index 6c92a40d86b..77b21316707 100644 --- a/pkg/models/user_token.go +++ b/pkg/services/auth/auth.go @@ -1,19 +1,32 @@ -package models +package auth import ( "context" "errors" "net" + "github.com/grafana/grafana/pkg/models/usertoken" "github.com/grafana/grafana/pkg/registry" + "github.com/grafana/grafana/pkg/services/quota" "github.com/grafana/grafana/pkg/services/user" ) +const ( + QuotaTargetSrv quota.TargetSrv = "auth" + QuotaTarget quota.Target = "session" +) + +type ActiveTokenService interface { + ActiveTokenCount(ctx context.Context, _ *quota.ScopeParameters) (*quota.Map, error) +} + // Typed errors var ( ErrUserTokenNotFound = errors.New("user token not found") ) +type TokenRevokedError = usertoken.TokenRevokedError + // CreateTokenErr represents a token creation error; used in Enterprise type CreateTokenErr struct { StatusCode int @@ -35,30 +48,7 @@ type TokenExpiredError struct { func (e *TokenExpiredError) Error() string { return "user token expired" } -type TokenRevokedError struct { - UserID int64 - TokenID int64 - MaxConcurrentSessions int64 -} - -func (e *TokenRevokedError) Error() string { return "user token revoked" } - -// UserToken represents a user token -type UserToken struct { - Id int64 - UserId int64 - AuthToken string - PrevAuthToken string - UserAgent string - ClientIp string - AuthTokenSeen bool - SeenAt int64 - RotatedAt int64 - CreatedAt int64 - UpdatedAt int64 - RevokedAt int64 - UnhashedToken string -} +type UserToken = usertoken.UserToken type RevokeAuthTokenCmd struct { AuthTokenId int64 `json:"authTokenId"` diff --git a/pkg/services/auth/auth_token.go b/pkg/services/auth/authimpl/auth_token.go similarity index 91% rename from pkg/services/auth/auth_token.go rename to pkg/services/auth/authimpl/auth_token.go index f261e33bcd1..c969686a2f8 100644 --- a/pkg/services/auth/auth_token.go +++ b/pkg/services/auth/authimpl/auth_token.go @@ -1,4 +1,4 @@ -package auth +package authimpl import ( "context" @@ -11,7 +11,7 @@ import ( "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/serverlock" - "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/quota" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" @@ -42,10 +42,6 @@ type UserAuthTokenService struct { log log.Logger } -type ActiveTokenService interface { - ActiveTokenCount(ctx context.Context, _ *quota.ScopeParameters) (*quota.Map, error) -} - type ActiveAuthTokenService struct { cfg *setting.Cfg sqlStore db.DB @@ -63,7 +59,7 @@ func ProvideActiveAuthTokenService(cfg *setting.Cfg, sqlStore db.DB, quotaServic } if err := quotaService.RegisterQuotaReporter("a.NewUsageReporter{ - TargetSrv: QuotaTargetSrv, + TargetSrv: auth.QuotaTargetSrv, DefaultLimits: defaultLimits, Reporter: s.ActiveTokenCount, }); err != nil { @@ -86,7 +82,7 @@ func (a *ActiveAuthTokenService) ActiveTokenCount(ctx context.Context, _ *quota. return err }) - tag, err := quota.NewTag(QuotaTargetSrv, QuotaTarget, quota.GlobalScope) + tag, err := quota.NewTag(auth.QuotaTargetSrv, auth.QuotaTarget, quota.GlobalScope) if err != nil { return nil, err } @@ -96,7 +92,7 @@ func (a *ActiveAuthTokenService) ActiveTokenCount(ctx context.Context, _ *quota. return u, err } -func (s *UserAuthTokenService) CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*models.UserToken, error) { +func (s *UserAuthTokenService) CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error) { token, err := util.RandomHex(16) if err != nil { return nil, err @@ -138,13 +134,13 @@ func (s *UserAuthTokenService) CreateToken(ctx context.Context, user *user.User, ctxLogger := s.log.FromContext(ctx) ctxLogger.Debug("user auth token created", "tokenId", userAuthToken.Id, "userId", userAuthToken.UserId, "clientIP", userAuthToken.ClientIp, "userAgent", userAuthToken.UserAgent, "authToken", userAuthToken.AuthToken) - var userToken models.UserToken + var userToken auth.UserToken err = userAuthToken.toUserToken(&userToken) return &userToken, err } -func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*models.UserToken, error) { +func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { hashedToken := hashToken(unhashedToken) var model userAuthToken var exists bool @@ -162,14 +158,14 @@ func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken st } if !exists { - return nil, models.ErrUserTokenNotFound + return nil, auth.ErrUserTokenNotFound } ctxLogger := s.log.FromContext(ctx) if model.RevokedAt > 0 { ctxLogger.Debug("user token has been revoked", "user ID", model.UserId, "token ID", model.Id) - return nil, &models.TokenRevokedError{ + return nil, &auth.TokenRevokedError{ UserID: model.UserId, TokenID: model.Id, } @@ -177,7 +173,7 @@ func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken st if model.CreatedAt <= s.createdAfterParam() || model.RotatedAt <= s.rotatedAfterParam() { ctxLogger.Debug("user token has expired", "user ID", model.UserId, "token ID", model.Id) - return nil, &models.TokenExpiredError{ + return nil, &auth.TokenExpiredError{ UserID: model.UserId, TokenID: model.Id, } @@ -242,13 +238,13 @@ func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken st model.UnhashedToken = unhashedToken - var userToken models.UserToken + var userToken auth.UserToken err = model.toUserToken(&userToken) return &userToken, err } -func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *models.UserToken, +func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) { if token == nil { return false, nil @@ -328,9 +324,9 @@ func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *models return false, nil } -func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *models.UserToken, soft bool) error { +func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *auth.UserToken, soft bool) error { if token == nil { - return models.ErrUserTokenNotFound + return auth.ErrUserTokenNotFound } model, err := userAuthTokenFromUserToken(token) @@ -361,7 +357,7 @@ func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *models.Us if rowsAffected == 0 { ctxLogger.Debug("user auth token not found/revoked", "tokenId", model.Id, "userId", model.UserId, "clientIP", model.ClientIp, "userAgent", model.UserAgent) - return models.ErrUserTokenNotFound + return auth.ErrUserTokenNotFound } ctxLogger.Debug("user auth token revoked", "tokenId", model.Id, "userId", model.UserId, "clientIP", model.ClientIp, "userAgent", model.UserAgent, "soft", soft) @@ -418,8 +414,8 @@ func (s *UserAuthTokenService) BatchRevokeAllUserTokens(ctx context.Context, use }) } -func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) { - var result models.UserToken +func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) { + var result auth.UserToken err := s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error { var token userAuthToken exists, err := dbSession.Where("id = ? AND user_id = ?", userTokenId, userId).Get(&token) @@ -428,7 +424,7 @@ func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTok } if !exists { - return models.ErrUserTokenNotFound + return auth.ErrUserTokenNotFound } return token.toUserToken(&result) @@ -437,8 +433,8 @@ func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTok return &result, err } -func (s *UserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) ([]*models.UserToken, error) { - result := []*models.UserToken{} +func (s *UserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) ([]*auth.UserToken, error) { + result := []*auth.UserToken{} err := s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error { var tokens []*userAuthToken err := dbSession.Where("user_id = ? AND created_at > ? AND rotated_at > ? AND revoked_at = 0", @@ -451,7 +447,7 @@ func (s *UserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) } for _, token := range tokens { - var userToken models.UserToken + var userToken auth.UserToken if err := token.toUserToken(&userToken); err != nil { return err } @@ -464,8 +460,8 @@ func (s *UserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) return result, err } -func (s *UserAuthTokenService) GetUserRevokedTokens(ctx context.Context, userId int64) ([]*models.UserToken, error) { - result := []*models.UserToken{} +func (s *UserAuthTokenService) GetUserRevokedTokens(ctx context.Context, userId int64) ([]*auth.UserToken, error) { + result := []*auth.UserToken{} err := s.SQLStore.WithDbSession(ctx, func(dbSession *db.Session) error { var tokens []*userAuthToken err := dbSession.Where("user_id = ? AND revoked_at > 0", userId).Find(&tokens) @@ -474,7 +470,7 @@ func (s *UserAuthTokenService) GetUserRevokedTokens(ctx context.Context, userId } for _, token := range tokens { - var userToken models.UserToken + var userToken auth.UserToken if err := token.toUserToken(&userToken); err != nil { return err } @@ -507,7 +503,7 @@ func readQuotaConfig(cfg *setting.Cfg) (*quota.Map, error) { return limits, nil } - globalQuotaTag, err := quota.NewTag(QuotaTargetSrv, QuotaTarget, quota.GlobalScope) + globalQuotaTag, err := quota.NewTag(auth.QuotaTargetSrv, auth.QuotaTarget, quota.GlobalScope) if err != nil { return limits, err } diff --git a/pkg/services/auth/auth_token_test.go b/pkg/services/auth/authimpl/auth_token_test.go similarity index 96% rename from pkg/services/auth/auth_token_test.go rename to pkg/services/auth/authimpl/auth_token_test.go index 16886d7b439..97528bf6be2 100644 --- a/pkg/services/auth/auth_token_test.go +++ b/pkg/services/auth/authimpl/auth_token_test.go @@ -1,4 +1,4 @@ -package auth +package authimpl import ( "context" @@ -8,12 +8,12 @@ import ( "testing" "time" + "github.com/grafana/grafana/pkg/services/auth" "github.com/stretchr/testify/require" "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/log" - "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/quota" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" @@ -29,7 +29,7 @@ func TestUserAuthToken(t *testing.T) { defer func() { getTime = time.Now }() t.Run("When creating token", func(t *testing.T) { - createToken := func() *models.UserToken { + createToken := func() *auth.UserToken { userToken, err := ctx.tokenService.CreateToken(context.Background(), user, net.ParseIP("192.168.10.11"), "some user agent") require.Nil(t, err) @@ -43,7 +43,7 @@ func TestUserAuthToken(t *testing.T) { t.Run("Can count active tokens", func(t *testing.T) { m, err := ctx.activeTokenService.ActiveTokenCount(context.Background(), "a.ScopeParameters{}) require.Nil(t, err) - tag, err := quota.NewTag(QuotaTargetSrv, QuotaTarget, quota.GlobalScope) + tag, err := quota.NewTag(auth.QuotaTargetSrv, auth.QuotaTarget, quota.GlobalScope) require.NoError(t, err) count, ok := m.Get(tag) require.True(t, ok) @@ -65,7 +65,7 @@ func TestUserAuthToken(t *testing.T) { t.Run("When lookup hashed token should return user auth token not found error", func(t *testing.T) { userToken, err := ctx.tokenService.LookupToken(context.Background(), userToken.AuthToken) - require.Equal(t, models.ErrUserTokenNotFound, err) + require.Equal(t, auth.ErrUserTokenNotFound, err) require.Nil(t, userToken) }) @@ -90,13 +90,13 @@ func TestUserAuthToken(t *testing.T) { t.Run("revoking nil token should return error", func(t *testing.T) { err := ctx.tokenService.RevokeToken(context.Background(), nil, false) - require.Equal(t, models.ErrUserTokenNotFound, err) + require.Equal(t, auth.ErrUserTokenNotFound, err) }) t.Run("revoking non-existing token should return error", func(t *testing.T) { userToken.Id = 1000 err := ctx.tokenService.RevokeToken(context.Background(), userToken, false) - require.Equal(t, models.ErrUserTokenNotFound, err) + require.Equal(t, auth.ErrUserTokenNotFound, err) }) ctx = createTestContext(t) @@ -209,13 +209,13 @@ func TestUserAuthToken(t *testing.T) { } notGood, err := ctx.tokenService.LookupToken(context.Background(), userToken.UnhashedToken) - require.Equal(t, reflect.TypeOf(err), reflect.TypeOf(&models.TokenExpiredError{})) + require.Equal(t, reflect.TypeOf(err), reflect.TypeOf(&auth.TokenExpiredError{})) require.Nil(t, notGood) t.Run("should not find active token when expired", func(t *testing.T) { m, err := ctx.activeTokenService.ActiveTokenCount(context.Background(), "a.ScopeParameters{}) require.Nil(t, err) - tag, err := quota.NewTag(QuotaTargetSrv, QuotaTarget, quota.GlobalScope) + tag, err := quota.NewTag(auth.QuotaTargetSrv, auth.QuotaTarget, quota.GlobalScope) require.NoError(t, err) count, ok := m.Get(tag) require.True(t, ok) @@ -247,7 +247,7 @@ func TestUserAuthToken(t *testing.T) { } notGood, err := ctx.tokenService.LookupToken(context.Background(), userToken.UnhashedToken) - require.Equal(t, reflect.TypeOf(err), reflect.TypeOf(&models.TokenExpiredError{})) + require.Equal(t, reflect.TypeOf(err), reflect.TypeOf(&auth.TokenExpiredError{})) require.Nil(t, notGood) }) }) @@ -274,7 +274,7 @@ func TestUserAuthToken(t *testing.T) { model, err := ctx.getAuthTokenByID(userToken.Id) require.Nil(t, err) - var tok models.UserToken + var tok auth.UserToken err = model.toUserToken(&tok) require.Nil(t, err) @@ -471,7 +471,7 @@ func TestUserAuthToken(t *testing.T) { }) t.Run("When populating userAuthToken from UserToken should copy all properties", func(t *testing.T) { - ut := models.UserToken{ + ut := auth.UserToken{ Id: 1, UserId: 2, AuthToken: "a", @@ -524,7 +524,7 @@ func TestUserAuthToken(t *testing.T) { require.Nil(t, err) uatMap := uatJSON.MustMap() - var ut models.UserToken + var ut auth.UserToken err = uat.toUserToken(&ut) require.Nil(t, err) utBytes, err := json.Marshal(ut) diff --git a/pkg/services/auth/model.go b/pkg/services/auth/authimpl/model.go similarity index 77% rename from pkg/services/auth/model.go rename to pkg/services/auth/authimpl/model.go index afc5b566c48..407927df572 100644 --- a/pkg/services/auth/model.go +++ b/pkg/services/auth/authimpl/model.go @@ -1,10 +1,9 @@ -package auth +package authimpl import ( "fmt" - "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/services/quota" + "github.com/grafana/grafana/pkg/services/auth" ) type userAuthToken struct { @@ -23,13 +22,13 @@ type userAuthToken struct { UnhashedToken string `xorm:"-"` } -func userAuthTokenFromUserToken(ut *models.UserToken) (*userAuthToken, error) { +func userAuthTokenFromUserToken(ut *auth.UserToken) (*userAuthToken, error) { var uat userAuthToken err := uat.fromUserToken(ut) return &uat, err } -func (uat *userAuthToken) fromUserToken(ut *models.UserToken) error { +func (uat *userAuthToken) fromUserToken(ut *auth.UserToken) error { if uat == nil { return fmt.Errorf("needs pointer to userAuthToken struct") } @@ -51,7 +50,7 @@ func (uat *userAuthToken) fromUserToken(ut *models.UserToken) error { return nil } -func (uat *userAuthToken) toUserToken(ut *models.UserToken) error { +func (uat *userAuthToken) toUserToken(ut *auth.UserToken) error { if uat == nil { return fmt.Errorf("needs pointer to userAuthToken struct") } @@ -72,8 +71,3 @@ func (uat *userAuthToken) toUserToken(ut *models.UserToken) error { return nil } - -const ( - QuotaTargetSrv quota.TargetSrv = "auth" - QuotaTarget quota.Target = "session" -) diff --git a/pkg/services/auth/token_cleanup.go b/pkg/services/auth/authimpl/token_cleanup.go similarity index 99% rename from pkg/services/auth/token_cleanup.go rename to pkg/services/auth/authimpl/token_cleanup.go index a82f13630fe..08d8ae7c614 100644 --- a/pkg/services/auth/token_cleanup.go +++ b/pkg/services/auth/authimpl/token_cleanup.go @@ -1,4 +1,4 @@ -package auth +package authimpl import ( "context" diff --git a/pkg/services/auth/token_cleanup_test.go b/pkg/services/auth/authimpl/token_cleanup_test.go similarity index 99% rename from pkg/services/auth/token_cleanup_test.go rename to pkg/services/auth/authimpl/token_cleanup_test.go index a39e0d7892b..e207448c8f6 100644 --- a/pkg/services/auth/token_cleanup_test.go +++ b/pkg/services/auth/authimpl/token_cleanup_test.go @@ -1,4 +1,4 @@ -package auth +package authimpl import ( "context" diff --git a/pkg/services/auth/testing.go b/pkg/services/auth/authtest/testing.go similarity index 77% rename from pkg/services/auth/testing.go rename to pkg/services/auth/authtest/testing.go index 63b08a9a639..d2bbd09b46b 100644 --- a/pkg/services/auth/testing.go +++ b/pkg/services/auth/authtest/testing.go @@ -1,4 +1,4 @@ -package auth +package authtest import ( "context" @@ -6,42 +6,43 @@ import ( "time" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/user" "golang.org/x/oauth2" ) type FakeUserAuthTokenService struct { - CreateTokenProvider func(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*models.UserToken, error) - TryRotateTokenProvider func(ctx context.Context, token *models.UserToken, clientIP net.IP, userAgent string) (bool, error) - LookupTokenProvider func(ctx context.Context, unhashedToken string) (*models.UserToken, error) - RevokeTokenProvider func(ctx context.Context, token *models.UserToken, soft bool) error + CreateTokenProvider func(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error) + TryRotateTokenProvider func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) + LookupTokenProvider func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) + RevokeTokenProvider func(ctx context.Context, token *auth.UserToken, soft bool) error RevokeAllUserTokensProvider func(ctx context.Context, userId int64) error ActiveAuthTokenCount func(ctx context.Context) (int64, error) - GetUserTokenProvider func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) - GetUserTokensProvider func(ctx context.Context, userId int64) ([]*models.UserToken, error) - GetUserRevokedTokensProvider func(ctx context.Context, userId int64) ([]*models.UserToken, error) + GetUserTokenProvider func(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) + GetUserTokensProvider func(ctx context.Context, userId int64) ([]*auth.UserToken, error) + GetUserRevokedTokensProvider func(ctx context.Context, userId int64) ([]*auth.UserToken, error) BatchRevokedTokenProvider func(ctx context.Context, userIds []int64) error } func NewFakeUserAuthTokenService() *FakeUserAuthTokenService { return &FakeUserAuthTokenService{ - CreateTokenProvider: func(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*models.UserToken, error) { - return &models.UserToken{ + CreateTokenProvider: func(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error) { + return &auth.UserToken{ UserId: 0, UnhashedToken: "", }, nil }, - TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP, userAgent string) (bool, error) { + TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) { return false, nil }, - LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { - return &models.UserToken{ + LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { + return &auth.UserToken{ UserId: 0, UnhashedToken: "", }, nil }, - RevokeTokenProvider: func(ctx context.Context, token *models.UserToken, soft bool) error { + RevokeTokenProvider: func(ctx context.Context, token *auth.UserToken, soft bool) error { return nil }, RevokeAllUserTokensProvider: func(ctx context.Context, userId int64) error { @@ -53,10 +54,10 @@ func NewFakeUserAuthTokenService() *FakeUserAuthTokenService { ActiveAuthTokenCount: func(ctx context.Context) (int64, error) { return 10, nil }, - GetUserTokenProvider: func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) { + GetUserTokenProvider: func(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) { return nil, nil }, - GetUserTokensProvider: func(ctx context.Context, userId int64) ([]*models.UserToken, error) { + GetUserTokensProvider: func(ctx context.Context, userId int64) ([]*auth.UserToken, error) { return nil, nil }, } @@ -68,20 +69,20 @@ func (s *FakeUserAuthTokenService) Init() error { return nil } -func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*models.UserToken, error) { +func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error) { return s.CreateTokenProvider(context.Background(), user, clientIP, userAgent) } -func (s *FakeUserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*models.UserToken, error) { +func (s *FakeUserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { return s.LookupTokenProvider(context.Background(), unhashedToken) } -func (s *FakeUserAuthTokenService) TryRotateToken(ctx context.Context, token *models.UserToken, clientIP net.IP, +func (s *FakeUserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) { return s.TryRotateTokenProvider(context.Background(), token, clientIP, userAgent) } -func (s *FakeUserAuthTokenService) RevokeToken(ctx context.Context, token *models.UserToken, soft bool) error { +func (s *FakeUserAuthTokenService) RevokeToken(ctx context.Context, token *auth.UserToken, soft bool) error { return s.RevokeTokenProvider(context.Background(), token, soft) } @@ -93,15 +94,15 @@ func (s *FakeUserAuthTokenService) ActiveTokenCount(ctx context.Context) (int64, return s.ActiveAuthTokenCount(context.Background()) } -func (s *FakeUserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) { +func (s *FakeUserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*auth.UserToken, error) { return s.GetUserTokenProvider(context.Background(), userId, userTokenId) } -func (s *FakeUserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) ([]*models.UserToken, error) { +func (s *FakeUserAuthTokenService) GetUserTokens(ctx context.Context, userId int64) ([]*auth.UserToken, error) { return s.GetUserTokensProvider(context.Background(), userId) } -func (s *FakeUserAuthTokenService) GetUserRevokedTokens(ctx context.Context, userId int64) ([]*models.UserToken, error) { +func (s *FakeUserAuthTokenService) GetUserRevokedTokens(ctx context.Context, userId int64) ([]*auth.UserToken, error) { return s.GetUserRevokedTokensProvider(context.Background(), userId) } diff --git a/pkg/services/contexthandler/auth_proxy_test.go b/pkg/services/contexthandler/auth_proxy_test.go index 307ca2b51bc..dc697eff484 100644 --- a/pkg/services/contexthandler/auth_proxy_test.go +++ b/pkg/services/contexthandler/auth_proxy_test.go @@ -13,7 +13,7 @@ import ( "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" "github.com/grafana/grafana/pkg/services/login/loginservice" "github.com/grafana/grafana/pkg/services/org/orgtest" @@ -80,7 +80,7 @@ func getContextHandler(t *testing.T) *ContextHandler { cfg.AuthProxyHeaderProperty = "username" remoteCacheSvc, err := remotecache.ProvideService(cfg, sqlStore) require.NoError(t, err) - userAuthTokenSvc := auth.NewFakeUserAuthTokenService() + userAuthTokenSvc := authtest.NewFakeUserAuthTokenService() renderSvc := &fakeRenderService{} authJWTSvc := models.NewFakeJWTService() tracer := tracing.InitializeTracerForTest() diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index c8d67f165c0..6e2be2233c5 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -44,7 +44,7 @@ const ( const ServiceName = "ContextHandler" -func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtService models.JWTService, +func ProvideService(cfg *setting.Cfg, tokenService auth.UserTokenService, jwtService models.JWTService, remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore db.DB, tracer tracing.Tracer, authProxy *authproxy.AuthProxy, loginService login.Service, apiKeyService apikey.Service, authenticator loginpkg.Authenticator, userService user.Service, @@ -77,7 +77,7 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtS // ContextHandler is a middleware. type ContextHandler struct { Cfg *setting.Cfg - AuthTokenService models.UserTokenService + AuthTokenService auth.UserTokenService JWTAuthService models.JWTService RemoteCache *remotecache.RemoteCache RenderService rendering.Service @@ -474,7 +474,7 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org } err = h.AuthTokenService.RevokeToken(ctx, token, false) - if err != nil && !errors.Is(err, models.ErrUserTokenNotFound) { + if err != nil && !errors.Is(err, auth.ErrUserTokenNotFound) { reqContext.Logger.Error("failed to revoke auth token", "error", err) } return false @@ -506,8 +506,8 @@ func (h *ContextHandler) deleteInvalidCookieEndOfRequestFunc(reqContext *models. } } -func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *models.ReqContext, authTokenService models.UserTokenService, - token *models.UserToken) web.BeforeFunc { +func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *models.ReqContext, authTokenService auth.UserTokenService, + token *auth.UserToken) web.BeforeFunc { return func(w web.ResponseWriter) { // if response has already been written, skip. if w.Written() { diff --git a/pkg/services/contexthandler/contexthandler_test.go b/pkg/services/contexthandler/contexthandler_test.go index 3e61b65d0ed..fa882f2fc81 100644 --- a/pkg/services/contexthandler/contexthandler_test.go +++ b/pkg/services/contexthandler/contexthandler_test.go @@ -7,14 +7,16 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/grafana/grafana-plugin-sdk-go/backend/gtime" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authtest" "github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/web" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestDontRotateTokensOnCancelledRequests(t *testing.T) { @@ -25,15 +27,15 @@ func TestDontRotateTokensOnCancelledRequests(t *testing.T) { require.NoError(t, err) tryRotateCallCount := 0 - uts := &auth.FakeUserAuthTokenService{ - TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP, + uts := &authtest.FakeUserAuthTokenService{ + TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) { tryRotateCallCount++ return false, nil }, } - token := &models.UserToken{AuthToken: "oldtoken"} + token := &auth.UserToken{AuthToken: "oldtoken"} fn := ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token) cancel() @@ -48,8 +50,8 @@ func TestTokenRotationAtEndOfRequest(t *testing.T) { reqContext, rr, err := initTokenRotationScenario(context.Background(), t, ctxHdlr) require.NoError(t, err) - uts := &auth.FakeUserAuthTokenService{ - TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP, + uts := &authtest.FakeUserAuthTokenService{ + TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) { newToken, err := util.RandomHex(16) require.NoError(t, err) @@ -58,7 +60,7 @@ func TestTokenRotationAtEndOfRequest(t *testing.T) { }, } - token := &models.UserToken{AuthToken: "oldtoken"} + token := &auth.UserToken{AuthToken: "oldtoken"} ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token)(reqContext.Resp) diff --git a/pkg/services/ngalert/api/util_test.go b/pkg/services/ngalert/api/util_test.go index 34f4deb08e2..3ed32d797e6 100644 --- a/pkg/services/ngalert/api/util_test.go +++ b/pkg/services/ngalert/api/util_test.go @@ -12,6 +12,7 @@ import ( "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/models" accesscontrolmock "github.com/grafana/grafana/pkg/services/accesscontrol/mock" + "github.com/grafana/grafana/pkg/services/auth" models2 "github.com/grafana/grafana/pkg/services/ngalert/models" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/user" @@ -45,7 +46,7 @@ func TestAlertingProxy_createProxyContext(t *testing.T) { Req: &http.Request{}, }, SignedInUser: &user.SignedInUser{}, - UserToken: &models.UserToken{}, + UserToken: &auth.UserToken{}, IsSignedIn: rand.Int63()%2 == 1, IsRenderCall: rand.Int63()%2 == 1, AllowAnonymous: rand.Int63()%2 == 1, diff --git a/pkg/services/quota/quotaimpl/quota_test.go b/pkg/services/quota/quotaimpl/quota_test.go index 17164adc785..5d73021f2c9 100644 --- a/pkg/services/quota/quotaimpl/quota_test.go +++ b/pkg/services/quota/quotaimpl/quota_test.go @@ -14,6 +14,7 @@ import ( "github.com/grafana/grafana/pkg/services/apikey" "github.com/grafana/grafana/pkg/services/apikey/apikeyimpl" "github.com/grafana/grafana/pkg/services/auth" + "github.com/grafana/grafana/pkg/services/auth/authimpl" "github.com/grafana/grafana/pkg/services/dashboards" dashboardStore "github.com/grafana/grafana/pkg/services/dashboards/database" "github.com/grafana/grafana/pkg/services/datasources" @@ -464,7 +465,7 @@ func getQuotaBySrvTargetScope(t *testing.T, quotaService quota.Service, srv quot func setupEnv(t *testing.T, sqlStore *sqlstore.SQLStore, b bus.Bus, quotaService quota.Service) { _, err := apikeyimpl.ProvideService(sqlStore, sqlStore.Cfg, quotaService) require.NoError(t, err) - _, err = auth.ProvideActiveAuthTokenService(sqlStore.Cfg, sqlStore, quotaService) + _, err = authimpl.ProvideActiveAuthTokenService(sqlStore.Cfg, sqlStore, quotaService) require.NoError(t, err) _, err = dashboardStore.ProvideDashboardStore(sqlStore, sqlStore.Cfg, featuremgmt.WithFeatures(), tagimpl.ProvideService(sqlStore, sqlStore.Cfg), quotaService) require.NoError(t, err)