# retry.py — async retry with exponential backoff.
import asyncio
from typing import Awaitable, Callable, Optional, TypeVar
T = TypeVar("T")


async def retry(
    fn: Callable[[], Awaitable[T]],
    *,
    max_attempts: int = 5,
    initial_delay_s: float = 0.5,
    max_delay_s: float = 10.0,
    factor: float = 2.0,
    should_retry: Optional[Callable[[BaseException, int], bool]] = None,
    on_retry: Optional[Callable[[BaseException, int, float], None]] = None,
) -> T:
    """Await ``fn()`` until it succeeds, backing off exponentially on failure.

    When attempt ``n`` (1-based) fails, the next attempt waits for
    ``min(initial_delay_s * factor ** (n - 1), max_delay_s)`` seconds.

    Only ``Exception`` subclasses are retried. Exceptions that derive
    directly from ``BaseException`` propagate immediately, so task
    cancellation and interpreter shutdown are never swallowed. These include
    ``asyncio.CancelledError`` (Python 3.8+), ``KeyboardInterrupt`` and
    ``SystemExit``.

    Args:
        fn: Zero-argument callable that returns a fresh awaitable per call.
        max_attempts: Total number of times ``fn`` may be called; must be >= 1.
        initial_delay_s: Delay after the first failure, in seconds.
        max_delay_s: Upper bound on any single delay, in seconds.
        factor: Multiplier applied to the delay after each further failure.
        should_retry: Optional predicate ``(err, attempt) -> bool``. When it
            returns False, the error is re-raised at once. It is not
            consulted after the final attempt.
        on_retry: Optional callback ``(err, attempt, delay)`` that is invoked
            just before sleeping.

    Returns:
        The result of the first successful ``await fn()``.

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
        Exception: The error from the final attempt, or the first error that
            ``should_retry`` rejects, re-raised unchanged.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts!r}")

    for attempt in range(1, max_attempts + 1):
        try:
            return await fn()
        except Exception as err:
            if attempt == max_attempts:
                raise
            if should_retry is not None and not should_retry(err, attempt):
                raise
            try:
                backoff = initial_delay_s * factor ** (attempt - 1)
            except OverflowError:
                # factor ** (attempt - 1) no longer fits in a float (e.g.
                # 2.0 ** 1024). For a positive initial delay the cap has long
                # since been reached, so use it directly.
                backoff = max_delay_s
            delay = min(backoff, max_delay_s)
            if on_retry is not None:
                on_retry(err, attempt, delay)
            await asyncio.sleep(delay)

    # Unreachable: the final attempt either returns or re-raises above.
    # This keeps type checkers from flagging a missing return.
    raise AssertionError("unreachable: final attempt returns or raises")