go_zoom_kinesis/retry/
backoff.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
use rand::Rng;
use std::time::Duration;
use tracing::trace;

/// Strategy for computing how long to wait before retrying an operation.
///
/// The `Send + Sync` supertraits allow an implementor to be shared across
/// threads (e.g. behind an `Arc` or as a `Box<dyn Backoff>`).
///
/// No method here is `async`, so no `async_trait` expansion is needed; plain
/// trait methods are object-safe as written.
pub trait Backoff: Send + Sync {
    /// Returns the delay to wait before retry number `attempt`.
    fn next_delay(&self, attempt: u32) -> Duration;

    /// Resets any internal state held by the strategy.
    fn reset(&mut self);
}

/// Exponential backoff strategy with randomized jitter.
///
/// For attempt `n` the delay starts from `initial_delay * multiplier^n`,
/// is capped at `max_delay`, then perturbed by a random amount of up to
/// `jitter_factor` times that capped value (and capped again).
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
    /// Base delay the multiplier is applied to.
    initial_delay: Duration,
    /// Upper bound applied to every computed delay.
    max_delay: Duration,
    /// Growth factor raised to the power of the attempt number.
    multiplier: f64,
    /// Fraction of the capped delay used as the symmetric jitter range.
    jitter_factor: f64,
}

impl ExponentialBackoff {
    /// Creates a backoff with the given bounds, a multiplier of `2.0`, and a
    /// jitter factor of `0.1`.
    pub fn new(initial_delay: Duration, max_delay: Duration) -> Self {
        Self {
            initial_delay,
            max_delay,
            multiplier: 2.0,
            jitter_factor: 0.1,
        }
    }

    /// Create a new builder for ExponentialBackoff
    pub fn builder() -> ExponentialBackoffBuilder {
        ExponentialBackoffBuilder::default()
    }

    /// Computes the delay for `attempt`, truncated to whole milliseconds.
    ///
    /// The delay is `initial_delay * multiplier^attempt` clamped into
    /// `[0, max_delay]`, plus uniform jitter in `±(jitter_factor * that value)`,
    /// clamped into `[0, max_delay]` again.
    fn calculate_delay(&self, attempt: u32) -> Duration {
        let base = self.initial_delay.as_millis() as f64;
        let max = self.max_delay.as_millis() as f64;

        // `powi` takes an `i32`; a bare `as i32` would wrap attempts above
        // `i32::MAX` into negative exponents (near-zero delays). Saturate instead.
        let exponent = attempt.min(i32::MAX as u32) as i32;

        // Calculate exponential delay. A zero base must always give a zero
        // delay: otherwise `0 * inf` (overflowed growth) is NaN, which the cap
        // below would silently turn into `max_delay`.
        let exp_delay = if base == 0.0 {
            0.0
        } else {
            base * self.multiplier.powi(exponent)
        };

        // Cap at max_delay BEFORE adding jitter. A NaN product (NaN multiplier)
        // keeps falling back to max_delay; a negative product (negative
        // multiplier, odd attempt) is floored at zero so the jitter range below
        // can never be inverted.
        let capped_delay = if exp_delay.is_nan() {
            max
        } else {
            exp_delay.clamp(0.0, max)
        };

        // Add jitter: random value between -jitter_range and +jitter_range.
        // Only sample when the range is strictly positive: `gen_range` panics
        // when `low > high` or the bounds are NaN.
        let jitter_range = capped_delay * self.jitter_factor;
        let jitter = if jitter_range > 0.0 {
            rand::thread_rng().gen_range(-jitter_range..=jitter_range)
        } else {
            0.0
        };

        // Clamp again after adding jitter to ensure we never exceed max_delay.
        let final_delay = (capped_delay + jitter).clamp(0.0, max);

        trace!(
            attempt = attempt,
            base_delay_ms = capped_delay,
            jitter_ms = jitter,
            final_delay_ms = final_delay,
            "Calculated backoff delay"
        );

        Duration::from_millis(final_delay as u64)
    }
}

impl Backoff for ExponentialBackoff {
    /// Delegates to the exponential-with-jitter calculation.
    fn next_delay(&self, attempt: u32) -> Duration {
        Self::calculate_delay(self, attempt)
    }

    /// No-op: the delay depends only on `attempt` and the configuration,
    /// so there is no per-sequence state to clear.
    fn reset(&mut self) {}
}

/// Builder for [`ExponentialBackoff`].
///
/// Starts from a 100 ms initial delay, a 30 s maximum delay, a multiplier of
/// `2.0`, and a jitter factor of `0.1`; each setter overrides one field.
#[derive(Debug)]
#[must_use = "builders do nothing unless `build()` is called"]
pub struct ExponentialBackoffBuilder {
    initial_delay: Duration,
    max_delay: Duration,
    multiplier: f64,
    jitter_factor: f64,
}

impl Default for ExponentialBackoffBuilder {
    fn default() -> Self {
        Self {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            multiplier: 2.0,
            jitter_factor: 0.1,
        }
    }
}

impl ExponentialBackoffBuilder {
    /// Sets the base delay that the multiplier is applied to.
    pub fn initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Sets the upper bound for every computed delay.
    pub fn max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Sets the per-attempt growth factor. The value is stored as given
    /// (no validation).
    pub fn multiplier(mut self, multiplier: f64) -> Self {
        self.multiplier = multiplier;
        self
    }

    /// Sets the jitter factor, clamped into `[0.0, 1.0]`.
    ///
    /// NaN is treated as `0.0` (no jitter): `f64::clamp` returns NaN
    /// unchanged, which would otherwise slip past the range check.
    pub fn jitter_factor(mut self, factor: f64) -> Self {
        self.jitter_factor = if factor.is_nan() {
            0.0
        } else {
            factor.clamp(0.0, 1.0)
        };
        self
    }

    /// Consumes the builder and produces the configured [`ExponentialBackoff`].
    pub fn build(self) -> ExponentialBackoff {
        let Self {
            initial_delay,
            max_delay,
            multiplier,
            jitter_factor,
        } = self;
        ExponentialBackoff {
            initial_delay,
            max_delay,
            multiplier,
            jitter_factor,
        }
    }
}

/// Backoff strategy that always yields the same configured delay.
#[derive(Clone, Debug)]
pub struct FixedBackoff {
    /// Delay returned for every attempt.
    delay: Duration,
}

impl FixedBackoff {
    /// Creates a fixed backoff that returns `delay` for every attempt.
    ///
    /// Only compiled in test builds (`#[cfg(test)]`).
    #[cfg(test)]
    pub fn new(delay: Duration) -> Self {
        FixedBackoff { delay }
    }
}

impl Backoff for FixedBackoff {
    /// Returns the configured delay; `attempt` is only recorded in the trace event.
    fn next_delay(&self, attempt: u32) -> Duration {
        let delay = self.delay;
        trace!(attempt, delay_ms = ?delay.as_millis(), "Fixed backoff delay");
        delay
    }

    /// No-op: the configured delay never changes.
    fn reset(&mut self) {}
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::time::Duration;

    /// Delays grow with the attempt number and never exceed `max_delay`.
    #[test]
    fn test_exponential_backoff_calculation() {
        let max_delay = Duration::from_secs(10);
        let backoff = ExponentialBackoff::builder()
            .initial_delay(Duration::from_millis(100))
            .max_delay(max_delay)
            .multiplier(2.0)
            .jitter_factor(0.1)
            .build();

        // Sample the first five attempts.
        let delays: Vec<Duration> = (0..5).map(|n| backoff.next_delay(n)).collect();

        // Non-decreasing, except where a delay has reached the cap.
        for window in delays.windows(2) {
            let (earlier, later) = (window[0], window[1]);
            assert!(later >= earlier || later == max_delay);
        }

        // 100ms * 2^20 is far above the cap, so this exercises the clamp.
        let far_attempt_delay = backoff.next_delay(20);
        assert!(
            far_attempt_delay <= max_delay,
            "Delay {:?} exceeded max delay {:?}",
            far_attempt_delay,
            max_delay
        );
    }

    /// Jitter yields varying delays that stay within ±50% of the base delay.
    #[test]
    fn test_jitter_variation() {
        let backoff = ExponentialBackoff::builder()
            .initial_delay(Duration::from_millis(100))
            .jitter_factor(0.5)
            .build();

        // Sample the same attempt repeatedly.
        let delays: Vec<Duration> = std::iter::repeat_with(|| backoff.next_delay(1))
            .take(100)
            .collect();

        // More than one distinct value means jitter is being applied.
        let distinct: HashSet<&Duration> = delays.iter().collect();
        assert!(distinct.len() > 1);

        // Attempt 1 => 100ms * 2^1 = 200ms, with ±50% jitter.
        let base_delay = 200.0;
        for ms in delays.iter().map(|d| d.as_millis() as f64) {
            assert!(ms >= base_delay * 0.5);
            assert!(ms <= base_delay * 1.5);
        }
    }

    /// A fixed backoff returns the same delay for every attempt.
    #[test]
    fn test_fixed_backoff() {
        let expected = Duration::from_millis(100);
        let backoff = FixedBackoff::new(expected);

        (0..5).for_each(|attempt| assert_eq!(backoff.next_delay(attempt), expected));
    }

    /// Out-of-range jitter factors are clamped into [0.0, 1.0].
    #[test]
    fn test_builder_constraints() {
        let too_high = ExponentialBackoff::builder().jitter_factor(1.5).build();
        assert!(too_high.jitter_factor <= 1.0);

        let too_low = ExponentialBackoff::builder().jitter_factor(-0.5).build();
        assert!(too_low.jitter_factor >= 0.0);
    }
}