mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Add support for custom signing keys in auth.azure_ad (#71365)
* fallthrough JWKS validation and caching for Azure * remove unused field
This commit is contained in:
117
pkg/login/social/azuread_jwks.go
Normal file
117
pkg/login/social/azuread_jwks.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
azureCacheKeyPrefix = "azuread_oauth_jwks-"
|
||||
defaultCacheExpiration = 5 * time.Minute
|
||||
)
|
||||
|
||||
func (s *SocialAzureAD) getJWKSCacheKey() (string, error) {
|
||||
return azureCacheKeyPrefix + s.ClientID, nil
|
||||
}
|
||||
func (s *SocialAzureAD) retrieveJWKSFromCache(ctx context.Context, client *http.Client, authURL string) (*keySetJWKS, time.Duration, error) {
|
||||
cacheKey, err := s.getJWKSCacheKey()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if val, err := s.cache.Get(ctx, cacheKey); err == nil {
|
||||
var jwks keySetJWKS
|
||||
err := json.Unmarshal(val, &jwks)
|
||||
s.log.Debug("Retrieved cached key set", "cacheKey", cacheKey)
|
||||
return &jwks, 0, err
|
||||
}
|
||||
s.log.Debug("Keyset not found in cache", "err", err)
|
||||
|
||||
return &keySetJWKS{}, 0, nil
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) cacheJWKS(ctx context.Context, jwks *keySetJWKS, cacheExpiration time.Duration) error {
|
||||
cacheKey, err := s.getJWKSCacheKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var jsonBuf bytes.Buffer
|
||||
if err := json.NewEncoder(&jsonBuf).Encode(jwks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.cache.Set(ctx, cacheKey, jsonBuf.Bytes(), cacheExpiration); err != nil {
|
||||
s.log.Warn("Failed to cache key set", "err", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) retrieveGeneralJWKS(ctx context.Context, client *http.Client, authURL string) (*keySetJWKS, time.Duration, error) {
|
||||
keysetURL := strings.Replace(authURL, "/oauth2/v2.0/authorize", "/discovery/v2.0/keys", 1)
|
||||
|
||||
resp, err := s.httpGet(context.Background(), client, keysetURL)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
bytesReader := bytes.NewReader(resp.Body)
|
||||
var jwks keySetJWKS
|
||||
if err := json.NewDecoder(bytesReader).Decode(&jwks); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
cacheExpiration := getCacheExpiration(resp.Headers.Get("cache-control"))
|
||||
s.log.Debug("Retrieved general key set", "url", keysetURL, "cacheExpiration", cacheExpiration)
|
||||
|
||||
return &jwks, cacheExpiration, nil
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) retrieveSpecificJWKS(ctx context.Context, client *http.Client, authURL string) (*keySetJWKS, time.Duration, error) {
|
||||
keysetURL := strings.Replace(authURL, "/oauth2/v2.0/authorize", "/discovery/v2.0/keys", 1) + "?appid=" + s.ClientID
|
||||
|
||||
resp, err := s.httpGet(ctx, client, keysetURL)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
bytesReader := bytes.NewReader(resp.Body)
|
||||
var jwks keySetJWKS
|
||||
if err := json.NewDecoder(bytesReader).Decode(&jwks); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
cacheExpiration := getCacheExpiration(resp.Headers.Get("cache-control"))
|
||||
s.log.Debug("Retrieved specific key set", "url", keysetURL, "cacheExpiration", cacheExpiration)
|
||||
|
||||
return &jwks, cacheExpiration, nil
|
||||
}
|
||||
|
||||
func getCacheExpiration(header string) time.Duration {
|
||||
if header == "" {
|
||||
return defaultCacheExpiration
|
||||
}
|
||||
|
||||
// Cache-Control: public, max-age=14400
|
||||
cacheControl := strings.Split(header, ",")
|
||||
for _, v := range cacheControl {
|
||||
if strings.Contains(v, "max-age") {
|
||||
parts := strings.Split(v, "=")
|
||||
if len(parts) == 2 {
|
||||
seconds, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return defaultCacheExpiration
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return defaultCacheExpiration
|
||||
}
|
||||
@@ -7,8 +7,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -21,12 +19,6 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
)
|
||||
|
||||
const (
|
||||
azureCacheKeyPrefix = "azuread_oauth_jwks-"
|
||||
defaultCacheExpiration = 24 * time.Hour
|
||||
tenantRegex = `^https:\/\/login\.microsoftonline\.com\/(?P<tenant>[a-zA-Z0-9\-]+)\/oauth2\/v2\.0\/authorize$`
|
||||
)
|
||||
|
||||
type SocialAzureAD struct {
|
||||
*SocialBase
|
||||
cache remotecache.CacheStorage
|
||||
@@ -34,7 +26,6 @@ type SocialAzureAD struct {
|
||||
allowedGroups []string
|
||||
forceUseGraphAPI bool
|
||||
skipOrgRoleSync bool
|
||||
compiledTenantRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
type azureClaims struct {
|
||||
@@ -78,45 +69,9 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token
|
||||
return nil, fmt.Errorf("error parsing id token: %w", err)
|
||||
}
|
||||
|
||||
var claims azureClaims
|
||||
|
||||
keyset, err := s.retrieveJWKS(client)
|
||||
claims, err := s.validateClaims(ctx, client, parsedToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving jwks: %w", err)
|
||||
}
|
||||
|
||||
var errClaims error
|
||||
keyID := parsedToken.Headers[0].KeyID
|
||||
keys := keyset.Key(keyID)
|
||||
if len(keys) == 0 {
|
||||
s.log.Warn("AzureAD OAuth: signing key not found",
|
||||
"kid", keyID,
|
||||
"keys", fmt.Sprintf("%v", keyset.Keys))
|
||||
return nil, &Error{"AzureAD OAuth: signing key not found"}
|
||||
}
|
||||
for _, key := range keys {
|
||||
s.log.Debug("AzureAD OAuth: trying to parse token with key", "kid", key.KeyID)
|
||||
if errClaims = parsedToken.Claims(key, &claims); errClaims == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if errClaims != nil {
|
||||
return nil, fmt.Errorf("error getting claims from id token: %w", errClaims)
|
||||
}
|
||||
|
||||
if claims.OAuthVersion == "1.0" {
|
||||
return nil, &Error{"AzureAD OAuth: version 1.0 is not supported. Please ensure the auth_url and token_url are set to the v2.0 endpoints."}
|
||||
}
|
||||
|
||||
s.log.Debug("Validating audience", "audience", claims.Audience, "client_id", s.ClientID)
|
||||
if claims.Audience != s.ClientID {
|
||||
return nil, &Error{"AzureAD OAuth: audience mismatch"}
|
||||
}
|
||||
|
||||
s.log.Debug("Validating tenant", "tenant", claims.TenantID, "allowed_tenants", s.allowedOrganizations)
|
||||
if !s.isAllowedTenant(claims.TenantID) {
|
||||
return nil, &Error{"AzureAD OAuth: tenant mismatch"}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
email := claims.extractEmail()
|
||||
@@ -128,7 +83,7 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token
|
||||
var role roletype.RoleType
|
||||
var grafanaAdmin bool
|
||||
if !s.skipOrgRoleSync {
|
||||
role, grafanaAdmin = s.extractRoleAndAdmin(&claims)
|
||||
role, grafanaAdmin = s.extractRoleAndAdmin(claims)
|
||||
}
|
||||
if s.roleAttributeStrict && !role.IsValid() {
|
||||
return nil, &InvalidBasicRoleError{idP: "Azure", assignedRole: string(role)}
|
||||
@@ -160,6 +115,65 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) validateClaims(ctx context.Context, client *http.Client, parsedToken *jwt.JSONWebToken) (*azureClaims, error) {
|
||||
claims, err := s.validateIDTokenSignature(ctx, client, parsedToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting claims from id token: %w", err)
|
||||
}
|
||||
|
||||
if claims.OAuthVersion == "1.0" {
|
||||
return nil, &Error{"AzureAD OAuth: version 1.0 is not supported. Please ensure the auth_url and token_url are set to the v2.0 endpoints."}
|
||||
}
|
||||
|
||||
s.log.Debug("Validating audience", "audience", claims.Audience, "client_id", s.ClientID)
|
||||
if claims.Audience != s.ClientID {
|
||||
return nil, &Error{"AzureAD OAuth: audience mismatch"}
|
||||
}
|
||||
|
||||
s.log.Debug("Validating tenant", "tenant", claims.TenantID, "allowed_tenants", s.allowedOrganizations)
|
||||
if !s.isAllowedTenant(claims.TenantID) {
|
||||
return nil, &Error{"AzureAD OAuth: tenant mismatch"}
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) validateIDTokenSignature(ctx context.Context, client *http.Client, parsedToken *jwt.JSONWebToken) (*azureClaims, error) {
|
||||
var claims azureClaims
|
||||
|
||||
jwksFuncs := []func(ctx context.Context, client *http.Client, authURL string) (*keySetJWKS, time.Duration, error){
|
||||
s.retrieveJWKSFromCache, s.retrieveSpecificJWKS, s.retrieveGeneralJWKS,
|
||||
}
|
||||
|
||||
keyID := parsedToken.Headers[0].KeyID
|
||||
|
||||
for _, jwksFunc := range jwksFuncs {
|
||||
keyset, expiry, err := jwksFunc(ctx, client, s.Endpoint.AuthURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving jwks: %w", err)
|
||||
}
|
||||
var errClaims error
|
||||
keys := keyset.Key(keyID)
|
||||
for _, key := range keys {
|
||||
s.log.Debug("AzureAD OAuth: trying to parse token with key", "kid", key.KeyID)
|
||||
if errClaims = parsedToken.Claims(key, &claims); errClaims == nil {
|
||||
if expiry != 0 {
|
||||
s.log.Debug("AzureAD OAuth: caching key set", "kid", key.KeyID, "expiry", expiry)
|
||||
if err := s.cacheJWKS(ctx, keyset, expiry); err != nil {
|
||||
s.log.Warn("Failed to set key set in cache", "err", err)
|
||||
}
|
||||
}
|
||||
return &claims, nil
|
||||
} else {
|
||||
s.log.Warn("AzureAD OAuth: failed to parse token with key", "kid", key.KeyID, "err", errClaims)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Warn("AzureAD OAuth: signing key not found", "kid", keyID)
|
||||
|
||||
return nil, &Error{"AzureAD OAuth: signing key not found"}
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) IsGroupMember(groups []string) bool {
|
||||
if len(s.allowedGroups) == 0 {
|
||||
return true
|
||||
@@ -229,7 +243,7 @@ type getAzureGroupResponse struct {
|
||||
// Note: If user groups exceeds 200 no groups will be found in claims and URL to target the Graph API will be
|
||||
// given instead.
|
||||
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
|
||||
func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
|
||||
func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client, claims *azureClaims, token *oauth2.Token) ([]string, error) {
|
||||
if !s.forceUseGraphAPI {
|
||||
s.log.Debug("checking the claim for groups")
|
||||
if len(claims.Groups) > 0 {
|
||||
@@ -289,7 +303,7 @@ func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client,
|
||||
|
||||
// groupsGraphAPIURL retrieves the Microsoft Graph API URL to fetch user groups from the _claim_sources if present
|
||||
// otherwise it generates an handcrafted URL.
|
||||
func (s *SocialAzureAD) groupsGraphAPIURL(claims azureClaims, token *oauth2.Token) (string, error) {
|
||||
func (s *SocialAzureAD) groupsGraphAPIURL(claims *azureClaims, token *oauth2.Token) (string, error) {
|
||||
var endpoint string
|
||||
// First check if an endpoint was specified in the claims
|
||||
if claims.ClaimNames.Groups != "" {
|
||||
@@ -332,68 +346,6 @@ func (s *SocialAzureAD) SupportBundleContent(bf *bytes.Buffer) error {
|
||||
return s.SocialBase.SupportBundleContent(bf)
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) extractTenantID(authURL string) (string, error) {
|
||||
if s.compiledTenantRegex == nil {
|
||||
compiledTenantRegex, err := regexp.Compile(`https://login.microsoftonline.(com|us)/([^/]+)/oauth2`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
s.compiledTenantRegex = compiledTenantRegex
|
||||
}
|
||||
|
||||
matches := s.compiledTenantRegex.FindStringSubmatch(authURL)
|
||||
if len(matches) < 3 {
|
||||
return "", fmt.Errorf("unable to extract tenant ID from URL")
|
||||
}
|
||||
return matches[2], nil
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) retrieveJWKS(client *http.Client) (*jose.JSONWebKeySet, error) {
|
||||
var jwks keySetJWKS
|
||||
// https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize retrieve organizations
|
||||
// https://login.microsoftonline.com/xxx/oauth2/v2.0/authorize retrieve specific tenant xxx
|
||||
|
||||
tenant_id, err := s.extractTenantID(s.Endpoint.AuthURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// example: azuread_oauth_jwks-33121321nUd
|
||||
cacheKey := azureCacheKeyPrefix + tenant_id
|
||||
|
||||
// TODO: propagate context
|
||||
if val, err := s.cache.Get(context.Background(), cacheKey); err == nil {
|
||||
err := json.Unmarshal(val, &jwks)
|
||||
return &jwks.JSONWebKeySet, err
|
||||
} else {
|
||||
s.log.Debug("Keyset not found in cache", "err", err)
|
||||
}
|
||||
|
||||
// TODO: allow setting well-known endpoint and retrieve from there
|
||||
keysetURL := strings.Replace(s.Endpoint.AuthURL, "/oauth2/v2.0/authorize", "/discovery/v2.0/keys", 1)
|
||||
|
||||
resp, err := s.httpGet(context.Background(), client, keysetURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bytesReader := bytes.NewReader(resp.Body)
|
||||
var jsonBuf bytes.Buffer
|
||||
if err := json.NewDecoder(io.TeeReader(bytesReader, &jsonBuf)).Decode(&jwks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cacheExpiration := getCacheExpiration(resp.Headers.Get("cache-control"))
|
||||
s.log.Debug("Setting key set in cache", "url", keysetURL, "cache-key", cacheKey,
|
||||
"cacheExpiration", cacheExpiration)
|
||||
|
||||
if err := s.cache.Set(context.Background(), cacheKey, jsonBuf.Bytes(), cacheExpiration); err != nil {
|
||||
s.log.Warn("Failed to set key set in cache", "url", keysetURL, "cache-key", cacheKey, "err", err)
|
||||
}
|
||||
|
||||
return &jwks.JSONWebKeySet, nil
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) isAllowedTenant(tenantID string) bool {
|
||||
if len(s.allowedOrganizations) == 0 {
|
||||
s.log.Warn("No allowed organizations specified, all tenants are allowed. Configure allowed_organizations to restrict access")
|
||||
@@ -407,26 +359,3 @@ func (s *SocialAzureAD) isAllowedTenant(tenantID string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getCacheExpiration(header string) time.Duration {
|
||||
if header == "" {
|
||||
return defaultCacheExpiration
|
||||
}
|
||||
|
||||
// Cache-Control: public, max-age=14400
|
||||
cacheControl := strings.Split(header, ",")
|
||||
for _, v := range cacheControl {
|
||||
if strings.Contains(v, "max-age") {
|
||||
parts := strings.Split(v, "=")
|
||||
if len(parts) == 2 {
|
||||
seconds, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return defaultCacheExpiration
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return defaultCacheExpiration
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
ID: "1234",
|
||||
},
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
want: &BasicUserInfo{
|
||||
Id: "1234",
|
||||
@@ -74,6 +74,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "No email",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "",
|
||||
PreferredUsername: "",
|
||||
@@ -100,7 +103,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
ID: "1234",
|
||||
},
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
usGovURL: true,
|
||||
},
|
||||
want: &BasicUserInfo{
|
||||
@@ -122,7 +125,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
ID: "1234",
|
||||
},
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
want: &BasicUserInfo{
|
||||
Id: "1234",
|
||||
@@ -135,6 +138,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Admin role",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
PreferredUsername: "",
|
||||
@@ -153,6 +159,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Lowercase Admin role",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
PreferredUsername: "",
|
||||
@@ -172,7 +181,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
{
|
||||
name: "Only other roles",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
@@ -200,7 +209,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
ID: "1234",
|
||||
},
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()),
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
want: &BasicUserInfo{
|
||||
Id: "1234",
|
||||
@@ -220,6 +229,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
Name: "My Name",
|
||||
ID: "1234",
|
||||
},
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
want: &BasicUserInfo{
|
||||
Id: "1234",
|
||||
Name: "My Name",
|
||||
@@ -231,6 +243,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Admin and Editor roles in claim",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "Editor", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
PreferredUsername: "",
|
||||
@@ -249,7 +264,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Grafana Admin but setting is disabled",
|
||||
fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures())},
|
||||
fields: fields{SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures())},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
PreferredUsername: "",
|
||||
@@ -271,7 +286,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
name: "Editor roles in claim and GrafanaAdminAssignment enabled",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread",
|
||||
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
|
||||
&oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
@@ -293,7 +308,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
{
|
||||
name: "Grafana Admin and Editor roles in claim",
|
||||
fields: fields{SocialBase: newSocialBase("azuread",
|
||||
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())},
|
||||
&oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
PreferredUsername: "",
|
||||
@@ -314,6 +329,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
{
|
||||
name: "Error if user is not a member of allowed_groups",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures()),
|
||||
allowedGroups: []string{"dead-beef"},
|
||||
},
|
||||
claims: &azureClaims{
|
||||
@@ -330,6 +346,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
{
|
||||
name: "Error if user is not a member of allowed_organizations",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Editor", false, *featuremgmt.WithFeatures()),
|
||||
allowedOrganizations: []string{"uuid-1234"},
|
||||
},
|
||||
claims: &azureClaims{
|
||||
@@ -372,7 +389,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
fields: fields{
|
||||
allowedGroups: []string{"foo", "bar"},
|
||||
SocialBase: newSocialBase("azuread",
|
||||
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
&oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{AllowAssignGrafanaAdmin: false}, "Viewer", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
@@ -394,7 +411,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
{
|
||||
name: "Fetch groups when ClaimsNames and ClaimsSources is set",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
ID: "1",
|
||||
@@ -419,7 +436,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
{
|
||||
name: "Fetch groups when forceUseGraphAPI is set",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures()),
|
||||
forceUseGraphAPI: true,
|
||||
},
|
||||
claims: &azureClaims{
|
||||
@@ -446,7 +463,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
{
|
||||
name: "Fetch empty role when strict attribute role is true and no match",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
@@ -462,7 +479,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
{
|
||||
name: "Fetch empty role when strict attribute role is true and no role claims returned",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{RoleAttributeStrict: true}, "", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
@@ -506,7 +523,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
jwksDump, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Set(context.Background(), azureCacheKeyPrefix+"1234", jwksDump, 0)
|
||||
err = cache.Set(context.Background(), azureCacheKeyPrefix+"client-id-example", jwksDump, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -520,7 +537,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
}
|
||||
|
||||
if tt.fields.SocialBase == nil {
|
||||
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
|
||||
s.SocialBase = newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
|
||||
}
|
||||
|
||||
if tt.fields.usGovURL {
|
||||
@@ -530,14 +547,15 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
}
|
||||
|
||||
cl := jwt.Claims{
|
||||
Audience: jwt.Audience{"client-id-example"},
|
||||
Subject: "subject",
|
||||
Issuer: "issuer",
|
||||
NotBefore: jwt.NewNumericDate(time.Date(2016, 1, 1, 0, 0, 0, 0, time.UTC)),
|
||||
Audience: jwt.Audience{"leela", "fry"},
|
||||
}
|
||||
|
||||
var raw string
|
||||
if tt.claims != nil {
|
||||
tt.claims.Audience = "client-id-example"
|
||||
if tt.claims.ClaimNames.Groups != "" {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
tokenParts := strings.Split(request.Header.Get("Authorization"), " ")
|
||||
@@ -609,7 +627,9 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
|
||||
{
|
||||
name: "Grafana Admin and Editor roles in claim, skipOrgRoleSync disabled should get roles, skipOrgRoleSyncBase disabled",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
|
||||
SocialBase: newSocialBase("azuread",
|
||||
&oauth2.Config{ClientID: "client-id-example"},
|
||||
&OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
|
||||
skipOrgRoleSync: false,
|
||||
},
|
||||
claims: &azureClaims{
|
||||
@@ -632,7 +652,9 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
|
||||
{
|
||||
name: "Grafana Admin and Editor roles in claim, skipOrgRoleSync disabled should not get roles",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
|
||||
SocialBase: newSocialBase("azuread",
|
||||
&oauth2.Config{ClientID: "client-id-example"},
|
||||
&OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
|
||||
skipOrgRoleSync: false,
|
||||
},
|
||||
claims: &azureClaims{
|
||||
@@ -681,7 +703,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
|
||||
jwksDump, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Set(context.Background(), azureCacheKeyPrefix+"1234", jwksDump, 0)
|
||||
err = cache.Set(context.Background(), azureCacheKeyPrefix+"client-id-example", jwksDump, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -695,7 +717,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
|
||||
}
|
||||
|
||||
if tt.fields.SocialBase == nil {
|
||||
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
|
||||
s.SocialBase = newSocialBase("azuread", &oauth2.Config{ClientID: "client-id-example"}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
|
||||
}
|
||||
|
||||
s.SocialBase.Endpoint.AuthURL = authURL
|
||||
@@ -709,6 +731,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
|
||||
|
||||
var raw string
|
||||
if tt.claims != nil {
|
||||
tt.claims.Audience = "client-id-example"
|
||||
if tt.claims.ClaimNames.Groups != "" {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
tokenParts := strings.Split(request.Header.Get("Authorization"), " ")
|
||||
|
||||
@@ -193,9 +193,9 @@ func ProvideService(cfg *setting.Cfg,
|
||||
ss.socialMap["azuread"] = &SocialAzureAD{
|
||||
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
|
||||
cache: cache,
|
||||
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
|
||||
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
|
||||
forceUseGraphAPI: sec.Key("force_use_graph_api").MustBool(false),
|
||||
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
|
||||
skipOrgRoleSync: cfg.AzureADSkipOrgRoleSync,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user