Auth: Refresh OAuth access_token automatically using the refresh_token (#56076)

* Verify OAuth token expiration for oauth users in the ctx handler middleware

* Use refresh token to get a new access token

* Refactor oauth_token.go

* Add tests for the middleware changes

* Align other tests

* Add tests, wip

* Add more tests

* Add InvalidateOAuthTokens method

* Fix ExpiryDate update to default

* Invalidate OAuth tokens during logout

* Improve logout

* Add more comments

* Cleanup

* Fix import order

* Add error to HasOAuthEntry return values

* add dev debug logs

* Fix tests

Co-authored-by: jguer <joao.guerreiro@grafana.com>
This commit is contained in:
Misi
2022-10-18 18:17:28 +02:00
committed by GitHub
parent 984ec00aac
commit 9c954d06ab
17 changed files with 828 additions and 89 deletions

View File

@@ -213,7 +213,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, &usertest.FakeUserService{}, sqlStore)
loginService := &logintest.LoginServiceFake{}
authenticator := &logintest.AuthenticatorFake{}
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake())
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake(), nil)
return ctxHdlr
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/middleware/csrf"
"github.com/grafana/grafana/pkg/services/folder"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/querylibrary"
"github.com/grafana/grafana/pkg/services/searchV2"
"github.com/grafana/grafana/pkg/services/store/object"
@@ -206,6 +207,7 @@ type HTTPServer struct {
annotationsRepo annotations.Repository
tagService tag.Service
userAuthService userauth.Service
oauthTokenService oauthtoken.OAuthTokenService
}
type ServerOptions struct {
@@ -248,6 +250,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi
accesscontrolService accesscontrol.Service, dashboardThumbsService thumbs.DashboardThumbService, navTreeService navtree.Service,
annotationRepo annotations.Repository, tagService tag.Service, searchv2HTTPService searchV2.SearchHTTPService,
userAuthService userauth.Service, queryLibraryHTTPService querylibrary.HTTPService, queryLibraryService querylibrary.Service,
oauthTokenService oauthtoken.OAuthTokenService,
) (*HTTPServer, error) {
web.Env = cfg.Env
m := web.New()
@@ -352,6 +355,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi
userAuthService: userAuthService,
QueryLibraryHTTPService: queryLibraryHTTPService,
QueryLibraryService: queryLibraryService,
oauthTokenService: oauthTokenService,
}
if hs.Listener != nil {
hs.log.Debug("Using provided listener")

View File

@@ -304,6 +304,13 @@ func (hs *HTTPServer) Logout(c *models.ReqContext) {
}
}
// Invalidate the OAuth tokens in case the User logged in with OAuth or the last external AuthEntry is an OAuth one
if entry, exists, _ := hs.oauthTokenService.HasOAuthEntry(c.Req.Context(), c.SignedInUser); exists {
if err := hs.oauthTokenService.InvalidateOAuthTokens(c.Req.Context(), entry); err != nil {
hs.log.Warn("failed to invalidate oauth tokens for user", "userId", c.UserID, "error", err)
}
}
err := hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken, false)
if err != nil && !errors.Is(err, models.ErrUserTokenNotFound) {
hs.log.Error("failed to revoke auth token", "error", err)

View File

@@ -194,7 +194,15 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
// token.TokenType was defaulting to "bearer", which is out of spec, so we explicitly set to "Bearer"
token.TokenType = "Bearer"
oauthLogger.Debug("OAuthLogin: got token", "expiry", fmt.Sprintf("%v", token.Expiry))
if hs.Cfg.Env != setting.Dev {
oauthLogger.Debug("OAuthLogin: got token", "expiry", fmt.Sprintf("%v", token.Expiry))
} else {
oauthLogger.Debug("OAuthLogin: got token",
"expiry", fmt.Sprintf("%v", token.Expiry),
"access_token", fmt.Sprintf("%v", token.AccessToken),
"refresh_token", fmt.Sprintf("%v", token.RefreshToken),
)
}
// set up oauth2 client
client := connect.Client(oauthCtx, token)

View File

@@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin"
pluginClient "github.com/grafana/grafana/pkg/plugins/manager/client"
@@ -56,6 +57,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource)
return ts.passThruEnabled
}
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
return nil, false, nil
}
func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
return nil
}
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
return nil
}
// `/ds/query` endpoint test
func TestAPIEndpoint_Metrics_QueryMetricsV2(t *testing.T) {
qds := query.ProvideService(

View File

@@ -1065,3 +1065,15 @@ func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user *
func (m *mockOAuthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
return m.oAuthEnabled
}
func (m *mockOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
return nil, false, nil
}
func (m *mockOAuthTokenService) TryTokenRefresh(context.Context, *models.UserAuth) error {
return nil
}
func (m *mockOAuthTokenService) InvalidateOAuthTokens(context.Context, *models.UserAuth) error {
return nil
}

View File

@@ -2,6 +2,7 @@ package middleware
import (
"context"
"errors"
"fmt"
"io"
"net"
@@ -339,6 +340,123 @@ func TestMiddlewareContext(t *testing.T) {
assert.Nil(t, sc.context.UserToken)
})
middlewareScenario(t, "Non-expired auth token in cookie and non-expired OAuth access token", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(11 * time.Second)}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
require.NotNil(t, sc.context)
require.NotNil(t, sc.context.UserToken)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token fails", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
signedInUser := &user.SignedInUser{OrgID: 2, UserID: userID}
sc.userService.ExpectedSignedInUser = signedInUser
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{
UserId: userID,
OAuthExpiry: fakeGetTime()().Add(-1 * time.Second),
OAuthAccessToken: "access_token",
OAuthRefreshToken: "refresh_token"}
sc.oauthTokenService.ExpectedErrors = map[string]error{"TryTokenRefresh": errors.New("error")}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
token := sc.oauthTokenService.GetCurrentOAuthToken(sc.context.Req.Context(), signedInUser)
assert.Equal(t, token.AccessToken, "")
assert.Equal(t, token.RefreshToken, "")
assert.True(t, token.Expiry.IsZero())
require.NotNil(t, sc.context)
require.Nil(t, sc.context.UserToken)
assert.False(t, sc.context.IsSignedIn)
assert.Equal(t, int64(0), sc.context.UserID)
assert.Equal(t, "grafana_session=; Path=/; Max-Age=0; HttpOnly", sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token succeeds", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(-5 * time.Second), OAuthRefreshToken: "refreshtoken"}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
require.NotNil(t, sc.context)
require.NotNil(t, sc.context.UserToken)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "Non-expired auth token in cookie and OAuth Access Token's Expiry is not set", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
require.NotNil(t, sc.context)
require.NotNil(t, sc.context.UserToken)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
sc.mockSQLStore.ExpectedOrg = &models.Org{Id: 1, Name: sc.cfg.AnonymousOrgName}
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
@@ -655,7 +773,8 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
sc.userService = usertest.NewUserServiceFake()
sc.orgService = orgtest.NewOrgServiceFake()
sc.apiKeyService = &apikeytest.Service{}
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService)
sc.oauthTokenService = &auth.FakeOAuthTokenService{}
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService, sc.oauthTokenService)
sc.sqlStore = ctxHdlr.SQLStore
sc.contextHandler = ctxHdlr
sc.m.Use(ctxHdlr.Middleware)
@@ -691,6 +810,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock,
loginService *loginservice.LoginServiceMock, apiKeyService *apikeytest.Service,
userService *usertest.FakeUserService, orgService *orgtest.FakeOrgService,
oauthTokenService *auth.FakeOAuthTokenService,
) *contexthandler.ContextHandler {
t.Helper()
@@ -708,7 +828,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.S
tracer := tracing.InitializeTracerForTest()
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, userService, mockSQLStore)
authenticator := &logintest.AuthenticatorFake{ExpectedUser: &user.User{}}
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService)
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService, oauthTokenService)
}
type fakeRenderService struct {

View File

@@ -68,7 +68,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
sc.remoteCacheService = remotecache.NewFakeStore(t)
contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil)
contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil, nil)
sc.m.Use(contextHandler.Middleware)
// mock out gc goroutine
sc.m.Use(OrgRedirect(cfg, sc.userService))

View File

@@ -44,6 +44,7 @@ type scenarioContext struct {
loginService *loginservice.LoginServiceMock
apiKeyService *apikeytest.Service
userService *usertest.FakeUserService
oauthTokenService *auth.FakeOAuthTokenService
orgService *orgtest.FakeOrgService
req *http.Request

View File

@@ -3,9 +3,12 @@ package auth
import (
"context"
"net"
"time"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/user"
"golang.org/x/oauth2"
)
type FakeUserAuthTokenService struct {
@@ -105,3 +108,46 @@ func (s *FakeUserAuthTokenService) GetUserRevokedTokens(ctx context.Context, use
func (s *FakeUserAuthTokenService) BatchRevokeAllUserTokens(ctx context.Context, userIds []int64) error {
return s.BatchRevokedTokenProvider(ctx, userIds)
}
type FakeOAuthTokenService struct {
passThruEnabled bool
ExpectedAuthUser *models.UserAuth
ExpectedErrors map[string]error
}
func (ts *FakeOAuthTokenService) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token {
return &oauth2.Token{
AccessToken: ts.ExpectedAuthUser.OAuthAccessToken,
RefreshToken: ts.ExpectedAuthUser.OAuthRefreshToken,
Expiry: ts.ExpectedAuthUser.OAuthExpiry,
TokenType: ts.ExpectedAuthUser.OAuthTokenType,
}
}
func (ts *FakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) bool {
return ts.passThruEnabled
}
func (ts *FakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
if ts.ExpectedAuthUser != nil {
return ts.ExpectedAuthUser, true, nil
}
if error, ok := ts.ExpectedErrors["HasOAuthEntry"]; ok {
return nil, false, error
}
return nil, false, nil
}
func (ts *FakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
ts.ExpectedAuthUser.OAuthAccessToken = ""
ts.ExpectedAuthUser.OAuthRefreshToken = ""
ts.ExpectedAuthUser.OAuthExpiry = time.Time{}
return nil
}
func (ts *FakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
if err, ok := ts.ExpectedErrors["TryTokenRefresh"]; ok {
return err
}
return nil
}

View File

@@ -104,7 +104,7 @@ func getContextHandler(t *testing.T) *ContextHandler {
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc,
renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator,
&userService, orgService)
&userService, orgService, nil)
}
type FakeGetSignUserStore struct {

View File

@@ -4,6 +4,7 @@ package contexthandler
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
@@ -23,6 +24,7 @@ import (
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/contexthandler/ctxkey"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/sqlstore"
@@ -44,39 +46,42 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtS
remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore sqlstore.Store,
tracer tracing.Tracer, authProxy *authproxy.AuthProxy, loginService login.Service,
apiKeyService apikey.Service, authenticator loginpkg.Authenticator, userService user.Service,
orgService org.Service) *ContextHandler {
orgService org.Service, oauthTokenService oauthtoken.OAuthTokenService,
) *ContextHandler {
return &ContextHandler{
Cfg: cfg,
AuthTokenService: tokenService,
JWTAuthService: jwtService,
RemoteCache: remoteCache,
RenderService: renderService,
SQLStore: sqlStore,
tracer: tracer,
authProxy: authProxy,
authenticator: authenticator,
loginService: loginService,
apiKeyService: apiKeyService,
userService: userService,
orgService: orgService,
Cfg: cfg,
AuthTokenService: tokenService,
JWTAuthService: jwtService,
RemoteCache: remoteCache,
RenderService: renderService,
SQLStore: sqlStore,
tracer: tracer,
authProxy: authProxy,
authenticator: authenticator,
loginService: loginService,
apiKeyService: apiKeyService,
userService: userService,
orgService: orgService,
oauthTokenService: oauthTokenService,
}
}
// ContextHandler is a middleware.
type ContextHandler struct {
Cfg *setting.Cfg
AuthTokenService models.UserTokenService
JWTAuthService models.JWTService
RemoteCache *remotecache.RemoteCache
RenderService rendering.Service
SQLStore sqlstore.Store
tracer tracing.Tracer
authProxy *authproxy.AuthProxy
authenticator loginpkg.Authenticator
loginService login.Service
apiKeyService apikey.Service
userService user.Service
orgService org.Service
Cfg *setting.Cfg
AuthTokenService models.UserTokenService
JWTAuthService models.JWTService
RemoteCache *remotecache.RemoteCache
RenderService rendering.Service
SQLStore sqlstore.Store
tracer tracing.Tracer
authProxy *authproxy.AuthProxy
authenticator loginpkg.Authenticator
loginService login.Service
apiKeyService apikey.Service
userService user.Service
orgService org.Service
oauthTokenService oauthtoken.OAuthTokenService
// GetTime returns the current time.
// Stubbable by tests.
GetTime func() time.Time
@@ -428,6 +433,38 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org
return false
}
getTime := h.GetTime
if getTime == nil {
getTime = time.Now
}
// Check whether the logged in User has a token (whether the User used an OAuth provider to login)
oauthToken, exists, _ := h.oauthTokenService.HasOAuthEntry(ctx, queryResult)
if exists {
// Skip where the OAuthExpiry is default/zero/unset
if !oauthToken.OAuthExpiry.IsZero() && oauthToken.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime()) {
reqContext.Logger.Info("access token expired", "userId", query.UserID, "expiry", fmt.Sprintf("%v", oauthToken.OAuthExpiry))
// If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and Invalidate the OAuth tokens
if err = h.oauthTokenService.TryTokenRefresh(ctx, oauthToken); err != nil {
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
reqContext.Logger.Error("could not fetch a new access token", "userId", oauthToken.UserId, "error", err)
}
reqContext.Resp.Before(h.deleteInvalidCookieEndOfRequestFunc(reqContext))
if err = h.oauthTokenService.InvalidateOAuthTokens(ctx, oauthToken); err != nil {
reqContext.Logger.Error("could not invalidate OAuth tokens", "userId", oauthToken.UserId, "error", err)
}
err = h.AuthTokenService.RevokeToken(ctx, token, false)
if err != nil && !errors.Is(err, models.ErrUserTokenNotFound) {
reqContext.Logger.Error("failed to revoke auth token", "error", err)
}
return false
}
}
}
reqContext.SignedInUser = queryResult
reqContext.IsSignedIn = true
reqContext.UserToken = token

View File

@@ -204,13 +204,8 @@ func (s *AuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAu
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
}
cond := &models.UserAuth{
UserId: cmd.UserId,
AuthModule: cmd.AuthModule,
}
return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error {
upd, err := sess.Update(authUser, cond)
upd, err := sess.MustCols("o_auth_expiry").Where("user_id = ? AND auth_module = ?", cmd.UserId, cmd.AuthModule).Update(authUser)
s.logger.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_module", cmd.AuthModule, "rows", upd)
return err
})

View File

@@ -3,8 +3,12 @@ package oauthtoken
import (
"context"
"errors"
"fmt"
"strings"
"time"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
@@ -12,26 +16,39 @@ import (
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
var (
logger = log.New("oauthtoken")
// ExpiryDelta is used to prevent any issue that is caused by the clock skew (server times can differ slightly between different machines).
// Shouldn't be more than 30s
ExpiryDelta = 10 * time.Second
ErrNoRefreshTokenFound = errors.New("no refresh token found")
ErrNotAnOAuthProvider = errors.New("not an oauth provider")
)
type Service struct {
SocialService social.Service
AuthInfoService login.AuthInfoService
Cfg *setting.Cfg
SocialService social.Service
AuthInfoService login.AuthInfoService
singleFlightGroup *singleflight.Group
}
type OAuthTokenService interface {
GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token
IsOAuthPassThruEnabled(*datasources.DataSource) bool
HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error)
TryTokenRefresh(context.Context, *models.UserAuth) error
InvalidateOAuthTokens(context.Context, *models.UserAuth) error
}
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService) *Service {
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg) *Service {
return &Service{
SocialService: socialService,
AuthInfoService: authInfoService,
Cfg: cfg,
SocialService: socialService,
AuthInfoService: authInfoService,
singleFlightGroup: new(singleflight.Group),
}
}
@@ -46,59 +63,17 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs
if err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery); err != nil {
if errors.Is(err, user.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way.
logger.Debug("no OAuth token for user found", "userId", usr.UserID, "username", usr.Login)
logger.Debug("no oauth token for user found", "userId", usr.UserID, "username", usr.Login)
} else {
logger.Error("failed to get OAuth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
logger.Error("failed to get oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
}
return nil
}
authProvider := authInfoQuery.Result.AuthModule
connect, err := o.SocialService.GetConnector(authProvider)
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfoQuery.Result)
if err != nil {
logger.Error("failed to get OAuth connector", "provider", authProvider, "error", err)
return nil
}
client, err := o.SocialService.GetOAuthHttpClient(authProvider)
if err != nil {
logger.Error("failed to get OAuth http client", "provider", authProvider, "error", err)
return nil
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
persistedToken := &oauth2.Token{
AccessToken: authInfoQuery.Result.OAuthAccessToken,
Expiry: authInfoQuery.Result.OAuthExpiry,
RefreshToken: authInfoQuery.Result.OAuthRefreshToken,
TokenType: authInfoQuery.Result.OAuthTokenType,
}
if authInfoQuery.Result.OAuthIdToken != "" {
persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": authInfoQuery.Result.OAuthIdToken})
}
// TokenSource handles refreshing the token if it has expired
token, err := connect.TokenSource(ctx, persistedToken).Token()
if err != nil {
logger.Error("failed to retrieve OAuth access token", "provider", authInfoQuery.Result.AuthModule, "userId", usr.UserID, "username", usr.Login, "error", err)
return nil
}
// If the tokens are not the same, update the entry in the DB
if !tokensEq(persistedToken, token) {
updateAuthCommand := &models.UpdateAuthInfoCommand{
UserId: authInfoQuery.Result.UserId,
AuthModule: authInfoQuery.Result.AuthModule,
AuthId: authInfoQuery.Result.AuthId,
OAuthToken: token,
}
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
logger.Error("failed to update auth info during token refresh", "userId", usr.UserID, "username", usr.Login, "error", err)
return nil
}
logger.Debug("updated OAuth info for user", "userId", usr.UserID, "username", usr.Login)
}
return token
}
@@ -107,6 +82,128 @@ func (o *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool()
}
// HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User
func (o *Service) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error) {
if usr == nil {
// No user, therefore no token
return nil, false, nil
}
authInfoQuery := &models.GetAuthInfoQuery{UserId: usr.UserID}
err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) {
// Not necessarily an error. User may be logged in another way.
return nil, false, nil
}
logger.Error("failed to fetch oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
return nil, false, err
}
if !strings.Contains(authInfoQuery.Result.AuthModule, "oauth") {
return nil, false, nil
}
return authInfoQuery.Result, true, nil
}
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful
// It uses a singleflight.Group to prevent getting the Refresh Token multiple times for a given User
func (o *Service) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
lockKey := fmt.Sprintf("oauth-refresh-token-%d", usr.UserId)
_, err, _ := o.singleFlightGroup.Do(lockKey, func() (interface{}, error) {
logger.Debug("singleflight request for getting a new access token", "key", lockKey)
authProvider := usr.AuthModule
if !strings.Contains(authProvider, "oauth") {
logger.Error("the specified user's auth provider is not oauth", "authmodule", usr.AuthModule, "userid", usr.UserId)
return nil, ErrNotAnOAuthProvider
}
if usr.OAuthRefreshToken == "" {
logger.Debug("no refresh token available", "authmodule", usr.AuthModule, "userid", usr.UserId)
return nil, ErrNoRefreshTokenFound
}
return o.tryGetOrRefreshAccessToken(ctx, usr)
})
return err
}
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
return o.AuthInfoService.UpdateAuthInfo(ctx, &models.UpdateAuthInfoCommand{
UserId: usr.UserId,
AuthModule: usr.AuthModule,
AuthId: usr.AuthId,
OAuthToken: &oauth2.Token{
AccessToken: "",
RefreshToken: "",
Expiry: time.Time{},
},
})
}
func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *models.UserAuth) (*oauth2.Token, error) {
authProvider := usr.AuthModule
connect, err := o.SocialService.GetConnector(authProvider)
if err != nil {
logger.Error("failed to get oauth connector", "provider", authProvider, "error", err)
return nil, err
}
client, err := o.SocialService.GetOAuthHttpClient(authProvider)
if err != nil {
logger.Error("failed to get oauth http client", "provider", authProvider, "error", err)
return nil, err
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
persistedToken := &oauth2.Token{
AccessToken: usr.OAuthAccessToken,
Expiry: usr.OAuthExpiry,
RefreshToken: usr.OAuthRefreshToken,
TokenType: usr.OAuthTokenType,
}
if usr.OAuthIdToken != "" {
persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": usr.OAuthIdToken})
}
// TokenSource handles refreshing the token if it has expired
token, err := connect.TokenSource(ctx, persistedToken).Token()
if err != nil {
logger.Error("failed to retrieve oauth access token", "provider", usr.AuthModule, "userId", usr.UserId, "error", err)
return nil, err
}
// If the tokens are not the same, update the entry in the DB
if !tokensEq(persistedToken, token) {
updateAuthCommand := &models.UpdateAuthInfoCommand{
UserId: usr.UserId,
AuthModule: usr.AuthModule,
AuthId: usr.AuthId,
OAuthToken: token,
}
if o.Cfg.Env == setting.Dev {
logger.Debug("oauth got token",
"user", usr.UserId,
"auth_module", usr.AuthModule,
"expiry", fmt.Sprintf("%v", token.Expiry),
"access_token", fmt.Sprintf("%v", token.AccessToken),
"refresh_token", fmt.Sprintf("%v", token.RefreshToken),
)
}
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
logger.Error("failed to update auth info during token refresh", "userId", usr.UserId, "error", err)
return nil, err
}
logger.Debug("updated oauth info for user", "userId", usr.UserId)
}
return token, nil
}
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
func tokensEq(t1, t2 *oauth2.Token) bool {
return t1.AccessToken == t2.AccessToken &&

View File

@@ -0,0 +1,374 @@
package oauthtoken
import (
"context"
"errors"
"net/http"
"reflect"
"testing"
"time"
"github.com/grafana/grafana/pkg/infra/usagestats"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
"golang.org/x/sync/singleflight"
)
func TestService_HasOAuthEntry(t *testing.T) {
testCases := []struct {
name string
user *user.SignedInUser
want *models.UserAuth
wantExist bool
wantErr bool
err error
getAuthInfoErr error
getAuthInfoUser models.UserAuth
}{
{
name: "returns false without an error in case user is nil",
user: nil,
want: nil,
wantExist: false,
wantErr: false,
},
{
name: "returns false and an error in case GetAuthInfo returns an error",
user: &user.SignedInUser{},
want: nil,
wantExist: false,
wantErr: true,
getAuthInfoErr: errors.New("error"),
},
{
name: "returns false without an error in case auth entry is not found",
user: &user.SignedInUser{},
want: nil,
wantExist: false,
wantErr: false,
getAuthInfoErr: user.ErrUserNotFound,
},
{
name: "returns false without an error in case the auth entry is not oauth",
user: &user.SignedInUser{},
want: nil,
wantExist: false,
wantErr: false,
getAuthInfoUser: models.UserAuth{AuthModule: "auth_saml"},
},
{
name: "returns true when the auth entry is found",
user: &user.SignedInUser{},
want: &models.UserAuth{AuthModule: "oauth_generic_oauth"},
wantExist: true,
wantErr: false,
getAuthInfoUser: models.UserAuth{AuthModule: "oauth_generic_oauth"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
srv, authInfoStore, _ := setupOAuthTokenService(t)
authInfoStore.ExpectedOAuth = &tc.getAuthInfoUser
authInfoStore.ExpectedError = tc.getAuthInfoErr
entry, exists, err := srv.HasOAuthEntry(context.Background(), tc.user)
if tc.wantErr {
assert.Error(t, err)
}
if tc.want != nil {
assert.True(t, reflect.DeepEqual(tc.want, entry))
}
assert.Equal(t, tc.wantExist, exists)
})
}
}
func TestService_TryTokenRefresh_ValidToken(t *testing.T) {
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
ctx := context.Background()
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now(),
TokenType: "Bearer",
}
usr := &models.UserAuth{
AuthModule: "oauth_generic_oauth",
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
OAuthTokenType: token.TokenType,
}
authInfoStore.ExpectedOAuth = usr
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
err := srv.TryTokenRefresh(ctx, usr)
assert.Nil(t, err)
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
authInfoQuery := &models.GetAuthInfoQuery{}
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
assert.Nil(t, err)
// User's token data had not been updated
resultUsr := authInfoQuery.Result
assert.Equal(t, resultUsr.OAuthAccessToken, token.AccessToken)
assert.Equal(t, resultUsr.OAuthExpiry, token.Expiry)
assert.Equal(t, resultUsr.OAuthRefreshToken, token.RefreshToken)
assert.Equal(t, resultUsr.OAuthTokenType, token.TokenType)
}
func TestService_TryTokenRefresh_NoRefreshToken(t *testing.T) {
srv, _, socialConnector := setupOAuthTokenService(t)
ctx := context.Background()
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "",
Expiry: time.Now().Add(-time.Hour),
TokenType: "Bearer",
}
usr := &models.UserAuth{
AuthModule: "oauth_generic_oauth",
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
OAuthTokenType: token.TokenType,
}
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
err := srv.TryTokenRefresh(ctx, usr)
assert.NotNil(t, err)
assert.ErrorIs(t, err, ErrNoRefreshTokenFound)
socialConnector.AssertNotCalled(t, "TokenSource")
}
func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) {
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
ctx := context.Background()
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now().Add(-time.Hour),
TokenType: "Bearer",
}
newToken := &oauth2.Token{
AccessToken: "testaccess_new",
RefreshToken: "testrefresh_new",
Expiry: time.Now().Add(time.Hour),
TokenType: "Bearer",
}
usr := &models.UserAuth{
AuthModule: "oauth_generic_oauth",
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthExpiry: token.Expiry,
OAuthTokenType: token.TokenType,
}
authInfoStore.ExpectedOAuth = usr
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.ReuseTokenSource(token, oauth2.StaticTokenSource(newToken)), nil)
err := srv.TryTokenRefresh(ctx, usr)
assert.Nil(t, err)
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
authInfoQuery := &models.GetAuthInfoQuery{}
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
assert.Nil(t, err)
// newToken should be returned after the .Token() call, therefore the User had to be updated
assert.Equal(t, authInfoQuery.Result.OAuthAccessToken, newToken.AccessToken)
assert.Equal(t, authInfoQuery.Result.OAuthExpiry, newToken.Expiry)
assert.Equal(t, authInfoQuery.Result.OAuthRefreshToken, newToken.RefreshToken)
assert.Equal(t, authInfoQuery.Result.OAuthTokenType, newToken.TokenType)
}
func TestService_TryTokenRefresh_DifferentAuthModuleForUser(t *testing.T) {
srv, _, socialConnector := setupOAuthTokenService(t)
ctx := context.Background()
token := &oauth2.Token{}
usr := &models.UserAuth{
AuthModule: "auth.saml",
}
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
err := srv.TryTokenRefresh(ctx, usr)
assert.NotNil(t, err)
assert.ErrorIs(t, err, ErrNotAnOAuthProvider)
socialConnector.AssertNotCalled(t, "TokenSource")
}
func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *MockSocialConnector) {
t.Helper()
socialConnector := &MockSocialConnector{}
socialService := &FakeSocialService{
connector: socialConnector,
}
authInfoStore := &FakeAuthInfoStore{}
authInfoService := authinfoservice.ProvideAuthInfoService(nil, authInfoStore, &usagestats.UsageStatsMock{})
return &Service{
Cfg: setting.NewCfg(),
SocialService: socialService,
AuthInfoService: authInfoService,
singleFlightGroup: &singleflight.Group{},
}, authInfoStore, socialConnector
}
type FakeSocialService struct {
httpClient *http.Client
connector *MockSocialConnector
}
func (fss *FakeSocialService) GetOAuthProviders() map[string]bool {
panic("not implemented")
}
func (fss *FakeSocialService) GetOAuthHttpClient(string) (*http.Client, error) {
return fss.httpClient, nil
}
func (fss *FakeSocialService) GetConnector(string) (social.SocialConnector, error) {
return fss.connector, nil
}
func (fss *FakeSocialService) GetOAuthInfoProvider(string) *social.OAuthInfo {
panic("not implemented")
}
func (fss *FakeSocialService) GetOAuthInfoProviders() map[string]*social.OAuthInfo {
panic("not implemented")
}
type MockSocialConnector struct {
mock.Mock
}
func (m *MockSocialConnector) Type() int {
args := m.Called()
return args.Int(0)
}
func (m *MockSocialConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
args := m.Called(client, token)
return args.Get(0).(*social.BasicUserInfo), args.Error(1)
}
func (m *MockSocialConnector) IsEmailAllowed(email string) bool {
panic("not implemented")
}
func (m *MockSocialConnector) IsSignupAllowed() bool {
panic("not implemented")
}
func (m *MockSocialConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
panic("not implemented")
}
func (m *MockSocialConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
panic("not implemented")
}
func (m *MockSocialConnector) Client(ctx context.Context, t *oauth2.Token) *http.Client {
panic("not implemented")
}
func (m *MockSocialConnector) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource {
args := m.Called(ctx, t)
return args.Get(0).(oauth2.TokenSource)
}
type FakeAuthInfoStore struct {
ExpectedError error
ExpectedUser *user.User
ExpectedOAuth *models.UserAuth
ExpectedDuplicateUserEntries int
ExpectedHasDuplicateUserEntries int
ExpectedLoginStats login.LoginStats
}
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *models.GetExternalUserInfoByLoginQuery) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *models.GetAuthInfoQuery) error {
query.Result = f.ExpectedOAuth
return f.ExpectedError
}
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *models.UserAuth) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error {
f.ExpectedOAuth.OAuthAccessToken = cmd.OAuthToken.AccessToken
f.ExpectedOAuth.OAuthExpiry = cmd.OAuthToken.Expiry
f.ExpectedOAuth.OAuthTokenType = cmd.OAuthToken.TokenType
f.ExpectedOAuth.OAuthRefreshToken = cmd.OAuthToken.RefreshToken
return f.ExpectedError
}
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAuthInfoCommand) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeAuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeAuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeAuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]interface{}, error) {
var res = make(map[string]interface{})
res["stats.users.duplicate_user_entries"] = f.ExpectedDuplicateUserEntries
res["stats.users.has_duplicate_user_entries"] = f.ExpectedHasDuplicateUserEntries
res["stats.users.duplicate_user_entries_by_login"] = 0
res["stats.users.has_duplicate_user_entries_by_login"] = 0
res["stats.users.duplicate_user_entries_by_email"] = 0
res["stats.users.has_duplicate_user_entries_by_email"] = 0
res["stats.users.mixed_cased_users"] = f.ExpectedLoginStats.MixedCasedUsers
return res, f.ExpectedError
}
func (f *FakeAuthInfoStore) RunMetricsCollection(ctx context.Context) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, error) {
return f.ExpectedLoginStats, f.ExpectedError
}

View File

@@ -164,6 +164,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource)
return ts.passThruEnabled
}
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
return nil, false, nil
}
func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
return nil
}
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
return nil
}
// copied from pkg/api/plugins_test.go
type fakePluginClient struct {
plugins.Client

View File

@@ -8,6 +8,7 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend"
"github.com/grafana/grafana/pkg/expr"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -440,6 +441,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource)
return ts.passThruEnabled
}
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
return nil, false, nil
}
func (ts *fakeOAuthTokenService) TryTokenRefresh(context.Context, *models.UserAuth) error {
return nil
}
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(context.Context, *models.UserAuth) error {
return nil
}
type fakeDataSourceCache struct {
ds *datasources.DataSource
}