grafana/pkg/services/auth/jwt/key_sets.go
Jo 6f62d970e3
JWT Authentication: Add support for specifying groups in auth.jwt for teamsync (#82175)
* merge JSON search logic

* document public methods

* improve test coverage

* use separate JWT setting struct

* correct use of cfg.JWTAuth

* add group tests

* fix DynMap typing

* add settings to default ini

* add groups option to devenv path

* fix test

* lint

* revert jwt-proxy change

* remove redundant check

* fix parallel test
2024-02-09 16:35:58 +01:00

269 lines
6.6 KiB
Go

package jwt
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
jose "github.com/go-jose/go-jose/v3"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/setting"
)
// Sentinel errors reported while validating the key set configuration and
// loading the JWT verification keys.
var (
	// ErrFailedToParsePemFile is returned when the configured key file does
	// not contain a decodable PEM block.
	ErrFailedToParsePemFile = errors.New("failed to parse pem-encoded file")

	// ErrKeySetIsNotConfigured is returned when none of key_file,
	// jwk_set_file or jwk_set_url is set.
	ErrKeySetIsNotConfigured = errors.New("key set for jwt verification is not configured")

	// ErrKeySetConfigurationAmbiguous is returned when more than one of
	// key_file, jwk_set_file or jwk_set_url is set.
	ErrKeySetConfigurationAmbiguous = errors.New("key set configuration is ambiguous: you should set either key_file, jwk_set_file or jwk_set_url")

	// ErrJWTSetURLMustHaveHTTPSScheme is returned when jwk_set_url does not
	// use the https scheme outside of the dev environment. The message names
	// the option as it is spelled in the configuration (jwk_set_url).
	ErrJWTSetURLMustHaveHTTPSScheme = errors.New("jwk_set_url must have https scheme")
)
// keySet is a source of JSON Web Keys that can be looked up by key ID.
// It is implemented in this file by keySetJWKS (in-memory keys) and
// keySetHTTP (keys fetched from a URL).
type keySet interface {
// Key returns the keys registered under kid, or an error when the
// underlying key set could not be obtained.
Key(ctx context.Context, kid string) ([]jose.JSONWebKey, error)
}
// keySetJWKS is an in-memory key set backed by an embedded
// jose.JSONWebKeySet. It is built from a PEM key file or a JWK set file, and
// is also the decoded form of a key set fetched over HTTP.
type keySetJWKS struct {
jose.JSONWebKeySet
}
// keySetHTTP is a key set that is downloaded from a URL on demand and
// optionally kept in the remote cache between lookups.
type keySetHTTP struct {
// url is the endpoint the key set is requested from with an HTTP GET.
url string
// log receives debug and warning messages about fetching and caching.
log log.Logger
// client performs the HTTP request for the key set.
client *http.Client
// cache stores the raw JSON of the fetched key set.
cache *remotecache.RemoteCache
// cacheKey is the key under which the raw JSON is stored in cache.
cacheKey string
// cacheExpiration enables caching when greater than zero and acts as the
// upper bound for how long a fetched key set is cached.
cacheExpiration time.Duration
}
// checkKeySetConfiguration verifies that exactly one key source is configured.
//
// It returns ErrKeySetIsNotConfigured when none of KeyFile, JWKSetFile and
// JWKSetURL is set, ErrKeySetConfigurationAmbiguous when more than one is set,
// and nil otherwise.
func (s *AuthService) checkKeySetConfiguration() error {
	sources := [...]string{
		s.Cfg.JWTAuth.KeyFile,
		s.Cfg.JWTAuth.JWKSetFile,
		s.Cfg.JWTAuth.JWKSetURL,
	}

	configured := 0
	for _, source := range sources {
		if source != "" {
			configured++
		}
	}

	switch {
	case configured == 0:
		return ErrKeySetIsNotConfigured
	case configured > 1:
		return ErrKeySetConfigurationAmbiguous
	default:
		return nil
	}
}
// initKeySet populates s.keySet from the single key source selected in the
// JWT auth settings: a PEM key file, a JWK set file, or a JWK set URL.
//
// It returns the error from checkKeySetConfiguration when zero or several
// sources are configured, and otherwise whatever error the selected loader
// reports. s.keySet is only assigned when loading succeeds.
func (s *AuthService) initKeySet() error {
	if err := s.checkKeySetConfiguration(); err != nil {
		return err
	}

	// checkKeySetConfiguration guarantees that exactly one source is set.
	switch {
	case s.Cfg.JWTAuth.KeyFile != "":
		return s.initKeySetFromPEMFile(s.Cfg.JWTAuth.KeyFile)
	case s.Cfg.JWTAuth.JWKSetFile != "":
		return s.initKeySetFromJWKSetFile(s.Cfg.JWTAuth.JWKSetFile)
	case s.Cfg.JWTAuth.JWKSetURL != "":
		return s.initKeySetFromURL(s.Cfg.JWTAuth.JWKSetURL)
	}
	return nil
}

// initKeySetFromPEMFile reads the PEM file at keyFilePath and installs a key
// set holding its first key under the configured key ID.
func (s *AuthService) initKeySetFromPEMFile(keyFilePath string) error {
	// nolint:gosec
	// We can ignore the gosec G304 warning on this one because `keyFilePath` comes from grafana configuration file
	file, err := os.Open(keyFilePath)
	if err != nil {
		return err
	}
	defer func() {
		if err := file.Close(); err != nil {
			s.log.Warn("Failed to close file", "path", keyFilePath, "err", err)
		}
	}()

	data, err := io.ReadAll(file)
	if err != nil {
		return err
	}

	key, err := decodePEMVerificationKey(data)
	if err != nil {
		return err
	}

	s.keySet = &keySetJWKS{
		jose.JSONWebKeySet{
			Keys: []jose.JSONWebKey{{Key: key, KeyID: s.Cfg.JWTAuth.KeyID}},
		},
	}
	return nil
}

// decodePEMVerificationKey parses the first PEM block in data into a key.
//
// It returns ErrFailedToParsePemFile when data holds no PEM block, the x509
// parser's error when the block content is malformed, and a descriptive error
// for PEM block types other than the five handled below.
func decodePEMVerificationKey(data []byte) (any, error) {
	// Only the first block is considered; any trailing data is ignored.
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, ErrFailedToParsePemFile
	}

	var (
		key any
		err error
	)
	switch block.Type {
	case "PUBLIC KEY":
		key, err = x509.ParsePKIXPublicKey(block.Bytes)
	case "PRIVATE KEY":
		key, err = x509.ParsePKCS8PrivateKey(block.Bytes)
	case "RSA PUBLIC KEY":
		key, err = x509.ParsePKCS1PublicKey(block.Bytes)
	case "RSA PRIVATE KEY":
		key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
	case "EC PRIVATE KEY":
		key, err = x509.ParseECPrivateKey(block.Bytes)
	default:
		return nil, fmt.Errorf("unknown pem block type %q", block.Type)
	}
	if err != nil {
		// Return a literal nil so callers never see a typed-nil key.
		return nil, err
	}
	return key, nil
}

// initKeySetFromJWKSetFile decodes the JSON Web Key Set stored in the file at
// keyFilePath and installs it as the key set.
func (s *AuthService) initKeySetFromJWKSetFile(keyFilePath string) error {
	// nolint:gosec
	// We can ignore the gosec G304 warning on this one because `keyFilePath` comes from grafana configuration file
	file, err := os.Open(keyFilePath)
	if err != nil {
		return err
	}
	defer func() {
		if err := file.Close(); err != nil {
			s.log.Warn("Failed to close file", "path", keyFilePath, "err", err)
		}
	}()

	var jwks jose.JSONWebKeySet
	if err := json.NewDecoder(file).Decode(&jwks); err != nil {
		return err
	}

	s.keySet = &keySetJWKS{jwks}
	return nil
}

// initKeySetFromURL installs a key set that is fetched from urlStr on demand.
//
// It returns the url.Parse error for a malformed URL and
// ErrJWTSetURLMustHaveHTTPSScheme when the scheme is not https, unless the
// service runs in the dev environment.
func (s *AuthService) initKeySetFromURL(urlStr string) error {
	urlParsed, err := url.Parse(urlStr)
	if err != nil {
		return err
	}
	// Non-HTTPS URLs are accepted only in the dev environment.
	if urlParsed.Scheme != "https" && s.Cfg.Env != setting.Dev {
		return ErrJWTSetURLMustHaveHTTPSScheme
	}

	s.keySet = &keySetHTTP{
		url: urlStr,
		log: s.log,
		client: &http.Client{
			Transport: &http.Transport{
				TLSClientConfig: &tls.Config{
					Renegotiation: tls.RenegotiateFreelyAsClient,
				},
				Proxy: http.ProxyFromEnvironment,
				DialContext: (&net.Dialer{
					Timeout:   time.Second * 30,
					KeepAlive: 15 * time.Second,
				}).DialContext,
				TLSHandshakeTimeout:   10 * time.Second,
				ExpectContinueTimeout: 1 * time.Second,
				MaxIdleConns:          100,
				IdleConnTimeout:       30 * time.Second,
			},
			Timeout: time.Second * 30,
		},
		cacheKey:        fmt.Sprintf("auth-jwt:jwk-%s", urlStr),
		cacheExpiration: s.Cfg.JWTAuth.CacheTTL,
		cache:           s.RemoteCache,
	}
	return nil
}
// Key returns every key in the embedded set whose key ID equals keyID.
// The lookup is purely in-memory: ctx is unused and the error is always nil.
func (ks *keySetJWKS) Key(ctx context.Context, keyID string) ([]jose.JSONWebKey, error) {
	matches := ks.JSONWebKeySet.Key(keyID)
	return matches, nil
}
// getJWKS returns the JSON Web Key Set served at ks.url.
//
// When caching is enabled (ks.cacheExpiration > 0) a cached copy is returned
// if one exists and can be decoded. Otherwise the key set is fetched with an
// HTTP GET, decoded, and stored in the cache on a best-effort basis.
//
// An error is returned when the request cannot be built or performed, when the
// endpoint answers with a non-2xx status, or when the body is not a valid key
// set. Problems reading from or writing to the cache are logged and never
// fail the lookup.
func (ks *keySetHTTP) getJWKS(ctx context.Context) (keySetJWKS, error) {
	cachingEnabled := ks.cacheExpiration > 0

	if cachingEnabled {
		if val, err := ks.cache.Get(ctx, ks.cacheKey); err == nil {
			// Decode into a dedicated value so that a corrupt cache entry
			// cannot leak partial state into the freshly fetched key set.
			var cached keySetJWKS
			err := json.Unmarshal(val, &cached)
			if err == nil {
				return cached, nil
			}
			ks.log.Warn("Failed to unmarshal key set from cache", "err", err)
		}
	}

	ks.log.Debug("Getting key set from endpoint", "url", ks.url)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, ks.url, nil)
	if err != nil {
		return keySetJWKS{}, err
	}

	resp, err := ks.client.Do(req)
	if err != nil {
		return keySetJWKS{}, err
	}
	defer func() {
		if err := resp.Body.Close(); err != nil {
			ks.log.Warn("Failed to close response body", "err", err)
		}
	}()

	// The HTTP client does not treat non-2xx responses as errors. Without this
	// check a JSON error document would decode into an empty key set and be
	// cached as if it were valid.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return keySetJWKS{}, fmt.Errorf("fetching key set from %q: unexpected status %q", ks.url, resp.Status)
	}

	// Tee the body so the exact bytes that were decoded can be cached.
	var (
		jwks    keySetJWKS
		jsonBuf bytes.Buffer
	)
	if err := json.NewDecoder(io.TeeReader(resp.Body, &jsonBuf)).Decode(&jwks); err != nil {
		return keySetJWKS{}, err
	}

	if cachingEnabled {
		cacheControl := resp.Header.Get("cache-control")
		cacheExpiration := ks.getCacheExpiration(cacheControl)
		// A non-positive lifetime (for example "max-age=0") means the server
		// does not want the response reused, so it is not cached.
		if cacheExpiration > 0 {
			ks.log.Debug("Setting key set in cache", "url", ks.url,
				"cacheExpiration", cacheExpiration, "cacheControl", cacheControl)
			// Caching is an optimisation: the keys were fetched successfully,
			// so a cache write failure must not fail the lookup.
			if err := ks.cache.Set(ctx, ks.cacheKey, jsonBuf.Bytes(), cacheExpiration); err != nil {
				ks.log.Warn("Failed to set key set in cache", "url", ks.url, "err", err)
			}
		}
	}

	return jwks, nil
}
// getCacheExpiration returns how long a fetched key set may be cached, given
// the response's Cache-Control header value.
//
// The configured ks.cacheExpiration is the default and the upper bound. A
// max-age directive shortens it when it is smaller, and is used as-is when no
// expiration is configured (ks.cacheExpiration == 0). The configured value is
// returned when cacheControl is empty, has no max-age directive, or carries a
// max-age that is not a non-negative integer.
func (ks *keySetHTTP) getCacheExpiration(cacheControl string) time.Duration {
	// Largest whole number of seconds a time.Duration can hold; larger
	// max-age values are clamped so the multiplication cannot overflow.
	const maxSeconds = int64(1<<63-1) / int64(time.Second)

	cacheDuration := ks.cacheExpiration

	for _, directive := range strings.Split(cacheControl, ",") {
		name, value, found := strings.Cut(directive, "=")
		// Cache-Control directive names are case-insensitive.
		if !found || !strings.EqualFold(strings.TrimSpace(name), "max-age") {
			continue
		}

		maxAge, err := strconv.Atoi(strings.TrimSpace(value))
		if err != nil || maxAge < 0 {
			// A malformed or negative max-age is ignored.
			return cacheDuration
		}

		seconds := int64(maxAge)
		if seconds > maxSeconds {
			seconds = maxSeconds
		}

		// Use max-age when no duration is configured or when it is shorter
		// than the configured one.
		if age := time.Duration(seconds) * time.Second; cacheDuration == 0 || age < cacheDuration {
			return age
		}
	}

	return cacheDuration
}
// Key obtains the current key set via getJWKS and returns the keys whose key
// ID equals kid. Any error from obtaining the key set is returned unchanged
// together with a nil slice.
func (ks keySetHTTP) Key(ctx context.Context, kid string) ([]jose.JSONWebKey, error) {
	set, err := ks.getJWKS(ctx)
	if err == nil {
		return set.Key(ctx, kid)
	}
	return nil, err
}