mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
Authn: Refactor user sync and org sync as post auth hooks (#60504)
* add user sync * add org user sync * add client params * merge remaining conflicts * remove change to report.go * update comments * add basic tests for user ID population * add tests for auth ID find * add tests for user sync create and update * add tests for orgsync * satisfy lint * add userID guards
This commit is contained in:
parent
1b1a14b6f6
commit
a553040441
@ -7,8 +7,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -16,14 +18,25 @@ const (
|
||||
ClientAnonymous = "auth.client.anonymous"
|
||||
)
|
||||
|
||||
type ClientParams struct {
|
||||
SyncUser bool
|
||||
AllowSignUp bool
|
||||
EnableDisabledUsers bool
|
||||
}
|
||||
|
||||
type PostAuthHookFn func(ctx context.Context, clientParams *ClientParams, identity *Identity) error
|
||||
|
||||
type Service interface {
|
||||
// Authenticate is used to authenticate using a specific client
|
||||
// RegisterPostAuthHook registers a hook that is called after a successful authentication.
|
||||
RegisterPostAuthHook(hook PostAuthHookFn)
|
||||
// Authenticate authenticates a request using the specified client.
|
||||
Authenticate(ctx context.Context, client string, r *Request) (*Identity, bool, error)
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
// Authenticate performs the authentication for the request
|
||||
Authenticate(ctx context.Context, r *Request) (*Identity, error)
|
||||
ClientParams() *ClientParams
|
||||
// Test should return true if client can be used to authenticate request
|
||||
Test(ctx context.Context, r *Request) bool
|
||||
}
|
||||
@ -38,17 +51,20 @@ const (
|
||||
)
|
||||
|
||||
type Identity struct {
|
||||
OrgID int64
|
||||
OrgCount int
|
||||
OrgName string
|
||||
OrgRoles map[int64]org.RoleType
|
||||
|
||||
ID string
|
||||
OrgID int64
|
||||
OrgCount int
|
||||
OrgName string
|
||||
OrgRoles map[int64]org.RoleType
|
||||
Login string
|
||||
Name string
|
||||
Email string
|
||||
AuthID string
|
||||
AuthModule string
|
||||
IsGrafanaAdmin bool
|
||||
IsGrafanaAdmin *bool
|
||||
AuthModule string // AuthModule is the name of the external system
|
||||
AuthID string // AuthId is the unique identifier for the user in the external system
|
||||
OAuthToken *oauth2.Token
|
||||
LookUpParams models.UserLookupParams
|
||||
IsDisabled bool
|
||||
HelpFlags1 user.HelpFlags1
|
||||
LastSeenAt time.Time
|
||||
@ -64,7 +80,28 @@ func (i *Identity) IsAnonymous() bool {
|
||||
return i.ID == ""
|
||||
}
|
||||
|
||||
// SignedInUser is used to translate Identity into SignedInUser struct
|
||||
// TODO: improve error handling
|
||||
func (i *Identity) NamespacedID() (string, int64) {
|
||||
var (
|
||||
id int64
|
||||
namespace string
|
||||
)
|
||||
|
||||
split := strings.Split(i.ID, ":")
|
||||
if len(split) != 2 {
|
||||
return "", -1
|
||||
}
|
||||
|
||||
id, errI := strconv.ParseInt(split[1], 10, 64)
|
||||
if errI != nil {
|
||||
return "", -1
|
||||
}
|
||||
|
||||
namespace = split[0]
|
||||
|
||||
return namespace, id
|
||||
}
|
||||
|
||||
func (i *Identity) SignedInUser() *user.SignedInUser {
|
||||
u := &user.SignedInUser{
|
||||
UserID: 0,
|
||||
@ -77,7 +114,7 @@ func (i *Identity) SignedInUser() *user.SignedInUser {
|
||||
Name: i.Name,
|
||||
Email: i.Email,
|
||||
OrgCount: i.OrgCount,
|
||||
IsGrafanaAdmin: i.IsGrafanaAdmin,
|
||||
IsGrafanaAdmin: *i.IsGrafanaAdmin,
|
||||
IsAnonymous: i.IsAnonymous(),
|
||||
IsDisabled: i.IsDisabled,
|
||||
HelpFlags1: i.HelpFlags1,
|
||||
@ -108,7 +145,7 @@ func IdentityFromSignedInUser(id string, usr *user.SignedInUser) *Identity {
|
||||
Name: usr.Name,
|
||||
Email: usr.Email,
|
||||
OrgCount: usr.OrgCount,
|
||||
IsGrafanaAdmin: usr.IsGrafanaAdmin,
|
||||
IsGrafanaAdmin: &usr.IsGrafanaAdmin,
|
||||
IsDisabled: usr.IsDisabled,
|
||||
HelpFlags1: usr.HelpFlags1,
|
||||
LastSeenAt: usr.LastSeenAt,
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"github.com/grafana/grafana/pkg/infra/tracing"
|
||||
"github.com/grafana/grafana/pkg/services/apikey"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
sync "github.com/grafana/grafana/pkg/services/authn/authnimpl/usersync"
|
||||
"github.com/grafana/grafana/pkg/services/authn/clients"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
@ -18,11 +19,12 @@ var _ authn.Service = new(Service)
|
||||
|
||||
func ProvideService(cfg *setting.Cfg, tracer tracing.Tracer, orgService org.Service, apikeyService apikey.Service, userService user.Service) *Service {
|
||||
s := &Service{
|
||||
log: log.New("authn.service"),
|
||||
cfg: cfg,
|
||||
clients: make(map[string]authn.Client),
|
||||
tracer: tracer,
|
||||
userService: userService,
|
||||
log: log.New("authn.service"),
|
||||
cfg: cfg,
|
||||
clients: make(map[string]authn.Client),
|
||||
tracer: tracer,
|
||||
postAuthHooks: []authn.PostAuthHookFn{},
|
||||
userService: userService,
|
||||
}
|
||||
|
||||
s.clients[authn.ClientAPIKey] = clients.ProvideAPIKey(apikeyService, userService)
|
||||
@ -31,6 +33,12 @@ func ProvideService(cfg *setting.Cfg, tracer tracing.Tracer, orgService org.Serv
|
||||
s.clients[authn.ClientAnonymous] = clients.ProvideAnonymous(cfg, orgService)
|
||||
}
|
||||
|
||||
// FIXME (jguer): move to User package
|
||||
userSyncService := &sync.UserSync{}
|
||||
orgUserSyncService := &sync.OrgSync{}
|
||||
s.RegisterPostAuthHook(userSyncService.SyncUser)
|
||||
s.RegisterPostAuthHook(orgUserSyncService.SyncOrgUser)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@ -39,6 +47,9 @@ type Service struct {
|
||||
cfg *setting.Cfg
|
||||
clients map[string]authn.Client
|
||||
|
||||
// postAuthHooks are called after a successful authentication. They can modify the identity.
|
||||
postAuthHooks []authn.PostAuthHookFn
|
||||
|
||||
tracer tracing.Tracer
|
||||
userService user.Service
|
||||
}
|
||||
@ -78,6 +89,17 @@ func (s *Service) Authenticate(ctx context.Context, client string, r *authn.Requ
|
||||
// login handler, but if we want to perform basic auth during a request (called from contexthandler) we don't
|
||||
// want a session to be created.
|
||||
|
||||
logger.Debug("auth client successfully authenticated request", "client", client, "identity", identity)
|
||||
params := c.ClientParams()
|
||||
|
||||
for _, hook := range s.postAuthHooks {
|
||||
if err := hook(ctx, params, identity); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
return identity, true, nil
|
||||
}
|
||||
|
||||
func (s *Service) RegisterPostAuthHook(hook authn.PostAuthHookFn) {
|
||||
s.postAuthHooks = append(s.postAuthHooks, hook)
|
||||
}
|
||||
|
116
pkg/services/authn/authnimpl/usersync/orgsync.go
Normal file
116
pkg/services/authn/authnimpl/usersync/orgsync.go
Normal file
@ -0,0 +1,116 @@
|
||||
package usersync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/grafana/grafana/pkg/cmd/grafana-cli/logger"
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
)
|
||||
|
||||
type OrgSync struct {
|
||||
userService user.Service
|
||||
orgService org.Service
|
||||
accessControl accesscontrol.Service
|
||||
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
func (s *OrgSync) SyncOrgUser(ctx context.Context, clientParams *authn.ClientParams, id *authn.Identity) error {
|
||||
if !clientParams.SyncUser {
|
||||
s.log.Debug("Not syncing org user", "auth_module", id.AuthModule, "auth_id", id.AuthID)
|
||||
return nil
|
||||
}
|
||||
|
||||
namespace, userID := id.NamespacedID()
|
||||
if namespace != "user" && userID <= 0 {
|
||||
return fmt.Errorf("invalid namespace %q for user ID %q", namespace, userID)
|
||||
}
|
||||
|
||||
s.log.Debug("Syncing organization roles", "id", userID, "extOrgRoles", id.OrgRoles)
|
||||
// don't sync org roles if none is specified
|
||||
if len(id.OrgRoles) == 0 {
|
||||
s.log.Debug("Not syncing organization roles since external user doesn't have any")
|
||||
return nil
|
||||
}
|
||||
|
||||
orgsQuery := &org.GetUserOrgListQuery{UserID: userID}
|
||||
result, err := s.orgService.GetUserOrgList(ctx, orgsQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
handledOrgIds := map[int64]bool{}
|
||||
deleteOrgIds := []int64{}
|
||||
|
||||
// update existing org roles
|
||||
for _, orga := range result {
|
||||
handledOrgIds[orga.OrgID] = true
|
||||
|
||||
extRole := id.OrgRoles[orga.OrgID]
|
||||
if extRole == "" {
|
||||
deleteOrgIds = append(deleteOrgIds, orga.OrgID)
|
||||
} else if extRole != orga.Role {
|
||||
// update role
|
||||
cmd := &org.UpdateOrgUserCommand{OrgID: orga.OrgID, UserID: userID, Role: extRole}
|
||||
if err := s.orgService.UpdateOrgUser(ctx, cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add any new org roles
|
||||
for orgId, orgRole := range id.OrgRoles {
|
||||
if _, exists := handledOrgIds[orgId]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
// add role
|
||||
cmd := &org.AddOrgUserCommand{UserID: userID, Role: orgRole, OrgID: orgId}
|
||||
err := s.orgService.AddOrgUser(ctx, cmd)
|
||||
if err != nil && !errors.Is(err, models.ErrOrgNotFound) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// delete any removed org roles
|
||||
for _, orgId := range deleteOrgIds {
|
||||
s.log.Debug("Removing user's organization membership as part of syncing with OAuth login",
|
||||
"userId", userID, "orgId", orgId)
|
||||
cmd := &org.RemoveOrgUserCommand{OrgID: orgId, UserID: userID}
|
||||
if err := s.orgService.RemoveOrgUser(ctx, cmd); err != nil {
|
||||
if errors.Is(err, models.ErrLastOrgAdmin) {
|
||||
logger.Error(err.Error(), "userId", cmd.UserID, "orgId", cmd.OrgID)
|
||||
continue
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.accessControl.DeleteUserPermissions(ctx, orgId, cmd.UserID); err != nil {
|
||||
logger.Error("failed to delete permissions for user", "error", err, "userID", cmd.UserID, "orgID", orgId)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// update user's default org if needed
|
||||
if _, ok := id.OrgRoles[id.OrgID]; !ok {
|
||||
for orgId := range id.OrgRoles {
|
||||
id.OrgID = orgId
|
||||
break
|
||||
}
|
||||
|
||||
return s.userService.SetUsingOrg(ctx, &user.SetUsingOrgCommand{
|
||||
UserID: userID,
|
||||
OrgID: id.OrgID,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
123
pkg/services/authn/authnimpl/usersync/orgsync_test.go
Normal file
123
pkg/services/authn/authnimpl/usersync/orgsync_test.go
Normal file
@ -0,0 +1,123 @@
|
||||
package usersync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/models/roletype"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol"
|
||||
"github.com/grafana/grafana/pkg/services/accesscontrol/actest"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/org"
|
||||
"github.com/grafana/grafana/pkg/services/org/orgtest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/services/user/usertest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestOrgSync_SyncOrgUser(t *testing.T) {
|
||||
orgService := &orgtest.FakeOrgService{ExpectedUserOrgDTO: []*org.UserOrgDTO{
|
||||
{
|
||||
OrgID: 1,
|
||||
Role: org.RoleEditor,
|
||||
},
|
||||
{
|
||||
OrgID: 3,
|
||||
Role: org.RoleViewer,
|
||||
},
|
||||
},
|
||||
ExpectedOrgListResponse: orgtest.OrgListResponse{
|
||||
{
|
||||
OrgID: 3,
|
||||
Response: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
acService := &actest.FakeService{}
|
||||
userService := &usertest.FakeUserService{ExpectedUser: &user.User{
|
||||
ID: 1,
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
}}
|
||||
|
||||
type fields struct {
|
||||
userService user.Service
|
||||
orgService org.Service
|
||||
accessControl accesscontrol.Service
|
||||
log log.Logger
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
clientParams *authn.ClientParams
|
||||
id *authn.Identity
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
wantID *authn.Identity
|
||||
}{
|
||||
{
|
||||
name: "add user to multiple orgs",
|
||||
fields: fields{
|
||||
userService: userService,
|
||||
orgService: orgService,
|
||||
accessControl: acService,
|
||||
log: log.NewNopLogger(),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientParams: &authn.ClientParams{
|
||||
SyncUser: true,
|
||||
},
|
||||
id: &authn.Identity{
|
||||
ID: "user:1",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
OrgRoles: map[int64]roletype.RoleType{1: org.RoleAdmin, 2: org.RoleEditor},
|
||||
IsGrafanaAdmin: ptrBool(false),
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: ptrString("test"),
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantID: &authn.Identity{
|
||||
ID: "user:1",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
OrgRoles: map[int64]roletype.RoleType{1: org.RoleAdmin, 2: org.RoleEditor},
|
||||
OrgID: 1, //set using org
|
||||
IsGrafanaAdmin: ptrBool(false),
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: ptrString("test"),
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &OrgSync{
|
||||
userService: tt.fields.userService,
|
||||
orgService: tt.fields.orgService,
|
||||
accessControl: tt.fields.accessControl,
|
||||
log: tt.fields.log,
|
||||
}
|
||||
if err := s.SyncOrgUser(tt.args.ctx, tt.args.clientParams, tt.args.id); (err != nil) != tt.wantErr {
|
||||
t.Errorf("OrgSync.SyncOrgUser() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
assert.EqualValues(t, tt.wantID, tt.args.id)
|
||||
})
|
||||
}
|
||||
}
|
248
pkg/services/authn/authnimpl/usersync/usersync.go
Normal file
248
pkg/services/authn/authnimpl/usersync/usersync.go
Normal file
@ -0,0 +1,248 @@
|
||||
package usersync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/quota"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
)
|
||||
|
||||
type UserSync struct {
|
||||
userService user.Service
|
||||
authInfoService login.AuthInfoService
|
||||
quotaService quota.Service
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
// SyncUser syncs a user with the database
|
||||
func (s *UserSync) SyncUser(ctx context.Context, clientParams *authn.ClientParams, id *authn.Identity) error {
|
||||
if !clientParams.SyncUser {
|
||||
s.log.Debug("Not syncing user", "auth_module", id.AuthModule, "auth_id", id.AuthID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Does user exist in the database?
|
||||
usr, errUserInDB := s.UserInDB(ctx, &id.AuthModule, &id.AuthID, id.LookUpParams)
|
||||
if errUserInDB != nil && !errors.Is(errUserInDB, user.ErrUserNotFound) {
|
||||
return errUserInDB
|
||||
}
|
||||
|
||||
if errors.Is(errUserInDB, user.ErrUserNotFound) {
|
||||
if !clientParams.AllowSignUp {
|
||||
s.log.Warn("Not allowing login, user not found in internal user database and allow signup = false",
|
||||
"auth_module", id.AuthModule)
|
||||
return login.ErrSignupNotAllowed
|
||||
}
|
||||
|
||||
// create user
|
||||
var errCreate error
|
||||
usr, errCreate = s.createUser(ctx, id)
|
||||
if errCreate != nil {
|
||||
return errCreate
|
||||
}
|
||||
}
|
||||
|
||||
// update user
|
||||
if errUpdate := s.updateUserAttributes(ctx, clientParams, usr, id); errUpdate != nil {
|
||||
return errUpdate
|
||||
}
|
||||
|
||||
syncUserToIdentity(usr, id)
|
||||
|
||||
// persist latest auth info token
|
||||
if errAuthInfo := s.updateAuthInfo(ctx, id); errAuthInfo != nil {
|
||||
return errAuthInfo
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncUserToIdentity syncs a user to an identity.
|
||||
// This is used to update the identity with the latest user information.
|
||||
func syncUserToIdentity(usr *user.User, id *authn.Identity) {
|
||||
id.ID = fmt.Sprintf("user:%d", usr.ID)
|
||||
id.Login = usr.Login
|
||||
id.Email = usr.Email
|
||||
id.Name = usr.Name
|
||||
id.IsGrafanaAdmin = &usr.IsAdmin
|
||||
}
|
||||
|
||||
func (s *UserSync) updateAuthInfo(ctx context.Context, id *authn.Identity) error {
|
||||
if id.AuthModule != "" && id.OAuthToken != nil && id.AuthID != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
namespace, userID := id.NamespacedID()
|
||||
if namespace != "user" && userID <= 0 { // FIXME: constant namespace
|
||||
return fmt.Errorf("invalid namespace %q for user ID %q", namespace, userID)
|
||||
}
|
||||
|
||||
updateCmd := &models.UpdateAuthInfoCommand{
|
||||
AuthModule: id.AuthModule,
|
||||
AuthId: id.AuthID,
|
||||
UserId: userID,
|
||||
OAuthToken: id.OAuthToken,
|
||||
}
|
||||
|
||||
s.log.Debug("Updating user_auth info", "user_id", userID)
|
||||
return s.authInfoService.UpdateAuthInfo(ctx, updateCmd)
|
||||
}
|
||||
|
||||
func (s *UserSync) updateUserAttributes(ctx context.Context, clientParams *authn.ClientParams, usr *user.User, id *authn.Identity) error {
|
||||
// sync user info
|
||||
updateCmd := &user.UpdateUserCommand{
|
||||
UserID: usr.ID,
|
||||
}
|
||||
|
||||
needsUpdate := false
|
||||
if id.Login != "" && id.Login != usr.Login {
|
||||
updateCmd.Login = id.Login
|
||||
usr.Login = id.Login
|
||||
needsUpdate = true
|
||||
}
|
||||
|
||||
if id.Email != "" && id.Email != usr.Email {
|
||||
updateCmd.Email = id.Email
|
||||
usr.Email = id.Email
|
||||
needsUpdate = true
|
||||
}
|
||||
|
||||
if id.Name != "" && id.Name != usr.Name {
|
||||
updateCmd.Name = id.Name
|
||||
usr.Name = id.Name
|
||||
needsUpdate = true
|
||||
}
|
||||
|
||||
if needsUpdate {
|
||||
s.log.Debug("Syncing user info", "id", usr.ID, "update", updateCmd)
|
||||
if err := s.userService.Update(ctx, updateCmd); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if usr.IsDisabled && clientParams.EnableDisabledUsers {
|
||||
usr.IsDisabled = false
|
||||
if errDisableUser := s.userService.Disable(ctx,
|
||||
&user.DisableUserCommand{
|
||||
UserID: usr.ID, IsDisabled: false}); errDisableUser != nil {
|
||||
return errDisableUser
|
||||
}
|
||||
}
|
||||
|
||||
// Sync isGrafanaAdmin permission
|
||||
if id.IsGrafanaAdmin != nil && *id.IsGrafanaAdmin != usr.IsAdmin {
|
||||
usr.IsAdmin = *id.IsGrafanaAdmin
|
||||
if errPerms := s.userService.UpdatePermissions(ctx, usr.ID, *id.IsGrafanaAdmin); errPerms != nil {
|
||||
return errPerms
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserSync) createUser(ctx context.Context, id *authn.Identity) (*user.User, error) {
|
||||
isAdmin := false
|
||||
if id.IsGrafanaAdmin != nil {
|
||||
isAdmin = *id.IsGrafanaAdmin
|
||||
}
|
||||
|
||||
// TODO: add quota check
|
||||
usr, errCreateUser := s.userService.Create(ctx, &user.CreateUserCommand{
|
||||
Login: id.Login,
|
||||
Email: id.Email,
|
||||
Name: id.Name,
|
||||
IsAdmin: isAdmin,
|
||||
SkipOrgSetup: len(id.OrgRoles) > 0,
|
||||
})
|
||||
if errCreateUser != nil {
|
||||
return nil, errCreateUser
|
||||
}
|
||||
|
||||
if id.AuthModule != "" && id.AuthID != "" {
|
||||
if errSetAuth := s.authInfoService.SetAuthInfo(ctx, &models.SetAuthInfoCommand{
|
||||
UserId: usr.ID,
|
||||
AuthModule: id.AuthModule,
|
||||
AuthId: id.AuthID,
|
||||
OAuthToken: id.OAuthToken,
|
||||
}); errSetAuth != nil {
|
||||
return nil, errSetAuth
|
||||
}
|
||||
}
|
||||
|
||||
return usr, nil
|
||||
}
|
||||
|
||||
// Does user exist in the database?
|
||||
// Check first authinfo table, then user table
|
||||
// return user id if found, 0 if not found
|
||||
func (s *UserSync) UserInDB(ctx context.Context,
|
||||
authID *string,
|
||||
authModule *string,
|
||||
params models.UserLookupParams) (*user.User, error) {
|
||||
// Check authinfo table
|
||||
if authID != nil && authModule != nil {
|
||||
query := &models.GetAuthInfoQuery{
|
||||
AuthModule: *authModule,
|
||||
AuthId: *authID,
|
||||
}
|
||||
errGetAuthInfo := s.authInfoService.GetAuthInfo(ctx, query)
|
||||
if errGetAuthInfo == nil {
|
||||
usr, errGetByID := s.userService.GetByID(ctx, &user.GetUserByIDQuery{ID: query.Result.UserId})
|
||||
if errGetByID == nil {
|
||||
return usr, nil
|
||||
}
|
||||
|
||||
if !errors.Is(errGetByID, user.ErrUserNotFound) {
|
||||
return nil, errGetByID
|
||||
}
|
||||
}
|
||||
|
||||
if !errors.Is(errGetAuthInfo, user.ErrUserNotFound) {
|
||||
return nil, errGetAuthInfo
|
||||
}
|
||||
}
|
||||
|
||||
// Check user table to grab existing user
|
||||
return s.LookupByOneOf(ctx, ¶ms)
|
||||
}
|
||||
|
||||
func (s *UserSync) LookupByOneOf(ctx context.Context, params *models.UserLookupParams) (*user.User, error) {
|
||||
var usr *user.User
|
||||
var err error
|
||||
|
||||
// If not found, try to find the user by id
|
||||
if params.UserID != nil && *params.UserID != 0 {
|
||||
usr, err = s.userService.GetByID(ctx, &user.GetUserByIDQuery{ID: *params.UserID})
|
||||
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, try to find the user by email address
|
||||
if usr == nil && params.Email != nil && *params.Email != "" {
|
||||
usr, err = s.userService.GetByEmail(ctx, &user.GetUserByEmailQuery{Email: *params.Email})
|
||||
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// If not found, try to find the user by login
|
||||
if usr == nil && params.Login != nil && *params.Login != "" {
|
||||
usr, err = s.userService.GetByLogin(ctx, &user.GetUserByLoginQuery{LoginOrEmail: *params.Login})
|
||||
if err != nil && !errors.Is(err, user.ErrUserNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if usr == nil {
|
||||
return nil, user.ErrUserNotFound
|
||||
}
|
||||
|
||||
return usr, nil
|
||||
}
|
443
pkg/services/authn/authnimpl/usersync/usersync_test.go
Normal file
443
pkg/services/authn/authnimpl/usersync/usersync_test.go
Normal file
@ -0,0 +1,443 @@
|
||||
package usersync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/grafana/grafana/pkg/models"
|
||||
"github.com/grafana/grafana/pkg/services/authn"
|
||||
"github.com/grafana/grafana/pkg/services/login"
|
||||
"github.com/grafana/grafana/pkg/services/login/logintest"
|
||||
"github.com/grafana/grafana/pkg/services/quota"
|
||||
"github.com/grafana/grafana/pkg/services/quota/quotatest"
|
||||
"github.com/grafana/grafana/pkg/services/user"
|
||||
"github.com/grafana/grafana/pkg/services/user/usertest"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func ptrString(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func ptrBool(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
func ptrInt64(i int64) *int64 {
|
||||
return &i
|
||||
}
|
||||
|
||||
func TestUserSync_SyncUser(t *testing.T) {
|
||||
authFakeNil := &logintest.AuthInfoServiceFake{
|
||||
ExpectedUser: nil,
|
||||
ExpectedError: user.ErrUserNotFound,
|
||||
SetAuthInfoFn: func(ctx context.Context, cmd *models.SetAuthInfoCommand) error {
|
||||
return nil
|
||||
},
|
||||
UpdateAuthInfoFn: func(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
authFakeUserID := &logintest.AuthInfoServiceFake{
|
||||
ExpectedUser: nil,
|
||||
ExpectedError: nil,
|
||||
ExpectedUserAuth: &models.UserAuth{
|
||||
AuthModule: "oauth",
|
||||
AuthId: "2032",
|
||||
UserId: 1,
|
||||
Id: 1}}
|
||||
|
||||
userService := &usertest.FakeUserService{ExpectedUser: &user.User{
|
||||
ID: 1,
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
}}
|
||||
|
||||
userServiceMod := &usertest.FakeUserService{ExpectedUser: &user.User{
|
||||
ID: 3,
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
IsDisabled: true,
|
||||
IsAdmin: false,
|
||||
}}
|
||||
|
||||
userServiceNil := &usertest.FakeUserService{
|
||||
ExpectedUser: nil,
|
||||
ExpectedError: user.ErrUserNotFound,
|
||||
CreateFn: func(ctx context.Context, cmd *user.CreateUserCommand) (*user.User, error) {
|
||||
return &user.User{
|
||||
ID: 2,
|
||||
Login: cmd.Login,
|
||||
Name: cmd.Name,
|
||||
Email: cmd.Email,
|
||||
IsAdmin: cmd.IsAdmin,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
userService user.Service
|
||||
authInfoService login.AuthInfoService
|
||||
quotaService quota.Service
|
||||
log log.Logger
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
clientParams *authn.ClientParams
|
||||
id *authn.Identity
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
wantID *authn.Identity
|
||||
}{
|
||||
{
|
||||
name: "no sync",
|
||||
fields: fields{
|
||||
userService: userService,
|
||||
authInfoService: authFakeNil,
|
||||
quotaService: "atest.FakeQuotaService{},
|
||||
log: log.NewNopLogger(),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientParams: &authn.ClientParams{
|
||||
SyncUser: false,
|
||||
AllowSignUp: false,
|
||||
EnableDisabledUsers: false,
|
||||
},
|
||||
id: &authn.Identity{
|
||||
ID: "",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: ptrString("test"),
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
wantID: &authn.Identity{
|
||||
ID: "",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: ptrString("test"),
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sync - user found in DB - by email",
|
||||
fields: fields{
|
||||
userService: userService,
|
||||
authInfoService: authFakeNil,
|
||||
quotaService: "atest.FakeQuotaService{},
|
||||
log: log.NewNopLogger(),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientParams: &authn.ClientParams{
|
||||
SyncUser: true,
|
||||
AllowSignUp: false,
|
||||
EnableDisabledUsers: false,
|
||||
},
|
||||
id: &authn.Identity{
|
||||
ID: "",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: ptrString("test"),
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
wantID: &authn.Identity{
|
||||
ID: "user:1",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
IsGrafanaAdmin: ptrBool(false),
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: ptrString("test"),
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sync - user found in DB - by login",
|
||||
fields: fields{
|
||||
userService: userService,
|
||||
authInfoService: authFakeNil,
|
||||
quotaService: "atest.FakeQuotaService{},
|
||||
log: log.NewNopLogger(),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientParams: &authn.ClientParams{
|
||||
SyncUser: true,
|
||||
AllowSignUp: false,
|
||||
EnableDisabledUsers: false,
|
||||
},
|
||||
id: &authn.Identity{
|
||||
ID: "",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: nil,
|
||||
Login: ptrString("test"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
wantID: &authn.Identity{
|
||||
ID: "user:1",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
IsGrafanaAdmin: ptrBool(false),
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: nil,
|
||||
Login: ptrString("test"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sync - user found in DB - by ID",
|
||||
fields: fields{
|
||||
userService: userService,
|
||||
authInfoService: authFakeNil,
|
||||
quotaService: "atest.FakeQuotaService{},
|
||||
log: log.NewNopLogger(),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientParams: &authn.ClientParams{
|
||||
SyncUser: true,
|
||||
AllowSignUp: false,
|
||||
EnableDisabledUsers: false,
|
||||
},
|
||||
id: &authn.Identity{
|
||||
ID: "",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: ptrInt64(1),
|
||||
Email: nil,
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
wantID: &authn.Identity{
|
||||
ID: "user:1",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
IsGrafanaAdmin: ptrBool(false),
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: ptrInt64(1),
|
||||
Email: nil,
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sync - user found in authInfo",
|
||||
fields: fields{
|
||||
userService: userService,
|
||||
authInfoService: authFakeUserID,
|
||||
quotaService: "atest.FakeQuotaService{},
|
||||
log: log.NewNopLogger(),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientParams: &authn.ClientParams{
|
||||
SyncUser: true,
|
||||
AllowSignUp: false,
|
||||
EnableDisabledUsers: false,
|
||||
},
|
||||
id: &authn.Identity{
|
||||
ID: "",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: nil,
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
wantID: &authn.Identity{
|
||||
ID: "user:1",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
IsGrafanaAdmin: ptrBool(false),
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: nil,
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sync - user needs to be created - disabled signup",
|
||||
fields: fields{
|
||||
userService: userService,
|
||||
authInfoService: authFakeNil,
|
||||
quotaService: "atest.FakeQuotaService{},
|
||||
log: log.NewNopLogger(),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientParams: &authn.ClientParams{
|
||||
SyncUser: true,
|
||||
AllowSignUp: false,
|
||||
EnableDisabledUsers: false,
|
||||
},
|
||||
id: &authn.Identity{
|
||||
ID: "",
|
||||
Login: "test",
|
||||
Name: "test",
|
||||
Email: "test",
|
||||
AuthModule: "oauth",
|
||||
AuthID: "2032",
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: nil,
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "sync - user needs to be created - enabled signup",
|
||||
fields: fields{
|
||||
userService: userServiceNil,
|
||||
authInfoService: authFakeNil,
|
||||
quotaService: "atest.FakeQuotaService{},
|
||||
log: log.NewNopLogger(),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientParams: &authn.ClientParams{
|
||||
SyncUser: true,
|
||||
AllowSignUp: true,
|
||||
EnableDisabledUsers: true,
|
||||
},
|
||||
id: &authn.Identity{
|
||||
ID: "",
|
||||
Login: "test_create",
|
||||
Name: "test_create",
|
||||
IsGrafanaAdmin: ptrBool(true),
|
||||
Email: "test_create",
|
||||
AuthModule: "oauth",
|
||||
AuthID: "2032",
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: ptrString("test_create"),
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
wantID: &authn.Identity{
|
||||
ID: "user:2",
|
||||
Login: "test_create",
|
||||
Name: "test_create",
|
||||
Email: "test_create",
|
||||
AuthModule: "oauth",
|
||||
AuthID: "2032",
|
||||
IsGrafanaAdmin: ptrBool(true),
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: nil,
|
||||
Email: ptrString("test_create"),
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "sync - needs full update",
|
||||
fields: fields{
|
||||
userService: userServiceMod,
|
||||
authInfoService: authFakeNil,
|
||||
quotaService: "atest.FakeQuotaService{},
|
||||
log: log.NewNopLogger(),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
clientParams: &authn.ClientParams{
|
||||
SyncUser: true,
|
||||
AllowSignUp: false,
|
||||
EnableDisabledUsers: true,
|
||||
},
|
||||
id: &authn.Identity{
|
||||
ID: "",
|
||||
Login: "test_mod",
|
||||
Name: "test_mod",
|
||||
Email: "test_mod",
|
||||
IsDisabled: false,
|
||||
IsGrafanaAdmin: ptrBool(true),
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: ptrInt64(3),
|
||||
Email: nil,
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
wantID: &authn.Identity{
|
||||
ID: "user:3",
|
||||
Login: "test_mod",
|
||||
Name: "test_mod",
|
||||
Email: "test_mod",
|
||||
IsDisabled: false,
|
||||
IsGrafanaAdmin: ptrBool(true),
|
||||
LookUpParams: models.UserLookupParams{
|
||||
UserID: ptrInt64(3),
|
||||
Email: nil,
|
||||
Login: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &UserSync{
|
||||
userService: tt.fields.userService,
|
||||
authInfoService: tt.fields.authInfoService,
|
||||
quotaService: tt.fields.quotaService,
|
||||
log: tt.fields.log,
|
||||
}
|
||||
err := s.SyncUser(tt.args.ctx, tt.args.clientParams, tt.args.id)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
require.EqualValues(t, tt.wantID, tt.args.id)
|
||||
})
|
||||
}
|
||||
}
|
@ -22,6 +22,10 @@ func (f *FakeClient) Authenticate(ctx context.Context, r *authn.Request) (*authn
|
||||
return f.ExpectedIdentity, f.ExpectedErr
|
||||
}
|
||||
|
||||
func (f *FakeClient) ClientParams() *authn.ClientParams {
|
||||
return &authn.ClientParams{}
|
||||
}
|
||||
|
||||
func (f *FakeClient) Test(ctx context.Context, r *authn.Request) bool {
|
||||
return f.ExpectedTest
|
||||
}
|
||||
|
@ -39,6 +39,10 @@ func (a *Anonymous) Authenticate(ctx context.Context, r *authn.Request) (*authn.
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Anonymous) ClientParams() *authn.ClientParams {
|
||||
return &authn.ClientParams{}
|
||||
}
|
||||
|
||||
func (a *Anonymous) Test(ctx context.Context, r *authn.Request) bool {
|
||||
// If anonymous client is register it can always be used for authentication
|
||||
return true
|
||||
|
@ -46,6 +46,14 @@ type APIKey struct {
|
||||
apiKeyService apikey.Service
|
||||
}
|
||||
|
||||
func (s *APIKey) ClientParams() *authn.ClientParams {
|
||||
return &authn.ClientParams{
|
||||
SyncUser: false,
|
||||
AllowSignUp: false,
|
||||
EnableDisabledUsers: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKey) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) {
|
||||
apiKey, err := s.getAPIKey(ctx, getTokenFromRequest(r))
|
||||
if err != nil {
|
||||
|
@ -75,11 +75,12 @@ func TestAPIKey_Authenticate(t *testing.T) {
|
||||
Name: "test",
|
||||
},
|
||||
expectedIdentity: &authn.Identity{
|
||||
ID: "service-account:1",
|
||||
OrgID: 1,
|
||||
OrgCount: 1,
|
||||
Name: "test",
|
||||
OrgRoles: map[int64]org.RoleType{1: org.RoleViewer},
|
||||
ID: "service-account:1",
|
||||
OrgID: 1,
|
||||
OrgCount: 1,
|
||||
Name: "test",
|
||||
OrgRoles: map[int64]org.RoleType{1: org.RoleViewer},
|
||||
IsGrafanaAdmin: boolPtr(false),
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -126,7 +127,7 @@ func TestAPIKey_Authenticate(t *testing.T) {
|
||||
assert.ErrorIs(t, err, tt.expectedErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, *tt.expectedIdentity, *identity)
|
||||
assert.EqualValues(t, *tt.expectedIdentity, *identity)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -193,6 +194,10 @@ func intPtr(n int64) *int64 {
|
||||
return &n
|
||||
}
|
||||
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
|
||||
func genApiKey(legacy bool) (string, string) {
|
||||
if legacy {
|
||||
res, _ := apikeygen.New(1, "test")
|
||||
|
@ -26,6 +26,9 @@ type AuthInfoServiceFake struct {
|
||||
ExpectedExternalUser *models.ExternalUserInfo
|
||||
ExpectedError error
|
||||
ExpectedLabels map[int64]string
|
||||
|
||||
SetAuthInfoFn func(ctx context.Context, cmd *models.SetAuthInfoCommand) error
|
||||
UpdateAuthInfoFn func(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error
|
||||
}
|
||||
|
||||
func (a *AuthInfoServiceFake) LookupAndUpdate(ctx context.Context, query *models.GetUserByAuthInfoQuery) (*user.User, error) {
|
||||
@ -48,10 +51,18 @@ func (a *AuthInfoServiceFake) GetUserLabels(ctx context.Context, query models.Ge
|
||||
}
|
||||
|
||||
func (a *AuthInfoServiceFake) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error {
|
||||
if a.SetAuthInfoFn != nil {
|
||||
return a.SetAuthInfoFn(ctx, cmd)
|
||||
}
|
||||
|
||||
return a.ExpectedError
|
||||
}
|
||||
|
||||
func (a *AuthInfoServiceFake) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error {
|
||||
if a.UpdateAuthInfoFn != nil {
|
||||
return a.UpdateAuthInfoFn(ctx, cmd)
|
||||
}
|
||||
|
||||
return a.ExpectedError
|
||||
}
|
||||
|
||||
|
@ -16,6 +16,7 @@ type FakeUserService struct {
|
||||
ExpectedUserProfileDTOs []*user.UserProfileDTO
|
||||
|
||||
GetSignedInUserFn func(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error)
|
||||
CreateFn func(ctx context.Context, cmd *user.CreateUserCommand) (*user.User, error)
|
||||
|
||||
counter int
|
||||
}
|
||||
@ -25,6 +26,10 @@ func NewUserServiceFake() *FakeUserService {
|
||||
}
|
||||
|
||||
func (f *FakeUserService) Create(ctx context.Context, cmd *user.CreateUserCommand) (*user.User, error) {
|
||||
if f.CreateFn != nil {
|
||||
return f.CreateFn(ctx, cmd)
|
||||
}
|
||||
|
||||
return f.ExpectedUser, f.ExpectedError
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user