diff --git a/pkg/api/common_test.go b/pkg/api/common_test.go index b256c1454d7..f1870cca517 100644 --- a/pkg/api/common_test.go +++ b/pkg/api/common_test.go @@ -197,7 +197,9 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg) *contexthandler.ContextHa tracer, err := tracing.InitializeTracerForTest() require.NoError(t, err) authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginservice.LoginServiceMock{}, sqlStore) - ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy) + loginService := &logintest.LoginServiceFake{} + authenticator := &logintest.AuthenticatorFake{} + ctxHdlr := contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, authenticator) return ctxHdlr } diff --git a/pkg/api/http_server.go b/pkg/api/http_server.go index 6d026fc5f26..fc87b642ca5 100644 --- a/pkg/api/http_server.go +++ b/pkg/api/http_server.go @@ -507,7 +507,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() { m.Use(hs.pluginMetricsEndpoint) m.Use(hs.ContextHandler.Middleware) - m.Use(middleware.OrgRedirect(hs.Cfg)) + m.Use(middleware.OrgRedirect(hs.Cfg, hs.SQLStore)) m.Use(acmiddleware.LoadPermissionsMiddleware(hs.AccessControl)) // needs to be after context handler diff --git a/pkg/middleware/middleware_basic_auth_test.go b/pkg/middleware/middleware_basic_auth_test.go index 7588d2172e1..dda1b609269 100644 --- a/pkg/middleware/middleware_basic_auth_test.go +++ b/pkg/middleware/middleware_basic_auth_test.go @@ -88,6 +88,7 @@ func TestMiddlewareBasicAuth(t *testing.T) { }, configure) middlewareScenario(t, "Should return error if user is not found", func(t *testing.T, sc *scenarioContext) { + sc.mockSQLStore.ExpectedError = models.ErrUserNotFound sc.fakeReq("GET", "/") sc.req.SetBasicAuth("user", "password") sc.exec() @@ -100,10 +101,7 @@ func TestMiddlewareBasicAuth(t *testing.T) { }, configure) middlewareScenario(t, "Should return error if user & password do not match", func(t *testing.T, sc *scenarioContext) { - bus.AddHandler("user-query", func(ctx context.Context, loginUserQuery *models.GetUserByLoginQuery) error { - return nil - }) - + sc.mockSQLStore.ExpectedError = models.ErrUserNotFound sc.fakeReq("GET", "/") sc.req.SetBasicAuth("killa", "gorilla") sc.exec() diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index fe645b87cd2..b52083dbeed 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -25,6 +25,7 @@ import ( "github.com/grafana/grafana/pkg/services/contexthandler" "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" "github.com/grafana/grafana/pkg/services/login/loginservice" + "github.com/grafana/grafana/pkg/services/login/logintest" "github.com/grafana/grafana/pkg/services/rendering" "github.com/grafana/grafana/pkg/services/sqlstore/mockstore" "github.com/grafana/grafana/pkg/setting" @@ -602,7 +603,7 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func( sc.sqlStore = ctxHdlr.SQLStore sc.contextHandler = ctxHdlr sc.m.Use(ctxHdlr.Middleware) - sc.m.Use(OrgRedirect(sc.cfg)) + sc.m.Use(OrgRedirect(sc.cfg, sc.mockSQLStore)) sc.userAuthTokenService = ctxHdlr.AuthTokenService.(*auth.FakeUserAuthTokenService) sc.jwtAuthService = ctxHdlr.JWTAuthService.(*models.FakeJWTService) @@ -644,8 +645,9 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.S authJWTSvc := models.NewFakeJWTService() tracer, err := tracing.InitializeTracerForTest() authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, mockSQLStore) + authenticator := &logintest.AuthenticatorFake{ExpectedUser: &models.User{}} require.NoError(t, err) - return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy) + return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy, loginService, authenticator) } type fakeRenderService struct { diff --git a/pkg/middleware/org_redirect.go b/pkg/middleware/org_redirect.go index e6f8b638f52..ecf0d0800b0 100644 --- a/pkg/middleware/org_redirect.go +++ b/pkg/middleware/org_redirect.go @@ -6,16 +6,16 @@ import ( "strconv" "strings" - "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/contexthandler" + "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/web" ) // OrgRedirect changes org and redirects users if the // querystring `orgId` doesn't match the active org. -func OrgRedirect(cfg *setting.Cfg) web.Handler { +func OrgRedirect(cfg *setting.Cfg, store sqlstore.Store) web.Handler { return func(res http.ResponseWriter, req *http.Request, c *web.Context) { orgIdValue := req.URL.Query().Get("orgId") orgId, err := strconv.ParseInt(orgIdValue, 10, 64) @@ -34,7 +34,7 @@ func OrgRedirect(cfg *setting.Cfg) web.Handler { } cmd := models.SetUsingOrgCommand{UserId: ctx.UserId, OrgId: orgId} - if err := bus.Dispatch(ctx.Req.Context(), &cmd); err != nil { + if err := store.SetUsingOrg(ctx.Req.Context(), &cmd); err != nil { if ctx.IsApiRequest() { ctx.JsonApiErr(404, "Not found", nil) } else { diff --git a/pkg/middleware/org_redirect_test.go b/pkg/middleware/org_redirect_test.go index 5917e9a5377..22450ec880e 100644 --- a/pkg/middleware/org_redirect_test.go +++ b/pkg/middleware/org_redirect_test.go @@ -68,9 +68,7 @@ func TestOrgRedirectMiddleware(t *testing.T) { middlewareScenario(t, "when setting an invalid org for user", func(t *testing.T, sc *scenarioContext) { sc.withTokenSessionCookie("token") - bus.AddHandler("test", func(ctx context.Context, query *models.SetUsingOrgCommand) error { - return fmt.Errorf("") - }) + sc.mockSQLStore.ExpectedSetUsingOrgError = fmt.Errorf("") sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: 1, UserId: 12} sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { diff --git a/pkg/middleware/recovery_test.go b/pkg/middleware/recovery_test.go index 51b1c6c1c4d..7e3e4a5939f 100644 --- a/pkg/middleware/recovery_test.go +++ b/pkg/middleware/recovery_test.go @@ -73,7 +73,7 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) { contextHandler := getContextHandler(t, nil, nil, nil) sc.m.Use(contextHandler.Middleware) // mock out gc goroutine - sc.m.Use(OrgRedirect(cfg)) + sc.m.Use(OrgRedirect(cfg, sc.mockSQLStore)) sc.defaultHandler = func(c *models.ReqContext) { sc.context = c diff --git a/pkg/services/contexthandler/auth_jwt.go b/pkg/services/contexthandler/auth_jwt.go index d29b4b07d65..cda763f2c37 100644 --- a/pkg/services/contexthandler/auth_jwt.go +++ b/pkg/services/contexthandler/auth_jwt.go @@ -3,7 +3,6 @@ package contexthandler import ( "errors" - "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/models" ) @@ -67,7 +66,7 @@ func (h *ContextHandler) initContextWithJWT(ctx *models.ReqContext, orgId int64) SignupAllowed: h.Cfg.JWTAuthAutoSignUp, ExternalUser: extUser, } - if err := bus.Dispatch(ctx.Req.Context(), upsert); err != nil { + if err := h.loginService.UpsertUser(ctx.Req.Context(), upsert); err != nil { ctx.Logger.Error("Failed to upsert JWT user", "error", err) return false } diff --git a/pkg/services/contexthandler/auth_proxy_test.go b/pkg/services/contexthandler/auth_proxy_test.go index 82e6245c428..48fea6227bb 100644 --- a/pkg/services/contexthandler/auth_proxy_test.go +++ b/pkg/services/contexthandler/auth_proxy_test.go @@ -84,8 +84,9 @@ func getContextHandler(t *testing.T) *ContextHandler { loginService := loginservice.LoginServiceMock{ExpectedUser: &models.User{Id: userID}} authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, &FakeGetSignUserStore{}) + authenticator := &fakeAuthenticator{} - return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy) + return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy, loginService, authenticator) } type FakeGetSignUserStore struct { @@ -103,3 +104,9 @@ func (f *FakeGetSignUserStore) GetSignedInUser(ctx context.Context, query *model } return nil } + +type fakeAuthenticator struct{} + +func (fa *fakeAuthenticator) AuthenticateUser(c context.Context, query *models.LoginUserQuery) error { + return nil +} diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index 99111ac9194..407b89fa326 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -9,12 +9,12 @@ import ( "strings" "time" - "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/components/apikeygen" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/network" "github.com/grafana/grafana/pkg/infra/remotecache" "github.com/grafana/grafana/pkg/infra/tracing" + loginpkg "github.com/grafana/grafana/pkg/login" "github.com/grafana/grafana/pkg/middleware/cookies" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" @@ -36,7 +36,7 @@ const ServiceName = "ContextHandler" func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtService models.JWTService, remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore sqlstore.Store, - tracer tracing.Tracer, authProxy *authproxy.AuthProxy) *ContextHandler { + tracer tracing.Tracer, authProxy *authproxy.AuthProxy, loginService login.Service, authenticator loginpkg.Authenticator) *ContextHandler { return &ContextHandler{ Cfg: cfg, AuthTokenService: tokenService, @@ -46,6 +46,8 @@ func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtS SQLStore: sqlStore, tracer: tracer, authProxy: authProxy, + authenticator: authenticator, + loginService: loginService, } } @@ -59,6 +61,8 @@ type ContextHandler struct { SQLStore sqlstore.Store tracer tracing.Tracer authProxy *authproxy.AuthProxy + authenticator loginpkg.Authenticator + loginService login.Service // GetTime returns the current time. // Stubbable by tests. GetTime func() time.Time @@ -291,7 +295,7 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *models.ReqContext, Password: password, Cfg: h.Cfg, } - if err := bus.Dispatch(reqContext.Req.Context(), &authQuery); err != nil { + if err := h.authenticator.AuthenticateUser(reqContext.Req.Context(), &authQuery); err != nil { reqContext.Logger.Debug( "Failed to authorize the user", "username", username, diff --git a/pkg/services/login/logintest/logintest.go b/pkg/services/login/logintest/logintest.go index 6197e8d5ce9..020714632e8 100644 --- a/pkg/services/login/logintest/logintest.go +++ b/pkg/services/login/logintest/logintest.go @@ -49,3 +49,13 @@ func (a *AuthInfoServiceFake) GetExternalUserInfoByLogin(ctx context.Context, qu query.Result = a.ExpectedExternalUser return a.ExpectedError } + +type AuthenticatorFake struct { + ExpectedUser *models.User + ExpectedError error +} + +func (a *AuthenticatorFake) AuthenticateUser(c context.Context, query *models.LoginUserQuery) error { + query.User = a.ExpectedUser + return a.ExpectedError +} diff --git a/pkg/services/sqlstore/mockstore/mockstore.go b/pkg/services/sqlstore/mockstore/mockstore.go index 2094f0100da..f66b16b723e 100644 --- a/pkg/services/sqlstore/mockstore/mockstore.go +++ b/pkg/services/sqlstore/mockstore/mockstore.go @@ -44,7 +44,8 @@ type SQLStoreMock struct { ExpectedUserStars map[int64]bool ExpectedLoginAttempts int64 - ExpectedError error + ExpectedError error + ExpectedSetUsingOrgError error } func NewSQLStoreMock() *SQLStoreMock { @@ -178,7 +179,7 @@ func (m *SQLStoreMock) UpdateUserLastSeenAt(ctx context.Context, cmd *models.Upd } func (m *SQLStoreMock) SetUsingOrg(ctx context.Context, cmd *models.SetUsingOrgCommand) error { - return m.ExpectedError + return m.ExpectedSetUsingOrgError } func (m *SQLStoreMock) GetUserProfile(ctx context.Context, query *models.GetUserProfileQuery) error {