mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Refresh OAuth access_token automatically using the refresh_token (#56076)
* Verify OAuth token expiration for oauth users in the ctx handler middleware * Use refresh token to get a new access token * Refactor oauth_token.go * Add tests for the middleware changes * Align other tests * Add tests, wip * Add more tests * Add InvalidateOAuthTokens method * Fix ExpiryDate update to default * Invalidate OAuth tokens during logout * Improve logout * Add more comments * Cleanup * Fix import order * Add error to HasOAuthEntry return values * add dev debug logs * Fix tests Co-authored-by: jguer <joao.guerreiro@grafana.com>
This commit is contained in:
@@ -213,7 +213,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
|
||||
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, &usertest.FakeUserService{}, sqlStore)
|
||||
loginService := &logintest.LoginServiceFake{}
|
||||
authenticator := &logintest.AuthenticatorFake{}
|
||||
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake())
|
||||
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake(), nil)
|
||||
|
||||
return ctxHdlr
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/middleware/csrf"
|
||||
"github.com/grafana/grafana/pkg/services/folder"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/querylibrary"
|
||||
"github.com/grafana/grafana/pkg/services/searchV2"
|
||||
"github.com/grafana/grafana/pkg/services/store/object"
|
||||
@@ -206,6 +207,7 @@ type HTTPServer struct {
|
||||
annotationsRepo annotations.Repository
|
||||
tagService tag.Service
|
||||
userAuthService userauth.Service
|
||||
oauthTokenService oauthtoken.OAuthTokenService
|
||||
}
|
||||
|
||||
type ServerOptions struct {
|
||||
@@ -248,6 +250,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi
|
||||
accesscontrolService accesscontrol.Service, dashboardThumbsService thumbs.DashboardThumbService, navTreeService navtree.Service,
|
||||
annotationRepo annotations.Repository, tagService tag.Service, searchv2HTTPService searchV2.SearchHTTPService,
|
||||
userAuthService userauth.Service, queryLibraryHTTPService querylibrary.HTTPService, queryLibraryService querylibrary.Service,
|
||||
oauthTokenService oauthtoken.OAuthTokenService,
|
||||
) (*HTTPServer, error) {
|
||||
web.Env = cfg.Env
|
||||
m := web.New()
|
||||
@@ -352,6 +355,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi
|
||||
userAuthService: userAuthService,
|
||||
QueryLibraryHTTPService: queryLibraryHTTPService,
|
||||
QueryLibraryService: queryLibraryService,
|
||||
oauthTokenService: oauthTokenService,
|
||||
}
|
||||
if hs.Listener != nil {
|
||||
hs.log.Debug("Using provided listener")
|
||||
|
||||
@@ -304,6 +304,13 @@ func (hs *HTTPServer) Logout(c *models.ReqContext) {
|
||||
}
|
||||
}
|
||||
|
||||
// Invalidate the OAuth tokens in case the User logged in with OAuth or the last external AuthEntry is an OAuth one
|
||||
if entry, exists, _ := hs.oauthTokenService.HasOAuthEntry(c.Req.Context(), c.SignedInUser); exists {
|
||||
if err := hs.oauthTokenService.InvalidateOAuthTokens(c.Req.Context(), entry); err != nil {
|
||||
hs.log.Warn("failed to invalidate oauth tokens for user", "userId", c.UserID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
err := hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken, false)
|
||||
if err != nil && !errors.Is(err, models.ErrUserTokenNotFound) {
|
||||
hs.log.Error("failed to revoke auth token", "error", err)
|
||||
|
||||
@@ -194,7 +194,15 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
|
||||
// token.TokenType was defaulting to "bearer", which is out of spec, so we explicitly set to "Bearer"
|
||||
token.TokenType = "Bearer"
|
||||
|
||||
oauthLogger.Debug("OAuthLogin: got token", "expiry", fmt.Sprintf("%v", token.Expiry))
|
||||
if hs.Cfg.Env != setting.Dev {
|
||||
oauthLogger.Debug("OAuthLogin: got token", "expiry", fmt.Sprintf("%v", token.Expiry))
|
||||
} else {
|
||||
oauthLogger.Debug("OAuthLogin: got token",
|
||||
"expiry", fmt.Sprintf("%v", token.Expiry),
|
||||
"access_token", fmt.Sprintf("%v", token.AccessToken),
|
||||
"refresh_token", fmt.Sprintf("%v", token.RefreshToken),
|
||||
)
|
||||
}
|
||||
|
||||
// set up oauth2 client
|
||||
client := connect.Client(oauthCtx, token)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
"github.com/grafana/grafana/pkg/plugins/backendplugin"
|
||||
pluginClient "github.com/grafana/grafana/pkg/plugins/manager/client"
|
||||
@@ -56,6 +57,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource)
|
||||
return ts.passThruEnabled
|
||||
}
|
||||
|
||||
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// `/ds/query` endpoint test
|
||||
func TestAPIEndpoint_Metrics_QueryMetricsV2(t *testing.T) {
|
||||
qds := query.ProvideService(
|
||||
|
||||
@@ -1065,3 +1065,15 @@ func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user *
|
||||
func (m *mockOAuthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
|
||||
return m.oAuthEnabled
|
||||
}
|
||||
|
||||
func (m *mockOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (m *mockOAuthTokenService) TryTokenRefresh(context.Context, *models.UserAuth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockOAuthTokenService) InvalidateOAuthTokens(context.Context, *models.UserAuth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -339,6 +340,123 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
assert.Nil(t, sc.context.UserToken)
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and non-expired OAuth access token", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(11 * time.Second)}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||
return &models.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.NotNil(t, sc.context.UserToken)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token fails", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
signedInUser := &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.userService.ExpectedSignedInUser = signedInUser
|
||||
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{
|
||||
UserId: userID,
|
||||
OAuthExpiry: fakeGetTime()().Add(-1 * time.Second),
|
||||
OAuthAccessToken: "access_token",
|
||||
OAuthRefreshToken: "refresh_token"}
|
||||
sc.oauthTokenService.ExpectedErrors = map[string]error{"TryTokenRefresh": errors.New("error")}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||
return &models.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
token := sc.oauthTokenService.GetCurrentOAuthToken(sc.context.Req.Context(), signedInUser)
|
||||
assert.Equal(t, token.AccessToken, "")
|
||||
assert.Equal(t, token.RefreshToken, "")
|
||||
assert.True(t, token.Expiry.IsZero())
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.Nil(t, sc.context.UserToken)
|
||||
assert.False(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, int64(0), sc.context.UserID)
|
||||
assert.Equal(t, "grafana_session=; Path=/; Max-Age=0; HttpOnly", sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token succeeds", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(-5 * time.Second), OAuthRefreshToken: "refreshtoken"}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||
return &models.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.NotNil(t, sc.context.UserToken)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and OAuth Access Token's Expiry is not set", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||
return &models.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.NotNil(t, sc.context.UserToken)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.mockSQLStore.ExpectedOrg = &models.Org{Id: 1, Name: sc.cfg.AnonymousOrgName}
|
||||
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
|
||||
@@ -655,7 +773,8 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
||||
sc.userService = usertest.NewUserServiceFake()
|
||||
sc.orgService = orgtest.NewOrgServiceFake()
|
||||
sc.apiKeyService = &apikeytest.Service{}
|
||||
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService)
|
||||
sc.oauthTokenService = &auth.FakeOAuthTokenService{}
|
||||
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService, sc.oauthTokenService)
|
||||
sc.sqlStore = ctxHdlr.SQLStore
|
||||
sc.contextHandler = ctxHdlr
|
||||
sc.m.Use(ctxHdlr.Middleware)
|
||||
@@ -691,6 +810,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
||||
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock,
|
||||
loginService *loginservice.LoginServiceMock, apiKeyService *apikeytest.Service,
|
||||
userService *usertest.FakeUserService, orgService *orgtest.FakeOrgService,
|
||||
oauthTokenService *auth.FakeOAuthTokenService,
|
||||
) *contexthandler.ContextHandler {
|
||||
t.Helper()
|
||||
|
||||
@@ -708,7 +828,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.S
|
||||
tracer := tracing.InitializeTracerForTest()
|
||||
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, userService, mockSQLStore)
|
||||
authenticator := &logintest.AuthenticatorFake{ExpectedUser: &user.User{}}
|
||||
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService)
|
||||
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService, oauthTokenService)
|
||||
}
|
||||
|
||||
type fakeRenderService struct {
|
||||
|
||||
@@ -68,7 +68,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
|
||||
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
|
||||
sc.remoteCacheService = remotecache.NewFakeStore(t)
|
||||
|
||||
contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil)
|
||||
contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil, nil)
|
||||
sc.m.Use(contextHandler.Middleware)
|
||||
// mock out gc goroutine
|
||||
sc.m.Use(OrgRedirect(cfg, sc.userService))
|
||||
|
||||
@@ -44,6 +44,7 @@ type scenarioContext struct {
|
||||
loginService *loginservice.LoginServiceMock
|
||||
apiKeyService *apikeytest.Service
|
||||
userService *usertest.FakeUserService
|
||||
oauthTokenService *auth.FakeOAuthTokenService
|
||||
orgService *orgtest.FakeOrgService
|
||||
|
||||
req *http.Request
|
||||
|
||||
@@ -3,9 +3,12 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type FakeUserAuthTokenService struct {
|
||||
@@ -105,3 +108,46 @@ func (s *FakeUserAuthTokenService) GetUserRevokedTokens(ctx context.Context, use
|
||||
func (s *FakeUserAuthTokenService) BatchRevokeAllUserTokens(ctx context.Context, userIds []int64) error {
|
||||
return s.BatchRevokedTokenProvider(ctx, userIds)
|
||||
}
|
||||
|
||||
type FakeOAuthTokenService struct {
|
||||
passThruEnabled bool
|
||||
ExpectedAuthUser *models.UserAuth
|
||||
ExpectedErrors map[string]error
|
||||
}
|
||||
|
||||
func (ts *FakeOAuthTokenService) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token {
|
||||
return &oauth2.Token{
|
||||
AccessToken: ts.ExpectedAuthUser.OAuthAccessToken,
|
||||
RefreshToken: ts.ExpectedAuthUser.OAuthRefreshToken,
|
||||
Expiry: ts.ExpectedAuthUser.OAuthExpiry,
|
||||
TokenType: ts.ExpectedAuthUser.OAuthTokenType,
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *FakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) bool {
|
||||
return ts.passThruEnabled
|
||||
}
|
||||
|
||||
func (ts *FakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||
if ts.ExpectedAuthUser != nil {
|
||||
return ts.ExpectedAuthUser, true, nil
|
||||
}
|
||||
if error, ok := ts.ExpectedErrors["HasOAuthEntry"]; ok {
|
||||
return nil, false, error
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (ts *FakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
|
||||
ts.ExpectedAuthUser.OAuthAccessToken = ""
|
||||
ts.ExpectedAuthUser.OAuthRefreshToken = ""
|
||||
ts.ExpectedAuthUser.OAuthExpiry = time.Time{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *FakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
|
||||
if err, ok := ts.ExpectedErrors["TryTokenRefresh"]; ok {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func getContextHandler(t *testing.T) *ContextHandler {
|
||||
|
||||
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc,
|
||||
renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator,
|
||||
&userService, orgService)
|
||||
&userService, orgService, nil)
|
||||
}
|
||||
|
||||
type FakeGetSignUserStore struct {
|
||||
|
||||
@@ -4,6 +4,7 @@ package contexthandler
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler/ctxkey"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/rendering"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
@@ -44,39 +46,42 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtS
|
||||
remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore sqlstore.Store,
|
||||
tracer tracing.Tracer, authProxy *authproxy.AuthProxy, loginService login.Service,
|
||||
apiKeyService apikey.Service, authenticator loginpkg.Authenticator, userService user.Service,
|
||||
orgService org.Service) *ContextHandler {
|
||||
orgService org.Service, oauthTokenService oauthtoken.OAuthTokenService,
|
||||
) *ContextHandler {
|
||||
return &ContextHandler{
|
||||
Cfg: cfg,
|
||||
AuthTokenService: tokenService,
|
||||
JWTAuthService: jwtService,
|
||||
RemoteCache: remoteCache,
|
||||
RenderService: renderService,
|
||||
SQLStore: sqlStore,
|
||||
tracer: tracer,
|
||||
authProxy: authProxy,
|
||||
authenticator: authenticator,
|
||||
loginService: loginService,
|
||||
apiKeyService: apiKeyService,
|
||||
userService: userService,
|
||||
orgService: orgService,
|
||||
Cfg: cfg,
|
||||
AuthTokenService: tokenService,
|
||||
JWTAuthService: jwtService,
|
||||
RemoteCache: remoteCache,
|
||||
RenderService: renderService,
|
||||
SQLStore: sqlStore,
|
||||
tracer: tracer,
|
||||
authProxy: authProxy,
|
||||
authenticator: authenticator,
|
||||
loginService: loginService,
|
||||
apiKeyService: apiKeyService,
|
||||
userService: userService,
|
||||
orgService: orgService,
|
||||
oauthTokenService: oauthTokenService,
|
||||
}
|
||||
}
|
||||
|
||||
// ContextHandler is a middleware.
|
||||
type ContextHandler struct {
|
||||
Cfg *setting.Cfg
|
||||
AuthTokenService models.UserTokenService
|
||||
JWTAuthService models.JWTService
|
||||
RemoteCache *remotecache.RemoteCache
|
||||
RenderService rendering.Service
|
||||
SQLStore sqlstore.Store
|
||||
tracer tracing.Tracer
|
||||
authProxy *authproxy.AuthProxy
|
||||
authenticator loginpkg.Authenticator
|
||||
loginService login.Service
|
||||
apiKeyService apikey.Service
|
||||
userService user.Service
|
||||
orgService org.Service
|
||||
Cfg *setting.Cfg
|
||||
AuthTokenService models.UserTokenService
|
||||
JWTAuthService models.JWTService
|
||||
RemoteCache *remotecache.RemoteCache
|
||||
RenderService rendering.Service
|
||||
SQLStore sqlstore.Store
|
||||
tracer tracing.Tracer
|
||||
authProxy *authproxy.AuthProxy
|
||||
authenticator loginpkg.Authenticator
|
||||
loginService login.Service
|
||||
apiKeyService apikey.Service
|
||||
userService user.Service
|
||||
orgService org.Service
|
||||
oauthTokenService oauthtoken.OAuthTokenService
|
||||
// GetTime returns the current time.
|
||||
// Stubbable by tests.
|
||||
GetTime func() time.Time
|
||||
@@ -428,6 +433,38 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org
|
||||
return false
|
||||
}
|
||||
|
||||
getTime := h.GetTime
|
||||
if getTime == nil {
|
||||
getTime = time.Now
|
||||
}
|
||||
|
||||
// Check whether the logged in User has a token (whether the User used an OAuth provider to login)
|
||||
oauthToken, exists, _ := h.oauthTokenService.HasOAuthEntry(ctx, queryResult)
|
||||
if exists {
|
||||
// Skip where the OAuthExpiry is default/zero/unset
|
||||
if !oauthToken.OAuthExpiry.IsZero() && oauthToken.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime()) {
|
||||
reqContext.Logger.Info("access token expired", "userId", query.UserID, "expiry", fmt.Sprintf("%v", oauthToken.OAuthExpiry))
|
||||
|
||||
// If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and Invalidate the OAuth tokens
|
||||
if err = h.oauthTokenService.TryTokenRefresh(ctx, oauthToken); err != nil {
|
||||
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
|
||||
reqContext.Logger.Error("could not fetch a new access token", "userId", oauthToken.UserId, "error", err)
|
||||
}
|
||||
|
||||
reqContext.Resp.Before(h.deleteInvalidCookieEndOfRequestFunc(reqContext))
|
||||
if err = h.oauthTokenService.InvalidateOAuthTokens(ctx, oauthToken); err != nil {
|
||||
reqContext.Logger.Error("could not invalidate OAuth tokens", "userId", oauthToken.UserId, "error", err)
|
||||
}
|
||||
|
||||
err = h.AuthTokenService.RevokeToken(ctx, token, false)
|
||||
if err != nil && !errors.Is(err, models.ErrUserTokenNotFound) {
|
||||
reqContext.Logger.Error("failed to revoke auth token", "error", err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
reqContext.SignedInUser = queryResult
|
||||
reqContext.IsSignedIn = true
|
||||
reqContext.UserToken = token
|
||||
|
||||
@@ -204,13 +204,8 @@ func (s *AuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAu
|
||||
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
|
||||
}
|
||||
|
||||
cond := &models.UserAuth{
|
||||
UserId: cmd.UserId,
|
||||
AuthModule: cmd.AuthModule,
|
||||
}
|
||||
|
||||
return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error {
|
||||
upd, err := sess.Update(authUser, cond)
|
||||
upd, err := sess.MustCols("o_auth_expiry").Where("user_id = ? AND auth_module = ?", cmd.UserId, cmd.AuthModule).Update(authUser)
|
||||
s.logger.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_module", cmd.AuthModule, "rows", upd)
|
||||
return err
|
||||
})
|
||||
|
||||
@@ -3,8 +3,12 @@ package oauthtoken
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
@@ -12,26 +16,39 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/datasources"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = log.New("oauthtoken")
|
||||
// ExpiryDelta is used to prevent any issue that is caused by the clock skew (server times can differ slightly between different machines).
|
||||
// Shouldn't be more than 30s
|
||||
ExpiryDelta = 10 * time.Second
|
||||
ErrNoRefreshTokenFound = errors.New("no refresh token found")
|
||||
ErrNotAnOAuthProvider = errors.New("not an oauth provider")
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
SocialService social.Service
|
||||
AuthInfoService login.AuthInfoService
|
||||
Cfg *setting.Cfg
|
||||
SocialService social.Service
|
||||
AuthInfoService login.AuthInfoService
|
||||
singleFlightGroup *singleflight.Group
|
||||
}
|
||||
|
||||
type OAuthTokenService interface {
|
||||
GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token
|
||||
IsOAuthPassThruEnabled(*datasources.DataSource) bool
|
||||
HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error)
|
||||
TryTokenRefresh(context.Context, *models.UserAuth) error
|
||||
InvalidateOAuthTokens(context.Context, *models.UserAuth) error
|
||||
}
|
||||
|
||||
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService) *Service {
|
||||
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg) *Service {
|
||||
return &Service{
|
||||
SocialService: socialService,
|
||||
AuthInfoService: authInfoService,
|
||||
Cfg: cfg,
|
||||
SocialService: socialService,
|
||||
AuthInfoService: authInfoService,
|
||||
singleFlightGroup: new(singleflight.Group),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,59 +63,17 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs
|
||||
if err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery); err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
// Not necessarily an error. User may be logged in another way.
|
||||
logger.Debug("no OAuth token for user found", "userId", usr.UserID, "username", usr.Login)
|
||||
logger.Debug("no oauth token for user found", "userId", usr.UserID, "username", usr.Login)
|
||||
} else {
|
||||
logger.Error("failed to get OAuth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
|
||||
logger.Error("failed to get oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
authProvider := authInfoQuery.Result.AuthModule
|
||||
connect, err := o.SocialService.GetConnector(authProvider)
|
||||
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfoQuery.Result)
|
||||
if err != nil {
|
||||
logger.Error("failed to get OAuth connector", "provider", authProvider, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
client, err := o.SocialService.GetOAuthHttpClient(authProvider)
|
||||
if err != nil {
|
||||
logger.Error("failed to get OAuth http client", "provider", authProvider, "error", err)
|
||||
return nil
|
||||
}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||
|
||||
persistedToken := &oauth2.Token{
|
||||
AccessToken: authInfoQuery.Result.OAuthAccessToken,
|
||||
Expiry: authInfoQuery.Result.OAuthExpiry,
|
||||
RefreshToken: authInfoQuery.Result.OAuthRefreshToken,
|
||||
TokenType: authInfoQuery.Result.OAuthTokenType,
|
||||
}
|
||||
|
||||
if authInfoQuery.Result.OAuthIdToken != "" {
|
||||
persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": authInfoQuery.Result.OAuthIdToken})
|
||||
}
|
||||
|
||||
// TokenSource handles refreshing the token if it has expired
|
||||
token, err := connect.TokenSource(ctx, persistedToken).Token()
|
||||
if err != nil {
|
||||
logger.Error("failed to retrieve OAuth access token", "provider", authInfoQuery.Result.AuthModule, "userId", usr.UserID, "username", usr.Login, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the tokens are not the same, update the entry in the DB
|
||||
if !tokensEq(persistedToken, token) {
|
||||
updateAuthCommand := &models.UpdateAuthInfoCommand{
|
||||
UserId: authInfoQuery.Result.UserId,
|
||||
AuthModule: authInfoQuery.Result.AuthModule,
|
||||
AuthId: authInfoQuery.Result.AuthId,
|
||||
OAuthToken: token,
|
||||
}
|
||||
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
|
||||
logger.Error("failed to update auth info during token refresh", "userId", usr.UserID, "username", usr.Login, "error", err)
|
||||
return nil
|
||||
}
|
||||
logger.Debug("updated OAuth info for user", "userId", usr.UserID, "username", usr.Login)
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
@@ -107,6 +82,128 @@ func (o *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
|
||||
return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool()
|
||||
}
|
||||
|
||||
// HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User
|
||||
func (o *Service) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||
if usr == nil {
|
||||
// No user, therefore no token
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
authInfoQuery := &models.GetAuthInfoQuery{UserId: usr.UserID}
|
||||
err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
if err != nil {
|
||||
if errors.Is(err, user.ErrUserNotFound) {
|
||||
// Not necessarily an error. User may be logged in another way.
|
||||
return nil, false, nil
|
||||
}
|
||||
logger.Error("failed to fetch oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
|
||||
return nil, false, err
|
||||
}
|
||||
if !strings.Contains(authInfoQuery.Result.AuthModule, "oauth") {
|
||||
return nil, false, nil
|
||||
}
|
||||
return authInfoQuery.Result, true, nil
|
||||
}
|
||||
|
||||
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful
|
||||
// It uses a singleflight.Group to prevent getting the Refresh Token multiple times for a given User
|
||||
func (o *Service) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
|
||||
lockKey := fmt.Sprintf("oauth-refresh-token-%d", usr.UserId)
|
||||
_, err, _ := o.singleFlightGroup.Do(lockKey, func() (interface{}, error) {
|
||||
logger.Debug("singleflight request for getting a new access token", "key", lockKey)
|
||||
authProvider := usr.AuthModule
|
||||
|
||||
if !strings.Contains(authProvider, "oauth") {
|
||||
logger.Error("the specified user's auth provider is not oauth", "authmodule", usr.AuthModule, "userid", usr.UserId)
|
||||
return nil, ErrNotAnOAuthProvider
|
||||
}
|
||||
|
||||
if usr.OAuthRefreshToken == "" {
|
||||
logger.Debug("no refresh token available", "authmodule", usr.AuthModule, "userid", usr.UserId)
|
||||
return nil, ErrNoRefreshTokenFound
|
||||
}
|
||||
|
||||
return o.tryGetOrRefreshAccessToken(ctx, usr)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
|
||||
func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
|
||||
return o.AuthInfoService.UpdateAuthInfo(ctx, &models.UpdateAuthInfoCommand{
|
||||
UserId: usr.UserId,
|
||||
AuthModule: usr.AuthModule,
|
||||
AuthId: usr.AuthId,
|
||||
OAuthToken: &oauth2.Token{
|
||||
AccessToken: "",
|
||||
RefreshToken: "",
|
||||
Expiry: time.Time{},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *models.UserAuth) (*oauth2.Token, error) {
|
||||
authProvider := usr.AuthModule
|
||||
connect, err := o.SocialService.GetConnector(authProvider)
|
||||
if err != nil {
|
||||
logger.Error("failed to get oauth connector", "provider", authProvider, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := o.SocialService.GetOAuthHttpClient(authProvider)
|
||||
if err != nil {
|
||||
logger.Error("failed to get oauth http client", "provider", authProvider, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||
|
||||
persistedToken := &oauth2.Token{
|
||||
AccessToken: usr.OAuthAccessToken,
|
||||
Expiry: usr.OAuthExpiry,
|
||||
RefreshToken: usr.OAuthRefreshToken,
|
||||
TokenType: usr.OAuthTokenType,
|
||||
}
|
||||
|
||||
if usr.OAuthIdToken != "" {
|
||||
persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": usr.OAuthIdToken})
|
||||
}
|
||||
|
||||
// TokenSource handles refreshing the token if it has expired
|
||||
token, err := connect.TokenSource(ctx, persistedToken).Token()
|
||||
if err != nil {
|
||||
logger.Error("failed to retrieve oauth access token", "provider", usr.AuthModule, "userId", usr.UserId, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the tokens are not the same, update the entry in the DB
|
||||
if !tokensEq(persistedToken, token) {
|
||||
updateAuthCommand := &models.UpdateAuthInfoCommand{
|
||||
UserId: usr.UserId,
|
||||
AuthModule: usr.AuthModule,
|
||||
AuthId: usr.AuthId,
|
||||
OAuthToken: token,
|
||||
}
|
||||
|
||||
if o.Cfg.Env == setting.Dev {
|
||||
logger.Debug("oauth got token",
|
||||
"user", usr.UserId,
|
||||
"auth_module", usr.AuthModule,
|
||||
"expiry", fmt.Sprintf("%v", token.Expiry),
|
||||
"access_token", fmt.Sprintf("%v", token.AccessToken),
|
||||
"refresh_token", fmt.Sprintf("%v", token.RefreshToken),
|
||||
)
|
||||
}
|
||||
|
||||
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
|
||||
logger.Error("failed to update auth info during token refresh", "userId", usr.UserId, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
logger.Debug("updated oauth info for user", "userId", usr.UserId)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
|
||||
func tokensEq(t1, t2 *oauth2.Token) bool {
|
||||
return t1.AccessToken == t2.AccessToken &&
|
||||
|
||||
374
pkg/services/oauthtoken/oauth_token_test.go
Normal file
374
pkg/services/oauthtoken/oauth_token_test.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package oauthtoken
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
func TestService_HasOAuthEntry(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
user *user.SignedInUser
|
||||
want *models.UserAuth
|
||||
wantExist bool
|
||||
wantErr bool
|
||||
err error
|
||||
getAuthInfoErr error
|
||||
getAuthInfoUser models.UserAuth
|
||||
}{
|
||||
{
|
||||
name: "returns false without an error in case user is nil",
|
||||
user: nil,
|
||||
want: nil,
|
||||
wantExist: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "returns false and an error in case GetAuthInfo returns an error",
|
||||
user: &user.SignedInUser{},
|
||||
want: nil,
|
||||
wantExist: false,
|
||||
wantErr: true,
|
||||
getAuthInfoErr: errors.New("error"),
|
||||
},
|
||||
{
|
||||
name: "returns false without an error in case auth entry is not found",
|
||||
user: &user.SignedInUser{},
|
||||
want: nil,
|
||||
wantExist: false,
|
||||
wantErr: false,
|
||||
getAuthInfoErr: user.ErrUserNotFound,
|
||||
},
|
||||
{
|
||||
name: "returns false without an error in case the auth entry is not oauth",
|
||||
user: &user.SignedInUser{},
|
||||
want: nil,
|
||||
wantExist: false,
|
||||
wantErr: false,
|
||||
getAuthInfoUser: models.UserAuth{AuthModule: "auth_saml"},
|
||||
},
|
||||
{
|
||||
name: "returns true when the auth entry is found",
|
||||
user: &user.SignedInUser{},
|
||||
want: &models.UserAuth{AuthModule: "oauth_generic_oauth"},
|
||||
wantExist: true,
|
||||
wantErr: false,
|
||||
getAuthInfoUser: models.UserAuth{AuthModule: "oauth_generic_oauth"},
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv, authInfoStore, _ := setupOAuthTokenService(t)
|
||||
authInfoStore.ExpectedOAuth = &tc.getAuthInfoUser
|
||||
authInfoStore.ExpectedError = tc.getAuthInfoErr
|
||||
|
||||
entry, exists, err := srv.HasOAuthEntry(context.Background(), tc.user)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
if tc.want != nil {
|
||||
assert.True(t, reflect.DeepEqual(tc.want, entry))
|
||||
}
|
||||
assert.Equal(t, tc.wantExist, exists)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_TryTokenRefresh_ValidToken(t *testing.T) {
|
||||
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
|
||||
ctx := context.Background()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "testrefresh",
|
||||
Expiry: time.Now(),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
usr := &models.UserAuth{
|
||||
AuthModule: "oauth_generic_oauth",
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthTokenType: token.TokenType,
|
||||
}
|
||||
|
||||
authInfoStore.ExpectedOAuth = usr
|
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr)
|
||||
assert.Nil(t, err)
|
||||
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
|
||||
|
||||
authInfoQuery := &models.GetAuthInfoQuery{}
|
||||
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
|
||||
assert.Nil(t, err)
|
||||
|
||||
// User's token data had not been updated
|
||||
resultUsr := authInfoQuery.Result
|
||||
assert.Equal(t, resultUsr.OAuthAccessToken, token.AccessToken)
|
||||
assert.Equal(t, resultUsr.OAuthExpiry, token.Expiry)
|
||||
assert.Equal(t, resultUsr.OAuthRefreshToken, token.RefreshToken)
|
||||
assert.Equal(t, resultUsr.OAuthTokenType, token.TokenType)
|
||||
}
|
||||
|
||||
func TestService_TryTokenRefresh_NoRefreshToken(t *testing.T) {
|
||||
srv, _, socialConnector := setupOAuthTokenService(t)
|
||||
ctx := context.Background()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "",
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
usr := &models.UserAuth{
|
||||
AuthModule: "oauth_generic_oauth",
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthTokenType: token.TokenType,
|
||||
}
|
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.ErrorIs(t, err, ErrNoRefreshTokenFound)
|
||||
|
||||
socialConnector.AssertNotCalled(t, "TokenSource")
|
||||
}
|
||||
|
||||
func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) {
|
||||
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
|
||||
ctx := context.Background()
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "testaccess",
|
||||
RefreshToken: "testrefresh",
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
|
||||
newToken := &oauth2.Token{
|
||||
AccessToken: "testaccess_new",
|
||||
RefreshToken: "testrefresh_new",
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
TokenType: "Bearer",
|
||||
}
|
||||
|
||||
usr := &models.UserAuth{
|
||||
AuthModule: "oauth_generic_oauth",
|
||||
OAuthAccessToken: token.AccessToken,
|
||||
OAuthRefreshToken: token.RefreshToken,
|
||||
OAuthExpiry: token.Expiry,
|
||||
OAuthTokenType: token.TokenType,
|
||||
}
|
||||
|
||||
authInfoStore.ExpectedOAuth = usr
|
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.ReuseTokenSource(token, oauth2.StaticTokenSource(newToken)), nil)
|
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr)
|
||||
|
||||
assert.Nil(t, err)
|
||||
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
|
||||
|
||||
authInfoQuery := &models.GetAuthInfoQuery{}
|
||||
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||
|
||||
assert.Nil(t, err)
|
||||
|
||||
// newToken should be returned after the .Token() call, therefore the User had to be updated
|
||||
assert.Equal(t, authInfoQuery.Result.OAuthAccessToken, newToken.AccessToken)
|
||||
assert.Equal(t, authInfoQuery.Result.OAuthExpiry, newToken.Expiry)
|
||||
assert.Equal(t, authInfoQuery.Result.OAuthRefreshToken, newToken.RefreshToken)
|
||||
assert.Equal(t, authInfoQuery.Result.OAuthTokenType, newToken.TokenType)
|
||||
}
|
||||
|
||||
func TestService_TryTokenRefresh_DifferentAuthModuleForUser(t *testing.T) {
|
||||
srv, _, socialConnector := setupOAuthTokenService(t)
|
||||
ctx := context.Background()
|
||||
token := &oauth2.Token{}
|
||||
usr := &models.UserAuth{
|
||||
AuthModule: "auth.saml",
|
||||
}
|
||||
|
||||
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
||||
|
||||
err := srv.TryTokenRefresh(ctx, usr)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.ErrorIs(t, err, ErrNotAnOAuthProvider)
|
||||
|
||||
socialConnector.AssertNotCalled(t, "TokenSource")
|
||||
}
|
||||
|
||||
func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *MockSocialConnector) {
|
||||
t.Helper()
|
||||
|
||||
socialConnector := &MockSocialConnector{}
|
||||
socialService := &FakeSocialService{
|
||||
connector: socialConnector,
|
||||
}
|
||||
|
||||
authInfoStore := &FakeAuthInfoStore{}
|
||||
authInfoService := authinfoservice.ProvideAuthInfoService(nil, authInfoStore, &usagestats.UsageStatsMock{})
|
||||
return &Service{
|
||||
Cfg: setting.NewCfg(),
|
||||
SocialService: socialService,
|
||||
AuthInfoService: authInfoService,
|
||||
singleFlightGroup: &singleflight.Group{},
|
||||
}, authInfoStore, socialConnector
|
||||
}
|
||||
|
||||
type FakeSocialService struct {
|
||||
httpClient *http.Client
|
||||
connector *MockSocialConnector
|
||||
}
|
||||
|
||||
func (fss *FakeSocialService) GetOAuthProviders() map[string]bool {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (fss *FakeSocialService) GetOAuthHttpClient(string) (*http.Client, error) {
|
||||
return fss.httpClient, nil
|
||||
}
|
||||
|
||||
func (fss *FakeSocialService) GetConnector(string) (social.SocialConnector, error) {
|
||||
return fss.connector, nil
|
||||
}
|
||||
|
||||
func (fss *FakeSocialService) GetOAuthInfoProvider(string) *social.OAuthInfo {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (fss *FakeSocialService) GetOAuthInfoProviders() map[string]*social.OAuthInfo {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
type MockSocialConnector struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSocialConnector) Type() int {
|
||||
args := m.Called()
|
||||
return args.Int(0)
|
||||
}
|
||||
|
||||
func (m *MockSocialConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
||||
args := m.Called(client, token)
|
||||
return args.Get(0).(*social.BasicUserInfo), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSocialConnector) IsEmailAllowed(email string) bool {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *MockSocialConnector) IsSignupAllowed() bool {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *MockSocialConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *MockSocialConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *MockSocialConnector) Client(ctx context.Context, t *oauth2.Token) *http.Client {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *MockSocialConnector) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource {
|
||||
args := m.Called(ctx, t)
|
||||
return args.Get(0).(oauth2.TokenSource)
|
||||
}
|
||||
|
||||
type FakeAuthInfoStore struct {
|
||||
ExpectedError error
|
||||
ExpectedUser *user.User
|
||||
ExpectedOAuth *models.UserAuth
|
||||
ExpectedDuplicateUserEntries int
|
||||
ExpectedHasDuplicateUserEntries int
|
||||
ExpectedLoginStats login.LoginStats
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *models.GetExternalUserInfoByLoginQuery) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *models.GetAuthInfoQuery) error {
|
||||
query.Result = f.ExpectedOAuth
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *models.UserAuth) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error {
|
||||
f.ExpectedOAuth.OAuthAccessToken = cmd.OAuthToken.AccessToken
|
||||
f.ExpectedOAuth.OAuthExpiry = cmd.OAuthToken.Expiry
|
||||
f.ExpectedOAuth.OAuthTokenType = cmd.OAuthToken.TokenType
|
||||
f.ExpectedOAuth.OAuthRefreshToken = cmd.OAuthToken.RefreshToken
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAuthInfoCommand) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) {
|
||||
return f.ExpectedUser, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) {
|
||||
return f.ExpectedUser, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
|
||||
return f.ExpectedUser, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]interface{}, error) {
|
||||
var res = make(map[string]interface{})
|
||||
res["stats.users.duplicate_user_entries"] = f.ExpectedDuplicateUserEntries
|
||||
res["stats.users.has_duplicate_user_entries"] = f.ExpectedHasDuplicateUserEntries
|
||||
res["stats.users.duplicate_user_entries_by_login"] = 0
|
||||
res["stats.users.has_duplicate_user_entries_by_login"] = 0
|
||||
res["stats.users.duplicate_user_entries_by_email"] = 0
|
||||
res["stats.users.has_duplicate_user_entries_by_email"] = 0
|
||||
res["stats.users.mixed_cased_users"] = f.ExpectedLoginStats.MixedCasedUsers
|
||||
return res, f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) RunMetricsCollection(ctx context.Context) error {
|
||||
return f.ExpectedError
|
||||
}
|
||||
|
||||
func (f *FakeAuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, error) {
|
||||
return f.ExpectedLoginStats, f.ExpectedError
|
||||
}
|
||||
@@ -164,6 +164,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource)
|
||||
return ts.passThruEnabled
|
||||
}
|
||||
|
||||
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// copied from pkg/api/plugins_test.go
|
||||
type fakePluginClient struct {
|
||||
plugins.Client
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||
"github.com/grafana/grafana/pkg/expr"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -440,6 +441,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource)
|
||||
return ts.passThruEnabled
|
||||
}
|
||||
|
||||
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (ts *fakeOAuthTokenService) TryTokenRefresh(context.Context, *models.UserAuth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(context.Context, *models.UserAuth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeDataSourceCache struct {
|
||||
ds *datasources.DataSource
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user