mirror of
https://github.com/grafana/grafana.git
synced 2025-01-16 03:32:37 -06:00
AzureAD: Add option to force fetch the groups from the Graph API (#56916)
* Add a new option to systematically fetch AzureAD groups from the Graph API
This commit is contained in:
parent
e4f2006cce
commit
0f4d126109
@ -544,6 +544,7 @@ allowed_domains =
|
||||
allowed_groups =
|
||||
role_attribute_strict = false
|
||||
allow_assign_grafana_admin = false
|
||||
force_use_graph_api = false
|
||||
|
||||
#################################### Okta OAuth #######################
|
||||
[auth.okta]
|
||||
|
@ -225,3 +225,11 @@ Grafana attempts to retrieve the user's group membership by calling the included
|
||||
|
||||
> Note: The token must include the `GroupMember.Read.All` permission for group overage claim calls to succeed.
|
||||
> Admin consent may be required for this permission.
|
||||
|
||||
### Force fetching groups from Microsoft graph API
|
||||
|
||||
To force fetching groups from Microsoft Graph API instead of the `id_token`. You can use the `force_use_graph_api` config option.
|
||||
|
||||
```
|
||||
force_use_graph_api = true
|
||||
```
|
||||
|
@ -17,7 +17,8 @@ import (
|
||||
|
||||
type SocialAzureAD struct {
|
||||
*SocialBase
|
||||
allowedGroups []string
|
||||
allowedGroups []string
|
||||
forceUseGraphAPI bool
|
||||
}
|
||||
|
||||
type azureClaims struct {
|
||||
@ -76,7 +77,7 @@ func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*Bas
|
||||
|
||||
logger.Debug("AzureAD OAuth: extracted role", "email", email, "role", role)
|
||||
|
||||
groups, err := extractGroups(client, claims, token)
|
||||
groups, err := s.extractGroups(client, claims, token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extract groups: %w", err)
|
||||
}
|
||||
@ -166,39 +167,26 @@ type getAzureGroupResponse struct {
|
||||
Value []string `json:"value"`
|
||||
}
|
||||
|
||||
func extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
|
||||
if len(claims.Groups) > 0 {
|
||||
return claims.Groups, nil
|
||||
}
|
||||
|
||||
if claims.ClaimNames.Groups == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// If user groups exceeds 200 no groups will be found in claims.
|
||||
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
|
||||
endpoint := claims.ClaimSources[claims.ClaimNames.Groups].Endpoint
|
||||
|
||||
// If the endpoints provided in _claim_source is pointing to the deprecated "graph.windows.net" api
|
||||
// replace with handcrafted url to graph.microsoft.com
|
||||
// See https://docs.microsoft.com/en-us/graph/migrate-azure-ad-graph-overview
|
||||
if strings.Contains(endpoint, "graph.windows.net") {
|
||||
tenantID := claims.TenantID
|
||||
// If tenantID wasn't found in the id_token, parse access token
|
||||
if tenantID == "" {
|
||||
parsedToken, err := jwt.ParseSigned(token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
|
||||
var accessClaims azureAccessClaims
|
||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&accessClaims); err != nil {
|
||||
return nil, fmt.Errorf("error getting claims from access token: %w", err)
|
||||
}
|
||||
tenantID = accessClaims.TenantID
|
||||
// extractGroups retrieves groups from the claims.
|
||||
// Note: If user groups exceeds 200 no groups will be found in claims and URL to target the Graph API will be
|
||||
// given instead.
|
||||
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
|
||||
func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
|
||||
if !s.forceUseGraphAPI {
|
||||
logger.Debug("checking the claim for groups")
|
||||
if len(claims.Groups) > 0 {
|
||||
return claims.Groups, nil
|
||||
}
|
||||
|
||||
endpoint = fmt.Sprintf("https://graph.microsoft.com/v1.0/%s/users/%s/getMemberObjects", tenantID, claims.ID)
|
||||
if claims.ClaimNames.Groups == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to the Graph API
|
||||
endpoint, errBuildGraphURI := groupsGraphAPIURL(claims, token)
|
||||
if errBuildGraphURI != nil {
|
||||
return nil, errBuildGraphURI
|
||||
}
|
||||
|
||||
data, err := json.Marshal(&getAzureGroupRequest{SecurityEnabledOnly: false})
|
||||
@ -234,3 +222,38 @@ func extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token)
|
||||
|
||||
return body.Value, nil
|
||||
}
|
||||
|
||||
// groupsGraphAPIURL retrieves the Microsoft Graph API URL to fetch user groups from the _claim_sources if present
|
||||
// otherwise it generates an handcrafted URL.
|
||||
func groupsGraphAPIURL(claims azureClaims, token *oauth2.Token) (string, error) {
|
||||
var endpoint string
|
||||
// First check if an endpoint was specified in the claims
|
||||
if claims.ClaimNames.Groups != "" {
|
||||
endpoint = claims.ClaimSources[claims.ClaimNames.Groups].Endpoint
|
||||
logger.Debug(fmt.Sprintf("endpoint to fetch groups specified in the claims: %s", endpoint))
|
||||
}
|
||||
|
||||
// If no endpoint was specified or if the endpoints provided in _claim_source is pointing to the deprecated
|
||||
// "graph.windows.net" api, use an handcrafted url to graph.microsoft.com
|
||||
// See https://docs.microsoft.com/en-us/graph/migrate-azure-ad-graph-overview
|
||||
if endpoint == "" || strings.Contains(endpoint, "graph.windows.net") {
|
||||
tenantID := claims.TenantID
|
||||
// If tenantID wasn't found in the id_token, parse access token
|
||||
if tenantID == "" {
|
||||
parsedToken, err := jwt.ParseSigned(token.AccessToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
|
||||
var accessClaims azureAccessClaims
|
||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&accessClaims); err != nil {
|
||||
return "", fmt.Errorf("error getting claims from access token: %w", err)
|
||||
}
|
||||
tenantID = accessClaims.TenantID
|
||||
}
|
||||
|
||||
endpoint = fmt.Sprintf("https://graph.microsoft.com/v1.0/%s/users/%s/getMemberObjects", tenantID, claims.ID)
|
||||
logger.Debug(fmt.Sprintf("handcrafted endpoint to fetch groups: %s", endpoint))
|
||||
}
|
||||
return endpoint, nil
|
||||
}
|
||||
|
@ -27,8 +27,9 @@ func falseBoolPtr() *bool {
|
||||
|
||||
func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
type fields struct {
|
||||
SocialBase *SocialBase
|
||||
allowedGroups []string
|
||||
SocialBase *SocialBase
|
||||
allowedGroups []string
|
||||
forceUseGraphAPI bool
|
||||
}
|
||||
type args struct {
|
||||
client *http.Client
|
||||
@ -345,6 +346,33 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Fetch groups when forceUseGraphAPI is set",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, ""),
|
||||
forceUseGraphAPI: true,
|
||||
},
|
||||
claims: &azureClaims{
|
||||
ID: "1",
|
||||
Name: "test",
|
||||
PreferredUsername: "test",
|
||||
Email: "test@test.com",
|
||||
Roles: []string{"Viewer"},
|
||||
ClaimNames: claimNames{Groups: "src1"},
|
||||
ClaimSources: nil, // set by the test
|
||||
Groups: []string{"foo", "bar"}, // must be ignored
|
||||
},
|
||||
settingAutoAssignOrgRole: "",
|
||||
want: &BasicUserInfo{
|
||||
Id: "1",
|
||||
Name: "test",
|
||||
Email: "test@test.com",
|
||||
Login: "test@test.com",
|
||||
Role: "Viewer",
|
||||
Groups: []string{"from_server"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Fetch empty role when strict attribute role is true and no match",
|
||||
fields: fields{
|
||||
@ -382,8 +410,9 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &SocialAzureAD{
|
||||
SocialBase: tt.fields.SocialBase,
|
||||
allowedGroups: tt.fields.allowedGroups,
|
||||
SocialBase: tt.fields.SocialBase,
|
||||
allowedGroups: tt.fields.allowedGroups,
|
||||
forceUseGraphAPI: tt.fields.forceUseGraphAPI,
|
||||
}
|
||||
|
||||
if tt.fields.SocialBase == nil {
|
||||
|
@ -167,8 +167,9 @@ func ProvideService(cfg *setting.Cfg) *SocialService {
|
||||
// AzureAD.
|
||||
if name == "azuread" {
|
||||
ss.socialMap["azuread"] = &SocialAzureAD{
|
||||
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole),
|
||||
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
|
||||
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole),
|
||||
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
|
||||
forceUseGraphAPI: sec.Key("force_use_graph_api").MustBool(false),
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user