Auth: Add request context to UserInfo calls (#70007)

* use context for UserInfo requests

* set timeouts for oauth http client

* Update pkg/login/social/common.go

Co-authored-by: Ieva <ieva.vasiljeva@grafana.com>

---------

Co-authored-by: Ieva <ieva.vasiljeva@grafana.com>
This commit is contained in:
Jo 2023-06-14 12:30:40 +00:00 committed by GitHub
parent 1445a7cc5c
commit 914daef0fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 172 additions and 138 deletions

View File

@ -70,61 +70,61 @@ func genPKCECode() (string, string, error) {
return string(ascii), pkce, nil
}
func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
name := web.Params(ctx.Req)[":name"]
func (hs *HTTPServer) OAuthLogin(reqCtx *contextmodel.ReqContext) {
name := web.Params(reqCtx.Req)[":name"]
loginInfo := loginservice.LoginInfo{AuthModule: name}
if errorParam := ctx.Query("error"); errorParam != "" {
errorDesc := ctx.Query("error_description")
if errorParam := reqCtx.Query("error"); errorParam != "" {
errorDesc := reqCtx.Query("error_description")
oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc)
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc)
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc)
return
}
code := ctx.Query("code")
code := reqCtx.Query("code")
if hs.Cfg.AuthBrokerEnabled {
req := &authn.Request{HTTPRequest: ctx.Req, Resp: ctx.Resp}
req := &authn.Request{HTTPRequest: reqCtx.Req, Resp: reqCtx.Resp}
if code == "" {
redirect, err := hs.authnService.RedirectURL(ctx.Req.Context(), authn.ClientWithPrefix(name), req)
redirect, err := hs.authnService.RedirectURL(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
if err != nil {
ctx.Redirect(hs.redirectURLWithErrorCookie(ctx, err))
reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))
return
}
if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" {
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
}
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
ctx.Redirect(redirect.URL)
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
reqCtx.Redirect(redirect.URL)
return
}
identity, err := hs.authnService.Login(ctx.Req.Context(), authn.ClientWithPrefix(name), req)
identity, err := hs.authnService.Login(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
// NOTE: always delete these cookies, even if login failed
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
if err != nil {
ctx.Redirect(hs.redirectURLWithErrorCookie(ctx, err))
reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))
return
}
metrics.MApiLoginOAuth.Inc()
authn.HandleLoginRedirect(ctx.Req, ctx.Resp, hs.Cfg, identity, hs.ValidateRedirectTo)
authn.HandleLoginRedirect(reqCtx.Req, reqCtx.Resp, hs.Cfg, identity, hs.ValidateRedirectTo)
return
}
provider := hs.SocialService.GetOAuthInfoProvider(name)
if provider == nil {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, errors.New("OAuth not enabled"))
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, errors.New("OAuth not enabled"))
return
}
connect, err := hs.SocialService.GetConnector(name)
if err != nil {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, fmt.Errorf("no OAuth with name %s configured", name))
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, fmt.Errorf("no OAuth with name %s configured", name))
return
}
@ -133,15 +133,15 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
if provider.UsePKCE {
ascii, pkce, err := genPKCECode()
if err != nil {
ctx.Logger.Error("Generating PKCE failed", "error", err)
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
reqCtx.Logger.Error("Generating PKCE failed", "error", err)
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError,
PublicMessage: "An internal error occurred",
})
return
}
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
opts = append(opts,
oauth2.SetAuthURLParam("code_challenge", pkce),
@ -151,8 +151,8 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
state, err := GenStateString()
if err != nil {
ctx.Logger.Error("Generating state string failed", "err", err)
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
reqCtx.Logger.Error("Generating state string failed", "err", err)
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError,
PublicMessage: "An internal error occurred",
})
@ -160,32 +160,32 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
}
hashedState := hs.hashStatecode(state, provider.ClientSecret)
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
if provider.HostedDomain != "" {
opts = append(opts, oauth2.SetAuthURLParam("hd", provider.HostedDomain))
}
ctx.Redirect(connect.AuthCodeURL(state, opts...))
reqCtx.Redirect(connect.AuthCodeURL(state, opts...))
return
}
cookieState := ctx.GetCookie(OauthStateCookieName)
cookieState := reqCtx.GetCookie(OauthStateCookieName)
// delete cookie
cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
if cookieState == "" {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError,
PublicMessage: "login.OAuthLogin(missing saved state)",
})
return
}
queryState := hs.hashStatecode(ctx.Query("state"), provider.ClientSecret)
queryState := hs.hashStatecode(reqCtx.Query("state"), provider.ClientSecret)
oauthLogger.Info("state check", "queryState", queryState, "cookieState", cookieState)
if cookieState != queryState {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError,
PublicMessage: "login.OAuthLogin(state mismatch)",
})
@ -194,19 +194,20 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
oauthClient, err := hs.SocialService.GetOAuthHttpClient(name)
if err != nil {
ctx.Logger.Error("Failed to create OAuth http client", "error", err)
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
reqCtx.Logger.Error("Failed to create OAuth http client", "error", err)
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError,
PublicMessage: "login.OAuthLogin(" + err.Error() + ")",
})
return
}
oauthCtx := context.WithValue(context.Background(), oauth2.HTTPClient, oauthClient)
ctx := reqCtx.Req.Context()
oauthCtx := context.WithValue(ctx, oauth2.HTTPClient, oauthClient)
opts := []oauth2.AuthCodeOption{}
codeVerifier := ctx.GetCookie(OauthPKCECookieName)
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
codeVerifier := reqCtx.GetCookie(OauthPKCECookieName)
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
if codeVerifier != "" {
opts = append(opts,
oauth2.SetAuthURLParam("code_verifier", codeVerifier),
@ -216,7 +217,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
// get token from provider
token, err := connect.Exchange(oauthCtx, code, opts...)
if err != nil {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError,
PublicMessage: "login.OAuthLogin(NewTransportWithCode)",
Err: err,
@ -245,13 +246,13 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
client := connect.Client(oauthCtx, token)
// get user info
userInfo, err := connect.UserInfo(client, token)
userInfo, err := connect.UserInfo(ctx, client, token)
if err != nil {
var sErr *social.Error
if errors.As(err, &sErr) {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, sErr)
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, sErr)
} else {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError,
PublicMessage: fmt.Sprintf("login.OAuthLogin(get info from %s)", name),
Err: err,
@ -264,34 +265,34 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
// validate that we got at least an email address
if userInfo.Email == "" {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrNoEmail)
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrNoEmail)
return
}
// validate that the email is allowed to login to grafana
if !connect.IsEmailAllowed(userInfo.Email) {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrEmailNotAllowed)
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrEmailNotAllowed)
return
}
loginInfo.ExternalUser = *hs.buildExternalUserInfo(token, userInfo, name)
loginInfo.User, err = hs.SyncUser(ctx, &loginInfo.ExternalUser, connect)
loginInfo.User, err = hs.SyncUser(reqCtx, &loginInfo.ExternalUser, connect)
if err != nil {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, err)
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, err)
return
}
// login
if err := hs.loginUserWithUser(loginInfo.User, ctx); err != nil {
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, err)
if err := hs.loginUserWithUser(loginInfo.User, reqCtx); err != nil {
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, err)
return
}
loginInfo.HTTPStatus = http.StatusOK
hs.HooksService.RunLoginHook(&loginInfo, ctx)
hs.HooksService.RunLoginHook(&loginInfo, reqCtx)
metrics.MApiLoginOAuth.Inc()
ctx.Redirect(hs.GetRedirectURL(ctx))
reqCtx.Redirect(hs.GetRedirectURL(reqCtx))
}
// buildExternalUserInfo returns a ExternalUserInfo struct from OAuth user profile

View File

@ -2,6 +2,7 @@ package social
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@ -47,7 +48,7 @@ type azureAccessClaims struct {
TenantID string `json:"tid"`
}
func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
idToken := token.Extra("id_token")
if idToken == nil {
return nil, ErrIDTokenNotFound
@ -83,7 +84,7 @@ func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*Bas
}
logger.Debug("AzureAD OAuth: extracted role", "email", email, "role", role)
groups, err := s.extractGroups(client, claims, token)
groups, err := s.extractGroups(ctx, client, claims, token)
if err != nil {
return nil, fmt.Errorf("failed to extract groups: %w", err)
}
@ -176,7 +177,7 @@ type getAzureGroupResponse struct {
// Note: If user groups exceeds 200 no groups will be found in claims and URL to target the Graph API will be
// given instead.
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
if !s.forceUseGraphAPI {
logger.Debug("checking the claim for groups")
if len(claims.Groups) > 0 {
@ -199,7 +200,13 @@ func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, t
return nil, err
}
res, err := client.Post(endpoint, "application/json", bytes.NewBuffer(data))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
res, err := client.Do(req)
if err != nil {
return nil, err
}

View File

@ -473,7 +473,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
tt.args.client = s.Client(context.Background(), token)
}
got, err := s.UserInfo(tt.args.client, token)
got, err := s.UserInfo(context.Background(), tt.args.client, token)
if (err != nil) != tt.wantErr {
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
return
@ -617,7 +617,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
tt.args.client = s.Client(context.Background(), token)
}
got, err := s.UserInfo(tt.args.client, token)
got, err := s.UserInfo(context.Background(), tt.args.client, token)
if (err != nil) != tt.wantErr {
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@ -1,6 +1,7 @@
package social
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -42,10 +43,15 @@ func isEmailAllowed(email string, allowedDomains []string) bool {
return valid
}
func (s *SocialBase) httpGet(client *http.Client, url string) (response httpGetResponse, err error) {
r, err := client.Get(url)
if err != nil {
return
func (s *SocialBase) httpGet(ctx context.Context, client *http.Client, url string) (*httpGetResponse, error) {
req, errReq := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if errReq != nil {
return nil, errReq
}
r, errDo := client.Do(req)
if errDo != nil {
return nil, errDo
}
defer func() {
@ -54,21 +60,20 @@ func (s *SocialBase) httpGet(client *http.Client, url string) (response httpGetR
}
}()
body, err := io.ReadAll(r.Body)
if err != nil {
return
body, errRead := io.ReadAll(r.Body)
if errRead != nil {
return nil, errRead
}
response = httpGetResponse{body, r.Header}
response := &httpGetResponse{body, r.Header}
if r.StatusCode >= 300 {
err = fmt.Errorf(string(response.Body))
return
return nil, fmt.Errorf("unsuccessful response status code %d: %s", r.StatusCode, string(response.Body))
}
s.log.Debug("HTTP GET", "url", url, "status", r.Status, "response_body", string(response.Body))
err = nil
return
return response, nil
}
func (s *SocialBase) searchJSONForAttr(attributePath string, data []byte) (interface{}, error) {

View File

@ -3,6 +3,7 @@ package social
import (
"bytes"
"compress/zlib"
"context"
"encoding/base64"
"encoding/json"
"errors"
@ -50,12 +51,12 @@ func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool {
return false
}
func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
func (s *SocialGenericOAuth) IsTeamMember(ctx context.Context, client *http.Client) bool {
if len(s.teamIds) == 0 {
return true
}
teamMemberships, err := s.FetchTeamMemberships(client)
teamMemberships, err := s.FetchTeamMemberships(ctx, client)
if err != nil {
return false
}
@ -71,12 +72,12 @@ func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
return false
}
func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool {
func (s *SocialGenericOAuth) IsOrganizationMember(ctx context.Context, client *http.Client) bool {
if len(s.allowedOrganizations) == 0 {
return true
}
organizations, ok := s.FetchOrganizations(client)
organizations, ok := s.FetchOrganizations(ctx, client)
if !ok {
return false
}
@ -111,14 +112,14 @@ func (info *UserInfoJson) String() string {
info.Name, info.DisplayName, info.Login, info.Username, info.Email, info.Upn, info.Attributes)
}
func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
s.log.Debug("Getting user info")
toCheck := make([]*UserInfoJson, 0, 2)
if tokenData := s.extractFromToken(token); tokenData != nil {
toCheck = append(toCheck, tokenData)
}
if apiData := s.extractFromAPI(client); apiData != nil {
if apiData := s.extractFromAPI(ctx, client); apiData != nil {
toCheck = append(toCheck, apiData)
}
@ -179,7 +180,7 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
if userInfo.Email == "" {
var err error
userInfo.Email, err = s.FetchPrivateEmail(client)
userInfo.Email, err = s.FetchPrivateEmail(ctx, client)
if err != nil {
return nil, err
}
@ -191,11 +192,11 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
userInfo.Login = userInfo.Email
}
if !s.IsTeamMember(client) {
if !s.IsTeamMember(ctx, client) {
return nil, errors.New("user not a member of one of the required teams")
}
if !s.IsOrganizationMember(client) {
if !s.IsOrganizationMember(ctx, client) {
return nil, errors.New("user not a member of one of the required organizations")
}
@ -288,14 +289,14 @@ func (s *SocialGenericOAuth) extractFromToken(token *oauth2.Token) *UserInfoJson
return &data
}
func (s *SocialGenericOAuth) extractFromAPI(client *http.Client) *UserInfoJson {
func (s *SocialGenericOAuth) extractFromAPI(ctx context.Context, client *http.Client) *UserInfoJson {
s.log.Debug("Getting user info from API")
if s.apiUrl == "" {
s.log.Debug("No api url configured")
return nil
}
rawUserInfoResponse, err := s.httpGet(client, s.apiUrl)
rawUserInfoResponse, err := s.httpGet(ctx, client, s.apiUrl)
if err != nil {
s.log.Debug("Error getting user info from API", "url", s.apiUrl, "error", err)
return nil
@ -404,7 +405,7 @@ func (s *SocialGenericOAuth) extractGroups(data *UserInfoJson) ([]string, error)
return s.searchJSONForStringArrayAttr(s.groupsAttributePath, data.rawJSON)
}
func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) {
func (s *SocialGenericOAuth) FetchPrivateEmail(ctx context.Context, client *http.Client) (string, error) {
type Record struct {
Email string `json:"email"`
Primary bool `json:"primary"`
@ -413,7 +414,7 @@ func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, err
IsConfirmed bool `json:"is_confirmed"`
}
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/emails"))
if err != nil {
s.log.Error("Error getting email address", "url", s.apiUrl+"/emails", "error", err)
return "", fmt.Errorf("%v: %w", "Error getting email address", err)
@ -451,14 +452,14 @@ func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, err
return email, nil
}
func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]string, error) {
func (s *SocialGenericOAuth) FetchTeamMemberships(ctx context.Context, client *http.Client) ([]string, error) {
var err error
var ids []string
if s.teamsUrl == "" {
ids, err = s.fetchTeamMembershipsFromDeprecatedTeamsUrl(client)
ids, err = s.fetchTeamMembershipsFromDeprecatedTeamsUrl(ctx, client)
} else {
ids, err = s.fetchTeamMembershipsFromTeamsUrl(client)
ids, err = s.fetchTeamMembershipsFromTeamsUrl(ctx, client)
}
if err == nil {
@ -468,16 +469,14 @@ func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]string
return ids, err
}
func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(client *http.Client) ([]string, error) {
var response httpGetResponse
var err error
func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(ctx context.Context, client *http.Client) ([]string, error) {
var ids []string
type Record struct {
Id int `json:"id"`
}
response, err = s.httpGet(client, fmt.Sprintf(s.apiUrl+"/teams"))
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/teams"))
if err != nil {
s.log.Error("Error getting team memberships", "url", s.apiUrl+"/teams", "error", err)
return []string{}, err
@ -499,15 +498,12 @@ func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(client *
return ids, nil
}
func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(client *http.Client) ([]string, error) {
func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(ctx context.Context, client *http.Client) ([]string, error) {
if s.teamIdsAttributePath == "" {
return []string{}, nil
}
var response httpGetResponse
var err error
response, err = s.httpGet(client, fmt.Sprintf(s.teamsUrl))
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.teamsUrl))
if err != nil {
s.log.Error("Error getting team memberships", "url", s.teamsUrl, "error", err)
return nil, err
@ -516,12 +512,12 @@ func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(client *http.Clien
return s.searchJSONForStringArrayAttr(s.teamIdsAttributePath, response.Body)
}
func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, bool) {
func (s *SocialGenericOAuth) FetchOrganizations(ctx context.Context, client *http.Client) ([]string, bool) {
type Record struct {
Login string `json:"login"`
}
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/orgs"))
if err != nil {
s.log.Error("Error getting organizations", "url", s.apiUrl+"/orgs", "error", err)
return nil, false

View File

@ -1,6 +1,7 @@
package social
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
@ -490,7 +491,7 @@ func TestUserInfoSearchesForEmailAndRole(t *testing.T) {
}
token := staticToken.WithExtra(test.OAuth2Extra)
actualResult, err := provider.UserInfo(ts.Client(), token)
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
require.NoError(t, err)
require.Equal(t, test.ExpectedEmail, actualResult.Email)
require.Equal(t, test.ExpectedEmail, actualResult.Login)
@ -588,7 +589,7 @@ func TestUserInfoSearchesForLogin(t *testing.T) {
}
token := staticToken.WithExtra(test.OAuth2Extra)
actualResult, err := provider.UserInfo(ts.Client(), token)
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
require.NoError(t, err)
require.Equal(t, test.ExpectedLogin, actualResult.Login)
})
@ -686,7 +687,7 @@ func TestUserInfoSearchesForName(t *testing.T) {
}
token := staticToken.WithExtra(test.OAuth2Extra)
actualResult, err := provider.UserInfo(ts.Client(), token)
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
require.NoError(t, err)
require.Equal(t, test.ExpectedName, actualResult.Name)
})
@ -755,7 +756,7 @@ func TestUserInfoSearchesForGroup(t *testing.T) {
Expiry: time.Now(),
}
userInfo, err := provider.UserInfo(ts.Client(), token)
userInfo, err := provider.UserInfo(context.Background(), ts.Client(), token)
assert.NoError(t, err)
assert.Equal(t, test.expectedResult, userInfo.Groups)
})

View File

@ -1,6 +1,7 @@
package social
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -35,12 +36,12 @@ var (
ErrMissingOrganizationMembership = Error{"user not a member of one of the required organizations"}
)
func (s *SocialGithub) IsTeamMember(client *http.Client) bool {
func (s *SocialGithub) IsTeamMember(ctx context.Context, client *http.Client) bool {
if len(s.teamIds) == 0 {
return true
}
teamMemberships, err := s.FetchTeamMemberships(client)
teamMemberships, err := s.FetchTeamMemberships(ctx, client)
if err != nil {
return false
}
@ -56,12 +57,13 @@ func (s *SocialGithub) IsTeamMember(client *http.Client) bool {
return false
}
func (s *SocialGithub) IsOrganizationMember(client *http.Client, organizationsUrl string) bool {
func (s *SocialGithub) IsOrganizationMember(ctx context.Context,
client *http.Client, organizationsUrl string) bool {
if len(s.allowedOrganizations) == 0 {
return true
}
organizations, err := s.FetchOrganizations(client, organizationsUrl)
organizations, err := s.FetchOrganizations(ctx, client, organizationsUrl)
if err != nil {
return false
}
@ -77,14 +79,14 @@ func (s *SocialGithub) IsOrganizationMember(client *http.Client, organizationsUr
return false
}
func (s *SocialGithub) FetchPrivateEmail(client *http.Client) (string, error) {
func (s *SocialGithub) FetchPrivateEmail(ctx context.Context, client *http.Client) (string, error) {
type Record struct {
Email string `json:"email"`
Primary bool `json:"primary"`
Verified bool `json:"verified"`
}
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/emails"))
if err != nil {
return "", fmt.Errorf("Error getting email address: %s", err)
}
@ -106,13 +108,13 @@ func (s *SocialGithub) FetchPrivateEmail(client *http.Client) (string, error) {
return email, nil
}
func (s *SocialGithub) FetchTeamMemberships(client *http.Client) ([]GithubTeam, error) {
func (s *SocialGithub) FetchTeamMemberships(ctx context.Context, client *http.Client) ([]GithubTeam, error) {
url := fmt.Sprintf(s.apiUrl + "/teams?per_page=100")
hasMore := true
teams := make([]GithubTeam, 0)
for hasMore {
response, err := s.httpGet(client, url)
response, err := s.httpGet(ctx, client, url)
if err != nil {
return nil, fmt.Errorf("Error getting team memberships: %s", err)
}
@ -150,7 +152,7 @@ func (s *SocialGithub) HasMoreRecords(headers http.Header) (string, bool) {
return url, true
}
func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl string) ([]string, error) {
func (s *SocialGithub) FetchOrganizations(ctx context.Context, client *http.Client, organizationsUrl string) ([]string, error) {
url := organizationsUrl
hasMore := true
logins := make([]string, 0)
@ -160,7 +162,7 @@ func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl
}
for hasMore {
response, err := s.httpGet(client, url)
response, err := s.httpGet(ctx, client, url)
if err != nil {
return nil, fmt.Errorf("error getting organizations: %s", err)
}
@ -181,7 +183,7 @@ func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl
return logins, nil
}
func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
func (s *SocialGithub) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
var data struct {
Id int `json:"id"`
Login string `json:"login"`
@ -189,7 +191,7 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
Name string `json:"name"`
}
response, err := s.httpGet(client, s.apiUrl)
response, err := s.httpGet(ctx, client, s.apiUrl)
if err != nil {
return nil, fmt.Errorf("error getting user info: %s", err)
}
@ -198,7 +200,7 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
return nil, fmt.Errorf("error unmarshalling user info: %s", err)
}
teamMemberships, err := s.FetchTeamMemberships(client)
teamMemberships, err := s.FetchTeamMemberships(ctx, client)
if err != nil {
return nil, fmt.Errorf("error getting user teams: %s", err)
}
@ -241,16 +243,16 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
organizationsUrl := fmt.Sprintf(s.apiUrl + "/orgs?per_page=100")
if !s.IsTeamMember(client) {
if !s.IsTeamMember(ctx, client) {
return nil, ErrMissingTeamMembership
}
if !s.IsOrganizationMember(client, organizationsUrl) {
if !s.IsOrganizationMember(ctx, client, organizationsUrl) {
return nil, ErrMissingOrganizationMembership
}
if userInfo.Email == "" {
userInfo.Email, err = s.FetchPrivateEmail(client)
userInfo.Email, err = s.FetchPrivateEmail(ctx, client)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package social
import (
"context"
"net/http"
"net/http/httptest"
"reflect"
@ -250,7 +251,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) {
AccessToken: "fake_token",
}
got, err := s.UserInfo(server.Client(), token)
got, err := s.UserInfo(context.Background(), server.Client(), token)
if (err != nil) != tt.wantErr {
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@ -1,6 +1,7 @@
package social
import (
"context"
"encoding/json"
"fmt"
"net/http"
@ -34,10 +35,10 @@ func (s *SocialGitlab) IsGroupMember(groups []string) bool {
return false
}
func (s *SocialGitlab) GetGroups(client *http.Client) []string {
func (s *SocialGitlab) GetGroups(ctx context.Context, client *http.Client) []string {
groups := make([]string, 0)
for page, url := s.GetGroupsPage(client, s.apiUrl+"/groups"); page != nil; page, url = s.GetGroupsPage(client, url) {
for page, url := s.GetGroupsPage(ctx, client, s.apiUrl+"/groups"); page != nil; page, url = s.GetGroupsPage(ctx, client, url) {
groups = append(groups, page...)
}
@ -45,7 +46,7 @@ func (s *SocialGitlab) GetGroups(client *http.Client) []string {
}
// GetGroupsPage returns groups and link to the next page if response is paginated
func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string, string) {
func (s *SocialGitlab) GetGroupsPage(ctx context.Context, client *http.Client, url string) ([]string, string) {
type Group struct {
FullPath string `json:"full_path"`
}
@ -59,7 +60,7 @@ func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string,
return nil, next
}
response, err := s.httpGet(client, url)
response, err := s.httpGet(ctx, client, url)
if err != nil {
s.log.Error("Error getting groups from GitLab API", "err", err)
return nil, next
@ -86,7 +87,7 @@ func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string,
return fullPaths, next
}
func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
func (s *SocialGitlab) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
var data struct {
Id int
Username string
@ -95,7 +96,7 @@ func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUse
State string
}
response, err := s.httpGet(client, s.apiUrl+"/user")
response, err := s.httpGet(ctx, client, s.apiUrl+"/user")
if err != nil {
return nil, fmt.Errorf("Error getting user info: %s", err)
}
@ -108,7 +109,7 @@ func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUse
return nil, fmt.Errorf("user %s is inactive", data.Username)
}
groups := s.GetGroups(client)
groups := s.GetGroups(ctx, client)
var role roletype.RoleType
var isGrafanaAdmin *bool = nil

View File

@ -1,6 +1,7 @@
package social
import (
"context"
"net/http"
"net/http/httptest"
"strings"
@ -159,7 +160,7 @@ func TestSocialGitlab_UserInfo(t *testing.T) {
}
}))
provider.apiUrl = ts.URL + apiURI
actualResult, err := provider.UserInfo(ts.Client(), nil)
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), nil)
if test.ExpectedError != nil {
require.Equal(t, err, test.ExpectedError)
return

View File

@ -1,6 +1,7 @@
package social
import (
"context"
"encoding/json"
"fmt"
"net/http"
@ -16,14 +17,14 @@ type SocialGoogle struct {
apiUrl string
}
func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
var data struct {
Id string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
}
response, err := s.httpGet(client, s.apiUrl)
response, err := s.httpGet(ctx, client, s.apiUrl)
if err != nil {
return nil, fmt.Errorf("Error getting user info: %s", err)
}

View File

@ -1,6 +1,7 @@
package social
import (
"context"
"encoding/json"
"fmt"
"net/http"
@ -43,7 +44,7 @@ func (s *SocialGrafanaCom) IsOrganizationMember(organizations []OrgRecord) bool
}
// UserInfo is used for login credentials for the user
func (s *SocialGrafanaCom) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
var data struct {
Id int `json:"id"`
Name string `json:"name"`
@ -53,7 +54,8 @@ func (s *SocialGrafanaCom) UserInfo(client *http.Client, _ *oauth2.Token) (*Basi
Orgs []OrgRecord `json:"orgs"`
}
response, err := s.httpGet(client, s.url+"/api/oauth2/user")
response, err := s.httpGet(ctx, client, s.url+"/api/oauth2/user")
if err != nil {
return nil, fmt.Errorf("Error getting user info: %s", err)
}

View File

@ -1,6 +1,7 @@
package social
import (
"context"
"net/http"
"net/http/httptest"
"testing"
@ -81,7 +82,7 @@ func TestSocialGrafanaCom_UserInfo(t *testing.T) {
}
}))
provider.url = ts.URL
actualResult, err := provider.UserInfo(ts.Client(), nil)
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), nil)
if test.ExpectedError != nil {
require.Equal(t, err, test.ExpectedError)
return

View File

@ -1,6 +1,7 @@
package social
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -46,7 +47,7 @@ func (claims *OktaClaims) extractEmail() string {
return claims.Email
}
func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
idToken := token.Extra("id_token")
if idToken == nil {
return nil, fmt.Errorf("no id_token found")
@ -68,7 +69,7 @@ func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicU
}
var data OktaUserInfoJson
err = s.extractAPI(&data, client)
err = s.extractAPI(ctx, &data, client)
if err != nil {
return nil, err
}
@ -105,8 +106,8 @@ func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicU
}, nil
}
func (s *SocialOkta) extractAPI(data *OktaUserInfoJson, client *http.Client) error {
rawUserInfoResponse, err := s.httpGet(client, s.apiUrl)
func (s *SocialOkta) extractAPI(ctx context.Context, data *OktaUserInfoJson, client *http.Client) error {
rawUserInfoResponse, err := s.httpGet(ctx, client, s.apiUrl)
if err != nil {
s.log.Debug("Error getting user info response", "url", s.apiUrl, "error", err)
return fmt.Errorf("error getting user info response: %w", err)

View File

@ -1,6 +1,7 @@
package social
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@ -110,7 +111,7 @@ func TestSocialOkta_UserInfo(t *testing.T) {
Expiry: time.Now(),
}
token := staticToken.WithExtra(tt.OAuth2Extra)
got, err := provider.UserInfo(server.Client(), token)
got, err := provider.UserInfo(context.Background(), server.Client(), token)
if (err != nil) != tt.wantErr {
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
return

View File

@ -7,9 +7,11 @@ import (
"crypto/x509"
"encoding/json"
"fmt"
"net"
"net/http"
"os"
"strings"
"time"
"golang.org/x/oauth2"
"golang.org/x/text/cases"
@ -261,7 +263,7 @@ func (b *BasicUserInfo) String() string {
}
type SocialConnector interface {
UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error)
UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error)
IsEmailAllowed(email string) bool
IsSignupAllowed() bool
@ -450,9 +452,19 @@ func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) {
TLSClientConfig: &tls.Config{
InsecureSkipVerify: info.TlsSkipVerify,
},
DialContext: (&net.Dialer{
Timeout: time.Second * 10,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 15 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
}
oauthClient := &http.Client{
Transport: tr,
Timeout: time.Second * 15,
}
if info.TlsClientCert != "" || info.TlsClientKey != "" {

View File

@ -116,7 +116,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
}
token.TokenType = "Bearer"
userInfo, err := c.connector.UserInfo(c.connector.Client(clientCtx, token), token)
userInfo, err := c.connector.UserInfo(ctx, c.connector.Client(clientCtx, token), token)
if err != nil {
var sErr *social.Error
if errors.As(err, &sErr) {

View File

@ -8,13 +8,14 @@ import (
"golang.org/x/oauth2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOAuth_Authenticate(t *testing.T) {
@ -278,7 +279,7 @@ type fakeConnector struct {
social.SocialConnector
}
func (f fakeConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
func (f fakeConnector) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
return f.ExpectedUserInfo, f.ExpectedUserInfoErr
}

View File

@ -273,7 +273,7 @@ func (m *MockSocialConnector) Type() int {
return args.Int(0)
}
func (m *MockSocialConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
func (m *MockSocialConnector) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
args := m.Called(client, token)
return args.Get(0).(*social.BasicUserInfo), args.Error(1)
}