mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Use authn.Service for all tests (#72921)
* Dashboards: Fix tests when authn broker is enabled. StarService was not configured for tests, the call was guarded by !c.IsSignedIn * Change default to be anon user to match expectations from tests * OAuth: rewrite tests to work with authn.Service * Setup template renderer by default * Extract cookie options from cfg instead of relying on global variables * Fix test to work with authn service * Middleware: rewrite auth tests * Remvoe session cookie if we cannot refresh access token
This commit is contained in:
parent
5eef8291e2
commit
144e4887ee
@ -72,6 +72,7 @@ func loggedInUserScenarioWithRole(t *testing.T, desc string, method string, url
|
||||
sc.context.OrgID = testOrgID
|
||||
sc.context.Login = testUserLogin
|
||||
sc.context.OrgRole = role
|
||||
sc.context.IsAnonymous = false
|
||||
if sc.handlerFunc != nil {
|
||||
return sc.handlerFunc(sc.context)
|
||||
}
|
||||
@ -212,7 +213,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
|
||||
remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil,
|
||||
authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake(),
|
||||
nil, featuremgmt.WithFeatures(), &authntest.FakeService{
|
||||
ExpectedIdentity: &authn.Identity{OrgID: 1, ID: "user:1", SessionToken: &usertoken.UserToken{}}}, &anontest.FakeAnonymousSessionService{})
|
||||
ExpectedIdentity: &authn.Identity{IsAnonymous: true, SessionToken: &usertoken.UserToken{}}}, &anontest.FakeAnonymousSessionService{})
|
||||
|
||||
return ctxHdlr
|
||||
}
|
||||
@ -310,6 +311,11 @@ func SetupAPITestServer(t *testing.T, opts ...APITestServerOption) *webtest.Serv
|
||||
hs.registerRoutes()
|
||||
|
||||
s := webtest.NewServer(t, hs.RouteRegister)
|
||||
|
||||
viewsPath, err := filepath.Abs("../../public/views")
|
||||
require.NoError(t, err)
|
||||
s.Mux.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -52,6 +52,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/publicdashboards"
|
||||
"github.com/grafana/grafana/pkg/services/publicdashboards/api"
|
||||
"github.com/grafana/grafana/pkg/services/quota/quotatest"
|
||||
"github.com/grafana/grafana/pkg/services/star/startest"
|
||||
"github.com/grafana/grafana/pkg/services/tag/tagimpl"
|
||||
"github.com/grafana/grafana/pkg/services/team/teamtest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
@ -160,6 +161,7 @@ func TestDashboardAPIEndpoint(t *testing.T) {
|
||||
dashboardVersionService: fakeDashboardVersionService,
|
||||
Kinds: corekind.NewBase(nil),
|
||||
QuotaService: quotatest.New(false, nil),
|
||||
starService: startest.NewStarServiceFake(),
|
||||
userService: &usertest.FakeUserService{
|
||||
ExpectedUser: &user.User{ID: 1, Login: "test-user"},
|
||||
},
|
||||
@ -933,6 +935,7 @@ func TestDashboardAPIEndpoint(t *testing.T) {
|
||||
DashboardService: dashboardService,
|
||||
Features: featuremgmt.WithFeatures(),
|
||||
Kinds: corekind.NewBase(nil),
|
||||
starService: startest.NewStarServiceFake(),
|
||||
}
|
||||
hs.callGetDashboard(sc)
|
||||
|
||||
@ -1121,6 +1124,7 @@ func getDashboardShouldReturn200WithConfig(t *testing.T, sc *scenarioContext, pr
|
||||
DashboardService: dashboardService,
|
||||
Features: featuremgmt.WithFeatures(),
|
||||
Kinds: corekind.NewBase(nil),
|
||||
starService: startest.NewStarServiceFake(),
|
||||
}
|
||||
|
||||
hs.callGetDashboard(sc)
|
||||
|
@ -92,19 +92,20 @@ func (hs *HTTPServer) OAuthLogin(reqCtx *contextmodel.ReqContext) {
|
||||
return
|
||||
}
|
||||
|
||||
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
|
||||
if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" {
|
||||
cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
}
|
||||
|
||||
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||
reqCtx.Redirect(redirect.URL)
|
||||
return
|
||||
}
|
||||
|
||||
identity, err := hs.authnService.Login(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
|
||||
// NOTE: always delete these cookies, even if login failed
|
||||
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
||||
cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
||||
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
||||
|
||||
if err != nil {
|
||||
reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))
|
||||
|
@ -1,237 +1,212 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/infra/usagestats"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/models/roletype"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/hooks"
|
||||
"github.com/grafana/grafana/pkg/services/licensing"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||
"github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||
"github.com/grafana/grafana/pkg/services/supportbundles/supportbundlestest"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
)
|
||||
|
||||
func setupSocialHTTPServerWithConfig(t *testing.T, cfg *setting.Cfg) *HTTPServer {
|
||||
sqlStore := db.InitTestDB(t)
|
||||
features := featuremgmt.WithFeatures()
|
||||
|
||||
return &HTTPServer{
|
||||
Cfg: cfg,
|
||||
License: &licensing.OSSLicensingService{Cfg: cfg},
|
||||
SQLStore: sqlStore,
|
||||
SocialService: social.ProvideService(cfg, features, &usagestats.UsageStatsMock{}, supportbundlestest.NewFakeBundleService(), remotecache.NewFakeCacheStorage()),
|
||||
HooksService: hooks.ProvideService(),
|
||||
SecretsService: fakes.NewFakeSecretsService(),
|
||||
Features: features,
|
||||
}
|
||||
}
|
||||
|
||||
func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux {
|
||||
func setClientWithoutRedirectFollow(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
if cfg == nil {
|
||||
cfg = setting.NewCfg()
|
||||
}
|
||||
cfg.ErrTemplateName = "error-template"
|
||||
hs := setupSocialHTTPServerWithConfig(t, cfg)
|
||||
|
||||
m := web.New()
|
||||
m.Use(getContextHandler(t, cfg).Middleware)
|
||||
viewPath, err := filepath.Abs("../../public/views")
|
||||
require.NoError(t, err)
|
||||
|
||||
m.UseMiddleware(web.Renderer(viewPath, "[[", "]]"))
|
||||
|
||||
m.Get("/login/:name", hs.OAuthLogin)
|
||||
return m
|
||||
}
|
||||
|
||||
func TestOAuthLogin_UnknownProvider(t *testing.T) {
|
||||
m := setupOAuthTest(t, nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/login/notaprovider", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
m.ServeHTTP(recorder, req)
|
||||
// expect to be redirected to /login
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
assert.Equal(t, "/login", recorder.Header().Get("Location"))
|
||||
}
|
||||
|
||||
func TestOAuthLogin_Base(t *testing.T) {
|
||||
cfg := setting.NewCfg()
|
||||
sec := cfg.Raw.Section("auth.generic_oauth")
|
||||
_, err := sec.NewKey("enabled", "true")
|
||||
require.NoError(t, err)
|
||||
|
||||
m := setupOAuthTest(t, cfg)
|
||||
req := httptest.NewRequest(http.MethodGet, "/login/generic_oauth", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
m.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.NotEmpty(t, location)
|
||||
|
||||
u, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, u.Query().Has("code_challenge"))
|
||||
assert.False(t, u.Query().Has("code_challenge_method"))
|
||||
|
||||
resp := recorder.Result()
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
cookies := resp.Cookies()
|
||||
var stateCookie *http.Cookie
|
||||
for _, c := range cookies {
|
||||
if c.Name == OauthStateCookieName {
|
||||
stateCookie = c
|
||||
}
|
||||
}
|
||||
require.NotNil(t, stateCookie)
|
||||
|
||||
req = httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
(&url.URL{
|
||||
Path: "/login/generic_oauth",
|
||||
RawQuery: url.Values{
|
||||
"code": []string{"helloworld"},
|
||||
"state": []string{u.Query().Get("state")},
|
||||
}.Encode(),
|
||||
}).String(),
|
||||
nil,
|
||||
)
|
||||
req.AddCookie(stateCookie)
|
||||
recorder = httptest.NewRecorder()
|
||||
|
||||
m.ServeHTTP(recorder, req)
|
||||
// TODO: validate that 'creating a token works'
|
||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "login.OAuthLogin(NewTransportWithCode)")
|
||||
}
|
||||
|
||||
func TestOAuthLogin_UsePKCE(t *testing.T) {
|
||||
cfg := setting.NewCfg()
|
||||
sec := cfg.Raw.Section("auth.generic_oauth")
|
||||
_, err := sec.NewKey("enabled", "true")
|
||||
require.NoError(t, err)
|
||||
_, err = sec.NewKey("use_pkce", "true")
|
||||
require.NoError(t, err)
|
||||
|
||||
m := setupOAuthTest(t, cfg)
|
||||
req := httptest.NewRequest(http.MethodGet, "/login/generic_oauth", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
m.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
||||
|
||||
location := recorder.Header().Get("Location")
|
||||
assert.NotEmpty(t, location)
|
||||
|
||||
u, err := url.Parse(location)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, u.Query().Has("code_challenge"))
|
||||
assert.Equal(t, "S256", u.Query().Get("code_challenge_method"))
|
||||
|
||||
resp := recorder.Result()
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
var oauthCookie *http.Cookie
|
||||
for _, cookie := range resp.Cookies() {
|
||||
if cookie.Name == OauthPKCECookieName {
|
||||
oauthCookie = cookie
|
||||
}
|
||||
}
|
||||
require.NotNil(t, oauthCookie)
|
||||
|
||||
shasum := sha256.Sum256([]byte(oauthCookie.Value))
|
||||
assert.Equal(
|
||||
t,
|
||||
u.Query().Get("code_challenge"),
|
||||
base64.RawURLEncoding.EncodeToString(shasum[:]),
|
||||
)
|
||||
}
|
||||
|
||||
func TestOAuthLogin_BuildExternalUserInfo(t *testing.T) {
|
||||
t.Helper()
|
||||
cfgOAuthSkipRoleSync := setting.NewCfg()
|
||||
authOAuthSec := cfgOAuthSkipRoleSync.Raw.Section("auth")
|
||||
_, err := authOAuthSec.NewKey("oauth_skip_org_role_update_sync", "true")
|
||||
require.NoError(t, err)
|
||||
cfgOAuthSkipRoleSync.ErrTemplateName = "error-template"
|
||||
|
||||
cfgOAuthOrgRoleSync := setting.NewCfg()
|
||||
authOAutoWithoutSec := cfgOAuthOrgRoleSync.Raw.Section("auth")
|
||||
_, err = authOAutoWithoutSec.NewKey("oauth_skip_org_role_update_sync", "false")
|
||||
require.NoError(t, err)
|
||||
cfgOAuthOrgRoleSync.ErrTemplateName = "error-template"
|
||||
|
||||
testcases := []struct {
|
||||
name string
|
||||
cfg *setting.Cfg
|
||||
basicUser *social.BasicUserInfo
|
||||
expectedOrgRoles map[int64]org.RoleType
|
||||
}{
|
||||
{
|
||||
name: "should return empty map of org role mapping if the role for the basic info is empty",
|
||||
cfg: cfgOAuthOrgRoleSync,
|
||||
basicUser: &social.BasicUserInfo{
|
||||
Id: "1",
|
||||
Name: "first lastname",
|
||||
Email: "example@github.com",
|
||||
Login: "example",
|
||||
Role: "",
|
||||
},
|
||||
expectedOrgRoles: map[int64]org.RoleType{},
|
||||
},
|
||||
{
|
||||
name: "should set internal role if role exists and we are skipping org role sync",
|
||||
cfg: cfgOAuthSkipRoleSync,
|
||||
basicUser: &social.BasicUserInfo{
|
||||
Id: "1",
|
||||
Name: "first lastname",
|
||||
Email: "example@github.com",
|
||||
Login: "example",
|
||||
Role: roletype.RoleAdmin,
|
||||
},
|
||||
expectedOrgRoles: map[int64]org.RoleType{1: roletype.RoleAdmin},
|
||||
},
|
||||
{
|
||||
name: "should return empty external role, if the role for the basic info is empty",
|
||||
cfg: cfgOAuthSkipRoleSync,
|
||||
basicUser: &social.BasicUserInfo{
|
||||
Id: "1",
|
||||
Name: "first lastname",
|
||||
Email: "example@github.com",
|
||||
Login: "example",
|
||||
Role: "",
|
||||
},
|
||||
expectedOrgRoles: map[int64]org.RoleType{},
|
||||
old := http.DefaultClient
|
||||
http.DefaultClient = &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
for _, tc := range testcases {
|
||||
t.Logf("%s", tc.name)
|
||||
cfg := tc.cfg
|
||||
hs := setupSocialHTTPServerWithConfig(t, cfg)
|
||||
externalUser := hs.buildExternalUserInfo(nil, tc.basicUser, "")
|
||||
require.Equal(t, tc.expectedOrgRoles, externalUser.OrgRoles)
|
||||
|
||||
t.Cleanup(func() {
|
||||
http.DefaultClient = old
|
||||
})
|
||||
}
|
||||
|
||||
func TestOAuthLogin_Redirect(t *testing.T) {
|
||||
type testCase struct {
|
||||
desc string
|
||||
expectedErr error
|
||||
expectedCode int
|
||||
expectedRedirect *authn.Redirect
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "should be redirected to /login when passing un-configured provider",
|
||||
expectedErr: authn.ErrClientNotConfigured,
|
||||
expectedCode: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "should be redirected to provider",
|
||||
expectedCode: http.StatusFound,
|
||||
expectedRedirect: &authn.Redirect{
|
||||
URL: "https://some-provider.com",
|
||||
Extra: map[string]string{
|
||||
authn.KeyOAuthState: "some-state",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "should set pkce cookie",
|
||||
expectedCode: http.StatusFound,
|
||||
expectedRedirect: &authn.Redirect{
|
||||
URL: "https://some-provider.com",
|
||||
Extra: map[string]string{
|
||||
authn.KeyOAuthState: "some-state",
|
||||
authn.KeyOAuthPKCE: "pkce-",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
server := SetupAPITestServer(t, func(hs *HTTPServer) {
|
||||
hs.Cfg = setting.NewCfg()
|
||||
hs.SecretsService = fakes.NewFakeSecretsService()
|
||||
hs.authnService = &authntest.FakeService{
|
||||
ExpectedErr: tt.expectedErr,
|
||||
ExpectedRedirect: tt.expectedRedirect,
|
||||
}
|
||||
})
|
||||
|
||||
// we need to prevent the http.Client from following redirects
|
||||
setClientWithoutRedirectFollow(t)
|
||||
|
||||
res, err := server.Send(server.NewGetRequest("/login/generic_oauth"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusFound, res.StatusCode)
|
||||
|
||||
// on every error we should get redirected to /login
|
||||
if tt.expectedErr != nil {
|
||||
assert.Equal(t, "/login", res.Header.Get("Location"))
|
||||
} else {
|
||||
// check that we get correct redirect url
|
||||
assert.Equal(t, tt.expectedRedirect.URL, res.Header.Get("Location"))
|
||||
|
||||
require.GreaterOrEqual(t, len(res.Cookies()), 1)
|
||||
if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" {
|
||||
require.Len(t, res.Cookies(), 2)
|
||||
} else {
|
||||
require.Len(t, res.Cookies(), 1)
|
||||
}
|
||||
|
||||
require.GreaterOrEqual(t, len(res.Cookies()), 1)
|
||||
stateCookie := res.Cookies()[0]
|
||||
assert.Equal(t, OauthStateCookieName, stateCookie.Name)
|
||||
assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthState], stateCookie.Value)
|
||||
|
||||
if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" {
|
||||
require.Len(t, res.Cookies(), 2)
|
||||
pkceCookie := res.Cookies()[1]
|
||||
assert.Equal(t, OauthPKCECookieName, pkceCookie.Name)
|
||||
assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthPKCE], pkceCookie.Value)
|
||||
} else {
|
||||
require.Len(t, res.Cookies(), 1)
|
||||
}
|
||||
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthLogin_AuthorizationCode(t *testing.T) {
|
||||
type testCase struct {
|
||||
desc string
|
||||
expectedErr error
|
||||
expectedIdentity *authn.Identity
|
||||
}
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "should redirect to /login on error",
|
||||
expectedErr: errors.New("some error"),
|
||||
},
|
||||
{
|
||||
desc: "should redirect to / and set session cookie on successful authentication",
|
||||
expectedIdentity: &authn.Identity{
|
||||
SessionToken: &usertoken.UserToken{UnhashedToken: "some-token"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
var cfg *setting.Cfg
|
||||
server := SetupAPITestServer(t, func(hs *HTTPServer) {
|
||||
cfg = setting.NewCfg()
|
||||
hs.Cfg = cfg
|
||||
hs.Cfg.LoginCookieName = "some_name"
|
||||
hs.SecretsService = fakes.NewFakeSecretsService()
|
||||
hs.authnService = &authntest.FakeService{
|
||||
ExpectedErr: tt.expectedErr,
|
||||
ExpectedIdentity: tt.expectedIdentity,
|
||||
}
|
||||
})
|
||||
|
||||
// we need to prevent the http.Client from following redirects
|
||||
setClientWithoutRedirectFollow(t)
|
||||
|
||||
res, err := server.Send(server.NewGetRequest("/login/generic_oauth?code=code"))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.GreaterOrEqual(t, len(res.Cookies()), 3)
|
||||
|
||||
// make sure oauth state cookie is deleted
|
||||
assert.Equal(t, OauthStateCookieName, res.Cookies()[0].Name)
|
||||
assert.Equal(t, "", res.Cookies()[0].Value)
|
||||
assert.Equal(t, -1, res.Cookies()[0].MaxAge)
|
||||
|
||||
// make sure oauth pkce cookie is deleted
|
||||
assert.Equal(t, OauthPKCECookieName, res.Cookies()[1].Name)
|
||||
assert.Equal(t, "", res.Cookies()[1].Value)
|
||||
assert.Equal(t, -1, res.Cookies()[1].MaxAge)
|
||||
|
||||
if tt.expectedErr != nil {
|
||||
require.Len(t, res.Cookies(), 3)
|
||||
assert.Equal(t, http.StatusFound, res.StatusCode)
|
||||
assert.Equal(t, "/login", res.Header.Get("Location"))
|
||||
assert.Equal(t, loginErrorCookieName, res.Cookies()[2].Name)
|
||||
} else {
|
||||
require.Len(t, res.Cookies(), 4)
|
||||
assert.Equal(t, http.StatusFound, res.StatusCode)
|
||||
assert.Equal(t, "/", res.Header.Get("Location"))
|
||||
|
||||
// verify session expiry cookie is set
|
||||
assert.Equal(t, cfg.LoginCookieName, res.Cookies()[2].Name)
|
||||
assert.Equal(t, "grafana_session_expiry", res.Cookies()[3].Name)
|
||||
}
|
||||
|
||||
require.NoError(t, res.Body.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthLogin_Error(t *testing.T) {
|
||||
server := SetupAPITestServer(t, func(hs *HTTPServer) {
|
||||
hs.Cfg = setting.NewCfg()
|
||||
hs.SecretsService = fakes.NewFakeSecretsService()
|
||||
})
|
||||
|
||||
setClientWithoutRedirectFollow(t)
|
||||
|
||||
res, err := server.Send(server.NewGetRequest("/login/azuread?error=someerror"))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusFound, res.StatusCode)
|
||||
assert.Equal(t, "/login", res.Header.Get("Location"))
|
||||
|
||||
require.Len(t, res.Cookies(), 1)
|
||||
errCookie := res.Cookies()[0]
|
||||
assert.Equal(t, loginErrorCookieName, errCookie.Name)
|
||||
require.NoError(t, res.Body.Close())
|
||||
}
|
||||
|
@ -20,14 +20,15 @@ import (
|
||||
"github.com/grafana/grafana/pkg/api/routing"
|
||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/login"
|
||||
"github.com/grafana/grafana/pkg/login/social"
|
||||
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/hooks"
|
||||
"github.com/grafana/grafana/pkg/services/licensing"
|
||||
loginservice "github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/navtree"
|
||||
"github.com/grafana/grafana/pkg/services/secrets"
|
||||
"github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||
@ -317,11 +318,15 @@ func TestLoginPostRedirect(t *testing.T) {
|
||||
|
||||
fakeViewIndex(t)
|
||||
sc := setupScenarioContext(t, "/login")
|
||||
|
||||
hs := &HTTPServer{
|
||||
log: log.NewNopLogger(),
|
||||
Cfg: setting.NewCfg(),
|
||||
HooksService: &hooks.HooksService{},
|
||||
License: &licensing.OSSLicensingService{},
|
||||
log: log.NewNopLogger(),
|
||||
Cfg: setting.NewCfg(),
|
||||
HooksService: &hooks.HooksService{},
|
||||
License: &licensing.OSSLicensingService{},
|
||||
authnService: &authntest.FakeService{
|
||||
ExpectedIdentity: &authn.Identity{ID: "user:42", SessionToken: &usertoken.UserToken{}},
|
||||
},
|
||||
AuthTokenService: authtest.NewFakeUserAuthTokenService(),
|
||||
Features: featuremgmt.WithFeatures(),
|
||||
}
|
||||
@ -333,13 +338,6 @@ func TestLoginPostRedirect(t *testing.T) {
|
||||
return hs.LoginPost(c)
|
||||
})
|
||||
|
||||
user := &user.User{
|
||||
ID: 42,
|
||||
Email: "",
|
||||
}
|
||||
|
||||
hs.authenticator = &fakeAuthenticator{user, "", nil}
|
||||
|
||||
redirectCases := []redirectCase{
|
||||
{
|
||||
desc: "grafana relative url without subpath",
|
||||
@ -429,6 +427,9 @@ func TestLoginPostRedirect(t *testing.T) {
|
||||
hs.Cfg.AppSubURL = c.appSubURL
|
||||
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
if c.desc == "grafana invalid relative url starting with subpath" {
|
||||
fmt.Println()
|
||||
}
|
||||
expCookiePath := "/"
|
||||
if len(hs.Cfg.AppSubURL) > 0 {
|
||||
expCookiePath = hs.Cfg.AppSubURL
|
||||
@ -640,112 +641,6 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte
|
||||
return sc
|
||||
}
|
||||
|
||||
type loginHookTest struct {
|
||||
info *loginservice.LoginInfo
|
||||
}
|
||||
|
||||
func (r *loginHookTest) LoginHook(loginInfo *loginservice.LoginInfo, req *contextmodel.ReqContext) {
|
||||
r.info = loginInfo
|
||||
}
|
||||
|
||||
// TOREMOVE: remove with context handler auth
|
||||
func TestLoginPostRunLokingHook(t *testing.T) {
|
||||
sc := setupScenarioContext(t, "/login")
|
||||
hookService := &hooks.HooksService{}
|
||||
hs := &HTTPServer{
|
||||
log: log.New("test"),
|
||||
Cfg: sc.cfg,
|
||||
License: &licensing.OSSLicensingService{},
|
||||
AuthTokenService: authtest.NewFakeUserAuthTokenService(),
|
||||
Features: featuremgmt.WithFeatures(),
|
||||
HooksService: hookService,
|
||||
authnService: sc.ctxHdlr.AuthnService,
|
||||
}
|
||||
|
||||
sc.cfg.AuthBrokerEnabled = false
|
||||
|
||||
sc.defaultHandler = routing.Wrap(func(c *contextmodel.ReqContext) response.Response {
|
||||
c.Req.Header.Set("Content-Type", "application/json")
|
||||
c.Req.Body = io.NopCloser(bytes.NewBufferString(`{"user":"admin","password":"admin"}`))
|
||||
x := hs.LoginPost(c)
|
||||
return x
|
||||
})
|
||||
|
||||
testHook := loginHookTest{}
|
||||
hookService.AddLoginHook(testHook.LoginHook)
|
||||
|
||||
testUser := &user.User{
|
||||
ID: 42,
|
||||
Email: "",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
authUser *user.User
|
||||
authModule string
|
||||
authErr error
|
||||
info loginservice.LoginInfo
|
||||
}{
|
||||
{
|
||||
desc: "invalid credentials",
|
||||
authErr: login.ErrInvalidCredentials,
|
||||
info: loginservice.LoginInfo{
|
||||
AuthModule: "",
|
||||
HTTPStatus: 401,
|
||||
Error: login.ErrInvalidCredentials,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "user disabled",
|
||||
authErr: login.ErrUserDisabled,
|
||||
info: loginservice.LoginInfo{
|
||||
AuthModule: "",
|
||||
HTTPStatus: 401,
|
||||
Error: login.ErrUserDisabled,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "valid Grafana user",
|
||||
authUser: testUser,
|
||||
authModule: "grafana",
|
||||
info: loginservice.LoginInfo{
|
||||
AuthModule: "grafana",
|
||||
User: testUser,
|
||||
HTTPStatus: 200,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "valid LDAP user",
|
||||
authUser: testUser,
|
||||
authModule: loginservice.LDAPAuthModule,
|
||||
info: loginservice.LoginInfo{
|
||||
AuthModule: loginservice.LDAPAuthModule,
|
||||
User: testUser,
|
||||
HTTPStatus: 200,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range testCases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
hs.authenticator = &fakeAuthenticator{c.authUser, c.authModule, c.authErr}
|
||||
sc.m.Post(sc.url, sc.defaultHandler)
|
||||
sc.fakeReqNoAssertions("POST", sc.url).exec()
|
||||
|
||||
info := testHook.info
|
||||
assert.Equal(t, c.info.AuthModule, info.AuthModule)
|
||||
assert.Equal(t, "admin", info.LoginUsername)
|
||||
assert.Equal(t, c.info.HTTPStatus, info.HTTPStatus)
|
||||
assert.Equal(t, c.info.Error, info.Error)
|
||||
|
||||
if c.info.User != nil {
|
||||
require.NotEmpty(t, info.User)
|
||||
assert.Equal(t, c.info.User.ID, info.User.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockSocialService struct {
|
||||
oAuthInfo *social.OAuthInfo
|
||||
oAuthInfos map[string]*social.OAuthInfo
|
||||
@ -774,15 +669,3 @@ func (m *mockSocialService) GetOAuthHttpClient(name string) (*http.Client, error
|
||||
func (m *mockSocialService) GetConnector(string) (social.SocialConnector, error) {
|
||||
return m.socialConnector, m.err
|
||||
}
|
||||
|
||||
type fakeAuthenticator struct {
|
||||
ExpectedUser *user.User
|
||||
ExpectedAuthModule string
|
||||
ExpectedError error
|
||||
}
|
||||
|
||||
func (fa *fakeAuthenticator) AuthenticateUser(c context.Context, query *loginservice.LoginUserQuery) error {
|
||||
query.User = fa.ExpectedUser
|
||||
query.AuthModule = fa.ExpectedAuthModule
|
||||
return fa.ExpectedError
|
||||
}
|
||||
|
@ -1,115 +1,152 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
)
|
||||
|
||||
func TestMiddlewareAuth(t *testing.T) {
|
||||
reqSignIn := Auth(&AuthOptions{ReqSignedIn: true})
|
||||
func setupAuthMiddlewareTest(t *testing.T, identity *authn.Identity, authErr error) *contexthandler.ContextHandler {
|
||||
return contexthandler.ProvideService(setting.NewCfg(), nil, nil, nil, nil, nil, tracing.NewFakeTracer(), nil, nil, nil, nil, nil, nil, nil, nil, &authntest.FakeService{
|
||||
ExpectedErr: authErr,
|
||||
ExpectedIdentity: identity,
|
||||
}, nil)
|
||||
}
|
||||
|
||||
middlewareScenario(t, "ReqSignIn true and unauthenticated request", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
|
||||
sc.fakeReq("GET", "/secure").exec()
|
||||
func TestAuth_Middleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
desc string
|
||||
identity *authn.Identity
|
||||
path string
|
||||
authErr error
|
||||
authMiddleware web.Handler
|
||||
expecedReached bool
|
||||
expectedCode int
|
||||
}
|
||||
|
||||
assert.Equal(t, 302, sc.resp.Code)
|
||||
})
|
||||
tests := []testCase{
|
||||
{
|
||||
desc: "ReqSignedIn should redirect unauthenticated request to secure endpoint",
|
||||
path: "/secure",
|
||||
authMiddleware: ReqSignedIn,
|
||||
authErr: errors.New("no auth"),
|
||||
expectedCode: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "ReqSignedIn should return 401 for api endpint",
|
||||
path: "/api/secure",
|
||||
authMiddleware: ReqSignedIn,
|
||||
authErr: errors.New("no auth"),
|
||||
expectedCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
desc: "ReqSignedIn should return 200 for anonymous user",
|
||||
path: "/api/secure",
|
||||
authMiddleware: ReqSignedIn,
|
||||
identity: &authn.Identity{IsAnonymous: true},
|
||||
expecedReached: true,
|
||||
expectedCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "ReqSignedIn should return redirect anonymous user with forceLogin query string",
|
||||
path: "/secure?forceLogin=true",
|
||||
authMiddleware: ReqSignedIn,
|
||||
identity: &authn.Identity{IsAnonymous: true},
|
||||
expecedReached: false,
|
||||
expectedCode: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "ReqSignedIn should return redirect anonymous user when orgId in query string is different from currently used",
|
||||
path: "/secure?orgId=2",
|
||||
authMiddleware: ReqSignedIn,
|
||||
identity: &authn.Identity{IsAnonymous: true, OrgID: 1},
|
||||
expecedReached: false,
|
||||
expectedCode: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "ReqSignedInNoAnonymous should return 401 for anonymous user",
|
||||
path: "/api/secure",
|
||||
authMiddleware: ReqSignedInNoAnonymous,
|
||||
identity: &authn.Identity{IsAnonymous: true},
|
||||
expecedReached: false,
|
||||
expectedCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
desc: "ReqSignedInNoAnonymous should return 200 for authenticated user",
|
||||
path: "/api/secure",
|
||||
authMiddleware: ReqSignedInNoAnonymous,
|
||||
identity: &authn.Identity{ID: "user:1"},
|
||||
expecedReached: true,
|
||||
expectedCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "snapshot public mode disabled should return 200 for authenticated user",
|
||||
path: "/api/secure",
|
||||
authMiddleware: SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: false}),
|
||||
identity: &authn.Identity{ID: "user:1"},
|
||||
expecedReached: true,
|
||||
expectedCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "snapshot public mode disabled should return 401 for unauthenticated request",
|
||||
path: "/api/secure",
|
||||
authMiddleware: SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: false}),
|
||||
authErr: errors.New("no auth"),
|
||||
expecedReached: false,
|
||||
expectedCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
desc: "snapshot public mode enabled should return 200 for unauthenticated request",
|
||||
path: "/api/secure",
|
||||
authMiddleware: SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: true}),
|
||||
authErr: errors.New("no auth"),
|
||||
expecedReached: true,
|
||||
expectedCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
middlewareScenario(t, "ReqSignIn true and unauthenticated API request", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.m.Get("/api/secure", reqSignIn, sc.defaultHandler)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
ctxHandler := setupAuthMiddlewareTest(t, tt.identity, tt.authErr)
|
||||
|
||||
sc.fakeReq("GET", "/api/secure").exec()
|
||||
server := web.New()
|
||||
server.Use(ctxHandler.Middleware)
|
||||
server.Use(tt.authMiddleware)
|
||||
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
})
|
||||
var reached bool
|
||||
server.Get("/secure", func(c *contextmodel.ReqContext) {
|
||||
reached = true
|
||||
c.Resp.WriteHeader(http.StatusOK)
|
||||
})
|
||||
server.Get("/api/secure", func(c *contextmodel.ReqContext) {
|
||||
reached = true
|
||||
c.Resp.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("Anonymous auth enabled", func(t *testing.T) {
|
||||
const orgID int64 = 1
|
||||
req, err := http.NewRequest(http.MethodGet, tt.path, nil)
|
||||
require.NoError(t, err)
|
||||
recorder := httptest.NewRecorder()
|
||||
server.ServeHTTP(recorder, req)
|
||||
|
||||
configure := func(cfg *setting.Cfg) {
|
||||
cfg.AnonymousEnabled = true
|
||||
cfg.AnonymousOrgName = "test"
|
||||
}
|
||||
|
||||
middlewareScenario(t, "ReqSignIn true and NoAnonynmous true", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
sc.orgService.ExpectedOrg = &org.Org{ID: orgID, Name: "test"}
|
||||
sc.m.Get("/api/secure", ReqSignedInNoAnonymous, sc.defaultHandler)
|
||||
sc.fakeReq("GET", "/api/secure").exec()
|
||||
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
}, configure)
|
||||
|
||||
middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
sc.orgService.ExpectedOrg = &org.Org{ID: orgID, Name: "test"}
|
||||
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
|
||||
|
||||
sc.fakeReq("GET", "/secure?forceLogin=true").exec()
|
||||
|
||||
assert.Equal(t, 302, sc.resp.Code)
|
||||
location, ok := sc.resp.Header()["Location"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "/login", location[0])
|
||||
}, configure)
|
||||
|
||||
middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
|
||||
|
||||
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
|
||||
|
||||
sc.fakeReq("GET", fmt.Sprintf("/secure?orgId=%d", 1)).exec()
|
||||
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
}, configure)
|
||||
|
||||
middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
|
||||
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
|
||||
|
||||
sc.fakeReq("GET", "/secure?orgId=2").exec()
|
||||
|
||||
assert.Equal(t, 302, sc.resp.Code)
|
||||
location, ok := sc.resp.Header()["Location"]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "/login", location[0])
|
||||
}, configure)
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Snapshot public mode disabled and unauthenticated request should return 401", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
sc.m.Get("/api/snapshot", func(c *contextmodel.ReqContext) {
|
||||
c.IsSignedIn = false
|
||||
}, SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
|
||||
sc.fakeReq("GET", "/api/snapshot").exec()
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Snapshot public mode disabled and authenticated request should return 200", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
sc.m.Get("/api/snapshot", func(c *contextmodel.ReqContext) {
|
||||
c.IsSignedIn = true
|
||||
}, SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
|
||||
sc.fakeReq("GET", "/api/snapshot").exec()
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Snapshot public mode enabled and unauthenticated request should return 200", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
sc.cfg.SnapshotPublicMode = true
|
||||
sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
|
||||
sc.fakeReq("GET", "/api/snapshot").exec()
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
})
|
||||
res := recorder.Result()
|
||||
assert.Equal(t, tt.expecedReached, reached)
|
||||
assert.Equal(t, tt.expectedCode, res.StatusCode)
|
||||
require.NoError(t, res.Body.Close())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveForceLoginparams(t *testing.T) {
|
||||
|
@ -1,108 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/login"
|
||||
"github.com/grafana/grafana/pkg/services/apikey"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/services/login/logintest"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
|
||||
func TestMiddlewareBasicAuth(t *testing.T) {
|
||||
const id int64 = 12
|
||||
|
||||
configure := func(cfg *setting.Cfg) {
|
||||
cfg.BasicAuthEnabled = true
|
||||
cfg.DisableBruteForceLoginProtection = true
|
||||
}
|
||||
|
||||
middlewareScenario(t, "Valid API key", func(t *testing.T, sc *scenarioContext) {
|
||||
const orgID int64 = 2
|
||||
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
|
||||
require.NoError(t, err)
|
||||
|
||||
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: orgID, Role: org.RoleEditor, Key: keyhash}
|
||||
|
||||
authHeader := util.GetBasicAuthHeader("api_key", "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9")
|
||||
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
|
||||
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
|
||||
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
|
||||
require.NotNil(t, list)
|
||||
require.EqualValues(t, []string{"Authorization"}, list.Items)
|
||||
}, configure)
|
||||
|
||||
middlewareScenario(t, "Handle auth", func(t *testing.T, sc *scenarioContext) {
|
||||
const password = "MyPass"
|
||||
const orgID int64 = 2
|
||||
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: id}
|
||||
|
||||
authHeader := util.GetBasicAuthHeader("myUser", password)
|
||||
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
assert.Equal(t, id, sc.context.UserID)
|
||||
}, configure)
|
||||
|
||||
middlewareScenario(t, "Auth sequence", func(t *testing.T, sc *scenarioContext) {
|
||||
const password = "MyPass"
|
||||
const salt = "Salt"
|
||||
|
||||
encoded, err := util.EncodePassword(password, salt)
|
||||
require.NoError(t, err)
|
||||
|
||||
sc.userService.ExpectedUser = &user.User{Password: encoded, ID: id, Salt: salt}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id}
|
||||
login.ProvideService(sc.mockSQLStore, &logintest.LoginServiceFake{}, nil, sc.userService, sc.cfg)
|
||||
|
||||
authHeader := util.GetBasicAuthHeader("myUser", password)
|
||||
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
|
||||
require.NotNil(t, sc.context)
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, id, sc.context.UserID)
|
||||
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
|
||||
require.NotNil(t, list)
|
||||
require.EqualValues(t, []string{"Authorization"}, list.Items)
|
||||
}, configure)
|
||||
|
||||
middlewareScenario(t, "Should return error if user is not found", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.userService.ExpectedError = user.ErrUserNotFound
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.SetBasicAuth("user", "password")
|
||||
sc.exec()
|
||||
|
||||
err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"])
|
||||
}, configure)
|
||||
|
||||
middlewareScenario(t, "Should return error if user & password do not match", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.userService.ExpectedError = user.ErrUserNotFound
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.SetBasicAuth("killa", "gorilla")
|
||||
sc.exec()
|
||||
|
||||
err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"])
|
||||
}, configure)
|
||||
}
|
@ -1,303 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/auth/jwt"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
)
|
||||
|
||||
func TestMiddlewareJWTAuth(t *testing.T) {
|
||||
const myEmail = "vladimir@example.com"
|
||||
const id int64 = 12
|
||||
const orgID int64 = 2
|
||||
|
||||
configure := func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthEnabled = true
|
||||
cfg.JWTAuthHeaderName = "x-jwt-assertion"
|
||||
}
|
||||
|
||||
configureAuthHeader := func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthEnabled = true
|
||||
cfg.JWTAuthHeaderName = "Authorization"
|
||||
}
|
||||
|
||||
configureUsernameClaim := func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthUsernameClaim = "foo-username"
|
||||
}
|
||||
|
||||
configureEmailClaim := func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthEmailClaim = "foo-email"
|
||||
}
|
||||
|
||||
configureAutoSignUp := func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthAutoSignUp = true
|
||||
}
|
||||
|
||||
configureRole := func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthEmailClaim = "sub"
|
||||
cfg.JWTAuthRoleAttributePath = "role"
|
||||
}
|
||||
|
||||
configureRoleStrict := func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthRoleAttributeStrict = true
|
||||
}
|
||||
|
||||
configureRoleAllowAdmin := func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthAllowAssignGrafanaAdmin = true
|
||||
}
|
||||
|
||||
// #nosec G101 -- This is dummy/test token
|
||||
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ2bGFkaW1pckBleGFtcGxlLmNvbSIsImlhdCI6MTUxNjIzOTAyMiwiZm9vLXVzZXJuYW1lIjoidmxhZGltaXIiLCJuYW1lIjoiVmxhZGltaXIgRXhhbXBsZSIsImZvby1lbWFpbCI6InZsYWRpbWlyQGV4YW1wbGUuY29tIn0.MeNU1pCzRHGdQuu5ppeftxT31_2Le2kM1wd1GK2jExs"
|
||||
|
||||
middlewareScenario(t, "Valid token with valid login claim", func(t *testing.T, sc *scenarioContext) {
|
||||
myUsername := "vladimir"
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return jwt.JWTClaims{
|
||||
"sub": myUsername,
|
||||
"foo-username": myUsername,
|
||||
}, nil
|
||||
}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Login: myUsername}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
assert.Equal(t, id, sc.context.UserID)
|
||||
assert.Equal(t, myUsername, sc.context.Login)
|
||||
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
|
||||
require.NotNil(t, list)
|
||||
require.EqualValues(t, []string{"Authorization", sc.cfg.JWTAuthHeaderName}, list.Items)
|
||||
}, configure, configureUsernameClaim)
|
||||
|
||||
middlewareScenario(t, "Valid token with bearer in authorization header", func(t *testing.T, sc *scenarioContext) {
|
||||
myUsername := "vladimir"
|
||||
// We can ignore gosec G101 since this does not contain any credentials.
|
||||
// nolint:gosec
|
||||
myToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ2bGFkaW1pckBleGFtcGxlLmNvbSIsImlhdCI6MTUxNjIzOTAyMiwiZm9vLXVzZXJuYW1lIjoidmxhZGltaXIiLCJuYW1lIjoiVmxhZGltaXIgRXhhbXBsZSIsImZvby1lbWFpbCI6InZsYWRpbWlyQGV4YW1wbGUuY29tIn0.MeNU1pCzRHGdQuu5ppeftxT31_2Le2kM1wd1GK2jExs"
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = myToken
|
||||
return jwt.JWTClaims{
|
||||
"sub": myUsername,
|
||||
"foo-username": myUsername,
|
||||
}, nil
|
||||
}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Login: myUsername}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader("Bearer " + myToken).exec()
|
||||
assert.Equal(t, verifiedToken, myToken)
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
assert.Equal(t, id, sc.context.UserID)
|
||||
assert.Equal(t, myUsername, sc.context.Login)
|
||||
}, configureAuthHeader, configureUsernameClaim)
|
||||
|
||||
middlewareScenario(t, "Valid token with valid email claim", func(t *testing.T, sc *scenarioContext) {
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return jwt.JWTClaims{
|
||||
"sub": myEmail,
|
||||
"foo-email": myEmail,
|
||||
}, nil
|
||||
}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
assert.Equal(t, id, sc.context.UserID)
|
||||
assert.Equal(t, myEmail, sc.context.Email)
|
||||
}, configure, configureEmailClaim)
|
||||
|
||||
middlewareScenario(t, "Valid token with no user and auto_sign_up disabled", func(t *testing.T, sc *scenarioContext) {
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return jwt.JWTClaims{
|
||||
"sub": myEmail,
|
||||
"name": "Vladimir Example",
|
||||
"foo-email": myEmail,
|
||||
}, nil
|
||||
}
|
||||
sc.userService.ExpectedError = user.ErrUserNotFound
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
assert.Equal(t, contexthandler.UserNotFound, sc.respJson["message"])
|
||||
}, configure, configureEmailClaim)
|
||||
|
||||
middlewareScenario(t, "Valid token with no user and auto_sign_up enabled", func(t *testing.T, sc *scenarioContext) {
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return jwt.JWTClaims{
|
||||
"sub": myEmail,
|
||||
"name": "Vladimir Example",
|
||||
"foo-email": myEmail,
|
||||
}, nil
|
||||
}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
assert.Equal(t, id, sc.context.UserID)
|
||||
assert.Equal(t, myEmail, sc.context.Email)
|
||||
}, configure, configureEmailClaim, configureAutoSignUp)
|
||||
|
||||
middlewareScenario(t, "Valid token without a login claim", func(t *testing.T, sc *scenarioContext) {
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return jwt.JWTClaims{
|
||||
"sub": "baz",
|
||||
"foo": "bar",
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
assert.Equal(t, contexthandler.InvalidJWT, sc.respJson["message"])
|
||||
}, configure, configureUsernameClaim)
|
||||
|
||||
middlewareScenario(t, "Valid token without a email claim", func(t *testing.T, sc *scenarioContext) {
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return jwt.JWTClaims{
|
||||
"sub": "baz",
|
||||
"foo": "bar",
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
assert.Equal(t, contexthandler.InvalidJWT, sc.respJson["message"])
|
||||
}, configure, configureEmailClaim)
|
||||
|
||||
middlewareScenario(t, "Valid token with role", func(t *testing.T, sc *scenarioContext) {
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return jwt.JWTClaims{
|
||||
"sub": myEmail,
|
||||
"role": "Editor",
|
||||
}, nil
|
||||
}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleEditor}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
|
||||
}, configure, configureAutoSignUp, configureRole)
|
||||
|
||||
middlewareScenario(t, "Valid token with invalid role", func(t *testing.T, sc *scenarioContext) {
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return jwt.JWTClaims{
|
||||
"sub": myEmail,
|
||||
"role": "test",
|
||||
}, nil
|
||||
}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleViewer}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, org.RoleViewer, sc.context.OrgRole)
|
||||
}, configure, configureAutoSignUp, configureRole)
|
||||
|
||||
middlewareScenario(t, "Valid token with invalid role in strict mode", func(t *testing.T, sc *scenarioContext) {
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return jwt.JWTClaims{
|
||||
"sub": myEmail,
|
||||
"role": "test",
|
||||
}, nil
|
||||
}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleViewer}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 403, sc.resp.Code)
|
||||
assert.Equal(t, contexthandler.InvalidRole, sc.respJson["message"])
|
||||
}, configure, configureAutoSignUp, configureRole, configureRoleStrict)
|
||||
|
||||
middlewareScenario(t, "Valid token with grafana admin role not allowed", func(t *testing.T, sc *scenarioContext) {
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return jwt.JWTClaims{
|
||||
"sub": myEmail,
|
||||
"role": "GrafanaAdmin",
|
||||
}, nil
|
||||
}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleAdmin}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, org.RoleAdmin, sc.context.OrgRole)
|
||||
assert.False(t, sc.context.IsGrafanaAdmin)
|
||||
}, configure, configureAutoSignUp, configureRole)
|
||||
|
||||
middlewareScenario(t, "Valid token with grafana admin role allowed", func(t *testing.T, sc *scenarioContext) {
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return jwt.JWTClaims{
|
||||
"sub": myEmail,
|
||||
"role": "GrafanaAdmin",
|
||||
}, nil
|
||||
}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleAdmin, IsGrafanaAdmin: true}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 200, sc.resp.Code)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, org.RoleAdmin, sc.context.OrgRole)
|
||||
assert.True(t, sc.context.IsGrafanaAdmin)
|
||||
}, configure, configureAutoSignUp, configureRole, configureRoleAllowAdmin)
|
||||
|
||||
middlewareScenario(t, "Invalid token", func(t *testing.T, sc *scenarioContext) {
|
||||
var verifiedToken string
|
||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
||||
verifiedToken = token
|
||||
return nil, errors.New("token is invalid")
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
||||
assert.Equal(t, verifiedToken, token)
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
assert.Equal(t, contexthandler.InvalidJWT, sc.respJson["message"])
|
||||
}, configure, configureUsernameClaim)
|
||||
}
|
@ -1,17 +1,10 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -19,47 +12,22 @@ import (
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/gtime"
|
||||
|
||||
"github.com/grafana/grafana/pkg/api/dtos"
|
||||
"github.com/grafana/grafana/pkg/infra/db/dbtest"
|
||||
"github.com/grafana/grafana/pkg/infra/fs"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/login"
|
||||
"github.com/grafana/grafana/pkg/services/anonymous/anontest"
|
||||
"github.com/grafana/grafana/pkg/services/apikey"
|
||||
"github.com/grafana/grafana/pkg/services/apikey/apikeytest"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||
"github.com/grafana/grafana/pkg/services/auth/jwt"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/ldap/service"
|
||||
loginsvc "github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||
"github.com/grafana/grafana/pkg/services/login/logintest"
|
||||
"github.com/grafana/grafana/pkg/services/navtree"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/org/orgtest"
|
||||
"github.com/grafana/grafana/pkg/services/rendering"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/services/user/usertest"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
)
|
||||
|
||||
func fakeGetTime() func() time.Time {
|
||||
var timeSeed int64
|
||||
return func() time.Time {
|
||||
fakeNow := time.Unix(timeSeed, 0)
|
||||
timeSeed++
|
||||
return fakeNow
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddleWareSecurityHeaders(t *testing.T) {
|
||||
middlewareScenario(t, "middleware should get correct x-xss-protection header", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.fakeReq("GET", "/api/").exec()
|
||||
@ -134,11 +102,6 @@ func TestMiddleWareContentSecurityPolicyHeaders(t *testing.T) {
|
||||
func TestMiddlewareContext(t *testing.T) {
|
||||
const noStore = "no-store"
|
||||
|
||||
configureJWTAuthHeader := func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthEnabled = true
|
||||
cfg.JWTAuthHeaderName = "Authorization"
|
||||
}
|
||||
|
||||
middlewareScenario(t, "middleware should add context to injector", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
assert.NotNil(t, sc.context)
|
||||
@ -214,372 +177,6 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
cfg.AllowEmbedding = true
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Invalid api key", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.apiKey = "invalid_key_test"
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
assert.Equal(t, contexthandler.InvalidAPIKey, sc.respJson["message"])
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Valid API key", func(t *testing.T, sc *scenarioContext) {
|
||||
const orgID int64 = 12
|
||||
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
|
||||
require.NoError(t, err)
|
||||
|
||||
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: orgID, Role: org.RoleEditor, Key: keyhash}
|
||||
|
||||
sc.fakeReq("GET", "/").withValidApiKey().exec()
|
||||
|
||||
require.Equal(t, 200, sc.resp.Code)
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Valid API key with JWT enabled", func(t *testing.T, sc *scenarioContext) {
|
||||
const orgID int64 = 12
|
||||
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
|
||||
require.NoError(t, err)
|
||||
|
||||
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: orgID, Role: org.RoleEditor, Key: keyhash}
|
||||
|
||||
sc.fakeReq("GET", "/").withValidApiKey().exec()
|
||||
|
||||
require.Equal(t, 200, sc.resp.Code)
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
|
||||
}, configureJWTAuthHeader)
|
||||
|
||||
middlewareScenario(t, "Valid Basic Auth header with JWT enabled and empty 'sub' claim", func(t *testing.T, sc *scenarioContext) {
|
||||
const password = "MyPass"
|
||||
const orgID int64 = 2
|
||||
const userID int64 = 12
|
||||
// #nosec G101 -- This is dummy/test token
|
||||
const emptySubToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJzdWIiOiIiLCJpYXQiOjE1MTYyMzkwMjJ9.tnwtOHK58d47dO4DHW4b9MzeToxa1kGiko5Oo887Rqc"
|
||||
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
||||
authHeader := util.GetBasicAuthHeader("myuser", password)
|
||||
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).withJWTAuthHeader(emptySubToken).exec()
|
||||
|
||||
require.Equal(t, 200, sc.resp.Code)
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
}, func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthEnabled = true
|
||||
cfg.JWTAuthHeaderName = "X-JWT-Token"
|
||||
cfg.BasicAuthEnabled = true
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Valid Basic Auth header with JWT enabled and missing 'sub' claim", func(t *testing.T, sc *scenarioContext) {
|
||||
const password = "MyPass"
|
||||
const orgID int64 = 2
|
||||
const userID int64 = 12
|
||||
// #nosec G101 -- This is dummy/test token
|
||||
const missingSubToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjJ9.8nYFUX869Y1mnDDDU4yL11aANgVRuifoxrE8BHZY1iE"
|
||||
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
||||
authHeader := util.GetBasicAuthHeader("myuser", password)
|
||||
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).withJWTAuthHeader(missingSubToken).exec()
|
||||
|
||||
require.Equal(t, 200, sc.resp.Code)
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
}, func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthEnabled = true
|
||||
cfg.JWTAuthHeaderName = "X-JWT-Token"
|
||||
cfg.BasicAuthEnabled = true
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Valid API key, but does not match DB hash", func(t *testing.T, sc *scenarioContext) {
|
||||
const keyhash = "Something_not_matching"
|
||||
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: 12, Role: org.RoleEditor, Key: keyhash}
|
||||
|
||||
sc.fakeReq("GET", "/").withValidApiKey().exec()
|
||||
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
assert.Equal(t, contexthandler.InvalidAPIKey, sc.respJson["message"])
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Valid API key, but expired", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
|
||||
require.NoError(t, err)
|
||||
|
||||
expires := sc.contextHandler.GetTime().Add(-1 * time.Second).Unix()
|
||||
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: 12, Role: org.RoleEditor, Key: keyhash, Expires: &expires}
|
||||
|
||||
sc.fakeReq("GET", "/").withValidApiKey().exec()
|
||||
|
||||
assert.Equal(t, 401, sc.resp.Code)
|
||||
assert.Equal(t, "Expired API key", sc.respJson["message"])
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie which is not being rotated", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
||||
return &auth.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.NotNil(t, sc.context.UserToken)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie which is being rotated", func(t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
||||
return &auth.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *auth.UserToken,
|
||||
clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
|
||||
userToken.UnhashedToken = "rotated"
|
||||
return true, userToken, nil
|
||||
}
|
||||
|
||||
maxAge := int(sc.cfg.LoginMaxLifetime.Seconds())
|
||||
|
||||
sameSiteModes := []http.SameSite{
|
||||
http.SameSiteNoneMode,
|
||||
http.SameSiteLaxMode,
|
||||
http.SameSiteStrictMode,
|
||||
}
|
||||
for _, sameSiteMode := range sameSiteModes {
|
||||
t.Run(fmt.Sprintf("Same site mode %d", sameSiteMode), func(t *testing.T) {
|
||||
origCookieSameSiteMode := setting.CookieSameSiteMode
|
||||
t.Cleanup(func() {
|
||||
setting.CookieSameSiteMode = origCookieSameSiteMode
|
||||
})
|
||||
setting.CookieSameSiteMode = sameSiteMode
|
||||
|
||||
expectedCookiePath := "/"
|
||||
if len(sc.cfg.AppSubURL) > 0 {
|
||||
expectedCookiePath = sc.cfg.AppSubURL
|
||||
}
|
||||
expectedCookie := &http.Cookie{
|
||||
Name: sc.cfg.LoginCookieName,
|
||||
Value: "rotated",
|
||||
Path: expectedCookiePath,
|
||||
HttpOnly: true,
|
||||
MaxAge: maxAge,
|
||||
Secure: setting.CookieSecure,
|
||||
SameSite: sameSiteMode,
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||
assert.Equal(t, "rotated", sc.context.UserToken.UnhashedToken)
|
||||
assert.Equal(t, expectedCookie.String(), sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Should not set cookie with SameSite attribute when setting.CookieSameSiteDisabled is true", func(t *testing.T) {
|
||||
origCookieSameSiteDisabled := setting.CookieSameSiteDisabled
|
||||
origCookieSameSiteMode := setting.CookieSameSiteMode
|
||||
t.Cleanup(func() {
|
||||
setting.CookieSameSiteDisabled = origCookieSameSiteDisabled
|
||||
setting.CookieSameSiteMode = origCookieSameSiteMode
|
||||
})
|
||||
setting.CookieSameSiteDisabled = true
|
||||
setting.CookieSameSiteMode = http.SameSiteLaxMode
|
||||
|
||||
expectedCookiePath := "/"
|
||||
if len(sc.cfg.AppSubURL) > 0 {
|
||||
expectedCookiePath = sc.cfg.AppSubURL
|
||||
}
|
||||
expectedCookie := &http.Cookie{
|
||||
Name: sc.cfg.LoginCookieName,
|
||||
Value: "rotated",
|
||||
Path: expectedCookiePath,
|
||||
HttpOnly: true,
|
||||
MaxAge: maxAge,
|
||||
Secure: setting.CookieSecure,
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
assert.Equal(t, expectedCookie.String(), sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Invalid/expired auth token in cookie", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.withTokenSessionCookie("token")
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
||||
return nil, auth.ErrUserTokenNotFound
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
assert.False(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, int64(0), sc.context.UserID)
|
||||
assert.Nil(t, sc.context.UserToken)
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and non-expired OAuth access token", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(11 * time.Second)}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
||||
return &auth.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.NotNil(t, sc.context.UserToken)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token fails", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
signedInUser := &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.userService.ExpectedSignedInUser = signedInUser
|
||||
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{
|
||||
UserId: userID,
|
||||
OAuthExpiry: fakeGetTime()().Add(-1 * time.Second),
|
||||
OAuthAccessToken: "access_token",
|
||||
OAuthRefreshToken: "refresh_token"}
|
||||
sc.oauthTokenService.ExpectedErrors = map[string]error{"TryTokenRefresh": errors.New("error")}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
||||
return &auth.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
token := sc.oauthTokenService.GetCurrentOAuthToken(sc.context.Req.Context(), signedInUser)
|
||||
assert.Equal(t, token.AccessToken, "")
|
||||
assert.Equal(t, token.RefreshToken, "")
|
||||
assert.True(t, token.Expiry.IsZero())
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.Nil(t, sc.context.UserToken)
|
||||
assert.False(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, int64(0), sc.context.UserID)
|
||||
assert.Equal(t, "grafana_session=; Path=/; Max-Age=0; HttpOnly", sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token succeeds", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(-5 * time.Second), OAuthRefreshToken: "refreshtoken"}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
||||
return &auth.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.NotNil(t, sc.context.UserToken)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Non-expired auth token in cookie and OAuth Access Token's Expiry is not set", func(
|
||||
t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
sc.contextHandler.GetTime = fakeGetTime()
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
||||
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{UserId: userID}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
||||
return &auth.UserToken{
|
||||
UserId: userID,
|
||||
UnhashedToken: unhashedToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
require.NotNil(t, sc.context)
|
||||
require.NotNil(t, sc.context.UserToken)
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
||||
})
|
||||
|
||||
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
|
||||
sc.fakeReq("GET", "/").exec()
|
||||
|
||||
assert.Equal(t, int64(0), sc.context.UserID)
|
||||
assert.Equal(t, int64(1), sc.context.OrgID)
|
||||
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
|
||||
assert.False(t, sc.context.IsSignedIn)
|
||||
}, func(cfg *setting.Cfg) {
|
||||
cfg.AnonymousEnabled = true
|
||||
cfg.AnonymousOrgName = "test"
|
||||
cfg.AnonymousOrgRole = string(org.RoleEditor)
|
||||
})
|
||||
|
||||
middlewareScenario(t, "middleware should add custom response headers", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.fakeReq("GET", "/api/").exec()
|
||||
assert.Regexp(t, "test", sc.resp.Header().Get("X-Custom-Header"))
|
||||
@ -590,278 +187,6 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
"X-Other-Header": "other-test",
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auth_proxy", func(t *testing.T) {
|
||||
const userID int64 = 33
|
||||
const orgID int64 = 4
|
||||
const defaultOrgId int64 = 1
|
||||
const orgRole = "Admin"
|
||||
|
||||
configure := func(cfg *setting.Cfg) {
|
||||
cfg.AuthProxyEnabled = true
|
||||
cfg.AuthProxyAutoSignUp = true
|
||||
cfg.LDAPAuthEnabled = true
|
||||
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
|
||||
cfg.AuthProxyHeaderProperty = "username"
|
||||
cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS", "Role": "X-WEBAUTH-ROLE"}
|
||||
}
|
||||
|
||||
const hdrName = "markelog"
|
||||
const group = "grafana-core-team"
|
||||
|
||||
middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
||||
h, err := authproxy.HashCacheKey(hdrName + "-" + group)
|
||||
require.NoError(t, err)
|
||||
key := fmt.Sprintf(authproxy.CachePrefix, h)
|
||||
userIdBytes := []byte(strconv.FormatInt(userID, 10))
|
||||
err = sc.remoteCacheService.Set(context.Background(), key, userIdBytes, 0)
|
||||
require.NoError(t, err)
|
||||
sc.fakeReq("GET", "/")
|
||||
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.req.Header.Set("X-WEBAUTH-GROUPS", group)
|
||||
sc.exec()
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
}, configure)
|
||||
|
||||
middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) {
|
||||
var actualAuthProxyAutoSignUp *bool = nil
|
||||
sc.loginService.ExpectedUserFunc = func(cmd *loginsvc.UpsertUserCommand) *user.User {
|
||||
actualAuthProxyAutoSignUp = &cmd.SignupAllowed
|
||||
return nil
|
||||
}
|
||||
sc.loginService.ExpectedError = login.ErrInvalidCredentials
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.exec()
|
||||
|
||||
assert.False(t, *actualAuthProxyAutoSignUp)
|
||||
assert.Equal(t, 407, sc.resp.Code)
|
||||
assert.Nil(t, sc.context)
|
||||
}, func(cfg *setting.Cfg) {
|
||||
configure(cfg)
|
||||
cfg.LDAPAuthEnabled = false
|
||||
cfg.AuthProxyAutoSignUp = false
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.loginService.ExpectedUser = &user.User{ID: userID}
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.exec()
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
|
||||
require.NotNil(t, list)
|
||||
require.Contains(t, list.Items, sc.cfg.AuthProxyHeaderName)
|
||||
require.Contains(t, list.Items, "X-WEBAUTH-GROUPS")
|
||||
require.Contains(t, list.Items, "X-WEBAUTH-ROLE")
|
||||
}, func(cfg *setting.Cfg) {
|
||||
configure(cfg)
|
||||
cfg.LDAPAuthEnabled = false
|
||||
cfg.AuthProxyAutoSignUp = true
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should assign role from header to default org", func(t *testing.T, sc *scenarioContext) {
|
||||
var storedRoleInfo map[int64]org.RoleType = nil
|
||||
sc.loginService.ExpectedUserFunc = func(cmd *loginsvc.UpsertUserCommand) *user.User {
|
||||
storedRoleInfo = cmd.ExternalUser.OrgRoles
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: defaultOrgId, UserID: userID, OrgRole: storedRoleInfo[defaultOrgId]}
|
||||
return &user.User{ID: userID}
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.req.Header.Set("X-WEBAUTH-ROLE", orgRole)
|
||||
sc.exec()
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, defaultOrgId, sc.context.OrgID)
|
||||
assert.Equal(t, orgRole, string(sc.context.OrgRole))
|
||||
}, func(cfg *setting.Cfg) {
|
||||
configure(cfg)
|
||||
cfg.LDAPAuthEnabled = false
|
||||
cfg.AuthProxyAutoSignUp = true
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should NOT assign role from header to non-default org", func(t *testing.T, sc *scenarioContext) {
|
||||
var storedRoleInfo map[int64]org.RoleType = nil
|
||||
sc.loginService.ExpectedUserFunc = func(cmd *loginsvc.UpsertUserCommand) *user.User {
|
||||
storedRoleInfo = cmd.ExternalUser.OrgRoles
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID, OrgRole: storedRoleInfo[orgID]}
|
||||
return &user.User{ID: userID}
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.req.Header.Set("X-WEBAUTH-ROLE", "Admin")
|
||||
sc.req.Header.Set("X-Grafana-Org-Id", strconv.FormatInt(orgID, 10))
|
||||
sc.exec()
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
|
||||
// For non-default org, the user role should be empty
|
||||
assert.Equal(t, "", string(sc.context.OrgRole))
|
||||
}, func(cfg *setting.Cfg) {
|
||||
configure(cfg)
|
||||
cfg.LDAPAuthEnabled = false
|
||||
cfg.AuthProxyAutoSignUp = true
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should use organisation specified by targetOrgId parameter", func(t *testing.T, sc *scenarioContext) {
|
||||
var targetOrgID int64 = 123
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: targetOrgID, UserID: userID}
|
||||
sc.loginService.ExpectedUser = &user.User{ID: userID}
|
||||
|
||||
sc.fakeReq("GET", fmt.Sprintf("/?targetOrgId=%d", targetOrgID))
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.exec()
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, targetOrgID, sc.context.OrgID)
|
||||
}, func(cfg *setting.Cfg) {
|
||||
configure(cfg)
|
||||
cfg.LDAPAuthEnabled = false
|
||||
cfg.AuthProxyAutoSignUp = true
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Request body should not be read in default context handler", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.fakeReq("POST", "/?targetOrgId=123")
|
||||
body := "key=value"
|
||||
sc.req.Body = io.NopCloser(strings.NewReader(body))
|
||||
|
||||
sc.handlerFunc = func(c *contextmodel.ReqContext) {
|
||||
t.Log("Handler called")
|
||||
defer func() {
|
||||
err := c.Req.Body.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
bodyAfterHandler, e := io.ReadAll(c.Req.Body)
|
||||
require.NoError(t, e)
|
||||
require.Equal(t, body, string(bodyAfterHandler))
|
||||
}
|
||||
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
sc.req.Header.Set("Content-Length", strconv.Itoa(len(body)))
|
||||
sc.m.Post("/", sc.defaultHandler)
|
||||
sc.exec()
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Request body should not be read in default context handler, but query should be altered - jwt", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.fakeReq("POST", "/?targetOrgId=123&auth_token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NSIsImlhdCI6MTUxNjIzOTAyMn0.1E9qmtctlHAeJzNLPgGFfxdA8WfbEl_vwYO91ffQGxs")
|
||||
body := "key=value"
|
||||
sc.req.Body = io.NopCloser(strings.NewReader(body))
|
||||
|
||||
sc.handlerFunc = func(c *contextmodel.ReqContext) {
|
||||
t.Log("Handler called")
|
||||
defer func() {
|
||||
err := c.Req.Body.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
require.Equal(t, "", c.Req.URL.Query().Get("auth_token"))
|
||||
|
||||
bodyAfterHandler, e := io.ReadAll(c.Req.Body)
|
||||
require.NoError(t, e)
|
||||
require.Equal(t, body, string(bodyAfterHandler))
|
||||
}
|
||||
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
sc.req.Header.Set("Content-Length", strconv.Itoa(len(body)))
|
||||
sc.m.Post("/", sc.defaultHandler)
|
||||
sc.exec()
|
||||
}, func(cfg *setting.Cfg) {
|
||||
cfg.JWTAuthEnabled = true
|
||||
cfg.JWTAuthURLLogin = true
|
||||
cfg.JWTAuthHeaderName = "X-WEBAUTH-TOKEN"
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should get an existing user from header", func(t *testing.T, sc *scenarioContext) {
|
||||
const userID int64 = 12
|
||||
const orgID int64 = 2
|
||||
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
||||
sc.loginService.ExpectedUser = &user.User{ID: userID}
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.exec()
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
}, func(cfg *setting.Cfg) {
|
||||
configure(cfg)
|
||||
cfg.LDAPAuthEnabled = false
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
||||
sc.loginService.ExpectedUser = &user.User{ID: userID}
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.req.RemoteAddr = "[2001::23]:12345"
|
||||
sc.exec()
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserID)
|
||||
assert.Equal(t, orgID, sc.context.OrgID)
|
||||
}, func(cfg *setting.Cfg) {
|
||||
configure(cfg)
|
||||
cfg.AuthProxyWhitelist = "192.168.1.0/24, 2001::0/120"
|
||||
cfg.LDAPAuthEnabled = false
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.loginService.ExpectedUser = &user.User{ID: userID}
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.req.RemoteAddr = "[2001::23]:12345"
|
||||
sc.exec()
|
||||
|
||||
assert.Equal(t, 407, sc.resp.Code)
|
||||
assert.Nil(t, sc.context)
|
||||
}, func(cfg *setting.Cfg) {
|
||||
configure(cfg)
|
||||
cfg.AuthProxyWhitelist = "8.8.8.8"
|
||||
cfg.LDAPAuthEnabled = false
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.exec()
|
||||
|
||||
assert.Equal(t, 407, sc.resp.Code)
|
||||
assert.Nil(t, sc.context)
|
||||
}, configure)
|
||||
|
||||
middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.exec()
|
||||
|
||||
assert.Equal(t, 407, sc.resp.Code)
|
||||
assert.Nil(t, sc.context)
|
||||
}, configure)
|
||||
})
|
||||
}
|
||||
|
||||
func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(*setting.Cfg)) {
|
||||
@ -894,22 +219,14 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
||||
sc.m.UseMiddleware(ContentSecurityPolicy(cfg, logger))
|
||||
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
||||
|
||||
sc.mockSQLStore = dbtest.NewFakeDB()
|
||||
sc.loginService = &loginservice.LoginServiceMock{}
|
||||
// defalut to not authenticated request
|
||||
sc.authnService = &authntest.FakeService{ExpectedErr: errors.New("no auth")}
|
||||
sc.userService = usertest.NewUserServiceFake()
|
||||
sc.orgService = orgtest.NewOrgServiceFake()
|
||||
sc.apiKeyService = &apikeytest.Service{}
|
||||
sc.oauthTokenService = &authtest.FakeOAuthTokenService{}
|
||||
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService, sc.oauthTokenService)
|
||||
sc.sqlStore = ctxHdlr.SQLStore
|
||||
sc.contextHandler = ctxHdlr
|
||||
|
||||
ctxHdlr := getContextHandler(t, cfg, sc.authnService)
|
||||
sc.m.Use(ctxHdlr.Middleware)
|
||||
sc.m.Use(OrgRedirect(sc.cfg, sc.userService))
|
||||
|
||||
sc.userAuthTokenService = ctxHdlr.AuthTokenService.(*authtest.FakeUserAuthTokenService)
|
||||
sc.jwtAuthService = ctxHdlr.JWTAuthService.(*jwt.FakeJWTService)
|
||||
sc.remoteCacheService = ctxHdlr.RemoteCache
|
||||
|
||||
sc.defaultHandler = func(c *contextmodel.ReqContext) {
|
||||
require.NotNil(t, c)
|
||||
t.Log("Default HTTP handler called")
|
||||
@ -933,40 +250,14 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
||||
})
|
||||
}
|
||||
|
||||
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *dbtest.FakeDB,
|
||||
loginService *loginservice.LoginServiceMock, apiKeyService *apikeytest.Service,
|
||||
userService *usertest.FakeUserService, orgService *orgtest.FakeOrgService,
|
||||
oauthTokenService *authtest.FakeOAuthTokenService,
|
||||
) *contexthandler.ContextHandler {
|
||||
func getContextHandler(t *testing.T, cfg *setting.Cfg, authnService authn.Service) *contexthandler.ContextHandler {
|
||||
t.Helper()
|
||||
|
||||
if cfg == nil {
|
||||
cfg = setting.NewCfg()
|
||||
}
|
||||
cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{
|
||||
Name: "database",
|
||||
}
|
||||
|
||||
remoteCacheSvc := remotecache.NewFakeStore(t)
|
||||
userAuthTokenSvc := authtest.NewFakeUserAuthTokenService()
|
||||
renderSvc := &fakeRenderService{}
|
||||
authJWTSvc := jwt.NewFakeJWTService()
|
||||
tracer := tracing.InitializeTracerForTest()
|
||||
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService,
|
||||
userService, mockSQLStore, &service.LDAPFakeService{ExpectedError: service.ErrUnableToCreateLDAPClient})
|
||||
authenticator := &logintest.AuthenticatorFake{ExpectedUser: &user.User{}}
|
||||
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc,
|
||||
remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy,
|
||||
loginService, apiKeyService, authenticator, userService, orgService,
|
||||
oauthTokenService,
|
||||
featuremgmt.WithFeatures(featuremgmt.FlagAccessTokenExpirationCheck),
|
||||
&authntest.FakeService{}, &anontest.FakeAnonymousSessionService{})
|
||||
}
|
||||
|
||||
type fakeRenderService struct {
|
||||
rendering.Service
|
||||
}
|
||||
|
||||
func (s *fakeRenderService) Init() error {
|
||||
return nil
|
||||
tracer := tracing.NewFakeTracer()
|
||||
return contexthandler.ProvideService(cfg, authtest.NewFakeUserAuthTokenService(), nil,
|
||||
nil, nil, nil, tracer, nil,
|
||||
nil, nil, nil, nil, nil,
|
||||
nil, featuremgmt.WithFeatures(featuremgmt.FlagAccessTokenExpirationCheck),
|
||||
authnService, &anontest.FakeAnonymousSessionService{},
|
||||
)
|
||||
}
|
||||
|
@ -1,14 +1,12 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
)
|
||||
|
||||
func TestOrgRedirectMiddleware(t *testing.T) {
|
||||
@ -46,15 +44,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
middlewareScenario(t, tc.desc, func(t *testing.T, sc *scenarioContext) {
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12}
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
||||
return &auth.UserToken{
|
||||
UserId: 0,
|
||||
UnhashedToken: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.withIdentity(&authn.Identity{})
|
||||
sc.m.Get("/", sc.defaultHandler)
|
||||
sc.fakeReq("GET", tc.input).exec()
|
||||
|
||||
@ -64,19 +54,11 @@ func TestOrgRedirectMiddleware(t *testing.T) {
|
||||
}
|
||||
|
||||
middlewareScenario(t, "when setting an invalid org for user", func(t *testing.T, sc *scenarioContext) {
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.withIdentity(&authn.Identity{})
|
||||
sc.userService.ExpectedSetUsingOrgError = fmt.Errorf("")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12}
|
||||
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
||||
return &auth.UserToken{
|
||||
UserId: 12,
|
||||
UnhashedToken: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
sc.m.Get("/", sc.defaultHandler)
|
||||
sc.fakeReq("GET", "/?orgId=3").exec()
|
||||
sc.fakeReq("GET", "/?orgId=1").exec()
|
||||
|
||||
require.Equal(t, 404, sc.resp.Code)
|
||||
})
|
||||
|
@ -1,14 +1,13 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/quota/quotatest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
)
|
||||
@ -53,14 +52,7 @@ func TestMiddlewareQuota(t *testing.T) {
|
||||
|
||||
t.Run("with user logged in", func(t *testing.T) {
|
||||
setUp := func(sc *scenarioContext) {
|
||||
sc.withTokenSessionCookie("token")
|
||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: 12}
|
||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
||||
return &auth.UserToken{
|
||||
UserId: 12,
|
||||
UnhashedToken: "",
|
||||
}, nil
|
||||
}
|
||||
sc.withIdentity(&authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{UserId: 12}})
|
||||
}
|
||||
|
||||
middlewareScenario(t, "global datasource quota reached", func(t *testing.T, sc *scenarioContext) {
|
||||
|
@ -8,9 +8,10 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||
"github.com/grafana/grafana/pkg/services/user/usertest"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
)
|
||||
@ -66,13 +67,10 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
|
||||
sc.m.Use(AddDefaultResponseHeaders(cfg))
|
||||
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
||||
|
||||
sc.userAuthTokenService = authtest.NewFakeUserAuthTokenService()
|
||||
sc.remoteCacheService = remotecache.NewFakeStore(t)
|
||||
|
||||
contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil, nil)
|
||||
contextHandler := getContextHandler(t, setting.NewCfg(), &authntest.FakeService{ExpectedIdentity: &authn.Identity{}})
|
||||
sc.m.Use(contextHandler.Middleware)
|
||||
// mock out gc goroutine
|
||||
sc.m.Use(OrgRedirect(cfg, sc.userService))
|
||||
sc.m.Use(OrgRedirect(cfg, usertest.NewUserServiceFake()))
|
||||
|
||||
sc.defaultHandler = func(c *contextmodel.ReqContext) {
|
||||
sc.context = c
|
||||
|
@ -8,69 +8,35 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db"
|
||||
"github.com/grafana/grafana/pkg/infra/db/dbtest"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/services/apikey/apikeytest"
|
||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||
"github.com/grafana/grafana/pkg/services/auth/jwt"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler/ctxkey"
|
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||
"github.com/grafana/grafana/pkg/services/org/orgtest"
|
||||
"github.com/grafana/grafana/pkg/services/user/usertest"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
)
|
||||
|
||||
type scenarioContext struct {
|
||||
t *testing.T
|
||||
m *web.Mux
|
||||
context *contextmodel.ReqContext
|
||||
resp *httptest.ResponseRecorder
|
||||
apiKey string
|
||||
authHeader string
|
||||
jwtAuthHeader string
|
||||
tokenSessionCookie string
|
||||
respJson map[string]interface{}
|
||||
handlerFunc handlerFunc
|
||||
defaultHandler web.Handler
|
||||
url string
|
||||
userAuthTokenService *authtest.FakeUserAuthTokenService
|
||||
jwtAuthService *jwt.FakeJWTService
|
||||
remoteCacheService *remotecache.RemoteCache
|
||||
cfg *setting.Cfg
|
||||
sqlStore db.DB
|
||||
mockSQLStore *dbtest.FakeDB
|
||||
contextHandler *contexthandler.ContextHandler
|
||||
loginService *loginservice.LoginServiceMock
|
||||
apiKeyService *apikeytest.Service
|
||||
userService *usertest.FakeUserService
|
||||
oauthTokenService *authtest.FakeOAuthTokenService
|
||||
orgService *orgtest.FakeOrgService
|
||||
t *testing.T
|
||||
m *web.Mux
|
||||
context *contextmodel.ReqContext
|
||||
resp *httptest.ResponseRecorder
|
||||
respJson map[string]interface{}
|
||||
handlerFunc handlerFunc
|
||||
defaultHandler web.Handler
|
||||
url string
|
||||
authnService *authntest.FakeService
|
||||
userService *usertest.FakeUserService
|
||||
cfg *setting.Cfg
|
||||
|
||||
req *http.Request
|
||||
}
|
||||
|
||||
func (sc *scenarioContext) withValidApiKey() *scenarioContext {
|
||||
sc.apiKey = "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9"
|
||||
return sc
|
||||
}
|
||||
|
||||
func (sc *scenarioContext) withTokenSessionCookie(unhashedToken string) *scenarioContext {
|
||||
sc.tokenSessionCookie = unhashedToken
|
||||
return sc
|
||||
}
|
||||
|
||||
func (sc *scenarioContext) withAuthorizationHeader(authHeader string) *scenarioContext {
|
||||
sc.authHeader = authHeader
|
||||
return sc
|
||||
}
|
||||
|
||||
func (sc *scenarioContext) withJWTAuthHeader(jwtAuthHeader string) *scenarioContext {
|
||||
sc.jwtAuthHeader = jwtAuthHeader
|
||||
return sc
|
||||
// set identity to use for request
|
||||
func (sc *scenarioContext) withIdentity(identity *authn.Identity) {
|
||||
sc.authnService.ExpectedErr = nil
|
||||
sc.authnService.ExpectedIdentity = identity
|
||||
}
|
||||
|
||||
func (sc *scenarioContext) fakeReq(method, url string) *scenarioContext {
|
||||
@ -116,29 +82,6 @@ func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map
|
||||
func (sc *scenarioContext) exec() {
|
||||
sc.t.Helper()
|
||||
|
||||
if sc.apiKey != "" {
|
||||
sc.t.Logf(`Adding header "Authorization: Bearer %s"`, sc.apiKey)
|
||||
sc.req.Header.Set("Authorization", "Bearer "+sc.apiKey)
|
||||
}
|
||||
|
||||
if sc.authHeader != "" {
|
||||
sc.t.Logf(`Adding header "Authorization: %s"`, sc.authHeader)
|
||||
sc.req.Header.Set("Authorization", sc.authHeader)
|
||||
}
|
||||
|
||||
if sc.jwtAuthHeader != "" {
|
||||
sc.t.Logf(`Adding header "%s: %s"`, sc.cfg.JWTAuthHeaderName, sc.jwtAuthHeader)
|
||||
sc.req.Header.Set(sc.cfg.JWTAuthHeaderName, sc.jwtAuthHeader)
|
||||
}
|
||||
|
||||
if sc.tokenSessionCookie != "" {
|
||||
sc.t.Log(`Adding cookie`, "name", sc.cfg.LoginCookieName, "value", sc.tokenSessionCookie)
|
||||
sc.req.AddCookie(&http.Cookie{
|
||||
Name: sc.cfg.LoginCookieName,
|
||||
Value: sc.tokenSessionCookie,
|
||||
})
|
||||
}
|
||||
|
||||
sc.m.ServeHTTP(sc.resp, sc.req)
|
||||
|
||||
if sc.resp.Header().Get("Content-Type") == "application/json; charset=UTF-8" {
|
||||
|
@ -337,9 +337,7 @@ type RedirectValidator func(url string) error
|
||||
// HandleLoginResponse is a utility function to perform common operations after a successful login and returns response.NormalResponse
|
||||
func HandleLoginResponse(r *http.Request, w http.ResponseWriter, cfg *setting.Cfg, identity *Identity, validator RedirectValidator) *response.NormalResponse {
|
||||
result := map[string]interface{}{"message": "Logged in"}
|
||||
if redirectURL := handleLogin(r, w, cfg, identity, validator); redirectURL != cfg.AppSubURL+"/" {
|
||||
result["redirectUrl"] = redirectURL
|
||||
}
|
||||
result["redirectUrl"] = handleLogin(r, w, cfg, identity, validator)
|
||||
return response.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
@ -356,9 +354,11 @@ func HandleLoginRedirectResponse(r *http.Request, w http.ResponseWriter, cfg *se
|
||||
|
||||
func handleLogin(r *http.Request, w http.ResponseWriter, cfg *setting.Cfg, identity *Identity, validator RedirectValidator) string {
|
||||
redirectURL := cfg.AppSubURL + "/"
|
||||
if redirectTo := getRedirectURL(r); len(redirectTo) > 0 && validator(redirectTo) == nil {
|
||||
cookies.DeleteCookie(w, "redirect_to", nil)
|
||||
redirectURL = redirectTo
|
||||
if redirectTo := getRedirectURL(r); len(redirectTo) > 0 {
|
||||
if validator(redirectTo) == nil {
|
||||
redirectURL = redirectTo
|
||||
}
|
||||
cookies.DeleteCookie(w, "redirect_to", cookieOptions(cfg))
|
||||
}
|
||||
|
||||
WriteSessionCookie(w, cfg, identity.SessionToken)
|
||||
@ -386,17 +386,32 @@ func WriteSessionCookie(w http.ResponseWriter, cfg *setting.Cfg, token *usertoke
|
||||
cookies.WriteCookie(w, cfg.LoginCookieName, url.QueryEscape(token.UnhashedToken), maxAge, nil)
|
||||
expiry := token.NextRotation(time.Duration(cfg.TokenRotationIntervalMinutes) * time.Minute)
|
||||
cookies.WriteCookie(w, sessionExpiryCookie, url.QueryEscape(strconv.FormatInt(expiry.Unix(), 10)), maxAge, func() cookies.CookieOptions {
|
||||
opts := cookies.NewCookieOptions()
|
||||
opts := cookieOptions(cfg)()
|
||||
opts.NotHttpOnly = true
|
||||
return opts
|
||||
})
|
||||
}
|
||||
|
||||
func DeleteSessionCookie(w http.ResponseWriter, cfg *setting.Cfg) {
|
||||
cookies.DeleteCookie(w, cfg.LoginCookieName, nil)
|
||||
cookies.DeleteCookie(w, cfg.LoginCookieName, cookieOptions(cfg))
|
||||
cookies.DeleteCookie(w, sessionExpiryCookie, func() cookies.CookieOptions {
|
||||
opts := cookies.NewCookieOptions()
|
||||
opts := cookieOptions(cfg)()
|
||||
opts.NotHttpOnly = true
|
||||
return opts
|
||||
})
|
||||
}
|
||||
|
||||
func cookieOptions(cfg *setting.Cfg) func() cookies.CookieOptions {
|
||||
return func() cookies.CookieOptions {
|
||||
path := "/"
|
||||
if len(cfg.AppSubURL) > 0 {
|
||||
path = cfg.AppSubURL
|
||||
}
|
||||
return cookies.CookieOptions{
|
||||
Path: path,
|
||||
Secure: cfg.CookieSecure,
|
||||
SameSiteDisabled: cfg.CookieSameSiteDisabled,
|
||||
SameSiteMode: cfg.CookieSameSiteMode,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -17,14 +17,6 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/util/errutil"
|
||||
)
|
||||
|
||||
var (
|
||||
errExpiredAccessToken = errutil.NewBase(
|
||||
errutil.StatusUnauthorized,
|
||||
"oauth.expired-token",
|
||||
errutil.WithPublicMessage("OAuth access token expired"))
|
||||
)
|
||||
|
||||
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
|
||||
@ -122,7 +114,7 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
||||
s.log.FromContext(ctx).Error("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
|
||||
}
|
||||
|
||||
return errExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
|
||||
return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -89,7 +89,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
||||
expectInvalidateOauthTokensCalled: true,
|
||||
expectRevokeTokenCalled: true,
|
||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||
expectedErr: errExpiredAccessToken,
|
||||
expectedErr: authn.ErrExpiredAccessToken,
|
||||
}, {
|
||||
desc: "should skip sync when use_refresh_token is disabled",
|
||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},
|
||||
|
@ -7,4 +7,5 @@ var (
|
||||
ErrUnsupportedClient = errutil.NewBase(errutil.StatusBadRequest, "auth.client.unsupported")
|
||||
ErrClientNotConfigured = errutil.NewBase(errutil.StatusBadRequest, "auth.client.notConfigured")
|
||||
ErrUnsupportedIdentity = errutil.NewBase(errutil.StatusNotImplemented, "auth.identity.unsupported")
|
||||
ErrExpiredAccessToken = errutil.NewBase(errutil.StatusUnauthorized, "oauth.expired-token", errutil.WithPublicMessage("OAuth access token expired"))
|
||||
)
|
||||
|
@ -55,12 +55,11 @@ func ProvideService(cfg *setting.Cfg, tokenService auth.UserTokenService, jwtSer
|
||||
authnService authn.Service, anonDeviceService anonymous.Service,
|
||||
) *ContextHandler {
|
||||
return &ContextHandler{
|
||||
Cfg: cfg,
|
||||
AuthTokenService: tokenService,
|
||||
JWTAuthService: jwtService,
|
||||
RemoteCache: remoteCache,
|
||||
RenderService: renderService,
|
||||
SQLStore: sqlStore,
|
||||
Cfg: cfg,
|
||||
AuthTokenService: tokenService,
|
||||
JWTAuthService: jwtService,
|
||||
RemoteCache: remoteCache,
|
||||
RenderService: renderService, SQLStore: sqlStore,
|
||||
tracer: tracer,
|
||||
authProxy: authProxy,
|
||||
authenticator: authenticator,
|
||||
@ -173,7 +172,7 @@ func (h *ContextHandler) Middleware(next http.Handler) http.Handler {
|
||||
if h.Cfg.AuthBrokerEnabled {
|
||||
identity, err := h.AuthnService.Authenticate(ctx, &authn.Request{HTTPRequest: reqContext.Req, Resp: reqContext.Resp})
|
||||
if err != nil {
|
||||
if errors.Is(err, auth.ErrInvalidSessionToken) {
|
||||
if errors.Is(err, auth.ErrInvalidSessionToken) || errors.Is(err, authn.ErrExpiredAccessToken) {
|
||||
// Burn the cookie in case of invalid, expired or missing token
|
||||
reqContext.Resp.Before(h.deleteInvalidCookieEndOfRequestFunc(reqContext))
|
||||
}
|
||||
|
@ -968,11 +968,12 @@ var skipStaticRootValidation = false
|
||||
|
||||
func NewCfg() *Cfg {
|
||||
return &Cfg{
|
||||
Target: []string{},
|
||||
Logger: log.New("settings"),
|
||||
Raw: ini.Empty(),
|
||||
Azure: &azsettings.AzureSettings{},
|
||||
RBACEnabled: true,
|
||||
Target: []string{},
|
||||
Logger: log.New("settings"),
|
||||
Raw: ini.Empty(),
|
||||
Azure: &azsettings.AzureSettings{},
|
||||
RBACEnabled: true,
|
||||
AuthBrokerEnabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user