mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
AuthN: Add oauth clients and perform oauth authentication with authn.Service (#62072)
* AuthN: Update signature of redirect client and RedirectURL function * OAuth: use authn.Service to perform oauth authentication and login if feature toggle is enabled * AuthN: register oauth clients * AuthN: set auth module metadata * AuthN: add logs for failed login attempts * AuthN: Don't use enable disabled setting * OAuth: only run hooks when authnService feature toggle is disabled * OAuth: Add function to handle oauth errors from authn.Service
This commit is contained in:
parent
e3bfc67d7b
commit
efeb0daec6
@ -17,11 +17,14 @@ import (
|
||||
"github.com/grafana/grafana/pkg/login"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/middleware/cookies"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
loginservice "github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
)
|
||||
|
||||
@ -70,11 +73,59 @@ func genPKCECode() (string, string, error) {
|
||||
}
|
||||
|
||||
func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
||||
loginInfo := loginservice.LoginInfo{
|
||||
AuthModule: "oauth",
|
||||
}
|
||||
name := web.Params(ctx.Req)[":name"]
|
||||
loginInfo.AuthModule = name
|
||||
loginInfo := loginservice.LoginInfo{AuthModule: name}
|
||||
|
||||
if errorParam := ctx.Query("error"); errorParam != "" {
|
||||
errorDesc := ctx.Query("error_description")
|
||||
oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc)
|
||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc)
|
||||
return
|
||||
}
|
||||
|
||||
code := ctx.Query("code")
|
||||
|
||||
if hs.Features.IsEnabled(featuremgmt.FlagAuthnService) {
|
||||
req := &authn.Request{HTTPRequest: ctx.Req, Resp: ctx.Resp}
|
||||
if code == "" {
|
||||
redirect, err := hs.authnService.RedirectURL(ctx.Req.Context(), authn.ClientWithPrefix(name), req)
|
||||
if err != nil {
|
||||
hs.handleAuthnOAuthErr(ctx, "failed to generate oauth redirect url", err)
|
||||
return
|
||||
}
|
||||
|
||||
if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" {
|
||||
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
}
|
||||
|
||||
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
ctx.Redirect(redirect.URL)
|
||||
return
|
||||
}
|
||||
|
||||
identity, err := hs.authnService.Login(ctx.Req.Context(), authn.ClientWithPrefix(name), req)
|
||||
// NOTE: always delete these cookies, even if login failed
|
||||
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
||||
cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
||||
|
||||
if err != nil {
|
||||
hs.handleAuthnOAuthErr(ctx, "failed to perform login for oauth request", err)
|
||||
return
|
||||
}
|
||||
|
||||
metrics.MApiLoginOAuth.Inc()
|
||||
cookies.WriteSessionCookie(ctx, hs.Cfg, identity.SessionToken.UnhashedToken, hs.Cfg.LoginMaxLifetime)
|
||||
|
||||
redirectURL := setting.AppSubUrl + "/"
|
||||
if redirectTo := ctx.GetCookie("redirect_to"); len(redirectTo) > 0 && hs.ValidateRedirectTo(redirectTo) == nil {
|
||||
redirectURL = redirectTo
|
||||
cookies.DeleteCookie(ctx.Resp, "redirect_to", hs.CookieOptionsFromCfg)
|
||||
}
|
||||
|
||||
ctx.Redirect(redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
provider := hs.SocialService.GetOAuthInfoProvider(name)
|
||||
if provider == nil {
|
||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, errors.New("OAuth not enabled"))
|
||||
@ -87,15 +138,6 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
||||
return
|
||||
}
|
||||
|
||||
errorParam := ctx.Query("error")
|
||||
if errorParam != "" {
|
||||
errorDesc := ctx.Query("error_description")
|
||||
oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc)
|
||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc)
|
||||
return
|
||||
}
|
||||
|
||||
code := ctx.Query("code")
|
||||
if code == "" {
|
||||
var opts []oauth2.AuthCodeOption
|
||||
if provider.UsePKCE {
|
||||
@ -106,6 +148,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
||||
HttpStatus: http.StatusInternalServerError,
|
||||
PublicMessage: "An internal error occurred",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
@ -345,6 +388,19 @@ func (hs *HTTPServer) hashStatecode(code, seed string) string {
|
||||
return hex.EncodeToString(hashBytes[:])
|
||||
}
|
||||
|
||||
func (hs *HTTPServer) handleAuthnOAuthErr(c *contextmodel.ReqContext, msg string, err error) {
|
||||
gfErr := &errutil.Error{}
|
||||
if errors.As(err, gfErr) {
|
||||
if gfErr.Public().Message != "" {
|
||||
c.Handle(hs.Cfg, gfErr.Public().StatusCode, gfErr.Public().Message, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Logger.Warn(msg, "err", err)
|
||||
c.Redirect(hs.Cfg.AppSubURL + "/login")
|
||||
}
|
||||
|
||||
type LoginError struct {
|
||||
HttpStatus int
|
||||
PublicMessage string
|
||||
@ -354,18 +410,24 @@ type LoginError struct {
|
||||
func (hs *HTTPServer) handleOAuthLoginError(ctx *contextmodel.ReqContext, info loginservice.LoginInfo, err LoginError) {
|
||||
ctx.Handle(hs.Cfg, err.HttpStatus, err.PublicMessage, err.Err)
|
||||
|
||||
info.Error = err.Err
|
||||
if info.Error == nil {
|
||||
info.Error = errors.New(err.PublicMessage)
|
||||
}
|
||||
info.HTTPStatus = err.HttpStatus
|
||||
// login hooks is handled by authn.Service
|
||||
if !hs.Features.IsEnabled(featuremgmt.FlagAuthnService) {
|
||||
info.Error = err.Err
|
||||
if info.Error == nil {
|
||||
info.Error = errors.New(err.PublicMessage)
|
||||
}
|
||||
info.HTTPStatus = err.HttpStatus
|
||||
|
||||
hs.HooksService.RunLoginHook(&info, ctx)
|
||||
hs.HooksService.RunLoginHook(&info, ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (hs *HTTPServer) handleOAuthLoginErrorWithRedirect(ctx *contextmodel.ReqContext, info loginservice.LoginInfo, err error, v ...interface{}) {
|
||||
hs.redirectWithError(ctx, err, v...)
|
||||
|
||||
info.Error = err
|
||||
hs.HooksService.RunLoginHook(&info, ctx)
|
||||
// login hooks is handled by authn.Service
|
||||
if !hs.Features.IsEnabled(featuremgmt.FlagAuthnService) {
|
||||
info.Error = err
|
||||
hs.HooksService.RunLoginHook(&info, ctx)
|
||||
}
|
||||
}
|
||||
|
@ -34,6 +34,7 @@ func setupSocialHTTPServerWithConfig(t *testing.T, cfg *setting.Cfg) *HTTPServer
|
||||
SocialService: social.ProvideService(cfg, featuremgmt.WithFeatures()),
|
||||
HooksService: hooks.ProvideService(),
|
||||
SecretsService: fakes.NewFakeSecretsService(),
|
||||
Features: featuremgmt.WithFeatures(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -33,7 +33,7 @@ const (
|
||||
MetaKeyAuthModule = "authModule"
|
||||
)
|
||||
|
||||
// ClientParams are hints to the auth serviAuthN: Post login hooksce about how to handle the identity management
|
||||
// ClientParams are hints to the auth service about how to handle the identity management
|
||||
// from the authenticating client.
|
||||
type ClientParams struct {
|
||||
// Update the internal representation of the entity from the identity provided
|
||||
@ -42,7 +42,7 @@ type ClientParams struct {
|
||||
SyncTeamMembers bool
|
||||
// Create entity in the DB if it doesn't exist
|
||||
AllowSignUp bool
|
||||
// EnableDisabledUsers is a hint to the auth service that it should reenable disabled users
|
||||
// EnableDisabledUsers is a hint to the auth service that it should re-enable disabled users
|
||||
EnableDisabledUsers bool
|
||||
// LookUpParams are the arguments used to look up the entity in the DB.
|
||||
LookUpParams login.UserLookupParams
|
||||
@ -63,7 +63,7 @@ type Service interface {
|
||||
// A lower number means higher priority.
|
||||
RegisterPostLoginHook(hook PostLoginHookFn, priority uint)
|
||||
// RedirectURL will generate url that we can use to initiate auth flow for supported clients.
|
||||
RedirectURL(ctx context.Context, client string, r *Request) (string, error)
|
||||
RedirectURL(ctx context.Context, client string, r *Request) (*Redirect, error)
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
@ -83,7 +83,7 @@ type ContextAwareClient interface {
|
||||
|
||||
type RedirectClient interface {
|
||||
Client
|
||||
RedirectURL(ctx context.Context, r *Request) (string, error)
|
||||
RedirectURL(ctx context.Context, r *Request) (*Redirect, error)
|
||||
}
|
||||
|
||||
type PasswordClient interface {
|
||||
@ -122,6 +122,18 @@ func (r *Request) GetMeta(k string) string {
|
||||
return r.metadata[k]
|
||||
}
|
||||
|
||||
const (
|
||||
KeyOAuthPKCE = "pkce"
|
||||
KeyOAuthState = "state"
|
||||
)
|
||||
|
||||
type Redirect struct {
|
||||
// Url used for redirect
|
||||
URL string
|
||||
// Extra contains data used for redirect, e.g. for oauth this would be state and pkce
|
||||
Extra map[string]string
|
||||
}
|
||||
|
||||
const (
|
||||
NamespaceUser = "user"
|
||||
NamespaceAPIKey = "api-key"
|
||||
@ -144,7 +156,7 @@ type Identity struct {
|
||||
ID string
|
||||
// IsAnonymous
|
||||
IsAnonymous bool
|
||||
// Login is the short hand identifier of the entity. Should be unique.
|
||||
// Login is the shorthand identifier of the entity. Should be unique.
|
||||
Login string
|
||||
// Name is the display name of the entity. It is not guaranteed to be unique.
|
||||
Name string
|
||||
@ -283,3 +295,8 @@ func IdentityFromSignedInUser(id string, usr *user.SignedInUser, params ClientPa
|
||||
ClientParams: params,
|
||||
}
|
||||
}
|
||||
|
||||
// ClientWithPrefix returns a client name prefixed with "auth.client."
|
||||
func ClientWithPrefix(name string) string {
|
||||
return fmt.Sprintf("auth.client.%s", name)
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/network"
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/apikey"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
@ -52,6 +53,7 @@ func ProvideService(
|
||||
loginAttempts loginattempt.Service, quotaService quota.Service,
|
||||
authInfoService login.AuthInfoService, renderService rendering.Service,
|
||||
features *featuremgmt.FeatureManager, oauthTokenService oauthtoken.OAuthTokenService,
|
||||
socialService social.Service,
|
||||
) *Service {
|
||||
s := &Service{
|
||||
log: log.New("authn.service"),
|
||||
@ -116,6 +118,21 @@ func ProvideService(
|
||||
s.RegisterClient(clients.ProvideJWT(jwtService, cfg))
|
||||
}
|
||||
|
||||
for name := range socialService.GetOAuthProviders() {
|
||||
oauthCfg := socialService.GetOAuthInfoProvider(name)
|
||||
if oauthCfg != nil && oauthCfg.Enabled {
|
||||
clientName := authn.ClientWithPrefix(name)
|
||||
|
||||
connector, errConnector := socialService.GetConnector(name)
|
||||
httpClient, errHTTPClient := socialService.GetOAuthHttpClient(name)
|
||||
if errConnector != nil || errHTTPClient != nil {
|
||||
s.log.Error("failed to configure oauth client", "client", clientName, "err", multierror.Append(errConnector, errHTTPClient))
|
||||
}
|
||||
|
||||
s.RegisterClient(clients.ProvideOAuth(clientName, cfg, oauthCfg, connector, httpClient))
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME (jguer): move to User package
|
||||
userSyncService := sync.ProvideUserSync(userService, userProtectionService, authInfoService, quotaService)
|
||||
orgUserSyncService := sync.ProvideOrgSync(userService, orgService, accessControlService)
|
||||
@ -233,6 +250,7 @@ func (s *Service) Login(ctx context.Context, client string, r *authn.Request) (i
|
||||
|
||||
sessionToken, err := s.sessionService.CreateToken(ctx, &user.User{ID: id}, ip, r.HTTPRequest.UserAgent())
|
||||
if err != nil {
|
||||
s.log.FromContext(ctx).Error("failed to create session", "client", client, "userId", id, "err", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -244,19 +262,19 @@ func (s *Service) RegisterPostLoginHook(hook authn.PostLoginHookFn, priority uin
|
||||
s.postLoginHooks.insert(hook, priority)
|
||||
}
|
||||
|
||||
func (s *Service) RedirectURL(ctx context.Context, client string, r *authn.Request) (string, error) {
|
||||
func (s *Service) RedirectURL(ctx context.Context, client string, r *authn.Request) (*authn.Redirect, error) {
|
||||
ctx, span := s.tracer.Start(ctx, "authn.RedirectURL")
|
||||
defer span.End()
|
||||
span.SetAttributes(attributeKeyClient, client, attribute.Key(attributeKeyClient).String(client))
|
||||
|
||||
c, ok := s.clients[client]
|
||||
if !ok {
|
||||
return "", authn.ErrClientNotConfigured.Errorf("client not configured: %s", client)
|
||||
return nil, authn.ErrClientNotConfigured.Errorf("client not configured: %s", client)
|
||||
}
|
||||
|
||||
redirectClient, ok := c.(authn.RedirectClient)
|
||||
if !ok {
|
||||
return "", authn.ErrUnsupportedClient.Errorf("client does not support generating redirect url: %s", client)
|
||||
return nil, authn.ErrUnsupportedClient.Errorf("client does not support generating redirect url: %s", client)
|
||||
}
|
||||
|
||||
return redirectClient.RedirectURL(ctx, r)
|
||||
|
@ -243,16 +243,13 @@ func TestService_RedirectURL(t *testing.T) {
|
||||
type testCase struct {
|
||||
desc string
|
||||
client string
|
||||
expectedURL string
|
||||
expectedErr error
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "should generate url for valid redirect client",
|
||||
client: "redirect",
|
||||
expectedURL: "https://localhost/redirect",
|
||||
expectedErr: nil,
|
||||
desc: "should generate url for valid redirect client",
|
||||
client: "redirect",
|
||||
},
|
||||
{
|
||||
desc: "should return error on non existing client",
|
||||
@ -269,13 +266,12 @@ func TestService_RedirectURL(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
service := setupTests(t, func(svc *Service) {
|
||||
svc.RegisterClient(authntest.FakeRedirectClient{ExpectedName: "redirect", ExpectedURL: tt.expectedURL})
|
||||
svc.RegisterClient(authntest.FakeRedirectClient{ExpectedName: "redirect"})
|
||||
svc.RegisterClient(&authntest.FakeClient{ExpectedName: "non-redirect"})
|
||||
})
|
||||
|
||||
u, err := service.RedirectURL(context.Background(), tt.client, nil)
|
||||
_, err := service.RedirectURL(context.Background(), tt.client, nil)
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
assert.Equal(t, tt.expectedURL, u)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -53,6 +53,8 @@ type FakeRedirectClient struct {
|
||||
ExpectedErr error
|
||||
ExpectedURL string
|
||||
ExpectedName string
|
||||
ExpectedOK bool
|
||||
ExpectedRedirect *authn.Redirect
|
||||
ExpectedIdentity *authn.Identity
|
||||
}
|
||||
|
||||
@ -64,6 +66,10 @@ func (f FakeRedirectClient) Authenticate(ctx context.Context, r *authn.Request)
|
||||
return f.ExpectedIdentity, f.ExpectedErr
|
||||
}
|
||||
|
||||
func (f FakeRedirectClient) RedirectURL(ctx context.Context, r *authn.Request) (string, error) {
|
||||
return f.ExpectedURL, f.ExpectedErr
|
||||
func (f FakeRedirectClient) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) {
|
||||
return f.ExpectedRedirect, f.ExpectedErr
|
||||
}
|
||||
|
||||
func (f FakeRedirectClient) Test(ctx context.Context, r *authn.Request) bool {
|
||||
return f.ExpectedOK
|
||||
}
|
||||
|
237
pkg/services/authn/clients/oauth.go
Normal file
237
pkg/services/authn/clients/oauth.go
Normal file
@ -0,0 +1,237 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
)
|
||||
|
||||
const (
|
||||
hostedDomainParamName = "hd"
|
||||
codeVerifierParamName = "code_verifier"
|
||||
codeChallengeParamName = "code_challenge"
|
||||
codeChallengeMethodParamName = "code_challenge_method"
|
||||
codeChallengeMethod = "S256"
|
||||
|
||||
oauthStateQueryName = "state"
|
||||
oauthStateCookieName = "oauth_state"
|
||||
oauthPKCECookieName = "oauth_code_verifier"
|
||||
)
|
||||
|
||||
var (
|
||||
errOAuthGenPKCE = errutil.NewBase(errutil.StatusInternal, "auth.oauth.pkce.internal", errutil.WithPublicMessage("An internal error occurred"))
|
||||
errOAuthMissingPKCE = errutil.NewBase(errutil.StatusBadRequest, "auth.oauth.pkce.missing", errutil.WithPublicMessage("Missing required pkce cookie"))
|
||||
|
||||
errOAuthGenState = errutil.NewBase(errutil.StatusInternal, "auth.oauth.state.internal", errutil.WithPublicMessage("An internal error occurred"))
|
||||
errOAuthMissingState = errutil.NewBase(errutil.StatusBadRequest, "auth.oauth.state.missing", errutil.WithPublicMessage("Missing saved oauth state"))
|
||||
errOAuthInvalidState = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.state.invalid", errutil.WithPublicMessage("Provided state does not match stored state"))
|
||||
|
||||
errOAuthTokenExchange = errutil.NewBase(errutil.StatusInternal, "auth.oauth.token.exchange", errutil.WithPublicMessage("Failed to get token from provider"))
|
||||
|
||||
errOAuthMissingRequiredEmail = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.email.missing")
|
||||
errOAuthEmailNotAllowed = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.email.not-allowed")
|
||||
)
|
||||
|
||||
var _ authn.RedirectClient = new(OAuth)
|
||||
|
||||
func ProvideOAuth(
|
||||
name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo,
|
||||
connector social.SocialConnector, httpClient *http.Client,
|
||||
) *OAuth {
|
||||
return &OAuth{
|
||||
name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")),
|
||||
log.New(name), cfg, oauthCfg, connector, httpClient,
|
||||
}
|
||||
}
|
||||
|
||||
type OAuth struct {
|
||||
name string
|
||||
moduleName string
|
||||
log log.Logger
|
||||
cfg *setting.Cfg
|
||||
oauthCfg *social.OAuthInfo
|
||||
connector social.SocialConnector
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func (c *OAuth) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
|
||||
r.SetMeta(authn.MetaKeyAuthModule, c.moduleName)
|
||||
// get hashed state stored in cookie
|
||||
stateCookie, err := r.HTTPRequest.Cookie(oauthStateCookieName)
|
||||
if err != nil {
|
||||
return nil, errOAuthMissingState.Errorf("missing state cookie")
|
||||
}
|
||||
|
||||
if stateCookie.Value == "" {
|
||||
return nil, errOAuthMissingState.Errorf("missing state value in state cookie")
|
||||
}
|
||||
|
||||
// get state returned by the idp and hash it
|
||||
stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, c.oauthCfg.ClientSecret)
|
||||
// compare the state returned by idp against the one we stored in cookie
|
||||
if stateQuery != stateCookie.Value {
|
||||
return nil, errOAuthInvalidState.Errorf("provided state did not match stored state")
|
||||
}
|
||||
|
||||
var opts []oauth2.AuthCodeOption
|
||||
// if pkce is enabled for client validate we have the cookie and set it as url param
|
||||
if c.oauthCfg.UsePKCE {
|
||||
pkceCookie, err := r.HTTPRequest.Cookie(oauthPKCECookieName)
|
||||
if err != nil {
|
||||
return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err)
|
||||
}
|
||||
opts = append(opts, oauth2.SetAuthURLParam(codeVerifierParamName, pkceCookie.Value))
|
||||
}
|
||||
|
||||
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
|
||||
// exchange auth code to a valid token
|
||||
token, err := c.connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token.TokenType = "Bearer"
|
||||
|
||||
userInfo, err := c.connector.UserInfo(c.connector.Client(clientCtx, token), token)
|
||||
if err != nil {
|
||||
return nil, errOAuthTokenExchange.Errorf("failed to exchange code to token: %w", err)
|
||||
}
|
||||
|
||||
if userInfo.Email == "" {
|
||||
return nil, errOAuthMissingRequiredEmail.Errorf("required attribute email was not provided")
|
||||
}
|
||||
|
||||
if !c.connector.IsEmailAllowed(userInfo.Email) {
|
||||
return nil, errOAuthEmailNotAllowed.Errorf("provided email is not allowed")
|
||||
}
|
||||
|
||||
return &authn.Identity{
|
||||
Login: userInfo.Login,
|
||||
Name: userInfo.Name,
|
||||
Email: userInfo.Email,
|
||||
IsGrafanaAdmin: userInfo.IsGrafanaAdmin,
|
||||
AuthModule: c.moduleName,
|
||||
AuthID: userInfo.Id,
|
||||
Groups: userInfo.Groups,
|
||||
OAuthToken: token,
|
||||
OrgRoles: getOAuthOrgRole(userInfo, c.cfg),
|
||||
ClientParams: authn.ClientParams{
|
||||
SyncUser: true,
|
||||
SyncTeamMembers: true,
|
||||
AllowSignUp: c.connector.IsSignupAllowed(),
|
||||
LookUpParams: login.UserLookupParams{Email: &userInfo.Email},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) {
|
||||
var opts []oauth2.AuthCodeOption
|
||||
|
||||
if c.oauthCfg.HostedDomain != "" {
|
||||
opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, c.oauthCfg.HostedDomain))
|
||||
}
|
||||
|
||||
var plainPKCE string
|
||||
if c.oauthCfg.UsePKCE {
|
||||
pkce, hashedPKCE, err := genPKCECode()
|
||||
if err != nil {
|
||||
return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err)
|
||||
}
|
||||
|
||||
plainPKCE = pkce
|
||||
opts = append(opts,
|
||||
oauth2.SetAuthURLParam(codeChallengeParamName, hashedPKCE),
|
||||
oauth2.SetAuthURLParam(codeChallengeMethodParamName, codeChallengeMethod),
|
||||
)
|
||||
}
|
||||
|
||||
state, hashedSate, err := genOAuthState(c.cfg.SecretKey, c.oauthCfg.ClientSecret)
|
||||
if err != nil {
|
||||
return nil, errOAuthGenState.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
return &authn.Redirect{
|
||||
URL: c.connector.AuthCodeURL(state, opts...),
|
||||
Extra: map[string]string{
|
||||
authn.KeyOAuthState: hashedSate,
|
||||
authn.KeyOAuthPKCE: plainPKCE,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// genPKCECode returns a random URL-friendly string and it's base64 URL encoded SHA256 digest.
|
||||
func genPKCECode() (string, string, error) {
|
||||
// IETF RFC 7636 specifies that the code verifier should be 43-128
|
||||
// characters from a set of unreserved URI characters which is
|
||||
// almost the same as the set of characters in base64url.
|
||||
// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
|
||||
//
|
||||
// It doesn't hurt to generate a few more bytes here, we generate
|
||||
// 96 bytes which we then encode using base64url to make sure
|
||||
// they're within the set of unreserved characters.
|
||||
//
|
||||
// 96 is chosen because 96*8/6 = 128, which means that we'll have
|
||||
// 128 characters after it has been base64 encoded.
|
||||
raw := make([]byte, 96)
|
||||
_, err := rand.Read(raw)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
ascii := make([]byte, 128)
|
||||
base64.RawURLEncoding.Encode(ascii, raw)
|
||||
|
||||
shasum := sha256.Sum256(ascii)
|
||||
pkce := base64.RawURLEncoding.EncodeToString(shasum[:])
|
||||
return string(ascii), pkce, nil
|
||||
}
|
||||
|
||||
func genOAuthState(secret, seed string) (string, string, error) {
|
||||
rnd := make([]byte, 32)
|
||||
if _, err := rand.Read(rnd); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
state := base64.URLEncoding.EncodeToString(rnd)
|
||||
return state, hashOAuthState(state, secret, seed), nil
|
||||
}
|
||||
|
||||
func hashOAuthState(state, secret, seed string) string {
|
||||
hashBytes := sha256.Sum256([]byte(state + secret + seed))
|
||||
return hex.EncodeToString(hashBytes[:])
|
||||
}
|
||||
|
||||
func getOAuthOrgRole(userInfo *social.BasicUserInfo, cfg *setting.Cfg) map[int64]org.RoleType {
|
||||
orgRoles := make(map[int64]org.RoleType, 0)
|
||||
if cfg.OAuthSkipOrgRoleUpdateSync {
|
||||
return orgRoles
|
||||
}
|
||||
|
||||
if userInfo.Role == "" || !userInfo.Role.IsValid() {
|
||||
return orgRoles
|
||||
}
|
||||
|
||||
orgID := int64(1)
|
||||
if cfg.AutoAssignOrg && cfg.AutoAssignOrgId > 0 {
|
||||
orgID = int64(cfg.AutoAssignOrgId)
|
||||
}
|
||||
|
||||
orgRoles[orgID] = userInfo.Role
|
||||
return orgRoles
|
||||
}
|
305
pkg/services/authn/clients/oauth_test.go
Normal file
305
pkg/services/authn/clients/oauth_test.go
Normal file
@ -0,0 +1,305 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOAuth_Authenticate(t *testing.T) {
|
||||
type testCase struct {
|
||||
desc string
|
||||
req *authn.Request
|
||||
oauthCfg *social.OAuthInfo
|
||||
|
||||
addStateCookie bool
|
||||
stateCookieValue string
|
||||
|
||||
addPKCECookie bool
|
||||
pkceCookieValue string
|
||||
|
||||
isEmailAllowed bool
|
||||
userInfo *social.BasicUserInfo
|
||||
|
||||
expectedErr error
|
||||
expectedIdentity *authn.Identity
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "should return error when missing state cookie",
|
||||
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
|
||||
oauthCfg: &social.OAuthInfo{},
|
||||
expectedErr: errOAuthMissingState,
|
||||
},
|
||||
{
|
||||
desc: "should return error when state cookie is present but don't have a value",
|
||||
req: &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
|
||||
oauthCfg: &social.OAuthInfo{},
|
||||
addStateCookie: true,
|
||||
stateCookieValue: "",
|
||||
expectedErr: errOAuthMissingState,
|
||||
},
|
||||
{
|
||||
desc: "should return error when state from ipd does not match stored state",
|
||||
req: &authn.Request{HTTPRequest: &http.Request{
|
||||
Header: map[string][]string{},
|
||||
URL: mustParseURL("http://grafana.com/?state=some-other-state"),
|
||||
},
|
||||
},
|
||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
||||
addStateCookie: true,
|
||||
stateCookieValue: "some-state",
|
||||
expectedErr: errOAuthInvalidState,
|
||||
},
|
||||
{
|
||||
desc: "should return error when pkce is configured but the cookie is not present",
|
||||
req: &authn.Request{HTTPRequest: &http.Request{
|
||||
Header: map[string][]string{},
|
||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||
},
|
||||
},
|
||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
||||
addStateCookie: true,
|
||||
stateCookieValue: "some-state",
|
||||
expectedErr: errOAuthMissingPKCE,
|
||||
},
|
||||
{
|
||||
desc: "should return error when email is empty",
|
||||
req: &authn.Request{HTTPRequest: &http.Request{
|
||||
Header: map[string][]string{},
|
||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||
},
|
||||
},
|
||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
||||
addStateCookie: true,
|
||||
stateCookieValue: "some-state",
|
||||
addPKCECookie: true,
|
||||
pkceCookieValue: "some-pkce-value",
|
||||
userInfo: &social.BasicUserInfo{},
|
||||
expectedErr: errOAuthMissingRequiredEmail,
|
||||
},
|
||||
{
|
||||
desc: "should return error when email is not allowed",
|
||||
req: &authn.Request{HTTPRequest: &http.Request{
|
||||
Header: map[string][]string{},
|
||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||
},
|
||||
},
|
||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
||||
addStateCookie: true,
|
||||
stateCookieValue: "some-state",
|
||||
addPKCECookie: true,
|
||||
pkceCookieValue: "some-pkce-value",
|
||||
userInfo: &social.BasicUserInfo{Email: "some@email.com"},
|
||||
isEmailAllowed: false,
|
||||
expectedErr: errOAuthEmailNotAllowed,
|
||||
},
|
||||
{
|
||||
desc: "should return identity for valid request",
|
||||
req: &authn.Request{HTTPRequest: &http.Request{
|
||||
Header: map[string][]string{},
|
||||
URL: mustParseURL("http://grafana.com/?state=some-state"),
|
||||
},
|
||||
},
|
||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
||||
addStateCookie: true,
|
||||
stateCookieValue: "some-state",
|
||||
addPKCECookie: true,
|
||||
pkceCookieValue: "some-pkce-value",
|
||||
isEmailAllowed: true,
|
||||
userInfo: &social.BasicUserInfo{
|
||||
Id: "123",
|
||||
Name: "name",
|
||||
Email: "some@email.com",
|
||||
Role: "Admin",
|
||||
Groups: []string{"grp1", "grp2"},
|
||||
},
|
||||
expectedIdentity: &authn.Identity{
|
||||
Email: "some@email.com",
|
||||
AuthModule: "oauth_azuread",
|
||||
AuthID: "123",
|
||||
Name: "name",
|
||||
Groups: []string{"grp1", "grp2"},
|
||||
OAuthToken: &oauth2.Token{},
|
||||
OrgRoles: map[int64]org.RoleType{1: org.RoleAdmin},
|
||||
ClientParams: authn.ClientParams{
|
||||
SyncUser: true,
|
||||
SyncTeamMembers: true,
|
||||
AllowSignUp: true,
|
||||
LookUpParams: login.UserLookupParams{Email: strPtr("some@email.com")},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
cfg := setting.NewCfg()
|
||||
|
||||
if tt.addStateCookie {
|
||||
v := tt.stateCookieValue
|
||||
if v != "" {
|
||||
v = hashOAuthState(v, cfg.SecretKey, tt.oauthCfg.ClientSecret)
|
||||
}
|
||||
tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthStateCookieName, Value: v})
|
||||
}
|
||||
|
||||
if tt.addPKCECookie {
|
||||
tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue})
|
||||
}
|
||||
|
||||
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, tt.oauthCfg, fakeConnector{
|
||||
ExpectedUserInfo: tt.userInfo,
|
||||
ExpectedToken: &oauth2.Token{},
|
||||
ExpectedIsSignupAllowed: true,
|
||||
ExpectedIsEmailAllowed: tt.isEmailAllowed,
|
||||
}, nil)
|
||||
identity, err := c.Authenticate(context.Background(), tt.req)
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
|
||||
if tt.expectedIdentity != nil {
|
||||
assert.Equal(t, tt.expectedIdentity.Login, identity.Login)
|
||||
assert.Equal(t, tt.expectedIdentity.Name, identity.Name)
|
||||
assert.Equal(t, tt.expectedIdentity.Email, identity.Email)
|
||||
assert.Equal(t, tt.expectedIdentity.AuthID, identity.AuthID)
|
||||
assert.Equal(t, tt.expectedIdentity.AuthModule, identity.AuthModule)
|
||||
assert.Equal(t, tt.expectedIdentity.Groups, identity.Groups)
|
||||
|
||||
assert.Equal(t, tt.expectedIdentity.ClientParams.SyncUser, identity.ClientParams.SyncUser)
|
||||
assert.Equal(t, tt.expectedIdentity.ClientParams.AllowSignUp, identity.ClientParams.AllowSignUp)
|
||||
assert.Equal(t, tt.expectedIdentity.ClientParams.SyncTeamMembers, identity.ClientParams.SyncTeamMembers)
|
||||
assert.Equal(t, tt.expectedIdentity.ClientParams.EnableDisabledUsers, identity.ClientParams.EnableDisabledUsers)
|
||||
|
||||
assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.Email, identity.ClientParams.LookUpParams.Email)
|
||||
assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.Login, identity.ClientParams.LookUpParams.Login)
|
||||
assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.UserID, identity.ClientParams.LookUpParams.UserID)
|
||||
} else {
|
||||
assert.Nil(t, tt.expectedIdentity)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth_RedirectURL(t *testing.T) {
|
||||
type testCase struct {
|
||||
desc string
|
||||
oauthCfg *social.OAuthInfo
|
||||
expectedErr error
|
||||
|
||||
numCallOptions int
|
||||
authCodeUrlCalled bool
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "should generate redirect url and state",
|
||||
oauthCfg: &social.OAuthInfo{},
|
||||
authCodeUrlCalled: true,
|
||||
},
|
||||
{
|
||||
desc: "should generate redirect url with hosted domain option if configured",
|
||||
oauthCfg: &social.OAuthInfo{HostedDomain: "grafana.com"},
|
||||
numCallOptions: 1,
|
||||
authCodeUrlCalled: true,
|
||||
},
|
||||
{
|
||||
desc: "should generate redirect url with pkce if configured",
|
||||
oauthCfg: &social.OAuthInfo{UsePKCE: true},
|
||||
numCallOptions: 2,
|
||||
authCodeUrlCalled: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
var (
|
||||
authCodeUrlCalled = false
|
||||
)
|
||||
|
||||
c := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), tt.oauthCfg, mockConnector{
|
||||
AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string {
|
||||
authCodeUrlCalled = true
|
||||
require.Len(t, opts, tt.numCallOptions)
|
||||
return ""
|
||||
},
|
||||
}, nil)
|
||||
|
||||
redirect, err := c.RedirectURL(context.Background(), nil)
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
assert.Equal(t, tt.authCodeUrlCalled, authCodeUrlCalled)
|
||||
|
||||
if tt.expectedErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.NotEmpty(t, redirect.Extra[authn.KeyOAuthState])
|
||||
if tt.oauthCfg.UsePKCE {
|
||||
assert.NotEmpty(t, redirect.Extra[authn.KeyOAuthPKCE])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockConnector struct {
|
||||
AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string
|
||||
social.SocialConnector
|
||||
}
|
||||
|
||||
func (m mockConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
|
||||
if m.AuthCodeURLFunc != nil {
|
||||
return m.AuthCodeURLFunc(state, opts...)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var _ social.SocialConnector = new(fakeConnector)
|
||||
|
||||
type fakeConnector struct {
|
||||
ExpectedUserInfo *social.BasicUserInfo
|
||||
ExpectedUserInfoErr error
|
||||
ExpectedIsEmailAllowed bool
|
||||
ExpectedIsSignupAllowed bool
|
||||
ExpectedToken *oauth2.Token
|
||||
ExpectedTokenErr error
|
||||
social.SocialConnector
|
||||
}
|
||||
|
||||
func (f fakeConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
||||
return f.ExpectedUserInfo, f.ExpectedUserInfoErr
|
||||
}
|
||||
|
||||
func (f fakeConnector) IsEmailAllowed(email string) bool {
|
||||
return f.ExpectedIsEmailAllowed
|
||||
}
|
||||
|
||||
func (f fakeConnector) IsSignupAllowed() bool {
|
||||
return f.ExpectedIsSignupAllowed
|
||||
}
|
||||
|
||||
func (f fakeConnector) Exchange(ctx context.Context, code string, authOptions ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
|
||||
return f.ExpectedToken, f.ExpectedTokenErr
|
||||
}
|
||||
|
||||
func (f fakeConnector) Client(ctx context.Context, t *oauth2.Token) *http.Client {
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustParseURL(s string) *url.URL {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
Loading…
Reference in New Issue
Block a user