From 7c7ff93d973e549eb329ce411a816d8938e9bd7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Espino?= Date: Wed, 24 Apr 2019 12:50:25 +0200 Subject: [PATCH] Migrating compliance store to sync by default (#10660) * Migrating compliance store to sync by default * Addressing PR review comments --- app/compliance.go | 19 +- store/sqlstore/compliance_store.go | 368 +++++++++++------------ store/store.go | 12 +- store/storetest/compliance_store.go | 259 +++++----------- store/storetest/mocks/ComplianceStore.go | 115 +++++-- 5 files changed, 347 insertions(+), 426 deletions(-) diff --git a/app/compliance.go b/app/compliance.go index 2d8dd2b517..11107c4b93 100644 --- a/app/compliance.go +++ b/app/compliance.go @@ -16,11 +16,7 @@ func (a *App) GetComplianceReports(page, perPage int) (model.Compliances, *model return nil, model.NewAppError("GetComplianceReports", "ent.compliance.licence_disable.app_error", nil, "", http.StatusNotImplemented) } - result := <-a.Srv.Store.Compliance().GetAll(page*perPage, perPage) - if result.Err != nil { - return nil, result.Err - } - return result.Data.(model.Compliances), nil + return a.Srv.Store.Compliance().GetAll(page*perPage, perPage) } func (a *App) SaveComplianceReport(job *model.Compliance) (*model.Compliance, *model.AppError) { @@ -30,12 +26,11 @@ func (a *App) SaveComplianceReport(job *model.Compliance) (*model.Compliance, *m job.Type = model.COMPLIANCE_TYPE_ADHOC - result := <-a.Srv.Store.Compliance().Save(job) - if result.Err != nil { - return nil, result.Err + job, err := a.Srv.Store.Compliance().Save(job) + if err != nil { + return nil, err } - job = result.Data.(*model.Compliance) a.Srv.Go(func() { a.Compliance.RunComplianceJob(job) }) @@ -48,11 +43,7 @@ func (a *App) GetComplianceReport(reportId string) (*model.Compliance, *model.Ap return nil, model.NewAppError("downloadComplianceReport", "ent.compliance.licence_disable.app_error", nil, "", http.StatusNotImplemented) } - result := <-a.Srv.Store.Compliance().Get(reportId) - if result.Err != nil { - return nil, result.Err - } - return result.Data.(*model.Compliance), nil + return a.Srv.Store.Compliance().Get(reportId) } func (a *App) GetComplianceFile(job *model.Compliance) ([]byte, *model.AppError) { diff --git a/store/sqlstore/compliance_store.go b/store/sqlstore/compliance_store.go index 3dabffaabc..9ad74ca89f 100644 --- a/store/sqlstore/compliance_store.go +++ b/store/sqlstore/compliance_store.go @@ -36,226 +36,210 @@ func NewSqlComplianceStore(sqlStore SqlStore) store.ComplianceStore { func (s SqlComplianceStore) CreateIndexesIfNotExists() { } -func (s SqlComplianceStore) Save(compliance *model.Compliance) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - compliance.PreSave() - if result.Err = compliance.IsValid(); result.Err != nil { - return - } +func (s SqlComplianceStore) Save(compliance *model.Compliance) (*model.Compliance, *model.AppError) { + compliance.PreSave() + if err := compliance.IsValid(); err != nil { + return nil, err + } - if err := s.GetMaster().Insert(compliance); err != nil { - result.Err = model.NewAppError("SqlComplianceStore.Save", "store.sql_compliance.save.saving.app_error", nil, err.Error(), http.StatusInternalServerError) - } else { - result.Data = compliance - } - }) + if err := s.GetMaster().Insert(compliance); err != nil { + return nil, model.NewAppError("SqlComplianceStore.Save", "store.sql_compliance.save.saving.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return compliance, nil } -func (us SqlComplianceStore) Update(compliance *model.Compliance) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - if result.Err = compliance.IsValid(); result.Err != nil { - return - } +func (us SqlComplianceStore) Update(compliance *model.Compliance) (*model.Compliance, *model.AppError) { + if err := compliance.IsValid(); err != nil { + return nil, err + } - if _, err := us.GetMaster().Update(compliance); err != nil { - result.Err = model.NewAppError("SqlComplianceStore.Update", "store.sql_compliance.save.saving.app_error", nil, err.Error(), http.StatusInternalServerError) - } else { - result.Data = compliance - } - }) + if _, err := us.GetMaster().Update(compliance); err != nil { + return nil, model.NewAppError("SqlComplianceStore.Update", "store.sql_compliance.save.saving.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return compliance, nil } -func (s SqlComplianceStore) GetAll(offset, limit int) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - query := "SELECT * FROM Compliances ORDER BY CreateAt DESC LIMIT :Limit OFFSET :Offset" +func (s SqlComplianceStore) GetAll(offset, limit int) (model.Compliances, *model.AppError) { + query := "SELECT * FROM Compliances ORDER BY CreateAt DESC LIMIT :Limit OFFSET :Offset" - var compliances model.Compliances - if _, err := s.GetReplica().Select(&compliances, query, map[string]interface{}{"Offset": offset, "Limit": limit}); err != nil { - result.Err = model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusInternalServerError) - } else { - result.Data = compliances - } - }) + var compliances model.Compliances + if _, err := s.GetReplica().Select(&compliances, query, map[string]interface{}{"Offset": offset, "Limit": limit}); err != nil { + return nil, model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return compliances, nil } -func (us SqlComplianceStore) Get(id string) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - if obj, err := us.GetReplica().Get(model.Compliance{}, id); err != nil { - result.Err = model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusInternalServerError) - } else if obj == nil { - result.Err = model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusNotFound) - } else { - result.Data = obj.(*model.Compliance) - } - }) +func (us SqlComplianceStore) Get(id string) (*model.Compliance, *model.AppError) { + obj, err := us.GetReplica().Get(model.Compliance{}, id) + if err != nil { + return nil, model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusInternalServerError) + } + if obj == nil { + return nil, model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusNotFound) + } + return obj.(*model.Compliance), nil } -func (s SqlComplianceStore) ComplianceExport(job *model.Compliance) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - props := map[string]interface{}{"StartTime": job.StartAt, "EndTime": job.EndAt} +func (s SqlComplianceStore) ComplianceExport(job *model.Compliance) ([]*model.CompliancePost, *model.AppError) { + props := map[string]interface{}{"StartTime": job.StartAt, "EndTime": job.EndAt} - keywordQuery := "" - keywords := strings.Fields(strings.TrimSpace(strings.ToLower(strings.Replace(job.Keywords, ",", " ", -1)))) - if len(keywords) > 0 { + keywordQuery := "" + keywords := strings.Fields(strings.TrimSpace(strings.ToLower(strings.Replace(job.Keywords, ",", " ", -1)))) + if len(keywords) > 0 { - keywordQuery = "AND (" + keywordQuery = "AND (" - for index, keyword := range keywords { - if index >= 1 { - keywordQuery += " OR LOWER(Posts.Message) LIKE :Keyword" + strconv.Itoa(index) - } else { - keywordQuery += "LOWER(Posts.Message) LIKE :Keyword" + strconv.Itoa(index) - } - - props["Keyword"+strconv.Itoa(index)] = "%" + keyword + "%" + for index, keyword := range keywords { + if index >= 1 { + keywordQuery += " OR LOWER(Posts.Message) LIKE :Keyword" + strconv.Itoa(index) + } else { + keywordQuery += "LOWER(Posts.Message) LIKE :Keyword" + strconv.Itoa(index) } - keywordQuery += ")" + props["Keyword"+strconv.Itoa(index)] = "%" + keyword + "%" } - emailQuery := "" - emails := strings.Fields(strings.TrimSpace(strings.ToLower(strings.Replace(job.Emails, ",", " ", -1)))) - if len(emails) > 0 { + keywordQuery += ")" + } - emailQuery = "AND (" + emailQuery := "" + emails := strings.Fields(strings.TrimSpace(strings.ToLower(strings.Replace(job.Emails, ",", " ", -1)))) + if len(emails) > 0 { - for index, email := range emails { - if index >= 1 { - emailQuery += " OR Users.Email = :Email" + strconv.Itoa(index) - } else { - emailQuery += "Users.Email = :Email" + strconv.Itoa(index) - } + emailQuery = "AND (" - props["Email"+strconv.Itoa(index)] = email + for index, email := range emails { + if index >= 1 { + emailQuery += " OR Users.Email = :Email" + strconv.Itoa(index) + } else { + emailQuery += "Users.Email = :Email" + strconv.Itoa(index) } - emailQuery += ")" + props["Email"+strconv.Itoa(index)] = email } - query := - `(SELECT - Teams.Name AS TeamName, - Teams.DisplayName AS TeamDisplayName, - Channels.Name AS ChannelName, - Channels.DisplayName AS ChannelDisplayName, - Channels.Type AS ChannelType, - Users.Username AS UserUsername, - Users.Email AS UserEmail, - Users.Nickname AS UserNickname, - Posts.Id AS PostId, - Posts.CreateAt AS PostCreateAt, - Posts.UpdateAt AS PostUpdateAt, - Posts.DeleteAt AS PostDeleteAt, - Posts.RootId AS PostRootId, - Posts.ParentId AS PostParentId, - Posts.OriginalId AS PostOriginalId, - Posts.Message AS PostMessage, - Posts.Type AS PostType, - Posts.Props AS PostProps, - Posts.Hashtags AS PostHashtags, - Posts.FileIds AS PostFileIds - FROM - Teams, - Channels, - Users, - Posts - WHERE - Teams.Id = Channels.TeamId - AND Posts.ChannelId = Channels.Id - AND Posts.UserId = Users.Id - AND Posts.CreateAt > :StartTime - AND Posts.CreateAt <= :EndTime - ` + emailQuery + ` - ` + keywordQuery + `) - UNION ALL - (SELECT - 'direct-messages' AS TeamName, - 'Direct Messages' AS TeamDisplayName, - Channels.Name AS ChannelName, - Channels.DisplayName AS ChannelDisplayName, - Channels.Type AS ChannelType, - Users.Username AS UserUsername, - Users.Email AS UserEmail, - Users.Nickname AS UserNickname, - Posts.Id AS PostId, - Posts.CreateAt AS PostCreateAt, - Posts.UpdateAt AS PostUpdateAt, - Posts.DeleteAt AS PostDeleteAt, - Posts.RootId AS PostRootId, - Posts.ParentId AS PostParentId, - Posts.OriginalId AS PostOriginalId, - Posts.Message AS PostMessage, - Posts.Type AS PostType, - Posts.Props AS PostProps, - Posts.Hashtags AS PostHashtags, - Posts.FileIds AS PostFileIds - FROM - Channels, - Users, - Posts - WHERE - Channels.TeamId = '' - AND Posts.ChannelId = Channels.Id - AND Posts.UserId = Users.Id - AND Posts.CreateAt > :StartTime - AND Posts.CreateAt <= :EndTime - ` + emailQuery + ` - ` + keywordQuery + `) - ORDER BY PostCreateAt - LIMIT 30000` + emailQuery += ")" + } - var cposts []*model.CompliancePost + query := + `(SELECT + Teams.Name AS TeamName, + Teams.DisplayName AS TeamDisplayName, + Channels.Name AS ChannelName, + Channels.DisplayName AS ChannelDisplayName, + Channels.Type AS ChannelType, + Users.Username AS UserUsername, + Users.Email AS UserEmail, + Users.Nickname AS UserNickname, + Posts.Id AS PostId, + Posts.CreateAt AS PostCreateAt, + Posts.UpdateAt AS PostUpdateAt, + Posts.DeleteAt AS PostDeleteAt, + Posts.RootId AS PostRootId, + Posts.ParentId AS PostParentId, + Posts.OriginalId AS PostOriginalId, + Posts.Message AS PostMessage, + Posts.Type AS PostType, + Posts.Props AS PostProps, + Posts.Hashtags AS PostHashtags, + Posts.FileIds AS PostFileIds + FROM + Teams, + Channels, + Users, + Posts + WHERE + Teams.Id = Channels.TeamId + AND Posts.ChannelId = Channels.Id + AND Posts.UserId = Users.Id + AND Posts.CreateAt > :StartTime + AND Posts.CreateAt <= :EndTime + ` + emailQuery + ` + ` + keywordQuery + `) + UNION ALL + (SELECT + 'direct-messages' AS TeamName, + 'Direct Messages' AS TeamDisplayName, + Channels.Name AS ChannelName, + Channels.DisplayName AS ChannelDisplayName, + Channels.Type AS ChannelType, + Users.Username AS UserUsername, + Users.Email AS UserEmail, + Users.Nickname AS UserNickname, + Posts.Id AS PostId, + Posts.CreateAt AS PostCreateAt, + Posts.UpdateAt AS PostUpdateAt, + Posts.DeleteAt AS PostDeleteAt, + Posts.RootId AS PostRootId, + Posts.ParentId AS PostParentId, + Posts.OriginalId AS PostOriginalId, + Posts.Message AS PostMessage, + Posts.Type AS PostType, + Posts.Props AS PostProps, + Posts.Hashtags AS PostHashtags, + Posts.FileIds AS PostFileIds + FROM + Channels, + Users, + Posts + WHERE + Channels.TeamId = '' + AND Posts.ChannelId = Channels.Id + AND Posts.UserId = Users.Id + AND Posts.CreateAt > :StartTime + AND Posts.CreateAt <= :EndTime + ` + emailQuery + ` + ` + keywordQuery + `) + ORDER BY PostCreateAt + LIMIT 30000` - if _, err := s.GetReplica().Select(&cposts, query, props); err != nil { - result.Err = model.NewAppError("SqlPostStore.ComplianceExport", "store.sql_post.compliance_export.app_error", nil, err.Error(), http.StatusInternalServerError) - } else { - result.Data = cposts - } - }) + var cposts []*model.CompliancePost + + if _, err := s.GetReplica().Select(&cposts, query, props); err != nil { + return nil, model.NewAppError("SqlPostStore.ComplianceExport", "store.sql_post.compliance_export.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return cposts, nil } -func (s SqlComplianceStore) MessageExport(after int64, limit int) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - props := map[string]interface{}{"StartTime": after, "Limit": limit} - query := - `SELECT - Posts.Id AS PostId, - Posts.CreateAt AS PostCreateAt, - Posts.Message AS PostMessage, - Posts.Type AS PostType, - Posts.OriginalId AS PostOriginalId, - Posts.RootId AS PostRootId, - Posts.FileIds AS PostFileIds, - Teams.Id AS TeamId, - Teams.Name AS TeamName, - Teams.DisplayName AS TeamDisplayName, - Channels.Id AS ChannelId, - CASE - WHEN Channels.Type = 'D' THEN 'Direct Message' - WHEN Channels.Type = 'G' THEN 'Group Message' - ELSE Channels.DisplayName - END AS ChannelDisplayName, - Channels.Name AS ChannelName, - Channels.Type AS ChannelType, - Users.Id AS UserId, - Users.Email AS UserEmail, - Users.Username - FROM - Posts - LEFT OUTER JOIN Channels ON Posts.ChannelId = Channels.Id - LEFT OUTER JOIN Teams ON Channels.TeamId = Teams.Id - LEFT OUTER JOIN Users ON Posts.UserId = Users.Id - WHERE - Posts.CreateAt > :StartTime AND - Posts.Type = '' - ORDER BY PostCreateAt - LIMIT :Limit` +func (s SqlComplianceStore) MessageExport(after int64, limit int) ([]*model.MessageExport, *model.AppError) { + props := map[string]interface{}{"StartTime": after, "Limit": limit} + query := + `SELECT + Posts.Id AS PostId, + Posts.CreateAt AS PostCreateAt, + Posts.Message AS PostMessage, + Posts.Type AS PostType, + Posts.OriginalId AS PostOriginalId, + Posts.RootId AS PostRootId, + Posts.FileIds AS PostFileIds, + Teams.Id AS TeamId, + Teams.Name AS TeamName, + Teams.DisplayName AS TeamDisplayName, + Channels.Id AS ChannelId, + CASE + WHEN Channels.Type = 'D' THEN 'Direct Message' + WHEN Channels.Type = 'G' THEN 'Group Message' + ELSE Channels.DisplayName + END AS ChannelDisplayName, + Channels.Name AS ChannelName, + Channels.Type AS ChannelType, + Users.Id AS UserId, + Users.Email AS UserEmail, + Users.Username + FROM + Posts + LEFT OUTER JOIN Channels ON Posts.ChannelId = Channels.Id + LEFT OUTER JOIN Teams ON Channels.TeamId = Teams.Id + LEFT OUTER JOIN Users ON Posts.UserId = Users.Id + WHERE + Posts.CreateAt > :StartTime AND + Posts.Type = '' + ORDER BY PostCreateAt + LIMIT :Limit` - var cposts []*model.MessageExport - if _, err := s.GetReplica().Select(&cposts, query, props); err != nil { - result.Err = model.NewAppError("SqlComplianceStore.MessageExport", "store.sql_compliance.message_export.app_error", nil, err.Error(), http.StatusInternalServerError) - } else { - result.Data = cposts - } - }) + var cposts []*model.MessageExport + if _, err := s.GetReplica().Select(&cposts, query, props); err != nil { + return nil, model.NewAppError("SqlComplianceStore.MessageExport", "store.sql_compliance.message_export.app_error", nil, err.Error(), http.StatusInternalServerError) + } + return cposts, nil } diff --git a/store/store.go b/store/store.go index e03b22161b..b02cf5c0dc 100644 --- a/store/store.go +++ b/store/store.go @@ -338,12 +338,12 @@ type ClusterDiscoveryStore interface { } type ComplianceStore interface { - Save(compliance *model.Compliance) StoreChannel - Update(compliance *model.Compliance) StoreChannel - Get(id string) StoreChannel - GetAll(offset, limit int) StoreChannel - ComplianceExport(compliance *model.Compliance) StoreChannel - MessageExport(after int64, limit int) StoreChannel + Save(compliance *model.Compliance) (*model.Compliance, *model.AppError) + Update(compliance *model.Compliance) (*model.Compliance, *model.AppError) + Get(id string) (*model.Compliance, *model.AppError) + GetAll(offset, limit int) (model.Compliances, *model.AppError) + ComplianceExport(compliance *model.Compliance) ([]*model.CompliancePost, *model.AppError) + MessageExport(after int64, limit int) ([]*model.MessageExport, *model.AppError) } type OAuthStore interface { diff --git a/store/storetest/compliance_store.go b/store/storetest/compliance_store.go index 14ed298651..7ce84d1764 100644 --- a/store/storetest/compliance_store.go +++ b/store/storetest/compliance_store.go @@ -25,45 +25,40 @@ func TestComplianceStore(t *testing.T, ss store.Store) { func testComplianceStore(t *testing.T, ss store.Store) { compliance1 := &model.Compliance{Desc: "Audit for federal subpoena case #22443", UserId: model.NewId(), Status: model.COMPLIANCE_STATUS_FAILED, StartAt: model.GetMillis() - 1, EndAt: model.GetMillis() + 1, Type: model.COMPLIANCE_TYPE_ADHOC} - store.Must(ss.Compliance().Save(compliance1)) + _, err := ss.Compliance().Save(compliance1) + require.Nil(t, err) time.Sleep(100 * time.Millisecond) compliance2 := &model.Compliance{Desc: "Audit for federal subpoena case #11458", UserId: model.NewId(), Status: model.COMPLIANCE_STATUS_RUNNING, StartAt: model.GetMillis() - 1, EndAt: model.GetMillis() + 1, Type: model.COMPLIANCE_TYPE_ADHOC} - store.Must(ss.Compliance().Save(compliance2)) + _, err = ss.Compliance().Save(compliance2) + require.Nil(t, err) time.Sleep(100 * time.Millisecond) - c := ss.Compliance().GetAll(0, 1000) - result := <-c - compliances := result.Data.(model.Compliances) + compliances, _ := ss.Compliance().GetAll(0, 1000) require.Equal(t, model.COMPLIANCE_STATUS_RUNNING, compliances[0].Status) require.Equal(t, compliance2.Id, compliances[0].Id) compliance2.Status = model.COMPLIANCE_STATUS_FAILED - store.Must(ss.Compliance().Update(compliance2)) + _, err = ss.Compliance().Update(compliance2) + require.Nil(t, err) - c = ss.Compliance().GetAll(0, 1000) - result = <-c - compliances = result.Data.(model.Compliances) + compliances, _ = ss.Compliance().GetAll(0, 1000) require.Equal(t, model.COMPLIANCE_STATUS_FAILED, compliances[0].Status) require.Equal(t, compliance2.Id, compliances[0].Id) - c = ss.Compliance().GetAll(0, 1) - result = <-c - compliances = result.Data.(model.Compliances) + compliances, _ = ss.Compliance().GetAll(0, 1) require.Len(t, compliances, 1) - c = ss.Compliance().GetAll(1, 1) - result = <-c - compliances = result.Data.(model.Compliances) + compliances, _ = ss.Compliance().GetAll(1, 1) if len(compliances) != 1 { t.Fatal("should only have returned 1") } - rc2 := (<-ss.Compliance().Get(compliance2.Id)).Data.(*model.Compliance) + rc2, _ := ss.Compliance().Get(compliance2.Id) require.Equal(t, compliance2.Status, rc2.Status) } @@ -127,106 +122,43 @@ func testComplianceExport(t *testing.T, ss store.Store) { time.Sleep(100 * time.Millisecond) cr1 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1} - if r1 := <-ss.Compliance().ComplianceExport(cr1); r1.Err != nil { - t.Fatal(r1.Err) - } else { - cposts := r1.Data.([]*model.CompliancePost) - - if len(cposts) != 4 { - t.Fatal("return wrong results length") - } - - if cposts[0].PostId != o1.Id { - t.Fatal("Wrong sort") - } - - if cposts[3].PostId != o2a.Id { - t.Fatal("Wrong sort") - } - } + cposts, err := ss.Compliance().ComplianceExport(cr1) + require.Nil(t, err) + assert.Len(t, cposts, 4) + assert.Equal(t, cposts[0].PostId, o1.Id) + assert.Equal(t, cposts[3].PostId, o2a.Id) cr2 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1, Emails: u2.Email} - if r1 := <-ss.Compliance().ComplianceExport(cr2); r1.Err != nil { - t.Fatal(r1.Err) - } else { - cposts := r1.Data.([]*model.CompliancePost) - - if len(cposts) != 1 { - t.Fatal("return wrong results length") - } - - if cposts[0].PostId != o2a.Id { - t.Fatal("Wrong sort") - } - } + cposts, err = ss.Compliance().ComplianceExport(cr2) + require.Nil(t, err) + assert.Len(t, cposts, 1) + assert.Equal(t, cposts[0].PostId, o2a.Id) cr3 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1, Emails: u2.Email + ", " + u1.Email} - if r1 := <-ss.Compliance().ComplianceExport(cr3); r1.Err != nil { - t.Fatal(r1.Err) - } else { - cposts := r1.Data.([]*model.CompliancePost) - - if len(cposts) != 4 { - t.Fatal("return wrong results length") - } - - if cposts[0].PostId != o1.Id { - t.Fatal("Wrong sort") - } - - if cposts[3].PostId != o2a.Id { - t.Fatal("Wrong sort") - } - } + cposts, err = ss.Compliance().ComplianceExport(cr3) + require.Nil(t, err) + assert.Len(t, cposts, 4) + assert.Equal(t, cposts[0].PostId, o1.Id) + assert.Equal(t, cposts[3].PostId, o2a.Id) cr4 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1, Keywords: o2a.Message} - if r1 := <-ss.Compliance().ComplianceExport(cr4); r1.Err != nil { - t.Fatal(r1.Err) - } else { - cposts := r1.Data.([]*model.CompliancePost) - - if len(cposts) != 1 { - t.Fatal("return wrong results length") - } - - if cposts[0].PostId != o2a.Id { - t.Fatal("Wrong sort") - } - } + cposts, err = ss.Compliance().ComplianceExport(cr4) + require.Nil(t, err) + assert.Len(t, cposts, 1) + assert.Equal(t, cposts[0].PostId, o2a.Id) cr5 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1, Keywords: o2a.Message + " " + o1.Message} - if r1 := <-ss.Compliance().ComplianceExport(cr5); r1.Err != nil { - t.Fatal(r1.Err) - } else { - cposts := r1.Data.([]*model.CompliancePost) - - if len(cposts) != 2 { - t.Fatal("return wrong results length") - } - - if cposts[0].PostId != o1.Id { - t.Fatal("Wrong sort") - } - } + cposts, err = ss.Compliance().ComplianceExport(cr5) + require.Nil(t, err) + assert.Len(t, cposts, 2) + assert.Equal(t, cposts[0].PostId, o1.Id) cr6 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1, Emails: u2.Email + ", " + u1.Email, Keywords: o2a.Message + " " + o1.Message} - if r1 := <-ss.Compliance().ComplianceExport(cr6); r1.Err != nil { - t.Fatal(r1.Err) - } else { - cposts := r1.Data.([]*model.CompliancePost) - - if len(cposts) != 2 { - t.Fatal("return wrong results length") - } - - if cposts[0].PostId != o1.Id { - t.Fatal("Wrong sort") - } - - if cposts[1].PostId != o2a.Id { - t.Fatal("Wrong sort") - } - } + cposts, err = ss.Compliance().ComplianceExport(cr6) + require.Nil(t, err) + assert.Len(t, cposts, 2) + assert.Equal(t, cposts[0].PostId, o1.Id) + assert.Equal(t, cposts[1].PostId, o2a.Id) } func testComplianceExportDirectMessages(t *testing.T, ss store.Store) { @@ -298,35 +230,19 @@ func testComplianceExportDirectMessages(t *testing.T, ss store.Store) { time.Sleep(100 * time.Millisecond) cr1 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o3.CreateAt + 1, Emails: u1.Email} - if r1 := <-ss.Compliance().ComplianceExport(cr1); r1.Err != nil { - t.Fatal(r1.Err) - } else { - cposts := r1.Data.([]*model.CompliancePost) - - if len(cposts) != 4 { - t.Fatal("return wrong results length") - } - - if cposts[0].PostId != o1.Id { - t.Fatal("Wrong sort") - } - - if cposts[len(cposts)-1].PostId != o3.Id { - t.Fatal("Wrong sort") - } - } + cposts, err := ss.Compliance().ComplianceExport(cr1) + require.Nil(t, err) + assert.Len(t, cposts, 4) + assert.Equal(t, cposts[0].PostId, o1.Id) + assert.Equal(t, cposts[len(cposts)-1].PostId, o3.Id) } func testMessageExportPublicChannel(t *testing.T, ss store.Store) { // get the starting number of message export entries startTime := model.GetMillis() - var numMessageExports = 0 - if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil { - t.Fatal(r1.Err) - } else { - messages := r1.Data.([]*model.MessageExport) - numMessageExports = len(messages) - } + messages, err := ss.Compliance().MessageExport(startTime-10, 10) + require.Nil(t, err) + numMessageExports := len(messages) // need a team team := &model.Team{ @@ -386,15 +302,12 @@ func testMessageExportPublicChannel(t *testing.T, ss store.Store) { // fetch the message exports for both posts that user1 sent messageExportMap := map[string]model.MessageExport{} - if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil { - t.Fatal(r1.Err) - } else { - messages := r1.Data.([]*model.MessageExport) - assert.Equal(t, numMessageExports+2, len(messages)) + messages, err = ss.Compliance().MessageExport(startTime-10, 10) + require.Nil(t, err) + assert.Equal(t, numMessageExports+2, len(messages)) - for _, v := range messages { - messageExportMap[*v.PostId] = *v - } + for _, v := range messages { + messageExportMap[*v.PostId] = *v } // post1 was made by user1 in channel1 and team1 @@ -421,13 +334,9 @@ func testMessageExportPublicChannel(t *testing.T, ss store.Store) { func testMessageExportPrivateChannel(t *testing.T, ss store.Store) { // get the starting number of message export entries startTime := model.GetMillis() - var numMessageExports = 0 - if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil { - t.Fatal(r1.Err) - } else { - messages := r1.Data.([]*model.MessageExport) - numMessageExports = len(messages) - } + messages, err := ss.Compliance().MessageExport(startTime-10, 10) + require.Nil(t, err) + numMessageExports := len(messages) // need a team team := &model.Team{ @@ -487,15 +396,12 @@ func testMessageExportPrivateChannel(t *testing.T, ss store.Store) { // fetch the message exports for both posts that user1 sent messageExportMap := map[string]model.MessageExport{} - if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil { - t.Fatal(r1.Err) - } else { - messages := r1.Data.([]*model.MessageExport) - assert.Equal(t, numMessageExports+2, len(messages)) + messages, err = ss.Compliance().MessageExport(startTime-10, 10) + require.Nil(t, err) + assert.Equal(t, numMessageExports+2, len(messages)) - for _, v := range messages { - messageExportMap[*v.PostId] = *v - } + for _, v := range messages { + messageExportMap[*v.PostId] = *v } // post1 was made by user1 in channel1 and team1 @@ -524,13 +430,9 @@ func testMessageExportPrivateChannel(t *testing.T, ss store.Store) { func testMessageExportDirectMessageChannel(t *testing.T, ss store.Store) { // get the starting number of message export entries startTime := model.GetMillis() - var numMessageExports = 0 - if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil { - t.Fatal(r1.Err) - } else { - messages := r1.Data.([]*model.MessageExport) - numMessageExports = len(messages) - } + messages, err := ss.Compliance().MessageExport(startTime-10, 10) + require.Nil(t, err) + numMessageExports := len(messages) // need a team team := &model.Team{ @@ -576,15 +478,13 @@ func testMessageExportDirectMessageChannel(t *testing.T, ss store.Store) { // fetch the message export for the post that user1 sent messageExportMap := map[string]model.MessageExport{} - if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil { - t.Fatal(r1.Err) - } else { - messages := r1.Data.([]*model.MessageExport) - assert.Equal(t, numMessageExports+1, len(messages)) + messages, err = ss.Compliance().MessageExport(startTime-10, 10) + require.Nil(t, err) - for _, v := range messages { - messageExportMap[*v.PostId] = *v - } + assert.Equal(t, numMessageExports+1, len(messages)) + + for _, v := range messages { + messageExportMap[*v.PostId] = *v } // post is a DM between user1 and user2 @@ -602,13 +502,9 @@ func testMessageExportDirectMessageChannel(t *testing.T, ss store.Store) { func testMessageExportGroupMessageChannel(t *testing.T, ss store.Store) { // get the starting number of message export entries startTime := model.GetMillis() - var numMessageExports = 0 - if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil { - t.Fatal(r1.Err) - } else { - messages := r1.Data.([]*model.MessageExport) - numMessageExports = len(messages) - } + messages, err := ss.Compliance().MessageExport(startTime-10, 10) + require.Nil(t, err) + numMessageExports := len(messages) // need a team team := &model.Team{ @@ -669,15 +565,12 @@ func testMessageExportGroupMessageChannel(t *testing.T, ss store.Store) { // fetch the message export for the post that user1 sent messageExportMap := map[string]model.MessageExport{} - if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil { - t.Fatal(r1.Err) - } else { - messages := r1.Data.([]*model.MessageExport) - assert.Equal(t, numMessageExports+1, len(messages)) + messages, err = ss.Compliance().MessageExport(startTime-10, 10) + require.Nil(t, err) + assert.Equal(t, numMessageExports+1, len(messages)) - for _, v := range messages { - messageExportMap[*v.PostId] = *v - } + for _, v := range messages { + messageExportMap[*v.PostId] = *v } // post is a DM between user1 and user2 diff --git a/store/storetest/mocks/ComplianceStore.go b/store/storetest/mocks/ComplianceStore.go index dd75941b3c..b175a9a7ba 100644 --- a/store/storetest/mocks/ComplianceStore.go +++ b/store/storetest/mocks/ComplianceStore.go @@ -6,7 +6,6 @@ package mocks import mock "github.com/stretchr/testify/mock" import model "github.com/mattermost/mattermost-server/model" -import store "github.com/mattermost/mattermost-server/store" // ComplianceStore is an autogenerated mock type for the ComplianceStore type type ComplianceStore struct { @@ -14,97 +13,151 @@ type ComplianceStore struct { } // ComplianceExport provides a mock function with given fields: compliance -func (_m *ComplianceStore) ComplianceExport(compliance *model.Compliance) store.StoreChannel { +func (_m *ComplianceStore) ComplianceExport(compliance *model.Compliance) ([]*model.CompliancePost, *model.AppError) { ret := _m.Called(compliance) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(*model.Compliance) store.StoreChannel); ok { + var r0 []*model.CompliancePost + if rf, ok := ret.Get(0).(func(*model.Compliance) []*model.CompliancePost); ok { r0 = rf(compliance) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).([]*model.CompliancePost) } } - return r0 + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.Compliance) *model.AppError); ok { + r1 = rf(compliance) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 } // Get provides a mock function with given fields: id -func (_m *ComplianceStore) Get(id string) store.StoreChannel { +func (_m *ComplianceStore) Get(id string) (*model.Compliance, *model.AppError) { ret := _m.Called(id) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(string) store.StoreChannel); ok { + var r0 *model.Compliance + if rf, ok := ret.Get(0).(func(string) *model.Compliance); ok { r0 = rf(id) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).(*model.Compliance) } } - return r0 + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(string) *model.AppError); ok { + r1 = rf(id) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 } // GetAll provides a mock function with given fields: offset, limit -func (_m *ComplianceStore) GetAll(offset int, limit int) store.StoreChannel { +func (_m *ComplianceStore) GetAll(offset int, limit int) (model.Compliances, *model.AppError) { ret := _m.Called(offset, limit) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(int, int) store.StoreChannel); ok { + var r0 model.Compliances + if rf, ok := ret.Get(0).(func(int, int) model.Compliances); ok { r0 = rf(offset, limit) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).(model.Compliances) } } - return r0 + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(int, int) *model.AppError); ok { + r1 = rf(offset, limit) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 } // MessageExport provides a mock function with given fields: after, limit -func (_m *ComplianceStore) MessageExport(after int64, limit int) store.StoreChannel { +func (_m *ComplianceStore) MessageExport(after int64, limit int) ([]*model.MessageExport, *model.AppError) { ret := _m.Called(after, limit) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(int64, int) store.StoreChannel); ok { + var r0 []*model.MessageExport + if rf, ok := ret.Get(0).(func(int64, int) []*model.MessageExport); ok { r0 = rf(after, limit) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).([]*model.MessageExport) } } - return r0 + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(int64, int) *model.AppError); ok { + r1 = rf(after, limit) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 } // Save provides a mock function with given fields: compliance -func (_m *ComplianceStore) Save(compliance *model.Compliance) store.StoreChannel { +func (_m *ComplianceStore) Save(compliance *model.Compliance) (*model.Compliance, *model.AppError) { ret := _m.Called(compliance) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(*model.Compliance) store.StoreChannel); ok { + var r0 *model.Compliance + if rf, ok := ret.Get(0).(func(*model.Compliance) *model.Compliance); ok { r0 = rf(compliance) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).(*model.Compliance) } } - return r0 + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.Compliance) *model.AppError); ok { + r1 = rf(compliance) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 } // Update provides a mock function with given fields: compliance -func (_m *ComplianceStore) Update(compliance *model.Compliance) store.StoreChannel { +func (_m *ComplianceStore) Update(compliance *model.Compliance) (*model.Compliance, *model.AppError) { ret := _m.Called(compliance) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(*model.Compliance) store.StoreChannel); ok { + var r0 *model.Compliance + if rf, ok := ret.Get(0).(func(*model.Compliance) *model.Compliance); ok { r0 = rf(compliance) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).(*model.Compliance) } } - return r0 + var r1 *model.AppError + if rf, ok := ret.Get(1).(func(*model.Compliance) *model.AppError); ok { + r1 = rf(compliance) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*model.AppError) + } + } + + return r0, r1 }