Give dialects control over how insert and update queries are performed (#79946)

* Refactor insert, update
* Add separate insert, update methods
* Refactor Insert, Update signatures
This commit is contained in:
Arati R 2024-01-11 19:55:45 +01:00 committed by GitHub
parent 4e6b0fd9ce
commit ca9d147a44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 37 deletions

View File

@ -1,10 +1,12 @@
package migrator
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/grafana/grafana/pkg/services/sqlstore/session"
"golang.org/x/exp/slices"
"xorm.io/xorm"
)
@ -82,6 +84,13 @@ type Dialect interface {
// column names to values to use in the where clause.
// It returns a query string and a slice of parameters that can be executed against the database.
UpdateQuery(tableName string, row map[string]any, where map[string]any) (string, []any, error)
// Insert accepts a table name and a map of column names to insert.
// The insert is executed as part of the provided session.
Insert(ctx context.Context, tx *session.SessionTx, tableName string, row map[string]any) error
// Update accepts a table name, a map of column names to values to update, and a map of
// column names to values to use in the where clause.
// The update is executed as part of the provided session.
Update(ctx context.Context, tx *session.SessionTx, tableName string, row map[string]any, where map[string]any) error
}
type LockCfg struct {
@ -421,3 +430,23 @@ func (b *BaseDialect) UpdateQuery(tableName string, row map[string]any, where ma
return fmt.Sprintf("UPDATE %s SET %s WHERE %s", b.dialect.Quote(tableName), strings.Join(cols, ", "), strings.Join(whereCols, " AND ")), vals, nil
}
func (b *BaseDialect) Insert(ctx context.Context, tx *session.SessionTx, tableName string, row map[string]any) error {
query, args, err := b.InsertQuery(tableName, row)
if err != nil {
return err
}
_, err = tx.Exec(ctx, query, args...)
return err
}
func (b *BaseDialect) Update(ctx context.Context, tx *session.SessionTx, tableName string, row map[string]any, where map[string]any) error {
query, args, err := b.UpdateQuery(tableName, row, where)
if err != nil {
return err
}
_, err = tx.Exec(ctx, query, args...)
return err
}

View File

@ -443,31 +443,13 @@ func (s *sqlEntityServer) Create(ctx context.Context, r *entity.CreateEntityRequ
}
// 1. Add row to the `entity_history` values
query, args, err := s.dialect.InsertQuery("entity_history", values)
if err != nil {
s.log.Error("error building entity history insert", "msg", err.Error())
return err
}
s.log.Debug("create", "query", query, "args", args)
_, err = tx.Exec(ctx, query, args...)
if err != nil {
s.log.Error("error writing entity history", "msg", err.Error())
if err := s.dialect.Insert(ctx, tx, "entity_history", values); err != nil {
s.log.Error("error inserting entity history", "msg", err.Error())
return err
}
// 2. Add row to the main `entity` table
query, args, err = s.dialect.InsertQuery("entity", values)
if err != nil {
s.log.Error("error building entity insert sql", "msg", err.Error())
return err
}
s.log.Debug("create", "query", query, "args", args)
_, err = tx.Exec(ctx, query, args...)
if err != nil {
if err := s.dialect.Insert(ctx, tx, "entity", values); err != nil {
s.log.Error("error inserting entity", "msg", err.Error())
return err
}
@ -666,15 +648,8 @@ func (s *sqlEntityServer) Update(ctx context.Context, r *entity.UpdateEntityRequ
}
// 1. Add the `entity_history` values
query, args, err := s.dialect.InsertQuery("entity_history", values)
if err != nil {
s.log.Error("error building entity history insert", "msg", err.Error())
return err
}
_, err = tx.Exec(ctx, query, args...)
if err != nil {
s.log.Error("error writing entity history", "msg", err.Error())
if err := s.dialect.Insert(ctx, tx, "entity_history", values); err != nil {
s.log.Error("error inserting entity history", "msg", err.Error())
return err
}
@ -690,19 +665,15 @@ func (s *sqlEntityServer) Update(ctx context.Context, r *entity.UpdateEntityRequ
delete(values, "created_at")
delete(values, "created_by")
query, args, err = s.dialect.UpdateQuery(
err = s.dialect.Update(
ctx,
tx,
"entity",
values,
map[string]any{
"guid": current.Guid,
},
)
if err != nil {
s.log.Error("error building entity update sql", "msg", err.Error())
return err
}
_, err = tx.Exec(ctx, query, args...)
if err != nil {
s.log.Error("error updating entity", "msg", err.Error())
return err