mirror of
https://github.com/grafana/grafana.git
synced 2024-11-26 19:00:54 -06:00
Auth: Signing Key persistence (#75487)
* signing key wip use db keyset storage add signing_key table add testing for key storage add ES256 key tests Remove caching and implement UpdateOrCreate Stabilize interfaces * Encrypt private keys * Fixup signer * Fixup ext_jwt * Add GetOrCreatePrivate with automatic key rotation * use GetOrCreate for ext_jwt * use GetOrCreate in id * catch invalid block type * fix broken test * remove key generator * reduce public interface of signing service
This commit is contained in:
parent
0eac9aff7f
commit
44fa0697ce
@ -10,14 +10,20 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/signingkeys"
|
||||
)
|
||||
|
||||
const idSignerKeyPrefix = "id"
|
||||
|
||||
var _ auth.IDSigner = (*LocalSigner)(nil)
|
||||
|
||||
func ProvideLocalSigner(keyService signingkeys.Service) (*LocalSigner, error) {
|
||||
key := keyService.GetServerPrivateKey() // FIXME: replace with signing specific key
|
||||
id, key, err := keyService.GetOrCreatePrivateKey(context.Background(), idSignerKeyPrefix, jose.ES256)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// FIXME: Handle key rotation
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, &jose.SignerOptions{
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
"kid": "default", // FIXME: replace with specific key id
|
||||
"kid": id,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
@ -172,7 +173,13 @@ func (s *ExtendedJWT) verifyRFC9068Token(ctx context.Context, rawToken string) (
|
||||
}
|
||||
|
||||
var claims ExtendedJWTClaims
|
||||
err = parsedToken.Claims(s.signingKeys.GetServerPublicKey(), &claims)
|
||||
_, key, err := s.signingKeys.GetOrCreatePrivateKey(ctx,
|
||||
signingkeys.ServerPrivateKeyID, jose.ES256)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get public key: %w", err)
|
||||
}
|
||||
|
||||
err = parsedToken.Claims(key.Public(), &claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to verify the signature: %w", err)
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
@ -12,17 +13,19 @@ import (
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models/roletype"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthserver"
|
||||
"github.com/grafana/grafana/pkg/services/oauthserver/oastest"
|
||||
"github.com/grafana/grafana/pkg/services/signingkeys"
|
||||
"github.com/grafana/grafana/pkg/services/signingkeys/signingkeystest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/services/user/usertest"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -513,8 +516,9 @@ func setupTestCtx(t *testing.T, cfg *setting.Cfg) *testEnv {
|
||||
}
|
||||
}
|
||||
|
||||
signingKeysSvc := &signingkeystest.FakeSigningKeysService{}
|
||||
signingKeysSvc.ExpectedServerPublicKey = &pk.PublicKey
|
||||
signingKeysSvc := &signingkeystest.FakeSigningKeysService{ExpectedKeys: map[string]crypto.Signer{
|
||||
signingkeys.ServerPrivateKeyID: pk},
|
||||
}
|
||||
|
||||
userSvc := &usertest.FakeUserService{}
|
||||
oauthSvc := &oastest.FakeService{}
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/compose"
|
||||
"github.com/ory/fosite/storage"
|
||||
@ -75,18 +76,6 @@ func ProvideService(router routing.RouteRegister, db db.DB, cfg *setting.Cfg,
|
||||
ScopeStrategy: fosite.WildcardScopeStrategy,
|
||||
}
|
||||
|
||||
privateKey := keySvc.GetServerPrivateKey()
|
||||
|
||||
var publicKey any
|
||||
switch k := privateKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
publicKey = &k.PublicKey
|
||||
case *ecdsa.PrivateKey:
|
||||
publicKey = &k.PublicKey
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown private key type %T", k)
|
||||
}
|
||||
|
||||
s := &OAuth2ServiceImpl{
|
||||
cache: localcache.New(cacheExpirationTime, cacheCleanupInterval),
|
||||
cfg: cfg,
|
||||
@ -98,20 +87,20 @@ func ProvideService(router routing.RouteRegister, db db.DB, cfg *setting.Cfg,
|
||||
userService: userSvc,
|
||||
saService: svcAccSvc,
|
||||
teamService: teamSvc,
|
||||
publicKey: publicKey,
|
||||
}
|
||||
|
||||
api := api.NewAPI(router, s)
|
||||
api.RegisterAPIEndpoints()
|
||||
|
||||
s.oauthProvider = newProvider(config, s, privateKey)
|
||||
s.oauthProvider = newProvider(config, s, keySvc)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func newProvider(config *fosite.Config, storage any, key any) fosite.OAuth2Provider {
|
||||
keyGetter := func(context.Context) (any, error) {
|
||||
return key, nil
|
||||
func newProvider(config *fosite.Config, storage any, signingKeyService signingkeys.Service) fosite.OAuth2Provider {
|
||||
keyGetter := func(ctx context.Context) (any, error) {
|
||||
_, key, err := signingKeyService.GetOrCreatePrivateKey(ctx, signingkeys.ServerPrivateKeyID, jose.ES256)
|
||||
return key, err
|
||||
}
|
||||
return compose.Compose(
|
||||
config,
|
||||
|
@ -2,6 +2,7 @@ package oasimpl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
@ -26,6 +27,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/oauthserver/oastest"
|
||||
sa "github.com/grafana/grafana/pkg/services/serviceaccounts"
|
||||
satests "github.com/grafana/grafana/pkg/services/serviceaccounts/tests"
|
||||
"github.com/grafana/grafana/pkg/services/signingkeys/signingkeystest"
|
||||
"github.com/grafana/grafana/pkg/services/team/teamtest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/services/user/usertest"
|
||||
@ -89,7 +91,13 @@ func setupTestEnv(t *testing.T) *TestEnv {
|
||||
teamService: env.TeamService,
|
||||
publicKey: &pk.PublicKey,
|
||||
}
|
||||
env.S.oauthProvider = newProvider(config, env.S, pk)
|
||||
|
||||
env.S.oauthProvider = newProvider(config, env.S, &signingkeystest.FakeSigningKeysService{
|
||||
ExpectedKeys: map[string]crypto.Signer{
|
||||
"default": pk,
|
||||
},
|
||||
ExpectedError: nil,
|
||||
})
|
||||
|
||||
return env
|
||||
}
|
||||
|
@ -609,7 +609,7 @@ func TestOAuth2ServiceImpl_HandleTokenRequest(t *testing.T) {
|
||||
|
||||
env.S.HandleTokenRequest(resp, req)
|
||||
|
||||
require.Equal(t, tt.wantCode, resp.Code)
|
||||
require.Equal(t, tt.wantCode, resp.Code, resp.Body.String())
|
||||
if tt.wantCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
@ -43,6 +43,7 @@ func ProvideSecretsMigrator(
|
||||
b64Secret{simpleSecret: simpleSecret{tableName: "secrets", columnName: "value"}, hasUpdatedColumn: true, encoding: base64.RawStdEncoding},
|
||||
jsonSecret{tableName: "data_source"},
|
||||
jsonSecret{tableName: "plugin_setting"},
|
||||
b64Secret{simpleSecret: simpleSecret{tableName: "signing_key", columnName: "private_key"}, encoding: base64.StdEncoding},
|
||||
alertingSecret{},
|
||||
}
|
||||
|
||||
|
@ -8,27 +8,21 @@
|
||||
package signingkeys
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
)
|
||||
|
||||
const (
|
||||
ServerPrivateKeyID = "default"
|
||||
)
|
||||
|
||||
// Service provides functionality for managing signing keys used to sign and verify JWT tokens.
|
||||
//
|
||||
// The service is under active development and is not yet ready for production use.
|
||||
type Service interface {
|
||||
// GetJWKS returns the JSON Web Key Set (JWKS) with all the keys that can be used to verify tokens (public keys)
|
||||
GetJWKS() jose.JSONWebKeySet
|
||||
// GetJWK returns the JSON Web Key (JWK) with the specified key ID which can be used to verify tokens (public key)
|
||||
GetJWK(keyID string) (jose.JSONWebKey, error)
|
||||
// GetPublicKey returns the public key with the specified key ID
|
||||
GetPublicKey(keyID string) (crypto.PublicKey, error)
|
||||
// GetPrivateKey returns the private key with the specified key ID
|
||||
GetPrivateKey(keyID string) (crypto.PrivateKey, error)
|
||||
// GetServerPrivateKey returns the private key used to sign tokens
|
||||
GetServerPrivateKey() crypto.PrivateKey
|
||||
// GetServerPublicKey returns the public key used to verify tokens
|
||||
GetServerPublicKey() crypto.PublicKey
|
||||
// AddPrivateKey adds a private key to the service
|
||||
AddPrivateKey(keyID string, privateKey crypto.PrivateKey) error
|
||||
GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error)
|
||||
GetOrCreatePrivateKey(ctx context.Context, keyPrefix string, alg jose.SignatureAlgorithm) (string, crypto.Signer, error)
|
||||
}
|
||||
|
@ -1,37 +1,34 @@
|
||||
package signingkeysimpl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/services/secrets"
|
||||
"github.com/grafana/grafana/pkg/services/signingkeys"
|
||||
)
|
||||
|
||||
const (
|
||||
serverPrivateKeyID = "default"
|
||||
"github.com/grafana/grafana/pkg/services/signingkeys/signingkeystore"
|
||||
)
|
||||
|
||||
var _ signingkeys.Service = new(Service)
|
||||
|
||||
func ProvideEmbeddedSigningKeysService() (*Service, error) {
|
||||
func ProvideEmbeddedSigningKeysService(dbStore db.DB, secretsService secrets.Service,
|
||||
remoteCache remotecache.CacheStorage,
|
||||
) (*Service, error) {
|
||||
s := &Service{
|
||||
log: log.New("auth.key_service"),
|
||||
keys: map[string]crypto.Signer{},
|
||||
}
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
s.log.Error("Error generating private key", "err", err)
|
||||
return nil, signingkeys.ErrKeyGenerationFailed.Errorf("Error generating private key: %v", err)
|
||||
}
|
||||
|
||||
if err := s.AddPrivateKey(serverPrivateKeyID, privateKey); err != nil {
|
||||
return nil, err
|
||||
log: log.New("auth.key_service"),
|
||||
store: signingkeystore.NewSigningKeyStore(dbStore, secretsService),
|
||||
remoteCache: remoteCache,
|
||||
}
|
||||
|
||||
return s, nil
|
||||
@ -42,81 +39,49 @@ func ProvideEmbeddedSigningKeysService() (*Service, error) {
|
||||
//
|
||||
// The service is under active development and is not yet ready for production use.
|
||||
type Service struct {
|
||||
log log.Logger
|
||||
keys map[string]crypto.Signer
|
||||
log log.Logger
|
||||
store signingkeystore.SigningStore
|
||||
remoteCache remotecache.CacheStorage
|
||||
}
|
||||
|
||||
// GetJWKS returns the JSON Web Key Set (JWKS) with all the keys that can be used to verify tokens (public keys)
|
||||
func (s *Service) GetJWKS() jose.JSONWebKeySet {
|
||||
result := jose.JSONWebKeySet{}
|
||||
func (s *Service) GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error) {
|
||||
jwks, err := s.store.GetJWKS(ctx)
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
for keyID := range s.keys {
|
||||
// Skip error check because keyID must be a valid key ID
|
||||
jwk, _ := s.GetJWK(keyID)
|
||||
result.Keys = append(result.Keys, jwk)
|
||||
// GetOrCreatePrivateKey returns the private key with the specified key ID. If the key does not exist, it will be
|
||||
// created with the specified algorithm.
|
||||
// The key will be automatically rotated at the beginning of each month. The previous key will be kept for 30 days.
|
||||
func (s *Service) GetOrCreatePrivateKey(ctx context.Context,
|
||||
keyPrefix string, alg jose.SignatureAlgorithm) (string, crypto.Signer, error) {
|
||||
if alg != jose.ES256 {
|
||||
s.log.Error("Only ES256 is supported", "alg", alg)
|
||||
return "", nil, signingkeys.ErrKeyGenerationFailed.Errorf("Only ES256 is supported: %v", alg)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
keyID := keyMonthScopedID(keyPrefix, alg)
|
||||
signer, err := s.store.GetPrivateKey(ctx, keyID)
|
||||
if err == nil {
|
||||
return keyID, signer, nil
|
||||
}
|
||||
s.log.Debug("Private key not found, generating new key", "keyID", keyID, "err", err)
|
||||
|
||||
// GetJWK returns the JSON Web Key (JWK) with the specified key ID which can be used to verify tokens (public key)
|
||||
func (s *Service) GetJWK(keyID string) (jose.JSONWebKey, error) {
|
||||
privateKey, ok := s.keys[keyID]
|
||||
if !ok {
|
||||
s.log.Error("The specified key was not found", "keyID", keyID)
|
||||
return jose.JSONWebKey{}, signingkeys.ErrSigningKeyNotFound.Errorf("The specified key was not found: %s", keyID)
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
s.log.Error("Error generating private key", "err", err)
|
||||
return "", nil, signingkeys.ErrKeyGenerationFailed.Errorf("Error generating private key: %v", err)
|
||||
}
|
||||
|
||||
result := jose.JSONWebKey{
|
||||
Key: privateKey.Public(),
|
||||
Use: "sig",
|
||||
expiry := time.Now().Add(30 * 24 * time.Hour)
|
||||
if signer, err = s.store.AddPrivateKey(ctx, keyID, alg, privateKey, &expiry, false); err != nil && !errors.Is(err, signingkeys.ErrSigningKeyAlreadyExists) {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return keyID, signer, nil
|
||||
}
|
||||
|
||||
// GetPublicKey returns the public key with the specified key ID
|
||||
func (s *Service) GetPublicKey(keyID string) (crypto.PublicKey, error) {
|
||||
privateKey, ok := s.keys[keyID]
|
||||
if !ok {
|
||||
s.log.Error("The specified key was not found", "keyID", keyID)
|
||||
return nil, signingkeys.ErrSigningKeyNotFound.Errorf("The specified key was not found: %s", keyID)
|
||||
}
|
||||
|
||||
return privateKey.Public(), nil
|
||||
}
|
||||
|
||||
// GetPrivateKey returns the private key with the specified key ID
|
||||
func (s *Service) GetPrivateKey(keyID string) (crypto.PrivateKey, error) {
|
||||
privateKey, ok := s.keys[keyID]
|
||||
if !ok {
|
||||
s.log.Error("The specified key was not found", "keyID", keyID)
|
||||
return nil, signingkeys.ErrSigningKeyNotFound.Errorf("The specified key was not found: %s", keyID)
|
||||
}
|
||||
|
||||
return privateKey, nil
|
||||
}
|
||||
|
||||
// AddPrivateKey adds a private key to the service
|
||||
func (s *Service) AddPrivateKey(keyID string, privateKey crypto.PrivateKey) error {
|
||||
if _, ok := s.keys[keyID]; ok {
|
||||
s.log.Error("The specified key ID is already in use", "keyID", keyID)
|
||||
return signingkeys.ErrSigningKeyAlreadyExists.Errorf("The specified key ID is already in use: %s", keyID)
|
||||
}
|
||||
s.keys[keyID] = privateKey.(crypto.Signer)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetServerPrivateKey returns the private key used to sign tokens
|
||||
func (s *Service) GetServerPrivateKey() crypto.PrivateKey {
|
||||
// The server private key is always available
|
||||
pk, _ := s.GetPrivateKey(serverPrivateKeyID)
|
||||
return pk
|
||||
}
|
||||
|
||||
// GetServerPrivateKey returns the private key used to sign tokens
|
||||
func (s *Service) GetServerPublicKey() crypto.PublicKey {
|
||||
// The server public key is always available
|
||||
publicKey, _ := s.GetPublicKey(serverPrivateKeyID)
|
||||
return publicKey
|
||||
func keyMonthScopedID(keyPrefix string, alg jose.SignatureAlgorithm) string {
|
||||
keyID := keyPrefix + "-" + time.Now().UTC().Format("2006-01") + "-" + strings.ToLower(string(alg))
|
||||
return keyID
|
||||
}
|
||||
|
@ -1,17 +1,22 @@
|
||||
package signingkeysimpl
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/services/signingkeys"
|
||||
"github.com/grafana/grafana/pkg/services/signingkeys/signingkeystore"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -28,84 +33,23 @@ func getPrivateKey(t *testing.T) *ecdsa.PrivateKey {
|
||||
return privateKey.(*ecdsa.PrivateKey)
|
||||
}
|
||||
|
||||
func setupTestService(t *testing.T) *Service {
|
||||
svc := &Service{
|
||||
log: log.NewNopLogger(),
|
||||
keys: map[string]crypto.Signer{serverPrivateKeyID: getPrivateKey(t)},
|
||||
}
|
||||
return svc
|
||||
}
|
||||
|
||||
func TestEmbeddedKeyService_GetJWK(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyID string
|
||||
want jose.JSONWebKey
|
||||
wantErr bool
|
||||
}{
|
||||
{name: "creates a JSON Web Key successfully",
|
||||
keyID: "default",
|
||||
want: jose.JSONWebKey{
|
||||
Key: getPrivateKey(t).Public(),
|
||||
Use: "sig",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{name: "returns error when the specified key was not found",
|
||||
keyID: "not-existing-key-id",
|
||||
want: jose.JSONWebKey{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
svc := setupTestService(t)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := svc.GetJWK(tt.keyID)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, got, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedKeyService_GetJWK_OnlyPublicKeyShared(t *testing.T) {
|
||||
svc := setupTestService(t)
|
||||
jwk, err := svc.GetJWK("default")
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
jwkJson, err := jwk.MarshalJSON()
|
||||
require.NoError(t, err)
|
||||
|
||||
kvs := make(map[string]any)
|
||||
err = json.Unmarshal(jwkJson, &kvs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// check that the private key is not shared
|
||||
require.NotContains(t, kvs, "d")
|
||||
require.NotContains(t, kvs, "p")
|
||||
require.NotContains(t, kvs, "q")
|
||||
}
|
||||
|
||||
func TestEmbeddedKeyService_GetJWKS(t *testing.T) {
|
||||
svc := &Service{
|
||||
log: log.NewNopLogger(),
|
||||
keys: map[string]crypto.Signer{
|
||||
serverPrivateKeyID: getPrivateKey(t),
|
||||
"other": getPrivateKey(t),
|
||||
},
|
||||
}
|
||||
jwk := svc.GetJWKS()
|
||||
|
||||
require.Equal(t, 2, len(jwk.Keys))
|
||||
}
|
||||
|
||||
func TestEmbeddedKeyService_GetJWKS_OnlyPublicKeyShared(t *testing.T) {
|
||||
svc := setupTestService(t)
|
||||
jwks := svc.GetJWKS()
|
||||
mockStore := signingkeystore.NewFakeStore()
|
||||
|
||||
_, err := mockStore.AddPrivateKey(context.Background(), signingkeys.ServerPrivateKeyID, jose.ES256, getPrivateKey(t), nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = mockStore.AddPrivateKey(context.Background(), "other", jose.ES256, getPrivateKey(t), nil, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := &Service{
|
||||
log: log.NewNopLogger(),
|
||||
store: mockStore,
|
||||
}
|
||||
jwks, err := svc.GetJWKS(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, 2, len(jwks.Keys))
|
||||
|
||||
jwksJson, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
@ -115,6 +59,7 @@ func TestEmbeddedKeyService_GetJWKS_OnlyPublicKeyShared(t *testing.T) {
|
||||
}
|
||||
|
||||
var kvs keys
|
||||
|
||||
err = json.Unmarshal(jwksJson, &kvs)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -126,120 +71,34 @@ func TestEmbeddedKeyService_GetJWKS_OnlyPublicKeyShared(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedKeyService_GetPublicKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyID string
|
||||
want crypto.PublicKey
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "returns the public key successfully",
|
||||
keyID: "default",
|
||||
want: getPrivateKey(t).Public(),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "returns error when the specified key was not found",
|
||||
keyID: "not-existent-key-id",
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
svc := setupTestService(t)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := svc.GetPublicKey(tt.keyID)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, got, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestEmbeddedKeyService_GetOrCreatePrivateKey(t *testing.T) {
|
||||
mockStore := signingkeystore.NewFakeStore()
|
||||
|
||||
func TestEmbeddedKeyService_GetPrivateKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyID string
|
||||
want crypto.PrivateKey
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "returns the private key successfully",
|
||||
keyID: "default",
|
||||
want: getPrivateKey(t),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "returns error when the specified key was not found",
|
||||
keyID: "not-existent-key-id",
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
svc := &Service{
|
||||
log: log.NewNopLogger(),
|
||||
store: mockStore,
|
||||
}
|
||||
svc := setupTestService(t)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := svc.GetPrivateKey(tt.keyID)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, got, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedKeyService_AddPrivateKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyID string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "adds the private key successfully",
|
||||
keyID: "new-key-id",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "returns error when the specified key is already in the store",
|
||||
keyID: serverPrivateKeyID,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
svc := setupTestService(t)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := svc.AddPrivateKey(tt.keyID, &dummyPrivateKey{})
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
wantedKeyID := keyMonthScopedID("test", jose.ES256)
|
||||
assert.Equal(t, wantedKeyID, fmt.Sprintf("test-%s-es256", time.Now().UTC().Format("2006-01")))
|
||||
|
||||
func TestProvideEmbeddedSigningKeysService(t *testing.T) {
|
||||
s, err := ProvideEmbeddedSigningKeysService()
|
||||
// only ES256 is supported
|
||||
_, _, err := svc.GetOrCreatePrivateKey(context.Background(), "test", jose.RS256)
|
||||
require.Error(t, err)
|
||||
|
||||
// first call should generate a key
|
||||
_, key, err := svc.GetOrCreatePrivateKey(context.Background(), "test", jose.ES256)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, s)
|
||||
require.NotNil(t, key)
|
||||
|
||||
// Verify that ProvideEmbeddedSigningKeysService generates an ECDSA private key by default
|
||||
require.IsType(t, &ecdsa.PrivateKey{}, s.GetServerPrivateKey())
|
||||
}
|
||||
assert.Contains(t, mockStore.PrivateKeys, wantedKeyID)
|
||||
|
||||
type dummyPrivateKey struct {
|
||||
}
|
||||
// second call should return the same key
|
||||
id, key2, err := svc.GetOrCreatePrivateKey(context.Background(), "test", jose.ES256)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, key2)
|
||||
require.Equal(t, key, key2)
|
||||
require.Equal(t, wantedKeyID, id)
|
||||
|
||||
func (d dummyPrivateKey) Public() crypto.PublicKey {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (d dummyPrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
||||
return nil, nil
|
||||
assert.Len(t, mockStore.PrivateKeys, 1)
|
||||
}
|
||||
|
@ -1,54 +1,48 @@
|
||||
package signingkeystest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
)
|
||||
|
||||
type FakeSigningKeysService struct {
|
||||
ExpectedJSONWebKeySet jose.JSONWebKeySet
|
||||
ExpectedJSONWebKey jose.JSONWebKey
|
||||
ExpectedKeys map[string]crypto.Signer
|
||||
ExpectedServerPrivateKey crypto.PrivateKey
|
||||
ExpectedServerPublicKey crypto.PublicKey
|
||||
ExpectedError error
|
||||
ExpectedJSONWebKeySet jose.JSONWebKeySet
|
||||
ExpectedJSONWebKey jose.JSONWebKey
|
||||
ExpectedKeys map[string]crypto.Signer
|
||||
ExpectedError error
|
||||
}
|
||||
|
||||
func (s *FakeSigningKeysService) GetJWKS() jose.JSONWebKeySet {
|
||||
return s.ExpectedJSONWebKeySet
|
||||
}
|
||||
|
||||
// GetJWK returns the JSON Web Key (JWK) with the specified key ID which can be used to verify tokens (public key)
|
||||
func (s *FakeSigningKeysService) GetJWK(keyID string) (jose.JSONWebKey, error) {
|
||||
return s.ExpectedJSONWebKey, s.ExpectedError
|
||||
func (s *FakeSigningKeysService) GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error) {
|
||||
return s.ExpectedJSONWebKeySet, nil
|
||||
}
|
||||
|
||||
// GetPublicKey returns the public key with the specified key ID
|
||||
func (s *FakeSigningKeysService) GetPublicKey(keyID string) (crypto.PublicKey, error) {
|
||||
func (s *FakeSigningKeysService) GetPublicKey(ctx context.Context, keyID string) (crypto.PublicKey, error) {
|
||||
return s.ExpectedKeys[keyID].Public(), s.ExpectedError
|
||||
}
|
||||
|
||||
// GetPrivateKey returns the private key with the specified key ID
|
||||
func (s *FakeSigningKeysService) GetPrivateKey(keyID string) (crypto.PrivateKey, error) {
|
||||
func (s *FakeSigningKeysService) GetPrivateKey(ctx context.Context, keyID string) (crypto.PrivateKey, error) {
|
||||
return s.ExpectedKeys[keyID], s.ExpectedError
|
||||
}
|
||||
|
||||
// GetServerPrivateKey returns the private key used to sign tokens
|
||||
func (s *FakeSigningKeysService) GetServerPrivateKey() crypto.PrivateKey {
|
||||
return s.ExpectedServerPrivateKey
|
||||
}
|
||||
|
||||
// GetServerPublicKey returns the public key used to verify tokens
|
||||
func (s *FakeSigningKeysService) GetServerPublicKey() crypto.PublicKey {
|
||||
return s.ExpectedServerPublicKey
|
||||
}
|
||||
|
||||
// AddPrivateKey adds a private key to the service
|
||||
func (s *FakeSigningKeysService) AddPrivateKey(keyID string, privateKey crypto.PrivateKey) error {
|
||||
func (s *FakeSigningKeysService) AddPrivateKey(ctx context.Context, keyID string,
|
||||
privateKey crypto.Signer, alg jose.SignatureAlgorithm, expiresAt *time.Time, force bool) error {
|
||||
if s.ExpectedError != nil {
|
||||
return s.ExpectedError
|
||||
}
|
||||
s.ExpectedKeys[keyID] = privateKey.(crypto.Signer)
|
||||
s.ExpectedKeys[keyID] = privateKey
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FakeSigningKeysService) GetOrCreatePrivateKey(ctx context.Context,
|
||||
keyPrefix string, alg jose.SignatureAlgorithm) (string, crypto.Signer, error) {
|
||||
if s.ExpectedError != nil {
|
||||
return "", nil, s.ExpectedError
|
||||
}
|
||||
return keyPrefix, s.ExpectedKeys[keyPrefix], nil
|
||||
}
|
||||
|
62
pkg/services/signingkeys/signingkeystore/fake.go
Normal file
62
pkg/services/signingkeys/signingkeystore/fake.go
Normal file
@ -0,0 +1,62 @@
|
||||
package signingkeystore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
)
|
||||
|
||||
type FakeStore struct {
|
||||
PrivateKeys map[string]crypto.Signer
|
||||
jwks jose.JSONWebKeySet
|
||||
}
|
||||
|
||||
func NewFakeStore() *FakeStore {
|
||||
return &FakeStore{
|
||||
PrivateKeys: make(map[string]crypto.Signer),
|
||||
jwks: jose.JSONWebKeySet{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *FakeStore) GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error) {
|
||||
return s.jwks, nil
|
||||
}
|
||||
|
||||
func (s *FakeStore) AddPrivateKey(ctx context.Context, keyID string, alg jose.SignatureAlgorithm,
|
||||
privateKey crypto.Signer, expiresAt *time.Time, force bool) (crypto.Signer, error) {
|
||||
if !force {
|
||||
if key, ok := s.PrivateKeys[keyID]; ok {
|
||||
if !hasExpired(key) {
|
||||
return nil, fmt.Errorf("key already exists and has not expired")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.PrivateKeys[keyID] = privateKey
|
||||
|
||||
jwk := jose.JSONWebKey{
|
||||
Key: privateKey.Public(),
|
||||
Algorithm: string(alg),
|
||||
KeyID: keyID,
|
||||
Use: "sig",
|
||||
}
|
||||
|
||||
s.jwks.Keys = append(s.jwks.Keys, jwk)
|
||||
|
||||
return privateKey, nil
|
||||
}
|
||||
|
||||
func (s *FakeStore) GetPrivateKey(ctx context.Context, keyID string) (crypto.Signer, error) {
|
||||
if key, ok := s.PrivateKeys[keyID]; ok {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("key not found")
|
||||
}
|
||||
|
||||
func hasExpired(key crypto.Signer) bool {
|
||||
return false
|
||||
}
|
215
pkg/services/signingkeys/signingkeystore/store.go
Normal file
215
pkg/services/signingkeys/signingkeystore/store.go
Normal file
@ -0,0 +1,215 @@
|
||||
package signingkeystore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/services/secrets"
|
||||
"github.com/grafana/grafana/pkg/services/signingkeys"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/session"
|
||||
)
|
||||
|
||||
type SigningStore interface {
|
||||
// GetJWKS returns the JSON Web Key Set for the service
|
||||
GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error)
|
||||
// AddPrivateKey adds a private key to the service. If the key already exists, it will be updated if force is true.
|
||||
// If force is false, the key will only be updated if it has expired. If the key does not exist, it will be added.
|
||||
// If expiresAt is nil, the key will not expire. Retrieve the result key with GetPrivateKey.
|
||||
AddPrivateKey(ctx context.Context, keyID string, alg jose.SignatureAlgorithm,
|
||||
privateKey crypto.Signer, expiresAt *time.Time, force bool) (crypto.Signer, error)
|
||||
// GetPrivateKey returns the private key with the specified key ID
|
||||
GetPrivateKey(ctx context.Context, keyID string) (crypto.Signer, error)
|
||||
}
|
||||
|
||||
var _ SigningStore = (*Store)(nil)
|
||||
|
||||
type Store struct {
|
||||
dbStore db.DB
|
||||
secretsService secrets.Service
|
||||
}
|
||||
|
||||
type SigningKey struct {
|
||||
ID int64 `json:"-" db:"id"`
|
||||
KeyID string `json:"key_id" db:"key_id"`
|
||||
PrivateKey []byte `json:"private_key" db:"private_key"`
|
||||
AddedAt time.Time `json:"added_at" db:"added_at"`
|
||||
ExpiresAt *time.Time `json:"expires_at" db:"expires_at"`
|
||||
Alg jose.SignatureAlgorithm `json:"alg" db:"alg"`
|
||||
}
|
||||
|
||||
func NewSigningKeyStore(dbStore db.DB, secretsService secrets.Service) *Store {
|
||||
return &Store{
|
||||
dbStore: dbStore,
|
||||
secretsService: secretsService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetJWKS returns the JSON Web Key Set (JWKS) for the service. Expired keys will not be returned.
|
||||
func (s *Store) GetJWKS(ctx context.Context) (jose.JSONWebKeySet, error) {
|
||||
keySet := jose.JSONWebKeySet{}
|
||||
|
||||
keys := []*SigningKey{}
|
||||
if err := s.dbStore.GetSqlxSession().Select(ctx,
|
||||
&keys, "SELECT * FROM signing_key WHERE expires_at IS NULL OR expires_at > ?", time.Now()); err != nil {
|
||||
return keySet, err
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
assertedKey, err := s.decodePrivateKey(ctx, key)
|
||||
if err != nil {
|
||||
return keySet, err
|
||||
}
|
||||
|
||||
keySet.Keys = append(keySet.Keys, jose.JSONWebKey{
|
||||
Key: assertedKey.Public(),
|
||||
Algorithm: string(key.Alg),
|
||||
KeyID: key.KeyID,
|
||||
Use: "sig",
|
||||
})
|
||||
}
|
||||
|
||||
return keySet, nil
|
||||
}
|
||||
|
||||
// AddPrivateKey adds a private key to the service.
|
||||
func (s *Store) AddPrivateKey(ctx context.Context,
|
||||
keyID string, alg jose.SignatureAlgorithm, privateKey crypto.Signer, expiresAt *time.Time, force bool) (crypto.Signer, error) {
|
||||
privateKeyPEM, err := s.encodePrivateKey(ctx, privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := &SigningKey{
|
||||
KeyID: keyID,
|
||||
PrivateKey: privateKeyPEM,
|
||||
AddedAt: time.Now(),
|
||||
Alg: alg,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
|
||||
dbSession := s.dbStore.GetSqlxSession()
|
||||
var signer crypto.Signer
|
||||
err = dbSession.WithTransaction(ctx, func(tx *session.SessionTx) error {
|
||||
existingKey := SigningKey{}
|
||||
err := tx.Get(ctx, &existingKey, "SELECT * FROM signing_key WHERE key_id = ?", keyID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(existingKey.PrivateKey) == 0 {
|
||||
_, err = tx.Exec(ctx,
|
||||
"INSERT INTO signing_key (key_id, private_key, added_at, alg, expires_at) VALUES (?, ?, ?, ?, ?)",
|
||||
key.KeyID, key.PrivateKey, key.AddedAt, key.Alg, key.ExpiresAt)
|
||||
signer = privateKey
|
||||
return err
|
||||
}
|
||||
|
||||
if force || (existingKey.ExpiresAt != nil && existingKey.ExpiresAt.Before(time.Now())) {
|
||||
_, err = tx.Exec(ctx,
|
||||
"UPDATE signing_key SET private_key = ?, added_at = ?, alg = ?, expires_at = ? WHERE key_id = ?",
|
||||
key.PrivateKey, key.AddedAt, key.Alg, key.ExpiresAt, key.KeyID)
|
||||
signer = privateKey
|
||||
return err
|
||||
}
|
||||
|
||||
signer, err = s.decodePrivateKey(ctx, &existingKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return signingkeys.ErrSigningKeyAlreadyExists.Errorf("The specified key already exists: %s", keyID)
|
||||
})
|
||||
return signer, err
|
||||
}
|
||||
|
||||
// GetPrivateKey returns the private key with the specified key ID. Expired keys will not be returned.
|
||||
func (s *Store) GetPrivateKey(ctx context.Context, keyID string) (crypto.Signer, error) {
|
||||
key := &SigningKey{}
|
||||
err := s.dbStore.GetSqlxSession().Get(ctx, key,
|
||||
"SELECT * FROM signing_key WHERE key_id = ?", keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Bail out if key has expired
|
||||
if key.ExpiresAt != nil && key.ExpiresAt.Before(time.Now()) {
|
||||
return nil, signingkeys.ErrSigningKeyNotFound.Errorf("The specified key was not found: %s", keyID)
|
||||
}
|
||||
|
||||
signKey, err := s.decodePrivateKey(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signKey, nil
|
||||
}
|
||||
|
||||
func (s *Store) encodePrivateKey(ctx context.Context, privateKey crypto.Signer) ([]byte, error) {
|
||||
// Encode private key to binary format
|
||||
pKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Encode private key to PEM format
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: pKeyBytes,
|
||||
})
|
||||
|
||||
encrypted, err := s.secretsService.Encrypt(ctx, privateKeyPEM, secrets.WithoutScope())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encoded := make([]byte, base64.StdEncoding.EncodedLen(len(encrypted)))
|
||||
base64.StdEncoding.Encode(encoded, encrypted)
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
func (s *Store) decodePrivateKey(ctx context.Context, signingKey *SigningKey) (crypto.Signer, error) {
|
||||
// Bail out if empty string since it'll cause a segfault in Decrypt
|
||||
if len(signingKey.PrivateKey) == 0 {
|
||||
return nil, errors.New("private key is empty")
|
||||
}
|
||||
|
||||
payload := make([]byte, base64.StdEncoding.DecodedLen(len(signingKey.PrivateKey)))
|
||||
_, err := base64.StdEncoding.Decode(payload, signingKey.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decrypted, err := s.secretsService.Decrypt(ctx, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(decrypted)
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode private key PEM")
|
||||
}
|
||||
|
||||
if block.Type != "PRIVATE KEY" {
|
||||
return nil, errors.New("invalid block type")
|
||||
}
|
||||
|
||||
parsedKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
assertedKey, ok := parsedKey.(crypto.Signer)
|
||||
if !ok {
|
||||
return nil, errors.New("failed to assert private key as crypto.Signer")
|
||||
}
|
||||
return assertedKey, nil
|
||||
}
|
199
pkg/services/signingkeys/signingkeystore/store_test.go
Normal file
199
pkg/services/signingkeys/signingkeystore/store_test.go
Normal file
@ -0,0 +1,199 @@
|
||||
package signingkeystore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||
"github.com/grafana/grafana/pkg/services/signingkeys"
|
||||
)
|
||||
|
||||
func TestIntegrationSigningKeyStore(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
keyFunc func() (crypto.Signer, error)
|
||||
keyID string
|
||||
alg jose.SignatureAlgorithm
|
||||
expected jose.JSONWebKey
|
||||
}{
|
||||
{
|
||||
name: "RSA key",
|
||||
keyFunc: func() (crypto.Signer, error) {
|
||||
return rsa.GenerateKey(rand.Reader, 2048)
|
||||
},
|
||||
keyID: "test-rsa-key",
|
||||
alg: jose.RS256,
|
||||
expected: jose.JSONWebKey{
|
||||
Key: &rsa.PublicKey{},
|
||||
Algorithm: "RS256",
|
||||
KeyID: "test-rsa-key",
|
||||
Use: "sig",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Elliptic Curve key",
|
||||
keyFunc: func() (crypto.Signer, error) {
|
||||
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
},
|
||||
keyID: "test-ec-key",
|
||||
alg: jose.ES256,
|
||||
expected: jose.JSONWebKey{
|
||||
Key: &ecdsa.PublicKey{},
|
||||
Algorithm: "ES256",
|
||||
KeyID: "test-ec-key",
|
||||
Use: "sig",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
dbStore := db.InitTestDB(t)
|
||||
secretSvc := fakes.NewFakeSecretsService()
|
||||
store := NewSigningKeyStore(dbStore, secretSvc)
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
key, err := tc.keyFunc()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = store.AddPrivateKey(ctx, tc.keyID, tc.alg, key, nil, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
retrievedKey, err := store.GetPrivateKey(ctx, tc.keyID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, key.Public(), retrievedKey.Public())
|
||||
|
||||
jwks, err := store.GetJWKS(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
require.Len(t, jwks.Keys, 1)
|
||||
assert.Equal(t, key.Public(), jwks.Keys[0].Key)
|
||||
assert.Equal(t, tc.expected.Algorithm, jwks.Keys[0].Algorithm)
|
||||
assert.Equal(t, tc.expected.KeyID, jwks.Keys[0].KeyID)
|
||||
assert.Equal(t, tc.expected.Use, jwks.Keys[0].Use)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegrationAddPrivateKey(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
dbStore := db.InitTestDB(t)
|
||||
secretSvc := fakes.NewFakeSecretsService()
|
||||
store := NewSigningKeyStore(dbStore, secretSvc)
|
||||
|
||||
key1 := generateRSAKey(t)
|
||||
key2 := generateECKey(t)
|
||||
key3 := generateECKey(t)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
keyID string
|
||||
alg jose.SignatureAlgorithm
|
||||
privateKey crypto.Signer
|
||||
expiresAt *time.Time
|
||||
force bool
|
||||
expectedErr error
|
||||
expectedKey crypto.Signer
|
||||
expectedGot crypto.Signer
|
||||
}{
|
||||
{
|
||||
name: "Add new private key",
|
||||
keyID: "test-key-1",
|
||||
alg: jose.RS256,
|
||||
privateKey: key1,
|
||||
force: false,
|
||||
expectedKey: key1,
|
||||
expectedGot: key1,
|
||||
},
|
||||
{
|
||||
name: "Add new private key with expiration",
|
||||
keyID: "test-key-2",
|
||||
alg: jose.ES256,
|
||||
privateKey: key2,
|
||||
expiresAt: &[]time.Time{time.Now().Add(24 * time.Hour)}[0],
|
||||
force: false,
|
||||
expectedKey: key2,
|
||||
expectedGot: key2,
|
||||
},
|
||||
{
|
||||
name: "Fail to replace unexpired key",
|
||||
keyID: "test-key-1",
|
||||
alg: jose.RS256,
|
||||
privateKey: key3,
|
||||
expiresAt: &[]time.Time{time.Now().Add(-24 * time.Hour)}[0],
|
||||
force: false,
|
||||
expectedErr: signingkeys.ErrSigningKeyAlreadyExists,
|
||||
expectedKey: key1,
|
||||
expectedGot: key1,
|
||||
},
|
||||
{
|
||||
name: "Replace key1 private key with force, already expired",
|
||||
keyID: "test-key-1",
|
||||
alg: jose.ES256,
|
||||
privateKey: key3,
|
||||
expiresAt: &[]time.Time{time.Now().Add(-24 * time.Hour)}[0],
|
||||
force: true,
|
||||
expectedKey: nil,
|
||||
expectedGot: key3,
|
||||
},
|
||||
{
|
||||
name: "Replace key1 private key with no force, is expired",
|
||||
keyID: "test-key-1",
|
||||
alg: jose.ES256,
|
||||
privateKey: key1,
|
||||
expiresAt: &[]time.Time{time.Now().Add(24 * time.Hour)}[0],
|
||||
force: false,
|
||||
expectedKey: nil,
|
||||
expectedGot: key1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := store.AddPrivateKey(ctx, tc.keyID, tc.alg, tc.privateKey, tc.expiresAt, tc.force)
|
||||
if tc.expectedErr != nil {
|
||||
assert.ErrorIs(t, err, tc.expectedErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
if tc.expectedGot != nil {
|
||||
assert.Equal(t, tc.expectedGot.Public(), got.Public())
|
||||
} else {
|
||||
assert.Nil(t, got)
|
||||
}
|
||||
|
||||
if tc.expectedKey != nil {
|
||||
retrievedKey, err := store.GetPrivateKey(ctx, tc.keyID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedKey.Public(), retrievedKey.Public())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateRSAKey(t *testing.T) *rsa.PrivateKey {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
return key
|
||||
}
|
||||
|
||||
func generateECKey(t *testing.T) *ecdsa.PrivateKey {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
return key
|
||||
}
|
@ -5,6 +5,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrations/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrations/anonservice"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrations/oauthserver"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrations/signingkeys"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrations/ualert"
|
||||
. "github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
)
|
||||
@ -99,6 +100,7 @@ func (*OSSMigrations) AddMigration(mg *Migrator) {
|
||||
}
|
||||
|
||||
anonservice.AddMigration(mg)
|
||||
signingkeys.AddMigration(mg)
|
||||
}
|
||||
|
||||
func addStarMigrations(mg *Migrator) {
|
||||
|
23
pkg/services/sqlstore/migrations/signingkeys/migrations.go
Normal file
23
pkg/services/sqlstore/migrations/signingkeys/migrations.go
Normal file
@ -0,0 +1,23 @@
|
||||
package signingkeys
|
||||
|
||||
import "github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
|
||||
func AddMigration(mg *migrator.Migrator) {
|
||||
var signingKeysV1 = migrator.Table{
|
||||
Name: "signing_key",
|
||||
Columns: []*migrator.Column{
|
||||
{Name: "id", Type: migrator.DB_BigInt, IsPrimaryKey: true, IsAutoIncrement: true},
|
||||
{Name: "key_id", Type: migrator.DB_NVarchar, Length: 255, Nullable: false},
|
||||
{Name: "private_key", Type: migrator.DB_Text, Nullable: false},
|
||||
{Name: "added_at", Type: migrator.DB_DateTime, Nullable: false},
|
||||
{Name: "expires_at", Type: migrator.DB_DateTime, Nullable: true},
|
||||
{Name: "alg", Type: migrator.DB_NVarchar, Length: 255, Nullable: false},
|
||||
},
|
||||
Indices: []*migrator.Index{
|
||||
{Cols: []string{"key_id"}, Type: migrator.UniqueIndex},
|
||||
},
|
||||
}
|
||||
|
||||
mg.AddMigration("create signing_key table", migrator.NewAddTableMigration(signingKeysV1))
|
||||
mg.AddMigration("add unique index signing_key.key_id", migrator.NewAddIndexMigration(signingKeysV1, signingKeysV1.Indices[0]))
|
||||
}
|
Loading…
Reference in New Issue
Block a user