Refactor: Change sqlstore.inTransaction to SQLStore.WithTransactionalDBSession in misc files (#43926)

* Refactor: Change sqlstore.inTransaction to SQLStore.WithTransactionalDBSession in misc files

* Refactor: Change .inTransaction in org.go file

* Refactor: Update init() to proper SQLStore handlers

* Refactor: Update funcs in tests to be sqlStore methods

* Refactor: Update API funcs to receive HTTPServer

* Fix: define methods on sqlstore

* Adjust GetSignedInUser calls

* Refactor: Add sqlStore to Service struct

* Chore: Add back black spaces to remove file from PR

Co-authored-by: Ida Furjesova <ida.furjesova@grafana.com>
This commit is contained in:
Katarina Yang
2022-01-25 14:30:08 -05:00
committed by GitHub
parent eed9e5543d
commit 92ca38bedf
17 changed files with 88 additions and 83 deletions

View File

@@ -206,8 +206,8 @@ func (hs *HTTPServer) registerRoutes() {
// current org
apiRoute.Group("/org", func(orgRoute routing.RouteRegister) {
userIDScope := ac.Scope("users", "id", ac.Parameter(":userId"))
orgRoute.Put("/", authorize(reqOrgAdmin, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(UpdateCurrentOrg))
orgRoute.Put("/address", authorize(reqOrgAdmin, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(UpdateCurrentOrgAddress))
orgRoute.Put("/", authorize(reqOrgAdmin, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(hs.UpdateCurrentOrg))
orgRoute.Put("/address", authorize(reqOrgAdmin, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(hs.UpdateCurrentOrgAddress))
orgRoute.Get("/users", authorize(reqOrgAdmin, ac.EvalPermission(ac.ActionOrgUsersRead)), routing.Wrap(hs.GetOrgUsersForCurrentOrg))
orgRoute.Get("/users/search", authorize(reqOrgAdmin, ac.EvalPermission(ac.ActionOrgUsersRead)), routing.Wrap(hs.SearchOrgUsersWithPaging))
orgRoute.Post("/users", authorize(reqOrgAdmin, ac.EvalPermission(ac.ActionOrgUsersAdd, ac.ScopeUsersAll)), quota("user"), routing.Wrap(hs.AddOrgUserToCurrentOrg))
@@ -239,9 +239,9 @@ func (hs *HTTPServer) registerRoutes() {
apiRoute.Group("/orgs/:orgId", func(orgsRoute routing.RouteRegister) {
userIDScope := ac.Scope("users", "id", ac.Parameter(":userId"))
orgsRoute.Get("/", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsRead)), routing.Wrap(GetOrgByID))
orgsRoute.Put("/", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(UpdateOrg))
orgsRoute.Put("/address", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(UpdateOrgAddress))
orgsRoute.Delete("/", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsDelete)), routing.Wrap(DeleteOrgByID))
orgsRoute.Put("/", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(hs.UpdateOrg))
orgsRoute.Put("/address", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsWrite)), routing.Wrap(hs.UpdateOrgAddress))
orgsRoute.Delete("/", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ActionOrgsDelete)), routing.Wrap(hs.DeleteOrgByID))
orgsRoute.Get("/users", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ac.ActionOrgUsersRead, ac.ScopeUsersAll)), routing.Wrap(hs.GetOrgUsers))
orgsRoute.Post("/users", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ac.ActionOrgUsersAdd, ac.ScopeUsersAll)), routing.Wrap(hs.AddOrgUser))
orgsRoute.Patch("/users/:userId", authorizeInOrg(reqGrafanaAdmin, acmiddleware.UseOrgFromContextParams, ac.EvalPermission(ac.ActionOrgUsersRoleUpdate, userIDScope)), routing.Wrap(hs.UpdateOrgUser))

View File

@@ -111,31 +111,30 @@ func (hs *HTTPServer) CreateOrg(c *models.ReqContext) response.Response {
}
// PUT /api/org
func UpdateCurrentOrg(c *models.ReqContext) response.Response {
func (hs *HTTPServer) UpdateCurrentOrg(c *models.ReqContext) response.Response {
form := dtos.UpdateOrgForm{}
if err := web.Bind(c.Req, &form); err != nil {
return response.Error(http.StatusBadRequest, "bad request data", err)
}
return updateOrgHelper(c.Req.Context(), form, c.OrgId)
return hs.updateOrgHelper(c.Req.Context(), form, c.OrgId)
}
// PUT /api/orgs/:orgId
func UpdateOrg(c *models.ReqContext) response.Response {
func (hs *HTTPServer) UpdateOrg(c *models.ReqContext) response.Response {
form := dtos.UpdateOrgForm{}
if err := web.Bind(c.Req, &form); err != nil {
return response.Error(http.StatusBadRequest, "bad request data", err)
}
orgId, err := strconv.ParseInt(web.Params(c.Req)[":orgId"], 10, 64)
if err != nil {
return response.Error(http.StatusBadRequest, "orgId is invalid", err)
}
return updateOrgHelper(c.Req.Context(), form, orgId)
return hs.updateOrgHelper(c.Req.Context(), form, orgId)
}
func updateOrgHelper(ctx context.Context, form dtos.UpdateOrgForm, orgID int64) response.Response {
func (hs *HTTPServer) updateOrgHelper(ctx context.Context, form dtos.UpdateOrgForm, orgID int64) response.Response {
cmd := models.UpdateOrgCommand{Name: form.Name, OrgId: orgID}
if err := sqlstore.UpdateOrg(ctx, &cmd); err != nil {
if err := hs.SQLStore.UpdateOrg(ctx, &cmd); err != nil {
if errors.Is(err, models.ErrOrgNameTaken) {
return response.Error(400, "Organization name taken", err)
}
@@ -146,16 +145,16 @@ func updateOrgHelper(ctx context.Context, form dtos.UpdateOrgForm, orgID int64)
}
// PUT /api/org/address
func UpdateCurrentOrgAddress(c *models.ReqContext) response.Response {
func (hs *HTTPServer) UpdateCurrentOrgAddress(c *models.ReqContext) response.Response {
form := dtos.UpdateOrgAddressForm{}
if err := web.Bind(c.Req, &form); err != nil {
return response.Error(http.StatusBadRequest, "bad request data", err)
}
return updateOrgAddressHelper(c.Req.Context(), form, c.OrgId)
return hs.updateOrgAddressHelper(c.Req.Context(), form, c.OrgId)
}
// PUT /api/orgs/:orgId/address
func UpdateOrgAddress(c *models.ReqContext) response.Response {
func (hs *HTTPServer) UpdateOrgAddress(c *models.ReqContext) response.Response {
form := dtos.UpdateOrgAddressForm{}
if err := web.Bind(c.Req, &form); err != nil {
return response.Error(http.StatusBadRequest, "bad request data", err)
@@ -164,10 +163,10 @@ func UpdateOrgAddress(c *models.ReqContext) response.Response {
if err != nil {
return response.Error(http.StatusBadRequest, "orgId is invalid", err)
}
return updateOrgAddressHelper(c.Req.Context(), form, orgId)
return hs.updateOrgAddressHelper(c.Req.Context(), form, orgId)
}
func updateOrgAddressHelper(ctx context.Context, form dtos.UpdateOrgAddressForm, orgID int64) response.Response {
func (hs *HTTPServer) updateOrgAddressHelper(ctx context.Context, form dtos.UpdateOrgAddressForm, orgID int64) response.Response {
cmd := models.UpdateOrgAddressCommand{
OrgId: orgID,
Address: models.Address{
@@ -180,7 +179,7 @@ func updateOrgAddressHelper(ctx context.Context, form dtos.UpdateOrgAddressForm,
},
}
if err := sqlstore.UpdateOrgAddress(ctx, &cmd); err != nil {
if err := hs.SQLStore.UpdateOrgAddress(ctx, &cmd); err != nil {
return response.Error(500, "Failed to update org address", err)
}
@@ -188,7 +187,7 @@ func updateOrgAddressHelper(ctx context.Context, form dtos.UpdateOrgAddressForm,
}
// DELETE /api/orgs/:orgId
func DeleteOrgByID(c *models.ReqContext) response.Response {
func (hs *HTTPServer) DeleteOrgByID(c *models.ReqContext) response.Response {
orgID, err := strconv.ParseInt(web.Params(c.Req)[":orgId"], 10, 64)
if err != nil {
return response.Error(http.StatusBadRequest, "orgId is invalid", err)
@@ -198,7 +197,7 @@ func DeleteOrgByID(c *models.ReqContext) response.Response {
return response.Error(400, "Can not delete org for current user", nil)
}
if err := sqlstore.DeleteOrg(c.Req.Context(), &models.DeleteOrgCommand{Id: orgID}); err != nil {
if err := hs.SQLStore.DeleteOrg(c.Req.Context(), &models.DeleteOrgCommand{Id: orgID}); err != nil {
if errors.Is(err, models.ErrOrgNotFound) {
return response.Error(404, "Failed to delete organization. ID not found", nil)
}

View File

@@ -690,7 +690,7 @@ func TestPatchOrgUsersAPIEndpoint_AccessControl(t *testing.T) {
UserId: tc.targetUserId,
OrgId: tc.targetOrg,
}
err = sqlstore.GetSignedInUser(context.Background(), &getUserQuery)
err = sc.db.GetSignedInUser(context.Background(), &getUserQuery)
require.NoError(t, err)
assert.Equal(t, tc.expectedUserRole, getUserQuery.Result.OrgRole)
}

View File

@@ -207,7 +207,7 @@ func (s *Service) validateResource(ctx context.Context, orgID int64, resourceID
}
func (s *Service) validateUser(ctx context.Context, orgID, userID int64) error {
if err := sqlstore.GetSignedInUser(ctx, &models.GetSignedInUserQuery{OrgId: orgID, UserId: userID}); err != nil {
if err := s.sqlStore.GetSignedInUser(ctx, &models.GetSignedInUserQuery{OrgId: orgID, UserId: userID}); err != nil {
return err
}
return nil

View File

@@ -245,7 +245,7 @@ func TestAlertingDataAccess(t *testing.T) {
err := sqlStore.SaveAlerts(context.Background(), testDash.Id, items)
require.Nil(t, err)
err = DeleteDashboard(context.Background(), &models.DeleteDashboardCommand{
err = sqlStore.DeleteDashboard(context.Background(), &models.DeleteDashboardCommand{
OrgId: 1,
Id: testDash.Id,
})

View File

@@ -29,7 +29,6 @@ var shadowSearchCounter = prometheus.NewCounterVec(
func init() {
bus.AddHandler("sql", GetDashboard)
bus.AddHandler("sql", GetDashboards)
bus.AddHandler("sql", DeleteDashboard)
bus.AddHandler("sql", GetDashboardTags)
bus.AddHandler("sql", GetDashboardSlugById)
bus.AddHandler("sql", GetDashboardsByPluginId)
@@ -44,6 +43,7 @@ func init() {
func (ss *SQLStore) addDashboardQueryAndCommandHandlers() {
bus.AddHandler("sql", ss.GetDashboardUIDById)
bus.AddHandler("sql", ss.SearchDashboards)
bus.AddHandler("sql", ss.DeleteDashboard)
}
var generateNewUid func() string = util.GenerateShortUID
@@ -410,8 +410,8 @@ func GetDashboardTags(ctx context.Context, query *models.GetDashboardTagsQuery)
return err
}
func DeleteDashboard(ctx context.Context, cmd *models.DeleteDashboardCommand) error {
return inTransaction(func(sess *DBSession) error {
func (ss *SQLStore) DeleteDashboard(ctx context.Context, cmd *models.DeleteDashboardCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
return deleteDashboard(cmd, sess)
})
}

View File

@@ -8,9 +8,9 @@ import (
"github.com/grafana/grafana/pkg/models"
)
func init() {
func (ss *SQLStore) addDashboardProvisioningQueryAndCommandHandlers() {
bus.AddHandler("sql", UnprovisionDashboard)
bus.AddHandler("sql", DeleteOrphanedProvisionedDashboards)
bus.AddHandler("sql", ss.DeleteOrphanedProvisionedDashboards)
}
type DashboardExtras struct {
@@ -111,7 +111,7 @@ func UnprovisionDashboard(ctx context.Context, cmd *models.UnprovisionDashboardC
return nil
}
func DeleteOrphanedProvisionedDashboards(ctx context.Context, cmd *models.DeleteOrphanedProvisionedDashboardsCommand) error {
func (ss *SQLStore) DeleteOrphanedProvisionedDashboards(ctx context.Context, cmd *models.DeleteOrphanedProvisionedDashboardsCommand) error {
var result []*models.DashboardProvisioning
convertedReaderNames := make([]interface{}, len(cmd.ReaderNames))
@@ -125,7 +125,7 @@ func DeleteOrphanedProvisionedDashboards(ctx context.Context, cmd *models.Delete
}
for _, deleteDashCommand := range result {
err := DeleteDashboard(ctx, &models.DeleteDashboardCommand{Id: deleteDashCommand.DashboardId})
err := ss.DeleteDashboard(ctx, &models.DeleteDashboardCommand{Id: deleteDashCommand.DashboardId})
if err != nil && !errors.Is(err, models.ErrDashboardNotFound) {
return err
}

View File

@@ -80,7 +80,7 @@ func TestDashboardProvisioningTest(t *testing.T) {
require.NotNil(t, query.Result)
deleteCmd := &models.DeleteOrphanedProvisionedDashboardsCommand{ReaderNames: []string{"default"}}
require.Nil(t, DeleteOrphanedProvisionedDashboards(context.Background(), deleteCmd))
require.Nil(t, sqlStore.DeleteOrphanedProvisionedDashboards(context.Background(), deleteCmd))
query = &models.GetDashboardsQuery{DashboardIds: []int64{dash.Id, anotherDash.Id}}
err = GetDashboards(context.Background(), query)
@@ -117,7 +117,7 @@ func TestDashboardProvisioningTest(t *testing.T) {
OrgId: 1,
}
require.Nil(t, DeleteDashboard(context.Background(), deleteCmd))
require.Nil(t, sqlStore.DeleteDashboard(context.Background(), deleteCmd))
data, err := sqlStore.GetProvisionedDataByDashboardID(dash.Id)
require.Nil(t, err)

View File

@@ -117,7 +117,7 @@ func TestDashboardDataAccess(t *testing.T) {
setup()
dash := insertTestDashboard(t, sqlStore, "delete me", 1, 0, false, "delete this")
err := DeleteDashboard(context.Background(), &models.DeleteDashboardCommand{
err := sqlStore.DeleteDashboard(context.Background(), &models.DeleteDashboardCommand{
Id: dash.Id,
OrgId: 1,
})
@@ -214,21 +214,21 @@ func TestDashboardDataAccess(t *testing.T) {
emptyFolder := insertTestDashboard(t, sqlStore, "2 test dash folder", 1, 0, true, "prod", "webapp")
deleteCmd := &models.DeleteDashboardCommand{Id: emptyFolder.Id}
err := DeleteDashboard(context.Background(), deleteCmd)
err := sqlStore.DeleteDashboard(context.Background(), deleteCmd)
require.NoError(t, err)
})
t.Run("Should be not able to delete a dashboard if force delete rules is disabled", func(t *testing.T) {
setup()
deleteCmd := &models.DeleteDashboardCommand{Id: savedFolder.Id, ForceDeleteFolderRules: false}
err := DeleteDashboard(context.Background(), deleteCmd)
err := sqlStore.DeleteDashboard(context.Background(), deleteCmd)
require.True(t, errors.Is(err, models.ErrFolderContainsAlertRules))
})
t.Run("Should be able to delete a dashboard folder and its children if force delete rules is enabled", func(t *testing.T) {
setup()
deleteCmd := &models.DeleteDashboardCommand{Id: savedFolder.Id, ForceDeleteFolderRules: true}
err := DeleteDashboard(context.Background(), deleteCmd)
err := sqlStore.DeleteDashboard(context.Background(), deleteCmd)
require.NoError(t, err)
query := search.FindPersistedDashboardsQuery{

View File

@@ -11,14 +11,14 @@ import (
var getTimeNow = time.Now
func init() {
bus.AddHandler("sql", CreateLoginAttempt)
bus.AddHandler("sql", DeleteOldLoginAttempts)
func (ss *SQLStore) addLoginAttemptQueryAndCommandHandlers() {
bus.AddHandler("sql", ss.CreateLoginAttempt)
bus.AddHandler("sql", ss.DeleteOldLoginAttempts)
bus.AddHandler("sql", GetUserLoginAttemptCount)
}
func CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error {
return inTransaction(func(sess *DBSession) error {
func (ss *SQLStore) CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
loginAttempt := models.LoginAttempt{
Username: cmd.Username,
IpAddress: cmd.IpAddress,
@@ -35,8 +35,8 @@ func CreateLoginAttempt(ctx context.Context, cmd *models.CreateLoginAttemptComma
})
}
func DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error {
return inTransaction(func(sess *DBSession) error {
func (ss *SQLStore) DeleteOldLoginAttempts(ctx context.Context, cmd *models.DeleteOldLoginAttemptsCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
var maxId int64
sql := "SELECT max(id) as id FROM login_attempt WHERE created < ?"
result, err := sess.Query(sql, cmd.OlderThan.Unix())

View File

@@ -20,24 +20,25 @@ func mockTime(mock time.Time) time.Time {
func TestLoginAttempts(t *testing.T) {
var beginningOfTime, timePlusOneMinute, timePlusTwoMinutes time.Time
var sqlStore *SQLStore
user := "user"
setup := func(t *testing.T) {
InitTestDB(t)
sqlStore = InitTestDB(t)
beginningOfTime = mockTime(time.Date(2017, 10, 22, 8, 0, 0, 0, time.Local))
err := CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
err := sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
timePlusOneMinute = mockTime(beginningOfTime.Add(time.Minute * 1))
err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
err = sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
require.Nil(t, err)
timePlusTwoMinutes = mockTime(beginningOfTime.Add(time.Minute * 2))
err = CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
err = sqlStore.CreateLoginAttempt(context.Background(), &models.CreateLoginAttemptCommand{
Username: user,
IpAddress: "192.168.0.1",
})
@@ -93,7 +94,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: beginningOfTime,
}
err := DeleteOldLoginAttempts(context.Background(), &cmd)
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(0), cmd.DeletedRows)
@@ -104,7 +105,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusOneMinute,
}
err := DeleteOldLoginAttempts(context.Background(), &cmd)
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(1), cmd.DeletedRows)
@@ -115,7 +116,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusTwoMinutes,
}
err := DeleteOldLoginAttempts(context.Background(), &cmd)
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(2), cmd.DeletedRows)
@@ -126,7 +127,7 @@ func TestLoginAttempts(t *testing.T) {
cmd := models.DeleteOldLoginAttemptsCommand{
OlderThan: timePlusTwoMinutes.Add(time.Second * 1),
}
err := DeleteOldLoginAttempts(context.Background(), &cmd)
err := sqlStore.DeleteOldLoginAttempts(context.Background(), &cmd)
require.Nil(t, err)
require.Equal(t, int64(3), cmd.DeletedRows)

View File

@@ -15,14 +15,14 @@ import (
// MainOrgName is the name of the main organization.
const MainOrgName = "Main Org."
func init() {
func (ss *SQLStore) addOrgQueryAndCommandHandlers() {
bus.AddHandler("sql", GetOrgById)
bus.AddHandler("sql", CreateOrg)
bus.AddHandler("sql", UpdateOrg)
bus.AddHandler("sql", UpdateOrgAddress)
bus.AddHandler("sql", ss.UpdateOrg)
bus.AddHandler("sql", ss.UpdateOrgAddress)
bus.AddHandler("sql", GetOrgByName)
bus.AddHandler("sql", SearchOrgs)
bus.AddHandler("sql", DeleteOrg)
bus.AddHandler("sql", ss.DeleteOrg)
}
func SearchOrgs(ctx context.Context, query *models.SearchOrgsQuery) error {
@@ -164,8 +164,8 @@ func CreateOrg(ctx context.Context, cmd *models.CreateOrgCommand) error {
return nil
}
func UpdateOrg(ctx context.Context, cmd *models.UpdateOrgCommand) error {
return inTransaction(func(sess *DBSession) error {
func (ss *SQLStore) UpdateOrg(ctx context.Context, cmd *models.UpdateOrgCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
if isNameTaken, err := isOrgNameTaken(cmd.Name, cmd.OrgId, sess); err != nil {
return err
} else if isNameTaken {
@@ -197,8 +197,8 @@ func UpdateOrg(ctx context.Context, cmd *models.UpdateOrgCommand) error {
})
}
func UpdateOrgAddress(ctx context.Context, cmd *models.UpdateOrgAddressCommand) error {
return inTransaction(func(sess *DBSession) error {
func (ss *SQLStore) UpdateOrgAddress(ctx context.Context, cmd *models.UpdateOrgAddressCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
org := models.Org{
Address1: cmd.Address1,
Address2: cmd.Address2,
@@ -224,8 +224,8 @@ func UpdateOrgAddress(ctx context.Context, cmd *models.UpdateOrgAddressCommand)
})
}
func DeleteOrg(ctx context.Context, cmd *models.DeleteOrgCommand) error {
return inTransaction(func(sess *DBSession) error {
func (ss *SQLStore) DeleteOrg(ctx context.Context, cmd *models.DeleteOrgCommand) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
if res, err := sess.Query("SELECT 1 from org WHERE id=?", cmd.Id); err != nil {
return err
} else if len(res) != 1 {

View File

@@ -195,7 +195,7 @@ func TestAccountDataAccess(t *testing.T) {
t.Run("Can get logged in user projection", func(t *testing.T) {
query := models.GetSignedInUserQuery{UserId: ac2.Id}
err := GetSignedInUser(context.Background(), &query)
err := sqlStore.GetSignedInUser(context.Background(), &query)
require.NoError(t, err)
require.Equal(t, query.Result.Email, "ac2@test.com")
@@ -256,7 +256,7 @@ func TestAccountDataAccess(t *testing.T) {
t.Run("SignedInUserQuery with a different org", func(t *testing.T) {
query := models.GetSignedInUserQuery{UserId: ac2.Id}
err := GetSignedInUser(context.Background(), &query)
err := sqlStore.GetSignedInUser(context.Background(), &query)
require.NoError(t, err)
require.Equal(t, query.Result.OrgId, ac1.OrgId)
@@ -273,7 +273,7 @@ func TestAccountDataAccess(t *testing.T) {
require.NoError(t, err)
query := models.GetSignedInUserQuery{UserId: ac2.Id}
err = GetSignedInUser(context.Background(), &query)
err = sqlStore.GetSignedInUser(context.Background(), &query)
require.NoError(t, err)
require.Equal(t, query.Result.OrgId, ac2.OrgId)
@@ -282,7 +282,7 @@ func TestAccountDataAccess(t *testing.T) {
t.Run("Removing user from org should delete user completely if in no other org", func(t *testing.T) {
// make sure ac2 has no org
err := DeleteOrg(context.Background(), &models.DeleteOrgCommand{Id: ac2.OrgId})
err := sqlStore.DeleteOrg(context.Background(), &models.DeleteOrgCommand{Id: ac2.OrgId})
require.NoError(t, err)
// remove ac2 user from ac1 org
@@ -291,7 +291,7 @@ func TestAccountDataAccess(t *testing.T) {
require.NoError(t, err)
require.True(t, remCmd.UserWasDeleted)
err = GetSignedInUser(context.Background(), &models.GetSignedInUserQuery{UserId: ac2.Id})
err = sqlStore.GetSignedInUser(context.Background(), &models.GetSignedInUserQuery{UserId: ac2.Id})
require.Equal(t, err, models.ErrUserNotFound)
})

View File

@@ -127,7 +127,10 @@ func newSQLStore(cfg *setting.Cfg, cacheService *localcache.CacheService, bus bu
ss.addDashboardVersionQueryAndCommandHandlers()
ss.addAPIKeysQueryAndCommandHandlers()
ss.addPlaylistQueryAndCommandHandlers()
ss.addLoginAttemptQueryAndCommandHandlers()
ss.addTeamQueryAndCommandHandlers()
ss.addDashboardProvisioningQueryAndCommandHandlers()
ss.addOrgQueryAndCommandHandlers()
// if err := ss.Reset(); err != nil {
// return nil, err

View File

@@ -16,7 +16,7 @@ func (ss *SQLStore) addTeamQueryAndCommandHandlers() {
bus.AddHandler("sql", ss.DeleteTeam)
bus.AddHandler("sql", ss.SearchTeams)
bus.AddHandler("sql", ss.GetTeamById)
bus.AddHandler("sql", GetTeamsByUser)
bus.AddHandler("sql", ss.GetTeamsByUser)
bus.AddHandler("sql", ss.UpdateTeamMember)
bus.AddHandler("sql", ss.RemoveTeamMember)
@@ -106,7 +106,7 @@ func (ss *SQLStore) CreateTeam(name, email string, orgID int64) (models.Team, er
}
func (ss *SQLStore) UpdateTeam(ctx context.Context, cmd *models.UpdateTeamCommand) error {
return inTransaction(func(sess *DBSession) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
if isNameTaken, err := isTeamNameTaken(cmd.OrgId, cmd.Name, cmd.Id, sess); err != nil {
return err
} else if isNameTaken {
@@ -137,7 +137,7 @@ func (ss *SQLStore) UpdateTeam(ctx context.Context, cmd *models.UpdateTeamComman
// DeleteTeam will delete a team, its member and any permissions connected to the team
func (ss *SQLStore) DeleteTeam(ctx context.Context, cmd *models.DeleteTeamCommand) error {
return inTransaction(func(sess *DBSession) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
if _, err := teamExists(cmd.OrgId, cmd.Id, sess); err != nil {
return err
}
@@ -274,17 +274,19 @@ func (ss *SQLStore) GetTeamById(ctx context.Context, query *models.GetTeamByIdQu
}
// GetTeamsByUser is used by the Guardian when checking a users' permissions
func GetTeamsByUser(ctx context.Context, query *models.GetTeamsByUserQuery) error {
query.Result = make([]*models.TeamDTO, 0)
func (ss *SQLStore) GetTeamsByUser(ctx context.Context, query *models.GetTeamsByUserQuery) error {
return ss.WithDbSession(ctx, func(sess *DBSession) error {
query.Result = make([]*models.TeamDTO, 0)
var sql bytes.Buffer
var sql bytes.Buffer
sql.WriteString(getTeamSelectSQLBase([]string{}))
sql.WriteString(` INNER JOIN team_member on team.id = team_member.team_id`)
sql.WriteString(` WHERE team.org_id = ? and team_member.user_id = ?`)
sql.WriteString(getTeamSelectSQLBase([]string{}))
sql.WriteString(` INNER JOIN team_member on team.id = team_member.team_id`)
sql.WriteString(` WHERE team.org_id = ? and team_member.user_id = ?`)
err := x.SQL(sql.String(), query.OrgId, query.UserId).Find(&query.Result)
return err
err := sess.SQL(sql.String(), query.OrgId, query.UserId).Find(&query.Result)
return err
})
}
// AddTeamMember adds a user to a team
@@ -333,7 +335,7 @@ func getTeamMember(sess *DBSession, orgId int64, teamId int64, userId int64) (mo
// UpdateTeamMember updates a team member
func (ss *SQLStore) UpdateTeamMember(ctx context.Context, cmd *models.UpdateTeamMemberCommand) error {
return inTransaction(func(sess *DBSession) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
member, err := getTeamMember(sess, cmd.OrgId, cmd.TeamId, cmd.UserId)
if err != nil {
return err
@@ -359,7 +361,7 @@ func (ss *SQLStore) UpdateTeamMember(ctx context.Context, cmd *models.UpdateTeam
// RemoveTeamMember removes a member from a team
func (ss *SQLStore) RemoveTeamMember(ctx context.Context, cmd *models.RemoveTeamMemberCommand) error {
return inTransaction(func(sess *DBSession) error {
return ss.WithTransactionalDbSession(ctx, func(sess *DBSession) error {
if _, err := teamExists(cmd.OrgId, cmd.TeamId, sess); err != nil {
return err
}

View File

@@ -209,7 +209,7 @@ func TestTeamCommandsAndQueries(t *testing.T) {
require.NoError(t, err)
query := &models.GetTeamsByUserQuery{OrgId: testOrgID, UserId: userIds[0]}
err = GetTeamsByUser(context.Background(), query)
err = sqlStore.GetTeamsByUser(context.Background(), query)
require.NoError(t, err)
require.Equal(t, len(query.Result), 1)
require.Equal(t, query.Result[0].Name, "group2 name")

View File

@@ -546,7 +546,7 @@ func (ss *SQLStore) GetSignedInUserWithCacheCtx(ctx context.Context, query *mode
return nil
}
err := GetSignedInUser(ctx, query)
err := ss.GetSignedInUser(ctx, query)
if err != nil {
return err
}
@@ -556,7 +556,7 @@ func (ss *SQLStore) GetSignedInUserWithCacheCtx(ctx context.Context, query *mode
return nil
}
func GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error {
func (ss *SQLStore) GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) error {
orgId := "u.org_id"
if query.OrgId > 0 {
orgId = strconv.FormatInt(query.OrgId, 10)
@@ -603,7 +603,7 @@ func GetSignedInUser(ctx context.Context, query *models.GetSignedInUserQuery) er
}
getTeamsByUserQuery := &models.GetTeamsByUserQuery{OrgId: user.OrgId, UserId: user.UserId}
err = GetTeamsByUser(ctx, getTeamsByUserQuery)
err = ss.GetTeamsByUser(ctx, getTeamsByUserQuery)
if err != nil {
return err
}