Chore: Remove bus from authproxy (#46936)

* Make authproxy injectable

* Fix import

* Provide function was in wrong place

* Fixing tests

* More imports and rollback a change

* Fix lint
This commit is contained in:
Selene 2022-03-30 17:01:24 +02:00 committed by GitHub
parent 118b87ee8f
commit 8e52dbb87b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 189 additions and 260 deletions

View File

@ -27,11 +27,13 @@ import (
"github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/dashboards"
dashboardsstore "github.com/grafana/grafana/pkg/services/dashboards/database"
dashboardservice "github.com/grafana/grafana/pkg/services/dashboards/manager"
"github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/ldap"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/searchusers"
@ -193,7 +195,8 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
authJWTSvc := models.NewFakeJWTService()
tracer, err := tracing.InitializeTracerForTest()
require.NoError(t, err)
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer)
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, sqlStore)
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
return ctxHdlr
}

View File

@ -2,7 +2,6 @@ package middleware
import (
"context"
"errors"
"fmt"
"io"
"net"
@ -25,8 +24,10 @@ import (
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/grafana/grafana/pkg/web"
@ -364,10 +365,7 @@ func TestMiddlewareContext(t *testing.T) {
const group = "grafana-core-team"
middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: orgID, UserId: query.UserId}
return nil
})
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
h, err := authproxy.HashCacheKey(hdrName + "-" + group)
require.NoError(t, err)
@ -387,11 +385,11 @@ func TestMiddlewareContext(t *testing.T) {
middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) {
var actualAuthProxyAutoSignUp *bool = nil
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
actualAuthProxyAutoSignUp = &cmd.SignupAllowed
return login.ErrInvalidCredentials
})
return nil
}
sc.loginService.ExpectedError = login.ErrInvalidCredentials
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -407,18 +405,8 @@ func TestMiddlewareContext(t *testing.T) {
})
middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
if query.UserId > 0 {
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil
}
return models.ErrUserNotFound
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
return nil
})
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
sc.loginService.ExpectedUser = &models.User{Id: userID}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -435,19 +423,11 @@ func TestMiddlewareContext(t *testing.T) {
middlewareScenario(t, "Should assign role from header to default org", func(t *testing.T, sc *scenarioContext) {
var storedRoleInfo map[int64]models.RoleType = nil
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
if query.UserId > 0 {
query.Result = &models.SignedInUser{OrgId: defaultOrgId, UserId: userID, OrgRole: storedRoleInfo[defaultOrgId]}
return nil
}
return models.ErrUserNotFound
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
storedRoleInfo = cmd.ExternalUser.OrgRoles
return nil
})
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: defaultOrgId, UserId: userID, OrgRole: storedRoleInfo[defaultOrgId]}
return &models.User{Id: userID}
}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -466,19 +446,11 @@ func TestMiddlewareContext(t *testing.T) {
middlewareScenario(t, "Should NOT assign role from header to non-default org", func(t *testing.T, sc *scenarioContext) {
var storedRoleInfo map[int64]models.RoleType = nil
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
if query.UserId > 0 {
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID, OrgRole: storedRoleInfo[orgID]}
return nil
}
return models.ErrUserNotFound
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
storedRoleInfo = cmd.ExternalUser.OrgRoles
return nil
})
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID, OrgRole: storedRoleInfo[orgID]}
return &models.User{Id: userID}
}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -499,27 +471,17 @@ func TestMiddlewareContext(t *testing.T) {
})
middlewareScenario(t, "Should use organisation specified by targetOrgId parameter", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
if query.UserId > 0 {
query.Result = &models.SignedInUser{OrgId: query.OrgId, UserId: userID}
return nil
}
return models.ErrUserNotFound
})
var targetOrgID int64 = 123
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: targetOrgID, UserId: userID}
sc.loginService.ExpectedUser = &models.User{Id: userID}
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
return nil
})
targetOrgID := 123
sc.fakeReq("GET", fmt.Sprintf("/?targetOrgId=%d", targetOrgID))
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec()
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, int64(targetOrgID), sc.context.OrgId)
assert.Equal(t, targetOrgID, sc.context.OrgId)
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.LDAPEnabled = false
@ -554,15 +516,8 @@ func TestMiddlewareContext(t *testing.T) {
const userID int64 = 12
const orgID int64 = 2
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
return nil
})
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
sc.loginService.ExpectedUser = &models.User{Id: userID}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -577,15 +532,8 @@ func TestMiddlewareContext(t *testing.T) {
})
middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
return nil
})
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
sc.loginService.ExpectedUser = &models.User{Id: userID}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -602,15 +550,7 @@ func TestMiddlewareContext(t *testing.T) {
})
middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil
})
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{Id: userID}
return nil
})
sc.loginService.ExpectedUser = &models.User{Id: userID}
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
@ -626,10 +566,6 @@ func TestMiddlewareContext(t *testing.T) {
})
middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("LDAP", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
return errors.New("Do not add user")
})
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec()
@ -639,10 +575,6 @@ func TestMiddlewareContext(t *testing.T) {
}, configure)
middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("Do not have the user", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
return errors.New("Do not add user")
})
sc.fakeReq("GET", "/")
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec()
@ -684,7 +616,9 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
sc.m.UseMiddleware(AddCSPHeader(cfg, logger))
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
ctxHdlr := getContextHandler(t, cfg)
sc.mockSQLStore = mockstore.NewSQLStoreMock()
sc.loginService = &loginservice.LoginServiceMock{}
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService)
sc.sqlStore = ctxHdlr.SQLStore
sc.contextHandler = ctxHdlr
sc.m.Use(ctxHdlr.Middleware)
@ -714,7 +648,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
})
}
func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHandler {
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock, loginService *loginservice.LoginServiceMock) *contexthandler.ContextHandler {
t.Helper()
sqlStore := sqlstore.InitTestDB(t)
@ -730,8 +664,9 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
renderSvc := &fakeRenderService{}
authJWTSvc := models.NewFakeJWTService()
tracer, err := tracing.InitializeTracerForTest()
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, mockSQLStore)
require.NoError(t, err)
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer)
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
}
type fakeRenderService struct {

View File

@ -33,7 +33,7 @@ func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateL
m := web.New()
m.UseMiddleware(web.Renderer("../../public/views", "[[", "]]"))
m.Use(getContextHandler(t, cfg).Middleware)
m.Use(getContextHandler(t, cfg, nil, nil).Middleware)
m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler)
fn(func() *httptest.ResponseRecorder {

View File

@ -70,7 +70,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
sc.remoteCacheService = remotecache.NewFakeStore(t)
contextHandler := getContextHandler(t, nil)
contextHandler := getContextHandler(t, nil, nil, nil)
sc.m.Use(contextHandler.Middleware)
// mock out gc goroutine
sc.m.Use(OrgRedirect(cfg))

View File

@ -10,7 +10,9 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web"
"github.com/stretchr/testify/require"
@ -34,7 +36,9 @@ type scenarioContext struct {
remoteCacheService *remotecache.RemoteCache
cfg *setting.Cfg
sqlStore sqlstore.Store
mockSQLStore *mockstore.SQLStoreMock
contextHandler *contexthandler.ContextHandler
loginService *loginservice.LoginServiceMock
req *http.Request
}

View File

@ -32,6 +32,7 @@ import (
"github.com/grafana/grafana/pkg/services/cleanup"
"github.com/grafana/grafana/pkg/services/comments"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/dashboardimport"
dashboardimportservice "github.com/grafana/grafana/pkg/services/dashboardimport/service"
"github.com/grafana/grafana/pkg/services/dashboards"
@ -228,6 +229,7 @@ var wireBasicSet = wire.NewSet(
wire.Bind(new(alerting.DashAlertExtractor), new(*alerting.DashAlertExtractorService)),
comments.ProvideService,
guardian.ProvideService,
authproxy.ProvideAuthProxy,
)
var wireSet = wire.NewSet(

View File

@ -6,13 +6,13 @@ import (
"net/http"
"testing"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting"
@ -20,42 +20,18 @@ import (
"github.com/stretchr/testify/require"
)
const userID = int64(1)
const orgID = int64(4)
// Test initContextWithAuthProxy with a cached user ID that is no longer valid.
//
// In this case, the cache entry should be ignored/cleared and another attempt should be done to sign the user
// in without cache.
func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
const name = "markelog"
const userID = int64(1)
const orgID = int64(4)
svc := getContextHandler(t)
// XXX: These handlers have to be injected AFTER calling getContextHandler, since the latter
// creates a SQLStore which installs its own handlers.
upsertHandler := func(ctx context.Context, cmd *models.UpsertUserCommand) error {
require.Equal(t, name, cmd.ExternalUser.Login)
cmd.Result = &models.User{Id: userID}
return nil
}
getUserHandler := func(ctx context.Context, query *models.GetSignedInUserQuery) error {
// Simulate that the cached user ID is stale
if query.UserId != userID {
return models.ErrUserNotFound
}
query.Result = &models.SignedInUser{
UserId: userID,
OrgId: orgID,
}
return nil
}
bus.AddHandler("", upsertHandler)
bus.AddHandler("", getUserHandler)
t.Cleanup(func() {
bus.ClearBusHandlers()
})
req, err := http.NewRequest("POST", "http://example.com", nil)
require.NoError(t, err)
ctx := &models.ReqContext{
@ -106,5 +82,24 @@ func getContextHandler(t *testing.T) *ContextHandler {
tracer, err := tracing.InitializeTracerForTest()
require.NoError(t, err)
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer)
loginService := loginservice.LoginServiceMock{ExpectedUser: &models.User{Id: userID}}
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, &FakeGetSignUserStore{})
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
}
type FakeGetSignUserStore struct {
sqlstore.Store
}
func (f *FakeGetSignUserStore) GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error {
if query.UserId != userID {
return models.ErrUserNotFound
}
query.Result = &models.SignedInUser{
UserId: userID,
OrgId: orgID,
}
return nil
}

View File

@ -13,12 +13,13 @@ import (
"strings"
"time"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/ldap"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/multildap"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
)
@ -49,11 +50,22 @@ var supportedHeaderFields = []string{"Name", "Email", "Login", "Groups", "Role"}
// AuthProxy struct
type AuthProxy struct {
cfg *setting.Cfg
remoteCache *remotecache.RemoteCache
ctx *models.ReqContext
orgID int64
header string
cfg *setting.Cfg
remoteCache *remotecache.RemoteCache
loginService login.Service
sqlStore sqlstore.Store
logger log.Logger
}
func ProvideAuthProxy(cfg *setting.Cfg, remoteCache *remotecache.RemoteCache, loginService login.Service, sqlStore sqlstore.Store) *AuthProxy {
return &AuthProxy{
cfg: cfg,
remoteCache: remoteCache,
loginService: loginService,
sqlStore: sqlStore,
logger: log.New("auth.proxy"),
}
}
// Error auth proxy specific error
@ -75,40 +87,20 @@ func (err Error) Error() string {
return err.Message
}
// Options for the AuthProxy
type Options struct {
RemoteCache *remotecache.RemoteCache
Ctx *models.ReqContext
OrgID int64
}
// New instance of the AuthProxy.
func New(cfg *setting.Cfg, options *Options) *AuthProxy {
auth := &AuthProxy{
remoteCache: options.RemoteCache,
cfg: cfg,
ctx: options.Ctx,
orgID: options.OrgID,
}
auth.header = auth.getDecodedHeader(cfg.AuthProxyHeaderName)
return auth
}
// IsEnabled checks if the auth proxy is enabled.
func (auth *AuthProxy) IsEnabled() bool {
// Bail if the setting is not enabled
return auth.cfg.AuthProxyEnabled
}
// HasHeader checks if the we have specified header
func (auth *AuthProxy) HasHeader() bool {
return len(auth.header) != 0
// HasHeader checks if we have specified header
func (auth *AuthProxy) HasHeader(reqCtx *models.ReqContext) bool {
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
return len(header) != 0
}
// IsAllowedIP returns whether provided IP is allowed.
func (auth *AuthProxy) IsAllowedIP() error {
ip := auth.ctx.Req.RemoteAddr
func (auth *AuthProxy) IsAllowedIP(ip string) error {
if len(strings.TrimSpace(auth.cfg.AuthProxyWhitelist)) == 0 {
return nil
}
@ -137,7 +129,7 @@ func (auth *AuthProxy) IsAllowedIP() error {
}
return newError("proxy authentication required", fmt.Errorf(
"request for user (%s) from %s is not from the authentication proxy", auth.header,
"request for user from %s is not from the authentication proxy",
sourceIP,
))
}
@ -153,10 +145,11 @@ func HashCacheKey(key string) (string, error) {
// getKey forms a key for the cache based on the headers received as part of the authentication flow.
// Our configuration supports multiple headers. The main header contains the email or username.
// And the additional ones that allow us to specify extra attributes: Name, Email, Role, or Groups.
func (auth *AuthProxy) getKey() (string, error) {
key := strings.TrimSpace(auth.header) // start the key with the main header
func (auth *AuthProxy) getKey(reqCtx *models.ReqContext) (string, error) {
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
key := strings.TrimSpace(header) // start the key with the main header
auth.headersIterator(func(_, header string) {
auth.headersIterator(reqCtx, func(_, header string) {
key = strings.Join([]string{key, header}, "-") // compose the key with any additional headers
})
@ -168,17 +161,17 @@ func (auth *AuthProxy) getKey() (string, error) {
}
// Login logs in user ID by whatever means possible.
func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error) {
func (auth *AuthProxy) Login(reqCtx *models.ReqContext, ignoreCache bool) (int64, error) {
if !ignoreCache {
// Error here means absent cache - we don't need to handle that
id, err := auth.GetUserViaCache(logger)
id, err := auth.getUserViaCache(reqCtx)
if err == nil && id != 0 {
return id, nil
}
}
if isLDAPEnabled(auth.cfg) {
id, err := auth.LoginViaLDAP()
id, err := auth.LoginViaLDAP(reqCtx)
if err != nil {
if errors.Is(err, ldap.ErrInvalidCredentials) {
return 0, newError("proxy authentication required", ldap.ErrInvalidCredentials)
@ -189,7 +182,7 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
return id, nil
}
id, err := auth.LoginViaHeader()
id, err := auth.loginViaHeader(reqCtx)
if err != nil {
return 0, newError("failed to log in as user, specified in auth proxy header", err)
}
@ -197,87 +190,89 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
return id, nil
}
// GetUserViaCache gets user ID from cache.
func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) {
cacheKey, err := auth.getKey()
// getUserViaCache gets user ID from cache.
func (auth *AuthProxy) getUserViaCache(reqCtx *models.ReqContext) (int64, error) {
cacheKey, err := auth.getKey(reqCtx)
if err != nil {
return 0, err
}
logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
userID, err := auth.remoteCache.Get(auth.ctx.Req.Context(), cacheKey)
auth.logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), cacheKey)
if err != nil {
logger.Debug("Failed getting user ID via auth cache", "error", err)
auth.logger.Debug("Failed getting user ID via auth cache", "error", err)
return 0, err
}
logger.Debug("Successfully got user ID via auth cache", "id", userID)
auth.logger.Debug("Successfully got user ID via auth cache", "id", userID)
return userID.(int64), nil
}
// RemoveUserFromCache removes user from cache.
func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error {
cacheKey, err := auth.getKey()
func (auth *AuthProxy) RemoveUserFromCache(reqCtx *models.ReqContext) error {
cacheKey, err := auth.getKey(reqCtx)
if err != nil {
return err
}
logger.Debug("Removing user from auth cache", "cacheKey", cacheKey)
if err := auth.remoteCache.Delete(auth.ctx.Req.Context(), cacheKey); err != nil {
auth.logger.Debug("Removing user from auth cache", "cacheKey", cacheKey)
if err := auth.remoteCache.Delete(reqCtx.Req.Context(), cacheKey); err != nil {
return err
}
logger.Debug("Successfully removed user from auth cache", "cacheKey", cacheKey)
auth.logger.Debug("Successfully removed user from auth cache", "cacheKey", cacheKey)
return nil
}
// LoginViaLDAP logs in user via LDAP request
func (auth *AuthProxy) LoginViaLDAP() (int64, error) {
func (auth *AuthProxy) LoginViaLDAP(reqCtx *models.ReqContext) (int64, error) {
config, err := getLDAPConfig(auth.cfg)
if err != nil {
return 0, newError("failed to get LDAP config", err)
}
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
mldap := newLDAP(config.Servers)
extUser, _, err := mldap.User(auth.header)
extUser, _, err := mldap.User(header)
if err != nil {
return 0, err
}
// Have to sync grafana and LDAP user during log in
upsert := &models.UpsertUserCommand{
ReqContext: auth.ctx,
ReqContext: reqCtx,
SignupAllowed: auth.cfg.LDAPAllowSignup,
ExternalUser: extUser,
}
if err := bus.Dispatch(auth.ctx.Req.Context(), upsert); err != nil {
if err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert); err != nil {
return 0, err
}
return upsert.Result.Id, nil
}
// LoginViaHeader logs in user from the header only
func (auth *AuthProxy) LoginViaHeader() (int64, error) {
// loginViaHeader logs in user from the header only
func (auth *AuthProxy) loginViaHeader(reqCtx *models.ReqContext) (int64, error) {
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
extUser := &models.ExternalUserInfo{
AuthModule: "authproxy",
AuthId: auth.header,
AuthId: header,
}
switch auth.cfg.AuthProxyHeaderProperty {
case "username":
extUser.Login = auth.header
extUser.Login = header
emailAddr, emailErr := mail.ParseAddress(auth.header) // only set Email if it can be parsed as an email address
emailAddr, emailErr := mail.ParseAddress(header) // only set Email if it can be parsed as an email address
if emailErr == nil {
extUser.Email = emailAddr.Address
}
case "email":
extUser.Email = auth.header
extUser.Login = auth.header
extUser.Email = header
extUser.Login = header
default:
return 0, fmt.Errorf("auth proxy header property invalid")
}
auth.headersIterator(func(field string, header string) {
auth.headersIterator(reqCtx, func(field string, header string) {
switch field {
case "Groups":
extUser.Groups = util.SplitString(header)
@ -300,12 +295,12 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
})
upsert := &models.UpsertUserCommand{
ReqContext: auth.ctx,
ReqContext: reqCtx,
SignupAllowed: auth.cfg.AuthProxyAutoSignUp,
ExternalUser: extUser,
}
err := bus.Dispatch(auth.ctx.Req.Context(), upsert)
err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert)
if err != nil {
return 0, err
}
@ -314,8 +309,8 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
}
// getDecodedHeader gets decoded value of a header with given headerName
func (auth *AuthProxy) getDecodedHeader(headerName string) string {
headerValue := auth.ctx.Req.Header.Get(headerName)
func (auth *AuthProxy) getDecodedHeader(reqCtx *models.ReqContext, headerName string) string {
headerValue := reqCtx.Req.Header.Get(headerName)
if auth.cfg.AuthProxyHeadersEncoded {
headerValue = util.DecodeQuotedPrintable(headerValue)
@ -325,27 +320,27 @@ func (auth *AuthProxy) getDecodedHeader(headerName string) string {
}
// headersIterator iterates over all non-empty supported additional headers
func (auth *AuthProxy) headersIterator(fn func(field string, header string)) {
func (auth *AuthProxy) headersIterator(reqCtx *models.ReqContext, fn func(field string, header string)) {
for _, field := range supportedHeaderFields {
h := auth.cfg.AuthProxyHeaders[field]
if h == "" {
continue
}
if value := auth.getDecodedHeader(h); value != "" {
if value := auth.getDecodedHeader(reqCtx, h); value != "" {
fn(field, strings.TrimSpace(value))
}
}
}
// GetSignedUser gets full signed in user info.
func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, error) {
// GetSignedInUser gets full signed in user info.
func (auth *AuthProxy) GetSignedInUser(userID int64, orgID int64) (*models.SignedInUser, error) {
query := &models.GetSignedInUserQuery{
OrgId: auth.orgID,
OrgId: orgID,
UserId: userID,
}
if err := bus.Dispatch(context.Background(), query); err != nil {
if err := auth.sqlStore.GetSignedInUser(context.Background(), query); err != nil {
return nil, err
}
@ -353,21 +348,21 @@ func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, erro
}
// Remember user in cache
func (auth *AuthProxy) Remember(id int64) error {
key, err := auth.getKey()
func (auth *AuthProxy) Remember(reqCtx *models.ReqContext, id int64) error {
key, err := auth.getKey(reqCtx)
if err != nil {
return err
}
// Check if user already in cache
userID, err := auth.remoteCache.Get(auth.ctx.Req.Context(), key)
userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), key)
if err == nil && userID != nil {
return nil
}
expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute
if err := auth.remoteCache.Set(auth.ctx.Req.Context(), key, id, expiration); err != nil {
if err := auth.remoteCache.Set(reqCtx.Req.Context(), key, id, expiration); err != nil {
return err
}

View File

@ -7,8 +7,8 @@ import (
"net/http"
"testing"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/ldap"
@ -20,8 +20,9 @@ import (
)
const hdrName = "markelog"
const id int64 = 42
func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, configureReq func(*http.Request, *setting.Cfg)) *AuthProxy {
func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, configureReq func(*http.Request, *setting.Cfg)) (*AuthProxy, *models.ReqContext) {
t.Helper()
req, err := http.NewRequest("POST", "http://example.com", nil)
@ -40,17 +41,16 @@ func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, confi
Context: &web.Context{Req: req},
}
auth := New(cfg, &Options{
RemoteCache: remoteCache,
Ctx: ctx,
OrgID: 4,
})
loginService := loginservice.LoginServiceMock{
ExpectedUser: &models.User{
Id: id,
},
}
return auth
return ProvideAuthProxy(cfg, remoteCache, loginService, nil), ctx
}
func TestMiddlewareContext(t *testing.T) {
logger := log.New("test")
cache := remotecache.NewFakeStore(t)
t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) {
@ -62,12 +62,12 @@ func TestMiddlewareContext(t *testing.T) {
err = cache.Set(context.Background(), key, id, 0)
require.NoError(t, err)
// Set up the middleware
auth := prepareMiddleware(t, cache, nil)
gotKey, err := auth.getKey()
auth, reqCtx := prepareMiddleware(t, cache, nil)
gotKey, err := auth.getKey(reqCtx)
require.NoError(t, err)
assert.Equal(t, key, gotKey)
gotID, err := auth.Login(logger, false)
gotID, err := auth.Login(reqCtx, false)
require.NoError(t, err)
assert.Equal(t, id, gotID)
@ -84,7 +84,7 @@ func TestMiddlewareContext(t *testing.T) {
err = cache.Set(context.Background(), key, id, 0)
require.NoError(t, err)
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
cfg.AuthProxyHeaderName = "X-Killa"
cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS", "Role": "X-WEBAUTH-ROLE"}
req.Header.Set(cfg.AuthProxyHeaderName, hdrName)
@ -93,26 +93,14 @@ func TestMiddlewareContext(t *testing.T) {
})
assert.Equal(t, "auth-proxy-sync-ttl:f5acfffd56daac98d502ef8c8b8c5d56", key)
gotID, err := auth.Login(logger, false)
gotID, err := auth.Login(reqCtx, false)
require.NoError(t, err)
assert.Equal(t, id, gotID)
})
}
func TestMiddlewareContext_ldap(t *testing.T) {
logger := log.New("test")
t.Run("Logs in via LDAP", func(t *testing.T) {
const id int64 = 42
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
cmd.Result = &models.User{
Id: id,
}
return nil
})
origIsLDAPEnabled := isLDAPEnabled
origGetLDAPConfig := getLDAPConfig
origNewLDAP := newLDAP
@ -147,9 +135,9 @@ func TestMiddlewareContext_ldap(t *testing.T) {
cache := remotecache.NewFakeStore(t)
auth := prepareMiddleware(t, cache, nil)
auth, reqCtx := prepareMiddleware(t, cache, nil)
gotID, err := auth.Login(logger, false)
gotID, err := auth.Login(reqCtx, false)
require.NoError(t, err)
assert.Equal(t, id, gotID)
@ -177,7 +165,7 @@ func TestMiddlewareContext_ldap(t *testing.T) {
cache := remotecache.NewFakeStore(t)
auth := prepareMiddleware(t, cache, nil)
auth, reqCtx := prepareMiddleware(t, cache, nil)
stub := &multildap.MultiLDAPmock{
ID: id,
@ -187,7 +175,7 @@ func TestMiddlewareContext_ldap(t *testing.T) {
return stub
}
gotID, err := auth.Login(logger, false)
gotID, err := auth.Login(reqCtx, false)
require.EqualError(t, err, "failed to get the user")
assert.NotEqual(t, id, gotID)
@ -198,22 +186,24 @@ func TestMiddlewareContext_ldap(t *testing.T) {
func TestDecodeHeader(t *testing.T) {
cache := remotecache.NewFakeStore(t)
t.Run("should not decode header if not enabled in settings", func(t *testing.T) {
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
cfg.AuthProxyHeadersEncoded = false
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen")
})
assert.Equal(t, "M=C3=BCnchen", auth.header)
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
assert.Equal(t, "M=C3=BCnchen", header)
})
t.Run("should decode header if enabled in settings", func(t *testing.T) {
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
cfg.AuthProxyHeadersEncoded = true
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen")
})
assert.Equal(t, "München", auth.header)
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
assert.Equal(t, "München", header)
})
}

View File

@ -36,7 +36,7 @@ const ServiceName = "ContextHandler"
func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtService models.JWTService,
remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore *sqlstore.SQLStore,
tracer tracing.Tracer) *ContextHandler {
tracer tracing.Tracer, authProxy *authproxy.AuthProxy) *ContextHandler {
return &ContextHandler{
Cfg: cfg,
AuthTokenService: tokenService,
@ -45,6 +45,7 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtS
RenderService: renderService,
SQLStore: sqlStore,
tracer: tracer,
authProxy: authProxy,
}
}
@ -57,6 +58,7 @@ type ContextHandler struct {
RenderService rendering.Service
SQLStore sqlstore.Store
tracer tracing.Tracer
authProxy *authproxy.AuthProxy
// GetTime returns the current time.
// Stubbable by tests.
GetTime func() time.Time
@ -419,10 +421,10 @@ func (h *ContextHandler) initContextWithRenderAuth(reqContext *models.ReqContext
return true
}
func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) {
func logUserIn(reqContext *models.ReqContext, auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) {
logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache)
// Try to log in user via various providers
id, err := auth.Login(logger, ignoreCache)
id, err := auth.Login(reqContext, ignoreCache)
if err != nil {
details := err
var e authproxy.Error
@ -451,36 +453,31 @@ func (h *ContextHandler) handleError(ctx *models.ReqContext, err error, statusCo
func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, orgID int64) bool {
username := reqContext.Req.Header.Get(h.Cfg.AuthProxyHeaderName)
auth := authproxy.New(h.Cfg, &authproxy.Options{
RemoteCache: h.RemoteCache,
Ctx: reqContext,
OrgID: orgID,
})
logger := log.New("auth.proxy")
// Bail if auth proxy is not enabled
if !auth.IsEnabled() {
if !h.authProxy.IsEnabled() {
return false
}
// If there is no header - we can't move forward
if !auth.HasHeader() {
if !h.authProxy.HasHeader(reqContext) {
return false
}
_, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithAuthProxy")
defer span.End()
// Check if allowed to continue with this IP
if err := auth.IsAllowedIP(); err != nil {
// Check if allowed continuing with this IP
if err := h.authProxy.IsAllowedIP(reqContext.Req.RemoteAddr); err != nil {
h.handleError(reqContext, err, 407, func(details error) {
logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details)
})
return true
}
id, err := logUserIn(auth, username, logger, false)
id, err := logUserIn(reqContext, h.authProxy, username, logger, false)
if err != nil {
h.handleError(reqContext, err, 407, nil)
return true
@ -488,7 +485,7 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
logger.Debug("Got user ID, getting full user info", "userID", id)
user, err := auth.GetSignedInUser(id)
user, err := h.authProxy.GetSignedInUser(id, orgID)
if err != nil {
// The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale
// cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated
@ -496,18 +493,18 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
// we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to
// log the user in again without the cache.
logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id)
if err := auth.RemoveUserFromCache(logger); err != nil {
if err := h.authProxy.RemoveUserFromCache(reqContext); err != nil {
if !errors.Is(err, remotecache.ErrCacheItemNotFound) {
logger.Error("Got unexpected error when removing user from auth cache", "error", err)
}
}
id, err = logUserIn(auth, username, logger, true)
id, err = logUserIn(reqContext, h.authProxy, username, logger, true)
if err != nil {
h.handleError(reqContext, err, 407, nil)
return true
}
user, err = auth.GetSignedInUser(id)
user, err = h.authProxy.GetSignedInUser(id, orgID)
if err != nil {
h.handleError(reqContext, err, 407, nil)
return true
@ -521,7 +518,7 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
reqContext.IsSignedIn = true
// Remember user data in cache
if err := auth.Remember(id); err != nil {
if err := h.authProxy.Remember(reqContext, id); err != nil {
h.handleError(reqContext, err, 500, func(details error) {
logger.Error(
"Failed to store user in cache",

View File

@ -15,6 +15,9 @@ type LoginServiceMock struct {
NoExistingOrgId int64
AlreadyExitingLogin string
GeneratedUserId int64
ExpectedUser *models.User
ExpectedUserFunc func(cmd *models.UpsertUserCommand) *models.User
ExpectedError error
}
func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User, error) {
@ -35,5 +38,10 @@ func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User
}
func (s LoginServiceMock) UpsertUser(ctx context.Context, cmd *models.UpsertUserCommand) error {
return nil
if s.ExpectedUserFunc != nil {
cmd.Result = s.ExpectedUserFunc(cmd)
return s.ExpectedError
}
cmd.Result = s.ExpectedUser
return s.ExpectedError
}