SigningKeys: Add jwks endpoint (#76040)

* add jwks

add remote caching

add expose jwks test

tweaks

* fix swagger

* nt
This commit is contained in:
Jo 2023-10-05 15:17:31 +02:00 committed by GitHub
parent 8a33a6f958
commit f2bf066ad2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 145 additions and 7 deletions

View File

@ -32,7 +32,7 @@ func (fcs FakeCacheStorage) Count(_ context.Context, prefix string) (int64, erro
return int64(len(fcs.Storage)), nil
}
func NewFakeCacheStorage() CacheStorage {
func NewFakeCacheStorage() FakeCacheStorage {
return FakeCacheStorage{
Storage: map[string][]byte{},
}

View File

@ -6,15 +6,20 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/json"
"errors"
"net/http"
"strings"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/grafana/grafana/pkg/api/response"
"github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
"github.com/grafana/grafana/pkg/services/secrets"
"github.com/grafana/grafana/pkg/services/signingkeys"
"github.com/grafana/grafana/pkg/services/signingkeys/signingkeystore"
@ -23,7 +28,7 @@ import (
var _ signingkeys.Service = new(Service)
func ProvideEmbeddedSigningKeysService(dbStore db.DB, secretsService secrets.Service,
remoteCache remotecache.CacheStorage,
remoteCache remotecache.CacheStorage, routerRegister routing.RouteRegister,
) (*Service, error) {
s := &Service{
log: log.New("auth.key_service"),
@ -31,6 +36,8 @@ func ProvideEmbeddedSigningKeysService(dbStore db.DB, secretsService secrets.Ser
remoteCache: remoteCache,
}
s.registerAPIEndpoints(routerRegister)
return s, nil
}
@ -44,9 +51,34 @@ type Service struct {
remoteCache remotecache.CacheStorage
}
const (
jwksCacheKey = "signingkeys-jwks"
defaultExpiry = 12 * time.Hour
)
// GetJWKS returns the JSON Web Key Set (JWKS) with all the keys that can be used to verify tokens (public keys)
func (s *Service) GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error) {
// check cache for jwks
keySet := jose.JSONWebKeySet{}
if jwks, err := s.remoteCache.Get(ctx, jwksCacheKey); err == nil {
if err := json.Unmarshal(jwks, &keySet); err == nil {
return keySet, nil
}
}
jwks, err := s.store.GetJWKS(ctx)
if err != nil {
return jose.JSONWebKeySet{}, err
}
// cache jwks
jwksBytes, err := json.Marshal(jwks)
if err == nil {
if err := s.remoteCache.Set(ctx, jwksCacheKey, jwksBytes, defaultExpiry); err != nil {
s.log.Warn("Failed to cache JWKS", "err", err)
}
}
return jwks, err
}
@ -78,6 +110,12 @@ func (s *Service) GetOrCreatePrivateKey(ctx context.Context,
return "", nil, err
}
// invalidate cache
if err := s.remoteCache.Delete(ctx, jwksCacheKey); err != nil {
// not a critical error, key might not be in cache
s.log.Debug("Failed to invalidate JWKS cache", "err", err)
}
return keyID, signer, nil
}
@ -85,3 +123,40 @@ func keyMonthScopedID(keyPrefix string, alg jose.SignatureAlgorithm) string {
keyID := keyPrefix + "-" + time.Now().UTC().Format("2006-01") + "-" + strings.ToLower(string(alg))
return keyID
}
func (s *Service) registerAPIEndpoints(router routing.RouteRegister) {
router.Group("/api/signing-keys", func(grouper routing.RouteRegister) {
grouper.Get("/keys", s.exposeJWKS)
})
}
// swagger:response jwksResponse
type RetrieveJWKSResponse struct {
// in: body
Body struct {
Keys []jose.JSONWebKey `json:"keys"`
}
}
// swagger:route GET /signing-keys/keys signing_keys retrieveJWKS
//
// # Get JSON Web Key Set (JWKS) with all the keys that can be used to verify tokens (public keys)
//
// Required permissions
// None
//
// Responses:
// 200: jwksResponse
// 500: internalServerError
func (s *Service) exposeJWKS(ctx *contextmodel.ReqContext) response.Response {
jwks, err := s.GetJWKS(ctx.Req.Context())
if err != nil {
s.log.Error("Failed to get JWKS", "err", err)
return response.Error(http.StatusInternalServerError, "Failed to get JWKS", err)
}
// set cache headers to 1 hour
ctx.Resp.Header().Set("Cache-Control", "public, max-age=3600")
return response.JSON(http.StatusOK, jwks)
}

View File

@ -7,6 +7,8 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http"
"testing"
"time"
@ -14,9 +16,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/services/signingkeys"
"github.com/grafana/grafana/pkg/services/signingkeys/signingkeystore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/web/webtest"
)
const (
@ -35,6 +41,7 @@ func getPrivateKey(t *testing.T) *ecdsa.PrivateKey {
func TestEmbeddedKeyService_GetJWKS_OnlyPublicKeyShared(t *testing.T) {
mockStore := signingkeystore.NewFakeStore()
cacheStorage := remotecache.NewFakeCacheStorage()
_, err := mockStore.AddPrivateKey(context.Background(), signingkeys.ServerPrivateKeyID, jose.ES256, getPrivateKey(t), nil, false)
require.NoError(t, err)
@ -43,8 +50,9 @@ func TestEmbeddedKeyService_GetJWKS_OnlyPublicKeyShared(t *testing.T) {
require.NoError(t, err)
svc := &Service{
log: log.NewNopLogger(),
store: mockStore,
log: log.NewNopLogger(),
store: mockStore,
remoteCache: cacheStorage,
}
jwks, err := svc.GetJWKS(context.Background())
require.NoError(t, err)
@ -74,25 +82,39 @@ func TestEmbeddedKeyService_GetJWKS_OnlyPublicKeyShared(t *testing.T) {
func TestEmbeddedKeyService_GetOrCreatePrivateKey(t *testing.T) {
mockStore := signingkeystore.NewFakeStore()
cacheStorage := remotecache.NewFakeCacheStorage()
svc := &Service{
log: log.NewNopLogger(),
store: mockStore,
log: log.NewNopLogger(),
store: mockStore,
remoteCache: cacheStorage,
}
wantedKeyID := keyMonthScopedID("test", jose.ES256)
assert.Equal(t, wantedKeyID, fmt.Sprintf("test-%s-es256", time.Now().UTC().Format("2006-01")))
err := cacheStorage.Set(context.Background(), jwksCacheKey, []byte("invalid"), 0)
require.NoError(t, err)
// only ES256 is supported
_, _, err := svc.GetOrCreatePrivateKey(context.Background(), "test", jose.RS256)
_, _, err = svc.GetOrCreatePrivateKey(context.Background(), "test", jose.RS256)
require.Error(t, err)
_, err = cacheStorage.Get(context.Background(), jwksCacheKey)
require.NoError(t, err)
require.Len(t, cacheStorage.Storage, 1)
// first call should generate a key
_, key, err := svc.GetOrCreatePrivateKey(context.Background(), "test", jose.ES256)
require.NoError(t, err)
require.NotNil(t, key)
// new key is generated, so jwks cache should be voided
require.Len(t, cacheStorage.Storage, 0)
assert.Contains(t, mockStore.PrivateKeys, wantedKeyID)
err = cacheStorage.Set(context.Background(), jwksCacheKey, []byte("invalid"), 0)
require.NoError(t, err)
// second call should return the same key
id, key2, err := svc.GetOrCreatePrivateKey(context.Background(), "test", jose.ES256)
require.NoError(t, err)
@ -101,4 +123,45 @@ func TestEmbeddedKeyService_GetOrCreatePrivateKey(t *testing.T) {
require.Equal(t, wantedKeyID, id)
assert.Len(t, mockStore.PrivateKeys, 1)
// no new key is generated, so jwks cache should not be voided
require.Len(t, cacheStorage.Storage, 1)
}
func TestExposeJWKS(t *testing.T) {
// create a new service instance
mockStore := signingkeystore.NewFakeStore()
cacheStorage := remotecache.NewFakeCacheStorage()
svc := &Service{
log: log.NewNopLogger(),
store: mockStore,
remoteCache: cacheStorage,
}
routerRegister := routing.NewRouteRegister()
svc.registerAPIEndpoints(routerRegister)
server := webtest.NewServer(t, routerRegister)
_, err := mockStore.AddPrivateKey(context.Background(), "test-key", jose.ES256, getPrivateKey(t), nil, false)
require.NoError(t, err)
// create a new request context
req := server.NewRequest(http.MethodGet, "/api/signing-keys/keys", nil)
webtest.RequestWithSignedInUser(req, &user.SignedInUser{OrgID: 1,
Permissions: map[int64]map[string][]string{}})
res, err := server.Send(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "application/json", res.Header.Get("Content-Type"))
// check the response body
expected := `{"keys":[{"use":"sig","kty":"EC","kid":"test-key","crv":"P-256","alg":"ES256","x":"YYpLNYcnJp7FmSkPBHEOvwmCspeJvUYiOC3vo2h7jsY","y":"2PDsIq8bryoBUmLBYW1tlpy6fhMcHVNnaOApWStRYGw"}]}`
body, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.JSONEq(t, expected, string(body), string(body))
require.NoError(t, res.Body.Close())
assert.Contains(t, cacheStorage.Storage, jwksCacheKey)
}