mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Auth: Use authn.Service for all tests (#72921)
* Dashboards: Fix tests when authn broker is enabled. StarService was not configured for tests, the call was guarded by !c.IsSignedIn * Change default to be anon user to match expectations from tests * OAuth: rewrite tests to work with authn.Service * Setup template renderer by default * Extract cookie options from cfg instead of relying on global variables * Fix test to work with authn service * Middleware: rewrite auth tests * Remvoe session cookie if we cannot refresh access token
This commit is contained in:
parent
5eef8291e2
commit
144e4887ee
@ -72,6 +72,7 @@ func loggedInUserScenarioWithRole(t *testing.T, desc string, method string, url
|
|||||||
sc.context.OrgID = testOrgID
|
sc.context.OrgID = testOrgID
|
||||||
sc.context.Login = testUserLogin
|
sc.context.Login = testUserLogin
|
||||||
sc.context.OrgRole = role
|
sc.context.OrgRole = role
|
||||||
|
sc.context.IsAnonymous = false
|
||||||
if sc.handlerFunc != nil {
|
if sc.handlerFunc != nil {
|
||||||
return sc.handlerFunc(sc.context)
|
return sc.handlerFunc(sc.context)
|
||||||
}
|
}
|
||||||
@ -212,7 +213,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
|
|||||||
remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil,
|
remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, nil,
|
||||||
authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake(),
|
authenticator, usertest.NewUserServiceFake(), orgtest.NewOrgServiceFake(),
|
||||||
nil, featuremgmt.WithFeatures(), &authntest.FakeService{
|
nil, featuremgmt.WithFeatures(), &authntest.FakeService{
|
||||||
ExpectedIdentity: &authn.Identity{OrgID: 1, ID: "user:1", SessionToken: &usertoken.UserToken{}}}, &anontest.FakeAnonymousSessionService{})
|
ExpectedIdentity: &authn.Identity{IsAnonymous: true, SessionToken: &usertoken.UserToken{}}}, &anontest.FakeAnonymousSessionService{})
|
||||||
|
|
||||||
return ctxHdlr
|
return ctxHdlr
|
||||||
}
|
}
|
||||||
@ -310,6 +311,11 @@ func SetupAPITestServer(t *testing.T, opts ...APITestServerOption) *webtest.Serv
|
|||||||
hs.registerRoutes()
|
hs.registerRoutes()
|
||||||
|
|
||||||
s := webtest.NewServer(t, hs.RouteRegister)
|
s := webtest.NewServer(t, hs.RouteRegister)
|
||||||
|
|
||||||
|
viewsPath, err := filepath.Abs("../../public/views")
|
||||||
|
require.NoError(t, err)
|
||||||
|
s.Mux.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,6 +52,7 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/publicdashboards"
|
"github.com/grafana/grafana/pkg/services/publicdashboards"
|
||||||
"github.com/grafana/grafana/pkg/services/publicdashboards/api"
|
"github.com/grafana/grafana/pkg/services/publicdashboards/api"
|
||||||
"github.com/grafana/grafana/pkg/services/quota/quotatest"
|
"github.com/grafana/grafana/pkg/services/quota/quotatest"
|
||||||
|
"github.com/grafana/grafana/pkg/services/star/startest"
|
||||||
"github.com/grafana/grafana/pkg/services/tag/tagimpl"
|
"github.com/grafana/grafana/pkg/services/tag/tagimpl"
|
||||||
"github.com/grafana/grafana/pkg/services/team/teamtest"
|
"github.com/grafana/grafana/pkg/services/team/teamtest"
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
@ -160,6 +161,7 @@ func TestDashboardAPIEndpoint(t *testing.T) {
|
|||||||
dashboardVersionService: fakeDashboardVersionService,
|
dashboardVersionService: fakeDashboardVersionService,
|
||||||
Kinds: corekind.NewBase(nil),
|
Kinds: corekind.NewBase(nil),
|
||||||
QuotaService: quotatest.New(false, nil),
|
QuotaService: quotatest.New(false, nil),
|
||||||
|
starService: startest.NewStarServiceFake(),
|
||||||
userService: &usertest.FakeUserService{
|
userService: &usertest.FakeUserService{
|
||||||
ExpectedUser: &user.User{ID: 1, Login: "test-user"},
|
ExpectedUser: &user.User{ID: 1, Login: "test-user"},
|
||||||
},
|
},
|
||||||
@ -933,6 +935,7 @@ func TestDashboardAPIEndpoint(t *testing.T) {
|
|||||||
DashboardService: dashboardService,
|
DashboardService: dashboardService,
|
||||||
Features: featuremgmt.WithFeatures(),
|
Features: featuremgmt.WithFeatures(),
|
||||||
Kinds: corekind.NewBase(nil),
|
Kinds: corekind.NewBase(nil),
|
||||||
|
starService: startest.NewStarServiceFake(),
|
||||||
}
|
}
|
||||||
hs.callGetDashboard(sc)
|
hs.callGetDashboard(sc)
|
||||||
|
|
||||||
@ -1121,6 +1124,7 @@ func getDashboardShouldReturn200WithConfig(t *testing.T, sc *scenarioContext, pr
|
|||||||
DashboardService: dashboardService,
|
DashboardService: dashboardService,
|
||||||
Features: featuremgmt.WithFeatures(),
|
Features: featuremgmt.WithFeatures(),
|
||||||
Kinds: corekind.NewBase(nil),
|
Kinds: corekind.NewBase(nil),
|
||||||
|
starService: startest.NewStarServiceFake(),
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.callGetDashboard(sc)
|
hs.callGetDashboard(sc)
|
||||||
|
@ -92,19 +92,20 @@ func (hs *HTTPServer) OAuthLogin(reqCtx *contextmodel.ReqContext) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||||
|
|
||||||
if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" {
|
if pkce := redirect.Extra[authn.KeyOAuthPKCE]; pkce != "" {
|
||||||
cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
cookies.WriteCookie(reqCtx.Resp, OauthPKCECookieName, pkce, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
cookies.WriteCookie(reqCtx.Resp, OauthStateCookieName, redirect.Extra[authn.KeyOAuthState], hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
|
|
||||||
reqCtx.Redirect(redirect.URL)
|
reqCtx.Redirect(redirect.URL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, err := hs.authnService.Login(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
|
identity, err := hs.authnService.Login(reqCtx.Req.Context(), authn.ClientWithPrefix(name), req)
|
||||||
// NOTE: always delete these cookies, even if login failed
|
// NOTE: always delete these cookies, even if login failed
|
||||||
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
|
||||||
cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
cookies.DeleteCookie(reqCtx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
|
||||||
|
cookies.DeleteCookie(reqCtx.Resp, OauthPKCECookieName, hs.CookieOptionsFromCfg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))
|
reqCtx.Redirect(hs.redirectURLWithErrorCookie(reqCtx, err))
|
||||||
|
@ -1,237 +1,212 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"errors"
|
||||||
"encoding/base64"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/db"
|
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/infra/usagestats"
|
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
|
||||||
"github.com/grafana/grafana/pkg/models/roletype"
|
|
||||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
|
||||||
"github.com/grafana/grafana/pkg/services/hooks"
|
|
||||||
"github.com/grafana/grafana/pkg/services/licensing"
|
|
||||||
"github.com/grafana/grafana/pkg/services/org"
|
|
||||||
"github.com/grafana/grafana/pkg/services/secrets/fakes"
|
"github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||||
"github.com/grafana/grafana/pkg/services/supportbundles/supportbundlestest"
|
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/web"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupSocialHTTPServerWithConfig(t *testing.T, cfg *setting.Cfg) *HTTPServer {
|
func setClientWithoutRedirectFollow(t *testing.T) {
|
||||||
sqlStore := db.InitTestDB(t)
|
|
||||||
features := featuremgmt.WithFeatures()
|
|
||||||
|
|
||||||
return &HTTPServer{
|
|
||||||
Cfg: cfg,
|
|
||||||
License: &licensing.OSSLicensingService{Cfg: cfg},
|
|
||||||
SQLStore: sqlStore,
|
|
||||||
SocialService: social.ProvideService(cfg, features, &usagestats.UsageStatsMock{}, supportbundlestest.NewFakeBundleService(), remotecache.NewFakeCacheStorage()),
|
|
||||||
HooksService: hooks.ProvideService(),
|
|
||||||
SecretsService: fakes.NewFakeSecretsService(),
|
|
||||||
Features: features,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupOAuthTest(t *testing.T, cfg *setting.Cfg) *web.Mux {
|
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
old := http.DefaultClient
|
||||||
if cfg == nil {
|
http.DefaultClient = &http.Client{
|
||||||
cfg = setting.NewCfg()
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
}
|
return http.ErrUseLastResponse
|
||||||
cfg.ErrTemplateName = "error-template"
|
|
||||||
hs := setupSocialHTTPServerWithConfig(t, cfg)
|
|
||||||
|
|
||||||
m := web.New()
|
|
||||||
m.Use(getContextHandler(t, cfg).Middleware)
|
|
||||||
viewPath, err := filepath.Abs("../../public/views")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
m.UseMiddleware(web.Renderer(viewPath, "[[", "]]"))
|
|
||||||
|
|
||||||
m.Get("/login/:name", hs.OAuthLogin)
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOAuthLogin_UnknownProvider(t *testing.T) {
|
|
||||||
m := setupOAuthTest(t, nil)
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/login/notaprovider", nil)
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
|
|
||||||
m.ServeHTTP(recorder, req)
|
|
||||||
// expect to be redirected to /login
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
|
||||||
assert.Equal(t, "/login", recorder.Header().Get("Location"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOAuthLogin_Base(t *testing.T) {
|
|
||||||
cfg := setting.NewCfg()
|
|
||||||
sec := cfg.Raw.Section("auth.generic_oauth")
|
|
||||||
_, err := sec.NewKey("enabled", "true")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
m := setupOAuthTest(t, cfg)
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/login/generic_oauth", nil)
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
|
|
||||||
m.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
|
||||||
|
|
||||||
location := recorder.Header().Get("Location")
|
|
||||||
assert.NotEmpty(t, location)
|
|
||||||
|
|
||||||
u, err := url.Parse(location)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, u.Query().Has("code_challenge"))
|
|
||||||
assert.False(t, u.Query().Has("code_challenge_method"))
|
|
||||||
|
|
||||||
resp := recorder.Result()
|
|
||||||
require.NoError(t, resp.Body.Close())
|
|
||||||
|
|
||||||
cookies := resp.Cookies()
|
|
||||||
var stateCookie *http.Cookie
|
|
||||||
for _, c := range cookies {
|
|
||||||
if c.Name == OauthStateCookieName {
|
|
||||||
stateCookie = c
|
|
||||||
}
|
|
||||||
}
|
|
||||||
require.NotNil(t, stateCookie)
|
|
||||||
|
|
||||||
req = httptest.NewRequest(
|
|
||||||
http.MethodGet,
|
|
||||||
(&url.URL{
|
|
||||||
Path: "/login/generic_oauth",
|
|
||||||
RawQuery: url.Values{
|
|
||||||
"code": []string{"helloworld"},
|
|
||||||
"state": []string{u.Query().Get("state")},
|
|
||||||
}.Encode(),
|
|
||||||
}).String(),
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
req.AddCookie(stateCookie)
|
|
||||||
recorder = httptest.NewRecorder()
|
|
||||||
|
|
||||||
m.ServeHTTP(recorder, req)
|
|
||||||
// TODO: validate that 'creating a token works'
|
|
||||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
|
|
||||||
assert.Contains(t, recorder.Body.String(), "login.OAuthLogin(NewTransportWithCode)")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOAuthLogin_UsePKCE(t *testing.T) {
|
|
||||||
cfg := setting.NewCfg()
|
|
||||||
sec := cfg.Raw.Section("auth.generic_oauth")
|
|
||||||
_, err := sec.NewKey("enabled", "true")
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = sec.NewKey("use_pkce", "true")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
m := setupOAuthTest(t, cfg)
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/login/generic_oauth", nil)
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
|
|
||||||
m.ServeHTTP(recorder, req)
|
|
||||||
|
|
||||||
assert.Equal(t, http.StatusFound, recorder.Code)
|
|
||||||
|
|
||||||
location := recorder.Header().Get("Location")
|
|
||||||
assert.NotEmpty(t, location)
|
|
||||||
|
|
||||||
u, err := url.Parse(location)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, u.Query().Has("code_challenge"))
|
|
||||||
assert.Equal(t, "S256", u.Query().Get("code_challenge_method"))
|
|
||||||
|
|
||||||
resp := recorder.Result()
|
|
||||||
require.NoError(t, resp.Body.Close())
|
|
||||||
|
|
||||||
var oauthCookie *http.Cookie
|
|
||||||
for _, cookie := range resp.Cookies() {
|
|
||||||
if cookie.Name == OauthPKCECookieName {
|
|
||||||
oauthCookie = cookie
|
|
||||||
}
|
|
||||||
}
|
|
||||||
require.NotNil(t, oauthCookie)
|
|
||||||
|
|
||||||
shasum := sha256.Sum256([]byte(oauthCookie.Value))
|
|
||||||
assert.Equal(
|
|
||||||
t,
|
|
||||||
u.Query().Get("code_challenge"),
|
|
||||||
base64.RawURLEncoding.EncodeToString(shasum[:]),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOAuthLogin_BuildExternalUserInfo(t *testing.T) {
|
|
||||||
t.Helper()
|
|
||||||
cfgOAuthSkipRoleSync := setting.NewCfg()
|
|
||||||
authOAuthSec := cfgOAuthSkipRoleSync.Raw.Section("auth")
|
|
||||||
_, err := authOAuthSec.NewKey("oauth_skip_org_role_update_sync", "true")
|
|
||||||
require.NoError(t, err)
|
|
||||||
cfgOAuthSkipRoleSync.ErrTemplateName = "error-template"
|
|
||||||
|
|
||||||
cfgOAuthOrgRoleSync := setting.NewCfg()
|
|
||||||
authOAutoWithoutSec := cfgOAuthOrgRoleSync.Raw.Section("auth")
|
|
||||||
_, err = authOAutoWithoutSec.NewKey("oauth_skip_org_role_update_sync", "false")
|
|
||||||
require.NoError(t, err)
|
|
||||||
cfgOAuthOrgRoleSync.ErrTemplateName = "error-template"
|
|
||||||
|
|
||||||
testcases := []struct {
|
|
||||||
name string
|
|
||||||
cfg *setting.Cfg
|
|
||||||
basicUser *social.BasicUserInfo
|
|
||||||
expectedOrgRoles map[int64]org.RoleType
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "should return empty map of org role mapping if the role for the basic info is empty",
|
|
||||||
cfg: cfgOAuthOrgRoleSync,
|
|
||||||
basicUser: &social.BasicUserInfo{
|
|
||||||
Id: "1",
|
|
||||||
Name: "first lastname",
|
|
||||||
Email: "example@github.com",
|
|
||||||
Login: "example",
|
|
||||||
Role: "",
|
|
||||||
},
|
|
||||||
expectedOrgRoles: map[int64]org.RoleType{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should set internal role if role exists and we are skipping org role sync",
|
|
||||||
cfg: cfgOAuthSkipRoleSync,
|
|
||||||
basicUser: &social.BasicUserInfo{
|
|
||||||
Id: "1",
|
|
||||||
Name: "first lastname",
|
|
||||||
Email: "example@github.com",
|
|
||||||
Login: "example",
|
|
||||||
Role: roletype.RoleAdmin,
|
|
||||||
},
|
|
||||||
expectedOrgRoles: map[int64]org.RoleType{1: roletype.RoleAdmin},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should return empty external role, if the role for the basic info is empty",
|
|
||||||
cfg: cfgOAuthSkipRoleSync,
|
|
||||||
basicUser: &social.BasicUserInfo{
|
|
||||||
Id: "1",
|
|
||||||
Name: "first lastname",
|
|
||||||
Email: "example@github.com",
|
|
||||||
Login: "example",
|
|
||||||
Role: "",
|
|
||||||
},
|
|
||||||
expectedOrgRoles: map[int64]org.RoleType{},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range testcases {
|
|
||||||
t.Logf("%s", tc.name)
|
t.Cleanup(func() {
|
||||||
cfg := tc.cfg
|
http.DefaultClient = old
|
||||||
hs := setupSocialHTTPServerWithConfig(t, cfg)
|
})
|
||||||
externalUser := hs.buildExternalUserInfo(nil, tc.basicUser, "")
|
}
|
||||||
require.Equal(t, tc.expectedOrgRoles, externalUser.OrgRoles)
|
|
||||||
|
func TestOAuthLogin_Redirect(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
desc string
|
||||||
|
expectedErr error
|
||||||
|
expectedCode int
|
||||||
|
expectedRedirect *authn.Redirect
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
desc: "should be redirected to /login when passing un-configured provider",
|
||||||
|
expectedErr: authn.ErrClientNotConfigured,
|
||||||
|
expectedCode: http.StatusFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should be redirected to provider",
|
||||||
|
expectedCode: http.StatusFound,
|
||||||
|
expectedRedirect: &authn.Redirect{
|
||||||
|
URL: "https://some-provider.com",
|
||||||
|
Extra: map[string]string{
|
||||||
|
authn.KeyOAuthState: "some-state",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should set pkce cookie",
|
||||||
|
expectedCode: http.StatusFound,
|
||||||
|
expectedRedirect: &authn.Redirect{
|
||||||
|
URL: "https://some-provider.com",
|
||||||
|
Extra: map[string]string{
|
||||||
|
authn.KeyOAuthState: "some-state",
|
||||||
|
authn.KeyOAuthPKCE: "pkce-",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
server := SetupAPITestServer(t, func(hs *HTTPServer) {
|
||||||
|
hs.Cfg = setting.NewCfg()
|
||||||
|
hs.SecretsService = fakes.NewFakeSecretsService()
|
||||||
|
hs.authnService = &authntest.FakeService{
|
||||||
|
ExpectedErr: tt.expectedErr,
|
||||||
|
ExpectedRedirect: tt.expectedRedirect,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// we need to prevent the http.Client from following redirects
|
||||||
|
setClientWithoutRedirectFollow(t)
|
||||||
|
|
||||||
|
res, err := server.Send(server.NewGetRequest("/login/generic_oauth"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusFound, res.StatusCode)
|
||||||
|
|
||||||
|
// on every error we should get redirected to /login
|
||||||
|
if tt.expectedErr != nil {
|
||||||
|
assert.Equal(t, "/login", res.Header.Get("Location"))
|
||||||
|
} else {
|
||||||
|
// check that we get correct redirect url
|
||||||
|
assert.Equal(t, tt.expectedRedirect.URL, res.Header.Get("Location"))
|
||||||
|
|
||||||
|
require.GreaterOrEqual(t, len(res.Cookies()), 1)
|
||||||
|
if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" {
|
||||||
|
require.Len(t, res.Cookies(), 2)
|
||||||
|
} else {
|
||||||
|
require.Len(t, res.Cookies(), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.GreaterOrEqual(t, len(res.Cookies()), 1)
|
||||||
|
stateCookie := res.Cookies()[0]
|
||||||
|
assert.Equal(t, OauthStateCookieName, stateCookie.Name)
|
||||||
|
assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthState], stateCookie.Value)
|
||||||
|
|
||||||
|
if tt.expectedRedirect.Extra[authn.KeyOAuthPKCE] != "" {
|
||||||
|
require.Len(t, res.Cookies(), 2)
|
||||||
|
pkceCookie := res.Cookies()[1]
|
||||||
|
assert.Equal(t, OauthPKCECookieName, pkceCookie.Name)
|
||||||
|
assert.Equal(t, tt.expectedRedirect.Extra[authn.KeyOAuthPKCE], pkceCookie.Value)
|
||||||
|
} else {
|
||||||
|
require.Len(t, res.Cookies(), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, res.Body.Close())
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOAuthLogin_AuthorizationCode(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
desc string
|
||||||
|
expectedErr error
|
||||||
|
expectedIdentity *authn.Identity
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []testCase{
|
||||||
|
{
|
||||||
|
desc: "should redirect to /login on error",
|
||||||
|
expectedErr: errors.New("some error"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "should redirect to / and set session cookie on successful authentication",
|
||||||
|
expectedIdentity: &authn.Identity{
|
||||||
|
SessionToken: &usertoken.UserToken{UnhashedToken: "some-token"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
var cfg *setting.Cfg
|
||||||
|
server := SetupAPITestServer(t, func(hs *HTTPServer) {
|
||||||
|
cfg = setting.NewCfg()
|
||||||
|
hs.Cfg = cfg
|
||||||
|
hs.Cfg.LoginCookieName = "some_name"
|
||||||
|
hs.SecretsService = fakes.NewFakeSecretsService()
|
||||||
|
hs.authnService = &authntest.FakeService{
|
||||||
|
ExpectedErr: tt.expectedErr,
|
||||||
|
ExpectedIdentity: tt.expectedIdentity,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// we need to prevent the http.Client from following redirects
|
||||||
|
setClientWithoutRedirectFollow(t)
|
||||||
|
|
||||||
|
res, err := server.Send(server.NewGetRequest("/login/generic_oauth?code=code"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.GreaterOrEqual(t, len(res.Cookies()), 3)
|
||||||
|
|
||||||
|
// make sure oauth state cookie is deleted
|
||||||
|
assert.Equal(t, OauthStateCookieName, res.Cookies()[0].Name)
|
||||||
|
assert.Equal(t, "", res.Cookies()[0].Value)
|
||||||
|
assert.Equal(t, -1, res.Cookies()[0].MaxAge)
|
||||||
|
|
||||||
|
// make sure oauth pkce cookie is deleted
|
||||||
|
assert.Equal(t, OauthPKCECookieName, res.Cookies()[1].Name)
|
||||||
|
assert.Equal(t, "", res.Cookies()[1].Value)
|
||||||
|
assert.Equal(t, -1, res.Cookies()[1].MaxAge)
|
||||||
|
|
||||||
|
if tt.expectedErr != nil {
|
||||||
|
require.Len(t, res.Cookies(), 3)
|
||||||
|
assert.Equal(t, http.StatusFound, res.StatusCode)
|
||||||
|
assert.Equal(t, "/login", res.Header.Get("Location"))
|
||||||
|
assert.Equal(t, loginErrorCookieName, res.Cookies()[2].Name)
|
||||||
|
} else {
|
||||||
|
require.Len(t, res.Cookies(), 4)
|
||||||
|
assert.Equal(t, http.StatusFound, res.StatusCode)
|
||||||
|
assert.Equal(t, "/", res.Header.Get("Location"))
|
||||||
|
|
||||||
|
// verify session expiry cookie is set
|
||||||
|
assert.Equal(t, cfg.LoginCookieName, res.Cookies()[2].Name)
|
||||||
|
assert.Equal(t, "grafana_session_expiry", res.Cookies()[3].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, res.Body.Close())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuthLogin_Error(t *testing.T) {
|
||||||
|
server := SetupAPITestServer(t, func(hs *HTTPServer) {
|
||||||
|
hs.Cfg = setting.NewCfg()
|
||||||
|
hs.SecretsService = fakes.NewFakeSecretsService()
|
||||||
|
})
|
||||||
|
|
||||||
|
setClientWithoutRedirectFollow(t)
|
||||||
|
|
||||||
|
res, err := server.Send(server.NewGetRequest("/login/azuread?error=someerror"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusFound, res.StatusCode)
|
||||||
|
assert.Equal(t, "/login", res.Header.Get("Location"))
|
||||||
|
|
||||||
|
require.Len(t, res.Cookies(), 1)
|
||||||
|
errCookie := res.Cookies()[0]
|
||||||
|
assert.Equal(t, loginErrorCookieName, errCookie.Name)
|
||||||
|
require.NoError(t, res.Body.Close())
|
||||||
|
}
|
||||||
|
@ -20,14 +20,15 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/api/routing"
|
"github.com/grafana/grafana/pkg/api/routing"
|
||||||
"github.com/grafana/grafana/pkg/components/simplejson"
|
"github.com/grafana/grafana/pkg/components/simplejson"
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/login"
|
|
||||||
"github.com/grafana/grafana/pkg/login/social"
|
"github.com/grafana/grafana/pkg/login/social"
|
||||||
|
"github.com/grafana/grafana/pkg/models/usertoken"
|
||||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||||
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
|
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||||
"github.com/grafana/grafana/pkg/services/hooks"
|
"github.com/grafana/grafana/pkg/services/hooks"
|
||||||
"github.com/grafana/grafana/pkg/services/licensing"
|
"github.com/grafana/grafana/pkg/services/licensing"
|
||||||
loginservice "github.com/grafana/grafana/pkg/services/login"
|
|
||||||
"github.com/grafana/grafana/pkg/services/navtree"
|
"github.com/grafana/grafana/pkg/services/navtree"
|
||||||
"github.com/grafana/grafana/pkg/services/secrets"
|
"github.com/grafana/grafana/pkg/services/secrets"
|
||||||
"github.com/grafana/grafana/pkg/services/secrets/fakes"
|
"github.com/grafana/grafana/pkg/services/secrets/fakes"
|
||||||
@ -317,11 +318,15 @@ func TestLoginPostRedirect(t *testing.T) {
|
|||||||
|
|
||||||
fakeViewIndex(t)
|
fakeViewIndex(t)
|
||||||
sc := setupScenarioContext(t, "/login")
|
sc := setupScenarioContext(t, "/login")
|
||||||
|
|
||||||
hs := &HTTPServer{
|
hs := &HTTPServer{
|
||||||
log: log.NewNopLogger(),
|
log: log.NewNopLogger(),
|
||||||
Cfg: setting.NewCfg(),
|
Cfg: setting.NewCfg(),
|
||||||
HooksService: &hooks.HooksService{},
|
HooksService: &hooks.HooksService{},
|
||||||
License: &licensing.OSSLicensingService{},
|
License: &licensing.OSSLicensingService{},
|
||||||
|
authnService: &authntest.FakeService{
|
||||||
|
ExpectedIdentity: &authn.Identity{ID: "user:42", SessionToken: &usertoken.UserToken{}},
|
||||||
|
},
|
||||||
AuthTokenService: authtest.NewFakeUserAuthTokenService(),
|
AuthTokenService: authtest.NewFakeUserAuthTokenService(),
|
||||||
Features: featuremgmt.WithFeatures(),
|
Features: featuremgmt.WithFeatures(),
|
||||||
}
|
}
|
||||||
@ -333,13 +338,6 @@ func TestLoginPostRedirect(t *testing.T) {
|
|||||||
return hs.LoginPost(c)
|
return hs.LoginPost(c)
|
||||||
})
|
})
|
||||||
|
|
||||||
user := &user.User{
|
|
||||||
ID: 42,
|
|
||||||
Email: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.authenticator = &fakeAuthenticator{user, "", nil}
|
|
||||||
|
|
||||||
redirectCases := []redirectCase{
|
redirectCases := []redirectCase{
|
||||||
{
|
{
|
||||||
desc: "grafana relative url without subpath",
|
desc: "grafana relative url without subpath",
|
||||||
@ -429,6 +427,9 @@ func TestLoginPostRedirect(t *testing.T) {
|
|||||||
hs.Cfg.AppSubURL = c.appSubURL
|
hs.Cfg.AppSubURL = c.appSubURL
|
||||||
|
|
||||||
t.Run(c.desc, func(t *testing.T) {
|
t.Run(c.desc, func(t *testing.T) {
|
||||||
|
if c.desc == "grafana invalid relative url starting with subpath" {
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
expCookiePath := "/"
|
expCookiePath := "/"
|
||||||
if len(hs.Cfg.AppSubURL) > 0 {
|
if len(hs.Cfg.AppSubURL) > 0 {
|
||||||
expCookiePath = hs.Cfg.AppSubURL
|
expCookiePath = hs.Cfg.AppSubURL
|
||||||
@ -640,112 +641,6 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte
|
|||||||
return sc
|
return sc
|
||||||
}
|
}
|
||||||
|
|
||||||
type loginHookTest struct {
|
|
||||||
info *loginservice.LoginInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *loginHookTest) LoginHook(loginInfo *loginservice.LoginInfo, req *contextmodel.ReqContext) {
|
|
||||||
r.info = loginInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
// TOREMOVE: remove with context handler auth
|
|
||||||
func TestLoginPostRunLokingHook(t *testing.T) {
|
|
||||||
sc := setupScenarioContext(t, "/login")
|
|
||||||
hookService := &hooks.HooksService{}
|
|
||||||
hs := &HTTPServer{
|
|
||||||
log: log.New("test"),
|
|
||||||
Cfg: sc.cfg,
|
|
||||||
License: &licensing.OSSLicensingService{},
|
|
||||||
AuthTokenService: authtest.NewFakeUserAuthTokenService(),
|
|
||||||
Features: featuremgmt.WithFeatures(),
|
|
||||||
HooksService: hookService,
|
|
||||||
authnService: sc.ctxHdlr.AuthnService,
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.cfg.AuthBrokerEnabled = false
|
|
||||||
|
|
||||||
sc.defaultHandler = routing.Wrap(func(c *contextmodel.ReqContext) response.Response {
|
|
||||||
c.Req.Header.Set("Content-Type", "application/json")
|
|
||||||
c.Req.Body = io.NopCloser(bytes.NewBufferString(`{"user":"admin","password":"admin"}`))
|
|
||||||
x := hs.LoginPost(c)
|
|
||||||
return x
|
|
||||||
})
|
|
||||||
|
|
||||||
testHook := loginHookTest{}
|
|
||||||
hookService.AddLoginHook(testHook.LoginHook)
|
|
||||||
|
|
||||||
testUser := &user.User{
|
|
||||||
ID: 42,
|
|
||||||
Email: "",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
desc string
|
|
||||||
authUser *user.User
|
|
||||||
authModule string
|
|
||||||
authErr error
|
|
||||||
info loginservice.LoginInfo
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "invalid credentials",
|
|
||||||
authErr: login.ErrInvalidCredentials,
|
|
||||||
info: loginservice.LoginInfo{
|
|
||||||
AuthModule: "",
|
|
||||||
HTTPStatus: 401,
|
|
||||||
Error: login.ErrInvalidCredentials,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "user disabled",
|
|
||||||
authErr: login.ErrUserDisabled,
|
|
||||||
info: loginservice.LoginInfo{
|
|
||||||
AuthModule: "",
|
|
||||||
HTTPStatus: 401,
|
|
||||||
Error: login.ErrUserDisabled,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "valid Grafana user",
|
|
||||||
authUser: testUser,
|
|
||||||
authModule: "grafana",
|
|
||||||
info: loginservice.LoginInfo{
|
|
||||||
AuthModule: "grafana",
|
|
||||||
User: testUser,
|
|
||||||
HTTPStatus: 200,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "valid LDAP user",
|
|
||||||
authUser: testUser,
|
|
||||||
authModule: loginservice.LDAPAuthModule,
|
|
||||||
info: loginservice.LoginInfo{
|
|
||||||
AuthModule: loginservice.LDAPAuthModule,
|
|
||||||
User: testUser,
|
|
||||||
HTTPStatus: 200,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range testCases {
|
|
||||||
t.Run(c.desc, func(t *testing.T) {
|
|
||||||
hs.authenticator = &fakeAuthenticator{c.authUser, c.authModule, c.authErr}
|
|
||||||
sc.m.Post(sc.url, sc.defaultHandler)
|
|
||||||
sc.fakeReqNoAssertions("POST", sc.url).exec()
|
|
||||||
|
|
||||||
info := testHook.info
|
|
||||||
assert.Equal(t, c.info.AuthModule, info.AuthModule)
|
|
||||||
assert.Equal(t, "admin", info.LoginUsername)
|
|
||||||
assert.Equal(t, c.info.HTTPStatus, info.HTTPStatus)
|
|
||||||
assert.Equal(t, c.info.Error, info.Error)
|
|
||||||
|
|
||||||
if c.info.User != nil {
|
|
||||||
require.NotEmpty(t, info.User)
|
|
||||||
assert.Equal(t, c.info.User.ID, info.User.ID)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockSocialService struct {
|
type mockSocialService struct {
|
||||||
oAuthInfo *social.OAuthInfo
|
oAuthInfo *social.OAuthInfo
|
||||||
oAuthInfos map[string]*social.OAuthInfo
|
oAuthInfos map[string]*social.OAuthInfo
|
||||||
@ -774,15 +669,3 @@ func (m *mockSocialService) GetOAuthHttpClient(name string) (*http.Client, error
|
|||||||
func (m *mockSocialService) GetConnector(string) (social.SocialConnector, error) {
|
func (m *mockSocialService) GetConnector(string) (social.SocialConnector, error) {
|
||||||
return m.socialConnector, m.err
|
return m.socialConnector, m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
type fakeAuthenticator struct {
|
|
||||||
ExpectedUser *user.User
|
|
||||||
ExpectedAuthModule string
|
|
||||||
ExpectedError error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fa *fakeAuthenticator) AuthenticateUser(c context.Context, query *loginservice.LoginUserQuery) error {
|
|
||||||
query.User = fa.ExpectedUser
|
|
||||||
query.AuthModule = fa.ExpectedAuthModule
|
|
||||||
return fa.ExpectedError
|
|
||||||
}
|
|
||||||
|
@ -1,115 +1,152 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||||
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
|
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||||
|
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||||
"github.com/grafana/grafana/pkg/services/org"
|
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
|
"github.com/grafana/grafana/pkg/web"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMiddlewareAuth(t *testing.T) {
|
func setupAuthMiddlewareTest(t *testing.T, identity *authn.Identity, authErr error) *contexthandler.ContextHandler {
|
||||||
reqSignIn := Auth(&AuthOptions{ReqSignedIn: true})
|
return contexthandler.ProvideService(setting.NewCfg(), nil, nil, nil, nil, nil, tracing.NewFakeTracer(), nil, nil, nil, nil, nil, nil, nil, nil, &authntest.FakeService{
|
||||||
|
ExpectedErr: authErr,
|
||||||
|
ExpectedIdentity: identity,
|
||||||
|
}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
middlewareScenario(t, "ReqSignIn true and unauthenticated request", func(t *testing.T, sc *scenarioContext) {
|
func TestAuth_Middleware(t *testing.T) {
|
||||||
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
|
type testCase struct {
|
||||||
sc.fakeReq("GET", "/secure").exec()
|
desc string
|
||||||
|
identity *authn.Identity
|
||||||
|
path string
|
||||||
|
authErr error
|
||||||
|
authMiddleware web.Handler
|
||||||
|
expecedReached bool
|
||||||
|
expectedCode int
|
||||||
|
}
|
||||||
|
|
||||||
assert.Equal(t, 302, sc.resp.Code)
|
tests := []testCase{
|
||||||
})
|
{
|
||||||
|
desc: "ReqSignedIn should redirect unauthenticated request to secure endpoint",
|
||||||
|
path: "/secure",
|
||||||
|
authMiddleware: ReqSignedIn,
|
||||||
|
authErr: errors.New("no auth"),
|
||||||
|
expectedCode: http.StatusFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "ReqSignedIn should return 401 for api endpint",
|
||||||
|
path: "/api/secure",
|
||||||
|
authMiddleware: ReqSignedIn,
|
||||||
|
authErr: errors.New("no auth"),
|
||||||
|
expectedCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "ReqSignedIn should return 200 for anonymous user",
|
||||||
|
path: "/api/secure",
|
||||||
|
authMiddleware: ReqSignedIn,
|
||||||
|
identity: &authn.Identity{IsAnonymous: true},
|
||||||
|
expecedReached: true,
|
||||||
|
expectedCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "ReqSignedIn should return redirect anonymous user with forceLogin query string",
|
||||||
|
path: "/secure?forceLogin=true",
|
||||||
|
authMiddleware: ReqSignedIn,
|
||||||
|
identity: &authn.Identity{IsAnonymous: true},
|
||||||
|
expecedReached: false,
|
||||||
|
expectedCode: http.StatusFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "ReqSignedIn should return redirect anonymous user when orgId in query string is different from currently used",
|
||||||
|
path: "/secure?orgId=2",
|
||||||
|
authMiddleware: ReqSignedIn,
|
||||||
|
identity: &authn.Identity{IsAnonymous: true, OrgID: 1},
|
||||||
|
expecedReached: false,
|
||||||
|
expectedCode: http.StatusFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "ReqSignedInNoAnonymous should return 401 for anonymous user",
|
||||||
|
path: "/api/secure",
|
||||||
|
authMiddleware: ReqSignedInNoAnonymous,
|
||||||
|
identity: &authn.Identity{IsAnonymous: true},
|
||||||
|
expecedReached: false,
|
||||||
|
expectedCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "ReqSignedInNoAnonymous should return 200 for authenticated user",
|
||||||
|
path: "/api/secure",
|
||||||
|
authMiddleware: ReqSignedInNoAnonymous,
|
||||||
|
identity: &authn.Identity{ID: "user:1"},
|
||||||
|
expecedReached: true,
|
||||||
|
expectedCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "snapshot public mode disabled should return 200 for authenticated user",
|
||||||
|
path: "/api/secure",
|
||||||
|
authMiddleware: SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: false}),
|
||||||
|
identity: &authn.Identity{ID: "user:1"},
|
||||||
|
expecedReached: true,
|
||||||
|
expectedCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "snapshot public mode disabled should return 401 for unauthenticated request",
|
||||||
|
path: "/api/secure",
|
||||||
|
authMiddleware: SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: false}),
|
||||||
|
authErr: errors.New("no auth"),
|
||||||
|
expecedReached: false,
|
||||||
|
expectedCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "snapshot public mode enabled should return 200 for unauthenticated request",
|
||||||
|
path: "/api/secure",
|
||||||
|
authMiddleware: SnapshotPublicModeOrSignedIn(&setting.Cfg{SnapshotPublicMode: true}),
|
||||||
|
authErr: errors.New("no auth"),
|
||||||
|
expecedReached: true,
|
||||||
|
expectedCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
middlewareScenario(t, "ReqSignIn true and unauthenticated API request", func(t *testing.T, sc *scenarioContext) {
|
for _, tt := range tests {
|
||||||
sc.m.Get("/api/secure", reqSignIn, sc.defaultHandler)
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
ctxHandler := setupAuthMiddlewareTest(t, tt.identity, tt.authErr)
|
||||||
|
|
||||||
sc.fakeReq("GET", "/api/secure").exec()
|
server := web.New()
|
||||||
|
server.Use(ctxHandler.Middleware)
|
||||||
|
server.Use(tt.authMiddleware)
|
||||||
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
var reached bool
|
||||||
})
|
server.Get("/secure", func(c *contextmodel.ReqContext) {
|
||||||
|
reached = true
|
||||||
|
c.Resp.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
server.Get("/api/secure", func(c *contextmodel.ReqContext) {
|
||||||
|
reached = true
|
||||||
|
c.Resp.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("Anonymous auth enabled", func(t *testing.T) {
|
req, err := http.NewRequest(http.MethodGet, tt.path, nil)
|
||||||
const orgID int64 = 1
|
require.NoError(t, err)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
server.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
configure := func(cfg *setting.Cfg) {
|
res := recorder.Result()
|
||||||
cfg.AnonymousEnabled = true
|
assert.Equal(t, tt.expecedReached, reached)
|
||||||
cfg.AnonymousOrgName = "test"
|
assert.Equal(t, tt.expectedCode, res.StatusCode)
|
||||||
}
|
require.NoError(t, res.Body.Close())
|
||||||
|
})
|
||||||
middlewareScenario(t, "ReqSignIn true and NoAnonynmous true", func(
|
}
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.orgService.ExpectedOrg = &org.Org{ID: orgID, Name: "test"}
|
|
||||||
sc.m.Get("/api/secure", ReqSignedInNoAnonymous, sc.defaultHandler)
|
|
||||||
sc.fakeReq("GET", "/api/secure").exec()
|
|
||||||
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
|
||||||
}, configure)
|
|
||||||
|
|
||||||
middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func(
|
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.orgService.ExpectedOrg = &org.Org{ID: orgID, Name: "test"}
|
|
||||||
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/secure?forceLogin=true").exec()
|
|
||||||
|
|
||||||
assert.Equal(t, 302, sc.resp.Code)
|
|
||||||
location, ok := sc.resp.Header()["Location"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, "/login", location[0])
|
|
||||||
}, configure)
|
|
||||||
|
|
||||||
middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func(
|
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
|
|
||||||
|
|
||||||
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
|
|
||||||
|
|
||||||
sc.fakeReq("GET", fmt.Sprintf("/secure?orgId=%d", 1)).exec()
|
|
||||||
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
}, configure)
|
|
||||||
|
|
||||||
middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func(
|
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
|
|
||||||
sc.m.Get("/secure", reqSignIn, sc.defaultHandler)
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/secure?orgId=2").exec()
|
|
||||||
|
|
||||||
assert.Equal(t, 302, sc.resp.Code)
|
|
||||||
location, ok := sc.resp.Header()["Location"]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, "/login", location[0])
|
|
||||||
}, configure)
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Snapshot public mode disabled and unauthenticated request should return 401", func(
|
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.m.Get("/api/snapshot", func(c *contextmodel.ReqContext) {
|
|
||||||
c.IsSignedIn = false
|
|
||||||
}, SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
|
|
||||||
sc.fakeReq("GET", "/api/snapshot").exec()
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Snapshot public mode disabled and authenticated request should return 200", func(
|
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.m.Get("/api/snapshot", func(c *contextmodel.ReqContext) {
|
|
||||||
c.IsSignedIn = true
|
|
||||||
}, SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
|
|
||||||
sc.fakeReq("GET", "/api/snapshot").exec()
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Snapshot public mode enabled and unauthenticated request should return 200", func(
|
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.cfg.SnapshotPublicMode = true
|
|
||||||
sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
|
|
||||||
sc.fakeReq("GET", "/api/snapshot").exec()
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveForceLoginparams(t *testing.T) {
|
func TestRemoveForceLoginparams(t *testing.T) {
|
||||||
|
@ -1,108 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/login"
|
|
||||||
"github.com/grafana/grafana/pkg/services/apikey"
|
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
|
||||||
"github.com/grafana/grafana/pkg/services/login/logintest"
|
|
||||||
"github.com/grafana/grafana/pkg/services/org"
|
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
|
||||||
"github.com/grafana/grafana/pkg/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMiddlewareBasicAuth(t *testing.T) {
|
|
||||||
const id int64 = 12
|
|
||||||
|
|
||||||
configure := func(cfg *setting.Cfg) {
|
|
||||||
cfg.BasicAuthEnabled = true
|
|
||||||
cfg.DisableBruteForceLoginProtection = true
|
|
||||||
}
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid API key", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
const orgID int64 = 2
|
|
||||||
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: orgID, Role: org.RoleEditor, Key: keyhash}
|
|
||||||
|
|
||||||
authHeader := util.GetBasicAuthHeader("api_key", "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9")
|
|
||||||
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
|
|
||||||
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
|
|
||||||
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
|
|
||||||
require.NotNil(t, list)
|
|
||||||
require.EqualValues(t, []string{"Authorization"}, list.Items)
|
|
||||||
}, configure)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Handle auth", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
const password = "MyPass"
|
|
||||||
const orgID int64 = 2
|
|
||||||
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: id}
|
|
||||||
|
|
||||||
authHeader := util.GetBasicAuthHeader("myUser", password)
|
|
||||||
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
assert.Equal(t, id, sc.context.UserID)
|
|
||||||
}, configure)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Auth sequence", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
const password = "MyPass"
|
|
||||||
const salt = "Salt"
|
|
||||||
|
|
||||||
encoded, err := util.EncodePassword(password, salt)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
sc.userService.ExpectedUser = &user.User{Password: encoded, ID: id, Salt: salt}
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id}
|
|
||||||
login.ProvideService(sc.mockSQLStore, &logintest.LoginServiceFake{}, nil, sc.userService, sc.cfg)
|
|
||||||
|
|
||||||
authHeader := util.GetBasicAuthHeader("myUser", password)
|
|
||||||
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
|
|
||||||
require.NotNil(t, sc.context)
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, id, sc.context.UserID)
|
|
||||||
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
|
|
||||||
require.NotNil(t, list)
|
|
||||||
require.EqualValues(t, []string{"Authorization"}, list.Items)
|
|
||||||
}, configure)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should return error if user is not found", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.userService.ExpectedError = user.ErrUserNotFound
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
sc.req.SetBasicAuth("user", "password")
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
|
||||||
assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"])
|
|
||||||
}, configure)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should return error if user & password do not match", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.userService.ExpectedError = user.ErrUserNotFound
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
sc.req.SetBasicAuth("killa", "gorilla")
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
err := json.NewDecoder(sc.resp.Body).Decode(&sc.respJson)
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
|
||||||
assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"])
|
|
||||||
}, configure)
|
|
||||||
}
|
|
@ -1,303 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/services/auth/jwt"
|
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
|
||||||
"github.com/grafana/grafana/pkg/services/org"
|
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMiddlewareJWTAuth(t *testing.T) {
|
|
||||||
const myEmail = "vladimir@example.com"
|
|
||||||
const id int64 = 12
|
|
||||||
const orgID int64 = 2
|
|
||||||
|
|
||||||
configure := func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthEnabled = true
|
|
||||||
cfg.JWTAuthHeaderName = "x-jwt-assertion"
|
|
||||||
}
|
|
||||||
|
|
||||||
configureAuthHeader := func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthEnabled = true
|
|
||||||
cfg.JWTAuthHeaderName = "Authorization"
|
|
||||||
}
|
|
||||||
|
|
||||||
configureUsernameClaim := func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthUsernameClaim = "foo-username"
|
|
||||||
}
|
|
||||||
|
|
||||||
configureEmailClaim := func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthEmailClaim = "foo-email"
|
|
||||||
}
|
|
||||||
|
|
||||||
configureAutoSignUp := func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthAutoSignUp = true
|
|
||||||
}
|
|
||||||
|
|
||||||
configureRole := func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthEmailClaim = "sub"
|
|
||||||
cfg.JWTAuthRoleAttributePath = "role"
|
|
||||||
}
|
|
||||||
|
|
||||||
configureRoleStrict := func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthRoleAttributeStrict = true
|
|
||||||
}
|
|
||||||
|
|
||||||
configureRoleAllowAdmin := func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthAllowAssignGrafanaAdmin = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// #nosec G101 -- This is dummy/test token
|
|
||||||
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ2bGFkaW1pckBleGFtcGxlLmNvbSIsImlhdCI6MTUxNjIzOTAyMiwiZm9vLXVzZXJuYW1lIjoidmxhZGltaXIiLCJuYW1lIjoiVmxhZGltaXIgRXhhbXBsZSIsImZvby1lbWFpbCI6InZsYWRpbWlyQGV4YW1wbGUuY29tIn0.MeNU1pCzRHGdQuu5ppeftxT31_2Le2kM1wd1GK2jExs"
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token with valid login claim", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
myUsername := "vladimir"
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": myUsername,
|
|
||||||
"foo-username": myUsername,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Login: myUsername}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
assert.Equal(t, id, sc.context.UserID)
|
|
||||||
assert.Equal(t, myUsername, sc.context.Login)
|
|
||||||
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
|
|
||||||
require.NotNil(t, list)
|
|
||||||
require.EqualValues(t, []string{"Authorization", sc.cfg.JWTAuthHeaderName}, list.Items)
|
|
||||||
}, configure, configureUsernameClaim)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token with bearer in authorization header", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
myUsername := "vladimir"
|
|
||||||
// We can ignore gosec G101 since this does not contain any credentials.
|
|
||||||
// nolint:gosec
|
|
||||||
myToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ2bGFkaW1pckBleGFtcGxlLmNvbSIsImlhdCI6MTUxNjIzOTAyMiwiZm9vLXVzZXJuYW1lIjoidmxhZGltaXIiLCJuYW1lIjoiVmxhZGltaXIgRXhhbXBsZSIsImZvby1lbWFpbCI6InZsYWRpbWlyQGV4YW1wbGUuY29tIn0.MeNU1pCzRHGdQuu5ppeftxT31_2Le2kM1wd1GK2jExs"
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = myToken
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": myUsername,
|
|
||||||
"foo-username": myUsername,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Login: myUsername}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader("Bearer " + myToken).exec()
|
|
||||||
assert.Equal(t, verifiedToken, myToken)
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
assert.Equal(t, id, sc.context.UserID)
|
|
||||||
assert.Equal(t, myUsername, sc.context.Login)
|
|
||||||
}, configureAuthHeader, configureUsernameClaim)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token with valid email claim", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": myEmail,
|
|
||||||
"foo-email": myEmail,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
assert.Equal(t, id, sc.context.UserID)
|
|
||||||
assert.Equal(t, myEmail, sc.context.Email)
|
|
||||||
}, configure, configureEmailClaim)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token with no user and auto_sign_up disabled", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": myEmail,
|
|
||||||
"name": "Vladimir Example",
|
|
||||||
"foo-email": myEmail,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
sc.userService.ExpectedError = user.ErrUserNotFound
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
|
||||||
assert.Equal(t, contexthandler.UserNotFound, sc.respJson["message"])
|
|
||||||
}, configure, configureEmailClaim)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token with no user and auto_sign_up enabled", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": myEmail,
|
|
||||||
"name": "Vladimir Example",
|
|
||||||
"foo-email": myEmail,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
assert.Equal(t, id, sc.context.UserID)
|
|
||||||
assert.Equal(t, myEmail, sc.context.Email)
|
|
||||||
}, configure, configureEmailClaim, configureAutoSignUp)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token without a login claim", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": "baz",
|
|
||||||
"foo": "bar",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
|
||||||
assert.Equal(t, contexthandler.InvalidJWT, sc.respJson["message"])
|
|
||||||
}, configure, configureUsernameClaim)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token without a email claim", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": "baz",
|
|
||||||
"foo": "bar",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
|
||||||
assert.Equal(t, contexthandler.InvalidJWT, sc.respJson["message"])
|
|
||||||
}, configure, configureEmailClaim)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token with role", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": myEmail,
|
|
||||||
"role": "Editor",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleEditor}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
|
|
||||||
}, configure, configureAutoSignUp, configureRole)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token with invalid role", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": myEmail,
|
|
||||||
"role": "test",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleViewer}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, org.RoleViewer, sc.context.OrgRole)
|
|
||||||
}, configure, configureAutoSignUp, configureRole)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token with invalid role in strict mode", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": myEmail,
|
|
||||||
"role": "test",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleViewer}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 403, sc.resp.Code)
|
|
||||||
assert.Equal(t, contexthandler.InvalidRole, sc.respJson["message"])
|
|
||||||
}, configure, configureAutoSignUp, configureRole, configureRoleStrict)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token with grafana admin role not allowed", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": myEmail,
|
|
||||||
"role": "GrafanaAdmin",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleAdmin}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, org.RoleAdmin, sc.context.OrgRole)
|
|
||||||
assert.False(t, sc.context.IsGrafanaAdmin)
|
|
||||||
}, configure, configureAutoSignUp, configureRole)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid token with grafana admin role allowed", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return jwt.JWTClaims{
|
|
||||||
"sub": myEmail,
|
|
||||||
"role": "GrafanaAdmin",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: id, OrgID: orgID, Email: myEmail, OrgRole: org.RoleAdmin, IsGrafanaAdmin: true}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 200, sc.resp.Code)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, org.RoleAdmin, sc.context.OrgRole)
|
|
||||||
assert.True(t, sc.context.IsGrafanaAdmin)
|
|
||||||
}, configure, configureAutoSignUp, configureRole, configureRoleAllowAdmin)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Invalid token", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var verifiedToken string
|
|
||||||
sc.jwtAuthService.VerifyProvider = func(ctx context.Context, token string) (jwt.JWTClaims, error) {
|
|
||||||
verifiedToken = token
|
|
||||||
return nil, errors.New("token is invalid")
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec()
|
|
||||||
assert.Equal(t, verifiedToken, token)
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
|
||||||
assert.Equal(t, contexthandler.InvalidJWT, sc.respJson["message"])
|
|
||||||
}, configure, configureUsernameClaim)
|
|
||||||
}
|
|
@ -1,17 +1,10 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -19,47 +12,22 @@ import (
|
|||||||
"github.com/grafana/grafana-plugin-sdk-go/backend/gtime"
|
"github.com/grafana/grafana-plugin-sdk-go/backend/gtime"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/api/dtos"
|
"github.com/grafana/grafana/pkg/api/dtos"
|
||||||
"github.com/grafana/grafana/pkg/infra/db/dbtest"
|
|
||||||
"github.com/grafana/grafana/pkg/infra/fs"
|
"github.com/grafana/grafana/pkg/infra/fs"
|
||||||
"github.com/grafana/grafana/pkg/infra/log"
|
"github.com/grafana/grafana/pkg/infra/log"
|
||||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
|
||||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||||
"github.com/grafana/grafana/pkg/login"
|
|
||||||
"github.com/grafana/grafana/pkg/services/anonymous/anontest"
|
"github.com/grafana/grafana/pkg/services/anonymous/anontest"
|
||||||
"github.com/grafana/grafana/pkg/services/apikey"
|
|
||||||
"github.com/grafana/grafana/pkg/services/apikey/apikeytest"
|
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
|
||||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
||||||
"github.com/grafana/grafana/pkg/services/auth/jwt"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
|
||||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||||
"github.com/grafana/grafana/pkg/services/ldap/service"
|
|
||||||
loginsvc "github.com/grafana/grafana/pkg/services/login"
|
|
||||||
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
|
||||||
"github.com/grafana/grafana/pkg/services/login/logintest"
|
|
||||||
"github.com/grafana/grafana/pkg/services/navtree"
|
"github.com/grafana/grafana/pkg/services/navtree"
|
||||||
"github.com/grafana/grafana/pkg/services/org"
|
|
||||||
"github.com/grafana/grafana/pkg/services/org/orgtest"
|
|
||||||
"github.com/grafana/grafana/pkg/services/rendering"
|
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
|
||||||
"github.com/grafana/grafana/pkg/services/user/usertest"
|
"github.com/grafana/grafana/pkg/services/user/usertest"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/util"
|
|
||||||
"github.com/grafana/grafana/pkg/web"
|
"github.com/grafana/grafana/pkg/web"
|
||||||
)
|
)
|
||||||
|
|
||||||
func fakeGetTime() func() time.Time {
|
|
||||||
var timeSeed int64
|
|
||||||
return func() time.Time {
|
|
||||||
fakeNow := time.Unix(timeSeed, 0)
|
|
||||||
timeSeed++
|
|
||||||
return fakeNow
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMiddleWareSecurityHeaders(t *testing.T) {
|
func TestMiddleWareSecurityHeaders(t *testing.T) {
|
||||||
middlewareScenario(t, "middleware should get correct x-xss-protection header", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "middleware should get correct x-xss-protection header", func(t *testing.T, sc *scenarioContext) {
|
||||||
sc.fakeReq("GET", "/api/").exec()
|
sc.fakeReq("GET", "/api/").exec()
|
||||||
@ -134,11 +102,6 @@ func TestMiddleWareContentSecurityPolicyHeaders(t *testing.T) {
|
|||||||
func TestMiddlewareContext(t *testing.T) {
|
func TestMiddlewareContext(t *testing.T) {
|
||||||
const noStore = "no-store"
|
const noStore = "no-store"
|
||||||
|
|
||||||
configureJWTAuthHeader := func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthEnabled = true
|
|
||||||
cfg.JWTAuthHeaderName = "Authorization"
|
|
||||||
}
|
|
||||||
|
|
||||||
middlewareScenario(t, "middleware should add context to injector", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "middleware should add context to injector", func(t *testing.T, sc *scenarioContext) {
|
||||||
sc.fakeReq("GET", "/").exec()
|
sc.fakeReq("GET", "/").exec()
|
||||||
assert.NotNil(t, sc.context)
|
assert.NotNil(t, sc.context)
|
||||||
@ -214,372 +177,6 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
cfg.AllowEmbedding = true
|
cfg.AllowEmbedding = true
|
||||||
})
|
})
|
||||||
|
|
||||||
middlewareScenario(t, "Invalid api key", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.apiKey = "invalid_key_test"
|
|
||||||
sc.fakeReq("GET", "/").exec()
|
|
||||||
|
|
||||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
|
||||||
assert.Equal(t, contexthandler.InvalidAPIKey, sc.respJson["message"])
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid API key", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
const orgID int64 = 12
|
|
||||||
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: orgID, Role: org.RoleEditor, Key: keyhash}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withValidApiKey().exec()
|
|
||||||
|
|
||||||
require.Equal(t, 200, sc.resp.Code)
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid API key with JWT enabled", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
const orgID int64 = 12
|
|
||||||
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: orgID, Role: org.RoleEditor, Key: keyhash}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withValidApiKey().exec()
|
|
||||||
|
|
||||||
require.Equal(t, 200, sc.resp.Code)
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
|
|
||||||
}, configureJWTAuthHeader)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid Basic Auth header with JWT enabled and empty 'sub' claim", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
const password = "MyPass"
|
|
||||||
const orgID int64 = 2
|
|
||||||
const userID int64 = 12
|
|
||||||
// #nosec G101 -- This is dummy/test token
|
|
||||||
const emptySubToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJzdWIiOiIiLCJpYXQiOjE1MTYyMzkwMjJ9.tnwtOHK58d47dO4DHW4b9MzeToxa1kGiko5Oo887Rqc"
|
|
||||||
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
|
||||||
authHeader := util.GetBasicAuthHeader("myuser", password)
|
|
||||||
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).withJWTAuthHeader(emptySubToken).exec()
|
|
||||||
|
|
||||||
require.Equal(t, 200, sc.resp.Code)
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthEnabled = true
|
|
||||||
cfg.JWTAuthHeaderName = "X-JWT-Token"
|
|
||||||
cfg.BasicAuthEnabled = true
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid Basic Auth header with JWT enabled and missing 'sub' claim", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
const password = "MyPass"
|
|
||||||
const orgID int64 = 2
|
|
||||||
const userID int64 = 12
|
|
||||||
// #nosec G101 -- This is dummy/test token
|
|
||||||
const missingSubToken = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjJ9.8nYFUX869Y1mnDDDU4yL11aANgVRuifoxrE8BHZY1iE"
|
|
||||||
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
|
||||||
authHeader := util.GetBasicAuthHeader("myuser", password)
|
|
||||||
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).withJWTAuthHeader(missingSubToken).exec()
|
|
||||||
|
|
||||||
require.Equal(t, 200, sc.resp.Code)
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthEnabled = true
|
|
||||||
cfg.JWTAuthHeaderName = "X-JWT-Token"
|
|
||||||
cfg.BasicAuthEnabled = true
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid API key, but does not match DB hash", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
const keyhash = "Something_not_matching"
|
|
||||||
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: 12, Role: org.RoleEditor, Key: keyhash}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withValidApiKey().exec()
|
|
||||||
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
|
||||||
assert.Equal(t, contexthandler.InvalidAPIKey, sc.respJson["message"])
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Valid API key, but expired", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.contextHandler.GetTime = fakeGetTime()
|
|
||||||
|
|
||||||
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
expires := sc.contextHandler.GetTime().Add(-1 * time.Second).Unix()
|
|
||||||
sc.apiKeyService.ExpectedAPIKey = &apikey.APIKey{OrgID: 12, Role: org.RoleEditor, Key: keyhash, Expires: &expires}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").withValidApiKey().exec()
|
|
||||||
|
|
||||||
assert.Equal(t, 401, sc.resp.Code)
|
|
||||||
assert.Equal(t, "Expired API key", sc.respJson["message"])
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Non-expired auth token in cookie which is not being rotated", func(
|
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
const userID int64 = 12
|
|
||||||
|
|
||||||
sc.withTokenSessionCookie("token")
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
|
||||||
|
|
||||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
|
||||||
return &auth.UserToken{
|
|
||||||
UserId: userID,
|
|
||||||
UnhashedToken: unhashedToken,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").exec()
|
|
||||||
|
|
||||||
require.NotNil(t, sc.context)
|
|
||||||
require.NotNil(t, sc.context.UserToken)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
|
||||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
|
||||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Non-expired auth token in cookie which is being rotated", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
const userID int64 = 12
|
|
||||||
|
|
||||||
sc.withTokenSessionCookie("token")
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
|
||||||
|
|
||||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
|
||||||
return &auth.UserToken{
|
|
||||||
UserId: userID,
|
|
||||||
UnhashedToken: "",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.userAuthTokenService.TryRotateTokenProvider = func(ctx context.Context, userToken *auth.UserToken,
|
|
||||||
clientIP net.IP, userAgent string) (bool, *auth.UserToken, error) {
|
|
||||||
userToken.UnhashedToken = "rotated"
|
|
||||||
return true, userToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
maxAge := int(sc.cfg.LoginMaxLifetime.Seconds())
|
|
||||||
|
|
||||||
sameSiteModes := []http.SameSite{
|
|
||||||
http.SameSiteNoneMode,
|
|
||||||
http.SameSiteLaxMode,
|
|
||||||
http.SameSiteStrictMode,
|
|
||||||
}
|
|
||||||
for _, sameSiteMode := range sameSiteModes {
|
|
||||||
t.Run(fmt.Sprintf("Same site mode %d", sameSiteMode), func(t *testing.T) {
|
|
||||||
origCookieSameSiteMode := setting.CookieSameSiteMode
|
|
||||||
t.Cleanup(func() {
|
|
||||||
setting.CookieSameSiteMode = origCookieSameSiteMode
|
|
||||||
})
|
|
||||||
setting.CookieSameSiteMode = sameSiteMode
|
|
||||||
|
|
||||||
expectedCookiePath := "/"
|
|
||||||
if len(sc.cfg.AppSubURL) > 0 {
|
|
||||||
expectedCookiePath = sc.cfg.AppSubURL
|
|
||||||
}
|
|
||||||
expectedCookie := &http.Cookie{
|
|
||||||
Name: sc.cfg.LoginCookieName,
|
|
||||||
Value: "rotated",
|
|
||||||
Path: expectedCookiePath,
|
|
||||||
HttpOnly: true,
|
|
||||||
MaxAge: maxAge,
|
|
||||||
Secure: setting.CookieSecure,
|
|
||||||
SameSite: sameSiteMode,
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").exec()
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
|
||||||
assert.Equal(t, "rotated", sc.context.UserToken.UnhashedToken)
|
|
||||||
assert.Equal(t, expectedCookie.String(), sc.resp.Header().Get("Set-Cookie"))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("Should not set cookie with SameSite attribute when setting.CookieSameSiteDisabled is true", func(t *testing.T) {
|
|
||||||
origCookieSameSiteDisabled := setting.CookieSameSiteDisabled
|
|
||||||
origCookieSameSiteMode := setting.CookieSameSiteMode
|
|
||||||
t.Cleanup(func() {
|
|
||||||
setting.CookieSameSiteDisabled = origCookieSameSiteDisabled
|
|
||||||
setting.CookieSameSiteMode = origCookieSameSiteMode
|
|
||||||
})
|
|
||||||
setting.CookieSameSiteDisabled = true
|
|
||||||
setting.CookieSameSiteMode = http.SameSiteLaxMode
|
|
||||||
|
|
||||||
expectedCookiePath := "/"
|
|
||||||
if len(sc.cfg.AppSubURL) > 0 {
|
|
||||||
expectedCookiePath = sc.cfg.AppSubURL
|
|
||||||
}
|
|
||||||
expectedCookie := &http.Cookie{
|
|
||||||
Name: sc.cfg.LoginCookieName,
|
|
||||||
Value: "rotated",
|
|
||||||
Path: expectedCookiePath,
|
|
||||||
HttpOnly: true,
|
|
||||||
MaxAge: maxAge,
|
|
||||||
Secure: setting.CookieSecure,
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").exec()
|
|
||||||
assert.Equal(t, expectedCookie.String(), sc.resp.Header().Get("Set-Cookie"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Invalid/expired auth token in cookie", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.withTokenSessionCookie("token")
|
|
||||||
|
|
||||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
|
||||||
return nil, auth.ErrUserTokenNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").exec()
|
|
||||||
|
|
||||||
assert.False(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, int64(0), sc.context.UserID)
|
|
||||||
assert.Nil(t, sc.context.UserToken)
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Non-expired auth token in cookie and non-expired OAuth access token", func(
|
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
const userID int64 = 12
|
|
||||||
sc.contextHandler.GetTime = fakeGetTime()
|
|
||||||
|
|
||||||
sc.withTokenSessionCookie("token")
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
|
||||||
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(11 * time.Second)}
|
|
||||||
|
|
||||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
|
||||||
return &auth.UserToken{
|
|
||||||
UserId: userID,
|
|
||||||
UnhashedToken: unhashedToken,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").exec()
|
|
||||||
|
|
||||||
require.NotNil(t, sc.context)
|
|
||||||
require.NotNil(t, sc.context.UserToken)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
|
||||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
|
||||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token fails", func(
|
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
const userID int64 = 12
|
|
||||||
sc.contextHandler.GetTime = fakeGetTime()
|
|
||||||
|
|
||||||
sc.withTokenSessionCookie("token")
|
|
||||||
signedInUser := &user.SignedInUser{OrgID: 2, UserID: userID}
|
|
||||||
sc.userService.ExpectedSignedInUser = signedInUser
|
|
||||||
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{
|
|
||||||
UserId: userID,
|
|
||||||
OAuthExpiry: fakeGetTime()().Add(-1 * time.Second),
|
|
||||||
OAuthAccessToken: "access_token",
|
|
||||||
OAuthRefreshToken: "refresh_token"}
|
|
||||||
sc.oauthTokenService.ExpectedErrors = map[string]error{"TryTokenRefresh": errors.New("error")}
|
|
||||||
|
|
||||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
|
||||||
return &auth.UserToken{
|
|
||||||
UserId: userID,
|
|
||||||
UnhashedToken: unhashedToken,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").exec()
|
|
||||||
|
|
||||||
token := sc.oauthTokenService.GetCurrentOAuthToken(sc.context.Req.Context(), signedInUser)
|
|
||||||
assert.Equal(t, token.AccessToken, "")
|
|
||||||
assert.Equal(t, token.RefreshToken, "")
|
|
||||||
assert.True(t, token.Expiry.IsZero())
|
|
||||||
|
|
||||||
require.NotNil(t, sc.context)
|
|
||||||
require.Nil(t, sc.context.UserToken)
|
|
||||||
assert.False(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, int64(0), sc.context.UserID)
|
|
||||||
assert.Equal(t, "grafana_session=; Path=/; Max-Age=0; HttpOnly", sc.resp.Header().Get("Set-Cookie"))
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Non-expired auth token in cookie and expired OAuth access token and refreshing the token succeeds", func(
|
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
const userID int64 = 12
|
|
||||||
sc.contextHandler.GetTime = fakeGetTime()
|
|
||||||
|
|
||||||
sc.withTokenSessionCookie("token")
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
|
||||||
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{UserId: userID, OAuthExpiry: fakeGetTime()().Add(-5 * time.Second), OAuthRefreshToken: "refreshtoken"}
|
|
||||||
|
|
||||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
|
||||||
return &auth.UserToken{
|
|
||||||
UserId: userID,
|
|
||||||
UnhashedToken: unhashedToken,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").exec()
|
|
||||||
|
|
||||||
require.NotNil(t, sc.context)
|
|
||||||
require.NotNil(t, sc.context.UserToken)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
|
||||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
|
||||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Non-expired auth token in cookie and OAuth Access Token's Expiry is not set", func(
|
|
||||||
t *testing.T, sc *scenarioContext) {
|
|
||||||
const userID int64 = 12
|
|
||||||
sc.contextHandler.GetTime = fakeGetTime()
|
|
||||||
|
|
||||||
sc.withTokenSessionCookie("token")
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 2, UserID: userID}
|
|
||||||
sc.oauthTokenService.ExpectedAuthUser = &loginsvc.UserAuth{UserId: userID}
|
|
||||||
|
|
||||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
|
||||||
return &auth.UserToken{
|
|
||||||
UserId: userID,
|
|
||||||
UnhashedToken: unhashedToken,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/").exec()
|
|
||||||
|
|
||||||
require.NotNil(t, sc.context)
|
|
||||||
require.NotNil(t, sc.context.UserToken)
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, userID, sc.context.UserToken.UserId)
|
|
||||||
assert.Equal(t, "token", sc.context.UserToken.UnhashedToken)
|
|
||||||
assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.orgService.ExpectedOrg = &org.Org{ID: 1, Name: sc.cfg.AnonymousOrgName}
|
|
||||||
sc.fakeReq("GET", "/").exec()
|
|
||||||
|
|
||||||
assert.Equal(t, int64(0), sc.context.UserID)
|
|
||||||
assert.Equal(t, int64(1), sc.context.OrgID)
|
|
||||||
assert.Equal(t, org.RoleEditor, sc.context.OrgRole)
|
|
||||||
assert.False(t, sc.context.IsSignedIn)
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
cfg.AnonymousEnabled = true
|
|
||||||
cfg.AnonymousOrgName = "test"
|
|
||||||
cfg.AnonymousOrgRole = string(org.RoleEditor)
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "middleware should add custom response headers", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "middleware should add custom response headers", func(t *testing.T, sc *scenarioContext) {
|
||||||
sc.fakeReq("GET", "/api/").exec()
|
sc.fakeReq("GET", "/api/").exec()
|
||||||
assert.Regexp(t, "test", sc.resp.Header().Get("X-Custom-Header"))
|
assert.Regexp(t, "test", sc.resp.Header().Get("X-Custom-Header"))
|
||||||
@ -590,278 +187,6 @@ func TestMiddlewareContext(t *testing.T) {
|
|||||||
"X-Other-Header": "other-test",
|
"X-Other-Header": "other-test",
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("auth_proxy", func(t *testing.T) {
|
|
||||||
const userID int64 = 33
|
|
||||||
const orgID int64 = 4
|
|
||||||
const defaultOrgId int64 = 1
|
|
||||||
const orgRole = "Admin"
|
|
||||||
|
|
||||||
configure := func(cfg *setting.Cfg) {
|
|
||||||
cfg.AuthProxyEnabled = true
|
|
||||||
cfg.AuthProxyAutoSignUp = true
|
|
||||||
cfg.LDAPAuthEnabled = true
|
|
||||||
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
|
|
||||||
cfg.AuthProxyHeaderProperty = "username"
|
|
||||||
cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS", "Role": "X-WEBAUTH-ROLE"}
|
|
||||||
}
|
|
||||||
|
|
||||||
const hdrName = "markelog"
|
|
||||||
const group = "grafana-core-team"
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
|
||||||
h, err := authproxy.HashCacheKey(hdrName + "-" + group)
|
|
||||||
require.NoError(t, err)
|
|
||||||
key := fmt.Sprintf(authproxy.CachePrefix, h)
|
|
||||||
userIdBytes := []byte(strconv.FormatInt(userID, 10))
|
|
||||||
err = sc.remoteCacheService.Set(context.Background(), key, userIdBytes, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.req.Header.Set("X-WEBAUTH-GROUPS", group)
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
}, configure)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var actualAuthProxyAutoSignUp *bool = nil
|
|
||||||
sc.loginService.ExpectedUserFunc = func(cmd *loginsvc.UpsertUserCommand) *user.User {
|
|
||||||
actualAuthProxyAutoSignUp = &cmd.SignupAllowed
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
sc.loginService.ExpectedError = login.ErrInvalidCredentials
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
assert.False(t, *actualAuthProxyAutoSignUp)
|
|
||||||
assert.Equal(t, 407, sc.resp.Code)
|
|
||||||
assert.Nil(t, sc.context)
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
configure(cfg)
|
|
||||||
cfg.LDAPAuthEnabled = false
|
|
||||||
cfg.AuthProxyAutoSignUp = false
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.loginService.ExpectedUser = &user.User{ID: userID}
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
list := contexthandler.AuthHTTPHeaderListFromContext(sc.context.Req.Context())
|
|
||||||
require.NotNil(t, list)
|
|
||||||
require.Contains(t, list.Items, sc.cfg.AuthProxyHeaderName)
|
|
||||||
require.Contains(t, list.Items, "X-WEBAUTH-GROUPS")
|
|
||||||
require.Contains(t, list.Items, "X-WEBAUTH-ROLE")
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
configure(cfg)
|
|
||||||
cfg.LDAPAuthEnabled = false
|
|
||||||
cfg.AuthProxyAutoSignUp = true
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should assign role from header to default org", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var storedRoleInfo map[int64]org.RoleType = nil
|
|
||||||
sc.loginService.ExpectedUserFunc = func(cmd *loginsvc.UpsertUserCommand) *user.User {
|
|
||||||
storedRoleInfo = cmd.ExternalUser.OrgRoles
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: defaultOrgId, UserID: userID, OrgRole: storedRoleInfo[defaultOrgId]}
|
|
||||||
return &user.User{ID: userID}
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.req.Header.Set("X-WEBAUTH-ROLE", orgRole)
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, defaultOrgId, sc.context.OrgID)
|
|
||||||
assert.Equal(t, orgRole, string(sc.context.OrgRole))
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
configure(cfg)
|
|
||||||
cfg.LDAPAuthEnabled = false
|
|
||||||
cfg.AuthProxyAutoSignUp = true
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should NOT assign role from header to non-default org", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var storedRoleInfo map[int64]org.RoleType = nil
|
|
||||||
sc.loginService.ExpectedUserFunc = func(cmd *loginsvc.UpsertUserCommand) *user.User {
|
|
||||||
storedRoleInfo = cmd.ExternalUser.OrgRoles
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID, OrgRole: storedRoleInfo[orgID]}
|
|
||||||
return &user.User{ID: userID}
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.req.Header.Set("X-WEBAUTH-ROLE", "Admin")
|
|
||||||
sc.req.Header.Set("X-Grafana-Org-Id", strconv.FormatInt(orgID, 10))
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
|
|
||||||
// For non-default org, the user role should be empty
|
|
||||||
assert.Equal(t, "", string(sc.context.OrgRole))
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
configure(cfg)
|
|
||||||
cfg.LDAPAuthEnabled = false
|
|
||||||
cfg.AuthProxyAutoSignUp = true
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should use organisation specified by targetOrgId parameter", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
var targetOrgID int64 = 123
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: targetOrgID, UserID: userID}
|
|
||||||
sc.loginService.ExpectedUser = &user.User{ID: userID}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", fmt.Sprintf("/?targetOrgId=%d", targetOrgID))
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, targetOrgID, sc.context.OrgID)
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
configure(cfg)
|
|
||||||
cfg.LDAPAuthEnabled = false
|
|
||||||
cfg.AuthProxyAutoSignUp = true
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Request body should not be read in default context handler", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.fakeReq("POST", "/?targetOrgId=123")
|
|
||||||
body := "key=value"
|
|
||||||
sc.req.Body = io.NopCloser(strings.NewReader(body))
|
|
||||||
|
|
||||||
sc.handlerFunc = func(c *contextmodel.ReqContext) {
|
|
||||||
t.Log("Handler called")
|
|
||||||
defer func() {
|
|
||||||
err := c.Req.Body.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
bodyAfterHandler, e := io.ReadAll(c.Req.Body)
|
|
||||||
require.NoError(t, e)
|
|
||||||
require.Equal(t, body, string(bodyAfterHandler))
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
sc.req.Header.Set("Content-Length", strconv.Itoa(len(body)))
|
|
||||||
sc.m.Post("/", sc.defaultHandler)
|
|
||||||
sc.exec()
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Request body should not be read in default context handler, but query should be altered - jwt", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.fakeReq("POST", "/?targetOrgId=123&auth_token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NSIsImlhdCI6MTUxNjIzOTAyMn0.1E9qmtctlHAeJzNLPgGFfxdA8WfbEl_vwYO91ffQGxs")
|
|
||||||
body := "key=value"
|
|
||||||
sc.req.Body = io.NopCloser(strings.NewReader(body))
|
|
||||||
|
|
||||||
sc.handlerFunc = func(c *contextmodel.ReqContext) {
|
|
||||||
t.Log("Handler called")
|
|
||||||
defer func() {
|
|
||||||
err := c.Req.Body.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
require.Equal(t, "", c.Req.URL.Query().Get("auth_token"))
|
|
||||||
|
|
||||||
bodyAfterHandler, e := io.ReadAll(c.Req.Body)
|
|
||||||
require.NoError(t, e)
|
|
||||||
require.Equal(t, body, string(bodyAfterHandler))
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
sc.req.Header.Set("Content-Length", strconv.Itoa(len(body)))
|
|
||||||
sc.m.Post("/", sc.defaultHandler)
|
|
||||||
sc.exec()
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
cfg.JWTAuthEnabled = true
|
|
||||||
cfg.JWTAuthURLLogin = true
|
|
||||||
cfg.JWTAuthHeaderName = "X-WEBAUTH-TOKEN"
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should get an existing user from header", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
const userID int64 = 12
|
|
||||||
const orgID int64 = 2
|
|
||||||
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
|
||||||
sc.loginService.ExpectedUser = &user.User{ID: userID}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
configure(cfg)
|
|
||||||
cfg.LDAPAuthEnabled = false
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: orgID, UserID: userID}
|
|
||||||
sc.loginService.ExpectedUser = &user.User{ID: userID}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.req.RemoteAddr = "[2001::23]:12345"
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
assert.True(t, sc.context.IsSignedIn)
|
|
||||||
assert.Equal(t, userID, sc.context.UserID)
|
|
||||||
assert.Equal(t, orgID, sc.context.OrgID)
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
configure(cfg)
|
|
||||||
cfg.AuthProxyWhitelist = "192.168.1.0/24, 2001::0/120"
|
|
||||||
cfg.LDAPAuthEnabled = false
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.loginService.ExpectedUser = &user.User{ID: userID}
|
|
||||||
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.req.RemoteAddr = "[2001::23]:12345"
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
assert.Equal(t, 407, sc.resp.Code)
|
|
||||||
assert.Nil(t, sc.context)
|
|
||||||
}, func(cfg *setting.Cfg) {
|
|
||||||
configure(cfg)
|
|
||||||
cfg.AuthProxyWhitelist = "8.8.8.8"
|
|
||||||
cfg.LDAPAuthEnabled = false
|
|
||||||
})
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
assert.Equal(t, 407, sc.resp.Code)
|
|
||||||
assert.Nil(t, sc.context)
|
|
||||||
}, configure)
|
|
||||||
|
|
||||||
middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) {
|
|
||||||
sc.fakeReq("GET", "/")
|
|
||||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
|
||||||
sc.exec()
|
|
||||||
|
|
||||||
assert.Equal(t, 407, sc.resp.Code)
|
|
||||||
assert.Nil(t, sc.context)
|
|
||||||
}, configure)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(*setting.Cfg)) {
|
func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(*setting.Cfg)) {
|
||||||
@ -894,22 +219,14 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
|||||||
sc.m.UseMiddleware(ContentSecurityPolicy(cfg, logger))
|
sc.m.UseMiddleware(ContentSecurityPolicy(cfg, logger))
|
||||||
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
||||||
|
|
||||||
sc.mockSQLStore = dbtest.NewFakeDB()
|
// defalut to not authenticated request
|
||||||
sc.loginService = &loginservice.LoginServiceMock{}
|
sc.authnService = &authntest.FakeService{ExpectedErr: errors.New("no auth")}
|
||||||
sc.userService = usertest.NewUserServiceFake()
|
sc.userService = usertest.NewUserServiceFake()
|
||||||
sc.orgService = orgtest.NewOrgServiceFake()
|
|
||||||
sc.apiKeyService = &apikeytest.Service{}
|
ctxHdlr := getContextHandler(t, cfg, sc.authnService)
|
||||||
sc.oauthTokenService = &authtest.FakeOAuthTokenService{}
|
|
||||||
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService, sc.apiKeyService, sc.userService, sc.orgService, sc.oauthTokenService)
|
|
||||||
sc.sqlStore = ctxHdlr.SQLStore
|
|
||||||
sc.contextHandler = ctxHdlr
|
|
||||||
sc.m.Use(ctxHdlr.Middleware)
|
sc.m.Use(ctxHdlr.Middleware)
|
||||||
sc.m.Use(OrgRedirect(sc.cfg, sc.userService))
|
sc.m.Use(OrgRedirect(sc.cfg, sc.userService))
|
||||||
|
|
||||||
sc.userAuthTokenService = ctxHdlr.AuthTokenService.(*authtest.FakeUserAuthTokenService)
|
|
||||||
sc.jwtAuthService = ctxHdlr.JWTAuthService.(*jwt.FakeJWTService)
|
|
||||||
sc.remoteCacheService = ctxHdlr.RemoteCache
|
|
||||||
|
|
||||||
sc.defaultHandler = func(c *contextmodel.ReqContext) {
|
sc.defaultHandler = func(c *contextmodel.ReqContext) {
|
||||||
require.NotNil(t, c)
|
require.NotNil(t, c)
|
||||||
t.Log("Default HTTP handler called")
|
t.Log("Default HTTP handler called")
|
||||||
@ -933,40 +250,14 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *dbtest.FakeDB,
|
func getContextHandler(t *testing.T, cfg *setting.Cfg, authnService authn.Service) *contexthandler.ContextHandler {
|
||||||
loginService *loginservice.LoginServiceMock, apiKeyService *apikeytest.Service,
|
|
||||||
userService *usertest.FakeUserService, orgService *orgtest.FakeOrgService,
|
|
||||||
oauthTokenService *authtest.FakeOAuthTokenService,
|
|
||||||
) *contexthandler.ContextHandler {
|
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
if cfg == nil {
|
tracer := tracing.NewFakeTracer()
|
||||||
cfg = setting.NewCfg()
|
return contexthandler.ProvideService(cfg, authtest.NewFakeUserAuthTokenService(), nil,
|
||||||
}
|
nil, nil, nil, tracer, nil,
|
||||||
cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{
|
nil, nil, nil, nil, nil,
|
||||||
Name: "database",
|
nil, featuremgmt.WithFeatures(featuremgmt.FlagAccessTokenExpirationCheck),
|
||||||
}
|
authnService, &anontest.FakeAnonymousSessionService{},
|
||||||
|
)
|
||||||
remoteCacheSvc := remotecache.NewFakeStore(t)
|
|
||||||
userAuthTokenSvc := authtest.NewFakeUserAuthTokenService()
|
|
||||||
renderSvc := &fakeRenderService{}
|
|
||||||
authJWTSvc := jwt.NewFakeJWTService()
|
|
||||||
tracer := tracing.InitializeTracerForTest()
|
|
||||||
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService,
|
|
||||||
userService, mockSQLStore, &service.LDAPFakeService{ExpectedError: service.ErrUnableToCreateLDAPClient})
|
|
||||||
authenticator := &logintest.AuthenticatorFake{ExpectedUser: &user.User{}}
|
|
||||||
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc,
|
|
||||||
remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy,
|
|
||||||
loginService, apiKeyService, authenticator, userService, orgService,
|
|
||||||
oauthTokenService,
|
|
||||||
featuremgmt.WithFeatures(featuremgmt.FlagAccessTokenExpirationCheck),
|
|
||||||
&authntest.FakeService{}, &anontest.FakeAnonymousSessionService{})
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeRenderService struct {
|
|
||||||
rendering.Service
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *fakeRenderService) Init() error {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOrgRedirectMiddleware(t *testing.T) {
|
func TestOrgRedirectMiddleware(t *testing.T) {
|
||||||
@ -46,15 +44,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
middlewareScenario(t, tc.desc, func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, tc.desc, func(t *testing.T, sc *scenarioContext) {
|
||||||
sc.withTokenSessionCookie("token")
|
sc.withIdentity(&authn.Identity{})
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12}
|
|
||||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
|
||||||
return &auth.UserToken{
|
|
||||||
UserId: 0,
|
|
||||||
UnhashedToken: "",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.m.Get("/", sc.defaultHandler)
|
sc.m.Get("/", sc.defaultHandler)
|
||||||
sc.fakeReq("GET", tc.input).exec()
|
sc.fakeReq("GET", tc.input).exec()
|
||||||
|
|
||||||
@ -64,19 +54,11 @@ func TestOrgRedirectMiddleware(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
middlewareScenario(t, "when setting an invalid org for user", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "when setting an invalid org for user", func(t *testing.T, sc *scenarioContext) {
|
||||||
sc.withTokenSessionCookie("token")
|
sc.withIdentity(&authn.Identity{})
|
||||||
sc.userService.ExpectedSetUsingOrgError = fmt.Errorf("")
|
sc.userService.ExpectedSetUsingOrgError = fmt.Errorf("")
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{OrgID: 1, UserID: 12}
|
|
||||||
|
|
||||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
|
||||||
return &auth.UserToken{
|
|
||||||
UserId: 12,
|
|
||||||
UnhashedToken: "",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.m.Get("/", sc.defaultHandler)
|
sc.m.Get("/", sc.defaultHandler)
|
||||||
sc.fakeReq("GET", "/?orgId=3").exec()
|
sc.fakeReq("GET", "/?orgId=1").exec()
|
||||||
|
|
||||||
require.Equal(t, 404, sc.resp.Code)
|
require.Equal(t, 404, sc.resp.Code)
|
||||||
})
|
})
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/services/auth"
|
"github.com/grafana/grafana/pkg/services/auth"
|
||||||
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/quota/quotatest"
|
"github.com/grafana/grafana/pkg/services/quota/quotatest"
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/web"
|
"github.com/grafana/grafana/pkg/web"
|
||||||
)
|
)
|
||||||
@ -53,14 +52,7 @@ func TestMiddlewareQuota(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("with user logged in", func(t *testing.T) {
|
t.Run("with user logged in", func(t *testing.T) {
|
||||||
setUp := func(sc *scenarioContext) {
|
setUp := func(sc *scenarioContext) {
|
||||||
sc.withTokenSessionCookie("token")
|
sc.withIdentity(&authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{UserId: 12}})
|
||||||
sc.userService.ExpectedSignedInUser = &user.SignedInUser{UserID: 12}
|
|
||||||
sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*auth.UserToken, error) {
|
|
||||||
return &auth.UserToken{
|
|
||||||
UserId: 12,
|
|
||||||
UnhashedToken: "",
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
middlewareScenario(t, "global datasource quota reached", func(t *testing.T, sc *scenarioContext) {
|
middlewareScenario(t, "global datasource quota reached", func(t *testing.T, sc *scenarioContext) {
|
||||||
|
@ -8,9 +8,10 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||||
|
"github.com/grafana/grafana/pkg/services/user/usertest"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/web"
|
"github.com/grafana/grafana/pkg/web"
|
||||||
)
|
)
|
||||||
@ -66,13 +67,10 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
|
|||||||
sc.m.Use(AddDefaultResponseHeaders(cfg))
|
sc.m.Use(AddDefaultResponseHeaders(cfg))
|
||||||
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
||||||
|
|
||||||
sc.userAuthTokenService = authtest.NewFakeUserAuthTokenService()
|
contextHandler := getContextHandler(t, setting.NewCfg(), &authntest.FakeService{ExpectedIdentity: &authn.Identity{}})
|
||||||
sc.remoteCacheService = remotecache.NewFakeStore(t)
|
|
||||||
|
|
||||||
contextHandler := getContextHandler(t, nil, nil, nil, nil, nil, nil, nil)
|
|
||||||
sc.m.Use(contextHandler.Middleware)
|
sc.m.Use(contextHandler.Middleware)
|
||||||
// mock out gc goroutine
|
// mock out gc goroutine
|
||||||
sc.m.Use(OrgRedirect(cfg, sc.userService))
|
sc.m.Use(OrgRedirect(cfg, usertest.NewUserServiceFake()))
|
||||||
|
|
||||||
sc.defaultHandler = func(c *contextmodel.ReqContext) {
|
sc.defaultHandler = func(c *contextmodel.ReqContext) {
|
||||||
sc.context = c
|
sc.context = c
|
||||||
|
@ -8,69 +8,35 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/grafana/grafana/pkg/infra/db"
|
"github.com/grafana/grafana/pkg/services/authn"
|
||||||
"github.com/grafana/grafana/pkg/infra/db/dbtest"
|
"github.com/grafana/grafana/pkg/services/authn/authntest"
|
||||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
|
||||||
"github.com/grafana/grafana/pkg/services/apikey/apikeytest"
|
|
||||||
"github.com/grafana/grafana/pkg/services/auth/authtest"
|
|
||||||
"github.com/grafana/grafana/pkg/services/auth/jwt"
|
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
|
||||||
"github.com/grafana/grafana/pkg/services/contexthandler/ctxkey"
|
"github.com/grafana/grafana/pkg/services/contexthandler/ctxkey"
|
||||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model"
|
||||||
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
|
||||||
"github.com/grafana/grafana/pkg/services/org/orgtest"
|
|
||||||
"github.com/grafana/grafana/pkg/services/user/usertest"
|
"github.com/grafana/grafana/pkg/services/user/usertest"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
"github.com/grafana/grafana/pkg/web"
|
"github.com/grafana/grafana/pkg/web"
|
||||||
)
|
)
|
||||||
|
|
||||||
type scenarioContext struct {
|
type scenarioContext struct {
|
||||||
t *testing.T
|
t *testing.T
|
||||||
m *web.Mux
|
m *web.Mux
|
||||||
context *contextmodel.ReqContext
|
context *contextmodel.ReqContext
|
||||||
resp *httptest.ResponseRecorder
|
resp *httptest.ResponseRecorder
|
||||||
apiKey string
|
respJson map[string]interface{}
|
||||||
authHeader string
|
handlerFunc handlerFunc
|
||||||
jwtAuthHeader string
|
defaultHandler web.Handler
|
||||||
tokenSessionCookie string
|
url string
|
||||||
respJson map[string]interface{}
|
authnService *authntest.FakeService
|
||||||
handlerFunc handlerFunc
|
userService *usertest.FakeUserService
|
||||||
defaultHandler web.Handler
|
cfg *setting.Cfg
|
||||||
url string
|
|
||||||
userAuthTokenService *authtest.FakeUserAuthTokenService
|
|
||||||
jwtAuthService *jwt.FakeJWTService
|
|
||||||
remoteCacheService *remotecache.RemoteCache
|
|
||||||
cfg *setting.Cfg
|
|
||||||
sqlStore db.DB
|
|
||||||
mockSQLStore *dbtest.FakeDB
|
|
||||||
contextHandler *contexthandler.ContextHandler
|
|
||||||
loginService *loginservice.LoginServiceMock
|
|
||||||
apiKeyService *apikeytest.Service
|
|
||||||
userService *usertest.FakeUserService
|
|
||||||
oauthTokenService *authtest.FakeOAuthTokenService
|
|
||||||
orgService *orgtest.FakeOrgService
|
|
||||||
|
|
||||||
req *http.Request
|
req *http.Request
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *scenarioContext) withValidApiKey() *scenarioContext {
|
// set identity to use for request
|
||||||
sc.apiKey = "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9"
|
func (sc *scenarioContext) withIdentity(identity *authn.Identity) {
|
||||||
return sc
|
sc.authnService.ExpectedErr = nil
|
||||||
}
|
sc.authnService.ExpectedIdentity = identity
|
||||||
|
|
||||||
func (sc *scenarioContext) withTokenSessionCookie(unhashedToken string) *scenarioContext {
|
|
||||||
sc.tokenSessionCookie = unhashedToken
|
|
||||||
return sc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sc *scenarioContext) withAuthorizationHeader(authHeader string) *scenarioContext {
|
|
||||||
sc.authHeader = authHeader
|
|
||||||
return sc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sc *scenarioContext) withJWTAuthHeader(jwtAuthHeader string) *scenarioContext {
|
|
||||||
sc.jwtAuthHeader = jwtAuthHeader
|
|
||||||
return sc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *scenarioContext) fakeReq(method, url string) *scenarioContext {
|
func (sc *scenarioContext) fakeReq(method, url string) *scenarioContext {
|
||||||
@ -116,29 +82,6 @@ func (sc *scenarioContext) fakeReqWithParams(method, url string, queryParams map
|
|||||||
func (sc *scenarioContext) exec() {
|
func (sc *scenarioContext) exec() {
|
||||||
sc.t.Helper()
|
sc.t.Helper()
|
||||||
|
|
||||||
if sc.apiKey != "" {
|
|
||||||
sc.t.Logf(`Adding header "Authorization: Bearer %s"`, sc.apiKey)
|
|
||||||
sc.req.Header.Set("Authorization", "Bearer "+sc.apiKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sc.authHeader != "" {
|
|
||||||
sc.t.Logf(`Adding header "Authorization: %s"`, sc.authHeader)
|
|
||||||
sc.req.Header.Set("Authorization", sc.authHeader)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sc.jwtAuthHeader != "" {
|
|
||||||
sc.t.Logf(`Adding header "%s: %s"`, sc.cfg.JWTAuthHeaderName, sc.jwtAuthHeader)
|
|
||||||
sc.req.Header.Set(sc.cfg.JWTAuthHeaderName, sc.jwtAuthHeader)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sc.tokenSessionCookie != "" {
|
|
||||||
sc.t.Log(`Adding cookie`, "name", sc.cfg.LoginCookieName, "value", sc.tokenSessionCookie)
|
|
||||||
sc.req.AddCookie(&http.Cookie{
|
|
||||||
Name: sc.cfg.LoginCookieName,
|
|
||||||
Value: sc.tokenSessionCookie,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
sc.m.ServeHTTP(sc.resp, sc.req)
|
sc.m.ServeHTTP(sc.resp, sc.req)
|
||||||
|
|
||||||
if sc.resp.Header().Get("Content-Type") == "application/json; charset=UTF-8" {
|
if sc.resp.Header().Get("Content-Type") == "application/json; charset=UTF-8" {
|
||||||
|
@ -337,9 +337,7 @@ type RedirectValidator func(url string) error
|
|||||||
// HandleLoginResponse is a utility function to perform common operations after a successful login and returns response.NormalResponse
|
// HandleLoginResponse is a utility function to perform common operations after a successful login and returns response.NormalResponse
|
||||||
func HandleLoginResponse(r *http.Request, w http.ResponseWriter, cfg *setting.Cfg, identity *Identity, validator RedirectValidator) *response.NormalResponse {
|
func HandleLoginResponse(r *http.Request, w http.ResponseWriter, cfg *setting.Cfg, identity *Identity, validator RedirectValidator) *response.NormalResponse {
|
||||||
result := map[string]interface{}{"message": "Logged in"}
|
result := map[string]interface{}{"message": "Logged in"}
|
||||||
if redirectURL := handleLogin(r, w, cfg, identity, validator); redirectURL != cfg.AppSubURL+"/" {
|
result["redirectUrl"] = handleLogin(r, w, cfg, identity, validator)
|
||||||
result["redirectUrl"] = redirectURL
|
|
||||||
}
|
|
||||||
return response.JSON(http.StatusOK, result)
|
return response.JSON(http.StatusOK, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -356,9 +354,11 @@ func HandleLoginRedirectResponse(r *http.Request, w http.ResponseWriter, cfg *se
|
|||||||
|
|
||||||
func handleLogin(r *http.Request, w http.ResponseWriter, cfg *setting.Cfg, identity *Identity, validator RedirectValidator) string {
|
func handleLogin(r *http.Request, w http.ResponseWriter, cfg *setting.Cfg, identity *Identity, validator RedirectValidator) string {
|
||||||
redirectURL := cfg.AppSubURL + "/"
|
redirectURL := cfg.AppSubURL + "/"
|
||||||
if redirectTo := getRedirectURL(r); len(redirectTo) > 0 && validator(redirectTo) == nil {
|
if redirectTo := getRedirectURL(r); len(redirectTo) > 0 {
|
||||||
cookies.DeleteCookie(w, "redirect_to", nil)
|
if validator(redirectTo) == nil {
|
||||||
redirectURL = redirectTo
|
redirectURL = redirectTo
|
||||||
|
}
|
||||||
|
cookies.DeleteCookie(w, "redirect_to", cookieOptions(cfg))
|
||||||
}
|
}
|
||||||
|
|
||||||
WriteSessionCookie(w, cfg, identity.SessionToken)
|
WriteSessionCookie(w, cfg, identity.SessionToken)
|
||||||
@ -386,17 +386,32 @@ func WriteSessionCookie(w http.ResponseWriter, cfg *setting.Cfg, token *usertoke
|
|||||||
cookies.WriteCookie(w, cfg.LoginCookieName, url.QueryEscape(token.UnhashedToken), maxAge, nil)
|
cookies.WriteCookie(w, cfg.LoginCookieName, url.QueryEscape(token.UnhashedToken), maxAge, nil)
|
||||||
expiry := token.NextRotation(time.Duration(cfg.TokenRotationIntervalMinutes) * time.Minute)
|
expiry := token.NextRotation(time.Duration(cfg.TokenRotationIntervalMinutes) * time.Minute)
|
||||||
cookies.WriteCookie(w, sessionExpiryCookie, url.QueryEscape(strconv.FormatInt(expiry.Unix(), 10)), maxAge, func() cookies.CookieOptions {
|
cookies.WriteCookie(w, sessionExpiryCookie, url.QueryEscape(strconv.FormatInt(expiry.Unix(), 10)), maxAge, func() cookies.CookieOptions {
|
||||||
opts := cookies.NewCookieOptions()
|
opts := cookieOptions(cfg)()
|
||||||
opts.NotHttpOnly = true
|
opts.NotHttpOnly = true
|
||||||
return opts
|
return opts
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteSessionCookie(w http.ResponseWriter, cfg *setting.Cfg) {
|
func DeleteSessionCookie(w http.ResponseWriter, cfg *setting.Cfg) {
|
||||||
cookies.DeleteCookie(w, cfg.LoginCookieName, nil)
|
cookies.DeleteCookie(w, cfg.LoginCookieName, cookieOptions(cfg))
|
||||||
cookies.DeleteCookie(w, sessionExpiryCookie, func() cookies.CookieOptions {
|
cookies.DeleteCookie(w, sessionExpiryCookie, func() cookies.CookieOptions {
|
||||||
opts := cookies.NewCookieOptions()
|
opts := cookieOptions(cfg)()
|
||||||
opts.NotHttpOnly = true
|
opts.NotHttpOnly = true
|
||||||
return opts
|
return opts
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cookieOptions(cfg *setting.Cfg) func() cookies.CookieOptions {
|
||||||
|
return func() cookies.CookieOptions {
|
||||||
|
path := "/"
|
||||||
|
if len(cfg.AppSubURL) > 0 {
|
||||||
|
path = cfg.AppSubURL
|
||||||
|
}
|
||||||
|
return cookies.CookieOptions{
|
||||||
|
Path: path,
|
||||||
|
Secure: cfg.CookieSecure,
|
||||||
|
SameSiteDisabled: cfg.CookieSameSiteDisabled,
|
||||||
|
SameSiteMode: cfg.CookieSameSiteMode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -17,14 +17,6 @@ import (
|
|||||||
"github.com/grafana/grafana/pkg/services/login"
|
"github.com/grafana/grafana/pkg/services/login"
|
||||||
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
"github.com/grafana/grafana/pkg/services/oauthtoken"
|
||||||
"github.com/grafana/grafana/pkg/services/user"
|
"github.com/grafana/grafana/pkg/services/user"
|
||||||
"github.com/grafana/grafana/pkg/util/errutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errExpiredAccessToken = errutil.NewBase(
|
|
||||||
errutil.StatusUnauthorized,
|
|
||||||
"oauth.expired-token",
|
|
||||||
errutil.WithPublicMessage("OAuth access token expired"))
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
|
func ProvideOAuthTokenSync(service oauthtoken.OAuthTokenService, sessionService auth.UserTokenService, socialService social.Service) *OAuthTokenSync {
|
||||||
@ -122,7 +114,7 @@ func (s *OAuthTokenSync) SyncOauthTokenHook(ctx context.Context, identity *authn
|
|||||||
s.log.FromContext(ctx).Error("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
|
s.log.FromContext(ctx).Error("Failed to revoke session token", "id", identity.ID, "tokenId", identity.SessionToken.Id, "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return errExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
|
return authn.ErrExpiredAccessToken.Errorf("oauth access token could not be refreshed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -89,7 +89,7 @@ func TestOAuthTokenSync_SyncOAuthTokenHook(t *testing.T) {
|
|||||||
expectInvalidateOauthTokensCalled: true,
|
expectInvalidateOauthTokensCalled: true,
|
||||||
expectRevokeTokenCalled: true,
|
expectRevokeTokenCalled: true,
|
||||||
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
expectedHasEntryToken: &login.UserAuth{OAuthExpiry: time.Now().Add(-10 * time.Minute)},
|
||||||
expectedErr: errExpiredAccessToken,
|
expectedErr: authn.ErrExpiredAccessToken,
|
||||||
}, {
|
}, {
|
||||||
desc: "should skip sync when use_refresh_token is disabled",
|
desc: "should skip sync when use_refresh_token is disabled",
|
||||||
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},
|
identity: &authn.Identity{ID: "user:1", SessionToken: &auth.UserToken{}, AuthenticatedBy: login.GitLabAuthModule},
|
||||||
|
@ -7,4 +7,5 @@ var (
|
|||||||
ErrUnsupportedClient = errutil.NewBase(errutil.StatusBadRequest, "auth.client.unsupported")
|
ErrUnsupportedClient = errutil.NewBase(errutil.StatusBadRequest, "auth.client.unsupported")
|
||||||
ErrClientNotConfigured = errutil.NewBase(errutil.StatusBadRequest, "auth.client.notConfigured")
|
ErrClientNotConfigured = errutil.NewBase(errutil.StatusBadRequest, "auth.client.notConfigured")
|
||||||
ErrUnsupportedIdentity = errutil.NewBase(errutil.StatusNotImplemented, "auth.identity.unsupported")
|
ErrUnsupportedIdentity = errutil.NewBase(errutil.StatusNotImplemented, "auth.identity.unsupported")
|
||||||
|
ErrExpiredAccessToken = errutil.NewBase(errutil.StatusUnauthorized, "oauth.expired-token", errutil.WithPublicMessage("OAuth access token expired"))
|
||||||
)
|
)
|
||||||
|
@ -55,12 +55,11 @@ func ProvideService(cfg *setting.Cfg, tokenService auth.UserTokenService, jwtSer
|
|||||||
authnService authn.Service, anonDeviceService anonymous.Service,
|
authnService authn.Service, anonDeviceService anonymous.Service,
|
||||||
) *ContextHandler {
|
) *ContextHandler {
|
||||||
return &ContextHandler{
|
return &ContextHandler{
|
||||||
Cfg: cfg,
|
Cfg: cfg,
|
||||||
AuthTokenService: tokenService,
|
AuthTokenService: tokenService,
|
||||||
JWTAuthService: jwtService,
|
JWTAuthService: jwtService,
|
||||||
RemoteCache: remoteCache,
|
RemoteCache: remoteCache,
|
||||||
RenderService: renderService,
|
RenderService: renderService, SQLStore: sqlStore,
|
||||||
SQLStore: sqlStore,
|
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
authProxy: authProxy,
|
authProxy: authProxy,
|
||||||
authenticator: authenticator,
|
authenticator: authenticator,
|
||||||
@ -173,7 +172,7 @@ func (h *ContextHandler) Middleware(next http.Handler) http.Handler {
|
|||||||
if h.Cfg.AuthBrokerEnabled {
|
if h.Cfg.AuthBrokerEnabled {
|
||||||
identity, err := h.AuthnService.Authenticate(ctx, &authn.Request{HTTPRequest: reqContext.Req, Resp: reqContext.Resp})
|
identity, err := h.AuthnService.Authenticate(ctx, &authn.Request{HTTPRequest: reqContext.Req, Resp: reqContext.Resp})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, auth.ErrInvalidSessionToken) {
|
if errors.Is(err, auth.ErrInvalidSessionToken) || errors.Is(err, authn.ErrExpiredAccessToken) {
|
||||||
// Burn the cookie in case of invalid, expired or missing token
|
// Burn the cookie in case of invalid, expired or missing token
|
||||||
reqContext.Resp.Before(h.deleteInvalidCookieEndOfRequestFunc(reqContext))
|
reqContext.Resp.Before(h.deleteInvalidCookieEndOfRequestFunc(reqContext))
|
||||||
}
|
}
|
||||||
|
@ -968,11 +968,12 @@ var skipStaticRootValidation = false
|
|||||||
|
|
||||||
func NewCfg() *Cfg {
|
func NewCfg() *Cfg {
|
||||||
return &Cfg{
|
return &Cfg{
|
||||||
Target: []string{},
|
Target: []string{},
|
||||||
Logger: log.New("settings"),
|
Logger: log.New("settings"),
|
||||||
Raw: ini.Empty(),
|
Raw: ini.Empty(),
|
||||||
Azure: &azsettings.AzureSettings{},
|
Azure: &azsettings.AzureSettings{},
|
||||||
RBACEnabled: true,
|
RBACEnabled: true,
|
||||||
|
AuthBrokerEnabled: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user