From 5e5c751ecdb5ab28e321efc2a43d431c1ba7430b Mon Sep 17 00:00:00 2001 From: Jo Date: Tue, 23 May 2023 12:29:10 +0200 Subject: [PATCH] Auth: Respect cache control for JWKS in auth.jwt (#68872) * respect cache control for auth.jwt * add documentation * add small note on cache control header ignores * make distinction of env --- .../configure-authentication/jwt/index.md | 2 + pkg/services/auth/jwt/auth_test.go | 10 +++ pkg/services/auth/jwt/key_sets.go | 66 ++++++++++++++++--- pkg/services/auth/jwt/key_sets_test.go | 54 +++++++++++++++ 4 files changed, 124 insertions(+), 8 deletions(-) create mode 100644 pkg/services/auth/jwt/key_sets_test.go diff --git a/docs/sources/setup-grafana/configure-security/configure-authentication/jwt/index.md b/docs/sources/setup-grafana/configure-security/configure-authentication/jwt/index.md index 89b07cd1a5f..62bef87d9de 100644 --- a/docs/sources/setup-grafana/configure-security/configure-authentication/jwt/index.md +++ b/docs/sources/setup-grafana/configure-security/configure-authentication/jwt/index.md @@ -135,6 +135,8 @@ jwk_set_url = https://your-auth-provider.example.com/.well-known/jwks.json cache_ttl = 60m ``` +> **Note**: If the JWKS endpoint includes cache control headers and the value is less than the configured `cache_ttl`, then the cache control header value is used instead. If the cache_ttl is not set, no caching is performed. `no-store` and `no-cache` cache control headers are ignored. + ### Verify token using a JSON Web Key Set loaded from JSON file Key set in the same format as in JWKS endpoint but located on disk. diff --git a/pkg/services/auth/jwt/auth_test.go b/pkg/services/auth/jwt/auth_test.go index a77166534fc..95779e39c8c 100644 --- a/pkg/services/auth/jwt/auth_test.go +++ b/pkg/services/auth/jwt/auth_test.go @@ -109,6 +109,16 @@ func TestVerifyUsingJWKSetURL(t *testing.T) { }) require.NoError(t, err) + _, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) { + cfg.JWTAuthJWKSetURL = "http://example.com/.well-known/jwks.json" + }) + require.NoError(t, err) + + oldEnv := setting.Env + setting.Env = setting.Prod + defer func() { + setting.Env = oldEnv + }() _, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) { cfg.JWTAuthJWKSetURL = "http://example.com/.well-known/jwks.json" }) diff --git a/pkg/services/auth/jwt/key_sets.go b/pkg/services/auth/jwt/key_sets.go index 05b746f4ce0..0181072c1d5 100644 --- a/pkg/services/auth/jwt/key_sets.go +++ b/pkg/services/auth/jwt/key_sets.go @@ -3,21 +3,26 @@ package jwt import ( "bytes" "context" + "crypto/tls" "crypto/x509" "encoding/json" "encoding/pem" "errors" "fmt" "io" + "net" "net/http" "net/url" "os" + "strconv" + "strings" "time" jose "github.com/go-jose/go-jose/v3" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/remotecache" + "github.com/grafana/grafana/pkg/setting" ) var ErrFailedToParsePemFile = errors.New("failed to parse pem-encoded file") @@ -118,7 +123,7 @@ func (s *AuthService) initKeySet() error { return fmt.Errorf("unknown pem block type %q", block.Type) } - s.keySet = keySetJWKS{ + s.keySet = &keySetJWKS{ jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{{Key: key}}, }, @@ -141,19 +146,35 @@ func (s *AuthService) initKeySet() error { return err } - s.keySet = keySetJWKS{jwks} + s.keySet = &keySetJWKS{jwks} } else if urlStr := s.Cfg.JWTAuthJWKSetURL; urlStr != "" { urlParsed, err := url.Parse(urlStr) if err != nil { return err } - if urlParsed.Scheme != "https" { + if urlParsed.Scheme != "https" && setting.Env != setting.Dev { return ErrJWTSetURLMustHaveHTTPSScheme } s.keySet = &keySetHTTP{ - url: urlStr, - log: s.log, - client: &http.Client{}, + url: urlStr, + log: s.log, + client: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Renegotiation: tls.RenegotiateFreelyAsClient, + }, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: time.Second * 30, + KeepAlive: 15 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + MaxIdleConns: 100, + IdleConnTimeout: 30 * time.Second, + }, + Timeout: time.Second * 30, + }, cacheKey: fmt.Sprintf("auth-jwt:jwk-%s", urlStr), cacheExpiration: s.Cfg.JWTAuthCacheTTL, cache: s.RemoteCache, @@ -163,7 +184,7 @@ func (s *AuthService) initKeySet() error { return nil } -func (ks keySetJWKS) Key(ctx context.Context, keyID string) ([]jose.JSONWebKey, error) { +func (ks *keySetJWKS) Key(ctx context.Context, keyID string) ([]jose.JSONWebKey, error) { return ks.JSONWebKeySet.Key(keyID), nil } @@ -200,11 +221,40 @@ func (ks *keySetHTTP) getJWKS(ctx context.Context) (keySetJWKS, error) { } if ks.cacheExpiration > 0 { - err = ks.cache.Set(ctx, ks.cacheKey, jsonBuf.Bytes(), ks.cacheExpiration) + cacheExpiration := ks.getCacheExpiration(resp.Header.Get("cache-control")) + + ks.log.Debug("Setting key set in cache", "url", ks.url, + "cacheExpiration", cacheExpiration, "cacheControl", resp.Header.Get("cache-control")) + err = ks.cache.Set(ctx, ks.cacheKey, jsonBuf.Bytes(), cacheExpiration) } return jwks, err } +func (ks *keySetHTTP) getCacheExpiration(cacheControl string) time.Duration { + cacheDuration := ks.cacheExpiration + if cacheControl == "" { + return cacheDuration + } + + parts := strings.Split(cacheControl, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "max-age=") { + maxAge, err := strconv.Atoi(part[8:]) + if err != nil { + return cacheDuration + } + + // If the cache duration is 0 or the max-age is less than the cache duration, use the max-age + if cacheDuration == 0 || time.Duration(maxAge)*time.Second < cacheDuration { + return time.Duration(maxAge) * time.Second + } + } + } + + return cacheDuration +} + func (ks keySetHTTP) Key(ctx context.Context, kid string) ([]jose.JSONWebKey, error) { jwks, err := ks.getJWKS(ctx) if err != nil { diff --git a/pkg/services/auth/jwt/key_sets_test.go b/pkg/services/auth/jwt/key_sets_test.go new file mode 100644 index 00000000000..9512cb0e07a --- /dev/null +++ b/pkg/services/auth/jwt/key_sets_test.go @@ -0,0 +1,54 @@ +package jwt + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGetCacheExpiration(t *testing.T) { + ks := &keySetHTTP{cacheExpiration: 10 * time.Minute} + type testCase struct { + name string + header string + expiration time.Duration + } + + testCases := []testCase{ + { + name: "no cache control header", + header: "", + expiration: 10 * time.Minute, + }, + { + name: "max-age less than cache duration", + header: "max-age=300", + expiration: 5 * time.Minute, + }, + { + name: "max-age greater than cache duration", + header: "max-age=7200", + expiration: 10 * time.Minute, + }, + { + name: "invalid max-age", + header: "max-age=invalid", + expiration: 10 * time.Minute, + }, + { + name: "multiple cache control directives", + header: "max-age=300, no-cache", + expiration: 5 * time.Minute, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + expiration := ks.getCacheExpiration(tc.header) + assert.Equal(t, tc.expiration, expiration) + }) + } +}