Auth: Respect cache control for JWKS in auth.jwt (#68872)

* respect cache control for auth.jwt

* add documentation

* add small note on cache control header ignores

* make distinction of env
This commit is contained in:
Jo
2023-05-23 12:29:10 +02:00
committed by GitHub
parent 86ea0c2bc9
commit 5e5c751ecd
4 changed files with 124 additions and 8 deletions

View File

@@ -135,6 +135,8 @@ jwk_set_url = https://your-auth-provider.example.com/.well-known/jwks.json
cache_ttl = 60m
```
> **Note**: If the JWKS endpoint includes cache control headers and the value is less than the configured `cache_ttl`, then the cache control header value is used instead. If the cache_ttl is not set, no caching is performed. `no-store` and `no-cache` cache control headers are ignored.
### Verify token using a JSON Web Key Set loaded from JSON file
Key set in the same format as in JWKS endpoint but located on disk.

View File

@@ -109,6 +109,16 @@ func TestVerifyUsingJWKSetURL(t *testing.T) {
})
require.NoError(t, err)
_, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) {
cfg.JWTAuthJWKSetURL = "http://example.com/.well-known/jwks.json"
})
require.NoError(t, err)
oldEnv := setting.Env
setting.Env = setting.Prod
defer func() {
setting.Env = oldEnv
}()
_, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) {
cfg.JWTAuthJWKSetURL = "http://example.com/.well-known/jwks.json"
})

View File

@@ -3,21 +3,26 @@ package jwt
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
jose "github.com/go-jose/go-jose/v3"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/setting"
)
var ErrFailedToParsePemFile = errors.New("failed to parse pem-encoded file")
@@ -118,7 +123,7 @@ func (s *AuthService) initKeySet() error {
return fmt.Errorf("unknown pem block type %q", block.Type)
}
s.keySet = keySetJWKS{
s.keySet = &keySetJWKS{
jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{{Key: key}},
},
@@ -141,19 +146,35 @@ func (s *AuthService) initKeySet() error {
return err
}
s.keySet = keySetJWKS{jwks}
s.keySet = &keySetJWKS{jwks}
} else if urlStr := s.Cfg.JWTAuthJWKSetURL; urlStr != "" {
urlParsed, err := url.Parse(urlStr)
if err != nil {
return err
}
if urlParsed.Scheme != "https" {
if urlParsed.Scheme != "https" && setting.Env != setting.Dev {
return ErrJWTSetURLMustHaveHTTPSScheme
}
s.keySet = &keySetHTTP{
url: urlStr,
log: s.log,
client: &http.Client{},
url: urlStr,
log: s.log,
client: &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Renegotiation: tls.RenegotiateFreelyAsClient,
},
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: time.Second * 30,
KeepAlive: 15 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConns: 100,
IdleConnTimeout: 30 * time.Second,
},
Timeout: time.Second * 30,
},
cacheKey: fmt.Sprintf("auth-jwt:jwk-%s", urlStr),
cacheExpiration: s.Cfg.JWTAuthCacheTTL,
cache: s.RemoteCache,
@@ -163,7 +184,7 @@ func (s *AuthService) initKeySet() error {
return nil
}
func (ks keySetJWKS) Key(ctx context.Context, keyID string) ([]jose.JSONWebKey, error) {
func (ks *keySetJWKS) Key(ctx context.Context, keyID string) ([]jose.JSONWebKey, error) {
return ks.JSONWebKeySet.Key(keyID), nil
}
@@ -200,11 +221,40 @@ func (ks *keySetHTTP) getJWKS(ctx context.Context) (keySetJWKS, error) {
}
if ks.cacheExpiration > 0 {
err = ks.cache.Set(ctx, ks.cacheKey, jsonBuf.Bytes(), ks.cacheExpiration)
cacheExpiration := ks.getCacheExpiration(resp.Header.Get("cache-control"))
ks.log.Debug("Setting key set in cache", "url", ks.url,
"cacheExpiration", cacheExpiration, "cacheControl", resp.Header.Get("cache-control"))
err = ks.cache.Set(ctx, ks.cacheKey, jsonBuf.Bytes(), cacheExpiration)
}
return jwks, err
}
func (ks *keySetHTTP) getCacheExpiration(cacheControl string) time.Duration {
cacheDuration := ks.cacheExpiration
if cacheControl == "" {
return cacheDuration
}
parts := strings.Split(cacheControl, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "max-age=") {
maxAge, err := strconv.Atoi(part[8:])
if err != nil {
return cacheDuration
}
// If the cache duration is 0 or the max-age is less than the cache duration, use the max-age
if cacheDuration == 0 || time.Duration(maxAge)*time.Second < cacheDuration {
return time.Duration(maxAge) * time.Second
}
}
}
return cacheDuration
}
func (ks keySetHTTP) Key(ctx context.Context, kid string) ([]jose.JSONWebKey, error) {
jwks, err := ks.getJWKS(ctx)
if err != nil {

View File

@@ -0,0 +1,54 @@
package jwt
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestGetCacheExpiration(t *testing.T) {
ks := &keySetHTTP{cacheExpiration: 10 * time.Minute}
type testCase struct {
name string
header string
expiration time.Duration
}
testCases := []testCase{
{
name: "no cache control header",
header: "",
expiration: 10 * time.Minute,
},
{
name: "max-age less than cache duration",
header: "max-age=300",
expiration: 5 * time.Minute,
},
{
name: "max-age greater than cache duration",
header: "max-age=7200",
expiration: 10 * time.Minute,
},
{
name: "invalid max-age",
header: "max-age=invalid",
expiration: 10 * time.Minute,
},
{
name: "multiple cache control directives",
header: "max-age=300, no-cache",
expiration: 5 * time.Minute,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
expiration := ks.getCacheExpiration(tc.header)
assert.Equal(t, tc.expiration, expiration)
})
}
}