mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Add request context to UserInfo calls (#70007)
* use context for UserInfo requests * set timeouts for oauth http client * Update pkg/login/social/common.go Co-authored-by: Ieva <ieva.vasiljeva@grafana.com> --------- Co-authored-by: Ieva <ieva.vasiljeva@grafana.com>
This commit is contained in:
parent
1445a7cc5c
commit
914daef0fd
@ -70,61 +70,61 @@ func genPKCECode() (string, string, error) {
|
||||
return string(ascii), pkce, nil
|
||||
}
|
||||
|
||||
func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
||||
name := web.Params(ctx.Req)[":name"]
|
||||
func (hs *HTTPServer) OAuthLogin(reqCtx *contextmodel.ReqContext) {
|
||||
name := web.Params(reqCtx.Req)[":name"]
|
||||
loginInfo := loginservice.LoginInfo{AuthModule: name}
|
||||
|
||||
if errorParam := ctx.Query("error"); errorParam != "" {
|
||||
errorDesc := ctx.Query("error_description")
|
||||
if errorParam := reqCtx.Query("error"); errorParam != "" {
|
||||
errorDesc := reqCtx.Query("error_description")
|
||||
oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc)
|
||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc)
|
||||
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc)
|
||||
return
|
||||
}
|
||||
|
||||
code := ctx.Query("code")
|
||||
code := reqCtx.Query("code")
|
||||
|
||||
if hs.Cfg.AuthBrokerEnabled {
|
||||
req := &authn.Request{HTTPRequest: ctx.Req, Resp: ctx.Resp}
|
||||
req := &authn.Request{HTTPRequest: reqCtx.Req, Resp: reqCtx.Resp}
|
||||
if code == "" {
|
||||
redirect, err := hs.authnService.RedirectURL(ctx.Req.Context(), authn.ClientWithPrefix(name), req)
|
||||
redirect, err := hs.authnService.RedirectURL(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
|
||||
if err != nil {
|
||||
ctx.Redirect(hs.redirectURLWithErrorCookie(ctx, err))
|
||||
reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))
|
||||
return
|
||||
}
|
||||
|
||||
if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" {
|
||||
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
}
|
||||
|
||||
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
ctx.Redirect(redirect.URL)
|
||||
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
reqCtx.Redirect(redirect.URL)
|
||||
return
|
||||
}
|
||||
|
||||
identity, err := hs.authnService.Login(ctx.Req.Context(), authn.ClientWithPrefix(name), req)
|
||||
identity, err := hs.authnService.Login(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
|
||||
// NOTE: always delete these cookies, even if login failed
|
||||
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
||||
cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
||||
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
||||
cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
||||
|
||||
if err != nil {
|
||||
ctx.Redirect(hs.redirectURLWithErrorCookie(ctx, err))
|
||||
reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))
|
||||
return
|
||||
}
|
||||
|
||||
metrics.MApiLoginOAuth.Inc()
|
||||
authn.HandleLoginRedirect(ctx.Req, ctx.Resp, hs.Cfg, identity, hs.ValidateRedirectTo)
|
||||
authn.HandleLoginRedirect(reqCtx.Req, reqCtx.Resp, hs.Cfg, identity, hs.ValidateRedirectTo)
|
||||
return
|
||||
}
|
||||
|
||||
provider := hs.SocialService.GetOAuthInfoProvider(name)
|
||||
if provider == nil {
|
||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, errors.New("OAuth not enabled"))
|
||||
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, errors.New("OAuth not enabled"))
|
||||
return
|
||||
}
|
||||
|
||||
connect, err := hs.SocialService.GetConnector(name)
|
||||
if err != nil {
|
||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, fmt.Errorf("no OAuth with name %s configured", name))
|
||||
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, fmt.Errorf("no OAuth with name %s configured", name))
|
||||
return
|
||||
}
|
||||
|
||||
@ -133,15 +133,15 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
||||
if provider.UsePKCE {
|
||||
ascii, pkce, err := genPKCECode()
|
||||
if err != nil {
|
||||
ctx.Logger.Error("Generating PKCE failed", "error", err)
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
reqCtx.Logger.Error("Generating PKCE failed", "error", err)
|
||||
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||
HttpStatus: http.StatusInternalServerError,
|
||||
PublicMessage: "An internal error occurred",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
|
||||
opts = append(opts,
|
||||
oauth2.SetAuthURLParam("code_challenge", pkce),
|
||||
@ -151,8 +151,8 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
||||
|
||||
state, err := GenStateString()
|
||||
if err != nil {
|
||||
ctx.Logger.Error("Generating state string failed", "err", err)
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
reqCtx.Logger.Error("Generating state string failed", "err", err)
|
||||
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||
HttpStatus: http.StatusInternalServerError,
|
||||
PublicMessage: "An internal error occurred",
|
||||
})
|
||||
@ -160,32 +160,32 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
||||
}
|
||||
|
||||
hashedState := hs.hashStatecode(state, provider.ClientSecret)
|
||||
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
if provider.HostedDomain != "" {
|
||||
opts = append(opts, oauth2.SetAuthURLParam("hd", provider.HostedDomain))
|
||||
}
|
||||
|
||||
ctx.Redirect(connect.AuthCodeURL(state, opts...))
|
||||
reqCtx.Redirect(connect.AuthCodeURL(state, opts...))
|
||||
return
|
||||
}
|
||||
|
||||
cookieState := ctx.GetCookie(OauthStateCookieName)
|
||||
cookieState := reqCtx.GetCookie(OauthStateCookieName)
|
||||
|
||||
// delete cookie
|
||||
cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
||||
cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
||||
|
||||
if cookieState == "" {
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||
HttpStatus: http.StatusInternalServerError,
|
||||
PublicMessage: "login.OAuthLogin(missing saved state)",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
queryState := hs.hashStatecode(ctx.Query("state"), provider.ClientSecret)
|
||||
queryState := hs.hashStatecode(reqCtx.Query("state"), provider.ClientSecret)
|
||||
oauthLogger.Info("state check", "queryState", queryState, "cookieState", cookieState)
|
||||
if cookieState != queryState {
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||
HttpStatus: http.StatusInternalServerError,
|
||||
PublicMessage: "login.OAuthLogin(state mismatch)",
|
||||
})
|
||||
@ -194,19 +194,20 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
||||
|
||||
oauthClient, err := hs.SocialService.GetOAuthHttpClient(name)
|
||||
if err != nil {
|
||||
ctx.Logger.Error("Failed to create OAuth http client", "error", err)
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
reqCtx.Logger.Error("Failed to create OAuth http client", "error", err)
|
||||
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||
HttpStatus: http.StatusInternalServerError,
|
||||
PublicMessage: "login.OAuthLogin(" + err.Error() + ")",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
oauthCtx := context.WithValue(context.Background(), oauth2.HTTPClient, oauthClient)
|
||||
ctx := reqCtx.Req.Context()
|
||||
oauthCtx := context.WithValue(ctx, oauth2.HTTPClient, oauthClient)
|
||||
opts := []oauth2.AuthCodeOption{}
|
||||
|
||||
codeVerifier := ctx.GetCookie(OauthPKCECookieName)
|
||||
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
||||
codeVerifier := reqCtx.GetCookie(OauthPKCECookieName)
|
||||
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
||||
if codeVerifier != "" {
|
||||
opts = append(opts,
|
||||
oauth2.SetAuthURLParam("code_verifier", codeVerifier),
|
||||
@ -216,7 +217,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
||||
// get token from provider
|
||||
token, err := connect.Exchange(oauthCtx, code, opts...)
|
||||
if err != nil {
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||
HttpStatus: http.StatusInternalServerError,
|
||||
PublicMessage: "login.OAuthLogin(NewTransportWithCode)",
|
||||
Err: err,
|
||||
@ -245,13 +246,13 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
||||
client := connect.Client(oauthCtx, token)
|
||||
|
||||
// get user info
|
||||
userInfo, err := connect.UserInfo(client, token)
|
||||
userInfo, err := connect.UserInfo(ctx, client, token)
|
||||
if err != nil {
|
||||
var sErr *social.Error
|
||||
if errors.As(err, &sErr) {
|
||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, sErr)
|
||||
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, sErr)
|
||||
} else {
|
||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
||||
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||
HttpStatus: http.StatusInternalServerError,
|
||||
PublicMessage: fmt.Sprintf("login.OAuthLogin(get info from %s)", name),
|
||||
Err: err,
|
||||
@ -264,34 +265,34 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
||||
|
||||
// validate that we got at least an email address
|
||||
if userInfo.Email == "" {
|
||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrNoEmail)
|
||||
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrNoEmail)
|
||||
return
|
||||
}
|
||||
|
||||
// validate that the email is allowed to login to grafana
|
||||
if !connect.IsEmailAllowed(userInfo.Email) {
|
||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrEmailNotAllowed)
|
||||
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrEmailNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
loginInfo.ExternalUser = *hs.buildExternalUserInfo(token, userInfo, name)
|
||||
loginInfo.User, err = hs.SyncUser(ctx, &loginInfo.ExternalUser, connect)
|
||||
loginInfo.User, err = hs.SyncUser(reqCtx, &loginInfo.ExternalUser, connect)
|
||||
if err != nil {
|
||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, err)
|
||||
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, err)
|
||||
return
|
||||
}
|
||||
|
||||
// login
|
||||
if err := hs.loginUserWithUser(loginInfo.User, ctx); err != nil {
|
||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, err)
|
||||
if err := hs.loginUserWithUser(loginInfo.User, reqCtx); err != nil {
|
||||
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, err)
|
||||
return
|
||||
}
|
||||
|
||||
loginInfo.HTTPStatus = http.StatusOK
|
||||
hs.HooksService.RunLoginHook(&loginInfo, ctx)
|
||||
hs.HooksService.RunLoginHook(&loginInfo, reqCtx)
|
||||
metrics.MApiLoginOAuth.Inc()
|
||||
|
||||
ctx.Redirect(hs.GetRedirectURL(ctx))
|
||||
reqCtx.Redirect(hs.GetRedirectURL(reqCtx))
|
||||
}
|
||||
|
||||
// buildExternalUserInfo returns a ExternalUserInfo struct from OAuth user profile
|
||||
|
@ -2,6 +2,7 @@ package social
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -47,7 +48,7 @@ type azureAccessClaims struct {
|
||||
TenantID string `json:"tid"`
|
||||
}
|
||||
|
||||
func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
idToken := token.Extra("id_token")
|
||||
if idToken == nil {
|
||||
return nil, ErrIDTokenNotFound
|
||||
@ -83,7 +84,7 @@ func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*Bas
|
||||
}
|
||||
logger.Debug("AzureAD OAuth: extracted role", "email", email, "role", role)
|
||||
|
||||
groups, err := s.extractGroups(client, claims, token)
|
||||
groups, err := s.extractGroups(ctx, client, claims, token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extract groups: %w", err)
|
||||
}
|
||||
@ -176,7 +177,7 @@ type getAzureGroupResponse struct {
|
||||
// Note: If user groups exceeds 200 no groups will be found in claims and URL to target the Graph API will be
|
||||
// given instead.
|
||||
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
|
||||
func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
|
||||
func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
|
||||
if !s.forceUseGraphAPI {
|
||||
logger.Debug("checking the claim for groups")
|
||||
if len(claims.Groups) > 0 {
|
||||
@ -199,7 +200,13 @@ func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, t
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := client.Post(endpoint, "application/json", bytes.NewBuffer(data))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -473,7 +473,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
||||
tt.args.client = s.Client(context.Background(), token)
|
||||
}
|
||||
|
||||
got, err := s.UserInfo(tt.args.client, token)
|
||||
got, err := s.UserInfo(context.Background(), tt.args.client, token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@ -617,7 +617,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
|
||||
tt.args.client = s.Client(context.Background(), token)
|
||||
}
|
||||
|
||||
got, err := s.UserInfo(tt.args.client, token)
|
||||
got, err := s.UserInfo(context.Background(), tt.args.client, token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -1,6 +1,7 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -42,10 +43,15 @@ func isEmailAllowed(email string, allowedDomains []string) bool {
|
||||
return valid
|
||||
}
|
||||
|
||||
func (s *SocialBase) httpGet(client *http.Client, url string) (response httpGetResponse, err error) {
|
||||
r, err := client.Get(url)
|
||||
if err != nil {
|
||||
return
|
||||
func (s *SocialBase) httpGet(ctx context.Context, client *http.Client, url string) (*httpGetResponse, error) {
|
||||
req, errReq := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
}
|
||||
|
||||
r, errDo := client.Do(req)
|
||||
if errDo != nil {
|
||||
return nil, errDo
|
||||
}
|
||||
|
||||
defer func() {
|
||||
@ -54,21 +60,20 @@ func (s *SocialBase) httpGet(client *http.Client, url string) (response httpGetR
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return
|
||||
body, errRead := io.ReadAll(r.Body)
|
||||
if errRead != nil {
|
||||
return nil, errRead
|
||||
}
|
||||
|
||||
response = httpGetResponse{body, r.Header}
|
||||
response := &httpGetResponse{body, r.Header}
|
||||
|
||||
if r.StatusCode >= 300 {
|
||||
err = fmt.Errorf(string(response.Body))
|
||||
return
|
||||
return nil, fmt.Errorf("unsuccessful response status code %d: %s", r.StatusCode, string(response.Body))
|
||||
}
|
||||
|
||||
s.log.Debug("HTTP GET", "url", url, "status", r.Status, "response_body", string(response.Body))
|
||||
|
||||
err = nil
|
||||
return
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *SocialBase) searchJSONForAttr(attributePath string, data []byte) (interface{}, error) {
|
||||
|
@ -3,6 +3,7 @@ package social
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@ -50,12 +51,12 @@ func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
|
||||
func (s *SocialGenericOAuth) IsTeamMember(ctx context.Context, client *http.Client) bool {
|
||||
if len(s.teamIds) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
teamMemberships, err := s.FetchTeamMemberships(client)
|
||||
teamMemberships, err := s.FetchTeamMemberships(ctx, client)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@ -71,12 +72,12 @@ func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool {
|
||||
func (s *SocialGenericOAuth) IsOrganizationMember(ctx context.Context, client *http.Client) bool {
|
||||
if len(s.allowedOrganizations) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
organizations, ok := s.FetchOrganizations(client)
|
||||
organizations, ok := s.FetchOrganizations(ctx, client)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@ -111,14 +112,14 @@ func (info *UserInfoJson) String() string {
|
||||
info.Name, info.DisplayName, info.Login, info.Username, info.Email, info.Upn, info.Attributes)
|
||||
}
|
||||
|
||||
func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
s.log.Debug("Getting user info")
|
||||
toCheck := make([]*UserInfoJson, 0, 2)
|
||||
|
||||
if tokenData := s.extractFromToken(token); tokenData != nil {
|
||||
toCheck = append(toCheck, tokenData)
|
||||
}
|
||||
if apiData := s.extractFromAPI(client); apiData != nil {
|
||||
if apiData := s.extractFromAPI(ctx, client); apiData != nil {
|
||||
toCheck = append(toCheck, apiData)
|
||||
}
|
||||
|
||||
@ -179,7 +180,7 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
|
||||
|
||||
if userInfo.Email == "" {
|
||||
var err error
|
||||
userInfo.Email, err = s.FetchPrivateEmail(client)
|
||||
userInfo.Email, err = s.FetchPrivateEmail(ctx, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -191,11 +192,11 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
|
||||
userInfo.Login = userInfo.Email
|
||||
}
|
||||
|
||||
if !s.IsTeamMember(client) {
|
||||
if !s.IsTeamMember(ctx, client) {
|
||||
return nil, errors.New("user not a member of one of the required teams")
|
||||
}
|
||||
|
||||
if !s.IsOrganizationMember(client) {
|
||||
if !s.IsOrganizationMember(ctx, client) {
|
||||
return nil, errors.New("user not a member of one of the required organizations")
|
||||
}
|
||||
|
||||
@ -288,14 +289,14 @@ func (s *SocialGenericOAuth) extractFromToken(token *oauth2.Token) *UserInfoJson
|
||||
return &data
|
||||
}
|
||||
|
||||
func (s *SocialGenericOAuth) extractFromAPI(client *http.Client) *UserInfoJson {
|
||||
func (s *SocialGenericOAuth) extractFromAPI(ctx context.Context, client *http.Client) *UserInfoJson {
|
||||
s.log.Debug("Getting user info from API")
|
||||
if s.apiUrl == "" {
|
||||
s.log.Debug("No api url configured")
|
||||
return nil
|
||||
}
|
||||
|
||||
rawUserInfoResponse, err := s.httpGet(client, s.apiUrl)
|
||||
rawUserInfoResponse, err := s.httpGet(ctx, client, s.apiUrl)
|
||||
if err != nil {
|
||||
s.log.Debug("Error getting user info from API", "url", s.apiUrl, "error", err)
|
||||
return nil
|
||||
@ -404,7 +405,7 @@ func (s *SocialGenericOAuth) extractGroups(data *UserInfoJson) ([]string, error)
|
||||
return s.searchJSONForStringArrayAttr(s.groupsAttributePath, data.rawJSON)
|
||||
}
|
||||
|
||||
func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) {
|
||||
func (s *SocialGenericOAuth) FetchPrivateEmail(ctx context.Context, client *http.Client) (string, error) {
|
||||
type Record struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
@ -413,7 +414,7 @@ func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, err
|
||||
IsConfirmed bool `json:"is_confirmed"`
|
||||
}
|
||||
|
||||
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
|
||||
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/emails"))
|
||||
if err != nil {
|
||||
s.log.Error("Error getting email address", "url", s.apiUrl+"/emails", "error", err)
|
||||
return "", fmt.Errorf("%v: %w", "Error getting email address", err)
|
||||
@ -451,14 +452,14 @@ func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, err
|
||||
return email, nil
|
||||
}
|
||||
|
||||
func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]string, error) {
|
||||
func (s *SocialGenericOAuth) FetchTeamMemberships(ctx context.Context, client *http.Client) ([]string, error) {
|
||||
var err error
|
||||
var ids []string
|
||||
|
||||
if s.teamsUrl == "" {
|
||||
ids, err = s.fetchTeamMembershipsFromDeprecatedTeamsUrl(client)
|
||||
ids, err = s.fetchTeamMembershipsFromDeprecatedTeamsUrl(ctx, client)
|
||||
} else {
|
||||
ids, err = s.fetchTeamMembershipsFromTeamsUrl(client)
|
||||
ids, err = s.fetchTeamMembershipsFromTeamsUrl(ctx, client)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
@ -468,16 +469,14 @@ func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]string
|
||||
return ids, err
|
||||
}
|
||||
|
||||
func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(client *http.Client) ([]string, error) {
|
||||
var response httpGetResponse
|
||||
var err error
|
||||
func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(ctx context.Context, client *http.Client) ([]string, error) {
|
||||
var ids []string
|
||||
|
||||
type Record struct {
|
||||
Id int `json:"id"`
|
||||
}
|
||||
|
||||
response, err = s.httpGet(client, fmt.Sprintf(s.apiUrl+"/teams"))
|
||||
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/teams"))
|
||||
if err != nil {
|
||||
s.log.Error("Error getting team memberships", "url", s.apiUrl+"/teams", "error", err)
|
||||
return []string{}, err
|
||||
@ -499,15 +498,12 @@ func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(client *
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(client *http.Client) ([]string, error) {
|
||||
func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(ctx context.Context, client *http.Client) ([]string, error) {
|
||||
if s.teamIdsAttributePath == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
var response httpGetResponse
|
||||
var err error
|
||||
|
||||
response, err = s.httpGet(client, fmt.Sprintf(s.teamsUrl))
|
||||
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.teamsUrl))
|
||||
if err != nil {
|
||||
s.log.Error("Error getting team memberships", "url", s.teamsUrl, "error", err)
|
||||
return nil, err
|
||||
@ -516,12 +512,12 @@ func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(client *http.Clien
|
||||
return s.searchJSONForStringArrayAttr(s.teamIdsAttributePath, response.Body)
|
||||
}
|
||||
|
||||
func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, bool) {
|
||||
func (s *SocialGenericOAuth) FetchOrganizations(ctx context.Context, client *http.Client) ([]string, bool) {
|
||||
type Record struct {
|
||||
Login string `json:"login"`
|
||||
}
|
||||
|
||||
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
|
||||
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/orgs"))
|
||||
if err != nil {
|
||||
s.log.Error("Error getting organizations", "url", s.apiUrl+"/orgs", "error", err)
|
||||
return nil, false
|
||||
|
@ -1,6 +1,7 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -490,7 +491,7 @@ func TestUserInfoSearchesForEmailAndRole(t *testing.T) {
|
||||
}
|
||||
|
||||
token := staticToken.WithExtra(test.OAuth2Extra)
|
||||
actualResult, err := provider.UserInfo(ts.Client(), token)
|
||||
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.ExpectedEmail, actualResult.Email)
|
||||
require.Equal(t, test.ExpectedEmail, actualResult.Login)
|
||||
@ -588,7 +589,7 @@ func TestUserInfoSearchesForLogin(t *testing.T) {
|
||||
}
|
||||
|
||||
token := staticToken.WithExtra(test.OAuth2Extra)
|
||||
actualResult, err := provider.UserInfo(ts.Client(), token)
|
||||
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.ExpectedLogin, actualResult.Login)
|
||||
})
|
||||
@ -686,7 +687,7 @@ func TestUserInfoSearchesForName(t *testing.T) {
|
||||
}
|
||||
|
||||
token := staticToken.WithExtra(test.OAuth2Extra)
|
||||
actualResult, err := provider.UserInfo(ts.Client(), token)
|
||||
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.ExpectedName, actualResult.Name)
|
||||
})
|
||||
@ -755,7 +756,7 @@ func TestUserInfoSearchesForGroup(t *testing.T) {
|
||||
Expiry: time.Now(),
|
||||
}
|
||||
|
||||
userInfo, err := provider.UserInfo(ts.Client(), token)
|
||||
userInfo, err := provider.UserInfo(context.Background(), ts.Client(), token)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.expectedResult, userInfo.Groups)
|
||||
})
|
||||
|
@ -1,6 +1,7 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -35,12 +36,12 @@ var (
|
||||
ErrMissingOrganizationMembership = Error{"user not a member of one of the required organizations"}
|
||||
)
|
||||
|
||||
func (s *SocialGithub) IsTeamMember(client *http.Client) bool {
|
||||
func (s *SocialGithub) IsTeamMember(ctx context.Context, client *http.Client) bool {
|
||||
if len(s.teamIds) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
teamMemberships, err := s.FetchTeamMemberships(client)
|
||||
teamMemberships, err := s.FetchTeamMemberships(ctx, client)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@ -56,12 +57,13 @@ func (s *SocialGithub) IsTeamMember(client *http.Client) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SocialGithub) IsOrganizationMember(client *http.Client, organizationsUrl string) bool {
|
||||
func (s *SocialGithub) IsOrganizationMember(ctx context.Context,
|
||||
client *http.Client, organizationsUrl string) bool {
|
||||
if len(s.allowedOrganizations) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
organizations, err := s.FetchOrganizations(client, organizationsUrl)
|
||||
organizations, err := s.FetchOrganizations(ctx, client, organizationsUrl)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@ -77,14 +79,14 @@ func (s *SocialGithub) IsOrganizationMember(client *http.Client, organizationsUr
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SocialGithub) FetchPrivateEmail(client *http.Client) (string, error) {
|
||||
func (s *SocialGithub) FetchPrivateEmail(ctx context.Context, client *http.Client) (string, error) {
|
||||
type Record struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
Verified bool `json:"verified"`
|
||||
}
|
||||
|
||||
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
|
||||
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/emails"))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Error getting email address: %s", err)
|
||||
}
|
||||
@ -106,13 +108,13 @@ func (s *SocialGithub) FetchPrivateEmail(client *http.Client) (string, error) {
|
||||
return email, nil
|
||||
}
|
||||
|
||||
func (s *SocialGithub) FetchTeamMemberships(client *http.Client) ([]GithubTeam, error) {
|
||||
func (s *SocialGithub) FetchTeamMemberships(ctx context.Context, client *http.Client) ([]GithubTeam, error) {
|
||||
url := fmt.Sprintf(s.apiUrl + "/teams?per_page=100")
|
||||
hasMore := true
|
||||
teams := make([]GithubTeam, 0)
|
||||
|
||||
for hasMore {
|
||||
response, err := s.httpGet(client, url)
|
||||
response, err := s.httpGet(ctx, client, url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error getting team memberships: %s", err)
|
||||
}
|
||||
@ -150,7 +152,7 @@ func (s *SocialGithub) HasMoreRecords(headers http.Header) (string, bool) {
|
||||
return url, true
|
||||
}
|
||||
|
||||
func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl string) ([]string, error) {
|
||||
func (s *SocialGithub) FetchOrganizations(ctx context.Context, client *http.Client, organizationsUrl string) ([]string, error) {
|
||||
url := organizationsUrl
|
||||
hasMore := true
|
||||
logins := make([]string, 0)
|
||||
@ -160,7 +162,7 @@ func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl
|
||||
}
|
||||
|
||||
for hasMore {
|
||||
response, err := s.httpGet(client, url)
|
||||
response, err := s.httpGet(ctx, client, url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting organizations: %s", err)
|
||||
}
|
||||
@ -181,7 +183,7 @@ func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl
|
||||
return logins, nil
|
||||
}
|
||||
|
||||
func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
func (s *SocialGithub) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
var data struct {
|
||||
Id int `json:"id"`
|
||||
Login string `json:"login"`
|
||||
@ -189,7 +191,7 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
response, err := s.httpGet(client, s.apiUrl)
|
||||
response, err := s.httpGet(ctx, client, s.apiUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting user info: %s", err)
|
||||
}
|
||||
@ -198,7 +200,7 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
|
||||
return nil, fmt.Errorf("error unmarshalling user info: %s", err)
|
||||
}
|
||||
|
||||
teamMemberships, err := s.FetchTeamMemberships(client)
|
||||
teamMemberships, err := s.FetchTeamMemberships(ctx, client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting user teams: %s", err)
|
||||
}
|
||||
@ -241,16 +243,16 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
|
||||
|
||||
organizationsUrl := fmt.Sprintf(s.apiUrl + "/orgs?per_page=100")
|
||||
|
||||
if !s.IsTeamMember(client) {
|
||||
if !s.IsTeamMember(ctx, client) {
|
||||
return nil, ErrMissingTeamMembership
|
||||
}
|
||||
|
||||
if !s.IsOrganizationMember(client, organizationsUrl) {
|
||||
if !s.IsOrganizationMember(ctx, client, organizationsUrl) {
|
||||
return nil, ErrMissingOrganizationMembership
|
||||
}
|
||||
|
||||
if userInfo.Email == "" {
|
||||
userInfo.Email, err = s.FetchPrivateEmail(client)
|
||||
userInfo.Email, err = s.FetchPrivateEmail(ctx, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
@ -250,7 +251,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) {
|
||||
AccessToken: "fake_token",
|
||||
}
|
||||
|
||||
got, err := s.UserInfo(server.Client(), token)
|
||||
got, err := s.UserInfo(context.Background(), server.Client(), token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -1,6 +1,7 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -34,10 +35,10 @@ func (s *SocialGitlab) IsGroupMember(groups []string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SocialGitlab) GetGroups(client *http.Client) []string {
|
||||
func (s *SocialGitlab) GetGroups(ctx context.Context, client *http.Client) []string {
|
||||
groups := make([]string, 0)
|
||||
|
||||
for page, url := s.GetGroupsPage(client, s.apiUrl+"/groups"); page != nil; page, url = s.GetGroupsPage(client, url) {
|
||||
for page, url := s.GetGroupsPage(ctx, client, s.apiUrl+"/groups"); page != nil; page, url = s.GetGroupsPage(ctx, client, url) {
|
||||
groups = append(groups, page...)
|
||||
}
|
||||
|
||||
@ -45,7 +46,7 @@ func (s *SocialGitlab) GetGroups(client *http.Client) []string {
|
||||
}
|
||||
|
||||
// GetGroupsPage returns groups and link to the next page if response is paginated
|
||||
func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string, string) {
|
||||
func (s *SocialGitlab) GetGroupsPage(ctx context.Context, client *http.Client, url string) ([]string, string) {
|
||||
type Group struct {
|
||||
FullPath string `json:"full_path"`
|
||||
}
|
||||
@ -59,7 +60,7 @@ func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string,
|
||||
return nil, next
|
||||
}
|
||||
|
||||
response, err := s.httpGet(client, url)
|
||||
response, err := s.httpGet(ctx, client, url)
|
||||
if err != nil {
|
||||
s.log.Error("Error getting groups from GitLab API", "err", err)
|
||||
return nil, next
|
||||
@ -86,7 +87,7 @@ func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string,
|
||||
return fullPaths, next
|
||||
}
|
||||
|
||||
func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
|
||||
func (s *SocialGitlab) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
|
||||
var data struct {
|
||||
Id int
|
||||
Username string
|
||||
@ -95,7 +96,7 @@ func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUse
|
||||
State string
|
||||
}
|
||||
|
||||
response, err := s.httpGet(client, s.apiUrl+"/user")
|
||||
response, err := s.httpGet(ctx, client, s.apiUrl+"/user")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error getting user info: %s", err)
|
||||
}
|
||||
@ -108,7 +109,7 @@ func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUse
|
||||
return nil, fmt.Errorf("user %s is inactive", data.Username)
|
||||
}
|
||||
|
||||
groups := s.GetGroups(client)
|
||||
groups := s.GetGroups(ctx, client)
|
||||
|
||||
var role roletype.RoleType
|
||||
var isGrafanaAdmin *bool = nil
|
||||
|
@ -1,6 +1,7 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@ -159,7 +160,7 @@ func TestSocialGitlab_UserInfo(t *testing.T) {
|
||||
}
|
||||
}))
|
||||
provider.apiUrl = ts.URL + apiURI
|
||||
actualResult, err := provider.UserInfo(ts.Client(), nil)
|
||||
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), nil)
|
||||
if test.ExpectedError != nil {
|
||||
require.Equal(t, err, test.ExpectedError)
|
||||
return
|
||||
|
@ -1,6 +1,7 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -16,14 +17,14 @@ type SocialGoogle struct {
|
||||
apiUrl string
|
||||
}
|
||||
|
||||
func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
var data struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
response, err := s.httpGet(client, s.apiUrl)
|
||||
response, err := s.httpGet(ctx, client, s.apiUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error getting user info: %s", err)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -43,7 +44,7 @@ func (s *SocialGrafanaCom) IsOrganizationMember(organizations []OrgRecord) bool
|
||||
}
|
||||
|
||||
// UserInfo is used for login credentials for the user
|
||||
func (s *SocialGrafanaCom) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
|
||||
func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
|
||||
var data struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@ -53,7 +54,8 @@ func (s *SocialGrafanaCom) UserInfo(client *http.Client, _ *oauth2.Token) (*Basi
|
||||
Orgs []OrgRecord `json:"orgs"`
|
||||
}
|
||||
|
||||
response, err := s.httpGet(client, s.url+"/api/oauth2/user")
|
||||
response, err := s.httpGet(ctx, client, s.url+"/api/oauth2/user")
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error getting user info: %s", err)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@ -81,7 +82,7 @@ func TestSocialGrafanaCom_UserInfo(t *testing.T) {
|
||||
}
|
||||
}))
|
||||
provider.url = ts.URL
|
||||
actualResult, err := provider.UserInfo(ts.Client(), nil)
|
||||
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), nil)
|
||||
if test.ExpectedError != nil {
|
||||
require.Equal(t, err, test.ExpectedError)
|
||||
return
|
||||
|
@ -1,6 +1,7 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -46,7 +47,7 @@ func (claims *OktaClaims) extractEmail() string {
|
||||
return claims.Email
|
||||
}
|
||||
|
||||
func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||
idToken := token.Extra("id_token")
|
||||
if idToken == nil {
|
||||
return nil, fmt.Errorf("no id_token found")
|
||||
@ -68,7 +69,7 @@ func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicU
|
||||
}
|
||||
|
||||
var data OktaUserInfoJson
|
||||
err = s.extractAPI(&data, client)
|
||||
err = s.extractAPI(ctx, &data, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -105,8 +106,8 @@ func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicU
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SocialOkta) extractAPI(data *OktaUserInfoJson, client *http.Client) error {
|
||||
rawUserInfoResponse, err := s.httpGet(client, s.apiUrl)
|
||||
func (s *SocialOkta) extractAPI(ctx context.Context, data *OktaUserInfoJson, client *http.Client) error {
|
||||
rawUserInfoResponse, err := s.httpGet(ctx, client, s.apiUrl)
|
||||
if err != nil {
|
||||
s.log.Debug("Error getting user info response", "url", s.apiUrl, "error", err)
|
||||
return fmt.Errorf("error getting user info response: %w", err)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package social
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -110,7 +111,7 @@ func TestSocialOkta_UserInfo(t *testing.T) {
|
||||
Expiry: time.Now(),
|
||||
}
|
||||
token := staticToken.WithExtra(tt.OAuth2Extra)
|
||||
got, err := provider.UserInfo(server.Client(), token)
|
||||
got, err := provider.UserInfo(context.Background(), server.Client(), token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -7,9 +7,11 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/cases"
|
||||
@ -261,7 +263,7 @@ func (b *BasicUserInfo) String() string {
|
||||
}
|
||||
|
||||
type SocialConnector interface {
|
||||
UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error)
|
||||
UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error)
|
||||
IsEmailAllowed(email string) bool
|
||||
IsSignupAllowed() bool
|
||||
|
||||
@ -450,9 +452,19 @@ func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) {
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: info.TlsSkipVerify,
|
||||
},
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: time.Second * 10,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 15 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
|
||||
oauthClient := &http.Client{
|
||||
Transport: tr,
|
||||
Timeout: time.Second * 15,
|
||||
}
|
||||
|
||||
if info.TlsClientCert != "" || info.TlsClientKey != "" {
|
||||
|
@ -116,7 +116,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
|
||||
}
|
||||
token.TokenType = "Bearer"
|
||||
|
||||
userInfo, err := c.connector.UserInfo(c.connector.Client(clientCtx, token), token)
|
||||
userInfo, err := c.connector.UserInfo(ctx, c.connector.Client(clientCtx, token), token)
|
||||
if err != nil {
|
||||
var sErr *social.Error
|
||||
if errors.As(err, &sErr) {
|
||||
|
@ -8,13 +8,14 @@ import (
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOAuth_Authenticate(t *testing.T) {
|
||||
@ -278,7 +279,7 @@ type fakeConnector struct {
|
||||
social.SocialConnector
|
||||
}
|
||||
|
||||
func (f fakeConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
||||
func (f fakeConnector) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
||||
return f.ExpectedUserInfo, f.ExpectedUserInfoErr
|
||||
}
|
||||
|
||||
|
@ -273,7 +273,7 @@ func (m *MockSocialConnector) Type() int {
|
||||
return args.Int(0)
|
||||
}
|
||||
|
||||
func (m *MockSocialConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
||||
func (m *MockSocialConnector) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
||||
args := m.Called(client, token)
|
||||
return args.Get(0).(*social.BasicUserInfo), args.Error(1)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user