diff --git a/conf/defaults.ini b/conf/defaults.ini index b3cc15b3586..4e33f1793ea 100644 --- a/conf/defaults.ini +++ b/conf/defaults.ini @@ -794,6 +794,7 @@ jwk_set_file = cache_ttl = 60m expect_claims = {} key_file = +key_id = role_attribute_path = role_attribute_strict = false auto_sign_up = false diff --git a/conf/sample.ini b/conf/sample.ini index 265cb8a9b9a..cf3cb95d4ed 100644 --- a/conf/sample.ini +++ b/conf/sample.ini @@ -746,6 +746,8 @@ ;cache_ttl = 60m ;expect_claims = {"aud": ["foo", "bar"]} ;key_file = /path/to/key/file +# Use in conjunction with key_file in case the JWT token's header specifies a key ID in "kid" field +;key_id = some-key-id ;role_attribute_path = ;role_attribute_strict = false ;auto_sign_up = false diff --git a/docs/sources/setup-grafana/configure-security/configure-authentication/jwt/index.md b/docs/sources/setup-grafana/configure-security/configure-authentication/jwt/index.md index 08ebde16948..39de9c4f2a7 100644 --- a/docs/sources/setup-grafana/configure-security/configure-authentication/jwt/index.md +++ b/docs/sources/setup-grafana/configure-security/configure-authentication/jwt/index.md @@ -148,6 +148,12 @@ PEM-encoded key file in PKIX, PKCS #1, PKCS #8 or SEC 1 format. key_file = /path/to/key.pem ``` +If the JWT token's header specifies a `kid` (Key ID), then the Key ID must be set using the `key_id` configuration option. + +```ini +key_id = my-key-id +``` + ## Validate claims By default, only `"exp"`, `"nbf"` and `"iat"` claims are validated. diff --git a/pkg/services/auth/jwt/auth_test.go b/pkg/services/auth/jwt/auth_test.go index 95779e39c8c..39f3de0bd4f 100644 --- a/pkg/services/auth/jwt/auth_test.go +++ b/pkg/services/auth/jwt/auth_test.go @@ -46,7 +46,7 @@ func TestVerifyUsingPKIXPublicKeyFile(t *testing.T) { scenario(t, "verifies a token", func(t *testing.T, sc scenarioContext) { token := sign(t, key, jwt.Claims{ Subject: subject, - }) + }, nil) verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) require.NoError(t, err) assert.Equal(t, verifiedClaims["sub"], subject) @@ -55,10 +55,23 @@ func TestVerifyUsingPKIXPublicKeyFile(t *testing.T) { scenario(t, "rejects a token signed by unknown key", func(t *testing.T, sc scenarioContext) { token := sign(t, unknownKey, jwt.Claims{ Subject: subject, - }) + }, nil) _, err := sc.authJWTSvc.Verify(sc.ctx, token) require.Error(t, err) }, configurePKIXPublicKeyFile) + + publicKeyID := "some-key-id" + scenario(t, "verifies a token with a specified kid", func(t *testing.T, sc scenarioContext) { + token := sign(t, key, jwt.Claims{ + Subject: subject, + }, (&jose.SignerOptions{}).WithHeader("kid", publicKeyID)) + verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) + require.NoError(t, err) + assert.Equal(t, verifiedClaims["sub"], subject) + }, configurePKIXPublicKeyFile, func(t *testing.T, cfg *setting.Cfg) { + t.Helper() + cfg.JWTAuthKeyID = publicKeyID + }) } func TestVerifyUsingJWKSetFile(t *testing.T) { @@ -80,21 +93,21 @@ func TestVerifyUsingJWKSetFile(t *testing.T) { } scenario(t, "verifies a token signed with a key from the set", func(t *testing.T, sc scenarioContext) { - token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}) + token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}, nil) verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) require.NoError(t, err) assert.Equal(t, verifiedClaims["sub"], subject) }, configure) scenario(t, "verifies a token signed with another key from the set", func(t *testing.T, sc scenarioContext) { - token := sign(t, &jwKeys[1], jwt.Claims{Subject: subject}) + token := sign(t, &jwKeys[1], jwt.Claims{Subject: subject}, nil) verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) require.NoError(t, err) assert.Equal(t, verifiedClaims["sub"], subject) }, configure) scenario(t, "rejects a token signed with a key not from the set", func(t *testing.T, sc scenarioContext) { - token := sign(t, jwKeys[2], jwt.Claims{Subject: subject}) + token := sign(t, jwKeys[2], jwt.Claims{Subject: subject}, nil) _, err := sc.authJWTSvc.Verify(sc.ctx, token) require.Error(t, err) }, configure) @@ -126,21 +139,21 @@ func TestVerifyUsingJWKSetURL(t *testing.T) { }) jwkHTTPScenario(t, "verifies a token signed with a key from the set", func(t *testing.T, sc scenarioContext) { - token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}) + token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}, nil) verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) require.NoError(t, err) assert.Equal(t, verifiedClaims["sub"], subject) }) jwkHTTPScenario(t, "verifies a token signed with another key from the set", func(t *testing.T, sc scenarioContext) { - token := sign(t, &jwKeys[1], jwt.Claims{Subject: subject}) + token := sign(t, &jwKeys[1], jwt.Claims{Subject: subject}, nil) verifiedClaims, err := sc.authJWTSvc.Verify(sc.ctx, token) require.NoError(t, err) assert.Equal(t, verifiedClaims["sub"], subject) }) jwkHTTPScenario(t, "rejects a token signed with a key not from the set", func(t *testing.T, sc scenarioContext) { - token := sign(t, jwKeys[2], jwt.Claims{Subject: subject}) + token := sign(t, jwKeys[2], jwt.Claims{Subject: subject}, nil) _, err := sc.authJWTSvc.Verify(sc.ctx, token) require.Error(t, err) }) @@ -149,7 +162,7 @@ func TestVerifyUsingJWKSetURL(t *testing.T) { func TestCachingJWKHTTPResponse(t *testing.T) { jwkCachingScenario(t, "caches the jwk response", func(t *testing.T, sc cachingScenarioContext) { for i := 0; i < 5; i++ { - token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}) + token := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}, nil) _, err := sc.authJWTSvc.Verify(sc.ctx, token) require.NoError(t, err, "verify call %d", i+1) } @@ -160,8 +173,8 @@ func TestCachingJWKHTTPResponse(t *testing.T) { jwkCachingScenario(t, "respects TTL setting (while cached)", func(t *testing.T, sc cachingScenarioContext) { var err error - token0 := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}) - token1 := sign(t, &jwKeys[1], jwt.Claims{Subject: subject}) + token0 := sign(t, &jwKeys[0], jwt.Claims{Subject: subject}, nil) + token1 := sign(t, &jwKeys[1], jwt.Claims{Subject: subject}, nil) _, err = sc.authJWTSvc.Verify(sc.ctx, token0) require.NoError(t, err) @@ -176,7 +189,7 @@ func TestCachingJWKHTTPResponse(t *testing.T) { jwkCachingScenario(t, "does not cache the response when TTL is zero", func(t *testing.T, sc cachingScenarioContext) { for i := 0; i < 2; i++ { - _, err := sc.authJWTSvc.Verify(sc.ctx, sign(t, &jwKeys[i], jwt.Claims{Subject: subject})) + _, err := sc.authJWTSvc.Verify(sc.ctx, sign(t, &jwKeys[i], jwt.Claims{Subject: subject}, nil)) require.NoError(t, err, "verify call %d", i+1) } @@ -198,8 +211,8 @@ func TestClaimValidation(t *testing.T) { key := rsaKeys[0] scenario(t, "validates iss field for equality", func(t *testing.T, sc scenarioContext) { - tokenValid := sign(t, key, jwt.Claims{Issuer: "http://foo"}) - tokenInvalid := sign(t, key, jwt.Claims{Issuer: "http://bar"}) + tokenValid := sign(t, key, jwt.Claims{Issuer: "http://foo"}, nil) + tokenInvalid := sign(t, key, jwt.Claims{Issuer: "http://bar"}, nil) _, err := sc.authJWTSvc.Verify(sc.ctx, tokenValid) require.NoError(t, err) @@ -213,8 +226,8 @@ func TestClaimValidation(t *testing.T) { scenario(t, "validates sub field for equality", func(t *testing.T, sc scenarioContext) { var err error - tokenValid := sign(t, key, jwt.Claims{Subject: "foo"}) - tokenInvalid := sign(t, key, jwt.Claims{Subject: "bar"}) + tokenValid := sign(t, key, jwt.Claims{Subject: "foo"}, nil) + tokenInvalid := sign(t, key, jwt.Claims{Subject: "bar"}, nil) _, err = sc.authJWTSvc.Verify(sc.ctx, tokenValid) require.NoError(t, err) @@ -228,19 +241,19 @@ func TestClaimValidation(t *testing.T) { scenario(t, "validates aud field for inclusion", func(t *testing.T, sc scenarioContext) { var err error - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"bar", "foo"}})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"bar", "foo"}}, nil)) require.NoError(t, err) - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"foo", "bar", "baz"}})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"foo", "bar", "baz"}}, nil)) require.NoError(t, err) - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"foo"}})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"foo"}}, nil)) require.Error(t, err) - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"bar", "baz"}})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"bar", "baz"}}, nil)) require.Error(t, err) - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"baz"}})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Audience: []string{"baz"}}, nil)) require.Error(t, err) }, configurePKIXPublicKeyFile, func(t *testing.T, cfg *setting.Cfg) { cfg.JWTAuthExpectClaims = `{"aud": ["foo", "bar"]}` @@ -249,19 +262,19 @@ func TestClaimValidation(t *testing.T) { scenario(t, "validates non-registered (custom) claims for equality", func(t *testing.T, sc scenarioContext) { var err error - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo", "my-number": 123})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo", "my-number": 123}, nil)) require.NoError(t, err) - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "bar", "my-number": 123})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "bar", "my-number": 123}, nil)) require.Error(t, err) - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo", "my-number": 100})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo", "my-number": 100}, nil)) require.Error(t, err) - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo"})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-str": "foo"}, nil)) require.Error(t, err) - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-number": 123})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, map[string]interface{}{"my-number": 123}, nil)) require.Error(t, err) }, configurePKIXPublicKeyFile, func(t *testing.T, cfg *setting.Cfg) { cfg.JWTAuthExpectClaims = `{"my-str": "foo", "my-number": 123}` @@ -270,30 +283,30 @@ func TestClaimValidation(t *testing.T) { scenario(t, "validates exp claim of the token", func(t *testing.T, sc scenarioContext) { var err error - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour))})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour))}, nil)) require.NoError(t, err) - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Expiry: jwt.NewNumericDate(time.Now().Add(-time.Hour))})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{Expiry: jwt.NewNumericDate(time.Now().Add(-time.Hour))}, nil)) require.Error(t, err) }, configurePKIXPublicKeyFile) scenario(t, "validates nbf claim of the token", func(t *testing.T, sc scenarioContext) { var err error - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Hour))})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Hour))}, nil)) require.NoError(t, err) - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Hour))})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{NotBefore: jwt.NewNumericDate(time.Now().Add(time.Hour))}, nil)) require.Error(t, err) }, configurePKIXPublicKeyFile) scenario(t, "validates iat claim of the token", func(t *testing.T, sc scenarioContext) { var err error - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Hour))})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Hour))}, nil)) require.NoError(t, err) - _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour))})) + _, err = sc.authJWTSvc.Verify(sc.ctx, sign(t, key, jwt.Claims{IssuedAt: jwt.NewNumericDate(time.Now().Add(time.Hour))}, nil)) require.Error(t, err) }, configurePKIXPublicKeyFile) } @@ -360,7 +373,7 @@ func TestBase64Paddings(t *testing.T) { scenario(t, "verifies a token with base64 padding (non compliant rfc7515#section-2 but accepted)", func(t *testing.T, sc scenarioContext) { token := sign(t, key, jwt.Claims{ Subject: subject, - }) + }, nil) var tokenParts []string for i, part := range strings.Split(token, ".") { // Create parts with different padding numbers to test multiple cases. diff --git a/pkg/services/auth/jwt/key_sets.go b/pkg/services/auth/jwt/key_sets.go index 0181072c1d5..19562790d85 100644 --- a/pkg/services/auth/jwt/key_sets.go +++ b/pkg/services/auth/jwt/key_sets.go @@ -125,7 +125,7 @@ func (s *AuthService) initKeySet() error { s.keySet = &keySetJWKS{ jose.JSONWebKeySet{ - Keys: []jose.JSONWebKey{{Key: key}}, + Keys: []jose.JSONWebKey{{Key: key, KeyID: s.Cfg.JWTAuthKeyID}}, }, } } else if keyFilePath := s.Cfg.JWTAuthJWKSetFile; keyFilePath != "" { diff --git a/pkg/services/auth/jwt/signing_test.go b/pkg/services/auth/jwt/signing_test.go index a1088a58dd2..5837c020acb 100644 --- a/pkg/services/auth/jwt/signing_test.go +++ b/pkg/services/auth/jwt/signing_test.go @@ -10,10 +10,13 @@ import ( type noneSigner struct{} -func sign(t *testing.T, key interface{}, claims interface{}) string { +func sign(t *testing.T, key interface{}, claims interface{}, opts *jose.SignerOptions) string { t.Helper() - sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.PS512, Key: key}, (&jose.SignerOptions{}).WithType("JWT")) + if opts == nil { + opts = &jose.SignerOptions{} + } + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.PS512, Key: key}, (opts).WithType("JWT")) require.NoError(t, err) token, err := jwt.Signed(sig).Claims(claims).CompactSerialize() require.NoError(t, err) diff --git a/pkg/setting/setting.go b/pkg/setting/setting.go index bc447dcb35c..e9ffbf79d82 100644 --- a/pkg/setting/setting.go +++ b/pkg/setting/setting.go @@ -318,6 +318,7 @@ type Cfg struct { JWTAuthJWKSetURL string JWTAuthCacheTTL time.Duration JWTAuthKeyFile string + JWTAuthKeyID string JWTAuthJWKSetFile string JWTAuthAutoSignUp bool JWTAuthRoleAttributePath string @@ -1597,6 +1598,7 @@ func readAuthSettings(iniFile *ini.File, cfg *Cfg) (err error) { cfg.JWTAuthJWKSetURL = valueAsString(authJWT, "jwk_set_url", "") cfg.JWTAuthCacheTTL = authJWT.Key("cache_ttl").MustDuration(time.Minute * 60) cfg.JWTAuthKeyFile = valueAsString(authJWT, "key_file", "") + cfg.JWTAuthKeyID = authJWT.Key("key_id").MustString("") cfg.JWTAuthJWKSetFile = valueAsString(authJWT, "jwk_set_file", "") cfg.JWTAuthAutoSignUp = authJWT.Key("auto_sign_up").MustBool(false) cfg.JWTAuthRoleAttributePath = valueAsString(authJWT, "role_attribute_path", "")