mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Add request context to UserInfo calls (#70007)
* use context for UserInfo requests * set timeouts for oauth http client * Update pkg/login/social/common.go Co-authored-by: Ieva <ieva.vasiljeva@grafana.com> --------- Co-authored-by: Ieva <ieva.vasiljeva@grafana.com>
This commit is contained in:
parent
1445a7cc5c
commit
914daef0fd
@ -70,61 +70,61 @@ func genPKCECode() (string, string, error) {
|
|||||||
return string(ascii), pkce, nil
|
return string(ascii), pkce, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
func (hs *HTTPServer) OAuthLogin(reqCtx *contextmodel.ReqContext) {
|
||||||
name := web.Params(ctx.Req)[":name"]
|
name := web.Params(reqCtx.Req)[":name"]
|
||||||
loginInfo := loginservice.LoginInfo{AuthModule: name}
|
loginInfo := loginservice.LoginInfo{AuthModule: name}
|
||||||
|
|
||||||
if errorParam := ctx.Query("error"); errorParam != "" {
|
if errorParam := reqCtx.Query("error"); errorParam != "" {
|
||||||
errorDesc := ctx.Query("error_description")
|
errorDesc := reqCtx.Query("error_description")
|
||||||
oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc)
|
oauthLogger.Error("failed to login ", "error", errorParam, "errorDesc", errorDesc)
|
||||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc)
|
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrProviderDeniedRequest, "error", errorParam, "errorDesc", errorDesc)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
code := ctx.Query("code")
|
code := reqCtx.Query("code")
|
||||||
|
|
||||||
if hs.Cfg.AuthBrokerEnabled {
|
if hs.Cfg.AuthBrokerEnabled {
|
||||||
req := &authn.Request{HTTPRequest: ctx.Req, Resp: ctx.Resp}
|
req := &authn.Request{HTTPRequest: reqCtx.Req, Resp: reqCtx.Resp}
|
||||||
if code == "" {
|
if code == "" {
|
||||||
redirect, err := hs.authnService.RedirectURL(ctx.Req.Context(), authn.ClientWithPrefix(name), req)
|
redirect, err := hs.authnService.RedirectURL(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Redirect(hs.redirectURLWithErrorCookie(ctx, err))
|
reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" {
|
if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" {
|
||||||
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||||
ctx.Redirect(redirect.URL)
|
reqCtx.Redirect(redirect.URL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, err := hs.authnService.Login(ctx.Req.Context(), authn.ClientWithPrefix(name), req)
|
identity, err := hs.authnService.Login(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
|
||||||
// NOTE: always delete these cookies, even if login failed
|
// NOTE: always delete these cookies, even if login failed
|
||||||
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
||||||
cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Redirect(hs.redirectURLWithErrorCookie(ctx, err))
|
reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics.MApiLoginOAuth.Inc()
|
metrics.MApiLoginOAuth.Inc()
|
||||||
authn.HandleLoginRedirect(ctx.Req, ctx.Resp, hs.Cfg, identity, hs.ValidateRedirectTo)
|
authn.HandleLoginRedirect(reqCtx.Req, reqCtx.Resp, hs.Cfg, identity, hs.ValidateRedirectTo)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
provider := hs.SocialService.GetOAuthInfoProvider(name)
|
provider := hs.SocialService.GetOAuthInfoProvider(name)
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, errors.New("OAuth not enabled"))
|
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, errors.New("OAuth not enabled"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
connect, err := hs.SocialService.GetConnector(name)
|
connect, err := hs.SocialService.GetConnector(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, fmt.Errorf("no OAuth with name %s configured", name))
|
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, fmt.Errorf("no OAuth with name %s configured", name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,15 +133,15 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
|||||||
if provider.UsePKCE {
|
if provider.UsePKCE {
|
||||||
ascii, pkce, err := genPKCECode()
|
ascii, pkce, err := genPKCECode()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Logger.Error("Generating PKCE failed", "error", err)
|
reqCtx.Logger.Error("Generating PKCE failed", "error", err)
|
||||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||||
HttpStatus: http.StatusInternalServerError,
|
HttpStatus: http.StatusInternalServerError,
|
||||||
PublicMessage: "An internal error occurred",
|
PublicMessage: "An internal error occurred",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cookies.WriteCookie(ctx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, ascii, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||||
|
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
oauth2.SetAuthURLParam("code_challenge", pkce),
|
oauth2.SetAuthURLParam("code_challenge", pkce),
|
||||||
@ -151,8 +151,8 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
|||||||
|
|
||||||
state, err := GenStateString()
|
state, err := GenStateString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Logger.Error("Generating state string failed", "err", err)
|
reqCtx.Logger.Error("Generating state string failed", "err", err)
|
||||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||||
HttpStatus: http.StatusInternalServerError,
|
HttpStatus: http.StatusInternalServerError,
|
||||||
PublicMessage: "An internal error occurred",
|
PublicMessage: "An internal error occurred",
|
||||||
})
|
})
|
||||||
@ -160,32 +160,32 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hashedState := hs.hashStatecode(state, provider.ClientSecret)
|
hashedState := hs.hashStatecode(state, provider.ClientSecret)
|
||||||
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||||
if provider.HostedDomain != "" {
|
if provider.HostedDomain != "" {
|
||||||
opts = append(opts, oauth2.SetAuthURLParam("hd", provider.HostedDomain))
|
opts = append(opts, oauth2.SetAuthURLParam("hd", provider.HostedDomain))
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Redirect(connect.AuthCodeURL(state, opts...))
|
reqCtx.Redirect(connect.AuthCodeURL(state, opts...))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieState := ctx.GetCookie(OauthStateCookieName)
|
cookieState := reqCtx.GetCookie(OauthStateCookieName)
|
||||||
|
|
||||||
// delete cookie
|
// delete cookie
|
||||||
cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
||||||
|
|
||||||
if cookieState == "" {
|
if cookieState == "" {
|
||||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||||
HttpStatus: http.StatusInternalServerError,
|
HttpStatus: http.StatusInternalServerError,
|
||||||
PublicMessage: "login.OAuthLogin(missing saved state)",
|
PublicMessage: "login.OAuthLogin(missing saved state)",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queryState := hs.hashStatecode(ctx.Query("state"), provider.ClientSecret)
|
queryState := hs.hashStatecode(reqCtx.Query("state"), provider.ClientSecret)
|
||||||
oauthLogger.Info("state check", "queryState", queryState, "cookieState", cookieState)
|
oauthLogger.Info("state check", "queryState", queryState, "cookieState", cookieState)
|
||||||
if cookieState != queryState {
|
if cookieState != queryState {
|
||||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||||
HttpStatus: http.StatusInternalServerError,
|
HttpStatus: http.StatusInternalServerError,
|
||||||
PublicMessage: "login.OAuthLogin(state mismatch)",
|
PublicMessage: "login.OAuthLogin(state mismatch)",
|
||||||
})
|
})
|
||||||
@ -194,19 +194,20 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
|||||||
|
|
||||||
oauthClient, err := hs.SocialService.GetOAuthHttpClient(name)
|
oauthClient, err := hs.SocialService.GetOAuthHttpClient(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.Logger.Error("Failed to create OAuth http client", "error", err)
|
reqCtx.Logger.Error("Failed to create OAuth http client", "error", err)
|
||||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||||
HttpStatus: http.StatusInternalServerError,
|
HttpStatus: http.StatusInternalServerError,
|
||||||
PublicMessage: "login.OAuthLogin(" + err.Error() + ")",
|
PublicMessage: "login.OAuthLogin(" + err.Error() + ")",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthCtx := context.WithValue(context.Background(), oauth2.HTTPClient, oauthClient)
|
ctx := reqCtx.Req.Context()
|
||||||
|
oauthCtx := context.WithValue(ctx, oauth2.HTTPClient, oauthClient)
|
||||||
opts := []oauth2.AuthCodeOption{}
|
opts := []oauth2.AuthCodeOption{}
|
||||||
|
|
||||||
codeVerifier := ctx.GetCookie(OauthPKCECookieName)
|
codeVerifier := reqCtx.GetCookie(OauthPKCECookieName)
|
||||||
cookies.DeleteCookie(ctx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
||||||
if codeVerifier != "" {
|
if codeVerifier != "" {
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
oauth2.SetAuthURLParam("code_verifier", codeVerifier),
|
oauth2.SetAuthURLParam("code_verifier", codeVerifier),
|
||||||
@ -216,7 +217,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
|||||||
// get token from provider
|
// get token from provider
|
||||||
token, err := connect.Exchange(oauthCtx, code, opts...)
|
token, err := connect.Exchange(oauthCtx, code, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||||
HttpStatus: http.StatusInternalServerError,
|
HttpStatus: http.StatusInternalServerError,
|
||||||
PublicMessage: "login.OAuthLogin(NewTransportWithCode)",
|
PublicMessage: "login.OAuthLogin(NewTransportWithCode)",
|
||||||
Err: err,
|
Err: err,
|
||||||
@ -245,13 +246,13 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
|||||||
client := connect.Client(oauthCtx, token)
|
client := connect.Client(oauthCtx, token)
|
||||||
|
|
||||||
// get user info
|
// get user info
|
||||||
userInfo, err := connect.UserInfo(client, token)
|
userInfo, err := connect.UserInfo(ctx, client, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var sErr *social.Error
|
var sErr *social.Error
|
||||||
if errors.As(err, &sErr) {
|
if errors.As(err, &sErr) {
|
||||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, sErr)
|
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, sErr)
|
||||||
} else {
|
} else {
|
||||||
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
|
hs.handleOAuthLoginError(reqCtx, loginInfo, LoginError{
|
||||||
HttpStatus: http.StatusInternalServerError,
|
HttpStatus: http.StatusInternalServerError,
|
||||||
PublicMessage: fmt.Sprintf("login.OAuthLogin(get info from %s)", name),
|
PublicMessage: fmt.Sprintf("login.OAuthLogin(get info from %s)", name),
|
||||||
Err: err,
|
Err: err,
|
||||||
@ -264,34 +265,34 @@ func (hs *HTTPServer) OAuthLogin(ctx *contextmodel.ReqContext) {
|
|||||||
|
|
||||||
// validate that we got at least an email address
|
// validate that we got at least an email address
|
||||||
if userInfo.Email == "" {
|
if userInfo.Email == "" {
|
||||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrNoEmail)
|
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrNoEmail)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate that the email is allowed to login to grafana
|
// validate that the email is allowed to login to grafana
|
||||||
if !connect.IsEmailAllowed(userInfo.Email) {
|
if !connect.IsEmailAllowed(userInfo.Email) {
|
||||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, login.ErrEmailNotAllowed)
|
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, login.ErrEmailNotAllowed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
loginInfo.ExternalUser = *hs.buildExternalUserInfo(token, userInfo, name)
|
loginInfo.ExternalUser = *hs.buildExternalUserInfo(token, userInfo, name)
|
||||||
loginInfo.User, err = hs.SyncUser(ctx, &loginInfo.ExternalUser, connect)
|
loginInfo.User, err = hs.SyncUser(reqCtx, &loginInfo.ExternalUser, connect)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, err)
|
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// login
|
// login
|
||||||
if err := hs.loginUserWithUser(loginInfo.User, ctx); err != nil {
|
if err := hs.loginUserWithUser(loginInfo.User, reqCtx); err != nil {
|
||||||
hs.handleOAuthLoginErrorWithRedirect(ctx, loginInfo, err)
|
hs.handleOAuthLoginErrorWithRedirect(reqCtx, loginInfo, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
loginInfo.HTTPStatus = http.StatusOK
|
loginInfo.HTTPStatus = http.StatusOK
|
||||||
hs.HooksService.RunLoginHook(&loginInfo, ctx)
|
hs.HooksService.RunLoginHook(&loginInfo, reqCtx)
|
||||||
metrics.MApiLoginOAuth.Inc()
|
metrics.MApiLoginOAuth.Inc()
|
||||||
|
|
||||||
ctx.Redirect(hs.GetRedirectURL(ctx))
|
reqCtx.Redirect(hs.GetRedirectURL(reqCtx))
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildExternalUserInfo returns a ExternalUserInfo struct from OAuth user profile
|
// buildExternalUserInfo returns a ExternalUserInfo struct from OAuth user profile
|
||||||
|
@ -2,6 +2,7 @@ package social
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -47,7 +48,7 @@ type azureAccessClaims struct {
|
|||||||
TenantID string `json:"tid"`
|
TenantID string `json:"tid"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
func (s *SocialAzureAD) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||||
idToken := token.Extra("id_token")
|
idToken := token.Extra("id_token")
|
||||||
if idToken == nil {
|
if idToken == nil {
|
||||||
return nil, ErrIDTokenNotFound
|
return nil, ErrIDTokenNotFound
|
||||||
@ -83,7 +84,7 @@ func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*Bas
|
|||||||
}
|
}
|
||||||
logger.Debug("AzureAD OAuth: extracted role", "email", email, "role", role)
|
logger.Debug("AzureAD OAuth: extracted role", "email", email, "role", role)
|
||||||
|
|
||||||
groups, err := s.extractGroups(client, claims, token)
|
groups, err := s.extractGroups(ctx, client, claims, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to extract groups: %w", err)
|
return nil, fmt.Errorf("failed to extract groups: %w", err)
|
||||||
}
|
}
|
||||||
@ -176,7 +177,7 @@ type getAzureGroupResponse struct {
|
|||||||
// Note: If user groups exceeds 200 no groups will be found in claims and URL to target the Graph API will be
|
// Note: If user groups exceeds 200 no groups will be found in claims and URL to target the Graph API will be
|
||||||
// given instead.
|
// given instead.
|
||||||
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
|
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
|
||||||
func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
|
func (s *SocialAzureAD) extractGroups(ctx context.Context, client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) {
|
||||||
if !s.forceUseGraphAPI {
|
if !s.forceUseGraphAPI {
|
||||||
logger.Debug("checking the claim for groups")
|
logger.Debug("checking the claim for groups")
|
||||||
if len(claims.Groups) > 0 {
|
if len(claims.Groups) > 0 {
|
||||||
@ -199,7 +200,13 @@ func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, t
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := client.Post(endpoint, "application/json", bytes.NewBuffer(data))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -473,7 +473,7 @@ func TestSocialAzureAD_UserInfo(t *testing.T) {
|
|||||||
tt.args.client = s.Client(context.Background(), token)
|
tt.args.client = s.Client(context.Background(), token)
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := s.UserInfo(tt.args.client, token)
|
got, err := s.UserInfo(context.Background(), tt.args.client, token)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
@ -617,7 +617,7 @@ func TestSocialAzureAD_SkipOrgRole(t *testing.T) {
|
|||||||
tt.args.client = s.Client(context.Background(), token)
|
tt.args.client = s.Client(context.Background(), token)
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := s.UserInfo(tt.args.client, token)
|
got, err := s.UserInfo(context.Background(), tt.args.client, token)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package social
|
package social
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -42,10 +43,15 @@ func isEmailAllowed(email string, allowedDomains []string) bool {
|
|||||||
return valid
|
return valid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialBase) httpGet(client *http.Client, url string) (response httpGetResponse, err error) {
|
func (s *SocialBase) httpGet(ctx context.Context, client *http.Client, url string) (*httpGetResponse, error) {
|
||||||
r, err := client.Get(url)
|
req, errReq := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
if err != nil {
|
if errReq != nil {
|
||||||
return
|
return nil, errReq
|
||||||
|
}
|
||||||
|
|
||||||
|
r, errDo := client.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return nil, errDo
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -54,21 +60,20 @@ func (s *SocialBase) httpGet(client *http.Client, url string) (response httpGetR
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
body, err := io.ReadAll(r.Body)
|
body, errRead := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if errRead != nil {
|
||||||
return
|
return nil, errRead
|
||||||
}
|
}
|
||||||
|
|
||||||
response = httpGetResponse{body, r.Header}
|
response := &httpGetResponse{body, r.Header}
|
||||||
|
|
||||||
if r.StatusCode >= 300 {
|
if r.StatusCode >= 300 {
|
||||||
err = fmt.Errorf(string(response.Body))
|
return nil, fmt.Errorf("unsuccessful response status code %d: %s", r.StatusCode, string(response.Body))
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.log.Debug("HTTP GET", "url", url, "status", r.Status, "response_body", string(response.Body))
|
s.log.Debug("HTTP GET", "url", url, "status", r.Status, "response_body", string(response.Body))
|
||||||
|
|
||||||
err = nil
|
return response, nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialBase) searchJSONForAttr(attributePath string, data []byte) (interface{}, error) {
|
func (s *SocialBase) searchJSONForAttr(attributePath string, data []byte) (interface{}, error) {
|
||||||
|
@ -3,6 +3,7 @@ package social
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/zlib"
|
"compress/zlib"
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@ -50,12 +51,12 @@ func (s *SocialGenericOAuth) IsGroupMember(groups []string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
|
func (s *SocialGenericOAuth) IsTeamMember(ctx context.Context, client *http.Client) bool {
|
||||||
if len(s.teamIds) == 0 {
|
if len(s.teamIds) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
teamMemberships, err := s.FetchTeamMemberships(client)
|
teamMemberships, err := s.FetchTeamMemberships(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -71,12 +72,12 @@ func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool {
|
func (s *SocialGenericOAuth) IsOrganizationMember(ctx context.Context, client *http.Client) bool {
|
||||||
if len(s.allowedOrganizations) == 0 {
|
if len(s.allowedOrganizations) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
organizations, ok := s.FetchOrganizations(client)
|
organizations, ok := s.FetchOrganizations(ctx, client)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -111,14 +112,14 @@ func (info *UserInfoJson) String() string {
|
|||||||
info.Name, info.DisplayName, info.Login, info.Username, info.Email, info.Upn, info.Attributes)
|
info.Name, info.DisplayName, info.Login, info.Username, info.Email, info.Upn, info.Attributes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
func (s *SocialGenericOAuth) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||||
s.log.Debug("Getting user info")
|
s.log.Debug("Getting user info")
|
||||||
toCheck := make([]*UserInfoJson, 0, 2)
|
toCheck := make([]*UserInfoJson, 0, 2)
|
||||||
|
|
||||||
if tokenData := s.extractFromToken(token); tokenData != nil {
|
if tokenData := s.extractFromToken(token); tokenData != nil {
|
||||||
toCheck = append(toCheck, tokenData)
|
toCheck = append(toCheck, tokenData)
|
||||||
}
|
}
|
||||||
if apiData := s.extractFromAPI(client); apiData != nil {
|
if apiData := s.extractFromAPI(ctx, client); apiData != nil {
|
||||||
toCheck = append(toCheck, apiData)
|
toCheck = append(toCheck, apiData)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,7 +180,7 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
|
|||||||
|
|
||||||
if userInfo.Email == "" {
|
if userInfo.Email == "" {
|
||||||
var err error
|
var err error
|
||||||
userInfo.Email, err = s.FetchPrivateEmail(client)
|
userInfo.Email, err = s.FetchPrivateEmail(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -191,11 +192,11 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
|
|||||||
userInfo.Login = userInfo.Email
|
userInfo.Login = userInfo.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.IsTeamMember(client) {
|
if !s.IsTeamMember(ctx, client) {
|
||||||
return nil, errors.New("user not a member of one of the required teams")
|
return nil, errors.New("user not a member of one of the required teams")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.IsOrganizationMember(client) {
|
if !s.IsOrganizationMember(ctx, client) {
|
||||||
return nil, errors.New("user not a member of one of the required organizations")
|
return nil, errors.New("user not a member of one of the required organizations")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -288,14 +289,14 @@ func (s *SocialGenericOAuth) extractFromToken(token *oauth2.Token) *UserInfoJson
|
|||||||
return &data
|
return &data
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGenericOAuth) extractFromAPI(client *http.Client) *UserInfoJson {
|
func (s *SocialGenericOAuth) extractFromAPI(ctx context.Context, client *http.Client) *UserInfoJson {
|
||||||
s.log.Debug("Getting user info from API")
|
s.log.Debug("Getting user info from API")
|
||||||
if s.apiUrl == "" {
|
if s.apiUrl == "" {
|
||||||
s.log.Debug("No api url configured")
|
s.log.Debug("No api url configured")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rawUserInfoResponse, err := s.httpGet(client, s.apiUrl)
|
rawUserInfoResponse, err := s.httpGet(ctx, client, s.apiUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Debug("Error getting user info from API", "url", s.apiUrl, "error", err)
|
s.log.Debug("Error getting user info from API", "url", s.apiUrl, "error", err)
|
||||||
return nil
|
return nil
|
||||||
@ -404,7 +405,7 @@ func (s *SocialGenericOAuth) extractGroups(data *UserInfoJson) ([]string, error)
|
|||||||
return s.searchJSONForStringArrayAttr(s.groupsAttributePath, data.rawJSON)
|
return s.searchJSONForStringArrayAttr(s.groupsAttributePath, data.rawJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) {
|
func (s *SocialGenericOAuth) FetchPrivateEmail(ctx context.Context, client *http.Client) (string, error) {
|
||||||
type Record struct {
|
type Record struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Primary bool `json:"primary"`
|
Primary bool `json:"primary"`
|
||||||
@ -413,7 +414,7 @@ func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, err
|
|||||||
IsConfirmed bool `json:"is_confirmed"`
|
IsConfirmed bool `json:"is_confirmed"`
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
|
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/emails"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("Error getting email address", "url", s.apiUrl+"/emails", "error", err)
|
s.log.Error("Error getting email address", "url", s.apiUrl+"/emails", "error", err)
|
||||||
return "", fmt.Errorf("%v: %w", "Error getting email address", err)
|
return "", fmt.Errorf("%v: %w", "Error getting email address", err)
|
||||||
@ -451,14 +452,14 @@ func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, err
|
|||||||
return email, nil
|
return email, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]string, error) {
|
func (s *SocialGenericOAuth) FetchTeamMemberships(ctx context.Context, client *http.Client) ([]string, error) {
|
||||||
var err error
|
var err error
|
||||||
var ids []string
|
var ids []string
|
||||||
|
|
||||||
if s.teamsUrl == "" {
|
if s.teamsUrl == "" {
|
||||||
ids, err = s.fetchTeamMembershipsFromDeprecatedTeamsUrl(client)
|
ids, err = s.fetchTeamMembershipsFromDeprecatedTeamsUrl(ctx, client)
|
||||||
} else {
|
} else {
|
||||||
ids, err = s.fetchTeamMembershipsFromTeamsUrl(client)
|
ids, err = s.fetchTeamMembershipsFromTeamsUrl(ctx, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -468,16 +469,14 @@ func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]string
|
|||||||
return ids, err
|
return ids, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(client *http.Client) ([]string, error) {
|
func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(ctx context.Context, client *http.Client) ([]string, error) {
|
||||||
var response httpGetResponse
|
|
||||||
var err error
|
|
||||||
var ids []string
|
var ids []string
|
||||||
|
|
||||||
type Record struct {
|
type Record struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err = s.httpGet(client, fmt.Sprintf(s.apiUrl+"/teams"))
|
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/teams"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("Error getting team memberships", "url", s.apiUrl+"/teams", "error", err)
|
s.log.Error("Error getting team memberships", "url", s.apiUrl+"/teams", "error", err)
|
||||||
return []string{}, err
|
return []string{}, err
|
||||||
@ -499,15 +498,12 @@ func (s *SocialGenericOAuth) fetchTeamMembershipsFromDeprecatedTeamsUrl(client *
|
|||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(client *http.Client) ([]string, error) {
|
func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(ctx context.Context, client *http.Client) ([]string, error) {
|
||||||
if s.teamIdsAttributePath == "" {
|
if s.teamIdsAttributePath == "" {
|
||||||
return []string{}, nil
|
return []string{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var response httpGetResponse
|
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.teamsUrl))
|
||||||
var err error
|
|
||||||
|
|
||||||
response, err = s.httpGet(client, fmt.Sprintf(s.teamsUrl))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("Error getting team memberships", "url", s.teamsUrl, "error", err)
|
s.log.Error("Error getting team memberships", "url", s.teamsUrl, "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -516,12 +512,12 @@ func (s *SocialGenericOAuth) fetchTeamMembershipsFromTeamsUrl(client *http.Clien
|
|||||||
return s.searchJSONForStringArrayAttr(s.teamIdsAttributePath, response.Body)
|
return s.searchJSONForStringArrayAttr(s.teamIdsAttributePath, response.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, bool) {
|
func (s *SocialGenericOAuth) FetchOrganizations(ctx context.Context, client *http.Client) ([]string, bool) {
|
||||||
type Record struct {
|
type Record struct {
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
|
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/orgs"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("Error getting organizations", "url", s.apiUrl+"/orgs", "error", err)
|
s.log.Error("Error getting organizations", "url", s.apiUrl+"/orgs", "error", err)
|
||||||
return nil, false
|
return nil, false
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package social
|
package social
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -490,7 +491,7 @@ func TestUserInfoSearchesForEmailAndRole(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
token := staticToken.WithExtra(test.OAuth2Extra)
|
token := staticToken.WithExtra(test.OAuth2Extra)
|
||||||
actualResult, err := provider.UserInfo(ts.Client(), token)
|
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, test.ExpectedEmail, actualResult.Email)
|
require.Equal(t, test.ExpectedEmail, actualResult.Email)
|
||||||
require.Equal(t, test.ExpectedEmail, actualResult.Login)
|
require.Equal(t, test.ExpectedEmail, actualResult.Login)
|
||||||
@ -588,7 +589,7 @@ func TestUserInfoSearchesForLogin(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
token := staticToken.WithExtra(test.OAuth2Extra)
|
token := staticToken.WithExtra(test.OAuth2Extra)
|
||||||
actualResult, err := provider.UserInfo(ts.Client(), token)
|
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, test.ExpectedLogin, actualResult.Login)
|
require.Equal(t, test.ExpectedLogin, actualResult.Login)
|
||||||
})
|
})
|
||||||
@ -686,7 +687,7 @@ func TestUserInfoSearchesForName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
token := staticToken.WithExtra(test.OAuth2Extra)
|
token := staticToken.WithExtra(test.OAuth2Extra)
|
||||||
actualResult, err := provider.UserInfo(ts.Client(), token)
|
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), token)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, test.ExpectedName, actualResult.Name)
|
require.Equal(t, test.ExpectedName, actualResult.Name)
|
||||||
})
|
})
|
||||||
@ -755,7 +756,7 @@ func TestUserInfoSearchesForGroup(t *testing.T) {
|
|||||||
Expiry: time.Now(),
|
Expiry: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
userInfo, err := provider.UserInfo(ts.Client(), token)
|
userInfo, err := provider.UserInfo(context.Background(), ts.Client(), token)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, test.expectedResult, userInfo.Groups)
|
assert.Equal(t, test.expectedResult, userInfo.Groups)
|
||||||
})
|
})
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package social
|
package social
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -35,12 +36,12 @@ var (
|
|||||||
ErrMissingOrganizationMembership = Error{"user not a member of one of the required organizations"}
|
ErrMissingOrganizationMembership = Error{"user not a member of one of the required organizations"}
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *SocialGithub) IsTeamMember(client *http.Client) bool {
|
func (s *SocialGithub) IsTeamMember(ctx context.Context, client *http.Client) bool {
|
||||||
if len(s.teamIds) == 0 {
|
if len(s.teamIds) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
teamMemberships, err := s.FetchTeamMemberships(client)
|
teamMemberships, err := s.FetchTeamMemberships(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -56,12 +57,13 @@ func (s *SocialGithub) IsTeamMember(client *http.Client) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGithub) IsOrganizationMember(client *http.Client, organizationsUrl string) bool {
|
func (s *SocialGithub) IsOrganizationMember(ctx context.Context,
|
||||||
|
client *http.Client, organizationsUrl string) bool {
|
||||||
if len(s.allowedOrganizations) == 0 {
|
if len(s.allowedOrganizations) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
organizations, err := s.FetchOrganizations(client, organizationsUrl)
|
organizations, err := s.FetchOrganizations(ctx, client, organizationsUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -77,14 +79,14 @@ func (s *SocialGithub) IsOrganizationMember(client *http.Client, organizationsUr
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGithub) FetchPrivateEmail(client *http.Client) (string, error) {
|
func (s *SocialGithub) FetchPrivateEmail(ctx context.Context, client *http.Client) (string, error) {
|
||||||
type Record struct {
|
type Record struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Primary bool `json:"primary"`
|
Primary bool `json:"primary"`
|
||||||
Verified bool `json:"verified"`
|
Verified bool `json:"verified"`
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := s.httpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
|
response, err := s.httpGet(ctx, client, fmt.Sprintf(s.apiUrl+"/emails"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("Error getting email address: %s", err)
|
return "", fmt.Errorf("Error getting email address: %s", err)
|
||||||
}
|
}
|
||||||
@ -106,13 +108,13 @@ func (s *SocialGithub) FetchPrivateEmail(client *http.Client) (string, error) {
|
|||||||
return email, nil
|
return email, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGithub) FetchTeamMemberships(client *http.Client) ([]GithubTeam, error) {
|
func (s *SocialGithub) FetchTeamMemberships(ctx context.Context, client *http.Client) ([]GithubTeam, error) {
|
||||||
url := fmt.Sprintf(s.apiUrl + "/teams?per_page=100")
|
url := fmt.Sprintf(s.apiUrl + "/teams?per_page=100")
|
||||||
hasMore := true
|
hasMore := true
|
||||||
teams := make([]GithubTeam, 0)
|
teams := make([]GithubTeam, 0)
|
||||||
|
|
||||||
for hasMore {
|
for hasMore {
|
||||||
response, err := s.httpGet(client, url)
|
response, err := s.httpGet(ctx, client, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error getting team memberships: %s", err)
|
return nil, fmt.Errorf("Error getting team memberships: %s", err)
|
||||||
}
|
}
|
||||||
@ -150,7 +152,7 @@ func (s *SocialGithub) HasMoreRecords(headers http.Header) (string, bool) {
|
|||||||
return url, true
|
return url, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl string) ([]string, error) {
|
func (s *SocialGithub) FetchOrganizations(ctx context.Context, client *http.Client, organizationsUrl string) ([]string, error) {
|
||||||
url := organizationsUrl
|
url := organizationsUrl
|
||||||
hasMore := true
|
hasMore := true
|
||||||
logins := make([]string, 0)
|
logins := make([]string, 0)
|
||||||
@ -160,7 +162,7 @@ func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl
|
|||||||
}
|
}
|
||||||
|
|
||||||
for hasMore {
|
for hasMore {
|
||||||
response, err := s.httpGet(client, url)
|
response, err := s.httpGet(ctx, client, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting organizations: %s", err)
|
return nil, fmt.Errorf("error getting organizations: %s", err)
|
||||||
}
|
}
|
||||||
@ -181,7 +183,7 @@ func (s *SocialGithub) FetchOrganizations(client *http.Client, organizationsUrl
|
|||||||
return logins, nil
|
return logins, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
func (s *SocialGithub) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
@ -189,7 +191,7 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
|
|||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := s.httpGet(client, s.apiUrl)
|
response, err := s.httpGet(ctx, client, s.apiUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting user info: %s", err)
|
return nil, fmt.Errorf("error getting user info: %s", err)
|
||||||
}
|
}
|
||||||
@ -198,7 +200,7 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
|
|||||||
return nil, fmt.Errorf("error unmarshalling user info: %s", err)
|
return nil, fmt.Errorf("error unmarshalling user info: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
teamMemberships, err := s.FetchTeamMemberships(client)
|
teamMemberships, err := s.FetchTeamMemberships(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting user teams: %s", err)
|
return nil, fmt.Errorf("error getting user teams: %s", err)
|
||||||
}
|
}
|
||||||
@ -241,16 +243,16 @@ func (s *SocialGithub) UserInfo(client *http.Client, token *oauth2.Token) (*Basi
|
|||||||
|
|
||||||
organizationsUrl := fmt.Sprintf(s.apiUrl + "/orgs?per_page=100")
|
organizationsUrl := fmt.Sprintf(s.apiUrl + "/orgs?per_page=100")
|
||||||
|
|
||||||
if !s.IsTeamMember(client) {
|
if !s.IsTeamMember(ctx, client) {
|
||||||
return nil, ErrMissingTeamMembership
|
return nil, ErrMissingTeamMembership
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.IsOrganizationMember(client, organizationsUrl) {
|
if !s.IsOrganizationMember(ctx, client, organizationsUrl) {
|
||||||
return nil, ErrMissingOrganizationMembership
|
return nil, ErrMissingOrganizationMembership
|
||||||
}
|
}
|
||||||
|
|
||||||
if userInfo.Email == "" {
|
if userInfo.Email == "" {
|
||||||
userInfo.Email, err = s.FetchPrivateEmail(client)
|
userInfo.Email, err = s.FetchPrivateEmail(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package social
|
package social
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -250,7 +251,7 @@ func TestSocialGitHub_UserInfo(t *testing.T) {
|
|||||||
AccessToken: "fake_token",
|
AccessToken: "fake_token",
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := s.UserInfo(server.Client(), token)
|
got, err := s.UserInfo(context.Background(), server.Client(), token)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package social
|
package social
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -34,10 +35,10 @@ func (s *SocialGitlab) IsGroupMember(groups []string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGitlab) GetGroups(client *http.Client) []string {
|
func (s *SocialGitlab) GetGroups(ctx context.Context, client *http.Client) []string {
|
||||||
groups := make([]string, 0)
|
groups := make([]string, 0)
|
||||||
|
|
||||||
for page, url := s.GetGroupsPage(client, s.apiUrl+"/groups"); page != nil; page, url = s.GetGroupsPage(client, url) {
|
for page, url := s.GetGroupsPage(ctx, client, s.apiUrl+"/groups"); page != nil; page, url = s.GetGroupsPage(ctx, client, url) {
|
||||||
groups = append(groups, page...)
|
groups = append(groups, page...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -45,7 +46,7 @@ func (s *SocialGitlab) GetGroups(client *http.Client) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupsPage returns groups and link to the next page if response is paginated
|
// GetGroupsPage returns groups and link to the next page if response is paginated
|
||||||
func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string, string) {
|
func (s *SocialGitlab) GetGroupsPage(ctx context.Context, client *http.Client, url string) ([]string, string) {
|
||||||
type Group struct {
|
type Group struct {
|
||||||
FullPath string `json:"full_path"`
|
FullPath string `json:"full_path"`
|
||||||
}
|
}
|
||||||
@ -59,7 +60,7 @@ func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string,
|
|||||||
return nil, next
|
return nil, next
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := s.httpGet(client, url)
|
response, err := s.httpGet(ctx, client, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Error("Error getting groups from GitLab API", "err", err)
|
s.log.Error("Error getting groups from GitLab API", "err", err)
|
||||||
return nil, next
|
return nil, next
|
||||||
@ -86,7 +87,7 @@ func (s *SocialGitlab) GetGroupsPage(client *http.Client, url string) ([]string,
|
|||||||
return fullPaths, next
|
return fullPaths, next
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
|
func (s *SocialGitlab) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id int
|
Id int
|
||||||
Username string
|
Username string
|
||||||
@ -95,7 +96,7 @@ func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUse
|
|||||||
State string
|
State string
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := s.httpGet(client, s.apiUrl+"/user")
|
response, err := s.httpGet(ctx, client, s.apiUrl+"/user")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error getting user info: %s", err)
|
return nil, fmt.Errorf("Error getting user info: %s", err)
|
||||||
}
|
}
|
||||||
@ -108,7 +109,7 @@ func (s *SocialGitlab) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUse
|
|||||||
return nil, fmt.Errorf("user %s is inactive", data.Username)
|
return nil, fmt.Errorf("user %s is inactive", data.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
groups := s.GetGroups(client)
|
groups := s.GetGroups(ctx, client)
|
||||||
|
|
||||||
var role roletype.RoleType
|
var role roletype.RoleType
|
||||||
var isGrafanaAdmin *bool = nil
|
var isGrafanaAdmin *bool = nil
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package social
|
package social
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@ -159,7 +160,7 @@ func TestSocialGitlab_UserInfo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
provider.apiUrl = ts.URL + apiURI
|
provider.apiUrl = ts.URL + apiURI
|
||||||
actualResult, err := provider.UserInfo(ts.Client(), nil)
|
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), nil)
|
||||||
if test.ExpectedError != nil {
|
if test.ExpectedError != nil {
|
||||||
require.Equal(t, err, test.ExpectedError)
|
require.Equal(t, err, test.ExpectedError)
|
||||||
return
|
return
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package social
|
package social
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -16,14 +17,14 @@ type SocialGoogle struct {
|
|||||||
apiUrl string
|
apiUrl string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialGoogle) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id string `json:"id"`
|
Id string `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := s.httpGet(client, s.apiUrl)
|
response, err := s.httpGet(ctx, client, s.apiUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error getting user info: %s", err)
|
return nil, fmt.Errorf("Error getting user info: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package social
|
package social
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -43,7 +44,7 @@ func (s *SocialGrafanaCom) IsOrganizationMember(organizations []OrgRecord) bool
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UserInfo is used for login credentials for the user
|
// UserInfo is used for login credentials for the user
|
||||||
func (s *SocialGrafanaCom) UserInfo(client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
|
func (s *SocialGrafanaCom) UserInfo(ctx context.Context, client *http.Client, _ *oauth2.Token) (*BasicUserInfo, error) {
|
||||||
var data struct {
|
var data struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@ -53,7 +54,8 @@ func (s *SocialGrafanaCom) UserInfo(client *http.Client, _ *oauth2.Token) (*Basi
|
|||||||
Orgs []OrgRecord `json:"orgs"`
|
Orgs []OrgRecord `json:"orgs"`
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := s.httpGet(client, s.url+"/api/oauth2/user")
|
response, err := s.httpGet(ctx, client, s.url+"/api/oauth2/user")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error getting user info: %s", err)
|
return nil, fmt.Errorf("Error getting user info: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package social
|
package social
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@ -81,7 +82,7 @@ func TestSocialGrafanaCom_UserInfo(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
provider.url = ts.URL
|
provider.url = ts.URL
|
||||||
actualResult, err := provider.UserInfo(ts.Client(), nil)
|
actualResult, err := provider.UserInfo(context.Background(), ts.Client(), nil)
|
||||||
if test.ExpectedError != nil {
|
if test.ExpectedError != nil {
|
||||||
require.Equal(t, err, test.ExpectedError)
|
require.Equal(t, err, test.ExpectedError)
|
||||||
return
|
return
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package social
|
package social
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -46,7 +47,7 @@ func (claims *OktaClaims) extractEmail() string {
|
|||||||
return claims.Email
|
return claims.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
func (s *SocialOkta) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
|
||||||
idToken := token.Extra("id_token")
|
idToken := token.Extra("id_token")
|
||||||
if idToken == nil {
|
if idToken == nil {
|
||||||
return nil, fmt.Errorf("no id_token found")
|
return nil, fmt.Errorf("no id_token found")
|
||||||
@ -68,7 +69,7 @@ func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicU
|
|||||||
}
|
}
|
||||||
|
|
||||||
var data OktaUserInfoJson
|
var data OktaUserInfoJson
|
||||||
err = s.extractAPI(&data, client)
|
err = s.extractAPI(ctx, &data, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -105,8 +106,8 @@ func (s *SocialOkta) UserInfo(client *http.Client, token *oauth2.Token) (*BasicU
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SocialOkta) extractAPI(data *OktaUserInfoJson, client *http.Client) error {
|
func (s *SocialOkta) extractAPI(ctx context.Context, data *OktaUserInfoJson, client *http.Client) error {
|
||||||
rawUserInfoResponse, err := s.httpGet(client, s.apiUrl)
|
rawUserInfoResponse, err := s.httpGet(ctx, client, s.apiUrl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.log.Debug("Error getting user info response", "url", s.apiUrl, "error", err)
|
s.log.Debug("Error getting user info response", "url", s.apiUrl, "error", err)
|
||||||
return fmt.Errorf("error getting user info response: %w", err)
|
return fmt.Errorf("error getting user info response: %w", err)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package social
|
package social
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -110,7 +111,7 @@ func TestSocialOkta_UserInfo(t *testing.T) {
|
|||||||
Expiry: time.Now(),
|
Expiry: time.Now(),
|
||||||
}
|
}
|
||||||
token := staticToken.WithExtra(tt.OAuth2Extra)
|
token := staticToken.WithExtra(tt.OAuth2Extra)
|
||||||
got, err := provider.UserInfo(server.Client(), token)
|
got, err := provider.UserInfo(context.Background(), server.Client(), token)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -7,9 +7,11 @@ import (
|
|||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/text/cases"
|
"golang.org/x/text/cases"
|
||||||
@ -261,7 +263,7 @@ func (b *BasicUserInfo) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SocialConnector interface {
|
type SocialConnector interface {
|
||||||
UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error)
|
UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*BasicUserInfo, error)
|
||||||
IsEmailAllowed(email string) bool
|
IsEmailAllowed(email string) bool
|
||||||
IsSignupAllowed() bool
|
IsSignupAllowed() bool
|
||||||
|
|
||||||
@ -450,9 +452,19 @@ func (ss *SocialService) GetOAuthHttpClient(name string) (*http.Client, error) {
|
|||||||
TLSClientConfig: &tls.Config{
|
TLSClientConfig: &tls.Config{
|
||||||
InsecureSkipVerify: info.TlsSkipVerify,
|
InsecureSkipVerify: info.TlsSkipVerify,
|
||||||
},
|
},
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: time.Second * 10,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
TLSHandshakeTimeout: 15 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthClient := &http.Client{
|
oauthClient := &http.Client{
|
||||||
Transport: tr,
|
Transport: tr,
|
||||||
|
Timeout: time.Second * 15,
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.TlsClientCert != "" || info.TlsClientKey != "" {
|
if info.TlsClientCert != "" || info.TlsClientKey != "" {
|
||||||
|
@ -116,7 +116,7 @@ func (c *OAuth) Authenticate(ctx context.Context, r *authn.Request) (*authn.Iden
|
|||||||
}
|
}
|
||||||
token.TokenType = "Bearer"
|
token.TokenType = "Bearer"
|
||||||
|
|
||||||
userInfo, err := c.connector.UserInfo(c.connector.Client(clientCtx, token), token)
|
userInfo, err := c.connector.UserInfo(ctx, c.connector.Client(clientCtx, token), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var sErr *social.Error
|
var sErr *social.Error
|
||||||
if errors.As(err, &sErr) {
|
if errors.As(err, &sErr) {
|
||||||
|
@ -8,13 +8,14 @@ import (
|
|||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
"github.com/grafana/grafana/pkg/services/authn"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/login"
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
"github.com/grafana/grafana/pkg/services/org"
|
"github.com/grafana/grafana/pkg/services/org"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOAuth_Authenticate(t *testing.T) {
|
func TestOAuth_Authenticate(t *testing.T) {
|
||||||
@ -278,7 +279,7 @@ type fakeConnector struct {
|
|||||||
social.SocialConnector
|
social.SocialConnector
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f fakeConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
func (f fakeConnector) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
||||||
return f.ExpectedUserInfo, f.ExpectedUserInfoErr
|
return f.ExpectedUserInfo, f.ExpectedUserInfoErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -273,7 +273,7 @@ func (m *MockSocialConnector) Type() int {
|
|||||||
return args.Int(0)
|
return args.Int(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockSocialConnector) UserInfo(client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
func (m *MockSocialConnector) UserInfo(ctx context.Context, client *http.Client, token *oauth2.Token) (*social.BasicUserInfo, error) {
|
||||||
args := m.Called(client, token)
|
args := m.Called(client, token)
|
||||||
return args.Get(0).(*social.BasicUserInfo), args.Error(1)
|
return args.Get(0).(*social.BasicUserInfo), args.Error(1)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user