mirror of
https://github.com/grafana/grafana.git
synced 2024-11-30 04:34:23 -06:00
6543259a7d
* Add SyncPermissionsFromDB post auth hook * Delete FromDB prefix * Align tests * Fixes * Change SyncPermissionsHook prio
240 lines
8.3 KiB
Go
240 lines
8.3 KiB
Go
package clients
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"golang.org/x/oauth2"
|
|
|
|
"github.com/grafana/grafana/pkg/infra/log"
|
|
"github.com/grafana/grafana/pkg/login/social"
|
|
"github.com/grafana/grafana/pkg/services/authn"
|
|
"github.com/grafana/grafana/pkg/services/login"
|
|
"github.com/grafana/grafana/pkg/services/org"
|
|
"github.com/grafana/grafana/pkg/setting"
|
|
"github.com/grafana/grafana/pkg/util/errutil"
|
|
)
|
|
|
|
// Parameter names, fixed parameter values and cookie names used by the OAuth
// client in this file.
const (
	// hostedDomainParamName is the authorization URL parameter that carries
	// the configured hosted domain.
	hostedDomainParamName = "hd"
	// codeVerifierParamName is sent with the token exchange when PKCE is enabled.
	codeVerifierParamName = "code_verifier"
	// codeChallengeParamName and codeChallengeMethodParamName are added to the
	// authorization URL when PKCE is enabled.
	codeChallengeParamName       = "code_challenge"
	codeChallengeMethodParamName = "code_challenge_method"
	// codeChallengeMethod is the value sent for codeChallengeMethodParamName.
	codeChallengeMethod = "S256"

	// oauthStateQueryName is the callback query parameter holding the state
	// returned by the provider.
	oauthStateQueryName = "state"
	// oauthStateCookieName is the cookie Authenticate reads the hashed state from.
	oauthStateCookieName = "oauth_state"
	// oauthPKCECookieName is the cookie Authenticate reads the PKCE code
	// verifier from.
	oauthPKCECookieName = "oauth_code_verifier"
)
|
|
|
|
// Base errors returned by the OAuth client, grouped by the step that fails.
var (
	// PKCE: generating the code on redirect, and reading the verifier cookie
	// on callback.
	errOAuthGenPKCE     = errutil.NewBase(errutil.StatusInternal, "auth.oauth.pkce.internal", errutil.WithPublicMessage("An internal error occurred"))
	errOAuthMissingPKCE = errutil.NewBase(errutil.StatusBadRequest, "auth.oauth.pkce.missing", errutil.WithPublicMessage("Missing required pkce cookie"))

	// State: generating it on redirect, and validating it on callback.
	errOAuthGenState     = errutil.NewBase(errutil.StatusInternal, "auth.oauth.state.internal", errutil.WithPublicMessage("An internal error occurred"))
	errOAuthMissingState = errutil.NewBase(errutil.StatusBadRequest, "auth.oauth.state.missing", errutil.WithPublicMessage("Missing saved oauth state"))
	errOAuthInvalidState = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.state.invalid", errutil.WithPublicMessage("Provided state does not match stored state"))

	// Token exchange and user info lookup. errOAuthUserInfo is declared
	// without a public message.
	errOAuthTokenExchange = errutil.NewBase(errutil.StatusInternal, "auth.oauth.token.exchange", errutil.WithPublicMessage("Failed to get token from provider"))
	errOAuthUserInfo      = errutil.NewBase(errutil.StatusInternal, "auth.oauth.userinfo.error")

	// Email checks applied to the user info returned by the provider.
	errOAuthMissingRequiredEmail = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.email.missing", errutil.WithPublicMessage("Provider didn't return an email address"))
	errOAuthEmailNotAllowed      = errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.email.not-allowed", errutil.WithPublicMessage("Required email domain not fulfilled"))
)
|
|
|
|
func fromSocialErr(err *social.Error) error {
|
|
return errutil.NewBase(errutil.StatusUnauthorized, "auth.oauth.userinfo.failed", errutil.WithPublicMessage(err.Error())).Errorf("%w", err)
|
|
}
|
|
|
|
// Compile-time check that *OAuth satisfies authn.RedirectClient.
var _ authn.RedirectClient = new(OAuth)
|
|
|
|
func ProvideOAuth(
|
|
name string, cfg *setting.Cfg, oauthCfg *social.OAuthInfo,
|
|
connector social.SocialConnector, httpClient *http.Client,
|
|
) *OAuth {
|
|
return &OAuth{
|
|
name, fmt.Sprintf("oauth_%s", strings.TrimPrefix(name, "auth.client.")),
|
|
log.New(name), cfg, oauthCfg, connector, httpClient,
|
|
}
|
|
}
|
|
|
|
// OAuth is an authn client that authenticates users through an OAuth
// provider, using a social connector for the provider specific calls.
//
// NOTE: ProvideOAuth initialises this struct positionally, so the field order
// must not change without updating it.
type OAuth struct {
	// name is the client name returned by Name.
	name string
	// moduleName is reported as the auth module on requests and identities.
	moduleName string
	// log is the logger created from the client name.
	log log.Logger
	// cfg is the global configuration (secret key, org role sync settings).
	cfg *setting.Cfg
	// oauthCfg is the provider configuration (client secret, PKCE, hosted domain).
	oauthCfg *social.OAuthInfo
	// connector performs the auth URL, token exchange and user info calls.
	connector social.SocialConnector
	// httpClient is passed to the connector through the request context.
	httpClient *http.Client
}
|
|
|
|
// Name returns the name the client was created with.
func (c *OAuth) Name() string {
	return c.name
}
|
|
|
|
func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
|
|
r.SetMeta(authn.MetaKeyAuthModule, c.moduleName)
|
|
// get hashed state stored in cookie
|
|
stateCookie, err := r.HTTPRequest.Cookie(oauthStateCookieName)
|
|
if err != nil {
|
|
return nil, errOAuthMissingState.Errorf("missing state cookie")
|
|
}
|
|
|
|
if stateCookie.Value == "" {
|
|
return nil, errOAuthMissingState.Errorf("missing state value in state cookie")
|
|
}
|
|
|
|
// get state returned by the idp and hash it
|
|
stateQuery := hashOAuthState(r.HTTPRequest.URL.Query().Get(oauthStateQueryName), c.cfg.SecretKey, c.oauthCfg.ClientSecret)
|
|
// compare the state returned by idp against the one we stored in cookie
|
|
if stateQuery != stateCookie.Value {
|
|
return nil, errOAuthInvalidState.Errorf("provided state did not match stored state")
|
|
}
|
|
|
|
var opts []oauth2.AuthCodeOption
|
|
// if pkce is enabled for client validate we have the cookie and set it as url param
|
|
if c.oauthCfg.UsePKCE {
|
|
pkceCookie, err := r.HTTPRequest.Cookie(oauthPKCECookieName)
|
|
if err != nil {
|
|
return nil, errOAuthMissingPKCE.Errorf("no pkce cookie found: %w", err)
|
|
}
|
|
opts = append(opts, oauth2.SetAuthURLParam(codeVerifierParamName, pkceCookie.Value))
|
|
}
|
|
|
|
clientCtx := context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
|
|
// exchange auth code to a valid token
|
|
token, err := c.connector.Exchange(clientCtx, r.HTTPRequest.URL.Query().Get("code"), opts...)
|
|
if err != nil {
|
|
return nil, errOAuthTokenExchange.Errorf("failed to exchange code to token: %w", err)
|
|
}
|
|
token.TokenType = "Bearer"
|
|
|
|
userInfo, err := c.connector.UserInfo(c.connector.Client(clientCtx, token), token)
|
|
if err != nil {
|
|
var sErr *social.Error
|
|
if errors.As(err, &sErr) {
|
|
return nil, fromSocialErr(sErr)
|
|
}
|
|
return nil, errOAuthUserInfo.Errorf("failed to get user info: %w", err)
|
|
}
|
|
|
|
if userInfo.Email == "" {
|
|
return nil, errOAuthMissingRequiredEmail.Errorf("required attribute email was not provided")
|
|
}
|
|
|
|
if !c.connector.IsEmailAllowed(userInfo.Email) {
|
|
return nil, errOAuthEmailNotAllowed.Errorf("provided email is not allowed")
|
|
}
|
|
|
|
orgRoles, isGrafanaAdmin, _ := getRoles(c.cfg, func() (org.RoleType, *bool, error) {
|
|
if c.cfg.OAuthSkipOrgRoleUpdateSync {
|
|
return "", nil, nil
|
|
}
|
|
return userInfo.Role, userInfo.IsGrafanaAdmin, nil
|
|
})
|
|
|
|
return &authn.Identity{
|
|
Login: userInfo.Login,
|
|
Name: userInfo.Name,
|
|
Email: userInfo.Email,
|
|
IsGrafanaAdmin: isGrafanaAdmin,
|
|
AuthModule: c.moduleName,
|
|
AuthID: userInfo.Id,
|
|
Groups: userInfo.Groups,
|
|
OAuthToken: token,
|
|
OrgRoles: orgRoles,
|
|
ClientParams: authn.ClientParams{
|
|
SyncUser: true,
|
|
SyncTeams: true,
|
|
FetchSyncedUser: true,
|
|
SyncPermissions: true,
|
|
AllowSignUp: c.connector.IsSignupAllowed(),
|
|
// skip org role flag is checked and handled in the connector. For now we can skip the hook if no roles are passed
|
|
SyncOrgRoles: len(orgRoles) > 0,
|
|
LookUpParams: login.UserLookupParams{Email: &userInfo.Email},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (c *OAuth) RedirectURL(ctx context.Context, r *authn.Request) (*authn.Redirect, error) {
|
|
var opts []oauth2.AuthCodeOption
|
|
|
|
if c.oauthCfg.HostedDomain != "" {
|
|
opts = append(opts, oauth2.SetAuthURLParam(hostedDomainParamName, c.oauthCfg.HostedDomain))
|
|
}
|
|
|
|
var plainPKCE string
|
|
if c.oauthCfg.UsePKCE {
|
|
pkce, hashedPKCE, err := genPKCECode()
|
|
if err != nil {
|
|
return nil, errOAuthGenPKCE.Errorf("failed to generate pkce: %w", err)
|
|
}
|
|
|
|
plainPKCE = pkce
|
|
opts = append(opts,
|
|
oauth2.SetAuthURLParam(codeChallengeParamName, hashedPKCE),
|
|
oauth2.SetAuthURLParam(codeChallengeMethodParamName, codeChallengeMethod),
|
|
)
|
|
}
|
|
|
|
state, hashedSate, err := genOAuthState(c.cfg.SecretKey, c.oauthCfg.ClientSecret)
|
|
if err != nil {
|
|
return nil, errOAuthGenState.Errorf("failed to generate state: %w", err)
|
|
}
|
|
|
|
return &authn.Redirect{
|
|
URL: c.connector.AuthCodeURL(state, opts...),
|
|
Extra: map[string]string{
|
|
authn.KeyOAuthState: hashedSate,
|
|
authn.KeyOAuthPKCE: plainPKCE,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// genPKCECode returns a random URL-friendly code verifier and its code
// challenge, which is the unpadded base64url encoding of the verifier's
// SHA256 digest.
func genPKCECode() (string, string, error) {
	// RFC 7636 section 4.1 requires the code verifier to be 43-128 characters
	// taken from the unreserved URI characters, a set that is almost the same
	// as the base64url alphabet.
	// https://datatracker.ietf.org/doc/html/rfc7636#section-4.1
	//
	// 96 random bytes encode to exactly 96*8/6 = 128 base64url characters,
	// which is the maximum allowed length.
	const entropyBytes = 96

	seed := make([]byte, entropyBytes)
	if _, err := rand.Read(seed); err != nil {
		return "", "", err
	}

	verifier := base64.RawURLEncoding.EncodeToString(seed)
	digest := sha256.Sum256([]byte(verifier))
	challenge := base64.RawURLEncoding.EncodeToString(digest[:])

	return verifier, challenge, nil
}
|
|
|
|
func genOAuthState(secret, seed string) (string, string, error) {
|
|
rnd := make([]byte, 32)
|
|
if _, err := rand.Read(rnd); err != nil {
|
|
return "", "", err
|
|
}
|
|
state := base64.URLEncoding.EncodeToString(rnd)
|
|
return state, hashOAuthState(state, secret, seed), nil
|
|
}
|
|
|
|
// hashOAuthState returns the hex encoded SHA256 digest of state, secret and
// seed concatenated in that order, with no separator between them.
func hashOAuthState(state, secret, seed string) string {
	payload := make([]byte, 0, len(state)+len(secret)+len(seed))
	payload = append(payload, state...)
	payload = append(payload, secret...)
	payload = append(payload, seed...)

	digest := sha256.Sum256(payload)
	return hex.EncodeToString(digest[:])
}
|