# Source code for hs_py.ws_server

"""Haystack WebSocket server.

Accepts WebSocket connections using the ``websockets`` sans-I/O layer
(via :class:`~hs_py.ws.HaystackWebSocket`) and dispatches operations to a
:class:`~hs_py.server.HaystackOps` implementation.

Uses ``asyncio.start_server`` for full control over TLS context, consistent
with bac-py patterns and independent of aiohttp.
"""

from __future__ import annotations

import asyncio
import contextlib
import hmac as hmac_mod
import logging
import time
import weakref
from typing import TYPE_CHECKING, Any

import orjson
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK

from hs_py._scram_core import (
    HandshakeState,
    TokenEntry,
    handle_scram,
    scram_hello,
)
from hs_py.encoding.json import encode_grid as encode_grid_json
from hs_py.errors import HaystackError
from hs_py.grid import Grid
from hs_py.metrics import MetricsHooks, _fire
from hs_py.ops import HaystackOps, dispatch_op
from hs_py.tls import TLSConfig, build_server_ssl_context
from hs_py.ws import HaystackWebSocket, cancel_task, heartbeat_loop
from hs_py.ws_codec import (
    CHUNK_SIZE,
    CHUNK_THRESHOLD,
    FLAG_CHUNKED,
    FLAG_ERROR,
    FLAG_PUSH,
    FLAG_RESPONSE,
    ChunkAssembler,
    decode_binary_frame,
    encode_binary_push,
    encode_binary_response,
    encode_chunked_frames,
)

if TYPE_CHECKING:
    from asyncio import StreamReader, StreamWriter

    from hs_py.auth_types import Authenticator, CertAuthenticator

# Public API of this module: only the server class is exported.
__all__ = [
    "WebSocketServer",
]

# Module-level logger, named after this module.
_log = logging.getLogger(__name__)

# Maximum concurrent connections to prevent resource exhaustion.
# Checked in WebSocketServer._handle_client before the WebSocket handshake;
# connections beyond this limit are closed immediately.
_MAX_CONNECTIONS = 1000

# Maximum WebSocket message size (10 MB) to prevent memory exhaustion.
# Passed as ``max_size`` to HaystackWebSocket.accept.
_MAX_WS_MESSAGE_SIZE = 10 * 1024 * 1024

# Maximum number of items in a batch request.
# Longer batches are truncated to this many items before dispatch.
_MAX_BATCH_SIZE = 1000

# Maximum entries in the response cache.
# Once the cache holds this many entries, new results are not stored
# (there is no eviction; the cache is only emptied by mutation ops).
_MAX_CACHE_SIZE = 2048

# Ops that mutate state and should invalidate read caches.
_MUTATION_OPS = frozenset({"hisWrite", "pointWrite", "invokeAction"})


def _ws_envelope(grid_bytes: bytes, req_id: Any = None, ch: Any = None) -> bytes:
    """Build a JSON envelope around pre-encoded grid bytes."""
    parts = [b'{"grid":', grid_bytes]
    if req_id is not None:
        parts.append(b',"id":')
        parts.append(orjson.dumps(req_id))
    if ch is not None:
        parts.append(b',"ch":')
        parts.append(orjson.dumps(ch))
    parts.append(b"}")
    return b"".join(parts)


class WebSocketServer:
    """Haystack WebSocket server using websockets sans-I/O.

    Usage::

        ops = MyHaystackOps()
        server = WebSocketServer(ops, host="0.0.0.0", port=8080)
        await server.start()
        # ... server is running ...
        await server.stop()
    """

    def __init__(
        self,
        ops: HaystackOps,
        *,
        auth_token: str = "",
        authenticator: Authenticator | None = None,
        tls: TLSConfig | None = None,
        host: str = "0.0.0.0",
        port: int = 8080,
        heartbeat: float = 30.0,
        metrics: MetricsHooks | None = None,
        cert_auth: CertAuthenticator | None = None,
        compression: bool = False,
        binary: bool = False,
        user_store: Any = None,
        binary_compression: int | None = None,
        chunked: bool = False,
    ) -> None:
        """Initialise the WebSocket server.

        :param ops: :class:`~hs_py.server.HaystackOps` implementation to dispatch to.
        :param auth_token: Expected bearer token from clients (empty to skip auth).
        :param authenticator: Optional :class:`~hs_py.auth_types.Authenticator`
            for SCRAM-SHA-256 authentication over WebSocket messages. When
            provided, clients perform a SCRAM handshake after connecting.
            Takes precedence over *auth_token* when both are given.
        :param tls: Optional :class:`~hs_py.tls.TLSConfig` for TLS 1.3.
        :param host: Bind address.
        :param port: Bind port (0 for OS-assigned).
        :param heartbeat: Ping interval in seconds (0 to disable).
        :param metrics: Optional :class:`~hs_py.metrics.MetricsHooks` callbacks.
        :param cert_auth: Optional :class:`~hs_py.server.CertAuthenticator` for mTLS.
        :param compression: Enable per-message deflate compression.
        :param binary: Use binary frame encoding instead of JSON envelopes.
        :param user_store: Optional user store for role-based authorization checks.
        :param binary_compression: Codec-level compression algorithm for binary
            frames (e.g. ``COMP_ZLIB``). ``None`` disables codec compression.
        :param chunked: Enable chunked transfer for large binary payloads.
        """
        self._ops = ops
        self._auth_token = auth_token
        self._authenticator = authenticator
        self._tls = tls
        self._cert_auth = cert_auth
        self._binary = binary
        self._host = host
        self._port = port
        self._heartbeat = heartbeat
        self._metrics = metrics or MetricsHooks()
        self._compression = compression
        self._binary_compression = binary_compression
        self._chunked = chunked
        self._user_store = user_store
        self._server: asyncio.Server | None = None
        # Connections that passed authentication; these receive watch pushes.
        self._connections: weakref.WeakSet[HaystackWebSocket] = weakref.WeakSet()
        # Accepted sockets still authenticating. Kept separate so push_watch
        # never sends data to a peer that has not authenticated yet, while
        # stop() can still close them.
        self._pending: weakref.WeakSet[HaystackWebSocket] = weakref.WeakSet()
        self._connection_count = 0
        # SCRAM state shared across connections
        self._handshakes: dict[str, HandshakeState] = {}
        self._tokens: dict[str, TokenEntry] = {}
        # Response cache for read ops (grid bytes keyed by filter+limit)
        self._ws_grid_cache: dict[str, bytes] = {}
        # Wire push handler so ops can trigger watch pushes
        self._ops.set_push_handler(self.push_watch)

    async def start(self) -> None:
        """Start the WebSocket server."""
        ssl_ctx = build_server_ssl_context(self._tls) if self._tls else None
        self._server = await asyncio.start_server(
            self._handle_client,
            self._host,
            self._port,
            ssl=ssl_ctx,
        )
        addr = self._server.sockets[0].getsockname() if self._server.sockets else ("?", "?")
        _log.info("Haystack WebSocket server listening on %s:%s", addr[0], addr[1])

    async def stop(self) -> None:
        """Stop the server and close all connections.

        The listening socket is closed first, then every tracked connection,
        and only then is ``wait_closed()`` awaited. On Python 3.12+
        ``Server.wait_closed()`` waits for active connections to finish, so
        awaiting it while clients are still open can block indefinitely.
        """
        server = self._server
        if server is not None:
            server.close()  # stop accepting new connections
        # Snapshot: the sets may shrink while we await each close().
        for ws in [*self._connections, *self._pending]:
            with contextlib.suppress(Exception):
                await ws.close()
        if server is not None:
            await server.wait_closed()
            self._server = None
        self._connection_count = 0

    @property
    def port(self) -> int:
        """Return the bound port (useful when bound to port 0)."""
        if self._server and self._server.sockets:
            return int(self._server.sockets[0].getsockname()[1])
        return self._port

    # ---- Watch push --------------------------------------------------------

    async def push_watch(self, watch_id: str, grid: Grid) -> None:
        """Push a watch change notification to all authenticated clients.

        Send failures on one connection are suppressed so they do not
        prevent delivery to the remaining connections.

        :param watch_id: The watch identifier.
        :param grid: Grid of changed entities.
        """
        connections = set(self._connections)
        if self._binary:
            frames: list[bytes] | None = None
            # Check if chunking is needed for large push payloads
            if self._chunked:
                payload = encode_grid_json(grid)
                if len(payload) > CHUNK_THRESHOLD:
                    frames = encode_chunked_frames(
                        FLAG_PUSH,
                        0,
                        "watchPoll",
                        payload,
                        compression=self._binary_compression,
                        chunk_size=CHUNK_SIZE,
                    )
            if frames is None:
                frames = [
                    encode_binary_push("watchPoll", grid, compression=self._binary_compression)
                ]
            for ws in connections:
                with contextlib.suppress(Exception):
                    for frame in frames:
                        await ws.send_bytes(frame)
            return

        grid_bytes = encode_grid_json(grid)
        wid_bytes = orjson.dumps(watch_id)
        payload = b'{"type":"watch","watchId":' + wid_bytes + b',"grid":' + grid_bytes + b"}"
        for ws in connections:
            with contextlib.suppress(Exception):
                await ws.send_text_preencoded(payload)

    # ---- Response cache ------------------------------------------------------

    @staticmethod
    def _read_cache_key(msg: dict[str, Any]) -> str | None:
        """Return the cache key for a filter-based read request, else ``None``.

        Only requests whose first grid row carries a non-empty string
        ``filter`` are cacheable. Requests without a filter would all map to
        the same key and receive each other's results.
        """
        grid_data = msg.get("grid")
        if not isinstance(grid_data, dict):
            return None
        rows = grid_data.get("rows")
        if not isinstance(rows, list) or not rows or not isinstance(rows[0], dict):
            return None
        filt = rows[0].get("filter", "")
        if not isinstance(filt, str) or not filt:
            return None
        limit = rows[0].get("limit", "")
        return f"ws_read:{filt}:{limit}"

    def _cached_grid_bytes(
        self,
        op: str,
        msg: dict[str, Any],
        grid: Grid,
    ) -> bytes:
        """Return cached grid bytes for read ops, encode otherwise.

        Error grids are never served from nor stored in the cache: a failed
        read (e.g. a permission error) must not return a cached success, and
        a transient failure must not be replayed to later requests.

        NOTE(review): on a cache hit the freshly computed *grid* is discarded
        in favour of the cached bytes, so results can be stale until a
        mutation op arrives over this server.

        :param op: Operation name.
        :param msg: Request envelope the op was dispatched with.
        :param grid: Result grid produced for this request.
        :returns: JSON-encoded grid bytes.
        """
        # Invalidate cache on mutation ops
        if op in _MUTATION_OPS and self._ws_grid_cache:
            self._ws_grid_cache.clear()

        if op != "read" or grid.is_error:
            return encode_grid_json(grid)

        key = self._read_cache_key(msg)
        if key is None:
            return encode_grid_json(grid)

        cached = self._ws_grid_cache.get(key)
        if cached is not None:
            return cached
        grid_bytes = encode_grid_json(grid)
        if len(self._ws_grid_cache) < _MAX_CACHE_SIZE:
            self._ws_grid_cache[key] = grid_bytes
        return grid_bytes

    # ---- Connection handling -----------------------------------------------

    async def _handle_client(self, reader: StreamReader, writer: StreamWriter) -> None:
        """Per-connection handler: accept WS, authenticate, dispatch ops."""
        if self._connection_count >= _MAX_CONNECTIONS:
            _log.warning("Max connections reached (%d), rejecting", _MAX_CONNECTIONS)
            writer.close()
            with contextlib.suppress(OSError, ConnectionError):
                await writer.wait_closed()
            return

        self._connection_count += 1
        ws: HaystackWebSocket | None = None
        hb_task: asyncio.Task[None] | None = None
        remote: str = "?"
        try:
            ws = await HaystackWebSocket.accept(
                reader, writer, compression=self._compression, max_size=_MAX_WS_MESSAGE_SIZE
            )
            # Track as pending so stop() can close it, but keep it out of the
            # push target set until authentication succeeds.
            self._pending.add(ws)
            remote = writer.get_extra_info("peername", ("?",))[0]
            _fire(self._metrics.on_ws_connect, str(remote))

            authenticated, ws_username = await self._authenticate(ws, writer)
            if not authenticated:
                return
            self._pending.discard(ws)
            self._connections.add(ws)

            # Start heartbeat and message dispatch
            if self._heartbeat > 0:
                hb_task = asyncio.create_task(
                    heartbeat_loop(ws, self._heartbeat), name="hs-ws-srv-heartbeat"
                )
            await self._message_loop(ws, ws_username)
        except (ConnectionClosedOK, ConnectionClosedError):
            pass
        except (TimeoutError, ConnectionError, OSError):
            _log.debug("WebSocket connection error")
        except Exception:
            _log.exception("Unexpected error in WebSocket connection handler")
        finally:
            await cancel_task(hb_task)
            # stop() resets the counter to zero; clamp so handlers that finish
            # afterwards cannot drive it negative.
            self._connection_count = max(0, self._connection_count - 1)
            if ws is not None:
                self._pending.discard(ws)
                self._connections.discard(ws)
                _fire(self._metrics.on_ws_disconnect, str(remote))
                with contextlib.suppress(Exception):
                    await ws.close()

    async def _authenticate(
        self, ws: HaystackWebSocket, writer: StreamWriter
    ) -> tuple[bool, str | None]:
        """Authenticate a freshly accepted connection.

        Order: client certificate (when *cert_auth* is configured), then
        SCRAM (when an authenticator is configured), then bearer token (when
        *auth_token* is set). With none configured the connection is accepted
        without a username.

        :returns: ``(authenticated, username)``. *username* is ``None`` for
            token auth and for unauthenticated servers.
        """
        cert_auth = self._cert_auth
        if cert_auth is not None:
            username = cert_auth.authorize(writer.get_extra_info("peercert"))
            if username is not None:
                _log.debug("Client authenticated via certificate CN=%s", username)
                return True, username
            # Certificate not accepted: fall through to SCRAM / token if configured.

        if self._authenticator is not None:
            scram_user = await self._scram_authenticate(ws)
            return scram_user is not None, scram_user

        if self._auth_token:
            return await self._token_authenticate(ws), None

        if cert_auth is not None:
            _log.warning("Client certificate not authorized")
            return False, None
        return True, None

    @staticmethod
    def _token_matches(candidate: Any, expected: str) -> bool:
        """Constant-time comparison of a client-supplied token.

        Compares UTF-8 bytes because ``hmac.compare_digest`` only accepts
        ASCII-only ``str`` arguments. Non-string or empty candidates, and an
        empty *expected*, never match.
        """
        if not expected or not isinstance(candidate, str) or not candidate:
            return False
        return hmac_mod.compare_digest(candidate.encode(), expected.encode())

    @staticmethod
    async def _send_auth_err(ws: HaystackWebSocket) -> None:
        """Send the ``authErr`` message to the client."""
        await ws.send_text(orjson.dumps({"type": "authErr"}).decode())

    async def _token_authenticate(self, ws: HaystackWebSocket) -> bool:
        """Read and validate a bearer token message.

        Waits up to 10 seconds for the message. Any failure (timeout, closed
        connection, malformed message) counts as a failed authentication.

        Return ``True`` if valid.
        """
        try:
            async with asyncio.timeout(10.0):
                data = await ws.recv()
            msg = orjson.loads(data)
            token = msg.get("authToken", "")
            if self._token_matches(token, self._auth_token):
                return True
            _log.warning("WebSocket token auth failed")
            return False
        except Exception:
            _log.warning("WebSocket token auth error")
            return False

    async def _scram_authenticate(self, ws: HaystackWebSocket) -> str | None:
        """Perform SCRAM-SHA-256 handshake over WebSocket messages.

        Message flow::

            Client → {"type":"hello","username":"<b64>"}
            Server → {"type":"hello","handshakeToken":"...","hash":"SHA-256"}
            Client → {"type":"scram","handshakeToken":"...","data":"<b64>"}
            Server → {"type":"scram","handshakeToken":"...","data":"<b64>"}
            Client → {"type":"scram","handshakeToken":"...","data":"<b64>"}
            Server → {"type":"authOk","authToken":"...","data":"<b64>"}

        The whole exchange must complete within 30 seconds. A first message
        carrying ``authToken`` is treated as legacy bearer-token auth.

        Returns the authenticated username on success (empty string when no
        username is known), or ``None`` on failure.
        """
        assert self._authenticator is not None
        try:
            # Function-local import, as in the original handshake code.
            from hs_py.auth import _parse_header_params

            async with asyncio.timeout(30.0):
                # Step 1: HELLO
                data = await ws.recv()
                msg = orjson.loads(data)

                # Also accept legacy token auth during SCRAM mode
                if "authToken" in msg:
                    if self._token_matches(msg["authToken"], self._auth_token):
                        return ""  # token auth has no username
                    _log.warning("WebSocket token auth failed (SCRAM mode)")
                    await self._send_auth_err(ws)
                    return None

                if msg.get("type") != "hello":
                    _log.warning("Expected hello message, got: %s", msg.get("type"))
                    await self._send_auth_err(ws)
                    return None

                auth_header = f"HELLO username={msg.get('username', '')}"
                result = await scram_hello(self._authenticator, self._handshakes, auth_header)
                if result.status != 401 or "handshakeToken" not in result.headers.get(
                    "WWW-Authenticate", ""
                ):
                    await self._send_auth_err(ws)
                    return None

                hello_resp_params = _parse_header_params(result.headers["WWW-Authenticate"])
                await ws.send_text(
                    orjson.dumps(
                        {
                            "type": "hello",
                            "handshakeToken": hello_resp_params.get("handshakeToken", ""),
                            "hash": hello_resp_params.get("hash", "SHA-256"),
                        }
                    ).decode()
                )

                # Step 2: client-first → server-first
                data = await ws.recv()
                msg = orjson.loads(data)
                if msg.get("type") != "scram":
                    await self._send_auth_err(ws)
                    return None

                ht = msg.get("handshakeToken", "")
                scram_data = msg.get("data", "")
                auth_header = f"SCRAM handshakeToken={ht}, data={scram_data}"
                result = handle_scram(self._handshakes, self._tokens, auth_header)
                if result.status != 401:
                    await self._send_auth_err(ws)
                    return None

                www_auth = result.headers.get("WWW-Authenticate", "")
                resp_params = _parse_header_params(www_auth)
                await ws.send_text(
                    orjson.dumps(
                        {
                            "type": "scram",
                            "handshakeToken": resp_params.get("handshakeToken", ""),
                            "hash": resp_params.get("hash", "SHA-256"),
                            "data": resp_params.get("data", ""),
                        }
                    ).decode()
                )

                # Step 3: client-final → server-final + authToken
                data = await ws.recv()
                msg = orjson.loads(data)
                if msg.get("type") != "scram":
                    await self._send_auth_err(ws)
                    return None

                ht = msg.get("handshakeToken", "")
                scram_data = msg.get("data", "")
                auth_header = f"SCRAM handshakeToken={ht}, data={scram_data}"
                result = handle_scram(self._handshakes, self._tokens, auth_header)
                if result.status != 200:
                    await self._send_auth_err(ws)
                    return None

                auth_info = result.headers.get("Authentication-Info", "")
                resp_params = _parse_header_params(auth_info)
                auth_token = resp_params.get("authToken", "")
                await ws.send_text(
                    orjson.dumps(
                        {
                            "type": "authOk",
                            "authToken": auth_token,
                            "data": resp_params.get("data", ""),
                        }
                    ).decode()
                )
                _log.debug("WebSocket SCRAM auth succeeded")
                # Look up username from the issued token
                token_entry = self._tokens.get(auth_token)
                return token_entry.username if token_entry else ""
        except Exception:
            # Any failure (timeout, closed socket, malformed message) is an
            # authentication failure.
            _log.warning("WebSocket SCRAM auth error")
            return None

    async def _check_permission(self, username: str | None, op_name: str) -> None:
        """Check role permissions for a WebSocket op.

        Does nothing when no user store is configured.

        :raises HaystackError: If the user is unknown, disabled, or lacks the
            required role.
        """
        if self._user_store is None:
            return
        from hs_py.user import WRITE_OPS, Role

        if username is None:
            raise HaystackError("Authentication required")
        user = await self._user_store.get_user(username)
        if user is None or not user.enabled:
            raise HaystackError("Authentication required")
        if op_name in WRITE_OPS and user.role < Role.OPERATOR:
            raise HaystackError(
                f"Insufficient permissions: {op_name} requires operator or admin role"
            )

    # ---- Message dispatch --------------------------------------------------

    async def _message_loop(self, ws: HaystackWebSocket, username: str | None = None) -> None:
        """Read request messages and dispatch to HaystackOps.

        Runs until ``ws.recv()`` raises. Malformed messages (invalid JSON,
        invalid UTF-8, or JSON that is neither an object nor an array) are
        logged and ignored rather than terminating the connection.
        """
        chunk_assembler = ChunkAssembler() if self._chunked else None
        last_cleanup = time.monotonic()

        while True:
            data = await ws.recv()
            _fire(self._metrics.on_ws_message_recv, "", len(data))

            # Periodic cleanup of orphaned chunk buffers
            if chunk_assembler is not None:
                now = time.monotonic()
                if now - last_cleanup > 30.0:
                    chunk_assembler.cleanup(now)
                    last_cleanup = now

            # Binary frame handling — only bytes from binary WS frames
            if isinstance(data, bytes) and self._binary and len(data) >= 4:
                await self._handle_binary_message(ws, data, username, chunk_assembler)
                continue

            # JSON handling: orjson accepts str and bytes directly and reports
            # invalid UTF-8 as JSONDecodeError, so no separate decode step.
            try:
                msg = orjson.loads(data)
            except orjson.JSONDecodeError:
                _log.warning("Non-JSON WebSocket message, ignoring")
                continue

            if isinstance(msg, list):
                # Batch: JSON array of envelopes
                await self._handle_batch(ws, msg, username)
            elif isinstance(msg, dict):
                await self._handle_json_message(ws, msg, username)
            else:
                _log.warning("Non-object WebSocket message, ignoring")

    async def _handle_capabilities(self, ws: HaystackWebSocket, msg: dict[str, Any]) -> None:
        """Respond to a capabilities negotiation message.

        The client sends its supported features and the server responds with
        the intersection of what it supports. Malformed client values are
        treated as "not supported".
        """
        client_compression = msg.get("compression")
        if not isinstance(client_compression, list):
            client_compression = []
        client_chunked = bool(msg.get("chunked", False))

        # Determine agreed compression algorithm
        server_algos: list[str] = []
        if self._binary_compression is not None:
            from hs_py.ws_codec import COMP_LZMA, COMP_ZLIB

            algo_names = {COMP_ZLIB: "zlib", COMP_LZMA: "lzma"}
            name = algo_names.get(self._binary_compression)
            if name:
                server_algos.append(name)

        # First client-preferred algorithm the server also supports.
        agreed_compression: str | None = next(
            (algo for algo in client_compression if algo in server_algos), None
        )
        agreed_chunked = client_chunked and self._chunked

        response = {
            "type": "capabilities",
            "compression": agreed_compression,
            "chunked": agreed_chunked,
            "version": 2,
        }
        await ws.send_text_preencoded(orjson.dumps(response))
        _log.debug(
            "Capabilities negotiated: compression=%s chunked=%s",
            agreed_compression,
            agreed_chunked,
        )

    async def _handle_json_message(
        self, ws: HaystackWebSocket, msg: dict[str, Any], username: str | None = None
    ) -> None:
        """Dispatch a single JSON request envelope.

        Failures in the op are converted to an error grid in the response.
        """
        # Handle capabilities negotiation
        if msg.get("type") == "capabilities":
            await self._handle_capabilities(ws, msg)
            return

        req_id = msg.get("id")
        op = msg.get("op", "")
        ch = msg.get("ch")
        _fire(self._metrics.on_ws_message_recv, op, 0)

        try:
            await self._check_permission(username, op)
            result_grid = await dispatch_op(self._ops, op, msg)
        except HaystackError as exc:
            _fire(self._metrics.on_error, op, type(exc).__name__)
            result_grid = Grid.make_error(str(exc))
        except Exception as exc:
            _log.exception("Unhandled error in op '%s'", op)
            _fire(self._metrics.on_error, op, type(exc).__name__)
            result_grid = Grid.make_error("Internal server error")

        grid_bytes = self._cached_grid_bytes(op, msg, result_grid)
        payload = _ws_envelope(grid_bytes, req_id, ch)
        await ws.send_text_preencoded(payload)
        _fire(self._metrics.on_ws_message_sent, op, len(payload))

    async def _handle_binary_message(
        self,
        ws: HaystackWebSocket,
        data: bytes,
        username: str | None = None,
        chunk_assembler: ChunkAssembler | None = None,
    ) -> None:
        """Dispatch a binary frame request.

        Invalid frames, inbound push frames, and chunked frames received
        while chunking is disabled are ignored. Chunked frames are
        accumulated in *chunk_assembler* until the request is complete.
        """
        try:
            flags, req_id, op, grid_bytes = decode_binary_frame(data)
        except ValueError:
            _log.warning("Invalid binary frame, ignoring")
            return

        if flags & FLAG_PUSH:
            return  # Server doesn't process inbound pushes

        # Chunked frame — accumulate until complete
        if flags & FLAG_CHUNKED:
            if chunk_assembler is None:
                _log.warning("Received chunked frame but chunking is not enabled, ignoring")
                return
            assembled = chunk_assembler.feed(flags, req_id, op, grid_bytes)
            if assembled is None:
                return  # waiting for more chunks
            grid_bytes = assembled

        _fire(self._metrics.on_ws_message_recv, op, len(data))

        try:
            msg: dict[str, Any] = {"op": op}
            if grid_bytes:
                msg["grid"] = orjson.loads(grid_bytes)
            await self._check_permission(username, op)
            result_grid = await dispatch_op(self._ops, op, msg)
        except HaystackError as exc:
            _fire(self._metrics.on_error, op, type(exc).__name__)
            result_grid = Grid.make_error(str(exc))
        except Exception as exc:
            _log.exception("Unhandled error in binary op '%s'", op)
            _fire(self._metrics.on_error, op, type(exc).__name__)
            result_grid = Grid.make_error("Internal server error")

        # The JSON payload is only needed to decide on / perform chunking, so
        # skip that encode entirely when chunking is disabled.
        if self._chunked:
            payload = encode_grid_json(result_grid)
            # Chunk based on raw payload size (before compression)
            if len(payload) > CHUNK_THRESHOLD:
                resp_flags = FLAG_RESPONSE
                if result_grid.is_error:
                    resp_flags |= FLAG_ERROR
                frames = encode_chunked_frames(
                    resp_flags,
                    req_id,
                    op,
                    payload,
                    compression=self._binary_compression,
                    chunk_size=CHUNK_SIZE,
                )
                for frame in frames:
                    await ws.send_bytes(frame)
                _fire(self._metrics.on_ws_message_sent, op, sum(len(f) for f in frames))
                return

        response = encode_binary_response(
            req_id,
            op,
            result_grid,
            is_error=result_grid.is_error,
            compression=self._binary_compression,
        )
        await ws.send_bytes(response)
        _fire(self._metrics.on_ws_message_sent, op, len(response))

    async def _handle_batch(
        self, ws: HaystackWebSocket, batch: list[Any], username: str | None = None
    ) -> None:
        """Dispatch a batch of JSON request envelopes concurrently.

        Non-object items are dropped and the batch is truncated to
        ``_MAX_BATCH_SIZE`` items. The responses are sent as one JSON array
        in request order.
        """
        items = [item for item in batch if isinstance(item, dict)]
        if not items:
            return

        # Cap batch size to prevent resource exhaustion
        if len(items) > _MAX_BATCH_SIZE:
            items = items[:_MAX_BATCH_SIZE]

        async def _dispatch_item(item: dict[str, Any]) -> bytes:
            r_id = item.get("id")
            r_op = item.get("op", "")
            try:
                await self._check_permission(username, r_op)
                r_grid = await dispatch_op(self._ops, r_op, item)
            except HaystackError as exc:
                _fire(self._metrics.on_error, r_op, type(exc).__name__)
                r_grid = Grid.make_error(str(exc))
            except Exception:
                _log.exception("Unhandled error in batch op '%s'", r_op)
                _fire(self._metrics.on_error, r_op, "InternalError")
                r_grid = Grid.make_error("Internal server error")
            grid_bytes = self._cached_grid_bytes(r_op, item, r_grid)
            return _ws_envelope(grid_bytes, r_id)

        item_bytes = await asyncio.gather(*[_dispatch_item(item) for item in items])
        payload = b"[" + b",".join(item_bytes) + b"]"
        await ws.send_text_preencoded(payload)
        _fire(self._metrics.on_ws_message_sent, "batch", len(payload))