Move middleware context handler logic to service (#29605)

* middleware: Move context handler to own service

Signed-off-by: Arve Knudsen <arve.knudsen@gmail.com>

Co-authored-by: Emil Tullsted <sakjur@users.noreply.github.com>
Co-authored-by: Will Browne <wbrowne@users.noreply.github.com>
This commit is contained in:
Arve Knudsen
2020-12-11 11:44:44 +01:00
committed by GitHub
parent d0f52d5334
commit 12661e8a9d
51 changed files with 1321 additions and 1079 deletions

View File

@@ -22,7 +22,7 @@ const (
func TestAdminAPIEndpoint(t *testing.T) { func TestAdminAPIEndpoint(t *testing.T) {
const role = models.ROLE_ADMIN const role = models.ROLE_ADMIN
t.Run("Given a server admin attempts to remove themself as an admin", func(t *testing.T) { t.Run("Given a server admin attempts to remove themselves as an admin", func(t *testing.T) {
updateCmd := dtos.AdminUpdateUserPermissionsForm{ updateCmd := dtos.AdminUpdateUserPermissionsForm{
IsGrafanaAdmin: false, IsGrafanaAdmin: false,
} }

View File

@@ -18,7 +18,7 @@ func (hs *HTTPServer) registerRoutes() {
reqEditorRole := middleware.ReqEditorRole reqEditorRole := middleware.ReqEditorRole
reqOrgAdmin := middleware.ReqOrgAdmin reqOrgAdmin := middleware.ReqOrgAdmin
reqCanAccessTeams := middleware.AdminOrFeatureEnabled(hs.Cfg.EditorsCanAdmin) reqCanAccessTeams := middleware.AdminOrFeatureEnabled(hs.Cfg.EditorsCanAdmin)
reqSnapshotPublicModeOrSignedIn := middleware.SnapshotPublicModeOrSignedIn() reqSnapshotPublicModeOrSignedIn := middleware.SnapshotPublicModeOrSignedIn(hs.Cfg)
redirectFromLegacyDashboardURL := middleware.RedirectFromLegacyDashboardURL() redirectFromLegacyDashboardURL := middleware.RedirectFromLegacyDashboardURL()
redirectFromLegacyDashboardSoloURL := middleware.RedirectFromLegacyDashboardSoloURL() redirectFromLegacyDashboardSoloURL := middleware.RedirectFromLegacyDashboardSoloURL()
redirectFromLegacyPanelEditURL := middleware.RedirectFromLegacyPanelEditURL() redirectFromLegacyPanelEditURL := middleware.RedirectFromLegacyPanelEditURL()

View File

@@ -85,7 +85,7 @@ func Success(message string) *NormalResponse {
return JSON(200, resp) return JSON(200, resp)
} }
// Error create a erroneous response // Error creates an error response.
func Error(status int, message string, err error) *NormalResponse { func Error(status int, message string, err error) *NormalResponse {
data := make(map[string]interface{}) data := make(map[string]interface{})

View File

@@ -8,9 +8,14 @@ import (
"testing" "testing"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/middleware" "github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/macaron.v1" "gopkg.in/macaron.v1"
) )
@@ -141,20 +146,68 @@ func (sc *scenarioContext) exec() {
type scenarioFunc func(c *scenarioContext) type scenarioFunc func(c *scenarioContext)
type handlerFunc func(c *models.ReqContext) Response type handlerFunc func(c *models.ReqContext) Response
func getContextHandler(t *testing.T) *contexthandler.ContextHandler {
t.Helper()
sqlStore := sqlstore.InitTestDB(t)
remoteCacheSvc := &remotecache.RemoteCache{}
cfg := setting.NewCfg()
cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{
Name: "database",
}
userAuthTokenSvc := auth.NewFakeUserAuthTokenService()
renderSvc := &fakeRenderService{}
ctxHdlr := &contexthandler.ContextHandler{}
err := registry.BuildServiceGraph([]interface{}{cfg}, []*registry.Descriptor{
{
Name: sqlstore.ServiceName,
Instance: sqlStore,
},
{
Name: remotecache.ServiceName,
Instance: remoteCacheSvc,
},
{
Name: auth.ServiceName,
Instance: userAuthTokenSvc,
},
{
Name: rendering.ServiceName,
Instance: renderSvc,
},
{
Name: contexthandler.ServiceName,
Instance: ctxHdlr,
},
})
require.NoError(t, err)
return ctxHdlr
}
func setupScenarioContext(t *testing.T, url string) *scenarioContext { func setupScenarioContext(t *testing.T, url string) *scenarioContext {
sc := &scenarioContext{ sc := &scenarioContext{
url: url, url: url,
t: t, t: t,
} }
viewsPath, _ := filepath.Abs("../../public/views") viewsPath, err := filepath.Abs("../../public/views")
require.NoError(t, err)
sc.m = macaron.New() sc.m = macaron.New()
sc.m.Use(macaron.Renderer(macaron.RenderOptions{ sc.m.Use(macaron.Renderer(macaron.RenderOptions{
Directory: viewsPath, Directory: viewsPath,
Delims: macaron.Delims{Left: "[[", Right: "]]"}, Delims: macaron.Delims{Left: "[[", Right: "]]"},
})) }))
sc.m.Use(getContextHandler(t).Middleware)
sc.m.Use(middleware.GetContextHandler(nil, nil, nil))
return sc return sc
} }
type fakeRenderService struct {
rendering.Service
}
func (s *fakeRenderService) Init() error {
return nil
}

View File

@@ -193,11 +193,11 @@ func (hs *HTTPServer) getFrontendSettingsMap(c *models.ReqContext) (map[string]i
"datasources": dataSources, "datasources": dataSources,
"minRefreshInterval": setting.MinRefreshInterval, "minRefreshInterval": setting.MinRefreshInterval,
"panels": panels, "panels": panels,
"appUrl": setting.AppUrl, "appUrl": hs.Cfg.AppURL,
"appSubUrl": setting.AppSubUrl, "appSubUrl": hs.Cfg.AppSubURL,
"allowOrgCreate": (setting.AllowUserOrgCreate && c.IsSignedIn) || c.IsGrafanaAdmin, "allowOrgCreate": (setting.AllowUserOrgCreate && c.IsSignedIn) || c.IsGrafanaAdmin,
"authProxyEnabled": setting.AuthProxyEnabled, "authProxyEnabled": setting.AuthProxyEnabled,
"ldapEnabled": setting.LDAPEnabled, "ldapEnabled": hs.Cfg.LDAPEnabled,
"alertingEnabled": setting.AlertingEnabled, "alertingEnabled": setting.AlertingEnabled,
"alertingErrorOrTimeout": setting.AlertingErrorOrTimeout, "alertingErrorOrTimeout": setting.AlertingErrorOrTimeout,
"alertingNoDataOrNullValues": setting.AlertingNoDataOrNullValues, "alertingNoDataOrNullValues": setting.AlertingNoDataOrNullValues,

View File

@@ -18,7 +18,6 @@ import (
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/middleware"
"gopkg.in/macaron.v1" "gopkg.in/macaron.v1"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
@@ -53,7 +52,7 @@ func setupTestEnvironment(t *testing.T, cfg *setting.Cfg) (*macaron.Macaron, *HT
} }
m := macaron.New() m := macaron.New()
m.Use(middleware.GetContextHandler(nil, nil, nil)) m.Use(getContextHandler(t).Middleware)
m.Use(macaron.Renderer(macaron.RenderOptions{ m.Use(macaron.Renderer(macaron.RenderOptions{
Directory: filepath.Join(setting.StaticRootPath, "views"), Directory: filepath.Join(setting.StaticRootPath, "views"),
IndentJSON: true, IndentJSON: true,
@@ -84,10 +83,12 @@ func TestHTTPServer_GetFrontendSettings_hideVersionAnonyomus(t *testing.T) {
setting.Env = "testing" setting.Env = "testing"
tests := []struct { tests := []struct {
desc string
hideVersion bool hideVersion bool
expected settings expected settings
}{ }{
{ {
desc: "Not hiding version",
hideVersion: false, hideVersion: false,
expected: settings{ expected: settings{
BuildInfo: buildInfo{ BuildInfo: buildInfo{
@@ -98,6 +99,7 @@ func TestHTTPServer_GetFrontendSettings_hideVersionAnonyomus(t *testing.T) {
}, },
}, },
{ {
desc: "Hiding version",
hideVersion: true, hideVersion: true,
expected: settings{ expected: settings{
BuildInfo: buildInfo{ BuildInfo: buildInfo{
@@ -110,6 +112,7 @@ func TestHTTPServer_GetFrontendSettings_hideVersionAnonyomus(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
hs.Cfg.AnonymousHideVersion = test.hideVersion hs.Cfg.AnonymousHideVersion = test.hideVersion
expected := test.expected expected := test.expected
@@ -118,8 +121,9 @@ func TestHTTPServer_GetFrontendSettings_hideVersionAnonyomus(t *testing.T) {
got := settings{} got := settings{}
err := json.Unmarshal(recorder.Body.Bytes(), &got) err := json.Unmarshal(recorder.Body.Bytes(), &got)
require.NoError(t, err) require.NoError(t, err)
require.GreaterOrEqual(t, 400, recorder.Code, "status codes higher than 400 indicates a failure") require.GreaterOrEqual(t, 400, recorder.Code, "status codes higher than 400 indicate a failure")
assert.EqualValues(t, expected, got) assert.EqualValues(t, expected, got)
})
} }
} }

View File

@@ -29,6 +29,7 @@ import (
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/registry" "github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/hooks" "github.com/grafana/grafana/pkg/services/hooks"
"github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login"
@@ -75,6 +76,7 @@ type HTTPServer struct {
SearchService *search.SearchService `inject:""` SearchService *search.SearchService `inject:""`
ShortURLService *shorturls.ShortURLService `inject:""` ShortURLService *shorturls.ShortURLService `inject:""`
Live *live.GrafanaLive `inject:""` Live *live.GrafanaLive `inject:""`
ContextHandler *contexthandler.ContextHandler `inject:""`
Listener net.Listener Listener net.Listener
} }
@@ -100,7 +102,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
Addr: net.JoinHostPort(setting.HttpAddr, setting.HttpPort), Addr: net.JoinHostPort(setting.HttpAddr, setting.HttpPort),
Handler: hs.macaron, Handler: hs.macaron,
} }
switch setting.Protocol { switch hs.Cfg.Protocol {
case setting.HTTP2Scheme: case setting.HTTP2Scheme:
if err := hs.configureHttp2(); err != nil { if err := hs.configureHttp2(); err != nil {
return err return err
@@ -118,7 +120,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
} }
hs.log.Info("HTTP Server Listen", "address", listener.Addr().String(), "protocol", hs.log.Info("HTTP Server Listen", "address", listener.Addr().String(), "protocol",
setting.Protocol, "subUrl", setting.AppSubUrl, "socket", setting.SocketPath) hs.Cfg.Protocol, "subUrl", hs.Cfg.AppSubURL, "socket", hs.Cfg.SocketPath)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@@ -133,7 +135,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
} }
}() }()
switch setting.Protocol { switch hs.Cfg.Protocol {
case setting.HTTPScheme, setting.SocketScheme: case setting.HTTPScheme, setting.SocketScheme:
if err := hs.httpSrv.Serve(listener); err != nil { if err := hs.httpSrv.Serve(listener); err != nil {
if errors.Is(err, http.ErrServerClosed) { if errors.Is(err, http.ErrServerClosed) {
@@ -151,7 +153,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
return err return err
} }
default: default:
panic(fmt.Sprintf("Unhandled protocol %q", setting.Protocol)) panic(fmt.Sprintf("Unhandled protocol %q", hs.Cfg.Protocol))
} }
wg.Wait() wg.Wait()
@@ -164,7 +166,7 @@ func (hs *HTTPServer) getListener() (net.Listener, error) {
return hs.Listener, nil return hs.Listener, nil
} }
switch setting.Protocol { switch hs.Cfg.Protocol {
case setting.HTTPScheme, setting.HTTPSScheme, setting.HTTP2Scheme: case setting.HTTPScheme, setting.HTTPSScheme, setting.HTTP2Scheme:
listener, err := net.Listen("tcp", hs.httpSrv.Addr) listener, err := net.Listen("tcp", hs.httpSrv.Addr)
if err != nil { if err != nil {
@@ -172,21 +174,21 @@ func (hs *HTTPServer) getListener() (net.Listener, error) {
} }
return listener, nil return listener, nil
case setting.SocketScheme: case setting.SocketScheme:
listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: setting.SocketPath, Net: "unix"}) listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: hs.Cfg.SocketPath, Net: "unix"})
if err != nil { if err != nil {
return nil, errutil.Wrapf(err, "failed to open listener for socket %s", setting.SocketPath) return nil, errutil.Wrapf(err, "failed to open listener for socket %s", hs.Cfg.SocketPath)
} }
// Make socket writable by group // Make socket writable by group
// nolint:gosec // nolint:gosec
if err := os.Chmod(setting.SocketPath, 0660); err != nil { if err := os.Chmod(hs.Cfg.SocketPath, 0660); err != nil {
return nil, errutil.Wrapf(err, "failed to change socket permissions") return nil, errutil.Wrapf(err, "failed to change socket permissions")
} }
return listener, nil return listener, nil
default: default:
hs.log.Error("Invalid protocol", "protocol", setting.Protocol) hs.log.Error("Invalid protocol", "protocol", hs.Cfg.Protocol)
return nil, fmt.Errorf("invalid protocol %q", setting.Protocol) return nil, fmt.Errorf("invalid protocol %q", hs.Cfg.Protocol)
} }
} }
@@ -271,7 +273,7 @@ func (hs *HTTPServer) configureHttp2() error {
} }
func (hs *HTTPServer) newMacaron() *macaron.Macaron { func (hs *HTTPServer) newMacaron() *macaron.Macaron {
macaron.Env = setting.Env macaron.Env = hs.Cfg.Env
m := macaron.New() m := macaron.New()
// automatically set HEAD for every GET // automatically set HEAD for every GET
@@ -294,13 +296,13 @@ func (hs *HTTPServer) applyRoutes() {
func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() { func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
m := hs.macaron m := hs.macaron
m.Use(middleware.Logger()) m.Use(middleware.Logger(hs.Cfg))
if setting.EnableGzip { if setting.EnableGzip {
m.Use(middleware.Gziper()) m.Use(middleware.Gziper())
} }
m.Use(middleware.Recovery()) m.Use(middleware.Recovery(hs.Cfg))
for _, route := range plugins.StaticRoutes { for _, route := range plugins.StaticRoutes {
pluginRoute := path.Join("/public/plugins/", route.PluginId) pluginRoute := path.Join("/public/plugins/", route.PluginId)
@@ -316,7 +318,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
hs.mapStatic(m, hs.Cfg.ImagesDir, "", "/public/img/attachments") hs.mapStatic(m, hs.Cfg.ImagesDir, "", "/public/img/attachments")
} }
m.Use(middleware.AddDefaultResponseHeaders()) m.Use(middleware.AddDefaultResponseHeaders(hs.Cfg))
if setting.ServeFromSubPath && setting.AppSubUrl != "" { if setting.ServeFromSubPath && setting.AppSubUrl != "" {
m.SetURLPrefix(setting.AppSubUrl) m.SetURLPrefix(setting.AppSubUrl)
@@ -334,16 +336,12 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
m.Use(hs.apiHealthHandler) m.Use(hs.apiHealthHandler)
m.Use(hs.metricsEndpoint) m.Use(hs.metricsEndpoint)
m.Use(middleware.GetContextHandler( m.Use(hs.ContextHandler.Middleware)
hs.AuthTokenService,
hs.RemoteCacheService,
hs.RenderService,
))
m.Use(middleware.OrgRedirect()) m.Use(middleware.OrgRedirect())
// needs to be after context handler // needs to be after context handler
if setting.EnforceDomain { if setting.EnforceDomain {
m.Use(middleware.ValidateHostHeader(setting.Domain)) m.Use(middleware.ValidateHostHeader(hs.Cfg.Domain))
} }
m.Use(middleware.HandleNoCacheHeader()) m.Use(middleware.HandleNoCacheHeader())
@@ -433,7 +431,7 @@ func (hs *HTTPServer) mapStatic(m *macaron.Macaron, rootDir string, dir string,
} }
} }
if setting.Env == setting.Dev { if hs.Cfg.Env == setting.Dev {
headers = func(c *macaron.Context) { headers = func(c *macaron.Context) {
c.Resp.Header().Set("Cache-Control", "max-age=0, must-revalidate, no-cache") c.Resp.Header().Set("Cache-Control", "max-age=0, must-revalidate, no-cache")
} }

View File

@@ -300,7 +300,7 @@ func (hs *HTTPServer) getNavTree(c *models.ReqContext, hasEditPerm bool) ([]*dto
{Text: "Stats", Id: "server-stats", Url: setting.AppSubUrl + "/admin/stats", Icon: "graph-bar"}, {Text: "Stats", Id: "server-stats", Url: setting.AppSubUrl + "/admin/stats", Icon: "graph-bar"},
} }
if setting.LDAPEnabled { if hs.Cfg.LDAPEnabled {
adminNavLinks = append(adminNavLinks, &dtos.NavLink{ adminNavLinks = append(adminNavLinks, &dtos.NavLink{
Text: "LDAP", Id: "ldap", Url: setting.AppSubUrl + "/admin/ldap", Icon: "book", Text: "LDAP", Id: "ldap", Url: setting.AppSubUrl + "/admin/ldap", Icon: "book",
}) })
@@ -371,7 +371,7 @@ func (hs *HTTPServer) setIndexViewData(c *models.ReqContext) (*dtos.IndexViewDat
// special case when doing localhost call from image renderer // special case when doing localhost call from image renderer
if c.IsRenderCall && !hs.Cfg.ServeFromSubPath { if c.IsRenderCall && !hs.Cfg.ServeFromSubPath {
appURL = fmt.Sprintf("%s://localhost:%s", setting.Protocol, setting.HttpPort) appURL = fmt.Sprintf("%s://localhost:%s", hs.Cfg.Protocol, setting.HttpPort)
appSubURL = "" appSubURL = ""
settings["appSubUrl"] = "" settings["appSubUrl"] = ""
} }

View File

@@ -116,7 +116,7 @@ func (hs *HTTPServer) GetLDAPStatus(c *models.ReqContext) Response {
return Error(http.StatusBadRequest, "LDAP is not enabled", nil) return Error(http.StatusBadRequest, "LDAP is not enabled", nil)
} }
ldapConfig, err := getLDAPConfig() ldapConfig, err := getLDAPConfig(hs.Cfg)
if err != nil { if err != nil {
return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration. Please verify the configuration and try again", err) return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration. Please verify the configuration and try again", err)
@@ -158,7 +158,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) Response {
return Error(http.StatusBadRequest, "LDAP is not enabled", nil) return Error(http.StatusBadRequest, "LDAP is not enabled", nil)
} }
ldapConfig, err := getLDAPConfig() ldapConfig, err := getLDAPConfig(hs.Cfg)
if err != nil { if err != nil {
return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration. Please verify the configuration and try again", err) return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration. Please verify the configuration and try again", err)
} }
@@ -217,7 +217,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) Response {
upsertCmd := &models.UpsertUserCommand{ upsertCmd := &models.UpsertUserCommand{
ReqContext: c, ReqContext: c,
ExternalUser: user, ExternalUser: user,
SignupAllowed: setting.LDAPAllowSignup, SignupAllowed: hs.Cfg.LDAPAllowSignup,
} }
err = bus.Dispatch(upsertCmd) err = bus.Dispatch(upsertCmd)
@@ -235,7 +235,7 @@ func (hs *HTTPServer) GetUserFromLDAP(c *models.ReqContext) Response {
return Error(http.StatusBadRequest, "LDAP is not enabled", nil) return Error(http.StatusBadRequest, "LDAP is not enabled", nil)
} }
ldapConfig, err := getLDAPConfig() ldapConfig, err := getLDAPConfig(hs.Cfg)
if err != nil { if err != nil {
return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration", err) return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration", err)

View File

@@ -74,7 +74,7 @@ func getUserFromLDAPContext(t *testing.T, requestURL string) *scenarioContext {
} }
func TestGetUserFromLDAPAPIEndpoint_UserNotFound(t *testing.T) { func TestGetUserFromLDAPAPIEndpoint_UserNotFound(t *testing.T) {
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
} }
@@ -131,7 +131,7 @@ func TestGetUserFromLDAPAPIEndpoint_OrgNotfound(t *testing.T) {
return nil return nil
}) })
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
} }
@@ -193,7 +193,7 @@ func TestGetUserFromLDAPAPIEndpoint(t *testing.T) {
return nil return nil
}) })
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
} }
@@ -273,7 +273,7 @@ func TestGetUserFromLDAPAPIEndpoint_WithTeamHandler(t *testing.T) {
return nil return nil
}) })
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
} }
@@ -349,7 +349,7 @@ func TestGetLDAPStatusAPIEndpoint(t *testing.T) {
{Host: "10.0.0.5", Port: 361, Available: false, Error: errors.New("something is awfully wrong")}, {Host: "10.0.0.5", Port: 361, Available: false, Error: errors.New("something is awfully wrong")},
} }
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
} }
@@ -412,7 +412,7 @@ func postSyncUserWithLDAPContext(t *testing.T, requestURL string, preHook func(t
func TestPostSyncUserWithLDAPAPIEndpoint_Success(t *testing.T) { func TestPostSyncUserWithLDAPAPIEndpoint_Success(t *testing.T) {
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) {
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
} }
@@ -457,7 +457,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_Success(t *testing.T) {
func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotFound(t *testing.T) { func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotFound(t *testing.T) {
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) {
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
} }
@@ -485,7 +485,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotFound(t *testing.T) {
func TestPostSyncUserWithLDAPAPIEndpoint_WhenGrafanaAdmin(t *testing.T) { func TestPostSyncUserWithLDAPAPIEndpoint_WhenGrafanaAdmin(t *testing.T) {
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) {
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
} }
@@ -528,7 +528,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_WhenGrafanaAdmin(t *testing.T) {
func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotInLDAP(t *testing.T) { func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotInLDAP(t *testing.T) {
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) {
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
} }

View File

@@ -13,7 +13,7 @@ import (
"github.com/grafana/grafana/pkg/infra/metrics" "github.com/grafana/grafana/pkg/infra/metrics"
"github.com/grafana/grafana/pkg/infra/network" "github.com/grafana/grafana/pkg/infra/network"
"github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/middleware" "github.com/grafana/grafana/pkg/middleware/cookies"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
@@ -61,12 +61,12 @@ func (hs *HTTPServer) ValidateRedirectTo(redirectTo string) error {
return nil return nil
} }
func (hs *HTTPServer) CookieOptionsFromCfg() middleware.CookieOptions { func (hs *HTTPServer) CookieOptionsFromCfg() cookies.CookieOptions {
path := "/" path := "/"
if len(hs.Cfg.AppSubURL) > 0 { if len(hs.Cfg.AppSubURL) > 0 {
path = hs.Cfg.AppSubURL path = hs.Cfg.AppSubURL
} }
return middleware.CookieOptions{ return cookies.CookieOptions{
Path: path, Path: path,
Secure: hs.Cfg.CookieSecure, Secure: hs.Cfg.CookieSecure,
SameSiteDisabled: hs.Cfg.CookieSameSiteDisabled, SameSiteDisabled: hs.Cfg.CookieSameSiteDisabled,
@@ -101,7 +101,7 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) {
// therefore the loginError should be passed to the view data // therefore the loginError should be passed to the view data
// and the view should return immediately before attempting // and the view should return immediately before attempting
// to login again via OAuth and enter to a redirect loop // to login again via OAuth and enter to a redirect loop
middleware.DeleteCookie(c.Resp, LoginErrorCookieName, hs.CookieOptionsFromCfg) cookies.DeleteCookie(c.Resp, LoginErrorCookieName, hs.CookieOptionsFromCfg)
viewData.Settings["loginError"] = loginError viewData.Settings["loginError"] = loginError
c.HTML(200, getViewIndex(), viewData) c.HTML(200, getViewIndex(), viewData)
return return
@@ -113,7 +113,7 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) {
if c.IsSignedIn { if c.IsSignedIn {
// Assign login token to auth proxy users if enable_login_token = true // Assign login token to auth proxy users if enable_login_token = true
if setting.AuthProxyEnabled && setting.AuthProxyEnableLoginToken { if hs.Cfg.AuthProxyEnabled && hs.Cfg.AuthProxyEnableLoginToken {
user := &models.User{Id: c.SignedInUser.UserId, Email: c.SignedInUser.Email, Login: c.SignedInUser.Login} user := &models.User{Id: c.SignedInUser.UserId, Email: c.SignedInUser.Email, Login: c.SignedInUser.Login}
err := hs.loginUserWithUser(user, c) err := hs.loginUserWithUser(user, c)
if err != nil { if err != nil {
@@ -129,7 +129,7 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) {
log.Debugf("Ignored invalid redirect_to cookie value: %v", redirectTo) log.Debugf("Ignored invalid redirect_to cookie value: %v", redirectTo)
redirectTo = hs.Cfg.AppSubURL + "/" redirectTo = hs.Cfg.AppSubURL + "/"
} }
middleware.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg) cookies.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg)
c.Redirect(redirectTo) c.Redirect(redirectTo)
return return
} }
@@ -196,6 +196,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext, cmd dtos.LoginCommand) Res
Username: cmd.User, Username: cmd.User,
Password: cmd.Password, Password: cmd.Password,
IpAddress: c.Req.RemoteAddr, IpAddress: c.Req.RemoteAddr,
Cfg: hs.Cfg,
} }
err := bus.Dispatch(authQuery) err := bus.Dispatch(authQuery)
@@ -236,7 +237,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext, cmd dtos.LoginCommand) Res
} else { } else {
log.Infof("Ignored invalid redirect_to cookie value: %v", redirectTo) log.Infof("Ignored invalid redirect_to cookie value: %v", redirectTo)
} }
middleware.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg) cookies.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg)
} }
metrics.MApiLoginPost.Inc() metrics.MApiLoginPost.Inc()
@@ -263,7 +264,7 @@ func (hs *HTTPServer) loginUserWithUser(user *models.User, c *models.ReqContext)
} }
hs.log.Info("Successful Login", "User", user.Email) hs.log.Info("Successful Login", "User", user.Email)
middleware.WriteSessionCookie(c, userToken.UnhashedToken, hs.Cfg.LoginMaxLifetime) cookies.WriteSessionCookie(c, hs.Cfg, userToken.UnhashedToken, hs.Cfg.LoginMaxLifetime)
return nil return nil
} }
@@ -278,7 +279,7 @@ func (hs *HTTPServer) Logout(c *models.ReqContext) {
hs.log.Error("failed to revoke auth token", "error", err) hs.log.Error("failed to revoke auth token", "error", err)
} }
middleware.WriteSessionCookie(c, "", -1) cookies.WriteSessionCookie(c, hs.Cfg, "", -1)
if setting.SignoutRedirectUrl != "" { if setting.SignoutRedirectUrl != "" {
c.Redirect(setting.SignoutRedirectUrl) c.Redirect(setting.SignoutRedirectUrl)
@@ -309,7 +310,7 @@ func (hs *HTTPServer) trySetEncryptedCookie(ctx *models.ReqContext, cookieName s
return err return err
} }
middleware.WriteCookie(ctx.Resp, cookieName, hex.EncodeToString(encryptedError), 60, hs.CookieOptionsFromCfg) cookies.WriteCookie(ctx.Resp, cookieName, hex.EncodeToString(encryptedError), 60, hs.CookieOptionsFromCfg)
return nil return nil
} }

View File

@@ -18,7 +18,7 @@ import (
"github.com/grafana/grafana/pkg/infra/metrics" "github.com/grafana/grafana/pkg/infra/metrics"
"github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/login/social" "github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/middleware" "github.com/grafana/grafana/pkg/middleware/cookies"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
) )
@@ -81,7 +81,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
} }
hashedState := hashStatecode(state, setting.OAuthService.OAuthInfos[name].ClientSecret) hashedState := hashStatecode(state, setting.OAuthService.OAuthInfos[name].ClientSecret)
middleware.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg) cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
if setting.OAuthService.OAuthInfos[name].HostedDomain == "" { if setting.OAuthService.OAuthInfos[name].HostedDomain == "" {
ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline)) ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline))
} else { } else {
@@ -93,7 +93,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
cookieState := ctx.GetCookie(OauthStateCookieName) cookieState := ctx.GetCookie(OauthStateCookieName)
// delete cookie // delete cookie
middleware.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg) cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
if cookieState == "" { if cookieState == "" {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{ hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
@@ -192,7 +192,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
if redirectTo, err := url.QueryUnescape(ctx.GetCookie("redirect_to")); err == nil && len(redirectTo) > 0 { if redirectTo, err := url.QueryUnescape(ctx.GetCookie("redirect_to")); err == nil && len(redirectTo) > 0 {
if err := hs.ValidateRedirectTo(redirectTo); err == nil { if err := hs.ValidateRedirectTo(redirectTo); err == nil {
middleware.DeleteCookie(ctx.Resp, "redirect_to", hs.CookieOptionsFromCfg) cookies.DeleteCookie(ctx.Resp, "redirect_to", hs.CookieOptionsFromCfg)
ctx.Redirect(redirectTo) ctx.Redirect(redirectTo)
return return
} }

View File

@@ -592,8 +592,8 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte
setting.OAuthService = &setting.OAuther{} setting.OAuthService = &setting.OAuther{}
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo) setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
setting.AuthProxyEnabled = true hs.Cfg.AuthProxyEnabled = true
setting.AuthProxyEnableLoginToken = enableLoginToken hs.Cfg.AuthProxyEnableLoginToken = enableLoginToken
sc.m.Get(sc.url, sc.defaultHandler) sc.m.Get(sc.url, sc.defaultHandler)
sc.fakeReqNoAssertions("GET", sc.url).exec() sc.fakeReqNoAssertions("GET", sc.url).exec()

View File

@@ -23,8 +23,17 @@ var (
defaultMaxCacheExpiration = time.Hour * 24 defaultMaxCacheExpiration = time.Hour * 24
) )
const (
ServiceName = "RemoteCache"
)
func init() { func init() {
registry.RegisterService(&RemoteCache{}) rc := &RemoteCache{}
registry.Register(&registry.Descriptor{
Name: ServiceName,
Instance: rc,
InitPriority: registry.Medium,
})
} }
// CacheStorage allows the caller to set, get and delete items in the cache. // CacheStorage allows the caller to set, get and delete items in the cache.

View File

@@ -25,12 +25,12 @@ var (
var loginLogger = log.New("login") var loginLogger = log.New("login")
func Init() { func Init() {
bus.AddHandler("auth", AuthenticateUser) bus.AddHandler("auth", authenticateUser)
} }
// AuthenticateUser authenticates the user via username & password // authenticateUser authenticates the user via username & password
func AuthenticateUser(query *models.LoginUserQuery) error { func authenticateUser(query *models.LoginUserQuery) error {
if err := validateLoginAttempts(query.Username); err != nil { if err := validateLoginAttempts(query); err != nil {
return err return err
} }

View File

@@ -21,7 +21,7 @@ func TestAuthenticateUser(t *testing.T) {
Username: "user", Username: "user",
Password: "", Password: "",
} }
err := AuthenticateUser(&loginQuery) err := authenticateUser(&loginQuery)
Convey("login should fail", func() { Convey("login should fail", func() {
So(sc.grafanaLoginWasCalled, ShouldBeFalse) So(sc.grafanaLoginWasCalled, ShouldBeFalse)
@@ -37,7 +37,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, nil, sc) mockLoginUsingLDAP(true, nil, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery) err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() { Convey("it should result in", func() {
So(err, ShouldEqual, ErrTooManyLoginAttempts) So(err, ShouldEqual, ErrTooManyLoginAttempts)
@@ -55,7 +55,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc) mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery) err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() { Convey("it should result in", func() {
So(err, ShouldEqual, nil) So(err, ShouldEqual, nil)
@@ -74,7 +74,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc) mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery) err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() { Convey("it should result in", func() {
So(err, ShouldEqual, customErr) So(err, ShouldEqual, customErr)
@@ -92,7 +92,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(false, nil, sc) mockLoginUsingLDAP(false, nil, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery) err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() { Convey("it should result in", func() {
So(err, ShouldEqual, models.ErrUserNotFound) So(err, ShouldEqual, models.ErrUserNotFound)
@@ -110,7 +110,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc) mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery) err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() { Convey("it should result in", func() {
So(err, ShouldEqual, ErrInvalidCredentials) So(err, ShouldEqual, ErrInvalidCredentials)
@@ -128,7 +128,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, nil, sc) mockLoginUsingLDAP(true, nil, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery) err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() { Convey("it should result in", func() {
So(err, ShouldBeNil) So(err, ShouldBeNil)
@@ -147,7 +147,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, customErr, sc) mockLoginUsingLDAP(true, customErr, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery) err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() { Convey("it should result in", func() {
So(err, ShouldEqual, customErr) So(err, ShouldEqual, customErr)
@@ -165,7 +165,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc) mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc) mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery) err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() { Convey("it should result in", func() {
So(err, ShouldEqual, ErrInvalidCredentials) So(err, ShouldEqual, ErrInvalidCredentials)
@@ -203,7 +203,7 @@ func mockLoginUsingLDAP(enabled bool, err error, sc *authScenarioContext) {
} }
func mockLoginAttemptValidation(err error, sc *authScenarioContext) { func mockLoginAttemptValidation(err error, sc *authScenarioContext) {
validateLoginAttempts = func(username string) error { validateLoginAttempts = func(*models.LoginUserQuery) error {
sc.loginAttemptValidationWasCalled = true sc.loginAttemptValidationWasCalled = true
return err return err
} }

View File

@@ -5,7 +5,6 @@ import (
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
) )
var ( var (
@@ -13,13 +12,13 @@ var (
loginAttemptsWindow = time.Minute * 5 loginAttemptsWindow = time.Minute * 5
) )
var validateLoginAttempts = func(username string) error { var validateLoginAttempts = func(query *models.LoginUserQuery) error {
if setting.DisableBruteForceLoginProtection { if query.Cfg.DisableBruteForceLoginProtection {
return nil return nil
} }
loginAttemptCountQuery := models.GetUserLoginAttemptCountQuery{ loginAttemptCountQuery := models.GetUserLoginAttemptCountQuery{
Username: username, Username: query.Username,
Since: time.Now().Add(-loginAttemptsWindow), Since: time.Now().Add(-loginAttemptsWindow),
} }
@@ -35,7 +34,7 @@ var validateLoginAttempts = func(username string) error {
} }
var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error { var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error {
if setting.DisableBruteForceLoginProtection { if query.Cfg.DisableBruteForceLoginProtection {
return nil return nil
} }

View File

@@ -12,11 +12,16 @@ import (
func TestLoginAttemptsValidation(t *testing.T) { func TestLoginAttemptsValidation(t *testing.T) {
Convey("Validate login attempts", t, func() { Convey("Validate login attempts", t, func() {
Convey("Given brute force login protection enabled", func() { Convey("Given brute force login protection enabled", func() {
setting.DisableBruteForceLoginProtection = false cfg := setting.NewCfg()
cfg.DisableBruteForceLoginProtection = false
query := &models.LoginUserQuery{
Username: "user",
Cfg: cfg,
}
Convey("When user login attempt count equals max-1 ", func() { Convey("When user login attempt count equals max-1 ", func() {
withLoginAttempts(maxInvalidLoginAttempts - 1) withLoginAttempts(maxInvalidLoginAttempts - 1)
err := validateLoginAttempts("user") err := validateLoginAttempts(query)
Convey("it should not result in error", func() { Convey("it should not result in error", func() {
So(err, ShouldBeNil) So(err, ShouldBeNil)
@@ -25,7 +30,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Convey("When user login attempt count equals max ", func() { Convey("When user login attempt count equals max ", func() {
withLoginAttempts(maxInvalidLoginAttempts) withLoginAttempts(maxInvalidLoginAttempts)
err := validateLoginAttempts("user") err := validateLoginAttempts(query)
Convey("it should result in too many login attempts error", func() { Convey("it should result in too many login attempts error", func() {
So(err, ShouldEqual, ErrTooManyLoginAttempts) So(err, ShouldEqual, ErrTooManyLoginAttempts)
@@ -34,7 +39,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Convey("When user login attempt count is greater than max ", func() { Convey("When user login attempt count is greater than max ", func() {
withLoginAttempts(maxInvalidLoginAttempts + 5) withLoginAttempts(maxInvalidLoginAttempts + 5)
err := validateLoginAttempts("user") err := validateLoginAttempts(query)
Convey("it should result in too many login attempts error", func() { Convey("it should result in too many login attempts error", func() {
So(err, ShouldEqual, ErrTooManyLoginAttempts) So(err, ShouldEqual, ErrTooManyLoginAttempts)
@@ -54,6 +59,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Username: "user", Username: "user",
Password: "pwd", Password: "pwd",
IpAddress: "192.168.1.1:56433", IpAddress: "192.168.1.1:56433",
Cfg: setting.NewCfg(),
}) })
So(err, ShouldBeNil) So(err, ShouldBeNil)
@@ -66,11 +72,16 @@ func TestLoginAttemptsValidation(t *testing.T) {
}) })
Convey("Given brute force login protection disabled", func() { Convey("Given brute force login protection disabled", func() {
setting.DisableBruteForceLoginProtection = true cfg := setting.NewCfg()
cfg.DisableBruteForceLoginProtection = true
query := &models.LoginUserQuery{
Username: "user",
Cfg: cfg,
}
Convey("When user login attempt count equals max-1 ", func() { Convey("When user login attempt count equals max-1 ", func() {
withLoginAttempts(maxInvalidLoginAttempts - 1) withLoginAttempts(maxInvalidLoginAttempts - 1)
err := validateLoginAttempts("user") err := validateLoginAttempts(query)
Convey("it should not result in error", func() { Convey("it should not result in error", func() {
So(err, ShouldBeNil) So(err, ShouldBeNil)
@@ -79,7 +90,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Convey("When user login attempt count equals max ", func() { Convey("When user login attempt count equals max ", func() {
withLoginAttempts(maxInvalidLoginAttempts) withLoginAttempts(maxInvalidLoginAttempts)
err := validateLoginAttempts("user") err := validateLoginAttempts(query)
Convey("it should not result in error", func() { Convey("it should not result in error", func() {
So(err, ShouldBeNil) So(err, ShouldBeNil)
@@ -88,7 +99,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Convey("When user login attempt count is greater than max ", func() { Convey("When user login attempt count is greater than max ", func() {
withLoginAttempts(maxInvalidLoginAttempts + 5) withLoginAttempts(maxInvalidLoginAttempts + 5)
err := validateLoginAttempts("user") err := validateLoginAttempts(query)
Convey("it should not result in error", func() { Convey("it should not result in error", func() {
So(err, ShouldBeNil) So(err, ShouldBeNil)
@@ -97,7 +108,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Convey("When saving invalid login attempt", func() { Convey("When saving invalid login attempt", func() {
defer bus.ClearBusHandlers() defer bus.ClearBusHandlers()
createLoginAttemptCmd := (*models.CreateLoginAttemptCommand)(nil) var createLoginAttemptCmd *models.CreateLoginAttemptCommand
bus.AddHandler("test", func(cmd *models.CreateLoginAttemptCommand) error { bus.AddHandler("test", func(cmd *models.CreateLoginAttemptCommand) error {
createLoginAttemptCmd = cmd createLoginAttemptCmd = cmd
@@ -108,6 +119,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Username: "user", Username: "user",
Password: "pwd", Password: "pwd",
IpAddress: "192.168.1.1:56433", IpAddress: "192.168.1.1:56433",
Cfg: cfg,
}) })
So(err, ShouldBeNil) So(err, ShouldBeNil)

View File

@@ -33,7 +33,7 @@ var loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) {
return false, nil return false, nil
} }
config, err := getLDAPConfig() config, err := getLDAPConfig(query.Cfg)
if err != nil { if err != nil {
return true, errutil.Wrap("Failed to get LDAP config", err) return true, errutil.Wrap("Failed to get LDAP config", err)
} }

View File

@@ -20,7 +20,7 @@ func TestLDAPLogin(t *testing.T) {
LDAPLoginScenario("When login", func(sc *LDAPLoginScenarioContext) { LDAPLoginScenario("When login", func(sc *LDAPLoginScenarioContext) {
sc.withLoginResult(false) sc.withLoginResult(false)
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
config := &ldap.Config{ config := &ldap.Config{
Servers: []*ldap.ServerConfig{}, Servers: []*ldap.ServerConfig{},
} }
@@ -150,7 +150,14 @@ func LDAPLoginScenario(desc string, fn LDAPLoginScenarioFunc) {
LDAPAuthenticatorMock: mock, LDAPAuthenticatorMock: mock,
} }
getLDAPConfig = func() (*ldap.Config, error) { origNewLDAP := newLDAP
origGetLDAPConfig := getLDAPConfig
defer func() {
newLDAP = origNewLDAP
getLDAPConfig = origGetLDAPConfig
}()
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
config := &ldap.Config{ config := &ldap.Config{
Servers: []*ldap.ServerConfig{ Servers: []*ldap.ServerConfig{
{ {
@@ -166,11 +173,6 @@ func LDAPLoginScenario(desc string, fn LDAPLoginScenarioFunc) {
return mock return mock
} }
defer func() {
newLDAP = multildap.New
getLDAPConfig = multildap.GetConfig
}()
fn(sc) fn(sc)
}) })
} }

View File

@@ -8,9 +8,9 @@ import (
macaron "gopkg.in/macaron.v1" macaron "gopkg.in/macaron.v1"
"github.com/grafana/grafana/pkg/middleware/cookies"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
) )
type AuthOptions struct { type AuthOptions struct {
@@ -18,22 +18,6 @@ type AuthOptions struct {
ReqSignedIn bool ReqSignedIn bool
} }
func getApiKey(c *models.ReqContext) string {
header := c.Req.Header.Get("Authorization")
parts := strings.SplitN(header, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
key := parts[1]
return key
}
username, password, err := util.DecodeBasicAuthHeader(header)
if err == nil && username == "api_key" {
return password
}
return ""
}
func accessForbidden(c *models.ReqContext) { func accessForbidden(c *models.ReqContext) {
if c.IsApiRequest() { if c.IsApiRequest() {
c.JsonApiErr(403, "Permission denied", nil) c.JsonApiErr(403, "Permission denied", nil)
@@ -57,7 +41,7 @@ func notAuthorized(c *models.ReqContext) {
// remove any forceLogin=true params // remove any forceLogin=true params
redirectTo = removeForceLoginParams(redirectTo) redirectTo = removeForceLoginParams(redirectTo)
WriteCookie(c.Resp, "redirect_to", url.QueryEscape(redirectTo), 0, nil) cookies.WriteCookie(c.Resp, "redirect_to", url.QueryEscape(redirectTo), 0, nil)
c.Redirect(setting.AppSubUrl + "/login") c.Redirect(setting.AppSubUrl + "/login")
} }
@@ -135,9 +119,9 @@ func AdminOrFeatureEnabled(enabled bool) macaron.Handler {
} }
} }
func SnapshotPublicModeOrSignedIn() macaron.Handler { func SnapshotPublicModeOrSignedIn(cfg *setting.Cfg) macaron.Handler {
return func(c *models.ReqContext) { return func(c *models.ReqContext) {
if setting.SnapshotPublicMode { if cfg.SnapshotPublicMode {
return return
} }

View File

@@ -1,133 +0,0 @@
package middleware
import (
"errors"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/middleware/authproxy"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
)
var header = setting.AuthProxyHeaderName
func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) {
logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache)
// Try to log in user via various providers
id, err := auth.Login(logger, ignoreCache)
if err != nil {
details := err
var e authproxy.Error
if errors.As(err, &e) {
details = e.DetailsError
}
logger.Error("Failed to login", "username", username, "message", err.Error(), "error", details,
"ignoreCache", ignoreCache)
return 0, err
}
return id, nil
}
// handleError calls ctx.Handle with the error message and the underlying error.
// If the error is of type authproxy.Error, its DetailsError is unwrapped and passed to ctx.Handle.
// If a callback is provided, it's called with either err.DetailsError, if err is of type
// authproxy.Error, otherwise err itself.
func handleError(ctx *models.ReqContext, err error, statusCode int, cb func(err error)) {
details := err
var e authproxy.Error
if errors.As(err, &e) {
details = e.DetailsError
}
ctx.Handle(statusCode, err.Error(), details)
if cb != nil {
cb(details)
}
}
func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *models.ReqContext, orgID int64) bool {
username := ctx.Req.Header.Get(header)
auth := authproxy.New(&authproxy.Options{
Store: store,
Ctx: ctx,
OrgID: orgID,
})
logger := log.New("auth.proxy")
// Bail if auth proxy is not enabled
if !auth.IsEnabled() {
return false
}
// If there is no header - we can't move forward
if !auth.HasHeader() {
return false
}
// Check if allowed to continue with this IP
if err := auth.IsAllowedIP(); err != nil {
handleError(ctx, err, 407, func(details error) {
logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details)
})
return true
}
id, err := logUserIn(auth, username, logger, false)
if err != nil {
handleError(ctx, err, 407, nil)
return true
}
logger.Debug("Got user ID, getting full user info", "userID", id)
user, e := auth.GetSignedUser(id)
if e != nil {
// The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale
// cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated
// because cache keys are computed from request header values and not just the user ID. Meaning that
// we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to
// log the user in again without the cache.
logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id)
if err := auth.RemoveUserFromCache(logger); err != nil {
if !errors.Is(err, remotecache.ErrCacheItemNotFound) {
logger.Error("Got unexpected error when removing user from auth cache", "error", err)
}
}
id, err = logUserIn(auth, username, logger, true)
if err != nil {
handleError(ctx, err, 407, nil)
return true
}
user, err = auth.GetSignedUser(id)
if err != nil {
handleError(ctx, err, 407, nil)
return true
}
}
logger.Debug("Successfully got user info", "userID", user.UserId, "username", user.Login)
// Add user info to context
ctx.SignedInUser = user
ctx.IsSignedIn = true
// Remember user data in cache
if err := auth.Remember(id); err != nil {
handleError(ctx, err, 500, func(details error) {
logger.Error(
"Failed to store user in cache",
"username", username,
"message", e.Error(),
"error", details,
)
})
return true
}
return true
}

View File

@@ -33,16 +33,10 @@ func TestMiddlewareAuth(t *testing.T) {
t.Run("Anonymous auth enabled", func(t *testing.T) { t.Run("Anonymous auth enabled", func(t *testing.T) {
const orgID int64 = 1 const orgID int64 = 1
origEnabled := setting.AnonymousEnabled configure := func(cfg *setting.Cfg) {
t.Cleanup(func() { cfg.AnonymousEnabled = true
setting.AnonymousEnabled = origEnabled cfg.AnonymousOrgName = "test"
}) }
origName := setting.AnonymousOrgName
t.Cleanup(func() {
setting.AnonymousOrgName = origName
})
setting.AnonymousEnabled = true
setting.AnonymousOrgName = "test"
middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func( middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func(
t *testing.T, sc *scenarioContext) { t *testing.T, sc *scenarioContext) {
@@ -59,7 +53,7 @@ func TestMiddlewareAuth(t *testing.T) {
location, ok := sc.resp.Header()["Location"] location, ok := sc.resp.Header()["Location"]
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "/login", location[0]) assert.Equal(t, "/login", location[0])
}) }, configure)
middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func( middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func(
t *testing.T, sc *scenarioContext) { t *testing.T, sc *scenarioContext) {
@@ -73,7 +67,7 @@ func TestMiddlewareAuth(t *testing.T) {
sc.fakeReq("GET", fmt.Sprintf("/secure?orgId=%d", orgID)).exec() sc.fakeReq("GET", fmt.Sprintf("/secure?orgId=%d", orgID)).exec()
assert.Equal(t, 200, sc.resp.Code) assert.Equal(t, 200, sc.resp.Code)
}) }, configure)
middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func( middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func(
t *testing.T, sc *scenarioContext) { t *testing.T, sc *scenarioContext) {
@@ -90,20 +84,20 @@ func TestMiddlewareAuth(t *testing.T) {
location, ok := sc.resp.Header()["Location"] location, ok := sc.resp.Header()["Location"]
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "/login", location[0]) assert.Equal(t, "/login", location[0])
}) }, configure)
}) })
middlewareScenario(t, "Snapshot public mode disabled and unauthenticated request should return 401", func( middlewareScenario(t, "Snapshot public mode disabled and unauthenticated request should return 401", func(
t *testing.T, sc *scenarioContext) { t *testing.T, sc *scenarioContext) {
sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler) sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
sc.fakeReq("GET", "/api/snapshot").exec() sc.fakeReq("GET", "/api/snapshot").exec()
assert.Equal(t, 401, sc.resp.Code) assert.Equal(t, 401, sc.resp.Code)
}) })
middlewareScenario(t, "Snapshot public mode enabled and unauthenticated request should return 200", func( middlewareScenario(t, "Snapshot public mode enabled and unauthenticated request should return 200", func(
t *testing.T, sc *scenarioContext) { t *testing.T, sc *scenarioContext) {
setting.SnapshotPublicMode = true sc.cfg.SnapshotPublicMode = true
sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler) sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
sc.fakeReq("GET", "/api/snapshot").exec() sc.fakeReq("GET", "/api/snapshot").exec()
assert.Equal(t, 200, sc.resp.Code) assert.Equal(t, 200, sc.resp.Code)
}) })

View File

@@ -1,4 +1,4 @@
package middleware package cookies
import ( import (
"net/http" "net/http"
@@ -55,8 +55,8 @@ func WriteCookie(w http.ResponseWriter, name string, value string, maxAge int, g
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
} }
func WriteSessionCookie(ctx *models.ReqContext, value string, maxLifetime time.Duration) { func WriteSessionCookie(ctx *models.ReqContext, cfg *setting.Cfg, value string, maxLifetime time.Duration) {
if setting.Env == setting.Dev { if cfg.Env == setting.Dev {
ctx.Logger.Info("New token", "unhashed token", value) ctx.Logger.Info("New token", "unhashed token", value)
} }

View File

@@ -26,7 +26,7 @@ import (
"gopkg.in/macaron.v1" "gopkg.in/macaron.v1"
) )
func Logger() macaron.Handler { func Logger(cfg *setting.Cfg) macaron.Handler {
return func(res http.ResponseWriter, req *http.Request, c *macaron.Context) { return func(res http.ResponseWriter, req *http.Request, c *macaron.Context) {
start := time.Now() start := time.Now()
c.Data["perfmon.start"] = start c.Data["perfmon.start"] = start
@@ -43,7 +43,7 @@ func Logger() macaron.Handler {
status := rw.Status() status := rw.Status()
if status == 200 || status == 304 { if status == 200 || status == 304 {
if !setting.RouterLogging { if !cfg.RouterLogging {
return return
} }
} }

View File

@@ -1,32 +1,13 @@
package middleware package middleware
import ( import (
"context"
"errors"
"fmt" "fmt"
"strconv"
"strings" "strings"
"time"
macaron "gopkg.in/macaron.v1" macaron "gopkg.in/macaron.v1"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/components/apikeygen"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/network"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
)
var getTime = time.Now
const (
errStringInvalidUsernamePassword = "Invalid username or password"
errStringInvalidAPIKey = "Invalid API key"
) )
var ( var (
@@ -39,244 +20,7 @@ var (
ReqOrgAdmin = RoleAuth(models.ROLE_ADMIN) ReqOrgAdmin = RoleAuth(models.ROLE_ADMIN)
) )
func GetContextHandler( func AddDefaultResponseHeaders(cfg *setting.Cfg) macaron.Handler {
ats models.UserTokenService,
remoteCache *remotecache.RemoteCache,
renderService rendering.Service,
) macaron.Handler {
return func(c *macaron.Context) {
ctx := &models.ReqContext{
Context: c,
SignedInUser: &models.SignedInUser{},
IsSignedIn: false,
AllowAnonymous: false,
SkipCache: false,
Logger: log.New("context"),
}
orgID := int64(0)
orgIDHeader := ctx.Req.Header.Get("X-Grafana-Org-Id")
if orgIDHeader != "" {
orgIDParsed, err := strconv.ParseInt(orgIDHeader, 10, 64)
if err == nil {
orgID = orgIDParsed
}
}
// the order in which these are tested are important
// look for api key in Authorization header first
// then init session and look for userId in session
// then look for api key in session (special case for render calls via api)
// then test if anonymous access is enabled
switch {
case initContextWithRenderAuth(ctx, renderService):
case initContextWithApiKey(ctx):
case initContextWithBasicAuth(ctx, orgID):
case initContextWithAuthProxy(remoteCache, ctx, orgID):
case initContextWithToken(ats, ctx, orgID):
case initContextWithAnonymousUser(ctx):
}
ctx.Logger = log.New("context", "userId", ctx.UserId, "orgId", ctx.OrgId, "uname", ctx.Login)
ctx.Data["ctx"] = ctx
c.Map(ctx)
// update last seen every 5min
if ctx.ShouldUpdateLastSeenAt() {
ctx.Logger.Debug("Updating last user_seen_at", "user_id", ctx.UserId)
if err := bus.Dispatch(&models.UpdateUserLastSeenAtCommand{UserId: ctx.UserId}); err != nil {
ctx.Logger.Error("Failed to update last_seen_at", "error", err)
}
}
}
}
func initContextWithAnonymousUser(ctx *models.ReqContext) bool {
if !setting.AnonymousEnabled {
return false
}
orgQuery := models.GetOrgByNameQuery{Name: setting.AnonymousOrgName}
if err := bus.Dispatch(&orgQuery); err != nil {
log.Errorf(3, "Anonymous access organization error: '%s': %s", setting.AnonymousOrgName, err)
return false
}
ctx.IsSignedIn = false
ctx.AllowAnonymous = true
ctx.SignedInUser = &models.SignedInUser{IsAnonymous: true}
ctx.OrgRole = models.RoleType(setting.AnonymousOrgRole)
ctx.OrgId = orgQuery.Result.Id
ctx.OrgName = orgQuery.Result.Name
return true
}
func initContextWithApiKey(ctx *models.ReqContext) bool {
var keyString string
if keyString = getApiKey(ctx); keyString == "" {
return false
}
// base64 decode key
decoded, err := apikeygen.Decode(keyString)
if err != nil {
ctx.JsonApiErr(401, errStringInvalidAPIKey, err)
return true
}
// fetch key
keyQuery := models.GetApiKeyByNameQuery{KeyName: decoded.Name, OrgId: decoded.OrgId}
if err := bus.Dispatch(&keyQuery); err != nil {
ctx.JsonApiErr(401, errStringInvalidAPIKey, err)
return true
}
apikey := keyQuery.Result
// validate api key
isValid, err := apikeygen.IsValid(decoded, apikey.Key)
if err != nil {
ctx.JsonApiErr(500, "Validating API key failed", err)
return true
}
if !isValid {
ctx.JsonApiErr(401, errStringInvalidAPIKey, err)
return true
}
// check for expiration
if apikey.Expires != nil && *apikey.Expires <= getTime().Unix() {
ctx.JsonApiErr(401, "Expired API key", err)
return true
}
ctx.IsSignedIn = true
ctx.SignedInUser = &models.SignedInUser{}
ctx.OrgRole = apikey.Role
ctx.ApiKeyId = apikey.Id
ctx.OrgId = apikey.OrgId
return true
}
func initContextWithBasicAuth(ctx *models.ReqContext, orgId int64) bool {
if !setting.BasicAuthEnabled {
return false
}
header := ctx.Req.Header.Get("Authorization")
if header == "" {
return false
}
username, password, err := util.DecodeBasicAuthHeader(header)
if err != nil {
ctx.JsonApiErr(401, "Invalid Basic Auth Header", err)
return true
}
authQuery := models.LoginUserQuery{
Username: username,
Password: password,
}
if err := bus.Dispatch(&authQuery); err != nil {
ctx.Logger.Debug(
"Failed to authorize the user",
"username", username,
"err", err,
)
if errors.Is(err, models.ErrUserNotFound) {
err = login.ErrInvalidCredentials
}
ctx.JsonApiErr(401, errStringInvalidUsernamePassword, err)
return true
}
user := authQuery.User
query := models.GetSignedInUserQuery{UserId: user.Id, OrgId: orgId}
if err := bus.Dispatch(&query); err != nil {
ctx.Logger.Error(
"Failed at user signed in",
"id", user.Id,
"org", orgId,
)
ctx.JsonApiErr(401, errStringInvalidUsernamePassword, err)
return true
}
ctx.SignedInUser = query.Result
ctx.IsSignedIn = true
return true
}
func initContextWithToken(authTokenService models.UserTokenService, ctx *models.ReqContext, orgID int64) bool {
if setting.LoginCookieName == "" {
return false
}
rawToken := ctx.GetCookie(setting.LoginCookieName)
if rawToken == "" {
return false
}
token, err := authTokenService.LookupToken(ctx.Req.Context(), rawToken)
if err != nil {
ctx.Logger.Error("Failed to look up user based on cookie", "error", err)
WriteSessionCookie(ctx, "", -1)
return false
}
query := models.GetSignedInUserQuery{UserId: token.UserId, OrgId: orgID}
if err := bus.Dispatch(&query); err != nil {
ctx.Logger.Error("Failed to get user with id", "userId", token.UserId, "error", err)
return false
}
ctx.SignedInUser = query.Result
ctx.IsSignedIn = true
ctx.UserToken = token
// Rotate the token just before we write response headers to ensure there is no delay between
// the new token being generated and the client receiving it.
ctx.Resp.Before(rotateEndOfRequestFunc(ctx, authTokenService, token))
return true
}
func rotateEndOfRequestFunc(ctx *models.ReqContext, authTokenService models.UserTokenService, token *models.UserToken) macaron.BeforeFunc {
return func(w macaron.ResponseWriter) {
// if response has already been written, skip.
if w.Written() {
return
}
// if the request is cancelled by the client we should not try
// to rotate the token since the client would not accept any result.
if errors.Is(ctx.Context.Req.Context().Err(), context.Canceled) {
return
}
addr := ctx.RemoteAddr()
ip, err := network.GetIPFromAddress(addr)
if err != nil {
ctx.Logger.Debug("Failed to get client IP address", "addr", addr, "err", err)
ip = nil
}
rotated, err := authTokenService.TryRotateToken(ctx.Req.Context(), token, ip, ctx.Req.UserAgent())
if err != nil {
ctx.Logger.Error("Failed to rotate token", "error", err)
return
}
if rotated {
WriteSessionCookie(ctx, token.UnhashedToken, setting.LoginMaxLifetime)
}
}
}
func AddDefaultResponseHeaders() macaron.Handler {
return func(ctx *macaron.Context) { return func(ctx *macaron.Context) {
ctx.Resp.Before(func(w macaron.ResponseWriter) { ctx.Resp.Before(func(w macaron.ResponseWriter) {
// if response has already been written, skip. // if response has already been written, skip.
@@ -285,47 +29,46 @@ func AddDefaultResponseHeaders() macaron.Handler {
} }
if !strings.HasPrefix(ctx.Req.URL.Path, "/api/datasources/proxy/") { if !strings.HasPrefix(ctx.Req.URL.Path, "/api/datasources/proxy/") {
AddNoCacheHeaders(ctx.Resp) addNoCacheHeaders(ctx.Resp)
} }
if !setting.AllowEmbedding { if !cfg.AllowEmbedding {
AddXFrameOptionsDenyHeader(w) addXFrameOptionsDenyHeader(w)
} }
AddSecurityHeaders(w) addSecurityHeaders(w, cfg)
}) })
} }
} }
// AddSecurityHeaders adds various HTTP(S) response headers that enable various security protections behaviors in the client's browser. // addSecurityHeaders adds HTTP(S) response headers that enable various security protections in the client's browser.
func AddSecurityHeaders(w macaron.ResponseWriter) { func addSecurityHeaders(w macaron.ResponseWriter, cfg *setting.Cfg) {
if (setting.Protocol == setting.HTTPSScheme || setting.Protocol == setting.HTTP2Scheme) && if (cfg.Protocol == setting.HTTPSScheme || cfg.Protocol == setting.HTTP2Scheme) && cfg.StrictTransportSecurity {
setting.StrictTransportSecurity { strictHeaderValues := []string{fmt.Sprintf("max-age=%v", cfg.StrictTransportSecurityMaxAge)}
strictHeaderValues := []string{fmt.Sprintf("max-age=%v", setting.StrictTransportSecurityMaxAge)} if cfg.StrictTransportSecurityPreload {
if setting.StrictTransportSecurityPreload {
strictHeaderValues = append(strictHeaderValues, "preload") strictHeaderValues = append(strictHeaderValues, "preload")
} }
if setting.StrictTransportSecuritySubDomains { if cfg.StrictTransportSecuritySubDomains {
strictHeaderValues = append(strictHeaderValues, "includeSubDomains") strictHeaderValues = append(strictHeaderValues, "includeSubDomains")
} }
w.Header().Add("Strict-Transport-Security", strings.Join(strictHeaderValues, "; ")) w.Header().Add("Strict-Transport-Security", strings.Join(strictHeaderValues, "; "))
} }
if setting.ContentTypeProtectionHeader { if cfg.ContentTypeProtectionHeader {
w.Header().Add("X-Content-Type-Options", "nosniff") w.Header().Add("X-Content-Type-Options", "nosniff")
} }
if setting.XSSProtectionHeader { if cfg.XSSProtectionHeader {
w.Header().Add("X-XSS-Protection", "1; mode=block") w.Header().Add("X-XSS-Protection", "1; mode=block")
} }
} }
func AddNoCacheHeaders(w macaron.ResponseWriter) { func addNoCacheHeaders(w macaron.ResponseWriter) {
w.Header().Add("Cache-Control", "no-cache") w.Header().Add("Cache-Control", "no-cache")
w.Header().Add("Pragma", "no-cache") w.Header().Add("Pragma", "no-cache")
w.Header().Add("Expires", "-1") w.Header().Add("Expires", "-1")
} }
func AddXFrameOptionsDenyHeader(w macaron.ResponseWriter) { func addXFrameOptionsDenyHeader(w macaron.ResponseWriter) {
w.Header().Add("X-Frame-Options", "deny") w.Header().Add("X-Frame-Options", "deny")
} }

View File

@@ -7,6 +7,7 @@ import (
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -14,19 +15,13 @@ import (
) )
func TestMiddlewareBasicAuth(t *testing.T) { func TestMiddlewareBasicAuth(t *testing.T) {
var origBasicAuthEnabled = setting.BasicAuthEnabled
var origDisableBruteForceLoginProtection = setting.DisableBruteForceLoginProtection
t.Cleanup(func() {
setting.BasicAuthEnabled = origBasicAuthEnabled
setting.DisableBruteForceLoginProtection = origDisableBruteForceLoginProtection
})
setting.BasicAuthEnabled = true
setting.DisableBruteForceLoginProtection = true
bus.ClearBusHandlers()
const id int64 = 12 const id int64 = 12
configure := func(cfg *setting.Cfg) {
cfg.BasicAuthEnabled = true
cfg.DisableBruteForceLoginProtection = true
}
middlewareScenario(t, "Valid API key", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Valid API key", func(t *testing.T, sc *scenarioContext) {
const orgID int64 = 2 const orgID int64 = 2
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
@@ -44,16 +39,15 @@ func TestMiddlewareBasicAuth(t *testing.T) {
assert.True(t, sc.context.IsSignedIn) assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgId) assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole) assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole)
}) }, configure)
middlewareScenario(t, "Handle auth", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Handle auth", func(t *testing.T, sc *scenarioContext) {
const password = "MyPass" const password = "MyPass"
const salt = "Salt" const salt = "Salt"
const orgID int64 = 2 const orgID int64 = 2
t.Cleanup(bus.ClearBusHandlers)
bus.AddHandler("grafana-auth", func(query *models.LoginUserQuery) error { bus.AddHandler("grafana-auth", func(query *models.LoginUserQuery) error {
t.Log("Handling LoginUserQuery")
encoded, err := util.EncodePassword(password, salt) encoded, err := util.EncodePassword(password, salt)
if err != nil { if err != nil {
return err return err
@@ -66,6 +60,7 @@ func TestMiddlewareBasicAuth(t *testing.T) {
}) })
bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error { bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error {
t.Log("Handling GetSignedInUserQuery")
query.Result = &models.SignedInUser{OrgId: orgID, UserId: id} query.Result = &models.SignedInUser{OrgId: orgID, UserId: id}
return nil return nil
}) })
@@ -76,7 +71,7 @@ func TestMiddlewareBasicAuth(t *testing.T) {
assert.True(t, sc.context.IsSignedIn) assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgId) assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, id, sc.context.UserId) assert.Equal(t, id, sc.context.UserId)
}) }, configure)
middlewareScenario(t, "Auth sequence", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Auth sequence", func(t *testing.T, sc *scenarioContext) {
const password = "MyPass" const password = "MyPass"
@@ -104,10 +99,11 @@ func TestMiddlewareBasicAuth(t *testing.T) {
authHeader := util.GetBasicAuthHeader("myUser", password) authHeader := util.GetBasicAuthHeader("myUser", password)
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
require.NotNil(t, sc.context)
assert.True(t, sc.context.IsSignedIn) assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, id, sc.context.UserId) assert.Equal(t, id, sc.context.UserId)
}) }, configure)
middlewareScenario(t, "Should return error if user is not found", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should return error if user is not found", func(t *testing.T, sc *scenarioContext) {
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
@@ -118,8 +114,8 @@ func TestMiddlewareBasicAuth(t *testing.T) {
require.Error(t, err) require.Error(t, err)
assert.Equal(t, 401, sc.resp.Code) assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, errStringInvalidUsernamePassword, sc.respJson["message"]) assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"])
}) }, configure)
middlewareScenario(t, "Should return error if user & password do not match", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should return error if user & password do not match", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("user-query", func(loginUserQuery *models.GetUserByLoginQuery) error { bus.AddHandler("user-query", func(loginUserQuery *models.GetUserByLoginQuery) error {
@@ -134,6 +130,6 @@ func TestMiddlewareBasicAuth(t *testing.T) {
require.Error(t, err) require.Error(t, err)
assert.Equal(t, 401, sc.resp.Code) assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, errStringInvalidUsernamePassword, sc.respJson["message"]) assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"])
}) }, configure)
} }

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"path/filepath" "path/filepath"
"testing" "testing"
"time" "time"
@@ -18,31 +17,30 @@ import (
"github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/dtos"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/components/gtime" "github.com/grafana/grafana/pkg/components/gtime"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/middleware/authproxy" "github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
) )
const errorTemplate = "error-template" const errorTemplate = "error-template"
func mockGetTime() { func fakeGetTime() func() time.Time {
var timeSeed int64 var timeSeed int64
getTime = func() time.Time { return func() time.Time {
fakeNow := time.Unix(timeSeed, 0) fakeNow := time.Unix(timeSeed, 0)
timeSeed++ timeSeed++
return fakeNow return fakeNow
} }
} }
func resetGetTime() {
getTime = time.Now
}
func TestMiddleWareSecurityHeaders(t *testing.T) { func TestMiddleWareSecurityHeaders(t *testing.T) {
origErrTemplateName := setting.ErrTemplateName origErrTemplateName := setting.ErrTemplateName
t.Cleanup(func() { t.Cleanup(func() {
@@ -51,46 +49,32 @@ func TestMiddleWareSecurityHeaders(t *testing.T) {
setting.ErrTemplateName = errorTemplate setting.ErrTemplateName = errorTemplate
middlewareScenario(t, "middleware should get correct x-xss-protection header", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "middleware should get correct x-xss-protection header", func(t *testing.T, sc *scenarioContext) {
origXSSProtectionHeader := setting.XSSProtectionHeader
t.Cleanup(func() {
setting.XSSProtectionHeader = origXSSProtectionHeader
})
setting.XSSProtectionHeader = true
sc.fakeReq("GET", "/api/").exec() sc.fakeReq("GET", "/api/").exec()
assert.Equal(t, "1; mode=block", sc.resp.Header().Get("X-XSS-Protection")) assert.Equal(t, "1; mode=block", sc.resp.Header().Get("X-XSS-Protection"))
}, func(cfg *setting.Cfg) {
cfg.XSSProtectionHeader = true
}) })
middlewareScenario(t, "middleware should not get x-xss-protection when disabled", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "middleware should not get x-xss-protection when disabled", func(t *testing.T, sc *scenarioContext) {
origXSSProtectionHeader := setting.XSSProtectionHeader
t.Cleanup(func() {
setting.XSSProtectionHeader = origXSSProtectionHeader
})
setting.XSSProtectionHeader = false
sc.fakeReq("GET", "/api/").exec() sc.fakeReq("GET", "/api/").exec()
assert.Empty(t, sc.resp.Header().Get("X-XSS-Protection")) assert.Empty(t, sc.resp.Header().Get("X-XSS-Protection"))
}, func(cfg *setting.Cfg) {
cfg.XSSProtectionHeader = false
}) })
middlewareScenario(t, "middleware should add correct Strict-Transport-Security header", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "middleware should add correct Strict-Transport-Security header", func(t *testing.T, sc *scenarioContext) {
origStrictTransportSecurity := setting.StrictTransportSecurity
origProtocol := setting.Protocol
origStrictTransportSecurityMaxAge := setting.StrictTransportSecurityMaxAge
t.Cleanup(func() {
setting.StrictTransportSecurity = origStrictTransportSecurity
setting.Protocol = origProtocol
setting.StrictTransportSecurityMaxAge = origStrictTransportSecurityMaxAge
})
setting.StrictTransportSecurity = true
setting.Protocol = setting.HTTPSScheme
setting.StrictTransportSecurityMaxAge = 64000
sc.fakeReq("GET", "/api/").exec() sc.fakeReq("GET", "/api/").exec()
assert.Equal(t, "max-age=64000", sc.resp.Header().Get("Strict-Transport-Security")) assert.Equal(t, "max-age=64000", sc.resp.Header().Get("Strict-Transport-Security"))
setting.StrictTransportSecurityPreload = true sc.cfg.StrictTransportSecurityPreload = true
sc.fakeReq("GET", "/api/").exec() sc.fakeReq("GET", "/api/").exec()
assert.Equal(t, "max-age=64000; preload", sc.resp.Header().Get("Strict-Transport-Security")) assert.Equal(t, "max-age=64000; preload", sc.resp.Header().Get("Strict-Transport-Security"))
setting.StrictTransportSecuritySubDomains = true sc.cfg.StrictTransportSecuritySubDomains = true
sc.fakeReq("GET", "/api/").exec() sc.fakeReq("GET", "/api/").exec()
assert.Equal(t, "max-age=64000; preload; includeSubDomains", sc.resp.Header().Get("Strict-Transport-Security")) assert.Equal(t, "max-age=64000; preload; includeSubDomains", sc.resp.Header().Get("Strict-Transport-Security"))
}, func(cfg *setting.Cfg) {
cfg.Protocol = setting.HTTPSScheme
cfg.StrictTransportSecurity = true
cfg.StrictTransportSecurityMaxAge = 64000
}) })
} }
@@ -151,13 +135,10 @@ func TestMiddlewareContext(t *testing.T) {
middlewareScenario(t, "middleware should not add X-Frame-Options header for request when allowing embedding", func( middlewareScenario(t, "middleware should not add X-Frame-Options header for request when allowing embedding", func(
t *testing.T, sc *scenarioContext) { t *testing.T, sc *scenarioContext) {
origAllowEmbedding := setting.AllowEmbedding
t.Cleanup(func() {
setting.AllowEmbedding = origAllowEmbedding
})
setting.AllowEmbedding = true
sc.fakeReq("GET", "/api/search").exec() sc.fakeReq("GET", "/api/search").exec()
assert.Empty(t, sc.resp.Header().Get("X-Frame-Options")) assert.Empty(t, sc.resp.Header().Get("X-Frame-Options"))
}, func(cfg *setting.Cfg) {
cfg.AllowEmbedding = true
}) })
middlewareScenario(t, "Invalid api key", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Invalid api key", func(t *testing.T, sc *scenarioContext) {
@@ -166,7 +147,7 @@ func TestMiddlewareContext(t *testing.T) {
assert.Empty(t, sc.resp.Header().Get("Set-Cookie")) assert.Empty(t, sc.resp.Header().Get("Set-Cookie"))
assert.Equal(t, 401, sc.resp.Code) assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, errStringInvalidAPIKey, sc.respJson["message"]) assert.Equal(t, contexthandler.InvalidAPIKey, sc.respJson["message"])
}) })
middlewareScenario(t, "Valid api key", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Valid api key", func(t *testing.T, sc *scenarioContext) {
@@ -199,19 +180,18 @@ func TestMiddlewareContext(t *testing.T) {
sc.fakeReq("GET", "/").withValidApiKey().exec() sc.fakeReq("GET", "/").withValidApiKey().exec()
assert.Equal(t, 401, sc.resp.Code) assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, errStringInvalidAPIKey, sc.respJson["message"]) assert.Equal(t, contexthandler.InvalidAPIKey, sc.respJson["message"])
}) })
middlewareScenario(t, "Valid api key, but expired", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Valid API key, but expired", func(t *testing.T, sc *scenarioContext) {
mockGetTime() sc.contextHandler.GetTime = fakeGetTime()
defer resetGetTime()
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
require.NoError(t, err) require.NoError(t, err)
bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error { bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error {
// api key expired one second before // api key expired one second before
expires := getTime().Add(-1 * time.Second).Unix() expires := sc.contextHandler.GetTime().Add(-1 * time.Second).Unix()
query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash, query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash,
Expires: &expires} Expires: &expires}
return nil return nil
@@ -223,7 +203,7 @@ func TestMiddlewareContext(t *testing.T) {
assert.Equal(t, "Expired API key", sc.respJson["message"]) assert.Equal(t, "Expired API key", sc.respJson["message"])
}) })
middlewareScenario(t, "Non-expired auth token in cookie which not are being rotated", func( middlewareScenario(t, "Non-expired auth token in cookie which is not being rotated", func(
t *testing.T, sc *scenarioContext) { t *testing.T, sc *scenarioContext) {
const userID int64 = 12 const userID int64 = 12
@@ -357,18 +337,6 @@ func TestMiddlewareContext(t *testing.T) {
middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) {
const orgID int64 = 2 const orgID int64 = 2
origAnonymousEnabled := setting.AnonymousEnabled
origAnonymousOrgName := setting.AnonymousOrgName
origAnonymousOrgRole := setting.AnonymousOrgRole
t.Cleanup(func() {
setting.AnonymousEnabled = origAnonymousEnabled
setting.AnonymousOrgName = origAnonymousOrgName
setting.AnonymousOrgRole = origAnonymousOrgRole
})
setting.AnonymousEnabled = true
setting.AnonymousOrgName = "test"
setting.AnonymousOrgRole = string(models.ROLE_EDITOR)
bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error { bus.AddHandler("test", func(query *models.GetOrgByNameQuery) error {
assert.Equal(t, "test", query.Name) assert.Equal(t, "test", query.Name)
@@ -382,35 +350,24 @@ func TestMiddlewareContext(t *testing.T) {
assert.Equal(t, orgID, sc.context.OrgId) assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole) assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole)
assert.False(t, sc.context.IsSignedIn) assert.False(t, sc.context.IsSignedIn)
}, func(cfg *setting.Cfg) {
cfg.AnonymousEnabled = true
cfg.AnonymousOrgName = "test"
cfg.AnonymousOrgRole = string(models.ROLE_EDITOR)
}) })
t.Run("auth_proxy", func(t *testing.T) { t.Run("auth_proxy", func(t *testing.T) {
const userID int64 = 33 const userID int64 = 33
const orgID int64 = 4 const orgID int64 = 4
origAuthProxyEnabled := setting.AuthProxyEnabled configure := func(cfg *setting.Cfg) {
origAuthProxyWhitelist := setting.AuthProxyWhitelist cfg.AuthProxyEnabled = true
origAuthProxyAutoSignUp := setting.AuthProxyAutoSignUp cfg.AuthProxyAutoSignUp = true
origLDAPEnabled := setting.LDAPEnabled cfg.LDAPEnabled = true
origAuthProxyHeaderName := setting.AuthProxyHeaderName cfg.AuthProxyHeaderName = "X-WEBAUTH-USER"
origAuthProxyHeaderProperty := setting.AuthProxyHeaderProperty cfg.AuthProxyHeaderProperty = "username"
origAuthProxyHeaders := setting.AuthProxyHeaders cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"}
t.Cleanup(func() { }
setting.AuthProxyEnabled = origAuthProxyEnabled
setting.AuthProxyWhitelist = origAuthProxyWhitelist
setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp
setting.LDAPEnabled = origLDAPEnabled
setting.AuthProxyHeaderName = origAuthProxyHeaderName
setting.AuthProxyHeaderProperty = origAuthProxyHeaderProperty
setting.AuthProxyHeaders = origAuthProxyHeaders
})
setting.AuthProxyEnabled = true
setting.AuthProxyWhitelist = ""
setting.AuthProxyAutoSignUp = true
setting.LDAPEnabled = true
setting.AuthProxyHeaderName = "X-WEBAUTH-USER"
setting.AuthProxyHeaderProperty = "username"
setting.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"}
const hdrName = "markelog" const hdrName = "markelog"
const group = "grafana-core-team" const group = "grafana-core-team"
@@ -426,25 +383,16 @@ func TestMiddlewareContext(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.req.Header.Set("X-WEBAUTH-GROUPS", group) sc.req.Header.Set("X-WEBAUTH-GROUPS", group)
sc.exec() sc.exec()
assert.True(t, sc.context.IsSignedIn) assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId) assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, orgID, sc.context.OrgId) assert.Equal(t, orgID, sc.context.OrgId)
}) }, configure)
middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should respect auto signup option", func(t *testing.T, sc *scenarioContext) {
origLDAPEnabled = setting.LDAPEnabled
origAuthProxyAutoSignUp = setting.AuthProxyAutoSignUp
t.Cleanup(func() {
setting.LDAPEnabled = origLDAPEnabled
setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp
})
setting.LDAPEnabled = false
setting.AuthProxyAutoSignUp = false
var actualAuthProxyAutoSignUp *bool = nil var actualAuthProxyAutoSignUp *bool = nil
bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error { bus.AddHandler("test", func(cmd *models.UpsertUserCommand) error {
@@ -453,24 +401,19 @@ func TestMiddlewareContext(t *testing.T) {
}) })
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec() sc.exec()
assert.False(t, *actualAuthProxyAutoSignUp) assert.False(t, *actualAuthProxyAutoSignUp)
assert.Equal(t, sc.resp.Code, 407) assert.Equal(t, 407, sc.resp.Code)
assert.Nil(t, sc.context) assert.Nil(t, sc.context)
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.LDAPEnabled = false
cfg.AuthProxyAutoSignUp = false
}) })
middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) {
origLDAPEnabled = setting.LDAPEnabled
origAuthProxyAutoSignUp = setting.AuthProxyAutoSignUp
t.Cleanup(func() {
setting.LDAPEnabled = origLDAPEnabled
setting.AuthProxyAutoSignUp = origAuthProxyAutoSignUp
})
setting.LDAPEnabled = false
setting.AuthProxyAutoSignUp = true
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
if query.UserId > 0 { if query.UserId > 0 {
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
@@ -485,24 +428,22 @@ func TestMiddlewareContext(t *testing.T) {
}) })
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec() sc.exec()
assert.True(t, sc.context.IsSignedIn) assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId) assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, orgID, sc.context.OrgId) assert.Equal(t, orgID, sc.context.OrgId)
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.LDAPEnabled = false
cfg.AuthProxyAutoSignUp = true
}) })
middlewareScenario(t, "Should get an existing user from header", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should get an existing user from header", func(t *testing.T, sc *scenarioContext) {
const userID int64 = 12 const userID int64 = 12
const orgID int64 = 2 const orgID int64 = 2
origLDAPEnabled = setting.LDAPEnabled
t.Cleanup(func() {
setting.LDAPEnabled = origLDAPEnabled
})
setting.LDAPEnabled = false
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil return nil
@@ -514,24 +455,18 @@ func TestMiddlewareContext(t *testing.T) {
}) })
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec() sc.exec()
assert.True(t, sc.context.IsSignedIn) assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId) assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, orgID, sc.context.OrgId) assert.Equal(t, orgID, sc.context.OrgId)
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.LDAPEnabled = false
}) })
middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) {
origAuthProxyWhitelist = setting.AuthProxyWhitelist
origLDAPEnabled = setting.LDAPEnabled
t.Cleanup(func() {
setting.AuthProxyWhitelist = origAuthProxyWhitelist
setting.LDAPEnabled = origLDAPEnabled
})
setting.AuthProxyWhitelist = "192.168.1.0/24, 2001::0/120"
setting.LDAPEnabled = false
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil return nil
@@ -543,25 +478,20 @@ func TestMiddlewareContext(t *testing.T) {
}) })
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.req.RemoteAddr = "[2001::23]:12345" sc.req.RemoteAddr = "[2001::23]:12345"
sc.exec() sc.exec()
assert.True(t, sc.context.IsSignedIn) assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, userID, sc.context.UserId) assert.Equal(t, userID, sc.context.UserId)
assert.Equal(t, orgID, sc.context.OrgId) assert.Equal(t, orgID, sc.context.OrgId)
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.AuthProxyWhitelist = "192.168.1.0/24, 2001::0/120"
cfg.LDAPEnabled = false
}) })
middlewareScenario(t, "Should not allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) {
origAuthProxyWhitelist = setting.AuthProxyWhitelist
origLDAPEnabled = setting.LDAPEnabled
t.Cleanup(func() {
setting.AuthProxyWhitelist = origAuthProxyWhitelist
setting.LDAPEnabled = origLDAPEnabled
})
setting.AuthProxyWhitelist = "8.8.8.8"
setting.LDAPEnabled = false
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error { bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID} query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
return nil return nil
@@ -573,12 +503,16 @@ func TestMiddlewareContext(t *testing.T) {
}) })
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.req.RemoteAddr = "[2001::23]:12345" sc.req.RemoteAddr = "[2001::23]:12345"
sc.exec() sc.exec()
assert.Equal(t, 407, sc.resp.Code) assert.Equal(t, 407, sc.resp.Code)
assert.Nil(t, sc.context) assert.Nil(t, sc.context)
}, func(cfg *setting.Cfg) {
configure(cfg)
cfg.AuthProxyWhitelist = "8.8.8.8"
cfg.LDAPEnabled = false
}) })
middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should return 407 status code if LDAP says no", func(t *testing.T, sc *scenarioContext) {
@@ -587,12 +521,12 @@ func TestMiddlewareContext(t *testing.T) {
}) })
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec() sc.exec()
assert.Equal(t, 407, sc.resp.Code) assert.Equal(t, 407, sc.resp.Code)
assert.Nil(t, sc.context) assert.Nil(t, sc.context)
}) }, configure)
middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) { middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("Do not have the user", func(query *models.GetSignedInUserQuery) error { bus.AddHandler("Do not have the user", func(query *models.GetSignedInUserQuery) error {
@@ -600,52 +534,53 @@ func TestMiddlewareContext(t *testing.T) {
}) })
sc.fakeReq("GET", "/") sc.fakeReq("GET", "/")
sc.req.Header.Set(setting.AuthProxyHeaderName, hdrName) sc.req.Header.Set(sc.cfg.AuthProxyHeaderName, hdrName)
sc.exec() sc.exec()
assert.Equal(t, 407, sc.resp.Code) assert.Equal(t, 407, sc.resp.Code)
assert.Nil(t, sc.context) assert.Nil(t, sc.context)
}) }, configure)
}) })
} }
func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) { func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func(*setting.Cfg)) {
t.Helper() t.Helper()
t.Run(desc, func(t *testing.T) { t.Run(desc, func(t *testing.T) {
t.Cleanup(bus.ClearBusHandlers) t.Cleanup(bus.ClearBusHandlers)
origLoginCookieName := setting.LoginCookieName loginMaxLifetime, err := gtime.ParseDuration("30d")
origLoginMaxLifetime := setting.LoginMaxLifetime
t.Cleanup(func() {
setting.LoginCookieName = origLoginCookieName
setting.LoginMaxLifetime = origLoginMaxLifetime
})
setting.LoginCookieName = "grafana_session"
var err error
setting.LoginMaxLifetime, err = gtime.ParseDuration("30d")
require.NoError(t, err) require.NoError(t, err)
cfg := setting.NewCfg()
cfg.LoginCookieName = "grafana_session"
cfg.LoginMaxLifetime = loginMaxLifetime
for _, cb := range cbs {
cb(cfg)
}
sc := &scenarioContext{t: t} sc := &scenarioContext{t: t, cfg: cfg}
viewsPath, err := filepath.Abs("../../public/views") viewsPath, err := filepath.Abs("../../public/views")
require.NoError(t, err) require.NoError(t, err)
sc.m = macaron.New() sc.m = macaron.New()
sc.m.Use(AddDefaultResponseHeaders()) sc.m.Use(AddDefaultResponseHeaders(cfg))
sc.m.Use(macaron.Renderer(macaron.RenderOptions{ sc.m.Use(macaron.Renderer(macaron.RenderOptions{
Directory: viewsPath, Directory: viewsPath,
Delims: macaron.Delims{Left: "[[", Right: "]]"}, Delims: macaron.Delims{Left: "[[", Right: "]]"},
})) }))
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService() ctxHdlr := getContextHandler(t, cfg)
sc.remoteCacheService = remotecache.NewFakeStore(t) sc.contextHandler = ctxHdlr
sc.m.Use(ctxHdlr.Middleware)
sc.m.Use(GetContextHandler(sc.userAuthTokenService, sc.remoteCacheService, nil))
sc.m.Use(OrgRedirect()) sc.m.Use(OrgRedirect())
sc.userAuthTokenService = ctxHdlr.AuthTokenService.(*auth.FakeUserAuthTokenService)
sc.remoteCacheService = ctxHdlr.RemoteCache
sc.defaultHandler = func(c *models.ReqContext) { sc.defaultHandler = func(c *models.ReqContext) {
require.NotNil(t, c)
t.Log("Default HTTP handler called")
sc.context = c sc.context = c
if sc.handlerFunc != nil { if sc.handlerFunc != nil {
sc.handlerFunc(sc.context) sc.handlerFunc(sc.context)
@@ -662,106 +597,52 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc) {
}) })
} }
func TestDontRotateTokensOnCancelledRequests(t *testing.T) { func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHandler {
ctx, cancel := context.WithCancel(context.Background())
reqContext, _, err := initTokenRotationTest(ctx, t)
require.NoError(t, err)
tryRotateCallCount := 0
uts := &auth.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP,
userAgent string) (bool, error) {
tryRotateCallCount++
return false, nil
},
}
token := &models.UserToken{AuthToken: "oldtoken"}
fn := rotateEndOfRequestFunc(reqContext, uts, token)
cancel()
fn(reqContext.Resp)
assert.Equal(t, 0, tryRotateCallCount, "Token rotation was attempted")
}
func TestTokenRotationAtEndOfRequest(t *testing.T) {
reqContext, rr, err := initTokenRotationTest(context.Background(), t)
require.NoError(t, err)
uts := &auth.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP,
userAgent string) (bool, error) {
newToken, err := util.RandomHex(16)
require.NoError(t, err)
token.AuthToken = newToken
return true, nil
},
}
token := &models.UserToken{AuthToken: "oldtoken"}
rotateEndOfRequestFunc(reqContext, uts, token)(reqContext.Resp)
foundLoginCookie := false
resp := rr.Result()
defer resp.Body.Close()
for _, c := range resp.Cookies() {
if c.Name == "login_token" {
foundLoginCookie = true
require.NotEqual(t, token.AuthToken, c.Value, "Auth token is still the same")
}
}
assert.True(t, foundLoginCookie, "Could not find cookie")
}
func initTokenRotationTest(ctx context.Context, t *testing.T) (*models.ReqContext, *httptest.ResponseRecorder, error) {
t.Helper() t.Helper()
origLoginCookieName := setting.LoginCookieName sqlStore := sqlstore.InitTestDB(t)
origLoginMaxLifetime := setting.LoginMaxLifetime remoteCacheSvc := &remotecache.RemoteCache{}
t.Cleanup(func() { if cfg == nil {
setting.LoginCookieName = origLoginCookieName cfg = setting.NewCfg()
setting.LoginMaxLifetime = origLoginMaxLifetime }
cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{
Name: "database",
}
userAuthTokenSvc := auth.NewFakeUserAuthTokenService()
renderSvc := &fakeRenderService{}
ctxHdlr := &contexthandler.ContextHandler{}
err := registry.BuildServiceGraph([]interface{}{cfg}, []*registry.Descriptor{
{
Name: sqlstore.ServiceName,
Instance: sqlStore,
},
{
Name: remotecache.ServiceName,
Instance: remoteCacheSvc,
},
{
Name: auth.ServiceName,
Instance: userAuthTokenSvc,
},
{
Name: rendering.ServiceName,
Instance: renderSvc,
},
{
Name: contexthandler.ServiceName,
Instance: ctxHdlr,
},
}) })
setting.LoginCookieName = "login_token" require.NoError(t, err)
var err error
setting.LoginMaxLifetime, err = gtime.ParseDuration("7d") return ctxHdlr
if err != nil {
return nil, nil, err
} }
rr := httptest.NewRecorder() type fakeRenderService struct {
req, err := http.NewRequestWithContext(ctx, "", "", nil) rendering.Service
if err != nil {
return nil, nil, err
}
reqContext := &models.ReqContext{
Context: &macaron.Context{
Req: macaron.Request{
Request: req,
},
},
Logger: log.New("testlogger"),
} }
mw := mockWriter{rr} func (s *fakeRenderService) Init() error {
reqContext.Resp = mw
return reqContext, rr, nil
}
type mockWriter struct {
*httptest.ResponseRecorder
}
func (mw mockWriter) Flush() {}
func (mw mockWriter) Status() int { return 0 }
func (mw mockWriter) Size() int { return 0 }
func (mw mockWriter) Written() bool { return false }
func (mw mockWriter) Before(macaron.BeforeFunc) {}
func (mw mockWriter) Push(target string, opts *http.PushOptions) error {
return nil return nil
} }

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -18,6 +19,8 @@ type advanceTimeFunc func(deltaTime time.Duration)
type rateLimiterScenarioFunc func(c execFunc, t advanceTimeFunc) type rateLimiterScenarioFunc func(c execFunc, t advanceTimeFunc)
func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateLimiterScenarioFunc) { func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateLimiterScenarioFunc) {
t.Helper()
t.Run(desc, func(t *testing.T) { t.Run(desc, func(t *testing.T) {
defaultHandler := func(c *models.ReqContext) { defaultHandler := func(c *models.ReqContext) {
resp := make(map[string]interface{}) resp := make(map[string]interface{})
@@ -26,12 +29,14 @@ func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateL
} }
currentTime := time.Now() currentTime := time.Now()
cfg := setting.NewCfg()
m := macaron.New() m := macaron.New()
m.Use(macaron.Renderer(macaron.RenderOptions{ m.Use(macaron.Renderer(macaron.RenderOptions{
Directory: "", Directory: "",
Delims: macaron.Delims{Left: "[[", Right: "]]"}, Delims: macaron.Delims{Left: "[[", Right: "]]"},
})) }))
m.Use(GetContextHandler(nil, nil, nil)) m.Use(getContextHandler(t, cfg).Middleware)
m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler) m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler)
fn(func() *httptest.ResponseRecorder { fn(func() *httptest.ResponseRecorder {

View File

@@ -103,7 +103,7 @@ func function(pc uintptr) []byte {
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one. // Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
// While Martini is in development mode, Recovery will also output the panic as HTML. // While Martini is in development mode, Recovery will also output the panic as HTML.
func Recovery() macaron.Handler { func Recovery(cfg *setting.Cfg) macaron.Handler {
return func(c *macaron.Context) { return func(c *macaron.Context) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -134,7 +134,7 @@ func Recovery() macaron.Handler {
c.Data["Title"] = "Server Error" c.Data["Title"] = "Server Error"
c.Data["AppSubUrl"] = setting.AppSubUrl c.Data["AppSubUrl"] = setting.AppSubUrl
c.Data["Theme"] = setting.DefaultTheme c.Data["Theme"] = cfg.DefaultTheme
if setting.Env == setting.Dev { if setting.Env == setting.Dev {
if err, ok := r.(error); ok { if err, ok := r.(error); ok {
@@ -158,7 +158,7 @@ func Recovery() macaron.Handler {
c.JSON(500, resp) c.JSON(500, resp)
} else { } else {
c.HTML(500, setting.ErrTemplateName) c.HTML(500, cfg.ErrTemplateName)
} }
} }
}() }()

View File

@@ -16,8 +16,6 @@ import (
) )
func TestRecoveryMiddleware(t *testing.T) { func TestRecoveryMiddleware(t *testing.T) {
setting.ErrTemplateName = "error-template"
t.Run("Given an API route that panics", func(t *testing.T) { t.Run("Given an API route that panics", func(t *testing.T) {
apiURL := "/api/whatever" apiURL := "/api/whatever"
recoveryScenario(t, "recovery middleware should return json", apiURL, func(t *testing.T, sc *scenarioContext) { recoveryScenario(t, "recovery middleware should return json", apiURL, func(t *testing.T, sc *scenarioContext) {
@@ -52,18 +50,21 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
t.Run(desc, func(t *testing.T) { t.Run(desc, func(t *testing.T) {
defer bus.ClearBusHandlers() defer bus.ClearBusHandlers()
cfg := setting.NewCfg()
cfg.ErrTemplateName = "error-template"
sc := &scenarioContext{ sc := &scenarioContext{
t: t, t: t,
url: url, url: url,
cfg: cfg,
} }
viewsPath, err := filepath.Abs("../../public/views") viewsPath, err := filepath.Abs("../../public/views")
require.NoError(t, err) require.NoError(t, err)
sc.m = macaron.New() sc.m = macaron.New()
sc.m.Use(Recovery()) sc.m.Use(Recovery(cfg))
sc.m.Use(AddDefaultResponseHeaders()) sc.m.Use(AddDefaultResponseHeaders(cfg))
sc.m.Use(macaron.Renderer(macaron.RenderOptions{ sc.m.Use(macaron.Renderer(macaron.RenderOptions{
Directory: viewsPath, Directory: viewsPath,
Delims: macaron.Delims{Left: "[[", Right: "]]"}, Delims: macaron.Delims{Left: "[[", Right: "]]"},
@@ -72,7 +73,8 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService() sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
sc.remoteCacheService = remotecache.NewFakeStore(t) sc.remoteCacheService = remotecache.NewFakeStore(t)
sc.m.Use(GetContextHandler(sc.userAuthTokenService, sc.remoteCacheService, nil)) contextHandler := getContextHandler(t, nil)
sc.m.Use(contextHandler.Middleware)
// mock out gc goroutine // mock out gc goroutine
sc.m.Use(OrgRedirect()) sc.m.Use(OrgRedirect())

View File

@@ -1,31 +0,0 @@
package middleware
import (
"time"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/rendering"
)
func initContextWithRenderAuth(ctx *models.ReqContext, renderService rendering.Service) bool {
key := ctx.GetCookie("renderKey")
if key == "" {
return false
}
renderUser, exists := renderService.GetRenderUser(key)
if !exists {
ctx.JsonApiErr(401, "Invalid Render Key", nil)
return true
}
ctx.IsSignedIn = true
ctx.SignedInUser = &models.SignedInUser{
OrgId: renderUser.OrgID,
UserId: renderUser.UserID,
OrgRole: models.RoleType(renderUser.OrgRole),
}
ctx.IsRenderCall = true
ctx.LastSeenAt = time.Now()
return true
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -29,6 +30,8 @@ type scenarioContext struct {
url string url string
userAuthTokenService *auth.FakeUserAuthTokenService userAuthTokenService *auth.FakeUserAuthTokenService
remoteCacheService *remotecache.RemoteCache remoteCacheService *remotecache.RemoteCache
cfg *setting.Cfg
contextHandler *contexthandler.ContextHandler
req *http.Request req *http.Request
} }
@@ -94,9 +97,9 @@ func (sc *scenarioContext) exec() {
} }
if sc.tokenSessionCookie != "" { if sc.tokenSessionCookie != "" {
sc.t.Log(`Adding cookie`, "name", setting.LoginCookieName, "value", sc.tokenSessionCookie) sc.t.Log(`Adding cookie`, "name", sc.cfg.LoginCookieName, "value", sc.tokenSessionCookie)
sc.req.AddCookie(&http.Cookie{ sc.req.AddCookie(&http.Cookie{
Name: setting.LoginCookieName, Name: sc.cfg.LoginCookieName,
Value: sc.tokenSessionCookie, Value: sc.tokenSessionCookie,
}) })
} }

View File

@@ -3,6 +3,7 @@ package models
import ( import (
"time" "time"
"github.com/grafana/grafana/pkg/setting"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@@ -84,6 +85,7 @@ type LoginUserQuery struct {
User *User User *User
IpAddress string IpAddress string
AuthModule string AuthModule string
Cfg *setting.Cfg
} }
type GetUserByAuthInfoQuery struct { type GetUserByAuthInfoQuery struct {

45
pkg/registry/di.go Normal file
View File

@@ -0,0 +1,45 @@
package registry
import (
"fmt"
"github.com/facebookgo/inject"
)
// BuildServiceGraph builds a graph of services and their dependencies.
// The services are initialized after the graph is built.
func BuildServiceGraph(objs []interface{}, services []*Descriptor) error {
if services == nil {
services = GetServices()
}
for _, service := range services {
objs = append(objs, service.Instance)
}
serviceGraph := inject.Graph{}
// Provide services and their dependencies to the graph.
for _, obj := range objs {
if err := serviceGraph.Provide(&inject.Object{Value: obj}); err != nil {
return fmt.Errorf("failed to provide object to the graph: %w", err)
}
}
// Resolve services and their dependencies.
if err := serviceGraph.Populate(); err != nil {
return fmt.Errorf("failed to populate service dependencies: %w", err)
}
// Initialize services.
for _, service := range services {
if IsDisabled(service.Instance) {
continue
}
if err := service.Instance.Init(); err != nil {
return fmt.Errorf("service init failed: %w", err)
}
}
return nil
}

View File

@@ -18,8 +18,14 @@ import (
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
) )
const ServiceName = "UserAuthTokenService"
func init() { func init() {
registry.RegisterService(&UserAuthTokenService{}) registry.Register(&registry.Descriptor{
Name: ServiceName,
Instance: &UserAuthTokenService{},
InitPriority: registry.Medium,
})
} }
var getTime = time.Now var getTime = time.Now

View File

@@ -57,8 +57,13 @@ func NewFakeUserAuthTokenService() *FakeUserAuthTokenService {
} }
} }
func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP net.IP, // Init initializes the service.
userAgent string) (*models.UserToken, error) { // Required for dependency injection.
func (s *FakeUserAuthTokenService) Init() error {
return nil
}
func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP net.IP, userAgent string) (*models.UserToken, error) {
return s.CreateTokenProvider(context.Background(), userId, clientIP, userAgent) return s.CreateTokenProvider(context.Background(), userId, clientIP, userAgent)
} }

View File

@@ -1,4 +1,4 @@
package middleware package contexthandler
import ( import (
"fmt" "fmt"
@@ -8,8 +8,12 @@ import (
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/middleware/authproxy"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
macaron "gopkg.in/macaron.v1" macaron "gopkg.in/macaron.v1"
@@ -41,25 +45,16 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
} }
return nil return nil
} }
origHeaderName := setting.AuthProxyHeaderName
origEnabled := setting.AuthProxyEnabled
origHeaderProperty := setting.AuthProxyHeaderProperty
bus.AddHandler("", upsertHandler) bus.AddHandler("", upsertHandler)
bus.AddHandler("", getUserHandler) bus.AddHandler("", getUserHandler)
t.Cleanup(func() { t.Cleanup(func() {
setting.AuthProxyHeaderName = origHeaderName
setting.AuthProxyEnabled = origEnabled
setting.AuthProxyHeaderProperty = origHeaderProperty
bus.ClearBusHandlers() bus.ClearBusHandlers()
}) })
setting.AuthProxyHeaderName = "X-Killa" svc := getContextHandler(t)
setting.AuthProxyEnabled = true
setting.AuthProxyHeaderProperty = "username"
req, err := http.NewRequest("POST", "http://example.com", nil) req, err := http.NewRequest("POST", "http://example.com", nil)
require.NoError(t, err) require.NoError(t, err)
store := remotecache.NewFakeStore(t)
ctx := &models.ReqContext{ ctx := &models.ReqContext{
Context: &macaron.Context{ Context: &macaron.Context{
Req: macaron.Request{ Req: macaron.Request{
@@ -69,20 +64,72 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
}, },
Logger: log.New("Test"), Logger: log.New("Test"),
} }
req.Header.Add(setting.AuthProxyHeaderName, name) req.Header.Set(svc.Cfg.AuthProxyHeaderName, name)
key := fmt.Sprintf(authproxy.CachePrefix, authproxy.HashCacheKey(name)) key := fmt.Sprintf(authproxy.CachePrefix, authproxy.HashCacheKey(name))
t.Logf("Injecting stale user ID in cache with key %q", key) t.Logf("Injecting stale user ID in cache with key %q", key)
err = store.Set(key, int64(33), 0) err = svc.RemoteCache.Set(key, int64(33), 0)
require.NoError(t, err) require.NoError(t, err)
authEnabled := initContextWithAuthProxy(store, ctx, orgID) authEnabled := svc.initContextWithAuthProxy(ctx, orgID)
require.True(t, authEnabled) require.True(t, authEnabled)
require.Equal(t, userID, ctx.SignedInUser.UserId) require.Equal(t, userID, ctx.SignedInUser.UserId)
require.True(t, ctx.IsSignedIn) require.True(t, ctx.IsSignedIn)
i, err := store.Get(key) i, err := svc.RemoteCache.Get(key)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, userID, i.(int64)) require.Equal(t, userID, i.(int64))
} }
type fakeRenderService struct {
rendering.Service
}
func (s *fakeRenderService) Init() error {
return nil
}
func getContextHandler(t *testing.T) *ContextHandler {
t.Helper()
sqlStore := sqlstore.InitTestDB(t)
remoteCacheSvc := &remotecache.RemoteCache{}
cfg := setting.NewCfg()
cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{
Name: "database",
}
cfg.AuthProxyHeaderName = "X-Killa"
cfg.AuthProxyEnabled = true
cfg.AuthProxyHeaderProperty = "username"
userAuthTokenSvc := auth.NewFakeUserAuthTokenService()
renderSvc := &fakeRenderService{}
svc := &ContextHandler{}
err := registry.BuildServiceGraph([]interface{}{cfg}, []*registry.Descriptor{
{
Name: sqlstore.ServiceName,
Instance: sqlStore,
},
{
Name: remotecache.ServiceName,
Instance: remoteCacheSvc,
},
{
Name: auth.ServiceName,
Instance: userAuthTokenSvc,
},
{
Name: rendering.ServiceName,
Instance: renderSvc,
},
{
Name: ServiceName,
Instance: svc,
},
})
require.NoError(t, err)
return svc
}

View File

@@ -32,7 +32,13 @@ const (
var getLDAPConfig = ldap.GetConfig var getLDAPConfig = ldap.GetConfig
// isLDAPEnabled checks if LDAP is enabled // isLDAPEnabled checks if LDAP is enabled
var isLDAPEnabled = ldap.IsEnabled var isLDAPEnabled = func(cfg *setting.Cfg) bool {
if cfg != nil {
return cfg.LDAPEnabled
}
return setting.LDAPEnabled
}
// newLDAP creates multiple LDAP instance // newLDAP creates multiple LDAP instance
var newLDAP = multildap.New var newLDAP = multildap.New
@@ -42,18 +48,11 @@ var supportedHeaderFields = []string{"Name", "Email", "Login", "Groups"}
// AuthProxy struct // AuthProxy struct
type AuthProxy struct { type AuthProxy struct {
store *remotecache.RemoteCache cfg *setting.Cfg
remoteCache *remotecache.RemoteCache
ctx *models.ReqContext ctx *models.ReqContext
orgID int64 orgID int64
header string header string
enabled bool
LDAPAllowSignup bool
AuthProxyAutoSignUp bool
whitelistIP string
headerType string
headers map[string]string
cacheTTL int
} }
// Error auth proxy specific error // Error auth proxy specific error
@@ -77,35 +76,27 @@ func (err Error) Error() string {
// Options for the AuthProxy // Options for the AuthProxy
type Options struct { type Options struct {
Store *remotecache.RemoteCache RemoteCache *remotecache.RemoteCache
Ctx *models.ReqContext Ctx *models.ReqContext
OrgID int64 OrgID int64
} }
// New instance of the AuthProxy // New instance of the AuthProxy
func New(options *Options) *AuthProxy { func New(cfg *setting.Cfg, options *Options) *AuthProxy {
header := options.Ctx.Req.Header.Get(setting.AuthProxyHeaderName) header := options.Ctx.Req.Header.Get(cfg.AuthProxyHeaderName)
return &AuthProxy{ return &AuthProxy{
store: options.Store, remoteCache: options.RemoteCache,
cfg: cfg,
ctx: options.Ctx, ctx: options.Ctx,
orgID: options.OrgID, orgID: options.OrgID,
header: header, header: header,
enabled: setting.AuthProxyEnabled,
headerType: setting.AuthProxyHeaderProperty,
headers: setting.AuthProxyHeaders,
whitelistIP: setting.AuthProxyWhitelist,
cacheTTL: setting.AuthProxySyncTtl,
LDAPAllowSignup: setting.LDAPAllowSignup,
AuthProxyAutoSignUp: setting.AuthProxyAutoSignUp,
} }
} }
// IsEnabled checks if the proxy auth is enabled // IsEnabled checks if the proxy auth is enabled
func (auth *AuthProxy) IsEnabled() bool { func (auth *AuthProxy) IsEnabled() bool {
// Bail if the setting is not enabled // Bail if the setting is not enabled
return auth.enabled return auth.cfg.AuthProxyEnabled
} }
// HasHeader checks if the we have specified header // HasHeader checks if the we have specified header
@@ -113,15 +104,15 @@ func (auth *AuthProxy) HasHeader() bool {
return len(auth.header) != 0 return len(auth.header) != 0
} }
// IsAllowedIP compares presented IP with the whitelist one // IsAllowedIP returns whether provided IP is allowed.
func (auth *AuthProxy) IsAllowedIP() error { func (auth *AuthProxy) IsAllowedIP() error {
ip := auth.ctx.Req.RemoteAddr ip := auth.ctx.Req.RemoteAddr
if len(strings.TrimSpace(auth.whitelistIP)) == 0 { if len(strings.TrimSpace(auth.cfg.AuthProxyWhitelist)) == 0 {
return nil return nil
} }
proxies := strings.Split(auth.whitelistIP, ",") proxies := strings.Split(auth.cfg.AuthProxyWhitelist, ",")
var proxyObjs []*net.IPNet var proxyObjs []*net.IPNet
for _, proxy := range proxies { for _, proxy := range proxies {
result, err := coerceProxyAddress(proxy) result, err := coerceProxyAddress(proxy)
@@ -181,7 +172,7 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
} }
} }
if isLDAPEnabled() { if isLDAPEnabled(auth.cfg) {
id, err := auth.LoginViaLDAP() id, err := auth.LoginViaLDAP()
if err != nil { if err != nil {
if errors.Is(err, ldap.ErrInvalidCredentials) { if errors.Is(err, ldap.ErrInvalidCredentials) {
@@ -205,7 +196,7 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) { func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) {
cacheKey := auth.getKey() cacheKey := auth.getKey()
logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey) logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
userID, err := auth.store.Get(cacheKey) userID, err := auth.remoteCache.Get(cacheKey)
if err != nil { if err != nil {
logger.Debug("Failed getting user ID via auth cache", "error", err) logger.Debug("Failed getting user ID via auth cache", "error", err)
return 0, err return 0, err
@@ -219,7 +210,7 @@ func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) {
func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error { func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error {
cacheKey := auth.getKey() cacheKey := auth.getKey()
logger.Debug("Removing user from auth cache", "cacheKey", cacheKey) logger.Debug("Removing user from auth cache", "cacheKey", cacheKey)
if err := auth.store.Delete(cacheKey); err != nil { if err := auth.remoteCache.Delete(cacheKey); err != nil {
return err return err
} }
@@ -229,12 +220,13 @@ func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error {
// LoginViaLDAP logs in user via LDAP request // LoginViaLDAP logs in user via LDAP request
func (auth *AuthProxy) LoginViaLDAP() (int64, error) { func (auth *AuthProxy) LoginViaLDAP() (int64, error) {
config, err := getLDAPConfig() config, err := getLDAPConfig(auth.cfg)
if err != nil { if err != nil {
return 0, newError("failed to get LDAP config", err) return 0, newError("failed to get LDAP config", err)
} }
extUser, _, err := newLDAP(config.Servers).User(auth.header) mldap := newLDAP(config.Servers)
extUser, _, err := mldap.User(auth.header)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -242,7 +234,7 @@ func (auth *AuthProxy) LoginViaLDAP() (int64, error) {
// Have to sync grafana and LDAP user during log in // Have to sync grafana and LDAP user during log in
upsert := &models.UpsertUserCommand{ upsert := &models.UpsertUserCommand{
ReqContext: auth.ctx, ReqContext: auth.ctx,
SignupAllowed: auth.LDAPAllowSignup, SignupAllowed: auth.cfg.LDAPAllowSignup,
ExternalUser: extUser, ExternalUser: extUser,
} }
if err := bus.Dispatch(upsert); err != nil { if err := bus.Dispatch(upsert); err != nil {
@@ -259,7 +251,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
AuthId: auth.header, AuthId: auth.header,
} }
switch auth.headerType { switch auth.cfg.AuthProxyHeaderProperty {
case "username": case "username":
extUser.Login = auth.header extUser.Login = auth.header
@@ -284,7 +276,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
upsert := &models.UpsertUserCommand{ upsert := &models.UpsertUserCommand{
ReqContext: auth.ctx, ReqContext: auth.ctx,
SignupAllowed: setting.AuthProxyAutoSignUp, SignupAllowed: auth.cfg.AuthProxyAutoSignUp,
ExternalUser: extUser, ExternalUser: extUser,
} }
@@ -299,8 +291,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
// headersIterator iterates over all non-empty supported additional headers // headersIterator iterates over all non-empty supported additional headers
func (auth *AuthProxy) headersIterator(fn func(field string, header string)) { func (auth *AuthProxy) headersIterator(fn func(field string, header string)) {
for _, field := range supportedHeaderFields { for _, field := range supportedHeaderFields {
h := auth.headers[field] h := auth.cfg.AuthProxyHeaders[field]
if h == "" { if h == "" {
continue continue
} }
@@ -311,8 +302,8 @@ func (auth *AuthProxy) headersIterator(fn func(field string, header string)) {
} }
} }
// GetSignedUser gets full signed user info. // GetSignedUser gets full signed in user info.
func (auth *AuthProxy) GetSignedUser(userID int64) (*models.SignedInUser, error) { func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, error) {
query := &models.GetSignedInUserQuery{ query := &models.GetSignedInUserQuery{
OrgId: auth.orgID, OrgId: auth.orgID,
UserId: userID, UserId: userID,
@@ -330,14 +321,14 @@ func (auth *AuthProxy) Remember(id int64) error {
key := auth.getKey() key := auth.getKey()
// Check if user already in cache // Check if user already in cache
userID, _ := auth.store.Get(key) userID, _ := auth.remoteCache.Get(key)
if userID != nil { if userID != nil {
return nil return nil
} }
expiration := time.Duration(auth.cacheTTL) * time.Minute expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute
err := auth.store.Set(key, id, expiration) err := auth.remoteCache.Set(key, id, expiration)
if err != nil { if err != nil {
return err return err
} }
@@ -353,5 +344,8 @@ func coerceProxyAddress(proxyAddr string) (*net.IPNet, error) {
} }
_, network, err := net.ParseCIDR(proxyAddr) _, network, err := net.ParseCIDR(proxyAddr)
return network, err if err != nil {
return nil, fmt.Errorf("could not parse the network: %w", err)
}
return network, nil
} }

View File

@@ -47,9 +47,22 @@ func (m *fakeMultiLDAP) User(login string) (
return result, ldap.ServerConfig{}, nil return result, ldap.ServerConfig{}, nil
} }
func prepareMiddleware(t *testing.T, req *http.Request, store *remotecache.RemoteCache) *AuthProxy { const hdrName = "markelog"
func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, cb func(*http.Request, *setting.Cfg)) *AuthProxy {
t.Helper() t.Helper()
cfg := setting.NewCfg()
cfg.AuthProxyHeaderName = "X-Killa"
req, err := http.NewRequest("POST", "http://example.com", nil)
require.NoError(t, err)
req.Header.Set(cfg.AuthProxyHeaderName, hdrName)
if cb != nil {
cb(req, cfg)
}
ctx := &models.ReqContext{ ctx := &models.ReqContext{
Context: &macaron.Context{ Context: &macaron.Context{
Req: macaron.Request{ Req: macaron.Request{
@@ -58,8 +71,8 @@ func prepareMiddleware(t *testing.T, req *http.Request, store *remotecache.Remot
}, },
} }
auth := New(&Options{ auth := New(cfg, &Options{
Store: store, RemoteCache: remoteCache,
Ctx: ctx, Ctx: ctx,
OrgID: 4, OrgID: 4,
}) })
@@ -69,24 +82,17 @@ func prepareMiddleware(t *testing.T, req *http.Request, store *remotecache.Remot
func TestMiddlewareContext(t *testing.T) { func TestMiddlewareContext(t *testing.T) {
logger := log.New("test") logger := log.New("test")
req, err := http.NewRequest("POST", "http://example.com", nil) cache := remotecache.NewFakeStore(t)
require.NoError(t, err)
setting.AuthProxyHeaderName = "X-Killa"
store := remotecache.NewFakeStore(t)
name := "markelog"
req.Header.Add(setting.AuthProxyHeaderName, name)
t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) { t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) {
const id int64 = 33 const id int64 = 33
// Set cache key // Set cache key
key := fmt.Sprintf(CachePrefix, HashCacheKey(name)) key := fmt.Sprintf(CachePrefix, HashCacheKey(hdrName))
err := store.Set(key, id, 0) err := cache.Set(key, id, 0)
require.NoError(t, err) require.NoError(t, err)
// Set up the middleware // Set up the middleware
auth := prepareMiddleware(t, req, store) auth := prepareMiddleware(t, cache, nil)
assert.Equal(t, "auth-proxy-sync-ttl:0a7f3374e9659b10980fd66247b0cf2f", auth.getKey()) assert.Equal(t, key, auth.getKey())
gotID, err := auth.Login(logger, false) gotID, err := auth.Login(logger, false)
require.NoError(t, err) require.NoError(t, err)
@@ -96,15 +102,16 @@ func TestMiddlewareContext(t *testing.T) {
t.Run("When the cache key contains additional headers", func(t *testing.T) { t.Run("When the cache key contains additional headers", func(t *testing.T) {
const id int64 = 33 const id int64 = 33
setting.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"} const group = "grafana-core-team"
group := "grafana-core-team"
req.Header.Add("X-WEBAUTH-GROUPS", group)
key := fmt.Sprintf(CachePrefix, HashCacheKey(name+"-"+group)) key := fmt.Sprintf(CachePrefix, HashCacheKey(hdrName+"-"+group))
err := store.Set(key, id, 0) err := cache.Set(key, id, 0)
require.NoError(t, err) require.NoError(t, err)
auth := prepareMiddleware(t, req, store) auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
req.Header.Set("X-WEBAUTH-GROUPS", group)
cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"}
})
assert.Equal(t, "auth-proxy-sync-ttl:14f69b7023baa0ac98c96b31cec07bc0", auth.getKey()) assert.Equal(t, "auth-proxy-sync-ttl:14f69b7023baa0ac98c96b31cec07bc0", auth.getKey())
gotID, err := auth.Login(logger, false) gotID, err := auth.Login(logger, false)
@@ -115,12 +122,6 @@ func TestMiddlewareContext(t *testing.T) {
func TestMiddlewareContext_ldap(t *testing.T) { func TestMiddlewareContext_ldap(t *testing.T) {
logger := log.New("test") logger := log.New("test")
req, err := http.NewRequest("POST", "http://example.com", nil)
require.NoError(t, err)
setting.AuthProxyHeaderName = "X-Killa"
const headerName = "markelog"
req.Header.Add(setting.AuthProxyHeaderName, headerName)
t.Run("Logs in via LDAP", func(t *testing.T) { t.Run("Logs in via LDAP", func(t *testing.T) {
const id int64 = 42 const id int64 = 42
@@ -133,7 +134,16 @@ func TestMiddlewareContext_ldap(t *testing.T) {
return nil return nil
}) })
isLDAPEnabled = func() bool { origIsLDAPEnabled := isLDAPEnabled
origGetLDAPConfig := getLDAPConfig
origNewLDAP := newLDAP
t.Cleanup(func() {
newLDAP = origNewLDAP
isLDAPEnabled = origIsLDAPEnabled
getLDAPConfig = origGetLDAPConfig
})
isLDAPEnabled = func(*setting.Cfg) bool {
return true return true
} }
@@ -141,7 +151,7 @@ func TestMiddlewareContext_ldap(t *testing.T) {
ID: id, ID: id,
} }
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
config := &ldap.Config{ config := &ldap.Config{
Servers: []*ldap.ServerConfig{ Servers: []*ldap.ServerConfig{
{ {
@@ -156,15 +166,9 @@ func TestMiddlewareContext_ldap(t *testing.T) {
return stub return stub
} }
defer func() { cache := remotecache.NewFakeStore(t)
newLDAP = multildap.New
isLDAPEnabled = ldap.IsEnabled
getLDAPConfig = ldap.GetConfig
}()
store := remotecache.NewFakeStore(t) auth := prepareMiddleware(t, cache, nil)
auth := prepareMiddleware(t, req, store)
gotID, err := auth.Login(logger, false) gotID, err := auth.Login(logger, false)
require.NoError(t, err) require.NoError(t, err)
@@ -173,25 +177,28 @@ func TestMiddlewareContext_ldap(t *testing.T) {
assert.True(t, stub.userCalled) assert.True(t, stub.userCalled)
}) })
t.Run("Gets nice error if ldap is enabled but not configured", func(t *testing.T) { t.Run("Gets nice error if LDAP is enabled, but not configured", func(t *testing.T) {
const id int64 = 42 const id int64 = 42
isLDAPEnabled = func() bool { origIsLDAPEnabled := isLDAPEnabled
origNewLDAP := newLDAP
origGetLDAPConfig := getLDAPConfig
t.Cleanup(func() {
isLDAPEnabled = origIsLDAPEnabled
newLDAP = origNewLDAP
getLDAPConfig = origGetLDAPConfig
})
isLDAPEnabled = func(*setting.Cfg) bool {
return true return true
} }
getLDAPConfig = func() (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return nil, errors.New("something went wrong") return nil, errors.New("something went wrong")
} }
defer func() { cache := remotecache.NewFakeStore(t)
newLDAP = multildap.New
isLDAPEnabled = ldap.IsEnabled
getLDAPConfig = ldap.GetConfig
}()
store := remotecache.NewFakeStore(t) auth := prepareMiddleware(t, cache, nil)
auth := prepareMiddleware(t, req, store)
stub := &fakeMultiLDAP{ stub := &fakeMultiLDAP{
ID: id, ID: id,

View File

@@ -0,0 +1,448 @@
// Package contexthandler contains the ContextHandler service.
package contexthandler
import (
"context"
"errors"
"strconv"
"strings"
"time"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/components/apikeygen"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/network"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/middleware/cookies"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"gopkg.in/macaron.v1"
)
const (
InvalidUsernamePassword = "invalid username or password"
InvalidAPIKey = "invalid API key"
)
const ServiceName = "ContextHandler"
func init() {
registry.Register(&registry.Descriptor{
Name: ServiceName,
Instance: &ContextHandler{},
InitPriority: registry.High,
})
}
// ContextHandler is a middleware.
type ContextHandler struct {
Cfg *setting.Cfg `inject:""`
AuthTokenService models.UserTokenService `inject:""`
RemoteCache *remotecache.RemoteCache `inject:""`
RenderService rendering.Service `inject:""`
SQLStore *sqlstore.SQLStore `inject:""`
// GetTime returns the current time.
// Stubbable by tests.
GetTime func() time.Time
}
// Init initializes the service.
func (h *ContextHandler) Init() error {
return nil
}
// Middleware provides a middleware to initialize the Macaron context.
func (h *ContextHandler) Middleware(c *macaron.Context) {
ctx := &models.ReqContext{
Context: c,
SignedInUser: &models.SignedInUser{},
IsSignedIn: false,
AllowAnonymous: false,
SkipCache: false,
Logger: log.New("context"),
}
const headerName = "X-Grafana-Org-Id"
orgID := int64(0)
orgIDHeader := ctx.Req.Header.Get(headerName)
if orgIDHeader != "" {
id, err := strconv.ParseInt(orgIDHeader, 10, 64)
if err == nil {
orgID = id
} else {
ctx.Logger.Debug("Received invalid header", "header", headerName, "value", orgIDHeader)
}
}
// the order in which these are tested are important
// look for api key in Authorization header first
// then init session and look for userId in session
// then look for api key in session (special case for render calls via api)
// then test if anonymous access is enabled
switch {
case h.initContextWithRenderAuth(ctx):
case h.initContextWithAPIKey(ctx):
case h.initContextWithBasicAuth(ctx, orgID):
case h.initContextWithAuthProxy(ctx, orgID):
case h.initContextWithToken(ctx, orgID):
case h.initContextWithAnonymousUser(ctx):
}
ctx.Logger = log.New("context", "userId", ctx.UserId, "orgId", ctx.OrgId, "uname", ctx.Login)
ctx.Data["ctx"] = ctx
c.Map(ctx)
// update last seen every 5min
if ctx.ShouldUpdateLastSeenAt() {
ctx.Logger.Debug("Updating last user_seen_at", "user_id", ctx.UserId)
if err := bus.Dispatch(&models.UpdateUserLastSeenAtCommand{UserId: ctx.UserId}); err != nil {
ctx.Logger.Error("Failed to update last_seen_at", "error", err)
}
}
}
func (h *ContextHandler) initContextWithAnonymousUser(ctx *models.ReqContext) bool {
if !h.Cfg.AnonymousEnabled {
return false
}
orgQuery := models.GetOrgByNameQuery{Name: h.Cfg.AnonymousOrgName}
if err := bus.Dispatch(&orgQuery); err != nil {
log.Errorf(3, "Anonymous access organization error: '%s': %s", h.Cfg.AnonymousOrgName, err)
return false
}
ctx.IsSignedIn = false
ctx.AllowAnonymous = true
ctx.SignedInUser = &models.SignedInUser{IsAnonymous: true}
ctx.OrgRole = models.RoleType(h.Cfg.AnonymousOrgRole)
ctx.OrgId = orgQuery.Result.Id
ctx.OrgName = orgQuery.Result.Name
return true
}
func (h *ContextHandler) initContextWithAPIKey(ctx *models.ReqContext) bool {
header := ctx.Req.Header.Get("Authorization")
parts := strings.SplitN(header, " ", 2)
var keyString string
if len(parts) == 2 && parts[0] == "Bearer" {
keyString = parts[1]
} else {
username, password, err := util.DecodeBasicAuthHeader(header)
if err == nil && username == "api_key" {
keyString = password
}
}
if keyString == "" {
return false
}
// base64 decode key
decoded, err := apikeygen.Decode(keyString)
if err != nil {
ctx.JsonApiErr(401, InvalidAPIKey, err)
return true
}
// fetch key
keyQuery := models.GetApiKeyByNameQuery{KeyName: decoded.Name, OrgId: decoded.OrgId}
if err := bus.Dispatch(&keyQuery); err != nil {
ctx.JsonApiErr(401, InvalidAPIKey, err)
return true
}
apikey := keyQuery.Result
// validate api key
isValid, err := apikeygen.IsValid(decoded, apikey.Key)
if err != nil {
ctx.JsonApiErr(500, "Validating API key failed", err)
return true
}
if !isValid {
ctx.JsonApiErr(401, InvalidAPIKey, err)
return true
}
// check for expiration
getTime := h.GetTime
if getTime == nil {
getTime = time.Now
}
if apikey.Expires != nil && *apikey.Expires <= getTime().Unix() {
ctx.JsonApiErr(401, "Expired API key", err)
return true
}
ctx.IsSignedIn = true
ctx.SignedInUser = &models.SignedInUser{}
ctx.OrgRole = apikey.Role
ctx.ApiKeyId = apikey.Id
ctx.OrgId = apikey.OrgId
return true
}
func (h *ContextHandler) initContextWithBasicAuth(ctx *models.ReqContext, orgID int64) bool {
if !h.Cfg.BasicAuthEnabled {
return false
}
header := ctx.Req.Header.Get("Authorization")
if header == "" {
return false
}
username, password, err := util.DecodeBasicAuthHeader(header)
if err != nil {
ctx.JsonApiErr(401, "Invalid Basic Auth Header", err)
return true
}
authQuery := models.LoginUserQuery{
Username: username,
Password: password,
Cfg: h.Cfg,
}
if err := bus.Dispatch(&authQuery); err != nil {
ctx.Logger.Debug(
"Failed to authorize the user",
"username", username,
"err", err,
)
if errors.Is(err, models.ErrUserNotFound) {
err = login.ErrInvalidCredentials
}
ctx.JsonApiErr(401, InvalidUsernamePassword, err)
return true
}
user := authQuery.User
query := models.GetSignedInUserQuery{UserId: user.Id, OrgId: orgID}
if err := bus.Dispatch(&query); err != nil {
ctx.Logger.Error(
"Failed at user signed in",
"id", user.Id,
"org", orgID,
)
ctx.JsonApiErr(401, InvalidUsernamePassword, err)
return true
}
ctx.SignedInUser = query.Result
ctx.IsSignedIn = true
return true
}
func (h *ContextHandler) initContextWithToken(ctx *models.ReqContext, orgID int64) bool {
if h.Cfg.LoginCookieName == "" {
return false
}
rawToken := ctx.GetCookie(h.Cfg.LoginCookieName)
if rawToken == "" {
return false
}
token, err := h.AuthTokenService.LookupToken(ctx.Req.Context(), rawToken)
if err != nil {
ctx.Logger.Error("Failed to look up user based on cookie", "error", err)
cookies.WriteSessionCookie(ctx, h.Cfg, "", -1)
return false
}
query := models.GetSignedInUserQuery{UserId: token.UserId, OrgId: orgID}
if err := bus.Dispatch(&query); err != nil {
ctx.Logger.Error("Failed to get user with id", "userId", token.UserId, "error", err)
return false
}
ctx.SignedInUser = query.Result
ctx.IsSignedIn = true
ctx.UserToken = token
// Rotate the token just before we write response headers to ensure there is no delay between
// the new token being generated and the client receiving it.
ctx.Resp.Before(h.rotateEndOfRequestFunc(ctx, h.AuthTokenService, token))
return true
}
func (h *ContextHandler) rotateEndOfRequestFunc(ctx *models.ReqContext, authTokenService models.UserTokenService,
token *models.UserToken) macaron.BeforeFunc {
return func(w macaron.ResponseWriter) {
// if response has already been written, skip.
if w.Written() {
return
}
// if the request is cancelled by the client we should not try
// to rotate the token since the client would not accept any result.
if errors.Is(ctx.Context.Req.Context().Err(), context.Canceled) {
return
}
addr := ctx.RemoteAddr()
ip, err := network.GetIPFromAddress(addr)
if err != nil {
ctx.Logger.Debug("Failed to get client IP address", "addr", addr, "err", err)
ip = nil
}
rotated, err := authTokenService.TryRotateToken(ctx.Req.Context(), token, ip, ctx.Req.UserAgent())
if err != nil {
ctx.Logger.Error("Failed to rotate token", "error", err)
return
}
if rotated {
cookies.WriteSessionCookie(ctx, h.Cfg, token.UnhashedToken, h.Cfg.LoginMaxLifetime)
}
}
}
func (h *ContextHandler) initContextWithRenderAuth(ctx *models.ReqContext) bool {
key := ctx.GetCookie("renderKey")
if key == "" {
return false
}
renderUser, exists := h.RenderService.GetRenderUser(key)
if !exists {
ctx.JsonApiErr(401, "Invalid Render Key", nil)
return true
}
ctx.IsSignedIn = true
ctx.SignedInUser = &models.SignedInUser{
OrgId: renderUser.OrgID,
UserId: renderUser.UserID,
OrgRole: models.RoleType(renderUser.OrgRole),
}
ctx.IsRenderCall = true
ctx.LastSeenAt = time.Now()
return true
}
func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) {
logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache)
// Try to log in user via various providers
id, err := auth.Login(logger, ignoreCache)
if err != nil {
details := err
var e authproxy.Error
if errors.As(err, &e) {
details = e.DetailsError
}
logger.Error("Failed to login", "username", username, "message", err.Error(), "error", details,
"ignoreCache", ignoreCache)
return 0, err
}
return id, nil
}
func handleError(ctx *models.ReqContext, err error, statusCode int, cb func(error)) {
details := err
var e authproxy.Error
if errors.As(err, &e) {
details = e.DetailsError
}
ctx.Handle(statusCode, err.Error(), details)
if cb != nil {
cb(details)
}
}
func (h *ContextHandler) initContextWithAuthProxy(ctx *models.ReqContext, orgID int64) bool {
username := ctx.Req.Header.Get(h.Cfg.AuthProxyHeaderName)
auth := authproxy.New(h.Cfg, &authproxy.Options{
RemoteCache: h.RemoteCache,
Ctx: ctx,
OrgID: orgID,
})
logger := log.New("auth.proxy")
// Bail if auth proxy is not enabled
if !auth.IsEnabled() {
return false
}
// If there is no header - we can't move forward
if !auth.HasHeader() {
return false
}
// Check if allowed to continue with this IP
if err := auth.IsAllowedIP(); err != nil {
handleError(ctx, err, 407, func(details error) {
logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details)
})
return true
}
id, err := logUserIn(auth, username, logger, false)
if err != nil {
handleError(ctx, err, 407, nil)
return true
}
logger.Debug("Got user ID, getting full user info", "userID", id)
user, err := auth.GetSignedInUser(id)
if err != nil {
// The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale
// cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated
// because cache keys are computed from request header values and not just the user ID. Meaning that
// we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to
// log the user in again without the cache.
logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id)
if err := auth.RemoveUserFromCache(logger); err != nil {
if !errors.Is(err, remotecache.ErrCacheItemNotFound) {
logger.Error("Got unexpected error when removing user from auth cache", "error", err)
}
}
id, err = logUserIn(auth, username, logger, true)
if err != nil {
handleError(ctx, err, 407, nil)
return true
}
user, err = auth.GetSignedInUser(id)
if err != nil {
handleError(ctx, err, 407, nil)
return true
}
}
logger.Debug("Successfully got user info", "userID", user.UserId, "username", user.Login)
// Add user info to context
ctx.SignedInUser = user
ctx.IsSignedIn = true
// Remember user data in cache
if err := auth.Remember(id); err != nil {
handleError(ctx, err, 500, func(details error) {
logger.Error(
"Failed to store user in cache",
"username", username,
"message", err.Error(),
"error", details,
)
})
return true
}
return true
}

View File

@@ -0,0 +1,126 @@
package contexthandler
import (
"context"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/grafana/grafana/pkg/components/gtime"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
macaron "gopkg.in/macaron.v1"
)
func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
ctxHdlr := getContextHandler(t)
ctx, cancel := context.WithCancel(context.Background())
reqContext, _, err := initTokenRotationScenario(ctx, t)
require.NoError(t, err)
tryRotateCallCount := 0
uts := &auth.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP,
userAgent string) (bool, error) {
tryRotateCallCount++
return false, nil
},
}
token := &models.UserToken{AuthToken: "oldtoken"}
fn := ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token)
cancel()
fn(reqContext.Resp)
assert.Equal(t, 0, tryRotateCallCount, "Token rotation was attempted")
}
func TestTokenRotationAtEndOfRequest(t *testing.T) {
ctxHdlr := getContextHandler(t)
reqContext, rr, err := initTokenRotationScenario(context.Background(), t)
require.NoError(t, err)
uts := &auth.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP,
userAgent string) (bool, error) {
newToken, err := util.RandomHex(16)
require.NoError(t, err)
token.AuthToken = newToken
return true, nil
},
}
token := &models.UserToken{AuthToken: "oldtoken"}
ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token)(reqContext.Resp)
foundLoginCookie := false
resp := rr.Result()
defer resp.Body.Close()
for _, c := range resp.Cookies() {
if c.Name == "login_token" {
foundLoginCookie = true
require.NotEqual(t, token.AuthToken, c.Value, "Auth token is still the same")
}
}
assert.True(t, foundLoginCookie, "Could not find cookie")
}
func initTokenRotationScenario(ctx context.Context, t *testing.T) (*models.ReqContext, *httptest.ResponseRecorder, error) {
t.Helper()
origLoginCookieName := setting.LoginCookieName
origLoginMaxLifetime := setting.LoginMaxLifetime
t.Cleanup(func() {
setting.LoginCookieName = origLoginCookieName
setting.LoginMaxLifetime = origLoginMaxLifetime
})
setting.LoginCookieName = "login_token"
var err error
setting.LoginMaxLifetime, err = gtime.ParseDuration("7d")
if err != nil {
return nil, nil, err
}
rr := httptest.NewRecorder()
req, err := http.NewRequestWithContext(ctx, "", "", nil)
if err != nil {
return nil, nil, err
}
reqContext := &models.ReqContext{
Context: &macaron.Context{
Req: macaron.Request{
Request: req,
},
},
Logger: log.New("testlogger"),
}
mw := mockWriter{rr}
reqContext.Resp = mw
return reqContext, rr, nil
}
type mockWriter struct {
*httptest.ResponseRecorder
}
func (mw mockWriter) Flush() {}
func (mw mockWriter) Status() int { return 0 }
func (mw mockWriter) Size() int { return 0 }
func (mw mockWriter) Written() bool { return false }
func (mw mockWriter) Before(macaron.BeforeFunc) {}
func (mw mockWriter) Push(target string, opts *http.PushOptions) error {
return nil
}

View File

@@ -94,8 +94,12 @@ var config *Config
// GetConfig returns the LDAP config if LDAP is enabled otherwise it returns nil. It returns either cached value of // GetConfig returns the LDAP config if LDAP is enabled otherwise it returns nil. It returns either cached value of
// the config or it reads it and caches it first. // the config or it reads it and caches it first.
func GetConfig() (*Config, error) { func GetConfig(cfg *setting.Cfg) (*Config, error) {
if !IsEnabled() { if cfg != nil {
if !cfg.LDAPEnabled {
return nil, nil
}
} else if !IsEnabled() {
return nil, nil return nil, nil
} }

View File

@@ -25,12 +25,13 @@ import (
func init() { func init() {
remotecache.Register(&RenderUser{}) remotecache.Register(&RenderUser{})
registry.Register(&registry.Descriptor{ registry.Register(&registry.Descriptor{
Name: "RenderingService", Name: ServiceName,
Instance: &RenderingService{}, Instance: &RenderingService{},
InitPriority: registry.High, InitPriority: registry.High,
}) })
} }
const ServiceName = "RenderingService"
const renderKeyPrefix = "render-%s" const renderKeyPrefix = "render-%s"
type RenderUser struct { type RenderUser struct {
@@ -226,8 +227,8 @@ func (rs *RenderingService) getURL(path string) string {
return fmt.Sprintf("%s%s&render=1", rs.Cfg.RendererCallbackUrl, path) return fmt.Sprintf("%s%s&render=1", rs.Cfg.RendererCallbackUrl, path)
} }
protocol := setting.Protocol protocol := rs.Cfg.Protocol
switch setting.Protocol { switch protocol {
case setting.HTTPScheme: case setting.HTTPScheme:
protocol = "http" protocol = "http"
case setting.HTTP2Scheme, setting.HTTPSScheme: case setting.HTTP2Scheme, setting.HTTPSScheme:

View File

@@ -28,7 +28,7 @@ func TestGetUrl(t *testing.T) {
t.Run("And protocol HTTP configured should return expected path", func(t *testing.T) { t.Run("And protocol HTTP configured should return expected path", func(t *testing.T) {
rs.Cfg.ServeFromSubPath = false rs.Cfg.ServeFromSubPath = false
rs.Cfg.AppSubURL = "" rs.Cfg.AppSubURL = ""
setting.Protocol = setting.HTTPScheme rs.Cfg.Protocol = setting.HTTPScheme
url := rs.getURL(path) url := rs.getURL(path)
require.Equal(t, "http://localhost:3000/"+path+"&render=1", url) require.Equal(t, "http://localhost:3000/"+path+"&render=1", url)
@@ -43,7 +43,7 @@ func TestGetUrl(t *testing.T) {
t.Run("And protocol HTTPS configured should return expected path", func(t *testing.T) { t.Run("And protocol HTTPS configured should return expected path", func(t *testing.T) {
rs.Cfg.ServeFromSubPath = false rs.Cfg.ServeFromSubPath = false
rs.Cfg.AppSubURL = "" rs.Cfg.AppSubURL = ""
setting.Protocol = setting.HTTPSScheme rs.Cfg.Protocol = setting.HTTPSScheme
url := rs.getURL(path) url := rs.getURL(path)
require.Equal(t, "https://localhost:3000/"+path+"&render=1", url) require.Equal(t, "https://localhost:3000/"+path+"&render=1", url)
}) })
@@ -51,7 +51,7 @@ func TestGetUrl(t *testing.T) {
t.Run("And protocol HTTP2 configured should return expected path", func(t *testing.T) { t.Run("And protocol HTTP2 configured should return expected path", func(t *testing.T) {
rs.Cfg.ServeFromSubPath = false rs.Cfg.ServeFromSubPath = false
rs.Cfg.AppSubURL = "" rs.Cfg.AppSubURL = ""
setting.Protocol = setting.HTTP2Scheme rs.Cfg.Protocol = setting.HTTP2Scheme
url := rs.getURL(path) url := rs.getURL(path)
require.Equal(t, "https://localhost:3000/"+path+"&render=1", url) require.Equal(t, "https://localhost:3000/"+path+"&render=1", url)
}) })

View File

@@ -6,8 +6,6 @@ import (
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
) )
func (ss *SQLStore) addPreferencesQueryAndCommandHandlers() { func (ss *SQLStore) addPreferencesQueryAndCommandHandlers() {
@@ -42,7 +40,7 @@ func (ss *SQLStore) GetPreferencesWithDefaults(query *models.GetPreferencesWithD
} }
res := &models.Preferences{ res := &models.Preferences{
Theme: setting.DefaultTheme, Theme: ss.Cfg.DefaultTheme,
Timezone: ss.Cfg.DateFormats.DefaultTimezone, Timezone: ss.Cfg.DateFormats.DefaultTimezone,
HomeDashboardId: 0, HomeDashboardId: 0,
} }

View File

@@ -6,7 +6,6 @@ import (
"testing" "testing"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -14,7 +13,7 @@ func TestPreferencesDataAccess(t *testing.T) {
ss := InitTestDB(t) ss := InitTestDB(t)
t.Run("GetPreferencesWithDefaults with no saved preferences should return defaults", func(t *testing.T) { t.Run("GetPreferencesWithDefaults with no saved preferences should return defaults", func(t *testing.T) {
setting.DefaultTheme = "light" ss.Cfg.DefaultTheme = "light"
ss.Cfg.DateFormats.DefaultTimezone = "UTC" ss.Cfg.DateFormats.DefaultTimezone = "UTC"
query := &models.GetPreferencesWithDefaultsQuery{User: &models.SignedInUser{}} query := &models.GetPreferencesWithDefaultsQuery{User: &models.SignedInUser{}}

View File

@@ -39,16 +39,21 @@ var (
// ContextSessionKey is used as key to save values in `context.Context` // ContextSessionKey is used as key to save values in `context.Context`
type ContextSessionKey struct{} type ContextSessionKey struct{}
const ServiceName = "SqlStore"
const InitPriority = registry.High
func init() { func init() {
ss := &SQLStore{}
// This change will make xorm use an empty default schema for postgres and // This change will make xorm use an empty default schema for postgres and
// by that mimic the functionality of how it was functioning before // by that mimic the functionality of how it was functioning before
// xorm's changes above. // xorm's changes above.
xorm.DefaultPostgresSchema = "" xorm.DefaultPostgresSchema = ""
registry.Register(&registry.Descriptor{ registry.Register(&registry.Descriptor{
Name: "SQLStore", Name: ServiceName,
Instance: &SQLStore{}, Instance: ss,
InitPriority: registry.High, InitPriority: InitPriority,
}) })
} }
@@ -113,13 +118,20 @@ func (ss *SQLStore) Init() error {
func (ss *SQLStore) ensureMainOrgAndAdminUser() error { func (ss *SQLStore) ensureMainOrgAndAdminUser() error {
err := ss.InTransaction(context.Background(), func(ctx context.Context) error { err := ss.InTransaction(context.Background(), func(ctx context.Context) error {
systemUserCountQuery := models.GetSystemUserCountStatsQuery{} var stats models.SystemUserCountStats
err := bus.DispatchCtx(ctx, &systemUserCountQuery) err := ss.WithDbSession(ctx, func(sess *DBSession) error {
if err != nil { var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
if _, err := sess.SQL(rawSql).Get(&stats); err != nil {
return fmt.Errorf("could not determine if admin user exists: %w", err) return fmt.Errorf("could not determine if admin user exists: %w", err)
} }
if systemUserCountQuery.Result.Count > 0 { return nil
})
if err != nil {
return err
}
if stats.Count > 0 {
return nil return nil
} }
@@ -351,7 +363,7 @@ func InitTestDB(t ITestDB) *SQLStore {
testSQLStore = &SQLStore{} testSQLStore = &SQLStore{}
testSQLStore.Bus = bus.New() testSQLStore.Bus = bus.New()
testSQLStore.CacheService = localcache.New(5*time.Minute, 10*time.Minute) testSQLStore.CacheService = localcache.New(5*time.Minute, 10*time.Minute)
testSQLStore.skipEnsureDefaultOrgAndUser = true testSQLStore.skipEnsureDefaultOrgAndUser = false
dbType := migrator.SQLite dbType := migrator.SQLite

View File

@@ -47,7 +47,7 @@ var (
// This constant corresponds to the default value for ldap_sync_ttl in .ini files // This constant corresponds to the default value for ldap_sync_ttl in .ini files
// it is used for comparison and has to be kept in sync // it is used for comparison and has to be kept in sync
const ( const (
AuthProxySyncTTL = 60 authProxySyncTTL = 60
) )
var ( var (
@@ -75,12 +75,8 @@ var (
CustomInitPath = "conf/custom.ini" CustomInitPath = "conf/custom.ini"
// HTTP server options // HTTP server options
Protocol Scheme
Domain string
HttpAddr, HttpPort string HttpAddr, HttpPort string
CertFile, KeyFile string CertFile, KeyFile string
SocketPath string
RouterLogging bool
DataProxyLogging bool DataProxyLogging bool
DataProxyTimeout int DataProxyTimeout int
DataProxyTLSHandshakeTimeout int DataProxyTLSHandshakeTimeout int
@@ -97,24 +93,15 @@ var (
DisableGravatar bool DisableGravatar bool
EmailCodeValidMinutes int EmailCodeValidMinutes int
DataProxyWhiteList map[string]bool DataProxyWhiteList map[string]bool
DisableBruteForceLoginProtection bool
CookieSecure bool CookieSecure bool
CookieSameSiteDisabled bool CookieSameSiteDisabled bool
CookieSameSiteMode http.SameSite CookieSameSiteMode http.SameSite
AllowEmbedding bool
XSSProtectionHeader bool
ContentTypeProtectionHeader bool
StrictTransportSecurity bool
StrictTransportSecurityMaxAge int
StrictTransportSecurityPreload bool
StrictTransportSecuritySubDomains bool
// Snapshots // Snapshots
ExternalSnapshotUrl string ExternalSnapshotUrl string
ExternalSnapshotName string ExternalSnapshotName string
ExternalEnabled bool ExternalEnabled bool
SnapShotRemoveExpired bool SnapShotRemoveExpired bool
SnapshotPublicMode bool
// Dashboard history // Dashboard history
DashboardVersionsToKeep int DashboardVersionsToKeep int
@@ -129,7 +116,6 @@ var (
VerifyEmailEnabled bool VerifyEmailEnabled bool
LoginHint string LoginHint string
PasswordHint string PasswordHint string
DefaultTheme string
DisableLoginForm bool DisableLoginForm bool
DisableSignoutMenu bool DisableSignoutMenu bool
SignoutRedirectUrl string SignoutRedirectUrl string
@@ -139,7 +125,7 @@ var (
OAuthAutoLogin bool OAuthAutoLogin bool
ViewersCanEdit bool ViewersCanEdit bool
// Http auth // HTTP auth
AdminUser string AdminUser string
AdminPassword string AdminPassword string
LoginCookieName string LoginCookieName string
@@ -147,18 +133,10 @@ var (
SigV4AuthEnabled bool SigV4AuthEnabled bool
AnonymousEnabled bool AnonymousEnabled bool
AnonymousOrgName string
AnonymousOrgRole string
// Auth proxy settings // Auth proxy settings
AuthProxyEnabled bool AuthProxyEnabled bool
AuthProxyHeaderName string
AuthProxyHeaderProperty string AuthProxyHeaderProperty string
AuthProxyAutoSignUp bool
AuthProxyEnableLoginToken bool
AuthProxySyncTtl int
AuthProxyWhitelist string
AuthProxyHeaders map[string]string
// Basic Auth // Basic Auth
BasicAuthEnabled bool BasicAuthEnabled bool
@@ -224,6 +202,9 @@ type Cfg struct {
ServeFromSubPath bool ServeFromSubPath bool
StaticRootPath string StaticRootPath string
Protocol Scheme Protocol Scheme
SocketPath string
RouterLogging bool
Domain string
// build // build
BuildVersion string BuildVersion string
@@ -256,6 +237,13 @@ type Cfg struct {
CookieSecure bool CookieSecure bool
CookieSameSiteDisabled bool CookieSameSiteDisabled bool
CookieSameSiteMode http.SameSite CookieSameSiteMode http.SameSite
AllowEmbedding bool
XSSProtectionHeader bool
ContentTypeProtectionHeader bool
StrictTransportSecurity bool
StrictTransportSecurityMaxAge int
StrictTransportSecurityPreload bool
StrictTransportSecuritySubDomains bool
TempDataLifetime time.Duration TempDataLifetime time.Duration
PluginsEnableAlpha bool PluginsEnableAlpha bool
@@ -282,6 +270,17 @@ type Cfg struct {
LoginMaxLifetime time.Duration LoginMaxLifetime time.Duration
TokenRotationIntervalMinutes int TokenRotationIntervalMinutes int
SigV4AuthEnabled bool SigV4AuthEnabled bool
BasicAuthEnabled bool
// Auth proxy settings
AuthProxyEnabled bool
AuthProxyHeaderName string
AuthProxyHeaderProperty string
AuthProxyAutoSignUp bool
AuthProxyEnableLoginToken bool
AuthProxyWhitelist string
AuthProxyHeaders map[string]string
AuthProxySyncTTL int
// OAuth // OAuth
OAuthCookieMaxAge int OAuthCookieMaxAge int
@@ -302,6 +301,9 @@ type Cfg struct {
// Use to enable new features which may still be in alpha/beta stage. // Use to enable new features which may still be in alpha/beta stage.
FeatureToggles map[string]bool FeatureToggles map[string]bool
AnonymousEnabled bool
AnonymousOrgName string
AnonymousOrgRole string
AnonymousHideVersion bool AnonymousHideVersion bool
DateFormats DateFormats DateFormats DateFormats
@@ -317,6 +319,21 @@ type Cfg struct {
// Sentry config // Sentry config
Sentry Sentry Sentry Sentry
// Snapshots
SnapshotPublicMode bool
ErrTemplateName string
Env string
// LDAP
LDAPEnabled bool
LDAPAllowSignup bool
Quota QuotaSettings
DefaultTheme string
} }
// IsExpressionsEnabled returns whether the expressions feature is enabled. // IsExpressionsEnabled returns whether the expressions feature is enabled.
@@ -707,9 +724,12 @@ func (cfg *Cfg) Load(args *CommandLineArgs) error {
cfg.IsEnterprise = IsEnterprise cfg.IsEnterprise = IsEnterprise
cfg.Packaging = Packaging cfg.Packaging = Packaging
cfg.ErrTemplateName = ErrTemplateName
ApplicationName = "Grafana" ApplicationName = "Grafana"
Env = valueAsString(iniFile.Section(""), "app_mode", "development") Env = valueAsString(iniFile.Section(""), "app_mode", "development")
cfg.Env = Env
InstanceName = valueAsString(iniFile.Section(""), "instance_name", "unknown_instance_name") InstanceName = valueAsString(iniFile.Section(""), "instance_name", "unknown_instance_name")
plugins := valueAsString(iniFile.Section("paths"), "plugins", "") plugins := valueAsString(iniFile.Section("paths"), "plugins", "")
PluginsPath = makeAbsolute(plugins, HomePath) PluginsPath = makeAbsolute(plugins, HomePath)
@@ -736,7 +756,7 @@ func (cfg *Cfg) Load(args *CommandLineArgs) error {
return err return err
} }
if err := readSnapshotsSettings(iniFile); err != nil { if err := readSnapshotsSettings(cfg, iniFile); err != nil {
return err return err
} }
@@ -789,7 +809,6 @@ func (cfg *Cfg) Load(args *CommandLineArgs) error {
cfg.PluginsAllowUnsigned = append(cfg.PluginsAllowUnsigned, plug) cfg.PluginsAllowUnsigned = append(cfg.PluginsAllowUnsigned, plug)
} }
cfg.MarketplaceURL = pluginsSection.Key("marketplace_url").MustString("https://grafana.com/grafana/plugins/") cfg.MarketplaceURL = pluginsSection.Key("marketplace_url").MustString("https://grafana.com/grafana/plugins/")
cfg.Protocol = Protocol
// Read and populate feature toggles list // Read and populate feature toggles list
featureTogglesSection := iniFile.Section("feature_toggles") featureTogglesSection := iniFile.Section("feature_toggles")
@@ -858,8 +877,10 @@ func (cfg *Cfg) readLDAPConfig() {
LDAPConfigFile = ldapSec.Key("config_file").String() LDAPConfigFile = ldapSec.Key("config_file").String()
LDAPSyncCron = ldapSec.Key("sync_cron").String() LDAPSyncCron = ldapSec.Key("sync_cron").String()
LDAPEnabled = ldapSec.Key("enabled").MustBool(false) LDAPEnabled = ldapSec.Key("enabled").MustBool(false)
cfg.LDAPEnabled = LDAPEnabled
LDAPActiveSyncEnabled = ldapSec.Key("active_sync_enabled").MustBool(false) LDAPActiveSyncEnabled = ldapSec.Key("active_sync_enabled").MustBool(false)
LDAPAllowSignup = ldapSec.Key("allow_sign_up").MustBool(true) LDAPAllowSignup = ldapSec.Key("allow_sign_up").MustBool(true)
cfg.LDAPAllowSignup = LDAPAllowSignup
} }
func (cfg *Cfg) readSessionConfig() { func (cfg *Cfg) readSessionConfig() {
@@ -910,7 +931,7 @@ func (cfg *Cfg) LogConfigSources() {
cfg.Logger.Info("Path Logs", "path", cfg.LogsPath) cfg.Logger.Info("Path Logs", "path", cfg.LogsPath)
cfg.Logger.Info("Path Plugins", "path", PluginsPath) cfg.Logger.Info("Path Plugins", "path", PluginsPath)
cfg.Logger.Info("Path Provisioning", "path", cfg.ProvisioningPath) cfg.Logger.Info("Path Provisioning", "path", cfg.ProvisioningPath)
cfg.Logger.Info("App mode " + Env) cfg.Logger.Info("App mode " + cfg.Env)
} }
type DynamicSection struct { type DynamicSection struct {
@@ -949,7 +970,6 @@ func readSecuritySettings(iniFile *ini.File, cfg *Cfg) error {
SecretKey = valueAsString(security, "secret_key", "") SecretKey = valueAsString(security, "secret_key", "")
DisableGravatar = security.Key("disable_gravatar").MustBool(true) DisableGravatar = security.Key("disable_gravatar").MustBool(true)
cfg.DisableBruteForceLoginProtection = security.Key("disable_brute_force_login_protection").MustBool(false) cfg.DisableBruteForceLoginProtection = security.Key("disable_brute_force_login_protection").MustBool(false)
DisableBruteForceLoginProtection = cfg.DisableBruteForceLoginProtection
CookieSecure = security.Key("cookie_secure").MustBool(false) CookieSecure = security.Key("cookie_secure").MustBool(false)
cfg.CookieSecure = CookieSecure cfg.CookieSecure = CookieSecure
@@ -974,14 +994,14 @@ func readSecuritySettings(iniFile *ini.File, cfg *Cfg) error {
cfg.CookieSameSiteMode = CookieSameSiteMode cfg.CookieSameSiteMode = CookieSameSiteMode
} }
} }
AllowEmbedding = security.Key("allow_embedding").MustBool(false) cfg.AllowEmbedding = security.Key("allow_embedding").MustBool(false)
ContentTypeProtectionHeader = security.Key("x_content_type_options").MustBool(true) cfg.ContentTypeProtectionHeader = security.Key("x_content_type_options").MustBool(true)
XSSProtectionHeader = security.Key("x_xss_protection").MustBool(true) cfg.XSSProtectionHeader = security.Key("x_xss_protection").MustBool(true)
StrictTransportSecurity = security.Key("strict_transport_security").MustBool(false) cfg.StrictTransportSecurity = security.Key("strict_transport_security").MustBool(false)
StrictTransportSecurityMaxAge = security.Key("strict_transport_security_max_age_seconds").MustInt(86400) cfg.StrictTransportSecurityMaxAge = security.Key("strict_transport_security_max_age_seconds").MustInt(86400)
StrictTransportSecurityPreload = security.Key("strict_transport_security_preload").MustBool(false) cfg.StrictTransportSecurityPreload = security.Key("strict_transport_security_preload").MustBool(false)
StrictTransportSecuritySubDomains = security.Key("strict_transport_security_subdomains").MustBool(false) cfg.StrictTransportSecuritySubDomains = security.Key("strict_transport_security_subdomains").MustBool(false)
// read data source proxy whitelist // read data source proxy whitelist
DataProxyWhiteList = make(map[string]bool) DataProxyWhiteList = make(map[string]bool)
@@ -1054,41 +1074,45 @@ func readAuthSettings(iniFile *ini.File, cfg *Cfg) (err error) {
// anonymous access // anonymous access
AnonymousEnabled = iniFile.Section("auth.anonymous").Key("enabled").MustBool(false) AnonymousEnabled = iniFile.Section("auth.anonymous").Key("enabled").MustBool(false)
AnonymousOrgName = valueAsString(iniFile.Section("auth.anonymous"), "org_name", "") cfg.AnonymousEnabled = AnonymousEnabled
AnonymousOrgRole = valueAsString(iniFile.Section("auth.anonymous"), "org_role", "") cfg.AnonymousOrgName = valueAsString(iniFile.Section("auth.anonymous"), "org_name", "")
cfg.AnonymousOrgRole = valueAsString(iniFile.Section("auth.anonymous"), "org_role", "")
cfg.AnonymousHideVersion = iniFile.Section("auth.anonymous").Key("hide_version").MustBool(false) cfg.AnonymousHideVersion = iniFile.Section("auth.anonymous").Key("hide_version").MustBool(false)
// basic auth // basic auth
authBasic := iniFile.Section("auth.basic") authBasic := iniFile.Section("auth.basic")
BasicAuthEnabled = authBasic.Key("enabled").MustBool(true) BasicAuthEnabled = authBasic.Key("enabled").MustBool(true)
cfg.BasicAuthEnabled = BasicAuthEnabled
authProxy := iniFile.Section("auth.proxy") authProxy := iniFile.Section("auth.proxy")
AuthProxyEnabled = authProxy.Key("enabled").MustBool(false) AuthProxyEnabled = authProxy.Key("enabled").MustBool(false)
cfg.AuthProxyEnabled = AuthProxyEnabled
AuthProxyHeaderName = valueAsString(authProxy, "header_name", "") cfg.AuthProxyHeaderName = valueAsString(authProxy, "header_name", "")
AuthProxyHeaderProperty = valueAsString(authProxy, "header_property", "") AuthProxyHeaderProperty = valueAsString(authProxy, "header_property", "")
AuthProxyAutoSignUp = authProxy.Key("auto_sign_up").MustBool(true) cfg.AuthProxyHeaderProperty = AuthProxyHeaderProperty
AuthProxyEnableLoginToken = authProxy.Key("enable_login_token").MustBool(false) cfg.AuthProxyAutoSignUp = authProxy.Key("auto_sign_up").MustBool(true)
cfg.AuthProxyEnableLoginToken = authProxy.Key("enable_login_token").MustBool(false)
ldapSyncVal := authProxy.Key("ldap_sync_ttl").MustInt() ldapSyncVal := authProxy.Key("ldap_sync_ttl").MustInt()
syncVal := authProxy.Key("sync_ttl").MustInt() syncVal := authProxy.Key("sync_ttl").MustInt()
if ldapSyncVal != AuthProxySyncTTL { if ldapSyncVal != authProxySyncTTL {
AuthProxySyncTtl = ldapSyncVal cfg.AuthProxySyncTTL = ldapSyncVal
cfg.Logger.Warn("[Deprecated] the configuration setting 'ldap_sync_ttl' is deprecated, please use 'sync_ttl' instead") cfg.Logger.Warn("[Deprecated] the configuration setting 'ldap_sync_ttl' is deprecated, please use 'sync_ttl' instead")
} else { } else {
AuthProxySyncTtl = syncVal cfg.AuthProxySyncTTL = syncVal
} }
AuthProxyWhitelist = valueAsString(authProxy, "whitelist", "") cfg.AuthProxyWhitelist = valueAsString(authProxy, "whitelist", "")
AuthProxyHeaders = make(map[string]string) cfg.AuthProxyHeaders = make(map[string]string)
headers := valueAsString(authProxy, "headers", "") headers := valueAsString(authProxy, "headers", "")
for _, propertyAndHeader := range util.SplitString(headers) { for _, propertyAndHeader := range util.SplitString(headers) {
split := strings.SplitN(propertyAndHeader, ":", 2) split := strings.SplitN(propertyAndHeader, ":", 2)
if len(split) == 2 { if len(split) == 2 {
AuthProxyHeaders[split[0]] = split[1] cfg.AuthProxyHeaders[split[0]] = split[1]
} }
} }
@@ -1106,7 +1130,7 @@ func readUserSettings(iniFile *ini.File, cfg *Cfg) error {
LoginHint = valueAsString(users, "login_hint", "") LoginHint = valueAsString(users, "login_hint", "")
PasswordHint = valueAsString(users, "password_hint", "") PasswordHint = valueAsString(users, "password_hint", "")
DefaultTheme = valueAsString(users, "default_theme", "") cfg.DefaultTheme = valueAsString(users, "default_theme", "")
ExternalUserMngLinkUrl = valueAsString(users, "external_manage_link_url", "") ExternalUserMngLinkUrl = valueAsString(users, "external_manage_link_url", "")
ExternalUserMngLinkName = valueAsString(users, "external_manage_link_name", "") ExternalUserMngLinkName = valueAsString(users, "external_manage_link_name", "")
ExternalUserMngInfo = valueAsString(users, "external_manage_info", "") ExternalUserMngInfo = valueAsString(users, "external_manage_info", "")
@@ -1178,7 +1202,7 @@ func readAlertingSettings(iniFile *ini.File) error {
return nil return nil
} }
func readSnapshotsSettings(iniFile *ini.File) error { func readSnapshotsSettings(cfg *Cfg, iniFile *ini.File) error {
snapshots := iniFile.Section("snapshots") snapshots := iniFile.Section("snapshots")
ExternalSnapshotUrl = valueAsString(snapshots, "external_snapshot_url", "") ExternalSnapshotUrl = valueAsString(snapshots, "external_snapshot_url", "")
@@ -1186,7 +1210,7 @@ func readSnapshotsSettings(iniFile *ini.File) error {
ExternalEnabled = snapshots.Key("external_enabled").MustBool(true) ExternalEnabled = snapshots.Key("external_enabled").MustBool(true)
SnapShotRemoveExpired = snapshots.Key("snapshot_remove_expired").MustBool(true) SnapShotRemoveExpired = snapshots.Key("snapshot_remove_expired").MustBool(true)
SnapshotPublicMode = snapshots.Key("public_mode").MustBool(false) cfg.SnapshotPublicMode = snapshots.Key("public_mode").MustBool(false)
return nil return nil
} }
@@ -1204,28 +1228,28 @@ func readServerSettings(iniFile *ini.File, cfg *Cfg) error {
cfg.AppSubURL = AppSubUrl cfg.AppSubURL = AppSubUrl
cfg.ServeFromSubPath = ServeFromSubPath cfg.ServeFromSubPath = ServeFromSubPath
Protocol = HTTPScheme cfg.Protocol = HTTPScheme
protocolStr := valueAsString(server, "protocol", "http") protocolStr := valueAsString(server, "protocol", "http")
if protocolStr == "https" { if protocolStr == "https" {
Protocol = HTTPSScheme cfg.Protocol = HTTPSScheme
CertFile = server.Key("cert_file").String() CertFile = server.Key("cert_file").String()
KeyFile = server.Key("cert_key").String() KeyFile = server.Key("cert_key").String()
} }
if protocolStr == "h2" { if protocolStr == "h2" {
Protocol = HTTP2Scheme cfg.Protocol = HTTP2Scheme
CertFile = server.Key("cert_file").String() CertFile = server.Key("cert_file").String()
KeyFile = server.Key("cert_key").String() KeyFile = server.Key("cert_key").String()
} }
if protocolStr == "socket" { if protocolStr == "socket" {
Protocol = SocketScheme cfg.Protocol = SocketScheme
SocketPath = server.Key("socket").String() cfg.SocketPath = server.Key("socket").String()
} }
Domain = valueAsString(server, "domain", "localhost") cfg.Domain = valueAsString(server, "domain", "localhost")
HttpAddr = valueAsString(server, "http_addr", DefaultHTTPAddr) HttpAddr = valueAsString(server, "http_addr", DefaultHTTPAddr)
HttpPort = valueAsString(server, "http_port", "3000") HttpPort = valueAsString(server, "http_port", "3000")
RouterLogging = server.Key("router_logging").MustBool(false) cfg.RouterLogging = server.Key("router_logging").MustBool(false)
EnableGzip = server.Key("enable_gzip").MustBool(false) EnableGzip = server.Key("enable_gzip").MustBool(false)
EnforceDomain = server.Key("enforce_domain").MustBool(false) EnforceDomain = server.Key("enforce_domain").MustBool(false)

View File

@@ -86,4 +86,6 @@ func (cfg *Cfg) readQuotaSettings() {
ApiKey: quota.Key("global_api_key").MustInt64(-1), ApiKey: quota.Key("global_api_key").MustInt64(-1),
Session: quota.Key("global_session").MustInt64(-1), Session: quota.Key("global_session").MustInt64(-1),
} }
cfg.Quota = Quota
} }

View File

@@ -133,7 +133,7 @@ func TestLoadingSettings(t *testing.T) {
}) })
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(Domain, ShouldEqual, "test2") So(cfg.Domain, ShouldEqual, "test2")
}) })
Convey("Defaults can be overridden in specified config file", func() { Convey("Defaults can be overridden in specified config file", func() {
@@ -239,7 +239,7 @@ func TestLoadingSettings(t *testing.T) {
}) })
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(AuthProxySyncTtl, ShouldEqual, 2) So(cfg.AuthProxySyncTTL, ShouldEqual, 2)
}) })
Convey("Only ldap_sync_ttl should return the value ldap_sync_ttl", func() { Convey("Only ldap_sync_ttl should return the value ldap_sync_ttl", func() {
@@ -250,7 +250,7 @@ func TestLoadingSettings(t *testing.T) {
}) })
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(AuthProxySyncTtl, ShouldEqual, 5) So(cfg.AuthProxySyncTTL, ShouldEqual, 5)
}) })
Convey("ldap_sync should override ldap_sync_ttl that is default value", func() { Convey("ldap_sync should override ldap_sync_ttl that is default value", func() {
@@ -261,7 +261,7 @@ func TestLoadingSettings(t *testing.T) {
}) })
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(AuthProxySyncTtl, ShouldEqual, 5) So(cfg.AuthProxySyncTTL, ShouldEqual, 5)
}) })
Convey("ldap_sync should not override ldap_sync_ttl that is different from default value", func() { Convey("ldap_sync should not override ldap_sync_ttl that is different from default value", func() {
@@ -272,7 +272,7 @@ func TestLoadingSettings(t *testing.T) {
}) })
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(AuthProxySyncTtl, ShouldEqual, 12) So(cfg.AuthProxySyncTTL, ShouldEqual, 12)
}) })
}) })