grafana/pkg/social/generic_oauth.go

304 lines
6.6 KiB
Go

package social
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/mail"
"regexp"
"github.com/grafana/grafana/pkg/models"
"golang.org/x/oauth2"
)
// SocialGenericOAuth authenticates users against a generic OAuth2 provider
// (Type reports models.GENERIC) and applies optional email-domain,
// organization and team restrictions.
type SocialGenericOAuth struct {
// Embedded shared provider state; supplies (at least) the logger used as s.log.
*SocialBase
// Email domain allow-list, evaluated by the package helper isEmailAllowed.
allowedDomains []string
// Organization logins a user must belong to, matched against the "login"
// field of records returned by <apiUrl>/orgs. Empty means no restriction.
allowedOrganizations []string
// User-info endpoint URL. "/emails", "/teams" and "/orgs" are appended to it
// for the supplementary lookups.
apiUrl string
// Value reported by IsSignupAllowed.
allowSignup bool
// Team IDs a user must belong to, matched against the "id" field of records
// returned by <apiUrl>/teams. Empty means no restriction.
teamIds []int
}
// Type returns the provider kind identifier for this connector:
// models.GENERIC converted to an int.
func (s *SocialGenericOAuth) Type() int {
	kind := models.GENERIC
	return int(kind)
}
// IsEmailAllowed reports whether email passes the configured domain
// allow-list. The actual check is delegated to the package helper
// isEmailAllowed.
func (s *SocialGenericOAuth) IsEmailAllowed(email string) bool {
	domains := s.allowedDomains
	return isEmailAllowed(email, domains)
}
// IsSignupAllowed reports the allowSignup setting configured for this
// provider.
func (s *SocialGenericOAuth) IsSignupAllowed() bool {
	signup := s.allowSignup
	return signup
}
// IsTeamMember reports whether the authenticated user belongs to at least one
// of the configured team IDs. When no team IDs are configured every user is
// accepted. Memberships are fetched from <apiUrl>/teams using client. If that
// request fails, the error is logged and the user is treated as not being a
// member.
func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
	if len(s.teamIds) == 0 {
		return true
	}

	teamMemberships, err := s.FetchTeamMemberships(client)
	if err != nil {
		// Still deny access, but leave a trace. Otherwise a bad API URL or a
		// provider outage surfaces only as an unexplained login rejection.
		s.log.Error("Failed to fetch team memberships", "err", err)
		return false
	}

	for _, teamId := range s.teamIds {
		for _, membershipId := range teamMemberships {
			if teamId == membershipId {
				return true
			}
		}
	}

	return false
}
// IsOrganizationMember reports whether the authenticated user belongs to at
// least one of the allowed organizations. When no organizations are
// configured every user is accepted. Organization logins are fetched from
// <apiUrl>/orgs using client. If that request fails, the error is logged and
// the user is treated as not being a member.
func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool {
	if len(s.allowedOrganizations) == 0 {
		return true
	}

	organizations, err := s.FetchOrganizations(client)
	if err != nil {
		// Still deny access, but leave a trace. Otherwise a bad API URL or a
		// provider outage surfaces only as an unexplained login rejection.
		s.log.Error("Failed to fetch organizations", "err", err)
		return false
	}

	for _, allowedOrganization := range s.allowedOrganizations {
		for _, organization := range organizations {
			if organization == allowedOrganization {
				return true
			}
		}
	}

	return false
}
// FetchPrivateEmail looks up the user's primary email address at
// <apiUrl>/emails.
//
// The response body may be either a bare JSON array of email records or an
// object wrapping that array in a "values" field; both shapes are accepted.
// The first record flagged "primary" or "is_primary" is returned. If no record
// is flagged primary, the result is "" with a nil error. If neither shape
// decodes, the error from the "values" attempt is returned.
func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) {
	// Verified / IsConfirmed are decoded but not currently consulted.
	type Record struct {
		Email       string `json:"email"`
		Primary     bool   `json:"primary"`
		IsPrimary   bool   `json:"is_primary"`
		Verified    bool   `json:"verified"`
		IsConfirmed bool   `json:"is_confirmed"`
	}

	// Plain concatenation: using the URL as a fmt format string would mangle
	// any '%' it contains (e.g. percent-encoded characters).
	response, err := HttpGet(client, s.apiUrl+"/emails")
	if err != nil {
		return "", fmt.Errorf("Error getting email address: %s", err)
	}

	var records []Record
	err = json.Unmarshal(response.Body, &records)
	if err != nil {
		// Not a bare array; try the {"values": [...]} envelope instead.
		var data struct {
			Values []Record `json:"values"`
		}
		err = json.Unmarshal(response.Body, &data)
		if err != nil {
			return "", fmt.Errorf("Error getting email address: %s", err)
		}
		records = data.Values
	}

	var email = ""
	for _, record := range records {
		if record.Primary || record.IsPrimary {
			email = record.Email
			break
		}
	}

	return email, nil
}
// FetchTeamMemberships returns the IDs of the teams the user belongs to, read
// from the "id" field of each element of the JSON array returned by
// <apiUrl>/teams. The request or decoding error is returned, wrapped in a
// message, on failure.
func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]int, error) {
	type Record struct {
		Id int `json:"id"`
	}

	// Plain concatenation: using the URL as a fmt format string would mangle
	// any '%' it contains (e.g. percent-encoded characters).
	response, err := HttpGet(client, s.apiUrl+"/teams")
	if err != nil {
		return nil, fmt.Errorf("Error getting team memberships: %s", err)
	}

	var records []Record
	err = json.Unmarshal(response.Body, &records)
	if err != nil {
		return nil, fmt.Errorf("Error getting team memberships: %s", err)
	}

	var ids = make([]int, len(records))
	for i, record := range records {
		ids[i] = record.Id
	}

	return ids, nil
}
// FetchOrganizations returns the logins of the organizations the user belongs
// to, read from the "login" field of each element of the JSON array returned
// by <apiUrl>/orgs. The request or decoding error is returned, wrapped in a
// message, on failure.
func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, error) {
	type Record struct {
		Login string `json:"login"`
	}

	// Plain concatenation: using the URL as a fmt format string would mangle
	// any '%' it contains (e.g. percent-encoded characters).
	response, err := HttpGet(client, s.apiUrl+"/orgs")
	if err != nil {
		return nil, fmt.Errorf("Error getting organizations: %s", err)
	}

	var records []Record
	err = json.Unmarshal(response.Body, &records)
	if err != nil {
		return nil, fmt.Errorf("Error getting organizations: %s", err)
	}

	var logins = make([]string, len(records))
	for i, record := range records {
		logins[i] = record.Login
	}

	return logins, nil
}
// UserInfoJson holds the user claims understood by the generic OAuth provider.
// It is decoded either from the payload of an id_token or from the response
// body of the user-info endpoint (apiUrl).
type UserInfoJson struct {
// Preferred source for the user's name.
Name string `json:"name"`
// Fallback name used when Name is empty.
DisplayName string `json:"display_name"`
// Preferred source for the user's login.
Login string `json:"login"`
// Fallback login used when Login is empty.
Username string `json:"username"`
// Preferred source for the user's email address.
Email string `json:"email"`
// Email fallback: used only if it parses as an RFC 5322 address.
Upn string `json:"upn"`
// Free-form attributes; the first value of "email:primary" is an email fallback.
Attributes map[string][]string `json:"attributes"`
}
// UserInfo resolves the authenticated user's name, login and email.
//
// Claims are taken from the token's id_token when it carries a usable email.
// Otherwise they are fetched from apiUrl using client. If no email can be
// extracted, FetchPrivateEmail is consulted. Finally the user must pass the
// configured team and organization restrictions. An error is returned if any
// lookup fails or either membership check rejects the user.
func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
	var data UserInfoJson

	if !s.extractToken(&data, token) {
		// NOTE: extractToken may have partially filled data before rejecting
		// the id_token. json.Unmarshal only overwrites fields present in the
		// API response, so any other id_token claims are kept.
		response, err := HttpGet(client, s.apiUrl)
		if err != nil {
			return nil, fmt.Errorf("Error getting user info: %s", err)
		}
		if err = json.Unmarshal(response.Body, &data); err != nil {
			return nil, fmt.Errorf("Error decoding user info JSON: %s", err)
		}
	}

	name := s.extractName(&data)

	email := s.extractEmail(&data)
	if email == "" {
		privateEmail, err := s.FetchPrivateEmail(client)
		if err != nil {
			return nil, err
		}
		email = privateEmail
	}

	login := s.extractLogin(&data, email)

	userInfo := &BasicUserInfo{
		Name:  name,
		Login: login,
		Email: email,
	}

	if !s.IsTeamMember(client) {
		return nil, errors.New("User not a member of one of the required teams")
	}

	if !s.IsOrganizationMember(client) {
		return nil, errors.New("User not a member of one of the required organizations")
	}

	return userInfo, nil
}
// idTokenJWTRegexp matches a compact JWT: three dot-separated segments of
// base64url characters (no padding). It is compiled once at package load
// instead of on every login.
var idTokenJWTRegexp = regexp.MustCompile("^([-_a-zA-Z0-9]+)[.]([-_a-zA-Z0-9]+)[.]([-_a-zA-Z0-9]+)$")

// extractToken tries to populate data from the claims (the middle JWT segment)
// of the id_token found in token's extra fields. It returns true only when the
// id_token is present, is a string in JWT format, decodes as base64url JSON,
// and yields a non-empty email via extractEmail. On false, data may have been
// partially populated.
//
// NOTE(review): the JWT signature is not verified here. The claims are trusted
// presumably because the token came directly from the provider's token
// endpoint; confirm this against the caller.
func (s *SocialGenericOAuth) extractToken(data *UserInfoJson, token *oauth2.Token) bool {
	idToken := token.Extra("id_token")
	if idToken == nil {
		// Do not log the token itself: it carries the access and refresh tokens.
		s.log.Debug("No id_token found")
		return false
	}

	// Extra values come straight from the provider's response; guard the
	// assertion so an unexpected type cannot panic the login handler.
	idTokenString, ok := idToken.(string)
	if !ok {
		s.log.Debug("id_token is not a string", "type", fmt.Sprintf("%T", idToken))
		return false
	}

	matched := idTokenJWTRegexp.FindStringSubmatch(idTokenString)
	if matched == nil {
		s.log.Debug("id_token is not in JWT format", "id_token", idTokenString)
		return false
	}

	payload, err := base64.RawURLEncoding.DecodeString(matched[2])
	if err != nil {
		s.log.Error("Error base64 decoding id_token", "raw_payload", matched[2], "err", err)
		return false
	}

	err = json.Unmarshal(payload, data)
	if err != nil {
		s.log.Error("Error decoding id_token JSON", "payload", string(payload), "err", err)
		return false
	}

	email := s.extractEmail(data)
	if email == "" {
		s.log.Debug("No email found in id_token", "json", string(payload), "data", data)
		return false
	}

	s.log.Debug("Received id_token", "json", string(payload), "data", data)
	return true
}
// extractEmail returns the user's email address, trying in order:
//   - the "email" claim;
//   - the first value of the "email:primary" attribute;
//   - the "upn" claim, if it parses as an RFC 5322 address.
//
// It returns "" when none of these yields an address.
func (s *SocialGenericOAuth) extractEmail(data *UserInfoJson) string {
	if data.Email != "" {
		return data.Email
	}

	// Check the length, not just non-nil: a provider sending
	// "email:primary": [] decodes to an empty non-nil slice, and indexing it
	// would panic.
	if primary := data.Attributes["email:primary"]; len(primary) > 0 {
		return primary[0]
	}

	if data.Upn != "" {
		emailAddr, emailErr := mail.ParseAddress(data.Upn)
		if emailErr == nil {
			return emailAddr.Address
		}
	}

	return ""
}
// extractLogin chooses the user's login: the "login" claim if set, otherwise
// the "username" claim, otherwise the supplied email address.
func (s *SocialGenericOAuth) extractLogin(data *UserInfoJson, email string) string {
	switch {
	case data.Login != "":
		return data.Login
	case data.Username != "":
		return data.Username
	default:
		return email
	}
}
// extractName chooses the user's name: the "name" claim if set, otherwise the
// "display_name" claim (which is itself "" when absent).
func (s *SocialGenericOAuth) extractName(data *UserInfoJson) string {
	if name := data.Name; name != "" {
		return name
	}
	return data.DisplayName
}