grafana/pkg/services/auth/jwt/validation.go
2021-03-31 08:40:44 -07:00

138 lines
3.6 KiB
Go

package jwt
import (
"encoding/json"
"fmt"
"reflect"
"time"
"github.com/grafana/grafana/pkg/models"
"gopkg.in/square/go-jose.v2/jwt"
)
// initClaimExpectations parses s.Cfg.JWTAuthExpectClaims as JSON into s.expect
// and moves the registered claims it understands ("iss", "sub", "aud") into
// s.expectRegistered, removing them from s.expect. Whatever remains in
// s.expect is compared verbatim against token claims by validateClaims.
//
// It returns an error when the configuration cannot be decoded, when "iss" or
// "sub" is not a string, or when "aud" is neither a string nor an array of
// strings.
func (s *AuthService) initClaimExpectations() error {
	if err := json.Unmarshal([]byte(s.Cfg.JWTAuthExpectClaims), &s.expect); err != nil {
		return fmt.Errorf("parsing JWT expected claims: %w", err)
	}

	// Deleting the entry currently being visited is permitted while ranging
	// over a map (Go spec, "For statements with range clause").
	for key, value := range s.expect {
		switch key {
		case "iss":
			issuer, ok := value.(string)
			if !ok {
				return fmt.Errorf("%q expectation has invalid type %T, string expected", key, value)
			}
			s.expectRegistered.Issuer = issuer
			delete(s.expect, key)
		case "sub":
			subject, ok := value.(string)
			if !ok {
				return fmt.Errorf("%q expectation has invalid type %T, string expected", key, value)
			}
			s.expectRegistered.Subject = subject
			delete(s.expect, key)
		case "aud":
			switch value := value.(type) {
			case []interface{}:
				audience := make([]string, 0, len(value))
				for _, item := range value {
					str, ok := item.(string)
					if !ok {
						// Report the element's dynamic type, not that of the
						// (zero-valued) asserted string.
						return fmt.Errorf("%q expectation contains value with invalid type %T, string expected", key, item)
					}
					audience = append(audience, str)
				}
				s.expectRegistered.Audience = audience
			case string:
				s.expectRegistered.Audience = []string{value}
			default:
				return fmt.Errorf("%q expectation has invalid type %T, array or string expected", key, value)
			}
			delete(s.expect, key)
		}
	}
	return nil
}
// validateClaims checks the claims of a decoded token against the expectations
// prepared by initClaimExpectations.
//
// The registered claims are type-checked first:
//   - "iss" and "sub" must be strings.
//   - "aud" must be a string or an array of strings.
//   - "exp", "nbf" and "iat" must be numbers; a nil value is treated as absent.
//
// They are then checked with go-jose's Claims.Validate against
// s.expectRegistered at the current time. Finally, every remaining entry in
// s.expect must be present in claims and reflect.DeepEqual to the expected
// value. The first failure found is returned.
//
// NOTE(review): claim values are presumably JSON-decoded (string, float64,
// []interface{}); other Go types for these keys are rejected — confirm callers.
func (s *AuthService) validateClaims(claims models.JWTClaims) error {
	var registeredClaims jwt.Claims

	// numericDate converts a numeric claim into a *jwt.NumericDate. A nil value
	// yields (nil, nil) so the claim is treated as absent. The float64 is
	// truncated to whole seconds by the conversion.
	numericDate := func(key string, value interface{}) (*jwt.NumericDate, error) {
		if value == nil {
			return nil, nil
		}
		seconds, ok := value.(float64)
		if !ok {
			return nil, fmt.Errorf("%q claim has invalid type %T, number expected", key, value)
		}
		date := jwt.NumericDate(seconds)
		return &date, nil
	}

	for key, value := range claims {
		var err error
		switch key {
		case "iss":
			issuer, ok := value.(string)
			if !ok {
				return fmt.Errorf("%q claim has invalid type %T, string expected", key, value)
			}
			registeredClaims.Issuer = issuer
		case "sub":
			subject, ok := value.(string)
			if !ok {
				return fmt.Errorf("%q claim has invalid type %T, string expected", key, value)
			}
			registeredClaims.Subject = subject
		case "aud":
			switch value := value.(type) {
			case []interface{}:
				for _, item := range value {
					str, ok := item.(string)
					if !ok {
						// Report the element's dynamic type, not that of the
						// (zero-valued) asserted string.
						return fmt.Errorf("%q claim contains value with invalid type %T, string expected", key, item)
					}
					registeredClaims.Audience = append(registeredClaims.Audience, str)
				}
			case string:
				registeredClaims.Audience = []string{value}
			default:
				return fmt.Errorf("%q claim has invalid type %T, array or string expected", key, value)
			}
		case "exp":
			registeredClaims.Expiry, err = numericDate(key, value)
		case "nbf":
			registeredClaims.NotBefore, err = numericDate(key, value)
		case "iat":
			registeredClaims.IssuedAt, err = numericDate(key, value)
		}
		if err != nil {
			return err
		}
	}

	// Validate against a copy so the shared expectations stored in
	// s.expectRegistered are never mutated by setting Time.
	expected := s.expectRegistered
	expected.Time = time.Now()
	if err := registeredClaims.Validate(expected); err != nil {
		return err
	}

	for key, want := range s.expect {
		got, ok := claims[key]
		if !ok {
			return fmt.Errorf("%q claim is missing", key)
		}
		if !reflect.DeepEqual(want, got) {
			return fmt.Errorf("%q claim mismatch", key)
		}
	}
	return nil
}