Auth: Add support for custom signing keys in auth.azure_ad (#71365)

* fallthrough JWKS validation and caching for Azure

* remove unused field
This commit is contained in:
Jo
2023-07-12 11:29:02 +02:00
committed by GitHub
parent c2a0487572
commit fbfdd6ba32
4 changed files with 225 additions and 156 deletions

View File

@@ -0,0 +1,117 @@
package social
import (
"bytes"
"context"
"encoding/json"
"net/http"
"strconv"
"strings"
"time"
)
const (
azureCacheKeyPrefix = "azuread_oauth_jwks-"
defaultCacheExpiration = 5 * time.Minute
)
func (s *SocialAzureAD) getJWKSCacheKey() (string, error) {
return azureCacheKeyPrefix + s.ClientID, nil
}
func (s *SocialAzureAD) retrieveJWKSFromCache(ctx context.Context, client *http.Client, authURL string) (*keySetJWKS, time.Duration, error) {
cacheKey, err := s.getJWKSCacheKey()
if err != nil {
return nil, 0, err
}
if val, err := s.cache.Get(ctx, cacheKey); err == nil {
var jwks keySetJWKS
err := json.Unmarshal(val, &jwks)
s.log.Debug("Retrieved cached key set", "cacheKey", cacheKey)
return &jwks, 0, err
}
s.log.Debug("Keyset not found in cache", "err", err)
return &keySetJWKS{}, 0, nil
}
func (s *SocialAzureAD) cacheJWKS(ctx context.Context, jwks *keySetJWKS, cacheExpiration time.Duration) error {
cacheKey, err := s.getJWKSCacheKey()
if err != nil {
return err
}
var jsonBuf bytes.Buffer
if err := json.NewEncoder(&jsonBuf).Encode(jwks); err != nil {
return err
}
if err := s.cache.Set(ctx, cacheKey, jsonBuf.Bytes(), cacheExpiration); err != nil {
s.log.Warn("Failed to cache key set", "err", err)
}
return nil
}
func (s *SocialAzureAD) retrieveGeneralJWKS(ctx context.Context, client *http.Client, authURL string) (*keySetJWKS, time.Duration, error) {
keysetURL := strings.Replace(authURL, "/oauth2/v2.0/authorize", "/discovery/v2.0/keys", 1)
resp, err := s.httpGet(context.Background(), client, keysetURL)
if err != nil {
return nil, 0, err
}
bytesReader := bytes.NewReader(resp.Body)
var jwks keySetJWKS
if err := json.NewDecoder(bytesReader).Decode(&jwks); err != nil {
return nil, 0, err
}
cacheExpiration := getCacheExpiration(resp.Headers.Get("cache-control"))
s.log.Debug("Retrieved general key set", "url", keysetURL, "cacheExpiration", cacheExpiration)
return &jwks, cacheExpiration, nil
}
func (s *SocialAzureAD) retrieveSpecificJWKS(ctx context.Context, client *http.Client, authURL string) (*keySetJWKS, time.Duration, error) {
keysetURL := strings.Replace(authURL, "/oauth2/v2.0/authorize", "/discovery/v2.0/keys", 1) + "?appid=" + s.ClientID
resp, err := s.httpGet(ctx, client, keysetURL)
if err != nil {
return nil, 0, err
}
bytesReader := bytes.NewReader(resp.Body)
var jwks keySetJWKS
if err := json.NewDecoder(bytesReader).Decode(&jwks); err != nil {
return nil, 0, err
}
cacheExpiration := getCacheExpiration(resp.Headers.Get("cache-control"))
s.log.Debug("Retrieved specific key set", "url", keysetURL, "cacheExpiration", cacheExpiration)
return &jwks, cacheExpiration, nil
}
func getCacheExpiration(header string) time.Duration {
if header == "" {
return defaultCacheExpiration
}
// Cache-Control: public, max-age=14400
cacheControl := strings.Split(header, ",")
for _, v := range cacheControl {
if strings.Contains(v, "max-age") {
parts := strings.Split(v, "=")
if len(parts) == 2 {
seconds, err := strconv.Atoi(parts[1])
if err != nil {
return defaultCacheExpiration
}
return time.Duration(seconds) * time.Second
}
}
}
return defaultCacheExpiration
}

View File

@@ -7,8 +7,6 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"time"
@@ -21,12 +19,6 @@ import (
"github.com/grafana/grafana/pkg/services/org"
)
const (
azureCacheKeyPrefix = "azuread_oauth_jwks-"
defaultCacheExpiration = 24 * time.Hour
tenantRegex = `^https:\/\/login\.microsoftonline\.com\/(?P<tenant>[a-zA-Z0-9\-]+)\/oauth2\/v2\.0\/authorize$`
)
type SocialAzureAD struct {
*SocialBase
cache remotecache.CacheStorage
@@ -34,7 +26,6 @@ type SocialAzureAD struct {
allowedGroups []string
forceUseGraphAPI bool
skipOrgRoleSync bool
compiledTenantRegex *regexp.Regexp
}
type azureClaims struct {
@@ -78,45 +69,9 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token
return nil, fmt.Errorf("error parsing id token: %w", err)
}
var claims azureClaims
keyset, err := s.retrieveJWKS(client)
claims, err := s.validateClaims(ctx, client, parsedToken)
if err != nil {
return nil, fmt.Errorf("error retrieving jwks: %w", err)
}
var errClaims error
keyID := parsedToken.Headers[0].KeyID
keys := keyset.Key(keyID)
if len(keys) == 0 {
s.log.Warn("AzureAD OAuth: signing key not found",
"kid", keyID,
"keys", fmt.Sprintf("%v", keyset.Keys))
return nil, &Error{"AzureAD OAuth: signing key not found"}
}
for _, key := range keys {
s.log.Debug("AzureAD OAuth: trying to parse token with key", "kid", key.KeyID)
if errClaims = parsedToken.Claims(key, &claims); errClaims == nil {
break
}
}
if errClaims != nil {
return nil, fmt.Errorf("error getting claims from id token: %w", errClaims)
}
if claims.OAuthVersion == "1.0" {
return nil, &Error{"AzureAD OAuth: version 1.0 is not supported. Please ensure the auth_url and token_url are set to the v2.0 endpoints."}
}
s.log.Debug("Validating audience", "audience", claims.Audience, "client_id", s.ClientID)
if claims.Audience != s.ClientID {
return nil, &Error{"AzureAD OAuth: audience mismatch"}
}
s.log.Debug("Validating tenant", "tenant", claims.TenantID, "allowed_tenants", s.allowedOrganizations)
if !s.isAllowedTenant(claims.TenantID) {
return nil, &Error{"AzureAD OAuth: tenant mismatch"}
return nil, err
}
email := claims.extractEmail()
@@ -128,7 +83,7 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token
var role roletype.RoleType
var grafanaAdmin bool
if !s.skipOrgRoleSync {
role, grafanaAdmin = s.extractRoleAndAdmin(&claims)
role, grafanaAdmin = s.extractRoleAndAdmin(claims)
}
if s.roleAttributeStrict && !role.IsValid() {
return nil, &InvalidBasicRoleError{idP: "Azure", assignedRole: string(role)}
@@ -160,6 +115,65 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token
}, nil
}
func (s *SocialAzureAD) validateClaims(ctx context.Context, client *http.Client, parsedToken *jwt.JSONWebToken) (*azureClaims, error) {
claims, err := s.validateIDTokenSignature(ctx, client, parsedToken)
if err != nil {
return nil, fmt.Errorf("error getting claims from id token: %w", err)
}
if claims.OAuthVersion == "1.0" {
return nil, &Error{"AzureAD OAuth: version 1.0 is not supported. Please ensure the auth_url and token_url are set to the v2.0 endpoints."}
}
s.log.Debug("Validating audience", "audience", claims.Audience, "client_id", s.ClientID)
if claims.Audience != s.ClientID {
return nil, &Error{"AzureAD OAuth: audience mismatch"}
}
s.log.Debug("Validating tenant", "tenant", claims.TenantID, "allowed_tenants", s.allowedOrganizations)
if !s.isAllowedTenant(claims.TenantID) {
return nil, &Error{"AzureAD OAuth: tenant mismatch"}
}
return claims, nil
}
func (s *SocialAzureAD) validateIDTokenSignature(ctx context.Context, client *http.Client, parsedToken *jwt.JSONWebToken) (*azureClaims, error) {
var claims azureClaims
jwksFuncs := []func(ctx context.Context, client *http.Client, authURL string) (*keySetJWKS, time.Duration, error){
s.retrieveJWKSFromCache, s.retrieveSpecificJWKS, s.retrieveGeneralJWKS,
}
keyID := parsedToken.Headers[0].KeyID
for _, jwksFunc := range jwksFuncs {
keyset, expiry, err := jwksFunc(ctx, client, s.Endpoint.AuthURL)
if err != nil {
return nil, fmt.Errorf("error retrieving jwks: %w", err)
}
var errClaims error
keys := keyset.Key(keyID)
for _, key := range keys {
s.log.Debug("AzureAD OAuth: trying to parse token with key", "kid", key.KeyID)
if errClaims = parsedToken.Claims(key, &claims); errClaims == nil {
if expiry != 0 {
s.log.Debug("AzureAD OAuth: caching key set", "kid", key.KeyID, "expiry", expiry)
if err := s.cacheJWKS(ctx, keyset, expiry); err != nil {
s.log.Warn("Failed to set key set in cache", "err", err)
}
}
return &claims, nil
} else {
s.log.Warn("AzureAD OAuth: failed to parse token with key", "kid", key.KeyID, "err", errClaims)
}
}
}
s.log.Warn("AzureAD OAuth: signing key not found", "kid", keyID)
return nil, &Error{"AzureAD OAuth: signing key not found"}
}
func (s *SocialAzureAD) IsGroupMember(groups []string) bool {
if len(s.allowedGroups) == 0 {
return true
@@ -229,7 +243,7 @@ type getAzureGroupResponse struct {
// Note: If user groups exceeds 200 no groups will be found in claims and URL to target the Graph API will be
// given instead.
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client, claims *azureClaims, token *oauth2.Token) ([]string, error) {
if !s.forceUseGraphAPI {
s.log.Debug("checking the claim for groups")
if len(claims.Groups) > 0 {
@@ -289,7 +303,7 @@ func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client,
// groupsGraphAPIURL retrieves the Microsoft Graph API URL to fetch user groups from the _claim_sources if present
// otherwise it generates an handcrafted URL.
func (s *SocialAzureAD) groupsGraphAPIURL(claims azureClaims, token *oauth2.Token) (string, error) {
func (s *SocialAzureAD) groupsGraphAPIURL(claims *azureClaims, token *oauth2.Token) (string, error) {
var endpoint string
// First check if an endpoint was specified in the claims
if claims.ClaimNames.Groups != "" {
@@ -332,68 +346,6 @@ func (s *SocialAzureAD) SupportBundleContent(bf *bytes.Buffer) error {
return s.SocialBase.SupportBundleContent(bf)
}
func (s *SocialAzureAD) extractTenantID(authURL string) (string, error) {
if s.compiledTenantRegex == nil {
compiledTenantRegex, err := regexp.Compile(`https://login.microsoftonline.(com|us)/([^/]+)/oauth2`)
if err != nil {
return "", err
}
s.compiledTenantRegex = compiledTenantRegex
}
matches := s.compiledTenantRegex.FindStringSubmatch(authURL)
if len(matches) < 3 {
return "", fmt.Errorf("unable to extract tenant ID from URL")
}
return matches[2], nil
}
func (s *SocialAzureAD) retrieveJWKS(client *http.Client) (*jose.JSONWebKeySet, error) {
var jwks keySetJWKS
// https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize retrieve organizations
// https://login.microsoftonline.com/xxx/oauth2/v2.0/authorize retrieve specific tenant xxx
tenant_id, err := s.extractTenantID(s.Endpoint.AuthURL)
if err != nil {
return nil, err
}
// example: azuread_oauth_jwks-33121321nUd
cacheKey := azureCacheKeyPrefix + tenant_id
// TODO: propagate context
if val, err := s.cache.Get(context.Background(), cacheKey); err == nil {
err := json.Unmarshal(val, &jwks)
return &jwks.JSONWebKeySet, err
} else {
s.log.Debug("Keyset not found in cache", "err", err)
}
// TODO: allow setting well-known endpoint and retrieve from there
keysetURL := strings.Replace(s.Endpoint.AuthURL, "/oauth2/v2.0/authorize", "/discovery/v2.0/keys", 1)
resp, err := s.httpGet(context.Background(), client, keysetURL)
if err != nil {
return nil, err
}
bytesReader := bytes.NewReader(resp.Body)
var jsonBuf bytes.Buffer
if err := json.NewDecoder(io.TeeReader(bytesReader, &jsonBuf)).Decode(&jwks); err != nil {
return nil, err
}
cacheExpiration := getCacheExpiration(resp.Headers.Get("cache-control"))
s.log.Debug("Setting key set in cache", "url", keysetURL, "cache-key", cacheKey,
"cacheExpiration", cacheExpiration)
if err := s.cache.Set(context.Background(), cacheKey, jsonBuf.Bytes(), cacheExpiration); err != nil {
s.log.Warn("Failed to set key set in cache", "url", keysetURL, "cache-key", cacheKey, "err", err)
}
return &jwks.JSONWebKeySet, nil
}
func (s *SocialAzureAD) isAllowedTenant(tenantID string) bool {
if len(s.allowedOrganizations) == 0 {
s.log.Warn("No allowed organizations specified, all tenants are allowed. Configure allowed_organizations to restrict access")
@@ -407,26 +359,3 @@ func (s *SocialAzureAD) isAllowedTenant(tenantID string) bool {
}
return false
}
func getCacheExpiration(header string) time.Duration {
if header == "" {
return defaultCacheExpiration
}
// Cache-Control: public, max-age=14400
cacheControl := strings.Split(header, ",")
for _, v := range cacheControl {
if strings.Contains(v, "max-age") {
parts := strings.Split(v, "=")
if len(parts) == 2 {
seconds, err := strconv.Atoi(parts[1])
if err != nil {
return defaultCacheExpiration
}
return time.Duration(seconds) * time.Second
}
}
}
return defaultCacheExpiration
}

View File

@@ -61,7 +61,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
ID: "1234",
},
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
want: &BasicUserInfo{
Id: "1234",
@@ -74,6 +74,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
},
{
name: "No email",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "",
PreferredUsername: "",
@@ -100,7 +103,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
ID: "1234",
},
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
usGovURL: true,
},
want: &BasicUserInfo{
@@ -122,7 +125,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
ID: "1234",
},
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
want: &BasicUserInfo{
Id: "1234",
@@ -135,6 +138,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
},
{
name: "Admin role",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
@@ -153,6 +159,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
},
{
name: "Lowercase Admin role",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
@@ -172,7 +181,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Only other roles",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
@@ -200,7 +209,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
ID: "1234",
},
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()),
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()),
},
want: &BasicUserInfo{
Id: "1234",
@@ -220,6 +229,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
Name: "My Name",
ID: "1234",
},
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()),
},
want: &BasicUserInfo{
Id: "1234",
Name: "My Name",
@@ -231,6 +243,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
},
{
name: "Admin and Editor roles in claim",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
@@ -249,7 +264,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
},
{
name: "Grafana Admin but setting is disabled",
fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures())},
fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures())},
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
@@ -271,7 +286,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
name: "Editor roles in claim and GrafanaAdminAssignment enabled",
fields: fields{
SocialBase: newSocialBase("azuread",
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
&oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
@@ -293,7 +308,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Grafana Admin and Editor roles in claim",
fields: fields{SocialBase: newSocialBase("azuread",
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())},
&oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())},
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
@@ -314,6 +329,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Error if user is not a member of allowed_groups",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures()),
allowedGroups: []string{"dead-beef"},
},
claims: &azureClaims{
@@ -330,6 +346,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Error if user is not a member of allowed_organizations",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures()),
allowedOrganizations: []string{"uuid-1234"},
},
claims: &azureClaims{
@@ -372,7 +389,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
fields: fields{
allowedGroups: []string{"foo", "bar"},
SocialBase: newSocialBase("azuread",
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false, *featuremgmt.WithFeatures()),
&oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
@@ -394,7 +411,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch groups when ClaimsNames and ClaimsSources is set",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
ID: "1",
@@ -419,7 +436,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch groups when forceUseGraphAPI is set",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
forceUseGraphAPI: true,
},
claims: &azureClaims{
@@ -446,7 +463,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch empty role when strict attribute role is true and no match",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
@@ -462,7 +479,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
{
name: "Fetch empty role when strict attribute role is true and no role claims returned",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
},
claims: &azureClaims{
Email: "me@example.com",
@@ -506,7 +523,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
jwksDump, err := json.Marshal(jwks)
require.NoError(t, err)
err = cache.Set(context.Background(), azureCacheKeyPrefix+"1234", jwksDump, 0)
err = cache.Set(context.Background(), azureCacheKeyPrefix+"client-id-example", jwksDump, 0)
require.NoError(t, err)
for _, tt := range tests {
@@ -520,7 +537,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
}
if tt.fields.SocialBase == nil {
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
s.SocialBase = newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
}
if tt.fields.usGovURL {
@@ -530,14 +547,15 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
}
cl := jwt.Claims{
Audience: jwt.Audience{"client-id-example"},
Subject: "subject",
Issuer: "issuer",
NotBefore: jwt.NewNumericDate(time.Date(2016, 1, 1, 0, 0, 0, 0, time.UTC)),
Audience: jwt.Audience{"leela", "fry"},
}
var raw string
if tt.claims != nil {
tt.claims.Audience = "client-id-example"
if tt.claims.ClaimNames.Groups != "" {
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
tokenParts := strings.Split(request.Header.Get("Authorization"), " ")
@@ -609,7 +627,9 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
{
name: "Grafana Admin and Editor roles in claim, skipOrgRoleSync disabled should get roles, skipOrgRoleSyncBase disabled",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
SocialBase: newSocialBase("azuread",
&oauth2.Config{ClientID: "client-id-example"},
&OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
skipOrgRoleSync: false,
},
claims: &azureClaims{
@@ -632,7 +652,9 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
{
name: "Grafana Admin and Editor roles in claim, skipOrgRoleSync disabled should not get roles",
fields: fields{
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
SocialBase: newSocialBase("azuread",
&oauth2.Config{ClientID: "client-id-example"},
&OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
skipOrgRoleSync: false,
},
claims: &azureClaims{
@@ -681,7 +703,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
jwksDump, err := json.Marshal(jwks)
require.NoError(t, err)
err = cache.Set(context.Background(), azureCacheKeyPrefix+"1234", jwksDump, 0)
err = cache.Set(context.Background(), azureCacheKeyPrefix+"client-id-example", jwksDump, 0)
require.NoError(t, err)
for _, tt := range tests {
@@ -695,7 +717,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
}
if tt.fields.SocialBase == nil {
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
s.SocialBase = newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
}
s.SocialBase.Endpoint.AuthURL = authURL
@@ -709,6 +731,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
var raw string
if tt.claims != nil {
tt.claims.Audience = "client-id-example"
if tt.claims.ClaimNames.Groups != "" {
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
tokenParts := strings.Split(request.Header.Get("Authorization"), " ")

View File

@@ -193,9 +193,9 @@ func ProvideService(cfg *setting.Cfg,
ss.socialMap["azuread"] = &SocialAzureAD{
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
cache: cache,
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
forceUseGraphAPI: sec.Key("force_use_graph_api").MustBool(false),
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
skipOrgRoleSync: cfg.AzureADSkipOrgRoleSync,
}
}