mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
AuthN: set org id for authentication request in service (#60528)
* AuthN: Replicate functionallity to get org id for request * Authn: parse org id for the request and populate the auth request with it * AuthN: add simple mock for client to use in test * AuthN: add tests to verify that authentication is called with correct org id * AuthN: Add ClientParams to mock * AuthN: Fix flaky org id selection
This commit is contained in:
parent
17696f8dec
commit
c4b4baea2a
@ -42,6 +42,8 @@ type Client interface {
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
// OrgID will be populated by authn.Service
|
||||
OrgID int64
|
||||
HTTPRequest *http.Request
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,8 @@ package authnimpl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
@ -15,6 +17,7 @@ import (
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
)
|
||||
|
||||
// make sure service implements authn.Service interface
|
||||
var _ authn.Service = new(Service)
|
||||
|
||||
func ProvideService(cfg *setting.Cfg, tracer tracing.Tracer, orgService org.Service, apikeyService apikey.Service, userService user.Service) *Service {
|
||||
@ -24,7 +27,6 @@ func ProvideService(cfg *setting.Cfg, tracer tracing.Tracer, orgService org.Serv
|
||||
clients: make(map[string]authn.Client),
|
||||
tracer: tracer,
|
||||
postAuthHooks: []authn.PostAuthHookFn{},
|
||||
userService: userService,
|
||||
}
|
||||
|
||||
s.clients[authn.ClientAPIKey] = clients.ProvideAPIKey(apikeyService, userService)
|
||||
@ -46,12 +48,9 @@ type Service struct {
|
||||
log log.Logger
|
||||
cfg *setting.Cfg
|
||||
clients map[string]authn.Client
|
||||
|
||||
// postAuthHooks are called after a successful authentication. They can modify the identity.
|
||||
postAuthHooks []authn.PostAuthHookFn
|
||||
|
||||
tracer tracing.Tracer
|
||||
userService user.Service
|
||||
tracer tracing.Tracer
|
||||
}
|
||||
|
||||
func (s *Service) Authenticate(ctx context.Context, client string, r *authn.Request) (*authn.Identity, bool, error) {
|
||||
@ -74,6 +73,7 @@ func (s *Service) Authenticate(ctx context.Context, client string, r *authn.Requ
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
r.OrgID = orgIDFromRequest(r)
|
||||
identity, err := c.Authenticate(ctx, r)
|
||||
if err != nil {
|
||||
logger.Warn("auth client could not authenticate request", "client", client, "error", err)
|
||||
@ -103,3 +103,46 @@ func (s *Service) Authenticate(ctx context.Context, client string, r *authn.Requ
|
||||
func (s *Service) RegisterPostAuthHook(hook authn.PostAuthHookFn) {
|
||||
s.postAuthHooks = append(s.postAuthHooks, hook)
|
||||
}
|
||||
|
||||
func orgIDFromRequest(r *authn.Request) int64 {
|
||||
if r.HTTPRequest == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
orgID := orgIDFromQuery(r.HTTPRequest)
|
||||
if orgID > 0 {
|
||||
return orgID
|
||||
}
|
||||
|
||||
return orgIDFromHeader(r.HTTPRequest)
|
||||
}
|
||||
|
||||
// name of query string used to target specific org for request
|
||||
const orgIDTargetQuery = "targetOrgId"
|
||||
|
||||
func orgIDFromQuery(req *http.Request) int64 {
|
||||
params := req.URL.Query()
|
||||
if !params.Has(orgIDTargetQuery) {
|
||||
return 0
|
||||
}
|
||||
id, err := strconv.ParseInt(params.Get(orgIDTargetQuery), 10, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// name of header containing org id for request
|
||||
const orgIDHeaderName = "X-Grafana-Org-Id"
|
||||
|
||||
func orgIDFromHeader(req *http.Request) int64 {
|
||||
header := req.Header.Get(orgIDHeaderName)
|
||||
if header == "" {
|
||||
return 0
|
||||
}
|
||||
id, err := strconv.ParseInt(header, 10, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
@ -3,6 +3,8 @@ package authnimpl
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -59,6 +61,77 @@ func TestService_Authenticate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestService_AuthenticateOrgID(t *testing.T) {
|
||||
type TestCase struct {
|
||||
desc string
|
||||
req *authn.Request
|
||||
expectedOrgID int64
|
||||
}
|
||||
|
||||
tests := []TestCase{
|
||||
{
|
||||
desc: "should set org id when present in header",
|
||||
req: &authn.Request{HTTPRequest: &http.Request{
|
||||
Header: map[string][]string{orgIDHeaderName: {"1"}},
|
||||
URL: &url.URL{},
|
||||
}},
|
||||
expectedOrgID: 1,
|
||||
},
|
||||
{
|
||||
desc: "should set org id when present in url",
|
||||
req: &authn.Request{HTTPRequest: &http.Request{
|
||||
Header: map[string][]string{},
|
||||
URL: mustParseURL("http://localhost/?targetOrgId=2"),
|
||||
}},
|
||||
expectedOrgID: 2,
|
||||
},
|
||||
{
|
||||
desc: "should prioritise org id from url when present in both header and url",
|
||||
req: &authn.Request{HTTPRequest: &http.Request{
|
||||
Header: map[string][]string{orgIDHeaderName: {"1"}},
|
||||
URL: mustParseURL("http://localhost/?targetOrgId=2"),
|
||||
}},
|
||||
expectedOrgID: 2,
|
||||
},
|
||||
{
|
||||
desc: "should set org id to 0 when missing in both header and url",
|
||||
req: &authn.Request{HTTPRequest: &http.Request{
|
||||
Header: map[string][]string{},
|
||||
URL: &url.URL{},
|
||||
}},
|
||||
expectedOrgID: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
var calledWith int64
|
||||
s := setupTests(t, func(svc *Service) {
|
||||
svc.clients["fake"] = authntest.MockClient{
|
||||
AuthenticateFunc: func(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
|
||||
calledWith = r.OrgID
|
||||
return nil, nil
|
||||
},
|
||||
TestFunc: func(ctx context.Context, r *authn.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
_, _, _ = s.Authenticate(context.Background(), "fake", tt.req)
|
||||
assert.Equal(t, tt.expectedOrgID, calledWith)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseURL(s string) *url.URL {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func setupTests(t *testing.T, opts ...func(svc *Service)) *Service {
|
||||
t.Helper()
|
||||
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/grafana/grafana/pkg/cmd/grafana-cli/logger"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
@ -65,8 +66,10 @@ func (s *OrgSync) SyncOrgUser(ctx context.Context, clientParams *authn.ClientPar
|
||||
}
|
||||
}
|
||||
|
||||
orgIDs := make([]int64, 0, len(id.OrgRoles))
|
||||
// add any new org roles
|
||||
for orgId, orgRole := range id.OrgRoles {
|
||||
orgIDs = append(orgIDs, orgId)
|
||||
if _, exists := handledOrgIds[orgId]; exists {
|
||||
continue
|
||||
}
|
||||
@ -99,17 +102,17 @@ func (s *OrgSync) SyncOrgUser(ctx context.Context, clientParams *authn.ClientPar
|
||||
}
|
||||
}
|
||||
|
||||
// Note: sort all org ids to not make it flaky, for now we default to the lowest id
|
||||
sort.Slice(orgIDs, func(i, j int) bool { return orgIDs[i] < orgIDs[j] })
|
||||
// update user's default org if needed
|
||||
if _, ok := id.OrgRoles[id.OrgID]; !ok {
|
||||
for orgId := range id.OrgRoles {
|
||||
id.OrgID = orgId
|
||||
break
|
||||
if len(orgIDs) > 0 {
|
||||
id.OrgID = orgIDs[0]
|
||||
return s.userService.SetUsingOrg(ctx, &user.SetUsingOrgCommand{
|
||||
UserID: userID,
|
||||
OrgID: id.OrgID,
|
||||
})
|
||||
}
|
||||
|
||||
return s.userService.SetUsingOrg(ctx, &user.SetUsingOrgCommand{
|
||||
UserID: userID,
|
||||
OrgID: id.OrgID,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
|
36
pkg/services/authn/authntest/mock.go
Normal file
36
pkg/services/authn/authntest/mock.go
Normal file
@ -0,0 +1,36 @@
|
||||
package authntest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
)
|
||||
|
||||
var _ authn.Client = new(MockClient)
|
||||
|
||||
type MockClient struct {
|
||||
AuthenticateFunc func(ctx context.Context, r *authn.Request) (*authn.Identity, error)
|
||||
ClientParamsFunc func() *authn.ClientParams
|
||||
TestFunc func(ctx context.Context, r *authn.Request) bool
|
||||
}
|
||||
|
||||
func (m MockClient) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
|
||||
if m.AuthenticateFunc != nil {
|
||||
return m.AuthenticateFunc(ctx, r)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m MockClient) ClientParams() *authn.ClientParams {
|
||||
if m.ClientParamsFunc != nil {
|
||||
return m.ClientParamsFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MockClient) Test(ctx context.Context, r *authn.Request) bool {
|
||||
if m.TestFunc != nil {
|
||||
return m.TestFunc(ctx, r)
|
||||
}
|
||||
return false
|
||||
}
|
Loading…
Reference in New Issue
Block a user