diff --git a/pkg/services/auth/auth.go b/pkg/services/auth/auth.go index 4577f19994e..12481ab0b77 100644 --- a/pkg/services/auth/auth.go +++ b/pkg/services/auth/auth.go @@ -71,7 +71,6 @@ type UserTokenService interface { LookupToken(ctx context.Context, unhashedToken string) (*UserToken, error) // RotateToken will always rotate a valid token RotateToken(ctx context.Context, cmd RotateCommand) (*UserToken, error) - TryRotateToken(ctx context.Context, token *UserToken, clientIP net.IP, userAgent string) (bool, *UserToken, error) RevokeToken(ctx context.Context, token *UserToken, soft bool) error RevokeAllUserTokens(ctx context.Context, userID int64) error GetUserToken(ctx context.Context, userID, userTokenID int64) (*UserToken, error) diff --git a/pkg/services/auth/authimpl/auth_token.go b/pkg/services/auth/authimpl/auth_token.go index 2b28327e6d5..1c51e07982d 100644 --- a/pkg/services/auth/authimpl/auth_token.go +++ b/pkg/services/auth/authimpl/auth_token.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/hex" "errors" - "fmt" "net" "strings" "time" @@ -296,105 +295,6 @@ func (s *UserAuthTokenService) rotateToken(ctx context.Context, token *auth.User return token, nil } -func (s *UserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.UserToken, - clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) { - if token == nil { - return false, nil, nil - } - - model, err := userAuthTokenFromUserToken(token) - if err != nil { - return false, nil, err - } - - now := getTime() - - type rotationResult struct { - rotated bool - newToken *auth.UserToken - } - - rotResult, err, _ := s.singleflight.Do(fmt.Sprint(model.Id), func() (any, error) { - var needsRotation bool - rotatedAt := time.Unix(model.RotatedAt, 0) - if model.AuthTokenSeen { - needsRotation = rotatedAt.Before(now.Add(-time.Duration(s.cfg.TokenRotationIntervalMinutes) * time.Minute)) - } else { - needsRotation = rotatedAt.Before(now.Add(-usertoken.UrgentRotateTime)) - } - - if !needsRotation { - return &rotationResult{rotated: false}, nil - } - - ctxLogger := s.log.FromContext(ctx) - ctxLogger.Debug("Token needs rotation", "tokenID", model.Id, "authTokenSeen", model.AuthTokenSeen, "rotatedAt", rotatedAt) - - clientIPStr := clientIP.String() - if len(clientIP) == 0 { - clientIPStr = "" - } - newToken, err := util.RandomHex(16) - if err != nil { - return nil, err - } - hashedToken := hashToken(s.cfg.SecretKey, newToken) - - // very important that auth_token_seen is set after the prev_auth_token = case when ... for mysql to function correctly - sql := ` - UPDATE user_auth_token - SET - seen_at = 0, - user_agent = ?, - client_ip = ?, - prev_auth_token = case when auth_token_seen = ? then auth_token else prev_auth_token end, - auth_token = ?, - auth_token_seen = ?, - rotated_at = ? - WHERE id = ? AND (auth_token_seen = ? OR rotated_at < ?)` - - var affected int64 - err = s.sqlStore.WithTransactionalDbSession(ctx, func(dbSession *db.Session) error { - res, err := dbSession.Exec(sql, userAgent, clientIPStr, s.sqlStore.GetDialect().BooleanStr(true), hashedToken, - s.sqlStore.GetDialect().BooleanStr(false), now.Unix(), model.Id, s.sqlStore.GetDialect().BooleanStr(true), - now.Add(-30*time.Second).Unix()) - if err != nil { - return err - } - - affected, err = res.RowsAffected() - return err - }) - - if err != nil { - return nil, err - } - - if affected > 0 { - ctxLogger.Debug("Auth token rotated", "affected", affected, "tokenID", model.Id, "userID", model.UserId) - model.UnhashedToken = newToken - var result auth.UserToken - if err := model.toUserToken(&result); err != nil { - return nil, err - } - return &rotationResult{ - rotated: true, - newToken: &result, - }, nil - } - - return &rotationResult{rotated: false}, nil - }) - - if err != nil { - return false, nil, err - } - - result := rotResult.(*rotationResult) - - return result.rotated, result.newToken, nil -} - func (s *UserAuthTokenService) RevokeToken(ctx context.Context, token *auth.UserToken, soft bool) error { if token == nil { return auth.ErrUserTokenNotFound diff --git a/pkg/services/auth/authimpl/auth_token_test.go b/pkg/services/auth/authimpl/auth_token_test.go index b0a8beb94f1..97154f1f0c2 100644 --- a/pkg/services/auth/authimpl/auth_token_test.go +++ b/pkg/services/auth/authimpl/auth_token_test.go @@ -184,10 +184,12 @@ func TestIntegrationUserAuthToken(t *testing.T) { getTime = func() time.Time { return now.Add(time.Hour) } - rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, - net.ParseIP("192.168.10.11"), "some user agent") + _, err = ctx.tokenService.RotateToken(context.Background(), auth.RotateCommand{ + UnHashedToken: userToken.UnhashedToken, + IP: net.ParseIP("192.168.10.11"), + UserAgent: "some user agent", + }) require.Nil(t, err) - require.True(t, rotated) userToken, err = ctx.tokenService.LookupToken(context.Background(), userToken.UnhashedToken) require.Nil(t, err) @@ -260,41 +262,28 @@ func TestIntegrationUserAuthToken(t *testing.T) { t.Run("can properly rotate tokens", func(t *testing.T) { getTime = func() time.Time { return now } ctx := createTestContext(t) - userToken, err := ctx.tokenService.CreateToken(context.Background(), usr, - net.ParseIP("192.168.10.11"), "some user agent") + userToken, err := ctx.tokenService.CreateToken(context.Background(), usr, net.ParseIP("192.168.10.11"), "some user agent") require.Nil(t, err) prevToken := userToken.AuthToken unhashedPrev := userToken.UnhashedToken - rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, - net.ParseIP("192.168.10.12"), "a new user agent") - require.Nil(t, err) - require.False(t, rotated) - - updated, err := ctx.markAuthTokenAsSeen(userToken.Id) - require.Nil(t, err) - require.True(t, updated) - model, err := ctx.getAuthTokenByID(userToken.Id) require.Nil(t, err) - var tok auth.UserToken - err = model.toUserToken(&tok) - require.Nil(t, err) - + model.UnhashedToken = userToken.UnhashedToken getTime = func() time.Time { return now.Add(time.Hour) } - rotated, newToken, err := ctx.tokenService.TryRotateToken(context.Background(), &tok, - net.ParseIP("192.168.10.12"), "a new user agent") + newToken, err := ctx.tokenService.RotateToken(context.Background(), auth.RotateCommand{ + UnHashedToken: model.UnhashedToken, + IP: net.ParseIP("192.168.10.12"), + UserAgent: "a new user agent", + }) require.Nil(t, err) - require.True(t, rotated) - unhashedToken := newToken.UnhashedToken - - model, err = ctx.getAuthTokenByID(tok.Id) + model, err = ctx.getAuthTokenByID(model.Id) require.Nil(t, err) - model.UnhashedToken = unhashedToken + model.UnhashedToken = newToken.UnhashedToken require.Equal(t, getTime().Unix(), model.RotatedAt) require.Equal(t, "192.168.10.12", model.ClientIp) @@ -331,10 +320,12 @@ func TestIntegrationUserAuthToken(t *testing.T) { require.NotNil(t, lookedUpModel) require.False(t, lookedUpModel.AuthTokenSeen) - rotated, _, err = ctx.tokenService.TryRotateToken(context.Background(), userToken, - net.ParseIP("192.168.10.12"), "a new user agent") + _, err = ctx.tokenService.RotateToken(context.Background(), auth.RotateCommand{ + UnHashedToken: userToken.UnhashedToken, + IP: net.ParseIP("192.168.10.12"), + UserAgent: "a new user agent", + }) require.Nil(t, err) - require.True(t, rotated) model, err = ctx.getAuthTokenByID(userToken.Id) require.Nil(t, err) @@ -356,10 +347,12 @@ func TestIntegrationUserAuthToken(t *testing.T) { getTime = func() time.Time { return now.Add(10 * time.Minute) } prevToken := userToken.UnhashedToken - rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, - net.ParseIP("1.1.1.1"), "firefox") + _, err = ctx.tokenService.RotateToken(context.Background(), auth.RotateCommand{ + UnHashedToken: userToken.UnhashedToken, + IP: net.ParseIP("1.1.1.1"), + UserAgent: "firefox", + }) require.Nil(t, err) - require.True(t, rotated) getTime = func() time.Time { return now.Add(20 * time.Minute) @@ -394,87 +387,6 @@ func TestIntegrationUserAuthToken(t *testing.T) { require.True(t, lookedUpModel.AuthTokenSeen) }) - t.Run("TryRotateToken", func(t *testing.T) { - t.Run("Should rotate current token and previous token when auth token seen", func(t *testing.T) { - getTime = func() time.Time { return now } - userToken, err := ctx.tokenService.CreateToken(context.Background(), usr, - net.ParseIP("192.168.10.11"), "some user agent") - require.Nil(t, err) - require.NotNil(t, userToken) - - prevToken := userToken.AuthToken - - updated, err := ctx.markAuthTokenAsSeen(userToken.Id) - require.Nil(t, err) - require.True(t, updated) - - getTime = func() time.Time { - return now.Add(10 * time.Minute) - } - - rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, - net.ParseIP("1.1.1.1"), "firefox") - require.Nil(t, err) - require.True(t, rotated) - - storedToken, err := ctx.getAuthTokenByID(userToken.Id) - require.Nil(t, err) - require.NotNil(t, storedToken) - require.False(t, storedToken.AuthTokenSeen) - require.Equal(t, prevToken, storedToken.PrevAuthToken) - require.NotEqual(t, prevToken, storedToken.AuthToken) - - prevToken = storedToken.AuthToken - - updated, err = ctx.markAuthTokenAsSeen(userToken.Id) - require.Nil(t, err) - require.True(t, updated) - - getTime = func() time.Time { - return now.Add(20 * time.Minute) - } - - rotated, _, err = ctx.tokenService.TryRotateToken(context.Background(), userToken, - net.ParseIP("1.1.1.1"), "firefox") - require.Nil(t, err) - require.True(t, rotated) - - storedToken, err = ctx.getAuthTokenByID(userToken.Id) - require.Nil(t, err) - require.NotNil(t, storedToken) - require.False(t, storedToken.AuthTokenSeen) - require.Equal(t, prevToken, storedToken.PrevAuthToken) - require.NotEqual(t, prevToken, storedToken.AuthToken) - }) - - t.Run("Should rotate current token, but keep previous token when auth token not seen", func(t *testing.T) { - getTime = func() time.Time { return now } - userToken, err := ctx.tokenService.CreateToken(context.Background(), usr, - net.ParseIP("192.168.10.11"), "some user agent") - require.Nil(t, err) - require.NotNil(t, userToken) - - prevToken := userToken.AuthToken - userToken.RotatedAt = now.Add(-2 * time.Minute).Unix() - - getTime = func() time.Time { - return now.Add(2 * time.Minute) - } - - rotated, _, err := ctx.tokenService.TryRotateToken(context.Background(), userToken, - net.ParseIP("1.1.1.1"), "firefox") - require.Nil(t, err) - require.True(t, rotated) - - storedToken, err := ctx.getAuthTokenByID(userToken.Id) - require.Nil(t, err) - require.NotNil(t, storedToken) - require.False(t, storedToken.AuthTokenSeen) - require.Equal(t, prevToken, storedToken.PrevAuthToken) - require.NotEqual(t, prevToken, storedToken.AuthToken) - }) - }) - t.Run("RotateToken", func(t *testing.T) { var prev string token, err := ctx.tokenService.CreateToken(context.Background(), usr, nil, "") @@ -673,24 +585,6 @@ func (c *testContext) getAuthTokenByID(id int64) (*userAuthToken, error) { return res, err } -func (c *testContext) markAuthTokenAsSeen(id int64) (bool, error) { - hasRowsAffected := false - err := c.sqlstore.WithDbSession(context.Background(), func(sess *db.Session) error { - res, err := sess.Exec("UPDATE user_auth_token SET auth_token_seen = ? WHERE id = ?", c.sqlstore.GetDialect().BooleanStr(true), id) - if err != nil { - return err - } - - rowsAffected, err := res.RowsAffected() - if err != nil { - return err - } - hasRowsAffected = rowsAffected == 1 - return nil - }) - return hasRowsAffected, err -} - func (c *testContext) updateRotatedAt(id, rotatedAt int64) (bool, error) { hasRowsAffected := false err := c.sqlstore.WithDbSession(context.Background(), func(sess *db.Session) error { diff --git a/pkg/services/auth/authtest/testing.go b/pkg/services/auth/authtest/testing.go index 75682224aa6..9f834382f27 100644 --- a/pkg/services/auth/authtest/testing.go +++ b/pkg/services/auth/authtest/testing.go @@ -84,11 +84,6 @@ func (s *FakeUserAuthTokenService) LookupToken(ctx context.Context, unhashedToke return s.LookupTokenProvider(context.Background(), unhashedToken) } -func (s *FakeUserAuthTokenService) TryRotateToken(ctx context.Context, token *auth.UserToken, clientIP net.IP, - userAgent string) (bool, *auth.UserToken, error) { - return s.TryRotateTokenProvider(context.Background(), token, clientIP, userAgent) -} - func (s *FakeUserAuthTokenService) RevokeToken(ctx context.Context, token *auth.UserToken, soft bool) error { return s.RevokeTokenProvider(context.Background(), token, soft) }