From 9c954d06ab73fff8ec8b057c4e03048f11e93e2d Mon Sep 17 00:00:00 2001 From: Misi Date: Tue, 18 Oct 2022 18:17:28 +0200 Subject: [PATCH] Auth: Refresh OAuth access_token automatically using the refresh_token (#56076) * Verify OAuth token expiration for oauth users in the ctx handler middleware * Use refresh token to get a new access token * Refactor oauth_token.go * Add tests for the middleware changes * Align other tests * Add tests, wip * Add more tests * Add InvalidateOAuthTokens method * Fix ExpiryDate update to default * Invalidate OAuth tokens during logout * Improve logout * Add more comments * Cleanup * Fix import order * Add error to HasOAuthEntry return values * add dev debug logs * Fix tests Co-authored-by: jguer --- pkg/api/common_test.go | 2 +- pkg/api/http_server.go | 4 + pkg/api/login.go | 7 + pkg/api/login_oauth.go | 10 +- pkg/api/metrics_test.go | 13 + pkg/api/pluginproxy/ds_proxy_test.go | 12 + pkg/middleware/middleware_test.go | 124 +++++- pkg/middleware/recovery_test.go | 2 +- pkg/middleware/testing.go | 1 + pkg/services/auth/testing.go | 46 +++ .../contexthandler/auth_proxy_test.go | 2 +- pkg/services/contexthandler/contexthandler.go | 91 +++-- .../authinfoservice/database/database.go | 7 +- pkg/services/oauthtoken/oauth_token.go | 197 ++++++--- pkg/services/oauthtoken/oauth_token_test.go | 374 ++++++++++++++++++ .../publicdashboards/api/common_test.go | 12 + pkg/services/query/query_test.go | 13 + 17 files changed, 828 insertions(+), 89 deletions(-) create mode 100644 pkg/services/oauthtoken/oauth_token_test.go diff --git a/pkg/api/common_test.go b/pkg/api/common_test.go index 5e59546a325..2e22fd8659b 100644 --- a/pkg/api/common_test.go +++ b/pkg/api/common_test.go @@ -213,7 +213,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, &usertest.FakeUserService{}, sqlStore) loginService := &logintest.LoginServiceFake{} authenticator := &logintest.AuthenticatorFake{} - ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake()) + ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake(), nil) return ctxHdlr } diff --git a/pkg/api/http_server.go b/pkg/api/http_server.go index 4fb0fd818a2..581e68ce6b7 100644 --- a/pkg/api/http_server.go +++ b/pkg/api/http_server.go @@ -16,6 +16,7 @@ import ( "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/middleware/csrf" "github.com/grafana/grafana/pkg/services/folder" + "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/services/querylibrary" "github.com/grafana/grafana/pkg/services/searchV2" "github.com/grafana/grafana/pkg/services/store/object" @@ -206,6 +207,7 @@ type HTTPServer struct { annotationsRepo annotations.Repository tagService tag.Service userAuthService userauth.Service + oauthTokenService oauthtoken.OAuthTokenService } type ServerOptions struct { @@ -248,6 +250,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi accesscontrolService accesscontrol.Service, dashboardThumbsService thumbs.DashboardThumbService, navTreeService navtree.Service, annotationRepo annotations.Repository, tagService tag.Service, searchv2HTTPService searchV2.SearchHTTPService, userAuthService userauth.Service, queryLibraryHTTPService querylibrary.HTTPService, queryLibraryService querylibrary.Service, + oauthTokenService oauthtoken.OAuthTokenService, ) (*HTTPServer, error) { web.Env = cfg.Env m := web.New() @@ -352,6 +355,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi userAuthService: userAuthService, QueryLibraryHTTPService: queryLibraryHTTPService, QueryLibraryService: queryLibraryService, + oauthTokenService: oauthTokenService, } if hs.Listener != nil { hs.log.Debug("Using provided listener") diff --git a/pkg/api/login.go b/pkg/api/login.go index a6da8729de1..4ab51b22d5c 100644 --- a/pkg/api/login.go +++ b/pkg/api/login.go @@ -304,6 +304,13 @@ func (hs *HTTPServer) Logout(c *models.ReqContext) { } } + // Invalidate the OAuth tokens in case the User logged in with OAuth or the last external AuthEntry is an OAuth one + if entry, exists, _ := hs.oauthTokenService.HasOAuthEntry(c.Req.Context(), c.SignedInUser); exists { + if err := hs.oauthTokenService.InvalidateOAuthTokens(c.Req.Context(), entry); err != nil { + hs.log.Warn("failed to invalidate oauth tokens for user", "userId", c.UserID, "error", err) + } + } + err := hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken, false) if err != nil && !errors.Is(err, models.ErrUserTokenNotFound) { hs.log.Error("failed to revoke auth token", "error", err) diff --git a/pkg/api/login_oauth.go b/pkg/api/login_oauth.go index b2cc9475e6f..6b570cd0c98 100644 --- a/pkg/api/login_oauth.go +++ b/pkg/api/login_oauth.go @@ -194,7 +194,15 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { // token.TokenType was defaulting to "bearer", which is out of spec, so we explicitly set to "Bearer" token.TokenType = "Bearer" - oauthLogger.Debug("OAuthLogin: got token", "expiry", fmt.Sprintf("%v", token.Expiry)) + if hs.Cfg.Env != setting.Dev { + oauthLogger.Debug("OAuthLogin: got token", "expiry", fmt.Sprintf("%v", token.Expiry)) + } else { + oauthLogger.Debug("OAuthLogin: got token", + "expiry", fmt.Sprintf("%v", token.Expiry), + "access_token", fmt.Sprintf("%v", token.AccessToken), + "refresh_token", fmt.Sprintf("%v", token.RefreshToken), + ) + } // set up oauth2 client client := connect.Client(oauthCtx, token) diff --git a/pkg/api/metrics_test.go b/pkg/api/metrics_test.go index afb7fb303b6..bfbc222b71b 100644 --- a/pkg/api/metrics_test.go +++ b/pkg/api/metrics_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" + "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/plugins/backendplugin" pluginClient "github.com/grafana/grafana/pkg/plugins/manager/client" @@ -56,6 +57,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) return ts.passThruEnabled } +func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) { + return nil, false, nil +} + +func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error { + return nil +} + +func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error { + return nil +} + // `/ds/query` endpoint test func TestAPIEndpoint_Metrics_QueryMetricsV2(t *testing.T) { qds := query.ProvideService( diff --git a/pkg/api/pluginproxy/ds_proxy_test.go b/pkg/api/pluginproxy/ds_proxy_test.go index fdbe5138d36..9e9d2df185a 100644 --- a/pkg/api/pluginproxy/ds_proxy_test.go +++ b/pkg/api/pluginproxy/ds_proxy_test.go @@ -1065,3 +1065,15 @@ func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user * func (m *mockOAuthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool { return m.oAuthEnabled } + +func (m *mockOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) { + return nil, false, nil +} + +func (m *mockOAuthTokenService) TryTokenRefresh(context.Context, *models.UserAuth) error { + return nil +} + +func (m *mockOAuthTokenService) InvalidateOAuthTokens(context.Context, *models.UserAuth) error { + return nil +} diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index c47ab96e1f3..d6c6b4d33b6 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "errors" "fmt" "io" "net" @@ -339,6 +340,123 @@ func TestMiddlewareContext(t *testing.T) { assert.Nil(t, sc.context.UserToken) }) + middlewareScenario(t, "Non-expired auth token in cookie and non-expired OAuth access token", func( + t *testing.T, sc *scenarioContext) { + const userID int64 = 12 + sc.contextHandler.GetTime = fakeGetTime() + + sc.withTokenSessionCookie("token") + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID} + sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(11 * time.Second)} + + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { + return &models.UserToken{ + UserId: userID, + UnhashedToken: unhashedToken, + }, nil + } + + sc.fakeReq("GET", "/").exec() + + require.NotNil(t, sc.context) + require.NotNil(t, sc.context.UserToken) + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, userID, sc.context.UserID) + assert.Equal(t, userID, sc.context.UserToken.UserId) + assert.Equal(t, "token", sc.context.UserToken.UnhashedToken) + assert.Empty(t, sc.resp.Header().Get("Set-Cookie")) + }) + + middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token fails", func( + t *testing.T, sc *scenarioContext) { + const userID int64 = 12 + sc.contextHandler.GetTime = fakeGetTime() + + sc.withTokenSessionCookie("token") + signedInUser := &user.SignedInUser{OrgID: 2, UserID: userID} + sc.userService.ExpectedSignedInUser = signedInUser + sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{ + UserId: userID, + OAuthExpiry: fakeGetTime()().Add(-1 * time.Second), + OAuthAccessToken: "access_token", + OAuthRefreshToken: "refresh_token"} + sc.oauthTokenService.ExpectedErrors = map[string]error{"TryTokenRefresh": errors.New("error")} + + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { + return &models.UserToken{ + UserId: userID, + UnhashedToken: unhashedToken, + }, nil + } + + sc.fakeReq("GET", "/").exec() + + token := sc.oauthTokenService.GetCurrentOAuthToken(sc.context.Req.Context(), signedInUser) + assert.Equal(t, token.AccessToken, "") + assert.Equal(t, token.RefreshToken, "") + assert.True(t, token.Expiry.IsZero()) + + require.NotNil(t, sc.context) + require.Nil(t, sc.context.UserToken) + assert.False(t, sc.context.IsSignedIn) + assert.Equal(t, int64(0), sc.context.UserID) + assert.Equal(t, "grafana_session=; Path=/; Max-Age=0; HttpOnly", sc.resp.Header().Get("Set-Cookie")) + }) + + middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token succeeds", func( + t *testing.T, sc *scenarioContext) { + const userID int64 = 12 + sc.contextHandler.GetTime = fakeGetTime() + + sc.withTokenSessionCookie("token") + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID} + sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(-5 * time.Second), OAuthRefreshToken: "refreshtoken"} + + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { + return &models.UserToken{ + UserId: userID, + UnhashedToken: unhashedToken, + }, nil + } + + sc.fakeReq("GET", "/").exec() + + require.NotNil(t, sc.context) + require.NotNil(t, sc.context.UserToken) + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, userID, sc.context.UserID) + assert.Equal(t, userID, sc.context.UserToken.UserId) + assert.Equal(t, "token", sc.context.UserToken.UnhashedToken) + assert.Empty(t, sc.resp.Header().Get("Set-Cookie")) + }) + + middlewareScenario(t, "Non-expired auth token in cookie and OAuth Access Token's Expiry is not set", func( + t *testing.T, sc *scenarioContext) { + const userID int64 = 12 + sc.contextHandler.GetTime = fakeGetTime() + + sc.withTokenSessionCookie("token") + sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID} + sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID} + + sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { + return &models.UserToken{ + UserId: userID, + UnhashedToken: unhashedToken, + }, nil + } + + sc.fakeReq("GET", "/").exec() + + require.NotNil(t, sc.context) + require.NotNil(t, sc.context.UserToken) + assert.True(t, sc.context.IsSignedIn) + assert.Equal(t, userID, sc.context.UserID) + assert.Equal(t, userID, sc.context.UserToken.UserId) + assert.Equal(t, "token", sc.context.UserToken.UnhashedToken) + assert.Empty(t, sc.resp.Header().Get("Set-Cookie")) + }) + middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) { sc.mockSQLStore.ExpectedOrg = &models.Org{Id: 1, Name: sc.cfg.AnonymousOrgName} sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName} @@ -655,7 +773,8 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func( sc.userService = usertest.NewUserServiceFake() sc.orgService = orgtest.NewOrgServiceFake() sc.apiKeyService = &apikeytest.Service{} - ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService) + sc.oauthTokenService = &auth.FakeOAuthTokenService{} + ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService, sc.oauthTokenService) sc.sqlStore = ctxHdlr.SQLStore sc.contextHandler = ctxHdlr sc.m.Use(ctxHdlr.Middleware) @@ -691,6 +810,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func( func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock, loginService *loginservice.LoginServiceMock, apiKeyService *apikeytest.Service, userService *usertest.FakeUserService, orgService *orgtest.FakeOrgService, + oauthTokenService *auth.FakeOAuthTokenService, ) *contexthandler.ContextHandler { t.Helper() @@ -708,7 +828,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.S tracer := tracing.InitializeTracerForTest() authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, userService, mockSQLStore) authenticator := &logintest.AuthenticatorFake{ExpectedUser: &user.User{}} - return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService) + return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService, oauthTokenService) } type fakeRenderService struct { diff --git a/pkg/middleware/recovery_test.go b/pkg/middleware/recovery_test.go index 4f226eef062..866eeaa490f 100644 --- a/pkg/middleware/recovery_test.go +++ b/pkg/middleware/recovery_test.go @@ -68,7 +68,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) { sc.userAuthTokenService = auth.NewFakeUserAuthTokenService() sc.remoteCacheService = remotecache.NewFakeStore(t) - contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil) + contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil, nil) sc.m.Use(contextHandler.Middleware) // mock out gc goroutine sc.m.Use(OrgRedirect(cfg, sc.userService)) diff --git a/pkg/middleware/testing.go b/pkg/middleware/testing.go index 02949e9b4ab..67b2916340a 100644 --- a/pkg/middleware/testing.go +++ b/pkg/middleware/testing.go @@ -44,6 +44,7 @@ type scenarioContext struct { loginService *loginservice.LoginServiceMock apiKeyService *apikeytest.Service userService *usertest.FakeUserService + oauthTokenService *auth.FakeOAuthTokenService orgService *orgtest.FakeOrgService req *http.Request diff --git a/pkg/services/auth/testing.go b/pkg/services/auth/testing.go index 4993c843432..63b08a9a639 100644 --- a/pkg/services/auth/testing.go +++ b/pkg/services/auth/testing.go @@ -3,9 +3,12 @@ package auth import ( "context" "net" + "time" "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/user" + "golang.org/x/oauth2" ) type FakeUserAuthTokenService struct { @@ -105,3 +108,46 @@ func (s *FakeUserAuthTokenService) GetUserRevokedTokens(ctx context.Context, use func (s *FakeUserAuthTokenService) BatchRevokeAllUserTokens(ctx context.Context, userIds []int64) error { return s.BatchRevokedTokenProvider(ctx, userIds) } + +type FakeOAuthTokenService struct { + passThruEnabled bool + ExpectedAuthUser *models.UserAuth + ExpectedErrors map[string]error +} + +func (ts *FakeOAuthTokenService) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token { + return &oauth2.Token{ + AccessToken: ts.ExpectedAuthUser.OAuthAccessToken, + RefreshToken: ts.ExpectedAuthUser.OAuthRefreshToken, + Expiry: ts.ExpectedAuthUser.OAuthExpiry, + TokenType: ts.ExpectedAuthUser.OAuthTokenType, + } +} + +func (ts *FakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) bool { + return ts.passThruEnabled +} + +func (ts *FakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) { + if ts.ExpectedAuthUser != nil { + return ts.ExpectedAuthUser, true, nil + } + if error, ok := ts.ExpectedErrors["HasOAuthEntry"]; ok { + return nil, false, error + } + return nil, false, nil +} + +func (ts *FakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error { + ts.ExpectedAuthUser.OAuthAccessToken = "" + ts.ExpectedAuthUser.OAuthRefreshToken = "" + ts.ExpectedAuthUser.OAuthExpiry = time.Time{} + return nil +} + +func (ts *FakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error { + if err, ok := ts.ExpectedErrors["TryTokenRefresh"]; ok { + return err + } + return nil +} diff --git a/pkg/services/contexthandler/auth_proxy_test.go b/pkg/services/contexthandler/auth_proxy_test.go index 8c904689ed6..90d5832ae56 100644 --- a/pkg/services/contexthandler/auth_proxy_test.go +++ b/pkg/services/contexthandler/auth_proxy_test.go @@ -104,7 +104,7 @@ func getContextHandler(t *testing.T) *ContextHandler { return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, - &userService, orgService) + &userService, orgService, nil) } type FakeGetSignUserStore struct { diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index 2d42e2e3d54..efec9ddc1ff 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -4,6 +4,7 @@ package contexthandler import ( "context" "errors" + "fmt" "net/http" "net/url" "strconv" @@ -23,6 +24,7 @@ import ( "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" "github.com/grafana/grafana/pkg/services/contexthandler/ctxkey" "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/oauthtoken" "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/rendering" "github.com/grafana/grafana/pkg/services/sqlstore" @@ -44,39 +46,42 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtS remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore sqlstore.Store, tracer tracing.Tracer, authProxy *authproxy.AuthProxy, loginService login.Service, apiKeyService apikey.Service, authenticator loginpkg.Authenticator, userService user.Service, - orgService org.Service) *ContextHandler { + orgService org.Service, oauthTokenService oauthtoken.OAuthTokenService, +) *ContextHandler { return &ContextHandler{ - Cfg: cfg, - AuthTokenService: tokenService, - JWTAuthService: jwtService, - RemoteCache: remoteCache, - RenderService: renderService, - SQLStore: sqlStore, - tracer: tracer, - authProxy: authProxy, - authenticator: authenticator, - loginService: loginService, - apiKeyService: apiKeyService, - userService: userService, - orgService: orgService, + Cfg: cfg, + AuthTokenService: tokenService, + JWTAuthService: jwtService, + RemoteCache: remoteCache, + RenderService: renderService, + SQLStore: sqlStore, + tracer: tracer, + authProxy: authProxy, + authenticator: authenticator, + loginService: loginService, + apiKeyService: apiKeyService, + userService: userService, + orgService: orgService, + oauthTokenService: oauthTokenService, } } // ContextHandler is a middleware. type ContextHandler struct { - Cfg *setting.Cfg - AuthTokenService models.UserTokenService - JWTAuthService models.JWTService - RemoteCache *remotecache.RemoteCache - RenderService rendering.Service - SQLStore sqlstore.Store - tracer tracing.Tracer - authProxy *authproxy.AuthProxy - authenticator loginpkg.Authenticator - loginService login.Service - apiKeyService apikey.Service - userService user.Service - orgService org.Service + Cfg *setting.Cfg + AuthTokenService models.UserTokenService + JWTAuthService models.JWTService + RemoteCache *remotecache.RemoteCache + RenderService rendering.Service + SQLStore sqlstore.Store + tracer tracing.Tracer + authProxy *authproxy.AuthProxy + authenticator loginpkg.Authenticator + loginService login.Service + apiKeyService apikey.Service + userService user.Service + orgService org.Service + oauthTokenService oauthtoken.OAuthTokenService // GetTime returns the current time. // Stubbable by tests. GetTime func() time.Time @@ -428,6 +433,38 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org return false } + getTime := h.GetTime + if getTime == nil { + getTime = time.Now + } + + // Check whether the logged in User has a token (whether the User used an OAuth provider to login) + oauthToken, exists, _ := h.oauthTokenService.HasOAuthEntry(ctx, queryResult) + if exists { + // Skip where the OAuthExpiry is default/zero/unset + if !oauthToken.OAuthExpiry.IsZero() && oauthToken.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime()) { + reqContext.Logger.Info("access token expired", "userId", query.UserID, "expiry", fmt.Sprintf("%v", oauthToken.OAuthExpiry)) + + // If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and Invalidate the OAuth tokens + if err = h.oauthTokenService.TryTokenRefresh(ctx, oauthToken); err != nil { + if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) { + reqContext.Logger.Error("could not fetch a new access token", "userId", oauthToken.UserId, "error", err) + } + + reqContext.Resp.Before(h.deleteInvalidCookieEndOfRequestFunc(reqContext)) + if err = h.oauthTokenService.InvalidateOAuthTokens(ctx, oauthToken); err != nil { + reqContext.Logger.Error("could not invalidate OAuth tokens", "userId", oauthToken.UserId, "error", err) + } + + err = h.AuthTokenService.RevokeToken(ctx, token, false) + if err != nil && !errors.Is(err, models.ErrUserTokenNotFound) { + reqContext.Logger.Error("failed to revoke auth token", "error", err) + } + return false + } + } + } + reqContext.SignedInUser = queryResult reqContext.IsSignedIn = true reqContext.UserToken = token diff --git a/pkg/services/login/authinfoservice/database/database.go b/pkg/services/login/authinfoservice/database/database.go index 7ceb1143211..04f88e12030 100644 --- a/pkg/services/login/authinfoservice/database/database.go +++ b/pkg/services/login/authinfoservice/database/database.go @@ -204,13 +204,8 @@ func (s *AuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAu authUser.OAuthExpiry = cmd.OAuthToken.Expiry } - cond := &models.UserAuth{ - UserId: cmd.UserId, - AuthModule: cmd.AuthModule, - } - return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error { - upd, err := sess.Update(authUser, cond) + upd, err := sess.MustCols("o_auth_expiry").Where("user_id = ? AND auth_module = ?", cmd.UserId, cmd.AuthModule).Update(authUser) s.logger.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_module", cmd.AuthModule, "rows", upd) return err }) diff --git a/pkg/services/oauthtoken/oauth_token.go b/pkg/services/oauthtoken/oauth_token.go index d258d0ff259..250a8f87ead 100644 --- a/pkg/services/oauthtoken/oauth_token.go +++ b/pkg/services/oauthtoken/oauth_token.go @@ -3,8 +3,12 @@ package oauthtoken import ( "context" "errors" + "fmt" + "strings" + "time" "golang.org/x/oauth2" + "golang.org/x/sync/singleflight" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/login/social" @@ -12,26 +16,39 @@ import ( "github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/setting" ) var ( logger = log.New("oauthtoken") + // ExpiryDelta is used to prevent any issue that is caused by the clock skew (server times can differ slightly between different machines). + // Shouldn't be more than 30s + ExpiryDelta = 10 * time.Second + ErrNoRefreshTokenFound = errors.New("no refresh token found") + ErrNotAnOAuthProvider = errors.New("not an oauth provider") ) type Service struct { - SocialService social.Service - AuthInfoService login.AuthInfoService + Cfg *setting.Cfg + SocialService social.Service + AuthInfoService login.AuthInfoService + singleFlightGroup *singleflight.Group } type OAuthTokenService interface { GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token IsOAuthPassThruEnabled(*datasources.DataSource) bool + HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) + TryTokenRefresh(context.Context, *models.UserAuth) error + InvalidateOAuthTokens(context.Context, *models.UserAuth) error } -func ProvideService(socialService social.Service, authInfoService login.AuthInfoService) *Service { +func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg) *Service { return &Service{ - SocialService: socialService, - AuthInfoService: authInfoService, + Cfg: cfg, + SocialService: socialService, + AuthInfoService: authInfoService, + singleFlightGroup: new(singleflight.Group), } } @@ -46,59 +63,17 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs if err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery); err != nil { if errors.Is(err, user.ErrUserNotFound) { // Not necessarily an error. User may be logged in another way. - logger.Debug("no OAuth token for user found", "userId", usr.UserID, "username", usr.Login) + logger.Debug("no oauth token for user found", "userId", usr.UserID, "username", usr.Login) } else { - logger.Error("failed to get OAuth token for user", "userId", usr.UserID, "username", usr.Login, "error", err) + logger.Error("failed to get oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err) } return nil } - authProvider := authInfoQuery.Result.AuthModule - connect, err := o.SocialService.GetConnector(authProvider) + token, err := o.tryGetOrRefreshAccessToken(ctx, authInfoQuery.Result) if err != nil { - logger.Error("failed to get OAuth connector", "provider", authProvider, "error", err) return nil } - - client, err := o.SocialService.GetOAuthHttpClient(authProvider) - if err != nil { - logger.Error("failed to get OAuth http client", "provider", authProvider, "error", err) - return nil - } - ctx = context.WithValue(ctx, oauth2.HTTPClient, client) - - persistedToken := &oauth2.Token{ - AccessToken: authInfoQuery.Result.OAuthAccessToken, - Expiry: authInfoQuery.Result.OAuthExpiry, - RefreshToken: authInfoQuery.Result.OAuthRefreshToken, - TokenType: authInfoQuery.Result.OAuthTokenType, - } - - if authInfoQuery.Result.OAuthIdToken != "" { - persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": authInfoQuery.Result.OAuthIdToken}) - } - - // TokenSource handles refreshing the token if it has expired - token, err := connect.TokenSource(ctx, persistedToken).Token() - if err != nil { - logger.Error("failed to retrieve OAuth access token", "provider", authInfoQuery.Result.AuthModule, "userId", usr.UserID, "username", usr.Login, "error", err) - return nil - } - - // If the tokens are not the same, update the entry in the DB - if !tokensEq(persistedToken, token) { - updateAuthCommand := &models.UpdateAuthInfoCommand{ - UserId: authInfoQuery.Result.UserId, - AuthModule: authInfoQuery.Result.AuthModule, - AuthId: authInfoQuery.Result.AuthId, - OAuthToken: token, - } - if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil { - logger.Error("failed to update auth info during token refresh", "userId", usr.UserID, "username", usr.Login, "error", err) - return nil - } - logger.Debug("updated OAuth info for user", "userId", usr.UserID, "username", usr.Login) - } return token } @@ -107,6 +82,128 @@ func (o *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool { return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool() } +// HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User +func (o *Service) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error) { + if usr == nil { + // No user, therefore no token + return nil, false, nil + } + + authInfoQuery := &models.GetAuthInfoQuery{UserId: usr.UserID} + err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery) + if err != nil { + if errors.Is(err, user.ErrUserNotFound) { + // Not necessarily an error. User may be logged in another way. + return nil, false, nil + } + logger.Error("failed to fetch oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err) + return nil, false, err + } + if !strings.Contains(authInfoQuery.Result.AuthModule, "oauth") { + return nil, false, nil + } + return authInfoQuery.Result, true, nil +} + +// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful +// It uses a singleflight.Group to prevent getting the Refresh Token multiple times for a given User +func (o *Service) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error { + lockKey := fmt.Sprintf("oauth-refresh-token-%d", usr.UserId) + _, err, _ := o.singleFlightGroup.Do(lockKey, func() (interface{}, error) { + logger.Debug("singleflight request for getting a new access token", "key", lockKey) + authProvider := usr.AuthModule + + if !strings.Contains(authProvider, "oauth") { + logger.Error("the specified user's auth provider is not oauth", "authmodule", usr.AuthModule, "userid", usr.UserId) + return nil, ErrNotAnOAuthProvider + } + + if usr.OAuthRefreshToken == "" { + logger.Debug("no refresh token available", "authmodule", usr.AuthModule, "userid", usr.UserId) + return nil, ErrNoRefreshTokenFound + } + + return o.tryGetOrRefreshAccessToken(ctx, usr) + }) + return err +} + +// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero +func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error { + return o.AuthInfoService.UpdateAuthInfo(ctx, &models.UpdateAuthInfoCommand{ + UserId: usr.UserId, + AuthModule: usr.AuthModule, + AuthId: usr.AuthId, + OAuthToken: &oauth2.Token{ + AccessToken: "", + RefreshToken: "", + Expiry: time.Time{}, + }, + }) +} + +func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *models.UserAuth) (*oauth2.Token, error) { + authProvider := usr.AuthModule + connect, err := o.SocialService.GetConnector(authProvider) + if err != nil { + logger.Error("failed to get oauth connector", "provider", authProvider, "error", err) + return nil, err + } + + client, err := o.SocialService.GetOAuthHttpClient(authProvider) + if err != nil { + logger.Error("failed to get oauth http client", "provider", authProvider, "error", err) + return nil, err + } + ctx = context.WithValue(ctx, oauth2.HTTPClient, client) + + persistedToken := &oauth2.Token{ + AccessToken: usr.OAuthAccessToken, + Expiry: usr.OAuthExpiry, + RefreshToken: usr.OAuthRefreshToken, + TokenType: usr.OAuthTokenType, + } + + if usr.OAuthIdToken != "" { + persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": usr.OAuthIdToken}) + } + + // TokenSource handles refreshing the token if it has expired + token, err := connect.TokenSource(ctx, persistedToken).Token() + if err != nil { + logger.Error("failed to retrieve oauth access token", "provider", usr.AuthModule, "userId", usr.UserId, "error", err) + return nil, err + } + + // If the tokens are not the same, update the entry in the DB + if !tokensEq(persistedToken, token) { + updateAuthCommand := &models.UpdateAuthInfoCommand{ + UserId: usr.UserId, + AuthModule: usr.AuthModule, + AuthId: usr.AuthId, + OAuthToken: token, + } + + if o.Cfg.Env == setting.Dev { + logger.Debug("oauth got token", + "user", usr.UserId, + "auth_module", usr.AuthModule, + "expiry", fmt.Sprintf("%v", token.Expiry), + "access_token", fmt.Sprintf("%v", token.AccessToken), + "refresh_token", fmt.Sprintf("%v", token.RefreshToken), + ) + } + + if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil { + logger.Error("failed to update auth info during token refresh", "userId", usr.UserId, "error", err) + return nil, err + } + logger.Debug("updated oauth info for user", "userId", usr.UserId) + } + + return token, nil +} + // tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in func tokensEq(t1, t2 *oauth2.Token) bool { return t1.AccessToken == t2.AccessToken && diff --git a/pkg/services/oauthtoken/oauth_token_test.go b/pkg/services/oauthtoken/oauth_token_test.go new file mode 100644 index 00000000000..ae761f0580f --- /dev/null +++ b/pkg/services/oauthtoken/oauth_token_test.go @@ -0,0 +1,374 @@ +package oauthtoken + +import ( + "context" + "errors" + "net/http" + "reflect" + "testing" + "time" + + "github.com/grafana/grafana/pkg/infra/usagestats" + "github.com/grafana/grafana/pkg/login/social" + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/login/authinfoservice" + "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/setting" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/oauth2" + "golang.org/x/sync/singleflight" +) + +func TestService_HasOAuthEntry(t *testing.T) { + testCases := []struct { + name string + user *user.SignedInUser + want *models.UserAuth + wantExist bool + wantErr bool + err error + getAuthInfoErr error + getAuthInfoUser models.UserAuth + }{ + { + name: "returns false without an error in case user is nil", + user: nil, + want: nil, + wantExist: false, + wantErr: false, + }, + { + name: "returns false and an error in case GetAuthInfo returns an error", + user: &user.SignedInUser{}, + want: nil, + wantExist: false, + wantErr: true, + getAuthInfoErr: errors.New("error"), + }, + { + name: "returns false without an error in case auth entry is not found", + user: &user.SignedInUser{}, + want: nil, + wantExist: false, + wantErr: false, + getAuthInfoErr: user.ErrUserNotFound, + }, + { + name: "returns false without an error in case the auth entry is not oauth", + user: &user.SignedInUser{}, + want: nil, + wantExist: false, + wantErr: false, + getAuthInfoUser: models.UserAuth{AuthModule: "auth_saml"}, + }, + { + name: "returns true when the auth entry is found", + user: &user.SignedInUser{}, + want: &models.UserAuth{AuthModule: "oauth_generic_oauth"}, + wantExist: true, + wantErr: false, + getAuthInfoUser: models.UserAuth{AuthModule: "oauth_generic_oauth"}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + srv, authInfoStore, _ := setupOAuthTokenService(t) + authInfoStore.ExpectedOAuth = &tc.getAuthInfoUser + authInfoStore.ExpectedError = tc.getAuthInfoErr + + entry, exists, err := srv.HasOAuthEntry(context.Background(), tc.user) + + if tc.wantErr { + assert.Error(t, err) + } + + if tc.want != nil { + assert.True(t, reflect.DeepEqual(tc.want, entry)) + } + assert.Equal(t, tc.wantExist, exists) + }) + } +} + +func TestService_TryTokenRefresh_ValidToken(t *testing.T) { + srv, authInfoStore, socialConnector := setupOAuthTokenService(t) + ctx := context.Background() + token := &oauth2.Token{ + AccessToken: "testaccess", + RefreshToken: "testrefresh", + Expiry: time.Now(), + TokenType: "Bearer", + } + usr := &models.UserAuth{ + AuthModule: "oauth_generic_oauth", + OAuthAccessToken: token.AccessToken, + OAuthRefreshToken: token.RefreshToken, + OAuthExpiry: token.Expiry, + OAuthTokenType: token.TokenType, + } + + authInfoStore.ExpectedOAuth = usr + + socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)) + + err := srv.TryTokenRefresh(ctx, usr) + assert.Nil(t, err) + socialConnector.AssertNumberOfCalls(t, "TokenSource", 1) + + authInfoQuery := &models.GetAuthInfoQuery{} + err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery) + + assert.Nil(t, err) + + // User's token data had not been updated + resultUsr := authInfoQuery.Result + assert.Equal(t, resultUsr.OAuthAccessToken, token.AccessToken) + assert.Equal(t, resultUsr.OAuthExpiry, token.Expiry) + assert.Equal(t, resultUsr.OAuthRefreshToken, token.RefreshToken) + assert.Equal(t, resultUsr.OAuthTokenType, token.TokenType) +} + +func TestService_TryTokenRefresh_NoRefreshToken(t *testing.T) { + srv, _, socialConnector := setupOAuthTokenService(t) + ctx := context.Background() + token := &oauth2.Token{ + AccessToken: "testaccess", + RefreshToken: "", + Expiry: time.Now().Add(-time.Hour), + TokenType: "Bearer", + } + usr := &models.UserAuth{ + AuthModule: "oauth_generic_oauth", + OAuthAccessToken: token.AccessToken, + OAuthRefreshToken: token.RefreshToken, + OAuthExpiry: token.Expiry, + OAuthTokenType: token.TokenType, + } + + socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)) + + err := srv.TryTokenRefresh(ctx, usr) + + assert.NotNil(t, err) + assert.ErrorIs(t, err, ErrNoRefreshTokenFound) + + socialConnector.AssertNotCalled(t, "TokenSource") +} + +func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) { + srv, authInfoStore, socialConnector := setupOAuthTokenService(t) + ctx := context.Background() + token := &oauth2.Token{ + AccessToken: "testaccess", + RefreshToken: "testrefresh", + Expiry: time.Now().Add(-time.Hour), + TokenType: "Bearer", + } + + newToken := &oauth2.Token{ + AccessToken: "testaccess_new", + RefreshToken: "testrefresh_new", + Expiry: time.Now().Add(time.Hour), + TokenType: "Bearer", + } + + usr := &models.UserAuth{ + AuthModule: "oauth_generic_oauth", + OAuthAccessToken: token.AccessToken, + OAuthRefreshToken: token.RefreshToken, + OAuthExpiry: token.Expiry, + OAuthTokenType: token.TokenType, + } + + authInfoStore.ExpectedOAuth = usr + + socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.ReuseTokenSource(token, oauth2.StaticTokenSource(newToken)), nil) + + err := srv.TryTokenRefresh(ctx, usr) + + assert.Nil(t, err) + socialConnector.AssertNumberOfCalls(t, "TokenSource", 1) + + authInfoQuery := &models.GetAuthInfoQuery{} + err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery) + + assert.Nil(t, err) + + // newToken should be returned after the .Token() call, therefore the User had to be updated + assert.Equal(t, authInfoQuery.Result.OAuthAccessToken, newToken.AccessToken) + assert.Equal(t, authInfoQuery.Result.OAuthExpiry, newToken.Expiry) + assert.Equal(t, authInfoQuery.Result.OAuthRefreshToken, newToken.RefreshToken) + assert.Equal(t, authInfoQuery.Result.OAuthTokenType, newToken.TokenType) +} + +func TestService_TryTokenRefresh_DifferentAuthModuleForUser(t *testing.T) { + srv, _, socialConnector := setupOAuthTokenService(t) + ctx := context.Background() + token := &oauth2.Token{} + usr := &models.UserAuth{ + AuthModule: "auth.saml", + } + + socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token)) + + err := srv.TryTokenRefresh(ctx, usr) + + assert.NotNil(t, err) + assert.ErrorIs(t, err, ErrNotAnOAuthProvider) + + socialConnector.AssertNotCalled(t, "TokenSource") +} + +func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *MockSocialConnector) { + t.Helper() + + socialConnector := &MockSocialConnector{} + socialService := &FakeSocialService{ + connector: socialConnector, + } + + authInfoStore := &FakeAuthInfoStore{} + authInfoService := authinfoservice.ProvideAuthInfoService(nil, authInfoStore, &usagestats.UsageStatsMock{}) + return &Service{ + Cfg: setting.NewCfg(), + SocialService: socialService, + AuthInfoService: authInfoService, + singleFlightGroup: &singleflight.Group{}, + }, authInfoStore, socialConnector +} + +type FakeSocialService struct { + httpClient *http.Client + connector *MockSocialConnector +} + +func (fss *FakeSocialService) GetOAuthProviders() map[string]bool { + panic("not implemented") +} + +func (fss *FakeSocialService) GetOAuthHttpClient(string) (*http.Client, error) { + return fss.httpClient, nil +} + +func (fss *FakeSocialService) GetConnector(string) (social.SocialConnector, error) { + return fss.connector, nil +} + +func (fss *FakeSocialService) GetOAuthInfoProvider(string) *social.OAuthInfo { + panic("not implemented") +} + +func (fss *FakeSocialService) GetOAuthInfoProviders() map[string]*social.OAuthInfo { + panic("not implemented") +} + +type MockSocialConnector struct { + mock.Mock +} + +func (m *MockSocialConnector) Type() int { + args := m.Called() + return args.Int(0) +} + +func (m *MockSocialConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { + args := m.Called(client, token) + return args.Get(0).(*social.BasicUserInfo), args.Error(1) +} + +func (m *MockSocialConnector) IsEmailAllowed(email string) bool { + panic("not implemented") +} + +func (m *MockSocialConnector) IsSignupAllowed() bool { + panic("not implemented") +} + +func (m *MockSocialConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + panic("not implemented") +} + +func (m *MockSocialConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + panic("not implemented") +} + +func (m *MockSocialConnector) Client(ctx context.Context, t *oauth2.Token) *http.Client { + panic("not implemented") +} + +func (m *MockSocialConnector) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource { + args := m.Called(ctx, t) + return args.Get(0).(oauth2.TokenSource) +} + +type FakeAuthInfoStore struct { + ExpectedError error + ExpectedUser *user.User + ExpectedOAuth *models.UserAuth + ExpectedDuplicateUserEntries int + ExpectedHasDuplicateUserEntries int + ExpectedLoginStats login.LoginStats +} + +func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *models.GetExternalUserInfoByLoginQuery) error { + return f.ExpectedError +} + +func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *models.GetAuthInfoQuery) error { + query.Result = f.ExpectedOAuth + return f.ExpectedError +} + +func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error { + return f.ExpectedError +} + +func (f *FakeAuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *models.UserAuth) error { + return f.ExpectedError +} + +func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error { + f.ExpectedOAuth.OAuthAccessToken = cmd.OAuthToken.AccessToken + f.ExpectedOAuth.OAuthExpiry = cmd.OAuthToken.Expiry + f.ExpectedOAuth.OAuthTokenType = cmd.OAuthToken.TokenType + f.ExpectedOAuth.OAuthRefreshToken = cmd.OAuthToken.RefreshToken + return f.ExpectedError +} + +func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAuthInfoCommand) error { + return f.ExpectedError +} + +func (f *FakeAuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) { + return f.ExpectedUser, f.ExpectedError +} + +func (f *FakeAuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) { + return f.ExpectedUser, f.ExpectedError +} + +func (f *FakeAuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) { + return f.ExpectedUser, f.ExpectedError +} + +func (f *FakeAuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]interface{}, error) { + var res = make(map[string]interface{}) + res["stats.users.duplicate_user_entries"] = f.ExpectedDuplicateUserEntries + res["stats.users.has_duplicate_user_entries"] = f.ExpectedHasDuplicateUserEntries + res["stats.users.duplicate_user_entries_by_login"] = 0 + res["stats.users.has_duplicate_user_entries_by_login"] = 0 + res["stats.users.duplicate_user_entries_by_email"] = 0 + res["stats.users.has_duplicate_user_entries_by_email"] = 0 + res["stats.users.mixed_cased_users"] = f.ExpectedLoginStats.MixedCasedUsers + return res, f.ExpectedError +} + +func (f *FakeAuthInfoStore) RunMetricsCollection(ctx context.Context) error { + return f.ExpectedError +} + +func (f *FakeAuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, error) { + return f.ExpectedLoginStats, f.ExpectedError +} diff --git a/pkg/services/publicdashboards/api/common_test.go b/pkg/services/publicdashboards/api/common_test.go index 0e64857dbe3..f49c3f983b5 100644 --- a/pkg/services/publicdashboards/api/common_test.go +++ b/pkg/services/publicdashboards/api/common_test.go @@ -164,6 +164,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) return ts.passThruEnabled } +func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) { + return nil, false, nil +} + +func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error { + return nil +} + +func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error { + return nil +} + // copied from pkg/api/plugins_test.go type fakePluginClient struct { plugins.Client diff --git a/pkg/services/query/query_test.go b/pkg/services/query/query_test.go index 697bfe64848..e0c82ba4849 100644 --- a/pkg/services/query/query_test.go +++ b/pkg/services/query/query_test.go @@ -8,6 +8,7 @@ import ( "github.com/grafana/grafana-plugin-sdk-go/backend" "github.com/grafana/grafana/pkg/expr" + "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/setting" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -440,6 +441,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) return ts.passThruEnabled } +func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) { + return nil, false, nil +} + +func (ts *fakeOAuthTokenService) TryTokenRefresh(context.Context, *models.UserAuth) error { + return nil +} + +func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(context.Context, *models.UserAuth) error { + return nil +} + type fakeDataSourceCache struct { ds *datasources.DataSource }