diff --git a/src/acp/connection.py b/src/acp/connection.py index aca1c19..cd8e9ee 100644 --- a/src/acp/connection.py +++ b/src/acp/connection.py @@ -73,6 +73,7 @@ def __init__( sender_factory: SenderFactory | None = None, observers: list[StreamObserver] | None = None, listening: bool = True, + receive_timeout: float | None = None, ) -> None: self._handler = handler self._writer = writer @@ -102,6 +103,7 @@ def __init__( ) self._dispatcher.start() self._observers: list[StreamObserver] = list(observers or []) + self._receive_timeout = receive_timeout async def close(self) -> None: """Stop the receive loop and cancel any in-flight handler tasks.""" @@ -148,7 +150,7 @@ async def send_notification(self, method: str, params: JsonValue | None = None) async def _receive_loop(self) -> None: try: while True: - line = await self._reader.readline() + line = await asyncio.wait_for(self._reader.readline(), timeout=self._receive_timeout) if not line: break try: @@ -160,6 +162,8 @@ async def _receive_loop(self) -> None: await self._process_message(message) except asyncio.CancelledError: return + except asyncio.TimeoutError: + raise RequestError.internal_error({"details": "Agent timeout"}) from None async def _process_message(self, message: dict[str, Any]) -> None: method = message.get("method")