diff --git a/app/post.go b/app/post.go index 4391123dce..46ac39c8c7 100644 --- a/app/post.go +++ b/app/post.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "strings" + "sync" "time" "github.com/mattermost/mattermost-server/mlog" @@ -895,7 +896,9 @@ func (a *App) parseAndFetchChannelIdByNameFromInFilter(channelName, userId, team } func (a *App) searchPostsInTeam(teamId string, userId string, paramsList []*model.SearchParams, modifierFun func(*model.SearchParams)) (*model.PostList, *model.AppError) { - channels := []store.StoreChannel{} + var wg sync.WaitGroup + + pchan := make(chan store.StoreResult, len(paramsList)) for _, params := range paramsList { // Don't allow users to search for everything. @@ -903,12 +906,21 @@ func (a *App) searchPostsInTeam(teamId string, userId string, paramsList []*mode continue } modifierFun(params) - channels = append(channels, a.Srv.Store.Post().Search(teamId, userId, params)) + wg.Add(1) + + go func(params *model.SearchParams) { + defer wg.Done() + postList, err := a.Srv.Store.Post().Search(teamId, userId, params) + pchan <- store.StoreResult{Data: postList, Err: err} + }(params) } + wg.Wait() + close(pchan) + posts := model.NewPostList() - for _, channel := range channels { - result := <-channel + + for result := range pchan { if result.Err != nil { return nil, result.Err } diff --git a/store/sqlstore/post_store.go b/store/sqlstore/post_store.go index f277156fcd..59c9bd891d 100644 --- a/store/sqlstore/post_store.go +++ b/store/sqlstore/post_store.go @@ -794,47 +794,46 @@ var specialSearchChar = []string{ ":", } -func (s *SqlPostStore) Search(teamId string, userId string, params *model.SearchParams) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - queryParams := map[string]interface{}{ - "TeamId": teamId, - "UserId": userId, +func (s *SqlPostStore) Search(teamId string, userId string, params *model.SearchParams) (*model.PostList, *model.AppError) { + queryParams := map[string]interface{}{ + "TeamId": teamId, + "UserId": userId, + } + + termMap := map[string]bool{} + terms := params.Terms + list := model.NewPostList() + + if terms == "" && len(params.InChannels) == 0 && len(params.FromUsers) == 0 && len(params.OnDate) == 0 && len(params.AfterDate) == 0 && len(params.BeforeDate) == 0 { + return list, nil + } + + searchType := "Message" + if params.IsHashtag { + searchType = "Hashtags" + for _, term := range strings.Split(terms, " ") { + termMap[strings.ToUpper(term)] = true } + } - termMap := map[string]bool{} - terms := params.Terms + // these chars have special meaning and can be treated as spaces + for _, c := range specialSearchChar { + terms = strings.Replace(terms, c, " ", -1) + } - if terms == "" && len(params.InChannels) == 0 && len(params.FromUsers) == 0 && len(params.OnDate) == 0 && len(params.AfterDate) == 0 && len(params.BeforeDate) == 0 { - result.Data = []*model.Post{} - return - } + var posts []*model.Post - searchType := "Message" - if params.IsHashtag { - searchType = "Hashtags" - for _, term := range strings.Split(terms, " ") { - termMap[strings.ToUpper(term)] = true - } - } + deletedQueryPart := "AND DeleteAt = 0" + if params.IncludeDeletedChannels { + deletedQueryPart = "" + } - // these chars have special meaning and can be treated as spaces - for _, c := range specialSearchChar { - terms = strings.Replace(terms, c, " ", -1) - } + userIdPart := "AND UserId = :UserId" + if params.SearchWithoutUserId { + userIdPart = "" + } - var posts []*model.Post - - deletedQueryPart := "AND DeleteAt = 0" - if params.IncludeDeletedChannels { - deletedQueryPart = "" - } - - userIdPart := "AND UserId = :UserId" - if params.SearchWithoutUserId { - userIdPart = "" - } - - searchQuery := ` + searchQuery := ` SELECT * FROM @@ -860,35 +859,35 @@ func (s *SqlPostStore) Search(teamId string, userId string, params *model.Search ORDER BY CreateAt DESC LIMIT 100` - if len(params.InChannels) > 1 { - inClause := ":InChannel0" - queryParams["InChannel0"] = params.InChannels[0] + if len(params.InChannels) > 1 { + inClause := ":InChannel0" + queryParams["InChannel0"] = params.InChannels[0] - for i := 1; i < len(params.InChannels); i++ { - paramName := "InChannel" + strconv.FormatInt(int64(i), 10) - inClause += ", :" + paramName - queryParams[paramName] = params.InChannels[i] - } - - searchQuery = strings.Replace(searchQuery, "CHANNEL_FILTER", "AND Name IN ("+inClause+")", 1) - } else if len(params.InChannels) == 1 { - queryParams["InChannel"] = params.InChannels[0] - searchQuery = strings.Replace(searchQuery, "CHANNEL_FILTER", "AND Name = :InChannel", 1) - } else { - searchQuery = strings.Replace(searchQuery, "CHANNEL_FILTER", "", 1) + for i := 1; i < len(params.InChannels); i++ { + paramName := "InChannel" + strconv.FormatInt(int64(i), 10) + inClause += ", :" + paramName + queryParams[paramName] = params.InChannels[i] } - if len(params.FromUsers) > 1 { - inClause := ":FromUser0" - queryParams["FromUser0"] = params.FromUsers[0] + searchQuery = strings.Replace(searchQuery, "CHANNEL_FILTER", "AND Name IN ("+inClause+")", 1) + } else if len(params.InChannels) == 1 { + queryParams["InChannel"] = params.InChannels[0] + searchQuery = strings.Replace(searchQuery, "CHANNEL_FILTER", "AND Name = :InChannel", 1) + } else { + searchQuery = strings.Replace(searchQuery, "CHANNEL_FILTER", "", 1) + } - for i := 1; i < len(params.FromUsers); i++ { - paramName := "FromUser" + strconv.FormatInt(int64(i), 10) - inClause += ", :" + paramName - queryParams[paramName] = params.FromUsers[i] - } + if len(params.FromUsers) > 1 { + inClause := ":FromUser0" + queryParams["FromUser0"] = params.FromUsers[0] - searchQuery = strings.Replace(searchQuery, "POST_FILTER", ` + for i := 1; i < len(params.FromUsers); i++ { + paramName := "FromUser" + strconv.FormatInt(int64(i), 10) + inClause += ", :" + paramName + queryParams[paramName] = params.FromUsers[i] + } + + searchQuery = strings.Replace(searchQuery, "POST_FILTER", ` AND UserId IN ( SELECT Id @@ -899,9 +898,9 @@ func (s *SqlPostStore) Search(teamId string, userId string, params *model.Search TeamMembers.TeamId = :TeamId AND Users.Id = TeamMembers.UserId AND Username IN (`+inClause+`))`, 1) - } else if len(params.FromUsers) == 1 { - queryParams["FromUser"] = params.FromUsers[0] - searchQuery = strings.Replace(searchQuery, "POST_FILTER", ` + } else if len(params.FromUsers) == 1 { + queryParams["FromUser"] = params.FromUsers[0] + searchQuery = strings.Replace(searchQuery, "POST_FILTER", ` AND UserId IN ( SELECT Id @@ -912,106 +911,104 @@ func (s *SqlPostStore) Search(teamId string, userId string, params *model.Search TeamMembers.TeamId = :TeamId AND Users.Id = TeamMembers.UserId AND Username = :FromUser)`, 1) - } else { - searchQuery = strings.Replace(searchQuery, "POST_FILTER", "", 1) + } else { + searchQuery = strings.Replace(searchQuery, "POST_FILTER", "", 1) + } + + // handle after: before: on: filters + if len(params.AfterDate) > 1 || len(params.BeforeDate) > 1 || len(params.OnDate) > 1 { + if len(params.OnDate) > 1 { + onDateStart, onDateEnd := params.GetOnDateMillis() + queryParams["OnDateStart"] = strconv.FormatInt(onDateStart, 10) + queryParams["OnDateEnd"] = strconv.FormatInt(onDateEnd, 10) + + // between `on date` start of day and end of day + searchQuery = strings.Replace(searchQuery, "CREATEDATE_CLAUSE", "AND CreateAt BETWEEN :OnDateStart AND :OnDateEnd ", 1) + } else if len(params.AfterDate) > 1 && len(params.BeforeDate) > 1 { + afterDate := params.GetAfterDateMillis() + beforeDate := params.GetBeforeDateMillis() + queryParams["OnDateStart"] = strconv.FormatInt(afterDate, 10) + queryParams["OnDateEnd"] = strconv.FormatInt(beforeDate, 10) + + // between clause + searchQuery = strings.Replace(searchQuery, "CREATEDATE_CLAUSE", "AND CreateAt BETWEEN :OnDateStart AND :OnDateEnd ", 1) + } else if len(params.AfterDate) > 1 { + afterDate := params.GetAfterDateMillis() + queryParams["AfterDate"] = strconv.FormatInt(afterDate, 10) + + // greater than `after date` + searchQuery = strings.Replace(searchQuery, "CREATEDATE_CLAUSE", "AND CreateAt >= :AfterDate ", 1) + } else if len(params.BeforeDate) > 1 { + beforeDate := params.GetBeforeDateMillis() + queryParams["BeforeDate"] = strconv.FormatInt(beforeDate, 10) + + // less than `before date` + searchQuery = strings.Replace(searchQuery, "CREATEDATE_CLAUSE", "AND CreateAt <= :BeforeDate ", 1) + } + } else { + // no create date filters set + searchQuery = strings.Replace(searchQuery, "CREATEDATE_CLAUSE", "", 1) + } + + if terms == "" { + // we've already confirmed that we have a channel or user to search for + searchQuery = strings.Replace(searchQuery, "SEARCH_CLAUSE", "", 1) + } else if s.DriverName() == model.DATABASE_DRIVER_POSTGRES { + // Parse text for wildcards + if wildcard, err := regexp.Compile(`\*($| )`); err == nil { + terms = wildcard.ReplaceAllLiteralString(terms, ":* ") } - // handle after: before: on: filters - if len(params.AfterDate) > 1 || len(params.BeforeDate) > 1 || len(params.OnDate) > 1 { - if len(params.OnDate) > 1 { - onDateStart, onDateEnd := params.GetOnDateMillis() - queryParams["OnDateStart"] = strconv.FormatInt(onDateStart, 10) - queryParams["OnDateEnd"] = strconv.FormatInt(onDateEnd, 10) - - // between `on date` start of day and end of day - searchQuery = strings.Replace(searchQuery, "CREATEDATE_CLAUSE", "AND CreateAt BETWEEN :OnDateStart AND :OnDateEnd ", 1) - } else if len(params.AfterDate) > 1 && len(params.BeforeDate) > 1 { - afterDate := params.GetAfterDateMillis() - beforeDate := params.GetBeforeDateMillis() - queryParams["OnDateStart"] = strconv.FormatInt(afterDate, 10) - queryParams["OnDateEnd"] = strconv.FormatInt(beforeDate, 10) - - // between clause - searchQuery = strings.Replace(searchQuery, "CREATEDATE_CLAUSE", "AND CreateAt BETWEEN :OnDateStart AND :OnDateEnd ", 1) - } else if len(params.AfterDate) > 1 { - afterDate := params.GetAfterDateMillis() - queryParams["AfterDate"] = strconv.FormatInt(afterDate, 10) - - // greater than `after date` - searchQuery = strings.Replace(searchQuery, "CREATEDATE_CLAUSE", "AND CreateAt >= :AfterDate ", 1) - } else if len(params.BeforeDate) > 1 { - beforeDate := params.GetBeforeDateMillis() - queryParams["BeforeDate"] = strconv.FormatInt(beforeDate, 10) - - // less than `before date` - searchQuery = strings.Replace(searchQuery, "CREATEDATE_CLAUSE", "AND CreateAt <= :BeforeDate ", 1) - } + if params.OrTerms { + terms = strings.Join(strings.Fields(terms), " | ") } else { - // no create date filters set - searchQuery = strings.Replace(searchQuery, "CREATEDATE_CLAUSE", "", 1) + terms = strings.Join(strings.Fields(terms), " & ") } - if terms == "" { - // we've already confirmed that we have a channel or user to search for - searchQuery = strings.Replace(searchQuery, "SEARCH_CLAUSE", "", 1) - } else if s.DriverName() == model.DATABASE_DRIVER_POSTGRES { - // Parse text for wildcards - if wildcard, err := regexp.Compile(`\*($| )`); err == nil { - terms = wildcard.ReplaceAllLiteralString(terms, ":* ") + searchClause := fmt.Sprintf("AND to_tsvector('english', %s) @@ to_tsquery(:Terms)", searchType) + searchQuery = strings.Replace(searchQuery, "SEARCH_CLAUSE", searchClause, 1) + } else if s.DriverName() == model.DATABASE_DRIVER_MYSQL { + searchClause := fmt.Sprintf("AND MATCH (%s) AGAINST (:Terms IN BOOLEAN MODE)", searchType) + searchQuery = strings.Replace(searchQuery, "SEARCH_CLAUSE", searchClause, 1) + + if !params.OrTerms { + splitTerms := strings.Fields(terms) + for i, t := range strings.Fields(terms) { + splitTerms[i] = "+" + t } - if params.OrTerms { - terms = strings.Join(strings.Fields(terms), " | ") - } else { - terms = strings.Join(strings.Fields(terms), " & ") - } + terms = strings.Join(splitTerms, " ") + } + } - searchClause := fmt.Sprintf("AND to_tsvector('english', %s) @@ to_tsquery(:Terms)", searchType) - searchQuery = strings.Replace(searchQuery, "SEARCH_CLAUSE", searchClause, 1) - } else if s.DriverName() == model.DATABASE_DRIVER_MYSQL { - searchClause := fmt.Sprintf("AND MATCH (%s) AGAINST (:Terms IN BOOLEAN MODE)", searchType) - searchQuery = strings.Replace(searchQuery, "SEARCH_CLAUSE", searchClause, 1) + queryParams["Terms"] = terms - if !params.OrTerms { - splitTerms := strings.Fields(terms) - for i, t := range strings.Fields(terms) { - splitTerms[i] = "+" + t + _, err := s.GetSearchReplica().Select(&posts, searchQuery, queryParams) + if err != nil { + mlog.Warn(fmt.Sprintf("Query error searching posts: %v", err.Error())) + // Don't return the error to the caller as it is of no use to the user. Instead return an empty set of search results. + return list, nil + } + + for _, p := range posts { + if searchType == "Hashtags" { + exactMatch := false + for _, tag := range strings.Split(p.Hashtags, " ") { + if termMap[strings.ToUpper(tag)] { + exactMatch = true } - - terms = strings.Join(splitTerms, " ") + } + if !exactMatch { + continue } } + list.AddPost(p) + list.AddOrder(p.Id) + } - queryParams["Terms"] = terms + list.MakeNonNil() - list := model.NewPostList() - - _, err := s.GetSearchReplica().Select(&posts, searchQuery, queryParams) - if err != nil { - mlog.Warn(fmt.Sprintf("Query error searching posts: %v", err.Error())) - // Don't return the error to the caller as it is of no use to the user. Instead return an empty set of search results. - } else { - for _, p := range posts { - if searchType == "Hashtags" { - exactMatch := false - for _, tag := range strings.Split(p.Hashtags, " ") { - if termMap[strings.ToUpper(tag)] { - exactMatch = true - } - } - if !exactMatch { - continue - } - } - list.AddPost(p) - list.AddOrder(p.Id) - } - } - - list.MakeNonNil() - - result.Data = list - }) + return list, nil } func (s *SqlPostStore) AnalyticsUserCountsWithPostsByDay(teamId string) (model.AnalyticsRows, *model.AppError) { diff --git a/store/store.go b/store/store.go index cb74d670b2..ea07a9b33e 100644 --- a/store/store.go +++ b/store/store.go @@ -230,7 +230,7 @@ type PostStore interface { GetPostIdAfterTime(channelId string, time int64) (string, *model.AppError) GetPostIdBeforeTime(channelId string, time int64) (string, *model.AppError) GetEtag(channelId string, allowFromCache bool) string - Search(teamId string, userId string, params *model.SearchParams) StoreChannel + Search(teamId string, userId string, params *model.SearchParams) (*model.PostList, *model.AppError) AnalyticsUserCountsWithPostsByDay(teamId string) (model.AnalyticsRows, *model.AppError) AnalyticsPostCountsByDay(options *model.AnalyticsPostCountsOptions) (model.AnalyticsRows, *model.AppError) AnalyticsPostCount(teamId string, mustHaveFile bool, mustHaveHashtag bool) (int64, *model.AppError) diff --git a/store/storetest/mocks/PostStore.go b/store/storetest/mocks/PostStore.go index ce7068bc0f..70818dd7d7 100644 --- a/store/storetest/mocks/PostStore.go +++ b/store/storetest/mocks/PostStore.go @@ -6,7 +6,6 @@ package mocks import mock "github.com/stretchr/testify/mock" import model "github.com/mattermost/mattermost-server/model" -import store "github.com/mattermost/mattermost-server/store" // PostStore is an autogenerated mock type for the PostStore type type PostStore struct { @@ -717,19 +716,28 @@ func (_m *PostStore) Save(post *model.Post) (*model.Post, *model.AppError) { } // Search provides a mock function with given fields: teamId, userId, params -func (_m *PostStore) Search(teamId string, userId string, params *model.SearchParams) store.StoreChannel { +func (_m *PostStore) Search(teamId string, userId string, params *model.SearchParams) (*model.PostList, *model.AppError) { ret := _m.Called(teamId, userId, params) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(string, string, *model.SearchParams) store.StoreChannel); ok { + var r0 *model.PostList + if rf, ok := ret.Get(0).(func(string, string, *model.SearchParams) *model.PostList); ok { r0 = rf(teamId, userId, params) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).(*model.PostList) } } - return r0 + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(string, string, *model.SearchParams) *model.AppError); ok { + r1 = rf(teamId, userId, params) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 } // Update provides a mock function with given fields: newPost, oldPost diff --git a/store/storetest/post_store.go b/store/storetest/post_store.go index 295fda6a12..4b0ab85313 100644 --- a/store/storetest/post_store.go +++ b/store/storetest/post_store.go @@ -1343,7 +1343,8 @@ func testPostStoreSearch(t *testing.T, ss store.Store) { } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - result := (<-ss.Post().Search(teamId, userId, tc.searchParams)).Data.(*model.PostList) + result, err := ss.Post().Search(teamId, userId, tc.searchParams) + require.Nil(t, err) require.Len(t, result.Order, tc.expectedResultsCount) for _, expectedMessageResultId := range tc.expectedMessageResultIds { assert.Contains(t, result.Order, expectedMessageResultId)