Auth: Check id token expiry date (#69829)

* fixed: added id token expiry check to oauth token sync

* use go-jose and id token in cache

* Update pkg/services/authn/authnimpl/sync/oauth_token_sync.go

* refactored getOAuthTokenCacheTTL and added unit tests

* Small changes to oauth_token_sync

* Remove unnecessary contexthandler changes

---------

Co-authored-by: linoman <2051016+linoman@users.noreply.github.com>
Co-authored-by: Mihaly Gyongyosi <mgyongyosi@users.noreply.github.com>
This commit is contained in:
Aksel Skaar Leirvaag 2023-08-01 09:23:47 +02:00 committed by GitHub
parent 7bf3998510
commit 6d98d06f6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 162 additions and 13 deletions

View File

@ -3,14 +3,18 @@ package sync
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/util/errutil"
@ -64,10 +68,15 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
return nil
}
idTokenExpiry, err := getIDTokenExpiry(token)
if err != nil {
s.log.FromContext(ctx).Error("Failed to extract expiry of ID token", "id", identity.ID, "error", err)
}
// token has no expire time configured, so we don't have to refresh it
if token.OAuthExpiry.IsZero() {
// cache the token check, so we don't perform it on every request
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry))
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(token.OAuthExpiry, idTokenExpiry))
return nil
}
@ -84,11 +93,19 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
return nil
}
expires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
accessTokenExpires := token.OAuthExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
hasIdTokenExpired := false
idTokenExpires := time.Time{}
if !idTokenExpiry.IsZero() {
idTokenExpires = idTokenExpiry.Round(0).Add(-oauthtoken.ExpiryDelta)
hasIdTokenExpired = idTokenExpires.Before(time.Now())
}
// token has not expired, so we don't have to refresh it
if !expires.Before(time.Now()) {
if !accessTokenExpires.Before(time.Now()) && !hasIdTokenExpired {
// cache the token check, so we don't perform it on every request
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(expires))
s.cache.Set(identity.ID, struct{}{}, getOAuthTokenCacheTTL(accessTokenExpires, idTokenExpires))
return nil
}
@ -113,15 +130,47 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
const maxOAuthTokenCacheTTL = 10 * time.Minute
func getOAuthTokenCacheTTL(t time.Time) time.Duration {
if t.IsZero() {
func getOAuthTokenCacheTTL(accessTokenExpiry, idTokenExpiry time.Time) time.Duration {
if accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
return maxOAuthTokenCacheTTL
}
ttl := time.Until(t)
if ttl > maxOAuthTokenCacheTTL {
return maxOAuthTokenCacheTTL
min := func(a, b time.Duration) time.Duration {
if a <= b {
return a
}
return b
}
return ttl
if accessTokenExpiry.IsZero() && !idTokenExpiry.IsZero() {
return min(time.Until(idTokenExpiry), maxOAuthTokenCacheTTL)
}
if !accessTokenExpiry.IsZero() && idTokenExpiry.IsZero() {
return min(time.Until(accessTokenExpiry), maxOAuthTokenCacheTTL)
}
return min(min(time.Until(accessTokenExpiry), time.Until(idTokenExpiry)), maxOAuthTokenCacheTTL)
}
// getIDTokenExpiry extracts the expiry time from the ID token
func getIDTokenExpiry(token *login.UserAuth) (time.Time, error) {
if token.OAuthIdToken == "" {
return time.Time{}, nil
}
parsedToken, err := jwt.ParseSigned(token.OAuthIdToken)
if err != nil {
return time.Time{}, fmt.Errorf("error parsing id token: %w", err)
}
type Claims struct {
Exp int64 `json:"exp"`
}
var claims Claims
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
return time.Time{}, fmt.Errorf("error getting claims from id token: %w", err)
}
return time.Unix(claims.Exp, 0), nil
}

View File

@ -2,12 +2,13 @@ package sync
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
@ -18,9 +19,11 @@ import (
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken/oauthtokentest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
type testCase struct {
desc string
identity *authn.Identity
@ -95,6 +98,13 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
oauthInfo: &social.OAuthInfo{UseRefreshToken: false},
},
{
desc: "should refresh access token when ID token has expired",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}},
expectHasEntryCalled: true,
expectTryRefreshTokenCalled: true,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(10 * time.Minute), OAuthIdToken: fakeIDToken(t, time.Now().Add(-10*time.Minute))},
},
}
for _, tt := range tests {
@ -155,3 +165,93 @@ func TestOauthTokenSync_SyncOAuthTokenHook(t *testing.T) {
})
}
}
// fakeIDToken is used to create a fake invalid token to verify expiry logic
func fakeIDToken(t *testing.T, expiryDate time.Time) string {
type Header struct {
Kid string `json:"kid"`
Alg string `json:"alg"`
}
type Payload struct {
Iss string `json:"iss"`
Sub string `json:"sub"`
Exp int64 `json:"exp"`
}
header, err := json.Marshal(Header{Kid: "123", Alg: "none"})
require.NoError(t, err)
u := expiryDate.UTC().Unix()
payload, err := json.Marshal(Payload{Iss: "fake", Sub: "a-sub", Exp: u})
require.NoError(t, err)
fakeSignature := []byte("6ICJm")
return fmt.Sprintf("%s.%s.%s", base64.RawURLEncoding.EncodeToString(header), base64.RawURLEncoding.EncodeToString(payload), base64.RawURLEncoding.EncodeToString(fakeSignature))
}
func TestOAuthTokenSync_getOAuthTokenCacheTTL(t *testing.T) {
defaultTime := time.Now()
tests := []struct {
name string
accessTokenExpiry time.Time
idTokenExpiry time.Time
want time.Duration
}{
{
name: "should return maxOAuthTokenCacheTTL when no expiry is given",
accessTokenExpiry: time.Time{},
idTokenExpiry: time.Time{},
want: maxOAuthTokenCacheTTL,
},
{
name: "should return maxOAuthTokenCacheTTL when access token is not given and id token expiry is greater than max cache ttl",
accessTokenExpiry: time.Time{},
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: maxOAuthTokenCacheTTL,
},
{
name: "should return idTokenExpiry when access token is not given and id token expiry is less than max cache ttl",
accessTokenExpiry: time.Time{},
idTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token is not given",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: time.Time{},
want: maxOAuthTokenCacheTTL,
},
{
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and id token is not given",
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: time.Time{},
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return accessTokenExpiry when access token expiry is less than max cache ttl and less than id token expiry",
accessTokenExpiry: defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-5*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return idTokenExpiry when id token expiry is less than max cache ttl and less than access token expiry",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL),
want: time.Until(defaultTime.Add(-3*time.Minute + maxOAuthTokenCacheTTL)),
},
{
name: "should return maxOAuthTokenCacheTTL when access token expiry is greater than max cache ttl and id token expiry is greater than max cache ttl",
accessTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
idTokenExpiry: defaultTime.Add(5*time.Minute + maxOAuthTokenCacheTTL),
want: maxOAuthTokenCacheTTL,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getOAuthTokenCacheTTL(tt.accessTokenExpiry, tt.idTokenExpiry)
assert.Equal(t, tt.want.Round(time.Second), got.Round(time.Second))
})
}
}