mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Chore: Remove bus from authproxy (#46936)
* Make authproxy injectable * Fix import * Provide function was in wrong place * Fixing tests * More imports and rollback a change * Fix lint
This commit is contained in:
parent
118b87ee8f
commit
8e52dbb87b
@ -27,11 +27,13 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
||||
"github.com/grafana/grafana/pkg/services/dashboards"
|
||||
dashboardsstore "github.com/grafana/grafana/pkg/services/dashboards/database"
|
||||
dashboardservice "github.com/grafana/grafana/pkg/services/dashboards/manager"
|
||||
"github.com/grafana/grafana/pkg/services/featuremgmt"
|
||||
"github.com/grafana/grafana/pkg/services/ldap"
|
||||
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||
"github.com/grafana/grafana/pkg/services/quota"
|
||||
"github.com/grafana/grafana/pkg/services/rendering"
|
||||
"github.com/grafana/grafana/pkg/services/searchusers"
|
||||
@ -193,7 +195,8 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
|
||||
authJWTSvc := models.NewFakeJWTService()
|
||||
tracer, err := tracing.InitializeTracerForTest()
|
||||
require.NoError(t, err)
|
||||
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer)
|
||||
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, sqlStore)
|
||||
ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
|
||||
|
||||
return ctxHdlr
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -25,8 +24,10 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
||||
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||
"github.com/grafana/grafana/pkg/services/rendering"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
@ -364,10 +365,7 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
const group = "grafana-core-team"
|
||||
|
||||
middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: query.UserId}
|
||||
return nil
|
||||
})
|
||||
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
|
||||
h, err := authproxy.HashCacheKey(hdrName + "-" + group)
|
||||
require.NoError(t, err)
|
||||
@ -387,11 +385,11 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
|
||||
middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) {
|
||||
var actualAuthProxyAutoSignUp *bool = nil
|
||||
|
||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
|
||||
actualAuthProxyAutoSignUp = &cmd.SignupAllowed
|
||||
return login.ErrInvalidCredentials
|
||||
})
|
||||
return nil
|
||||
}
|
||||
sc.loginService.ExpectedError = login.ErrInvalidCredentials
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
@ -407,18 +405,8 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
if query.UserId > 0 {
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
return nil
|
||||
}
|
||||
return models.ErrUserNotFound
|
||||
})
|
||||
|
||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
cmd.Result = &models.User{Id: userID}
|
||||
return nil
|
||||
})
|
||||
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
sc.loginService.ExpectedUser = &models.User{Id: userID}
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
@ -435,19 +423,11 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
|
||||
middlewareScenario(t, "Should assign role from header to default org", func(t *testing.T, sc *scenarioContext) {
|
||||
var storedRoleInfo map[int64]models.RoleType = nil
|
||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
if query.UserId > 0 {
|
||||
query.Result = &models.SignedInUser{OrgId: defaultOrgId, UserId: userID, OrgRole: storedRoleInfo[defaultOrgId]}
|
||||
return nil
|
||||
}
|
||||
return models.ErrUserNotFound
|
||||
})
|
||||
|
||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
cmd.Result = &models.User{Id: userID}
|
||||
sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
|
||||
storedRoleInfo = cmd.ExternalUser.OrgRoles
|
||||
return nil
|
||||
})
|
||||
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: defaultOrgId, UserId: userID, OrgRole: storedRoleInfo[defaultOrgId]}
|
||||
return &models.User{Id: userID}
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
@ -466,19 +446,11 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
|
||||
middlewareScenario(t, "Should NOT assign role from header to non-default org", func(t *testing.T, sc *scenarioContext) {
|
||||
var storedRoleInfo map[int64]models.RoleType = nil
|
||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
if query.UserId > 0 {
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID, OrgRole: storedRoleInfo[orgID]}
|
||||
return nil
|
||||
}
|
||||
return models.ErrUserNotFound
|
||||
})
|
||||
|
||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
cmd.Result = &models.User{Id: userID}
|
||||
sc.loginService.ExpectedUserFunc = func(cmd *models.UpsertUserCommand) *models.User {
|
||||
storedRoleInfo = cmd.ExternalUser.OrgRoles
|
||||
return nil
|
||||
})
|
||||
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID, OrgRole: storedRoleInfo[orgID]}
|
||||
return &models.User{Id: userID}
|
||||
}
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
@ -499,27 +471,17 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should use organisation specified by targetOrgId parameter", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
if query.UserId > 0 {
|
||||
query.Result = &models.SignedInUser{OrgId: query.OrgId, UserId: userID}
|
||||
return nil
|
||||
}
|
||||
return models.ErrUserNotFound
|
||||
})
|
||||
var targetOrgID int64 = 123
|
||||
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: targetOrgID, UserId: userID}
|
||||
sc.loginService.ExpectedUser = &models.User{Id: userID}
|
||||
|
||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
cmd.Result = &models.User{Id: userID}
|
||||
return nil
|
||||
})
|
||||
|
||||
targetOrgID := 123
|
||||
sc.fakeReq("GET", fmt.Sprintf("/?targetOrgId=%d", targetOrgID))
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.exec()
|
||||
|
||||
assert.True(t, sc.context.IsSignedIn)
|
||||
assert.Equal(t, userID, sc.context.UserId)
|
||||
assert.Equal(t, int64(targetOrgID), sc.context.OrgId)
|
||||
assert.Equal(t, targetOrgID, sc.context.OrgId)
|
||||
}, func(cfg *setting.Cfg) {
|
||||
configure(cfg)
|
||||
cfg.LDAPEnabled = false
|
||||
@ -554,15 +516,8 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
const userID int64 = 12
|
||||
const orgID int64 = 2
|
||||
|
||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
return nil
|
||||
})
|
||||
|
||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
cmd.Result = &models.User{Id: userID}
|
||||
return nil
|
||||
})
|
||||
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
sc.loginService.ExpectedUser = &models.User{Id: userID}
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
@ -577,15 +532,8 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
return nil
|
||||
})
|
||||
|
||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
cmd.Result = &models.User{Id: userID}
|
||||
return nil
|
||||
})
|
||||
sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
sc.loginService.ExpectedUser = &models.User{Id: userID}
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
@ -602,15 +550,7 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
return nil
|
||||
})
|
||||
|
||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
cmd.Result = &models.User{Id: userID}
|
||||
return nil
|
||||
})
|
||||
sc.loginService.ExpectedUser = &models.User{Id: userID}
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
@ -626,10 +566,6 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("LDAP", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
return errors.New("Do not add user")
|
||||
})
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.exec()
|
||||
@ -639,10 +575,6 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
}, configure)
|
||||
|
||||
middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("Do not have the user", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
return errors.New("Do not add user")
|
||||
})
|
||||
|
||||
sc.fakeReq("GET", "/")
|
||||
sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
|
||||
sc.exec()
|
||||
@ -684,7 +616,9 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
||||
sc.m.UseMiddleware(AddCSPHeader(cfg, logger))
|
||||
sc.m.UseMiddleware(web.Renderer(viewsPath, "[[", "]]"))
|
||||
|
||||
ctxHdlr := getContextHandler(t, cfg)
|
||||
sc.mockSQLStore = mockstore.NewSQLStoreMock()
|
||||
sc.loginService = &loginservice.LoginServiceMock{}
|
||||
ctxHdlr := getContextHandler(t, cfg, sc.mockSQLStore, sc.loginService)
|
||||
sc.sqlStore = ctxHdlr.SQLStore
|
||||
sc.contextHandler = ctxHdlr
|
||||
sc.m.Use(ctxHdlr.Middleware)
|
||||
@ -714,7 +648,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(
|
||||
})
|
||||
}
|
||||
|
||||
func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHandler {
|
||||
func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock, loginService *loginservice.LoginServiceMock) *contexthandler.ContextHandler {
|
||||
t.Helper()
|
||||
|
||||
sqlStore := sqlstore.InitTestDB(t)
|
||||
@ -730,8 +664,9 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa
|
||||
renderSvc := &fakeRenderService{}
|
||||
authJWTSvc := models.NewFakeJWTService()
|
||||
tracer, err := tracing.InitializeTracerForTest()
|
||||
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, mockSQLStore)
|
||||
require.NoError(t, err)
|
||||
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer)
|
||||
return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
|
||||
}
|
||||
|
||||
type fakeRenderService struct {
|
||||
|
@ -33,7 +33,7 @@ func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateL
|
||||
|
||||
m := web.New()
|
||||
m.UseMiddleware(web.Renderer("../../public/views", "[[", "]]"))
|
||||
m.Use(getContextHandler(t, cfg).Middleware)
|
||||
m.Use(getContextHandler(t, cfg, nil, nil).Middleware)
|
||||
m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler)
|
||||
|
||||
fn(func() *httptest.ResponseRecorder {
|
||||
|
@ -70,7 +70,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
|
||||
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
|
||||
sc.remoteCacheService = remotecache.NewFakeStore(t)
|
||||
|
||||
contextHandler := getContextHandler(t, nil)
|
||||
contextHandler := getContextHandler(t, nil, nil, nil)
|
||||
sc.m.Use(contextHandler.Middleware)
|
||||
// mock out gc goroutine
|
||||
sc.m.Use(OrgRedirect(cfg))
|
||||
|
@ -10,7 +10,9 @@ import (
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/web"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -34,7 +36,9 @@ type scenarioContext struct {
|
||||
remoteCacheService *remotecache.RemoteCache
|
||||
cfg *setting.Cfg
|
||||
sqlStore sqlstore.Store
|
||||
mockSQLStore *mockstore.SQLStoreMock
|
||||
contextHandler *contexthandler.ContextHandler
|
||||
loginService *loginservice.LoginServiceMock
|
||||
|
||||
req *http.Request
|
||||
}
|
||||
|
@ -32,6 +32,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/services/cleanup"
|
||||
"github.com/grafana/grafana/pkg/services/comments"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
||||
"github.com/grafana/grafana/pkg/services/dashboardimport"
|
||||
dashboardimportservice "github.com/grafana/grafana/pkg/services/dashboardimport/service"
|
||||
"github.com/grafana/grafana/pkg/services/dashboards"
|
||||
@ -228,6 +229,7 @@ var wireBasicSet = wire.NewSet(
|
||||
wire.Bind(new(alerting.DashAlertExtractor), new(*alerting.DashAlertExtractorService)),
|
||||
comments.ProvideService,
|
||||
guardian.ProvideService,
|
||||
authproxy.ProvideAuthProxy,
|
||||
)
|
||||
|
||||
var wireSet = wire.NewSet(
|
||||
|
@ -6,13 +6,13 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/auth"
|
||||
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
|
||||
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||
"github.com/grafana/grafana/pkg/services/rendering"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
@ -20,42 +20,18 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const userID = int64(1)
|
||||
const orgID = int64(4)
|
||||
|
||||
// Test initContextWithAuthProxy with a cached user ID that is no longer valid.
|
||||
//
|
||||
// In this case, the cache entry should be ignored/cleared and another attempt should be done to sign the user
|
||||
// in without cache.
|
||||
func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
|
||||
const name = "markelog"
|
||||
const userID = int64(1)
|
||||
const orgID = int64(4)
|
||||
|
||||
svc := getContextHandler(t)
|
||||
|
||||
// XXX: These handlers have to be injected AFTER calling getContextHandler, since the latter
|
||||
// creates a SQLStore which installs its own handlers.
|
||||
upsertHandler := func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
require.Equal(t, name, cmd.ExternalUser.Login)
|
||||
cmd.Result = &models.User{Id: userID}
|
||||
return nil
|
||||
}
|
||||
getUserHandler := func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
// Simulate that the cached user ID is stale
|
||||
if query.UserId != userID {
|
||||
return models.ErrUserNotFound
|
||||
}
|
||||
|
||||
query.Result = &models.SignedInUser{
|
||||
UserId: userID,
|
||||
OrgId: orgID,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
bus.AddHandler("", upsertHandler)
|
||||
bus.AddHandler("", getUserHandler)
|
||||
t.Cleanup(func() {
|
||||
bus.ClearBusHandlers()
|
||||
})
|
||||
|
||||
req, err := http.NewRequest("POST", "http://example.com", nil)
|
||||
require.NoError(t, err)
|
||||
ctx := &models.ReqContext{
|
||||
@ -106,5 +82,24 @@ func getContextHandler(t *testing.T) *ContextHandler {
|
||||
tracer, err := tracing.InitializeTracerForTest()
|
||||
require.NoError(t, err)
|
||||
|
||||
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer)
|
||||
loginService := loginservice.LoginServiceMock{ExpectedUser: &models.User{Id: userID}}
|
||||
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, &FakeGetSignUserStore{})
|
||||
|
||||
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy)
|
||||
}
|
||||
|
||||
type FakeGetSignUserStore struct {
|
||||
sqlstore.Store
|
||||
}
|
||||
|
||||
func (f *FakeGetSignUserStore) GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
if query.UserId != userID {
|
||||
return models.ErrUserNotFound
|
||||
}
|
||||
|
||||
query.Result = &models.SignedInUser{
|
||||
UserId: userID,
|
||||
OrgId: orgID,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -13,12 +13,13 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/ldap"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/multildap"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util"
|
||||
)
|
||||
@ -49,11 +50,22 @@ var supportedHeaderFields = []string{"Name", "Email", "Login", "Groups", "Role"}
|
||||
|
||||
// AuthProxy struct
|
||||
type AuthProxy struct {
|
||||
cfg *setting.Cfg
|
||||
remoteCache *remotecache.RemoteCache
|
||||
ctx *models.ReqContext
|
||||
orgID int64
|
||||
header string
|
||||
cfg *setting.Cfg
|
||||
remoteCache *remotecache.RemoteCache
|
||||
loginService login.Service
|
||||
sqlStore sqlstore.Store
|
||||
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
func ProvideAuthProxy(cfg *setting.Cfg, remoteCache *remotecache.RemoteCache, loginService login.Service, sqlStore sqlstore.Store) *AuthProxy {
|
||||
return &AuthProxy{
|
||||
cfg: cfg,
|
||||
remoteCache: remoteCache,
|
||||
loginService: loginService,
|
||||
sqlStore: sqlStore,
|
||||
logger: log.New("auth.proxy"),
|
||||
}
|
||||
}
|
||||
|
||||
// Error auth proxy specific error
|
||||
@ -75,40 +87,20 @@ func (err Error) Error() string {
|
||||
return err.Message
|
||||
}
|
||||
|
||||
// Options for the AuthProxy
|
||||
type Options struct {
|
||||
RemoteCache *remotecache.RemoteCache
|
||||
Ctx *models.ReqContext
|
||||
OrgID int64
|
||||
}
|
||||
|
||||
// New instance of the AuthProxy.
|
||||
func New(cfg *setting.Cfg, options *Options) *AuthProxy {
|
||||
auth := &AuthProxy{
|
||||
remoteCache: options.RemoteCache,
|
||||
cfg: cfg,
|
||||
ctx: options.Ctx,
|
||||
orgID: options.OrgID,
|
||||
}
|
||||
auth.header = auth.getDecodedHeader(cfg.AuthProxyHeaderName)
|
||||
return auth
|
||||
}
|
||||
|
||||
// IsEnabled checks if the auth proxy is enabled.
|
||||
func (auth *AuthProxy) IsEnabled() bool {
|
||||
// Bail if the setting is not enabled
|
||||
return auth.cfg.AuthProxyEnabled
|
||||
}
|
||||
|
||||
// HasHeader checks if the we have specified header
|
||||
func (auth *AuthProxy) HasHeader() bool {
|
||||
return len(auth.header) != 0
|
||||
// HasHeader checks if we have specified header
|
||||
func (auth *AuthProxy) HasHeader(reqCtx *models.ReqContext) bool {
|
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||
return len(header) != 0
|
||||
}
|
||||
|
||||
// IsAllowedIP returns whether provided IP is allowed.
|
||||
func (auth *AuthProxy) IsAllowedIP() error {
|
||||
ip := auth.ctx.Req.RemoteAddr
|
||||
|
||||
func (auth *AuthProxy) IsAllowedIP(ip string) error {
|
||||
if len(strings.TrimSpace(auth.cfg.AuthProxyWhitelist)) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -137,7 +129,7 @@ func (auth *AuthProxy) IsAllowedIP() error {
|
||||
}
|
||||
|
||||
return newError("proxy authentication required", fmt.Errorf(
|
||||
"request for user (%s) from %s is not from the authentication proxy", auth.header,
|
||||
"request for user from %s is not from the authentication proxy",
|
||||
sourceIP,
|
||||
))
|
||||
}
|
||||
@ -153,10 +145,11 @@ func HashCacheKey(key string) (string, error) {
|
||||
// getKey forms a key for the cache based on the headers received as part of the authentication flow.
|
||||
// Our configuration supports multiple headers. The main header contains the email or username.
|
||||
// And the additional ones that allow us to specify extra attributes: Name, Email, Role, or Groups.
|
||||
func (auth *AuthProxy) getKey() (string, error) {
|
||||
key := strings.TrimSpace(auth.header) // start the key with the main header
|
||||
func (auth *AuthProxy) getKey(reqCtx *models.ReqContext) (string, error) {
|
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||
key := strings.TrimSpace(header) // start the key with the main header
|
||||
|
||||
auth.headersIterator(func(_, header string) {
|
||||
auth.headersIterator(reqCtx, func(_, header string) {
|
||||
key = strings.Join([]string{key, header}, "-") // compose the key with any additional headers
|
||||
})
|
||||
|
||||
@ -168,17 +161,17 @@ func (auth *AuthProxy) getKey() (string, error) {
|
||||
}
|
||||
|
||||
// Login logs in user ID by whatever means possible.
|
||||
func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error) {
|
||||
func (auth *AuthProxy) Login(reqCtx *models.ReqContext, ignoreCache bool) (int64, error) {
|
||||
if !ignoreCache {
|
||||
// Error here means absent cache - we don't need to handle that
|
||||
id, err := auth.GetUserViaCache(logger)
|
||||
id, err := auth.getUserViaCache(reqCtx)
|
||||
if err == nil && id != 0 {
|
||||
return id, nil
|
||||
}
|
||||
}
|
||||
|
||||
if isLDAPEnabled(auth.cfg) {
|
||||
id, err := auth.LoginViaLDAP()
|
||||
id, err := auth.LoginViaLDAP(reqCtx)
|
||||
if err != nil {
|
||||
if errors.Is(err, ldap.ErrInvalidCredentials) {
|
||||
return 0, newError("proxy authentication required", ldap.ErrInvalidCredentials)
|
||||
@ -189,7 +182,7 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
|
||||
return id, nil
|
||||
}
|
||||
|
||||
id, err := auth.LoginViaHeader()
|
||||
id, err := auth.loginViaHeader(reqCtx)
|
||||
if err != nil {
|
||||
return 0, newError("failed to log in as user, specified in auth proxy header", err)
|
||||
}
|
||||
@ -197,87 +190,89 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// GetUserViaCache gets user ID from cache.
|
||||
func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) {
|
||||
cacheKey, err := auth.getKey()
|
||||
// getUserViaCache gets user ID from cache.
|
||||
func (auth *AuthProxy) getUserViaCache(reqCtx *models.ReqContext) (int64, error) {
|
||||
cacheKey, err := auth.getKey(reqCtx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
|
||||
userID, err := auth.remoteCache.Get(auth.ctx.Req.Context(), cacheKey)
|
||||
auth.logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
|
||||
userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), cacheKey)
|
||||
if err != nil {
|
||||
logger.Debug("Failed getting user ID via auth cache", "error", err)
|
||||
auth.logger.Debug("Failed getting user ID via auth cache", "error", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
logger.Debug("Successfully got user ID via auth cache", "id", userID)
|
||||
auth.logger.Debug("Successfully got user ID via auth cache", "id", userID)
|
||||
return userID.(int64), nil
|
||||
}
|
||||
|
||||
// RemoveUserFromCache removes user from cache.
|
||||
func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error {
|
||||
cacheKey, err := auth.getKey()
|
||||
func (auth *AuthProxy) RemoveUserFromCache(reqCtx *models.ReqContext) error {
|
||||
cacheKey, err := auth.getKey(reqCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Debug("Removing user from auth cache", "cacheKey", cacheKey)
|
||||
if err := auth.remoteCache.Delete(auth.ctx.Req.Context(), cacheKey); err != nil {
|
||||
auth.logger.Debug("Removing user from auth cache", "cacheKey", cacheKey)
|
||||
if err := auth.remoteCache.Delete(reqCtx.Req.Context(), cacheKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Debug("Successfully removed user from auth cache", "cacheKey", cacheKey)
|
||||
auth.logger.Debug("Successfully removed user from auth cache", "cacheKey", cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoginViaLDAP logs in user via LDAP request
|
||||
func (auth *AuthProxy) LoginViaLDAP() (int64, error) {
|
||||
func (auth *AuthProxy) LoginViaLDAP(reqCtx *models.ReqContext) (int64, error) {
|
||||
config, err := getLDAPConfig(auth.cfg)
|
||||
if err != nil {
|
||||
return 0, newError("failed to get LDAP config", err)
|
||||
}
|
||||
|
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||
mldap := newLDAP(config.Servers)
|
||||
extUser, _, err := mldap.User(auth.header)
|
||||
extUser, _, err := mldap.User(header)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Have to sync grafana and LDAP user during log in
|
||||
upsert := &models.UpsertUserCommand{
|
||||
ReqContext: auth.ctx,
|
||||
ReqContext: reqCtx,
|
||||
SignupAllowed: auth.cfg.LDAPAllowSignup,
|
||||
ExternalUser: extUser,
|
||||
}
|
||||
if err := bus.Dispatch(auth.ctx.Req.Context(), upsert); err != nil {
|
||||
if err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return upsert.Result.Id, nil
|
||||
}
|
||||
|
||||
// LoginViaHeader logs in user from the header only
|
||||
func (auth *AuthProxy) LoginViaHeader() (int64, error) {
|
||||
// loginViaHeader logs in user from the header only
|
||||
func (auth *AuthProxy) loginViaHeader(reqCtx *models.ReqContext) (int64, error) {
|
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||
extUser := &models.ExternalUserInfo{
|
||||
AuthModule: "authproxy",
|
||||
AuthId: auth.header,
|
||||
AuthId: header,
|
||||
}
|
||||
|
||||
switch auth.cfg.AuthProxyHeaderProperty {
|
||||
case "username":
|
||||
extUser.Login = auth.header
|
||||
extUser.Login = header
|
||||
|
||||
emailAddr, emailErr := mail.ParseAddress(auth.header) // only set Email if it can be parsed as an email address
|
||||
emailAddr, emailErr := mail.ParseAddress(header) // only set Email if it can be parsed as an email address
|
||||
if emailErr == nil {
|
||||
extUser.Email = emailAddr.Address
|
||||
}
|
||||
case "email":
|
||||
extUser.Email = auth.header
|
||||
extUser.Login = auth.header
|
||||
extUser.Email = header
|
||||
extUser.Login = header
|
||||
default:
|
||||
return 0, fmt.Errorf("auth proxy header property invalid")
|
||||
}
|
||||
|
||||
auth.headersIterator(func(field string, header string) {
|
||||
auth.headersIterator(reqCtx, func(field string, header string) {
|
||||
switch field {
|
||||
case "Groups":
|
||||
extUser.Groups = util.SplitString(header)
|
||||
@ -300,12 +295,12 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
|
||||
})
|
||||
|
||||
upsert := &models.UpsertUserCommand{
|
||||
ReqContext: auth.ctx,
|
||||
ReqContext: reqCtx,
|
||||
SignupAllowed: auth.cfg.AuthProxyAutoSignUp,
|
||||
ExternalUser: extUser,
|
||||
}
|
||||
|
||||
err := bus.Dispatch(auth.ctx.Req.Context(), upsert)
|
||||
err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -314,8 +309,8 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
|
||||
}
|
||||
|
||||
// getDecodedHeader gets decoded value of a header with given headerName
|
||||
func (auth *AuthProxy) getDecodedHeader(headerName string) string {
|
||||
headerValue := auth.ctx.Req.Header.Get(headerName)
|
||||
func (auth *AuthProxy) getDecodedHeader(reqCtx *models.ReqContext, headerName string) string {
|
||||
headerValue := reqCtx.Req.Header.Get(headerName)
|
||||
|
||||
if auth.cfg.AuthProxyHeadersEncoded {
|
||||
headerValue = util.DecodeQuotedPrintable(headerValue)
|
||||
@ -325,27 +320,27 @@ func (auth *AuthProxy) getDecodedHeader(headerName string) string {
|
||||
}
|
||||
|
||||
// headersIterator iterates over all non-empty supported additional headers
|
||||
func (auth *AuthProxy) headersIterator(fn func(field string, header string)) {
|
||||
func (auth *AuthProxy) headersIterator(reqCtx *models.ReqContext, fn func(field string, header string)) {
|
||||
for _, field := range supportedHeaderFields {
|
||||
h := auth.cfg.AuthProxyHeaders[field]
|
||||
if h == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if value := auth.getDecodedHeader(h); value != "" {
|
||||
if value := auth.getDecodedHeader(reqCtx, h); value != "" {
|
||||
fn(field, strings.TrimSpace(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetSignedUser gets full signed in user info.
|
||||
func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, error) {
|
||||
// GetSignedInUser gets full signed in user info.
|
||||
func (auth *AuthProxy) GetSignedInUser(userID int64, orgID int64) (*models.SignedInUser, error) {
|
||||
query := &models.GetSignedInUserQuery{
|
||||
OrgId: auth.orgID,
|
||||
OrgId: orgID,
|
||||
UserId: userID,
|
||||
}
|
||||
|
||||
if err := bus.Dispatch(context.Background(), query); err != nil {
|
||||
if err := auth.sqlStore.GetSignedInUser(context.Background(), query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -353,21 +348,21 @@ func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, erro
|
||||
}
|
||||
|
||||
// Remember user in cache
|
||||
func (auth *AuthProxy) Remember(id int64) error {
|
||||
key, err := auth.getKey()
|
||||
func (auth *AuthProxy) Remember(reqCtx *models.ReqContext, id int64) error {
|
||||
key, err := auth.getKey(reqCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if user already in cache
|
||||
userID, err := auth.remoteCache.Get(auth.ctx.Req.Context(), key)
|
||||
userID, err := auth.remoteCache.Get(reqCtx.Req.Context(), key)
|
||||
if err == nil && userID != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute
|
||||
|
||||
if err := auth.remoteCache.Set(auth.ctx.Req.Context(), key, id, expiration); err != nil {
|
||||
if err := auth.remoteCache.Set(reqCtx.Req.Context(), key, id, expiration); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -7,8 +7,8 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/services/login/loginservice"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/ldap"
|
||||
@ -20,8 +20,9 @@ import (
|
||||
)
|
||||
|
||||
const hdrName = "markelog"
|
||||
const id int64 = 42
|
||||
|
||||
func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, configureReq func(*http.Request, *setting.Cfg)) *AuthProxy {
|
||||
func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, configureReq func(*http.Request, *setting.Cfg)) (*AuthProxy, *models.ReqContext) {
|
||||
t.Helper()
|
||||
|
||||
req, err := http.NewRequest("POST", "http://example.com", nil)
|
||||
@ -40,17 +41,16 @@ func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, confi
|
||||
Context: &web.Context{Req: req},
|
||||
}
|
||||
|
||||
auth := New(cfg, &Options{
|
||||
RemoteCache: remoteCache,
|
||||
Ctx: ctx,
|
||||
OrgID: 4,
|
||||
})
|
||||
loginService := loginservice.LoginServiceMock{
|
||||
ExpectedUser: &models.User{
|
||||
Id: id,
|
||||
},
|
||||
}
|
||||
|
||||
return auth
|
||||
return ProvideAuthProxy(cfg, remoteCache, loginService, nil), ctx
|
||||
}
|
||||
|
||||
func TestMiddlewareContext(t *testing.T) {
|
||||
logger := log.New("test")
|
||||
cache := remotecache.NewFakeStore(t)
|
||||
|
||||
t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) {
|
||||
@ -62,12 +62,12 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
err = cache.Set(context.Background(), key, id, 0)
|
||||
require.NoError(t, err)
|
||||
// Set up the middleware
|
||||
auth := prepareMiddleware(t, cache, nil)
|
||||
gotKey, err := auth.getKey()
|
||||
auth, reqCtx := prepareMiddleware(t, cache, nil)
|
||||
gotKey, err := auth.getKey(reqCtx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, key, gotKey)
|
||||
|
||||
gotID, err := auth.Login(logger, false)
|
||||
gotID, err := auth.Login(reqCtx, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, id, gotID)
|
||||
@ -84,7 +84,7 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
err = cache.Set(context.Background(), key, id, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
||||
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
||||
cfg.AuthProxyHeaderName = "X-Killa"
|
||||
cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS", "Role": "X-WEBAUTH-ROLE"}
|
||||
req.Header.Set(cfg.AuthProxyHeaderName, hdrName)
|
||||
@ -93,26 +93,14 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
})
|
||||
assert.Equal(t, "auth-proxy-sync-ttl:f5acfffd56daac98d502ef8c8b8c5d56", key)
|
||||
|
||||
gotID, err := auth.Login(logger, false)
|
||||
gotID, err := auth.Login(reqCtx, false)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, id, gotID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMiddlewareContext_ldap(t *testing.T) {
|
||||
logger := log.New("test")
|
||||
|
||||
t.Run("Logs in via LDAP", func(t *testing.T) {
|
||||
const id int64 = 42
|
||||
|
||||
bus.AddHandler("test", func(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
cmd.Result = &models.User{
|
||||
Id: id,
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
origIsLDAPEnabled := isLDAPEnabled
|
||||
origGetLDAPConfig := getLDAPConfig
|
||||
origNewLDAP := newLDAP
|
||||
@ -147,9 +135,9 @@ func TestMiddlewareContext_ldap(t *testing.T) {
|
||||
|
||||
cache := remotecache.NewFakeStore(t)
|
||||
|
||||
auth := prepareMiddleware(t, cache, nil)
|
||||
auth, reqCtx := prepareMiddleware(t, cache, nil)
|
||||
|
||||
gotID, err := auth.Login(logger, false)
|
||||
gotID, err := auth.Login(reqCtx, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, id, gotID)
|
||||
@ -177,7 +165,7 @@ func TestMiddlewareContext_ldap(t *testing.T) {
|
||||
|
||||
cache := remotecache.NewFakeStore(t)
|
||||
|
||||
auth := prepareMiddleware(t, cache, nil)
|
||||
auth, reqCtx := prepareMiddleware(t, cache, nil)
|
||||
|
||||
stub := &multildap.MultiLDAPmock{
|
||||
ID: id,
|
||||
@ -187,7 +175,7 @@ func TestMiddlewareContext_ldap(t *testing.T) {
|
||||
return stub
|
||||
}
|
||||
|
||||
gotID, err := auth.Login(logger, false)
|
||||
gotID, err := auth.Login(reqCtx, false)
|
||||
require.EqualError(t, err, "failed to get the user")
|
||||
|
||||
assert.NotEqual(t, id, gotID)
|
||||
@ -198,22 +186,24 @@ func TestMiddlewareContext_ldap(t *testing.T) {
|
||||
func TestDecodeHeader(t *testing.T) {
|
||||
cache := remotecache.NewFakeStore(t)
|
||||
t.Run("should not decode header if not enabled in settings", func(t *testing.T) {
|
||||
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
||||
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
||||
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
|
||||
cfg.AuthProxyHeadersEncoded = false
|
||||
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen")
|
||||
})
|
||||
|
||||
assert.Equal(t, "M=C3=BCnchen", auth.header)
|
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||
assert.Equal(t, "M=C3=BCnchen", header)
|
||||
})
|
||||
|
||||
t.Run("should decode header if enabled in settings", func(t *testing.T) {
|
||||
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
||||
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
|
||||
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
|
||||
cfg.AuthProxyHeadersEncoded = true
|
||||
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen")
|
||||
})
|
||||
|
||||
assert.Equal(t, "München", auth.header)
|
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName)
|
||||
assert.Equal(t, "München", header)
|
||||
})
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ const ServiceName = "ContextHandler"
|
||||
|
||||
func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtService models.JWTService,
|
||||
remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore *sqlstore.SQLStore,
|
||||
tracer tracing.Tracer) *ContextHandler {
|
||||
tracer tracing.Tracer, authProxy *authproxy.AuthProxy) *ContextHandler {
|
||||
return &ContextHandler{
|
||||
Cfg: cfg,
|
||||
AuthTokenService: tokenService,
|
||||
@ -45,6 +45,7 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtS
|
||||
RenderService: renderService,
|
||||
SQLStore: sqlStore,
|
||||
tracer: tracer,
|
||||
authProxy: authProxy,
|
||||
}
|
||||
}
|
||||
|
||||
@ -57,6 +58,7 @@ type ContextHandler struct {
|
||||
RenderService rendering.Service
|
||||
SQLStore sqlstore.Store
|
||||
tracer tracing.Tracer
|
||||
authProxy *authproxy.AuthProxy
|
||||
// GetTime returns the current time.
|
||||
// Stubbable by tests.
|
||||
GetTime func() time.Time
|
||||
@ -419,10 +421,10 @@ func (h *ContextHandler) initContextWithRenderAuth(reqContext *models.ReqContext
|
||||
return true
|
||||
}
|
||||
|
||||
func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) {
|
||||
func logUserIn(reqContext *models.ReqContext, auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) {
|
||||
logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache)
|
||||
// Try to log in user via various providers
|
||||
id, err := auth.Login(logger, ignoreCache)
|
||||
id, err := auth.Login(reqContext, ignoreCache)
|
||||
if err != nil {
|
||||
details := err
|
||||
var e authproxy.Error
|
||||
@ -451,36 +453,31 @@ func (h *ContextHandler) handleError(ctx *models.ReqContext, err error, statusCo
|
||||
|
||||
func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext, orgID int64) bool {
|
||||
username := reqContext.Req.Header.Get(h.Cfg.AuthProxyHeaderName)
|
||||
auth := authproxy.New(h.Cfg, &authproxy.Options{
|
||||
RemoteCache: h.RemoteCache,
|
||||
Ctx: reqContext,
|
||||
OrgID: orgID,
|
||||
})
|
||||
|
||||
logger := log.New("auth.proxy")
|
||||
|
||||
// Bail if auth proxy is not enabled
|
||||
if !auth.IsEnabled() {
|
||||
if !h.authProxy.IsEnabled() {
|
||||
return false
|
||||
}
|
||||
|
||||
// If there is no header - we can't move forward
|
||||
if !auth.HasHeader() {
|
||||
if !h.authProxy.HasHeader(reqContext) {
|
||||
return false
|
||||
}
|
||||
|
||||
_, span := h.tracer.Start(reqContext.Req.Context(), "initContextWithAuthProxy")
|
||||
defer span.End()
|
||||
|
||||
// Check if allowed to continue with this IP
|
||||
if err := auth.IsAllowedIP(); err != nil {
|
||||
// Check if allowed continuing with this IP
|
||||
if err := h.authProxy.IsAllowedIP(reqContext.Req.RemoteAddr); err != nil {
|
||||
h.handleError(reqContext, err, 407, func(details error) {
|
||||
logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details)
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
id, err := logUserIn(auth, username, logger, false)
|
||||
id, err := logUserIn(reqContext, h.authProxy, username, logger, false)
|
||||
if err != nil {
|
||||
h.handleError(reqContext, err, 407, nil)
|
||||
return true
|
||||
@ -488,7 +485,7 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
|
||||
|
||||
logger.Debug("Got user ID, getting full user info", "userID", id)
|
||||
|
||||
user, err := auth.GetSignedInUser(id)
|
||||
user, err := h.authProxy.GetSignedInUser(id, orgID)
|
||||
if err != nil {
|
||||
// The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale
|
||||
// cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated
|
||||
@ -496,18 +493,18 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
|
||||
// we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to
|
||||
// log the user in again without the cache.
|
||||
logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id)
|
||||
if err := auth.RemoveUserFromCache(logger); err != nil {
|
||||
if err := h.authProxy.RemoveUserFromCache(reqContext); err != nil {
|
||||
if !errors.Is(err, remotecache.ErrCacheItemNotFound) {
|
||||
logger.Error("Got unexpected error when removing user from auth cache", "error", err)
|
||||
}
|
||||
}
|
||||
id, err = logUserIn(auth, username, logger, true)
|
||||
id, err = logUserIn(reqContext, h.authProxy, username, logger, true)
|
||||
if err != nil {
|
||||
h.handleError(reqContext, err, 407, nil)
|
||||
return true
|
||||
}
|
||||
|
||||
user, err = auth.GetSignedInUser(id)
|
||||
user, err = h.authProxy.GetSignedInUser(id, orgID)
|
||||
if err != nil {
|
||||
h.handleError(reqContext, err, 407, nil)
|
||||
return true
|
||||
@ -521,7 +518,7 @@ func (h *ContextHandler) initContextWithAuthProxy(reqContext *models.ReqContext,
|
||||
reqContext.IsSignedIn = true
|
||||
|
||||
// Remember user data in cache
|
||||
if err := auth.Remember(id); err != nil {
|
||||
if err := h.authProxy.Remember(reqContext, id); err != nil {
|
||||
h.handleError(reqContext, err, 500, func(details error) {
|
||||
logger.Error(
|
||||
"Failed to store user in cache",
|
||||
|
@ -15,6 +15,9 @@ type LoginServiceMock struct {
|
||||
NoExistingOrgId int64
|
||||
AlreadyExitingLogin string
|
||||
GeneratedUserId int64
|
||||
ExpectedUser *models.User
|
||||
ExpectedUserFunc func(cmd *models.UpsertUserCommand) *models.User
|
||||
ExpectedError error
|
||||
}
|
||||
|
||||
func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User, error) {
|
||||
@ -35,5 +38,10 @@ func (s LoginServiceMock) CreateUser(cmd models.CreateUserCommand) (*models.User
|
||||
}
|
||||
|
||||
func (s LoginServiceMock) UpsertUser(ctx context.Context, cmd *models.UpsertUserCommand) error {
|
||||
return nil
|
||||
if s.ExpectedUserFunc != nil {
|
||||
cmd.Result = s.ExpectedUserFunc(cmd)
|
||||
return s.ExpectedError
|
||||
}
|
||||
cmd.Result = s.ExpectedUser
|
||||
return s.ExpectedError
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user