diff --git a/pkg/services/authn/authn.go b/pkg/services/authn/authn.go index 500507620ee..2f7da220426 100644 --- a/pkg/services/authn/authn.go +++ b/pkg/services/authn/authn.go @@ -22,15 +22,16 @@ import ( ) const ( - ClientAPIKey = "auth.client.api-key" // #nosec G101 - ClientAnonymous = "auth.client.anonymous" - ClientBasic = "auth.client.basic" - ClientJWT = "auth.client.jwt" - ClientRender = "auth.client.render" - ClientSession = "auth.client.session" - ClientForm = "auth.client.form" - ClientProxy = "auth.client.proxy" - ClientSAML = "auth.client.saml" + ClientAPIKey = "auth.client.api-key" // #nosec G101 + ClientAnonymous = "auth.client.anonymous" + ClientBasic = "auth.client.basic" + ClientJWT = "auth.client.jwt" + ClientExtendedJWT = "auth.client.extended-jwt" + ClientRender = "auth.client.render" + ClientSession = "auth.client.session" + ClientForm = "auth.client.form" + ClientProxy = "auth.client.proxy" + ClientSAML = "auth.client.saml" ) const ( diff --git a/pkg/services/authn/authnimpl/service.go b/pkg/services/authn/authnimpl/service.go index 43d8ea210e7..6cd1825e1de 100644 --- a/pkg/services/authn/authnimpl/service.go +++ b/pkg/services/authn/authnimpl/service.go @@ -31,6 +31,7 @@ import ( "github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/quota" "github.com/grafana/grafana/pkg/services/rendering" + "github.com/grafana/grafana/pkg/services/signingkeys" "github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util/errutil" @@ -63,6 +64,7 @@ func ProvideService( features *featuremgmt.FeatureManager, oauthTokenService oauthtoken.OAuthTokenService, socialService social.Service, cache *remotecache.RemoteCache, ldapService service.LDAP, registerer prometheus.Registerer, + signingKeysService signingkeys.Service, ) authn.Service { s := &Service{ log: log.New("authn.service"), @@ -128,6 +130,10 @@ func ProvideService( s.RegisterClient(clients.ProvideJWT(jwtService, cfg)) } + if s.cfg.ExtendedJWTAuthEnabled && features.IsEnabled(featuremgmt.FlagExternalServiceAuth) { + s.RegisterClient(clients.ProvideExtendedJWT(userService, cfg, signingKeysService)) + } + for name := range socialService.GetOAuthProviders() { oauthCfg := socialService.GetOAuthInfoProvider(name) if oauthCfg != nil && oauthCfg.Enabled { diff --git a/pkg/services/authn/clients/ext_jwt.go b/pkg/services/authn/clients/ext_jwt.go new file mode 100644 index 00000000000..df44b60570e --- /dev/null +++ b/pkg/services/authn/clients/ext_jwt.go @@ -0,0 +1,228 @@ +package clients + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/go-jose/go-jose/v3/jwt" + "golang.org/x/exp/slices" + + "github.com/grafana/grafana/pkg/infra/log" + "github.com/grafana/grafana/pkg/services/authn" + "github.com/grafana/grafana/pkg/services/signingkeys" + "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/setting" +) + +var _ authn.Client = new(ExtendedJWT) + +var ( + acceptedSigningMethods = []string{"RS256", "ES256"} + timeNow = time.Now +) + +const ( + rfc9068ShortMediaType = "at+jwt" + rfc9068MediaType = "application/at+jwt" +) + +func ProvideExtendedJWT(userService user.Service, cfg *setting.Cfg, signingKeys signingkeys.Service) *ExtendedJWT { + return &ExtendedJWT{ + cfg: cfg, + log: log.New(authn.ClientExtendedJWT), + userService: userService, + signingKeys: signingKeys, + } +} + +type ExtendedJWT struct { + cfg *setting.Cfg + log log.Logger + userService user.Service + signingKeys signingkeys.Service +} + +type ExtendedJWTClaims struct { + jwt.Claims + ClientID string `json:"client_id"` + Groups []string `json:"groups"` + Email string `json:"email"` + Name string `json:"name"` + Login string `json:"login"` + Scopes []string `json:"scope"` + Entitlements map[string][]string `json:"entitlements"` +} + +func (s *ExtendedJWT) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) { + jwtToken := s.retrieveToken(r.HTTPRequest) + + claims, err := s.verifyRFC9068Token(ctx, jwtToken) + if err != nil { + s.log.Error("Failed to verify JWT", "error", err) + return nil, errJWTInvalid.Errorf("Failed to verify JWT: %w", err) + } + + // user:id:18 + userID, err := strconv.ParseInt(strings.TrimPrefix(claims.Subject, fmt.Sprintf("%s:id:", authn.NamespaceUser)), 10, 64) + if err != nil { + s.log.Error("Failed to parse sub", "error", err) + return nil, errJWTInvalid.Errorf("Failed to parse sub: %w", err) + } + + // FIXME: support multiple organizations + defaultOrgID := s.getDefaultOrgID() + if r.OrgID != defaultOrgID { + s.log.Error("Failed to verify the Organization: OrgID is not the default") + return nil, errJWTInvalid.Errorf("Failed to verify the Organization. Only the default org is supported") + } + + signedInUser, err := s.userService.GetSignedInUserWithCacheCtx(ctx, &user.GetSignedInUserQuery{OrgID: defaultOrgID, UserID: userID}) + if err != nil { + s.log.Error("Failed to get user", "error", err) + return nil, errJWTInvalid.Errorf("Failed to get user: %w", err) + } + + if signedInUser.Permissions == nil { + signedInUser.Permissions = make(map[int64]map[string][]string) + } + + if len(claims.Entitlements) == 0 { + s.log.Error("Entitlements claim is missing") + return nil, errJWTInvalid.Errorf("Entitlements claim is missing") + } + + signedInUser.Permissions[s.getDefaultOrgID()] = claims.Entitlements + + return authn.IdentityFromSignedInUser(authn.NamespacedID(authn.NamespaceUser, signedInUser.UserID), signedInUser, authn.ClientParams{SyncPermissions: false}), nil +} + +func (s *ExtendedJWT) Test(ctx context.Context, r *authn.Request) bool { + if !s.cfg.ExtendedJWTAuthEnabled { + return false + } + + rawToken := s.retrieveToken(r.HTTPRequest) + if rawToken == "" { + return false + } + + parsedToken, err := jwt.ParseSigned(rawToken) + if err != nil { + return false + } + + var claims jwt.Claims + if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil { + return false + } + + return claims.Issuer == s.cfg.ExtendedJWTExpectIssuer +} + +func (s *ExtendedJWT) Name() string { + return authn.ClientExtendedJWT +} + +func (s *ExtendedJWT) Priority() uint { + // This client should come before the normal JWT client, because it is more specific, because of the Issuer check + return 15 +} + +// retrieveToken retrieves the JWT token from the request. +func (s *ExtendedJWT) retrieveToken(httpRequest *http.Request) string { + jwtToken := httpRequest.Header.Get("Authorization") + + // Strip the 'Bearer' prefix if it exists. + return strings.TrimPrefix(jwtToken, "Bearer ") +} + +// verifyRFC9068Token verifies the token against the RFC 9068 specification. +func (s *ExtendedJWT) verifyRFC9068Token(ctx context.Context, rawToken string) (*ExtendedJWTClaims, error) { + parsedToken, err := jwt.ParseSigned(rawToken) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT: %w", err) + } + + if len(parsedToken.Headers) != 1 { + return nil, fmt.Errorf("only one header supported, got %d", len(parsedToken.Headers)) + } + + parsedHeader := parsedToken.Headers[0] + + typeHeader := parsedHeader.ExtraHeaders["typ"] + if typeHeader == nil { + return nil, fmt.Errorf("missing 'typ' field from the header") + } + + jwtType := strings.ToLower(typeHeader.(string)) + if jwtType != rfc9068ShortMediaType && jwtType != rfc9068MediaType { + return nil, fmt.Errorf("invalid JWT type: %s", jwtType) + } + + if !slices.Contains(acceptedSigningMethods, parsedHeader.Algorithm) { + return nil, fmt.Errorf("invalid algorithm: %s. Accepted algorithms: %s", parsedHeader.Algorithm, strings.Join(acceptedSigningMethods, ", ")) + } + + var claims ExtendedJWTClaims + err = parsedToken.Claims(s.signingKeys.GetServerPublicKey(), &claims) + if err != nil { + return nil, fmt.Errorf("failed to verify the signature: %w", err) + } + + if claims.Expiry == nil { + return nil, fmt.Errorf("missing 'exp' claim") + } + + if claims.ID == "" { + return nil, fmt.Errorf("missing 'jti' claim") + } + + if claims.Subject == "" { + return nil, fmt.Errorf("missing 'sub' claim") + } + + if claims.IssuedAt == nil { + return nil, fmt.Errorf("missing 'iat' claim") + } + + err = claims.ValidateWithLeeway(jwt.Expected{ + Issuer: s.cfg.ExtendedJWTExpectIssuer, + Audience: jwt.Audience{s.cfg.ExtendedJWTExpectAudience}, + Time: timeNow(), + }, 0) + + if err != nil { + return nil, fmt.Errorf("failed to validate JWT: %w", err) + } + + if err := s.validateClientIdClaim(ctx, claims); err != nil { + return nil, err + } + + return &claims, nil +} + +func (s *ExtendedJWT) validateClientIdClaim(ctx context.Context, claims ExtendedJWTClaims) error { + if claims.ClientID == "" { + return fmt.Errorf("missing 'client_id' claim") + } + + // TODO: Implement the validation for client_id when the OAuth server is ready. + // if _, err := s.oauthService.GetExternalService(ctx, clientId); err != nil { + // return fmt.Errorf("invalid 'client_id' claim: %s", clientIdClaim) + // } + + return nil +} + +func (s *ExtendedJWT) getDefaultOrgID() int64 { + orgID := int64(1) + if s.cfg.AutoAssignOrg && s.cfg.AutoAssignOrgId > 0 { + orgID = int64(s.cfg.AutoAssignOrgId) + } + return orgID +} diff --git a/pkg/services/authn/clients/ext_jwt_test.go b/pkg/services/authn/clients/ext_jwt_test.go new file mode 100644 index 00000000000..480d32d53f3 --- /dev/null +++ b/pkg/services/authn/clients/ext_jwt_test.go @@ -0,0 +1,535 @@ +package clients + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "fmt" + "net/http" + "testing" + "time" + + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" + + "github.com/grafana/grafana/pkg/models/roletype" + "github.com/grafana/grafana/pkg/services/authn" + "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/signingkeys/signingkeystest" + "github.com/grafana/grafana/pkg/services/user" + "github.com/grafana/grafana/pkg/services/user/usertest" + "github.com/grafana/grafana/pkg/setting" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + validPayload = ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + Entitlements: map[string][]string{ + "dashboards:create": { + "folders:uid:general", + }, + "folders:read": { + "folders:uid:general", + }, + "datasources:explore": nil, + "datasources.insights:read": {}, + }, + } + pk, _ = rsa.GenerateKey(rand.Reader, 4096) +) + +func TestExtendedJWTTest(t *testing.T) { + type testCase struct { + name string + cfg *setting.Cfg + authHeaderFunc func() string + want bool + } + + testCases := []testCase{ + { + name: "should return false when extended jwt is disabled", + cfg: &setting.Cfg{ + ExtendedJWTAuthEnabled: false, + }, + authHeaderFunc: func() string { return "eyJ" }, + want: false, + }, + { + name: "should return true when Authorization header contains Bearer prefix", + cfg: nil, + authHeaderFunc: func() string { return "Bearer " + generateToken(validPayload, pk, jose.RS256) }, + want: true, + }, + { + name: "should return true when Authorization header only contains the token", + cfg: nil, + authHeaderFunc: func() string { return generateToken(validPayload, pk, jose.RS256) }, + want: true, + }, + { + name: "should return false when Authorization header is empty", + cfg: nil, + authHeaderFunc: func() string { return "" }, + want: false, + }, + { + name: "should return false when jwt.ParseSigned fails", + cfg: nil, + authHeaderFunc: func() string { return "invalid token" }, + want: false, + }, + { + name: "should return false when the issuer does not match the configured issuer", + cfg: &setting.Cfg{ + ExtendedJWTExpectIssuer: "http://localhost:3000", + }, + authHeaderFunc: func() string { + payload := validPayload + payload.Issuer = "http://unknown-issuer" + return generateToken(payload, pk, jose.RS256) + }, + want: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + extJwtClient := setupTestCtx(t, nil, tc.cfg) + + validHTTPReq := &http.Request{ + Header: map[string][]string{ + "Authorization": {tc.authHeaderFunc()}, + }, + } + + actual := extJwtClient.Test(context.Background(), &authn.Request{ + HTTPRequest: validHTTPReq, + Resp: nil, + }) + + assert.Equal(t, tc.want, actual) + }) + } +} + +func TestExtendedJWTAuthenticate(t *testing.T) { + type testCase struct { + name string + payload ExtendedJWTClaims + orgID int64 + want *authn.Identity + userSvcSetup func(userSvc *usertest.FakeUserService) + wantErr bool + } + testCases := []testCase{ + { + name: "successful authentication", + payload: validPayload, + orgID: 1, + userSvcSetup: func(userSvc *usertest.FakeUserService) { + userSvc.ExpectedSignedInUser = &user.SignedInUser{ + UserID: 2, + OrgID: 1, + OrgRole: roletype.RoleAdmin, + Name: "John Doe", + Email: "johndoe@grafana.com", + Login: "johndoe", + } + }, + want: &authn.Identity{ + OrgID: 1, + OrgCount: 0, + OrgName: "", + OrgRoles: map[int64]roletype.RoleType{1: roletype.RoleAdmin}, + ID: "user:2", + Login: "johndoe", + Name: "John Doe", + Email: "johndoe@grafana.com", + IsGrafanaAdmin: boolPtr(false), + AuthModule: "", + AuthID: "", + IsDisabled: false, + HelpFlags1: 0, + Permissions: map[int64]map[string][]string{ + 1: { + "dashboards:create": { + "folders:uid:general", + }, + "folders:read": { + "folders:uid:general", + }, + "datasources:explore": nil, + "datasources.insights:read": []string{}, + }, + }, + ClientParams: authn.ClientParams{ + SyncUser: false, + AllowSignUp: false, + FetchSyncedUser: false, + EnableDisabledUsers: false, + SyncOrgRoles: false, + SyncTeams: false, + SyncPermissions: false, + LookUpParams: login.UserLookupParams{ + UserID: nil, + Email: nil, + Login: nil, + }, + }, + }, + wantErr: false, + }, + { + name: "should return error when the user cannot be parsed from the Subject claim", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + orgID: 1, + want: nil, + wantErr: true, + }, + { + name: "should return error when the OrgId is not the ID of the default org", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + orgID: 0, + want: nil, + wantErr: true, + }, + { + name: "should return error when the user cannot be found", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + orgID: 1, + want: nil, + userSvcSetup: func(userSvc *usertest.FakeUserService) { + userSvc.ExpectedError = user.ErrUserNotFound + }, + wantErr: true, + }, + { + name: "should return error when entitlements claim is missing", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + orgID: 1, + want: nil, + wantErr: true, + }, + // { + // name: "should return error when the entitlements are not in the correct format", + // payload: ExtendedJWTClaims{ + // Claims: jwt.Claims{ + // Issuer: "http://localhost:3000", + // Subject: "user:id:2", + // Audience: jwt.Audience{"http://localhost:3000"}, + // ID: "1234567890", + // Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + // IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + // }, + // ClientID: "grafana", + // Scopes: []string{"profile", "groups"}, + // Entitlements: []string{"dashboards:create", "folders:read"}, + // }, + // orgID: 1, + // want: nil, + // wantErr: true, + // }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + userSvc := &usertest.FakeUserService{} + extJwtClient := setupTestCtx(t, userSvc, nil) + if tc.userSvcSetup != nil { + tc.userSvcSetup(userSvc) + } + + validHTTPReq := &http.Request{ + Header: map[string][]string{ + "Authorization": {generateToken(tc.payload, pk, jose.RS256)}, + }, + } + + mockTimeNow(time.Date(2023, 5, 2, 0, 1, 0, 0, time.UTC)) + + id, err := extJwtClient.Authenticate(context.Background(), &authn.Request{ + OrgID: tc.orgID, + HTTPRequest: validHTTPReq, + Resp: nil, + }) + if tc.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.EqualValues(t, tc.want, id, fmt.Sprintf("%+v", id)) + } + }) + } +} + +// https://datatracker.ietf.org/doc/html/rfc9068#name-data-structure +func TestVerifyRFC9068TokenFailureScenarios(t *testing.T) { + type testCase struct { + name string + payload ExtendedJWTClaims + alg jose.SignatureAlgorithm + } + + testCases := []testCase{ + { + name: "missing iss", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + }, + { + name: "missing expiry", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + }, + { + name: "expired token", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + }, + { + name: "missing aud", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + }, + { + name: "wrong aud", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://some-other-host:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + }, + { + name: "missing sub", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + }, + { + name: "missing client_id", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + Scopes: []string{"profile", "groups"}, + }, + }, + { + name: "missing iat", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + }, + { + name: "iat later than current time", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 2, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + }, + { + name: "missing jti", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + }, + { + name: "unsupported alg", + payload: ExtendedJWTClaims{ + Claims: jwt.Claims{ + Issuer: "http://localhost:3000", + Subject: "user:id:2", + Audience: jwt.Audience{"http://localhost:3000"}, + ID: "1234567890", + Expiry: jwt.NewNumericDate(time.Date(2023, 5, 3, 0, 0, 0, 0, time.UTC)), + IssuedAt: jwt.NewNumericDate(time.Date(2023, 5, 2, 0, 0, 0, 0, time.UTC)), + }, + ClientID: "grafana", + Scopes: []string{"profile", "groups"}, + }, + alg: jose.RS384, + }, + } + + extJwtClient := setupTestCtx(t, nil, nil) + mockTimeNow(time.Date(2023, 5, 2, 0, 1, 0, 0, time.UTC)) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.alg == "" { + tc.alg = jose.RS256 + } + tokenToTest := generateToken(tc.payload, pk, tc.alg) + _, err := extJwtClient.verifyRFC9068Token(context.Background(), tokenToTest) + require.Error(t, err) + }) + } +} + +func setupTestCtx(t *testing.T, userSvc user.Service, cfg *setting.Cfg) *ExtendedJWT { + if cfg == nil { + cfg = &setting.Cfg{ + ExtendedJWTAuthEnabled: true, + ExtendedJWTExpectIssuer: "http://localhost:3000", + ExtendedJWTExpectAudience: "http://localhost:3000", + } + } + + signingKeysSvc := &signingkeystest.FakeSigningKeysService{} + signingKeysSvc.ExpectedServerPublicKey = &pk.PublicKey + + extJwtClient := ProvideExtendedJWT(userSvc, cfg, signingKeysSvc) + return extJwtClient +} + +func generateToken(payload ExtendedJWTClaims, signingKey interface{}, alg jose.SignatureAlgorithm) string { + signer, _ := jose.NewSigner(jose.SigningKey{Algorithm: alg, Key: signingKey}, &jose.SignerOptions{ + ExtraHeaders: map[jose.HeaderKey]interface{}{ + jose.HeaderType: "at+jwt", + }}) + + result, _ := jwt.Signed(signer).Claims(payload).CompactSerialize() + return result +} + +func mockTimeNow(timeSeed time.Time) { + timeNow = func() time.Time { + return timeSeed + } +} diff --git a/pkg/services/signingkeys/signingkeys.go b/pkg/services/signingkeys/signingkeys.go index c3e37e44339..35ef7bd3b3a 100644 --- a/pkg/services/signingkeys/signingkeys.go +++ b/pkg/services/signingkeys/signingkeys.go @@ -26,7 +26,9 @@ type Service interface { // GetPrivateKey returns the private key with the specified key ID GetPrivateKey(keyID string) (crypto.PrivateKey, error) // GetServerPrivateKey returns the private key used to sign tokens - GetServerPrivateKey() (crypto.PrivateKey, error) + GetServerPrivateKey() crypto.PrivateKey + // GetServerPublicKey returns the public key used to verify tokens + GetServerPublicKey() crypto.PublicKey // AddPrivateKey adds a private key to the service AddPrivateKey(keyID string, privateKey crypto.PrivateKey) error } diff --git a/pkg/services/signingkeys/signingkeysimpl/service.go b/pkg/services/signingkeys/signingkeysimpl/service.go index 30d0ec99612..52fb4b867e0 100644 --- a/pkg/services/signingkeys/signingkeysimpl/service.go +++ b/pkg/services/signingkeys/signingkeysimpl/service.go @@ -108,6 +108,15 @@ func (s *Service) AddPrivateKey(keyID string, privateKey crypto.PrivateKey) erro } // GetServerPrivateKey returns the private key used to sign tokens -func (s *Service) GetServerPrivateKey() (crypto.PrivateKey, error) { - return s.GetPrivateKey(serverPrivateKeyID) +func (s *Service) GetServerPrivateKey() crypto.PrivateKey { + // The server private key is always available + pk, _ := s.GetPrivateKey(serverPrivateKeyID) + return pk +} + +// GetServerPrivateKey returns the private key used to sign tokens +func (s *Service) GetServerPublicKey() crypto.PublicKey { + // The server public key is always available + publicKey, _ := s.GetPublicKey(serverPrivateKeyID) + return publicKey } diff --git a/pkg/services/signingkeys/signingkeystest/fake.go b/pkg/services/signingkeys/signingkeystest/fake.go index 4e9163c2970..f3f1c4648e3 100644 --- a/pkg/services/signingkeys/signingkeystest/fake.go +++ b/pkg/services/signingkeys/signingkeystest/fake.go @@ -11,6 +11,7 @@ type FakeSigningKeysService struct { ExpectedJSONWebKey jose.JSONWebKey ExpectedKeys map[string]crypto.Signer ExpectedServerPrivateKey crypto.PrivateKey + ExpectedServerPublicKey crypto.PublicKey ExpectedError error } @@ -34,8 +35,13 @@ func (s *FakeSigningKeysService) GetPrivateKey(keyID string) (crypto.PrivateKey, } // GetServerPrivateKey returns the private key used to sign tokens -func (s *FakeSigningKeysService) GetServerPrivateKey() (crypto.PrivateKey, error) { - return s.ExpectedServerPrivateKey, s.ExpectedError +func (s *FakeSigningKeysService) GetServerPrivateKey() crypto.PrivateKey { + return s.ExpectedServerPrivateKey +} + +// GetServerPublicKey returns the public key used to verify tokens +func (s *FakeSigningKeysService) GetServerPublicKey() crypto.PublicKey { + return s.ExpectedServerPublicKey } // AddPrivateKey adds a private key to the service diff --git a/pkg/setting/setting.go b/pkg/setting/setting.go index 7a3ac78fb8a..86af80fd232 100644 --- a/pkg/setting/setting.go +++ b/pkg/setting/setting.go @@ -315,6 +315,11 @@ type Cfg struct { JWTAuthAllowAssignGrafanaAdmin bool JWTAuthSkipOrgRoleSync bool + // Extended JWT Auth + ExtendedJWTAuthEnabled bool + ExtendedJWTExpectIssuer string + ExtendedJWTExpectAudience string + // Dataproxy SendUserHeader bool DataProxyLogging bool @@ -1542,6 +1547,13 @@ func readAuthSettings(iniFile *ini.File, cfg *Cfg) (err error) { cfg.JWTAuthAllowAssignGrafanaAdmin = authJWT.Key("allow_assign_grafana_admin").MustBool(false) cfg.JWTAuthSkipOrgRoleSync = authJWT.Key("skip_org_role_sync").MustBool(false) + // Extended JWT auth + authExtendedJWT := iniFile.Section("auth.extended_jwt") + cfg.ExtendedJWTAuthEnabled = authExtendedJWT.Key("enabled").MustBool(false) + cfg.ExtendedJWTExpectAudience = authExtendedJWT.Key("expect_audience").MustString("") + cfg.ExtendedJWTExpectIssuer = authExtendedJWT.Key("expect_issuer").MustString("") + + // Auth Proxy authProxy := iniFile.Section("auth.proxy") cfg.AuthProxyEnabled = authProxy.Key("enabled").MustBool(false)