diff --git a/api4/terms_of_service.go b/api4/terms_of_service.go index 8272d3c86e..68d95f8361 100644 --- a/api4/terms_of_service.go +++ b/api4/terms_of_service.go @@ -11,11 +11,11 @@ import ( ) func (api *API) InitTermsOfService() { - api.BaseRoutes.TermsOfService.Handle("", api.ApiSessionRequired(getTermsOfService)).Methods("GET") + api.BaseRoutes.TermsOfService.Handle("", api.ApiSessionRequired(getLatestTermsOfService)).Methods("GET") api.BaseRoutes.TermsOfService.Handle("", api.ApiSessionRequired(createTermsOfService)).Methods("POST") } -func getTermsOfService(c *Context, w http.ResponseWriter, r *http.Request) { +func getLatestTermsOfService(c *Context, w http.ResponseWriter, r *http.Request) { termsOfService, err := c.App.GetLatestTermsOfService() if err != nil { c.Err = err diff --git a/api4/user.go b/api4/user.go index fe9d331c33..4b76c59514 100644 --- a/api4/user.go +++ b/api4/user.go @@ -40,7 +40,8 @@ func (api *API) InitUser() { api.BaseRoutes.Users.Handle("/password/reset/send", api.ApiHandler(sendPasswordReset)).Methods("POST") api.BaseRoutes.Users.Handle("/email/verify", api.ApiHandler(verifyUserEmail)).Methods("POST") api.BaseRoutes.Users.Handle("/email/verify/send", api.ApiHandler(sendVerificationEmail)).Methods("POST") - api.BaseRoutes.User.Handle("/terms_of_service", api.ApiSessionRequired(registerTermsOfServiceAction)).Methods("POST") + api.BaseRoutes.User.Handle("/terms_of_service", api.ApiSessionRequired(saveUserTermsOfService)).Methods("POST") + api.BaseRoutes.User.Handle("/terms_of_service", api.ApiSessionRequired(getUserTermsOfService)).Methods("GET") api.BaseRoutes.User.Handle("/auth", api.ApiSessionRequiredTrustRequester(updateUserAuth)).Methods("PUT") @@ -1626,7 +1627,7 @@ func enableUserAccessToken(c *Context, w http.ResponseWriter, r *http.Request) { ReturnStatusOK(w) } -func registerTermsOfServiceAction(c *Context, w http.ResponseWriter, r *http.Request) { +func saveUserTermsOfService(c *Context, w http.ResponseWriter, r *http.Request) { props := model.StringInterfaceFromJson(r.Body) userId := c.Session.UserId @@ -1638,7 +1639,7 @@ func registerTermsOfServiceAction(c *Context, w http.ResponseWriter, r *http.Req return } - if err := c.App.RecordUserTermsOfServiceAction(userId, termsOfServiceId, accepted); err != nil { + if err := c.App.SaveUserTermsOfService(userId, termsOfServiceId, accepted); err != nil { c.Err = err return } @@ -1646,3 +1647,13 @@ func registerTermsOfServiceAction(c *Context, w http.ResponseWriter, r *http.Req c.LogAudit("TermsOfServiceId=" + termsOfServiceId + ", accepted=" + strconv.FormatBool(accepted)) ReturnStatusOK(w) } + +func getUserTermsOfService(c *Context, w http.ResponseWriter, r *http.Request) { + userId := c.Session.UserId + if result, err := c.App.GetUserTermsOfService(userId); err != nil { + c.Err = err + return + } else { + w.Write([]byte(result.ToJson())) + } +} diff --git a/api4/user_test.go b/api4/user_test.go index f7970128f5..bb5c782fd5 100644 --- a/api4/user_test.go +++ b/api4/user_test.go @@ -3086,10 +3086,34 @@ func TestRegisterTermsOfServiceAction(t *testing.T) { CheckNoError(t, resp) assert.True(t, *success) - user, err := th.App.GetUser(th.BasicUser.Id) + _, err = th.App.GetUser(th.BasicUser.Id) + if err != nil { + t.Fatal(err) + } +} + + +func TestGetUserTermsOfService(t *testing.T) { + th := Setup().InitBasic() + defer th.TearDown() + Client := th.Client + + _, resp := Client.GetUserTermsOfService(th.BasicUser.Id, "") + CheckErrorMessage(t, resp, "store.sql_user_terms_of_service.get_by_user.no_rows.app_error") + + termsOfService, err := th.App.CreateTermsOfService("terms of service", th.BasicUser.Id) if err != nil { t.Fatal(err) } - assert.Equal(t, user.AcceptedTermsOfServiceId, termsOfService.Id) + success, resp := Client.RegisteTermsOfServiceAction(th.BasicUser.Id, termsOfService.Id, true) + CheckNoError(t, resp) + assert.True(t, *success) + + userTermsOfService, resp := Client.GetUserTermsOfService(th.BasicUser.Id, "") + CheckNoError(t, resp) + + assert.Equal(t, th.BasicUser.Id, userTermsOfService.UserId) + assert.Equal(t, termsOfService.Id, userTermsOfService.TermsOfServiceId) + assert.NotEmpty(t, userTermsOfService.CreateAt) } diff --git a/app/diagnostics.go b/app/diagnostics.go index fdbf9cab7c..c706ce4bf2 100644 --- a/app/diagnostics.go +++ b/app/diagnostics.go @@ -408,13 +408,14 @@ func (a *App) trackConfig() { }) a.SendDiagnostic(TRACK_CONFIG_SUPPORT, map[string]interface{}{ - "isdefault_terms_of_service_link": isDefault(*cfg.SupportSettings.TermsOfServiceLink, model.SUPPORT_SETTINGS_DEFAULT_TERMS_OF_SERVICE_LINK), - "isdefault_privacy_policy_link": isDefault(*cfg.SupportSettings.PrivacyPolicyLink, model.SUPPORT_SETTINGS_DEFAULT_PRIVACY_POLICY_LINK), - "isdefault_about_link": isDefault(*cfg.SupportSettings.AboutLink, model.SUPPORT_SETTINGS_DEFAULT_ABOUT_LINK), - "isdefault_help_link": isDefault(*cfg.SupportSettings.HelpLink, model.SUPPORT_SETTINGS_DEFAULT_HELP_LINK), - "isdefault_report_a_problem_link": isDefault(*cfg.SupportSettings.ReportAProblemLink, model.SUPPORT_SETTINGS_DEFAULT_REPORT_A_PROBLEM_LINK), - "isdefault_support_email": isDefault(*cfg.SupportSettings.SupportEmail, model.SUPPORT_SETTINGS_DEFAULT_SUPPORT_EMAIL), - "custom_terms_of_service_enabled": *cfg.SupportSettings.CustomTermsOfServiceEnabled, + "isdefault_terms_of_service_link": isDefault(*cfg.SupportSettings.TermsOfServiceLink, model.SUPPORT_SETTINGS_DEFAULT_TERMS_OF_SERVICE_LINK), + "isdefault_privacy_policy_link": isDefault(*cfg.SupportSettings.PrivacyPolicyLink, model.SUPPORT_SETTINGS_DEFAULT_PRIVACY_POLICY_LINK), + "isdefault_about_link": isDefault(*cfg.SupportSettings.AboutLink, model.SUPPORT_SETTINGS_DEFAULT_ABOUT_LINK), + "isdefault_help_link": isDefault(*cfg.SupportSettings.HelpLink, model.SUPPORT_SETTINGS_DEFAULT_HELP_LINK), + "isdefault_report_a_problem_link": isDefault(*cfg.SupportSettings.ReportAProblemLink, model.SUPPORT_SETTINGS_DEFAULT_REPORT_A_PROBLEM_LINK), + "isdefault_support_email": isDefault(*cfg.SupportSettings.SupportEmail, model.SUPPORT_SETTINGS_DEFAULT_SUPPORT_EMAIL), + "custom_terms_of_service_enabled": *cfg.SupportSettings.CustomTermsOfServiceEnabled, + "custom_terms_of_service_re_acceptance_period": *cfg.SupportSettings.CustomTermsOfServiceReAcceptancePeriod, }) a.SendDiagnostic(TRACK_CONFIG_LDAP, map[string]interface{}{ diff --git a/app/user.go b/app/user.go index 92269e618f..05d0fe40f5 100644 --- a/app/user.go +++ b/app/user.go @@ -1622,22 +1622,3 @@ func (a *App) UpdateOAuthUserAttrs(userData io.Reader, user *model.User, provide return nil } - -func (a *App) RecordUserTermsOfServiceAction(userId, termsOfServiceId string, accepted bool) *model.AppError { - user, err := a.GetUser(userId) - if err != nil { - return err - } - - if accepted { - user.AcceptedTermsOfServiceId = termsOfServiceId - } else { - user.AcceptedTermsOfServiceId = "" - } - _, err = a.UpdateUser(user, false) - if err != nil { - return err - } - - return nil -} diff --git a/app/user_terms_of_service.go b/app/user_terms_of_service.go new file mode 100644 index 0000000000..11d50c002c --- /dev/null +++ b/app/user_terms_of_service.go @@ -0,0 +1,33 @@ +// Copyright (c) 2016-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package app + +import "github.com/mattermost/mattermost-server/model" + +func (a *App) GetUserTermsOfService(userId string) (*model.UserTermsOfService, *model.AppError) { + if result := <-a.Srv.Store.UserTermsOfService().GetByUser(userId); result.Err != nil { + return nil, result.Err + } else { + return result.Data.(*model.UserTermsOfService), nil + } +} + +func (a *App) SaveUserTermsOfService(userId, termsOfServiceId string, accepted bool) *model.AppError { + if accepted { + userTermsOfService := &model.UserTermsOfService{ + UserId: userId, + TermsOfServiceId: termsOfServiceId, + } + + if result := <-a.Srv.Store.UserTermsOfService().Save(userTermsOfService); result.Err != nil { + return result.Err + } + } else { + if result := <-a.Srv.Store.UserTermsOfService().Delete(userId, termsOfServiceId); result.Err != nil { + return result.Err + } + } + + return nil +} diff --git a/app/user_terms_of_service_test.go b/app/user_terms_of_service_test.go new file mode 100644 index 0000000000..4beb508258 --- /dev/null +++ b/app/user_terms_of_service_test.go @@ -0,0 +1,34 @@ +// Copyright (c) 2016-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package app + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestUserTermsOfService(t *testing.T) { + th := Setup().InitBasic() + defer th.TearDown() + + userTermsOfService, err := th.App.GetUserTermsOfService(th.BasicUser.Id) + checkError(t, err) + assert.Nil(t, userTermsOfService) + assert.Equal(t, "store.sql_user_terms_of_service.get_by_user.no_rows.app_error", err.Id) + + termsOfService, err := th.App.CreateTermsOfService("terms of service", th.BasicUser.Id) + checkNoError(t, err) + + err = th.App.SaveUserTermsOfService(th.BasicUser.Id, termsOfService.Id, true) + checkNoError(t, err) + + userTermsOfService, err = th.App.GetUserTermsOfService(th.BasicUser.Id) + checkNoError(t, err) + assert.NotNil(t, userTermsOfService) + assert.NotEmpty(t, userTermsOfService) + + assert.Equal(t, th.BasicUser.Id, userTermsOfService.UserId) + assert.Equal(t, termsOfService.Id, userTermsOfService.TermsOfServiceId) + assert.NotEmpty(t, userTermsOfService.CreateAt) +} diff --git a/app/user_test.go b/app/user_test.go index 7cd3697497..2aeed6fc84 100644 --- a/app/user_test.go +++ b/app/user_test.go @@ -544,43 +544,3 @@ func TestPermanentDeleteUser(t *testing.T) { t.Fatal("GetFileInfo after DeleteUser is nil") } } - -func TestRecordUserTermsOfServiceAction(t *testing.T) { - th := Setup().InitBasic() - defer th.TearDown() - - user := &model.User{ - Email: strings.ToLower(model.NewId()) + "success+test@example.com", - Nickname: "Luke Skywalker", // trying to bring balance to the "Force", one test user at a time - Username: "luke" + model.NewId(), - Password: "passwd1", - AuthService: "", - } - user, err := th.App.CreateUser(user) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - - defer th.App.PermanentDeleteUser(user) - - termsOfService, err := th.App.CreateTermsOfService("text", user.Id) - if err != nil { - t.Fatalf("failed to create terms of service: %v", err) - } - - err = th.App.RecordUserTermsOfServiceAction(user.Id, termsOfService.Id, true) - if err != nil { - t.Fatalf("failed to record user action: %v", err) - } - - nuser, err := th.App.GetUser(user.Id) - assert.Equal(t, termsOfService.Id, nuser.AcceptedTermsOfServiceId) - - err = th.App.RecordUserTermsOfServiceAction(user.Id, termsOfService.Id, false) - if err != nil { - t.Fatalf("failed to record user action: %v", err) - } - - nuser, err = th.App.GetUser(user.Id) - assert.Empty(t, nuser.AcceptedTermsOfServiceId) -} diff --git a/i18n/en.json b/i18n/en.json index b98d13b471..797141af8a 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -4794,6 +4794,18 @@ "id": "model.terms_of_service.is_valid.text.app_error", "translation": "Custom terms of service text is too long. Maximum {{.MaxLength}} characters allowed." }, + { + "id": "model.user_terms_of_service.is_valid.user_id.app_error", + "translation": "Missing required user terms of service property: user_id." + }, + { + "id": "model.user_terms_of_service.is_valid.service_terms_id.app_error", + "translation": "Missing required user terms of service property: service_terms_id." + }, + { + "id": "model.user_terms_of_service.is_valid.create_at.app_error", + "translation": "Missing required user terms of service property: create_at." + }, { "id": "oauth.gitlab.tos.error", "translation": "GitLab's Terms of Service have updated. Please go to gitlab.com to accept them and then try logging into Mattermost again." @@ -6454,6 +6466,22 @@ "id": "store.sql_terms_of_service_store.get.no_rows.app_error", "translation": "No terms of service found." }, + { + "id": "store.sql_user_terms_of_service.get_by_user.no_rows.app_error", + "translation": "No user terms of service found." + }, + { + "id": "store.sql_user_terms_of_service.get_by_user.app_error", + "translation": "Unable to fetch user terms of service." + }, + { + "id": "store.sql_user_terms_of_service.save.app_error", + "translation": "Unable to save user terms of service." + }, + { + "id": "store.sql_user_terms_of_service.delete.app_error", + "translation": "Unable to delete user terms of service." + }, { "id": "system.message.name", "translation": "System" diff --git a/model/client4.go b/model/client4.go index 484694416b..92a0f0565d 100644 --- a/model/client4.go +++ b/model/client4.go @@ -401,7 +401,7 @@ func (c *Client4) GetRedirectLocationRoute() string { return fmt.Sprintf("/redirect_location") } -func (c *Client4) GetRegisterTermsOfServiceRoute(userId string) string { +func (c *Client4) GetUserTermsOfServiceRoute(userId string) string { return c.GetUserRoute(userId) + "/terms_of_service" } @@ -3849,7 +3849,7 @@ func (c *Client4) GetRedirectLocation(urlParam, etag string) (string, *Response) } func (c *Client4) RegisteTermsOfServiceAction(userId, termsOfServiceId string, accepted bool) (*bool, *Response) { - url := c.GetRegisterTermsOfServiceRoute(userId) + url := c.GetUserTermsOfServiceRoute(userId) data := map[string]interface{}{"termsOfServiceId": termsOfServiceId, "accepted": accepted} if r, err := c.DoApiPost(url, StringInterfaceToJson(data)); err != nil { @@ -3871,11 +3871,22 @@ func (c *Client4) GetTermsOfService(etag string) (*TermsOfService, *Response) { } } +func (c *Client4) GetUserTermsOfService(userId, etag string) (*UserTermsOfService, *Response) { + url := c.GetUserTermsOfServiceRoute(userId) + + if r, err := c.DoApiGet(url, etag); err != nil { + return nil, BuildErrorResponse(r, err) + } else { + defer closeBody(r) + return UserTermsOfServiceFromJson(r.Body), BuildResponse(r) + } +} + func (c *Client4) CreateTermsOfService(text, userId string) (*TermsOfService, *Response) { url := c.GetTermsOfServiceRoute() - data := map[string]string{"text": text} - if r, err := c.DoApiPost(url, MapToJson(data)); err != nil { + data := map[string]interface{}{"text": text} + if r, err := c.DoApiPost(url, StringInterfaceToJson(data)); err != nil { return nil, BuildErrorResponse(r, err) } else { defer closeBody(r) diff --git a/model/config.go b/model/config.go index 9f25662c3a..3af08f5974 100644 --- a/model/config.go +++ b/model/config.go @@ -112,6 +112,7 @@ const ( SUPPORT_SETTINGS_DEFAULT_HELP_LINK = "https://about.mattermost.com/default-help/" SUPPORT_SETTINGS_DEFAULT_REPORT_A_PROBLEM_LINK = "https://about.mattermost.com/default-report-a-problem/" SUPPORT_SETTINGS_DEFAULT_SUPPORT_EMAIL = "feedback@mattermost.com" + SUPPORT_SETTINGS_DEFAULT_RE_ACCEPTANCE_PERIOD = 365 LDAP_SETTINGS_DEFAULT_FIRST_NAME_ATTRIBUTE = "" LDAP_SETTINGS_DEFAULT_LAST_NAME_ATTRIBUTE = "" @@ -1030,13 +1031,14 @@ type PrivacySettings struct { } type SupportSettings struct { - TermsOfServiceLink *string - PrivacyPolicyLink *string - AboutLink *string - HelpLink *string - ReportAProblemLink *string - SupportEmail *string - CustomTermsOfServiceEnabled *bool + TermsOfServiceLink *string + PrivacyPolicyLink *string + AboutLink *string + HelpLink *string + ReportAProblemLink *string + SupportEmail *string + CustomTermsOfServiceEnabled *bool + CustomTermsOfServiceReAcceptancePeriod *int } func (s *SupportSettings) SetDefaults() { @@ -1087,6 +1089,10 @@ func (s *SupportSettings) SetDefaults() { if s.CustomTermsOfServiceEnabled == nil { s.CustomTermsOfServiceEnabled = NewBool(false) } + + if s.CustomTermsOfServiceReAcceptancePeriod == nil { + s.CustomTermsOfServiceReAcceptancePeriod = NewInt(SUPPORT_SETTINGS_DEFAULT_RE_ACCEPTANCE_PERIOD) + } } type AnnouncementSettings struct { diff --git a/model/terms_of_service.go b/model/terms_of_service.go index c99a785688..e43f89095c 100644 --- a/model/terms_of_service.go +++ b/model/terms_of_service.go @@ -11,7 +11,6 @@ import ( "unicode/utf8" ) -// we only ever need the latest version of terms of service const TERMS_OF_SERVICE_CACHE_SIZE = 1 type TermsOfService struct { @@ -58,7 +57,7 @@ func InvalidTermsOfServiceError(fieldName string, termsOfServiceId string) *AppE if termsOfServiceId != "" { details = "terms_of_service_id=" + termsOfServiceId } - return NewAppError("TermsOfServiceStore.IsValid", id, map[string]interface{}{"MaxLength": POST_MESSAGE_MAX_RUNES_V2}, details, http.StatusBadRequest) + return NewAppError("TermsOfService.IsValid", id, map[string]interface{}{"MaxLength": POST_MESSAGE_MAX_RUNES_V2}, details, http.StatusBadRequest) } func (t *TermsOfService) PreSave() { diff --git a/model/user.go b/model/user.go index 40ccd16610..8fc9a771ce 100644 --- a/model/user.go +++ b/model/user.go @@ -50,33 +50,32 @@ const ( ) type User struct { - Id string `json:"id"` - CreateAt int64 `json:"create_at,omitempty"` - UpdateAt int64 `json:"update_at,omitempty"` - DeleteAt int64 `json:"delete_at"` - Username string `json:"username"` - Password string `json:"password,omitempty"` - AuthData *string `json:"auth_data,omitempty"` - AuthService string `json:"auth_service"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified,omitempty"` - Nickname string `json:"nickname"` - FirstName string `json:"first_name"` - LastName string `json:"last_name"` - Position string `json:"position"` - Roles string `json:"roles"` - AllowMarketing bool `json:"allow_marketing,omitempty"` - Props StringMap `json:"props,omitempty"` - NotifyProps StringMap `json:"notify_props,omitempty"` - LastPasswordUpdate int64 `json:"last_password_update,omitempty"` - LastPictureUpdate int64 `json:"last_picture_update,omitempty"` - FailedAttempts int `json:"failed_attempts,omitempty"` - Locale string `json:"locale"` - Timezone StringMap `json:"timezone"` - MfaActive bool `json:"mfa_active,omitempty"` - MfaSecret string `json:"mfa_secret,omitempty"` - LastActivityAt int64 `db:"-" json:"last_activity_at,omitempty"` - AcceptedTermsOfServiceId string `json:"accepted_terms_of_service_id,omitempty"` // TODO remove this field when new TOS user action table is created + Id string `json:"id"` + CreateAt int64 `json:"create_at,omitempty"` + UpdateAt int64 `json:"update_at,omitempty"` + DeleteAt int64 `json:"delete_at"` + Username string `json:"username"` + Password string `json:"password,omitempty"` + AuthData *string `json:"auth_data,omitempty"` + AuthService string `json:"auth_service"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified,omitempty"` + Nickname string `json:"nickname"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Position string `json:"position"` + Roles string `json:"roles"` + AllowMarketing bool `json:"allow_marketing,omitempty"` + Props StringMap `json:"props,omitempty"` + NotifyProps StringMap `json:"notify_props,omitempty"` + LastPasswordUpdate int64 `json:"last_password_update,omitempty"` + LastPictureUpdate int64 `json:"last_picture_update,omitempty"` + FailedAttempts int `json:"failed_attempts,omitempty"` + Locale string `json:"locale"` + Timezone StringMap `json:"timezone"` + MfaActive bool `json:"mfa_active,omitempty"` + MfaSecret string `json:"mfa_secret,omitempty"` + LastActivityAt int64 `db:"-" json:"last_activity_at,omitempty"` } type UserPatch struct { diff --git a/model/user_terms_of_Service_test.go b/model/user_terms_of_Service_test.go new file mode 100644 index 0000000000..f28171b41e --- /dev/null +++ b/model/user_terms_of_Service_test.go @@ -0,0 +1,46 @@ +// Copyright (c) 2016-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package model + +import ( + "github.com/stretchr/testify/assert" + "strings" + "testing" +) + +func TestUserTermsOfServiceIsValid(t *testing.T) { + s := UserTermsOfService{} + + if err := s.IsValid(); err == nil { + t.Fatal("should be invalid") + } + + s.UserId = NewId() + if err := s.IsValid(); err == nil { + t.Fatal("should be invalid") + } + + s.TermsOfServiceId = NewId() + if err := s.IsValid(); err == nil { + t.Fatal("should be invalid") + } + + s.CreateAt = GetMillis() + if err := s.IsValid(); err != nil { + t.Fatal("should be valid") + } +} + +func TestUserTermsOfServiceJson(t *testing.T) { + o := UserTermsOfService{ + UserId: NewId(), + TermsOfServiceId: NewId(), + CreateAt: GetMillis(), + } + j := o.ToJson() + ro := UserTermsOfServiceFromJson(strings.NewReader(j)) + + assert.NotNil(t, ro) + assert.Equal(t, o, *ro) +} diff --git a/model/user_terms_of_service.go b/model/user_terms_of_service.go new file mode 100644 index 0000000000..b714f923c4 --- /dev/null +++ b/model/user_terms_of_service.go @@ -0,0 +1,61 @@ +// Copyright (c) 2016-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package model + +import ( + "encoding/json" + "fmt" + "io" + "net/http" +) + +type UserTermsOfService struct { + UserId string `json:"user_id"` + TermsOfServiceId string `json:"terms_of_service_id"` + CreateAt int64 `json:"create_at"` +} + +func (ut *UserTermsOfService) IsValid() *AppError { + if len(ut.UserId) != 26 { + return InvalidUserTermsOfServiceError("user_id", ut.UserId) + } + + if len(ut.TermsOfServiceId) != 26 { + return InvalidUserTermsOfServiceError("terms_of_service_id", ut.UserId) + } + + if ut.CreateAt == 0 { + return InvalidUserTermsOfServiceError("create_at", ut.UserId) + } + + return nil +} + +func (ut *UserTermsOfService) ToJson() string { + b, _ := json.Marshal(ut) + return string(b) +} + +func (ut *UserTermsOfService) PreSave() { + if ut.UserId == "" { + ut.UserId = NewId() + } + + ut.CreateAt = GetMillis() +} + +func UserTermsOfServiceFromJson(data io.Reader) *UserTermsOfService { + var userTermsOfService *UserTermsOfService + json.NewDecoder(data).Decode(&userTermsOfService) + return userTermsOfService +} + +func InvalidUserTermsOfServiceError(fieldName string, userTermsOfServiceId string) *AppError { + id := fmt.Sprintf("model.user_terms_of_service.is_valid.%s.app_error", fieldName) + details := "" + if userTermsOfServiceId != "" { + details = "user_terms_of_service_user_id=" + userTermsOfServiceId + } + return NewAppError("UserTermsOfService.IsValid", id, nil, details, http.StatusBadRequest) +} diff --git a/store/layered_store.go b/store/layered_store.go index da2880fa5f..f69f55a7eb 100644 --- a/store/layered_store.go +++ b/store/layered_store.go @@ -173,6 +173,10 @@ func (s *LayeredStore) TermsOfService() TermsOfServiceStore { return s.DatabaseLayer.TermsOfService() } +func (s *LayeredStore) UserTermsOfService() UserTermsOfServiceStore { + return s.DatabaseLayer.UserTermsOfService() +} + func (s *LayeredStore) Scheme() SchemeStore { return s.SchemeStore } diff --git a/store/sqlstore/store.go b/store/sqlstore/store.go index 0408c5feb8..7f633d17de 100644 --- a/store/sqlstore/store.go +++ b/store/sqlstore/store.go @@ -94,4 +94,5 @@ type SqlStore interface { Role() store.RoleStore Scheme() store.SchemeStore TermsOfService() store.TermsOfServiceStore + UserTermsOfService() store.UserTermsOfServiceStore } diff --git a/store/sqlstore/supplier.go b/store/sqlstore/supplier.go index c0c92aaaf9..68fed70d34 100644 --- a/store/sqlstore/supplier.go +++ b/store/sqlstore/supplier.go @@ -93,6 +93,7 @@ type SqlSupplierOldStores struct { role store.RoleStore scheme store.SchemeStore TermsOfService store.TermsOfServiceStore + UserTermsOfService store.UserTermsOfServiceStore } type SqlSupplier struct { @@ -142,6 +143,7 @@ func NewSqlSupplier(settings model.SqlSettings, metrics einterfaces.MetricsInter supplier.oldStores.channelMemberHistory = NewSqlChannelMemberHistoryStore(supplier) supplier.oldStores.plugin = NewSqlPluginStore(supplier) supplier.oldStores.TermsOfService = NewSqlTermsOfServiceStore(supplier, metrics) + supplier.oldStores.UserTermsOfService = NewSqlUserTermsOfServiceStore(supplier) initSqlSupplierReactions(supplier) initSqlSupplierRoles(supplier) @@ -178,6 +180,7 @@ func NewSqlSupplier(settings model.SqlSettings, metrics einterfaces.MetricsInter supplier.oldStores.userAccessToken.(*SqlUserAccessTokenStore).CreateIndexesIfNotExists() supplier.oldStores.plugin.(*SqlPluginStore).CreateIndexesIfNotExists() supplier.oldStores.TermsOfService.(SqlTermsOfServiceStore).CreateIndexesIfNotExists() + supplier.oldStores.UserTermsOfService.(SqlUserTermsOfServiceStore).CreateIndexesIfNotExists() supplier.oldStores.preference.(*SqlPreferenceStore).DeleteUnusedFeatures() @@ -963,6 +966,10 @@ func (ss *SqlSupplier) TermsOfService() store.TermsOfServiceStore { return ss.oldStores.TermsOfService } +func (ss *SqlSupplier) UserTermsOfService() store.UserTermsOfServiceStore { + return ss.oldStores.UserTermsOfService +} + func (ss *SqlSupplier) Scheme() store.SchemeStore { return ss.oldStores.scheme } diff --git a/store/sqlstore/terms_of_service_store.go b/store/sqlstore/terms_of_service_store.go index 47557ee8ce..33907b297e 100644 --- a/store/sqlstore/terms_of_service_store.go +++ b/store/sqlstore/terms_of_service_store.go @@ -20,7 +20,9 @@ type SqlTermsOfServiceStore struct { var termsOfServiceCache = utils.NewLru(model.TERMS_OF_SERVICE_CACHE_SIZE) -const termsOfServiceCacheName = "TermsOfServiceStore" +const ( + termsOfServiceCacheName = "TermsOfServiceStore" +) func NewSqlTermsOfServiceStore(sqlStore SqlStore, metrics einterfaces.MetricsInterface) store.TermsOfServiceStore { s := SqlTermsOfServiceStore{sqlStore, metrics} diff --git a/store/sqlstore/upgrade.go b/store/sqlstore/upgrade.go index 26f0ed88db..73830032c6 100644 --- a/store/sqlstore/upgrade.go +++ b/store/sqlstore/upgrade.go @@ -504,7 +504,6 @@ func UpgradeDatabaseToVersion54(sqlStore SqlStore) { time.Sleep(time.Second) os.Exit(EXIT_GENERIC_FAILURE) } - sqlStore.CreateColumnIfNotExists("Users", "AcceptedTermsOfServiceId", "varchar(64)", "varchar(64)", "") saveSchemaVersion(sqlStore, VERSION_5_4_0) } } @@ -516,9 +515,12 @@ func UpgradeDatabaseToVersion55(sqlStore SqlStore) { } func UpgradeDatabaseToVersion56(sqlStore SqlStore) { - // TODO: Uncomment following condition when version 5.5.0 is released + // TODO: Uncomment following condition when version 5.6.0 is released //if shouldPerformUpgrade(sqlStore, VERSION_5_5_0, VERSION_5_6_0) { sqlStore.CreateColumnIfNotExists("PluginKeyValueStore", "ExpireAt", "bigint(20)", "bigint", "0") - // saveSchemaVersion(sqlStore, VERSION_5_5_0) + + // migrating user's accepted terms of service data into the new table + sqlStore.GetMaster().Exec("INSERT INTO UserTermsOfService SELECT Id, AcceptedTermsOfServiceId as TermsOfServiceId, :CreateAt FROM Users WHERE AcceptedTermsOfServiceId != \"\" AND AcceptedTermsOfServiceId IS NOT NULL", map[string]interface{}{"CreateAt": model.GetMillis()}) + //saveSchemaVersion(sqlStore, VERSION_5_6_0) //} } diff --git a/store/sqlstore/user_store.go b/store/sqlstore/user_store.go index 0e70a87d9c..8136d4ac4f 100644 --- a/store/sqlstore/user_store.go +++ b/store/sqlstore/user_store.go @@ -82,7 +82,6 @@ func NewSqlUserStore(sqlStore SqlStore, metrics einterfaces.MetricsInterface) st table.ColMap("MfaSecret").SetMaxSize(128) table.ColMap("Position").SetMaxSize(128) table.ColMap("Timezone").SetMaxSize(256) - table.ColMap("AcceptedTermsOfServiceId").SetMaxSize(64) } return us diff --git a/store/sqlstore/user_terms_of_service.go b/store/sqlstore/user_terms_of_service.go new file mode 100644 index 0000000000..1385029819 --- /dev/null +++ b/store/sqlstore/user_terms_of_service.go @@ -0,0 +1,89 @@ +// Copyright (c) 2016-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package sqlstore + +import ( + "database/sql" + "github.com/mattermost/mattermost-server/model" + "github.com/mattermost/mattermost-server/store" + "net/http" +) + +type SqlUserTermsOfServiceStore struct { + SqlStore +} + +func NewSqlUserTermsOfServiceStore(sqlStore SqlStore) store.UserTermsOfServiceStore { + s := SqlUserTermsOfServiceStore{sqlStore} + + for _, db := range sqlStore.GetAllConns() { + table := db.AddTableWithName(model.UserTermsOfService{}, "UserTermsOfService").SetKeys(false, "UserId") + table.ColMap("UserId").SetMaxSize(26) + table.ColMap("TermsOfServiceId").SetMaxSize(26) + } + + return s +} + +func (s SqlUserTermsOfServiceStore) CreateIndexesIfNotExists() { + s.CreateIndexIfNotExists("idx_user_terms_of_service_user_id", "UserTermsOfService", "UserId") +} + +func (s SqlUserTermsOfServiceStore) GetByUser(userId string) store.StoreChannel { + return store.Do(func(result *store.StoreResult) { + var userTermsOfService *model.UserTermsOfService + + err := s.GetReplica().SelectOne(&userTermsOfService, "SELECT * FROM UserTermsOfService WHERE UserId = :userId", map[string]interface{}{"userId": userId}) + if err != nil { + if err == sql.ErrNoRows { + result.Err = model.NewAppError("NewSqlUserTermsOfServiceStore.GetByUser", "store.sql_user_terms_of_service.get_by_user.no_rows.app_error", nil, "", http.StatusNotFound) + } else { + result.Err = model.NewAppError("NewSqlUserTermsOfServiceStore.GetByUser", "store.sql_user_terms_of_service.get_by_user.app_error", nil, "", http.StatusInternalServerError) + } + } else { + result.Data = userTermsOfService + } + }) +} + +func (s SqlUserTermsOfServiceStore) Save(userTermsOfService *model.UserTermsOfService) store.StoreChannel { + return store.Do(func(result *store.StoreResult) { + userTermsOfService.PreSave() + + if result.Err = userTermsOfService.IsValid(); result.Err != nil { + return + } + + if c, err := s.GetMaster().Update(userTermsOfService); err != nil { + result.Err = model.NewAppError( + "SqlUserTermsOfServiceStore.Save", + "store.sql_user_terms_of_service.save.app_error", + nil, + "user_terms_of_service_user_id="+userTermsOfService.UserId+",user_terms_of_service_terms_of_service_id="+userTermsOfService.TermsOfServiceId+",err="+err.Error(), + http.StatusInternalServerError, + ) + } else if c == 0 { + if err := s.GetMaster().Insert(userTermsOfService); err != nil { + result.Err = model.NewAppError( + "SqlUserTermsOfServiceStore.Save", + "store.sql_user_terms_of_service.save.app_error", + nil, + "user_terms_of_service_user_id="+userTermsOfService.UserId+",user_terms_of_service_terms_of_service_id="+userTermsOfService.TermsOfServiceId+",err="+err.Error(), + http.StatusInternalServerError, + ) + } + } + + result.Data = userTermsOfService + }) +} + +func (s SqlUserTermsOfServiceStore) Delete(userId, termsOfServiceId string) store.StoreChannel { + return store.Do(func(result *store.StoreResult) { + if _, err := s.GetMaster().Exec("DELETE FROM UserTermsOfService WHERE UserId = :UserId AND TermsOfServiceId = :TermsOfServiceId", map[string]interface{}{"UserId": userId, "TermsOfServiceId": termsOfServiceId}); err != nil { + result.Err = model.NewAppError("SqlUserTermsOfServiceStore.Delete", "store.sql_user_terms_of_service.delete.app_error", nil, "userId="+userId+", termsOfServiceId="+termsOfServiceId, http.StatusInternalServerError) + return + } + }) +} diff --git a/store/sqlstore/user_terms_of_service_store_test.go b/store/sqlstore/user_terms_of_service_store_test.go new file mode 100644 index 0000000000..ed1dd8f2ac --- /dev/null +++ b/store/sqlstore/user_terms_of_service_store_test.go @@ -0,0 +1,10 @@ +package sqlstore + +import ( + "github.com/mattermost/mattermost-server/store/storetest" + "testing" +) + +func TestUserTermsOfServiceStore(t *testing.T) { + StoreTest(t, storetest.TestUserTermsOfServiceStore) +} diff --git a/store/store.go b/store/store.go index eefaa4649b..6af7de6bad 100644 --- a/store/store.go +++ b/store/store.go @@ -66,6 +66,7 @@ type Store interface { ChannelMemberHistory() ChannelMemberHistoryStore Plugin() PluginStore TermsOfService() TermsOfServiceStore + UserTermsOfService() UserTermsOfServiceStore MarkSystemRanUnitTests() Close() LockToMaster() @@ -527,3 +528,9 @@ type TermsOfServiceStore interface { GetLatest(allowFromCache bool) StoreChannel Get(id string, allowFromCache bool) StoreChannel } + +type UserTermsOfServiceStore interface { + GetByUser(userId string) StoreChannel + Save(userTermsOfService *model.UserTermsOfService) StoreChannel + Delete(userId, termsOfServiceId string) StoreChannel +} diff --git a/store/storetest/mocks/LayeredStoreDatabaseLayer.go b/store/storetest/mocks/LayeredStoreDatabaseLayer.go index 3b06bbdf5d..0531d7c377 100644 --- a/store/storetest/mocks/LayeredStoreDatabaseLayer.go +++ b/store/storetest/mocks/LayeredStoreDatabaseLayer.go @@ -909,6 +909,22 @@ func (_m *LayeredStoreDatabaseLayer) UserAccessToken() store.UserAccessTokenStor return r0 } +// UserTermsOfService provides a mock function with given fields: +func (_m *LayeredStoreDatabaseLayer) UserTermsOfService() store.UserTermsOfServiceStore { + ret := _m.Called() + + var r0 store.UserTermsOfServiceStore + if rf, ok := ret.Get(0).(func() store.UserTermsOfServiceStore); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.UserTermsOfServiceStore) + } + } + + return r0 +} + // Webhook provides a mock function with given fields: func (_m *LayeredStoreDatabaseLayer) Webhook() store.WebhookStore { ret := _m.Called() diff --git a/store/storetest/mocks/SqlStore.go b/store/storetest/mocks/SqlStore.go index 278ca1a619..6f76c8a030 100644 --- a/store/storetest/mocks/SqlStore.go +++ b/store/storetest/mocks/SqlStore.go @@ -778,6 +778,22 @@ func (_m *SqlStore) UserAccessToken() store.UserAccessTokenStore { return r0 } +// UserTermsOfService provides a mock function with given fields: +func (_m *SqlStore) UserTermsOfService() store.UserTermsOfServiceStore { + ret := _m.Called() + + var r0 store.UserTermsOfServiceStore + if rf, ok := ret.Get(0).(func() store.UserTermsOfServiceStore); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.UserTermsOfServiceStore) + } + } + + return r0 +} + // Webhook provides a mock function with given fields: func (_m *SqlStore) Webhook() store.WebhookStore { ret := _m.Called() diff --git a/store/storetest/mocks/Store.go b/store/storetest/mocks/Store.go index b55df20971..1f52d98ecc 100644 --- a/store/storetest/mocks/Store.go +++ b/store/storetest/mocks/Store.go @@ -495,6 +495,22 @@ func (_m *Store) UserAccessToken() store.UserAccessTokenStore { return r0 } +// UserTermsOfService provides a mock function with given fields: +func (_m *Store) UserTermsOfService() store.UserTermsOfServiceStore { + ret := _m.Called() + + var r0 store.UserTermsOfServiceStore + if rf, ok := ret.Get(0).(func() store.UserTermsOfServiceStore); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.UserTermsOfServiceStore) + } + } + + return r0 +} + // Webhook provides a mock function with given fields: func (_m *Store) Webhook() store.WebhookStore { ret := _m.Called() diff --git a/store/storetest/mocks/UserTermsOfServiceStore.go b/store/storetest/mocks/UserTermsOfServiceStore.go new file mode 100644 index 0000000000..f13faecc9f --- /dev/null +++ b/store/storetest/mocks/UserTermsOfServiceStore.go @@ -0,0 +1,62 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +// Regenerate this file using `make store-mocks`. + +package mocks + +import mock "github.com/stretchr/testify/mock" +import model "github.com/mattermost/mattermost-server/model" +import store "github.com/mattermost/mattermost-server/store" + +// UserTermsOfServiceStore is an autogenerated mock type for the UserTermsOfServiceStore type +type UserTermsOfServiceStore struct { + mock.Mock +} + +// Delete provides a mock function with given fields: userId, termsOfServiceId +func (_m *UserTermsOfServiceStore) Delete(userId string, termsOfServiceId string) store.StoreChannel { + ret := _m.Called(userId, termsOfServiceId) + + var r0 store.StoreChannel + if rf, ok := ret.Get(0).(func(string, string) store.StoreChannel); ok { + r0 = rf(userId, termsOfServiceId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.StoreChannel) + } + } + + return r0 +} + +// GetByUser provides a mock function with given fields: userId +func (_m *UserTermsOfServiceStore) GetByUser(userId string) store.StoreChannel { + ret := _m.Called(userId) + + var r0 store.StoreChannel + if rf, ok := ret.Get(0).(func(string) store.StoreChannel); ok { + r0 = rf(userId) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.StoreChannel) + } + } + + return r0 +} + +// Save provides a mock function with given fields: userTermsOfService +func (_m *UserTermsOfServiceStore) Save(userTermsOfService *model.UserTermsOfService) store.StoreChannel { + ret := _m.Called(userTermsOfService) + + var r0 store.StoreChannel + if rf, ok := ret.Get(0).(func(*model.UserTermsOfService) store.StoreChannel); ok { + r0 = rf(userTermsOfService) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(store.StoreChannel) + } + } + + return r0 +} diff --git a/store/storetest/store.go b/store/storetest/store.go index d6ef4fcd00..15971a53fa 100644 --- a/store/storetest/store.go +++ b/store/storetest/store.go @@ -46,34 +46,36 @@ type Store struct { RoleStore mocks.RoleStore SchemeStore mocks.SchemeStore TermsOfServiceStore mocks.TermsOfServiceStore + UserTermsOfServiceStore mocks.UserTermsOfServiceStore } -func (s *Store) Team() store.TeamStore { return &s.TeamStore } -func (s *Store) Channel() store.ChannelStore { return &s.ChannelStore } -func (s *Store) Post() store.PostStore { return &s.PostStore } -func (s *Store) User() store.UserStore { return &s.UserStore } -func (s *Store) Audit() store.AuditStore { return &s.AuditStore } -func (s *Store) ClusterDiscovery() store.ClusterDiscoveryStore { return &s.ClusterDiscoveryStore } -func (s *Store) Compliance() store.ComplianceStore { return &s.ComplianceStore } -func (s *Store) Session() store.SessionStore { return &s.SessionStore } -func (s *Store) OAuth() store.OAuthStore { return &s.OAuthStore } -func (s *Store) System() store.SystemStore { return &s.SystemStore } -func (s *Store) Webhook() store.WebhookStore { return &s.WebhookStore } -func (s *Store) Command() store.CommandStore { return &s.CommandStore } -func (s *Store) CommandWebhook() store.CommandWebhookStore { return &s.CommandWebhookStore } -func (s *Store) Preference() store.PreferenceStore { return &s.PreferenceStore } -func (s *Store) License() store.LicenseStore { return &s.LicenseStore } -func (s *Store) Token() store.TokenStore { return &s.TokenStore } -func (s *Store) Emoji() store.EmojiStore { return &s.EmojiStore } -func (s *Store) Status() store.StatusStore { return &s.StatusStore } -func (s *Store) FileInfo() store.FileInfoStore { return &s.FileInfoStore } -func (s *Store) Reaction() store.ReactionStore { return &s.ReactionStore } -func (s *Store) Job() store.JobStore { return &s.JobStore } -func (s *Store) UserAccessToken() store.UserAccessTokenStore { return &s.UserAccessTokenStore } -func (s *Store) Plugin() store.PluginStore { return &s.PluginStore } -func (s *Store) Role() store.RoleStore { return &s.RoleStore } -func (s *Store) Scheme() store.SchemeStore { return &s.SchemeStore } -func (s *Store) TermsOfService() store.TermsOfServiceStore { return &s.TermsOfServiceStore } +func (s *Store) Team() store.TeamStore { return &s.TeamStore } +func (s *Store) Channel() store.ChannelStore { return &s.ChannelStore } +func (s *Store) Post() store.PostStore { return &s.PostStore } +func (s *Store) User() store.UserStore { return &s.UserStore } +func (s *Store) Audit() store.AuditStore { return &s.AuditStore } +func (s *Store) ClusterDiscovery() store.ClusterDiscoveryStore { return &s.ClusterDiscoveryStore } +func (s *Store) Compliance() store.ComplianceStore { return &s.ComplianceStore } +func (s *Store) Session() store.SessionStore { return &s.SessionStore } +func (s *Store) OAuth() store.OAuthStore { return &s.OAuthStore } +func (s *Store) System() store.SystemStore { return &s.SystemStore } +func (s *Store) Webhook() store.WebhookStore { return &s.WebhookStore } +func (s *Store) Command() store.CommandStore { return &s.CommandStore } +func (s *Store) CommandWebhook() store.CommandWebhookStore { return &s.CommandWebhookStore } +func (s *Store) Preference() store.PreferenceStore { return &s.PreferenceStore } +func (s *Store) License() store.LicenseStore { return &s.LicenseStore } +func (s *Store) Token() store.TokenStore { return &s.TokenStore } +func (s *Store) Emoji() store.EmojiStore { return &s.EmojiStore } +func (s *Store) Status() store.StatusStore { return &s.StatusStore } +func (s *Store) FileInfo() store.FileInfoStore { return &s.FileInfoStore } +func (s *Store) Reaction() store.ReactionStore { return &s.ReactionStore } +func (s *Store) Job() store.JobStore { return &s.JobStore } +func (s *Store) UserAccessToken() store.UserAccessTokenStore { return &s.UserAccessTokenStore } +func (s *Store) Plugin() store.PluginStore { return &s.PluginStore } +func (s *Store) Role() store.RoleStore { return &s.RoleStore } +func (s *Store) Scheme() store.SchemeStore { return &s.SchemeStore } +func (s *Store) TermsOfService() store.TermsOfServiceStore { return &s.TermsOfServiceStore } +func (s *Store) UserTermsOfService() store.UserTermsOfServiceStore { return &s.UserTermsOfServiceStore } func (s *Store) ChannelMemberHistory() store.ChannelMemberHistoryStore { return &s.ChannelMemberHistoryStore } diff --git a/store/storetest/user_terms_of_service.go b/store/storetest/user_terms_of_service.go new file mode 100644 index 0000000000..0b2d132e78 --- /dev/null +++ b/store/storetest/user_terms_of_service.go @@ -0,0 +1,81 @@ +// Copyright (c) 2016-present Mattermost, Inc. All Rights Reserved. +// See License.txt for license information. + +package storetest + +import ( + "github.com/mattermost/mattermost-server/model" + "github.com/mattermost/mattermost-server/store" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestUserTermsOfServiceStore(t *testing.T, ss store.Store) { + t.Run("TestSaveUserTermsOfService", func(t *testing.T) { testSaveUserTermsOfService(t, ss) }) + t.Run("TestGetByUserTermsOfService", func(t *testing.T) { testGetByUserTermsOfService(t, ss) }) + t.Run("TestDeleteUserTermsOfService", func(t *testing.T) { testDeleteUserTermsOfService(t, ss) }) +} + +func testSaveUserTermsOfService(t *testing.T, ss store.Store) { + userTermsOfService := &model.UserTermsOfService{ + UserId: model.NewId(), + TermsOfServiceId: model.NewId(), + } + + r1 := <-ss.UserTermsOfService().Save(userTermsOfService) + if r1.Err != nil { + t.Fatal(r1.Err) + } + + savedUserTermsOfService := r1.Data.(*model.UserTermsOfService) + assert.Equal(t, userTermsOfService.UserId, savedUserTermsOfService.UserId) + assert.Equal(t, userTermsOfService.TermsOfServiceId, savedUserTermsOfService.TermsOfServiceId) + assert.NotEmpty(t, savedUserTermsOfService.CreateAt) +} + +func testGetByUserTermsOfService(t *testing.T, ss store.Store) { + userTermsOfService := &model.UserTermsOfService{ + UserId: model.NewId(), + TermsOfServiceId: model.NewId(), + } + + r1 := <-ss.UserTermsOfService().Save(userTermsOfService) + if r1.Err != nil { + t.Fatal(r1.Err) + } + + r1 = <-ss.UserTermsOfService().GetByUser(userTermsOfService.UserId) + if r1.Err != nil { + t.Fatal(r1.Err) + } + + fetchedUserTermsOfService := r1.Data.(*model.UserTermsOfService) + assert.Equal(t, userTermsOfService.UserId, fetchedUserTermsOfService.UserId) + assert.Equal(t, userTermsOfService.TermsOfServiceId, fetchedUserTermsOfService.TermsOfServiceId) + assert.NotEmpty(t, fetchedUserTermsOfService.CreateAt) +} + +func testDeleteUserTermsOfService(t *testing.T, ss store.Store) { + userTermsOfService := &model.UserTermsOfService{ + UserId: model.NewId(), + TermsOfServiceId: model.NewId(), + } + + r1 := <-ss.UserTermsOfService().Save(userTermsOfService) + if r1.Err != nil { + t.Fatal(r1.Err) + } + + r1 = <-ss.UserTermsOfService().GetByUser(userTermsOfService.UserId) + if r1.Err != nil { + t.Fatal(r1.Err) + } + + r1 = <-ss.UserTermsOfService().Delete(userTermsOfService.UserId, userTermsOfService.TermsOfServiceId) + if r1.Err != nil { + t.Fatal(r1.Err) + } + + r1 = <-ss.UserTermsOfService().GetByUser(userTermsOfService.UserId) + assert.Equal(t, "store.sql_user_terms_of_service.get_by_user.no_rows.app_error", r1.Err.Id) +} diff --git a/utils/config.go b/utils/config.go index 4e14c3bf4e..4becfb3843 100644 --- a/utils/config.go +++ b/utils/config.go @@ -690,6 +690,7 @@ func GenerateClientConfig(c *model.Config, diagnosticId string, license *model.L if *license.Features.CustomTermsOfService { props["EnableCustomTermsOfService"] = strconv.FormatBool(*c.SupportSettings.CustomTermsOfServiceEnabled) + props["CustomTermsOfServiceReAcceptancePeriod"] = strconv.FormatInt(int64(*c.SupportSettings.CustomTermsOfServiceReAcceptancePeriod), 10) } }