grafana/pkg/services/authn/clients/oauth_test.go

353 lines
11 KiB
Go
Raw Normal View History

package clients
import (
"context"
"net/http"
"net/url"
"testing"
"golang.org/x/oauth2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting"
)
// TestOAuth_Authenticate runs table-driven cases against the Authenticate
// method of the client returned by ProvideOAuth. Each case builds a request,
// optionally attaches the OAuth state and PKCE cookies, and uses a
// fakeConnector to supply the token, user info, and email/signup decisions.
//
// For every case the returned error is matched against expectedErr. When an
// identity is expected, a subset of its fields is compared with the returned
// identity; otherwise the returned identity must be nil.
func TestOAuth_Authenticate(t *testing.T) {
	type testCase struct {
		desc                  string
		req                   *authn.Request
		oauthCfg              *social.OAuthInfo
		allowInsecureTakeover bool
		addStateCookie        bool
		stateCookieValue      string
		addPKCECookie         bool
		pkceCookieValue       string
		isEmailAllowed        bool
		userInfo              *social.BasicUserInfo
		expectedErr           error
		expectedIdentity      *authn.Identity
	}

	tests := []testCase{
		{
			desc:        "should return error when missing state cookie",
			req:         &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
			oauthCfg:    &social.OAuthInfo{},
			expectedErr: errOAuthMissingState,
		},
		{
			desc:             "should return error when state cookie is present but don't have a value",
			req:              &authn.Request{HTTPRequest: &http.Request{Header: map[string][]string{}}},
			oauthCfg:         &social.OAuthInfo{},
			addStateCookie:   true,
			stateCookieValue: "",
			expectedErr:      errOAuthMissingState,
		},
		{
			desc: "should return error when state from ipd does not match stored state",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-other-state"),
			}},
			oauthCfg:         &social.OAuthInfo{UsePKCE: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			expectedErr:      errOAuthInvalidState,
		},
		{
			desc: "should return error when pkce is configured but the cookie is not present",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-state"),
			}},
			oauthCfg:         &social.OAuthInfo{UsePKCE: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			expectedErr:      errOAuthMissingPKCE,
		},
		{
			desc: "should return error when email is empty",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-state"),
			}},
			oauthCfg:         &social.OAuthInfo{UsePKCE: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			addPKCECookie:    true,
			pkceCookieValue:  "some-pkce-value",
			userInfo:         &social.BasicUserInfo{},
			expectedErr:      errOAuthMissingRequiredEmail,
		},
		{
			desc: "should return error when email is not allowed",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-state"),
			}},
			oauthCfg:         &social.OAuthInfo{UsePKCE: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			addPKCECookie:    true,
			pkceCookieValue:  "some-pkce-value",
			userInfo:         &social.BasicUserInfo{Email: "some@email.com"},
			isEmailAllowed:   false,
			expectedErr:      errOAuthEmailNotAllowed,
		},
		{
			desc: "should return identity for valid request",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-state"),
			}},
			oauthCfg:         &social.OAuthInfo{UsePKCE: true},
			addStateCookie:   true,
			stateCookieValue: "some-state",
			addPKCECookie:    true,
			pkceCookieValue:  "some-pkce-value",
			isEmailAllowed:   true,
			userInfo: &social.BasicUserInfo{
				Id:     "123",
				Name:   "name",
				Email:  "some@email.com",
				Role:   "Admin",
				Groups: []string{"grp1", "grp2"},
			},
			expectedIdentity: &authn.Identity{
				Email:      "some@email.com",
				AuthModule: "oauth_azuread",
				AuthID:     "123",
				Name:       "name",
				Groups:     []string{"grp1", "grp2"},
				OAuthToken: &oauth2.Token{},
				OrgRoles:   map[int64]org.RoleType{1: org.RoleAdmin},
				ClientParams: authn.ClientParams{
					SyncUser:        true,
					SyncTeams:       true,
					AllowSignUp:     true,
					FetchSyncedUser: true,
					SyncOrgRoles:    true,
					LookUpParams:    login.UserLookupParams{},
				},
			},
		},
		{
			desc: "should return identity for valid request - and lookup user by email",
			req: &authn.Request{HTTPRequest: &http.Request{
				Header: map[string][]string{},
				URL:    mustParseURL("http://grafana.com/?state=some-state"),
			}},
			oauthCfg:              &social.OAuthInfo{UsePKCE: true},
			allowInsecureTakeover: true,
			addStateCookie:        true,
			stateCookieValue:      "some-state",
			addPKCECookie:         true,
			pkceCookieValue:       "some-pkce-value",
			isEmailAllowed:        true,
			userInfo: &social.BasicUserInfo{
				Id:     "123",
				Name:   "name",
				Email:  "some@email.com",
				Role:   "Admin",
				Groups: []string{"grp1", "grp2"},
			},
			expectedIdentity: &authn.Identity{
				Email:      "some@email.com",
				AuthModule: "oauth_azuread",
				AuthID:     "123",
				Name:       "name",
				Groups:     []string{"grp1", "grp2"},
				OAuthToken: &oauth2.Token{},
				OrgRoles:   map[int64]org.RoleType{1: org.RoleAdmin},
				ClientParams: authn.ClientParams{
					SyncUser:        true,
					SyncTeams:       true,
					AllowSignUp:     true,
					FetchSyncedUser: true,
					SyncOrgRoles:    true,
					LookUpParams:    login.UserLookupParams{Email: strPtr("some@email.com")},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			cfg := setting.NewCfg()
			if tt.allowInsecureTakeover {
				cfg.OAuthAllowInsecureEmailLookup = true
			}

			if tt.addStateCookie {
				v := tt.stateCookieValue
				// A non-empty state is stored in the cookie in hashed form; an
				// empty value is left as-is to exercise the missing-state path.
				if v != "" {
					v = hashOAuthState(v, cfg.SecretKey, tt.oauthCfg.ClientSecret)
				}
				tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthStateCookieName, Value: v})
			}

			if tt.addPKCECookie {
				tt.req.HTTPRequest.AddCookie(&http.Cookie{Name: oauthPKCECookieName, Value: tt.pkceCookieValue})
			}

			c := ProvideOAuth(authn.ClientWithPrefix("azuread"), cfg, tt.oauthCfg, fakeConnector{
				ExpectedUserInfo:        tt.userInfo,
				ExpectedToken:           &oauth2.Token{},
				ExpectedIsSignupAllowed: true,
				ExpectedIsEmailAllowed:  tt.isEmailAllowed,
			}, nil)

			identity, err := c.Authenticate(context.Background(), tt.req)
			assert.ErrorIs(t, err, tt.expectedErr)

			if tt.expectedIdentity == nil {
				// No identity is expected for this case, so verify the value
				// actually returned rather than the table entry itself.
				assert.Nil(t, identity)
				return
			}

			// Stop here on a nil identity so the field comparisons below fail
			// with a clear message instead of a nil-pointer panic.
			require.NotNil(t, identity)

			assert.Equal(t, tt.expectedIdentity.Login, identity.Login)
			assert.Equal(t, tt.expectedIdentity.Name, identity.Name)
			assert.Equal(t, tt.expectedIdentity.Email, identity.Email)
			assert.Equal(t, tt.expectedIdentity.AuthID, identity.AuthID)
			assert.Equal(t, tt.expectedIdentity.AuthModule, identity.AuthModule)
			assert.Equal(t, tt.expectedIdentity.Groups, identity.Groups)

			assert.Equal(t, tt.expectedIdentity.ClientParams.SyncUser, identity.ClientParams.SyncUser)
			assert.Equal(t, tt.expectedIdentity.ClientParams.AllowSignUp, identity.ClientParams.AllowSignUp)
			assert.Equal(t, tt.expectedIdentity.ClientParams.SyncTeams, identity.ClientParams.SyncTeams)
			assert.Equal(t, tt.expectedIdentity.ClientParams.EnableDisabledUsers, identity.ClientParams.EnableDisabledUsers)

			assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.Email, identity.ClientParams.LookUpParams.Email)
			assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.Login, identity.ClientParams.LookUpParams.Login)
			assert.EqualValues(t, tt.expectedIdentity.ClientParams.LookUpParams.UserID, identity.ClientParams.LookUpParams.UserID)
		})
	}
}
// TestOAuth_RedirectURL exercises RedirectURL on the client returned by
// ProvideOAuth. A mockConnector records whether AuthCodeURL was invoked and
// checks how many auth-code options it received. On success the redirect must
// carry a non-empty OAuth state, plus a non-empty PKCE value when UsePKCE is
// enabled in the config.
func TestOAuth_RedirectURL(t *testing.T) {
	type testCase struct {
		desc              string
		oauthCfg          *social.OAuthInfo
		expectedErr       error
		numCallOptions    int
		authCodeUrlCalled bool
	}

	tests := []testCase{
		{
			desc:              "should generate redirect url and state",
			oauthCfg:          &social.OAuthInfo{},
			authCodeUrlCalled: true,
		},
		{
			desc:              "should generate redirect url with hosted domain option if configured",
			oauthCfg:          &social.OAuthInfo{HostedDomain: "grafana.com"},
			numCallOptions:    1,
			authCodeUrlCalled: true,
		},
		{
			desc:              "should generate redirect url with pkce if configured",
			oauthCfg:          &social.OAuthInfo{UsePKCE: true},
			numCallOptions:    2,
			authCodeUrlCalled: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			// Flipped by the mock when the client asks it for an auth code URL.
			called := false

			connector := mockConnector{
				AuthCodeURLFunc: func(state string, opts ...oauth2.AuthCodeOption) string {
					called = true
					require.Len(t, opts, tt.numCallOptions)
					return ""
				},
			}
			client := ProvideOAuth(authn.ClientWithPrefix("azuread"), setting.NewCfg(), tt.oauthCfg, connector, nil)

			res, err := client.RedirectURL(context.Background(), nil)
			assert.ErrorIs(t, err, tt.expectedErr)
			assert.Equal(t, tt.authCodeUrlCalled, called)

			// Only inspect the redirect when the case expects success.
			if tt.expectedErr == nil {
				assert.NotEmpty(t, res.Extra[authn.KeyOAuthState])
				if tt.oauthCfg.UsePKCE {
					assert.NotEmpty(t, res.Extra[authn.KeyOAuthPKCE])
				}
			}
		})
	}
}
// mockConnector is a social.SocialConnector test double whose AuthCodeURL
// behavior can be supplied per test through AuthCodeURLFunc.
type mockConnector struct {
	// AuthCodeURLFunc, when non-nil, is what AuthCodeURL delegates to.
	AuthCodeURLFunc func(state string, opts ...oauth2.AuthCodeOption) string
	// The embedded interface provides the remaining SocialConnector methods.
	// The tests in this file leave it unset.
	social.SocialConnector
}
// AuthCodeURL forwards to AuthCodeURLFunc when one is configured and returns
// its result; without a configured function it returns the empty string.
func (m mockConnector) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string {
	fn := m.AuthCodeURLFunc
	if fn == nil {
		return ""
	}
	return fn(state, opts...)
}
// Compile-time check that *fakeConnector satisfies social.SocialConnector.
var _ social.SocialConnector = (*fakeConnector)(nil)

// fakeConnector is a social.SocialConnector test double that answers with
// canned values configured through its Expected* fields.
type fakeConnector struct {
	// Returned by UserInfo.
	ExpectedUserInfo    *social.BasicUserInfo
	ExpectedUserInfoErr error

	// Returned by IsEmailAllowed and IsSignupAllowed respectively.
	ExpectedIsEmailAllowed  bool
	ExpectedIsSignupAllowed bool

	// Returned by Exchange.
	ExpectedToken    *oauth2.Token
	ExpectedTokenErr error

	// The embedded interface provides the SocialConnector methods that the
	// fake does not override.
	social.SocialConnector
}
// UserInfo ignores its arguments and returns the ExpectedUserInfo and
// ExpectedUserInfoErr values configured on the fake.
func (f fakeConnector) UserInfo(_ context.Context, _ *http.Client, _ *oauth2.Token) (*social.BasicUserInfo, error) {
	info, err := f.ExpectedUserInfo, f.ExpectedUserInfoErr
	return info, err
}
// IsEmailAllowed ignores the given email and reports the configured
// ExpectedIsEmailAllowed value.
func (f fakeConnector) IsEmailAllowed(_ string) bool {
	allowed := f.ExpectedIsEmailAllowed
	return allowed
}
// IsSignupAllowed reports the configured ExpectedIsSignupAllowed value.
func (f fakeConnector) IsSignupAllowed() bool {
	allowed := f.ExpectedIsSignupAllowed
	return allowed
}
// Exchange ignores the code and options and returns the ExpectedToken and
// ExpectedTokenErr values configured on the fake.
func (f fakeConnector) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
	token, err := f.ExpectedToken, f.ExpectedTokenErr
	return token, err
}
// Client ignores its arguments and always returns a nil *http.Client.
func (f fakeConnector) Client(_ context.Context, _ *oauth2.Token) *http.Client {
	var none *http.Client
	return none
}
// mustParseURL parses s with url.Parse and returns the resulting URL. It
// panics with the parse error when s is not a valid URL, which keeps test
// table literals free of error handling.
func mustParseURL(s string) *url.URL {
	parsed, err := url.Parse(s)
	if err == nil {
		return parsed
	}
	panic(err)
}