diff --git a/api4/user_test.go b/api4/user_test.go index c8f73fb1ba..b5821c63d9 100644 --- a/api4/user_test.go +++ b/api4/user_test.go @@ -2118,8 +2118,8 @@ func TestUserLoginMFAFlow(t *testing.T) { t.Fatal(result.Err) } - if result := <-th.Server.Store.User().UpdateMfaSecret(th.BasicUser.Id, secret.Secret); result.Err != nil { - t.Fatal(result.Err) + if err = th.Server.Store.User().UpdateMfaSecret(th.BasicUser.Id, secret.Secret); err != nil { + t.Fatal(err) } user, resp := th.Client.Login(th.BasicUser.Email, th.BasicUser.Password) @@ -2150,8 +2150,8 @@ func TestUserLoginMFAFlow(t *testing.T) { t.Fatal(result.Err) } - if result := <-th.Server.Store.User().UpdateMfaSecret(th.BasicUser.Id, secret.Secret); result.Err != nil { - t.Fatal(result.Err) + if err = th.Server.Store.User().UpdateMfaSecret(th.BasicUser.Id, secret.Secret); err != nil { + t.Fatal(err) } code := dgoogauth.ComputeCode(secret.Secret, time.Now().UTC().Unix()/30) diff --git a/services/mfa/mfa.go b/services/mfa/mfa.go index e404f96e71..6b88250226 100644 --- a/services/mfa/mfa.go +++ b/services/mfa/mfa.go @@ -70,8 +70,8 @@ func (m *Mfa) GenerateSecret(user *model.User) (string, []byte, *model.AppError) img := code.PNG() - if result := <-m.Store.User().UpdateMfaSecret(user.Id, secret); result.Err != nil { - return "", nil, model.NewAppError("GenerateQrCode", "mfa.generate_qr_code.save_secret.app_error", nil, result.Err.Error(), http.StatusInternalServerError) + if err := m.Store.User().UpdateMfaSecret(user.Id, secret); err != nil { + return "", nil, model.NewAppError("GenerateQrCode", "mfa.generate_qr_code.save_secret.app_error", nil, err.Error(), http.StatusInternalServerError) } return secret, img, nil @@ -112,14 +112,18 @@ func (m *Mfa) Deactivate(userId string) *model.AppError { } achan := m.Store.User().UpdateMfaActive(userId, false) - schan := m.Store.User().UpdateMfaSecret(userId, "") + schan := make(chan *model.AppError, 1) + go func() { + schan <- m.Store.User().UpdateMfaSecret(userId, "") + close(schan) + }() if result := <-achan; result.Err != nil { return model.NewAppError("Deactivate", "mfa.deactivate.save_active.app_error", nil, result.Err.Error(), http.StatusInternalServerError) } - if result := <-schan; result.Err != nil { - return model.NewAppError("Deactivate", "mfa.deactivate.save_secret.app_error", nil, result.Err.Error(), http.StatusInternalServerError) + if err := <-schan; err != nil { + return model.NewAppError("Deactivate", "mfa.deactivate.save_secret.app_error", nil, err.Error(), http.StatusInternalServerError) } return nil diff --git a/services/mfa/mfa_test.go b/services/mfa/mfa_test.go index 34b7bbc713..17da3722df 100644 --- a/services/mfa/mfa_test.go +++ b/services/mfa/mfa_test.go @@ -9,7 +9,6 @@ import ( "github.com/mattermost/mattermost-server/model" "github.com/mattermost/mattermost-server/plugin/plugintest/mock" - "github.com/mattermost/mattermost-server/store" "github.com/mattermost/mattermost-server/store/storetest/mocks" "github.com/mattermost/mattermost-server/utils/testutils" @@ -26,11 +25,8 @@ func TestGenerateSecret(t *testing.T) { configService := testutils.StaticConfigService{Cfg: &config} storeMock := mocks.Store{} userStoreMock := mocks.UserStore{} - userStoreMock.On("UpdateMfaSecret", user.Id, mock.AnythingOfType("string")).Return(func(userId string, secret string) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - result.Data = nil - result.Err = nil - }) + userStoreMock.On("UpdateMfaSecret", user.Id, mock.AnythingOfType("string")).Return(func(userId string, secret string) *model.AppError { + return nil }) storeMock.On("User").Return(&userStoreMock) diff --git a/store/sqlstore/user_store.go b/store/sqlstore/user_store.go index ebf0b2e6b9..217308b26c 100644 --- a/store/sqlstore/user_store.go +++ b/store/sqlstore/user_store.go @@ -302,16 +302,14 @@ func (us SqlUserStore) UpdateAuthData(userId string, service string, authData *s return userId, nil } -func (us SqlUserStore) UpdateMfaSecret(userId, secret string) store.StoreChannel { - return store.Do(func(result *store.StoreResult) { - updateAt := model.GetMillis() +func (us SqlUserStore) UpdateMfaSecret(userId, secret string) *model.AppError { + updateAt := model.GetMillis() - if _, err := us.GetMaster().Exec("UPDATE Users SET MfaSecret = :Secret, UpdateAt = :UpdateAt WHERE Id = :UserId", map[string]interface{}{"Secret": secret, "UpdateAt": updateAt, "UserId": userId}); err != nil { - result.Err = model.NewAppError("SqlUserStore.UpdateMfaSecret", "store.sql_user.update_mfa_secret.app_error", nil, "id="+userId+", "+err.Error(), http.StatusInternalServerError) - } else { - result.Data = userId - } - }) + if _, err := us.GetMaster().Exec("UPDATE Users SET MfaSecret = :Secret, UpdateAt = :UpdateAt WHERE Id = :UserId", map[string]interface{}{"Secret": secret, "UpdateAt": updateAt, "UserId": userId}); err != nil { + return model.NewAppError("SqlUserStore.UpdateMfaSecret", "store.sql_user.update_mfa_secret.app_error", nil, "id="+userId+", "+err.Error(), http.StatusInternalServerError) + } + + return nil } func (us SqlUserStore) UpdateMfaActive(userId string, active bool) store.StoreChannel { diff --git a/store/store.go b/store/store.go index 45b56835fd..b472511834 100644 --- a/store/store.go +++ b/store/store.go @@ -253,7 +253,7 @@ type UserStore interface { UpdateUpdateAt(userId string) StoreChannel UpdatePassword(userId, newPassword string) StoreChannel UpdateAuthData(userId string, service string, authData *string, email string, resetMfa bool) (string, *model.AppError) - UpdateMfaSecret(userId, secret string) StoreChannel + UpdateMfaSecret(userId, secret string) *model.AppError UpdateMfaActive(userId string, active bool) StoreChannel Get(id string) (*model.User, *model.AppError) GetAll() StoreChannel diff --git a/store/storetest/mocks/UserStore.go b/store/storetest/mocks/UserStore.go index 4b58258d17..2860044ce1 100644 --- a/store/storetest/mocks/UserStore.go +++ b/store/storetest/mocks/UserStore.go @@ -1036,15 +1036,15 @@ func (_m *UserStore) UpdateMfaActive(userId string, active bool) store.StoreChan } // UpdateMfaSecret provides a mock function with given fields: userId, secret -func (_m *UserStore) UpdateMfaSecret(userId string, secret string) store.StoreChannel { +func (_m *UserStore) UpdateMfaSecret(userId string, secret string) *model.AppError { ret := _m.Called(userId, secret) - var r0 store.StoreChannel - if rf, ok := ret.Get(0).(func(string, string) store.StoreChannel); ok { + var r0 *model.AppError + if rf, ok := ret.Get(0).(func(string, string) *model.AppError); ok { r0 = rf(userId, secret) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(store.StoreChannel) + r0 = ret.Get(0).(*model.AppError) } } diff --git a/store/storetest/user_store.go b/store/storetest/user_store.go index 7e44b228f7..306934e6a9 100644 --- a/store/storetest/user_store.go +++ b/store/storetest/user_store.go @@ -1916,12 +1916,12 @@ func testUserStoreUpdateMfaSecret(t *testing.T, ss store.Store) { time.Sleep(100 * time.Millisecond) - if err := (<-ss.User().UpdateMfaSecret(u1.Id, "12345")).Err; err != nil { + if err := ss.User().UpdateMfaSecret(u1.Id, "12345"); err != nil { t.Fatal(err) } // should pass, no update will occur though - if err := (<-ss.User().UpdateMfaSecret("junk", "12345")).Err; err != nil { + if err := ss.User().UpdateMfaSecret("junk", "12345"); err != nil { t.Fatal(err) } }