AuthN: Add oauth clients and perform oauth authentication with authn.Service (#62072)

* AuthN: Update signature of redirect client and RedirectURL function

* OAuth: use authn.Service to perform oauth authentication and login if feature toggle is enabled

* AuthN: register oauth clients

* AuthN: set auth module metadata

* AuthN: add logs for failed login attempts

* AuthN: Don't use enable disabled setting

* OAuth: only run hooks when authnService feature toggle is disabled

* OAuth: Add function to handle oauth errors from authn.Service
This commit is contained in:
Karl Persson 2023-01-30 12:45:04 +01:00 committed by GitHub
parent e3bfc67d7b
commit efeb0daec6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 681 additions and 39 deletions

View File

@ -17,11 +17,14 @@ import (
"github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/middleware/cookies"
"github.com/grafana/grafana/pkg/services/authn"
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
"github.com/grafana/grafana/pkg/services/featuremgmt"
loginservice "github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util/errutil"
"github.com/grafana/grafana/pkg/web"
)
@ -70,11 +73,59 @@ func genPKCECode() (string, string, error) {
}
func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
loginInfo := loginservice.LoginInfo{
AuthModule: "oauth",
}
name := web.Params(ctx.Req)[":name"]
loginInfo.AuthModule = name
loginInfo := loginservice.LoginInfo{AuthModule: name}
if errorParam := ctx.Query("error"); errorParam != "" {
errorDesc := ctx.Query("error_description")
oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc)
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc)
return
}
code := ctx.Query("code")
if hs.Features.IsEnabled(featuremgmt.FlagAuthnService) {
req := &authn.Request{HTTPRequest: ctx.Req, Resp: ctx.Resp}
if code == "" {
redirect, err := hs.authnService.RedirectURL(ctx.Req.Context(), authn.ClientWithPrefix(name), req)
if err != nil {
hs.handleAuthnOAuthErr(ctx, "failed to generate oauth redirect url", err)
return
}
if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" {
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
}
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
ctx.Redirect(redirect.URL)
return
}
identity, err := hs.authnService.Login(ctx.Req.Context(), authn.ClientWithPrefix(name), req)
// NOTE: always delete these cookies, even if login failed
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
if err != nil {
hs.handleAuthnOAuthErr(ctx, "failed to perform login for oauth request", err)
return
}
metrics.MApiLoginOAuth.Inc()
cookies.WriteSessionCookie(ctx, hs.Cfg, identity.SessionToken.UnhashedToken, hs.Cfg.LoginMaxLifetime)
redirectURL := setting.AppSubUrl + "/"
if redirectTo := ctx.GetCookie("redirect_to"); len(redirectTo) > 0 && hs.ValidateRedirectTo(redirectTo) == nil {
redirectURL = redirectTo
cookies.DeleteCookie(ctx.Resp, "redirect_to", hs.CookieOptionsFromCfg)
}
ctx.Redirect(redirectURL)
return
}
provider := hs.SocialService.GetOAuthInfoProvider(name)
if provider == nil {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, errors.New("OAuth not enabled"))
@ -87,15 +138,6 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
return
}
errorParam := ctx.Query("error")
if errorParam != "" {
errorDesc := ctx.Query("error_description")
oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc)
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc)
return
}
code := ctx.Query("code")
if code == "" {
var opts []oauth2.AuthCodeOption
if provider.UsePKCE {
@ -106,6 +148,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
HttpStatus: http.StatusInternalServerError,
PublicMessage: "An internal error occurred",
})
return
}
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
@ -345,6 +388,19 @@ func (hs *HTTPServer) hashStatecode(code, seed string) string {
return hex.EncodeToString(hashBytes[:])
}
func (hs *HTTPServer) handleAuthnOAuthErr(c *contextmodel.ReqContext, msg string, err error) {
gfErr := &errutil.Error{}
if errors.As(err, gfErr) {
if gfErr.Public().Message != "" {
c.Handle(hs.Cfg, gfErr.Public().StatusCode, gfErr.Public().Message, err)
return
}
}
c.Logger.Warn(msg, "err", err)
c.Redirect(hs.Cfg.AppSubURL + "/login")
}
type LoginError struct {
HttpStatus int
PublicMessage string
@ -354,18 +410,24 @@ type LoginError struct {
func (hs *HTTPServer) handleOAuthLoginError(ctx *contextmodel.ReqContext, info loginservice.LoginInfo, err LoginError) {
ctx.Handle(hs.Cfg, err.HttpStatus, err.PublicMessage, err.Err)
info.Error = err.Err
if info.Error == nil {
info.Error = errors.New(err.PublicMessage)
}
info.HTTPStatus = err.HttpStatus
// login hooks is handled by authn.Service
if !hs.Features.IsEnabled(featuremgmt.FlagAuthnService) {
info.Error = err.Err
if info.Error == nil {
info.Error = errors.New(err.PublicMessage)
}
info.HTTPStatus = err.HttpStatus
hs.HooksService.RunLoginHook(&info, ctx)
hs.HooksService.RunLoginHook(&info, ctx)
}
}
func (hs *HTTPServer) handleOAuthLoginErrorWithRedirect(ctx *contextmodel.ReqContext, info loginservice.LoginInfo, err error, v ...interface{}) {
hs.redirectWithError(ctx, err, v...)
info.Error = err
hs.HooksService.RunLoginHook(&info, ctx)
// login hooks is handled by authn.Service
if !hs.Features.IsEnabled(featuremgmt.FlagAuthnService) {
info.Error = err
hs.HooksService.RunLoginHook(&info, ctx)
}
}

View File

@ -34,6 +34,7 @@ func setupSocialHTTPServerWithConfig(t *testing.T, cfg *setting.Cfg) *HTTPServer
SocialService: social.ProvideService(cfg, featuremgmt.WithFeatures()),
HooksService: hooks.ProvideService(),
SecretsService: fakes.NewFakeSecretsService(),
Features: featuremgmt.WithFeatures(),
}
}

View File

@ -33,7 +33,7 @@ const (
MetaKeyAuthModule = "authModule"
)
// ClientParams are hints to the auth serviAuthN: Post login hooksce about how to handle the identity management
// ClientParams are hints to the auth service about how to handle the identity management
// from the authenticating client.
type ClientParams struct {
// Update the internal representation of the entity from the identity provided
@ -42,7 +42,7 @@ type ClientParams struct {
SyncTeamMembers bool
// Create entity in the DB if it doesn't exist
AllowSignUp bool
// EnableDisabledUsers is a hint to the auth service that it should reenable disabled users
// EnableDisabledUsers is a hint to the auth service that it should re-enable disabled users
EnableDisabledUsers bool
// LookUpParams are the arguments used to look up the entity in the DB.
LookUpParams login.UserLookupParams
@ -63,7 +63,7 @@ type Service interface {
// A lower number means higher priority.
RegisterPostLoginHook(hook PostLoginHookFn, priority uint)
// RedirectURL will generate url that we can use to initiate auth flow for supported clients.
RedirectURL(ctx context.Context, client string, r *Request) (string, error)
RedirectURL(ctx context.Context, client string, r *Request) (*Redirect, error)
}
type Client interface {
@ -83,7 +83,7 @@ type ContextAwareClient interface {
type RedirectClient interface {
Client
RedirectURL(ctx context.Context, r *Request) (string, error)
RedirectURL(ctx context.Context, r *Request) (*Redirect, error)
}
type PasswordClient interface {
@ -122,6 +122,18 @@ func (r *Request) GetMeta(k string) string {
return r.metadata[k]
}
const (
KeyOAuthPKCE = "pkce"
KeyOAuthState = "state"
)
type Redirect struct {
// Url used for redirect
URL string
// Extra contains data used for redirect, e.g. for oauth this would be state and pkce
Extra map[string]string
}
const (
NamespaceUser = "user"
NamespaceAPIKey = "api-key"
@ -144,7 +156,7 @@ type Identity struct {
ID string
// IsAnonymous
IsAnonymous bool
// Login is the short hand identifier of the entity. Should be unique.
// Login is the shorthand identifier of the entity. Should be unique.
Login string
// Name is the display name of the entity. It is not guaranteed to be unique.
Name string
@ -283,3 +295,8 @@ func IdentityFromSignedInUser(id string, usr *user.SignedInUser, params ClientPa
ClientParams: params,
}
}
// ClientWithPrefix returns a client name prefixed with "auth.client."
func ClientWithPrefix(name string) string {
return fmt.Sprintf("auth.client.%s", name)
}

View File

@ -11,6 +11,7 @@ import (
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/network"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/apikey"
"github.com/grafana/grafana/pkg/services/auth"
@ -52,6 +53,7 @@ func ProvideService(
loginAttempts loginattempt.Service, quotaService quota.Service,
authInfoService login.AuthInfoService, renderService rendering.Service,
features *featuremgmt.FeatureManager, oauthTokenService oauthtoken.OAuthTokenService,
socialService social.Service,
) *Service {
s := &Service{
log: log.New("authn.service"),
@ -116,6 +118,21 @@ func ProvideService(
s.RegisterClient(clients.ProvideJWT(jwtService, cfg))
}
for name := range socialService.GetOAuthProviders() {
oauthCfg := socialService.GetOAuthInfoProvider(name)
if oauthCfg != nil && oauthCfg.Enabled {
clientName := authn.ClientWithPrefix(name)
connector, errConnector := socialService.GetConnector(name)
httpClient, errHTTPClient := socialService.GetOAuthHttpClient(name)
if errConnector != nil || errHTTPClient != nil {
s.log.Error("failed to configure oauth client", "client", clientName, "err", multierror.Append(errConnector, errHTTPClient))
}
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient))
}
}
// FIXME (jguer): move to User package
userSyncService := sync.ProvideUserSync(userService, userProtectionService, authInfoService, quotaService)
orgUserSyncService := sync.ProvideOrgSync(userService, orgService, accessControlService)
@ -233,6 +250,7 @@ func (s *Service) Login(ctx context.Context, client string, r *authn.Request) (i
sessionToken, err := s.sessionService.CreateToken(ctx, &user.User{ID: id}, ip, r.HTTPRequest.UserAgent())
if err != nil {
s.log.FromContext(ctx).Error("failed to create session", "client", client, "userId", id, "err", err)
return nil, err
}
@ -244,19 +262,19 @@ func (s *Service) RegisterPostLoginHook(hook authn.PostLoginHookFn, priority uin
s.postLoginHooks.insert(hook, priority)
}
func (s *Service) RedirectURL(ctx context.Context, client string, r *authn.Request) (string, error) {
func (s *Service) RedirectURL(ctx context.Context, client string, r *authn.Request) (*authn.Redirect, error) {
ctx, span := s.tracer.Start(ctx, "authn.RedirectURL")
defer span.End()
span.SetAttributes(attributeKeyClient, client, attribute.Key(attributeKeyClient).String(client))
c, ok := s.clients[client]
if !ok {
return "", authn.ErrClientNotConfigured.Errorf("client not configured: %s", client)
return nil, authn.ErrClientNotConfigured.Errorf("client not configured: %s", client)
}
redirectClient, ok := c.(authn.RedirectClient)
if !ok {
return "", authn.ErrUnsupportedClient.Errorf("client does not support generating redirect url: %s", client)
return nil, authn.ErrUnsupportedClient.Errorf("client does not support generating redirect url: %s", client)
}
return redirectClient.RedirectURL(ctx, r)

View File

@ -243,16 +243,13 @@ func TestService_RedirectURL(t *testing.T) {
type testCase struct {
desc string
client string
expectedURL string
expectedErr error
}
tests := []testCase{
{
desc: "should generate url for valid redirect client",
client: "redirect",
expectedURL: "https://localhost/redirect",
expectedErr: nil,
desc: "should generate url for valid redirect client",
client: "redirect",
},
{
desc: "should return error on non existing client",
@ -269,13 +266,12 @@ func TestService_RedirectURL(t *testing.T) {
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
service := setupTests(t, func(svc *Service) {
svc.RegisterClient(authntest.FakeRedirectClient{ExpectedName: "redirect", ExpectedURL: tt.expectedURL})
svc.RegisterClient(authntest.FakeRedirectClient{ExpectedName: "redirect"})
svc.RegisterClient(&authntest.FakeClient{ExpectedName: "non-redirect"})
})
u, err := service.RedirectURL(context.Background(), tt.client, nil)
_, err := service.RedirectURL(context.Background(), tt.client, nil)
assert.ErrorIs(t, err, tt.expectedErr)
assert.Equal(t, tt.expectedURL, u)
})
}
}

View File

@ -53,6 +53,8 @@ type FakeRedirectClient struct {
ExpectedErr error
ExpectedURL string
ExpectedName string
ExpectedOK bool
ExpectedRedirect *authn.Redirect
ExpectedIdentity *authn.Identity
}
@ -64,6 +66,10 @@ func (f FakeRedirectClient) Authenticate(ctx context.Context, r *authn.Request)
return f.ExpectedIdentity, f.ExpectedErr
}
func (f FakeRedirectClient) RedirectURL(ctx context.Context, r *authn.Request) (string, error) {
return f.ExpectedURL, f.ExpectedErr
func (f FakeRedirectClient) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) {
return f.ExpectedRedirect, f.ExpectedErr
}
func (f FakeRedirectClient) Test(ctx context.Context, r *authn.Request) bool {
return f.ExpectedOK
}

View File

@ -0,0 +1,237 @@
package clients
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"net/http"
"strings"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util/errutil"
)
const (
hostedDomainParamName = "hd"
codeVerifierParamName = "code_verifier"
codeChallengeParamName = "code_challenge"
codeChallengeMethodParamName = "code_challenge_method"
codeChallengeMethod = "S256"
oauthStateQueryName = "state"
oauthStateCookieName = "oauth_state"
oauthPKCECookieName = "oauth_code_verifier"
)
var (
errOAuthGenPKCE = errutil.NewBase(errutil.StatusInternal, "auth.oauth.pkce.internal", errutil.WithPublicMessage("An internal error occurred"))
errOAuthMissingPKCE = errutil.NewBase(errutil.StatusBadRequest, "auth.oauth.pkce.missing", errutil.WithPublicMessage("Missing required pkce cookie"))
errOAuthGenState = errutil.NewBase(errutil.StatusInternal, "auth.oauth.state.internal", errutil.WithPublicMessage("An internal error occurred"))
errOAuthMissingState = errutil.NewBase(errutil.StatusBadRequest, "auth.oauth.state.missing", errutil.WithPublicMessage("Missing saved oauth state"))
errOAuthInvalidState = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.state.invalid", errutil.WithPublicMessage("Provided state does not match stored state"))
errOAuthTokenExchange = errutil.NewBase(errutil.StatusInternal, "auth.oauth.token.exchange", errutil.WithPublicMessage("Failed to get token from provider"))
errOAuthMissingRequiredEmail = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.email.missing")
errOAuthEmailNotAllowed = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.email.not-allowed")
)
var _ authn.RedirectClient = new(OAuth)
func ProvideOAuth(
name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo,
connector social.SocialConnector, httpClient *http.Client,
) *OAuth {
return &OAuth{
name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")),
log.New(name), cfg, oauthCfg, connector, httpClient,
}
}
type OAuth struct {
name string
moduleName string
log log.Logger
cfg *setting.Cfg
oauthCfg *social.OAuthInfo
connector social.SocialConnector
httpClient *http.Client
}
func (c *OAuth) Name() string {
return c.name
}
func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
r.SetMeta(authn.MetaKeyAuthModule, c.moduleName)
// get hashed state stored in cookie
stateCookie, err := r.HTTPRequest.Cookie(oauthStateCookieName)
if err != nil {
return nil, errOAuthMissingState.Errorf("missing state cookie")
}
if stateCookie.Value == "" {
return nil, errOAuthMissingState.Errorf("missing state value in state cookie")
}
// get state returned by the idp and hash it
stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, c.oauthCfg.ClientSecret)
// compare the state returned by idp against the one we stored in cookie
if stateQuery != stateCookie.Value {
return nil, errOAuthInvalidState.Errorf("provided state did not match stored state")
}
var opts []oauth2.AuthCodeOption
// if pkce is enabled for client validate we have the cookie and set it as url param
if c.oauthCfg.UsePKCE {
pkceCookie, err := r.HTTPRequest.Cookie(oauthPKCECookieName)
if err != nil {
return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err)
}
opts = append(opts, oauth2.SetAuthURLParam(codeVerifierParamName, pkceCookie.Value))
}
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
// exchange auth code to a valid token
token, err := c.connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...)
if err != nil {
return nil, err
}
token.TokenType = "Bearer"
userInfo, err := c.connector.UserInfo(c.connector.Client(clientCtx, token), token)
if err != nil {
return nil, errOAuthTokenExchange.Errorf("failed to exchange code to token: %w", err)
}
if userInfo.Email == "" {
return nil, errOAuthMissingRequiredEmail.Errorf("required attribute email was not provided")
}
if !c.connector.IsEmailAllowed(userInfo.Email) {
return nil, errOAuthEmailNotAllowed.Errorf("provided email is not allowed")
}
return &authn.Identity{
Login: userInfo.Login,
Name: userInfo.Name,
Email: userInfo.Email,
IsGrafanaAdmin: userInfo.IsGrafanaAdmin,
AuthModule: c.moduleName,
AuthID: userInfo.Id,
Groups: userInfo.Groups,
OAuthToken: token,
OrgRoles: getOAuthOrgRole(userInfo, c.cfg),
ClientParams: authn.ClientParams{
SyncUser: true,
SyncTeamMembers: true,
AllowSignUp: c.connector.IsSignupAllowed(),
LookUpParams: login.UserLookupParams{Email: &userInfo.Email},
},
}, nil
}
func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) {
var opts []oauth2.AuthCodeOption
if c.oauthCfg.HostedDomain != "" {
opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, c.oauthCfg.HostedDomain))
}
var plainPKCE string
if c.oauthCfg.UsePKCE {
pkce, hashedPKCE, err := genPKCECode()
if err != nil {
return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err)
}
plainPKCE = pkce
opts = append(opts,
oauth2.SetAuthURLParam(codeChallengeParamName, hashedPKCE),
oauth2.SetAuthURLParam(codeChallengeMethodParamName, codeChallengeMethod),
)
}
state, hashedSate, err := genOAuthState(c.cfg.SecretKey, c.oauthCfg.ClientSecret)
if err != nil {
return nil, errOAuthGenState.Errorf("failed to generate state: %w", err)
}
return &authn.Redirect{
URL: c.connector.AuthCodeURL(state, opts...),
Extra: map[string]string{
authn.KeyOAuthState: hashedSate,
authn.KeyOAuthPKCE: plainPKCE,
},
}, nil
}
// genPKCECode returns a random URL-friendly string and it's base64 URL encoded SHA256 digest.
func genPKCECode() (string, string, error) {
// IETF RFC 7636 specifies that the code verifier should be 43-128
// characters from a set of unreserved URI characters which is
// almost the same as the set of characters in base64url.
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
//
// It doesn't hurt to generate a few more bytes here, we generate
// 96 bytes which we then encode using base64url to make sure
// they're within the set of unreserved characters.
//
// 96 is chosen because 96*8/6 = 128, which means that we'll have
// 128 characters after it has been base64 encoded.
raw := make([]byte, 96)
_, err := rand.Read(raw)
if err != nil {
return "", "", err
}
ascii := make([]byte, 128)
base64.RawURLEncoding.Encode(ascii, raw)
shasum := sha256.Sum256(ascii)
pkce := base64.RawURLEncoding.EncodeToString(shasum[:])
return string(ascii), pkce, nil
}
func genOAuthState(secret, seed string) (string, string, error) {
rnd := make([]byte, 32)
if _, err := rand.Read(rnd); err != nil {
return "", "", err
}
state := base64.URLEncoding.EncodeToString(rnd)
return state, hashOAuthState(state, secret, seed), nil
}
func hashOAuthState(state, secret, seed string) string {
hashBytes := sha256.Sum256([]byte(state + secret + seed))
return hex.EncodeToString(hashBytes[:])
}
func getOAuthOrgRole(userInfo *social.BasicUserInfo, cfg *setting.Cfg) map[int64]org.RoleType {
orgRoles := make(map[int64]org.RoleType, 0)
if cfg.OAuthSkipOrgRoleUpdateSync {
return orgRoles
}
if userInfo.Role == "" || !userInfo.Role.IsValid() {
return orgRoles
}
orgID := int64(1)
if cfg.AutoAssignOrg && cfg.AutoAssignOrgId > 0 {
orgID = int64(cfg.AutoAssignOrgId)
}
orgRoles[orgID] = userInfo.Role
return orgRoles
}

View File

@ -0,0 +1,305 @@
package clients
import (
"context"
"net/http"
"net/url"
"testing"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOAuth_Authenticate(t *testing.T) {
type testCase struct {
desc string
req *authn.Request
oauthCfg *social.OAuthInfo
addStateCookie bool
stateCookieValue string
addPKCECookie bool
pkceCookieValue string
isEmailAllowed bool
userInfo *social.BasicUserInfo
expectedErr error
expectedIdentity *authn.Identity
}
tests := []testCase{
{
desc: "should return error when missing state cookie",
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
oauthCfg: &social.OAuthInfo{},
expectedErr: errOAuthMissingState,
},
{
desc: "should return error when state cookie is present but don't have a value",
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
oauthCfg: &social.OAuthInfo{},
addStateCookie: true,
stateCookieValue: "",
expectedErr: errOAuthMissingState,
},
{
desc: "should return error when state from ipd does not match stored state",
req: &authn.Request{HTTPRequest: &http.Request{
Header: map[string][]string{},
URL: mustParseURL("http://grafana.com/?state=some-other-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
addStateCookie: true,
stateCookieValue: "some-state",
expectedErr: errOAuthInvalidState,
},
{
desc: "should return error when pkce is configured but the cookie is not present",
req: &authn.Request{HTTPRequest: &http.Request{
Header: map[string][]string{},
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
addStateCookie: true,
stateCookieValue: "some-state",
expectedErr: errOAuthMissingPKCE,
},
{
desc: "should return error when email is empty",
req: &authn.Request{HTTPRequest: &http.Request{
Header: map[string][]string{},
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
addStateCookie: true,
stateCookieValue: "some-state",
addPKCECookie: true,
pkceCookieValue: "some-pkce-value",
userInfo: &social.BasicUserInfo{},
expectedErr: errOAuthMissingRequiredEmail,
},
{
desc: "should return error when email is not allowed",
req: &authn.Request{HTTPRequest: &http.Request{
Header: map[string][]string{},
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
addStateCookie: true,
stateCookieValue: "some-state",
addPKCECookie: true,
pkceCookieValue: "some-pkce-value",
userInfo: &social.BasicUserInfo{Email: "some@email.com"},
isEmailAllowed: false,
expectedErr: errOAuthEmailNotAllowed,
},
{
desc: "should return identity for valid request",
req: &authn.Request{HTTPRequest: &http.Request{
Header: map[string][]string{},
URL: mustParseURL("http://grafana.com/?state=some-state"),
},
},
oauthCfg: &social.OAuthInfo{UsePKCE: true},
addStateCookie: true,
stateCookieValue: "some-state",
addPKCECookie: true,
pkceCookieValue: "some-pkce-value",
isEmailAllowed: true,
userInfo: &social.BasicUserInfo{
Id: "123",
Name: "name",
Email: "some@email.com",
Role: "Admin",
Groups: []string{"grp1", "grp2"},
},
expectedIdentity: &authn.Identity{
Email: "some@email.com",
AuthModule: "oauth_azuread",
AuthID: "123",
Name: "name",
Groups: []string{"grp1", "grp2"},
OAuthToken: &oauth2.Token{},
OrgRoles: map[int64]org.RoleType{1: org.RoleAdmin},
ClientParams: authn.ClientParams{
SyncUser: true,
SyncTeamMembers: true,
AllowSignUp: true,
LookUpParams: login.UserLookupParams{Email: strPtr("some@email.com")},
},
},
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
cfg := setting.NewCfg()
if tt.addStateCookie {
v := tt.stateCookieValue
if v != "" {
v = hashOAuthState(v, cfg.SecretKey, tt.oauthCfg.ClientSecret)
}
tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthStateCookieName, Value: v})
}
if tt.addPKCECookie {
tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue})
}
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, tt.oauthCfg, fakeConnector{
ExpectedUserInfo: tt.userInfo,
ExpectedToken: &oauth2.Token{},
ExpectedIsSignupAllowed: true,
ExpectedIsEmailAllowed: tt.isEmailAllowed,
}, nil)
identity, err := c.Authenticate(context.Background(), tt.req)
assert.ErrorIs(t, err, tt.expectedErr)
if tt.expectedIdentity != nil {
assert.Equal(t, tt.expectedIdentity.Login, identity.Login)
assert.Equal(t, tt.expectedIdentity.Name, identity.Name)
assert.Equal(t, tt.expectedIdentity.Email, identity.Email)
assert.Equal(t, tt.expectedIdentity.AuthID, identity.AuthID)
assert.Equal(t, tt.expectedIdentity.AuthModule, identity.AuthModule)
assert.Equal(t, tt.expectedIdentity.Groups, identity.Groups)
assert.Equal(t, tt.expectedIdentity.ClientParams.SyncUser, identity.ClientParams.SyncUser)
assert.Equal(t, tt.expectedIdentity.ClientParams.AllowSignUp, identity.ClientParams.AllowSignUp)
assert.Equal(t, tt.expectedIdentity.ClientParams.SyncTeamMembers, identity.ClientParams.SyncTeamMembers)
assert.Equal(t, tt.expectedIdentity.ClientParams.EnableDisabledUsers, identity.ClientParams.EnableDisabledUsers)
assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.Email, identity.ClientParams.LookUpParams.Email)
assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.Login, identity.ClientParams.LookUpParams.Login)
assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.UserID, identity.ClientParams.LookUpParams.UserID)
} else {
assert.Nil(t, tt.expectedIdentity)
}
})
}
}
func TestOAuth_RedirectURL(t *testing.T) {
type testCase struct {
desc string
oauthCfg *social.OAuthInfo
expectedErr error
numCallOptions int
authCodeUrlCalled bool
}
tests := []testCase{
{
desc: "should generate redirect url and state",
oauthCfg: &social.OAuthInfo{},
authCodeUrlCalled: true,
},
{
desc: "should generate redirect url with hosted domain option if configured",
oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com"},
numCallOptions: 1,
authCodeUrlCalled: true,
},
{
desc: "should generate redirect url with pkce if configured",
oauthCfg: &social.OAuthInfo{UsePKCE: true},
numCallOptions: 2,
authCodeUrlCalled: true,
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
var (
authCodeUrlCalled = false
)
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), tt.oauthCfg, mockConnector{
AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string {
authCodeUrlCalled = true
require.Len(t, opts, tt.numCallOptions)
return ""
},
}, nil)
redirect, err := c.RedirectURL(context.Background(), nil)
assert.ErrorIs(t, err, tt.expectedErr)
assert.Equal(t, tt.authCodeUrlCalled, authCodeUrlCalled)
if tt.expectedErr != nil {
return
}
assert.NotEmpty(t, redirect.Extra[authn.KeyOAuthState])
if tt.oauthCfg.UsePKCE {
assert.NotEmpty(t, redirect.Extra[authn.KeyOAuthPKCE])
}
})
}
}
type mockConnector struct {
AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string
social.SocialConnector
}
func (m mockConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
if m.AuthCodeURLFunc != nil {
return m.AuthCodeURLFunc(state, opts...)
}
return ""
}
var _ social.SocialConnector = new(fakeConnector)
type fakeConnector struct {
ExpectedUserInfo *social.BasicUserInfo
ExpectedUserInfoErr error
ExpectedIsEmailAllowed bool
ExpectedIsSignupAllowed bool
ExpectedToken *oauth2.Token
ExpectedTokenErr error
social.SocialConnector
}
func (f fakeConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
return f.ExpectedUserInfo, f.ExpectedUserInfoErr
}
func (f fakeConnector) IsEmailAllowed(email string) bool {
return f.ExpectedIsEmailAllowed
}
func (f fakeConnector) IsSignupAllowed() bool {
return f.ExpectedIsSignupAllowed
}
func (f fakeConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
return f.ExpectedToken, f.ExpectedTokenErr
}
func (f fakeConnector) Client(ctx context.Context, t *oauth2.Token) *http.Client {
return nil
}
func mustParseURL(s string) *url.URL {
u, err := url.Parse(s)
if err != nil {
panic(err)
}
return u
}