package social import ( "encoding/base64" "encoding/json" "errors" "fmt" "net/http" "net/mail" "regexp" "github.com/grafana/grafana/pkg/models" "golang.org/x/oauth2" ) type SocialGenericOAuth struct { *SocialBase allowedDomains []string allowedOrganizations []string apiUrl string allowSignup bool teamIds []int } func (s *SocialGenericOAuth) Type() int { return int(models.GENERIC) } func (s *SocialGenericOAuth) IsEmailAllowed(email string) bool { return isEmailAllowed(email, s.allowedDomains) } func (s *SocialGenericOAuth) IsSignupAllowed() bool { return s.allowSignup } func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool { if len(s.teamIds) == 0 { return true } teamMemberships, err := s.FetchTeamMemberships(client) if err != nil { return false } for _, teamId := range s.teamIds { for _, membershipId := range teamMemberships { if teamId == membershipId { return true } } } return false } func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool { if len(s.allowedOrganizations) == 0 { return true } organizations, err := s.FetchOrganizations(client) if err != nil { return false } for _, allowedOrganization := range s.allowedOrganizations { for _, organization := range organizations { if organization == allowedOrganization { return true } } } return false } func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) { type Record struct { Email string `json:"email"` Primary bool `json:"primary"` IsPrimary bool `json:"is_primary"` Verified bool `json:"verified"` IsConfirmed bool `json:"is_confirmed"` } response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/emails")) if err != nil { return "", fmt.Errorf("Error getting email address: %s", err) } var records []Record err = json.Unmarshal(response.Body, &records) if err != nil { var data struct { Values []Record `json:"values"` } err = json.Unmarshal(response.Body, &data) if err != nil { return "", fmt.Errorf("Error getting email address: %s", err) } records = data.Values } var email = "" for _, record := range records { if record.Primary || record.IsPrimary { email = record.Email break } } return email, nil } func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]int, error) { type Record struct { Id int `json:"id"` } response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/teams")) if err != nil { return nil, fmt.Errorf("Error getting team memberships: %s", err) } var records []Record err = json.Unmarshal(response.Body, &records) if err != nil { return nil, fmt.Errorf("Error getting team memberships: %s", err) } var ids = make([]int, len(records)) for i, record := range records { ids[i] = record.Id } return ids, nil } func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, error) { type Record struct { Login string `json:"login"` } response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/orgs")) if err != nil { return nil, fmt.Errorf("Error getting organizations: %s", err) } var records []Record err = json.Unmarshal(response.Body, &records) if err != nil { return nil, fmt.Errorf("Error getting organizations: %s", err) } var logins = make([]string, len(records)) for i, record := range records { logins[i] = record.Login } return logins, nil } type UserInfoJson struct { Name string `json:"name"` DisplayName string `json:"display_name"` Login string `json:"login"` Username string `json:"username"` Email string `json:"email"` Upn string `json:"upn"` Attributes map[string][]string `json:"attributes"` } func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { var data UserInfoJson var err error if !s.extractToken(&data, token) { response, err := HttpGet(client, s.apiUrl) if err != nil { return nil, fmt.Errorf("Error getting user info: %s", err) } err = json.Unmarshal(response.Body, &data) if err != nil { return nil, fmt.Errorf("Error decoding user info JSON: %s", err) } } name := s.extractName(&data) email := s.extractEmail(&data) if email == "" { email, err = s.FetchPrivateEmail(client) if err != nil { return nil, err } } login := s.extractLogin(&data, email) userInfo := &BasicUserInfo{ Name: name, Login: login, Email: email, } if !s.IsTeamMember(client) { return nil, errors.New("User not a member of one of the required teams") } if !s.IsOrganizationMember(client) { return nil, errors.New("User not a member of one of the required organizations") } return userInfo, nil } func (s *SocialGenericOAuth) extractToken(data *UserInfoJson, token *oauth2.Token) bool { idToken := token.Extra("id_token") if idToken == nil { s.log.Debug("No id_token found", "token", token) return false } jwtRegexp := regexp.MustCompile("^([-_a-zA-Z0-9]+)[.]([-_a-zA-Z0-9]+)[.]([-_a-zA-Z0-9]+)$") matched := jwtRegexp.FindStringSubmatch(idToken.(string)) if matched == nil { s.log.Debug("id_token is not in JWT format", "id_token", idToken.(string)) return false } payload, err := base64.RawURLEncoding.DecodeString(matched[2]) if err != nil { s.log.Error("Error base64 decoding id_token", "raw_payload", matched[2], "err", err) return false } err = json.Unmarshal(payload, data) if err != nil { s.log.Error("Error decoding id_token JSON", "payload", string(payload), "err", err) return false } email := s.extractEmail(data) if email == "" { s.log.Debug("No email found in id_token", "json", string(payload), "data", data) return false } s.log.Debug("Received id_token", "json", string(payload), "data", data) return true } func (s *SocialGenericOAuth) extractEmail(data *UserInfoJson) string { if data.Email != "" { return data.Email } if data.Attributes["email:primary"] != nil { return data.Attributes["email:primary"][0] } if data.Upn != "" { emailAddr, emailErr := mail.ParseAddress(data.Upn) if emailErr == nil { return emailAddr.Address } } return "" } func (s *SocialGenericOAuth) extractLogin(data *UserInfoJson, email string) string { if data.Login != "" { return data.Login } if data.Username != "" { return data.Username } return email } func (s *SocialGenericOAuth) extractName(data *UserInfoJson) string { if data.Name != "" { return data.Name } if data.DisplayName != "" { return data.DisplayName } return "" }