Refactor Get/Create Direct Channel into one function (#9867)

* refactor GetDirectChannel and CreateDirectChannel in one function

* remove CreateDirectChannel plugin api and update GetDirectChannel and GetGroupChannel plugin api

* update tests
This commit is contained in:
Carlos Tadeu Panato Junior
2018-11-28 18:01:49 +01:00
committed by GitHub
parent 09eae76c00
commit 1bcf08aa4b
18 changed files with 241 additions and 402 deletions

View File

@@ -426,7 +426,7 @@ func (me *TestHelper) CreateDmChannel(user *model.User) *model.Channel {
utils.DisableDebugLogForTest()
var err *model.AppError
var channel *model.Channel
if channel, err = me.App.CreateDirectChannel(me.BasicUser.Id, user.Id); err != nil {
if channel, err = me.App.GetOrCreateDirectChannel(me.BasicUser.Id, user.Id); err != nil {
mlog.Error(err.Error())
time.Sleep(time.Second)

View File

@@ -339,7 +339,7 @@ func createDirectChannel(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
sc, err := c.App.CreateDirectChannel(userIds[0], userIds[1])
sc, err := c.App.GetOrCreateDirectChannel(userIds[0], userIds[1])
if err != nil {
c.Err = err
return

View File

@@ -410,7 +410,7 @@ func TestCreatePostAll(t *testing.T) {
user := model.User{Email: th.GenerateTestEmail(), Nickname: "Joram Wilander", Password: "hello1", Username: GenerateTestUsername(), Roles: model.SYSTEM_USER_ROLE_ID}
directChannel, _ := th.App.CreateDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
directChannel, _ := th.App.GetOrCreateDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
ruser, resp := Client.CreateUser(&user)
CheckNoError(t, resp)

View File

@@ -281,7 +281,7 @@ func (me *TestHelper) CreateDmChannel(user *model.User) *model.Channel {
utils.DisableDebugLogForTest()
var err *model.AppError
var channel *model.Channel
if channel, err = me.App.CreateDirectChannel(me.BasicUser.Id, user.Id); err != nil {
if channel, err = me.App.GetOrCreateDirectChannel(me.BasicUser.Id, user.Id); err != nil {
mlog.Error(err.Error())
time.Sleep(time.Second)

View File

@@ -224,35 +224,42 @@ func (a *App) CreateChannel(channel *model.Channel, addMember bool) (*model.Chan
return sc, nil
}
func (a *App) CreateDirectChannel(userId string, otherUserId string) (*model.Channel, *model.AppError) {
channel, err := a.createDirectChannel(userId, otherUserId)
if err != nil {
if err.Id == store.CHANNEL_EXISTS_ERROR {
func (a *App) GetOrCreateDirectChannel(userId, otherUserId string) (*model.Channel, *model.AppError) {
result := <-a.Srv.Store.Channel().GetByName("", model.GetDMNameFromIds(userId, otherUserId), true)
if result.Err != nil {
if result.Err.Id == store.MISSING_CHANNEL_ERROR {
channel, err := a.createDirectChannel(userId, otherUserId)
if err != nil {
if err.Id == store.CHANNEL_EXISTS_ERROR {
return channel, nil
}
return nil, err
}
a.WaitForChannelMembership(channel.Id, userId)
a.InvalidateCacheForUser(userId)
a.InvalidateCacheForUser(otherUserId)
if pluginsEnvironment := a.GetPluginsEnvironment(); pluginsEnvironment != nil {
a.Srv.Go(func() {
pluginContext := &plugin.Context{}
pluginsEnvironment.RunMultiPluginHook(func(hooks plugin.Hooks) bool {
hooks.ChannelHasBeenCreated(pluginContext, channel)
return true
}, plugin.ChannelHasBeenCreatedId)
})
}
message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_DIRECT_ADDED, "", channel.Id, "", nil)
message.Add("teammate_id", otherUserId)
a.Publish(message)
return channel, nil
}
return nil, err
return nil, model.NewAppError("GetOrCreateDMChannel", "web.incoming_webhook.channel.app_error", nil, "err="+result.Err.Message, result.Err.StatusCode)
}
a.WaitForChannelMembership(channel.Id, userId)
a.InvalidateCacheForUser(userId)
a.InvalidateCacheForUser(otherUserId)
if pluginsEnvironment := a.GetPluginsEnvironment(); pluginsEnvironment != nil {
a.Srv.Go(func() {
pluginContext := &plugin.Context{}
pluginsEnvironment.RunMultiPluginHook(func(hooks plugin.Hooks) bool {
hooks.ChannelHasBeenCreated(pluginContext, channel)
return true
}, plugin.ChannelHasBeenCreatedId)
})
}
message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_DIRECT_ADDED, "", channel.Id, "", nil)
message.Add("teammate_id", otherUserId)
a.Publish(message)
return channel, nil
return result.Data.(*model.Channel), nil
}
func (a *App) createDirectChannel(userId string, otherUserId string) (*model.Channel, *model.AppError) {
@@ -1764,32 +1771,6 @@ func (a *App) GetPinnedPosts(channelId string) (*model.PostList, *model.AppError
return result.Data.(*model.PostList), nil
}
func (a *App) GetDirectChannel(userId1, userId2 string) (*model.Channel, *model.AppError) {
result := <-a.Srv.Store.Channel().GetByName("", model.GetDMNameFromIds(userId1, userId2), true)
if result.Err != nil {
if result.Err.Id == store.MISSING_CHANNEL_ERROR {
result := <-a.Srv.Store.Channel().CreateDirectChannel(userId1, userId2)
if result.Err != nil {
return nil, model.NewAppError("GetOrCreateDMChannel", "web.incoming_webhook.channel.app_error", nil, "err="+result.Err.Message, http.StatusBadRequest)
}
a.InvalidateCacheForUser(userId1)
a.InvalidateCacheForUser(userId2)
channel := result.Data.(*model.Channel)
if result := <-a.Srv.Store.ChannelMemberHistory().LogJoinEvent(userId1, channel.Id, model.GetMillis()); result.Err != nil {
mlog.Warn(fmt.Sprintf("Failed to update ChannelMemberHistory table %v", result.Err))
}
if result := <-a.Srv.Store.ChannelMemberHistory().LogJoinEvent(userId2, channel.Id, model.GetMillis()); result.Err != nil {
mlog.Warn(fmt.Sprintf("Failed to update ChannelMemberHistory table %v", result.Err))
}
return channel, nil
}
return nil, model.NewAppError("GetOrCreateDMChannel", "web.incoming_webhook.channel.app_error", nil, "err="+result.Err.Message, result.Err.StatusCode)
}
return result.Data.(*model.Channel), nil
}
func (a *App) ToggleMuteChannel(channelId string, userId string) *model.ChannelMember {
result := <-a.Srv.Store.Channel().GetMember(channelId, userId)

View File

@@ -317,7 +317,7 @@ func TestCreateDirectChannelCreatesChannelMemberHistoryRecord(t *testing.T) {
user1 := th.CreateUser()
user2 := th.CreateUser()
if channel, err := th.App.CreateDirectChannel(user1.Id, user2.Id); err != nil {
if channel, err := th.App.GetOrCreateDirectChannel(user1.Id, user2.Id); err != nil {
t.Fatal("Failed to create direct channel. Error: " + err.Message)
} else {
// there should be a ChannelMemberHistory record for both users
@@ -345,7 +345,7 @@ func TestGetDirectChannelCreatesChannelMemberHistoryRecord(t *testing.T) {
user2 := th.CreateUser()
// this function call implicitly creates a direct channel between the two users if one doesn't already exist
if channel, err := th.App.GetDirectChannel(user1.Id, user2.Id); err != nil {
if channel, err := th.App.GetOrCreateDirectChannel(user1.Id, user2.Id); err != nil {
t.Fatal("Failed to create direct channel. Error: " + err.Message)
} else {
// there should be a ChannelMemberHistory record for both users

View File

@@ -70,7 +70,7 @@ func (me *msgProvider) DoCommand(a *App, args *model.CommandArgs, message string
return &model.CommandResponse{Text: args.T("api.command_msg.permission.app_error"), ResponseType: model.COMMAND_RESPONSE_TYPE_EPHEMERAL}
}
if directChannel, err := a.CreateDirectChannel(args.UserId, userProfile.Id); err != nil {
if directChannel, err := a.GetOrCreateDirectChannel(args.UserId, userProfile.Id); err != nil {
mlog.Error(err.Error())
return &model.CommandResponse{Text: args.T("api.command_msg.dm_fail.app_error"), ResponseType: model.COMMAND_RESPONSE_TYPE_EPHEMERAL}
} else {

View File

@@ -173,7 +173,7 @@ func TestMuteCommandDMChannel(t *testing.T) {
t.SkipNow()
}
channel2, _ := th.App.CreateDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
channel2, _ := th.App.GetOrCreateDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
channel2M, _ := th.App.GetChannelMember(channel2.Id, th.BasicUser.Id)
assert.Equal(t, model.CHANNEL_NOTIFY_ALL, channel2M.NotifyProps[model.MARK_UNREAD_NOTIFY_PROP])

View File

@@ -1835,9 +1835,8 @@ func TestImportImportDirectChannel(t *testing.T) {
},
Header: ptrStr("Channel Header"),
}
if err := th.App.ImportDirectChannel(&data, true); err == nil {
t.Fatalf("Expected error due to invalid name.")
}
err := th.App.ImportDirectChannel(&data, true)
require.NotNil(t, err)
// Check that no more channels are in the DB.
AssertChannelCount(t, th.App, model.CHANNEL_DIRECT, directChannelCount)
@@ -1848,9 +1847,8 @@ func TestImportImportDirectChannel(t *testing.T) {
model.NewId(),
model.NewId(),
}
if err := th.App.ImportDirectChannel(&data, true); err != nil {
t.Fatalf("Expected success as cannot validate existence of channel members in dry run mode.")
}
err = th.App.ImportDirectChannel(&data, true)
require.Nil(t, err)
// Check that no more channels are in the DB.
AssertChannelCount(t, th.App, model.CHANNEL_DIRECT, directChannelCount)
@@ -1862,9 +1860,8 @@ func TestImportImportDirectChannel(t *testing.T) {
model.NewId(),
model.NewId(),
}
if err := th.App.ImportDirectChannel(&data, true); err != nil {
t.Fatalf("Expected success as cannot validate existence of channel members in dry run mode.")
}
err = th.App.ImportDirectChannel(&data, true)
require.Nil(t, err)
// Check that no more channels are in the DB.
AssertChannelCount(t, th.App, model.CHANNEL_DIRECT, directChannelCount)
@@ -1874,9 +1871,8 @@ func TestImportImportDirectChannel(t *testing.T) {
data.Members = &[]string{
model.NewId(),
}
if err := th.App.ImportDirectChannel(&data, false); err == nil {
t.Fatalf("Expected error due to invalid member (apply mode).")
}
err = th.App.ImportDirectChannel(&data, false)
require.NotNil(t, err)
// Check that no more channels are in the DB.
AssertChannelCount(t, th.App, model.CHANNEL_DIRECT, directChannelCount)
@@ -1887,18 +1883,16 @@ func TestImportImportDirectChannel(t *testing.T) {
th.BasicUser.Username,
th.BasicUser2.Username,
}
if err := th.App.ImportDirectChannel(&data, false); err != nil {
t.Fatalf("Expected success: %v", err.Error())
}
err = th.App.ImportDirectChannel(&data, false)
require.Nil(t, err)
// Check that one more DIRECT channel is in the DB.
AssertChannelCount(t, th.App, model.CHANNEL_DIRECT, directChannelCount+1)
AssertChannelCount(t, th.App, model.CHANNEL_GROUP, groupChannelCount)
// Do the same DIRECT channel again.
if err := th.App.ImportDirectChannel(&data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectChannel(&data, false)
require.Nil(t, err)
// Check that no more channels are in the DB.
AssertChannelCount(t, th.App, model.CHANNEL_DIRECT, directChannelCount+1)
@@ -1906,22 +1900,17 @@ func TestImportImportDirectChannel(t *testing.T) {
// Update the channel's HEADER
data.Header = ptrStr("New Channel Header 2")
if err := th.App.ImportDirectChannel(&data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectChannel(&data, false)
require.Nil(t, err)
// Check that no more channels are in the DB.
AssertChannelCount(t, th.App, model.CHANNEL_DIRECT, directChannelCount+1)
AssertChannelCount(t, th.App, model.CHANNEL_GROUP, groupChannelCount)
// Get the channel to check that the header was updated.
if channel, err := th.App.createDirectChannel(th.BasicUser.Id, th.BasicUser2.Id); err == nil || err.Id != store.CHANNEL_EXISTS_ERROR {
t.Fatal("Should have got store.CHANNEL_EXISTS_ERROR")
} else {
if channel.Header != *data.Header {
t.Fatal("Channel header has not been updated successfully.")
}
}
channel, err := th.App.GetOrCreateDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
require.Nil(t, err)
require.Equal(t, channel.Header, *data.Header)
// Do a GROUP channel with an extra invalid member.
user3 := th.CreateUser()
@@ -1931,9 +1920,8 @@ func TestImportImportDirectChannel(t *testing.T) {
user3.Username,
model.NewId(),
}
if err := th.App.ImportDirectChannel(&data, false); err == nil {
t.Fatalf("Should have failed due to invalid member in list.")
}
err = th.App.ImportDirectChannel(&data, false)
require.NotNil(t, err)
// Check that no more channels are in the DB.
AssertChannelCount(t, th.App, model.CHANNEL_DIRECT, directChannelCount+1)
@@ -1945,18 +1933,16 @@ func TestImportImportDirectChannel(t *testing.T) {
th.BasicUser2.Username,
user3.Username,
}
if err := th.App.ImportDirectChannel(&data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectChannel(&data, false)
require.Nil(t, err)
// Check that one more GROUP channel is in the DB.
AssertChannelCount(t, th.App, model.CHANNEL_DIRECT, directChannelCount+1)
AssertChannelCount(t, th.App, model.CHANNEL_GROUP, groupChannelCount+1)
// Do the same DIRECT channel again.
if err := th.App.ImportDirectChannel(&data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectChannel(&data, false)
require.Nil(t, err)
// Check that no more channels are in the DB.
AssertChannelCount(t, th.App, model.CHANNEL_DIRECT, directChannelCount+1)
@@ -1964,9 +1950,8 @@ func TestImportImportDirectChannel(t *testing.T) {
// Update the channel's HEADER
data.Header = ptrStr("New Channel Header 3")
if err := th.App.ImportDirectChannel(&data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectChannel(&data, false)
require.Nil(t, err)
// Check that no more channels are in the DB.
AssertChannelCount(t, th.App, model.CHANNEL_DIRECT, directChannelCount+1)
@@ -1978,13 +1963,9 @@ func TestImportImportDirectChannel(t *testing.T) {
th.BasicUser2.Id,
user3.Id,
}
if channel, err := th.App.createGroupChannel(userIds, th.BasicUser.Id); err.Id != store.CHANNEL_EXISTS_ERROR {
t.Fatal("Should have got store.CHANNEL_EXISTS_ERROR")
} else {
if channel.Header != *data.Header {
t.Fatal("Channel header has not been updated successfully.")
}
}
channel, err = th.App.createGroupChannel(userIds, th.BasicUser.Id)
require.Equal(t, err.Id, store.CHANNEL_EXISTS_ERROR)
require.Equal(t, channel.Header, *data.Header)
// Import a channel with some favorites.
data.Members = &[]string{
@@ -1995,16 +1976,14 @@ func TestImportImportDirectChannel(t *testing.T) {
th.BasicUser.Username,
th.BasicUser2.Username,
}
if err := th.App.ImportDirectChannel(&data, false); err != nil {
t.Fatal(err)
}
err = th.App.ImportDirectChannel(&data, false)
require.Nil(t, err)
channel, err = th.App.GetOrCreateDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
require.Nil(t, err)
checkPreference(t, th.App, th.BasicUser.Id, model.PREFERENCE_CATEGORY_FAVORITE_CHANNEL, channel.Id, "true")
checkPreference(t, th.App, th.BasicUser2.Id, model.PREFERENCE_CATEGORY_FAVORITE_CHANNEL, channel.Id, "true")
if channel, err := th.App.createDirectChannel(th.BasicUser.Id, th.BasicUser2.Id); err == nil || err.Id != store.CHANNEL_EXISTS_ERROR {
t.Fatal("Should have got store.CHANNEL_EXISTS_ERROR")
} else {
checkPreference(t, th.App, th.BasicUser.Id, model.PREFERENCE_CATEGORY_FAVORITE_CHANNEL, channel.Id, "true")
checkPreference(t, th.App, th.BasicUser2.Id, model.PREFERENCE_CATEGORY_FAVORITE_CHANNEL, channel.Id, "true")
}
}
func TestImportImportDirectPost(t *testing.T) {
@@ -2018,25 +1997,21 @@ func TestImportImportDirectPost(t *testing.T) {
th.BasicUser2.Username,
},
}
if err := th.App.ImportDirectChannel(&channelData, false); err != nil {
t.Fatalf("Expected success: %v", err.Error())
}
err := th.App.ImportDirectChannel(&channelData, false)
require.Nil(t, err)
// Get the channel.
var directChannel *model.Channel
if channel, err := th.App.createDirectChannel(th.BasicUser.Id, th.BasicUser2.Id); err.Id != store.CHANNEL_EXISTS_ERROR {
t.Fatal("Should have got store.CHANNEL_EXISTS_ERROR")
} else {
directChannel = channel
}
channel, err := th.App.GetOrCreateDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
require.Nil(t, err)
require.NotEmpty(t, channel)
directChannel = channel
// Get the number of posts in the system.
var initialPostCount int64
if result := <-th.App.Srv.Store.Post().AnalyticsPostCount("", false, false); result.Err != nil {
t.Fatal(result.Err)
} else {
initialPostCount = result.Data.(int64)
}
result := <-th.App.Srv.Store.Post().AnalyticsPostCount("", false, false)
require.Nil(t, result.Err)
initialPostCount = result.Data.(int64)
// Try adding an invalid post in dry run mode.
data := &DirectPostImportData{
@@ -2047,9 +2022,8 @@ func TestImportImportDirectPost(t *testing.T) {
User: ptrStr(th.BasicUser.Username),
CreateAt: ptrInt64(model.GetMillis()),
}
if err := th.App.ImportDirectPost(data, true); err == nil {
t.Fatalf("Expected error.")
}
err = th.App.ImportDirectPost(data, true)
require.NotNil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 0, "")
// Try adding a valid post in dry run mode.
@@ -2062,9 +2036,8 @@ func TestImportImportDirectPost(t *testing.T) {
Message: ptrStr("Message"),
CreateAt: ptrInt64(model.GetMillis()),
}
if err := th.App.ImportDirectPost(data, true); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectPost(data, true)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 0, "")
// Try adding an invalid post in apply mode.
@@ -2077,9 +2050,8 @@ func TestImportImportDirectPost(t *testing.T) {
Message: ptrStr("Message"),
CreateAt: ptrInt64(model.GetMillis()),
}
if err := th.App.ImportDirectPost(data, false); err == nil {
t.Fatalf("Expected error.")
}
err = th.App.ImportDirectPost(data, false)
require.NotNil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 0, "")
// Try adding a valid post in apply mode.
@@ -2092,82 +2064,69 @@ func TestImportImportDirectPost(t *testing.T) {
Message: ptrStr("Message"),
CreateAt: ptrInt64(model.GetMillis()),
}
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success: %v", err.Error())
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 1, "")
// Check the post values.
if result := <-th.App.Srv.Store.Post().GetPostsCreatedAt(directChannel.Id, *data.CreateAt); result.Err != nil {
t.Fatal(result.Err.Error())
} else {
posts := result.Data.([]*model.Post)
if len(posts) != 1 {
t.Fatal("Unexpected number of posts found.")
}
post := posts[0]
if post.Message != *data.Message || post.CreateAt != *data.CreateAt || post.UserId != th.BasicUser.Id {
t.Fatal("Post properties not as expected")
}
}
result = <-th.App.Srv.Store.Post().GetPostsCreatedAt(directChannel.Id, *data.CreateAt)
require.Nil(t, result.Err)
posts := result.Data.([]*model.Post)
require.Equal(t, len(posts), 1)
post := posts[0]
require.Equal(t, post.Message, *data.Message)
require.Equal(t, post.CreateAt, *data.CreateAt)
require.Equal(t, post.UserId, th.BasicUser.Id)
// Import the post again.
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 1, "")
// Check the post values.
if result := <-th.App.Srv.Store.Post().GetPostsCreatedAt(directChannel.Id, *data.CreateAt); result.Err != nil {
t.Fatal(result.Err.Error())
} else {
posts := result.Data.([]*model.Post)
if len(posts) != 1 {
t.Fatal("Unexpected number of posts found.")
}
post := posts[0]
if post.Message != *data.Message || post.CreateAt != *data.CreateAt || post.UserId != th.BasicUser.Id {
t.Fatal("Post properties not as expected")
}
}
result = <-th.App.Srv.Store.Post().GetPostsCreatedAt(directChannel.Id, *data.CreateAt)
require.Nil(t, result.Err)
posts = result.Data.([]*model.Post)
require.Equal(t, len(posts), 1)
post = posts[0]
require.Equal(t, post.Message, *data.Message)
require.Equal(t, post.CreateAt, *data.CreateAt)
require.Equal(t, post.UserId, th.BasicUser.Id)
// Save the post with a different time.
data.CreateAt = ptrInt64(*data.CreateAt + 1)
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 2, "")
// Save the post with a different message.
data.Message = ptrStr("Message 2")
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 3, "")
// Test with hashtags
data.Message = ptrStr("Message 2 #hashtagmashupcity")
data.CreateAt = ptrInt64(*data.CreateAt + 1)
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 4, "")
if result := <-th.App.Srv.Store.Post().GetPostsCreatedAt(directChannel.Id, *data.CreateAt); result.Err != nil {
t.Fatal(result.Err.Error())
} else {
posts := result.Data.([]*model.Post)
if len(posts) != 1 {
t.Fatal("Unexpected number of posts found.")
}
post := posts[0]
if post.Message != *data.Message || post.CreateAt != *data.CreateAt || post.UserId != th.BasicUser.Id {
t.Fatal("Post properties not as expected")
}
if post.Hashtags != "#hashtagmashupcity" {
t.Fatalf("Hashtags not as expected: %s", post.Hashtags)
}
}
result = <-th.App.Srv.Store.Post().GetPostsCreatedAt(directChannel.Id, *data.CreateAt)
require.Nil(t, result.Err)
posts = result.Data.([]*model.Post)
require.Equal(t, len(posts), 1)
post = posts[0]
require.Equal(t, post.Message, *data.Message)
require.Equal(t, post.CreateAt, *data.CreateAt)
require.Equal(t, post.UserId, th.BasicUser.Id)
require.Equal(t, post.Hashtags, "#hashtagmashupcity")
// Test with some flags.
data = &DirectPostImportData{
@@ -2184,22 +2143,19 @@ func TestImportImportDirectPost(t *testing.T) {
CreateAt: ptrInt64(model.GetMillis()),
}
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success: %v", err.Error())
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
// Check the post values.
if result := <-th.App.Srv.Store.Post().GetPostsCreatedAt(directChannel.Id, *data.CreateAt); result.Err != nil {
t.Fatal(result.Err.Error())
} else {
posts := result.Data.([]*model.Post)
if len(posts) != 1 {
t.Fatal("Unexpected number of posts found.")
}
post := posts[0]
checkPreference(t, th.App, th.BasicUser.Id, model.PREFERENCE_CATEGORY_FLAGGED_POST, post.Id, "true")
checkPreference(t, th.App, th.BasicUser2.Id, model.PREFERENCE_CATEGORY_FLAGGED_POST, post.Id, "true")
}
result = <-th.App.Srv.Store.Post().GetPostsCreatedAt(directChannel.Id, *data.CreateAt)
require.Nil(t, result.Err)
posts = result.Data.([]*model.Post)
require.Equal(t, len(posts), 1)
post = posts[0]
checkPreference(t, th.App, th.BasicUser.Id, model.PREFERENCE_CATEGORY_FLAGGED_POST, post.Id, "true")
checkPreference(t, th.App, th.BasicUser2.Id, model.PREFERENCE_CATEGORY_FLAGGED_POST, post.Id, "true")
// ------------------ Group Channel -------------------------
@@ -2212,9 +2168,8 @@ func TestImportImportDirectPost(t *testing.T) {
user3.Username,
},
}
if err := th.App.ImportDirectChannel(&channelData, false); err != nil {
t.Fatalf("Expected success: %v", err.Error())
}
err = th.App.ImportDirectChannel(&channelData, false)
require.Nil(t, err)
// Get the channel.
var groupChannel *model.Channel
@@ -2223,18 +2178,14 @@ func TestImportImportDirectPost(t *testing.T) {
th.BasicUser2.Id,
user3.Id,
}
if channel, err := th.App.createGroupChannel(userIds, th.BasicUser.Id); err.Id != store.CHANNEL_EXISTS_ERROR {
t.Fatal("Should have got store.CHANNEL_EXISTS_ERROR")
} else {
groupChannel = channel
}
channel, err = th.App.createGroupChannel(userIds, th.BasicUser.Id)
require.Equal(t, err.Id, store.CHANNEL_EXISTS_ERROR)
groupChannel = channel
// Get the number of posts in the system.
if result := <-th.App.Srv.Store.Post().AnalyticsPostCount("", false, false); result.Err != nil {
t.Fatal(result.Err)
} else {
initialPostCount = result.Data.(int64)
}
result = <-th.App.Srv.Store.Post().AnalyticsPostCount("", false, false)
require.Nil(t, result.Err)
initialPostCount = result.Data.(int64)
// Try adding an invalid post in dry run mode.
data = &DirectPostImportData{
@@ -2246,9 +2197,8 @@ func TestImportImportDirectPost(t *testing.T) {
User: ptrStr(th.BasicUser.Username),
CreateAt: ptrInt64(model.GetMillis()),
}
if err := th.App.ImportDirectPost(data, true); err == nil {
t.Fatalf("Expected error.")
}
err = th.App.ImportDirectPost(data, true)
require.NotNil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 0, "")
// Try adding a valid post in dry run mode.
@@ -2262,9 +2212,8 @@ func TestImportImportDirectPost(t *testing.T) {
Message: ptrStr("Message"),
CreateAt: ptrInt64(model.GetMillis()),
}
if err := th.App.ImportDirectPost(data, true); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectPost(data, true)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 0, "")
// Try adding an invalid post in apply mode.
@@ -2279,9 +2228,8 @@ func TestImportImportDirectPost(t *testing.T) {
Message: ptrStr("Message"),
CreateAt: ptrInt64(model.GetMillis()),
}
if err := th.App.ImportDirectPost(data, false); err == nil {
t.Fatalf("Expected error.")
}
err = th.App.ImportDirectPost(data, false)
require.NotNil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 0, "")
// Try adding a valid post in apply mode.
@@ -2295,82 +2243,69 @@ func TestImportImportDirectPost(t *testing.T) {
Message: ptrStr("Message"),
CreateAt: ptrInt64(model.GetMillis()),
}
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success: %v", err.Error())
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 1, "")
// Check the post values.
if result := <-th.App.Srv.Store.Post().GetPostsCreatedAt(groupChannel.Id, *data.CreateAt); result.Err != nil {
t.Fatal(result.Err.Error())
} else {
posts := result.Data.([]*model.Post)
if len(posts) != 1 {
t.Fatal("Unexpected number of posts found.")
}
post := posts[0]
if post.Message != *data.Message || post.CreateAt != *data.CreateAt || post.UserId != th.BasicUser.Id {
t.Fatal("Post properties not as expected")
}
}
result = <-th.App.Srv.Store.Post().GetPostsCreatedAt(groupChannel.Id, *data.CreateAt)
require.Nil(t, result.Err)
posts = result.Data.([]*model.Post)
require.Equal(t, len(posts), 1)
post = posts[0]
require.Equal(t, post.Message, *data.Message)
require.Equal(t, post.CreateAt, *data.CreateAt)
require.Equal(t, post.UserId, th.BasicUser.Id)
// Import the post again.
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 1, "")
// Check the post values.
if result := <-th.App.Srv.Store.Post().GetPostsCreatedAt(groupChannel.Id, *data.CreateAt); result.Err != nil {
t.Fatal(result.Err.Error())
} else {
posts := result.Data.([]*model.Post)
if len(posts) != 1 {
t.Fatal("Unexpected number of posts found.")
}
post := posts[0]
if post.Message != *data.Message || post.CreateAt != *data.CreateAt || post.UserId != th.BasicUser.Id {
t.Fatal("Post properties not as expected")
}
}
result = <-th.App.Srv.Store.Post().GetPostsCreatedAt(groupChannel.Id, *data.CreateAt)
require.Nil(t, result.Err)
posts = result.Data.([]*model.Post)
require.Equal(t, len(posts), 1)
post = posts[0]
require.Equal(t, post.Message, *data.Message)
require.Equal(t, post.CreateAt, *data.CreateAt)
require.Equal(t, post.UserId, th.BasicUser.Id)
// Save the post with a different time.
data.CreateAt = ptrInt64(*data.CreateAt + 1)
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 2, "")
// Save the post with a different message.
data.Message = ptrStr("Message 2")
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 3, "")
// Test with hashtags
data.Message = ptrStr("Message 2 #hashtagmashupcity")
data.CreateAt = ptrInt64(*data.CreateAt + 1)
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success.")
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
AssertAllPostsCount(t, th.App, initialPostCount, 4, "")
if result := <-th.App.Srv.Store.Post().GetPostsCreatedAt(groupChannel.Id, *data.CreateAt); result.Err != nil {
t.Fatal(result.Err.Error())
} else {
posts := result.Data.([]*model.Post)
if len(posts) != 1 {
t.Fatal("Unexpected number of posts found.")
}
post := posts[0]
if post.Message != *data.Message || post.CreateAt != *data.CreateAt || post.UserId != th.BasicUser.Id {
t.Fatal("Post properties not as expected")
}
if post.Hashtags != "#hashtagmashupcity" {
t.Fatalf("Hashtags not as expected: %s", post.Hashtags)
}
}
result = <-th.App.Srv.Store.Post().GetPostsCreatedAt(groupChannel.Id, *data.CreateAt)
require.Nil(t, result.Err)
posts = result.Data.([]*model.Post)
require.Equal(t, len(posts), 1)
post = posts[0]
require.Equal(t, post.Message, *data.Message)
require.Equal(t, post.CreateAt, *data.CreateAt)
require.Equal(t, post.UserId, th.BasicUser.Id)
require.Equal(t, post.Hashtags, "#hashtagmashupcity")
// Test with some flags.
data = &DirectPostImportData{
@@ -2388,22 +2323,20 @@ func TestImportImportDirectPost(t *testing.T) {
CreateAt: ptrInt64(model.GetMillis()),
}
if err := th.App.ImportDirectPost(data, false); err != nil {
t.Fatalf("Expected success: %v", err.Error())
}
err = th.App.ImportDirectPost(data, false)
require.Nil(t, err)
// Check the post values.
if result := <-th.App.Srv.Store.Post().GetPostsCreatedAt(groupChannel.Id, *data.CreateAt); result.Err != nil {
t.Fatal(result.Err.Error())
} else {
posts := result.Data.([]*model.Post)
if len(posts) != 1 {
t.Fatal("Unexpected number of posts found.")
}
post := posts[0]
checkPreference(t, th.App, th.BasicUser.Id, model.PREFERENCE_CATEGORY_FLAGGED_POST, post.Id, "true")
checkPreference(t, th.App, th.BasicUser2.Id, model.PREFERENCE_CATEGORY_FLAGGED_POST, post.Id, "true")
}
result = <-th.App.Srv.Store.Post().GetPostsCreatedAt(groupChannel.Id, *data.CreateAt)
require.Nil(t, result.Err)
posts = result.Data.([]*model.Post)
require.Equal(t, len(posts), 1)
post = posts[0]
checkPreference(t, th.App, th.BasicUser.Id, model.PREFERENCE_CATEGORY_FLAGGED_POST, post.Id, "true")
checkPreference(t, th.App, th.BasicUser2.Id, model.PREFERENCE_CATEGORY_FLAGGED_POST, post.Id, "true")
}
func TestImportImportEmoji(t *testing.T) {

View File

@@ -41,7 +41,7 @@ func TestSendNotifications(t *testing.T) {
t.Fatal("user should have been mentioned")
}
dm, err := th.App.CreateDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
dm, err := th.App.GetOrCreateDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
if err != nil {
t.Fatal(err)
}

View File

@@ -297,7 +297,7 @@ func (api *PluginAPI) GetChannelStats(channelId string) (*model.ChannelStats, *m
}
func (api *PluginAPI) GetDirectChannel(userId1, userId2 string) (*model.Channel, *model.AppError) {
return api.app.GetDirectChannel(userId1, userId2)
return api.app.GetOrCreateDirectChannel(userId1, userId2)
}
func (api *PluginAPI) GetGroupChannel(userIds []string) (*model.Channel, *model.AppError) {
@@ -523,25 +523,6 @@ func (api *PluginAPI) RemoveTeamIcon(teamId string) *model.AppError {
return nil
}
func (api *PluginAPI) CreateDirectChannel(userId1 string, userId2 string) (*model.Channel, *model.AppError) {
_, err := api.app.GetUser(userId1)
if err != nil {
return nil, err
}
_, err = api.app.GetUser(userId2)
if err != nil {
return nil, err
}
dm, err := api.app.CreateDirectChannel(userId1, userId2)
if err != nil {
return nil, err
}
return dm, nil
}
// Plugin Section
func (api *PluginAPI) GetPlugins() ([]*model.Manifest, *model.AppError) {

View File

@@ -510,20 +510,20 @@ func TestPluginAPIRemoveTeamIcon(t *testing.T) {
require.Nil(t, err)
}
func TestPluginAPICreateDirectChannel(t *testing.T) {
func TestPluginAPIGetDirectChannel(t *testing.T) {
th := Setup().InitBasic()
defer th.TearDown()
api := th.SetupPluginAPI()
dm1, err := api.CreateDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
dm1, err := api.GetDirectChannel(th.BasicUser.Id, th.BasicUser2.Id)
require.Nil(t, err)
require.NotEmpty(t, dm1)
dm2, err := api.CreateDirectChannel(th.BasicUser.Id, th.BasicUser.Id)
dm2, err := api.GetDirectChannel(th.BasicUser.Id, th.BasicUser.Id)
require.Nil(t, err)
require.NotEmpty(t, dm2)
dm3, err := api.CreateDirectChannel(th.BasicUser.Id, model.NewId())
dm3, err := api.GetDirectChannel(th.BasicUser.Id, model.NewId())
require.NotNil(t, err)
require.Empty(t, dm3)
}

View File

@@ -622,7 +622,7 @@ func (a *App) parseAndFetchChannelIdByNameFromInFilter(channelName, userId, team
if err != nil {
return nil, err
}
channel, err := a.GetDirectChannel(userId, user.Id)
channel, err := a.GetOrCreateDirectChannel(userId, user.Id)
if err != nil {
return nil, err
}

View File

@@ -617,7 +617,7 @@ func (a *App) HandleIncomingWebhook(hookId string, req *model.IncomingWebhookReq
if result := <-a.Srv.Store.User().GetByUsername(channelName[1:]); result.Err != nil {
return model.NewAppError("HandleIncomingWebhook", "web.incoming_webhook.user.app_error", nil, "err="+result.Err.Message, http.StatusBadRequest)
} else {
if ch, err := a.GetDirectChannel(hook.UserId, result.Data.(*model.User).Id); err != nil {
if ch, err := a.GetOrCreateDirectChannel(hook.UserId, result.Data.(*model.User).Id); err != nil {
return err
} else {
channel = ch

View File

@@ -227,7 +227,7 @@ func (me *TestHelper) CreateDmChannel(user *model.User) *model.Channel {
utils.DisableDebugLogForTest()
var err *model.AppError
var channel *model.Channel
if channel, err = me.App.CreateDirectChannel(me.BasicUser.Id, user.Id); err != nil {
if channel, err = me.App.GetOrCreateDirectChannel(me.BasicUser.Id, user.Id); err != nil {
mlog.Error(err.Error())
time.Sleep(time.Second)

View File

@@ -52,11 +52,6 @@ type API interface {
// CreateUser creates a user.
CreateUser(user *model.User) (*model.User, *model.AppError)
// CreateDirectChannel creates a Direct channel.
//
// Minimum server version: 5.6
CreateDirectChannel(userId1 string, userId2 string) (*model.Channel, *model.AppError)
// DeleteUser deletes a user.
DeleteUser(userId string) *model.AppError
@@ -196,9 +191,11 @@ type API interface {
GetChannelStats(channelId string) (*model.ChannelStats, *model.AppError)
// GetDirectChannel gets a direct message channel.
// If the channel does not exist it will create it.
GetDirectChannel(userId1, userId2 string) (*model.Channel, *model.AppError)
// GetGroupChannel gets a group message channel.
// If the channel does not exist it will create it.
GetGroupChannel(userIds []string) (*model.Channel, *model.AppError)
// UpdateChannel updates a channel.

View File

@@ -767,36 +767,6 @@ func (s *apiRPCServer) CreateUser(args *Z_CreateUserArgs, returns *Z_CreateUserR
return nil
}
type Z_CreateDirectChannelArgs struct {
A string
B string
}
type Z_CreateDirectChannelReturns struct {
A *model.Channel
B *model.AppError
}
func (g *apiRPCClient) CreateDirectChannel(userId1 string, userId2 string) (*model.Channel, *model.AppError) {
_args := &Z_CreateDirectChannelArgs{userId1, userId2}
_returns := &Z_CreateDirectChannelReturns{}
if err := g.client.Call("Plugin.CreateDirectChannel", _args, _returns); err != nil {
log.Printf("RPC call to CreateDirectChannel API failed: %s", err.Error())
}
return _returns.A, _returns.B
}
func (s *apiRPCServer) CreateDirectChannel(args *Z_CreateDirectChannelArgs, returns *Z_CreateDirectChannelReturns) error {
if hook, ok := s.impl.(interface {
CreateDirectChannel(userId1 string, userId2 string) (*model.Channel, *model.AppError)
}); ok {
returns.A, returns.B = hook.CreateDirectChannel(args.A, args.B)
} else {
return encodableError(fmt.Errorf("API CreateDirectChannel called but not implemented."))
}
return nil
}
type Z_DeleteUserArgs struct {
A string
}

View File

@@ -112,31 +112,6 @@ func (_m *API) CreateChannel(channel *model.Channel) (*model.Channel, *model.App
return r0, r1
}
// CreateDirectChannel provides a mock function with given fields: userId1, userId2
func (_m *API) CreateDirectChannel(userId1 string, userId2 string) (*model.Channel, *model.AppError) {
ret := _m.Called(userId1, userId2)
var r0 *model.Channel
if rf, ok := ret.Get(0).(func(string, string) *model.Channel); ok {
r0 = rf(userId1, userId2)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.Channel)
}
}
var r1 *model.AppError
if rf, ok := ret.Get(1).(func(string, string) *model.AppError); ok {
r1 = rf(userId1, userId2)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*model.AppError)
}
}
return r0, r1
}
// CreatePost provides a mock function with given fields: post
func (_m *API) CreatePost(post *model.Post) (*model.Post, *model.AppError) {
ret := _m.Called(post)
@@ -606,20 +581,6 @@ func (_m *API) GetConfig() *model.Config {
return r0
}
// GetPluginConfig provides a mock function with given fields:
func (_m *API) GetPluginConfig() map[string]interface{} {
ret := _m.Called()
var r0 map[string]interface{}
if rf, ok := ret.Get(0).(func() map[string]interface{}); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(map[string]interface{})
}
return r0
}
// GetDirectChannel provides a mock function with given fields: userId1, userId2
func (_m *API) GetDirectChannel(userId1 string, userId2 string) (*model.Channel, *model.AppError) {
ret := _m.Called(userId1, userId2)
@@ -850,6 +811,22 @@ func (_m *API) GetLDAPUserAttributes(userId string, attributes []string) (map[st
return r0, r1
}
// GetPluginConfig provides a mock function with given fields:
func (_m *API) GetPluginConfig() map[string]interface{} {
ret := _m.Called()
var r0 map[string]interface{}
if rf, ok := ret.Get(0).(func() map[string]interface{}); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]interface{})
}
}
return r0
}
// GetPluginStatus provides a mock function with given fields: id
func (_m *API) GetPluginStatus(id string) (*model.PluginStatus, *model.AppError) {
ret := _m.Called(id)
@@ -1890,13 +1867,13 @@ func (_m *API) SaveConfig(config *model.Config) *model.AppError {
return r0
}
// SavePluginConfig provides a mock function with given fields: pluginConfig
func (_m *API) SavePluginConfig(pluginConfig map[string]interface{}) *model.AppError {
ret := _m.Called(pluginConfig)
// SavePluginConfig provides a mock function with given fields: config
func (_m *API) SavePluginConfig(config map[string]interface{}) *model.AppError {
ret := _m.Called(config)
var r0 *model.AppError
if rf, ok := ret.Get(0).(func(map[string]interface{}) *model.AppError); ok {
r0 = rf(pluginConfig)
r0 = rf(config)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*model.AppError)