mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Refresh OAuth access_token automatically using the refresh_token (#56076)
* Verify OAuth token expiration for oauth users in the ctx handler middleware * Use refresh token to get a new access token * Refactor oauth_token.go * Add tests for the middleware changes * Align other tests * Add tests, wip * Add more tests * Add InvalidateOAuthTokens method * Fix ExpiryDate update to default * Invalidate OAuth tokens during logout * Improve logout * Add more comments * Cleanup * Fix import order * Add error to HasOAuthEntry return values * add dev debug logs * Fix tests Co-authored-by: jguer <joao.guerreiro@grafana.com>
This commit is contained in:
@@ -213,7 +213,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
|
|||||||
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, &usertest.FakeUserService{}, sqlStore)
|
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, &usertest.FakeUserService{}, sqlStore)
|
||||||
loginService := &logintest.LoginServiceFake{}
|
loginService := &logintest.LoginServiceFake{}
|
||||||
authenticator := &logintest.AuthenticatorFake{}
|
authenticator := &logintest.AuthenticatorFake{}
|
||||||
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake())
|
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake(), nil)
|
||||||
|
|
||||||
return ctxHdlr
|
return ctxHdlr
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/bus"
|
"github.com/grafana/grafana/pkg/bus"
|
||||||
"github.com/grafana/grafana/pkg/middleware/csrf"
|
"github.com/grafana/grafana/pkg/middleware/csrf"
|
||||||
"github.com/grafana/grafana/pkg/services/folder"
|
"github.com/grafana/grafana/pkg/services/folder"
|
||||||
|
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||||
"github.com/grafana/grafana/pkg/services/querylibrary"
|
"github.com/grafana/grafana/pkg/services/querylibrary"
|
||||||
"github.com/grafana/grafana/pkg/services/searchV2"
|
"github.com/grafana/grafana/pkg/services/searchV2"
|
||||||
"github.com/grafana/grafana/pkg/services/store/object"
|
"github.com/grafana/grafana/pkg/services/store/object"
|
||||||
@@ -206,6 +207,7 @@ type HTTPServer struct {
|
|||||||
annotationsRepo annotations.Repository
|
annotationsRepo annotations.Repository
|
||||||
tagService tag.Service
|
tagService tag.Service
|
||||||
userAuthService userauth.Service
|
userAuthService userauth.Service
|
||||||
|
oauthTokenService oauthtoken.OAuthTokenService
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerOptions struct {
|
type ServerOptions struct {
|
||||||
@@ -248,6 +250,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi
|
|||||||
accesscontrolService accesscontrol.Service, dashboardThumbsService thumbs.DashboardThumbService, navTreeService navtree.Service,
|
accesscontrolService accesscontrol.Service, dashboardThumbsService thumbs.DashboardThumbService, navTreeService navtree.Service,
|
||||||
annotationRepo annotations.Repository, tagService tag.Service, searchv2HTTPService searchV2.SearchHTTPService,
|
annotationRepo annotations.Repository, tagService tag.Service, searchv2HTTPService searchV2.SearchHTTPService,
|
||||||
userAuthService userauth.Service, queryLibraryHTTPService querylibrary.HTTPService, queryLibraryService querylibrary.Service,
|
userAuthService userauth.Service, queryLibraryHTTPService querylibrary.HTTPService, queryLibraryService querylibrary.Service,
|
||||||
|
oauthTokenService oauthtoken.OAuthTokenService,
|
||||||
) (*HTTPServer, error) {
|
) (*HTTPServer, error) {
|
||||||
web.Env = cfg.Env
|
web.Env = cfg.Env
|
||||||
m := web.New()
|
m := web.New()
|
||||||
@@ -352,6 +355,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi
|
|||||||
userAuthService: userAuthService,
|
userAuthService: userAuthService,
|
||||||
QueryLibraryHTTPService: queryLibraryHTTPService,
|
QueryLibraryHTTPService: queryLibraryHTTPService,
|
||||||
QueryLibraryService: queryLibraryService,
|
QueryLibraryService: queryLibraryService,
|
||||||
|
oauthTokenService: oauthTokenService,
|
||||||
}
|
}
|
||||||
if hs.Listener != nil {
|
if hs.Listener != nil {
|
||||||
hs.log.Debug("Using provided listener")
|
hs.log.Debug("Using provided listener")
|
||||||
|
|||||||
@@ -304,6 +304,13 @@ func (hs *HTTPServer) Logout(c *models.ReqContext) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Invalidate the OAuth tokens in case the User logged in with OAuth or the last external AuthEntry is an OAuth one
|
||||||
|
if entry, exists, _ := hs.oauthTokenService.HasOAuthEntry(c.Req.Context(), c.SignedInUser); exists {
|
||||||
|
if err := hs.oauthTokenService.InvalidateOAuthTokens(c.Req.Context(), entry); err != nil {
|
||||||
|
hs.log.Warn("failed to invalidate oauth tokens for user", "userId", c.UserID, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
err := hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken, false)
|
err := hs.AuthTokenService.RevokeToken(c.Req.Context(), c.UserToken, false)
|
||||||
if err != nil && !errors.Is(err, models.ErrUserTokenNotFound) {
|
if err != nil && !errors.Is(err, models.ErrUserTokenNotFound) {
|
||||||
hs.log.Error("failed to revoke auth token", "error", err)
|
hs.log.Error("failed to revoke auth token", "error", err)
|
||||||
|
|||||||
@@ -194,7 +194,15 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
|
|||||||
// token.TokenType was defaulting to "bearer", which is out of spec, so we explicitly set to "Bearer"
|
// token.TokenType was defaulting to "bearer", which is out of spec, so we explicitly set to "Bearer"
|
||||||
token.TokenType = "Bearer"
|
token.TokenType = "Bearer"
|
||||||
|
|
||||||
oauthLogger.Debug("OAuthLogin: got token", "expiry", fmt.Sprintf("%v", token.Expiry))
|
if hs.Cfg.Env != setting.Dev {
|
||||||
|
oauthLogger.Debug("OAuthLogin: got token", "expiry", fmt.Sprintf("%v", token.Expiry))
|
||||||
|
} else {
|
||||||
|
oauthLogger.Debug("OAuthLogin: got token",
|
||||||
|
"expiry", fmt.Sprintf("%v", token.Expiry),
|
||||||
|
"access_token", fmt.Sprintf("%v", token.AccessToken),
|
||||||
|
"refresh_token", fmt.Sprintf("%v", token.RefreshToken),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// set up oauth2 client
|
// set up oauth2 client
|
||||||
client := connect.Client(oauthCtx, token)
|
client := connect.Client(oauthCtx, token)
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/models"
|
||||||
"github.com/grafana/grafana/pkg/plugins"
|
"github.com/grafana/grafana/pkg/plugins"
|
||||||
"github.com/grafana/grafana/pkg/plugins/backendplugin"
|
"github.com/grafana/grafana/pkg/plugins/backendplugin"
|
||||||
pluginClient "github.com/grafana/grafana/pkg/plugins/manager/client"
|
pluginClient "github.com/grafana/grafana/pkg/plugins/manager/client"
|
||||||
@@ -56,6 +57,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource)
|
|||||||
return ts.passThruEnabled
|
return ts.passThruEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// `/ds/query` endpoint test
|
// `/ds/query` endpoint test
|
||||||
func TestAPIEndpoint_Metrics_QueryMetricsV2(t *testing.T) {
|
func TestAPIEndpoint_Metrics_QueryMetricsV2(t *testing.T) {
|
||||||
qds := query.ProvideService(
|
qds := query.ProvideService(
|
||||||
|
|||||||
@@ -1065,3 +1065,15 @@ func (m *mockOAuthTokenService) GetCurrentOAuthToken(ctx context.Context, user *
|
|||||||
func (m *mockOAuthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
|
func (m *mockOAuthTokenService) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
|
||||||
return m.oAuthEnabled
|
return m.oAuthEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockOAuthTokenService) TryTokenRefresh(context.Context, *models.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockOAuthTokenService) InvalidateOAuthTokens(context.Context, *models.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -339,6 +340,123 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
assert.Nil(t, sc.context.UserToken)
|
assert.Nil(t, sc.context.UserToken)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
middlewareScenario(t, "Non-expired auth token in cookie and non-expired OAuth access token", func(
|
||||||
|
t *testing.T, sc *scenarioContext) {
|
||||||
|
const userID int64 = 12
|
||||||
|
sc.contextHandler.GetTime = fakeGetTime()
|
||||||
|
|
||||||
|
sc.withTokenSessionCookie("token")
|
||||||
|
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||||
|
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(11 * time.Second)}
|
||||||
|
|
||||||
|
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||||
|
return &models.UserToken{
|
||||||
|
UserId: userID,
|
||||||
|
UnhashedToken: unhashedToken,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.fakeReq("GET", "/").exec()
|
||||||
|
|
||||||
|
require.NotNil(t, sc.context)
|
||||||
|
require.NotNil(t, sc.context.UserToken)
|
||||||
|
assert.True(t, sc.context.IsSignedIn)
|
||||||
|
assert.Equal(t, userID, sc.context.UserID)
|
||||||
|
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||||
|
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||||
|
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||||
|
})
|
||||||
|
|
||||||
|
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token fails", func(
|
||||||
|
t *testing.T, sc *scenarioContext) {
|
||||||
|
const userID int64 = 12
|
||||||
|
sc.contextHandler.GetTime = fakeGetTime()
|
||||||
|
|
||||||
|
sc.withTokenSessionCookie("token")
|
||||||
|
signedInUser := &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||||
|
sc.userService.ExpectedSignedInUser = signedInUser
|
||||||
|
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{
|
||||||
|
UserId: userID,
|
||||||
|
OAuthExpiry: fakeGetTime()().Add(-1 * time.Second),
|
||||||
|
OAuthAccessToken: "access_token",
|
||||||
|
OAuthRefreshToken: "refresh_token"}
|
||||||
|
sc.oauthTokenService.ExpectedErrors = map[string]error{"TryTokenRefresh": errors.New("error")}
|
||||||
|
|
||||||
|
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||||
|
return &models.UserToken{
|
||||||
|
UserId: userID,
|
||||||
|
UnhashedToken: unhashedToken,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.fakeReq("GET", "/").exec()
|
||||||
|
|
||||||
|
token := sc.oauthTokenService.GetCurrentOAuthToken(sc.context.Req.Context(), signedInUser)
|
||||||
|
assert.Equal(t, token.AccessToken, "")
|
||||||
|
assert.Equal(t, token.RefreshToken, "")
|
||||||
|
assert.True(t, token.Expiry.IsZero())
|
||||||
|
|
||||||
|
require.NotNil(t, sc.context)
|
||||||
|
require.Nil(t, sc.context.UserToken)
|
||||||
|
assert.False(t, sc.context.IsSignedIn)
|
||||||
|
assert.Equal(t, int64(0), sc.context.UserID)
|
||||||
|
assert.Equal(t, "grafana_session=; Path=/; Max-Age=0; HttpOnly", sc.resp.Header().Get("Set-Cookie"))
|
||||||
|
})
|
||||||
|
|
||||||
|
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token succeeds", func(
|
||||||
|
t *testing.T, sc *scenarioContext) {
|
||||||
|
const userID int64 = 12
|
||||||
|
sc.contextHandler.GetTime = fakeGetTime()
|
||||||
|
|
||||||
|
sc.withTokenSessionCookie("token")
|
||||||
|
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||||
|
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(-5 * time.Second), OAuthRefreshToken: "refreshtoken"}
|
||||||
|
|
||||||
|
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||||
|
return &models.UserToken{
|
||||||
|
UserId: userID,
|
||||||
|
UnhashedToken: unhashedToken,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.fakeReq("GET", "/").exec()
|
||||||
|
|
||||||
|
require.NotNil(t, sc.context)
|
||||||
|
require.NotNil(t, sc.context.UserToken)
|
||||||
|
assert.True(t, sc.context.IsSignedIn)
|
||||||
|
assert.Equal(t, userID, sc.context.UserID)
|
||||||
|
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||||
|
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||||
|
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||||
|
})
|
||||||
|
|
||||||
|
middlewareScenario(t, "Non-expired auth token in cookie and OAuth Access Token's Expiry is not set", func(
|
||||||
|
t *testing.T, sc *scenarioContext) {
|
||||||
|
const userID int64 = 12
|
||||||
|
sc.contextHandler.GetTime = fakeGetTime()
|
||||||
|
|
||||||
|
sc.withTokenSessionCookie("token")
|
||||||
|
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||||
|
sc.oauthTokenService.ExpectedAuthUser = &models.UserAuth{UserId: userID}
|
||||||
|
|
||||||
|
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
|
||||||
|
return &models.UserToken{
|
||||||
|
UserId: userID,
|
||||||
|
UnhashedToken: unhashedToken,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.fakeReq("GET", "/").exec()
|
||||||
|
|
||||||
|
require.NotNil(t, sc.context)
|
||||||
|
require.NotNil(t, sc.context.UserToken)
|
||||||
|
assert.True(t, sc.context.IsSignedIn)
|
||||||
|
assert.Equal(t, userID, sc.context.UserID)
|
||||||
|
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||||
|
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||||
|
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||||
|
})
|
||||||
|
|
||||||
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
|
||||||
sc.mockSQLStore.ExpectedOrg = &models.Org{Id: 1, Name: sc.cfg.AnonymousOrgName}
|
sc.mockSQLStore.ExpectedOrg = &models.Org{Id: 1, Name: sc.cfg.AnonymousOrgName}
|
||||||
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
|
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
|
||||||
@@ -655,7 +773,8 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
|||||||
sc.userService = usertest.NewUserServiceFake()
|
sc.userService = usertest.NewUserServiceFake()
|
||||||
sc.orgService = orgtest.NewOrgServiceFake()
|
sc.orgService = orgtest.NewOrgServiceFake()
|
||||||
sc.apiKeyService = &apikeytest.Service{}
|
sc.apiKeyService = &apikeytest.Service{}
|
||||||
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService)
|
sc.oauthTokenService = &auth.FakeOAuthTokenService{}
|
||||||
|
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService, sc.oauthTokenService)
|
||||||
sc.sqlStore = ctxHdlr.SQLStore
|
sc.sqlStore = ctxHdlr.SQLStore
|
||||||
sc.contextHandler = ctxHdlr
|
sc.contextHandler = ctxHdlr
|
||||||
sc.m.Use(ctxHdlr.Middleware)
|
sc.m.Use(ctxHdlr.Middleware)
|
||||||
@@ -691,6 +810,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
|||||||
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock,
|
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock,
|
||||||
loginService *loginservice.LoginServiceMock, apiKeyService *apikeytest.Service,
|
loginService *loginservice.LoginServiceMock, apiKeyService *apikeytest.Service,
|
||||||
userService *usertest.FakeUserService, orgService *orgtest.FakeOrgService,
|
userService *usertest.FakeUserService, orgService *orgtest.FakeOrgService,
|
||||||
|
oauthTokenService *auth.FakeOAuthTokenService,
|
||||||
) *contexthandler.ContextHandler {
|
) *contexthandler.ContextHandler {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -708,7 +828,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.S
|
|||||||
tracer := tracing.InitializeTracerForTest()
|
tracer := tracing.InitializeTracerForTest()
|
||||||
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, userService, mockSQLStore)
|
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, userService, mockSQLStore)
|
||||||
authenticator := &logintest.AuthenticatorFake{ExpectedUser: &user.User{}}
|
authenticator := &logintest.AuthenticatorFake{ExpectedUser: &user.User{}}
|
||||||
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService)
|
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, apiKeyService, authenticator, userService, orgService, oauthTokenService)
|
||||||
}
|
}
|
||||||
|
|
||||||
type fakeRenderService struct {
|
type fakeRenderService struct {
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
|
|||||||
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
|
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
|
||||||
sc.remoteCacheService = remotecache.NewFakeStore(t)
|
sc.remoteCacheService = remotecache.NewFakeStore(t)
|
||||||
|
|
||||||
contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil)
|
contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil, nil)
|
||||||
sc.m.Use(contextHandler.Middleware)
|
sc.m.Use(contextHandler.Middleware)
|
||||||
// mock out gc goroutine
|
// mock out gc goroutine
|
||||||
sc.m.Use(OrgRedirect(cfg, sc.userService))
|
sc.m.Use(OrgRedirect(cfg, sc.userService))
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ type scenarioContext struct {
|
|||||||
loginService *loginservice.LoginServiceMock
|
loginService *loginservice.LoginServiceMock
|
||||||
apiKeyService *apikeytest.Service
|
apiKeyService *apikeytest.Service
|
||||||
userService *usertest.FakeUserService
|
userService *usertest.FakeUserService
|
||||||
|
oauthTokenService *auth.FakeOAuthTokenService
|
||||||
orgService *orgtest.FakeOrgService
|
orgService *orgtest.FakeOrgService
|
||||||
|
|
||||||
req *http.Request
|
req *http.Request
|
||||||
|
|||||||
@@ -3,9 +3,12 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/models"
|
"github.com/grafana/grafana/pkg/models"
|
||||||
|
"github.com/grafana/grafana/pkg/services/datasources"
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type FakeUserAuthTokenService struct {
|
type FakeUserAuthTokenService struct {
|
||||||
@@ -105,3 +108,46 @@ func (s *FakeUserAuthTokenService) GetUserRevokedTokens(ctx context.Context, use
|
|||||||
func (s *FakeUserAuthTokenService) BatchRevokeAllUserTokens(ctx context.Context, userIds []int64) error {
|
func (s *FakeUserAuthTokenService) BatchRevokeAllUserTokens(ctx context.Context, userIds []int64) error {
|
||||||
return s.BatchRevokedTokenProvider(ctx, userIds)
|
return s.BatchRevokedTokenProvider(ctx, userIds)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FakeOAuthTokenService struct {
|
||||||
|
passThruEnabled bool
|
||||||
|
ExpectedAuthUser *models.UserAuth
|
||||||
|
ExpectedErrors map[string]error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *FakeOAuthTokenService) GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token {
|
||||||
|
return &oauth2.Token{
|
||||||
|
AccessToken: ts.ExpectedAuthUser.OAuthAccessToken,
|
||||||
|
RefreshToken: ts.ExpectedAuthUser.OAuthRefreshToken,
|
||||||
|
Expiry: ts.ExpectedAuthUser.OAuthExpiry,
|
||||||
|
TokenType: ts.ExpectedAuthUser.OAuthTokenType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *FakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource) bool {
|
||||||
|
return ts.passThruEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *FakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||||
|
if ts.ExpectedAuthUser != nil {
|
||||||
|
return ts.ExpectedAuthUser, true, nil
|
||||||
|
}
|
||||||
|
if error, ok := ts.ExpectedErrors["HasOAuthEntry"]; ok {
|
||||||
|
return nil, false, error
|
||||||
|
}
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *FakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
ts.ExpectedAuthUser.OAuthAccessToken = ""
|
||||||
|
ts.ExpectedAuthUser.OAuthRefreshToken = ""
|
||||||
|
ts.ExpectedAuthUser.OAuthExpiry = time.Time{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *FakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
if err, ok := ts.ExpectedErrors["TryTokenRefresh"]; ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ func getContextHandler(t *testing.T) *ContextHandler {
|
|||||||
|
|
||||||
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc,
|
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc,
|
||||||
renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator,
|
renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator,
|
||||||
&userService, orgService)
|
&userService, orgService, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
type FakeGetSignUserStore struct {
|
type FakeGetSignUserStore struct {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package contexthandler
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -23,6 +24,7 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler/ctxkey"
|
"github.com/grafana/grafana/pkg/services/contexthandler/ctxkey"
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
|
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||||
"github.com/grafana/grafana/pkg/services/org"
|
"github.com/grafana/grafana/pkg/services/org"
|
||||||
"github.com/grafana/grafana/pkg/services/rendering"
|
"github.com/grafana/grafana/pkg/services/rendering"
|
||||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||||
@@ -44,39 +46,42 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtS
|
|||||||
remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore sqlstore.Store,
|
remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore sqlstore.Store,
|
||||||
tracer tracing.Tracer, authProxy *authproxy.AuthProxy, loginService login.Service,
|
tracer tracing.Tracer, authProxy *authproxy.AuthProxy, loginService login.Service,
|
||||||
apiKeyService apikey.Service, authenticator loginpkg.Authenticator, userService user.Service,
|
apiKeyService apikey.Service, authenticator loginpkg.Authenticator, userService user.Service,
|
||||||
orgService org.Service) *ContextHandler {
|
orgService org.Service, oauthTokenService oauthtoken.OAuthTokenService,
|
||||||
|
) *ContextHandler {
|
||||||
return &ContextHandler{
|
return &ContextHandler{
|
||||||
Cfg: cfg,
|
Cfg: cfg,
|
||||||
AuthTokenService: tokenService,
|
AuthTokenService: tokenService,
|
||||||
JWTAuthService: jwtService,
|
JWTAuthService: jwtService,
|
||||||
RemoteCache: remoteCache,
|
RemoteCache: remoteCache,
|
||||||
RenderService: renderService,
|
RenderService: renderService,
|
||||||
SQLStore: sqlStore,
|
SQLStore: sqlStore,
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
authProxy: authProxy,
|
authProxy: authProxy,
|
||||||
authenticator: authenticator,
|
authenticator: authenticator,
|
||||||
loginService: loginService,
|
loginService: loginService,
|
||||||
apiKeyService: apiKeyService,
|
apiKeyService: apiKeyService,
|
||||||
userService: userService,
|
userService: userService,
|
||||||
orgService: orgService,
|
orgService: orgService,
|
||||||
|
oauthTokenService: oauthTokenService,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ContextHandler is a middleware.
|
// ContextHandler is a middleware.
|
||||||
type ContextHandler struct {
|
type ContextHandler struct {
|
||||||
Cfg *setting.Cfg
|
Cfg *setting.Cfg
|
||||||
AuthTokenService models.UserTokenService
|
AuthTokenService models.UserTokenService
|
||||||
JWTAuthService models.JWTService
|
JWTAuthService models.JWTService
|
||||||
RemoteCache *remotecache.RemoteCache
|
RemoteCache *remotecache.RemoteCache
|
||||||
RenderService rendering.Service
|
RenderService rendering.Service
|
||||||
SQLStore sqlstore.Store
|
SQLStore sqlstore.Store
|
||||||
tracer tracing.Tracer
|
tracer tracing.Tracer
|
||||||
authProxy *authproxy.AuthProxy
|
authProxy *authproxy.AuthProxy
|
||||||
authenticator loginpkg.Authenticator
|
authenticator loginpkg.Authenticator
|
||||||
loginService login.Service
|
loginService login.Service
|
||||||
apiKeyService apikey.Service
|
apiKeyService apikey.Service
|
||||||
userService user.Service
|
userService user.Service
|
||||||
orgService org.Service
|
orgService org.Service
|
||||||
|
oauthTokenService oauthtoken.OAuthTokenService
|
||||||
// GetTime returns the current time.
|
// GetTime returns the current time.
|
||||||
// Stubbable by tests.
|
// Stubbable by tests.
|
||||||
GetTime func() time.Time
|
GetTime func() time.Time
|
||||||
@@ -428,6 +433,38 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getTime := h.GetTime
|
||||||
|
if getTime == nil {
|
||||||
|
getTime = time.Now
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check whether the logged in User has a token (whether the User used an OAuth provider to login)
|
||||||
|
oauthToken, exists, _ := h.oauthTokenService.HasOAuthEntry(ctx, queryResult)
|
||||||
|
if exists {
|
||||||
|
// Skip where the OAuthExpiry is default/zero/unset
|
||||||
|
if !oauthToken.OAuthExpiry.IsZero() && oauthToken.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta).Before(getTime()) {
|
||||||
|
reqContext.Logger.Info("access token expired", "userId", query.UserID, "expiry", fmt.Sprintf("%v", oauthToken.OAuthExpiry))
|
||||||
|
|
||||||
|
// If the User doesn't have a refresh_token or refreshing the token was unsuccessful then log out the User and Invalidate the OAuth tokens
|
||||||
|
if err = h.oauthTokenService.TryTokenRefresh(ctx, oauthToken); err != nil {
|
||||||
|
if !errors.Is(err, oauthtoken.ErrNoRefreshTokenFound) {
|
||||||
|
reqContext.Logger.Error("could not fetch a new access token", "userId", oauthToken.UserId, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reqContext.Resp.Before(h.deleteInvalidCookieEndOfRequestFunc(reqContext))
|
||||||
|
if err = h.oauthTokenService.InvalidateOAuthTokens(ctx, oauthToken); err != nil {
|
||||||
|
reqContext.Logger.Error("could not invalidate OAuth tokens", "userId", oauthToken.UserId, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.AuthTokenService.RevokeToken(ctx, token, false)
|
||||||
|
if err != nil && !errors.Is(err, models.ErrUserTokenNotFound) {
|
||||||
|
reqContext.Logger.Error("failed to revoke auth token", "error", err)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
reqContext.SignedInUser = queryResult
|
reqContext.SignedInUser = queryResult
|
||||||
reqContext.IsSignedIn = true
|
reqContext.IsSignedIn = true
|
||||||
reqContext.UserToken = token
|
reqContext.UserToken = token
|
||||||
|
|||||||
@@ -204,13 +204,8 @@ func (s *AuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAu
|
|||||||
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
|
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
|
||||||
}
|
}
|
||||||
|
|
||||||
cond := &models.UserAuth{
|
|
||||||
UserId: cmd.UserId,
|
|
||||||
AuthModule: cmd.AuthModule,
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error {
|
return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *sqlstore.DBSession) error {
|
||||||
upd, err := sess.Update(authUser, cond)
|
upd, err := sess.MustCols("o_auth_expiry").Where("user_id = ? AND auth_module = ?", cmd.UserId, cmd.AuthModule).Update(authUser)
|
||||||
s.logger.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_module", cmd.AuthModule, "rows", upd)
|
s.logger.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_module", cmd.AuthModule, "rows", upd)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -3,8 +3,12 @@ package oauthtoken
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
@@ -12,26 +16,39 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/datasources"
|
"github.com/grafana/grafana/pkg/services/datasources"
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
logger = log.New("oauthtoken")
|
logger = log.New("oauthtoken")
|
||||||
|
// ExpiryDelta is used to prevent any issue that is caused by the clock skew (server times can differ slightly between different machines).
|
||||||
|
// Shouldn't be more than 30s
|
||||||
|
ExpiryDelta = 10 * time.Second
|
||||||
|
ErrNoRefreshTokenFound = errors.New("no refresh token found")
|
||||||
|
ErrNotAnOAuthProvider = errors.New("not an oauth provider")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
SocialService social.Service
|
Cfg *setting.Cfg
|
||||||
AuthInfoService login.AuthInfoService
|
SocialService social.Service
|
||||||
|
AuthInfoService login.AuthInfoService
|
||||||
|
singleFlightGroup *singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthTokenService interface {
|
type OAuthTokenService interface {
|
||||||
GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token
|
GetCurrentOAuthToken(context.Context, *user.SignedInUser) *oauth2.Token
|
||||||
IsOAuthPassThruEnabled(*datasources.DataSource) bool
|
IsOAuthPassThruEnabled(*datasources.DataSource) bool
|
||||||
|
HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error)
|
||||||
|
TryTokenRefresh(context.Context, *models.UserAuth) error
|
||||||
|
InvalidateOAuthTokens(context.Context, *models.UserAuth) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService) *Service {
|
func ProvideService(socialService social.Service, authInfoService login.AuthInfoService, cfg *setting.Cfg) *Service {
|
||||||
return &Service{
|
return &Service{
|
||||||
SocialService: socialService,
|
Cfg: cfg,
|
||||||
AuthInfoService: authInfoService,
|
SocialService: socialService,
|
||||||
|
AuthInfoService: authInfoService,
|
||||||
|
singleFlightGroup: new(singleflight.Group),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,59 +63,17 @@ func (o *Service) GetCurrentOAuthToken(ctx context.Context, usr *user.SignedInUs
|
|||||||
if err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery); err != nil {
|
if err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery); err != nil {
|
||||||
if errors.Is(err, user.ErrUserNotFound) {
|
if errors.Is(err, user.ErrUserNotFound) {
|
||||||
// Not necessarily an error. User may be logged in another way.
|
// Not necessarily an error. User may be logged in another way.
|
||||||
logger.Debug("no OAuth token for user found", "userId", usr.UserID, "username", usr.Login)
|
logger.Debug("no oauth token for user found", "userId", usr.UserID, "username", usr.Login)
|
||||||
} else {
|
} else {
|
||||||
logger.Error("failed to get OAuth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
|
logger.Error("failed to get oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
authProvider := authInfoQuery.Result.AuthModule
|
token, err := o.tryGetOrRefreshAccessToken(ctx, authInfoQuery.Result)
|
||||||
connect, err := o.SocialService.GetConnector(authProvider)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("failed to get OAuth connector", "provider", authProvider, "error", err)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := o.SocialService.GetOAuthHttpClient(authProvider)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to get OAuth http client", "provider", authProvider, "error", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
|
||||||
|
|
||||||
persistedToken := &oauth2.Token{
|
|
||||||
AccessToken: authInfoQuery.Result.OAuthAccessToken,
|
|
||||||
Expiry: authInfoQuery.Result.OAuthExpiry,
|
|
||||||
RefreshToken: authInfoQuery.Result.OAuthRefreshToken,
|
|
||||||
TokenType: authInfoQuery.Result.OAuthTokenType,
|
|
||||||
}
|
|
||||||
|
|
||||||
if authInfoQuery.Result.OAuthIdToken != "" {
|
|
||||||
persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": authInfoQuery.Result.OAuthIdToken})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenSource handles refreshing the token if it has expired
|
|
||||||
token, err := connect.TokenSource(ctx, persistedToken).Token()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("failed to retrieve OAuth access token", "provider", authInfoQuery.Result.AuthModule, "userId", usr.UserID, "username", usr.Login, "error", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the tokens are not the same, update the entry in the DB
|
|
||||||
if !tokensEq(persistedToken, token) {
|
|
||||||
updateAuthCommand := &models.UpdateAuthInfoCommand{
|
|
||||||
UserId: authInfoQuery.Result.UserId,
|
|
||||||
AuthModule: authInfoQuery.Result.AuthModule,
|
|
||||||
AuthId: authInfoQuery.Result.AuthId,
|
|
||||||
OAuthToken: token,
|
|
||||||
}
|
|
||||||
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
|
|
||||||
logger.Error("failed to update auth info during token refresh", "userId", usr.UserID, "username", usr.Login, "error", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
logger.Debug("updated OAuth info for user", "userId", usr.UserID, "username", usr.Login)
|
|
||||||
}
|
|
||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,6 +82,128 @@ func (o *Service) IsOAuthPassThruEnabled(ds *datasources.DataSource) bool {
|
|||||||
return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool()
|
return ds.JsonData != nil && ds.JsonData.Get("oauthPassThru").MustBool()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HasOAuthEntry returns true and the UserAuth object when OAuth info exists for the specified User
|
||||||
|
func (o *Service) HasOAuthEntry(ctx context.Context, usr *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||||
|
if usr == nil {
|
||||||
|
// No user, therefore no token
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
authInfoQuery := &models.GetAuthInfoQuery{UserId: usr.UserID}
|
||||||
|
err := o.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, user.ErrUserNotFound) {
|
||||||
|
// Not necessarily an error. User may be logged in another way.
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
logger.Error("failed to fetch oauth token for user", "userId", usr.UserID, "username", usr.Login, "error", err)
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if !strings.Contains(authInfoQuery.Result.AuthModule, "oauth") {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
return authInfoQuery.Result, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryTokenRefresh returns an error in case the OAuth token refresh was unsuccessful
|
||||||
|
// It uses a singleflight.Group to prevent getting the Refresh Token multiple times for a given User
|
||||||
|
func (o *Service) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
lockKey := fmt.Sprintf("oauth-refresh-token-%d", usr.UserId)
|
||||||
|
_, err, _ := o.singleFlightGroup.Do(lockKey, func() (interface{}, error) {
|
||||||
|
logger.Debug("singleflight request for getting a new access token", "key", lockKey)
|
||||||
|
authProvider := usr.AuthModule
|
||||||
|
|
||||||
|
if !strings.Contains(authProvider, "oauth") {
|
||||||
|
logger.Error("the specified user's auth provider is not oauth", "authmodule", usr.AuthModule, "userid", usr.UserId)
|
||||||
|
return nil, ErrNotAnOAuthProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
if usr.OAuthRefreshToken == "" {
|
||||||
|
logger.Debug("no refresh token available", "authmodule", usr.AuthModule, "userid", usr.UserId)
|
||||||
|
return nil, ErrNoRefreshTokenFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return o.tryGetOrRefreshAccessToken(ctx, usr)
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateOAuthTokens invalidates the OAuth tokens (access_token, refresh_token) and sets the Expiry to default/zero
|
||||||
|
func (o *Service) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
return o.AuthInfoService.UpdateAuthInfo(ctx, &models.UpdateAuthInfoCommand{
|
||||||
|
UserId: usr.UserId,
|
||||||
|
AuthModule: usr.AuthModule,
|
||||||
|
AuthId: usr.AuthId,
|
||||||
|
OAuthToken: &oauth2.Token{
|
||||||
|
AccessToken: "",
|
||||||
|
RefreshToken: "",
|
||||||
|
Expiry: time.Time{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Service) tryGetOrRefreshAccessToken(ctx context.Context, usr *models.UserAuth) (*oauth2.Token, error) {
|
||||||
|
authProvider := usr.AuthModule
|
||||||
|
connect, err := o.SocialService.GetConnector(authProvider)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to get oauth connector", "provider", authProvider, "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := o.SocialService.GetOAuthHttpClient(authProvider)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to get oauth http client", "provider", authProvider, "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||||
|
|
||||||
|
persistedToken := &oauth2.Token{
|
||||||
|
AccessToken: usr.OAuthAccessToken,
|
||||||
|
Expiry: usr.OAuthExpiry,
|
||||||
|
RefreshToken: usr.OAuthRefreshToken,
|
||||||
|
TokenType: usr.OAuthTokenType,
|
||||||
|
}
|
||||||
|
|
||||||
|
if usr.OAuthIdToken != "" {
|
||||||
|
persistedToken = persistedToken.WithExtra(map[string]interface{}{"id_token": usr.OAuthIdToken})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenSource handles refreshing the token if it has expired
|
||||||
|
token, err := connect.TokenSource(ctx, persistedToken).Token()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to retrieve oauth access token", "provider", usr.AuthModule, "userId", usr.UserId, "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the tokens are not the same, update the entry in the DB
|
||||||
|
if !tokensEq(persistedToken, token) {
|
||||||
|
updateAuthCommand := &models.UpdateAuthInfoCommand{
|
||||||
|
UserId: usr.UserId,
|
||||||
|
AuthModule: usr.AuthModule,
|
||||||
|
AuthId: usr.AuthId,
|
||||||
|
OAuthToken: token,
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.Cfg.Env == setting.Dev {
|
||||||
|
logger.Debug("oauth got token",
|
||||||
|
"user", usr.UserId,
|
||||||
|
"auth_module", usr.AuthModule,
|
||||||
|
"expiry", fmt.Sprintf("%v", token.Expiry),
|
||||||
|
"access_token", fmt.Sprintf("%v", token.AccessToken),
|
||||||
|
"refresh_token", fmt.Sprintf("%v", token.RefreshToken),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := o.AuthInfoService.UpdateAuthInfo(ctx, updateAuthCommand); err != nil {
|
||||||
|
logger.Error("failed to update auth info during token refresh", "userId", usr.UserId, "error", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
logger.Debug("updated oauth info for user", "userId", usr.UserId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
|
// tokensEq checks for OAuth2 token equivalence given the fields of the struct Grafana is interested in
|
||||||
func tokensEq(t1, t2 *oauth2.Token) bool {
|
func tokensEq(t1, t2 *oauth2.Token) bool {
|
||||||
return t1.AccessToken == t2.AccessToken &&
|
return t1.AccessToken == t2.AccessToken &&
|
||||||
|
|||||||
374
pkg/services/oauthtoken/oauth_token_test.go
Normal file
374
pkg/services/oauthtoken/oauth_token_test.go
Normal file
@@ -0,0 +1,374 @@
|
|||||||
|
package oauthtoken
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||||
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
|
"github.com/grafana/grafana/pkg/models"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
|
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
|
||||||
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestService_HasOAuthEntry(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
user *user.SignedInUser
|
||||||
|
want *models.UserAuth
|
||||||
|
wantExist bool
|
||||||
|
wantErr bool
|
||||||
|
err error
|
||||||
|
getAuthInfoErr error
|
||||||
|
getAuthInfoUser models.UserAuth
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns false without an error in case user is nil",
|
||||||
|
user: nil,
|
||||||
|
want: nil,
|
||||||
|
wantExist: false,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns false and an error in case GetAuthInfo returns an error",
|
||||||
|
user: &user.SignedInUser{},
|
||||||
|
want: nil,
|
||||||
|
wantExist: false,
|
||||||
|
wantErr: true,
|
||||||
|
getAuthInfoErr: errors.New("error"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns false without an error in case auth entry is not found",
|
||||||
|
user: &user.SignedInUser{},
|
||||||
|
want: nil,
|
||||||
|
wantExist: false,
|
||||||
|
wantErr: false,
|
||||||
|
getAuthInfoErr: user.ErrUserNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns false without an error in case the auth entry is not oauth",
|
||||||
|
user: &user.SignedInUser{},
|
||||||
|
want: nil,
|
||||||
|
wantExist: false,
|
||||||
|
wantErr: false,
|
||||||
|
getAuthInfoUser: models.UserAuth{AuthModule: "auth_saml"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "returns true when the auth entry is found",
|
||||||
|
user: &user.SignedInUser{},
|
||||||
|
want: &models.UserAuth{AuthModule: "oauth_generic_oauth"},
|
||||||
|
wantExist: true,
|
||||||
|
wantErr: false,
|
||||||
|
getAuthInfoUser: models.UserAuth{AuthModule: "oauth_generic_oauth"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
srv, authInfoStore, _ := setupOAuthTokenService(t)
|
||||||
|
authInfoStore.ExpectedOAuth = &tc.getAuthInfoUser
|
||||||
|
authInfoStore.ExpectedError = tc.getAuthInfoErr
|
||||||
|
|
||||||
|
entry, exists, err := srv.HasOAuthEntry(context.Background(), tc.user)
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.want != nil {
|
||||||
|
assert.True(t, reflect.DeepEqual(tc.want, entry))
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.wantExist, exists)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService_TryTokenRefresh_ValidToken(t *testing.T) {
|
||||||
|
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: "testaccess",
|
||||||
|
RefreshToken: "testrefresh",
|
||||||
|
Expiry: time.Now(),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
usr := &models.UserAuth{
|
||||||
|
AuthModule: "oauth_generic_oauth",
|
||||||
|
OAuthAccessToken: token.AccessToken,
|
||||||
|
OAuthRefreshToken: token.RefreshToken,
|
||||||
|
OAuthExpiry: token.Expiry,
|
||||||
|
OAuthTokenType: token.TokenType,
|
||||||
|
}
|
||||||
|
|
||||||
|
authInfoStore.ExpectedOAuth = usr
|
||||||
|
|
||||||
|
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
||||||
|
|
||||||
|
err := srv.TryTokenRefresh(ctx, usr)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
|
||||||
|
|
||||||
|
authInfoQuery := &models.GetAuthInfoQuery{}
|
||||||
|
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
// User's token data had not been updated
|
||||||
|
resultUsr := authInfoQuery.Result
|
||||||
|
assert.Equal(t, resultUsr.OAuthAccessToken, token.AccessToken)
|
||||||
|
assert.Equal(t, resultUsr.OAuthExpiry, token.Expiry)
|
||||||
|
assert.Equal(t, resultUsr.OAuthRefreshToken, token.RefreshToken)
|
||||||
|
assert.Equal(t, resultUsr.OAuthTokenType, token.TokenType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService_TryTokenRefresh_NoRefreshToken(t *testing.T) {
|
||||||
|
srv, _, socialConnector := setupOAuthTokenService(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: "testaccess",
|
||||||
|
RefreshToken: "",
|
||||||
|
Expiry: time.Now().Add(-time.Hour),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
usr := &models.UserAuth{
|
||||||
|
AuthModule: "oauth_generic_oauth",
|
||||||
|
OAuthAccessToken: token.AccessToken,
|
||||||
|
OAuthRefreshToken: token.RefreshToken,
|
||||||
|
OAuthExpiry: token.Expiry,
|
||||||
|
OAuthTokenType: token.TokenType,
|
||||||
|
}
|
||||||
|
|
||||||
|
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
||||||
|
|
||||||
|
err := srv.TryTokenRefresh(ctx, usr)
|
||||||
|
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNoRefreshTokenFound)
|
||||||
|
|
||||||
|
socialConnector.AssertNotCalled(t, "TokenSource")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService_TryTokenRefresh_ExpiredToken(t *testing.T) {
|
||||||
|
srv, authInfoStore, socialConnector := setupOAuthTokenService(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: "testaccess",
|
||||||
|
RefreshToken: "testrefresh",
|
||||||
|
Expiry: time.Now().Add(-time.Hour),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
|
||||||
|
newToken := &oauth2.Token{
|
||||||
|
AccessToken: "testaccess_new",
|
||||||
|
RefreshToken: "testrefresh_new",
|
||||||
|
Expiry: time.Now().Add(time.Hour),
|
||||||
|
TokenType: "Bearer",
|
||||||
|
}
|
||||||
|
|
||||||
|
usr := &models.UserAuth{
|
||||||
|
AuthModule: "oauth_generic_oauth",
|
||||||
|
OAuthAccessToken: token.AccessToken,
|
||||||
|
OAuthRefreshToken: token.RefreshToken,
|
||||||
|
OAuthExpiry: token.Expiry,
|
||||||
|
OAuthTokenType: token.TokenType,
|
||||||
|
}
|
||||||
|
|
||||||
|
authInfoStore.ExpectedOAuth = usr
|
||||||
|
|
||||||
|
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.ReuseTokenSource(token, oauth2.StaticTokenSource(newToken)), nil)
|
||||||
|
|
||||||
|
err := srv.TryTokenRefresh(ctx, usr)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
socialConnector.AssertNumberOfCalls(t, "TokenSource", 1)
|
||||||
|
|
||||||
|
authInfoQuery := &models.GetAuthInfoQuery{}
|
||||||
|
err = srv.AuthInfoService.GetAuthInfo(ctx, authInfoQuery)
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
// newToken should be returned after the .Token() call, therefore the User had to be updated
|
||||||
|
assert.Equal(t, authInfoQuery.Result.OAuthAccessToken, newToken.AccessToken)
|
||||||
|
assert.Equal(t, authInfoQuery.Result.OAuthExpiry, newToken.Expiry)
|
||||||
|
assert.Equal(t, authInfoQuery.Result.OAuthRefreshToken, newToken.RefreshToken)
|
||||||
|
assert.Equal(t, authInfoQuery.Result.OAuthTokenType, newToken.TokenType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestService_TryTokenRefresh_DifferentAuthModuleForUser(t *testing.T) {
|
||||||
|
srv, _, socialConnector := setupOAuthTokenService(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
token := &oauth2.Token{}
|
||||||
|
usr := &models.UserAuth{
|
||||||
|
AuthModule: "auth.saml",
|
||||||
|
}
|
||||||
|
|
||||||
|
socialConnector.On("TokenSource", mock.Anything, mock.Anything).Return(oauth2.StaticTokenSource(token))
|
||||||
|
|
||||||
|
err := srv.TryTokenRefresh(ctx, usr)
|
||||||
|
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrNotAnOAuthProvider)
|
||||||
|
|
||||||
|
socialConnector.AssertNotCalled(t, "TokenSource")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupOAuthTokenService(t *testing.T) (*Service, *FakeAuthInfoStore, *MockSocialConnector) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
socialConnector := &MockSocialConnector{}
|
||||||
|
socialService := &FakeSocialService{
|
||||||
|
connector: socialConnector,
|
||||||
|
}
|
||||||
|
|
||||||
|
authInfoStore := &FakeAuthInfoStore{}
|
||||||
|
authInfoService := authinfoservice.ProvideAuthInfoService(nil, authInfoStore, &usagestats.UsageStatsMock{})
|
||||||
|
return &Service{
|
||||||
|
Cfg: setting.NewCfg(),
|
||||||
|
SocialService: socialService,
|
||||||
|
AuthInfoService: authInfoService,
|
||||||
|
singleFlightGroup: &singleflight.Group{},
|
||||||
|
}, authInfoStore, socialConnector
|
||||||
|
}
|
||||||
|
|
||||||
|
type FakeSocialService struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
connector *MockSocialConnector
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fss *FakeSocialService) GetOAuthProviders() map[string]bool {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fss *FakeSocialService) GetOAuthHttpClient(string) (*http.Client, error) {
|
||||||
|
return fss.httpClient, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fss *FakeSocialService) GetConnector(string) (social.SocialConnector, error) {
|
||||||
|
return fss.connector, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fss *FakeSocialService) GetOAuthInfoProvider(string) *social.OAuthInfo {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fss *FakeSocialService) GetOAuthInfoProviders() map[string]*social.OAuthInfo {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockSocialConnector struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocialConnector) Type() int {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Int(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocialConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
||||||
|
args := m.Called(client, token)
|
||||||
|
return args.Get(0).(*social.BasicUserInfo), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocialConnector) IsEmailAllowed(email string) bool {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocialConnector) IsSignupAllowed() bool {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocialConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocialConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocialConnector) Client(ctx context.Context, t *oauth2.Token) *http.Client {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSocialConnector) TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource {
|
||||||
|
args := m.Called(ctx, t)
|
||||||
|
return args.Get(0).(oauth2.TokenSource)
|
||||||
|
}
|
||||||
|
|
||||||
|
type FakeAuthInfoStore struct {
|
||||||
|
ExpectedError error
|
||||||
|
ExpectedUser *user.User
|
||||||
|
ExpectedOAuth *models.UserAuth
|
||||||
|
ExpectedDuplicateUserEntries int
|
||||||
|
ExpectedHasDuplicateUserEntries int
|
||||||
|
ExpectedLoginStats login.LoginStats
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *models.GetExternalUserInfoByLoginQuery) error {
|
||||||
|
return f.ExpectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *models.GetAuthInfoQuery) error {
|
||||||
|
query.Result = f.ExpectedOAuth
|
||||||
|
return f.ExpectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error {
|
||||||
|
return f.ExpectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *models.UserAuth) error {
|
||||||
|
return f.ExpectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error {
|
||||||
|
f.ExpectedOAuth.OAuthAccessToken = cmd.OAuthToken.AccessToken
|
||||||
|
f.ExpectedOAuth.OAuthExpiry = cmd.OAuthToken.Expiry
|
||||||
|
f.ExpectedOAuth.OAuthTokenType = cmd.OAuthToken.TokenType
|
||||||
|
f.ExpectedOAuth.OAuthRefreshToken = cmd.OAuthToken.RefreshToken
|
||||||
|
return f.ExpectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAuthInfoCommand) error {
|
||||||
|
return f.ExpectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) {
|
||||||
|
return f.ExpectedUser, f.ExpectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) {
|
||||||
|
return f.ExpectedUser, f.ExpectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
|
||||||
|
return f.ExpectedUser, f.ExpectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]interface{}, error) {
|
||||||
|
var res = make(map[string]interface{})
|
||||||
|
res["stats.users.duplicate_user_entries"] = f.ExpectedDuplicateUserEntries
|
||||||
|
res["stats.users.has_duplicate_user_entries"] = f.ExpectedHasDuplicateUserEntries
|
||||||
|
res["stats.users.duplicate_user_entries_by_login"] = 0
|
||||||
|
res["stats.users.has_duplicate_user_entries_by_login"] = 0
|
||||||
|
res["stats.users.duplicate_user_entries_by_email"] = 0
|
||||||
|
res["stats.users.has_duplicate_user_entries_by_email"] = 0
|
||||||
|
res["stats.users.mixed_cased_users"] = f.ExpectedLoginStats.MixedCasedUsers
|
||||||
|
return res, f.ExpectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) RunMetricsCollection(ctx context.Context) error {
|
||||||
|
return f.ExpectedError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FakeAuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, error) {
|
||||||
|
return f.ExpectedLoginStats, f.ExpectedError
|
||||||
|
}
|
||||||
@@ -164,6 +164,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource)
|
|||||||
return ts.passThruEnabled
|
return ts.passThruEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *fakeOAuthTokenService) TryTokenRefresh(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(ctx context.Context, usr *models.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// copied from pkg/api/plugins_test.go
|
// copied from pkg/api/plugins_test.go
|
||||||
type fakePluginClient struct {
|
type fakePluginClient struct {
|
||||||
plugins.Client
|
plugins.Client
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
"github.com/grafana/grafana-plugin-sdk-go/backend"
|
||||||
"github.com/grafana/grafana/pkg/expr"
|
"github.com/grafana/grafana/pkg/expr"
|
||||||
|
"github.com/grafana/grafana/pkg/models"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -440,6 +441,18 @@ func (ts *fakeOAuthTokenService) IsOAuthPassThruEnabled(*datasources.DataSource)
|
|||||||
return ts.passThruEnabled
|
return ts.passThruEnabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts *fakeOAuthTokenService) HasOAuthEntry(context.Context, *user.SignedInUser) (*models.UserAuth, bool, error) {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *fakeOAuthTokenService) TryTokenRefresh(context.Context, *models.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *fakeOAuthTokenService) InvalidateOAuthTokens(context.Context, *models.UserAuth) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type fakeDataSourceCache struct {
|
type fakeDataSourceCache struct {
|
||||||
ds *datasources.DataSource
|
ds *datasources.DataSource
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user