Chore: Split get user by ID (#52442)

* Remove user from preferences, stars, orguser, team member

* Fix lint

* Add Delete user from org and dashboard acl

* Delete user from user auth

* Add DeleteUser to quota

* Add test files and adjust user auth store

* Rename package in wire for user auth

* Import Quota Service interface in other services

* do the same in tests

* fix lint tests

* Fix tests

* Add some tests

* Rename InsertUser and DeleteUser to InsertOrgUser and DeleteOrgUser

* Rename DeleteUser to DeleteByUser in quota

* changing a method name in few additional places

* Fix in other places

* Fix lint

* Fix tests

* Chore: Split Delete User method

* Add fakes for userauth

* Add mock for access control Delete User permossion, use interface

* Use interface for ream guardian

* Add simple fake for dashboard acl

* Add go routines, clean up, use interfaces

* fix lint

* Update pkg/services/user/userimpl/user_test.go

Co-authored-by: Sofia Papagiannaki <1632407+papagian@users.noreply.github.com>

* Update pkg/services/user/userimpl/user_test.go

Co-authored-by: Sofia Papagiannaki <1632407+papagian@users.noreply.github.com>

* Update pkg/services/user/userimpl/user_test.go

Co-authored-by: Sofia Papagiannaki <1632407+papagian@users.noreply.github.com>

* Split get user by ID

* Use new method in api

* Add tests

* Aplly emthod in auth info service

* Fix lint and some tests

* Fix get user by ID

* Fix lint
Remove unused fakes

* Use split get user id in admin users

* Use GetbyID in cli commands

* Clean up after merge

* Remove commented out code

* Clena up imports

* add back )

* Fix wire generation for runner after merge with main

Co-authored-by: Sofia Papagiannaki <1632407+papagian@users.noreply.github.com>
This commit is contained in:
idafurjes
2022-08-02 16:58:05 +02:00
committed by GitHub
parent 64488f6b90
commit fab6c38c95
28 changed files with 1182 additions and 1522 deletions

View File

@@ -109,13 +109,14 @@ func (hs *HTTPServer) AdminUpdateUserPassword(c *models.ReqContext) response.Res
return response.Error(400, "New password too short", nil) return response.Error(400, "New password too short", nil)
} }
userQuery := models.GetUserByIdQuery{Id: userID} userQuery := user.GetUserByIDQuery{ID: userID}
if err := hs.SQLStore.GetUserById(c.Req.Context(), &userQuery); err != nil { usr, err := hs.userService.GetByID(c.Req.Context(), &userQuery)
if err != nil {
return response.Error(500, "Could not read user from database", err) return response.Error(500, "Could not read user from database", err)
} }
passwordHashed, err := util.EncodePassword(form.Password, userQuery.Result.Salt) passwordHashed, err := util.EncodePassword(form.Password, usr.Salt)
if err != nil { if err != nil {
return response.Error(500, "Could not encode password", err) return response.Error(500, "Could not encode password", err)
} }

View File

@@ -4,6 +4,9 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/dtos"
"github.com/grafana/grafana/pkg/api/response" "github.com/grafana/grafana/pkg/api/response"
"github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/api/routing"
@@ -15,9 +18,8 @@ import (
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore" "github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/services/user/usertest"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
const ( const (
@@ -29,7 +31,7 @@ const (
func TestAdminAPIEndpoint(t *testing.T) { func TestAdminAPIEndpoint(t *testing.T) {
const role = models.ROLE_ADMIN const role = models.ROLE_ADMIN
userService := usertest.NewUserServiceFake()
t.Run("Given a server admin attempts to remove themselves as an admin", func(t *testing.T) { t.Run("Given a server admin attempts to remove themselves as an admin", func(t *testing.T) {
updateCmd := dtos.AdminUpdateUserPermissionsForm{ updateCmd := dtos.AdminUpdateUserPermissionsForm{
IsGrafanaAdmin: false, IsGrafanaAdmin: false,
@@ -45,46 +47,43 @@ func TestAdminAPIEndpoint(t *testing.T) {
}) })
t.Run("When a server admin attempts to logout himself from all devices", func(t *testing.T) { t.Run("When a server admin attempts to logout himself from all devices", func(t *testing.T) {
mock := mockstore.NewSQLStoreMock()
adminLogoutUserScenario(t, "Should not be allowed when calling POST on", adminLogoutUserScenario(t, "Should not be allowed when calling POST on",
"/api/admin/users/1/logout", "/api/admin/users/:id/logout", func(sc *scenarioContext) { "/api/admin/users/1/logout", "/api/admin/users/:id/logout", func(sc *scenarioContext) {
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
assert.Equal(t, 400, sc.resp.Code) assert.Equal(t, 400, sc.resp.Code)
}, mock) }, userService)
}) })
t.Run("When a server admin attempts to logout a non-existing user from all devices", func(t *testing.T) { t.Run("When a server admin attempts to logout a non-existing user from all devices", func(t *testing.T) {
mock := &mockstore.SQLStoreMock{ mockUserService := usertest.NewUserServiceFake()
ExpectedError: user.ErrUserNotFound, mockUserService.ExpectedError = user.ErrUserNotFound
}
adminLogoutUserScenario(t, "Should return not found when calling POST on", "/api/admin/users/200/logout", adminLogoutUserScenario(t, "Should return not found when calling POST on", "/api/admin/users/200/logout",
"/api/admin/users/:id/logout", func(sc *scenarioContext) { "/api/admin/users/:id/logout", func(sc *scenarioContext) {
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
assert.Equal(t, 404, sc.resp.Code) assert.Equal(t, 404, sc.resp.Code)
}, mock) }, mockUserService)
}) })
t.Run("When a server admin attempts to revoke an auth token for a non-existing user", func(t *testing.T) { t.Run("When a server admin attempts to revoke an auth token for a non-existing user", func(t *testing.T) {
cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2} cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2}
mock := &mockstore.SQLStoreMock{ mockUser := usertest.NewUserServiceFake()
ExpectedError: user.ErrUserNotFound, mockUser.ExpectedError = user.ErrUserNotFound
}
adminRevokeUserAuthTokenScenario(t, "Should return not found when calling POST on", adminRevokeUserAuthTokenScenario(t, "Should return not found when calling POST on",
"/api/admin/users/200/revoke-auth-token", "/api/admin/users/:id/revoke-auth-token", cmd, func(sc *scenarioContext) { "/api/admin/users/200/revoke-auth-token", "/api/admin/users/:id/revoke-auth-token", cmd, func(sc *scenarioContext) {
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
assert.Equal(t, 404, sc.resp.Code) assert.Equal(t, 404, sc.resp.Code)
}, mock) }, mockUser)
}) })
t.Run("When a server admin gets auth tokens for a non-existing user", func(t *testing.T) { t.Run("When a server admin gets auth tokens for a non-existing user", func(t *testing.T) {
mock := &mockstore.SQLStoreMock{ mockUserService := usertest.NewUserServiceFake()
ExpectedError: user.ErrUserNotFound, mockUserService.ExpectedError = user.ErrUserNotFound
}
adminGetUserAuthTokensScenario(t, "Should return not found when calling GET on", adminGetUserAuthTokensScenario(t, "Should return not found when calling GET on",
"/api/admin/users/200/auth-tokens", "/api/admin/users/:id/auth-tokens", func(sc *scenarioContext) { "/api/admin/users/200/auth-tokens", "/api/admin/users/:id/auth-tokens", func(sc *scenarioContext) {
sc.fakeReqWithParams("GET", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("GET", sc.url, map[string]string{}).exec()
assert.Equal(t, 404, sc.resp.Code) assert.Equal(t, 404, sc.resp.Code)
}, mock) }, mockUserService)
}) })
t.Run("When a server admin attempts to enable/disable a nonexistent user", func(t *testing.T) { t.Run("When a server admin attempts to enable/disable a nonexistent user", func(t *testing.T) {
@@ -154,17 +153,15 @@ func TestAdminAPIEndpoint(t *testing.T) {
adminDeleteUserScenario(t, "Should return user not found error", "/api/admin/users/42", adminDeleteUserScenario(t, "Should return user not found error", "/api/admin/users/42",
"/api/admin/users/:id", func(sc *scenarioContext) { "/api/admin/users/:id", func(sc *scenarioContext) {
sc.sqlStore.(*mockstore.SQLStoreMock).ExpectedError = user.ErrUserNotFound sc.sqlStore.(*mockstore.SQLStoreMock).ExpectedError = user.ErrUserNotFound
sc.userService.(*usertest.FakeUserService).ExpectedError = user.ErrUserNotFound
sc.authInfoService.ExpectedError = user.ErrUserNotFound sc.authInfoService.ExpectedError = user.ErrUserNotFound
sc.fakeReqWithParams("DELETE", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("DELETE", sc.url, map[string]string{}).exec()
userID := sc.sqlStore.(*mockstore.SQLStoreMock).LatestUserId
assert.Equal(t, 404, sc.resp.Code) assert.Equal(t, 404, sc.resp.Code)
respJSON, err := simplejson.NewJson(sc.resp.Body.Bytes()) respJSON, err := simplejson.NewJson(sc.resp.Body.Bytes())
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "user not found", respJSON.Get("message").MustString()) assert.Equal(t, "user not found", respJSON.Get("message").MustString())
assert.Equal(t, int64(42), userID)
}) })
}) })
@@ -266,11 +263,11 @@ func putAdminScenario(t *testing.T, desc string, url string, routePattern string
}) })
} }
func adminLogoutUserScenario(t *testing.T, desc string, url string, routePattern string, fn scenarioFunc, sqlStore sqlstore.Store) { func adminLogoutUserScenario(t *testing.T, desc string, url string, routePattern string, fn scenarioFunc, userService *usertest.FakeUserService) {
t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) {
hs := HTTPServer{ hs := HTTPServer{
AuthTokenService: auth.NewFakeUserAuthTokenService(), AuthTokenService: auth.NewFakeUserAuthTokenService(),
SQLStore: sqlStore, userService: userService,
} }
sc := setupScenarioContext(t, url) sc := setupScenarioContext(t, url)
@@ -291,13 +288,13 @@ func adminLogoutUserScenario(t *testing.T, desc string, url string, routePattern
}) })
} }
func adminRevokeUserAuthTokenScenario(t *testing.T, desc string, url string, routePattern string, cmd models.RevokeAuthTokenCmd, fn scenarioFunc, sqlStore sqlstore.Store) { func adminRevokeUserAuthTokenScenario(t *testing.T, desc string, url string, routePattern string, cmd models.RevokeAuthTokenCmd, fn scenarioFunc, userService user.Service) {
t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) {
fakeAuthTokenService := auth.NewFakeUserAuthTokenService() fakeAuthTokenService := auth.NewFakeUserAuthTokenService()
hs := HTTPServer{ hs := HTTPServer{
AuthTokenService: fakeAuthTokenService, AuthTokenService: fakeAuthTokenService,
SQLStore: sqlStore, userService: userService,
} }
sc := setupScenarioContext(t, url) sc := setupScenarioContext(t, url)
@@ -319,13 +316,13 @@ func adminRevokeUserAuthTokenScenario(t *testing.T, desc string, url string, rou
}) })
} }
func adminGetUserAuthTokensScenario(t *testing.T, desc string, url string, routePattern string, fn scenarioFunc, sqlStore sqlstore.Store) { func adminGetUserAuthTokensScenario(t *testing.T, desc string, url string, routePattern string, fn scenarioFunc, userService *usertest.FakeUserService) {
t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) {
fakeAuthTokenService := auth.NewFakeUserAuthTokenService() fakeAuthTokenService := auth.NewFakeUserAuthTokenService()
hs := HTTPServer{ hs := HTTPServer{
AuthTokenService: fakeAuthTokenService, AuthTokenService: fakeAuthTokenService,
SQLStore: sqlStore, userService: userService,
} }
sc := setupScenarioContext(t, url) sc := setupScenarioContext(t, url)
@@ -379,7 +376,8 @@ func adminDisableUserScenario(t *testing.T, desc string, action string, url stri
func adminDeleteUserScenario(t *testing.T, desc string, url string, routePattern string, fn scenarioFunc) { func adminDeleteUserScenario(t *testing.T, desc string, url string, routePattern string, fn scenarioFunc) {
hs := HTTPServer{ hs := HTTPServer{
SQLStore: mockstore.NewSQLStoreMock(), SQLStore: mockstore.NewSQLStoreMock(),
userService: usertest.NewUserServiceFake(),
} }
t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) {
sc := setupScenarioContext(t, url) sc := setupScenarioContext(t, url)
@@ -391,6 +389,7 @@ func adminDeleteUserScenario(t *testing.T, desc string, url string, routePattern
return hs.AdminDeleteUser(c) return hs.AdminDeleteUser(c)
}) })
sc.userService = hs.userService
sc.m.Delete(routePattern, sc.defaultHandler) sc.m.Delete(routePattern, sc.defaultHandler)

View File

@@ -43,6 +43,7 @@ import (
"github.com/grafana/grafana/pkg/services/searchusers" "github.com/grafana/grafana/pkg/services/searchusers"
"github.com/grafana/grafana/pkg/services/searchusers/filters" "github.com/grafana/grafana/pkg/services/searchusers/filters"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/web" "github.com/grafana/grafana/pkg/web"
"github.com/grafana/grafana/pkg/web/webtest" "github.com/grafana/grafana/pkg/web/webtest"
@@ -168,6 +169,7 @@ type scenarioContext struct {
sqlStore sqlstore.Store sqlStore sqlstore.Store
authInfoService *logintest.AuthInfoServiceFake authInfoService *logintest.AuthInfoServiceFake
dashboardVersionService dashver.Service dashboardVersionService dashver.Service
userService user.Service
} }
func (sc *scenarioContext) exec() { func (sc *scenarioContext) exec() {

View File

@@ -28,6 +28,7 @@ import (
"github.com/grafana/grafana/pkg/services/guardian" "github.com/grafana/grafana/pkg/services/guardian"
pref "github.com/grafana/grafana/pkg/services/preference" pref "github.com/grafana/grafana/pkg/services/preference"
"github.com/grafana/grafana/pkg/services/star" "github.com/grafana/grafana/pkg/services/star"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
"github.com/grafana/grafana/pkg/web" "github.com/grafana/grafana/pkg/web"
) )
@@ -250,12 +251,12 @@ func (hs *HTTPServer) getAnnotationPermissionsByScope(c *models.ReqContext, acti
} }
func (hs *HTTPServer) getUserLogin(ctx context.Context, userID int64) string { func (hs *HTTPServer) getUserLogin(ctx context.Context, userID int64) string {
query := models.GetUserByIdQuery{Id: userID} query := user.GetUserByIDQuery{ID: userID}
err := hs.SQLStore.GetUserById(ctx, &query) user, err := hs.userService.GetByID(ctx, &query)
if err != nil { if err != nil {
return anonString return anonString
} }
return query.Result.Login return user.Login
} }
func (hs *HTTPServer) getDashboardHelper(ctx context.Context, orgID int64, id int64, uid string) (*models.Dashboard, response.Response) { func (hs *HTTPServer) getDashboardHelper(ctx context.Context, orgID int64, id int64, uid string) (*models.Dashboard, response.Response) {

View File

@@ -209,9 +209,10 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon
return response.Error(http.StatusBadRequest, "id is invalid", err) return response.Error(http.StatusBadRequest, "id is invalid", err)
} }
query := models.GetUserByIdQuery{Id: userId} query := user.GetUserByIDQuery{ID: userId}
if err := hs.SQLStore.GetUserById(c.Req.Context(), &query); err != nil { // validate the userId exists usr, err := hs.userService.GetByID(c.Req.Context(), &query)
if err != nil { // validate the userId exists
if errors.Is(err, user.ErrUserNotFound) { if errors.Is(err, user.ErrUserNotFound) {
return response.Error(404, user.ErrUserNotFound.Error(), nil) return response.Error(404, user.ErrUserNotFound.Error(), nil)
} }
@@ -219,7 +220,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon
return response.Error(500, "Failed to get user", err) return response.Error(500, "Failed to get user", err)
} }
authModuleQuery := &models.GetAuthInfoQuery{UserId: query.Result.ID, AuthModule: models.AuthModuleLDAP} authModuleQuery := &models.GetAuthInfoQuery{UserId: usr.ID, AuthModule: models.AuthModuleLDAP}
if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), authModuleQuery); err != nil { // validate the userId comes from LDAP if err := hs.authInfoService.GetAuthInfo(c.Req.Context(), authModuleQuery); err != nil { // validate the userId comes from LDAP
if errors.Is(err, user.ErrUserNotFound) { if errors.Is(err, user.ErrUserNotFound) {
return response.Error(404, user.ErrUserNotFound.Error(), nil) return response.Error(404, user.ErrUserNotFound.Error(), nil)
@@ -229,17 +230,17 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon
} }
ldapServer := newLDAP(ldapConfig.Servers) ldapServer := newLDAP(ldapConfig.Servers)
user, _, err := ldapServer.User(query.Result.Login) userInfo, _, err := ldapServer.User(usr.Login)
if err != nil { if err != nil {
if errors.Is(err, multildap.ErrDidNotFindUser) { // User was not in the LDAP server - we need to take action: if errors.Is(err, multildap.ErrDidNotFindUser) { // User was not in the LDAP server - we need to take action:
if hs.Cfg.AdminUser == query.Result.Login { // User is *the* Grafana Admin. We cannot disable it. if hs.Cfg.AdminUser == usr.Login { // User is *the* Grafana Admin. We cannot disable it.
errMsg := fmt.Sprintf(`Refusing to sync grafana super admin "%s" - it would be disabled`, query.Result.Login) errMsg := fmt.Sprintf(`Refusing to sync grafana super admin "%s" - it would be disabled`, usr.Login)
ldapLogger.Error(errMsg) ldapLogger.Error(errMsg)
return response.Error(http.StatusBadRequest, errMsg, err) return response.Error(http.StatusBadRequest, errMsg, err)
} }
// Since the user was not in the LDAP server. Let's disable it. // Since the user was not in the LDAP server. Let's disable it.
err := hs.Login.DisableExternalUser(c.Req.Context(), query.Result.Login) err := hs.Login.DisableExternalUser(c.Req.Context(), usr.Login)
if err != nil { if err != nil {
return response.Error(http.StatusInternalServerError, "Failed to disable the user", err) return response.Error(http.StatusInternalServerError, "Failed to disable the user", err)
} }
@@ -258,10 +259,10 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) response.Respon
upsertCmd := &models.UpsertUserCommand{ upsertCmd := &models.UpsertUserCommand{
ReqContext: c, ReqContext: c,
ExternalUser: user, ExternalUser: userInfo,
SignupAllowed: hs.Cfg.LDAPAllowSignup, SignupAllowed: hs.Cfg.LDAPAllowSignup,
UserLookupParams: models.UserLookupParams{ UserLookupParams: models.UserLookupParams{
UserID: &query.Result.ID, // Upsert by ID only UserID: &usr.ID, // Upsert by ID only
Email: nil, Email: nil,
Login: nil, Login: nil,
}, },

View File

@@ -8,22 +8,22 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/stretchr/testify/assert"
"github.com/grafana/grafana/pkg/services/login/loginservice" "github.com/stretchr/testify/require"
"github.com/grafana/grafana/pkg/services/login/logintest"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/api/response" "github.com/grafana/grafana/pkg/api/response"
"github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/ldap" "github.com/grafana/grafana/pkg/services/ldap"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/login/logintest"
"github.com/grafana/grafana/pkg/services/multildap" "github.com/grafana/grafana/pkg/services/multildap"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/services/user/usertest"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
type LDAPMock struct { type LDAPMock struct {
@@ -363,7 +363,7 @@ func TestGetLDAPStatusAPIEndpoint(t *testing.T) {
// PostSyncUserWithLDAP tests // PostSyncUserWithLDAP tests
// *** // ***
func postSyncUserWithLDAPContext(t *testing.T, requestURL string, preHook func(*testing.T, *scenarioContext), sqlstoremock sqlstore.Store) *scenarioContext { func postSyncUserWithLDAPContext(t *testing.T, requestURL string, preHook func(*testing.T, *scenarioContext), userService user.Service) *scenarioContext {
t.Helper() t.Helper()
sc := setupScenarioContext(t, requestURL) sc := setupScenarioContext(t, requestURL)
@@ -378,9 +378,9 @@ func postSyncUserWithLDAPContext(t *testing.T, requestURL string, preHook func(*
hs := &HTTPServer{ hs := &HTTPServer{
Cfg: sc.cfg, Cfg: sc.cfg,
AuthTokenService: auth.NewFakeUserAuthTokenService(), AuthTokenService: auth.NewFakeUserAuthTokenService(),
SQLStore: sqlstoremock,
Login: loginservice.LoginServiceMock{}, Login: loginservice.LoginServiceMock{},
authInfoService: sc.authInfoService, authInfoService: sc.authInfoService,
userService: userService,
} }
sc.defaultHandler = routing.Wrap(func(c *models.ReqContext) response.Response { sc.defaultHandler = routing.Wrap(func(c *models.ReqContext) response.Response {
@@ -403,8 +403,8 @@ func postSyncUserWithLDAPContext(t *testing.T, requestURL string, preHook func(*
} }
func TestPostSyncUserWithLDAPAPIEndpoint_Success(t *testing.T) { func TestPostSyncUserWithLDAPAPIEndpoint_Success(t *testing.T) {
sqlstoremock := mockstore.SQLStoreMock{} userServiceMock := usertest.NewUserServiceFake()
sqlstoremock.ExpectedUser = &user.User{Login: "ldap-daniel", ID: 34} userServiceMock.ExpectedUser = &user.User{Login: "ldap-daniel", ID: 34}
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T, sc *scenarioContext) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T, sc *scenarioContext) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
@@ -417,7 +417,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_Success(t *testing.T) {
userSearchResult = &models.ExternalUserInfo{ userSearchResult = &models.ExternalUserInfo{
Login: "ldap-daniel", Login: "ldap-daniel",
} }
}, &sqlstoremock) }, userServiceMock)
assert.Equal(t, http.StatusOK, sc.resp.Code) assert.Equal(t, http.StatusOK, sc.resp.Code)
@@ -431,7 +431,8 @@ func TestPostSyncUserWithLDAPAPIEndpoint_Success(t *testing.T) {
} }
func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotFound(t *testing.T) { func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotFound(t *testing.T) {
sqlstoremock := mockstore.SQLStoreMock{ExpectedError: user.ErrUserNotFound} userServiceMock := usertest.NewUserServiceFake()
userServiceMock.ExpectedError = user.ErrUserNotFound
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T, sc *scenarioContext) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T, sc *scenarioContext) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
@@ -440,7 +441,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotFound(t *testing.T) {
newLDAP = func(_ []*ldap.ServerConfig) multildap.IMultiLDAP { newLDAP = func(_ []*ldap.ServerConfig) multildap.IMultiLDAP {
return &LDAPMock{} return &LDAPMock{}
} }
}, &sqlstoremock) }, userServiceMock)
assert.Equal(t, http.StatusNotFound, sc.resp.Code) assert.Equal(t, http.StatusNotFound, sc.resp.Code)
@@ -454,7 +455,8 @@ func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotFound(t *testing.T) {
} }
func TestPostSyncUserWithLDAPAPIEndpoint_WhenGrafanaAdmin(t *testing.T) { func TestPostSyncUserWithLDAPAPIEndpoint_WhenGrafanaAdmin(t *testing.T) {
sqlstoremock := mockstore.SQLStoreMock{ExpectedUser: &user.User{Login: "ldap-daniel", ID: 34}} userServiceMock := usertest.NewUserServiceFake()
userServiceMock.ExpectedUser = &user.User{Login: "ldap-daniel", ID: 34}
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T, sc *scenarioContext) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T, sc *scenarioContext) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil return &ldap.Config{}, nil
@@ -467,7 +469,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_WhenGrafanaAdmin(t *testing.T) {
userSearchError = multildap.ErrDidNotFindUser userSearchError = multildap.ErrDidNotFindUser
sc.cfg.AdminUser = "ldap-daniel" sc.cfg.AdminUser = "ldap-daniel"
}, &sqlstoremock) }, userServiceMock)
assert.Equal(t, http.StatusBadRequest, sc.resp.Code) assert.Equal(t, http.StatusBadRequest, sc.resp.Code)
var res map[string]interface{} var res map[string]interface{}
@@ -478,7 +480,8 @@ func TestPostSyncUserWithLDAPAPIEndpoint_WhenGrafanaAdmin(t *testing.T) {
} }
func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotInLDAP(t *testing.T) { func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotInLDAP(t *testing.T) {
sqlstoremock := mockstore.SQLStoreMock{ExpectedUser: &user.User{Login: "ldap-daniel", ID: 34}} userServiceMock := usertest.NewUserServiceFake()
userServiceMock.ExpectedUser = &user.User{Login: "ldap-daniel", ID: 34}
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T, sc *scenarioContext) { sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T, sc *scenarioContext) {
sc.authInfoService.ExpectedExternalUser = &models.ExternalUserInfo{IsDisabled: true, UserId: 34} sc.authInfoService.ExpectedExternalUser = &models.ExternalUserInfo{IsDisabled: true, UserId: 34}
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) { getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
@@ -491,7 +494,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotInLDAP(t *testing.T) {
userSearchResult = nil userSearchResult = nil
userSearchError = multildap.ErrDidNotFindUser userSearchError = multildap.ErrDidNotFindUser
}, &sqlstoremock) }, userServiceMock)
assert.Equal(t, http.StatusBadRequest, sc.resp.Code) assert.Equal(t, http.StatusBadRequest, sc.resp.Code)
@@ -603,6 +606,7 @@ func TestLDAP_AccessControl(t *testing.T) {
cfg.LDAPEnabled = true cfg.LDAPEnabled = true
sc, hs := setupAccessControlScenarioContext(t, cfg, test.url, test.permissions) sc, hs := setupAccessControlScenarioContext(t, cfg, test.url, test.permissions)
hs.SQLStore = &mockstore.SQLStoreMock{ExpectedUser: &user.User{}} hs.SQLStore = &mockstore.SQLStoreMock{ExpectedUser: &user.User{}}
hs.userService = &usertest.FakeUserService{ExpectedUser: &user.User{}}
hs.authInfoService = &logintest.AuthInfoServiceFake{} hs.authInfoService = &logintest.AuthInfoServiceFake{}
hs.Login = &loginservice.LoginServiceMock{} hs.Login = &loginservice.LoginServiceMock{}
sc.resp = httptest.NewRecorder() sc.resp = httptest.NewRecorder()

View File

@@ -387,17 +387,18 @@ func (hs *HTTPServer) ChangeUserPassword(c *models.ReqContext) response.Response
return response.Error(400, "Not allowed to change password when LDAP or Auth Proxy is enabled", nil) return response.Error(400, "Not allowed to change password when LDAP or Auth Proxy is enabled", nil)
} }
userQuery := models.GetUserByIdQuery{Id: c.UserId} userQuery := user.GetUserByIDQuery{ID: c.UserId}
if err := hs.SQLStore.GetUserById(c.Req.Context(), &userQuery); err != nil { user, err := hs.userService.GetByID(c.Req.Context(), &userQuery)
if err != nil {
return response.Error(500, "Could not read user from database", err) return response.Error(500, "Could not read user from database", err)
} }
passwordHashed, err := util.EncodePassword(cmd.OldPassword, userQuery.Result.Salt) passwordHashed, err := util.EncodePassword(cmd.OldPassword, user.Salt)
if err != nil { if err != nil {
return response.Error(500, "Failed to encode password", err) return response.Error(500, "Failed to encode password", err)
} }
if passwordHashed != userQuery.Result.Password { if passwordHashed != user.Password {
return response.Error(401, "Invalid old password", nil) return response.Error(401, "Invalid old password", nil)
} }
@@ -407,7 +408,7 @@ func (hs *HTTPServer) ChangeUserPassword(c *models.ReqContext) response.Response
} }
cmd.UserId = c.UserId cmd.UserId = c.UserId
cmd.NewPassword, err = util.EncodePassword(cmd.NewPassword, userQuery.Result.Salt) cmd.NewPassword, err = util.EncodePassword(cmd.NewPassword, user.Salt)
if err != nil { if err != nil {
return response.Error(500, "Failed to encode password", err) return response.Error(500, "Failed to encode password", err)
} }

View File

@@ -26,6 +26,7 @@ import (
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore" "github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/services/user/usertest"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
) )
@@ -49,7 +50,7 @@ func TestUserAPIEndpoint_userLoggedIn(t *testing.T) {
loggedInUserScenario(t, "When calling GET on", "api/users/1", "api/users/:id", func(sc *scenarioContext) { loggedInUserScenario(t, "When calling GET on", "api/users/1", "api/users/:id", func(sc *scenarioContext) {
fakeNow := time.Date(2019, 2, 11, 17, 30, 40, 0, time.UTC) fakeNow := time.Date(2019, 2, 11, 17, 30, 40, 0, time.UTC)
secretsService := secretsManager.SetupTestService(t, database.ProvideSecretsStore(sqlStore)) secretsService := secretsManager.SetupTestService(t, database.ProvideSecretsStore(sqlStore))
authInfoStore := authinfostore.ProvideAuthInfoStore(sqlStore, secretsService) authInfoStore := authinfostore.ProvideAuthInfoStore(sqlStore, secretsService, usertest.NewUserServiceFake())
srv := authinfoservice.ProvideAuthInfoService( srv := authinfoservice.ProvideAuthInfoService(
&authinfoservice.OSSUserProtectionImpl{}, &authinfoservice.OSSUserProtectionImpl{},
authInfoStore, authInfoStore,

View File

@@ -51,16 +51,17 @@ func (hs *HTTPServer) RevokeUserAuthToken(c *models.ReqContext) response.Respons
} }
func (hs *HTTPServer) logoutUserFromAllDevicesInternal(ctx context.Context, userID int64) response.Response { func (hs *HTTPServer) logoutUserFromAllDevicesInternal(ctx context.Context, userID int64) response.Response {
userQuery := models.GetUserByIdQuery{Id: userID} userQuery := user.GetUserByIDQuery{ID: userID}
if err := hs.SQLStore.GetUserById(ctx, &userQuery); err != nil { _, err := hs.userService.GetByID(ctx, &userQuery)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) { if errors.Is(err, user.ErrUserNotFound) {
return response.Error(404, "User not found", err) return response.Error(404, "User not found", err)
} }
return response.Error(500, "Could not read user from database", err) return response.Error(500, "Could not read user from database", err)
} }
err := hs.AuthTokenService.RevokeAllUserTokens(ctx, userID) err = hs.AuthTokenService.RevokeAllUserTokens(ctx, userID)
if err != nil { if err != nil {
return response.Error(500, "Failed to logout user", err) return response.Error(500, "Failed to logout user", err)
} }
@@ -71,9 +72,10 @@ func (hs *HTTPServer) logoutUserFromAllDevicesInternal(ctx context.Context, user
} }
func (hs *HTTPServer) getUserAuthTokensInternal(c *models.ReqContext, userID int64) response.Response { func (hs *HTTPServer) getUserAuthTokensInternal(c *models.ReqContext, userID int64) response.Response {
userQuery := models.GetUserByIdQuery{Id: userID} userQuery := user.GetUserByIDQuery{ID: userID}
if err := hs.SQLStore.GetUserById(c.Req.Context(), &userQuery); err != nil { _, err := hs.userService.GetByID(c.Req.Context(), &userQuery)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) { if errors.Is(err, user.ErrUserNotFound) {
return response.Error(http.StatusNotFound, "User not found", err) return response.Error(http.StatusNotFound, "User not found", err)
} else if errors.Is(err, user.ErrCaseInsensitive) { } else if errors.Is(err, user.ErrCaseInsensitive) {
@@ -142,8 +144,9 @@ func (hs *HTTPServer) getUserAuthTokensInternal(c *models.ReqContext, userID int
} }
func (hs *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, userID int64, cmd models.RevokeAuthTokenCmd) response.Response { func (hs *HTTPServer) revokeUserAuthTokenInternal(c *models.ReqContext, userID int64, cmd models.RevokeAuthTokenCmd) response.Response {
userQuery := models.GetUserByIdQuery{Id: userID} userQuery := user.GetUserByIDQuery{ID: userID}
if err := hs.SQLStore.GetUserById(c.Req.Context(), &userQuery); err != nil { _, err := hs.userService.GetByID(c.Req.Context(), &userQuery)
if err != nil {
if errors.Is(err, user.ErrUserNotFound) { if errors.Is(err, user.ErrUserNotFound) {
return response.Error(404, "User not found", err) return response.Error(404, "User not found", err)
} }

View File

@@ -6,62 +6,61 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/grafana/grafana/pkg/api/response" "github.com/grafana/grafana/pkg/api/response"
"github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth" "github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/assert" "github.com/grafana/grafana/pkg/services/user/usertest"
) )
func TestUserTokenAPIEndpoint(t *testing.T) { func TestUserTokenAPIEndpoint(t *testing.T) {
mock := mockstore.NewSQLStoreMock() userMock := usertest.NewUserServiceFake()
t.Run("When current user attempts to revoke an auth token for a non-existing user", func(t *testing.T) { t.Run("When current user attempts to revoke an auth token for a non-existing user", func(t *testing.T) {
cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2} cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2}
mock.ExpectedError = user.ErrUserNotFound userMock.ExpectedError = user.ErrUserNotFound
revokeUserAuthTokenScenario(t, "Should return not found when calling POST on", "/api/user/revoke-auth-token", revokeUserAuthTokenScenario(t, "Should return not found when calling POST on", "/api/user/revoke-auth-token",
"/api/user/revoke-auth-token", cmd, 200, func(sc *scenarioContext) { "/api/user/revoke-auth-token", cmd, 200, func(sc *scenarioContext) {
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
assert.Equal(t, 404, sc.resp.Code) assert.Equal(t, 404, sc.resp.Code)
}, mock) }, userMock)
}) })
t.Run("When current user gets auth tokens for a non-existing user", func(t *testing.T) { t.Run("When current user gets auth tokens for a non-existing user", func(t *testing.T) {
mock := &mockstore.SQLStoreMock{ mockUser := &usertest.FakeUserService{
ExpectedUser: &user.User{ID: 200}, ExpectedUser: &user.User{ID: 200},
ExpectedError: user.ErrUserNotFound, ExpectedError: user.ErrUserNotFound,
} }
getUserAuthTokensScenario(t, "Should return not found when calling GET on", "/api/user/auth-tokens", "/api/user/auth-tokens", 200, func(sc *scenarioContext) { getUserAuthTokensScenario(t, "Should return not found when calling GET on", "/api/user/auth-tokens", "/api/user/auth-tokens", 200, func(sc *scenarioContext) {
sc.fakeReqWithParams("GET", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("GET", sc.url, map[string]string{}).exec()
assert.Equal(t, 404, sc.resp.Code) assert.Equal(t, 404, sc.resp.Code)
}, mock) }, mockUser)
}) })
t.Run("When logging out an existing user from all devices", func(t *testing.T) { t.Run("When logging out an existing user from all devices", func(t *testing.T) {
mock := &mockstore.SQLStoreMock{ userMock := &usertest.FakeUserService{
ExpectedUser: &user.User{ID: 200}, ExpectedUser: &user.User{ID: 200},
} }
logoutUserFromAllDevicesInternalScenario(t, "Should be successful", 1, func(sc *scenarioContext) { logoutUserFromAllDevicesInternalScenario(t, "Should be successful", 1, func(sc *scenarioContext) {
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
assert.Equal(t, 200, sc.resp.Code) assert.Equal(t, 200, sc.resp.Code)
}, mock) }, userMock)
}) })
t.Run("When logout a non-existing user from all devices", func(t *testing.T) { t.Run("When logout a non-existing user from all devices", func(t *testing.T) {
logoutUserFromAllDevicesInternalScenario(t, "Should return not found", testUserID, func(sc *scenarioContext) { logoutUserFromAllDevicesInternalScenario(t, "Should return not found", testUserID, func(sc *scenarioContext) {
mock.ExpectedError = user.ErrUserNotFound userMock.ExpectedError = user.ErrUserNotFound
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
assert.Equal(t, 404, sc.resp.Code) assert.Equal(t, 404, sc.resp.Code)
}, mock) }, userMock)
}) })
t.Run("When revoke an auth token for a user", func(t *testing.T) { t.Run("When revoke an auth token for a user", func(t *testing.T) {
cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2} cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2}
token := &models.UserToken{Id: 1} token := &models.UserToken{Id: 1}
mock := &mockstore.SQLStoreMock{ mockUser := &usertest.FakeUserService{
ExpectedUser: &user.User{ID: 200}, ExpectedUser: &user.User{ID: 200},
} }
@@ -71,25 +70,25 @@ func TestUserTokenAPIEndpoint(t *testing.T) {
} }
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
assert.Equal(t, 200, sc.resp.Code) assert.Equal(t, 200, sc.resp.Code)
}, mock) }, mockUser)
}) })
t.Run("When revoke the active auth token used by himself", func(t *testing.T) { t.Run("When revoke the active auth token used by himself", func(t *testing.T) {
cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2} cmd := models.RevokeAuthTokenCmd{AuthTokenId: 2}
token := &models.UserToken{Id: 2} token := &models.UserToken{Id: 2}
mock := mockstore.NewSQLStoreMock() mockUser := usertest.NewUserServiceFake()
revokeUserAuthTokenInternalScenario(t, "Should not be successful", cmd, testUserID, token, func(sc *scenarioContext) { revokeUserAuthTokenInternalScenario(t, "Should not be successful", cmd, testUserID, token, func(sc *scenarioContext) {
sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) { sc.userAuthTokenService.GetUserTokenProvider = func(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) {
return token, nil return token, nil
} }
sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec() sc.fakeReqWithParams("POST", sc.url, map[string]string{}).exec()
assert.Equal(t, 400, sc.resp.Code) assert.Equal(t, 400, sc.resp.Code)
}, mock) }, mockUser)
}) })
t.Run("When gets auth tokens for a user", func(t *testing.T) { t.Run("When gets auth tokens for a user", func(t *testing.T) {
currentToken := &models.UserToken{Id: 1} currentToken := &models.UserToken{Id: 1}
mock := mockstore.NewSQLStoreMock() mockUser := usertest.NewUserServiceFake()
getUserAuthTokensInternalScenario(t, "Should be successful", currentToken, func(sc *scenarioContext) { getUserAuthTokensInternalScenario(t, "Should be successful", currentToken, func(sc *scenarioContext) {
tokens := []*models.UserToken{ tokens := []*models.UserToken{
{ {
@@ -141,18 +140,18 @@ func TestUserTokenAPIEndpoint(t *testing.T) {
assert.Equal(t, "11.0", resultTwo.Get("browserVersion").MustString()) assert.Equal(t, "11.0", resultTwo.Get("browserVersion").MustString())
assert.Equal(t, "iOS", resultTwo.Get("os").MustString()) assert.Equal(t, "iOS", resultTwo.Get("os").MustString())
assert.Equal(t, "11.0", resultTwo.Get("osVersion").MustString()) assert.Equal(t, "11.0", resultTwo.Get("osVersion").MustString())
}, mock) }, mockUser)
}) })
} }
func revokeUserAuthTokenScenario(t *testing.T, desc string, url string, routePattern string, cmd models.RevokeAuthTokenCmd, func revokeUserAuthTokenScenario(t *testing.T, desc string, url string, routePattern string, cmd models.RevokeAuthTokenCmd,
userId int64, fn scenarioFunc, sqlStore sqlstore.Store) { userId int64, fn scenarioFunc, userService user.Service) {
t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) {
fakeAuthTokenService := auth.NewFakeUserAuthTokenService() fakeAuthTokenService := auth.NewFakeUserAuthTokenService()
hs := HTTPServer{ hs := HTTPServer{
AuthTokenService: fakeAuthTokenService, AuthTokenService: fakeAuthTokenService,
SQLStore: sqlStore, userService: userService,
} }
sc := setupScenarioContext(t, url) sc := setupScenarioContext(t, url)
@@ -173,13 +172,13 @@ func revokeUserAuthTokenScenario(t *testing.T, desc string, url string, routePat
}) })
} }
func getUserAuthTokensScenario(t *testing.T, desc string, url string, routePattern string, userId int64, fn scenarioFunc, sqlStore sqlstore.Store) { func getUserAuthTokensScenario(t *testing.T, desc string, url string, routePattern string, userId int64, fn scenarioFunc, userService user.Service) {
t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) { t.Run(fmt.Sprintf("%s %s", desc, url), func(t *testing.T) {
fakeAuthTokenService := auth.NewFakeUserAuthTokenService() fakeAuthTokenService := auth.NewFakeUserAuthTokenService()
hs := HTTPServer{ hs := HTTPServer{
AuthTokenService: fakeAuthTokenService, AuthTokenService: fakeAuthTokenService,
SQLStore: sqlStore, userService: userService,
} }
sc := setupScenarioContext(t, url) sc := setupScenarioContext(t, url)
@@ -199,11 +198,11 @@ func getUserAuthTokensScenario(t *testing.T, desc string, url string, routePatte
}) })
} }
func logoutUserFromAllDevicesInternalScenario(t *testing.T, desc string, userId int64, fn scenarioFunc, sqlStore sqlstore.Store) { func logoutUserFromAllDevicesInternalScenario(t *testing.T, desc string, userId int64, fn scenarioFunc, userService user.Service) {
t.Run(desc, func(t *testing.T) { t.Run(desc, func(t *testing.T) {
hs := HTTPServer{ hs := HTTPServer{
AuthTokenService: auth.NewFakeUserAuthTokenService(), AuthTokenService: auth.NewFakeUserAuthTokenService(),
SQLStore: sqlStore, userService: userService,
} }
sc := setupScenarioContext(t, "/") sc := setupScenarioContext(t, "/")
@@ -223,13 +222,13 @@ func logoutUserFromAllDevicesInternalScenario(t *testing.T, desc string, userId
} }
func revokeUserAuthTokenInternalScenario(t *testing.T, desc string, cmd models.RevokeAuthTokenCmd, userId int64, func revokeUserAuthTokenInternalScenario(t *testing.T, desc string, cmd models.RevokeAuthTokenCmd, userId int64,
token *models.UserToken, fn scenarioFunc, sqlStore sqlstore.Store) { token *models.UserToken, fn scenarioFunc, userService user.Service) {
t.Run(desc, func(t *testing.T) { t.Run(desc, func(t *testing.T) {
fakeAuthTokenService := auth.NewFakeUserAuthTokenService() fakeAuthTokenService := auth.NewFakeUserAuthTokenService()
hs := HTTPServer{ hs := HTTPServer{
AuthTokenService: fakeAuthTokenService, AuthTokenService: fakeAuthTokenService,
SQLStore: sqlStore, userService: userService,
} }
sc := setupScenarioContext(t, "/") sc := setupScenarioContext(t, "/")
@@ -248,13 +247,13 @@ func revokeUserAuthTokenInternalScenario(t *testing.T, desc string, cmd models.R
}) })
} }
func getUserAuthTokensInternalScenario(t *testing.T, desc string, token *models.UserToken, fn scenarioFunc, sqlStore sqlstore.Store) { func getUserAuthTokensInternalScenario(t *testing.T, desc string, token *models.UserToken, fn scenarioFunc, userService user.Service) {
t.Run(desc, func(t *testing.T) { t.Run(desc, func(t *testing.T) {
fakeAuthTokenService := auth.NewFakeUserAuthTokenService() fakeAuthTokenService := auth.NewFakeUserAuthTokenService()
hs := HTTPServer{ hs := HTTPServer{
AuthTokenService: fakeAuthTokenService, AuthTokenService: fakeAuthTokenService,
SQLStore: sqlStore, userService: userService,
} }
sc := setupScenarioContext(t, "/") sc := setupScenarioContext(t, "/")

View File

@@ -151,7 +151,7 @@ var adminCommands = []*cli.Command{
{ {
Name: "reset-admin-password", Name: "reset-admin-password",
Usage: "reset-admin-password <new password>", Usage: "reset-admin-password <new password>",
Action: runDbCommand(resetPasswordCommand), Action: runRunnerCommand(resetPasswordCommand),
Flags: []cli.Flag{ Flags: []cli.Flag{
&cli.BoolFlag{ &cli.BoolFlag{
Name: "password-from-stdin", Name: "password-from-stdin",

View File

@@ -8,15 +8,16 @@ import (
"github.com/fatih/color" "github.com/fatih/color"
"github.com/grafana/grafana/pkg/cmd/grafana-cli/logger" "github.com/grafana/grafana/pkg/cmd/grafana-cli/logger"
"github.com/grafana/grafana/pkg/cmd/grafana-cli/runner"
"github.com/grafana/grafana/pkg/cmd/grafana-cli/utils" "github.com/grafana/grafana/pkg/cmd/grafana-cli/utils"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
) )
const AdminUserId = 1 const AdminUserId = 1
func resetPasswordCommand(c utils.CommandLine, sqlStore *sqlstore.SQLStore) error { func resetPasswordCommand(c utils.CommandLine, runner runner.Runner) error {
newPassword := "" newPassword := ""
if c.Bool("password-from-stdin") { if c.Bool("password-from-stdin") {
@@ -39,13 +40,14 @@ func resetPasswordCommand(c utils.CommandLine, sqlStore *sqlstore.SQLStore) erro
return fmt.Errorf("new password is too short") return fmt.Errorf("new password is too short")
} }
userQuery := models.GetUserByIdQuery{Id: AdminUserId} userQuery := user.GetUserByIDQuery{ID: AdminUserId}
if err := sqlStore.GetUserById(context.Background(), &userQuery); err != nil { usr, err := runner.UserService.GetByID(context.Background(), &userQuery)
if err != nil {
return fmt.Errorf("could not read user from database. Error: %v", err) return fmt.Errorf("could not read user from database. Error: %v", err)
} }
passwordHashed, err := util.EncodePassword(newPassword, userQuery.Result.Salt) passwordHashed, err := util.EncodePassword(newPassword, usr.Salt)
if err != nil { if err != nil {
return err return err
} }
@@ -55,7 +57,7 @@ func resetPasswordCommand(c utils.CommandLine, sqlStore *sqlstore.SQLStore) erro
NewPassword: passwordHashed, NewPassword: passwordHashed,
} }
if err := sqlStore.ChangeUserPassword(context.Background(), &cmd); err != nil { if err := runner.SQLStore.ChangeUserPassword(context.Background(), &cmd); err != nil {
return fmt.Errorf("failed to update user password: %w", err) return fmt.Errorf("failed to update user password: %w", err)
} }

View File

@@ -6,6 +6,7 @@ import (
"github.com/grafana/grafana/pkg/services/secrets" "github.com/grafana/grafana/pkg/services/secrets"
"github.com/grafana/grafana/pkg/services/secrets/manager" "github.com/grafana/grafana/pkg/services/secrets/manager"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
) )
@@ -17,11 +18,13 @@ type Runner struct {
EncryptionService encryption.Internal EncryptionService encryption.Internal
SecretsService *manager.SecretsService SecretsService *manager.SecretsService
SecretsMigrator secrets.Migrator SecretsMigrator secrets.Migrator
UserService user.Service
} }
func New(cfg *setting.Cfg, sqlStore *sqlstore.SQLStore, settingsProvider setting.Provider, func New(cfg *setting.Cfg, sqlStore *sqlstore.SQLStore, settingsProvider setting.Provider,
encryptionService encryption.Internal, features featuremgmt.FeatureToggles, encryptionService encryption.Internal, features featuremgmt.FeatureToggles,
secretsService *manager.SecretsService, secretsMigrator secrets.Migrator, secretsService *manager.SecretsService, secretsMigrator secrets.Migrator,
userService user.Service,
) Runner { ) Runner {
return Runner{ return Runner{
Cfg: cfg, Cfg: cfg,
@@ -31,5 +34,6 @@ func New(cfg *setting.Cfg, sqlStore *sqlstore.SQLStore, settingsProvider setting
SecretsService: secretsService, SecretsService: secretsService,
SecretsMigrator: secretsMigrator, SecretsMigrator: secretsMigrator,
Features: features, Features: features,
UserService: userService,
} }
} }

View File

@@ -7,21 +7,132 @@ import (
"context" "context"
"github.com/google/wire" "github.com/google/wire"
sdkhttpclient "github.com/grafana/grafana-plugin-sdk-go/backend/httpclient"
"github.com/grafana/grafana/pkg/api"
"github.com/grafana/grafana/pkg/api/avatar"
"github.com/grafana/grafana/pkg/api/routing" "github.com/grafana/grafana/pkg/api/routing"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/cuectx"
"github.com/grafana/grafana/pkg/expr"
cmreg "github.com/grafana/grafana/pkg/framework/coremodel/registry"
"github.com/grafana/grafana/pkg/infra/httpclient"
"github.com/grafana/grafana/pkg/infra/httpclient/httpclientprovider"
"github.com/grafana/grafana/pkg/infra/kvstore"
"github.com/grafana/grafana/pkg/infra/localcache" "github.com/grafana/grafana/pkg/infra/localcache"
"github.com/grafana/grafana/pkg/infra/metrics"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/infra/serverlock"
"github.com/grafana/grafana/pkg/infra/tracing" "github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/infra/usagestats" "github.com/grafana/grafana/pkg/infra/usagestats"
uss "github.com/grafana/grafana/pkg/infra/usagestats/service"
"github.com/grafana/grafana/pkg/infra/usagestats/statscollector"
loginpkg "github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/middleware/csrf"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin/coreplugin"
"github.com/grafana/grafana/pkg/plugins/manager"
"github.com/grafana/grafana/pkg/plugins/manager/loader"
"github.com/grafana/grafana/pkg/plugins/manager/registry"
"github.com/grafana/grafana/pkg/plugins/plugincontext"
"github.com/grafana/grafana/pkg/services/accesscontrol"
"github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol"
"github.com/grafana/grafana/pkg/services/alerting"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/auth/jwt"
"github.com/grafana/grafana/pkg/services/cleanup"
"github.com/grafana/grafana/pkg/services/comments"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/dashboardimport"
dashboardimportservice "github.com/grafana/grafana/pkg/services/dashboardimport/service"
"github.com/grafana/grafana/pkg/services/dashboards"
dashboardstore "github.com/grafana/grafana/pkg/services/dashboards/database"
dashboardservice "github.com/grafana/grafana/pkg/services/dashboards/service"
"github.com/grafana/grafana/pkg/services/dashboardsnapshots"
dashsnapstore "github.com/grafana/grafana/pkg/services/dashboardsnapshots/database"
dashsnapsvc "github.com/grafana/grafana/pkg/services/dashboardsnapshots/service"
"github.com/grafana/grafana/pkg/services/dashboardversion/dashverimpl"
"github.com/grafana/grafana/pkg/services/datasourceproxy"
"github.com/grafana/grafana/pkg/services/datasources"
datasourceservice "github.com/grafana/grafana/pkg/services/datasources/service"
"github.com/grafana/grafana/pkg/services/encryption" "github.com/grafana/grafana/pkg/services/encryption"
encryptionservice "github.com/grafana/grafana/pkg/services/encryption/service" encryptionservice "github.com/grafana/grafana/pkg/services/encryption/service"
"github.com/grafana/grafana/pkg/services/export"
"github.com/grafana/grafana/pkg/services/featuremgmt" "github.com/grafana/grafana/pkg/services/featuremgmt"
"github.com/grafana/grafana/pkg/services/guardian"
"github.com/grafana/grafana/pkg/services/hooks" "github.com/grafana/grafana/pkg/services/hooks"
"github.com/grafana/grafana/pkg/services/libraryelements"
"github.com/grafana/grafana/pkg/services/librarypanels"
"github.com/grafana/grafana/pkg/services/live"
"github.com/grafana/grafana/pkg/services/live/pushhttp"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
authinfodatabase "github.com/grafana/grafana/pkg/services/login/authinfoservice/database"
"github.com/grafana/grafana/pkg/services/login/loginservice"
"github.com/grafana/grafana/pkg/services/ngalert"
ngmetrics "github.com/grafana/grafana/pkg/services/ngalert/metrics"
"github.com/grafana/grafana/pkg/services/notifications"
"github.com/grafana/grafana/pkg/services/oauthtoken"
"github.com/grafana/grafana/pkg/services/org/orgimpl"
"github.com/grafana/grafana/pkg/services/playlist/playlistimpl"
"github.com/grafana/grafana/pkg/services/plugindashboards"
plugindashboardsservice "github.com/grafana/grafana/pkg/services/plugindashboards/service"
"github.com/grafana/grafana/pkg/services/pluginsettings"
pluginSettings "github.com/grafana/grafana/pkg/services/pluginsettings/service"
"github.com/grafana/grafana/pkg/services/preference/prefimpl"
"github.com/grafana/grafana/pkg/services/publicdashboards"
publicdashboardsApi "github.com/grafana/grafana/pkg/services/publicdashboards/api"
publicdashboardsStore "github.com/grafana/grafana/pkg/services/publicdashboards/database"
publicdashboardsService "github.com/grafana/grafana/pkg/services/publicdashboards/service"
"github.com/grafana/grafana/pkg/services/query"
"github.com/grafana/grafana/pkg/services/queryhistory"
"github.com/grafana/grafana/pkg/services/quota/quotaimpl"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/search"
"github.com/grafana/grafana/pkg/services/searchV2"
"github.com/grafana/grafana/pkg/services/secrets" "github.com/grafana/grafana/pkg/services/secrets"
secretsDatabase "github.com/grafana/grafana/pkg/services/secrets/database" secretsDatabase "github.com/grafana/grafana/pkg/services/secrets/database"
secretsStore "github.com/grafana/grafana/pkg/services/secrets/kvstore"
secretsMigrations "github.com/grafana/grafana/pkg/services/secrets/kvstore/migrations"
secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager" secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager"
secretsMigrator "github.com/grafana/grafana/pkg/services/secrets/migrator" secretsMigrator "github.com/grafana/grafana/pkg/services/secrets/migrator"
"github.com/grafana/grafana/pkg/services/serviceaccounts"
"github.com/grafana/grafana/pkg/services/serviceaccounts/database"
serviceaccountsmanager "github.com/grafana/grafana/pkg/services/serviceaccounts/manager"
"github.com/grafana/grafana/pkg/services/shorturls"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/db"
"github.com/grafana/grafana/pkg/services/sqlstore/mockstore"
"github.com/grafana/grafana/pkg/services/star/starimpl"
"github.com/grafana/grafana/pkg/services/store"
"github.com/grafana/grafana/pkg/services/store/sanitizer"
"github.com/grafana/grafana/pkg/services/teamguardian"
teamguardianDatabase "github.com/grafana/grafana/pkg/services/teamguardian/database"
teamguardianManager "github.com/grafana/grafana/pkg/services/teamguardian/manager"
"github.com/grafana/grafana/pkg/services/thumbs"
"github.com/grafana/grafana/pkg/services/updatechecker"
"github.com/grafana/grafana/pkg/services/user/userimpl"
"github.com/grafana/grafana/pkg/services/userauth/userauthimpl"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/tsdb/azuremonitor"
"github.com/grafana/grafana/pkg/tsdb/cloudmonitoring"
"github.com/grafana/grafana/pkg/tsdb/cloudwatch"
"github.com/grafana/grafana/pkg/tsdb/elasticsearch"
"github.com/grafana/grafana/pkg/tsdb/grafanads"
"github.com/grafana/grafana/pkg/tsdb/graphite"
"github.com/grafana/grafana/pkg/tsdb/influxdb"
"github.com/grafana/grafana/pkg/tsdb/legacydata"
legacydataservice "github.com/grafana/grafana/pkg/tsdb/legacydata/service"
"github.com/grafana/grafana/pkg/tsdb/loki"
"github.com/grafana/grafana/pkg/tsdb/mssql"
"github.com/grafana/grafana/pkg/tsdb/mysql"
"github.com/grafana/grafana/pkg/tsdb/postgres"
"github.com/grafana/grafana/pkg/tsdb/prometheus"
"github.com/grafana/grafana/pkg/tsdb/tempo"
"github.com/grafana/grafana/pkg/tsdb/testdatasource"
"github.com/grafana/grafana/pkg/web" "github.com/grafana/grafana/pkg/web"
) )
@@ -42,9 +153,171 @@ var wireSet = wire.NewSet(
wire.Bind(new(secrets.Store), new(*secretsDatabase.SecretsStoreImpl)), wire.Bind(new(secrets.Store), new(*secretsDatabase.SecretsStoreImpl)),
secretsManager.ProvideSecretsService, secretsManager.ProvideSecretsService,
wire.Bind(new(secrets.Service), new(*secretsManager.SecretsService)), wire.Bind(new(secrets.Service), new(*secretsManager.SecretsService)),
hooks.ProvideService,
legacydataservice.ProvideService,
wire.Bind(new(legacydata.RequestHandler), new(*legacydataservice.Service)),
alerting.ProvideAlertEngine,
wire.Bind(new(alerting.UsageStatsQuerier), new(*alerting.AlertEngine)),
api.ProvideHTTPServer,
query.ProvideService,
thumbs.ProvideService,
rendering.ProvideService,
wire.Bind(new(rendering.Service), new(*rendering.RenderingService)),
kvstore.ProvideService,
updatechecker.ProvideGrafanaService,
updatechecker.ProvidePluginsService,
uss.ProvideService,
registry.ProvideService,
wire.Bind(new(registry.Service), new(*registry.InMemory)),
manager.ProvideService,
wire.Bind(new(plugins.Manager), new(*manager.PluginManager)),
wire.Bind(new(plugins.Client), new(*manager.PluginManager)),
wire.Bind(new(plugins.Store), new(*manager.PluginManager)),
wire.Bind(new(plugins.DashboardFileStore), new(*manager.PluginManager)),
wire.Bind(new(plugins.StaticRouteResolver), new(*manager.PluginManager)),
wire.Bind(new(plugins.RendererManager), new(*manager.PluginManager)),
wire.Bind(new(plugins.SecretsPluginManager), new(*manager.PluginManager)),
coreplugin.ProvideCoreRegistry,
loader.ProvideService,
wire.Bind(new(loader.Service), new(*loader.Loader)),
wire.Bind(new(plugins.ErrorResolver), new(*loader.Loader)),
cloudwatch.ProvideService,
cloudmonitoring.ProvideService,
azuremonitor.ProvideService,
postgres.ProvideService,
mysql.ProvideService,
mssql.ProvideService,
store.ProvideEntityEventsService,
httpclientprovider.New,
wire.Bind(new(httpclient.Provider), new(*sdkhttpclient.Provider)),
serverlock.ProvideService,
cleanup.ProvideService,
shorturls.ProvideService,
wire.Bind(new(shorturls.Service), new(*shorturls.ShortURLService)),
queryhistory.ProvideService,
wire.Bind(new(queryhistory.Service), new(*queryhistory.QueryHistoryService)),
quotaimpl.ProvideService,
remotecache.ProvideService,
loginservice.ProvideService,
wire.Bind(new(login.Service), new(*loginservice.Implementation)),
authinfoservice.ProvideAuthInfoService,
wire.Bind(new(login.AuthInfoService), new(*authinfoservice.Implementation)),
authinfodatabase.ProvideAuthInfoStore,
loginpkg.ProvideService,
wire.Bind(new(loginpkg.Authenticator), new(*loginpkg.AuthenticatorService)),
datasourceproxy.ProvideService,
search.ProvideService,
searchV2.ProvideService,
store.ProvideService,
export.ProvideService,
live.ProvideService,
pushhttp.ProvideService,
plugincontext.ProvideService,
contexthandler.ProvideService,
jwt.ProvideService,
wire.Bind(new(models.JWTService), new(*jwt.AuthService)),
ngalert.ProvideService,
librarypanels.ProvideService,
wire.Bind(new(librarypanels.Service), new(*librarypanels.LibraryPanelService)),
libraryelements.ProvideService,
wire.Bind(new(libraryelements.Service), new(*libraryelements.LibraryElementService)),
notifications.ProvideService,
notifications.ProvideSmtpService,
metrics.ProvideService,
testdatasource.ProvideService,
social.ProvideService,
influxdb.ProvideService,
wire.Bind(new(social.Service), new(*social.SocialService)),
oauthtoken.ProvideService,
auth.ProvideActiveAuthTokenService,
wire.Bind(new(models.ActiveTokenService), new(*auth.ActiveAuthTokenService)),
wire.Bind(new(oauthtoken.OAuthTokenService), new(*oauthtoken.Service)),
tempo.ProvideService,
loki.ProvideService,
graphite.ProvideService,
prometheus.ProvideService,
elasticsearch.ProvideService,
secretsMigrator.ProvideSecretsMigrator, secretsMigrator.ProvideSecretsMigrator,
wire.Bind(new(secrets.Migrator), new(*secretsMigrator.SecretsMigrator)), wire.Bind(new(secrets.Migrator), new(*secretsMigrator.SecretsMigrator)),
hooks.ProvideService, grafanads.ProvideService,
wire.Bind(new(dashboardsnapshots.Store), new(*dashsnapstore.DashboardSnapshotStore)),
dashsnapstore.ProvideStore,
wire.Bind(new(dashboardsnapshots.Service), new(*dashsnapsvc.ServiceImpl)),
dashsnapsvc.ProvideService,
datasourceservice.ProvideService,
wire.Bind(new(datasources.DataSourceService), new(*datasourceservice.Service)),
pluginSettings.ProvideService,
wire.Bind(new(pluginsettings.Service), new(*pluginSettings.Service)),
alerting.ProvideService,
database.ProvideServiceAccountsStore,
wire.Bind(new(serviceaccounts.Store), new(*database.ServiceAccountsStoreImpl)),
ossaccesscontrol.ProvideServiceAccountPermissions,
wire.Bind(new(accesscontrol.ServiceAccountPermissionsService), new(*ossaccesscontrol.ServiceAccountPermissionsService)),
serviceaccountsmanager.ProvideServiceAccountsService,
wire.Bind(new(serviceaccounts.Service), new(*serviceaccountsmanager.ServiceAccountsService)),
expr.ProvideService,
teamguardianDatabase.ProvideTeamGuardianStore,
wire.Bind(new(teamguardian.Store), new(*teamguardianDatabase.TeamGuardianStoreImpl)),
teamguardianManager.ProvideService,
dashboardservice.ProvideDashboardService,
dashboardservice.ProvideFolderService,
dashboardstore.ProvideDashboardStore,
wire.Bind(new(dashboards.DashboardService), new(*dashboardservice.DashboardServiceImpl)),
wire.Bind(new(dashboards.DashboardProvisioningService), new(*dashboardservice.DashboardServiceImpl)),
wire.Bind(new(dashboards.PluginService), new(*dashboardservice.DashboardServiceImpl)),
wire.Bind(new(dashboards.FolderService), new(*dashboardservice.FolderServiceImpl)),
wire.Bind(new(dashboards.Store), new(*dashboardstore.DashboardStore)),
dashboardimportservice.ProvideService,
wire.Bind(new(dashboardimport.Service), new(*dashboardimportservice.ImportDashboardService)),
plugindashboardsservice.ProvideService,
wire.Bind(new(plugindashboards.Service), new(*plugindashboardsservice.Service)),
plugindashboardsservice.ProvideDashboardUpdater,
alerting.ProvideDashAlertExtractorService,
wire.Bind(new(alerting.DashAlertExtractor), new(*alerting.DashAlertExtractorService)),
comments.ProvideService,
guardian.ProvideService,
sanitizer.ProvideService,
secretsStore.ProvideService,
avatar.ProvideAvatarCacheServer,
authproxy.ProvideAuthProxy,
statscollector.ProvideService,
cmreg.CoremodelSet,
cuectx.ProvideCUEContext,
cuectx.ProvideThemaLibrary,
csrf.ProvideCSRFFilter,
ossaccesscontrol.ProvideTeamPermissions,
wire.Bind(new(accesscontrol.TeamPermissionsService), new(*ossaccesscontrol.TeamPermissionsService)),
ossaccesscontrol.ProvideFolderPermissions,
wire.Bind(new(accesscontrol.FolderPermissionsService), new(*ossaccesscontrol.FolderPermissionsService)),
ossaccesscontrol.ProvideDashboardPermissions,
wire.Bind(new(accesscontrol.DashboardPermissionsService), new(*ossaccesscontrol.DashboardPermissionsService)),
starimpl.ProvideService,
playlistimpl.ProvideService,
dashverimpl.ProvideService,
publicdashboardsService.ProvideService,
wire.Bind(new(publicdashboards.Service), new(*publicdashboardsService.PublicDashboardServiceImpl)),
publicdashboardsStore.ProvideStore,
wire.Bind(new(publicdashboards.Store), new(*publicdashboardsStore.PublicDashboardStoreImpl)),
publicdashboardsApi.ProvideApi,
userimpl.ProvideService,
orgimpl.ProvideService,
datasourceservice.ProvideDataSourceMigrationService,
secretsStore.ProvidePluginSecretMigrationService,
secretsMigrations.ProvideSecretMigrationService,
wire.Bind(new(secretsMigrations.SecretMigrationService), new(*secretsMigrations.SecretMigrationServiceImpl)),
userauthimpl.ProvideService,
ngmetrics.ProvideServiceForTest,
wire.Bind(new(alerting.AlertStore), new(*sqlstore.SQLStore)),
wire.Bind(new(sqlstore.TeamStore), new(*sqlstore.SQLStore)),
notifications.MockNotificationService,
wire.Bind(new(notifications.TempUserStore), new(*mockstore.SQLStoreMock)),
wire.Bind(new(notifications.Service), new(*notifications.NotificationServiceMock)),
wire.Bind(new(notifications.WebhookSender), new(*notifications.NotificationServiceMock)),
wire.Bind(new(notifications.EmailSender), new(*notifications.NotificationServiceMock)),
mockstore.NewSQLStoreMock,
wire.Bind(new(sqlstore.Store), new(*sqlstore.SQLStore)),
wire.Bind(new(db.DB), new(*sqlstore.SQLStore)),
prefimpl.ProvideService,
) )
func Initialize(cfg *setting.Cfg) (Runner, error) { func Initialize(cfg *setting.Cfg) (Runner, error) {

View File

@@ -5,14 +5,37 @@ package runner
import ( import (
"github.com/google/wire" "github.com/google/wire"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/plugins/backendplugin/provider"
"github.com/grafana/grafana/pkg/plugins/manager/signature"
"github.com/grafana/grafana/pkg/registry" "github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/server/backgroundsvcs"
"github.com/grafana/grafana/pkg/server/usagestatssvcs"
"github.com/grafana/grafana/pkg/services/accesscontrol"
acdb "github.com/grafana/grafana/pkg/services/accesscontrol/database"
"github.com/grafana/grafana/pkg/services/accesscontrol/ossaccesscontrol"
"github.com/grafana/grafana/pkg/services/accesscontrol/resourcepermissions"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/datasources/permissions"
datasourceservice "github.com/grafana/grafana/pkg/services/datasources/service"
"github.com/grafana/grafana/pkg/services/encryption" "github.com/grafana/grafana/pkg/services/encryption"
encryptionprovider "github.com/grafana/grafana/pkg/services/encryption/provider" encryptionprovider "github.com/grafana/grafana/pkg/services/encryption/provider"
"github.com/grafana/grafana/pkg/services/kmsproviders" "github.com/grafana/grafana/pkg/services/kmsproviders"
"github.com/grafana/grafana/pkg/services/kmsproviders/osskmsproviders" "github.com/grafana/grafana/pkg/services/kmsproviders/osskmsproviders"
"github.com/grafana/grafana/pkg/services/ldap"
"github.com/grafana/grafana/pkg/services/licensing" "github.com/grafana/grafana/pkg/services/licensing"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoservice"
"github.com/grafana/grafana/pkg/services/provisioning"
"github.com/grafana/grafana/pkg/services/searchusers"
"github.com/grafana/grafana/pkg/services/searchusers/filters"
secretsStore "github.com/grafana/grafana/pkg/services/secrets/kvstore"
"github.com/grafana/grafana/pkg/services/sqlstore/migrations" "github.com/grafana/grafana/pkg/services/sqlstore/migrations"
"github.com/grafana/grafana/pkg/services/thumbs"
"github.com/grafana/grafana/pkg/services/validations"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
) )
@@ -26,6 +49,47 @@ var wireExtsSet = wire.NewSet(
wire.Bind(new(setting.Provider), new(*setting.OSSImpl)), wire.Bind(new(setting.Provider), new(*setting.OSSImpl)),
osskmsproviders.ProvideService, osskmsproviders.ProvideService,
wire.Bind(new(kmsproviders.Service), new(osskmsproviders.Service)), wire.Bind(new(kmsproviders.Service), new(osskmsproviders.Service)),
// ossencryption.ProvideService,
// wire.Bind(new(encryption.Internal), new(*ossencryption.Service)),
auth.ProvideUserAuthTokenService,
wire.Bind(new(models.UserTokenService), new(*auth.UserAuthTokenService)),
wire.Bind(new(models.UserTokenBackgroundService), new(*auth.UserAuthTokenService)),
ossaccesscontrol.ProvideService,
wire.Bind(new(accesscontrol.RoleRegistry), new(*ossaccesscontrol.OSSAccessControlService)),
wire.Bind(new(accesscontrol.AccessControl), new(*ossaccesscontrol.OSSAccessControlService)),
thumbs.ProvideCrawlerAuthSetupService,
wire.Bind(new(thumbs.CrawlerAuthSetupService), new(*thumbs.OSSCrawlerAuthSetupService)),
validations.ProvideValidator,
wire.Bind(new(models.PluginRequestValidator), new(*validations.OSSPluginRequestValidator)),
provisioning.ProvideService,
wire.Bind(new(provisioning.ProvisioningService), new(*provisioning.ProvisioningServiceImpl)),
backgroundsvcs.ProvideBackgroundServiceRegistry,
wire.Bind(new(registry.BackgroundServiceRegistry), new(*backgroundsvcs.BackgroundServiceRegistry)),
datasourceservice.ProvideCacheService,
wire.Bind(new(datasources.CacheService), new(*datasourceservice.CacheServiceImpl)),
authinfoservice.ProvideOSSUserProtectionService,
wire.Bind(new(login.UserProtectionService), new(*authinfoservice.OSSUserProtectionImpl)),
filters.ProvideOSSSearchUserFilter,
wire.Bind(new(models.SearchUserFilter), new(*filters.OSSSearchUserFilter)),
searchusers.ProvideUsersService,
wire.Bind(new(searchusers.Service), new(*searchusers.OSSService)),
signature.ProvideOSSAuthorizer,
wire.Bind(new(plugins.PluginLoaderAuthorizer), new(*signature.UnsignedPluginAuthorizer)),
provider.ProvideService,
wire.Bind(new(plugins.BackendFactoryProvider), new(*provider.Service)),
acdb.ProvideService,
wire.Bind(new(resourcepermissions.Store), new(*acdb.AccessControlStore)),
wire.Bind(new(accesscontrol.PermissionsStore), new(*acdb.AccessControlStore)),
ldap.ProvideGroupsService,
wire.Bind(new(ldap.Groups), new(*ldap.OSSGroups)),
permissions.ProvideDatasourcePermissionsService,
wire.Bind(new(permissions.DatasourcePermissionsService), new(*permissions.OSSDatasourcePermissionsService)),
usagestatssvcs.ProvideUsageStatsProvidersRegistry,
wire.Bind(new(registry.UsageStatsProvidersRegistry), new(*usagestatssvcs.UsageStatsProvidersRegistry)),
ossaccesscontrol.ProvideDatasourcePermissionsService,
wire.Bind(new(accesscontrol.DatasourcePermissionsService), new(*ossaccesscontrol.DatasourcePermissionsService)),
secretsStore.ProvideRemotePluginCheck,
wire.Bind(new(secretsStore.UseRemoteSecretsPluginCheck), new(*secretsStore.OSSRemoteSecretsPluginCheck)),
encryptionprovider.ProvideEncryptionProvider, encryptionprovider.ProvideEncryptionProvider,
wire.Bind(new(encryption.Provider), new(encryptionprovider.Provider)), wire.Bind(new(encryption.Provider), new(encryptionprovider.Provider)),
) )

View File

@@ -197,7 +197,6 @@ var wireBasicSet = wire.NewSet(
authinfoservice.ProvideAuthInfoService, authinfoservice.ProvideAuthInfoService,
wire.Bind(new(login.AuthInfoService), new(*authinfoservice.Implementation)), wire.Bind(new(login.AuthInfoService), new(*authinfoservice.Implementation)),
authinfodatabase.ProvideAuthInfoStore, authinfodatabase.ProvideAuthInfoStore,
wire.Bind(new(login.Store), new(*authinfodatabase.AuthInfoStore)),
loginpkg.ProvideService, loginpkg.ProvideService,
wire.Bind(new(loginpkg.Authenticator), new(*loginpkg.AuthenticatorService)), wire.Bind(new(loginpkg.Authenticator), new(*loginpkg.AuthenticatorService)),
datasourceproxy.ProvideService, datasourceproxy.ProvideService,

View File

@@ -7,6 +7,7 @@ import (
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/secrets" "github.com/grafana/grafana/pkg/services/secrets"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user"
@@ -18,13 +19,15 @@ type AuthInfoStore struct {
sqlStore sqlstore.Store sqlStore sqlstore.Store
secretsService secrets.Service secretsService secrets.Service
logger log.Logger logger log.Logger
userService user.Service
} }
func ProvideAuthInfoStore(sqlStore sqlstore.Store, secretsService secrets.Service) *AuthInfoStore { func ProvideAuthInfoStore(sqlStore sqlstore.Store, secretsService secrets.Service, userService user.Service) login.Store {
store := &AuthInfoStore{ store := &AuthInfoStore{
sqlStore: sqlStore, sqlStore: sqlStore,
secretsService: secretsService, secretsService: secretsService,
logger: log.New("login.authinfo.store"), logger: log.New("login.authinfo.store"),
userService: userService,
} }
InitMetrics() InitMetrics()
return store return store
@@ -221,12 +224,13 @@ func (s *AuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAu
} }
func (s *AuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) { func (s *AuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) {
query := models.GetUserByIdQuery{Id: id} query := user.GetUserByIDQuery{ID: id}
if err := s.sqlStore.GetUserById(ctx, &query); err != nil { user, err := s.userService.GetByID(ctx, &query)
if err != nil {
return nil, err return nil, err
} }
return query.Result, nil return user, nil
} }
func (s *AuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) { func (s *AuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) {

View File

@@ -2,62 +2,38 @@ package database
import ( import (
"context" "context"
"sync"
"time" "time"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/sqlstore/db" "github.com/grafana/grafana/pkg/services/sqlstore/db"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
type LoginStats struct {
DuplicateUserEntries int `xorm:"duplicate_user_entries"`
MixedCasedUsers int `xorm:"mixed_cased_users"`
}
const (
ExporterName = "grafana"
metricsCollectionInterval = time.Second * 60 * 4 // every 4 hours, indication of duplicate users
)
var (
// MStatDuplicateUserEntries is a indication metric gauge for number of users with duplicate emails or logins
MStatDuplicateUserEntries prometheus.Gauge
// MStatHasDuplicateEntries is a metric for if there is duplicate users
MStatHasDuplicateEntries prometheus.Gauge
// MStatMixedCasedUsers is a metric for if there is duplicate users
MStatMixedCasedUsers prometheus.Gauge
once sync.Once
Initialised bool = false
)
func InitMetrics() { func InitMetrics() {
once.Do(func() { login.Once.Do(func() {
MStatDuplicateUserEntries = prometheus.NewGauge(prometheus.GaugeOpts{ login.MStatDuplicateUserEntries = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "stat_users_total_duplicate_user_entries", Name: "stat_users_total_duplicate_user_entries",
Help: "total number of duplicate user entries by email or login", Help: "total number of duplicate user entries by email or login",
Namespace: ExporterName, Namespace: login.ExporterName,
}) })
MStatHasDuplicateEntries = prometheus.NewGauge(prometheus.GaugeOpts{ login.MStatHasDuplicateEntries = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "stat_users_has_duplicate_user_entries", Name: "stat_users_has_duplicate_user_entries",
Help: "instance has duplicate user entries by email or login", Help: "instance has duplicate user entries by email or login",
Namespace: ExporterName, Namespace: login.ExporterName,
}) })
MStatMixedCasedUsers = prometheus.NewGauge(prometheus.GaugeOpts{ login.MStatMixedCasedUsers = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "stat_users_total_mixed_cased_users", Name: "stat_users_total_mixed_cased_users",
Help: "total number of users with upper and lower case logins or emails", Help: "total number of users with upper and lower case logins or emails",
Namespace: ExporterName, Namespace: login.ExporterName,
}) })
prometheus.MustRegister( prometheus.MustRegister(
MStatDuplicateUserEntries, login.MStatDuplicateUserEntries,
MStatHasDuplicateEntries, login.MStatHasDuplicateEntries,
MStatMixedCasedUsers, login.MStatMixedCasedUsers,
) )
}) })
} }
@@ -66,7 +42,7 @@ func (s *AuthInfoStore) RunMetricsCollection(ctx context.Context) error {
if _, err := s.GetLoginStats(ctx); err != nil { if _, err := s.GetLoginStats(ctx); err != nil {
s.logger.Warn("Failed to get authinfo metrics", "error", err.Error()) s.logger.Warn("Failed to get authinfo metrics", "error", err.Error())
} }
updateStatsTicker := time.NewTicker(metricsCollectionInterval) updateStatsTicker := time.NewTicker(login.MetricsCollectionInterval)
defer updateStatsTicker.Stop() defer updateStatsTicker.Stop()
for { for {
@@ -81,8 +57,8 @@ func (s *AuthInfoStore) RunMetricsCollection(ctx context.Context) error {
} }
} }
func (s *AuthInfoStore) GetLoginStats(ctx context.Context) (LoginStats, error) { func (s *AuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, error) {
var stats LoginStats var stats login.LoginStats
outerErr := s.sqlStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error { outerErr := s.sqlStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
rawSQL := `SELECT rawSQL := `SELECT
(SELECT COUNT(*) FROM (` + s.duplicateUserEntriesSQL(ctx) + `) AS d WHERE (d.dup_login IS NOT NULL OR d.dup_email IS NOT NULL)) as duplicate_user_entries, (SELECT COUNT(*) FROM (` + s.duplicateUserEntriesSQL(ctx) + `) AS d WHERE (d.dup_login IS NOT NULL OR d.dup_email IS NOT NULL)) as duplicate_user_entries,
@@ -96,14 +72,14 @@ func (s *AuthInfoStore) GetLoginStats(ctx context.Context) (LoginStats, error) {
} }
// set prometheus metrics stats // set prometheus metrics stats
MStatDuplicateUserEntries.Set(float64(stats.DuplicateUserEntries)) login.MStatDuplicateUserEntries.Set(float64(stats.DuplicateUserEntries))
if stats.DuplicateUserEntries == 0 { if stats.DuplicateUserEntries == 0 {
MStatHasDuplicateEntries.Set(float64(0)) login.MStatHasDuplicateEntries.Set(float64(0))
} else { } else {
MStatHasDuplicateEntries.Set(float64(1)) login.MStatHasDuplicateEntries.Set(float64(1))
} }
MStatMixedCasedUsers.Set(float64(stats.MixedCasedUsers)) login.MStatMixedCasedUsers.Set(float64(stats.MixedCasedUsers))
return stats, nil return stats, nil
} }
@@ -115,14 +91,12 @@ func (s *AuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]inter
s.logger.Error("Failed to get login stats", "error", err) s.logger.Error("Failed to get login stats", "error", err)
return nil, err return nil, err
} }
m["stats.users.duplicate_user_entries"] = loginStats.DuplicateUserEntries m["stats.users.duplicate_user_entries"] = loginStats.DuplicateUserEntries
if loginStats.DuplicateUserEntries > 0 { if loginStats.DuplicateUserEntries > 0 {
m["stats.users.has_duplicate_user_entries"] = 1 m["stats.users.has_duplicate_user_entries"] = 1
} else { } else {
m["stats.users.has_duplicate_user_entries"] = 0 m["stats.users.has_duplicate_user_entries"] = 0
} }
m["stats.users.mixed_cased_users"] = loginStats.MixedCasedUsers m["stats.users.mixed_cased_users"] = loginStats.MixedCasedUsers
return m, nil return m, nil

View File

@@ -2,26 +2,26 @@ package authinfoservice
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"testing" "testing"
"time" "time"
"github.com/grafana/grafana/pkg/infra/usagestats"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/login/authinfoservice/database"
secretstore "github.com/grafana/grafana/pkg/services/secrets/database"
secretsManager "github.com/grafana/grafana/pkg/services/secrets/manager"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/infra/usagestats"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/login/authinfoservice/database"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/services/user"
) )
//nolint:goconst //nolint:goconst
func TestUserAuth(t *testing.T) { func TestUserAuth(t *testing.T) {
sqlStore := sqlstore.InitTestDB(t) sqlStore := sqlstore.InitTestDB(t)
secretsService := secretsManager.SetupTestService(t, secretstore.ProvideSecretsStore(sqlStore)) authInfoStore := newFakeAuthInfoStore()
authInfoStore := database.ProvideAuthInfoStore(sqlStore, secretsService)
srv := ProvideAuthInfoService( srv := ProvideAuthInfoService(
&OSSUserProtectionImpl{}, &OSSUserProtectionImpl{},
authInfoStore, authInfoStore,
@@ -42,7 +42,11 @@ func TestUserAuth(t *testing.T) {
t.Run("Can find existing user", func(t *testing.T) { t.Run("Can find existing user", func(t *testing.T) {
// By Login // By Login
login := "loginuser0" login := "loginuser0"
authInfoStore.ExpectedUser = &user.User{
Login: "loginuser0",
ID: 1,
Email: "user1@test.com",
}
query := &models.GetUserByAuthInfoQuery{UserLookupParams: models.UserLookupParams{Login: &login}} query := &models.GetUserByAuthInfoQuery{UserLookupParams: models.UserLookupParams{Login: &login}}
usr, err := srv.LookupAndUpdate(context.Background(), query) usr, err := srv.LookupAndUpdate(context.Background(), query)
@@ -69,6 +73,7 @@ func TestUserAuth(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, usr.Email, email) require.Equal(t, usr.Email, email)
authInfoStore.ExpectedUser = nil
// Don't find nonexistent user // Don't find nonexistent user
email = "nonexistent@test.com" email = "nonexistent@test.com"
@@ -82,6 +87,8 @@ func TestUserAuth(t *testing.T) {
t.Run("Can set & locate by AuthModule and AuthId", func(t *testing.T) { t.Run("Can set & locate by AuthModule and AuthId", func(t *testing.T) {
// get nonexistent user_auth entry // get nonexistent user_auth entry
authInfoStore.ExpectedUser = &user.User{}
authInfoStore.ExpectedError = user.ErrUserNotFound
query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"} query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
usr, err := srv.LookupAndUpdate(context.Background(), query) usr, err := srv.LookupAndUpdate(context.Background(), query)
@@ -90,7 +97,9 @@ func TestUserAuth(t *testing.T) {
// create user_auth entry // create user_auth entry
login := "loginuser0" login := "loginuser0"
authInfoStore.ExpectedUser = &user.User{Login: "loginuser0", ID: 1, Email: ""}
authInfoStore.ExpectedError = nil
authInfoStore.ExpectedOAuth = &models.UserAuth{Id: 1}
query.UserLookupParams.Login = &login query.UserLookupParams.Login = &login
usr, err = srv.LookupAndUpdate(context.Background(), query) usr, err = srv.LookupAndUpdate(context.Background(), query)
@@ -107,6 +116,7 @@ func TestUserAuth(t *testing.T) {
// get with non-matching id // get with non-matching id
idPlusOne := usr.ID + 1 idPlusOne := usr.ID + 1
authInfoStore.ExpectedUser.Login = "loginuser1"
query.UserLookupParams.UserID = &idPlusOne query.UserLookupParams.UserID = &idPlusOne
usr, err = srv.LookupAndUpdate(context.Background(), query) usr, err = srv.LookupAndUpdate(context.Background(), query)
@@ -127,6 +137,8 @@ func TestUserAuth(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
authInfoStore.ExpectedUser = nil
authInfoStore.ExpectedError = user.ErrUserNotFound
// get via user_auth for deleted user // get via user_auth for deleted user
query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"} query = &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test"}
usr, err = srv.LookupAndUpdate(context.Background(), query) usr, err = srv.LookupAndUpdate(context.Background(), query)
@@ -147,7 +159,16 @@ func TestUserAuth(t *testing.T) {
// Find a user to set tokens on // Find a user to set tokens on
login := "loginuser0" login := "loginuser0"
authInfoStore.ExpectedUser = &user.User{Login: "loginuser0", ID: 1, Email: ""}
authInfoStore.ExpectedError = nil
authInfoStore.ExpectedOAuth = &models.UserAuth{
Id: 1,
OAuthAccessToken: token.AccessToken,
OAuthRefreshToken: token.RefreshToken,
OAuthTokenType: token.TokenType,
OAuthIdToken: idToken,
OAuthExpiry: token.Expiry,
}
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table // Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test", UserLookupParams: models.UserLookupParams{ query := &models.GetUserByAuthInfoQuery{AuthModule: "test", AuthId: "test", UserLookupParams: models.UserLookupParams{
Login: &login, Login: &login,
@@ -220,7 +241,7 @@ func TestUserAuth(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, user.Login, login) require.Equal(t, user.Login, login)
authInfoStore.ExpectedOAuth.AuthModule = "test2"
// Get the latest entry by not supply an authmodule or authid // Get the latest entry by not supply an authmodule or authid
getAuthQuery := &models.GetAuthInfoQuery{ getAuthQuery := &models.GetAuthInfoQuery{
UserId: user.ID, UserId: user.ID,
@@ -236,7 +257,7 @@ func TestUserAuth(t *testing.T) {
err = authInfoStore.UpdateAuthInfo(context.Background(), updateAuthCmd) err = authInfoStore.UpdateAuthInfo(context.Background(), updateAuthCmd)
require.Nil(t, err) require.Nil(t, err)
authInfoStore.ExpectedOAuth.AuthModule = "test1"
// Get the latest entry by not supply an authmodule or authid // Get the latest entry by not supply an authmodule or authid
getAuthQuery = &models.GetAuthInfoQuery{ getAuthQuery = &models.GetAuthInfoQuery{
UserId: user.ID, UserId: user.ID,
@@ -292,6 +313,7 @@ func TestUserAuth(t *testing.T) {
getAuthQuery := &models.GetAuthInfoQuery{ getAuthQuery := &models.GetAuthInfoQuery{
UserId: user.ID, UserId: user.ID,
} }
authInfoStore.ExpectedOAuth.AuthModule = "test2"
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery) err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
@@ -314,7 +336,8 @@ func TestUserAuth(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, user.Login, login) require.Equal(t, user.Login, login)
authInfoStore.ExpectedOAuth.AuthModule = "test1"
authInfoStore.ExpectedOAuth.OAuthAccessToken = "access_token"
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery) err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err) require.Nil(t, err)
@@ -327,6 +350,7 @@ func TestUserAuth(t *testing.T) {
user, err = srv.LookupAndUpdate(context.Background(), queryTwo) user, err = srv.LookupAndUpdate(context.Background(), queryTwo)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, user.Login, login) require.Equal(t, user.Login, login)
authInfoStore.ExpectedOAuth.AuthModule = "test2"
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery) err = authInfoStore.GetAuthInfo(context.Background(), getAuthQuery)
require.Nil(t, err) require.Nil(t, err)
@@ -337,10 +361,11 @@ func TestUserAuth(t *testing.T) {
UserId: user.ID, UserId: user.ID,
AuthModule: "test1", AuthModule: "test1",
} }
authInfoStore.ExpectedOAuth.AuthModule = "test1"
err = authInfoStore.GetAuthInfo(context.Background(), getAuthQueryUnchanged) err = authInfoStore.GetAuthInfo(context.Background(), getAuthQueryUnchanged)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "test1", getAuthQueryUnchanged.Result.AuthModule) require.Equal(t, "test1", getAuthQueryUnchanged.Result.AuthModule)
require.Less(t, getAuthQueryUnchanged.Result.Created, getAuthQuery.Result.Created)
}) })
t.Run("Can set & locate by generic oauth auth module and user id", func(t *testing.T) { t.Run("Can set & locate by generic oauth auth module and user id", func(t *testing.T) {
@@ -364,11 +389,14 @@ func TestUserAuth(t *testing.T) {
query = &models.GetUserByAuthInfoQuery{AuthModule: genericOAuthModule, AuthId: "", UserLookupParams: models.UserLookupParams{ query = &models.GetUserByAuthInfoQuery{AuthModule: genericOAuthModule, AuthId: "", UserLookupParams: models.UserLookupParams{
Login: &otherLoginUser, Login: &otherLoginUser,
}} }}
authInfoStore.ExpectedError = errors.New("some error")
user, err = srv.LookupAndUpdate(context.Background(), query) user, err = srv.LookupAndUpdate(context.Background(), query)
database.GetTime = time.Now database.GetTime = time.Now
require.NotNil(t, err) require.NotNil(t, err)
require.Nil(t, user) require.Nil(t, user)
authInfoStore.ExpectedError = nil
}) })
t.Run("should be able to run loginstats query in all dbs", func(t *testing.T) { t.Run("should be able to run loginstats query in all dbs", func(t *testing.T) {
@@ -426,8 +454,18 @@ func TestUserAuth(t *testing.T) {
} }
_, err = sqlStore.CreateUser(context.Background(), dupUserLogincmd) _, err = sqlStore.CreateUser(context.Background(), dupUserLogincmd)
require.NoError(t, err) require.NoError(t, err)
authInfoStore.ExpectedUser = &user.User{
// require stats to populate Email: "userduplicatetest1@test.com",
Name: "user name 1",
Login: "user_duplicate_test_1_login",
}
authInfoStore.ExpectedDuplicateUserEntries = 2
authInfoStore.ExpectedHasDuplicateUserEntries = 1
authInfoStore.ExpectedLoginStats = login.LoginStats{
DuplicateUserEntries: 2,
MixedCasedUsers: 1,
}
// require metrics and statistics to be 2
m, err := srv.authInfoStore.CollectLoginStats(context.Background()) m, err := srv.authInfoStore.CollectLoginStats(context.Background())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 2, m["stats.users.duplicate_user_entries"]) require.Equal(t, 2, m["stats.users.duplicate_user_entries"])
@@ -438,3 +476,67 @@ func TestUserAuth(t *testing.T) {
}) })
}) })
} }
type FakeAuthInfoStore struct {
ExpectedError error
ExpectedUser *user.User
ExpectedOAuth *models.UserAuth
ExpectedDuplicateUserEntries int
ExpectedHasDuplicateUserEntries int
ExpectedLoginStats login.LoginStats
}
func newFakeAuthInfoStore() *FakeAuthInfoStore {
return &FakeAuthInfoStore{}
}
func (f *FakeAuthInfoStore) GetExternalUserInfoByLogin(ctx context.Context, query *models.GetExternalUserInfoByLoginQuery) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) GetAuthInfo(ctx context.Context, query *models.GetAuthInfoQuery) error {
query.Result = f.ExpectedOAuth
return f.ExpectedError
}
func (f *FakeAuthInfoStore) SetAuthInfo(ctx context.Context, cmd *models.SetAuthInfoCommand) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *models.UserAuth) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *models.UpdateAuthInfoCommand) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *models.DeleteAuthInfoCommand) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeAuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeAuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeAuthInfoStore) CollectLoginStats(ctx context.Context) (map[string]interface{}, error) {
var res = make(map[string]interface{})
res["stats.users.duplicate_user_entries"] = f.ExpectedDuplicateUserEntries
res["stats.users.has_duplicate_user_entries"] = f.ExpectedHasDuplicateUserEntries
res["stats.users.duplicate_user_entries_by_login"] = 0
res["stats.users.has_duplicate_user_entries_by_login"] = 0
res["stats.users.duplicate_user_entries_by_email"] = 0
res["stats.users.has_duplicate_user_entries_by_email"] = 0
res["stats.users.mixed_cased_users"] = f.ExpectedLoginStats.MixedCasedUsers
return res, f.ExpectedError
}
func (f *FakeAuthInfoStore) RunMetricsCollection(ctx context.Context) error {
return f.ExpectedError
}
func (f *FakeAuthInfoStore) GetLoginStats(ctx context.Context) (login.LoginStats, error) {
return f.ExpectedLoginStats, f.ExpectedError
}

View File

@@ -0,0 +1,32 @@
package login
import (
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
)
type LoginStats struct {
DuplicateUserEntries int `xorm:"duplicate_user_entries"`
MixedCasedUsers int `xorm:"mixed_cased_users"`
}
const (
ExporterName = "grafana"
MetricsCollectionInterval = time.Second * 60 * 4 // every 4 hours, indication of duplicate users
)
var (
// MStatDuplicateUserEntries is a indication metric gauge for number of users with duplicate emails or logins
MStatDuplicateUserEntries prometheus.Gauge
// MStatHasDuplicateEntries is a metric for if there is duplicate users
MStatHasDuplicateEntries prometheus.Gauge
// MStatMixedCasedUsers is a metric for if there is duplicate users
MStatMixedCasedUsers prometheus.Gauge
Once sync.Once
Initialised bool = false
)

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/login/authinfoservice/database"
"github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user"
) )
@@ -24,5 +23,5 @@ type Store interface {
GetUserByEmail(ctx context.Context, email string) (*user.User, error) GetUserByEmail(ctx context.Context, email string) (*user.User, error)
CollectLoginStats(ctx context.Context) (map[string]interface{}, error) CollectLoginStats(ctx context.Context) (map[string]interface{}, error)
RunMetricsCollection(ctx context.Context) error RunMetricsCollection(ctx context.Context) error
GetLoginStats(ctx context.Context) (database.LoginStats, error) GetLoginStats(ctx context.Context) (LoginStats, error)
} }

View File

@@ -2,6 +2,8 @@ package user
import ( import (
"errors" "errors"
"fmt"
"strings"
"time" "time"
) )
@@ -69,3 +71,28 @@ func (u *User) NameOrFallback() string {
type DeleteUserCommand struct { type DeleteUserCommand struct {
UserID int64 UserID int64
} }
type GetUserByIDQuery struct {
ID int64
}
type ErrCaseInsensitiveLoginConflict struct {
Users []User
}
func (e *ErrCaseInsensitiveLoginConflict) Unwrap() error {
return ErrCaseInsensitive
}
func (e *ErrCaseInsensitiveLoginConflict) Error() string {
n := len(e.Users)
userStrings := make([]string, 0, n)
for _, v := range e.Users {
userStrings = append(userStrings, fmt.Sprintf("%s (email:%s, id:%d)", v.Login, v.Email, v.ID))
}
return fmt.Sprintf(
"Found a conflict in user login information. %d users already exist with either the same login or email: [%s].",
n, strings.Join(userStrings, ", "))
}

View File

@@ -7,4 +7,5 @@ import (
type Service interface { type Service interface {
Create(context.Context, *CreateUserCommand) (*User, error) Create(context.Context, *CreateUserCommand) (*User, error)
Delete(context.Context, *DeleteUserCommand) error Delete(context.Context, *DeleteUserCommand) error
GetByID(context.Context, *GetUserByIDQuery) (*User, error)
} }

View File

@@ -14,8 +14,10 @@ import (
type store interface { type store interface {
Insert(context.Context, *user.User) (int64, error) Insert(context.Context, *user.User) (int64, error)
Get(context.Context, *user.User) (*user.User, error) Get(context.Context, *user.User) (*user.User, error)
GetByID(context.Context, int64) (*user.User, error)
GetNotServiceAccount(context.Context, int64) (*user.User, error) GetNotServiceAccount(context.Context, int64) (*user.User, error)
Delete(context.Context, int64) error Delete(context.Context, int64) error
CaseInsensitiveLoginConflict(context.Context, string, string) error
} }
type sqlStore struct { type sqlStore struct {
@@ -91,8 +93,42 @@ func (ss *sqlStore) GetNotServiceAccount(ctx context.Context, userID int64) (*us
return &usr, err return &usr, err
} }
func (ss *sqlStore) GetByID(ctx context.Context, userID int64) (*user.User, error) {
var usr user.User
err := ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error {
has, err := sess.ID(&userID).
Where(ss.notServiceAccountFilter()).
Get(&usr)
if err != nil {
return err
} else if !has {
return user.ErrUserNotFound
}
return nil
})
return &usr, err
}
func (ss *sqlStore) notServiceAccountFilter() string { func (ss *sqlStore) notServiceAccountFilter() string {
return fmt.Sprintf("%s.is_service_account = %s", return fmt.Sprintf("%s.is_service_account = %s",
ss.dialect.Quote("user"), ss.dialect.Quote("user"),
ss.dialect.BooleanStr(false)) ss.dialect.BooleanStr(false))
} }
func (ss *sqlStore) CaseInsensitiveLoginConflict(ctx context.Context, login, email string) error {
users := make([]user.User, 0)
err := ss.db.WithDbSession(ctx, func(sess *sqlstore.DBSession) error {
if err := sess.Where("LOWER(email)=LOWER(?) OR LOWER(login)=LOWER(?)",
email, login).Find(&users); err != nil {
return err
}
if len(users) > 1 {
return &user.ErrCaseInsensitiveLoginConflict{Users: users}
}
return nil
})
return err
}

View File

@@ -3,7 +3,6 @@ package userimpl
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"time" "time"
"github.com/grafana/grafana/pkg/services/accesscontrol" "github.com/grafana/grafana/pkg/services/accesscontrol"
@@ -32,6 +31,8 @@ type Service struct {
userAuthService userauth.Service userAuthService userauth.Service
quotaService quota.Service quotaService quota.Service
accessControlStore accesscontrol.AccessControl accessControlStore accesscontrol.AccessControl
cfg *setting.Cfg
} }
func ProvideService( func ProvideService(
@@ -44,6 +45,7 @@ func ProvideService(
userAuthService userauth.Service, userAuthService userauth.Service,
quotaService quota.Service, quotaService quota.Service,
accessControlStore accesscontrol.AccessControl, accessControlStore accesscontrol.AccessControl,
cfg *setting.Cfg,
) user.Service { ) user.Service {
return &Service{ return &Service{
store: &sqlStore{ store: &sqlStore{
@@ -58,6 +60,7 @@ func ProvideService(
userAuthService: userAuthService, userAuthService: userAuthService,
quotaService: quotaService, quotaService: quotaService,
accessControlStore: accessControlStore, accessControlStore: accessControlStore,
cfg: cfg,
} }
} }
@@ -157,7 +160,7 @@ func (s *Service) Create(ctx context.Context, cmd *user.CreateUserCommand) (*use
func (s *Service) Delete(ctx context.Context, cmd *user.DeleteUserCommand) error { func (s *Service) Delete(ctx context.Context, cmd *user.DeleteUserCommand) error {
_, err := s.store.GetNotServiceAccount(ctx, cmd.UserID) _, err := s.store.GetNotServiceAccount(ctx, cmd.UserID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user with not service account: %w", err) return err
} }
// delete from all the stores // delete from all the stores
if err := s.store.Delete(ctx, cmd.UserID); err != nil { if err := s.store.Delete(ctx, cmd.UserID); err != nil {
@@ -225,3 +228,16 @@ func (s *Service) Delete(ctx context.Context, cmd *user.DeleteUserCommand) error
return nil return nil
} }
func (s *Service) GetByID(ctx context.Context, query *user.GetUserByIDQuery) (*user.User, error) {
user, err := s.store.GetByID(ctx, query.ID)
if err != nil {
return nil, err
}
if s.cfg.CaseInsensitiveLogin {
if err := s.store.CaseInsensitiveLoginConflict(ctx, user.Login, user.Email); err != nil {
return nil, err
}
}
return user, nil
}

View File

@@ -14,6 +14,8 @@ import (
"github.com/grafana/grafana/pkg/services/teamguardian/manager" "github.com/grafana/grafana/pkg/services/teamguardian/manager"
"github.com/grafana/grafana/pkg/services/user" "github.com/grafana/grafana/pkg/services/user"
"github.com/grafana/grafana/pkg/services/userauth/userauthtest" "github.com/grafana/grafana/pkg/services/userauth/userauthtest"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -40,7 +42,67 @@ func TestUserService(t *testing.T) {
} }
t.Run("create user", func(t *testing.T) { t.Run("create user", func(t *testing.T) {
_, err := userService.Create(context.Background(), &user.CreateUserCommand{}) _, err := userService.Create(context.Background(), &user.CreateUserCommand{
Email: "email",
Login: "login",
Name: "name",
})
require.NoError(t, err)
})
t.Run("get user by ID", func(t *testing.T) {
userService.cfg = setting.NewCfg()
userService.cfg.CaseInsensitiveLogin = false
userStore.ExpectedUser = &user.User{ID: 1, Email: "email", Login: "login", Name: "name"}
u, err := userService.GetByID(context.Background(), &user.GetUserByIDQuery{ID: 1})
require.NoError(t, err)
require.Equal(t, "login", u.Login)
require.Equal(t, "name", u.Name)
require.Equal(t, "email", u.Email)
})
t.Run("get user by ID with case insensitive login", func(t *testing.T) {
userService.cfg = setting.NewCfg()
userService.cfg.CaseInsensitiveLogin = true
userStore.ExpectedUser = &user.User{ID: 1, Email: "email", Login: "login", Name: "name"}
u, err := userService.GetByID(context.Background(), &user.GetUserByIDQuery{ID: 1})
require.NoError(t, err)
require.Equal(t, "login", u.Login)
require.Equal(t, "name", u.Name)
require.Equal(t, "email", u.Email)
})
t.Run("delete user store returns error", func(t *testing.T) {
userStore.ExpectedDeleteUserError = user.ErrUserNotFound
t.Cleanup(func() {
userStore.ExpectedDeleteUserError = nil
})
err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1})
require.Error(t, err, user.ErrUserNotFound)
})
t.Run("delete user returns from team", func(t *testing.T) {
teamMemberService.ExpectedError = errors.New("some error")
t.Cleanup(func() {
teamMemberService.ExpectedError = nil
})
err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1})
require.Error(t, err)
})
t.Run("delete user returns from team and pref", func(t *testing.T) {
teamMemberService.ExpectedError = errors.New("some error")
preferenceService.ExpectedError = errors.New("some error 2")
t.Cleanup(func() {
teamMemberService.ExpectedError = nil
preferenceService.ExpectedError = nil
})
err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1})
require.Error(t, err)
})
t.Run("delete user successfully", func(t *testing.T) {
err := userService.Delete(context.Background(), &user.DeleteUserCommand{UserID: 1})
require.NoError(t, err) require.NoError(t, err)
}) })
@@ -104,3 +166,11 @@ func (f *FakeUserStore) Delete(ctx context.Context, userID int64) error {
func (f *FakeUserStore) GetNotServiceAccount(ctx context.Context, userID int64) (*user.User, error) { func (f *FakeUserStore) GetNotServiceAccount(ctx context.Context, userID int64) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError return f.ExpectedUser, f.ExpectedError
} }
func (f *FakeUserStore) GetByID(context.Context, int64) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}
func (f *FakeUserStore) CaseInsensitiveLoginConflict(context.Context, string, string) error {
return f.ExpectedError
}

View File

@@ -22,3 +22,7 @@ func (f *FakeUserService) Create(ctx context.Context, cmd *user.CreateUserComman
func (f *FakeUserService) Delete(ctx context.Context, cmd *user.DeleteUserCommand) error { func (f *FakeUserService) Delete(ctx context.Context, cmd *user.DeleteUserCommand) error {
return f.ExpectedError return f.ExpectedError
} }
func (f *FakeUserService) GetByID(ctx context.Context, query *user.GetUserByIDQuery) (*user.User, error) {
return f.ExpectedUser, f.ExpectedError
}

File diff suppressed because it is too large Load Diff