diff --git a/server/channels/store/sqlstore/compliance_store.go b/server/channels/store/sqlstore/compliance_store.go index bd902a292e..cd82f50c6c 100644 --- a/server/channels/store/sqlstore/compliance_store.go +++ b/server/channels/store/sqlstore/compliance_store.go @@ -330,13 +330,8 @@ func (s SqlComplianceStore) MessageExport(c request.CTX, cursor model.MessageExp builder = builder.Where(sq.LtOrEq{"Posts.UpdateAt": cursor.UntilUpdateAt}) } - query, args, err := builder.ToSql() - if err != nil { - return nil, cursor, errors.Wrap(err, "unable to construct query to export messages") - } - cposts := []*model.MessageExport{} - if err := s.GetReplica().SelectCtx(c.Context(), &cposts, query, args...); err != nil { + if err := s.GetReplica().SelectBuilderCtx(c.Context(), &cposts, builder); err != nil { return nil, cursor, errors.Wrap(err, "unable to export messages") } if len(cposts) > 0 { diff --git a/server/channels/store/sqlstore/outgoing_oauth_connection_store.go b/server/channels/store/sqlstore/outgoing_oauth_connection_store.go index 398f6ad60c..ba625214d4 100644 --- a/server/channels/store/sqlstore/outgoing_oauth_connection_store.go +++ b/server/channels/store/sqlstore/outgoing_oauth_connection_store.go @@ -121,7 +121,7 @@ func (s *SqlOutgoingOAuthConnectionStore) GetConnections(c request.CTX, filters query = query.Where(sq.Like{"Audiences": fmt.Sprint("%", filters.Audience, "%")}) } - if err := s.GetReplica().SelectBuilder(&conns, query); err != nil { + if err := s.GetReplica().SelectBuilderCtx(c.Context(), &conns, query); err != nil { return nil, errors.Wrap(err, "failed to get OutgoingOAuthConnections") } diff --git a/server/channels/store/sqlstore/sqlx_wrapper.go b/server/channels/store/sqlstore/sqlx_wrapper.go index 3ea0d28dad..bd294ba628 100644 --- a/server/channels/store/sqlstore/sqlx_wrapper.go +++ b/server/channels/store/sqlstore/sqlx_wrapper.go @@ -249,12 +249,16 @@ func (w *sqlxDBWrapper) SelectCtx(ctx context.Context, dest any, query string, a } func (w *sqlxDBWrapper) SelectBuilder(dest any, builder Builder) error { + return w.SelectBuilderCtx(context.Background(), dest, builder) +} + +func (w *sqlxDBWrapper) SelectBuilderCtx(ctx context.Context, dest any, builder Builder) error { query, args, err := builder.ToSql() if err != nil { return err } - return w.Select(dest, query, args...) + return w.SelectCtx(ctx, dest, query, args...) } type sqlxTxWrapper struct { diff --git a/server/channels/store/sqlstore/sqlx_wrapper_test.go b/server/channels/store/sqlstore/sqlx_wrapper_test.go index 264c60844b..35dab5bbed 100644 --- a/server/channels/store/sqlstore/sqlx_wrapper_test.go +++ b/server/channels/store/sqlstore/sqlx_wrapper_test.go @@ -90,3 +90,73 @@ func TestSqlX(t *testing.T) { } }) } + +func TestSqlxSelect(t *testing.T) { + testDrivers := []string{ + model.DatabaseDriverPostgres, + model.DatabaseDriverMysql, + } + for _, driver := range testDrivers { + t.Run(driver, func(t *testing.T) { + settings, err := makeSqlSettings(driver) + if err != nil { + t.Skip(err) + } + *settings.QueryTimeout = 1 + store := &SqlStore{ + rrCounter: 0, + srCounter: 0, + settings: settings, + logger: mlog.CreateConsoleTestLogger(t), + quitMonitor: make(chan struct{}), + wgMonitor: &sync.WaitGroup{}, + } + + require.NoError(t, store.initConnection()) + defer store.Close() + + t.Run("SelectCtx", func(t *testing.T) { + var result []string + err := store.GetMaster().SelectCtx(context.Background(), &result, "SELECT 'test' AS col") + require.NoError(t, err) + require.Equal(t, []string{"test"}, result) + + // Test timeout + ctx, cancel := context.WithTimeout(context.Background(), 1) + defer cancel() + var query string + if driver == model.DatabaseDriverMysql { + query = "SELECT SLEEP(2)" + } else { + query = "SELECT pg_sleep(2)" + } + err = store.GetMaster().SelectCtx(ctx, &result, query) + require.Error(t, err) + require.Equal(t, context.DeadlineExceeded, err) + }) + + t.Run("SelectBuilderCtx", func(t *testing.T) { + var result []string + builder := store.getQueryBuilder(). + Select("'test' AS col") + err := store.GetMaster().SelectBuilderCtx(context.Background(), &result, builder) + require.NoError(t, err) + require.Equal(t, []string{"test"}, result) + + // Test timeout + ctx, cancel := context.WithTimeout(context.Background(), 1) + defer cancel() + if driver == model.DatabaseDriverMysql { + builder = store.getQueryBuilder(). + Select("SLEEP(2)") + } else { + builder = store.getQueryBuilder(). + Select("pg_sleep(2)") + } + err = store.GetMaster().SelectBuilderCtx(ctx, &result, builder) + require.Error(t, err) + require.Equal(t, context.DeadlineExceeded, err) + }) + }) + } +} diff --git a/server/channels/store/sqlstore/user_store.go b/server/channels/store/sqlstore/user_store.go index 4bc5388d83..fda085fc46 100644 --- a/server/channels/store/sqlstore/user_store.go +++ b/server/channels/store/sqlstore/user_store.go @@ -501,7 +501,7 @@ func (us SqlUserStore) GetMfaUsedTimestamps(userId string) ([]int, error) { func (us SqlUserStore) GetMany(ctx context.Context, ids []string) ([]*model.User, error) { query := us.usersQuery.Where(sq.Eq{"Id": ids}) users := []*model.User{} - if err := us.SqlStore.DBXFromContext(ctx).SelectBuilder(&users, query); err != nil { + if err := us.SqlStore.DBXFromContext(ctx).SelectBuilderCtx(ctx, &users, query); err != nil { return nil, errors.Wrap(err, "users_get_many_select") }