diff --git a/retry/retry.go b/retry/retry.go index 8d0cf64..0210c76 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -8,6 +8,8 @@ import ( "context" "errors" "fmt" + "math" + "math/rand" "reflect" "runtime" "strings" @@ -56,6 +58,36 @@ func RetryWithLinearBackoff(interval time.Duration) Option { } } +// RetryWithExponentialWithJitterBackoff set exponential strategy backoff +// todo: Add playground link +func RetryWithExponentialWithJitterBackoff(interval time.Duration, base uint64, maxJitter time.Duration) Option { + if interval <= 0 { + panic("programming error: retry interval should not be lower or equal to 0") + } + + if maxJitter < 0 { + panic("programming error: retry maxJitter should not be lower to 0") + } + + if base%2 == 0 { + return func(rc *RetryConfig) { + rc.backoffStrategy = &shiftExponentialWithJitter{ + interval: interval, + maxJitter: maxJitter, + shifter: uint64(math.Log2(float64(base))), + } + } + } + + return func(rc *RetryConfig) { + rc.backoffStrategy = &exponentialWithJitter{ + interval: interval, + base: time.Duration(base), + maxJitter: maxJitter, + } + } +} + // Context set retry context config. // Play: https://go.dev/play/p/xnAOOXv9GkS func Context(ctx context.Context) Option { @@ -117,8 +149,45 @@ type linear struct { interval time.Duration } -// CalculateInterval is the method implementation for the linear struct. -// It returns the fixed interval defined in the linear struct. +// CalculateInterval calculates the next interval returns a constant. func (l *linear) CalculateInterval() time.Duration { return l.interval } + +// exponentialWithJitter is a struct that implements the BackoffStrategy interface using a exponential backoff strategy. +type exponentialWithJitter struct { + base time.Duration // base is the multiplier for the exponential backoff. + interval time.Duration // interval is the current backoff interval, which will be adjusted over time. + maxJitter time.Duration // maxJitter is the maximum amount of jitter to apply to the backoff interval. +} + +// CalculateInterval calculates the next backoff interval with jitter and updates the interval. +func (e *exponentialWithJitter) CalculateInterval() time.Duration { + current := e.interval + e.interval = e.interval * e.base + return current + jitter(e.maxJitter) +} + +// shiftExponentialWithJitter is a struct that implements the BackoffStrategy interface using a exponential backoff strategy. +type shiftExponentialWithJitter struct { + interval time.Duration // interval is the current backoff interval, which will be adjusted over time. + maxJitter time.Duration // maxJitter is the maximum amount of jitter to apply to the backoff interval. + shifter uint64 // shift by n faster than multiplication +} + +// CalculateInterval calculates the next backoff interval with jitter and updates the interval. +// Uses shift instead of multiplication +func (e *shiftExponentialWithJitter) CalculateInterval() time.Duration { + current := e.interval + e.interval = e.interval << e.shifter + return current + jitter(e.maxJitter) +} + +// Jitter adds a random duration, up to maxJitter, +// to the current interval to introduce randomness and avoid synchronized patterns in retry behavior +func jitter(maxJitter time.Duration) time.Duration { + if maxJitter == 0 { + return 0 + } + return time.Duration(rand.Int63n(int64(maxJitter)) + 1) +} diff --git a/retry/retry_example_test.go b/retry/retry_example_test.go index 431afc4..81b26bb 100644 --- a/retry/retry_example_test.go +++ b/retry/retry_example_test.go @@ -51,6 +51,27 @@ func ExampleRetryWithLinearBackoff() { // 3 } +func ExampleRetryWithExponentialWithJitterBackoff() { + number := 0 + increaseNumber := func() error { + number++ + if number == 3 { + return nil + } + return errors.New("error occurs") + } + + err := Retry(increaseNumber, RetryWithExponentialWithJitterBackoff(time.Microsecond*50, 2, time.Microsecond*25)) + if err != nil { + return + } + + fmt.Println(number) + + // Output: + // 3 +} + func ExampleRetryTimes() { number := 0 diff --git a/retry/retry_test.go b/retry/retry_test.go index 373c821..f4d2e28 100644 --- a/retry/retry_test.go +++ b/retry/retry_test.go @@ -26,6 +26,80 @@ func TestRetryFailed(t *testing.T) { assert.Equal(DefaultRetryTimes, number) } +func TestRetryShiftExponentialWithJitterFailed(t *testing.T) { + t.Parallel() + + assert := internal.NewAssert(t, "TestRetryShiftExponentialWithJitterFailed") + + var number int + increaseNumber := func() error { + number++ + return errors.New("error occurs") + } + + err := Retry(increaseNumber, RetryWithExponentialWithJitterBackoff(time.Microsecond*50, 2, time.Microsecond*25)) + + assert.IsNotNil(err) + assert.Equal(DefaultRetryTimes, number) +} + +func TestRetryExponentialWithJitterFailed(t *testing.T) { + t.Parallel() + + assert := internal.NewAssert(t, "TestRetryExponentialWithJitterFailed") + + var number int + increaseNumber := func() error { + number++ + return errors.New("error occurs") + } + + err := Retry(increaseNumber, RetryWithExponentialWithJitterBackoff(time.Microsecond*50, 3, time.Microsecond*25)) + + assert.IsNotNil(err) + assert.Equal(DefaultRetryTimes, number) +} + +func TestRetryWithExponentialSucceeded(t *testing.T) { + t.Parallel() + + assert := internal.NewAssert(t, "TestRetryWithExponentialSucceeded") + + var number int + increaseNumber := func() error { + number++ + if number == DefaultRetryTimes { + return nil + } + return errors.New("error occurs") + } + + err := Retry(increaseNumber, RetryWithExponentialWithJitterBackoff(time.Microsecond*50, 3, time.Microsecond*25)) + + assert.IsNil(err) + assert.Equal(DefaultRetryTimes, number) +} + +func TestRetryWithExponentialShiftSucceeded(t *testing.T) { + t.Parallel() + + assert := internal.NewAssert(t, "TestRetryWithExponentialShiftSucceeded") + + var number int + increaseNumber := func() error { + number++ + if number == DefaultRetryTimes { + return nil + } + return errors.New("error occurs") + } + + err := Retry(increaseNumber, RetryWithExponentialWithJitterBackoff(time.Microsecond*50, 4, time.Microsecond*25)) + + assert.IsNil(err) + assert.Equal(DefaultRetryTimes, number) +} + func TestRetrySucceeded(t *testing.T) { t.Parallel() @@ -46,6 +120,74 @@ func TestRetrySucceeded(t *testing.T) { assert.Equal(DefaultRetryTimes, number) } +func TestRetryOneShotSucceeded(t *testing.T) { + t.Parallel() + + assert := internal.NewAssert(t, "TestRetryOneShotSucceeded") + + var number int + increaseNumber := func() error { + number++ + return nil + } + + err := Retry(increaseNumber, RetryWithLinearBackoff(time.Microsecond*50)) + + assert.IsNil(err) + assert.Equal(1, number) +} + +func TestRetryWithExponentialWithJitterBackoffShiftOneShotSucceeded(t *testing.T) { + t.Parallel() + + assert := internal.NewAssert(t, "TestRetryWithExponentialWithJitterBackoffShiftOneShotSucceeded") + + var number int + increaseNumber := func() error { + number++ + return nil + } + + err := Retry(increaseNumber, RetryWithExponentialWithJitterBackoff(time.Microsecond*50, 2, time.Microsecond*25)) + + assert.IsNil(err) + assert.Equal(1, number) +} + +func TestRetryWithExponentialWithJitterBackoffOneShotSucceeded(t *testing.T) { + t.Parallel() + + assert := internal.NewAssert(t, "TestRetryWithExponentialWithJitterBackoffOneShotSucceeded") + + var number int + increaseNumber := func() error { + number++ + return nil + } + + err := Retry(increaseNumber, RetryWithExponentialWithJitterBackoff(time.Microsecond*50, 3, time.Microsecond*25)) + + assert.IsNil(err) + assert.Equal(1, number) +} + +func TestRetryWithExponentialWithJitterBackoffNoJitterOneShotSucceeded(t *testing.T) { + t.Parallel() + + assert := internal.NewAssert(t, "TestRetryWithExponentialWithJitterBackoffNoJitterOneShotSucceeded") + + var number int + increaseNumber := func() error { + number++ + return nil + } + + err := Retry(increaseNumber, RetryWithExponentialWithJitterBackoff(time.Microsecond*50, 3, 0)) + + assert.IsNil(err) + assert.Equal(1, number) +} + func TestSetRetryTimes(t *testing.T) { t.Parallel()