Using native JSON operations to add thread participants (#18093)

Automatic Merge
This commit is contained in:
Agniva De Sarker
2021-08-12 20:15:03 +05:30
committed by GitHub
parent fd853e74a6
commit e1b0644b0d
5 changed files with 188 additions and 107 deletions

138
store/sqlstore/adapters.go Normal file
View File

@@ -0,0 +1,138 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package sqlstore
import (
"bytes"
"database/sql/driver"
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/dyatlov/go-opengraph/opengraph"
"github.com/mattermost/gorp"
"github.com/mattermost/mattermost-server/v6/model"
"github.com/mattermost/mattermost-server/v6/shared/i18n"
"github.com/mattermost/mattermost-server/v6/shared/mlog"
"github.com/pkg/errors"
)
type jsonArray []string
func (a jsonArray) Value() (driver.Value, error) {
var out bytes.Buffer
if err := out.WriteByte('['); err != nil {
return nil, err
}
for i, item := range a {
if _, err := out.WriteString(strconv.Quote(item)); err != nil {
return nil, err
}
// Skip the last element.
if i < len(a)-1 {
out.WriteByte(',')
}
}
if err := out.WriteByte(']'); err != nil {
return nil, err
}
return out.Bytes(), nil
}
type TraceOnAdapter struct{}
func (t *TraceOnAdapter) Printf(format string, v ...interface{}) {
originalString := fmt.Sprintf(format, v...)
newString := strings.ReplaceAll(originalString, "\n", " ")
newString = strings.ReplaceAll(newString, "\t", " ")
newString = strings.ReplaceAll(newString, "\"", "")
mlog.Debug(newString)
}
type JSONSerializable interface {
ToJson() string
}
type mattermConverter struct{}
func (me mattermConverter) ToDb(val interface{}) (interface{}, error) {
switch t := val.(type) {
case model.StringMap:
return model.MapToJson(t), nil
case map[string]string:
return model.MapToJson(model.StringMap(t)), nil
case model.StringArray:
return model.ArrayToJson(t), nil
case model.StringInterface:
return model.StringInterfaceToJson(t), nil
case map[string]interface{}:
return model.StringInterfaceToJson(model.StringInterface(t)), nil
case JSONSerializable:
return t.ToJson(), nil
case *opengraph.OpenGraph:
return json.Marshal(t)
}
return val, nil
}
func (me mattermConverter) FromDb(target interface{}) (gorp.CustomScanner, bool) {
switch target.(type) {
case *model.StringMap:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New(i18n.T("store.sql.convert_string_map"))
}
b := []byte(*s)
return json.Unmarshal(b, target)
}
return gorp.CustomScanner{Holder: new(string), Target: target, Binder: binder}, true
case *map[string]string:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New(i18n.T("store.sql.convert_string_map"))
}
b := []byte(*s)
return json.Unmarshal(b, target)
}
return gorp.CustomScanner{Holder: new(string), Target: target, Binder: binder}, true
case *model.StringArray:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New(i18n.T("store.sql.convert_string_array"))
}
b := []byte(*s)
return json.Unmarshal(b, target)
}
return gorp.CustomScanner{Holder: new(string), Target: target, Binder: binder}, true
case *model.StringInterface:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New(i18n.T("store.sql.convert_string_interface"))
}
b := []byte(*s)
return json.Unmarshal(b, target)
}
return gorp.CustomScanner{Holder: new(string), Target: target, Binder: binder}, true
case *map[string]interface{}:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New(i18n.T("store.sql.convert_string_interface"))
}
b := []byte(*s)
return json.Unmarshal(b, target)
}
return gorp.CustomScanner{Holder: new(string), Target: target, Binder: binder}, true
}
return gorp.CustomScanner{}, false
}

View File

@@ -0,0 +1,21 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package sqlstore
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestJSONArray(t *testing.T) {
input := []string{"a", "b"}
out, err := jsonArray(input).Value()
require.NoError(t, err)
outBuf, ok := out.([]byte)
require.True(t, ok)
assert.Equal(t, []byte(`["a","b"]`), outBuf)
}

View File

@@ -6,7 +6,6 @@ package sqlstore
import (
"context"
dbsql "database/sql"
"encoding/json"
"fmt"
"os"
"path/filepath"
@@ -17,7 +16,6 @@ import (
"time"
sq "github.com/Masterminds/squirrel"
"github.com/dyatlov/go-opengraph/opengraph"
"github.com/go-sql-driver/mysql"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
@@ -33,7 +31,6 @@ import (
"github.com/mattermost/mattermost-server/v6/db/migrations"
"github.com/mattermost/mattermost-server/v6/einterfaces"
"github.com/mattermost/mattermost-server/v6/model"
"github.com/mattermost/mattermost-server/v6/shared/i18n"
"github.com/mattermost/mattermost-server/v6/shared/mlog"
"github.com/mattermost/mattermost-server/v6/store"
)
@@ -155,22 +152,12 @@ type SqlStore struct {
metrics einterfaces.MetricsInterface
}
type TraceOnAdapter struct{}
// ColumnInfo holds information about a column.
type ColumnInfo struct {
DataType string
CharMaximumLength int
}
func (t *TraceOnAdapter) Printf(format string, v ...interface{}) {
originalString := fmt.Sprintf(format, v...)
newString := strings.ReplaceAll(originalString, "\n", " ")
newString = strings.ReplaceAll(newString, "\t", " ")
newString = strings.ReplaceAll(newString, "\"", "")
mlog.Debug(newString)
}
func New(settings model.SqlSettings, metrics einterfaces.MetricsInterface) *SqlStore {
store := &SqlStore{
rrCounter: 0,
@@ -1567,91 +1554,6 @@ func resetReadTimeout(dataSource string) (string, error) {
return config.FormatDSN(), nil
}
type mattermConverter struct{}
func (me mattermConverter) ToDb(val interface{}) (interface{}, error) {
switch t := val.(type) {
case model.StringMap:
return model.MapToJson(t), nil
case map[string]string:
return model.MapToJson(model.StringMap(t)), nil
case model.StringArray:
return model.ArrayToJson(t), nil
case model.StringInterface:
return model.StringInterfaceToJson(t), nil
case map[string]interface{}:
return model.StringInterfaceToJson(model.StringInterface(t)), nil
case JSONSerializable:
return t.ToJson(), nil
case *opengraph.OpenGraph:
return json.Marshal(t)
}
return val, nil
}
func (me mattermConverter) FromDb(target interface{}) (gorp.CustomScanner, bool) {
switch target.(type) {
case *model.StringMap:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New(i18n.T("store.sql.convert_string_map"))
}
b := []byte(*s)
return json.Unmarshal(b, target)
}
return gorp.CustomScanner{Holder: new(string), Target: target, Binder: binder}, true
case *map[string]string:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New(i18n.T("store.sql.convert_string_map"))
}
b := []byte(*s)
return json.Unmarshal(b, target)
}
return gorp.CustomScanner{Holder: new(string), Target: target, Binder: binder}, true
case *model.StringArray:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New(i18n.T("store.sql.convert_string_array"))
}
b := []byte(*s)
return json.Unmarshal(b, target)
}
return gorp.CustomScanner{Holder: new(string), Target: target, Binder: binder}, true
case *model.StringInterface:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New(i18n.T("store.sql.convert_string_interface"))
}
b := []byte(*s)
return json.Unmarshal(b, target)
}
return gorp.CustomScanner{Holder: new(string), Target: target, Binder: binder}, true
case *map[string]interface{}:
binder := func(holder, target interface{}) error {
s, ok := holder.(*string)
if !ok {
return errors.New(i18n.T("store.sql.convert_string_interface"))
}
b := []byte(*s)
return json.Unmarshal(b, target)
}
return gorp.CustomScanner{Holder: new(string), Target: target, Binder: binder}, true
}
return gorp.CustomScanner{}, false
}
type JSONSerializable interface {
ToJson() string
}
func convertMySQLFullTextColumnsToPostgres(columnNames string) string {
columns := strings.Split(columnNames, ", ")
concatenatedColumnNames := ""

View File

@@ -6,6 +6,7 @@ package sqlstore
import (
"context"
"database/sql"
"strconv"
"time"
sq "github.com/Masterminds/squirrel"
@@ -658,14 +659,21 @@ func (s *SqlThreadStore) MaintainMembership(userId, postId string, opts store.Th
}
if opts.UpdateParticipants {
thread, getErr := s.get(trx, postId)
if getErr != nil {
return nil, getErr
}
if thread != nil && !thread.Participants.Contains(userId) {
thread.Participants = append(thread.Participants, userId)
if _, err = s.update(trx, thread); err != nil {
return nil, err
if s.DriverName() == model.DatabaseDriverPostgres {
if _, err2 := trx.Exec(`UPDATE Threads
SET participants = participants || $1::jsonb
WHERE postid=$2
AND NOT participants ? $3`, jsonArray([]string{userId}), postId, userId); err2 != nil {
return nil, err2
}
} else {
// CONCAT('$[', JSON_LENGTH(Participants), ']') just generates $[n]
// which is the positional syntax required for appending.
if _, err2 := trx.Exec(`UPDATE Threads
SET Participants = JSON_ARRAY_INSERT(Participants, CONCAT('$[', JSON_LENGTH(Participants), ']'), ?)
WHERE PostId=?
AND NOT JSON_CONTAINS(Participants, ?)`, userId, postId, strconv.Quote(userId)); err2 != nil {
return nil, err2
}
}
}

View File

@@ -330,18 +330,30 @@ func testThreadStorePopulation(t *testing.T, ss store.Store) {
IncrementMentions: false,
UpdateFollowing: true,
UpdateViewedTimestamp: false,
UpdateParticipants: false,
UpdateParticipants: true,
}
tm, e := ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, opts)
require.NoError(t, e)
require.Equal(t, int64(0), tm.LastViewed)
// No update since array has same elements.
th, e := ss.Thread().Get(newPosts[0].Id)
require.NoError(t, e)
assert.ElementsMatch(t, model.StringArray{newPosts[0].UserId, newPosts[1].UserId}, th.Participants)
opts.UpdateViewedTimestamp = true
_, e = ss.Thread().MaintainMembership(newPosts[0].UserId, newPosts[0].Id, opts)
require.NoError(t, e)
m2, err2 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
require.NoError(t, err2)
require.Greater(t, m2.LastViewed, int64(0))
// Adding a new participant
_, e = ss.Thread().MaintainMembership("newuser", newPosts[0].Id, opts)
require.NoError(t, e)
th, e = ss.Thread().Get(newPosts[0].Id)
require.NoError(t, e)
assert.ElementsMatch(t, model.StringArray{newPosts[0].UserId, newPosts[1].UserId, "newuser"}, th.Participants)
})
t.Run("Thread membership 'viewed' timestamp is updated properly for new membership", func(t *testing.T) {