ContextHandler: Get token from req context when performing rotation (#60533)

ContextHandler: get token from req context when performing end of
request rotation
This commit is contained in:
Karl Persson
2022-12-20 14:41:26 +01:00
committed by GitHub
parent 707198227c
commit 1b1a14b6f6
2 changed files with 21 additions and 22 deletions

View File

@@ -522,7 +522,7 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org
// Rotate the token just before we write response headers to ensure there is no delay between
// the new token being generated and the client receiving it.
reqContext.Resp.Before(h.rotateEndOfRequestFunc(reqContext, h.AuthTokenService, token))
reqContext.Resp.Before(h.rotateEndOfRequestFunc(reqContext))
return true
}
@@ -539,8 +539,7 @@ func (h *ContextHandler) deleteInvalidCookieEndOfRequestFunc(reqContext *models.
}
}
func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *models.ReqContext, authTokenService auth.UserTokenService,
token *auth.UserToken) web.BeforeFunc {
func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *models.ReqContext) web.BeforeFunc {
return func(w web.ResponseWriter) {
// if response has already been written, skip.
if w.Written() {
@@ -553,6 +552,11 @@ func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *models.ReqContext, a
return
}
// if there is no user token attached to reqContext, skip.
if reqContext.UserToken == nil {
return
}
ctx, span := h.tracer.Start(reqContext.Req.Context(), "rotateEndOfRequestFunc")
defer span.End()
@@ -562,14 +566,14 @@ func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *models.ReqContext, a
reqContext.Logger.Debug("Failed to get client IP address", "addr", addr, "err", err)
ip = nil
}
rotated, err := authTokenService.TryRotateToken(ctx, token, ip, reqContext.Req.UserAgent())
rotated, err := h.AuthTokenService.TryRotateToken(ctx, reqContext.UserToken, ip, reqContext.Req.UserAgent())
if err != nil {
reqContext.Logger.Error("Failed to rotate token", "error", err)
return
}
if rotated {
cookies.WriteSessionCookie(reqContext, h.Cfg, token.UnhashedToken, h.Cfg.LoginMaxLifetime)
cookies.WriteSessionCookie(reqContext, h.Cfg, reqContext.UserToken.UnhashedToken, h.Cfg.LoginMaxLifetime)
}
}
}

View File

@@ -21,13 +21,8 @@ import (
func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
ctxHdlr := getContextHandler(t)
ctx, cancel := context.WithCancel(context.Background())
reqContext, _, err := initTokenRotationScenario(ctx, t, ctxHdlr)
require.NoError(t, err)
tryRotateCallCount := 0
uts := &authtest.FakeUserAuthTokenService{
ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP,
userAgent string) (bool, error) {
tryRotateCallCount++
@@ -35,9 +30,12 @@ func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
},
}
token := &auth.UserToken{AuthToken: "oldtoken"}
ctx, cancel := context.WithCancel(context.Background())
reqContext, _, err := initTokenRotationScenario(ctx, t, ctxHdlr)
require.NoError(t, err)
reqContext.UserToken = &auth.UserToken{AuthToken: "oldtoken"}
fn := ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token)
fn := ctxHdlr.rotateEndOfRequestFunc(reqContext)
cancel()
fn(reqContext.Resp)
@@ -46,11 +44,7 @@ func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
func TestTokenRotationAtEndOfRequest(t *testing.T) {
ctxHdlr := getContextHandler(t)
reqContext, rr, err := initTokenRotationScenario(context.Background(), t, ctxHdlr)
require.NoError(t, err)
uts := &authtest.FakeUserAuthTokenService{
ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP,
userAgent string) (bool, error) {
newToken, err := util.RandomHex(16)
@@ -60,10 +54,11 @@ func TestTokenRotationAtEndOfRequest(t *testing.T) {
},
}
token := &auth.UserToken{AuthToken: "oldtoken"}
ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token)(reqContext.Resp)
reqContext, rr, err := initTokenRotationScenario(context.Background(), t, ctxHdlr)
require.NoError(t, err)
reqContext.UserToken = &auth.UserToken{AuthToken: "oldtoken"}
ctxHdlr.rotateEndOfRequestFunc(reqContext)(reqContext.Resp)
foundLoginCookie := false
// nolint:bodyclose
resp := rr.Result()
@@ -74,7 +69,7 @@ func TestTokenRotationAtEndOfRequest(t *testing.T) {
for _, c := range resp.Cookies() {
if c.Name == "login_token" {
foundLoginCookie = true
require.NotEqual(t, token.AuthToken, c.Value, "Auth token is still the same")
require.NotEqual(t, reqContext.UserToken.AuthToken, c.Value, "Auth token is still the same")
}
}