IDForwading: cache based on expires in (#81136)

* IDFowarding: Cache based on expires in

* IDFowarding: Change default expires in

---------

Co-authored-by: Victor Cinaglia <victor@grafana.com>
This commit is contained in:
Karl Persson 2024-01-24 13:56:44 +01:00 committed by GitHub
parent 1c02220916
commit 28bb6979f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 10 deletions

View File

@ -24,8 +24,8 @@ import (
const ( const (
cachePrefix = "id-token" cachePrefix = "id-token"
tokenTTL = 1 * time.Hour tokenTTL = 10 * time.Minute
cacheTTL = 58 * time.Minute cacheLeeway = 30 * time.Second
) )
var _ auth.IDService = (*Service)(nil) var _ auth.IDService = (*Service)(nil)
@ -101,7 +101,22 @@ func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (stri
return "", err return "", err
} }
if err := s.cache.Set(ctx, cacheKey, []byte(token), cacheTTL); err != nil { parsed, err := jwt.ParseSigned(token)
if err != nil {
s.metrics.failedTokenSigningCounter.Inc()
return "", err
}
extracted := auth.IDClaims{}
// We don't need to verify the signature here, we are only intrested in checking
// when the token expires.
if err := parsed.UnsafeClaimsWithoutVerification(&extracted); err != nil {
s.metrics.failedTokenSigningCounter.Inc()
return "", err
}
expires := time.Until(extracted.Expiry.Time())
if err := s.cache.Set(ctx, cacheKey, []byte(token), expires-cacheLeeway); err != nil {
s.logger.FromContext(ctx).Error("Failed to add id token to cache", "error", err) s.logger.FromContext(ctx).Error("Failed to add id token to cache", "error", err)
} }

View File

@ -2,9 +2,10 @@ package idimpl
import ( import (
"context" "context"
"encoding/json"
"testing" "testing"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -53,11 +54,14 @@ func Test_ProvideService(t *testing.T) {
func TestService_SignIdentity(t *testing.T) { func TestService_SignIdentity(t *testing.T) {
signer := &idtest.MockSigner{ signer := &idtest.MockSigner{
SignIDTokenFn: func(_ context.Context, claims *auth.IDClaims) (string, error) { SignIDTokenFn: func(_ context.Context, claims *auth.IDClaims) (string, error) {
data, err := json.Marshal(claims) key := []byte("key")
if err != nil { s, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, nil)
return "", err require.NoError(t, err)
}
return string(data), nil token, err := jwt.Signed(s).Claims(claims).CompactSerialize()
require.NoError(t, err)
return token, nil
}, },
} }
@ -81,8 +85,11 @@ func TestService_SignIdentity(t *testing.T) {
token, err := s.SignIdentity(context.Background(), &authn.Identity{ID: "user:1"}) token, err := s.SignIdentity(context.Background(), &authn.Identity{ID: "user:1"})
require.NoError(t, err) require.NoError(t, err)
parsed, err := jwt.ParseSigned(token)
require.NoError(t, err)
claims := &auth.IDClaims{} claims := &auth.IDClaims{}
require.NoError(t, json.Unmarshal([]byte(token), claims)) require.NoError(t, parsed.UnsafeClaimsWithoutVerification(&claims))
assert.Equal(t, login.AzureADAuthModule, claims.AuthenticatedBy) assert.Equal(t, login.AzureADAuthModule, claims.AuthenticatedBy)
}) })
} }