mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Add auth.azure_ad security improvements (#912)
* security improvements id_token * add audience validation * add allowOrganizations * add allowOrganizations tests and documentation * add log warn on no configuration * anonymize tenant id * Apply suggestions from code review Co-authored-by: Misi <mgyongyosi@users.noreply.github.com> * Update docs/sources/setup-grafana/configure-security/configure-authentication/azuread/index.md Co-authored-by: Ieva <ieva.vasiljeva@grafana.com> * Update pkg/login/social/azuread_oauth_test.go Co-authored-by: Ieva <ieva.vasiljeva@grafana.com> * Update pkg/login/social/azuread_oauth_test.go Co-authored-by: Ieva <ieva.vasiljeva@grafana.com> * optimize key validation and add mising fields * fix missing key_id * lint * Update docs/sources/setup-grafana/configure-security/configure-authentication/azuread/index.md Co-authored-by: Misi <mgyongyosi@users.noreply.github.com> * lint docs --------- Co-authored-by: Misi <mgyongyosi@users.noreply.github.com> Co-authored-by: Ieva <ieva.vasiljeva@grafana.com>
This commit is contained in:
parent
87b127e073
commit
4821175d40
conf
docs/sources/setup-grafana/configure-security/configure-authentication/azuread
pkg
api
infra/remotecache
login/social
@ -680,6 +680,7 @@ auth_url = https://login.microsoftonline.com/<tenant-id>/oauth2/v2.0/authorize
|
||||
token_url = https://login.microsoftonline.com/<tenant-id>/oauth2/v2.0/token
|
||||
allowed_domains =
|
||||
allowed_groups =
|
||||
allowed_organizations =
|
||||
role_attribute_strict = false
|
||||
allow_assign_grafana_admin = false
|
||||
force_use_graph_api = false
|
||||
|
@ -652,6 +652,7 @@
|
||||
;token_url = https://login.microsoftonline.com/<tenant-id>/oauth2/v2.0/token
|
||||
;allowed_domains =
|
||||
;allowed_groups =
|
||||
;allowed_organizations =
|
||||
;role_attribute_strict = false
|
||||
;allow_assign_grafana_admin = false
|
||||
;use_pkce = true
|
||||
|
@ -149,6 +149,7 @@ auth_url = https://login.microsoftonline.com/TENANT_ID/oauth2/v2.0/authorize
|
||||
token_url = https://login.microsoftonline.com/TENANT_ID/oauth2/v2.0/token
|
||||
allowed_domains =
|
||||
allowed_groups =
|
||||
allowed_organizations = TENANT_ID
|
||||
role_attribute_strict = false
|
||||
allow_assign_grafana_admin = false
|
||||
skip_org_role_sync = false
|
||||
@ -176,6 +177,19 @@ Grafana uses a refresh token to obtain a new access token without requiring the
|
||||
|
||||
To enable a refresh token for AzureAD, extend the `scopes` in `[auth.azuread]` with `offline_access`.
|
||||
|
||||
### Configure allowed tenants
|
||||
|
||||
To limit access to authenticated users who are members of one or more tenants, set `allowed_organizations`
|
||||
to a comma- or space-separated list of tenant IDs. You can find tenant IDs on the Azure portal under **Azure Active Directory -> Overview**.
|
||||
|
||||
Make sure to include the tenant IDs of all the federated Users' root directory if your Azure AD contains external identities.
|
||||
|
||||
For example, if you want to only give access to members of the tenant `example` with an ID of `8bab1c86-8fba-33e5-2089-1d1c80ec267d`, then set the following:
|
||||
|
||||
```
|
||||
allowed_organizations = 8bab1c86-8fba-33e5-2089-1d1c80ec267d
|
||||
```
|
||||
|
||||
### Configure allowed groups
|
||||
|
||||
To limit access to authenticated users who are members of one or more groups, set `allowed_groups`
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/plugins"
|
||||
@ -72,7 +73,7 @@ func setupTestEnvironment(t *testing.T, cfg *setting.Cfg, features *featuremgmt.
|
||||
PluginsCDNURLTemplate: cfg.PluginsCDNURLTemplate,
|
||||
PluginSettings: cfg.PluginSettings,
|
||||
}),
|
||||
SocialService: social.ProvideService(cfg, features, &usagestats.UsageStatsMock{}, supportbundlestest.NewFakeBundleService()),
|
||||
SocialService: social.ProvideService(cfg, features, &usagestats.UsageStatsMock{}, supportbundlestest.NewFakeBundleService(), remotecache.NewFakeCacheStorage()),
|
||||
}
|
||||
|
||||
m := web.New()
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/models/roletype"
|
||||
@ -34,7 +35,7 @@ func setupSocialHTTPServerWithConfig(t *testing.T, cfg *setting.Cfg) *HTTPServer
|
||||
Cfg: cfg,
|
||||
License: &licensing.OSSLicensingService{Cfg: cfg},
|
||||
SQLStore: sqlStore,
|
||||
SocialService: social.ProvideService(cfg, features, &usagestats.UsageStatsMock{}, supportbundlestest.NewFakeBundleService()),
|
||||
SocialService: social.ProvideService(cfg, features, &usagestats.UsageStatsMock{}, supportbundlestest.NewFakeBundleService(), remotecache.NewFakeCacheStorage()),
|
||||
HooksService: hooks.ProvideService(),
|
||||
SecretsService: fakes.NewFakeSecretsService(),
|
||||
Features: features,
|
||||
|
@ -171,39 +171,6 @@ func TestEncryptedCache(t *testing.T) {
|
||||
require.Equal(t, "bar", string(v))
|
||||
}
|
||||
|
||||
type fakeCacheStorage struct {
|
||||
storage map[string][]byte
|
||||
}
|
||||
|
||||
func (fcs fakeCacheStorage) Set(_ context.Context, key string, value []byte, exp time.Duration) error {
|
||||
fcs.storage[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fcs fakeCacheStorage) Get(_ context.Context, key string) ([]byte, error) {
|
||||
value, exist := fcs.storage[key]
|
||||
if !exist {
|
||||
return nil, ErrCacheItemNotFound
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (fcs fakeCacheStorage) Delete(_ context.Context, key string) error {
|
||||
delete(fcs.storage, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fcs fakeCacheStorage) Count(_ context.Context, prefix string) (int64, error) {
|
||||
return int64(len(fcs.storage)), nil
|
||||
}
|
||||
|
||||
func NewFakeCacheStorage() CacheStorage {
|
||||
return fakeCacheStorage{
|
||||
storage: map[string][]byte{},
|
||||
}
|
||||
}
|
||||
|
||||
type fakeSecretsService struct{}
|
||||
|
||||
func (f fakeSecretsService) Encrypt(_ context.Context, payload []byte, _ secrets.EncryptionOptions) ([]byte, error) {
|
||||
|
39
pkg/infra/remotecache/test_utils.go
Normal file
39
pkg/infra/remotecache/test_utils.go
Normal file
@ -0,0 +1,39 @@
|
||||
package remotecache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FakeCacheStorage struct {
|
||||
Storage map[string][]byte
|
||||
}
|
||||
|
||||
func (fcs FakeCacheStorage) Set(_ context.Context, key string, value []byte, exp time.Duration) error {
|
||||
fcs.Storage[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fcs FakeCacheStorage) Get(_ context.Context, key string) ([]byte, error) {
|
||||
value, exist := fcs.Storage[key]
|
||||
if !exist {
|
||||
return nil, ErrCacheItemNotFound
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (fcs FakeCacheStorage) Delete(_ context.Context, key string) error {
|
||||
delete(fcs.Storage, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fcs FakeCacheStorage) Count(_ context.Context, prefix string) (int64, error) {
|
||||
return int64(len(fcs.Storage)), nil
|
||||
}
|
||||
|
||||
func NewFakeCacheStorage() CacheStorage {
|
||||
return FakeCacheStorage{
|
||||
Storage: map[string][]byte{},
|
||||
}
|
||||
}
|
@ -7,23 +7,38 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/models/roletype"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
)
|
||||
|
||||
const (
|
||||
azureCacheKeyPrefix = "azuread_oauth_jwks-"
|
||||
defaultCacheExpiration = 24 * time.Hour
|
||||
tenantRegex = `^https:\/\/login\.microsoftonline\.com\/(?P<tenant>[a-zA-Z0-9\-]+)\/oauth2\/v2\.0\/authorize$`
|
||||
)
|
||||
|
||||
type SocialAzureAD struct {
|
||||
*SocialBase
|
||||
allowedGroups []string
|
||||
forceUseGraphAPI bool
|
||||
skipOrgRoleSync bool
|
||||
cache remotecache.CacheStorage
|
||||
allowedOrganizations []string
|
||||
allowedGroups []string
|
||||
forceUseGraphAPI bool
|
||||
skipOrgRoleSync bool
|
||||
compiledTenantRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
type azureClaims struct {
|
||||
Audience string `json:"aud"`
|
||||
Email string `json:"email"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Roles []string `json:"roles"`
|
||||
@ -48,6 +63,10 @@ type azureAccessClaims struct {
|
||||
TenantID string `json:"tid"`
|
||||
}
|
||||
|
||||
type keySetJWKS struct {
|
||||
jose.JSONWebKeySet
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
idToken := token.Extra("id_token")
|
||||
if idToken == nil {
|
||||
@ -60,14 +79,46 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token
|
||||
}
|
||||
|
||||
var claims azureClaims
|
||||
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||
return nil, fmt.Errorf("error getting claims from id token: %w", err)
|
||||
|
||||
keyset, err := s.retrieveJWKS(client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving jwks: %w", err)
|
||||
}
|
||||
|
||||
var errClaims error
|
||||
keyID := parsedToken.Headers[0].KeyID
|
||||
keys := keyset.Key(keyID)
|
||||
if len(keys) == 0 {
|
||||
s.log.Warn("AzureAD OAuth: signing key not found",
|
||||
"kid", keyID,
|
||||
"keys", fmt.Sprintf("%v", keyset.Keys))
|
||||
return nil, &Error{"AzureAD OAuth: signing key not found"}
|
||||
}
|
||||
for _, key := range keys {
|
||||
s.log.Debug("AzureAD OAuth: trying to parse token with key", "kid", key.KeyID)
|
||||
if errClaims = parsedToken.Claims(key, &claims); errClaims == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if errClaims != nil {
|
||||
return nil, fmt.Errorf("error getting claims from id token: %w", errClaims)
|
||||
}
|
||||
|
||||
if claims.OAuthVersion == "1.0" {
|
||||
return nil, &Error{"AzureAD OAuth: version 1.0 is not supported. Please ensure the auth_url and token_url are set to the v2.0 endpoints."}
|
||||
}
|
||||
|
||||
s.log.Debug("Validating audience", "audience", claims.Audience, "client_id", s.ClientID)
|
||||
if claims.Audience != s.ClientID {
|
||||
return nil, &Error{"AzureAD OAuth: audience mismatch"}
|
||||
}
|
||||
|
||||
s.log.Debug("Validating tenant", "tenant", claims.TenantID, "allowed_tenants", s.allowedOrganizations)
|
||||
if !s.isAllowedTenant(claims.TenantID) {
|
||||
return nil, &Error{"AzureAD OAuth: tenant mismatch"}
|
||||
}
|
||||
|
||||
email := claims.extractEmail()
|
||||
if email == "" {
|
||||
return nil, ErrEmailNotFound
|
||||
@ -82,13 +133,13 @@ func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token
|
||||
if s.roleAttributeStrict && !role.IsValid() {
|
||||
return nil, &InvalidBasicRoleError{idP: "Azure", assignedRole: string(role)}
|
||||
}
|
||||
logger.Debug("AzureAD OAuth: extracted role", "email", email, "role", role)
|
||||
s.log.Debug("AzureAD OAuth: extracted role", "email", email, "role", role)
|
||||
|
||||
groups, err := s.extractGroups(ctx, client, claims, token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extract groups: %w", err)
|
||||
}
|
||||
logger.Debug("AzureAD OAuth: extracted groups", "email", email, "groups", fmt.Sprintf("%v", groups))
|
||||
s.log.Debug("AzureAD OAuth: extracted groups", "email", email, "groups", fmt.Sprintf("%v", groups))
|
||||
if !s.IsGroupMember(groups) {
|
||||
return nil, errMissingGroupMembership
|
||||
}
|
||||
@ -179,7 +230,7 @@ type getAzureGroupResponse struct {
|
||||
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
|
||||
func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
|
||||
if !s.forceUseGraphAPI {
|
||||
logger.Debug("checking the claim for groups")
|
||||
s.log.Debug("checking the claim for groups")
|
||||
if len(claims.Groups) > 0 {
|
||||
return claims.Groups, nil
|
||||
}
|
||||
@ -190,7 +241,7 @@ func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client,
|
||||
}
|
||||
|
||||
// Fallback to the Graph API
|
||||
endpoint, errBuildGraphURI := groupsGraphAPIURL(claims, token)
|
||||
endpoint, errBuildGraphURI := s.groupsGraphAPIURL(claims, token)
|
||||
if errBuildGraphURI != nil {
|
||||
return nil, errBuildGraphURI
|
||||
}
|
||||
@ -213,16 +264,16 @@ func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client,
|
||||
|
||||
defer func() {
|
||||
if err := res.Body.Close(); err != nil {
|
||||
logger.Warn("AzureAD OAuth: failed to close response body", "err", err)
|
||||
s.log.Warn("AzureAD OAuth: failed to close response body", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
if res.StatusCode == http.StatusForbidden {
|
||||
logger.Warn("AzureAD OAuh: Token need GroupMember.Read.All permission to fetch all groups")
|
||||
s.log.Warn("AzureAD OAuh: Token need GroupMember.Read.All permission to fetch all groups")
|
||||
} else {
|
||||
body, _ := io.ReadAll(res.Body)
|
||||
logger.Warn("AzureAD OAuh: could not fetch user groups", "code", res.StatusCode, "body", string(body))
|
||||
s.log.Warn("AzureAD OAuh: could not fetch user groups", "code", res.StatusCode, "body", string(body))
|
||||
}
|
||||
return []string{}, nil
|
||||
}
|
||||
@ -237,12 +288,12 @@ func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client,
|
||||
|
||||
// groupsGraphAPIURL retrieves the Microsoft Graph API URL to fetch user groups from the _claim_sources if present
|
||||
// otherwise it generates an handcrafted URL.
|
||||
func groupsGraphAPIURL(claims azureClaims, token *oauth2.Token) (string, error) {
|
||||
func (s *SocialAzureAD) groupsGraphAPIURL(claims azureClaims, token *oauth2.Token) (string, error) {
|
||||
var endpoint string
|
||||
// First check if an endpoint was specified in the claims
|
||||
if claims.ClaimNames.Groups != "" {
|
||||
endpoint = claims.ClaimSources[claims.ClaimNames.Groups].Endpoint
|
||||
logger.Debug(fmt.Sprintf("endpoint to fetch groups specified in the claims: %s", endpoint))
|
||||
s.log.Debug(fmt.Sprintf("endpoint to fetch groups specified in the claims: %s", endpoint))
|
||||
}
|
||||
|
||||
// If no endpoint was specified or if the endpoints provided in _claim_source is pointing to the deprecated
|
||||
@ -265,7 +316,7 @@ func groupsGraphAPIURL(claims azureClaims, token *oauth2.Token) (string, error)
|
||||
}
|
||||
|
||||
endpoint = fmt.Sprintf("https://graph.microsoft.com/v1.0/%s/users/%s/getMemberObjects", tenantID, claims.ID)
|
||||
logger.Debug(fmt.Sprintf("handcrafted endpoint to fetch groups: %s", endpoint))
|
||||
s.log.Debug(fmt.Sprintf("handcrafted endpoint to fetch groups: %s", endpoint))
|
||||
}
|
||||
return endpoint, nil
|
||||
}
|
||||
@ -279,3 +330,102 @@ func (s *SocialAzureAD) SupportBundleContent(bf *bytes.Buffer) error {
|
||||
|
||||
return s.SocialBase.SupportBundleContent(bf)
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) extractTenantID(authURL string) (string, error) {
|
||||
if s.compiledTenantRegex == nil {
|
||||
compiledTenantRegex, err := regexp.Compile(`https://login.microsoftonline.com/([^/]+)/oauth2`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
s.compiledTenantRegex = compiledTenantRegex
|
||||
}
|
||||
|
||||
matches := s.compiledTenantRegex.FindStringSubmatch(authURL)
|
||||
if len(matches) < 2 {
|
||||
return "", fmt.Errorf("unable to extract tenant ID from URL")
|
||||
}
|
||||
return matches[1], nil
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) retrieveJWKS(client *http.Client) (*jose.JSONWebKeySet, error) {
|
||||
var jwks keySetJWKS
|
||||
// https://login.microsoftonline.com/organizations/oauth2/v2.0/authorize retrieve organizations
|
||||
// https://login.microsoftonline.com/xxx/oauth2/v2.0/authorize retrieve specific tenant xxx
|
||||
|
||||
tenant_id, err := s.extractTenantID(s.Endpoint.AuthURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// example: azuread_oauth_jwks-33121321nUd
|
||||
cacheKey := azureCacheKeyPrefix + tenant_id
|
||||
|
||||
// TODO: propagate context
|
||||
if val, err := s.cache.Get(context.Background(), cacheKey); err == nil {
|
||||
err := json.Unmarshal(val, &jwks)
|
||||
return &jwks.JSONWebKeySet, err
|
||||
} else {
|
||||
s.log.Debug("Keyset not found in cache", "err", err)
|
||||
}
|
||||
|
||||
// TODO: allow setting well-known endpoint and retrieve from there
|
||||
keysetURL := strings.Replace(s.Endpoint.AuthURL, "/oauth2/v2.0/authorize", "/discovery/v2.0/keys", 1)
|
||||
|
||||
resp, err := s.httpGet(client, keysetURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bytesReader := bytes.NewReader(resp.Body)
|
||||
var jsonBuf bytes.Buffer
|
||||
if err := json.NewDecoder(io.TeeReader(bytesReader, &jsonBuf)).Decode(&jwks); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cacheExpiration := getCacheExpiration(resp.Headers.Get("cache-control"))
|
||||
s.log.Debug("Setting key set in cache", "url", keysetURL, "cache-key", cacheKey,
|
||||
"cacheExpiration", cacheExpiration)
|
||||
|
||||
if err := s.cache.Set(context.Background(), cacheKey, jsonBuf.Bytes(), cacheExpiration); err != nil {
|
||||
s.log.Warn("Failed to set key set in cache", "url", keysetURL, "cache-key", cacheKey, "err", err)
|
||||
}
|
||||
|
||||
return &jwks.JSONWebKeySet, nil
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) isAllowedTenant(tenantID string) bool {
|
||||
if len(s.allowedOrganizations) == 0 {
|
||||
s.log.Warn("No allowed organizations specified, all tenants are allowed. Configure allowed_organizations to restrict access")
|
||||
return true
|
||||
}
|
||||
|
||||
for _, t := range s.allowedOrganizations {
|
||||
if t == tenantID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getCacheExpiration(header string) time.Duration {
|
||||
if header == "" {
|
||||
return defaultCacheExpiration
|
||||
}
|
||||
|
||||
// Cache-Control: public, max-age=14400
|
||||
cacheControl := strings.Split(header, ",")
|
||||
for _, v := range cacheControl {
|
||||
if strings.Contains(v, "max-age") {
|
||||
parts := strings.Split(v, "=")
|
||||
if len(parts) == 2 {
|
||||
seconds, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return defaultCacheExpiration
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return defaultCacheExpiration
|
||||
}
|
||||
|
@ -2,6 +2,8 @@ package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -14,6 +16,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
)
|
||||
|
||||
@ -29,9 +32,10 @@ func falseBoolPtr() *bool {
|
||||
|
||||
func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
type fields struct {
|
||||
SocialBase *SocialBase
|
||||
allowedGroups []string
|
||||
forceUseGraphAPI bool
|
||||
SocialBase *SocialBase
|
||||
allowedGroups []string
|
||||
allowedOrganizations []string
|
||||
forceUseGraphAPI bool
|
||||
}
|
||||
type args struct {
|
||||
client *http.Client
|
||||
@ -244,7 +248,8 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
name: "Editor roles in claim and GrafanaAdminAssignment enabled",
|
||||
fields: fields{
|
||||
SocialBase: newSocialBase("azuread",
|
||||
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures())},
|
||||
&oauth2.Config{}, &OAuthInfo{AllowAssignGrafanaAdmin: true}, "", false, *featuremgmt.WithFeatures()),
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
PreferredUsername: "",
|
||||
@ -300,7 +305,47 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Error if user is a member of allowed_groups",
|
||||
name: "Error if user is not a member of allowed_organizations",
|
||||
fields: fields{
|
||||
allowedOrganizations: []string{"uuid-1234"},
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
TenantID: "uuid-5678",
|
||||
PreferredUsername: "",
|
||||
Roles: []string{},
|
||||
Groups: []string{"foo", "bar"},
|
||||
Name: "My Name",
|
||||
ID: "1234",
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
}, {
|
||||
name: "No error if user is a member of allowed_organizations",
|
||||
fields: fields{
|
||||
allowedOrganizations: []string{"uuid-1234", "uuid-5678"},
|
||||
},
|
||||
claims: &azureClaims{
|
||||
Email: "me@example.com",
|
||||
TenantID: "uuid-5678",
|
||||
PreferredUsername: "",
|
||||
Roles: []string{},
|
||||
Groups: []string{"foo", "bar"},
|
||||
Name: "My Name",
|
||||
ID: "1234",
|
||||
},
|
||||
want: &BasicUserInfo{
|
||||
Id: "1234",
|
||||
Name: "My Name",
|
||||
Email: "me@example.com",
|
||||
Login: "me@example.com",
|
||||
Role: "",
|
||||
Groups: []string{"foo", "bar"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "No Error if user is a member of allowed_groups",
|
||||
fields: fields{
|
||||
allowedGroups: []string{"foo", "bar"},
|
||||
SocialBase: newSocialBase("azuread",
|
||||
@ -409,23 +454,51 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instantiate a signer using RSASSA-PSS (SHA256) with the given private key.
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.PS256, Key: privateKey}, (&jose.SignerOptions{
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{"kid": "1"},
|
||||
}).WithType("JWT"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// generate JWKS
|
||||
jwks := &jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{
|
||||
Key: privateKey.Public(),
|
||||
KeyID: "1",
|
||||
Algorithm: "PS256",
|
||||
Use: "sig",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
authURL := "https://login.microsoftonline.com/1234/oauth2/v2.0/authorize"
|
||||
cache := remotecache.NewFakeCacheStorage()
|
||||
// put JWKS in cache
|
||||
jwksDump, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Set(context.Background(), azureCacheKeyPrefix+"1234", jwksDump, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &SocialAzureAD{
|
||||
SocialBase: tt.fields.SocialBase,
|
||||
allowedGroups: tt.fields.allowedGroups,
|
||||
forceUseGraphAPI: tt.fields.forceUseGraphAPI,
|
||||
SocialBase: tt.fields.SocialBase,
|
||||
allowedGroups: tt.fields.allowedGroups,
|
||||
allowedOrganizations: tt.fields.allowedOrganizations,
|
||||
forceUseGraphAPI: tt.fields.forceUseGraphAPI,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
if tt.fields.SocialBase == nil {
|
||||
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
|
||||
}
|
||||
|
||||
key := []byte("secret")
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s.SocialBase.Endpoint.AuthURL = authURL
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "subject",
|
||||
@ -552,6 +625,36 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Instantiate a signer using RSASSA-PSS (SHA256) with the given private key.
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.PS256, Key: privateKey}, (&jose.SignerOptions{
|
||||
ExtraHeaders: map[jose.HeaderKey]interface{}{"kid": "1"},
|
||||
}).WithType("JWT"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// generate JWKS
|
||||
jwks := &jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{
|
||||
Key: privateKey.Public(),
|
||||
KeyID: "1",
|
||||
Algorithm: string(jose.PS256),
|
||||
Use: "sig",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
authURL := "https://login.microsoftonline.com/1234/oauth2/v2.0/authorize"
|
||||
cache := remotecache.NewFakeCacheStorage()
|
||||
// put JWKS in cache
|
||||
jwksDump, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = cache.Set(context.Background(), azureCacheKeyPrefix+"1234", jwksDump, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &SocialAzureAD{
|
||||
@ -559,17 +662,14 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
|
||||
allowedGroups: tt.fields.allowedGroups,
|
||||
forceUseGraphAPI: tt.fields.forceUseGraphAPI,
|
||||
skipOrgRoleSync: tt.fields.skipOrgRoleSync,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
if tt.fields.SocialBase == nil {
|
||||
s.SocialBase = newSocialBase("azuread", &oauth2.Config{}, &OAuthInfo{}, "", false, *featuremgmt.WithFeatures())
|
||||
}
|
||||
|
||||
key := []byte("secret")
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s.SocialBase.Endpoint.AuthURL = authURL
|
||||
|
||||
cl := jwt.Claims{
|
||||
Subject: "subject",
|
||||
|
@ -21,7 +21,9 @@ import (
|
||||
"golang.org/x/text/cases"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/grafana/grafana/pkg/cmd/grafana-cli/logger"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
@ -30,10 +32,6 @@ import (
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
|
||||
var (
|
||||
logger = log.New("social")
|
||||
)
|
||||
|
||||
type SocialService struct {
|
||||
cfg *setting.Cfg
|
||||
|
||||
@ -74,6 +72,7 @@ func ProvideService(cfg *setting.Cfg,
|
||||
features *featuremgmt.FeatureManager,
|
||||
usageStats usagestats.Service,
|
||||
bundleRegistry supportbundles.Service,
|
||||
cache remotecache.CacheStorage,
|
||||
) *SocialService {
|
||||
ss := &SocialService{
|
||||
cfg: cfg,
|
||||
@ -188,10 +187,12 @@ func ProvideService(cfg *setting.Cfg,
|
||||
// AzureAD.
|
||||
if name == "azuread" {
|
||||
ss.socialMap["azuread"] = &SocialAzureAD{
|
||||
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
|
||||
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
|
||||
forceUseGraphAPI: sec.Key("force_use_graph_api").MustBool(false),
|
||||
skipOrgRoleSync: cfg.AzureADSkipOrgRoleSync,
|
||||
SocialBase: newSocialBase(name, &config, info, cfg.AutoAssignOrgRole, cfg.OAuthSkipOrgRoleUpdateSync, *features),
|
||||
cache: cache,
|
||||
allowedGroups: util.SplitString(sec.Key("allowed_groups").String()),
|
||||
forceUseGraphAPI: sec.Key("force_use_graph_api").MustBool(false),
|
||||
allowedOrganizations: util.SplitString(sec.Key("allowed_organizations").String()),
|
||||
skipOrgRoleSync: cfg.AzureADSkipOrgRoleSync,
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user