mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
bus: add ctx for all signed in user queries (#33970)
Signed-off-by: bergquist <carl.bergquist@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
@@ -59,7 +60,7 @@ func TestMiddlewareBasicAuth(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("get-sign-user", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
t.Log("Handling GetSignedInUserQuery")
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: id}
|
||||
return nil
|
||||
@@ -92,7 +93,7 @@ func TestMiddlewareBasicAuth(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("get-sign-user", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{UserId: query.UserId}
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -40,7 +40,7 @@ func TestMiddlewareJWTAuth(t *testing.T) {
|
||||
"foo-username": myUsername,
|
||||
}, nil
|
||||
}
|
||||
bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("get-sign-user", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{
|
||||
UserId: id,
|
||||
OrgId: orgID,
|
||||
@@ -67,7 +67,7 @@ func TestMiddlewareJWTAuth(t *testing.T) {
|
||||
"foo-email": myEmail,
|
||||
}, nil
|
||||
}
|
||||
bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("get-sign-user", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{
|
||||
UserId: id,
|
||||
OrgId: orgID,
|
||||
|
||||
@@ -203,7 +203,7 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
|
||||
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: 2, UserId: userID}
|
||||
return nil
|
||||
})
|
||||
@@ -231,7 +231,7 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
|
||||
sc.withTokenSessionCookie("token")
|
||||
|
||||
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: 2, UserId: userID}
|
||||
return nil
|
||||
})
|
||||
@@ -363,7 +363,7 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
const group = "grafana-core-team"
|
||||
|
||||
middlewareScenario(t, "Should not sync the user if it's in the cache", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: query.UserId}
|
||||
return nil
|
||||
})
|
||||
@@ -406,7 +406,7 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should create an user from a header", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
if query.UserId > 0 {
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
return nil
|
||||
@@ -436,7 +436,7 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
const userID int64 = 12
|
||||
const orgID int64 = 2
|
||||
|
||||
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
return nil
|
||||
})
|
||||
@@ -459,7 +459,7 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should allow the request from whitelist IP", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
return nil
|
||||
})
|
||||
@@ -484,7 +484,7 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
})
|
||||
|
||||
middlewareScenario(t, "Should not allow the request from whitelisted IP", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: orgID, UserId: userID}
|
||||
return nil
|
||||
})
|
||||
@@ -521,7 +521,7 @@ func TestMiddlewareContext(t *testing.T) {
|
||||
}, configure)
|
||||
|
||||
middlewareScenario(t, "Should return 407 status code if there is cache mishap", func(t *testing.T, sc *scenarioContext) {
|
||||
bus.AddHandler("Do not have the user", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("Do not have the user", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
return errors.New("Do not add user")
|
||||
})
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: 1, UserId: 12}
|
||||
return nil
|
||||
})
|
||||
@@ -41,7 +41,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
|
||||
return fmt.Errorf("")
|
||||
})
|
||||
|
||||
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: 1, UserId: 12}
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -89,7 +89,7 @@ func TestMiddlewareQuota(t *testing.T) {
|
||||
|
||||
setUp := func(sc *scenarioContext) {
|
||||
sc.withTokenSessionCookie("token")
|
||||
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{OrgId: 2, UserId: 12}
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -42,7 +42,7 @@ func (h *ContextHandler) initContextWithJWT(ctx *models.ReqContext, orgId int64)
|
||||
return true
|
||||
}
|
||||
|
||||
if err := bus.Dispatch(&query); err != nil {
|
||||
if err := bus.DispatchCtx(ctx.Req.Context(), &query); err != nil {
|
||||
if errors.Is(err, models.ErrUserNotFound) {
|
||||
ctx.Logger.Debug(
|
||||
"Failed to find user using JWT claims",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package contexthandler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
@@ -34,20 +35,20 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
|
||||
cmd.Result = &models.User{Id: userID}
|
||||
return nil
|
||||
}
|
||||
getUserHandler := func(cmd *models.GetSignedInUserQuery) error {
|
||||
getUserHandler := func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
// Simulate that the cached user ID is stale
|
||||
if cmd.UserId != userID {
|
||||
if query.UserId != userID {
|
||||
return models.ErrUserNotFound
|
||||
}
|
||||
|
||||
cmd.Result = &models.SignedInUser{
|
||||
query.Result = &models.SignedInUser{
|
||||
UserId: userID,
|
||||
OrgId: orgID,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
bus.AddHandler("", upsertHandler)
|
||||
bus.AddHandler("", getUserHandler)
|
||||
bus.AddHandlerCtx("", getUserHandler)
|
||||
t.Cleanup(func() {
|
||||
bus.ClearBusHandlers()
|
||||
})
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package authproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -319,7 +320,7 @@ func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, erro
|
||||
UserId: userID,
|
||||
}
|
||||
|
||||
if err := bus.Dispatch(query); err != nil {
|
||||
if err := bus.DispatchCtx(context.Background(), query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -237,7 +237,7 @@ func (h *ContextHandler) initContextWithBasicAuth(ctx *models.ReqContext, orgID
|
||||
user := authQuery.User
|
||||
|
||||
query := models.GetSignedInUserQuery{UserId: user.Id, OrgId: orgID}
|
||||
if err := bus.Dispatch(&query); err != nil {
|
||||
if err := bus.DispatchCtx(ctx.Req.Context(), &query); err != nil {
|
||||
ctx.Logger.Error(
|
||||
"Failed at user signed in",
|
||||
"id", user.Id,
|
||||
@@ -270,7 +270,7 @@ func (h *ContextHandler) initContextWithToken(ctx *models.ReqContext, orgID int6
|
||||
}
|
||||
|
||||
query := models.GetSignedInUserQuery{UserId: token.UserId, OrgId: orgID}
|
||||
if err := bus.Dispatch(&query); err != nil {
|
||||
if err := bus.DispatchCtx(ctx.Req.Context(), &query); err != nil {
|
||||
ctx.Logger.Error("Failed to get user with id", "userId", token.UserId, "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/bus"
|
||||
@@ -26,7 +27,7 @@ func TestSearch_SortedResults(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
bus.AddHandler("test", func(query *models.GetSignedInUserQuery) error {
|
||||
bus.AddHandlerCtx("test", func(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
query.Result = &models.SignedInUser{IsGrafanaAdmin: true}
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -186,7 +186,7 @@ func TestAccountDataAccess(t *testing.T) {
|
||||
|
||||
Convey("Can get logged in user projection", func() {
|
||||
query := models.GetSignedInUserQuery{UserId: ac2.Id}
|
||||
err := GetSignedInUser(&query)
|
||||
err := GetSignedInUser(context.Background(), &query)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
So(query.Result.Email, ShouldEqual, "ac2@test.com")
|
||||
@@ -247,7 +247,7 @@ func TestAccountDataAccess(t *testing.T) {
|
||||
|
||||
Convey("SignedInUserQuery with a different org", func() {
|
||||
query := models.GetSignedInUserQuery{UserId: ac2.Id}
|
||||
err := GetSignedInUser(&query)
|
||||
err := GetSignedInUser(context.Background(), &query)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
So(query.Result.OrgId, ShouldEqual, ac1.OrgId)
|
||||
@@ -264,7 +264,7 @@ func TestAccountDataAccess(t *testing.T) {
|
||||
So(err, ShouldBeNil)
|
||||
|
||||
query := models.GetSignedInUserQuery{UserId: ac2.Id}
|
||||
err = GetSignedInUser(&query)
|
||||
err = GetSignedInUser(context.Background(), &query)
|
||||
|
||||
So(err, ShouldBeNil)
|
||||
So(query.Result.OrgId, ShouldEqual, ac2.OrgId)
|
||||
@@ -282,7 +282,7 @@ func TestAccountDataAccess(t *testing.T) {
|
||||
So(err, ShouldBeNil)
|
||||
So(remCmd.UserWasDeleted, ShouldBeTrue)
|
||||
|
||||
err = GetSignedInUser(&models.GetSignedInUserQuery{UserId: ac2.Id})
|
||||
err = GetSignedInUser(context.Background(), &models.GetSignedInUserQuery{UserId: ac2.Id})
|
||||
So(err, ShouldEqual, models.ErrUserNotFound)
|
||||
})
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
func (ss *SQLStore) addUserQueryAndCommandHandlers() {
|
||||
ss.Bus.AddHandler(ss.GetSignedInUserWithCache)
|
||||
ss.Bus.AddHandlerCtx(ss.GetSignedInUserWithCacheCtx)
|
||||
|
||||
bus.AddHandler("sql", GetUserById)
|
||||
bus.AddHandler("sql", UpdateUser)
|
||||
@@ -490,14 +490,14 @@ func newSignedInUserCacheKey(orgID, userID int64) string {
|
||||
return fmt.Sprintf("signed-in-user-%d-%d", userID, orgID)
|
||||
}
|
||||
|
||||
func (ss *SQLStore) GetSignedInUserWithCache(query *models.GetSignedInUserQuery) error {
|
||||
func (ss *SQLStore) GetSignedInUserWithCacheCtx(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
cacheKey := newSignedInUserCacheKey(query.OrgId, query.UserId)
|
||||
if cached, found := ss.CacheService.Get(cacheKey); found {
|
||||
query.Result = cached.(*models.SignedInUser)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := GetSignedInUser(query)
|
||||
err := GetSignedInUser(ctx, query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -507,7 +507,7 @@ func (ss *SQLStore) GetSignedInUserWithCache(query *models.GetSignedInUserQuery)
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSignedInUser(query *models.GetSignedInUserQuery) error {
|
||||
func GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error {
|
||||
orgId := "u.org_id"
|
||||
if query.OrgId > 0 {
|
||||
orgId = strconv.FormatInt(query.OrgId, 10)
|
||||
|
||||
@@ -346,14 +346,14 @@ func TestUserDataAccess(t *testing.T) {
|
||||
ss.CacheService.Flush()
|
||||
|
||||
query3 := &models.GetSignedInUserQuery{OrgId: users[1].OrgId, UserId: users[1].Id}
|
||||
err = ss.GetSignedInUserWithCache(query3)
|
||||
err = ss.GetSignedInUserWithCacheCtx(context.Background(), query3)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, query3.Result)
|
||||
require.Equal(t, query3.OrgId, users[1].OrgId)
|
||||
err = SetUsingOrg(&models.SetUsingOrgCommand{UserId: users[1].Id, OrgId: users[0].OrgId})
|
||||
require.Nil(t, err)
|
||||
query4 := &models.GetSignedInUserQuery{OrgId: 0, UserId: users[1].Id}
|
||||
err = ss.GetSignedInUserWithCache(query4)
|
||||
err = ss.GetSignedInUserWithCacheCtx(context.Background(), query4)
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, query4.Result)
|
||||
require.Equal(t, query4.Result.OrgId, users[0].OrgId)
|
||||
|
||||
Reference in New Issue
Block a user