mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
ContextHandler: Get token from req context when performing rotation (#60533)
ContextHandler: get token from req context when performing end of request rotation
This commit is contained in:
@@ -522,7 +522,7 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org
|
||||
|
||||
// Rotate the token just before we write response headers to ensure there is no delay between
|
||||
// the new token being generated and the client receiving it.
|
||||
reqContext.Resp.Before(h.rotateEndOfRequestFunc(reqContext, h.AuthTokenService, token))
|
||||
reqContext.Resp.Before(h.rotateEndOfRequestFunc(reqContext))
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -539,8 +539,7 @@ func (h *ContextHandler) deleteInvalidCookieEndOfRequestFunc(reqContext *models.
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *models.ReqContext, authTokenService auth.UserTokenService,
|
||||
token *auth.UserToken) web.BeforeFunc {
|
||||
func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *models.ReqContext) web.BeforeFunc {
|
||||
return func(w web.ResponseWriter) {
|
||||
// if response has already been written, skip.
|
||||
if w.Written() {
|
||||
@@ -553,6 +552,11 @@ func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *models.ReqContext, a
|
||||
return
|
||||
}
|
||||
|
||||
// if there is no user token attached to reqContext, skip.
|
||||
if reqContext.UserToken == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, span := h.tracer.Start(reqContext.Req.Context(), "rotateEndOfRequestFunc")
|
||||
defer span.End()
|
||||
|
||||
@@ -562,14 +566,14 @@ func (h *ContextHandler) rotateEndOfRequestFunc(reqContext *models.ReqContext, a
|
||||
reqContext.Logger.Debug("Failed to get client IP address", "addr", addr, "err", err)
|
||||
ip = nil
|
||||
}
|
||||
rotated, err := authTokenService.TryRotateToken(ctx, token, ip, reqContext.Req.UserAgent())
|
||||
rotated, err := h.AuthTokenService.TryRotateToken(ctx, reqContext.UserToken, ip, reqContext.Req.UserAgent())
|
||||
if err != nil {
|
||||
reqContext.Logger.Error("Failed to rotate token", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if rotated {
|
||||
cookies.WriteSessionCookie(reqContext, h.Cfg, token.UnhashedToken, h.Cfg.LoginMaxLifetime)
|
||||
cookies.WriteSessionCookie(reqContext, h.Cfg, reqContext.UserToken.UnhashedToken, h.Cfg.LoginMaxLifetime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,13 +21,8 @@ import (
|
||||
|
||||
func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
|
||||
ctxHdlr := getContextHandler(t)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
reqContext, _, err := initTokenRotationScenario(ctx, t, ctxHdlr)
|
||||
require.NoError(t, err)
|
||||
|
||||
tryRotateCallCount := 0
|
||||
uts := &authtest.FakeUserAuthTokenService{
|
||||
ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{
|
||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP,
|
||||
userAgent string) (bool, error) {
|
||||
tryRotateCallCount++
|
||||
@@ -35,9 +30,12 @@ func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
token := &auth.UserToken{AuthToken: "oldtoken"}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
reqContext, _, err := initTokenRotationScenario(ctx, t, ctxHdlr)
|
||||
require.NoError(t, err)
|
||||
reqContext.UserToken = &auth.UserToken{AuthToken: "oldtoken"}
|
||||
|
||||
fn := ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token)
|
||||
fn := ctxHdlr.rotateEndOfRequestFunc(reqContext)
|
||||
cancel()
|
||||
fn(reqContext.Resp)
|
||||
|
||||
@@ -46,11 +44,7 @@ func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
|
||||
|
||||
func TestTokenRotationAtEndOfRequest(t *testing.T) {
|
||||
ctxHdlr := getContextHandler(t)
|
||||
|
||||
reqContext, rr, err := initTokenRotationScenario(context.Background(), t, ctxHdlr)
|
||||
require.NoError(t, err)
|
||||
|
||||
uts := &authtest.FakeUserAuthTokenService{
|
||||
ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{
|
||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP,
|
||||
userAgent string) (bool, error) {
|
||||
newToken, err := util.RandomHex(16)
|
||||
@@ -60,10 +54,11 @@ func TestTokenRotationAtEndOfRequest(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
token := &auth.UserToken{AuthToken: "oldtoken"}
|
||||
|
||||
ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token)(reqContext.Resp)
|
||||
reqContext, rr, err := initTokenRotationScenario(context.Background(), t, ctxHdlr)
|
||||
require.NoError(t, err)
|
||||
reqContext.UserToken = &auth.UserToken{AuthToken: "oldtoken"}
|
||||
|
||||
ctxHdlr.rotateEndOfRequestFunc(reqContext)(reqContext.Resp)
|
||||
foundLoginCookie := false
|
||||
// nolint:bodyclose
|
||||
resp := rr.Result()
|
||||
@@ -74,7 +69,7 @@ func TestTokenRotationAtEndOfRequest(t *testing.T) {
|
||||
for _, c := range resp.Cookies() {
|
||||
if c.Name == "login_token" {
|
||||
foundLoginCookie = true
|
||||
require.NotEqual(t, token.AuthToken, c.Value, "Auth token is still the same")
|
||||
require.NotEqual(t, reqContext.UserToken.AuthToken, c.Value, "Auth token is still the same")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user