From 7c1d9769ca5f658d2e1b61d9b07f201ea4231142 Mon Sep 17 00:00:00 2001 From: Misi Date: Thu, 2 Feb 2023 14:36:16 +0100 Subject: [PATCH] Auth: Rotate token patch (#62676) * Use singleflight.Group * Align tests * Cleanup --- pkg/middleware/middleware_test.go | 7 +- pkg/services/auth/auth.go | 2 +- pkg/services/auth/authimpl/auth_token.go | 142 +++++++++++------- pkg/services/auth/authimpl/auth_token_test.go | 26 ++-- pkg/services/auth/authtest/testing.go | 8 +- pkg/services/authn/clients/session.go | 3 +- pkg/services/authn/clients/session_test.go | 4 +- pkg/services/contexthandler/contexthandler.go | 10 +- .../contexthandler/contexthandler_test.go | 8 +- 9 files changed, 122 insertions(+), 88 deletions(-) diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index eaabf0085d1..bb606ebbf76 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -13,10 +13,11 @@ import ( "testing" "time" - "github.com/grafana/grafana-plugin-sdk-go/backend/gtime" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/grafana/grafana-plugin-sdk-go/backend/gtime" + "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/infra/db/dbtest" "github.com/grafana/grafana/pkg/infra/fs" @@ -344,9 +345,9 @@ func TestMiddlewareContext(t *testing.T) { } sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *auth.UserToken, - clientIP net.IP, userAgent string) (bool, error) { + clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) { userToken.UnhashedToken = "rotated" - return true, nil + return true, userToken, nil } maxAge := int(sc.cfg.LoginMaxLifetime.Seconds()) diff --git a/pkg/services/auth/auth.go b/pkg/services/auth/auth.go index 356e2426de1..1bf08cc3d31 100644 --- a/pkg/services/auth/auth.go +++ b/pkg/services/auth/auth.go @@ -62,7 +62,7 @@ type RevokeAuthTokenCmd struct { type UserTokenService interface { CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*UserToken, error) LookupToken(ctx context.Context, unhashedToken string) (*UserToken, error) - TryRotateToken(ctx context.Context, token *UserToken, clientIP net.IP, userAgent string) (bool, error) + TryRotateToken(ctx context.Context, token *UserToken, clientIP net.IP, userAgent string) (bool, *UserToken, error) RevokeToken(ctx context.Context, token *UserToken, soft bool) error RevokeAllUserTokens(ctx context.Context, userId int64) error GetUserToken(ctx context.Context, userId, userTokenId int64) (*UserToken, error) diff --git a/pkg/services/auth/authimpl/auth_token.go b/pkg/services/auth/authimpl/auth_token.go index 991143f538a..4f9061d4c79 100644 --- a/pkg/services/auth/authimpl/auth_token.go +++ b/pkg/services/auth/authimpl/auth_token.go @@ -5,10 +5,13 @@ import ( "crypto/sha256" "encoding/hex" "errors" + "fmt" "net" "strings" "time" + "golang.org/x/sync/singleflight" + "github.com/grafana/grafana/pkg/infra/db" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/remotecache" @@ -41,6 +44,7 @@ func ProvideUserAuthTokenService(sqlStore db.DB, log: log.New("auth"), remoteCache: remoteCache, features: features, + singleflight: new(singleflight.Group), } defaultLimits, err := readQuotaConfig(cfg) @@ -68,6 +72,7 @@ type UserAuthTokenService struct { log log.Logger remoteCache *remotecache.RemoteCache features *featuremgmt.FeatureManager + singleflight *singleflight.Group } func (s *UserAuthTokenService) CreateToken(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error) { @@ -202,6 +207,7 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st } } + // Current incoming token is the previous auth token in the DB and the auth_token_seen is true if model.AuthToken != hashedToken && model.PrevAuthToken == hashedToken && model.AuthTokenSeen { modelCopy := model modelCopy.AuthTokenSeen = false @@ -229,6 +235,7 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st } } + // Current incoming token is not seen and it is the latest valid auth token in the db if !model.AuthTokenSeen && model.AuthToken == hashedToken { modelCopy := model modelCopy.AuthTokenSeen = true @@ -268,83 +275,102 @@ func (s *UserAuthTokenService) lookupToken(ctx context.Context, unhashedToken st } func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.UserToken, - clientIP net.IP, userAgent string) (bool, error) { + clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) { if token == nil { - return false, nil + return false, nil, nil } model, err := userAuthTokenFromUserToken(token) if err != nil { - return false, err + return false, nil, err } now := getTime() - var needsRotation bool - rotatedAt := time.Unix(model.RotatedAt, 0) - if model.AuthTokenSeen { - needsRotation = rotatedAt.Before(now.Add(-time.Duration(s.cfg.TokenRotationIntervalMinutes) * time.Minute)) - } else { - needsRotation = rotatedAt.Before(now.Add(-urgentRotateTime)) + type rotationResult struct { + rotated bool + newToken *auth.UserToken } - if !needsRotation { - return false, nil - } - - ctxLogger := s.log.FromContext(ctx) - ctxLogger.Debug("token needs rotation", "tokenId", model.Id, "authTokenSeen", model.AuthTokenSeen, "rotatedAt", rotatedAt) - - clientIPStr := clientIP.String() - if len(clientIP) == 0 { - clientIPStr = "" - } - newToken, err := util.RandomHex(16) - if err != nil { - return false, err - } - hashedToken := hashToken(newToken) - - // very important that auth_token_seen is set after the prev_auth_token = case when ... for mysql to function correctly - sql := ` - UPDATE user_auth_token - SET - seen_at = 0, - user_agent = ?, - client_ip = ?, - prev_auth_token = case when auth_token_seen = ? then auth_token else prev_auth_token end, - auth_token = ?, - auth_token_seen = ?, - rotated_at = ? - WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)` - - var affected int64 - err = s.sqlStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error { - res, err := dbSession.Exec(sql, userAgent, clientIPStr, s.sqlStore.GetDialect().BooleanStr(true), hashedToken, - s.sqlStore.GetDialect().BooleanStr(false), now.Unix(), model.Id, s.sqlStore.GetDialect().BooleanStr(true), - now.Add(-30*time.Second).Unix()) - if err != nil { - return err + rotResult, err, _ := s.singleflight.Do(fmt.Sprint(model.Id), func() (interface{}, error) { + var needsRotation bool + rotatedAt := time.Unix(model.RotatedAt, 0) + if model.AuthTokenSeen { + needsRotation = rotatedAt.Before(now.Add(-time.Duration(s.cfg.TokenRotationIntervalMinutes) * time.Minute)) + } else { + needsRotation = rotatedAt.Before(now.Add(-urgentRotateTime)) } - affected, err = res.RowsAffected() - return err + if !needsRotation { + return &rotationResult{rotated: false}, nil + } + + ctxLogger := s.log.FromContext(ctx) + ctxLogger.Debug("token needs rotation", "tokenId", model.Id, "authTokenSeen", model.AuthTokenSeen, "rotatedAt", rotatedAt) + + clientIPStr := clientIP.String() + if len(clientIP) == 0 { + clientIPStr = "" + } + newToken, err := util.RandomHex(16) + if err != nil { + return nil, err + } + hashedToken := hashToken(newToken) + + // very important that auth_token_seen is set after the prev_auth_token = case when ... for mysql to function correctly + sql := ` + UPDATE user_auth_token + SET + seen_at = 0, + user_agent = ?, + client_ip = ?, + prev_auth_token = case when auth_token_seen = ? then auth_token else prev_auth_token end, + auth_token = ?, + auth_token_seen = ?, + rotated_at = ? + WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)` + + var affected int64 + err = s.sqlStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error { + res, err := dbSession.Exec(sql, userAgent, clientIPStr, s.sqlStore.GetDialect().BooleanStr(true), hashedToken, + s.sqlStore.GetDialect().BooleanStr(false), now.Unix(), model.Id, s.sqlStore.GetDialect().BooleanStr(true), + now.Add(-30*time.Second).Unix()) + if err != nil { + return err + } + + affected, err = res.RowsAffected() + return err + }) + + if err != nil { + return nil, err + } + + if affected > 0 { + ctxLogger.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId) + model.UnhashedToken = newToken + var result auth.UserToken + if err := model.toUserToken(&result); err != nil { + return nil, err + } + return &rotationResult{ + rotated: true, + newToken: &result, + }, nil + } + + return &rotationResult{rotated: false}, nil }) if err != nil { - return false, err + return false, nil, err } - ctxLogger.Debug("auth token rotated", "affected", affected, "auth_token_id", model.Id, "userId", model.UserId) - if affected > 0 { - model.UnhashedToken = newToken - if err := model.toUserToken(token); err != nil { - return false, err - } - return true, nil - } + result := rotResult.(*rotationResult) - return false, nil + return result.rotated, result.newToken, nil } func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *auth.UserToken, soft bool) error { diff --git a/pkg/services/auth/authimpl/auth_token_test.go b/pkg/services/auth/authimpl/auth_token_test.go index 4220a30d9cf..4c919e3dc40 100644 --- a/pkg/services/auth/authimpl/auth_token_test.go +++ b/pkg/services/auth/authimpl/auth_token_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/stretchr/testify/require" + "golang.org/x/sync/singleflight" "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/infra/db" @@ -178,7 +179,7 @@ func TestUserAuthToken(t *testing.T) { getTime = func() time.Time { return now.Add(time.Hour) } - rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, + rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, net.ParseIP("192.168.10.11"), "some user agent") require.Nil(t, err) require.True(t, rotated) @@ -262,7 +263,7 @@ func TestUserAuthToken(t *testing.T) { prevToken := userToken.AuthToken unhashedPrev := userToken.UnhashedToken - rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, + rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, net.ParseIP("192.168.10.12"), "a new user agent") require.Nil(t, err) require.False(t, rotated) @@ -280,12 +281,12 @@ func TestUserAuthToken(t *testing.T) { getTime = func() time.Time { return now.Add(time.Hour) } - rotated, err = ctx.tokenService.TryRotateToken(context.Background(), &tok, + rotated, newToken, err := ctx.tokenService.TryRotateToken(context.Background(), &tok, net.ParseIP("192.168.10.12"), "a new user agent") require.Nil(t, err) require.True(t, rotated) - unhashedToken := tok.UnhashedToken + unhashedToken := newToken.UnhashedToken model, err = ctx.getAuthTokenByID(tok.Id) require.Nil(t, err) @@ -326,7 +327,7 @@ func TestUserAuthToken(t *testing.T) { require.NotNil(t, lookedUpModel) require.False(t, lookedUpModel.AuthTokenSeen) - rotated, err = ctx.tokenService.TryRotateToken(context.Background(), userToken, + rotated, _, err = ctx.tokenService.TryRotateToken(context.Background(), userToken, net.ParseIP("192.168.10.12"), "a new user agent") require.Nil(t, err) require.True(t, rotated) @@ -351,7 +352,7 @@ func TestUserAuthToken(t *testing.T) { getTime = func() time.Time { return now.Add(10 * time.Minute) } prevToken := userToken.UnhashedToken - rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, + rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, net.ParseIP("1.1.1.1"), "firefox") require.Nil(t, err) require.True(t, rotated) @@ -407,7 +408,7 @@ func TestUserAuthToken(t *testing.T) { return now.Add(10 * time.Minute) } - rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, + rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, net.ParseIP("1.1.1.1"), "firefox") require.Nil(t, err) require.True(t, rotated) @@ -429,7 +430,7 @@ func TestUserAuthToken(t *testing.T) { return now.Add(20 * time.Minute) } - rotated, err = ctx.tokenService.TryRotateToken(context.Background(), userToken, + rotated, _, err = ctx.tokenService.TryRotateToken(context.Background(), userToken, net.ParseIP("1.1.1.1"), "firefox") require.Nil(t, err) require.True(t, rotated) @@ -456,7 +457,7 @@ func TestUserAuthToken(t *testing.T) { return now.Add(2 * time.Minute) } - rotated, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, + rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, net.ParseIP("1.1.1.1"), "firefox") require.Nil(t, err) require.True(t, rotated) @@ -550,9 +551,10 @@ func createTestContext(t *testing.T) *testContext { } tokenService := &UserAuthTokenService{ - sqlStore: sqlstore, - cfg: cfg, - log: log.New("test-logger"), + sqlStore: sqlstore, + cfg: cfg, + log: log.New("test-logger"), + singleflight: new(singleflight.Group), } return &testContext{ diff --git a/pkg/services/auth/authtest/testing.go b/pkg/services/auth/authtest/testing.go index 6a924c5ebf4..a092c6e997e 100644 --- a/pkg/services/auth/authtest/testing.go +++ b/pkg/services/auth/authtest/testing.go @@ -15,7 +15,7 @@ import ( type FakeUserAuthTokenService struct { CreateTokenProvider func(ctx context.Context, user *user.User, clientIP net.IP, userAgent string) (*auth.UserToken, error) - TryRotateTokenProvider func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) + TryRotateTokenProvider func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) LookupTokenProvider func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) RevokeTokenProvider func(ctx context.Context, token *auth.UserToken, soft bool) error RevokeAllUserTokensProvider func(ctx context.Context, userId int64) error @@ -34,8 +34,8 @@ func NewFakeUserAuthTokenService() *FakeUserAuthTokenService { UnhashedToken: "", }, nil }, - TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) { - return false, nil + TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) { + return false, nil, nil }, LookupTokenProvider: func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) { return &auth.UserToken{ @@ -79,7 +79,7 @@ func (s *FakeUserAuthTokenService) LookupToken(ctx context.Context, unhashedToke } func (s *FakeUserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.UserToken, clientIP net.IP, - userAgent string) (bool, error) { + userAgent string) (bool, *auth.UserToken, error) { return s.TryRotateTokenProvider(context.Background(), token, clientIP, userAgent) } diff --git a/pkg/services/authn/clients/session.go b/pkg/services/authn/clients/session.go index 8c26fb60332..2af8d35bc72 100644 --- a/pkg/services/authn/clients/session.go +++ b/pkg/services/authn/clients/session.go @@ -107,13 +107,14 @@ func (s *Session) RefreshTokenHook(ctx context.Context, identity *authn.Identity s.log.Debug("failed to get client IP address", "addr", addr, "err", err) ip = nil } - rotated, err := s.sessionService.TryRotateToken(ctx, identity.SessionToken, ip, userAgent) + rotated, newToken, err := s.sessionService.TryRotateToken(ctx, identity.SessionToken, ip, userAgent) if err != nil { s.log.Error("failed to rotate token", "error", err) return } if rotated { + identity.SessionToken = newToken s.log.Debug("rotated session token", "user", identity.ID) maxAge := int(s.loginMaxLifetime.Seconds()) diff --git a/pkg/services/authn/clients/session_test.go b/pkg/services/authn/clients/session_test.go index 443f6371a6c..4809c03fb54 100644 --- a/pkg/services/authn/clients/session_test.go +++ b/pkg/services/authn/clients/session_test.go @@ -143,9 +143,9 @@ func (f *fakeResponseWriter) WriteHeader(statusCode int) { func TestSession_RefreshHook(t *testing.T) { s := ProvideSession(&authtest.FakeUserAuthTokenService{ - TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, error) { + TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) { token.UnhashedToken = "new-token" - return true, nil + return true, token, nil }, }, &usertest.FakeUserService{}, "grafana-session", 20*time.Second) diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index 14f3398ec85..530047f2e25 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "golang.org/x/sync/singleflight" + "github.com/grafana/grafana/pkg/components/apikeygen" apikeygenprefix "github.com/grafana/grafana/pkg/components/apikeygenprefixed" "github.com/grafana/grafana/pkg/infra/db" @@ -70,6 +72,7 @@ func ProvideService(cfg *setting.Cfg, tokenService auth.UserTokenService, jwtSer oauthTokenService: oauthTokenService, features: features, authnService: authnService, + singleflight: new(singleflight.Group), } } @@ -91,6 +94,7 @@ type ContextHandler struct { oauthTokenService oauthtoken.OAuthTokenService features *featuremgmt.FeatureManager authnService authn.Service + singleflight *singleflight.Group // GetTime returns the current time. // Stubbable by tests. GetTime func() time.Time @@ -568,15 +572,15 @@ func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *contextmodel.ReqCont ip = nil } - // FIXME (jguer): rotation should return a new token instead of modifying the existing one. - rotated, err := h.AuthTokenService.TryRotateToken(ctx, reqContext.UserToken, ip, reqContext.Req.UserAgent()) + rotated, newToken, err := h.AuthTokenService.TryRotateToken(ctx, reqContext.UserToken, ip, reqContext.Req.UserAgent()) if err != nil { reqContext.Logger.Error("Failed to rotate token", "error", err) return } if rotated { - cookies.WriteSessionCookie(reqContext, h.Cfg, reqContext.UserToken.UnhashedToken, h.Cfg.LoginMaxLifetime) + reqContext.UserToken = newToken + cookies.WriteSessionCookie(reqContext, h.Cfg, newToken.UnhashedToken, h.Cfg.LoginMaxLifetime) } } } diff --git a/pkg/services/contexthandler/contexthandler_test.go b/pkg/services/contexthandler/contexthandler_test.go index efdb5ccb275..145fb3c64de 100644 --- a/pkg/services/contexthandler/contexthandler_test.go +++ b/pkg/services/contexthandler/contexthandler_test.go @@ -24,9 +24,9 @@ func TestDontRotateTokensOnCancelledRequests(t *testing.T) { tryRotateCallCount := 0 ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{ TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, - userAgent string) (bool, error) { + userAgent string) (bool, *auth.UserToken, error) { tryRotateCallCount++ - return false, nil + return false, nil, nil }, } @@ -46,11 +46,11 @@ func TestTokenRotationAtEndOfRequest(t *testing.T) { ctxHdlr := getContextHandler(t) ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{ TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, - userAgent string) (bool, error) { + userAgent string) (bool, *auth.UserToken, error) { newToken, err := util.RandomHex(16) require.NoError(t, err) token.AuthToken = newToken - return true, nil + return true, token, nil }, }