diff --git a/pkg/api/api.go b/pkg/api/api.go index 9074cf3f37a..e89f4996567 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -145,19 +145,19 @@ func (hs *HTTPServer) registerRoutes() { // user (signed in) apiRoute.Group("/user", func(userRoute routing.RouteRegister) { userRoute.Get("/", routing.Wrap(hs.GetSignedInUser)) - userRoute.Put("/", routing.Wrap(UpdateSignedInUser)) - userRoute.Post("/using/:id", routing.Wrap(UserSetUsingOrg)) - userRoute.Get("/orgs", routing.Wrap(GetSignedInUserOrgList)) - userRoute.Get("/teams", routing.Wrap(GetSignedInUserTeamList)) + userRoute.Put("/", routing.Wrap(hs.UpdateSignedInUser)) + userRoute.Post("/using/:id", routing.Wrap(hs.UserSetUsingOrg)) + userRoute.Get("/orgs", routing.Wrap(hs.GetSignedInUserOrgList)) + userRoute.Get("/teams", routing.Wrap(hs.GetSignedInUserTeamList)) userRoute.Post("/stars/dashboard/:id", routing.Wrap(hs.StarDashboard)) userRoute.Delete("/stars/dashboard/:id", routing.Wrap(hs.UnstarDashboard)) - userRoute.Put("/password", routing.Wrap(ChangeUserPassword)) + userRoute.Put("/password", routing.Wrap(hs.ChangeUserPassword)) userRoute.Get("/quotas", routing.Wrap(GetUserQuotas)) - userRoute.Put("/helpflags/:id", routing.Wrap(SetHelpFlag)) + userRoute.Put("/helpflags/:id", routing.Wrap(hs.SetHelpFlag)) // For dev purpose - userRoute.Get("/helpflags/clear", routing.Wrap(ClearHelpFlags)) + userRoute.Get("/helpflags/clear", routing.Wrap(hs.ClearHelpFlags)) userRoute.Get("/preferences", routing.Wrap(hs.GetUserPreferences)) userRoute.Put("/preferences", routing.Wrap(hs.UpdateUserPreferences)) @@ -171,12 +171,12 @@ func (hs *HTTPServer) registerRoutes() { usersRoute.Get("/", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersRead, ac.ScopeGlobalUsersAll)), routing.Wrap(hs.searchUsersService.SearchUsers)) usersRoute.Get("/search", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersRead, ac.ScopeGlobalUsersAll)), routing.Wrap(hs.searchUsersService.SearchUsersWithPaging)) usersRoute.Get("/:id", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersRead, userIDScope)), routing.Wrap(hs.GetUserByID)) - usersRoute.Get("/:id/teams", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersTeamRead, userIDScope)), routing.Wrap(GetUserTeams)) - usersRoute.Get("/:id/orgs", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersRead, userIDScope)), routing.Wrap(GetUserOrgList)) + usersRoute.Get("/:id/teams", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersTeamRead, userIDScope)), routing.Wrap(hs.GetUserTeams)) + usersRoute.Get("/:id/orgs", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersRead, userIDScope)), routing.Wrap(hs.GetUserOrgList)) // query parameters /users/lookup?loginOrEmail=admin@example.com - usersRoute.Get("/lookup", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersRead, ac.ScopeGlobalUsersAll)), routing.Wrap(GetUserByLoginOrEmail)) - usersRoute.Put("/:id", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersWrite, userIDScope)), routing.Wrap(UpdateUser)) - usersRoute.Post("/:id/using/:orgId", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersWrite, userIDScope)), routing.Wrap(UpdateUserActiveOrg)) + usersRoute.Get("/lookup", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersRead, ac.ScopeGlobalUsersAll)), routing.Wrap(hs.GetUserByLoginOrEmail)) + usersRoute.Put("/:id", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersWrite, userIDScope)), routing.Wrap(hs.UpdateUser)) + usersRoute.Post("/:id/using/:orgId", authorize(reqGrafanaAdmin, ac.EvalPermission(ac.ActionUsersWrite, userIDScope)), routing.Wrap(hs.UpdateUserActiveOrg)) }) // team (admin permission required) diff --git a/pkg/api/http_server.go b/pkg/api/http_server.go index 87619a0ca43..3124178c74e 100644 --- a/pkg/api/http_server.go +++ b/pkg/api/http_server.go @@ -43,6 +43,7 @@ import ( "github.com/grafana/grafana/pkg/services/live" "github.com/grafana/grafana/pkg/services/live/pushhttp" "github.com/grafana/grafana/pkg/services/login" + "github.com/grafana/grafana/pkg/services/login/authinfoservice" "github.com/grafana/grafana/pkg/services/ngalert" "github.com/grafana/grafana/pkg/services/provisioning" "github.com/grafana/grafana/pkg/services/query" @@ -121,6 +122,7 @@ type HTTPServer struct { teamGuardian teamguardian.TeamGuardian queryDataService *query.Service serviceAccountsService serviceaccounts.Service + authInfoService authinfoservice.Service TeamPermissionsService *resourcepermissions.Service } @@ -147,7 +149,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi encryptionService encryption.Internal, updateChecker *updatechecker.Service, searchUsersService searchusers.Service, dataSourcesService *datasources.Service, secretsService secrets.Service, queryDataService *query.Service, teamGuardian teamguardian.TeamGuardian, serviceaccountsService serviceaccounts.Service, - resourcePermissionServices *resourceservices.ResourceServices) (*HTTPServer, error) { + authInfoService authinfoservice.Service, resourcePermissionServices *resourceservices.ResourceServices) (*HTTPServer, error) { web.Env = cfg.Env m := web.New() @@ -202,6 +204,7 @@ func ProvideHTTPServer(opts ServerOptions, cfg *setting.Cfg, routeRegister routi teamGuardian: teamGuardian, queryDataService: queryDataService, serviceAccountsService: serviceaccountsService, + authInfoService: authInfoService, TeamPermissionsService: resourcePermissionServices.GetTeamService(), } if hs.Listener != nil { diff --git a/pkg/api/user.go b/pkg/api/user.go index 1f6ee0b221c..fac5aa8acbf 100644 --- a/pkg/api/user.go +++ b/pkg/api/user.go @@ -9,7 +9,6 @@ import ( "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/response" - "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/setting" @@ -34,7 +33,7 @@ func (hs *HTTPServer) GetUserByID(c *models.ReqContext) response.Response { func (hs *HTTPServer) getUserUserProfile(c *models.ReqContext, userID int64) response.Response { query := models.GetUserProfileQuery{UserId: userID} - if err := bus.Dispatch(c.Req.Context(), &query); err != nil { + if err := hs.SQLStore.GetUserProfile(c.Req.Context(), &query); err != nil { if errors.Is(err, models.ErrUserNotFound) { return response.Error(404, models.ErrUserNotFound.Error(), nil) } @@ -43,7 +42,7 @@ func (hs *HTTPServer) getUserUserProfile(c *models.ReqContext, userID int64) res getAuthQuery := models.GetAuthInfoQuery{UserId: userID} query.Result.AuthLabels = []string{} - if err := bus.Dispatch(c.Req.Context(), &getAuthQuery); err == nil { + if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), &getAuthQuery); err == nil { authLabel := GetAuthProviderLabel(getAuthQuery.Result.AuthModule) query.Result.AuthLabels = append(query.Result.AuthLabels, authLabel) query.Result.IsExternal = true @@ -77,9 +76,9 @@ func (hs *HTTPServer) getGlobalUserAccessControlMetadata(c *models.ReqContext, u } // GET /api/users/lookup -func GetUserByLoginOrEmail(c *models.ReqContext) response.Response { +func (hs *HTTPServer) GetUserByLoginOrEmail(c *models.ReqContext) response.Response { query := models.GetUserByLoginQuery{LoginOrEmail: c.Query("loginOrEmail")} - if err := bus.Dispatch(c.Req.Context(), &query); err != nil { + if err := hs.SQLStore.GetUserByLogin(c.Req.Context(), &query); err != nil { if errors.Is(err, models.ErrUserNotFound) { return response.Error(404, models.ErrUserNotFound.Error(), nil) } @@ -101,7 +100,7 @@ func GetUserByLoginOrEmail(c *models.ReqContext) response.Response { } // POST /api/user -func UpdateSignedInUser(c *models.ReqContext) response.Response { +func (hs *HTTPServer) UpdateSignedInUser(c *models.ReqContext) response.Response { cmd := models.UpdateUserCommand{} if err := web.Bind(c.Req, &cmd); err != nil { return response.Error(http.StatusBadRequest, "bad request data", err) @@ -115,11 +114,11 @@ func UpdateSignedInUser(c *models.ReqContext) response.Response { } } cmd.UserId = c.UserId - return handleUpdateUser(c.Req.Context(), cmd) + return hs.handleUpdateUser(c.Req.Context(), cmd) } // POST /api/users/:id -func UpdateUser(c *models.ReqContext) response.Response { +func (hs *HTTPServer) UpdateUser(c *models.ReqContext) response.Response { cmd := models.UpdateUserCommand{} var err error if err := web.Bind(c.Req, &cmd); err != nil { @@ -129,11 +128,11 @@ func UpdateUser(c *models.ReqContext) response.Response { if err != nil { return response.Error(http.StatusBadRequest, "id is invalid", err) } - return handleUpdateUser(c.Req.Context(), cmd) + return hs.handleUpdateUser(c.Req.Context(), cmd) } // POST /api/users/:id/using/:orgId -func UpdateUserActiveOrg(c *models.ReqContext) response.Response { +func (hs *HTTPServer) UpdateUserActiveOrg(c *models.ReqContext) response.Response { userID, err := strconv.ParseInt(web.Params(c.Req)[":id"], 10, 64) if err != nil { return response.Error(http.StatusBadRequest, "id is invalid", err) @@ -143,20 +142,20 @@ func UpdateUserActiveOrg(c *models.ReqContext) response.Response { return response.Error(http.StatusBadRequest, "orgId is invalid", err) } - if !validateUsingOrg(c.Req.Context(), userID, orgID) { + if !hs.validateUsingOrg(c.Req.Context(), userID, orgID) { return response.Error(401, "Not a valid organization", nil) } cmd := models.SetUsingOrgCommand{UserId: userID, OrgId: orgID} - if err := bus.Dispatch(c.Req.Context(), &cmd); err != nil { + if err := hs.SQLStore.SetUsingOrg(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to change active organization", err) } return response.Success("Active organization changed") } -func handleUpdateUser(ctx context.Context, cmd models.UpdateUserCommand) response.Response { +func (hs *HTTPServer) handleUpdateUser(ctx context.Context, cmd models.UpdateUserCommand) response.Response { if len(cmd.Login) == 0 { cmd.Login = cmd.Email if len(cmd.Login) == 0 { @@ -164,7 +163,7 @@ func handleUpdateUser(ctx context.Context, cmd models.UpdateUserCommand) respons } } - if err := bus.Dispatch(ctx, &cmd); err != nil { + if err := hs.SQLStore.UpdateUser(ctx, &cmd); err != nil { return response.Error(500, "Failed to update user", err) } @@ -172,28 +171,28 @@ func handleUpdateUser(ctx context.Context, cmd models.UpdateUserCommand) respons } // GET /api/user/orgs -func GetSignedInUserOrgList(c *models.ReqContext) response.Response { - return getUserOrgList(c.Req.Context(), c.UserId) +func (hs *HTTPServer) GetSignedInUserOrgList(c *models.ReqContext) response.Response { + return hs.getUserOrgList(c.Req.Context(), c.UserId) } // GET /api/user/teams -func GetSignedInUserTeamList(c *models.ReqContext) response.Response { - return getUserTeamList(c.Req.Context(), c.OrgId, c.UserId) +func (hs *HTTPServer) GetSignedInUserTeamList(c *models.ReqContext) response.Response { + return hs.getUserTeamList(c.Req.Context(), c.OrgId, c.UserId) } // GET /api/users/:id/teams -func GetUserTeams(c *models.ReqContext) response.Response { +func (hs *HTTPServer) GetUserTeams(c *models.ReqContext) response.Response { id, err := strconv.ParseInt(web.Params(c.Req)[":id"], 10, 64) if err != nil { return response.Error(http.StatusBadRequest, "id is invalid", err) } - return getUserTeamList(c.Req.Context(), c.OrgId, id) + return hs.getUserTeamList(c.Req.Context(), c.OrgId, id) } -func getUserTeamList(ctx context.Context, orgID int64, userID int64) response.Response { +func (hs *HTTPServer) getUserTeamList(ctx context.Context, orgID int64, userID int64) response.Response { query := models.GetTeamsByUserQuery{OrgId: orgID, UserId: userID} - if err := bus.Dispatch(ctx, &query); err != nil { + if err := hs.SQLStore.GetTeamsByUser(ctx, &query); err != nil { return response.Error(500, "Failed to get user teams", err) } @@ -204,28 +203,28 @@ func getUserTeamList(ctx context.Context, orgID int64, userID int64) response.Re } // GET /api/users/:id/orgs -func GetUserOrgList(c *models.ReqContext) response.Response { +func (hs *HTTPServer) GetUserOrgList(c *models.ReqContext) response.Response { id, err := strconv.ParseInt(web.Params(c.Req)[":id"], 10, 64) if err != nil { return response.Error(http.StatusBadRequest, "id is invalid", err) } - return getUserOrgList(c.Req.Context(), id) + return hs.getUserOrgList(c.Req.Context(), id) } -func getUserOrgList(ctx context.Context, userID int64) response.Response { +func (hs *HTTPServer) getUserOrgList(ctx context.Context, userID int64) response.Response { query := models.GetUserOrgListQuery{UserId: userID} - if err := bus.Dispatch(ctx, &query); err != nil { + if err := hs.SQLStore.GetUserOrgList(ctx, &query); err != nil { return response.Error(500, "Failed to get user organizations", err) } return response.JSON(200, query.Result) } -func validateUsingOrg(ctx context.Context, userID int64, orgID int64) bool { +func (hs *HTTPServer) validateUsingOrg(ctx context.Context, userID int64, orgID int64) bool { query := models.GetUserOrgListQuery{UserId: userID} - if err := bus.Dispatch(ctx, &query); err != nil { + if err := hs.SQLStore.GetUserOrgList(ctx, &query); err != nil { return false } @@ -241,19 +240,19 @@ func validateUsingOrg(ctx context.Context, userID int64, orgID int64) bool { } // POST /api/user/using/:id -func UserSetUsingOrg(c *models.ReqContext) response.Response { +func (hs *HTTPServer) UserSetUsingOrg(c *models.ReqContext) response.Response { orgID, err := strconv.ParseInt(web.Params(c.Req)[":id"], 10, 64) if err != nil { return response.Error(http.StatusBadRequest, "id is invalid", err) } - if !validateUsingOrg(c.Req.Context(), c.UserId, orgID) { + if !hs.validateUsingOrg(c.Req.Context(), c.UserId, orgID) { return response.Error(401, "Not a valid organization", nil) } cmd := models.SetUsingOrgCommand{UserId: c.UserId, OrgId: orgID} - if err := bus.Dispatch(c.Req.Context(), &cmd); err != nil { + if err := hs.SQLStore.SetUsingOrg(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to change active organization", err) } @@ -268,20 +267,20 @@ func (hs *HTTPServer) ChangeActiveOrgAndRedirectToHome(c *models.ReqContext) { return } - if !validateUsingOrg(c.Req.Context(), c.UserId, orgID) { + if !hs.validateUsingOrg(c.Req.Context(), c.UserId, orgID) { hs.NotFoundHandler(c) } cmd := models.SetUsingOrgCommand{UserId: c.UserId, OrgId: orgID} - if err := bus.Dispatch(c.Req.Context(), &cmd); err != nil { + if err := hs.SQLStore.SetUsingOrg(c.Req.Context(), &cmd); err != nil { hs.NotFoundHandler(c) } c.Redirect(hs.Cfg.AppSubURL + "/") } -func ChangeUserPassword(c *models.ReqContext) response.Response { +func (hs *HTTPServer) ChangeUserPassword(c *models.ReqContext) response.Response { cmd := models.ChangeUserPasswordCommand{} if err := web.Bind(c.Req, &cmd); err != nil { return response.Error(http.StatusBadRequest, "bad request data", err) @@ -292,7 +291,7 @@ func ChangeUserPassword(c *models.ReqContext) response.Response { userQuery := models.GetUserByIdQuery{Id: c.UserId} - if err := bus.Dispatch(c.Req.Context(), &userQuery); err != nil { + if err := hs.SQLStore.GetUserById(c.Req.Context(), &userQuery); err != nil { return response.Error(500, "Could not read user from database", err) } @@ -315,7 +314,7 @@ func ChangeUserPassword(c *models.ReqContext) response.Response { return response.Error(500, "Failed to encode password", err) } - if err := bus.Dispatch(c.Req.Context(), &cmd); err != nil { + if err := hs.SQLStore.ChangeUserPassword(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to change user password", err) } @@ -327,7 +326,7 @@ func redirectToChangePassword(c *models.ReqContext) { c.Redirect("/profile/password", 302) } -func SetHelpFlag(c *models.ReqContext) response.Response { +func (hs *HTTPServer) SetHelpFlag(c *models.ReqContext) response.Response { flag, err := strconv.ParseInt(web.Params(c.Req)[":id"], 10, 64) if err != nil { return response.Error(http.StatusBadRequest, "id is invalid", err) @@ -341,20 +340,20 @@ func SetHelpFlag(c *models.ReqContext) response.Response { HelpFlags1: *bitmask, } - if err := bus.Dispatch(c.Req.Context(), &cmd); err != nil { + if err := hs.SQLStore.SetUserHelpFlag(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to update help flag", err) } return response.JSON(200, &util.DynMap{"message": "Help flag set", "helpFlags1": cmd.HelpFlags1}) } -func ClearHelpFlags(c *models.ReqContext) response.Response { +func (hs *HTTPServer) ClearHelpFlags(c *models.ReqContext) response.Response { cmd := models.SetUserHelpFlagCommand{ UserId: c.UserId, HelpFlags1: models.HelpFlags1(0), } - if err := bus.Dispatch(c.Req.Context(), &cmd); err != nil { + if err := hs.SQLStore.SetUserHelpFlag(c.Req.Context(), &cmd); err != nil { return response.Error(500, "Failed to update help flag", err) } diff --git a/pkg/api/user_test.go b/pkg/api/user_test.go index bba0c577ae1..e9aa77d46c3 100644 --- a/pkg/api/user_test.go +++ b/pkg/api/user_test.go @@ -2,18 +2,23 @@ package api import ( "context" + "encoding/json" "fmt" "net/http" "testing" "time" + "github.com/grafana/grafana/pkg/api/dtos" + "github.com/grafana/grafana/pkg/services/login/authinfoservice" "github.com/grafana/grafana/pkg/services/searchusers/filters" + "github.com/grafana/grafana/pkg/services/secrets/database" + secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager" "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/setting" + "golang.org/x/oauth2" "github.com/grafana/grafana/pkg/services/searchusers" - "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/models" @@ -23,10 +28,11 @@ import ( func TestUserAPIEndpoint_userLoggedIn(t *testing.T) { settings := setting.NewCfg() - hs := &HTTPServer{Cfg: settings} - sqlStore := sqlstore.InitTestDB(t) - hs.SQLStore = sqlStore + hs := &HTTPServer{ + Cfg: settings, + SQLStore: sqlStore, + } mockResult := models.SearchUserQueryResult{ Users: []*models.UserSearchHitDTO{ @@ -38,55 +44,62 @@ func TestUserAPIEndpoint_userLoggedIn(t *testing.T) { loggedInUserScenario(t, "When calling GET on", "api/users/1", "api/users/:id", func(sc *scenarioContext) { fakeNow := time.Date(2019, 2, 11, 17, 30, 40, 0, time.UTC) - bus.AddHandler("test", func(ctx context.Context, query *models.GetUserProfileQuery) error { - query.Result = models.UserProfileDTO{ - Id: int64(1), - Email: "daniel@grafana.com", - Name: "Daniel", - Login: "danlee", - OrgId: int64(2), - IsGrafanaAdmin: true, - IsDisabled: false, - IsExternal: false, - UpdatedAt: fakeNow, - CreatedAt: fakeNow, - } - return nil - }) + secretsService := secretsManager.SetupTestService(t, database.ProvideSecretsStore(sqlStore)) + srv := authinfoservice.ProvideAuthInfoService(bus.New(), sqlStore, &authinfoservice.OSSUserProtectionImpl{}, secretsService) + hs.authInfoService = srv - bus.AddHandler("test", func(ctx context.Context, query *models.GetAuthInfoQuery) error { - query.Result = &models.UserAuth{ - AuthModule: models.AuthModuleLDAP, - } - return nil - }) + createUserCmd := models.CreateUserCommand{ + Email: fmt.Sprint("user", "@test.com"), + Name: "user", + Login: "loginuser", + IsAdmin: true, + } + user, err := sqlStore.CreateUser(context.Background(), createUserCmd) + require.Nil(t, err) sc.handlerFunc = hs.GetUserByID - avatarUrl := dtos.GetGravatarUrl("daniel@grafana.com") - sc.fakeReqWithParams("GET", sc.url, map[string]string{"id": "1"}).exec() - expected := fmt.Sprintf(` - { - "id": 1, - "email": "daniel@grafana.com", - "name": "Daniel", - "login": "danlee", - "theme": "", - "orgId": 2, - "isGrafanaAdmin": true, - "isDisabled": false, - "isExternal": true, - "authLabels": [ - "LDAP" - ], - "avatarUrl": "%s", - "updatedAt": "2019-02-11T17:30:40Z", - "createdAt": "2019-02-11T17:30:40Z" - } - `, avatarUrl) + token := &oauth2.Token{ + AccessToken: "testaccess", + RefreshToken: "testrefresh", + Expiry: time.Now(), + TokenType: "Bearer", + } + idToken := "testidtoken" + token = token.WithExtra(map[string]interface{}{"id_token": idToken}) + query := &models.GetUserByAuthInfoQuery{Login: "loginuser", AuthModule: "test", AuthId: "test"} + cmd := &models.UpdateAuthInfoCommand{ + UserId: user.Id, + AuthId: query.AuthId, + AuthModule: query.AuthModule, + OAuthToken: token, + } + err = srv.UpdateAuthInfo(context.Background(), cmd) + require.NoError(t, err) + avatarUrl := dtos.GetGravatarUrl("@test.com") + sc.fakeReqWithParams("GET", sc.url, map[string]string{"id": fmt.Sprintf("%v", user.Id)}).exec() + expected := models.UserProfileDTO{ + Id: 1, + Email: "user@test.com", + Name: "user", + Login: "loginuser", + OrgId: 1, + IsGrafanaAdmin: true, + AuthLabels: []string{}, + CreatedAt: fakeNow, + UpdatedAt: fakeNow, + AvatarUrl: avatarUrl, + } + + var resp models.UserProfileDTO require.Equal(t, http.StatusOK, sc.resp.Code) - require.JSONEq(t, expected, sc.resp.Body.String()) + err = json.Unmarshal(sc.resp.Body.Bytes(), &resp) + require.NoError(t, err) + resp.CreatedAt = fakeNow + resp.UpdatedAt = fakeNow + resp.AvatarUrl = avatarUrl + require.EqualValues(t, expected, resp) }) loggedInUserScenario(t, "When calling GET on", "/api/users/lookup", "/api/users/lookup", func(sc *scenarioContext) { @@ -109,30 +122,25 @@ func TestUserAPIEndpoint_userLoggedIn(t *testing.T) { return nil }) + createUserCmd := models.CreateUserCommand{ + Email: fmt.Sprint("admin", "@test.com"), + Name: "admin", + Login: "admin", + IsAdmin: true, + } + _, err := sqlStore.CreateUser(context.Background(), createUserCmd) + require.Nil(t, err) - sc.handlerFunc = GetUserByLoginOrEmail - sc.fakeReqWithParams("GET", sc.url, map[string]string{"loginOrEmail": "danlee"}).exec() - - expected := ` - { - "id": 1, - "email": "daniel@grafana.com", - "name": "Daniel", - "login": "danlee", - "theme": "light", - "orgId": 2, - "isGrafanaAdmin": true, - "isDisabled": false, - "authLabels": null, - "isExternal": false, - "avatarUrl": "", - "updatedAt": "2019-02-11T17:30:40Z", - "createdAt": "2019-02-11T17:30:40Z" - } - ` + sc.handlerFunc = hs.GetUserByLoginOrEmail + sc.fakeReqWithParams("GET", sc.url, map[string]string{"loginOrEmail": "admin@test.com"}).exec() + var resp models.UserProfileDTO require.Equal(t, http.StatusOK, sc.resp.Code) - require.JSONEq(t, expected, sc.resp.Body.String()) + err = json.Unmarshal(sc.resp.Body.Bytes(), &resp) + require.NoError(t, err) + require.Equal(t, "admin", resp.Login) + require.Equal(t, "admin@test.com", resp.Email) + require.True(t, resp.IsGrafanaAdmin) }) loggedInUserScenario(t, "When calling GET on", "/api/users", "/api/users", func(sc *scenarioContext) { diff --git a/pkg/server/wire.go b/pkg/server/wire.go index 4a5b226196b..c998cc16edd 100644 --- a/pkg/server/wire.go +++ b/pkg/server/wire.go @@ -184,6 +184,7 @@ var wireBasicSet = wire.NewSet( wire.Bind(new(teamguardian.Store), new(*teamguardianDatabase.TeamGuardianStoreImpl)), teamguardianManager.ProvideService, wire.Bind(new(teamguardian.TeamGuardian), new(*teamguardianManager.Service)), + wire.Bind(new(authinfoservice.Service), new(*authinfoservice.Implementation)), featuremgmt.ProvideManagerService, featuremgmt.ProvideToggles, resourceservices.ProvideResourceServices, diff --git a/pkg/services/login/authinfoservice/service.go b/pkg/services/login/authinfoservice/service.go index c5817b9b210..e1a13cea8c2 100644 --- a/pkg/services/login/authinfoservice/service.go +++ b/pkg/services/login/authinfoservice/service.go @@ -22,6 +22,10 @@ type Implementation struct { logger log.Logger } +type Service interface { + GetAuthInfo(ctx context.Context, query *models.GetAuthInfoQuery) error +} + func ProvideAuthInfoService(bus bus.Bus, store *sqlstore.SQLStore, userProtectionService login.UserProtectionService, secretsService secrets.Service) *Implementation { s := &Implementation{ diff --git a/pkg/services/sqlstore/dashboard_test.go b/pkg/services/sqlstore/dashboard_test.go index 07cff5f5d8a..96106172075 100644 --- a/pkg/services/sqlstore/dashboard_test.go +++ b/pkg/services/sqlstore/dashboard_test.go @@ -590,7 +590,7 @@ func createUser(t *testing.T, sqlStore *SQLStore, name string, role string, isAd currentUser, err := sqlStore.CreateUser(context.Background(), currentUserCmd) require.NoError(t, err) q1 := models.GetUserOrgListQuery{UserId: currentUser.Id} - err = GetUserOrgList(context.Background(), &q1) + err = sqlStore.GetUserOrgList(context.Background(), &q1) require.NoError(t, err) require.Equal(t, models.RoleType(role), q1.Result[0].Role) return *currentUser diff --git a/pkg/services/sqlstore/org_test.go b/pkg/services/sqlstore/org_test.go index 0cda5495fcb..2d3c5ffb103 100644 --- a/pkg/services/sqlstore/org_test.go +++ b/pkg/services/sqlstore/org_test.go @@ -87,9 +87,9 @@ func TestAccountDataAccess(t *testing.T) { q1 := models.GetUserOrgListQuery{UserId: ac1.Id} q2 := models.GetUserOrgListQuery{UserId: ac2.Id} - err = GetUserOrgList(context.Background(), &q1) + err = sqlStore.GetUserOrgList(context.Background(), &q1) require.NoError(t, err) - err = GetUserOrgList(context.Background(), &q2) + err = sqlStore.GetUserOrgList(context.Background(), &q2) require.NoError(t, err) require.Equal(t, q1.Result[0].OrgId, q2.Result[0].OrgId) @@ -209,7 +209,7 @@ func TestAccountDataAccess(t *testing.T) { t.Run("Can get user organizations", func(t *testing.T) { query := models.GetUserOrgListQuery{UserId: ac2.Id} - err := GetUserOrgList(context.Background(), &query) + err := sqlStore.GetUserOrgList(context.Background(), &query) require.NoError(t, err) require.Equal(t, len(query.Result), 2) diff --git a/pkg/services/sqlstore/user.go b/pkg/services/sqlstore/user.go index 5f3d7f35a0b..0e09553c1ef 100644 --- a/pkg/services/sqlstore/user.go +++ b/pkg/services/sqlstore/user.go @@ -20,7 +20,7 @@ import ( func (ss *SQLStore) addUserQueryAndCommandHandlers() { ss.Bus.AddHandler(ss.GetSignedInUserWithCacheCtx) - bus.AddHandler("sql", GetUserById) + bus.AddHandler("sql", ss.GetUserById) bus.AddHandler("sql", ss.UpdateUser) bus.AddHandler("sql", ss.ChangeUserPassword) bus.AddHandler("sql", ss.GetUserByLogin) @@ -29,7 +29,7 @@ func (ss *SQLStore) addUserQueryAndCommandHandlers() { bus.AddHandler("sql", ss.UpdateUserLastSeenAt) bus.AddHandler("sql", ss.GetUserProfile) bus.AddHandler("sql", SearchUsers) - bus.AddHandler("sql", GetUserOrgList) + bus.AddHandler("sql", ss.GetUserOrgList) bus.AddHandler("sql", DisableUser) bus.AddHandler("sql", ss.BatchDisableUsers) bus.AddHandler("sql", ss.DeleteUser) @@ -320,7 +320,7 @@ func (ss *SQLStore) CreateUser(ctx context.Context, cmd models.CreateUserCommand return user, err } -func GetUserById(ctx context.Context, query *models.GetUserByIdQuery) error { +func (ss *SQLStore) GetUserById(ctx context.Context, query *models.GetUserByIdQuery) error { return withDbSession(ctx, x, func(sess *DBSession) error { user := new(models.User) has, err := sess.ID(query.Id).Get(user) @@ -444,7 +444,7 @@ func (ss *SQLStore) UpdateUserLastSeenAt(ctx context.Context, cmd *models.Update func (ss *SQLStore) SetUsingOrg(ctx context.Context, cmd *models.SetUsingOrgCommand) error { getOrgsForUserCmd := &models.GetUserOrgListQuery{UserId: cmd.UserId} - if err := GetUserOrgList(ctx, getOrgsForUserCmd); err != nil { + if err := ss.GetUserOrgList(ctx, getOrgsForUserCmd); err != nil { return err } @@ -522,7 +522,7 @@ func (o byOrgName) Less(i, j int) bool { return o[i].Name < o[j].Name } -func GetUserOrgList(ctx context.Context, query *models.GetUserOrgListQuery) error { +func (ss *SQLStore) GetUserOrgList(ctx context.Context, query *models.GetUserOrgListQuery) error { query.Result = make([]*models.UserOrgDTO, 0) sess := x.Table("org_user") sess.Join("INNER", "org", "org_user.org_id=org.id") diff --git a/pkg/services/sqlstore/user_test.go b/pkg/services/sqlstore/user_test.go index 4a18372ec89..84f787afbbc 100644 --- a/pkg/services/sqlstore/user_test.go +++ b/pkg/services/sqlstore/user_test.go @@ -27,7 +27,7 @@ func TestUserDataAccess(t *testing.T) { require.NoError(t, err) query := models.GetUserByIdQuery{Id: user.Id} - err = GetUserById(context.Background(), &query) + err = ss.GetUserById(context.Background(), &query) require.Nil(t, err) require.Equal(t, query.Result.Email, "usertest@test.com") @@ -37,7 +37,7 @@ func TestUserDataAccess(t *testing.T) { require.False(t, query.Result.IsDisabled) query = models.GetUserByIdQuery{Id: user.Id} - err = GetUserById(context.Background(), &query) + err = ss.GetUserById(context.Background(), &query) require.Nil(t, err) require.Equal(t, query.Result.Email, "usertest@test.com") @@ -60,7 +60,7 @@ func TestUserDataAccess(t *testing.T) { require.Nil(t, err) query := models.GetUserByIdQuery{Id: user.Id} - err = GetUserById(context.Background(), &query) + err = ss.GetUserById(context.Background(), &query) require.Nil(t, err) require.Equal(t, query.Result.Email, "usertest@test.com") @@ -94,7 +94,7 @@ func TestUserDataAccess(t *testing.T) { require.Nil(t, err) query := models.GetUserByIdQuery{Id: user.Id} - err = GetUserById(context.Background(), &query) + err = ss.GetUserById(context.Background(), &query) require.Nil(t, err) require.Equal(t, query.Result.Email, "usertest@test.com") @@ -469,7 +469,7 @@ func TestUserDataAccess(t *testing.T) { require.Equal(t, updatePermsError, models.ErrLastGrafanaAdmin) query := models.GetUserByIdQuery{Id: user.Id} - getUserError := GetUserById(context.Background(), &query) + getUserError := ss.GetUserById(context.Background(), &query) require.Nil(t, getUserError) require.True(t, query.Result.IsAdmin)