mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Check id token expiry date (#69829)
* fixed: added id token expiry check to oauth token sync * use go-jose and id token in cache * Update pkg/services/authn/authnimpl/sync/oauth_token_sync.go * refactored getOAuthTokenCacheTTL and added unit tests * Small changes to oauth_token_sync * Remove unnecessary contexthandler changes --------- Co-authored-by: linoman <2051016+linoman@users.noreply.github.com> Co-authored-by: Mihaly Gyongyosi <mgyongyosi@users.noreply.github.com>
This commit is contained in:
parent
7bf3998510
commit
6d98d06f6e
@ -3,14 +3,18 @@ package sync
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
@ -64,10 +68,15 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||
return nil
|
||||
}
|
||||
|
||||
idTokenExpiry, err := getIDTokenExpiry(token)
|
||||
if err != nil {
|
||||
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
|
||||
}
|
||||
|
||||
// token has no expire time configured, so we don't have to refresh it
|
||||
if token.OAuthExpiry.IsZero() {
|
||||
// cache the token check, so we don't perform it on every request
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry))
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -84,11 +93,19 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||
return nil
|
||||
}
|
||||
|
||||
expires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
|
||||
hasIdTokenExpired := false
|
||||
idTokenExpires := time.Time{}
|
||||
|
||||
if !idTokenExpiry.IsZero() {
|
||||
idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
|
||||
hasIdTokenExpired = idTokenExpires.Before(time.Now())
|
||||
}
|
||||
// token has not expired, so we don't have to refresh it
|
||||
if !expires.Before(time.Now()) {
|
||||
if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired {
|
||||
// cache the token check, so we don't perform it on every request
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(expires))
|
||||
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -113,15 +130,47 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||
|
||||
const maxOAuthTokenCacheTTL = 10 * time.Minute
|
||||
|
||||
func getOAuthTokenCacheTTL(t time.Time) time.Duration {
|
||||
if t.IsZero() {
|
||||
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
|
||||
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||
return maxOAuthTokenCacheTTL
|
||||
}
|
||||
|
||||
ttl := time.Until(t)
|
||||
if ttl > maxOAuthTokenCacheTTL {
|
||||
return maxOAuthTokenCacheTTL
|
||||
min := func(a, b time.Duration) time.Duration {
|
||||
if a <= b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
return ttl
|
||||
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() {
|
||||
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
|
||||
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL)
|
||||
}
|
||||
|
||||
// getIDTokenExpiry extracts the expiry time from the ID token
|
||||
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
|
||||
if token.OAuthIdToken == "" {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
var claims Claims
|
||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
|
||||
}
|
||||
|
||||
return time.Unix(claims.Exp, 0), nil
|
||||
}
|
||||
|
@ -2,12 +2,13 @@ package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/localcache"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
@ -18,9 +19,11 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
type testCase struct {
|
||||
desc string
|
||||
identity *authn.Identity
|
||||
@ -95,6 +98,13 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
|
||||
},
|
||||
{
|
||||
desc: "should refresh access token when ID token has expired",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
|
||||
expectHasEntryCalled: true,
|
||||
expectTryRefreshTokenCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -155,3 +165,93 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// fakeIDToken is used to create a fake invalid token to verify expiry logic
|
||||
func fakeIDToken(t *testing.T, expiryDate time.Time) string {
|
||||
type Header struct {
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
}
|
||||
type Payload struct {
|
||||
Iss string `json:"iss"`
|
||||
Sub string `json:"sub"`
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
|
||||
header, err := json.Marshal(Header{Kid: "123", Alg: "none"})
|
||||
require.NoError(t, err)
|
||||
u := expiryDate.UTC().Unix()
|
||||
payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u})
|
||||
require.NoError(t, err)
|
||||
|
||||
fakeSignature := []byte("6ICJm")
|
||||
return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature))
|
||||
}
|
||||
|
||||
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
|
||||
defaultTime := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
accessTokenExpiry time.Time
|
||||
idTokenExpiry time.Time
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: time.Time{},
|
||||
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
|
||||
accessTokenExpiry: time.Time{},
|
||||
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: time.Time{},
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
{
|
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
|
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: time.Time{},
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
|
||||
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
|
||||
},
|
||||
{
|
||||
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
|
||||
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
|
||||
want: maxOAuthTokenCacheTTL,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
|
||||
|
||||
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user