diff --git a/pkg/middleware/auth_test.go b/pkg/middleware/auth_test.go index db6bdc96d5d..282c4a6eca8 100644 --- a/pkg/middleware/auth_test.go +++ b/pkg/middleware/auth_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/grafana/grafana/pkg/models" - "github.com/grafana/grafana/pkg/services/sqlstore/mockstore" "github.com/grafana/grafana/pkg/setting" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -39,9 +38,7 @@ func TestMiddlewareAuth(t *testing.T) { middlewareScenario(t, "ReqSignIn true and NoAnonynmous true", func( t *testing.T, sc *scenarioContext) { - sqlStore := mockstore.NewSQLStoreMock() - sqlStore.ExpectedOrg = &models.Org{Id: orgID, Name: "test"} - sc.sqlStore = sqlStore + sc.mockSQLStore.ExpectedOrg = &models.Org{Id: orgID, Name: "test"} sc.m.Get("/api/secure", ReqSignedInNoAnonymous, sc.defaultHandler) sc.fakeReq("GET", "/api/secure").exec() @@ -50,9 +47,7 @@ func TestMiddlewareAuth(t *testing.T) { middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func( t *testing.T, sc *scenarioContext) { - sqlStore := mockstore.NewSQLStoreMock() - sqlStore.ExpectedOrg = &models.Org{Id: orgID, Name: "test"} - sc.sqlStore = sqlStore + sc.mockSQLStore.ExpectedOrg = &models.Org{Id: orgID, Name: "test"} sc.m.Get("/secure", reqSignIn, sc.defaultHandler) sc.fakeReq("GET", "/secure?forceLogin=true").exec() @@ -65,7 +60,8 @@ func TestMiddlewareAuth(t *testing.T) { middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func( t *testing.T, sc *scenarioContext) { - org, err := sc.sqlStore.CreateOrgWithMember(sc.cfg.AnonymousOrgName, 1) + sc.mockSQLStore.ExpectedOrg = &models.Org{Id: 1, Name: sc.cfg.AnonymousOrgName} + org, err := sc.mockSQLStore.CreateOrgWithMember(sc.cfg.AnonymousOrgName, 1) require.NoError(t, err) sc.m.Get("/secure", reqSignIn, sc.defaultHandler) @@ -77,6 +73,7 @@ func TestMiddlewareAuth(t *testing.T) { middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func( t *testing.T, sc *scenarioContext) { + sc.mockSQLStore.ExpectedOrg = &models.Org{Id: 1, Name: sc.cfg.AnonymousOrgName} sc.m.Get("/secure", reqSignIn, sc.defaultHandler) sc.fakeReq("GET", "/secure?orgId=2").exec() diff --git a/pkg/middleware/middleware_basic_auth_test.go b/pkg/middleware/middleware_basic_auth_test.go index 1c4fbca2ee0..7588d2172e1 100644 --- a/pkg/middleware/middleware_basic_auth_test.go +++ b/pkg/middleware/middleware_basic_auth_test.go @@ -29,10 +29,7 @@ func TestMiddlewareBasicAuth(t *testing.T) { keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") require.NoError(t, err) - bus.AddHandler("test", func(ctx context.Context, query *models.GetApiKeyByNameQuery) error { - query.Result = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash} - return nil - }) + sc.mockSQLStore.ExpectedAPIKey = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash} authHeader := util.GetBasicAuthHeader("api_key", "eyJrIjoidjVuQXdwTWFmRlA2em5hUzR1cmhkV0RMUzU1MTFNNDIiLCJuIjoiYXNkIiwiaWQiOjF9") sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() @@ -61,11 +58,7 @@ func TestMiddlewareBasicAuth(t *testing.T) { return nil }) - bus.AddHandler("get-sign-user", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - t.Log("Handling GetSignedInUserQuery") - query.Result = &models.SignedInUser{OrgId: orgID, UserId: id} - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: orgID, UserId: id} authHeader := util.GetBasicAuthHeader("myUser", password) sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() @@ -85,10 +78,6 @@ func TestMiddlewareBasicAuth(t *testing.T) { sc.mockSQLStore.ExpectedUser = &models.User{Password: encoded, Id: id, Salt: salt} sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{UserId: id} login.ProvideService(sc.mockSQLStore, &logintest.LoginServiceFake{}) - bus.AddHandler("get-sign-user", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{UserId: query.UserId} - return nil - }) authHeader := util.GetBasicAuthHeader("myUser", password) sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec() diff --git a/pkg/middleware/middleware_jwt_auth_test.go b/pkg/middleware/middleware_jwt_auth_test.go index 770d7b1c802..02d4075201d 100644 --- a/pkg/middleware/middleware_jwt_auth_test.go +++ b/pkg/middleware/middleware_jwt_auth_test.go @@ -46,14 +46,7 @@ func TestMiddlewareJWTAuth(t *testing.T) { "foo-username": myUsername, }, nil } - bus.AddHandler("get-sign-user", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{ - UserId: id, - OrgId: orgID, - Login: query.Login, - } - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{UserId: id, OrgId: orgID, Login: myUsername} sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec() assert.Equal(t, verifiedToken, token) @@ -74,14 +67,7 @@ func TestMiddlewareJWTAuth(t *testing.T) { "foo-email": myEmail, }, nil } - bus.AddHandler("get-sign-user", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{ - UserId: id, - OrgId: orgID, - Email: query.Email, - } - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{UserId: id, OrgId: orgID, Email: myEmail} sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec() assert.Equal(t, verifiedToken, token) @@ -103,9 +89,7 @@ func TestMiddlewareJWTAuth(t *testing.T) { "foo-email": myEmail, }, nil } - bus.AddHandler("get-sign-user", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - return models.ErrUserNotFound - }) + sc.mockSQLStore.ExpectedError = models.ErrUserNotFound sc.fakeReq("GET", "/").withJWTAuthHeader(token).exec() assert.Equal(t, verifiedToken, token) @@ -124,14 +108,7 @@ func TestMiddlewareJWTAuth(t *testing.T) { "foo-email": myEmail, }, nil } - bus.AddHandler("get-sign-user", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{ - UserId: id, - OrgId: orgID, - Email: query.Email, - } - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{UserId: id, OrgId: orgID, Email: myEmail} bus.AddHandler("upsert-user", func(ctx context.Context, command *models.UpsertUserCommand) error { command.Result = &models.User{ Id: id, diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index 637e02c0a10..fe645b87cd2 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -26,7 +26,6 @@ import ( "github.com/grafana/grafana/pkg/services/contexthandler/authproxy" "github.com/grafana/grafana/pkg/services/login/loginservice" "github.com/grafana/grafana/pkg/services/rendering" - "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore/mockstore" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/util" @@ -150,10 +149,7 @@ func TestMiddlewareContext(t *testing.T) { keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") require.NoError(t, err) - bus.AddHandler("test", func(ctx context.Context, query *models.GetApiKeyByNameQuery) error { - query.Result = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash} - return nil - }) + sc.mockSQLStore.ExpectedAPIKey = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash} sc.fakeReq("GET", "/").withValidApiKey().exec() @@ -166,11 +162,7 @@ func TestMiddlewareContext(t *testing.T) { middlewareScenario(t, "Valid API key, but does not match DB hash", func(t *testing.T, sc *scenarioContext) { const keyhash = "Something_not_matching" - - bus.AddHandler("test", func(ctx context.Context, query *models.GetApiKeyByNameQuery) error { - query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash} - return nil - }) + sc.mockSQLStore.ExpectedAPIKey = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash} sc.fakeReq("GET", "/").withValidApiKey().exec() @@ -184,13 +176,8 @@ func TestMiddlewareContext(t *testing.T) { keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd") require.NoError(t, err) - bus.AddHandler("test", func(ctx context.Context, query *models.GetApiKeyByNameQuery) error { - // api key expired one second before - expires := sc.contextHandler.GetTime().Add(-1 * time.Second).Unix() - query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash, - Expires: &expires} - return nil - }) + expires := sc.contextHandler.GetTime().Add(-1 * time.Second).Unix() + sc.mockSQLStore.ExpectedAPIKey = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash, Expires: &expires} sc.fakeReq("GET", "/").withValidApiKey().exec() @@ -203,11 +190,7 @@ func TestMiddlewareContext(t *testing.T) { const userID int64 = 12 sc.withTokenSessionCookie("token") - - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 2, UserId: userID} - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: 2, UserId: userID} sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { return &models.UserToken{ @@ -231,11 +214,7 @@ func TestMiddlewareContext(t *testing.T) { const userID int64 = 12 sc.withTokenSessionCookie("token") - - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 2, UserId: userID} - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: 2, UserId: userID} sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { return &models.UserToken{ @@ -332,7 +311,8 @@ func TestMiddlewareContext(t *testing.T) { }) middlewareScenario(t, "When anonymous access is enabled", func(t *testing.T, sc *scenarioContext) { - org, err := sc.sqlStore.CreateOrgWithMember(sc.cfg.AnonymousOrgName, 1) + sc.mockSQLStore.ExpectedOrg = &models.Org{Id: 1, Name: sc.cfg.AnonymousOrgName} + org, err := sc.mockSQLStore.CreateOrgWithMember(sc.cfg.AnonymousOrgName, 1) require.NoError(t, err) sc.fakeReq("GET", "/").exec() @@ -651,7 +631,6 @@ func middlewareScenario(t *testing.T, desc string, fn scenarioFunc, cbs ...func( func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.SQLStoreMock, loginService *loginservice.LoginServiceMock) *contexthandler.ContextHandler { t.Helper() - sqlStore := sqlstore.InitTestDB(t) if cfg == nil { cfg = setting.NewCfg() } @@ -666,7 +645,7 @@ func getContextHandler(t *testing.T, cfg *setting.Cfg, mockSQLStore *mockstore.S tracer, err := tracing.InitializeTracerForTest() authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, mockSQLStore) require.NoError(t, err) - return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, sqlStore, tracer, authProxy) + return contexthandler.ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, renderSvc, mockSQLStore, tracer, authProxy) } type fakeRenderService struct { diff --git a/pkg/middleware/org_redirect_test.go b/pkg/middleware/org_redirect_test.go index 4c4620843f6..5917e9a5377 100644 --- a/pkg/middleware/org_redirect_test.go +++ b/pkg/middleware/org_redirect_test.go @@ -46,15 +46,11 @@ func TestOrgRedirectMiddleware(t *testing.T) { for _, tc := range testCases { middlewareScenario(t, tc.desc, func(t *testing.T, sc *scenarioContext) { sc.withTokenSessionCookie("token") + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: 1, UserId: 12} bus.AddHandler("test", func(ctx context.Context, query *models.SetUsingOrgCommand) error { return nil }) - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 1, UserId: 12} - return nil - }) - sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { return &models.UserToken{ UserId: 0, @@ -75,11 +71,7 @@ func TestOrgRedirectMiddleware(t *testing.T) { bus.AddHandler("test", func(ctx context.Context, query *models.SetUsingOrgCommand) error { return fmt.Errorf("") }) - - bus.AddHandler("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error { - query.Result = &models.SignedInUser{OrgId: 1, UserId: 12} - return nil - }) + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{OrgId: 1, UserId: 12} sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { return &models.UserToken{ diff --git a/pkg/middleware/quota_test.go b/pkg/middleware/quota_test.go index 3112798a4a4..0d3ffec35ea 100644 --- a/pkg/middleware/quota_test.go +++ b/pkg/middleware/quota_test.go @@ -60,6 +60,7 @@ func TestMiddlewareQuota(t *testing.T) { const quotaUsed = 4 setUp := func(sc *scenarioContext) { sc.withTokenSessionCookie("token") + sc.mockSQLStore.ExpectedSignedInUser = &models.SignedInUser{UserId: 12} sc.userAuthTokenService.LookupTokenProvider = func(ctx context.Context, unhashedToken string) (*models.UserToken, error) { return &models.UserToken{ UserId: 12, diff --git a/pkg/server/wire.go b/pkg/server/wire.go index 851e064c762..566aa24de28 100644 --- a/pkg/server/wire.go +++ b/pkg/server/wire.go @@ -263,7 +263,7 @@ var wireTestSet = wire.NewSet( wire.Bind(new(notifications.WebhookSender), new(*notifications.NotificationServiceMock)), wire.Bind(new(notifications.EmailSender), new(*notifications.NotificationServiceMock)), mockstore.NewSQLStoreMock, - wire.Bind(new(sqlstore.Store), new(*mockstore.SQLStoreMock)), + wire.Bind(new(sqlstore.Store), new(*sqlstore.SQLStore)), ) func Initialize(cla setting.CommandLineArgs, opts Options, apiOpts api.ServerOptions) (*Server, error) { diff --git a/pkg/services/contexthandler/auth_jwt.go b/pkg/services/contexthandler/auth_jwt.go index 748bd13c120..d29b4b07d65 100644 --- a/pkg/services/contexthandler/auth_jwt.go +++ b/pkg/services/contexthandler/auth_jwt.go @@ -73,7 +73,7 @@ func (h *ContextHandler) initContextWithJWT(ctx *models.ReqContext, orgId int64) } } - if err := bus.Dispatch(ctx.Req.Context(), &query); err != nil { + if err := h.SQLStore.GetSignedInUserWithCacheCtx(ctx.Req.Context(), &query); err != nil { if errors.Is(err, models.ErrUserNotFound) { ctx.Logger.Debug( "Failed to find user using JWT claims", diff --git a/pkg/services/contexthandler/contexthandler.go b/pkg/services/contexthandler/contexthandler.go index 0a62173b64a..99111ac9194 100644 --- a/pkg/services/contexthandler/contexthandler.go +++ b/pkg/services/contexthandler/contexthandler.go @@ -35,7 +35,7 @@ const ( const ServiceName = "ContextHandler" func ProvideService(cfg *setting.Cfg, tokenService models.UserTokenService, jwtService models.JWTService, - remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore *sqlstore.SQLStore, + remoteCache *remotecache.RemoteCache, renderService rendering.Service, sqlStore sqlstore.Store, tracer tracing.Tracer, authProxy *authproxy.AuthProxy) *ContextHandler { return &ContextHandler{ Cfg: cfg, @@ -151,7 +151,7 @@ func (h *ContextHandler) Middleware(mContext *web.Context) { // update last seen every 5min if reqContext.ShouldUpdateLastSeenAt() { reqContext.Logger.Debug("Updating last user_seen_at", "user_id", reqContext.UserId) - if err := bus.Dispatch(mContext.Req.Context(), &models.UpdateUserLastSeenAtCommand{UserId: reqContext.UserId}); err != nil { + if err := h.SQLStore.UpdateUserLastSeenAt(mContext.Req.Context(), &models.UpdateUserLastSeenAtCommand{UserId: reqContext.UserId}); err != nil { reqContext.Logger.Error("Failed to update last_seen_at", "error", err) } } @@ -209,7 +209,7 @@ func (h *ContextHandler) initContextWithAPIKey(reqContext *models.ReqContext) bo // fetch key keyQuery := models.GetApiKeyByNameQuery{KeyName: decoded.Name, OrgId: decoded.OrgId} - if err := bus.Dispatch(reqContext.Req.Context(), &keyQuery); err != nil { + if err := h.SQLStore.GetApiKeyByName(reqContext.Req.Context(), &keyQuery); err != nil { reqContext.JsonApiErr(401, InvalidAPIKey, err) return true } @@ -251,7 +251,7 @@ func (h *ContextHandler) initContextWithAPIKey(reqContext *models.ReqContext) bo //Use service account linked to API key as the signed in user query := models.GetSignedInUserQuery{UserId: *apikey.ServiceAccountId, OrgId: apikey.OrgId} - if err := bus.Dispatch(reqContext.Req.Context(), &query); err != nil { + if err := h.SQLStore.GetSignedInUserWithCacheCtx(reqContext.Req.Context(), &query); err != nil { reqContext.Logger.Error( "Failed to link API key to service account in", "id", query.UserId, @@ -308,7 +308,7 @@ func (h *ContextHandler) initContextWithBasicAuth(reqContext *models.ReqContext, user := authQuery.User query := models.GetSignedInUserQuery{UserId: user.Id, OrgId: orgID} - if err := bus.Dispatch(ctx, &query); err != nil { + if err := h.SQLStore.GetSignedInUserWithCacheCtx(ctx, &query); err != nil { reqContext.Logger.Error( "Failed at user signed in", "id", user.Id, @@ -344,7 +344,7 @@ func (h *ContextHandler) initContextWithToken(reqContext *models.ReqContext, org } query := models.GetSignedInUserQuery{UserId: token.UserId, OrgId: orgID} - if err := bus.Dispatch(ctx, &query); err != nil { + if err := h.SQLStore.GetSignedInUserWithCacheCtx(ctx, &query); err != nil { reqContext.Logger.Error("Failed to get user with id", "userId", token.UserId, "error", err) return false } diff --git a/pkg/services/sqlstore/mockstore/mockstore.go b/pkg/services/sqlstore/mockstore/mockstore.go index 403089ffdae..2094f0100da 100644 --- a/pkg/services/sqlstore/mockstore/mockstore.go +++ b/pkg/services/sqlstore/mockstore/mockstore.go @@ -40,6 +40,7 @@ type SQLStoreMock struct { ExpectedNotifierUsageStats []*models.NotifierUsageStats ExpectedPersistedDashboards models.HitList ExpectedSignedInUser *models.SignedInUser + ExpectedAPIKey *models.ApiKey ExpectedUserStars map[int64]bool ExpectedLoginAttempts int64 @@ -622,6 +623,7 @@ func (m *SQLStoreMock) GetApiKeyById(ctx context.Context, query *models.GetApiKe } func (m *SQLStoreMock) GetApiKeyByName(ctx context.Context, query *models.GetApiKeyByNameQuery) error { + query.Result = m.ExpectedAPIKey return m.ExpectedError }