Auth: Add request context to UserInfo calls (#70007)

* use context for UserInfo requests

* set timeouts for oauth http client

* Update pkg/login/social/common.go

Co-authored-by: Ieva <ieva.vasiljeva@grafana.com>

---------

Co-authored-by: Ieva <ieva.vasiljeva@grafana.com>
This commit is contained in:
Jo 2023-06-14 12:30:40 +00:00 committed by GitHub
parent 1445a7cc5c
commit 914daef0fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 172 additions and 138 deletions

View File

@ -70,61 +70,61 @@ func genPKCECode() (string, string, error) {
return string(ascii), pkce, nil return string(ascii), pkce, nil
} }
func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) { func (hs *HTTPServer) OAuthLogin(reqCtx *contextmodel.ReqContext) {
name := web.Params(ctx.Req)[":name"] name := web.Params(reqCtx.Req)[":name"]
loginInfo := loginservice.LoginInfo{AuthModule: name} loginInfo := loginservice.LoginInfo{AuthModule: name}
if errorParam := ctx.Query("error"); errorParam != "" { if errorParam := reqCtx.Query("error"); errorParam != "" {
errorDesc := ctx.Query("error_description") errorDesc := reqCtx.Query("error_description")
oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc) oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc)
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc) hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc)
return return
} }
code := ctx.Query("code") code := reqCtx.Query("code")
if hs.Cfg.AuthBrokerEnabled { if hs.Cfg.AuthBrokerEnabled {
req := &authn.Request{HTTPRequest: ctx.Req, Resp: ctx.Resp} req := &authn.Request{HTTPRequest: reqCtx.Req, Resp: reqCtx.Resp}
if code == "" { if code == "" {
redirect, err := hs.authnService.RedirectURL(ctx.Req.Context(), authn.ClientWithPrefix(name), req) redirect, err := hs.authnService.RedirectURL(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
if err != nil { if err != nil {
ctx.Redirect(hs.redirectURLWithErrorCookie(ctx, err)) reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))
return return
} }
if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" { if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" {
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
} }
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
ctx.Redirect(redirect.URL) reqCtx.Redirect(redirect.URL)
return return
} }
identity, err := hs.authnService.Login(ctx.Req.Context(), authn.ClientWithPrefix(name), req) identity, err := hs.authnService.Login(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
// NOTE: always delete these cookies, even if login failed // NOTE: always delete these cookies, even if login failed
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg) cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg) cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
if err != nil { if err != nil {
ctx.Redirect(hs.redirectURLWithErrorCookie(ctx, err)) reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))
return return
} }
metrics.MApiLoginOAuth.Inc() metrics.MApiLoginOAuth.Inc()
authn.HandleLoginRedirect(ctx.Req, ctx.Resp, hs.Cfg, identity, hs.ValidateRedirectTo) authn.HandleLoginRedirect(reqCtx.Req, reqCtx.Resp, hs.Cfg, identity, hs.ValidateRedirectTo)
return return
} }
provider := hs.SocialService.GetOAuthInfoProvider(name) provider := hs.SocialService.GetOAuthInfoProvider(name)
if provider == nil { if provider == nil {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, errors.New("OAuth not enabled")) hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, errors.New("OAuth not enabled"))
return return
} }
connect, err := hs.SocialService.GetConnector(name) connect, err := hs.SocialService.GetConnector(name)
if err != nil { if err != nil {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, fmt.Errorf("no OAuth with name %s configured", name)) hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, fmt.Errorf("no OAuth with name %s configured", name))
return return
} }
@ -133,15 +133,15 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
if provider.UsePKCE { if provider.UsePKCE {
ascii, pkce, err := genPKCECode() ascii, pkce, err := genPKCECode()
if err != nil { if err != nil {
ctx.Logger.Error("Generating PKCE failed", "error", err) reqCtx.Logger.Error("Generating PKCE failed", "error", err)
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError, HttpStatus: http.StatusInternalServerError,
PublicMessage: "An internal error occurred", PublicMessage: "An internal error occurred",
}) })
return return
} }
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
opts = append(opts, opts = append(opts,
oauth2.SetAuthURLParam("code_challenge", pkce), oauth2.SetAuthURLParam("code_challenge", pkce),
@ -151,8 +151,8 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
state, err := GenStateString() state, err := GenStateString()
if err != nil { if err != nil {
ctx.Logger.Error("Generating state string failed", "err", err) reqCtx.Logger.Error("Generating state string failed", "err", err)
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError, HttpStatus: http.StatusInternalServerError,
PublicMessage: "An internal error occurred", PublicMessage: "An internal error occurred",
}) })
@ -160,32 +160,32 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
} }
hashedState := hs.hashStatecode(state, provider.ClientSecret) hashedState := hs.hashStatecode(state, provider.ClientSecret)
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
if provider.HostedDomain != "" { if provider.HostedDomain != "" {
opts = append(opts, oauth2.SetAuthURLParam("hd", provider.HostedDomain)) opts = append(opts, oauth2.SetAuthURLParam("hd", provider.HostedDomain))
} }
ctx.Redirect(connect.AuthCodeURL(state, opts...)) reqCtx.Redirect(connect.AuthCodeURL(state, opts...))
return return
} }
cookieState := ctx.GetCookie(OauthStateCookieName) cookieState := reqCtx.GetCookie(OauthStateCookieName)
// delete cookie // delete cookie
cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg) cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
if cookieState == "" { if cookieState == "" {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError, HttpStatus: http.StatusInternalServerError,
PublicMessage: "login.OAuthLogin(missing saved state)", PublicMessage: "login.OAuthLogin(missing saved state)",
}) })
return return
} }
queryState := hs.hashStatecode(ctx.Query("state"), provider.ClientSecret) queryState := hs.hashStatecode(reqCtx.Query("state"), provider.ClientSecret)
oauthLogger.Info("state check", "queryState", queryState, "cookieState", cookieState) oauthLogger.Info("state check", "queryState", queryState, "cookieState", cookieState)
if cookieState != queryState { if cookieState != queryState {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError, HttpStatus: http.StatusInternalServerError,
PublicMessage: "login.OAuthLogin(state mismatch)", PublicMessage: "login.OAuthLogin(state mismatch)",
}) })
@ -194,19 +194,20 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
oauthClient, err := hs.SocialService.GetOAuthHttpClient(name) oauthClient, err := hs.SocialService.GetOAuthHttpClient(name)
if err != nil { if err != nil {
ctx.Logger.Error("Failed to create OAuth http client", "error", err) reqCtx.Logger.Error("Failed to create OAuth http client", "error", err)
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError, HttpStatus: http.StatusInternalServerError,
PublicMessage: "login.OAuthLogin(" + err.Error() + ")", PublicMessage: "login.OAuthLogin(" + err.Error() + ")",
}) })
return return
} }
oauthCtx := context.WithValue(context.Background(), oauth2.HTTPClient, oauthClient) ctx := reqCtx.Req.Context()
oauthCtx := context.WithValue(ctx, oauth2.HTTPClient, oauthClient)
opts := []oauth2.AuthCodeOption{} opts := []oauth2.AuthCodeOption{}
codeVerifier := ctx.GetCookie(OauthPKCECookieName) codeVerifier := reqCtx.GetCookie(OauthPKCECookieName)
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg) cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
if codeVerifier != "" { if codeVerifier != "" {
opts = append(opts, opts = append(opts,
oauth2.SetAuthURLParam("code_verifier", codeVerifier), oauth2.SetAuthURLParam("code_verifier", codeVerifier),
@ -216,7 +217,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
// get token from provider // get token from provider
token, err := connect.Exchange(oauthCtx, code, opts...) token, err := connect.Exchange(oauthCtx, code, opts...)
if err != nil { if err != nil {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError, HttpStatus: http.StatusInternalServerError,
PublicMessage: "login.OAuthLogin(NewTransportWithCode)", PublicMessage: "login.OAuthLogin(NewTransportWithCode)",
Err: err, Err: err,
@ -245,13 +246,13 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
client := connect.Client(oauthCtx, token) client := connect.Client(oauthCtx, token)
// get user info // get user info
userInfo, err := connect.UserInfo(client, token) userInfo, err := connect.UserInfo(ctx, client, token)
if err != nil { if err != nil {
var sErr *social.Error var sErr *social.Error
if errors.As(err, &sErr) { if errors.As(err, &sErr) {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, sErr) hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, sErr)
} else { } else {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError, HttpStatus: http.StatusInternalServerError,
PublicMessage: fmt.Sprintf("login.OAuthLogin(get info from %s)", name), PublicMessage: fmt.Sprintf("login.OAuthLogin(get info from %s)", name),
Err: err, Err: err,
@ -264,34 +265,34 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
// validate that we got at least an email address // validate that we got at least an email address
if userInfo.Email == "" { if userInfo.Email == "" {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrNoEmail) hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrNoEmail)
return return
} }
// validate that the email is allowed to login to grafana // validate that the email is allowed to login to grafana
if !connect.IsEmailAllowed(userInfo.Email) { if !connect.IsEmailAllowed(userInfo.Email) {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrEmailNotAllowed) hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrEmailNotAllowed)
return return
} }
loginInfo.ExternalUser = *hs.buildExternalUserInfo(token, userInfo, name) loginInfo.ExternalUser = *hs.buildExternalUserInfo(token, userInfo, name)
loginInfo.User, err = hs.SyncUser(ctx, &loginInfo.ExternalUser, connect) loginInfo.User, err = hs.SyncUser(reqCtx, &loginInfo.ExternalUser, connect)
if err != nil { if err != nil {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, err) hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, err)
return return
} }
// login // login
if err := hs.loginUserWithUser(loginInfo.User, ctx); err != nil { if err := hs.loginUserWithUser(loginInfo.User, reqCtx); err != nil {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, err) hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, err)
return return
} }
loginInfo.HTTPStatus = http.StatusOK loginInfo.HTTPStatus = http.StatusOK
hs.HooksService.RunLoginHook(&loginInfo, ctx) hs.HooksService.RunLoginHook(&loginInfo, reqCtx)
metrics.MApiLoginOAuth.Inc() metrics.MApiLoginOAuth.Inc()
ctx.Redirect(hs.GetRedirectURL(ctx)) reqCtx.Redirect(hs.GetRedirectURL(reqCtx))
} }
// buildExternalUserInfo returns a ExternalUserInfo struct from OAuth user profile // buildExternalUserInfo returns a ExternalUserInfo struct from OAuth user profile

View File

@ -2,6 +2,7 @@ package social
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -47,7 +48,7 @@ type azureAccessClaims struct {
TenantID string `json:"tid"` TenantID string `json:"tid"`
} }
func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
idToken := token.Extra("id_token") idToken := token.Extra("id_token")
if idToken == nil { if idToken == nil {
return nil, ErrIDTokenNotFound return nil, ErrIDTokenNotFound
@ -83,7 +84,7 @@ func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*Bas
} }
logger.Debug("AzureAD OAuth: extracted role", "email", email, "role", role) logger.Debug("AzureAD OAuth: extracted role", "email", email, "role", role)
groups, err := s.extractGroups(client, claims, token) groups, err := s.extractGroups(ctx, client, claims, token)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to extract groups: %w", err) return nil, fmt.Errorf("failed to extract groups: %w", err)
} }
@ -176,7 +177,7 @@ type getAzureGroupResponse struct {
// Note: If user groups exceeds 200 no groups will be found in claims and URL to target the Graph API will be // Note: If user groups exceeds 200 no groups will be found in claims and URL to target the Graph API will be
// given instead. // given instead.
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim // See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) { func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
if !s.forceUseGraphAPI { if !s.forceUseGraphAPI {
logger.Debug("checking the claim for groups") logger.Debug("checking the claim for groups")
if len(claims.Groups) > 0 { if len(claims.Groups) > 0 {
@ -199,7 +200,13 @@ func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, t
return nil, err return nil, err
} }
res, err := client.Post(endpoint, "application/json", bytes.NewBuffer(data)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
res, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -473,7 +473,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
tt.args.client = s.Client(context.Background(), token) tt.args.client = s.Client(context.Background(), token)
} }
got, err := s.UserInfo(tt.args.client, token) got, err := s.UserInfo(context.Background(), tt.args.client, token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -617,7 +617,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
tt.args.client = s.Client(context.Background(), token) tt.args.client = s.Client(context.Background(), token)
} }
got, err := s.UserInfo(tt.args.client, token) got, err := s.UserInfo(context.Background(), tt.args.client, token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
return return

View File

@ -1,6 +1,7 @@
package social package social
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -42,10 +43,15 @@ func isEmailAllowed(email string, allowedDomains []string) bool {
return valid return valid
} }
func (s *SocialBase) httpGet(client *http.Client, url string) (response httpGetResponse, err error) { func (s *SocialBase) httpGet(ctx context.Context, client *http.Client, url string) (*httpGetResponse, error) {
r, err := client.Get(url) req, errReq := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil { if errReq != nil {
return return nil, errReq
}
r, errDo := client.Do(req)
if errDo != nil {
return nil, errDo
} }
defer func() { defer func() {
@ -54,21 +60,20 @@ func (s *SocialBase) httpGet(client *http.Client, url string) (response httpGetR
} }
}() }()
body, err := io.ReadAll(r.Body) body, errRead := io.ReadAll(r.Body)
if err != nil { if errRead != nil {
return return nil, errRead
} }
response = httpGetResponse{body, r.Header} response := &httpGetResponse{body, r.Header}
if r.StatusCode >= 300 { if r.StatusCode >= 300 {
err = fmt.Errorf(string(response.Body)) return nil, fmt.Errorf("unsuccessful response status code %d: %s", r.StatusCode, string(response.Body))
return
} }
s.log.Debug("HTTP GET", "url", url, "status", r.Status, "response_body", string(response.Body)) s.log.Debug("HTTP GET", "url", url, "status", r.Status, "response_body", string(response.Body))
err = nil return response, nil
return
} }
func (s *SocialBase) searchJSONForAttr(attributePath string, data []byte) (interface{}, error) { func (s *SocialBase) searchJSONForAttr(attributePath string, data []byte) (interface{}, error) {

View File

@ -3,6 +3,7 @@ package social
import ( import (
"bytes" "bytes"
"compress/zlib" "compress/zlib"
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
@ -50,12 +51,12 @@ func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool {
return false return false
} }
func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool { func (s *SocialGenericOAuth) IsTeamMember(ctx context.Context, client *http.Client) bool {
if len(s.teamIds) == 0 { if len(s.teamIds) == 0 {
return true return true
} }
teamMemberships, err := s.FetchTeamMemberships(client) teamMemberships, err := s.FetchTeamMemberships(ctx, client)
if err != nil { if err != nil {
return false return false
} }
@ -71,12 +72,12 @@ func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
return false return false
} }
func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool { func (s *SocialGenericOAuth) IsOrganizationMember(ctx context.Context, client *http.Client) bool {
if len(s.allowedOrganizations) == 0 { if len(s.allowedOrganizations) == 0 {
return true return true
} }
organizations, ok := s.FetchOrganizations(client) organizations, ok := s.FetchOrganizations(ctx, client)
if !ok { if !ok {
return false return false
} }
@ -111,14 +112,14 @@ func (info *UserInfoJson) String() string {
info.Name, info.DisplayName, info.Login, info.Username, info.Email, info.Upn, info.Attributes) info.Name, info.DisplayName, info.Login, info.Username, info.Email, info.Upn, info.Attributes)
} }
func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
s.log.Debug("Getting user info") s.log.Debug("Getting user info")
toCheck := make([]*UserInfoJson, 0, 2) toCheck := make([]*UserInfoJson, 0, 2)
if tokenData := s.extractFromToken(token); tokenData != nil { if tokenData := s.extractFromToken(token); tokenData != nil {
toCheck = append(toCheck, tokenData) toCheck = append(toCheck, tokenData)
} }
if apiData := s.extractFromAPI(client); apiData != nil { if apiData := s.extractFromAPI(ctx, client); apiData != nil {
toCheck = append(toCheck, apiData) toCheck = append(toCheck, apiData)
} }
@ -179,7 +180,7 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
if userInfo.Email == "" { if userInfo.Email == "" {
var err error var err error
userInfo.Email, err = s.FetchPrivateEmail(client) userInfo.Email, err = s.FetchPrivateEmail(ctx, client)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -191,11 +192,11 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
userInfo.Login = userInfo.Email userInfo.Login = userInfo.Email
} }
if !s.IsTeamMember(client) { if !s.IsTeamMember(ctx, client) {
return nil, errors.New("user not a member of one of the required teams") return nil, errors.New("user not a member of one of the required teams")
} }
if !s.IsOrganizationMember(client) { if !s.IsOrganizationMember(ctx, client) {
return nil, errors.New("user not a member of one of the required organizations") return nil, errors.New("user not a member of one of the required organizations")
} }
@ -288,14 +289,14 @@ func (s *SocialGenericOAuth) extractFromToken(token *oauth2.Token) *UserInfoJson
return &data return &data
} }
func (s *SocialGenericOAuth) extractFromAPI(client *http.Client) *UserInfoJson { func (s *SocialGenericOAuth) extractFromAPI(ctx context.Context, client *http.Client) *UserInfoJson {
s.log.Debug("Getting user info from API") s.log.Debug("Getting user info from API")
if s.apiUrl == "" { if s.apiUrl == "" {
s.log.Debug("No api url configured") s.log.Debug("No api url configured")
return nil return nil
} }
rawUserInfoResponse, err := s.httpGet(client, s.apiUrl) rawUserInfoResponse, err := s.httpGet(ctx, client, s.apiUrl)
if err != nil { if err != nil {
s.log.Debug("Error getting user info from API", "url", s.apiUrl, "error", err) s.log.Debug("Error getting user info from API", "url", s.apiUrl, "error", err)
return nil return nil
@ -404,7 +405,7 @@ func (s *SocialGenericOAuth) extractGroups(data *UserInfoJson) ([]string, error)
return s.searchJSONForStringArrayAttr(s.groupsAttributePath, data.rawJSON) return s.searchJSONForStringArrayAttr(s.groupsAttributePath, data.rawJSON)
} }
func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) { func (s *SocialGenericOAuth) FetchPrivateEmail(ctx context.Context, client *http.Client) (string, error) {
type Record struct { type Record struct {
Email string `json:"email"` Email string `json:"email"`
Primary bool `json:"primary"` Primary bool `json:"primary"`
@ -413,7 +414,7 @@ func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, err
IsConfirmed bool `json:"is_confirmed"` IsConfirmed bool `json:"is_confirmed"`
} }
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/emails")) response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/emails"))
if err != nil { if err != nil {
s.log.Error("Error getting email address", "url", s.apiUrl+"/emails", "error", err) s.log.Error("Error getting email address", "url", s.apiUrl+"/emails", "error", err)
return "", fmt.Errorf("%v: %w", "Error getting email address", err) return "", fmt.Errorf("%v: %w", "Error getting email address", err)
@ -451,14 +452,14 @@ func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, err
return email, nil return email, nil
} }
func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]string, error) { func (s *SocialGenericOAuth) FetchTeamMemberships(ctx context.Context, client *http.Client) ([]string, error) {
var err error var err error
var ids []string var ids []string
if s.teamsUrl == "" { if s.teamsUrl == "" {
ids, err = s.fetchTeamMembershipsFromDeprecatedTeamsUrl(client) ids, err = s.fetchTeamMembershipsFromDeprecatedTeamsUrl(ctx, client)
} else { } else {
ids, err = s.fetchTeamMembershipsFromTeamsUrl(client) ids, err = s.fetchTeamMembershipsFromTeamsUrl(ctx, client)
} }
if err == nil { if err == nil {
@ -468,16 +469,14 @@ func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]string
return ids, err return ids, err
} }
func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(client *http.Client) ([]string, error) { func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(ctx context.Context, client *http.Client) ([]string, error) {
var response httpGetResponse
var err error
var ids []string var ids []string
type Record struct { type Record struct {
Id int `json:"id"` Id int `json:"id"`
} }
response, err = s.httpGet(client, fmt.Sprintf(s.apiUrl+"/teams")) response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/teams"))
if err != nil { if err != nil {
s.log.Error("Error getting team memberships", "url", s.apiUrl+"/teams", "error", err) s.log.Error("Error getting team memberships", "url", s.apiUrl+"/teams", "error", err)
return []string{}, err return []string{}, err
@ -499,15 +498,12 @@ func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(client *
return ids, nil return ids, nil
} }
func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(client *http.Client) ([]string, error) { func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(ctx context.Context, client *http.Client) ([]string, error) {
if s.teamIdsAttributePath == "" { if s.teamIdsAttributePath == "" {
return []string{}, nil return []string{}, nil
} }
var response httpGetResponse response, err := s.httpGet(ctx, client, fmt.Sprintf(s.teamsUrl))
var err error
response, err = s.httpGet(client, fmt.Sprintf(s.teamsUrl))
if err != nil { if err != nil {
s.log.Error("Error getting team memberships", "url", s.teamsUrl, "error", err) s.log.Error("Error getting team memberships", "url", s.teamsUrl, "error", err)
return nil, err return nil, err
@ -516,12 +512,12 @@ func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(client *http.Clien
return s.searchJSONForStringArrayAttr(s.teamIdsAttributePath, response.Body) return s.searchJSONForStringArrayAttr(s.teamIdsAttributePath, response.Body)
} }
func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, bool) { func (s *SocialGenericOAuth) FetchOrganizations(ctx context.Context, client *http.Client) ([]string, bool) {
type Record struct { type Record struct {
Login string `json:"login"` Login string `json:"login"`
} }
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/orgs")) response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/orgs"))
if err != nil { if err != nil {
s.log.Error("Error getting organizations", "url", s.apiUrl+"/orgs", "error", err) s.log.Error("Error getting organizations", "url", s.apiUrl+"/orgs", "error", err)
return nil, false return nil, false

View File

@ -1,6 +1,7 @@
package social package social
import ( import (
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -490,7 +491,7 @@ func TestUserInfoSearchesForEmailAndRole(t *testing.T) {
} }
token := staticToken.WithExtra(test.OAuth2Extra) token := staticToken.WithExtra(test.OAuth2Extra)
actualResult, err := provider.UserInfo(ts.Client(), token) actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, test.ExpectedEmail, actualResult.Email) require.Equal(t, test.ExpectedEmail, actualResult.Email)
require.Equal(t, test.ExpectedEmail, actualResult.Login) require.Equal(t, test.ExpectedEmail, actualResult.Login)
@ -588,7 +589,7 @@ func TestUserInfoSearchesForLogin(t *testing.T) {
} }
token := staticToken.WithExtra(test.OAuth2Extra) token := staticToken.WithExtra(test.OAuth2Extra)
actualResult, err := provider.UserInfo(ts.Client(), token) actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, test.ExpectedLogin, actualResult.Login) require.Equal(t, test.ExpectedLogin, actualResult.Login)
}) })
@ -686,7 +687,7 @@ func TestUserInfoSearchesForName(t *testing.T) {
} }
token := staticToken.WithExtra(test.OAuth2Extra) token := staticToken.WithExtra(test.OAuth2Extra)
actualResult, err := provider.UserInfo(ts.Client(), token) actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, test.ExpectedName, actualResult.Name) require.Equal(t, test.ExpectedName, actualResult.Name)
}) })
@ -755,7 +756,7 @@ func TestUserInfoSearchesForGroup(t *testing.T) {
Expiry: time.Now(), Expiry: time.Now(),
} }
userInfo, err := provider.UserInfo(ts.Client(), token) userInfo, err := provider.UserInfo(context.Background(), ts.Client(), token)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, test.expectedResult, userInfo.Groups) assert.Equal(t, test.expectedResult, userInfo.Groups)
}) })

View File

@ -1,6 +1,7 @@
package social package social
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -35,12 +36,12 @@ var (
ErrMissingOrganizationMembership = Error{"user not a member of one of the required organizations"} ErrMissingOrganizationMembership = Error{"user not a member of one of the required organizations"}
) )
func (s *SocialGithub) IsTeamMember(client *http.Client) bool { func (s *SocialGithub) IsTeamMember(ctx context.Context, client *http.Client) bool {
if len(s.teamIds) == 0 { if len(s.teamIds) == 0 {
return true return true
} }
teamMemberships, err := s.FetchTeamMemberships(client) teamMemberships, err := s.FetchTeamMemberships(ctx, client)
if err != nil { if err != nil {
return false return false
} }
@ -56,12 +57,13 @@ func (s *SocialGithub) IsTeamMember(client *http.Client) bool {
return false return false
} }
func (s *SocialGithub) IsOrganizationMember(client *http.Client, organizationsUrl string) bool { func (s *SocialGithub) IsOrganizationMember(ctx context.Context,
client *http.Client, organizationsUrl string) bool {
if len(s.allowedOrganizations) == 0 { if len(s.allowedOrganizations) == 0 {
return true return true
} }
organizations, err := s.FetchOrganizations(client, organizationsUrl) organizations, err := s.FetchOrganizations(ctx, client, organizationsUrl)
if err != nil { if err != nil {
return false return false
} }
@ -77,14 +79,14 @@ func (s *SocialGithub) IsOrganizationMember(client *http.Client, organizationsUr
return false return false
} }
func (s *SocialGithub) FetchPrivateEmail(client *http.Client) (string, error) { func (s *SocialGithub) FetchPrivateEmail(ctx context.Context, client *http.Client) (string, error) {
type Record struct { type Record struct {
Email string `json:"email"` Email string `json:"email"`
Primary bool `json:"primary"` Primary bool `json:"primary"`
Verified bool `json:"verified"` Verified bool `json:"verified"`
} }
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/emails")) response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/emails"))
if err != nil { if err != nil {
return "", fmt.Errorf("Error getting email address: %s", err) return "", fmt.Errorf("Error getting email address: %s", err)
} }
@ -106,13 +108,13 @@ func (s *SocialGithub) FetchPrivateEmail(client *http.Client) (string, error) {
return email, nil return email, nil
} }
func (s *SocialGithub) FetchTeamMemberships(client *http.Client) ([]GithubTeam, error) { func (s *SocialGithub) FetchTeamMemberships(ctx context.Context, client *http.Client) ([]GithubTeam, error) {
url := fmt.Sprintf(s.apiUrl + "/teams?per_page=100") url := fmt.Sprintf(s.apiUrl + "/teams?per_page=100")
hasMore := true hasMore := true
teams := make([]GithubTeam, 0) teams := make([]GithubTeam, 0)
for hasMore { for hasMore {
response, err := s.httpGet(client, url) response, err := s.httpGet(ctx, client, url)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error getting team memberships: %s", err) return nil, fmt.Errorf("Error getting team memberships: %s", err)
} }
@ -150,7 +152,7 @@ func (s *SocialGithub) HasMoreRecords(headers http.Header) (string, bool) {
return url, true return url, true
} }
func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl string) ([]string, error) { func (s *SocialGithub) FetchOrganizations(ctx context.Context, client *http.Client, organizationsUrl string) ([]string, error) {
url := organizationsUrl url := organizationsUrl
hasMore := true hasMore := true
logins := make([]string, 0) logins := make([]string, 0)
@ -160,7 +162,7 @@ func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl
} }
for hasMore { for hasMore {
response, err := s.httpGet(client, url) response, err := s.httpGet(ctx, client, url)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting organizations: %s", err) return nil, fmt.Errorf("error getting organizations: %s", err)
} }
@ -181,7 +183,7 @@ func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl
return logins, nil return logins, nil
} }
func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { func (s *SocialGithub) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
var data struct { var data struct {
Id int `json:"id"` Id int `json:"id"`
Login string `json:"login"` Login string `json:"login"`
@ -189,7 +191,7 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
Name string `json:"name"` Name string `json:"name"`
} }
response, err := s.httpGet(client, s.apiUrl) response, err := s.httpGet(ctx, client, s.apiUrl)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting user info: %s", err) return nil, fmt.Errorf("error getting user info: %s", err)
} }
@ -198,7 +200,7 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
return nil, fmt.Errorf("error unmarshalling user info: %s", err) return nil, fmt.Errorf("error unmarshalling user info: %s", err)
} }
teamMemberships, err := s.FetchTeamMemberships(client) teamMemberships, err := s.FetchTeamMemberships(ctx, client)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting user teams: %s", err) return nil, fmt.Errorf("error getting user teams: %s", err)
} }
@ -241,16 +243,16 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
organizationsUrl := fmt.Sprintf(s.apiUrl + "/orgs?per_page=100") organizationsUrl := fmt.Sprintf(s.apiUrl + "/orgs?per_page=100")
if !s.IsTeamMember(client) { if !s.IsTeamMember(ctx, client) {
return nil, ErrMissingTeamMembership return nil, ErrMissingTeamMembership
} }
if !s.IsOrganizationMember(client, organizationsUrl) { if !s.IsOrganizationMember(ctx, client, organizationsUrl) {
return nil, ErrMissingOrganizationMembership return nil, ErrMissingOrganizationMembership
} }
if userInfo.Email == "" { if userInfo.Email == "" {
userInfo.Email, err = s.FetchPrivateEmail(client) userInfo.Email, err = s.FetchPrivateEmail(ctx, client)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,7 @@
package social package social
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
@ -250,7 +251,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) {
AccessToken: "fake_token", AccessToken: "fake_token",
} }
got, err := s.UserInfo(server.Client(), token) got, err := s.UserInfo(context.Background(), server.Client(), token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
return return

View File

@ -1,6 +1,7 @@
package social package social
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -34,10 +35,10 @@ func (s *SocialGitlab) IsGroupMember(groups []string) bool {
return false return false
} }
func (s *SocialGitlab) GetGroups(client *http.Client) []string { func (s *SocialGitlab) GetGroups(ctx context.Context, client *http.Client) []string {
groups := make([]string, 0) groups := make([]string, 0)
for page, url := s.GetGroupsPage(client, s.apiUrl+"/groups"); page != nil; page, url = s.GetGroupsPage(client, url) { for page, url := s.GetGroupsPage(ctx, client, s.apiUrl+"/groups"); page != nil; page, url = s.GetGroupsPage(ctx, client, url) {
groups = append(groups, page...) groups = append(groups, page...)
} }
@ -45,7 +46,7 @@ func (s *SocialGitlab) GetGroups(client *http.Client) []string {
} }
// GetGroupsPage returns groups and link to the next page if response is paginated // GetGroupsPage returns groups and link to the next page if response is paginated
func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string, string) { func (s *SocialGitlab) GetGroupsPage(ctx context.Context, client *http.Client, url string) ([]string, string) {
type Group struct { type Group struct {
FullPath string `json:"full_path"` FullPath string `json:"full_path"`
} }
@ -59,7 +60,7 @@ func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string,
return nil, next return nil, next
} }
response, err := s.httpGet(client, url) response, err := s.httpGet(ctx, client, url)
if err != nil { if err != nil {
s.log.Error("Error getting groups from GitLab API", "err", err) s.log.Error("Error getting groups from GitLab API", "err", err)
return nil, next return nil, next
@ -86,7 +87,7 @@ func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string,
return fullPaths, next return fullPaths, next
} }
func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) { func (s *SocialGitlab) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
var data struct { var data struct {
Id int Id int
Username string Username string
@ -95,7 +96,7 @@ func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUse
State string State string
} }
response, err := s.httpGet(client, s.apiUrl+"/user") response, err := s.httpGet(ctx, client, s.apiUrl+"/user")
if err != nil { if err != nil {
return nil, fmt.Errorf("Error getting user info: %s", err) return nil, fmt.Errorf("Error getting user info: %s", err)
} }
@ -108,7 +109,7 @@ func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUse
return nil, fmt.Errorf("user %s is inactive", data.Username) return nil, fmt.Errorf("user %s is inactive", data.Username)
} }
groups := s.GetGroups(client) groups := s.GetGroups(ctx, client)
var role roletype.RoleType var role roletype.RoleType
var isGrafanaAdmin *bool = nil var isGrafanaAdmin *bool = nil

View File

@ -1,6 +1,7 @@
package social package social
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -159,7 +160,7 @@ func TestSocialGitlab_UserInfo(t *testing.T) {
} }
})) }))
provider.apiUrl = ts.URL + apiURI provider.apiUrl = ts.URL + apiURI
actualResult, err := provider.UserInfo(ts.Client(), nil) actualResult, err := provider.UserInfo(context.Background(), ts.Client(), nil)
if test.ExpectedError != nil { if test.ExpectedError != nil {
require.Equal(t, err, test.ExpectedError) require.Equal(t, err, test.ExpectedError)
return return

View File

@ -1,6 +1,7 @@
package social package social
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -16,14 +17,14 @@ type SocialGoogle struct {
apiUrl string apiUrl string
} }
func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
var data struct { var data struct {
Id string `json:"id"` Id string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Email string `json:"email"` Email string `json:"email"`
} }
response, err := s.httpGet(client, s.apiUrl) response, err := s.httpGet(ctx, client, s.apiUrl)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error getting user info: %s", err) return nil, fmt.Errorf("Error getting user info: %s", err)
} }

View File

@ -1,6 +1,7 @@
package social package social
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -43,7 +44,7 @@ func (s *SocialGrafanaCom) IsOrganizationMember(organizations []OrgRecord) bool
} }
// UserInfo is used for login credentials for the user // UserInfo is used for login credentials for the user
func (s *SocialGrafanaCom) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) { func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
var data struct { var data struct {
Id int `json:"id"` Id int `json:"id"`
Name string `json:"name"` Name string `json:"name"`
@ -53,7 +54,8 @@ func (s *SocialGrafanaCom) UserInfo(client *http.Client, _ *oauth2.Token) (*Basi
Orgs []OrgRecord `json:"orgs"` Orgs []OrgRecord `json:"orgs"`
} }
response, err := s.httpGet(client, s.url+"/api/oauth2/user") response, err := s.httpGet(ctx, client, s.url+"/api/oauth2/user")
if err != nil { if err != nil {
return nil, fmt.Errorf("Error getting user info: %s", err) return nil, fmt.Errorf("Error getting user info: %s", err)
} }

View File

@ -1,6 +1,7 @@
package social package social
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -81,7 +82,7 @@ func TestSocialGrafanaCom_UserInfo(t *testing.T) {
} }
})) }))
provider.url = ts.URL provider.url = ts.URL
actualResult, err := provider.UserInfo(ts.Client(), nil) actualResult, err := provider.UserInfo(context.Background(), ts.Client(), nil)
if test.ExpectedError != nil { if test.ExpectedError != nil {
require.Equal(t, err, test.ExpectedError) require.Equal(t, err, test.ExpectedError)
return return

View File

@ -1,6 +1,7 @@
package social package social
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -46,7 +47,7 @@ func (claims *OktaClaims) extractEmail() string {
return claims.Email return claims.Email
} }
func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
idToken := token.Extra("id_token") idToken := token.Extra("id_token")
if idToken == nil { if idToken == nil {
return nil, fmt.Errorf("no id_token found") return nil, fmt.Errorf("no id_token found")
@ -68,7 +69,7 @@ func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicU
} }
var data OktaUserInfoJson var data OktaUserInfoJson
err = s.extractAPI(&data, client) err = s.extractAPI(ctx, &data, client)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -105,8 +106,8 @@ func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicU
}, nil }, nil
} }
func (s *SocialOkta) extractAPI(data *OktaUserInfoJson, client *http.Client) error { func (s *SocialOkta) extractAPI(ctx context.Context, data *OktaUserInfoJson, client *http.Client) error {
rawUserInfoResponse, err := s.httpGet(client, s.apiUrl) rawUserInfoResponse, err := s.httpGet(ctx, client, s.apiUrl)
if err != nil { if err != nil {
s.log.Debug("Error getting user info response", "url", s.apiUrl, "error", err) s.log.Debug("Error getting user info response", "url", s.apiUrl, "error", err)
return fmt.Errorf("error getting user info response: %w", err) return fmt.Errorf("error getting user info response: %w", err)

View File

@ -1,6 +1,7 @@
package social package social
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -110,7 +111,7 @@ func TestSocialOkta_UserInfo(t *testing.T) {
Expiry: time.Now(), Expiry: time.Now(),
} }
token := staticToken.WithExtra(tt.OAuth2Extra) token := staticToken.WithExtra(tt.OAuth2Extra)
got, err := provider.UserInfo(server.Client(), token) got, err := provider.UserInfo(context.Background(), server.Client(), token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
return return

View File

@ -7,9 +7,11 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"net/http" "net/http"
"os" "os"
"strings" "strings"
"time"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/text/cases" "golang.org/x/text/cases"
@ -261,7 +263,7 @@ func (b *BasicUserInfo) String() string {
} }
type SocialConnector interface { type SocialConnector interface {
UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error)
IsEmailAllowed(email string) bool IsEmailAllowed(email string) bool
IsSignupAllowed() bool IsSignupAllowed() bool
@ -450,9 +452,19 @@ func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) {
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
InsecureSkipVerify: info.TlsSkipVerify, InsecureSkipVerify: info.TlsSkipVerify,
}, },
DialContext: (&net.Dialer{
Timeout: time.Second * 10,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 15 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
} }
oauthClient := &http.Client{ oauthClient := &http.Client{
Transport: tr, Transport: tr,
Timeout: time.Second * 15,
} }
if info.TlsClientCert != "" || info.TlsClientKey != "" { if info.TlsClientCert != "" || info.TlsClientKey != "" {

View File

@ -116,7 +116,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
} }
token.TokenType = "Bearer" token.TokenType = "Bearer"
userInfo, err := c.connector.UserInfo(c.connector.Client(clientCtx, token), token) userInfo, err := c.connector.UserInfo(ctx, c.connector.Client(clientCtx, token), token)
if err != nil { if err != nil {
var sErr *social.Error var sErr *social.Error
if errors.As(err, &sErr) { if errors.As(err, &sErr) {

View File

@ -8,13 +8,14 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/authn" "github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/org" "github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestOAuth_Authenticate(t *testing.T) { func TestOAuth_Authenticate(t *testing.T) {
@ -278,7 +279,7 @@ type fakeConnector struct {
social.SocialConnector social.SocialConnector
} }
func (f fakeConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { func (f fakeConnector) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
return f.ExpectedUserInfo, f.ExpectedUserInfoErr return f.ExpectedUserInfo, f.ExpectedUserInfoErr
} }

View File

@ -273,7 +273,7 @@ func (m *MockSocialConnector) Type() int {
return args.Int(0) return args.Int(0)
} }
func (m *MockSocialConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) { func (m *MockSocialConnector) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
args := m.Called(client, token) args := m.Called(client, token)
return args.Get(0).(*social.BasicUserInfo), args.Error(1) return args.Get(0).(*social.BasicUserInfo), args.Error(1)
} }