package sqlstore import ( "context" "errors" "fmt" "testing" "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRetryingDisabled(t *testing.T) { store := InitTestDB(t) require.Equal(t, 0, store.dbCfg.QueryRetries) funcToTest := map[string]func(ctx context.Context, callback DBTransactionFunc) error{ "WithDbSession": store.WithDbSession, "WithNewDbSession": store.WithNewDbSession, } for name, f := range funcToTest { t.Run(fmt.Sprintf("%s should return the error immediately", name), func(t *testing.T) { i := 0 callback := func(sess *DBSession) error { i++ return errors.New("some error") } err := f(context.Background(), callback) require.Error(t, err) require.Equal(t, 1, i) }) errCodes := []sqlite3.ErrNo{sqlite3.ErrBusy, sqlite3.ErrLocked} for _, c := range errCodes { t.Run(fmt.Sprintf("%s should return the sqlite3.Error %v immediately", name, c.Error()), func(t *testing.T) { i := 0 callback := func(sess *DBSession) error { i++ return sqlite3.Error{Code: c} } err := f(context.Background(), callback) require.Error(t, err) var driverErr sqlite3.Error require.ErrorAs(t, err, &driverErr) require.Equal(t, 1, i) assert.Equal(t, c, driverErr.Code) }) } t.Run(fmt.Sprintf("%s should not return error if the callback succeeds", name), func(t *testing.T) { i := 0 callback := func(sess *DBSession) error { i++ return nil } err := f(context.Background(), callback) require.NoError(t, err) require.Equal(t, 1, i) }) } } func TestRetryingOnFailures(t *testing.T) { store := InitTestDB(t) store.dbCfg.QueryRetries = 5 funcToTest := map[string]func(ctx context.Context, callback DBTransactionFunc) error{ "WithDbSession": store.WithDbSession, "WithNewDbSession": store.WithNewDbSession, } for name, f := range funcToTest { t.Run(fmt.Sprintf("%s should return the error immediately if it's other than sqlite3.ErrLocked or sqlite3.ErrBusy", name), func(t *testing.T) { i := 0 callback := func(sess *DBSession) error { i++ return errors.New("some error") } err := f(context.Background(), callback) require.Error(t, err) require.Equal(t, 1, i) }) errCodes := []sqlite3.ErrNo{sqlite3.ErrBusy, sqlite3.ErrLocked} for _, c := range errCodes { t.Run(fmt.Sprintf("%s should return the sqlite3.Error %v if all retries have failed", name, c.Error()), func(t *testing.T) { i := 0 callback := func(sess *DBSession) error { i++ return sqlite3.Error{Code: c} } err := f(context.Background(), callback) require.Error(t, err) var driverErr sqlite3.Error require.ErrorAs(t, err, &driverErr) require.Equal(t, store.dbCfg.QueryRetries, i) assert.Equal(t, c, driverErr.Code) }) } t.Run(fmt.Sprintf("%s should not return the error if successive retries succeed", name), func(t *testing.T) { i := 0 callback := func(sess *DBSession) error { i++ var err error switch { case store.dbCfg.QueryRetries == i: err = nil default: err = sqlite3.Error{Code: sqlite3.ErrBusy} } return err } err := f(context.Background(), callback) require.NoError(t, err) require.Equal(t, store.dbCfg.QueryRetries, i) }) } // Check SQL query sess := store.GetSqlxSession() rows, err := sess.Query(context.Background(), `SELECT "hello",2.3,4`) t.Cleanup(func() { err := rows.Close() require.NoError(t, err) }) require.NoError(t, err) require.True(t, rows.Next()) // first row str1 := "" val2 := float64(100.1) val3 := int64(200) err = rows.Scan(&str1, &val2, &val3) require.NoError(t, err) require.Equal(t, "hello", str1) require.Equal(t, 2.3, val2) require.Equal(t, int64(4), val3) require.False(t, rows.Next()) // no more rows }