grafana/pkg/services/authn/clients/ext_jwt.go
2023-11-01 09:17:38 -07:00

239 lines
6.8 KiB
Go

package clients
import (
"context"
"fmt"
"net/http"
"slices"
"strconv"
"strings"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/extsvcauth/oauthserver"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/signingkeys"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
// Compile-time assertion that *ExtendedJWT satisfies authn.Client.
var _ authn.Client = (*ExtendedJWT)(nil)

var (
	// acceptedSigningMethods lists the JWS "alg" header values that
	// verifyRFC9068Token accepts.
	acceptedSigningMethods = []string{"RS256", "ES256"}
	// timeNow is the clock used when validating time-based claims. It is a
	// package variable so it can be swapped out (presumably by tests).
	timeNow = time.Now
)

// Accepted values of the JWT "typ" header for RFC 9068 access tokens,
// compared after lower-casing the header value.
const (
	rfc9068ShortMediaType = "at+jwt"
	rfc9068MediaType      = "application/at+jwt"
)
// ProvideExtendedJWT builds the extended JWT authentication client from its
// dependencies. The client's logger is named after authn.ClientExtendedJWT.
func ProvideExtendedJWT(userService user.Service, cfg *setting.Cfg, signingKeys signingkeys.Service, oauthServer oauthserver.OAuth2Server) *ExtendedJWT {
	client := &ExtendedJWT{
		log:         log.New(authn.ClientExtendedJWT),
		cfg:         cfg,
		oauthServer: oauthServer,
		signingKeys: signingKeys,
		userService: userService,
	}
	return client
}
// ExtendedJWT is an authn.Client that authenticates requests carrying an
// RFC 9068 JWT access token in the Authorization header.
type ExtendedJWT struct {
	// cfg supplies the feature toggle, the expected issuer/audience and the
	// auto-assign org settings used to pick the default organization.
	cfg *setting.Cfg
	log log.Logger
	// userService resolves the token subject to a signed-in user.
	userService user.Service
	// signingKeys provides the server key whose public half verifies tokens.
	signingKeys signingkeys.Service
	// oauthServer is used to check that the token's client_id is a known
	// external service.
	oauthServer oauthserver.OAuth2Server
}
// ExtendedJWTClaims holds the claims decoded from an extended JWT access
// token: the registered JWT claims plus the extra claims below.
type ExtendedJWTClaims struct {
	jwt.Claims
	// ClientID is the "client_id" claim; validateClientIdClaim requires it.
	ClientID string `json:"client_id"`
	// Groups, Email, Name and Login are decoded but not used in this file.
	Groups []string `json:"groups"`
	Email  string   `json:"email"`
	Name   string   `json:"name"`
	Login  string   `json:"login"`
	// Scopes is decoded from the "scope" claim.
	// NOTE(review): RFC 8693 §4.2 (referenced by RFC 9068) defines "scope" as
	// a space-delimited JSON string; a []string field only decodes a JSON
	// array — confirm the issuer emits an array.
	Scopes []string `json:"scope"`
	// Entitlements become the user's permissions in the default org
	// (see Authenticate); presumably action -> scopes — TODO confirm.
	Entitlements map[string][]string `json:"entitlements"`
}
// Authenticate verifies the RFC 9068 access token from the request's
// Authorization header and resolves it to a user identity.
//
// The token's "sub" claim must carry a numeric user ID (e.g. "user:id:18"),
// the request must target the default organization, and the token must carry
// a non-empty "entitlements" claim. The entitlements are installed as the
// user's permissions in the default organization, and the identity is built
// with SyncPermissions disabled. Every failure is returned as errJWTInvalid.
func (s *ExtendedJWT) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
	jwtToken := s.retrieveToken(r.HTTPRequest)

	claims, err := s.verifyRFC9068Token(ctx, jwtToken)
	if err != nil {
		s.log.Error("Failed to verify JWT", "error", err)
		return nil, errJWTInvalid.Errorf("Failed to verify JWT: %w", err)
	}

	// Subject format: user:id:18
	userID, err := strconv.ParseInt(strings.TrimPrefix(claims.Subject, fmt.Sprintf("%s:id:", authn.NamespaceUser)), 10, 64)
	if err != nil {
		s.log.Error("Failed to parse sub", "error", err)
		return nil, errJWTInvalid.Errorf("Failed to parse sub: %w", err)
	}

	// FIXME: support multiple organizations
	defaultOrgID := s.getDefaultOrgID()
	if r.OrgID != defaultOrgID {
		s.log.Error("Failed to verify the Organization: OrgID is not the default")
		return nil, errJWTInvalid.Errorf("Failed to verify the Organization. Only the default org is supported")
	}

	// Reject tokens without entitlements before touching the user store: such
	// a token is refused regardless of the lookup result.
	if len(claims.Entitlements) == 0 {
		s.log.Error("Entitlements claim is missing")
		return nil, errJWTInvalid.Errorf("Entitlements claim is missing")
	}

	signedInUser, err := s.userService.GetSignedInUserWithCacheCtx(ctx, &user.GetSignedInUserQuery{OrgID: defaultOrgID, UserID: userID})
	if err != nil {
		s.log.Error("Failed to get user", "error", err)
		return nil, errJWTInvalid.Errorf("Failed to get user: %w", err)
	}

	// NOTE(review): signedInUser comes from a cached lookup; confirm its
	// Permissions map is never shared with the cache before mutating it here.
	if signedInUser.Permissions == nil {
		signedInUser.Permissions = make(map[int64]map[string][]string)
	}
	signedInUser.Permissions[defaultOrgID] = claims.Entitlements

	return authn.IdentityFromSignedInUser(authn.NamespacedID(authn.NamespaceUser, signedInUser.UserID), signedInUser, authn.ClientParams{SyncPermissions: false}, login.ExtendedJWTModule), nil
}
// Test reports whether this client should handle the request. It matches only
// when extended JWT auth is enabled and the request carries a parseable JWT
// whose unverified "iss" claim equals the configured expected issuer. The
// signature and remaining claims are verified by Authenticate (via
// verifyRFC9068Token), not here.
func (s *ExtendedJWT) Test(ctx context.Context, r *authn.Request) bool {
	if !s.cfg.ExtendedJWTAuthEnabled {
		return false
	}

	token := s.retrieveToken(r.HTTPRequest)
	if len(token) == 0 {
		return false
	}

	signed, parseErr := jwt.ParseSigned(token)
	if parseErr != nil {
		return false
	}

	// Only peek at the issuer; nothing here is trusted yet.
	var unverified jwt.Claims
	if decodeErr := signed.UnsafeClaimsWithoutVerification(&unverified); decodeErr != nil {
		return false
	}

	expectedIssuer := s.cfg.ExtendedJWTExpectIssuer
	return unverified.Issuer == expectedIssuer
}
// Name returns the client name, authn.ClientExtendedJWT.
func (*ExtendedJWT) Name() string {
	name := authn.ClientExtendedJWT
	return name
}
// Priority returns this client's ordering priority. It is chosen so the client
// comes before the normal JWT client: this client is the more specific of the
// two, because Test only matches tokens from the configured issuer.
func (*ExtendedJWT) Priority() uint {
	const priority uint = 15
	return priority
}
// retrieveToken returns the credential carried in the request's Authorization
// header with a leading "Bearer " scheme removed. The scheme is matched
// case-insensitively, as RFC 7235 §2.1 requires for authentication schemes.
// When the header does not start with that scheme, its raw value is returned
// unchanged (the empty string when the header is absent).
func (s *ExtendedJWT) retrieveToken(httpRequest *http.Request) string {
	const bearerPrefix = "Bearer "
	jwtToken := httpRequest.Header.Get("Authorization")
	if len(jwtToken) >= len(bearerPrefix) && strings.EqualFold(jwtToken[:len(bearerPrefix)], bearerPrefix) {
		return jwtToken[len(bearerPrefix):]
	}
	return jwtToken
}
// verifyRFC9068Token parses rawToken as a signed JWT and validates it as an
// OAuth 2.0 access token per RFC 9068. It checks, in order:
//   - the token has exactly one header;
//   - the "typ" header is a string equal (case-insensitively) to "at+jwt" or
//     "application/at+jwt";
//   - the "alg" header is one of acceptedSigningMethods;
//   - the signature, against the public half of the server's ES256 key;
//   - the "exp", "jti", "sub" and "iat" claims are present;
//   - issuer, audience and time validity, with zero leeway;
//   - the "client_id" claim (see validateClientIdClaim).
//
// The returned claims are only meaningful when the error is nil.
func (s *ExtendedJWT) verifyRFC9068Token(ctx context.Context, rawToken string) (*ExtendedJWTClaims, error) {
	parsedToken, err := jwt.ParseSigned(rawToken)
	if err != nil {
		return nil, fmt.Errorf("failed to parse JWT: %w", err)
	}

	if len(parsedToken.Headers) != 1 {
		return nil, fmt.Errorf("only one header supported, got %d", len(parsedToken.Headers))
	}
	parsedHeader := parsedToken.Headers[0]

	typeHeader := parsedHeader.ExtraHeaders["typ"]
	if typeHeader == nil {
		return nil, fmt.Errorf("missing 'typ' field from the header")
	}
	// The header is attacker-controlled and not yet signature-verified, so
	// "typ" may decode to any JSON type; an unchecked assertion would panic.
	typeValue, ok := typeHeader.(string)
	if !ok {
		return nil, fmt.Errorf("invalid 'typ' field in the header: expected a string, got %T", typeHeader)
	}
	jwtType := strings.ToLower(typeValue)
	if jwtType != rfc9068ShortMediaType && jwtType != rfc9068MediaType {
		return nil, fmt.Errorf("invalid JWT type: %s", jwtType)
	}

	if !slices.Contains(acceptedSigningMethods, parsedHeader.Algorithm) {
		return nil, fmt.Errorf("invalid algorithm: %s. Accepted algorithms: %s", parsedHeader.Algorithm, strings.Join(acceptedSigningMethods, ", "))
	}

	var claims ExtendedJWTClaims
	// NOTE(review): RS256 passes the alg check above, but verification always
	// uses the ES256 server key, so RS256 tokens presumably fail here —
	// confirm whether RS256 support is intended.
	_, key, err := s.signingKeys.GetOrCreatePrivateKey(ctx,
		signingkeys.ServerPrivateKeyID, jose.ES256)
	if err != nil {
		return nil, fmt.Errorf("failed to get public key: %w", err)
	}
	err = parsedToken.Claims(key.Public(), &claims)
	if err != nil {
		return nil, fmt.Errorf("failed to verify the signature: %w", err)
	}

	// Claims that RFC 9068 requires but ValidateWithLeeway treats as optional.
	if claims.Expiry == nil {
		return nil, fmt.Errorf("missing 'exp' claim")
	}
	if claims.ID == "" {
		return nil, fmt.Errorf("missing 'jti' claim")
	}
	if claims.Subject == "" {
		return nil, fmt.Errorf("missing 'sub' claim")
	}
	if claims.IssuedAt == nil {
		return nil, fmt.Errorf("missing 'iat' claim")
	}

	err = claims.ValidateWithLeeway(jwt.Expected{
		Issuer:   s.cfg.ExtendedJWTExpectIssuer,
		Audience: jwt.Audience{s.cfg.ExtendedJWTExpectAudience},
		Time:     timeNow(),
	}, 0)
	if err != nil {
		return nil, fmt.Errorf("failed to validate JWT: %w", err)
	}

	if err := s.validateClientIdClaim(ctx, claims); err != nil {
		return nil, err
	}

	return &claims, nil
}
// validateClientIdClaim checks the token's "client_id" claim. The claim must
// be non-empty and must name an external service that the OAuth2 server can
// look up. A lookup failure is wrapped rather than discarded, so callers and
// logs can tell an unknown client apart from an infrastructure error
// (errors.Is/As work on the result).
func (s *ExtendedJWT) validateClientIdClaim(ctx context.Context, claims ExtendedJWTClaims) error {
	if claims.ClientID == "" {
		return fmt.Errorf("missing 'client_id' claim")
	}

	if _, err := s.oauthServer.GetExternalService(ctx, claims.ClientID); err != nil {
		return fmt.Errorf("invalid 'client_id' claim: %s: %w", claims.ClientID, err)
	}

	return nil
}
// getDefaultOrgID returns the organization that extended JWT identities are
// bound to. This is cfg.AutoAssignOrgId when auto-assignment is enabled with a
// positive ID, and organization 1 otherwise.
func (s *ExtendedJWT) getDefaultOrgID() int64 {
	if !s.cfg.AutoAssignOrg || s.cfg.AutoAssignOrgId <= 0 {
		return 1
	}
	return int64(s.cfg.AutoAssignOrgId)
}