mirror of
https://github.com/grafana/grafana.git
synced 2024-12-02 05:29:42 -06:00
239 lines
6.8 KiB
Go
239 lines
6.8 KiB
Go
package clients
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-jose/go-jose/v3"
|
|
"github.com/go-jose/go-jose/v3/jwt"
|
|
|
|
"github.com/grafana/grafana/pkg/infra/log"
|
|
"github.com/grafana/grafana/pkg/services/authn"
|
|
"github.com/grafana/grafana/pkg/services/extsvcauth/oauthserver"
|
|
"github.com/grafana/grafana/pkg/services/login"
|
|
"github.com/grafana/grafana/pkg/services/signingkeys"
|
|
"github.com/grafana/grafana/pkg/services/user"
|
|
"github.com/grafana/grafana/pkg/setting"
|
|
)
|
|
|
|
// Compile-time check that *ExtendedJWT implements authn.Client.
var _ authn.Client = (*ExtendedJWT)(nil)
|
|
|
|
var (
	// acceptedSigningMethods lists the JWT "alg" header values that
	// verifyRFC9068Token accepts; any other algorithm is rejected.
	acceptedSigningMethods = []string{"RS256", "ES256"}
	// timeNow supplies the reference time used when validating claims.
	// NOTE(review): presumably a variable so tests can substitute a fixed
	// clock — confirm against the test files.
	timeNow = time.Now
)
|
|
|
|
// Values accepted for the JWT "typ" header. verifyRFC9068Token lower-cases
// the header before comparing it against these constants.
const (
	// rfc9068ShortMediaType is the short form of the access token media type.
	rfc9068ShortMediaType = "at+jwt"
	// rfc9068MediaType is the full form of the access token media type.
	rfc9068MediaType = "application/at+jwt"
)
|
|
|
|
func ProvideExtendedJWT(userService user.Service, cfg *setting.Cfg, signingKeys signingkeys.Service, oauthServer oauthserver.OAuth2Server) *ExtendedJWT {
|
|
return &ExtendedJWT{
|
|
cfg: cfg,
|
|
log: log.New(authn.ClientExtendedJWT),
|
|
userService: userService,
|
|
signingKeys: signingKeys,
|
|
oauthServer: oauthServer,
|
|
}
|
|
}
|
|
|
|
// ExtendedJWT is an authn client that authenticates requests carrying an
// RFC 9068 style JWT access token in the Authorization header.
type ExtendedJWT struct {
	// cfg provides the feature toggle, the expected issuer and audience, and
	// the auto-assign organization settings read by this client.
	cfg *setting.Cfg
	// log is the logger used to report authentication failures.
	log log.Logger
	// userService loads the signed-in user referenced by the token subject.
	userService user.Service
	// signingKeys supplies the server key whose public half verifies the
	// token signature.
	signingKeys signingkeys.Service
	// oauthServer is used to check that the token's client_id refers to an
	// external service.
	oauthServer oauthserver.OAuth2Server
}
|
|
|
|
// ExtendedJWTClaims is the claim set decoded from an extended JWT: the
// registered claims from jwt.Claims plus the custom claims below.
type ExtendedJWTClaims struct {
	jwt.Claims
	// ClientID must be non-empty and is looked up via the OAuth server in
	// validateClientIdClaim.
	ClientID string `json:"client_id"`
	// Groups, Email, Name and Login are decoded from the token; they are not
	// read by the code in this file.
	Groups []string `json:"groups"`
	Email  string   `json:"email"`
	Name   string   `json:"name"`
	Login  string   `json:"login"`
	// Scopes is decoded from the "scope" claim (singular in the JSON).
	Scopes []string `json:"scope"`
	// Entitlements must be non-empty; Authenticate installs it as the user's
	// permissions for the default organization.
	Entitlements map[string][]string `json:"entitlements"`
}
|
|
|
|
func (s *ExtendedJWT) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
|
|
jwtToken := s.retrieveToken(r.HTTPRequest)
|
|
|
|
claims, err := s.verifyRFC9068Token(ctx, jwtToken)
|
|
if err != nil {
|
|
s.log.Error("Failed to verify JWT", "error", err)
|
|
return nil, errJWTInvalid.Errorf("Failed to verify JWT: %w", err)
|
|
}
|
|
|
|
// user:id:18
|
|
userID, err := strconv.ParseInt(strings.TrimPrefix(claims.Subject, fmt.Sprintf("%s:id:", authn.NamespaceUser)), 10, 64)
|
|
if err != nil {
|
|
s.log.Error("Failed to parse sub", "error", err)
|
|
return nil, errJWTInvalid.Errorf("Failed to parse sub: %w", err)
|
|
}
|
|
|
|
// FIXME: support multiple organizations
|
|
defaultOrgID := s.getDefaultOrgID()
|
|
if r.OrgID != defaultOrgID {
|
|
s.log.Error("Failed to verify the Organization: OrgID is not the default")
|
|
return nil, errJWTInvalid.Errorf("Failed to verify the Organization. Only the default org is supported")
|
|
}
|
|
|
|
signedInUser, err := s.userService.GetSignedInUserWithCacheCtx(ctx, &user.GetSignedInUserQuery{OrgID: defaultOrgID, UserID: userID})
|
|
if err != nil {
|
|
s.log.Error("Failed to get user", "error", err)
|
|
return nil, errJWTInvalid.Errorf("Failed to get user: %w", err)
|
|
}
|
|
|
|
if signedInUser.Permissions == nil {
|
|
signedInUser.Permissions = make(map[int64]map[string][]string)
|
|
}
|
|
|
|
if len(claims.Entitlements) == 0 {
|
|
s.log.Error("Entitlements claim is missing")
|
|
return nil, errJWTInvalid.Errorf("Entitlements claim is missing")
|
|
}
|
|
|
|
signedInUser.Permissions[s.getDefaultOrgID()] = claims.Entitlements
|
|
|
|
return authn.IdentityFromSignedInUser(authn.NamespacedID(authn.NamespaceUser, signedInUser.UserID), signedInUser, authn.ClientParams{SyncPermissions: false}, login.ExtendedJWTModule), nil
|
|
}
|
|
|
|
func (s *ExtendedJWT) Test(ctx context.Context, r *authn.Request) bool {
|
|
if !s.cfg.ExtendedJWTAuthEnabled {
|
|
return false
|
|
}
|
|
|
|
rawToken := s.retrieveToken(r.HTTPRequest)
|
|
if rawToken == "" {
|
|
return false
|
|
}
|
|
|
|
parsedToken, err := jwt.ParseSigned(rawToken)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
var claims jwt.Claims
|
|
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
|
return false
|
|
}
|
|
|
|
return claims.Issuer == s.cfg.ExtendedJWTExpectIssuer
|
|
}
|
|
|
|
func (s *ExtendedJWT) Name() string {
|
|
return authn.ClientExtendedJWT
|
|
}
|
|
|
|
func (s *ExtendedJWT) Priority() uint {
|
|
// This client should come before the normal JWT client, because it is more specific, because of the Issuer check
|
|
return 15
|
|
}
|
|
|
|
// retrieveToken retrieves the JWT token from the request.
|
|
func (s *ExtendedJWT) retrieveToken(httpRequest *http.Request) string {
|
|
jwtToken := httpRequest.Header.Get("Authorization")
|
|
|
|
// Strip the 'Bearer' prefix if it exists.
|
|
return strings.TrimPrefix(jwtToken, "Bearer ")
|
|
}
|
|
|
|
// verifyRFC9068Token verifies the token against the RFC 9068 specification.
|
|
func (s *ExtendedJWT) verifyRFC9068Token(ctx context.Context, rawToken string) (*ExtendedJWTClaims, error) {
|
|
parsedToken, err := jwt.ParseSigned(rawToken)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
|
}
|
|
|
|
if len(parsedToken.Headers) != 1 {
|
|
return nil, fmt.Errorf("only one header supported, got %d", len(parsedToken.Headers))
|
|
}
|
|
|
|
parsedHeader := parsedToken.Headers[0]
|
|
|
|
typeHeader := parsedHeader.ExtraHeaders["typ"]
|
|
if typeHeader == nil {
|
|
return nil, fmt.Errorf("missing 'typ' field from the header")
|
|
}
|
|
|
|
jwtType := strings.ToLower(typeHeader.(string))
|
|
if jwtType != rfc9068ShortMediaType && jwtType != rfc9068MediaType {
|
|
return nil, fmt.Errorf("invalid JWT type: %s", jwtType)
|
|
}
|
|
|
|
if !slices.Contains(acceptedSigningMethods, parsedHeader.Algorithm) {
|
|
return nil, fmt.Errorf("invalid algorithm: %s. Accepted algorithms: %s", parsedHeader.Algorithm, strings.Join(acceptedSigningMethods, ", "))
|
|
}
|
|
|
|
var claims ExtendedJWTClaims
|
|
_, key, err := s.signingKeys.GetOrCreatePrivateKey(ctx,
|
|
signingkeys.ServerPrivateKeyID, jose.ES256)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get public key: %w", err)
|
|
}
|
|
|
|
err = parsedToken.Claims(key.Public(), &claims)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to verify the signature: %w", err)
|
|
}
|
|
|
|
if claims.Expiry == nil {
|
|
return nil, fmt.Errorf("missing 'exp' claim")
|
|
}
|
|
|
|
if claims.ID == "" {
|
|
return nil, fmt.Errorf("missing 'jti' claim")
|
|
}
|
|
|
|
if claims.Subject == "" {
|
|
return nil, fmt.Errorf("missing 'sub' claim")
|
|
}
|
|
|
|
if claims.IssuedAt == nil {
|
|
return nil, fmt.Errorf("missing 'iat' claim")
|
|
}
|
|
|
|
err = claims.ValidateWithLeeway(jwt.Expected{
|
|
Issuer: s.cfg.ExtendedJWTExpectIssuer,
|
|
Audience: jwt.Audience{s.cfg.ExtendedJWTExpectAudience},
|
|
Time: timeNow(),
|
|
}, 0)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to validate JWT: %w", err)
|
|
}
|
|
|
|
if err := s.validateClientIdClaim(ctx, claims); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &claims, nil
|
|
}
|
|
|
|
func (s *ExtendedJWT) validateClientIdClaim(ctx context.Context, claims ExtendedJWTClaims) error {
|
|
if claims.ClientID == "" {
|
|
return fmt.Errorf("missing 'client_id' claim")
|
|
}
|
|
|
|
if _, err := s.oauthServer.GetExternalService(ctx, claims.ClientID); err != nil {
|
|
return fmt.Errorf("invalid 'client_id' claim: %s", claims.ClientID)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *ExtendedJWT) getDefaultOrgID() int64 {
|
|
orgID := int64(1)
|
|
if s.cfg.AutoAssignOrg && s.cfg.AutoAssignOrgId > 0 {
|
|
orgID = int64(s.cfg.AutoAssignOrgId)
|
|
}
|
|
return orgID
|
|
}
|