mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
AzureAD OAuth: Add support for fetching user groups (#43470)
* Add functionallity to extractGroups to handle groups-overage claims Co-authored-by: Emil Tullstedt <emil.tullstedt@grafana.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -20,19 +22,33 @@ type SocialAzureAD struct {
|
||||
}
|
||||
|
||||
type azureClaims struct {
|
||||
Email string `json:"email"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Roles []string `json:"roles"`
|
||||
Groups []string `json:"groups"`
|
||||
Name string `json:"name"`
|
||||
ID string `json:"oid"`
|
||||
Email string `json:"email"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Roles []string `json:"roles"`
|
||||
Groups []string `json:"groups"`
|
||||
Name string `json:"name"`
|
||||
ID string `json:"oid"`
|
||||
ClaimNames claimNames `json:"_claim_names,omitempty"`
|
||||
ClaimSources map[string]claimSource `json:"_claim_sources,omitempty"`
|
||||
}
|
||||
|
||||
type claimNames struct {
|
||||
Groups string `json:"groups"`
|
||||
}
|
||||
|
||||
type claimSource struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
type azureAccessClaims struct {
|
||||
TenantID string `json:"tid"`
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) Type() int {
|
||||
return int(models.AZUREAD)
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) UserInfo(_ *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
idToken := token.Extra("id_token")
|
||||
if idToken == nil {
|
||||
return nil, fmt.Errorf("no id_token found")
|
||||
@@ -56,7 +72,11 @@ func (s *SocialAzureAD) UserInfo(_ *http.Client, token *oauth2.Token) (*BasicUse
|
||||
role := extractRole(claims, s.autoAssignOrgRole)
|
||||
logger.Debug("AzureAD OAuth: extracted role", "email", email, "role", role)
|
||||
|
||||
groups := extractGroups(claims)
|
||||
groups, err := extractGroups(client, claims, token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extract groups: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("AzureAD OAuth: extracted groups", "email", email, "groups", groups)
|
||||
if !s.IsGroupMember(groups) {
|
||||
return nil, errMissingGroupMembership
|
||||
@@ -127,8 +147,69 @@ func hasRole(roles []string, role models.RoleType) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func extractGroups(claims azureClaims) []string {
|
||||
groups := make([]string, 0)
|
||||
groups = append(groups, claims.Groups...)
|
||||
return groups
|
||||
type getAzureGroupRequest struct {
|
||||
SecurityEnabledOnly bool `json:"securityEnabledOnly"`
|
||||
}
|
||||
|
||||
type getAzureGroupResponse struct {
|
||||
Value []string `json:"value"`
|
||||
}
|
||||
|
||||
func extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
|
||||
if len(claims.Groups) > 0 {
|
||||
return claims.Groups, nil
|
||||
}
|
||||
|
||||
if claims.ClaimNames.Groups == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// If user groups exceeds 200 no groups will be found in claims.
|
||||
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
|
||||
endpoint := claims.ClaimSources[claims.ClaimNames.Groups].Endpoint
|
||||
if strings.Contains(endpoint, "graph.windows.net") {
|
||||
// If the endpoints provided in _claim_source is pointed to the deprecated "graph.windows.net" api
|
||||
// replace with handcrafted url to graph.microsoft.com
|
||||
// See https://docs.microsoft.com/en-us/graph/migrate-azure-ad-graph-overview
|
||||
parsedToken, err := jwt.ParseSigned(token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, errutil.Wrapf(err, "error parsing id token")
|
||||
}
|
||||
|
||||
var accessClaims azureAccessClaims
|
||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&accessClaims); err != nil {
|
||||
return nil, errutil.Wrapf(err, "error getting claims from access token")
|
||||
}
|
||||
endpoint = fmt.Sprintf("https://graph.microsoft.com/v1.0/%s/users/%s/getMemberObjects", accessClaims.TenantID, claims.ID)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(&getAzureGroupRequest{SecurityEnabledOnly: false})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := client.Post(endpoint, "application/json", bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := res.Body.Close(); err != nil {
|
||||
logger.Warn("AzureAD OAuth: failed to close response body", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
if res.StatusCode == http.StatusForbidden {
|
||||
logger.Error("AzureAD OAuth: failed to fetch user groups. Token need User.Read and GroupMember.Read.All permission")
|
||||
}
|
||||
return nil, errors.New("error fetching groups")
|
||||
}
|
||||
|
||||
var body getAzureGroupResponse
|
||||
if err := json.NewDecoder(res.Body).Decode(&body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return body.Value, nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
@@ -249,6 +254,31 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
Groups: []string{"foo"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Fetch groups when ClaimsNames and ClaimsSources is set",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
ID: "1",
|
||||
Name: "test",
|
||||
PreferredUsername: "test",
|
||||
Email: "test@test.com",
|
||||
Roles: []string{"Viewer"},
|
||||
ClaimNames: claimNames{Groups: "src1"},
|
||||
ClaimSources: nil, // set by the test
|
||||
},
|
||||
settingAutoAssignOrgRole: "",
|
||||
want: &BasicUserInfo{
|
||||
Id: "1",
|
||||
Name: "test",
|
||||
Email: "test@test.com",
|
||||
Login: "test@test.com",
|
||||
Role: "Viewer",
|
||||
Groups: []string{"from_server"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -273,6 +303,25 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
|
||||
var raw string
|
||||
if tt.claims != nil {
|
||||
if tt.claims.ClaimNames.Groups != "" {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
tokenParts := strings.Split(request.Header.Get("Authorization"), " ")
|
||||
require.Len(t, tokenParts, 2)
|
||||
require.Equal(t, "fake_token", tokenParts[1])
|
||||
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
type response struct {
|
||||
Value []string
|
||||
}
|
||||
res := response{Value: []string{"from_server"}}
|
||||
require.NoError(t, json.NewEncoder(writer).Encode(&res))
|
||||
}))
|
||||
// need to set the fake servers url as endpoint to capture request
|
||||
tt.claims.ClaimSources = map[string]claimSource{
|
||||
tt.claims.ClaimNames.Groups: {Endpoint: server.URL},
|
||||
}
|
||||
}
|
||||
raw, err = jwt.Signed(sig).Claims(cl).Claims(tt.claims).CompactSerialize()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -284,11 +333,17 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
token := &oauth2.Token{}
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "fake_token",
|
||||
}
|
||||
if tt.claims != nil {
|
||||
token = token.WithExtra(map[string]interface{}{"id_token": raw})
|
||||
}
|
||||
|
||||
if tt.fields.SocialBase != nil {
|
||||
tt.args.client = s.Client(context.Background(), token)
|
||||
}
|
||||
|
||||
got, err := s.UserInfo(tt.args.client, token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
Reference in New Issue
Block a user