Auth: Use authn.Service for all tests (#72921)

* Dashboards: Fix tests when authn broker is enabled.
StarService was not configured for tests, the call was guarded by !c.IsSignedIn

* Change default to be anon user to match expectations from tests

* OAuth: rewrite tests to work with authn.Service

* Setup template renderer by default

* Extract cookie options from cfg instead of relying on global variables

* Fix test to work with authn service

* Middleware: rewrite auth tests

* Remvoe session cookie if we cannot refresh access token
This commit is contained in:
Karl Persson 2023-08-09 08:54:52 +02:00 committed by GitHub
parent 5eef8291e2
commit 144e4887ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 431 additions and 1722 deletions

View File

@ -72,6 +72,7 @@ func loggedInUserScenarioWithRole(t *testing.T, desc string, method string, url
sc.context.OrgID = testOrgID
sc.context.Login = testUserLogin
sc.context.OrgRole = role
sc.context.IsAnonymous = false
if sc.handlerFunc != nil {
return sc.handlerFunc(sc.context)
}
@ -212,7 +213,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil,
authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake(),
nil, featuremgmt.WithFeatures(), &authntest.FakeService{
ExpectedIdentity: &authn.Identity{OrgID: 1, ID: "user:1", SessionToken: &usertoken.UserToken{}}}, &anontest.FakeAnonymousSessionService{})
ExpectedIdentity: &authn.Identity{IsAnonymous: true, SessionToken: &usertoken.UserToken{}}}, &anontest.FakeAnonymousSessionService{})
return ctxHdlr
}
@ -310,6 +311,11 @@ func SetupAPITestServer(t *testing.T, opts ...APITestServerOption) *webtest.Serv
hs.registerRoutes()
s := webtest.NewServer(t, hs.RouteRegister)
viewsPath, err := filepath.Abs("../../public/views")
require.NoError(t, err)
s.Mux.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
return s
}

View File

@ -52,6 +52,7 @@ import (
"github.com/grafana/grafana/pkg/services/publicdashboards"
"github.com/grafana/grafana/pkg/services/publicdashboards/api"
"github.com/grafana/grafana/pkg/services/quota/quotatest"
"github.com/grafana/grafana/pkg/services/star/startest"
"github.com/grafana/grafana/pkg/services/tag/tagimpl"
"github.com/grafana/grafana/pkg/services/team/teamtest"
"github.com/grafana/grafana/pkg/services/user"
@ -160,6 +161,7 @@ func TestDashboardAPIEndpoint(t *testing.T) {
dashboardVersionService: fakeDashboardVersionService,
Kinds: corekind.NewBase(nil),
QuotaService: quotatest.New(false, nil),
starService: startest.NewStarServiceFake(),
userService: &usertest.FakeUserService{
ExpectedUser: &user.User{ID: 1, Login: "test-user"},
},
@ -933,6 +935,7 @@ func TestDashboardAPIEndpoint(t *testing.T) {
DashboardService: dashboardService,
Features: featuremgmt.WithFeatures(),
Kinds: corekind.NewBase(nil),
starService: startest.NewStarServiceFake(),
}
hs.callGetDashboard(sc)
@ -1121,6 +1124,7 @@ func getDashboardShouldReturn200WithConfig(t *testing.T, sc *scenarioContext, pr
DashboardService: dashboardService,
Features: featuremgmt.WithFeatures(),
Kinds: corekind.NewBase(nil),
starService: startest.NewStarServiceFake(),
}
hs.callGetDashboard(sc)

View File

@ -92,19 +92,20 @@ func (hs *HTTPServer) OAuthLogin(reqCtx *contextmodel.ReqContext) {
return
}
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" {
cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
}
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
reqCtx.Redirect(redirect.URL)
return
}
identity, err := hs.authnService.Login(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
// NOTE: always delete these cookies, even if login failed
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
if err != nil {
reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))

View File

@ -1,237 +1,212 @@
package api
import (
"crypto/sha256"
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/infra/usagestats"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/models/roletype"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/hooks"
"github.com/grafana/grafana/pkg/services/licensing"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/models/usertoken"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/authn/authntest"
"github.com/grafana/grafana/pkg/services/secrets/fakes"
"github.com/grafana/grafana/pkg/services/supportbundles/supportbundlestest"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
)
func setupSocialHTTPServerWithConfig(t *testing.T, cfg *setting.Cfg) *HTTPServer {
sqlStore := db.InitTestDB(t)
features := featuremgmt.WithFeatures()
return &HTTPServer{
Cfg: cfg,
License: &licensing.OSSLicensingService{Cfg: cfg},
SQLStore: sqlStore,
SocialService: social.ProvideService(cfg, features, &usagestats.UsageStatsMock{}, supportbundlestest.NewFakeBundleService(), remotecache.NewFakeCacheStorage()),
HooksService: hooks.ProvideService(),
SecretsService: fakes.NewFakeSecretsService(),
Features: features,
}
}
func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux {
func setClientWithoutRedirectFollow(t *testing.T) {
t.Helper()
if cfg == nil {
cfg = setting.NewCfg()
}
cfg.ErrTemplateName = "error-template"
hs := setupSocialHTTPServerWithConfig(t, cfg)
m := web.New()
m.Use(getContextHandler(t, cfg).Middleware)
viewPath, err := filepath.Abs("../../public/views")
require.NoError(t, err)
m.UseMiddleware(web.Renderer(viewPath, "[[", "]]"))
m.Get("/login/:name", hs.OAuthLogin)
return m
}
func TestOAuthLogin_UnknownProvider(t *testing.T) {
m := setupOAuthTest(t, nil)
req := httptest.NewRequest(http.MethodGet, "/login/notaprovider", nil)
recorder := httptest.NewRecorder()
m.ServeHTTP(recorder, req)
// expect to be redirected to /login
assert.Equal(t, http.StatusFound, recorder.Code)
assert.Equal(t, "/login", recorder.Header().Get("Location"))
}
func TestOAuthLogin_Base(t *testing.T) {
cfg := setting.NewCfg()
sec := cfg.Raw.Section("auth.generic_oauth")
_, err := sec.NewKey("enabled", "true")
require.NoError(t, err)
m := setupOAuthTest(t, cfg)
req := httptest.NewRequest(http.MethodGet, "/login/generic_oauth", nil)
recorder := httptest.NewRecorder()
m.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.NotEmpty(t, location)
u, err := url.Parse(location)
require.NoError(t, err)
assert.False(t, u.Query().Has("code_challenge"))
assert.False(t, u.Query().Has("code_challenge_method"))
resp := recorder.Result()
require.NoError(t, resp.Body.Close())
cookies := resp.Cookies()
var stateCookie *http.Cookie
for _, c := range cookies {
if c.Name == OauthStateCookieName {
stateCookie = c
}
}
require.NotNil(t, stateCookie)
req = httptest.NewRequest(
http.MethodGet,
(&url.URL{
Path: "/login/generic_oauth",
RawQuery: url.Values{
"code": []string{"helloworld"},
"state": []string{u.Query().Get("state")},
}.Encode(),
}).String(),
nil,
)
req.AddCookie(stateCookie)
recorder = httptest.NewRecorder()
m.ServeHTTP(recorder, req)
// TODO: validate that 'creating a token works'
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
assert.Contains(t, recorder.Body.String(), "login.OAuthLogin(NewTransportWithCode)")
}
func TestOAuthLogin_UsePKCE(t *testing.T) {
cfg := setting.NewCfg()
sec := cfg.Raw.Section("auth.generic_oauth")
_, err := sec.NewKey("enabled", "true")
require.NoError(t, err)
_, err = sec.NewKey("use_pkce", "true")
require.NoError(t, err)
m := setupOAuthTest(t, cfg)
req := httptest.NewRequest(http.MethodGet, "/login/generic_oauth", nil)
recorder := httptest.NewRecorder()
m.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusFound, recorder.Code)
location := recorder.Header().Get("Location")
assert.NotEmpty(t, location)
u, err := url.Parse(location)
require.NoError(t, err)
assert.True(t, u.Query().Has("code_challenge"))
assert.Equal(t, "S256", u.Query().Get("code_challenge_method"))
resp := recorder.Result()
require.NoError(t, resp.Body.Close())
var oauthCookie *http.Cookie
for _, cookie := range resp.Cookies() {
if cookie.Name == OauthPKCECookieName {
oauthCookie = cookie
}
}
require.NotNil(t, oauthCookie)
shasum := sha256.Sum256([]byte(oauthCookie.Value))
assert.Equal(
t,
u.Query().Get("code_challenge"),
base64.RawURLEncoding.EncodeToString(shasum[:]),
)
}
func TestOAuthLogin_BuildExternalUserInfo(t *testing.T) {
t.Helper()
cfgOAuthSkipRoleSync := setting.NewCfg()
authOAuthSec := cfgOAuthSkipRoleSync.Raw.Section("auth")
_, err := authOAuthSec.NewKey("oauth_skip_org_role_update_sync", "true")
require.NoError(t, err)
cfgOAuthSkipRoleSync.ErrTemplateName = "error-template"
cfgOAuthOrgRoleSync := setting.NewCfg()
authOAutoWithoutSec := cfgOAuthOrgRoleSync.Raw.Section("auth")
_, err = authOAutoWithoutSec.NewKey("oauth_skip_org_role_update_sync", "false")
require.NoError(t, err)
cfgOAuthOrgRoleSync.ErrTemplateName = "error-template"
testcases := []struct {
name string
cfg *setting.Cfg
basicUser *social.BasicUserInfo
expectedOrgRoles map[int64]org.RoleType
}{
{
name: "should return empty map of org role mapping if the role for the basic info is empty",
cfg: cfgOAuthOrgRoleSync,
basicUser: &social.BasicUserInfo{
Id: "1",
Name: "first lastname",
Email: "example@github.com",
Login: "example",
Role: "",
},
expectedOrgRoles: map[int64]org.RoleType{},
},
{
name: "should set internal role if role exists and we are skipping org role sync",
cfg: cfgOAuthSkipRoleSync,
basicUser: &social.BasicUserInfo{
Id: "1",
Name: "first lastname",
Email: "example@github.com",
Login: "example",
Role: roletype.RoleAdmin,
},
expectedOrgRoles: map[int64]org.RoleType{1: roletype.RoleAdmin},
},
{
name: "should return empty external role, if the role for the basic info is empty",
cfg: cfgOAuthSkipRoleSync,
basicUser: &social.BasicUserInfo{
Id: "1",
Name: "first lastname",
Email: "example@github.com",
Login: "example",
Role: "",
},
expectedOrgRoles: map[int64]org.RoleType{},
old := http.DefaultClient
http.DefaultClient = &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
for _, tc := range testcases {
t.Logf("%s", tc.name)
cfg := tc.cfg
hs := setupSocialHTTPServerWithConfig(t, cfg)
externalUser := hs.buildExternalUserInfo(nil, tc.basicUser, "")
require.Equal(t, tc.expectedOrgRoles, externalUser.OrgRoles)
t.Cleanup(func() {
http.DefaultClient = old
})
}
func TestOAuthLogin_Redirect(t *testing.T) {
type testCase struct {
desc string
expectedErr error
expectedCode int
expectedRedirect *authn.Redirect
}
tests := []testCase{
{
desc: "should be redirected to /login when passing un-configured provider",
expectedErr: authn.ErrClientNotConfigured,
expectedCode: http.StatusFound,
},
{
desc: "should be redirected to provider",
expectedCode: http.StatusFound,
expectedRedirect: &authn.Redirect{
URL: "https://some-provider.com",
Extra: map[string]string{
authn.KeyOAuthState: "some-state",
},
},
},
{
desc: "should set pkce cookie",
expectedCode: http.StatusFound,
expectedRedirect: &authn.Redirect{
URL: "https://some-provider.com",
Extra: map[string]string{
authn.KeyOAuthState: "some-state",
authn.KeyOAuthPKCE: "pkce-",
},
},
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
server := SetupAPITestServer(t, func(hs *HTTPServer) {
hs.Cfg = setting.NewCfg()
hs.SecretsService = fakes.NewFakeSecretsService()
hs.authnService = &authntest.FakeService{
ExpectedErr: tt.expectedErr,
ExpectedRedirect: tt.expectedRedirect,
}
})
// we need to prevent the http.Client from following redirects
setClientWithoutRedirectFollow(t)
res, err := server.Send(server.NewGetRequest("/login/generic_oauth"))
require.NoError(t, err)
assert.Equal(t, http.StatusFound, res.StatusCode)
// on every error we should get redirected to /login
if tt.expectedErr != nil {
assert.Equal(t, "/login", res.Header.Get("Location"))
} else {
// check that we get correct redirect url
assert.Equal(t, tt.expectedRedirect.URL, res.Header.Get("Location"))
require.GreaterOrEqual(t, len(res.Cookies()), 1)
if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" {
require.Len(t, res.Cookies(), 2)
} else {
require.Len(t, res.Cookies(), 1)
}
require.GreaterOrEqual(t, len(res.Cookies()), 1)
stateCookie := res.Cookies()[0]
assert.Equal(t, OauthStateCookieName, stateCookie.Name)
assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthState], stateCookie.Value)
if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" {
require.Len(t, res.Cookies(), 2)
pkceCookie := res.Cookies()[1]
assert.Equal(t, OauthPKCECookieName, pkceCookie.Name)
assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthPKCE], pkceCookie.Value)
} else {
require.Len(t, res.Cookies(), 1)
}
require.NoError(t, res.Body.Close())
}
})
}
}
func TestOAuthLogin_AuthorizationCode(t *testing.T) {
type testCase struct {
desc string
expectedErr error
expectedIdentity *authn.Identity
}
tests := []testCase{
{
desc: "should redirect to /login on error",
expectedErr: errors.New("some error"),
},
{
desc: "should redirect to / and set session cookie on successful authentication",
expectedIdentity: &authn.Identity{
SessionToken: &usertoken.UserToken{UnhashedToken: "some-token"},
},
},
}
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
var cfg *setting.Cfg
server := SetupAPITestServer(t, func(hs *HTTPServer) {
cfg = setting.NewCfg()
hs.Cfg = cfg
hs.Cfg.LoginCookieName = "some_name"
hs.SecretsService = fakes.NewFakeSecretsService()
hs.authnService = &authntest.FakeService{
ExpectedErr: tt.expectedErr,
ExpectedIdentity: tt.expectedIdentity,
}
})
// we need to prevent the http.Client from following redirects
setClientWithoutRedirectFollow(t)
res, err := server.Send(server.NewGetRequest("/login/generic_oauth?code=code"))
require.NoError(t, err)
require.GreaterOrEqual(t, len(res.Cookies()), 3)
// make sure oauth state cookie is deleted
assert.Equal(t, OauthStateCookieName, res.Cookies()[0].Name)
assert.Equal(t, "", res.Cookies()[0].Value)
assert.Equal(t, -1, res.Cookies()[0].MaxAge)
// make sure oauth pkce cookie is deleted
assert.Equal(t, OauthPKCECookieName, res.Cookies()[1].Name)
assert.Equal(t, "", res.Cookies()[1].Value)
assert.Equal(t, -1, res.Cookies()[1].MaxAge)
if tt.expectedErr != nil {
require.Len(t, res.Cookies(), 3)
assert.Equal(t, http.StatusFound, res.StatusCode)
assert.Equal(t, "/login", res.Header.Get("Location"))
assert.Equal(t, loginErrorCookieName, res.Cookies()[2].Name)
} else {
require.Len(t, res.Cookies(), 4)
assert.Equal(t, http.StatusFound, res.StatusCode)
assert.Equal(t, "/", res.Header.Get("Location"))
// verify session expiry cookie is set
assert.Equal(t, cfg.LoginCookieName, res.Cookies()[2].Name)
assert.Equal(t, "grafana_session_expiry", res.Cookies()[3].Name)
}
require.NoError(t, res.Body.Close())
})
}
}
func TestOAuthLogin_Error(t *testing.T) {
server := SetupAPITestServer(t, func(hs *HTTPServer) {
hs.Cfg = setting.NewCfg()
hs.SecretsService = fakes.NewFakeSecretsService()
})
setClientWithoutRedirectFollow(t)
res, err := server.Send(server.NewGetRequest("/login/azuread?error=someerror"))
require.NoError(t, err)
assert.Equal(t, http.StatusFound, res.StatusCode)
assert.Equal(t, "/login", res.Header.Get("Location"))
require.Len(t, res.Cookies(), 1)
errCookie := res.Cookies()[0]
assert.Equal(t, loginErrorCookieName, errCookie.Name)
require.NoError(t, res.Body.Close())
}

View File

@ -20,14 +20,15 @@ import (
"github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/models/usertoken"
"github.com/grafana/grafana/pkg/services/auth/authtest"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/authn/authntest"
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/hooks"
"github.com/grafana/grafana/pkg/services/licensing"
loginservice "github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/navtree"
"github.com/grafana/grafana/pkg/services/secrets"
"github.com/grafana/grafana/pkg/services/secrets/fakes"
@ -317,11 +318,15 @@ func TestLoginPostRedirect(t *testing.T) {
fakeViewIndex(t)
sc := setupScenarioContext(t, "/login")
hs := &HTTPServer{
log: log.NewNopLogger(),
Cfg: setting.NewCfg(),
HooksService: &hooks.HooksService{},
License: &licensing.OSSLicensingService{},
log: log.NewNopLogger(),
Cfg: setting.NewCfg(),
HooksService: &hooks.HooksService{},
License: &licensing.OSSLicensingService{},
authnService: &authntest.FakeService{
ExpectedIdentity: &authn.Identity{ID: "user:42", SessionToken: &usertoken.UserToken{}},
},
AuthTokenService: authtest.NewFakeUserAuthTokenService(),
Features: featuremgmt.WithFeatures(),
}
@ -333,13 +338,6 @@ func TestLoginPostRedirect(t *testing.T) {
return hs.LoginPost(c)
})
user := &user.User{
ID: 42,
Email: "",
}
hs.authenticator = &fakeAuthenticator{user, "", nil}
redirectCases := []redirectCase{
{
desc: "grafana relative url without subpath",
@ -429,6 +427,9 @@ func TestLoginPostRedirect(t *testing.T) {
hs.Cfg.AppSubURL = c.appSubURL
t.Run(c.desc, func(t *testing.T) {
if c.desc == "grafana invalid relative url starting with subpath" {
fmt.Println()
}
expCookiePath := "/"
if len(hs.Cfg.AppSubURL) > 0 {
expCookiePath = hs.Cfg.AppSubURL
@ -640,112 +641,6 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte
return sc
}
type loginHookTest struct {
info *loginservice.LoginInfo
}
func (r *loginHookTest) LoginHook(loginInfo *loginservice.LoginInfo, req *contextmodel.ReqContext) {
r.info = loginInfo
}
// TOREMOVE: remove with context handler auth
func TestLoginPostRunLokingHook(t *testing.T) {
sc := setupScenarioContext(t, "/login")
hookService := &hooks.HooksService{}
hs := &HTTPServer{
log: log.New("test"),
Cfg: sc.cfg,
License: &licensing.OSSLicensingService{},
AuthTokenService: authtest.NewFakeUserAuthTokenService(),
Features: featuremgmt.WithFeatures(),
HooksService: hookService,
authnService: sc.ctxHdlr.AuthnService,
}
sc.cfg.AuthBrokerEnabled = false
sc.defaultHandler = routing.Wrap(func(c *contextmodel.ReqContext) response.Response {
c.Req.Header.Set("Content-Type", "application/json")
c.Req.Body = io.NopCloser(bytes.NewBufferString(`{"user":"admin","password":"admin"}`))
x := hs.LoginPost(c)
return x
})
testHook := loginHookTest{}
hookService.AddLoginHook(testHook.LoginHook)
testUser := &user.User{
ID: 42,
Email: "",
}
testCases := []struct {
desc string
authUser *user.User
authModule string
authErr error
info loginservice.LoginInfo
}{
{
desc: "invalid credentials",
authErr: login.ErrInvalidCredentials,
info: loginservice.LoginInfo{
AuthModule: "",
HTTPStatus: 401,
Error: login.ErrInvalidCredentials,
},
},
{
desc: "user disabled",
authErr: login.ErrUserDisabled,
info: loginservice.LoginInfo{
AuthModule: "",
HTTPStatus: 401,
Error: login.ErrUserDisabled,
},
},
{
desc: "valid Grafana user",
authUser: testUser,
authModule: "grafana",
info: loginservice.LoginInfo{
AuthModule: "grafana",
User: testUser,
HTTPStatus: 200,
},
},
{
desc: "valid LDAP user",
authUser: testUser,
authModule: loginservice.LDAPAuthModule,
info: loginservice.LoginInfo{
AuthModule: loginservice.LDAPAuthModule,
User: testUser,
HTTPStatus: 200,
},
},
}
for _, c := range testCases {
t.Run(c.desc, func(t *testing.T) {
hs.authenticator = &fakeAuthenticator{c.authUser, c.authModule, c.authErr}
sc.m.Post(sc.url, sc.defaultHandler)
sc.fakeReqNoAssertions("POST", sc.url).exec()
info := testHook.info
assert.Equal(t, c.info.AuthModule, info.AuthModule)
assert.Equal(t, "admin", info.LoginUsername)
assert.Equal(t, c.info.HTTPStatus, info.HTTPStatus)
assert.Equal(t, c.info.Error, info.Error)
if c.info.User != nil {
require.NotEmpty(t, info.User)
assert.Equal(t, c.info.User.ID, info.User.ID)
}
})
}
}
type mockSocialService struct {
oAuthInfo *social.OAuthInfo
oAuthInfos map[string]*social.OAuthInfo
@ -774,15 +669,3 @@ func (m *mockSocialService) GetOAuthHttpClient(name string) (*http.Client, error
func (m *mockSocialService) GetConnector(string) (social.SocialConnector, error) {
return m.socialConnector, m.err
}
type fakeAuthenticator struct {
ExpectedUser *user.User
ExpectedAuthModule string
ExpectedError error
}
func (fa *fakeAuthenticator) AuthenticateUser(c context.Context, query *loginservice.LoginUserQuery) error {
query.User = fa.ExpectedUser
query.AuthModule = fa.ExpectedAuthModule
return fa.ExpectedError
}

View File

@ -1,115 +1,152 @@
package middleware
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/authn/authntest"
"github.com/grafana/grafana/pkg/services/contexthandler"
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
)
func TestMiddlewareAuth(t *testing.T) {
reqSignIn := Auth(&AuthOptions{ReqSignedIn: true})
func setupAuthMiddlewareTest(t *testing.T, identity *authn.Identity, authErr error) *contexthandler.ContextHandler {
return contexthandler.ProvideService(setting.NewCfg(), nil, nil, nil, nil, nil, tracing.NewFakeTracer(), nil, nil, nil, nil, nil, nil, nil, nil, &authntest.FakeService{
ExpectedErr: authErr,
ExpectedIdentity: identity,
}, nil)
}
middlewareScenario(t, "ReqSignIn true and unauthenticated request", func(t *testing.T, sc *scenarioContext) {
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
sc.fakeReq("GET", "/secure").exec()
func TestAuth_Middleware(t *testing.T) {
type testCase struct {
desc string
identity *authn.Identity
path string
authErr error
authMiddleware web.Handler
expecedReached bool
expectedCode int
}
assert.Equal(t, 302, sc.resp.Code)
})
tests := []testCase{
{
desc: "ReqSignedIn should redirect unauthenticated request to secure endpoint",
path: "/secure",
authMiddleware: ReqSignedIn,
authErr: errors.New("no auth"),
expectedCode: http.StatusFound,
},
{
desc: "ReqSignedIn should return 401 for api endpint",
path: "/api/secure",
authMiddleware: ReqSignedIn,
authErr: errors.New("no auth"),
expectedCode: http.StatusUnauthorized,
},
{
desc: "ReqSignedIn should return 200 for anonymous user",
path: "/api/secure",
authMiddleware: ReqSignedIn,
identity: &authn.Identity{IsAnonymous: true},
expecedReached: true,
expectedCode: http.StatusOK,
},
{
desc: "ReqSignedIn should return redirect anonymous user with forceLogin query string",
path: "/secure?forceLogin=true",
authMiddleware: ReqSignedIn,
identity: &authn.Identity{IsAnonymous: true},
expecedReached: false,
expectedCode: http.StatusFound,
},
{
desc: "ReqSignedIn should return redirect anonymous user when orgId in query string is different from currently used",
path: "/secure?orgId=2",
authMiddleware: ReqSignedIn,
identity: &authn.Identity{IsAnonymous: true, OrgID: 1},
expecedReached: false,
expectedCode: http.StatusFound,
},
{
desc: "ReqSignedInNoAnonymous should return 401 for anonymous user",
path: "/api/secure",
authMiddleware: ReqSignedInNoAnonymous,
identity: &authn.Identity{IsAnonymous: true},
expecedReached: false,
expectedCode: http.StatusUnauthorized,
},
{
desc: "ReqSignedInNoAnonymous should return 200 for authenticated user",
path: "/api/secure",
authMiddleware: ReqSignedInNoAnonymous,
identity: &authn.Identity{ID: "user:1"},
expecedReached: true,
expectedCode: http.StatusOK,
},
{
desc: "snapshot public mode disabled should return 200 for authenticated user",
path: "/api/secure",
authMiddleware: SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: false}),
identity: &authn.Identity{ID: "user:1"},
expecedReached: true,
expectedCode: http.StatusOK,
},
{
desc: "snapshot public mode disabled should return 401 for unauthenticated request",
path: "/api/secure",
authMiddleware: SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: false}),
authErr: errors.New("no auth"),
expecedReached: false,
expectedCode: http.StatusUnauthorized,
},
{
desc: "snapshot public mode enabled should return 200 for unauthenticated request",
path: "/api/secure",
authMiddleware: SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: true}),
authErr: errors.New("no auth"),
expecedReached: true,
expectedCode: http.StatusOK,
},
}
middlewareScenario(t, "ReqSignIn true and unauthenticated API request", func(t *testing.T, sc *scenarioContext) {
sc.m.Get("/api/secure", reqSignIn, sc.defaultHandler)
for _, tt := range tests {
t.Run(tt.desc, func(t *testing.T) {
ctxHandler := setupAuthMiddlewareTest(t, tt.identity, tt.authErr)
sc.fakeReq("GET", "/api/secure").exec()
server := web.New()
server.Use(ctxHandler.Middleware)
server.Use(tt.authMiddleware)
assert.Equal(t, 401, sc.resp.Code)
})
var reached bool
server.Get("/secure", func(c *contextmodel.ReqContext) {
reached = true
c.Resp.WriteHeader(http.StatusOK)
})
server.Get("/api/secure", func(c *contextmodel.ReqContext) {
reached = true
c.Resp.WriteHeader(http.StatusOK)
})
t.Run("Anonymous auth enabled", func(t *testing.T) {
const orgID int64 = 1
req, err := http.NewRequest(http.MethodGet, tt.path, nil)
require.NoError(t, err)
recorder := httptest.NewRecorder()
server.ServeHTTP(recorder, req)
configure := func(cfg *setting.Cfg) {
cfg.AnonymousEnabled = true
cfg.AnonymousOrgName = "test"
}
middlewareScenario(t, "ReqSignIn true and NoAnonynmous true", func(
t *testing.T, sc *scenarioContext) {
sc.orgService.ExpectedOrg = &org.Org{ID: orgID, Name: "test"}
sc.m.Get("/api/secure", ReqSignedInNoAnonymous, sc.defaultHandler)
sc.fakeReq("GET", "/api/secure").exec()
assert.Equal(t, 401, sc.resp.Code)
}, configure)
middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func(
t *testing.T, sc *scenarioContext) {
sc.orgService.ExpectedOrg = &org.Org{ID: orgID, Name: "test"}
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
sc.fakeReq("GET", "/secure?forceLogin=true").exec()
assert.Equal(t, 302, sc.resp.Code)
location, ok := sc.resp.Header()["Location"]
assert.True(t, ok)
assert.Equal(t, "/login", location[0])
}, configure)
middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func(
t *testing.T, sc *scenarioContext) {
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
sc.fakeReq("GET", fmt.Sprintf("/secure?orgId=%d", 1)).exec()
assert.Equal(t, 200, sc.resp.Code)
}, configure)
middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func(
t *testing.T, sc *scenarioContext) {
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
sc.fakeReq("GET", "/secure?orgId=2").exec()
assert.Equal(t, 302, sc.resp.Code)
location, ok := sc.resp.Header()["Location"]
assert.True(t, ok)
assert.Equal(t, "/login", location[0])
}, configure)
})
middlewareScenario(t, "Snapshot public mode disabled and unauthenticated request should return 401", func(
t *testing.T, sc *scenarioContext) {
sc.m.Get("/api/snapshot", func(c *contextmodel.ReqContext) {
c.IsSignedIn = false
}, SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
sc.fakeReq("GET", "/api/snapshot").exec()
assert.Equal(t, 401, sc.resp.Code)
})
middlewareScenario(t, "Snapshot public mode disabled and authenticated request should return 200", func(
t *testing.T, sc *scenarioContext) {
sc.m.Get("/api/snapshot", func(c *contextmodel.ReqContext) {
c.IsSignedIn = true
}, SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
sc.fakeReq("GET", "/api/snapshot").exec()
assert.Equal(t, 200, sc.resp.Code)
})
middlewareScenario(t, "Snapshot public mode enabled and unauthenticated request should return 200", func(
t *testing.T, sc *scenarioContext) {
sc.cfg.SnapshotPublicMode = true
sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
sc.fakeReq("GET", "/api/snapshot").exec()
assert.Equal(t, 200, sc.resp.Code)
})
res := recorder.Result()
assert.Equal(t, tt.expecedReached, reached)
assert.Equal(t, tt.expectedCode, res.StatusCode)
require.NoError(t, res.Body.Close())
})
}
}
func TestRemoveForceLoginparams(t *testing.T) {

View File

@ -1,108 +0,0 @@
package middleware
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/services/apikey"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/login/logintest"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
)
func TestMiddlewareBasicAuth(t *testing.T) {
const id int64 = 12
configure := func(cfg *setting.Cfg) {
cfg.BasicAuthEnabled = true
cfg.DisableBruteForceLoginProtection = true
}
middlewareScenario(t, "Valid API key", func(t *testing.T, sc *scenarioContext) {
const orgID int64 = 2
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
require.NoError(t, err)
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: orgID, Role: org.RoleEditor, Key: keyhash}
authHeader := util.GetBasicAuthHeader("api_key", "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9")
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
assert.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgID)
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
require.NotNil(t, list)
require.EqualValues(t, []string{"Authorization"}, list.Items)
}, configure)
middlewareScenario(t, "Handle auth", func(t *testing.T, sc *scenarioContext) {
const password = "MyPass"
const orgID int64 = 2
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: id}
authHeader := util.GetBasicAuthHeader("myUser", password)
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgID)
assert.Equal(t, id, sc.context.UserID)
}, configure)
middlewareScenario(t, "Auth sequence", func(t *testing.T, sc *scenarioContext) {
const password = "MyPass"
const salt = "Salt"
encoded, err := util.EncodePassword(password, salt)
require.NoError(t, err)
sc.userService.ExpectedUser = &user.User{Password: encoded, ID: id, Salt: salt}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id}
login.ProvideService(sc.mockSQLStore, &logintest.LoginServiceFake{}, nil, sc.userService, sc.cfg)
authHeader := util.GetBasicAuthHeader("myUser", password)
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
require.NotNil(t, sc.context)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, id, sc.context.UserID)
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
require.NotNil(t, list)
require.EqualValues(t, []string{"Authorization"}, list.Items)
}, configure)
middlewareScenario(t, "Should return error if user is not found", func(t *testing.T, sc *scenarioContext) {
sc.userService.ExpectedError = user.ErrUserNotFound
sc.fakeReq("GET", "/")
sc.req.SetBasicAuth("user", "password")
sc.exec()
err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
require.Error(t, err)
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"])
}, configure)
middlewareScenario(t, "Should return error if user & password do not match", func(t *testing.T, sc *scenarioContext) {
sc.userService.ExpectedError = user.ErrUserNotFound
sc.fakeReq("GET", "/")
sc.req.SetBasicAuth("killa", "gorilla")
sc.exec()
err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
require.Error(t, err)
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"])
}, configure)
}

View File

@ -1,303 +0,0 @@
package middleware
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/services/auth/jwt"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
)
func TestMiddlewareJWTAuth(t *testing.T) {
const myEmail = "vladimir@example.com"
const id int64 = 12
const orgID int64 = 2
configure := func(cfg *setting.Cfg) {
cfg.JWTAuthEnabled = true
cfg.JWTAuthHeaderName = "x-jwt-assertion"
}
configureAuthHeader := func(cfg *setting.Cfg) {
cfg.JWTAuthEnabled = true
cfg.JWTAuthHeaderName = "Authorization"
}
configureUsernameClaim := func(cfg *setting.Cfg) {
cfg.JWTAuthUsernameClaim = "foo-username"
}
configureEmailClaim := func(cfg *setting.Cfg) {
cfg.JWTAuthEmailClaim = "foo-email"
}
configureAutoSignUp := func(cfg *setting.Cfg) {
cfg.JWTAuthAutoSignUp = true
}
configureRole := func(cfg *setting.Cfg) {
cfg.JWTAuthEmailClaim = "sub"
cfg.JWTAuthRoleAttributePath = "role"
}
configureRoleStrict := func(cfg *setting.Cfg) {
cfg.JWTAuthRoleAttributeStrict = true
}
configureRoleAllowAdmin := func(cfg *setting.Cfg) {
cfg.JWTAuthAllowAssignGrafanaAdmin = true
}
// #nosec G101 -- This is dummy/test token
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ2bGFkaW1pckBleGFtcGxlLmNvbSIsImlhdCI6MTUxNjIzOTAyMiwiZm9vLXVzZXJuYW1lIjoidmxhZGltaXIiLCJuYW1lIjoiVmxhZGltaXIgRXhhbXBsZSIsImZvby1lbWFpbCI6InZsYWRpbWlyQGV4YW1wbGUuY29tIn0.MeNU1pCzRHGdQuu5ppeftxT31_2Le2kM1wd1GK2jExs"
middlewareScenario(t, "Valid token with valid login claim", func(t *testing.T, sc *scenarioContext) {
myUsername := "vladimir"
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return jwt.JWTClaims{
"sub": myUsername,
"foo-username": myUsername,
}, nil
}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Login: myUsername}
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgID)
assert.Equal(t, id, sc.context.UserID)
assert.Equal(t, myUsername, sc.context.Login)
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
require.NotNil(t, list)
require.EqualValues(t, []string{"Authorization", sc.cfg.JWTAuthHeaderName}, list.Items)
}, configure, configureUsernameClaim)
middlewareScenario(t, "Valid token with bearer in authorization header", func(t *testing.T, sc *scenarioContext) {
myUsername := "vladimir"
// We can ignore gosec G101 since this does not contain any credentials.
// nolint:gosec
myToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ2bGFkaW1pckBleGFtcGxlLmNvbSIsImlhdCI6MTUxNjIzOTAyMiwiZm9vLXVzZXJuYW1lIjoidmxhZGltaXIiLCJuYW1lIjoiVmxhZGltaXIgRXhhbXBsZSIsImZvby1lbWFpbCI6InZsYWRpbWlyQGV4YW1wbGUuY29tIn0.MeNU1pCzRHGdQuu5ppeftxT31_2Le2kM1wd1GK2jExs"
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = myToken
return jwt.JWTClaims{
"sub": myUsername,
"foo-username": myUsername,
}, nil
}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Login: myUsername}
sc.fakeReq("GET", "/").withJWTAuthHeader("Bearer " + myToken).exec()
assert.Equal(t, verifiedToken, myToken)
assert.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgID)
assert.Equal(t, id, sc.context.UserID)
assert.Equal(t, myUsername, sc.context.Login)
}, configureAuthHeader, configureUsernameClaim)
middlewareScenario(t, "Valid token with valid email claim", func(t *testing.T, sc *scenarioContext) {
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return jwt.JWTClaims{
"sub": myEmail,
"foo-email": myEmail,
}, nil
}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail}
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgID)
assert.Equal(t, id, sc.context.UserID)
assert.Equal(t, myEmail, sc.context.Email)
}, configure, configureEmailClaim)
middlewareScenario(t, "Valid token with no user and auto_sign_up disabled", func(t *testing.T, sc *scenarioContext) {
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return jwt.JWTClaims{
"sub": myEmail,
"name": "Vladimir Example",
"foo-email": myEmail,
}, nil
}
sc.userService.ExpectedError = user.ErrUserNotFound
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, contexthandler.UserNotFound, sc.respJson["message"])
}, configure, configureEmailClaim)
middlewareScenario(t, "Valid token with no user and auto_sign_up enabled", func(t *testing.T, sc *scenarioContext) {
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return jwt.JWTClaims{
"sub": myEmail,
"name": "Vladimir Example",
"foo-email": myEmail,
}, nil
}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail}
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgID)
assert.Equal(t, id, sc.context.UserID)
assert.Equal(t, myEmail, sc.context.Email)
}, configure, configureEmailClaim, configureAutoSignUp)
middlewareScenario(t, "Valid token without a login claim", func(t *testing.T, sc *scenarioContext) {
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return jwt.JWTClaims{
"sub": "baz",
"foo": "bar",
}, nil
}
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, contexthandler.InvalidJWT, sc.respJson["message"])
}, configure, configureUsernameClaim)
middlewareScenario(t, "Valid token without a email claim", func(t *testing.T, sc *scenarioContext) {
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return jwt.JWTClaims{
"sub": "baz",
"foo": "bar",
}, nil
}
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, contexthandler.InvalidJWT, sc.respJson["message"])
}, configure, configureEmailClaim)
middlewareScenario(t, "Valid token with role", func(t *testing.T, sc *scenarioContext) {
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return jwt.JWTClaims{
"sub": myEmail,
"role": "Editor",
}, nil
}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleEditor}
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
}, configure, configureAutoSignUp, configureRole)
middlewareScenario(t, "Valid token with invalid role", func(t *testing.T, sc *scenarioContext) {
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return jwt.JWTClaims{
"sub": myEmail,
"role": "test",
}, nil
}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleViewer}
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, org.RoleViewer, sc.context.OrgRole)
}, configure, configureAutoSignUp, configureRole)
middlewareScenario(t, "Valid token with invalid role in strict mode", func(t *testing.T, sc *scenarioContext) {
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return jwt.JWTClaims{
"sub": myEmail,
"role": "test",
}, nil
}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleViewer}
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 403, sc.resp.Code)
assert.Equal(t, contexthandler.InvalidRole, sc.respJson["message"])
}, configure, configureAutoSignUp, configureRole, configureRoleStrict)
middlewareScenario(t, "Valid token with grafana admin role not allowed", func(t *testing.T, sc *scenarioContext) {
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return jwt.JWTClaims{
"sub": myEmail,
"role": "GrafanaAdmin",
}, nil
}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleAdmin}
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, org.RoleAdmin, sc.context.OrgRole)
assert.False(t, sc.context.IsGrafanaAdmin)
}, configure, configureAutoSignUp, configureRole)
middlewareScenario(t, "Valid token with grafana admin role allowed", func(t *testing.T, sc *scenarioContext) {
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return jwt.JWTClaims{
"sub": myEmail,
"role": "GrafanaAdmin",
}, nil
}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleAdmin, IsGrafanaAdmin: true}
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, org.RoleAdmin, sc.context.OrgRole)
assert.True(t, sc.context.IsGrafanaAdmin)
}, configure, configureAutoSignUp, configureRole, configureRoleAllowAdmin)
middlewareScenario(t, "Invalid token", func(t *testing.T, sc *scenarioContext) {
var verifiedToken string
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
verifiedToken = token
return nil, errors.New("token is invalid")
}
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
assert.Equal(t, verifiedToken, token)
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, contexthandler.InvalidJWT, sc.respJson["message"])
}, configure, configureUsernameClaim)
}

View File

@ -1,17 +1,10 @@
package middleware
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -19,47 +12,22 @@ import (
"github.com/grafana/grafana-plugin-sdk-go/backend/gtime"
"github.com/grafana/grafana/pkg/api/dtos"
"github.com/grafana/grafana/pkg/infra/db/dbtest"
"github.com/grafana/grafana/pkg/infra/fs"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/services/anonymous/anontest"
"github.com/grafana/grafana/pkg/services/apikey"
"github.com/grafana/grafana/pkg/services/apikey/apikeytest"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/auth/authtest"
"github.com/grafana/grafana/pkg/services/auth/jwt"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/authn/authntest"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/ldap/service"
loginsvc "github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/login/logintest"
"github.com/grafana/grafana/pkg/services/navtree"
"github.com/grafana/grafana/pkg/services/org"
"github.com/grafana/grafana/pkg/services/org/orgtest"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/services/user/usertest"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/grafana/grafana/pkg/web"
)
func fakeGetTime() func() time.Time {
var timeSeed int64
return func() time.Time {
fakeNow := time.Unix(timeSeed, 0)
timeSeed++
return fakeNow
}
}
func TestMiddleWareSecurityHeaders(t *testing.T) {
middlewareScenario(t, "middleware should get correct x-xss-protection header", func(t *testing.T, sc *scenarioContext) {
sc.fakeReq("GET", "/api/").exec()
@ -134,11 +102,6 @@ func TestMiddleWareContentSecurityPolicyHeaders(t *testing.T) {
func TestMiddlewareContext(t *testing.T) {
const noStore = "no-store"
configureJWTAuthHeader := func(cfg *setting.Cfg) {
cfg.JWTAuthEnabled = true
cfg.JWTAuthHeaderName = "Authorization"
}
middlewareScenario(t, "middleware should add context to injector", func(t *testing.T, sc *scenarioContext) {
sc.fakeReq("GET", "/").exec()
assert.NotNil(t, sc.context)
@ -214,372 +177,6 @@ func TestMiddlewareContext(t *testing.T) {
cfg.AllowEmbedding = true
})
middlewareScenario(t, "Invalid api key", func(t *testing.T, sc *scenarioContext) {
sc.apiKey = "invalid_key_test"
sc.fakeReq("GET", "/").exec()
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, contexthandler.InvalidAPIKey, sc.respJson["message"])
})
middlewareScenario(t, "Valid API key", func(t *testing.T, sc *scenarioContext) {
const orgID int64 = 12
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
require.NoError(t, err)
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: orgID, Role: org.RoleEditor, Key: keyhash}
sc.fakeReq("GET", "/").withValidApiKey().exec()
require.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgID)
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
})
middlewareScenario(t, "Valid API key with JWT enabled", func(t *testing.T, sc *scenarioContext) {
const orgID int64 = 12
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
require.NoError(t, err)
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: orgID, Role: org.RoleEditor, Key: keyhash}
sc.fakeReq("GET", "/").withValidApiKey().exec()
require.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgID)
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
}, configureJWTAuthHeader)
middlewareScenario(t, "Valid Basic Auth header with JWT enabled and empty 'sub' claim", func(t *testing.T, sc *scenarioContext) {
const password = "MyPass"
const orgID int64 = 2
const userID int64 = 12
// #nosec G101 -- This is dummy/test token
const emptySubToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJzdWIiOiIiLCJpYXQiOjE1MTYyMzkwMjJ9.tnwtOHK58d47dO4DHW4b9MzeToxa1kGiko5Oo887Rqc"
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
authHeader := util.GetBasicAuthHeader("myuser", password)
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).withJWTAuthHeader(emptySubToken).exec()
require.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgID)
assert.Equal(t, userID, sc.context.UserID)
}, func(cfg *setting.Cfg) {
cfg.JWTAuthEnabled = true
cfg.JWTAuthHeaderName = "X-JWT-Token"
cfg.BasicAuthEnabled = true
})
middlewareScenario(t, "Valid Basic Auth header with JWT enabled and missing 'sub' claim", func(t *testing.T, sc *scenarioContext) {
const password = "MyPass"
const orgID int64 = 2
const userID int64 = 12
// #nosec G101 -- This is dummy/test token
const missingSubToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjJ9.8nYFUX869Y1mnDDDU4yL11aANgVRuifoxrE8BHZY1iE"
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
authHeader := util.GetBasicAuthHeader("myuser", password)
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).withJWTAuthHeader(missingSubToken).exec()
require.Equal(t, 200, sc.resp.Code)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgID)
assert.Equal(t, userID, sc.context.UserID)
}, func(cfg *setting.Cfg) {
cfg.JWTAuthEnabled = true
cfg.JWTAuthHeaderName = "X-JWT-Token"
cfg.BasicAuthEnabled = true
})
middlewareScenario(t, "Valid API key, but does not match DB hash", func(t *testing.T, sc *scenarioContext) {
const keyhash = "Something_not_matching"
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: 12, Role: org.RoleEditor, Key: keyhash}
sc.fakeReq("GET", "/").withValidApiKey().exec()
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, contexthandler.InvalidAPIKey, sc.respJson["message"])
})
middlewareScenario(t, "Valid API key, but expired", func(t *testing.T, sc *scenarioContext) {
sc.contextHandler.GetTime = fakeGetTime()
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
require.NoError(t, err)
expires := sc.contextHandler.GetTime().Add(-1 * time.Second).Unix()
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: 12, Role: org.RoleEditor, Key: keyhash, Expires: &expires}
sc.fakeReq("GET", "/").withValidApiKey().exec()
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, "Expired API key", sc.respJson["message"])
})
middlewareScenario(t, "Non-expired auth token in cookie which is not being rotated", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
return &auth.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
require.NotNil(t, sc.context)
require.NotNil(t, sc.context.UserToken)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "Non-expired auth token in cookie which is being rotated", func(t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
return &auth.UserToken{
UserId: userID,
UnhashedToken: "",
}, nil
}
sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *auth.UserToken,
clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
userToken.UnhashedToken = "rotated"
return true, userToken, nil
}
maxAge := int(sc.cfg.LoginMaxLifetime.Seconds())
sameSiteModes := []http.SameSite{
http.SameSiteNoneMode,
http.SameSiteLaxMode,
http.SameSiteStrictMode,
}
for _, sameSiteMode := range sameSiteModes {
t.Run(fmt.Sprintf("Same site mode %d", sameSiteMode), func(t *testing.T) {
origCookieSameSiteMode := setting.CookieSameSiteMode
t.Cleanup(func() {
setting.CookieSameSiteMode = origCookieSameSiteMode
})
setting.CookieSameSiteMode = sameSiteMode
expectedCookiePath := "/"
if len(sc.cfg.AppSubURL) > 0 {
expectedCookiePath = sc.cfg.AppSubURL
}
expectedCookie := &http.Cookie{
Name: sc.cfg.LoginCookieName,
Value: "rotated",
Path: expectedCookiePath,
HttpOnly: true,
MaxAge: maxAge,
Secure: setting.CookieSecure,
SameSite: sameSiteMode,
}
sc.fakeReq("GET", "/").exec()
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "rotated", sc.context.UserToken.UnhashedToken)
assert.Equal(t, expectedCookie.String(), sc.resp.Header().Get("Set-Cookie"))
})
}
t.Run("Should not set cookie with SameSite attribute when setting.CookieSameSiteDisabled is true", func(t *testing.T) {
origCookieSameSiteDisabled := setting.CookieSameSiteDisabled
origCookieSameSiteMode := setting.CookieSameSiteMode
t.Cleanup(func() {
setting.CookieSameSiteDisabled = origCookieSameSiteDisabled
setting.CookieSameSiteMode = origCookieSameSiteMode
})
setting.CookieSameSiteDisabled = true
setting.CookieSameSiteMode = http.SameSiteLaxMode
expectedCookiePath := "/"
if len(sc.cfg.AppSubURL) > 0 {
expectedCookiePath = sc.cfg.AppSubURL
}
expectedCookie := &http.Cookie{
Name: sc.cfg.LoginCookieName,
Value: "rotated",
Path: expectedCookiePath,
HttpOnly: true,
MaxAge: maxAge,
Secure: setting.CookieSecure,
}
sc.fakeReq("GET", "/").exec()
assert.Equal(t, expectedCookie.String(), sc.resp.Header().Get("Set-Cookie"))
})
})
middlewareScenario(t, "Invalid/expired auth token in cookie", func(t *testing.T, sc *scenarioContext) {
sc.withTokenSessionCookie("token")
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
return nil, auth.ErrUserTokenNotFound
}
sc.fakeReq("GET", "/").exec()
assert.False(t, sc.context.IsSignedIn)
assert.Equal(t, int64(0), sc.context.UserID)
assert.Nil(t, sc.context.UserToken)
})
middlewareScenario(t, "Non-expired auth token in cookie and non-expired OAuth access token", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(11 * time.Second)}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
return &auth.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
require.NotNil(t, sc.context)
require.NotNil(t, sc.context.UserToken)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token fails", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
signedInUser := &user.SignedInUser{OrgID: 2, UserID: userID}
sc.userService.ExpectedSignedInUser = signedInUser
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{
UserId: userID,
OAuthExpiry: fakeGetTime()().Add(-1 * time.Second),
OAuthAccessToken: "access_token",
OAuthRefreshToken: "refresh_token"}
sc.oauthTokenService.ExpectedErrors = map[string]error{"TryTokenRefresh": errors.New("error")}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
return &auth.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
token := sc.oauthTokenService.GetCurrentOAuthToken(sc.context.Req.Context(), signedInUser)
assert.Equal(t, token.AccessToken, "")
assert.Equal(t, token.RefreshToken, "")
assert.True(t, token.Expiry.IsZero())
require.NotNil(t, sc.context)
require.Nil(t, sc.context.UserToken)
assert.False(t, sc.context.IsSignedIn)
assert.Equal(t, int64(0), sc.context.UserID)
assert.Equal(t, "grafana_session=; Path=/; Max-Age=0; HttpOnly", sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token succeeds", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(-5 * time.Second), OAuthRefreshToken: "refreshtoken"}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
return &auth.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
require.NotNil(t, sc.context)
require.NotNil(t, sc.context.UserToken)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "Non-expired auth token in cookie and OAuth Access Token's Expiry is not set", func(
t *testing.T, sc *scenarioContext) {
const userID int64 = 12
sc.contextHandler.GetTime = fakeGetTime()
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{UserId: userID}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
return &auth.UserToken{
UserId: userID,
UnhashedToken: unhashedToken,
}, nil
}
sc.fakeReq("GET", "/").exec()
require.NotNil(t, sc.context)
require.NotNil(t, sc.context.UserToken)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, userID, sc.context.UserToken.UserId)
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
})
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
sc.fakeReq("GET", "/").exec()
assert.Equal(t, int64(0), sc.context.UserID)
assert.Equal(t, int64(1), sc.context.OrgID)
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
assert.False(t, sc.context.IsSignedIn)
}, func(cfg *setting.Cfg) {
cfg.AnonymousEnabled = true
cfg.AnonymousOrgName = "test"
cfg.AnonymousOrgRole = string(org.RoleEditor)
})
middlewareScenario(t, "middleware should add custom response headers", func(t *testing.T, sc *scenarioContext) {
sc.fakeReq("GET", "/api/").exec()
assert.Regexp(t, "test", sc.resp.Header().Get("X-Custom-Header"))
@ -590,278 +187,6 @@ func TestMiddlewareContext(t *testing.T) {
"X-Other-Header": "other-test",
}
})
t.Run("auth_proxy", func(t *testing.T) {
const userID int64 = 33
const orgID int64 = 4
const defaultOrgId int64 = 1
const orgRole = "Admin"
configure := func(cfg *setting.Cfg) {
cfg.AuthProxyEnabled = true
cfg.AuthProxyAutoSignUp = true
cfg.LDAPAuthEnabled = true
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
cfg.AuthProxyHeaderProperty = "username"
cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS", "Role": "X-WEBAUTH-ROLE"}
}
const hdrName = "markelog"
const group = "grafana-core-team"
middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) {
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
h, err := authproxy.HashCacheKey(hdrName + "-" + group)
require.NoError(t, err)
key := fmt.Sprintf(authproxy.CachePrefix, h)
userIdBytes := []byte(strconv.FormatInt(userID, 10))
err = sc.remoteCacheService.Set(context.Background(), key, userIdBytes, 0)
require.NoError(t, err)
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.req.Header.Set("X-WEBAUTH-GROUPS", group)
sc.exec()
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, orgID, sc.context.OrgID)
}, configure)
middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) {
var actualAuthProxyAutoSignUp *bool = nil
sc.loginService.ExpectedUserFunc = func(cmd *loginsvc.UpsertUserCommand) *user.User {
actualAuthProxyAutoSignUp = &cmd.SignupAllowed
return nil
}
sc.loginService.ExpectedError = login.ErrInvalidCredentials
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec()
assert.False(t, *actualAuthProxyAutoSignUp)
assert.Equal(t, 407, sc.resp.Code)
assert.Nil(t, sc.context)
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.LDAPAuthEnabled = false
cfg.AuthProxyAutoSignUp = false
})
middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) {
sc.loginService.ExpectedUser = &user.User{ID: userID}
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec()
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, orgID, sc.context.OrgID)
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
require.NotNil(t, list)
require.Contains(t, list.Items, sc.cfg.AuthProxyHeaderName)
require.Contains(t, list.Items, "X-WEBAUTH-GROUPS")
require.Contains(t, list.Items, "X-WEBAUTH-ROLE")
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.LDAPAuthEnabled = false
cfg.AuthProxyAutoSignUp = true
})
middlewareScenario(t, "Should assign role from header to default org", func(t *testing.T, sc *scenarioContext) {
var storedRoleInfo map[int64]org.RoleType = nil
sc.loginService.ExpectedUserFunc = func(cmd *loginsvc.UpsertUserCommand) *user.User {
storedRoleInfo = cmd.ExternalUser.OrgRoles
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: defaultOrgId, UserID: userID, OrgRole: storedRoleInfo[defaultOrgId]}
return &user.User{ID: userID}
}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.req.Header.Set("X-WEBAUTH-ROLE", orgRole)
sc.exec()
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, defaultOrgId, sc.context.OrgID)
assert.Equal(t, orgRole, string(sc.context.OrgRole))
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.LDAPAuthEnabled = false
cfg.AuthProxyAutoSignUp = true
})
middlewareScenario(t, "Should NOT assign role from header to non-default org", func(t *testing.T, sc *scenarioContext) {
var storedRoleInfo map[int64]org.RoleType = nil
sc.loginService.ExpectedUserFunc = func(cmd *loginsvc.UpsertUserCommand) *user.User {
storedRoleInfo = cmd.ExternalUser.OrgRoles
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID, OrgRole: storedRoleInfo[orgID]}
return &user.User{ID: userID}
}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.req.Header.Set("X-WEBAUTH-ROLE", "Admin")
sc.req.Header.Set("X-Grafana-Org-Id", strconv.FormatInt(orgID, 10))
sc.exec()
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, orgID, sc.context.OrgID)
// For non-default org, the user role should be empty
assert.Equal(t, "", string(sc.context.OrgRole))
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.LDAPAuthEnabled = false
cfg.AuthProxyAutoSignUp = true
})
middlewareScenario(t, "Should use organisation specified by targetOrgId parameter", func(t *testing.T, sc *scenarioContext) {
var targetOrgID int64 = 123
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: targetOrgID, UserID: userID}
sc.loginService.ExpectedUser = &user.User{ID: userID}
sc.fakeReq("GET", fmt.Sprintf("/?targetOrgId=%d", targetOrgID))
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec()
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, targetOrgID, sc.context.OrgID)
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.LDAPAuthEnabled = false
cfg.AuthProxyAutoSignUp = true
})
middlewareScenario(t, "Request body should not be read in default context handler", func(t *testing.T, sc *scenarioContext) {
sc.fakeReq("POST", "/?targetOrgId=123")
body := "key=value"
sc.req.Body = io.NopCloser(strings.NewReader(body))
sc.handlerFunc = func(c *contextmodel.ReqContext) {
t.Log("Handler called")
defer func() {
err := c.Req.Body.Close()
require.NoError(t, err)
}()
bodyAfterHandler, e := io.ReadAll(c.Req.Body)
require.NoError(t, e)
require.Equal(t, body, string(bodyAfterHandler))
}
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
sc.req.Header.Set("Content-Length", strconv.Itoa(len(body)))
sc.m.Post("/", sc.defaultHandler)
sc.exec()
})
middlewareScenario(t, "Request body should not be read in default context handler, but query should be altered - jwt", func(t *testing.T, sc *scenarioContext) {
sc.fakeReq("POST", "/?targetOrgId=123&auth_token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NSIsImlhdCI6MTUxNjIzOTAyMn0.1E9qmtctlHAeJzNLPgGFfxdA8WfbEl_vwYO91ffQGxs")
body := "key=value"
sc.req.Body = io.NopCloser(strings.NewReader(body))
sc.handlerFunc = func(c *contextmodel.ReqContext) {
t.Log("Handler called")
defer func() {
err := c.Req.Body.Close()
require.NoError(t, err)
}()
require.Equal(t, "", c.Req.URL.Query().Get("auth_token"))
bodyAfterHandler, e := io.ReadAll(c.Req.Body)
require.NoError(t, e)
require.Equal(t, body, string(bodyAfterHandler))
}
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
sc.req.Header.Set("Content-Length", strconv.Itoa(len(body)))
sc.m.Post("/", sc.defaultHandler)
sc.exec()
}, func(cfg *setting.Cfg) {
cfg.JWTAuthEnabled = true
cfg.JWTAuthURLLogin = true
cfg.JWTAuthHeaderName = "X-WEBAUTH-TOKEN"
})
middlewareScenario(t, "Should get an existing user from header", func(t *testing.T, sc *scenarioContext) {
const userID int64 = 12
const orgID int64 = 2
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
sc.loginService.ExpectedUser = &user.User{ID: userID}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec()
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, orgID, sc.context.OrgID)
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.LDAPAuthEnabled = false
})
middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) {
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
sc.loginService.ExpectedUser = &user.User{ID: userID}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.req.RemoteAddr = "[2001::23]:12345"
sc.exec()
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserID)
assert.Equal(t, orgID, sc.context.OrgID)
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.AuthProxyWhitelist = "192.168.1.0/24, 2001::0/120"
cfg.LDAPAuthEnabled = false
})
middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) {
sc.loginService.ExpectedUser = &user.User{ID: userID}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.req.RemoteAddr = "[2001::23]:12345"
sc.exec()
assert.Equal(t, 407, sc.resp.Code)
assert.Nil(t, sc.context)
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.AuthProxyWhitelist = "8.8.8.8"
cfg.LDAPAuthEnabled = false
})
middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) {
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec()
assert.Equal(t, 407, sc.resp.Code)
assert.Nil(t, sc.context)
}, configure)
middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) {
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec()
assert.Equal(t, 407, sc.resp.Code)
assert.Nil(t, sc.context)
}, configure)
})
}
func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(*setting.Cfg)) {
@ -894,22 +219,14 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
sc.m.UseMiddleware(ContentSecurityPolicy(cfg, logger))
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
sc.mockSQLStore = dbtest.NewFakeDB()
sc.loginService = &loginservice.LoginServiceMock{}
// defalut to not authenticated request
sc.authnService = &authntest.FakeService{ExpectedErr: errors.New("no auth")}
sc.userService = usertest.NewUserServiceFake()
sc.orgService = orgtest.NewOrgServiceFake()
sc.apiKeyService = &apikeytest.Service{}
sc.oauthTokenService = &authtest.FakeOAuthTokenService{}
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService, sc.oauthTokenService)
sc.sqlStore = ctxHdlr.SQLStore
sc.contextHandler = ctxHdlr
ctxHdlr := getContextHandler(t, cfg, sc.authnService)
sc.m.Use(ctxHdlr.Middleware)
sc.m.Use(OrgRedirect(sc.cfg, sc.userService))
sc.userAuthTokenService = ctxHdlr.AuthTokenService.(*authtest.FakeUserAuthTokenService)
sc.jwtAuthService = ctxHdlr.JWTAuthService.(*jwt.FakeJWTService)
sc.remoteCacheService = ctxHdlr.RemoteCache
sc.defaultHandler = func(c *contextmodel.ReqContext) {
require.NotNil(t, c)
t.Log("Default HTTP handler called")
@ -933,40 +250,14 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
})
}
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *dbtest.FakeDB,
loginService *loginservice.LoginServiceMock, apiKeyService *apikeytest.Service,
userService *usertest.FakeUserService, orgService *orgtest.FakeOrgService,
oauthTokenService *authtest.FakeOAuthTokenService,
) *contexthandler.ContextHandler {
func getContextHandler(t *testing.T, cfg *setting.Cfg, authnService authn.Service) *contexthandler.ContextHandler {
t.Helper()
if cfg == nil {
cfg = setting.NewCfg()
}
cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{
Name: "database",
}
remoteCacheSvc := remotecache.NewFakeStore(t)
userAuthTokenSvc := authtest.NewFakeUserAuthTokenService()
renderSvc := &fakeRenderService{}
authJWTSvc := jwt.NewFakeJWTService()
tracer := tracing.InitializeTracerForTest()
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService,
userService, mockSQLStore, &service.LDAPFakeService{ExpectedError: service.ErrUnableToCreateLDAPClient})
authenticator := &logintest.AuthenticatorFake{ExpectedUser: &user.User{}}
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc,
remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy,
loginService, apiKeyService, authenticator, userService, orgService,
oauthTokenService,
featuremgmt.WithFeatures(featuremgmt.FlagAccessTokenExpirationCheck),
&authntest.FakeService{}, &anontest.FakeAnonymousSessionService{})
}
type fakeRenderService struct {
rendering.Service
}
func (s *fakeRenderService) Init() error {
return nil
tracer := tracing.NewFakeTracer()
return contexthandler.ProvideService(cfg, authtest.NewFakeUserAuthTokenService(), nil,
nil, nil, nil, tracer, nil,
nil, nil, nil, nil, nil,
nil, featuremgmt.WithFeatures(featuremgmt.FlagAccessTokenExpirationCheck),
authnService, &anontest.FakeAnonymousSessionService{},
)
}

View File

@ -1,14 +1,12 @@
package middleware
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/services/authn"
)
func TestOrgRedirectMiddleware(t *testing.T) {
@ -46,15 +44,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
for _, tc := range testCases {
middlewareScenario(t, tc.desc, func(t *testing.T, sc *scenarioContext) {
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
return &auth.UserToken{
UserId: 0,
UnhashedToken: "",
}, nil
}
sc.withIdentity(&authn.Identity{})
sc.m.Get("/", sc.defaultHandler)
sc.fakeReq("GET", tc.input).exec()
@ -64,19 +54,11 @@ func TestOrgRedirectMiddleware(t *testing.T) {
}
middlewareScenario(t, "when setting an invalid org for user", func(t *testing.T, sc *scenarioContext) {
sc.withTokenSessionCookie("token")
sc.withIdentity(&authn.Identity{})
sc.userService.ExpectedSetUsingOrgError = fmt.Errorf("")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
return &auth.UserToken{
UserId: 12,
UnhashedToken: "",
}, nil
}
sc.m.Get("/", sc.defaultHandler)
sc.fakeReq("GET", "/?orgId=3").exec()
sc.fakeReq("GET", "/?orgId=1").exec()
require.Equal(t, 404, sc.resp.Code)
})

View File

@ -1,14 +1,13 @@
package middleware
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/quota/quotatest"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
)
@ -53,14 +52,7 @@ func TestMiddlewareQuota(t *testing.T) {
t.Run("with user logged in", func(t *testing.T) {
setUp := func(sc *scenarioContext) {
sc.withTokenSessionCookie("token")
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: 12}
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
return &auth.UserToken{
UserId: 12,
UnhashedToken: "",
}, nil
}
sc.withIdentity(&authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{UserId: 12}})
}
middlewareScenario(t, "global datasource quota reached", func(t *testing.T, sc *scenarioContext) {

View File

@ -8,9 +8,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/services/auth/authtest"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/authn/authntest"
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
"github.com/grafana/grafana/pkg/services/user/usertest"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
)
@ -66,13 +67,10 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
sc.m.Use(AddDefaultResponseHeaders(cfg))
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
sc.userAuthTokenService = authtest.NewFakeUserAuthTokenService()
sc.remoteCacheService = remotecache.NewFakeStore(t)
contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil, nil)
contextHandler := getContextHandler(t, setting.NewCfg(), &authntest.FakeService{ExpectedIdentity: &authn.Identity{}})
sc.m.Use(contextHandler.Middleware)
// mock out gc goroutine
sc.m.Use(OrgRedirect(cfg, sc.userService))
sc.m.Use(OrgRedirect(cfg, usertest.NewUserServiceFake()))
sc.defaultHandler = func(c *contextmodel.ReqContext) {
sc.context = c

View File

@ -8,69 +8,35 @@ import (
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/db/dbtest"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/services/apikey/apikeytest"
"github.com/grafana/grafana/pkg/services/auth/authtest"
"github.com/grafana/grafana/pkg/services/auth/jwt"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/authn"
"github.com/grafana/grafana/pkg/services/authn/authntest"
"github.com/grafana/grafana/pkg/services/contexthandler/ctxkey"
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/org/orgtest"
"github.com/grafana/grafana/pkg/services/user/usertest"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
)
type scenarioContext struct {
t *testing.T
m *web.Mux
context *contextmodel.ReqContext
resp *httptest.ResponseRecorder
apiKey string
authHeader string
jwtAuthHeader string
tokenSessionCookie string
respJson map[string]interface{}
handlerFunc handlerFunc
defaultHandler web.Handler
url string
userAuthTokenService *authtest.FakeUserAuthTokenService
jwtAuthService *jwt.FakeJWTService
remoteCacheService *remotecache.RemoteCache
cfg *setting.Cfg
sqlStore db.DB
mockSQLStore *dbtest.FakeDB
contextHandler *contexthandler.ContextHandler
loginService *loginservice.LoginServiceMock
apiKeyService *apikeytest.Service
userService *usertest.FakeUserService
oauthTokenService *authtest.FakeOAuthTokenService
orgService *orgtest.FakeOrgService
t *testing.T
m *web.Mux
context *contextmodel.ReqContext
resp *httptest.ResponseRecorder
respJson map[string]interface{}
handlerFunc handlerFunc
defaultHandler web.Handler
url string
authnService *authntest.FakeService
userService *usertest.FakeUserService
cfg *setting.Cfg
req *http.Request
}
func (sc *scenarioContext) withValidApiKey() *scenarioContext {
sc.apiKey = "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9"
return sc
}
func (sc *scenarioContext) withTokenSessionCookie(unhashedToken string) *scenarioContext {
sc.tokenSessionCookie = unhashedToken
return sc
}
func (sc *scenarioContext) withAuthorizationHeader(authHeader string) *scenarioContext {
sc.authHeader = authHeader
return sc
}
func (sc *scenarioContext) withJWTAuthHeader(jwtAuthHeader string) *scenarioContext {
sc.jwtAuthHeader = jwtAuthHeader
return sc
// set identity to use for request
func (sc *scenarioContext) withIdentity(identity *authn.Identity) {
sc.authnService.ExpectedErr = nil
sc.authnService.ExpectedIdentity = identity
}
func (sc *scenarioContext) fakeReq(method, url string) *scenarioContext {
@ -116,29 +82,6 @@ func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map
func (sc *scenarioContext) exec() {
sc.t.Helper()
if sc.apiKey != "" {
sc.t.Logf(`Adding header "Authorization: Bearer %s"`, sc.apiKey)
sc.req.Header.Set("Authorization", "Bearer "+sc.apiKey)
}
if sc.authHeader != "" {
sc.t.Logf(`Adding header "Authorization: %s"`, sc.authHeader)
sc.req.Header.Set("Authorization", sc.authHeader)
}
if sc.jwtAuthHeader != "" {
sc.t.Logf(`Adding header "%s: %s"`, sc.cfg.JWTAuthHeaderName, sc.jwtAuthHeader)
sc.req.Header.Set(sc.cfg.JWTAuthHeaderName, sc.jwtAuthHeader)
}
if sc.tokenSessionCookie != "" {
sc.t.Log(`Adding cookie`, "name", sc.cfg.LoginCookieName, "value", sc.tokenSessionCookie)
sc.req.AddCookie(&http.Cookie{
Name: sc.cfg.LoginCookieName,
Value: sc.tokenSessionCookie,
})
}
sc.m.ServeHTTP(sc.resp, sc.req)
if sc.resp.Header().Get("Content-Type") == "application/json; charset=UTF-8" {

View File

@ -337,9 +337,7 @@ type RedirectValidator func(url string) error
// HandleLoginResponse is a utility function to perform common operations after a successful login and returns response.NormalResponse
func HandleLoginResponse(r *http.Request, w http.ResponseWriter, cfg *setting.Cfg, identity *Identity, validator RedirectValidator) *response.NormalResponse {
result := map[string]interface{}{"message": "Logged in"}
if redirectURL := handleLogin(r, w, cfg, identity, validator); redirectURL != cfg.AppSubURL+"/" {
result["redirectUrl"] = redirectURL
}
result["redirectUrl"] = handleLogin(r, w, cfg, identity, validator)
return response.JSON(http.StatusOK, result)
}
@ -356,9 +354,11 @@ func HandleLoginRedirectResponse(r *http.Request, w http.ResponseWriter, cfg *se
func handleLogin(r *http.Request, w http.ResponseWriter, cfg *setting.Cfg, identity *Identity, validator RedirectValidator) string {
redirectURL := cfg.AppSubURL + "/"
if redirectTo := getRedirectURL(r); len(redirectTo) > 0 && validator(redirectTo) == nil {
cookies.DeleteCookie(w, "redirect_to", nil)
redirectURL = redirectTo
if redirectTo := getRedirectURL(r); len(redirectTo) > 0 {
if validator(redirectTo) == nil {
redirectURL = redirectTo
}
cookies.DeleteCookie(w, "redirect_to", cookieOptions(cfg))
}
WriteSessionCookie(w, cfg, identity.SessionToken)
@ -386,17 +386,32 @@ func WriteSessionCookie(w http.ResponseWriter, cfg *setting.Cfg, token *usertoke
cookies.WriteCookie(w, cfg.LoginCookieName, url.QueryEscape(token.UnhashedToken), maxAge, nil)
expiry := token.NextRotation(time.Duration(cfg.TokenRotationIntervalMinutes) * time.Minute)
cookies.WriteCookie(w, sessionExpiryCookie, url.QueryEscape(strconv.FormatInt(expiry.Unix(), 10)), maxAge, func() cookies.CookieOptions {
opts := cookies.NewCookieOptions()
opts := cookieOptions(cfg)()
opts.NotHttpOnly = true
return opts
})
}
func DeleteSessionCookie(w http.ResponseWriter, cfg *setting.Cfg) {
cookies.DeleteCookie(w, cfg.LoginCookieName, nil)
cookies.DeleteCookie(w, cfg.LoginCookieName, cookieOptions(cfg))
cookies.DeleteCookie(w, sessionExpiryCookie, func() cookies.CookieOptions {
opts := cookies.NewCookieOptions()
opts := cookieOptions(cfg)()
opts.NotHttpOnly = true
return opts
})
}
func cookieOptions(cfg *setting.Cfg) func() cookies.CookieOptions {
return func() cookies.CookieOptions {
path := "/"
if len(cfg.AppSubURL) > 0 {
path = cfg.AppSubURL
}
return cookies.CookieOptions{
Path: path,
Secure: cfg.CookieSecure,
SameSiteDisabled: cfg.CookieSameSiteDisabled,
SameSiteMode: cfg.CookieSameSiteMode,
}
}
}

View File

@ -17,14 +17,6 @@ import (
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/util/errutil"
)
var (
errExpiredAccessToken = errutil.NewBase(
errutil.StatusUnauthorized,
"oauth.expired-token",
errutil.WithPublicMessage("OAuth access token expired"))
)
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
@ -122,7 +114,7 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
s.log.FromContext(ctx).Error("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
}
return errExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
}
return nil

View File

@ -89,7 +89,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
expectInvalidateOauthTokensCalled: true,
expectRevokeTokenCalled: true,
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
expectedErr: errExpiredAccessToken,
expectedErr: authn.ErrExpiredAccessToken,
}, {
desc: "should skip sync when use_refresh_token is disabled",
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},

View File

@ -7,4 +7,5 @@ var (
ErrUnsupportedClient = errutil.NewBase(errutil.StatusBadRequest, "auth.client.unsupported")
ErrClientNotConfigured = errutil.NewBase(errutil.StatusBadRequest, "auth.client.notConfigured")
ErrUnsupportedIdentity = errutil.NewBase(errutil.StatusNotImplemented, "auth.identity.unsupported")
ErrExpiredAccessToken = errutil.NewBase(errutil.StatusUnauthorized, "oauth.expired-token", errutil.WithPublicMessage("OAuth access token expired"))
)

View File

@ -55,12 +55,11 @@ func ProvideService(cfg *setting.Cfg, tokenService auth.UserTokenService, jwtSer
authnService authn.Service, anonDeviceService anonymous.Service,
) *ContextHandler {
return &ContextHandler{
Cfg: cfg,
AuthTokenService: tokenService,
JWTAuthService: jwtService,
RemoteCache: remoteCache,
RenderService: renderService,
SQLStore: sqlStore,
Cfg: cfg,
AuthTokenService: tokenService,
JWTAuthService: jwtService,
RemoteCache: remoteCache,
RenderService: renderService, SQLStore: sqlStore,
tracer: tracer,
authProxy: authProxy,
authenticator: authenticator,
@ -173,7 +172,7 @@ func (h *ContextHandler) Middleware(next http.Handler) http.Handler {
if h.Cfg.AuthBrokerEnabled {
identity, err := h.AuthnService.Authenticate(ctx, &authn.Request{HTTPRequest: reqContext.Req, Resp: reqContext.Resp})
if err != nil {
if errors.Is(err, auth.ErrInvalidSessionToken) {
if errors.Is(err, auth.ErrInvalidSessionToken) || errors.Is(err, authn.ErrExpiredAccessToken) {
// Burn the cookie in case of invalid, expired or missing token
reqContext.Resp.Before(h.deleteInvalidCookieEndOfRequestFunc(reqContext))
}

View File

@ -968,11 +968,12 @@ var skipStaticRootValidation = false
func NewCfg() *Cfg {
return &Cfg{
Target: []string{},
Logger: log.New("settings"),
Raw: ini.Empty(),
Azure: &azsettings.AzureSettings{},
RBACEnabled: true,
Target: []string{},
Logger: log.New("settings"),
Raw: ini.Empty(),
Azure: &azsettings.AzureSettings{},
RBACEnabled: true,
AuthBrokerEnabled: true,
}
}