diff --git a/pkg/infra/appcontext/user.go b/pkg/infra/appcontext/user.go new file mode 100644 index 00000000000..604ec008b7a --- /dev/null +++ b/pkg/infra/appcontext/user.go @@ -0,0 +1,52 @@ +package appcontext + +import ( + "context" + "fmt" + + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/contexthandler/ctxkey" + grpccontext "github.com/grafana/grafana/pkg/services/grpcserver/context" + "github.com/grafana/grafana/pkg/services/user" +) + +type ctxUserKey struct{} + +// WithUser adds the supplied SignedInUser to the context. +func WithUser(ctx context.Context, usr *user.SignedInUser) context.Context { + return context.WithValue(ctx, ctxUserKey{}, usr) +} + +// User extracts the SignedInUser from the supplied context. +// Supports context set by appcontext.WithUser, gRPC server context, and HTTP ReqContext. +func User(ctx context.Context) (*user.SignedInUser, error) { + // Set by appcontext.WithUser + u, ok := ctx.Value(ctxUserKey{}).(*user.SignedInUser) + if ok && u != nil { + return u, nil + } + + // Set by incoming gRPC server request + grpcCtx := grpccontext.FromContext(ctx) + if grpcCtx != nil && grpcCtx.SignedInUser != nil { + return grpcCtx.SignedInUser, nil + } + + // Set by incoming HTTP request + c, ok := ctxkey.Get(ctx).(*models.ReqContext) + if ok && c.SignedInUser != nil { + return c.SignedInUser, nil + } + + return nil, fmt.Errorf("a SignedInUser was not found in the context") +} + +// MustUser extracts the SignedInUser from the supplied context, and panics if a user is not found. +// Supports context set by appcontext.WithUser, gRPC server context, and HTTP ReqContext. +func MustUser(ctx context.Context) *user.SignedInUser { + usr, err := User(ctx) + if err != nil { + panic(err) + } + return usr +} diff --git a/pkg/infra/appcontext/user_test.go b/pkg/infra/appcontext/user_test.go new file mode 100644 index 00000000000..91913c692f1 --- /dev/null +++ b/pkg/infra/appcontext/user_test.go @@ -0,0 +1,68 @@ +package appcontext_test + +import ( + "context" + "crypto/rand" + "math/big" + "testing" + + "github.com/grafana/grafana/pkg/infra/appcontext" + "github.com/grafana/grafana/pkg/infra/tracing" + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/contexthandler/ctxkey" + grpccontext "github.com/grafana/grafana/pkg/services/grpcserver/context" + "github.com/grafana/grafana/pkg/services/user" + "github.com/stretchr/testify/require" +) + +func TestUserFromContext(t *testing.T) { + t.Run("User should error when context is missing user", func(t *testing.T) { + usr, err := appcontext.User(context.Background()) + require.Nil(t, usr) + require.Error(t, err) + }) + + t.Run("MustUser should panic when context is missing user", func(t *testing.T) { + require.Panics(t, func() { + _ = appcontext.MustUser(context.Background()) + }) + }) + + t.Run("should return user set by ContextWithUser", func(t *testing.T) { + expected := testUser() + ctx := appcontext.WithUser(context.Background(), expected) + actual, err := appcontext.User(ctx) + require.NoError(t, err) + require.Equal(t, expected.UserID, actual.UserID) + }) + + t.Run("should return user set by gRPC context", func(t *testing.T) { + expected := testUser() + handler := grpccontext.ProvideContextHandler(tracing.InitializeTracerForTest()) + ctx := handler.SetUser(context.Background(), expected) + actual, err := appcontext.User(ctx) + require.NoError(t, err) + require.Equal(t, expected.UserID, actual.UserID) + }) + + t.Run("should return user set by HTTP ReqContext", func(t *testing.T) { + expected := testUser() + ctx := ctxkey.Set(context.Background(), &models.ReqContext{ + SignedInUser: expected, + }) + actual, err := appcontext.User(ctx) + require.NoError(t, err) + require.Equal(t, expected.UserID, actual.UserID) + }) +} + +func testUser() *user.SignedInUser { + i, err := rand.Int(rand.Reader, big.NewInt(100000)) + if err != nil { + panic(err) + } + return &user.SignedInUser{ + UserID: i.Int64(), + OrgID: 1, + } +}