Migrating compliance store to sync by default (#10660)

* Migrating compliance store to sync by default

* Addressing PR review comments
This commit is contained in:
Jesús Espino
2019-04-24 12:50:25 +02:00
committed by Miguel de la Cruz
parent 4ae38d00a8
commit 7c7ff93d97
5 changed files with 347 additions and 426 deletions

View File

@@ -16,11 +16,7 @@ func (a *App) GetComplianceReports(page, perPage int) (model.Compliances, *model
return nil, model.NewAppError("GetComplianceReports", "ent.compliance.licence_disable.app_error", nil, "", http.StatusNotImplemented)
}
result := <-a.Srv.Store.Compliance().GetAll(page*perPage, perPage)
if result.Err != nil {
return nil, result.Err
}
return result.Data.(model.Compliances), nil
return a.Srv.Store.Compliance().GetAll(page*perPage, perPage)
}
func (a *App) SaveComplianceReport(job *model.Compliance) (*model.Compliance, *model.AppError) {
@@ -30,12 +26,11 @@ func (a *App) SaveComplianceReport(job *model.Compliance) (*model.Compliance, *m
job.Type = model.COMPLIANCE_TYPE_ADHOC
result := <-a.Srv.Store.Compliance().Save(job)
if result.Err != nil {
return nil, result.Err
job, err := a.Srv.Store.Compliance().Save(job)
if err != nil {
return nil, err
}
job = result.Data.(*model.Compliance)
a.Srv.Go(func() {
a.Compliance.RunComplianceJob(job)
})
@@ -48,11 +43,7 @@ func (a *App) GetComplianceReport(reportId string) (*model.Compliance, *model.Ap
return nil, model.NewAppError("downloadComplianceReport", "ent.compliance.licence_disable.app_error", nil, "", http.StatusNotImplemented)
}
result := <-a.Srv.Store.Compliance().Get(reportId)
if result.Err != nil {
return nil, result.Err
}
return result.Data.(*model.Compliance), nil
return a.Srv.Store.Compliance().Get(reportId)
}
func (a *App) GetComplianceFile(job *model.Compliance) ([]byte, *model.AppError) {

View File

@@ -36,226 +36,210 @@ func NewSqlComplianceStore(sqlStore SqlStore) store.ComplianceStore {
func (s SqlComplianceStore) CreateIndexesIfNotExists() {
}
func (s SqlComplianceStore) Save(compliance *model.Compliance) store.StoreChannel {
return store.Do(func(result *store.StoreResult) {
compliance.PreSave()
if result.Err = compliance.IsValid(); result.Err != nil {
return
}
func (s SqlComplianceStore) Save(compliance *model.Compliance) (*model.Compliance, *model.AppError) {
compliance.PreSave()
if err := compliance.IsValid(); err != nil {
return nil, err
}
if err := s.GetMaster().Insert(compliance); err != nil {
result.Err = model.NewAppError("SqlComplianceStore.Save", "store.sql_compliance.save.saving.app_error", nil, err.Error(), http.StatusInternalServerError)
} else {
result.Data = compliance
}
})
if err := s.GetMaster().Insert(compliance); err != nil {
return nil, model.NewAppError("SqlComplianceStore.Save", "store.sql_compliance.save.saving.app_error", nil, err.Error(), http.StatusInternalServerError)
}
return compliance, nil
}
func (us SqlComplianceStore) Update(compliance *model.Compliance) store.StoreChannel {
return store.Do(func(result *store.StoreResult) {
if result.Err = compliance.IsValid(); result.Err != nil {
return
}
func (us SqlComplianceStore) Update(compliance *model.Compliance) (*model.Compliance, *model.AppError) {
if err := compliance.IsValid(); err != nil {
return nil, err
}
if _, err := us.GetMaster().Update(compliance); err != nil {
result.Err = model.NewAppError("SqlComplianceStore.Update", "store.sql_compliance.save.saving.app_error", nil, err.Error(), http.StatusInternalServerError)
} else {
result.Data = compliance
}
})
if _, err := us.GetMaster().Update(compliance); err != nil {
return nil, model.NewAppError("SqlComplianceStore.Update", "store.sql_compliance.save.saving.app_error", nil, err.Error(), http.StatusInternalServerError)
}
return compliance, nil
}
func (s SqlComplianceStore) GetAll(offset, limit int) store.StoreChannel {
return store.Do(func(result *store.StoreResult) {
query := "SELECT * FROM Compliances ORDER BY CreateAt DESC LIMIT :Limit OFFSET :Offset"
func (s SqlComplianceStore) GetAll(offset, limit int) (model.Compliances, *model.AppError) {
query := "SELECT * FROM Compliances ORDER BY CreateAt DESC LIMIT :Limit OFFSET :Offset"
var compliances model.Compliances
if _, err := s.GetReplica().Select(&compliances, query, map[string]interface{}{"Offset": offset, "Limit": limit}); err != nil {
result.Err = model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusInternalServerError)
} else {
result.Data = compliances
}
})
var compliances model.Compliances
if _, err := s.GetReplica().Select(&compliances, query, map[string]interface{}{"Offset": offset, "Limit": limit}); err != nil {
return nil, model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusInternalServerError)
}
return compliances, nil
}
func (us SqlComplianceStore) Get(id string) store.StoreChannel {
return store.Do(func(result *store.StoreResult) {
if obj, err := us.GetReplica().Get(model.Compliance{}, id); err != nil {
result.Err = model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusInternalServerError)
} else if obj == nil {
result.Err = model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusNotFound)
} else {
result.Data = obj.(*model.Compliance)
}
})
func (us SqlComplianceStore) Get(id string) (*model.Compliance, *model.AppError) {
obj, err := us.GetReplica().Get(model.Compliance{}, id)
if err != nil {
return nil, model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusInternalServerError)
}
if obj == nil {
return nil, model.NewAppError("SqlComplianceStore.Get", "store.sql_compliance.get.finding.app_error", nil, err.Error(), http.StatusNotFound)
}
return obj.(*model.Compliance), nil
}
func (s SqlComplianceStore) ComplianceExport(job *model.Compliance) store.StoreChannel {
return store.Do(func(result *store.StoreResult) {
props := map[string]interface{}{"StartTime": job.StartAt, "EndTime": job.EndAt}
func (s SqlComplianceStore) ComplianceExport(job *model.Compliance) ([]*model.CompliancePost, *model.AppError) {
props := map[string]interface{}{"StartTime": job.StartAt, "EndTime": job.EndAt}
keywordQuery := ""
keywords := strings.Fields(strings.TrimSpace(strings.ToLower(strings.Replace(job.Keywords, ",", " ", -1))))
if len(keywords) > 0 {
keywordQuery := ""
keywords := strings.Fields(strings.TrimSpace(strings.ToLower(strings.Replace(job.Keywords, ",", " ", -1))))
if len(keywords) > 0 {
keywordQuery = "AND ("
keywordQuery = "AND ("
for index, keyword := range keywords {
if index >= 1 {
keywordQuery += " OR LOWER(Posts.Message) LIKE :Keyword" + strconv.Itoa(index)
} else {
keywordQuery += "LOWER(Posts.Message) LIKE :Keyword" + strconv.Itoa(index)
}
props["Keyword"+strconv.Itoa(index)] = "%" + keyword + "%"
for index, keyword := range keywords {
if index >= 1 {
keywordQuery += " OR LOWER(Posts.Message) LIKE :Keyword" + strconv.Itoa(index)
} else {
keywordQuery += "LOWER(Posts.Message) LIKE :Keyword" + strconv.Itoa(index)
}
keywordQuery += ")"
props["Keyword"+strconv.Itoa(index)] = "%" + keyword + "%"
}
emailQuery := ""
emails := strings.Fields(strings.TrimSpace(strings.ToLower(strings.Replace(job.Emails, ",", " ", -1))))
if len(emails) > 0 {
keywordQuery += ")"
}
emailQuery = "AND ("
emailQuery := ""
emails := strings.Fields(strings.TrimSpace(strings.ToLower(strings.Replace(job.Emails, ",", " ", -1))))
if len(emails) > 0 {
for index, email := range emails {
if index >= 1 {
emailQuery += " OR Users.Email = :Email" + strconv.Itoa(index)
} else {
emailQuery += "Users.Email = :Email" + strconv.Itoa(index)
}
emailQuery = "AND ("
props["Email"+strconv.Itoa(index)] = email
for index, email := range emails {
if index >= 1 {
emailQuery += " OR Users.Email = :Email" + strconv.Itoa(index)
} else {
emailQuery += "Users.Email = :Email" + strconv.Itoa(index)
}
emailQuery += ")"
props["Email"+strconv.Itoa(index)] = email
}
query :=
`(SELECT
Teams.Name AS TeamName,
Teams.DisplayName AS TeamDisplayName,
Channels.Name AS ChannelName,
Channels.DisplayName AS ChannelDisplayName,
Channels.Type AS ChannelType,
Users.Username AS UserUsername,
Users.Email AS UserEmail,
Users.Nickname AS UserNickname,
Posts.Id AS PostId,
Posts.CreateAt AS PostCreateAt,
Posts.UpdateAt AS PostUpdateAt,
Posts.DeleteAt AS PostDeleteAt,
Posts.RootId AS PostRootId,
Posts.ParentId AS PostParentId,
Posts.OriginalId AS PostOriginalId,
Posts.Message AS PostMessage,
Posts.Type AS PostType,
Posts.Props AS PostProps,
Posts.Hashtags AS PostHashtags,
Posts.FileIds AS PostFileIds
FROM
Teams,
Channels,
Users,
Posts
WHERE
Teams.Id = Channels.TeamId
AND Posts.ChannelId = Channels.Id
AND Posts.UserId = Users.Id
AND Posts.CreateAt > :StartTime
AND Posts.CreateAt <= :EndTime
` + emailQuery + `
` + keywordQuery + `)
UNION ALL
(SELECT
'direct-messages' AS TeamName,
'Direct Messages' AS TeamDisplayName,
Channels.Name AS ChannelName,
Channels.DisplayName AS ChannelDisplayName,
Channels.Type AS ChannelType,
Users.Username AS UserUsername,
Users.Email AS UserEmail,
Users.Nickname AS UserNickname,
Posts.Id AS PostId,
Posts.CreateAt AS PostCreateAt,
Posts.UpdateAt AS PostUpdateAt,
Posts.DeleteAt AS PostDeleteAt,
Posts.RootId AS PostRootId,
Posts.ParentId AS PostParentId,
Posts.OriginalId AS PostOriginalId,
Posts.Message AS PostMessage,
Posts.Type AS PostType,
Posts.Props AS PostProps,
Posts.Hashtags AS PostHashtags,
Posts.FileIds AS PostFileIds
FROM
Channels,
Users,
Posts
WHERE
Channels.TeamId = ''
AND Posts.ChannelId = Channels.Id
AND Posts.UserId = Users.Id
AND Posts.CreateAt > :StartTime
AND Posts.CreateAt <= :EndTime
` + emailQuery + `
` + keywordQuery + `)
ORDER BY PostCreateAt
LIMIT 30000`
emailQuery += ")"
}
var cposts []*model.CompliancePost
query :=
`(SELECT
Teams.Name AS TeamName,
Teams.DisplayName AS TeamDisplayName,
Channels.Name AS ChannelName,
Channels.DisplayName AS ChannelDisplayName,
Channels.Type AS ChannelType,
Users.Username AS UserUsername,
Users.Email AS UserEmail,
Users.Nickname AS UserNickname,
Posts.Id AS PostId,
Posts.CreateAt AS PostCreateAt,
Posts.UpdateAt AS PostUpdateAt,
Posts.DeleteAt AS PostDeleteAt,
Posts.RootId AS PostRootId,
Posts.ParentId AS PostParentId,
Posts.OriginalId AS PostOriginalId,
Posts.Message AS PostMessage,
Posts.Type AS PostType,
Posts.Props AS PostProps,
Posts.Hashtags AS PostHashtags,
Posts.FileIds AS PostFileIds
FROM
Teams,
Channels,
Users,
Posts
WHERE
Teams.Id = Channels.TeamId
AND Posts.ChannelId = Channels.Id
AND Posts.UserId = Users.Id
AND Posts.CreateAt > :StartTime
AND Posts.CreateAt <= :EndTime
` + emailQuery + `
` + keywordQuery + `)
UNION ALL
(SELECT
'direct-messages' AS TeamName,
'Direct Messages' AS TeamDisplayName,
Channels.Name AS ChannelName,
Channels.DisplayName AS ChannelDisplayName,
Channels.Type AS ChannelType,
Users.Username AS UserUsername,
Users.Email AS UserEmail,
Users.Nickname AS UserNickname,
Posts.Id AS PostId,
Posts.CreateAt AS PostCreateAt,
Posts.UpdateAt AS PostUpdateAt,
Posts.DeleteAt AS PostDeleteAt,
Posts.RootId AS PostRootId,
Posts.ParentId AS PostParentId,
Posts.OriginalId AS PostOriginalId,
Posts.Message AS PostMessage,
Posts.Type AS PostType,
Posts.Props AS PostProps,
Posts.Hashtags AS PostHashtags,
Posts.FileIds AS PostFileIds
FROM
Channels,
Users,
Posts
WHERE
Channels.TeamId = ''
AND Posts.ChannelId = Channels.Id
AND Posts.UserId = Users.Id
AND Posts.CreateAt > :StartTime
AND Posts.CreateAt <= :EndTime
` + emailQuery + `
` + keywordQuery + `)
ORDER BY PostCreateAt
LIMIT 30000`
if _, err := s.GetReplica().Select(&cposts, query, props); err != nil {
result.Err = model.NewAppError("SqlPostStore.ComplianceExport", "store.sql_post.compliance_export.app_error", nil, err.Error(), http.StatusInternalServerError)
} else {
result.Data = cposts
}
})
var cposts []*model.CompliancePost
if _, err := s.GetReplica().Select(&cposts, query, props); err != nil {
return nil, model.NewAppError("SqlPostStore.ComplianceExport", "store.sql_post.compliance_export.app_error", nil, err.Error(), http.StatusInternalServerError)
}
return cposts, nil
}
func (s SqlComplianceStore) MessageExport(after int64, limit int) store.StoreChannel {
return store.Do(func(result *store.StoreResult) {
props := map[string]interface{}{"StartTime": after, "Limit": limit}
query :=
`SELECT
Posts.Id AS PostId,
Posts.CreateAt AS PostCreateAt,
Posts.Message AS PostMessage,
Posts.Type AS PostType,
Posts.OriginalId AS PostOriginalId,
Posts.RootId AS PostRootId,
Posts.FileIds AS PostFileIds,
Teams.Id AS TeamId,
Teams.Name AS TeamName,
Teams.DisplayName AS TeamDisplayName,
Channels.Id AS ChannelId,
CASE
WHEN Channels.Type = 'D' THEN 'Direct Message'
WHEN Channels.Type = 'G' THEN 'Group Message'
ELSE Channels.DisplayName
END AS ChannelDisplayName,
Channels.Name AS ChannelName,
Channels.Type AS ChannelType,
Users.Id AS UserId,
Users.Email AS UserEmail,
Users.Username
FROM
Posts
LEFT OUTER JOIN Channels ON Posts.ChannelId = Channels.Id
LEFT OUTER JOIN Teams ON Channels.TeamId = Teams.Id
LEFT OUTER JOIN Users ON Posts.UserId = Users.Id
WHERE
Posts.CreateAt > :StartTime AND
Posts.Type = ''
ORDER BY PostCreateAt
LIMIT :Limit`
func (s SqlComplianceStore) MessageExport(after int64, limit int) ([]*model.MessageExport, *model.AppError) {
props := map[string]interface{}{"StartTime": after, "Limit": limit}
query :=
`SELECT
Posts.Id AS PostId,
Posts.CreateAt AS PostCreateAt,
Posts.Message AS PostMessage,
Posts.Type AS PostType,
Posts.OriginalId AS PostOriginalId,
Posts.RootId AS PostRootId,
Posts.FileIds AS PostFileIds,
Teams.Id AS TeamId,
Teams.Name AS TeamName,
Teams.DisplayName AS TeamDisplayName,
Channels.Id AS ChannelId,
CASE
WHEN Channels.Type = 'D' THEN 'Direct Message'
WHEN Channels.Type = 'G' THEN 'Group Message'
ELSE Channels.DisplayName
END AS ChannelDisplayName,
Channels.Name AS ChannelName,
Channels.Type AS ChannelType,
Users.Id AS UserId,
Users.Email AS UserEmail,
Users.Username
FROM
Posts
LEFT OUTER JOIN Channels ON Posts.ChannelId = Channels.Id
LEFT OUTER JOIN Teams ON Channels.TeamId = Teams.Id
LEFT OUTER JOIN Users ON Posts.UserId = Users.Id
WHERE
Posts.CreateAt > :StartTime AND
Posts.Type = ''
ORDER BY PostCreateAt
LIMIT :Limit`
var cposts []*model.MessageExport
if _, err := s.GetReplica().Select(&cposts, query, props); err != nil {
result.Err = model.NewAppError("SqlComplianceStore.MessageExport", "store.sql_compliance.message_export.app_error", nil, err.Error(), http.StatusInternalServerError)
} else {
result.Data = cposts
}
})
var cposts []*model.MessageExport
if _, err := s.GetReplica().Select(&cposts, query, props); err != nil {
return nil, model.NewAppError("SqlComplianceStore.MessageExport", "store.sql_compliance.message_export.app_error", nil, err.Error(), http.StatusInternalServerError)
}
return cposts, nil
}

View File

@@ -338,12 +338,12 @@ type ClusterDiscoveryStore interface {
}
type ComplianceStore interface {
Save(compliance *model.Compliance) StoreChannel
Update(compliance *model.Compliance) StoreChannel
Get(id string) StoreChannel
GetAll(offset, limit int) StoreChannel
ComplianceExport(compliance *model.Compliance) StoreChannel
MessageExport(after int64, limit int) StoreChannel
Save(compliance *model.Compliance) (*model.Compliance, *model.AppError)
Update(compliance *model.Compliance) (*model.Compliance, *model.AppError)
Get(id string) (*model.Compliance, *model.AppError)
GetAll(offset, limit int) (model.Compliances, *model.AppError)
ComplianceExport(compliance *model.Compliance) ([]*model.CompliancePost, *model.AppError)
MessageExport(after int64, limit int) ([]*model.MessageExport, *model.AppError)
}
type OAuthStore interface {

View File

@@ -25,45 +25,40 @@ func TestComplianceStore(t *testing.T, ss store.Store) {
func testComplianceStore(t *testing.T, ss store.Store) {
compliance1 := &model.Compliance{Desc: "Audit for federal subpoena case #22443", UserId: model.NewId(), Status: model.COMPLIANCE_STATUS_FAILED, StartAt: model.GetMillis() - 1, EndAt: model.GetMillis() + 1, Type: model.COMPLIANCE_TYPE_ADHOC}
store.Must(ss.Compliance().Save(compliance1))
_, err := ss.Compliance().Save(compliance1)
require.Nil(t, err)
time.Sleep(100 * time.Millisecond)
compliance2 := &model.Compliance{Desc: "Audit for federal subpoena case #11458", UserId: model.NewId(), Status: model.COMPLIANCE_STATUS_RUNNING, StartAt: model.GetMillis() - 1, EndAt: model.GetMillis() + 1, Type: model.COMPLIANCE_TYPE_ADHOC}
store.Must(ss.Compliance().Save(compliance2))
_, err = ss.Compliance().Save(compliance2)
require.Nil(t, err)
time.Sleep(100 * time.Millisecond)
c := ss.Compliance().GetAll(0, 1000)
result := <-c
compliances := result.Data.(model.Compliances)
compliances, _ := ss.Compliance().GetAll(0, 1000)
require.Equal(t, model.COMPLIANCE_STATUS_RUNNING, compliances[0].Status)
require.Equal(t, compliance2.Id, compliances[0].Id)
compliance2.Status = model.COMPLIANCE_STATUS_FAILED
store.Must(ss.Compliance().Update(compliance2))
_, err = ss.Compliance().Update(compliance2)
require.Nil(t, err)
c = ss.Compliance().GetAll(0, 1000)
result = <-c
compliances = result.Data.(model.Compliances)
compliances, _ = ss.Compliance().GetAll(0, 1000)
require.Equal(t, model.COMPLIANCE_STATUS_FAILED, compliances[0].Status)
require.Equal(t, compliance2.Id, compliances[0].Id)
c = ss.Compliance().GetAll(0, 1)
result = <-c
compliances = result.Data.(model.Compliances)
compliances, _ = ss.Compliance().GetAll(0, 1)
require.Len(t, compliances, 1)
c = ss.Compliance().GetAll(1, 1)
result = <-c
compliances = result.Data.(model.Compliances)
compliances, _ = ss.Compliance().GetAll(1, 1)
if len(compliances) != 1 {
t.Fatal("should only have returned 1")
}
rc2 := (<-ss.Compliance().Get(compliance2.Id)).Data.(*model.Compliance)
rc2, _ := ss.Compliance().Get(compliance2.Id)
require.Equal(t, compliance2.Status, rc2.Status)
}
@@ -127,106 +122,43 @@ func testComplianceExport(t *testing.T, ss store.Store) {
time.Sleep(100 * time.Millisecond)
cr1 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1}
if r1 := <-ss.Compliance().ComplianceExport(cr1); r1.Err != nil {
t.Fatal(r1.Err)
} else {
cposts := r1.Data.([]*model.CompliancePost)
if len(cposts) != 4 {
t.Fatal("return wrong results length")
}
if cposts[0].PostId != o1.Id {
t.Fatal("Wrong sort")
}
if cposts[3].PostId != o2a.Id {
t.Fatal("Wrong sort")
}
}
cposts, err := ss.Compliance().ComplianceExport(cr1)
require.Nil(t, err)
assert.Len(t, cposts, 4)
assert.Equal(t, cposts[0].PostId, o1.Id)
assert.Equal(t, cposts[3].PostId, o2a.Id)
cr2 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1, Emails: u2.Email}
if r1 := <-ss.Compliance().ComplianceExport(cr2); r1.Err != nil {
t.Fatal(r1.Err)
} else {
cposts := r1.Data.([]*model.CompliancePost)
if len(cposts) != 1 {
t.Fatal("return wrong results length")
}
if cposts[0].PostId != o2a.Id {
t.Fatal("Wrong sort")
}
}
cposts, err = ss.Compliance().ComplianceExport(cr2)
require.Nil(t, err)
assert.Len(t, cposts, 1)
assert.Equal(t, cposts[0].PostId, o2a.Id)
cr3 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1, Emails: u2.Email + ", " + u1.Email}
if r1 := <-ss.Compliance().ComplianceExport(cr3); r1.Err != nil {
t.Fatal(r1.Err)
} else {
cposts := r1.Data.([]*model.CompliancePost)
if len(cposts) != 4 {
t.Fatal("return wrong results length")
}
if cposts[0].PostId != o1.Id {
t.Fatal("Wrong sort")
}
if cposts[3].PostId != o2a.Id {
t.Fatal("Wrong sort")
}
}
cposts, err = ss.Compliance().ComplianceExport(cr3)
require.Nil(t, err)
assert.Len(t, cposts, 4)
assert.Equal(t, cposts[0].PostId, o1.Id)
assert.Equal(t, cposts[3].PostId, o2a.Id)
cr4 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1, Keywords: o2a.Message}
if r1 := <-ss.Compliance().ComplianceExport(cr4); r1.Err != nil {
t.Fatal(r1.Err)
} else {
cposts := r1.Data.([]*model.CompliancePost)
if len(cposts) != 1 {
t.Fatal("return wrong results length")
}
if cposts[0].PostId != o2a.Id {
t.Fatal("Wrong sort")
}
}
cposts, err = ss.Compliance().ComplianceExport(cr4)
require.Nil(t, err)
assert.Len(t, cposts, 1)
assert.Equal(t, cposts[0].PostId, o2a.Id)
cr5 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1, Keywords: o2a.Message + " " + o1.Message}
if r1 := <-ss.Compliance().ComplianceExport(cr5); r1.Err != nil {
t.Fatal(r1.Err)
} else {
cposts := r1.Data.([]*model.CompliancePost)
if len(cposts) != 2 {
t.Fatal("return wrong results length")
}
if cposts[0].PostId != o1.Id {
t.Fatal("Wrong sort")
}
}
cposts, err = ss.Compliance().ComplianceExport(cr5)
require.Nil(t, err)
assert.Len(t, cposts, 2)
assert.Equal(t, cposts[0].PostId, o1.Id)
cr6 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o2a.CreateAt + 1, Emails: u2.Email + ", " + u1.Email, Keywords: o2a.Message + " " + o1.Message}
if r1 := <-ss.Compliance().ComplianceExport(cr6); r1.Err != nil {
t.Fatal(r1.Err)
} else {
cposts := r1.Data.([]*model.CompliancePost)
if len(cposts) != 2 {
t.Fatal("return wrong results length")
}
if cposts[0].PostId != o1.Id {
t.Fatal("Wrong sort")
}
if cposts[1].PostId != o2a.Id {
t.Fatal("Wrong sort")
}
}
cposts, err = ss.Compliance().ComplianceExport(cr6)
require.Nil(t, err)
assert.Len(t, cposts, 2)
assert.Equal(t, cposts[0].PostId, o1.Id)
assert.Equal(t, cposts[1].PostId, o2a.Id)
}
func testComplianceExportDirectMessages(t *testing.T, ss store.Store) {
@@ -298,35 +230,19 @@ func testComplianceExportDirectMessages(t *testing.T, ss store.Store) {
time.Sleep(100 * time.Millisecond)
cr1 := &model.Compliance{Desc: "test" + model.NewId(), StartAt: o1.CreateAt - 1, EndAt: o3.CreateAt + 1, Emails: u1.Email}
if r1 := <-ss.Compliance().ComplianceExport(cr1); r1.Err != nil {
t.Fatal(r1.Err)
} else {
cposts := r1.Data.([]*model.CompliancePost)
if len(cposts) != 4 {
t.Fatal("return wrong results length")
}
if cposts[0].PostId != o1.Id {
t.Fatal("Wrong sort")
}
if cposts[len(cposts)-1].PostId != o3.Id {
t.Fatal("Wrong sort")
}
}
cposts, err := ss.Compliance().ComplianceExport(cr1)
require.Nil(t, err)
assert.Len(t, cposts, 4)
assert.Equal(t, cposts[0].PostId, o1.Id)
assert.Equal(t, cposts[len(cposts)-1].PostId, o3.Id)
}
func testMessageExportPublicChannel(t *testing.T, ss store.Store) {
// get the starting number of message export entries
startTime := model.GetMillis()
var numMessageExports = 0
if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil {
t.Fatal(r1.Err)
} else {
messages := r1.Data.([]*model.MessageExport)
numMessageExports = len(messages)
}
messages, err := ss.Compliance().MessageExport(startTime-10, 10)
require.Nil(t, err)
numMessageExports := len(messages)
// need a team
team := &model.Team{
@@ -386,15 +302,12 @@ func testMessageExportPublicChannel(t *testing.T, ss store.Store) {
// fetch the message exports for both posts that user1 sent
messageExportMap := map[string]model.MessageExport{}
if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil {
t.Fatal(r1.Err)
} else {
messages := r1.Data.([]*model.MessageExport)
assert.Equal(t, numMessageExports+2, len(messages))
messages, err = ss.Compliance().MessageExport(startTime-10, 10)
require.Nil(t, err)
assert.Equal(t, numMessageExports+2, len(messages))
for _, v := range messages {
messageExportMap[*v.PostId] = *v
}
for _, v := range messages {
messageExportMap[*v.PostId] = *v
}
// post1 was made by user1 in channel1 and team1
@@ -421,13 +334,9 @@ func testMessageExportPublicChannel(t *testing.T, ss store.Store) {
func testMessageExportPrivateChannel(t *testing.T, ss store.Store) {
// get the starting number of message export entries
startTime := model.GetMillis()
var numMessageExports = 0
if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil {
t.Fatal(r1.Err)
} else {
messages := r1.Data.([]*model.MessageExport)
numMessageExports = len(messages)
}
messages, err := ss.Compliance().MessageExport(startTime-10, 10)
require.Nil(t, err)
numMessageExports := len(messages)
// need a team
team := &model.Team{
@@ -487,15 +396,12 @@ func testMessageExportPrivateChannel(t *testing.T, ss store.Store) {
// fetch the message exports for both posts that user1 sent
messageExportMap := map[string]model.MessageExport{}
if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil {
t.Fatal(r1.Err)
} else {
messages := r1.Data.([]*model.MessageExport)
assert.Equal(t, numMessageExports+2, len(messages))
messages, err = ss.Compliance().MessageExport(startTime-10, 10)
require.Nil(t, err)
assert.Equal(t, numMessageExports+2, len(messages))
for _, v := range messages {
messageExportMap[*v.PostId] = *v
}
for _, v := range messages {
messageExportMap[*v.PostId] = *v
}
// post1 was made by user1 in channel1 and team1
@@ -524,13 +430,9 @@ func testMessageExportPrivateChannel(t *testing.T, ss store.Store) {
func testMessageExportDirectMessageChannel(t *testing.T, ss store.Store) {
// get the starting number of message export entries
startTime := model.GetMillis()
var numMessageExports = 0
if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil {
t.Fatal(r1.Err)
} else {
messages := r1.Data.([]*model.MessageExport)
numMessageExports = len(messages)
}
messages, err := ss.Compliance().MessageExport(startTime-10, 10)
require.Nil(t, err)
numMessageExports := len(messages)
// need a team
team := &model.Team{
@@ -576,15 +478,13 @@ func testMessageExportDirectMessageChannel(t *testing.T, ss store.Store) {
// fetch the message export for the post that user1 sent
messageExportMap := map[string]model.MessageExport{}
if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil {
t.Fatal(r1.Err)
} else {
messages := r1.Data.([]*model.MessageExport)
assert.Equal(t, numMessageExports+1, len(messages))
messages, err = ss.Compliance().MessageExport(startTime-10, 10)
require.Nil(t, err)
for _, v := range messages {
messageExportMap[*v.PostId] = *v
}
assert.Equal(t, numMessageExports+1, len(messages))
for _, v := range messages {
messageExportMap[*v.PostId] = *v
}
// post is a DM between user1 and user2
@@ -602,13 +502,9 @@ func testMessageExportDirectMessageChannel(t *testing.T, ss store.Store) {
func testMessageExportGroupMessageChannel(t *testing.T, ss store.Store) {
// get the starting number of message export entries
startTime := model.GetMillis()
var numMessageExports = 0
if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil {
t.Fatal(r1.Err)
} else {
messages := r1.Data.([]*model.MessageExport)
numMessageExports = len(messages)
}
messages, err := ss.Compliance().MessageExport(startTime-10, 10)
require.Nil(t, err)
numMessageExports := len(messages)
// need a team
team := &model.Team{
@@ -669,15 +565,12 @@ func testMessageExportGroupMessageChannel(t *testing.T, ss store.Store) {
// fetch the message export for the post that user1 sent
messageExportMap := map[string]model.MessageExport{}
if r1 := <-ss.Compliance().MessageExport(startTime-10, 10); r1.Err != nil {
t.Fatal(r1.Err)
} else {
messages := r1.Data.([]*model.MessageExport)
assert.Equal(t, numMessageExports+1, len(messages))
messages, err = ss.Compliance().MessageExport(startTime-10, 10)
require.Nil(t, err)
assert.Equal(t, numMessageExports+1, len(messages))
for _, v := range messages {
messageExportMap[*v.PostId] = *v
}
for _, v := range messages {
messageExportMap[*v.PostId] = *v
}
// post is a DM between user1 and user2

View File

@@ -6,7 +6,6 @@ package mocks
import mock "github.com/stretchr/testify/mock"
import model "github.com/mattermost/mattermost-server/model"
import store "github.com/mattermost/mattermost-server/store"
// ComplianceStore is an autogenerated mock type for the ComplianceStore type
type ComplianceStore struct {
@@ -14,97 +13,151 @@ type ComplianceStore struct {
}
// ComplianceExport provides a mock function with given fields: compliance
func (_m *ComplianceStore) ComplianceExport(compliance *model.Compliance) store.StoreChannel {
func (_m *ComplianceStore) ComplianceExport(compliance *model.Compliance) ([]*model.CompliancePost, *model.AppError) {
ret := _m.Called(compliance)
var r0 store.StoreChannel
if rf, ok := ret.Get(0).(func(*model.Compliance) store.StoreChannel); ok {
var r0 []*model.CompliancePost
if rf, ok := ret.Get(0).(func(*model.Compliance) []*model.CompliancePost); ok {
r0 = rf(compliance)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(store.StoreChannel)
r0 = ret.Get(0).([]*model.CompliancePost)
}
}
return r0
var r1 *model.AppError
if rf, ok := ret.Get(1).(func(*model.Compliance) *model.AppError); ok {
r1 = rf(compliance)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*model.AppError)
}
}
return r0, r1
}
// Get provides a mock function with given fields: id
func (_m *ComplianceStore) Get(id string) store.StoreChannel {
func (_m *ComplianceStore) Get(id string) (*model.Compliance, *model.AppError) {
ret := _m.Called(id)
var r0 store.StoreChannel
if rf, ok := ret.Get(0).(func(string) store.StoreChannel); ok {
var r0 *model.Compliance
if rf, ok := ret.Get(0).(func(string) *model.Compliance); ok {
r0 = rf(id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(store.StoreChannel)
r0 = ret.Get(0).(*model.Compliance)
}
}
return r0
var r1 *model.AppError
if rf, ok := ret.Get(1).(func(string) *model.AppError); ok {
r1 = rf(id)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*model.AppError)
}
}
return r0, r1
}
// GetAll provides a mock function with given fields: offset, limit
func (_m *ComplianceStore) GetAll(offset int, limit int) store.StoreChannel {
func (_m *ComplianceStore) GetAll(offset int, limit int) (model.Compliances, *model.AppError) {
ret := _m.Called(offset, limit)
var r0 store.StoreChannel
if rf, ok := ret.Get(0).(func(int, int) store.StoreChannel); ok {
var r0 model.Compliances
if rf, ok := ret.Get(0).(func(int, int) model.Compliances); ok {
r0 = rf(offset, limit)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(store.StoreChannel)
r0 = ret.Get(0).(model.Compliances)
}
}
return r0
var r1 *model.AppError
if rf, ok := ret.Get(1).(func(int, int) *model.AppError); ok {
r1 = rf(offset, limit)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*model.AppError)
}
}
return r0, r1
}
// MessageExport provides a mock function with given fields: after, limit
func (_m *ComplianceStore) MessageExport(after int64, limit int) store.StoreChannel {
func (_m *ComplianceStore) MessageExport(after int64, limit int) ([]*model.MessageExport, *model.AppError) {
ret := _m.Called(after, limit)
var r0 store.StoreChannel
if rf, ok := ret.Get(0).(func(int64, int) store.StoreChannel); ok {
var r0 []*model.MessageExport
if rf, ok := ret.Get(0).(func(int64, int) []*model.MessageExport); ok {
r0 = rf(after, limit)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(store.StoreChannel)
r0 = ret.Get(0).([]*model.MessageExport)
}
}
return r0
var r1 *model.AppError
if rf, ok := ret.Get(1).(func(int64, int) *model.AppError); ok {
r1 = rf(after, limit)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*model.AppError)
}
}
return r0, r1
}
// Save provides a mock function with given fields: compliance
func (_m *ComplianceStore) Save(compliance *model.Compliance) store.StoreChannel {
func (_m *ComplianceStore) Save(compliance *model.Compliance) (*model.Compliance, *model.AppError) {
ret := _m.Called(compliance)
var r0 store.StoreChannel
if rf, ok := ret.Get(0).(func(*model.Compliance) store.StoreChannel); ok {
var r0 *model.Compliance
if rf, ok := ret.Get(0).(func(*model.Compliance) *model.Compliance); ok {
r0 = rf(compliance)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(store.StoreChannel)
r0 = ret.Get(0).(*model.Compliance)
}
}
return r0
var r1 *model.AppError
if rf, ok := ret.Get(1).(func(*model.Compliance) *model.AppError); ok {
r1 = rf(compliance)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*model.AppError)
}
}
return r0, r1
}
// Update provides a mock function with given fields: compliance
func (_m *ComplianceStore) Update(compliance *model.Compliance) store.StoreChannel {
func (_m *ComplianceStore) Update(compliance *model.Compliance) (*model.Compliance, *model.AppError) {
ret := _m.Called(compliance)
var r0 store.StoreChannel
if rf, ok := ret.Get(0).(func(*model.Compliance) store.StoreChannel); ok {
var r0 *model.Compliance
if rf, ok := ret.Get(0).(func(*model.Compliance) *model.Compliance); ok {
r0 = rf(compliance)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(store.StoreChannel)
r0 = ret.Get(0).(*model.Compliance)
}
}
return r0
var r1 *model.AppError
if rf, ok := ret.Get(1).(func(*model.Compliance) *model.AppError); ok {
r1 = rf(compliance)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*model.AppError)
}
}
return r0, r1
}