mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Respect cache control for JWKS in auth.jwt (#68872)
* respect cache control for auth.jwt * add documentation * add small note on cache control header ignores * make distinction of env
This commit is contained in:
@@ -135,6 +135,8 @@ jwk_set_url = https://your-auth-provider.example.com/.well-known/jwks.json
|
||||
cache_ttl = 60m
|
||||
```
|
||||
|
||||
> **Note**: If the JWKS endpoint includes cache control headers and the value is less than the configured `cache_ttl`, then the cache control header value is used instead. If the cache_ttl is not set, no caching is performed. `no-store` and `no-cache` cache control headers are ignored.
|
||||
|
||||
### Verify token using a JSON Web Key Set loaded from JSON file
|
||||
|
||||
Key set in the same format as in JWKS endpoint but located on disk.
|
||||
|
||||
@@ -109,6 +109,16 @@ func TestVerifyUsingJWKSetURL(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) {
|
||||
cfg.JWTAuthJWKSetURL = "http://example.com/.well-known/jwks.json"
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
oldEnv := setting.Env
|
||||
setting.Env = setting.Prod
|
||||
defer func() {
|
||||
setting.Env = oldEnv
|
||||
}()
|
||||
_, err = initAuthService(t, func(t *testing.T, cfg *setting.Cfg) {
|
||||
cfg.JWTAuthJWKSetURL = "http://example.com/.well-known/jwks.json"
|
||||
})
|
||||
|
||||
@@ -3,21 +3,26 @@ package jwt
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
var ErrFailedToParsePemFile = errors.New("failed to parse pem-encoded file")
|
||||
@@ -118,7 +123,7 @@ func (s *AuthService) initKeySet() error {
|
||||
return fmt.Errorf("unknown pem block type %q", block.Type)
|
||||
}
|
||||
|
||||
s.keySet = keySetJWKS{
|
||||
s.keySet = &keySetJWKS{
|
||||
jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{{Key: key}},
|
||||
},
|
||||
@@ -141,19 +146,35 @@ func (s *AuthService) initKeySet() error {
|
||||
return err
|
||||
}
|
||||
|
||||
s.keySet = keySetJWKS{jwks}
|
||||
s.keySet = &keySetJWKS{jwks}
|
||||
} else if urlStr := s.Cfg.JWTAuthJWKSetURL; urlStr != "" {
|
||||
urlParsed, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if urlParsed.Scheme != "https" {
|
||||
if urlParsed.Scheme != "https" && setting.Env != setting.Dev {
|
||||
return ErrJWTSetURLMustHaveHTTPSScheme
|
||||
}
|
||||
s.keySet = &keySetHTTP{
|
||||
url: urlStr,
|
||||
log: s.log,
|
||||
client: &http.Client{},
|
||||
url: urlStr,
|
||||
log: s.log,
|
||||
client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
Renegotiation: tls.RenegotiateFreelyAsClient,
|
||||
},
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Second * 30,
|
||||
KeepAlive: 15 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
},
|
||||
Timeout: time.Second * 30,
|
||||
},
|
||||
cacheKey: fmt.Sprintf("auth-jwt:jwk-%s", urlStr),
|
||||
cacheExpiration: s.Cfg.JWTAuthCacheTTL,
|
||||
cache: s.RemoteCache,
|
||||
@@ -163,7 +184,7 @@ func (s *AuthService) initKeySet() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ks keySetJWKS) Key(ctx context.Context, keyID string) ([]jose.JSONWebKey, error) {
|
||||
func (ks *keySetJWKS) Key(ctx context.Context, keyID string) ([]jose.JSONWebKey, error) {
|
||||
return ks.JSONWebKeySet.Key(keyID), nil
|
||||
}
|
||||
|
||||
@@ -200,11 +221,40 @@ func (ks *keySetHTTP) getJWKS(ctx context.Context) (keySetJWKS, error) {
|
||||
}
|
||||
|
||||
if ks.cacheExpiration > 0 {
|
||||
err = ks.cache.Set(ctx, ks.cacheKey, jsonBuf.Bytes(), ks.cacheExpiration)
|
||||
cacheExpiration := ks.getCacheExpiration(resp.Header.Get("cache-control"))
|
||||
|
||||
ks.log.Debug("Setting key set in cache", "url", ks.url,
|
||||
"cacheExpiration", cacheExpiration, "cacheControl", resp.Header.Get("cache-control"))
|
||||
err = ks.cache.Set(ctx, ks.cacheKey, jsonBuf.Bytes(), cacheExpiration)
|
||||
}
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
func (ks *keySetHTTP) getCacheExpiration(cacheControl string) time.Duration {
|
||||
cacheDuration := ks.cacheExpiration
|
||||
if cacheControl == "" {
|
||||
return cacheDuration
|
||||
}
|
||||
|
||||
parts := strings.Split(cacheControl, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "max-age=") {
|
||||
maxAge, err := strconv.Atoi(part[8:])
|
||||
if err != nil {
|
||||
return cacheDuration
|
||||
}
|
||||
|
||||
// If the cache duration is 0 or the max-age is less than the cache duration, use the max-age
|
||||
if cacheDuration == 0 || time.Duration(maxAge)*time.Second < cacheDuration {
|
||||
return time.Duration(maxAge) * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cacheDuration
|
||||
}
|
||||
|
||||
func (ks keySetHTTP) Key(ctx context.Context, kid string) ([]jose.JSONWebKey, error) {
|
||||
jwks, err := ks.getJWKS(ctx)
|
||||
if err != nil {
|
||||
|
||||
54
pkg/services/auth/jwt/key_sets_test.go
Normal file
54
pkg/services/auth/jwt/key_sets_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetCacheExpiration(t *testing.T) {
|
||||
ks := &keySetHTTP{cacheExpiration: 10 * time.Minute}
|
||||
type testCase struct {
|
||||
name string
|
||||
header string
|
||||
expiration time.Duration
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "no cache control header",
|
||||
header: "",
|
||||
expiration: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "max-age less than cache duration",
|
||||
header: "max-age=300",
|
||||
expiration: 5 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "max-age greater than cache duration",
|
||||
header: "max-age=7200",
|
||||
expiration: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "invalid max-age",
|
||||
header: "max-age=invalid",
|
||||
expiration: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "multiple cache control directives",
|
||||
header: "max-age=300, no-cache",
|
||||
expiration: 5 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
expiration := ks.getCacheExpiration(tc.header)
|
||||
assert.Equal(t, tc.expiration, expiration)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user