Skip to content

Commit 460c3d5

Browse files
authored
fix(fal_client): expose subscribe polling interval (#1073)
1 parent 0fd3d9d commit 460c3d5

2 files changed

Lines changed: 139 additions & 16 deletions

File tree

projects/fal_client/src/fal_client/client.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,7 @@ async def _async_request(
10411041
# to fire when the hang is at the raw SSL socket level (ssl.read()). Using an
10421042
# httpx.Timeout object with a shorter connect timeout ensures we detect stalls.
10431043
QUEUE_POLL_TIMEOUT = httpx.Timeout(120.0, connect=30.0)
1044+
DEFAULT_QUEUE_POLL_INTERVAL = 0.1
10441045

10451046

10461047
def _is_ingress_error(response: httpx.Response) -> bool:
@@ -1514,7 +1515,7 @@ def status(self, *, with_logs: bool = False) -> Status:
15141515
return self._parse_status(response.json())
15151516

15161517
def iter_events(
1517-
self, *, with_logs: bool = False, interval: float = 0.1
1518+
self, *, with_logs: bool = False, interval: float = DEFAULT_QUEUE_POLL_INTERVAL
15181519
) -> Iterator[Status]:
15191520
"""Continuously poll for the status of the request and yield it at each interval till
15201521
the request is completed. If `with_logs` is True, logs will be included in the response.
@@ -1528,9 +1529,9 @@ def iter_events(
15281529

15291530
time.sleep(interval)
15301531

1531-
def get(self) -> AnyJSON:
1532+
def get(self, *, interval: float = DEFAULT_QUEUE_POLL_INTERVAL) -> AnyJSON:
15321533
"""Wait till the request is completed and return the result of the inference call."""
1533-
for _ in self.iter_events(with_logs=False):
1534+
for _ in self.iter_events(with_logs=False, interval=interval):
15341535
continue
15351536

15361537
response = _maybe_retry_request(
@@ -1591,7 +1592,7 @@ async def status(self, *, with_logs: bool = False) -> Status:
15911592
return self._parse_status(response.json())
15921593

15931594
async def iter_events(
1594-
self, *, with_logs: bool = False, interval: float = 0.1
1595+
self, *, with_logs: bool = False, interval: float = DEFAULT_QUEUE_POLL_INTERVAL
15951596
) -> AsyncIterator[Status]:
15961597
"""Continuously poll for the status of the request and yield it at each interval till
15971598
the request is completed. If `with_logs` is True, logs will be included in the response.
@@ -1605,9 +1606,9 @@ async def iter_events(
16051606

16061607
await asyncio.sleep(interval)
16071608

1608-
async def get(self) -> AnyJSON:
1609+
async def get(self, *, interval: float = DEFAULT_QUEUE_POLL_INTERVAL) -> AnyJSON:
16091610
"""Wait till the request is completed and return the result."""
1610-
async for _ in self.iter_events(with_logs=False):
1611+
async for _ in self.iter_events(with_logs=False, interval=interval):
16111612
continue
16121613

16131614
response = await _async_maybe_retry_request(
@@ -1807,6 +1808,7 @@ async def subscribe(
18071808
path: str = "",
18081809
hint: str | None = None,
18091810
with_logs: bool = False,
1811+
interval: float = DEFAULT_QUEUE_POLL_INTERVAL,
18101812
on_enqueue: Optional[Callable[[str], None | Awaitable[None]]] = None,
18111813
on_queue_update: Optional[Callable[[Status], None | Awaitable[None]]] = None,
18121814
priority: Optional[Priority] = None,
@@ -1817,6 +1819,7 @@ async def subscribe(
18171819
"""Subscribe to an application and wait for the result.
18181820
18191821
Args:
1822+
interval: Polling interval in seconds while waiting for request updates.
18201823
start_timeout: Server-side request timeout in seconds. Limits total time spent
18211824
waiting before processing starts (includes queue wait, retries, and
18221825
routing). Does not apply once the application begins processing.
@@ -1855,12 +1858,14 @@ async def _do_subscribe() -> AnyJSON:
18551858
await result
18561859

18571860
if on_queue_update is not None:
1858-
async for event in handle.iter_events(with_logs=with_logs):
1861+
async for event in handle.iter_events(
1862+
with_logs=with_logs, interval=interval
1863+
):
18591864
result = on_queue_update(event)
18601865
if inspect.isawaitable(result):
18611866
await result
18621867

1863-
return await handle.get()
1868+
return await handle.get(interval=interval)
18641869

18651870
if client_timeout is None:
18661871
return await _do_subscribe()
@@ -2338,6 +2343,7 @@ def subscribe(
23382343
path: str = "",
23392344
hint: str | None = None,
23402345
with_logs: bool = False,
2346+
interval: float = DEFAULT_QUEUE_POLL_INTERVAL,
23412347
on_enqueue: Optional[Callable[[str], None]] = None,
23422348
on_queue_update: Optional[Callable[[Status], None]] = None,
23432349
priority: Optional[Priority] = None,
@@ -2348,6 +2354,7 @@ def subscribe(
23482354
"""Subscribe to an application and wait for the result.
23492355
23502356
Args:
2357+
interval: Polling interval in seconds while waiting for request updates.
23512358
start_timeout: Server-side request timeout in seconds. Limits total time spent
23522359
waiting before processing starts (includes queue wait, retries, and
23532360
routing). Does not apply once the application begins processing.
@@ -2384,10 +2391,10 @@ def _do_subscribe() -> AnyJSON:
23842391
on_enqueue(handle.request_id)
23852392

23862393
if on_queue_update is not None:
2387-
for event in handle.iter_events(with_logs=with_logs):
2394+
for event in handle.iter_events(with_logs=with_logs, interval=interval):
23882395
on_queue_update(event)
23892396

2390-
return handle.get()
2397+
return handle.get(interval=interval)
23912398

23922399
if client_timeout is None:
23932400
return _do_subscribe()

projects/fal_client/tests/unit/test_client.py

Lines changed: 122 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
AsyncRequestHandle,
1818
CDN_URL,
1919
Completed,
20+
DEFAULT_QUEUE_POLL_INTERVAL,
2021
FAL_CDN_FALLBACK_URL,
2122
FalClientHTTPError,
2223
FalClientTimeoutError,
@@ -1254,7 +1255,11 @@ def test_sync_handle_retries(monkeypatch):
12541255
)
12551256

12561257
# Mock iter_events to skip waiting
1257-
def _iter_events(self, with_logs: bool = False, interval: float = 0.1):
1258+
def _iter_events(
1259+
self,
1260+
with_logs: bool = False,
1261+
interval: float = DEFAULT_QUEUE_POLL_INTERVAL,
1262+
):
12581263
return iter([Completed(logs=[], metrics={})])
12591264

12601265
monkeypatch.setattr(SyncRequestHandle, "iter_events", _iter_events, raising=True)
@@ -1336,7 +1341,11 @@ async def test_async_handle_retries(monkeypatch):
13361341
)
13371342

13381343
# Mock iter_events to skip waiting
1339-
async def _iter_events(self, with_logs: bool = False, interval: float = 0.1):
1344+
async def _iter_events(
1345+
self,
1346+
with_logs: bool = False,
1347+
interval: float = DEFAULT_QUEUE_POLL_INTERVAL,
1348+
):
13401349
yield Completed(logs=[], metrics={})
13411350

13421351
monkeypatch.setattr(AsyncRequestHandle, "iter_events", _iter_events, raising=True)
@@ -1418,7 +1427,11 @@ def test_sync_get_does_not_retry_on_500_503(monkeypatch, status_code):
14181427
)
14191428

14201429
# Mock iter_events to skip waiting
1421-
def _iter_events(self, with_logs: bool = False, interval: float = 0.1):
1430+
def _iter_events(
1431+
self,
1432+
with_logs: bool = False,
1433+
interval: float = DEFAULT_QUEUE_POLL_INTERVAL,
1434+
):
14221435
return iter([Completed(logs=[], metrics={})])
14231436

14241437
monkeypatch.setattr(SyncRequestHandle, "iter_events", _iter_events, raising=True)
@@ -1454,7 +1467,11 @@ async def test_async_get_does_not_retry_on_500_503(monkeypatch, status_code):
14541467
)
14551468

14561469
# Mock iter_events to skip waiting
1457-
async def _iter_events(self, with_logs: bool = False, interval: float = 0.1):
1470+
async def _iter_events(
1471+
self,
1472+
with_logs: bool = False,
1473+
interval: float = DEFAULT_QUEUE_POLL_INTERVAL,
1474+
):
14581475
yield Completed(logs=[], metrics={})
14591476

14601477
monkeypatch.setattr(AsyncRequestHandle, "iter_events", _iter_events, raising=True)
@@ -1486,7 +1503,11 @@ def test_sync_handle_retries_ingress(monkeypatch):
14861503
)
14871504

14881505
# Mock iter_events to skip waiting
1489-
def _iter_events(self, with_logs: bool = False, interval: float = 0.1):
1506+
def _iter_events(
1507+
self,
1508+
with_logs: bool = False,
1509+
interval: float = DEFAULT_QUEUE_POLL_INTERVAL,
1510+
):
14901511
return iter([Completed(logs=[], metrics={})])
14911512

14921513
monkeypatch.setattr(SyncRequestHandle, "iter_events", _iter_events, raising=True)
@@ -1574,7 +1595,11 @@ async def test_async_handle_retries_ingress(monkeypatch):
15741595
)
15751596

15761597
# Mock iter_events to skip waiting
1577-
async def _iter_events(self, with_logs: bool = False, interval: float = 0.1):
1598+
async def _iter_events(
1599+
self,
1600+
with_logs: bool = False,
1601+
interval: float = DEFAULT_QUEUE_POLL_INTERVAL,
1602+
):
15781603
yield Completed(logs=[], metrics={})
15791604

15801605
monkeypatch.setattr(AsyncRequestHandle, "iter_events", _iter_events, raising=True)
@@ -1972,6 +1997,50 @@ def test_sync_client_subscribe_with_start_timeout():
19721997
assert first_call_kwargs["headers"]["X-Fal-Request-Timeout"] == "90.0"
19731998

19741999

2000+
def test_sync_client_subscribe_with_interval(monkeypatch):
2001+
"""Test that subscribe() passes interval through to request polling."""
2002+
iter_event_calls = []
2003+
2004+
def _iter_events(
2005+
self,
2006+
*,
2007+
with_logs: bool = False,
2008+
interval: float = DEFAULT_QUEUE_POLL_INTERVAL,
2009+
):
2010+
iter_event_calls.append((with_logs, interval))
2011+
return iter([Completed(logs=[], metrics={})])
2012+
2013+
monkeypatch.setattr(SyncRequestHandle, "iter_events", _iter_events, raising=True)
2014+
2015+
with patch("fal_client.client._maybe_retry_request") as mock_request:
2016+
submit_response = Mock()
2017+
submit_response.json.return_value = {
2018+
"request_id": "req-123",
2019+
"response_url": "http://response",
2020+
"status_url": "http://status",
2021+
"cancel_url": "http://cancel",
2022+
}
2023+
2024+
result_response = Mock()
2025+
result_response.json.return_value = {"result": "done"}
2026+
2027+
mock_request.side_effect = [submit_response, result_response]
2028+
2029+
queue_updates = []
2030+
client = SyncClient(key="test-key")
2031+
result = client.subscribe(
2032+
"test-app",
2033+
{"input": "data"},
2034+
interval=0.5,
2035+
with_logs=True,
2036+
on_queue_update=queue_updates.append,
2037+
)
2038+
2039+
assert result == {"result": "done"}
2040+
assert len(queue_updates) == 1
2041+
assert iter_event_calls == [(True, 0.5), (False, 0.5)]
2042+
2043+
19752044
@pytest.mark.asyncio
19762045
async def test_async_client_run_with_start_timeout():
19772046
"""Test that start_timeout adds X-Fal-Request-Timeout header in async run()."""
@@ -2044,6 +2113,53 @@ async def test_async_client_subscribe_with_start_timeout():
20442113
assert first_call_kwargs["headers"]["X-Fal-Request-Timeout"] == "90.0"
20452114

20462115

2116+
@pytest.mark.asyncio
2117+
async def test_async_client_subscribe_with_interval(monkeypatch):
2118+
"""Test that async subscribe() passes interval through to request polling."""
2119+
iter_event_calls = []
2120+
2121+
async def _iter_events(
2122+
self,
2123+
*,
2124+
with_logs: bool = False,
2125+
interval: float = DEFAULT_QUEUE_POLL_INTERVAL,
2126+
):
2127+
iter_event_calls.append((with_logs, interval))
2128+
yield Completed(logs=[], metrics={})
2129+
2130+
monkeypatch.setattr(AsyncRequestHandle, "iter_events", _iter_events, raising=True)
2131+
2132+
with patch(
2133+
"fal_client.client._async_maybe_retry_request", new_callable=AsyncMock
2134+
) as mock_request:
2135+
submit_response = Mock()
2136+
submit_response.json.return_value = {
2137+
"request_id": "req-789",
2138+
"response_url": "http://response",
2139+
"status_url": "http://status",
2140+
"cancel_url": "http://cancel",
2141+
}
2142+
2143+
result_response = Mock()
2144+
result_response.json.return_value = {"result": "async_done"}
2145+
2146+
mock_request.side_effect = [submit_response, result_response]
2147+
2148+
queue_updates = []
2149+
client = AsyncClient(key="test-key")
2150+
result = await client.subscribe(
2151+
"test-app",
2152+
{"input": "data"},
2153+
interval=0.5,
2154+
with_logs=True,
2155+
on_queue_update=queue_updates.append,
2156+
)
2157+
2158+
assert result == {"result": "async_done"}
2159+
assert len(queue_updates) == 1
2160+
assert iter_event_calls == [(True, 0.5), (False, 0.5)]
2161+
2162+
20472163
def test_sync_client_run_without_start_timeout_no_header():
20482164
"""Test that no timeout header is added when start_timeout is not specified."""
20492165
with patch("fal_client.client._maybe_retry_request") as mock_request:

0 commit comments

Comments
 (0)