# Source code for hs_py.ws

"""Sans-I/O WebSocket wrapper for Haystack.

Uses the ``websockets`` library's sans-I/O protocol objects together with
``asyncio`` TCP/TLS streams.  Each :class:`HaystackWebSocket` instance owns
one WebSocket connection backed by a ``(StreamReader, StreamWriter)`` pair.
"""

from __future__ import annotations

import asyncio
import contextlib
import logging
import socket
from collections import deque
from typing import TYPE_CHECKING
from urllib.parse import urlparse

from websockets.client import ClientProtocol
from websockets.exceptions import (
    ConnectionClosedError,
    ConnectionClosedOK,
    InvalidState,
    ProtocolError,
)
from websockets.extensions.permessage_deflate import (
    ClientPerMessageDeflateFactory,
    ServerPerMessageDeflateFactory,
)
from websockets.frames import Close, Frame, Opcode
from websockets.http11 import Request
from websockets.protocol import State as _WSState
from websockets.server import ServerProtocol
from websockets.typing import Subprotocol
from websockets.uri import parse_uri

if TYPE_CHECKING:
    import ssl
    from asyncio import StreamReader, StreamWriter
    from collections.abc import Sequence

# Names exported by ``from hs_py.ws import *``.
__all__ = [
    "HaystackWebSocket",
    "cancel_task",
    "heartbeat_loop",
]

# Module logger; used for DEBUG-level connection and heartbeat messages.
_log = logging.getLogger(__name__)

# Read buffer size for asyncio streams: the maximum number of bytes requested
# from ``StreamReader.read`` per call (64 KiB).
_READ_SIZE = 65536

# Write buffer tuning — keep low for prompt frame delivery.  Passed to the
# transport's ``set_write_buffer_limits`` by ``_set_nodelay`` so that
# ``drain()`` applies back-pressure once about 32 KiB is queued and resumes
# below 8 KiB.
_WRITE_HIGH_WATER = 32768
_WRITE_LOW_WATER = 8192

# Default subprotocol for Haystack WebSocket connections (the default
# ``subprotocol`` argument of ``HaystackWebSocket.connect``/``accept``).
HAYSTACK_SUBPROTOCOL = "haystack"


async def cancel_task(task: asyncio.Task[object] | None) -> None:
    """Request cancellation of *task* and wait for it to finish.

    Passing ``None`` does nothing.  A :class:`~asyncio.CancelledError` raised
    while awaiting the task is swallowed; any other exception the task ended
    with is re-raised to the caller.

    NOTE: because the suppression wraps the ``await``, a cancellation of the
    *calling* task that arrives while waiting here is swallowed as well.

    :param task: Task to cancel, or ``None`` (no-op).
    """
    if task is None:
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
async def heartbeat_loop(ws: HaystackWebSocket, interval: float) -> None:
    """Periodically send WebSocket pings to keep a connection alive.

    Runs until cancelled or until sending a ping fails.  A failed ping ends
    the loop quietly (logged at DEBUG level, with the traceback).
    Cancellation is propagated, so a cancelled heartbeat task finishes in the
    cancelled state as asyncio expects.

    :param ws: WebSocket connection to ping.
    :param interval: Seconds between pings.
    """
    try:
        while True:
            await asyncio.sleep(interval)
            await ws.ping()
    except Exception:
        # asyncio.CancelledError derives from BaseException, so it is not
        # caught here and reaches whoever cancelled the task.
        _log.debug("Heartbeat loop ended", exc_info=True)
def _set_nodelay(writer: StreamWriter) -> None:
    """Enable TCP_NODELAY and tune write buffer limits on *writer*.

    Best effort: if the underlying socket does not accept TCP options (for
    example an AF_UNIX stream), it is left unchanged instead of failing the
    connection.
    """
    sock = writer.get_extra_info("socket")
    if sock is not None:
        # setsockopt raises OSError on sockets that are not TCP; not fatal.
        with contextlib.suppress(OSError):
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    transport = writer.transport
    if transport is not None:
        transport.set_write_buffer_limits(high=_WRITE_HIGH_WATER, low=_WRITE_LOW_WATER)


def _write_pending(protocol: ClientProtocol | ServerProtocol, writer: StreamWriter) -> bool:
    """Write all pending protocol output to *writer* without draining.

    Empty chunks from ``data_to_send()`` are skipped.
    NOTE(review): websockets uses an empty chunk to signal end-of-stream;
    this helper does not act on that signal.

    :returns: ``True`` if at least one non-empty chunk was written.
    """
    wrote = False
    for chunk in protocol.data_to_send():
        if chunk:
            writer.write(chunk)
            wrote = True
    return wrote


def _drain_to_send(protocol: ClientProtocol | ServerProtocol) -> bytes:
    """Collect all pending outgoing protocol data into a single bytes object."""
    return b"".join(protocol.data_to_send())
class HaystackWebSocket:
    """Async WebSocket connection using the websockets sans-I/O protocol.

    Use :meth:`connect` to initiate a client connection or :meth:`accept` to
    accept a server-side connection.  Both return a ready-to-use instance.

    :meth:`recv` reassembles fragmented messages, so callers always receive
    complete message payloads.
    """

    # Close codes treated as a graceful shutdown: 1000 (normal closure),
    # 1001 (going away) and 1005 (no status code in the close frame).
    _OK_CLOSE_CODES = frozenset({1000, 1001, 1005})

    def __init__(
        self,
        reader: StreamReader,
        writer: StreamWriter,
        protocol: ClientProtocol | ServerProtocol,
    ) -> None:
        """Initialise from an established TCP connection.

        Prefer :meth:`connect` (client) or :meth:`accept` (server) rather than
        calling this directly.

        :param reader: asyncio stream reader.
        :param writer: asyncio stream writer.
        :param protocol: Negotiated websockets protocol object.
        """
        self._reader = reader
        self._writer = writer
        self._protocol = protocol
        # Frames already parsed from the network but not yet consumed by recv().
        self._pending_frames: deque[Frame] = deque()
        # Reassembly state for a fragmented message: opcode of its first
        # frame (TEXT or BINARY) and the payload chunks collected so far.
        self._frag_opcode: Opcode | None = None
        self._frag_chunks: list[bytes] = []

    # -- Client factory --

    @classmethod
    async def connect(
        cls,
        uri: str,
        ssl_ctx: ssl.SSLContext | None = None,
        *,
        subprotocol: str = HAYSTACK_SUBPROTOCOL,
        handshake_timeout: float = 10.0,
        max_size: int | None = None,
        compression: bool = False,
    ) -> HaystackWebSocket:
        """Initiate a WebSocket client connection.

        :param uri: WebSocket URI (``ws://host:port/path`` or
            ``wss://host:port/path``).
        :param ssl_ctx: TLS context used for ``wss://`` URIs.  *None* selects
            asyncio's default TLS context for ``wss://``; the argument is
            ignored for plaintext ``ws://``.
        :param subprotocol: WebSocket subprotocol to negotiate.
        :param handshake_timeout: Maximum seconds for the handshake.
        :param max_size: Maximum WebSocket message size.
        :param compression: Enable per-message deflate compression.
        :returns: Connected :class:`HaystackWebSocket` instance.
        :raises TimeoutError: If the handshake does not finish in time.
        :raises ConnectionError: If the peer closes during the handshake.
        """
        # Validate the URI before dialling so a bad URI cannot leak a socket.
        ws_uri = parse_uri(uri)
        parsed = urlparse(uri)
        host = parsed.hostname or "localhost"
        secure = parsed.scheme == "wss"
        port = parsed.port or (443 if secure else 80)
        # ssl=True makes asyncio use its default TLS context, so a wss:// URI
        # is never dialled in plaintext when no context is supplied.
        use_ssl: ssl.SSLContext | bool | None = None
        if secure:
            use_ssl = ssl_ctx if ssl_ctx is not None else True

        reader, writer = await asyncio.open_connection(host, port, ssl=use_ssl)
        events: Sequence[object] = ()
        try:
            _set_nodelay(writer)
            extensions_factories = [ClientPerMessageDeflateFactory()] if compression else None
            protocol = ClientProtocol(
                ws_uri,
                subprotocols=[Subprotocol(subprotocol)],
                max_size=max_size,
                extensions=extensions_factories,
            )
            protocol.send_request(protocol.connect())
            await cls._flush_protocol(protocol, writer)

            async with asyncio.timeout(handshake_timeout):
                while True:
                    data = await reader.read(_READ_SIZE)
                    if not data:
                        msg = "Connection closed during WebSocket handshake"
                        raise ConnectionError(msg)
                    protocol.receive_data(data)
                    if protocol.handshake_exc is not None:
                        raise protocol.handshake_exc
                    events = protocol.events_received()
                    if events:
                        break
                    await cls._flush_protocol(protocol, writer)
        except BaseException:
            # Never leak the TCP connection when the handshake fails.
            writer.close()
            with contextlib.suppress(OSError, ConnectionError):
                await writer.wait_closed()
            raise

        _log.debug("Haystack WebSocket client connected to %s:%d", host, port)
        ws = cls(reader, writer, protocol)
        # events[0] is the HTTP response; frames the server sent right behind
        # it in the same read must be kept for recv(), not dropped.
        ws._stash_frames(events[1:])
        # Those frames may have queued protocol replies (e.g. pongs).
        await ws._flush_outgoing()
        return ws

    # -- Server factory --

    @classmethod
    async def accept(
        cls,
        reader: StreamReader,
        writer: StreamWriter,
        *,
        subprotocol: str = HAYSTACK_SUBPROTOCOL,
        handshake_timeout: float = 10.0,
        max_size: int | None = None,
        compression: bool = False,
    ) -> HaystackWebSocket:
        """Accept an inbound WebSocket connection on existing streams.

        The streams are not closed on failure; that stays the caller's job.

        :param reader: asyncio StreamReader from the accepted TCP connection.
        :param writer: asyncio StreamWriter from the accepted TCP connection.
        :param subprotocol: WebSocket subprotocol to accept.
        :param handshake_timeout: Maximum seconds to wait for the HTTP request.
        :param max_size: Maximum WebSocket message size.
        :param compression: Enable per-message deflate compression.
        :returns: Accepted :class:`HaystackWebSocket` instance.
        :raises TimeoutError: If no complete request arrives in time.
        :raises ConnectionError: If the peer closes before sending a request.
        """
        extensions_factories = [ServerPerMessageDeflateFactory()] if compression else []
        protocol = ServerProtocol(
            subprotocols=[Subprotocol(subprotocol)],
            max_size=max_size,
            extensions=extensions_factories,
        )
        _set_nodelay(writer)

        events: Sequence[object] = ()
        async with asyncio.timeout(handshake_timeout):
            while not events:
                data = await reader.read(_READ_SIZE)
                if not data:
                    msg = "Connection closed before WebSocket handshake"
                    raise ConnectionError(msg)
                protocol.receive_data(data)
                # A malformed request never yields an event; fail now rather
                # than waiting for the timeout.
                if protocol.handshake_exc is not None:
                    raise protocol.handshake_exc
                events = protocol.events_received()

        request = events[0]
        if not isinstance(request, Request):
            msg = f"Expected HTTP request, got {type(request)}"
            raise ProtocolError(msg)

        response = protocol.accept(request)
        protocol.send_response(response)
        await cls._flush_protocol(protocol, writer)
        # accept() may have produced an error response (e.g. 4xx); it has
        # been sent above, now surface the reason.
        if protocol.handshake_exc is not None:
            raise protocol.handshake_exc

        _log.debug("Haystack WebSocket server accepted connection")
        ws = cls(reader, writer, protocol)
        ws._stash_frames(events[1:])
        return ws

    @staticmethod
    async def _flush_protocol(
        protocol: ClientProtocol | ServerProtocol,
        writer: StreamWriter,
    ) -> None:
        """Write everything *protocol* has queued and wait for it to drain."""
        outgoing = _drain_to_send(protocol)
        if outgoing:
            writer.write(outgoing)
            await writer.drain()

    # -- I/O operations --

    async def send_text(self, text: str) -> None:
        """Send a text WebSocket frame.

        :param text: Text payload to send.
        """
        self._protocol.send_text(text.encode())
        await self._send_queued()

    async def send_text_preencoded(self, data: bytes) -> None:
        """Send pre-encoded UTF-8 bytes as a text WebSocket frame.

        Avoids the ``str → encode → bytes`` roundtrip when the caller already
        holds a UTF-8 byte string (e.g. from :func:`orjson.dumps`).

        :param data: UTF-8 encoded payload bytes.
        """
        self._protocol.send_text(data)
        await self._send_queued()

    async def send_bytes(self, data: bytes) -> None:
        """Send a binary WebSocket frame.

        :param data: Binary payload to send.
        """
        self._protocol.send_binary(data)
        await self._send_queued()

    async def ping(self, data: bytes = b"") -> None:
        """Send a WebSocket ping frame.

        :param data: Optional ping payload.
        """
        self._protocol.send_ping(data)
        await self._send_queued()

    async def _send_queued(self) -> None:
        """Write pending protocol output and wait for the buffer to drain."""
        if _write_pending(self._protocol, self._writer):
            await self._writer.drain()

    async def recv(self) -> str | bytes:
        """Receive the next complete WebSocket message payload.

        Fragmented messages are reassembled before being returned.

        :returns: ``str`` for text messages, ``bytes`` for binary messages.
        :raises ConnectionClosedOK: On a close frame with code 1000, 1001 or
            no status code.
        :raises ConnectionClosedError: On any other close code, or when the
            TCP stream ends without a close frame.
        """
        while True:
            # Frames left over from an earlier read (or from the handshake).
            while self._pending_frames:
                result = self._handle_frame(self._pending_frames.popleft())
                if result is not None:
                    await self._flush_outgoing()
                    return result

            # Frames parsed by the protocol but not yet consumed.
            events = self._protocol.events_received()
            for i, raw_event in enumerate(events):
                if isinstance(raw_event, Frame):
                    result = self._handle_frame(raw_event)
                    if result is not None:
                        self._stash_frames(events[i + 1 :])
                        await self._flush_outgoing()
                        return result

            # Need more data from the network.
            data = await self._reader.read(_READ_SIZE)
            if not data:
                # Let the protocol observe EOF so its state (and is_open)
                # reflects that the connection is gone.
                with contextlib.suppress(EOFError):
                    self._protocol.receive_eof()
                raise ConnectionClosedError(None, None, rcvd_then_sent=None)
            self._protocol.receive_data(data)
            if self._protocol.handshake_exc is not None:
                raise self._protocol.handshake_exc
            # Send anything the protocol queued in response (pongs, close reply).
            await self._flush_outgoing()

    def _handle_frame(self, frame: Frame) -> str | bytes | None:
        """Process a single frame.

        :returns: The message payload when *frame* completes a text (``str``)
            or binary (``bytes``) message; ``None`` for control frames and
            for non-final fragments, which are buffered.
        :raises ConnectionClosedOK: On a close frame with a graceful code.
        :raises ConnectionClosedError: On a close frame with any other code.
        """
        opcode = frame.opcode
        if opcode == Opcode.TEXT or opcode == Opcode.BINARY:
            payload = bytes(frame.data)
            if not frame.fin:
                # First fragment of a multi-frame message: buffer until FIN.
                self._frag_opcode = opcode
                self._frag_chunks = [payload]
                return None
        elif opcode == Opcode.CONT:
            if self._frag_opcode is None:
                # NOTE(review): a continuation without a started message is
                # expected to be rejected by the protocol layer; ignore it.
                return None
            self._frag_chunks.append(bytes(frame.data))
            if not frame.fin:
                return None
            opcode = self._frag_opcode
            payload = b"".join(self._frag_chunks)
            self._frag_opcode = None
            self._frag_chunks = []
        elif opcode == Opcode.CLOSE:
            rcvd = Close.parse(bytes(frame.data)) if frame.data else None
            code = 1005 if rcvd is None else rcvd.code
            exc_type = ConnectionClosedOK if code in self._OK_CLOSE_CODES else ConnectionClosedError
            raise exc_type(
                rcvd,
                self._protocol.close_sent,
                rcvd_then_sent=self._protocol.close_rcvd_then_sent,
            )
        else:
            # PING/PONG: the protocol queues any reply itself; recv() flushes.
            return None
        return payload.decode() if opcode == Opcode.TEXT else payload

    def _stash_frames(self, events: Sequence[object]) -> None:
        """Queue every :class:`Frame` found in *events* for later recv() calls."""
        self._pending_frames.extend(event for event in events if isinstance(event, Frame))

    async def close(self, code: int = 1000, reason: str = "", *, timeout: float = 5.0) -> None:
        """Initiate a graceful WebSocket close, then close the TCP connection.

        Sends a close frame and waits up to *timeout* seconds for one read of
        the peer's reply.  Network errors, and ``InvalidState`` raised by the
        protocol's ``send_close``, are ignored; the transport is always closed.

        :param code: WebSocket close code (default 1000 = normal).
        :param reason: Optional close reason string.
        :param timeout: Seconds to wait for the peer's reply (default 5).
        """
        try:
            self._protocol.send_close(code, reason)
            await self._flush_protocol(self._protocol, self._writer)
            with contextlib.suppress(TimeoutError, OSError, ConnectionError):
                async with asyncio.timeout(timeout):
                    data = await self._reader.read(_READ_SIZE)
                if data:
                    self._protocol.receive_data(data)
        except (OSError, ConnectionError, InvalidState):
            pass
        finally:
            await self._close_transport()

    async def _flush_outgoing(self) -> None:
        """Write any pending protocol output (e.g. PONG replies), ignoring I/O errors."""
        if _write_pending(self._protocol, self._writer):
            with contextlib.suppress(OSError, ConnectionError):
                await self._writer.drain()

    async def _close_transport(self) -> None:
        """Close the underlying TCP connection and wait for it to complete."""
        try:
            if not self._writer.is_closing():
                self._writer.close()
            with contextlib.suppress(OSError, ConnectionError):
                await self._writer.wait_closed()
        except (OSError, RuntimeError):
            pass

    @property
    def is_open(self) -> bool:
        """``True`` if the WebSocket connection appears open."""
        return self._protocol.state is _WSState.OPEN

    @property
    def subprotocol(self) -> str | None:
        """Return the negotiated WebSocket subprotocol."""
        return self._protocol.subprotocol