"""Haystack WebSocket server.
Accepts WebSocket connections using the ``websockets`` sans-I/O layer
(via :class:`~hs_py.ws.HaystackWebSocket`) and dispatches operations to a
:class:`~hs_py.server.HaystackOps` implementation.
Uses ``asyncio.start_server`` for full control over TLS context, consistent
with bac-py patterns and independent of aiohttp.
"""
from __future__ import annotations
import asyncio
import contextlib
import hmac as hmac_mod
import logging
import time
import weakref
from typing import TYPE_CHECKING, Any
import orjson
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from hs_py._scram_core import (
HandshakeState,
TokenEntry,
handle_scram,
scram_hello,
)
from hs_py.encoding.json import encode_grid as encode_grid_json
from hs_py.errors import HaystackError
from hs_py.grid import Grid
from hs_py.metrics import MetricsHooks, _fire
from hs_py.ops import HaystackOps, dispatch_op
from hs_py.tls import TLSConfig, build_server_ssl_context
from hs_py.ws import HaystackWebSocket, cancel_task, heartbeat_loop
from hs_py.ws_codec import (
CHUNK_SIZE,
CHUNK_THRESHOLD,
FLAG_CHUNKED,
FLAG_ERROR,
FLAG_PUSH,
FLAG_RESPONSE,
ChunkAssembler,
decode_binary_frame,
encode_binary_push,
encode_binary_response,
encode_chunked_frames,
)
if TYPE_CHECKING:
from asyncio import StreamReader, StreamWriter
from hs_py.auth_types import Authenticator, CertAuthenticator
__all__ = [
"WebSocketServer",
]
_log = logging.getLogger(__name__)
# Maximum concurrent connections to prevent resource exhaustion.
_MAX_CONNECTIONS = 1000
# Maximum WebSocket message size (10 MB) to prevent memory exhaustion.
_MAX_WS_MESSAGE_SIZE = 10 * 1024 * 1024
# Maximum number of items in a batch request.
_MAX_BATCH_SIZE = 1000
# Maximum entries in the response cache.
_MAX_CACHE_SIZE = 2048
# Ops that mutate state and should invalidate read caches.
_MUTATION_OPS = frozenset({"hisWrite", "pointWrite", "invokeAction"})
def _ws_envelope(grid_bytes: bytes, req_id: Any = None, ch: Any = None) -> bytes:
"""Build a JSON envelope around pre-encoded grid bytes."""
parts = [b'{"grid":', grid_bytes]
if req_id is not None:
parts.append(b',"id":')
parts.append(orjson.dumps(req_id))
if ch is not None:
parts.append(b',"ch":')
parts.append(orjson.dumps(ch))
parts.append(b"}")
return b"".join(parts)
[docs]
class WebSocketServer:
"""Haystack WebSocket server using websockets sans-I/O.
Usage::
ops = MyHaystackOps()
server = WebSocketServer(ops, host="0.0.0.0", port=8080)
await server.start()
# ... server is running ...
await server.stop()
"""
def __init__(
self,
ops: HaystackOps,
*,
auth_token: str = "",
authenticator: Authenticator | None = None,
tls: TLSConfig | None = None,
host: str = "0.0.0.0",
port: int = 8080,
heartbeat: float = 30.0,
metrics: MetricsHooks | None = None,
cert_auth: CertAuthenticator | None = None,
compression: bool = False,
binary: bool = False,
user_store: Any = None,
binary_compression: int | None = None,
chunked: bool = False,
) -> None:
"""Initialise the WebSocket server.
:param ops: :class:`~hs_py.server.HaystackOps` implementation to dispatch to.
:param auth_token: Expected bearer token from clients (empty to skip auth).
:param authenticator: Optional :class:`~hs_py.auth_types.Authenticator` for
SCRAM-SHA-256 authentication over WebSocket messages. When provided,
clients perform a SCRAM handshake after connecting. Takes precedence
over *auth_token* when both are given.
:param tls: Optional :class:`~hs_py.tls.TLSConfig` for TLS 1.3.
:param host: Bind address.
:param port: Bind port (0 for OS-assigned).
:param heartbeat: Ping interval in seconds (0 to disable).
:param metrics: Optional :class:`~hs_py.metrics.MetricsHooks` callbacks.
:param cert_auth: Optional :class:`~hs_py.server.CertAuthenticator` for mTLS.
:param compression: Enable per-message deflate compression.
:param binary: Use binary frame encoding instead of JSON envelopes.
:param user_store: Optional user store for role-based authorization checks.
:param binary_compression: Codec-level compression algorithm for binary frames
(e.g. ``COMP_ZLIB``). ``None`` disables codec compression.
:param chunked: Enable chunked transfer for large binary payloads.
"""
self._ops = ops
self._auth_token = auth_token
self._authenticator = authenticator
self._tls = tls
self._cert_auth = cert_auth
self._binary = binary
self._host = host
self._port = port
self._heartbeat = heartbeat
self._metrics = metrics or MetricsHooks()
self._compression = compression
self._binary_compression = binary_compression
self._chunked = chunked
self._user_store = user_store
self._server: asyncio.Server | None = None
self._connections: weakref.WeakSet[HaystackWebSocket] = weakref.WeakSet()
self._connection_count = 0
# SCRAM state shared across connections
self._handshakes: dict[str, HandshakeState] = {}
self._tokens: dict[str, TokenEntry] = {}
# Response cache for read ops (grid bytes keyed by filter+limit)
self._ws_grid_cache: dict[str, bytes] = {}
# Wire push handler so ops can trigger watch pushes
self._ops.set_push_handler(self.push_watch)
[docs]
async def start(self) -> None:
"""Start the WebSocket server."""
ssl_ctx = build_server_ssl_context(self._tls) if self._tls else None
self._server = await asyncio.start_server(
self._handle_client,
self._host,
self._port,
ssl=ssl_ctx,
)
addr = self._server.sockets[0].getsockname() if self._server.sockets else ("?", "?")
_log.info("Haystack WebSocket server listening on %s:%s", addr[0], addr[1])
[docs]
async def stop(self) -> None:
"""Stop the server and close all connections."""
if self._server is not None:
self._server.close()
await self._server.wait_closed()
self._server = None
# Close all tracked connections
for ws in set(self._connections):
with contextlib.suppress(Exception):
await ws.close()
self._connection_count = 0
@property
def port(self) -> int:
"""Return the bound port (useful when bound to port 0)."""
if self._server and self._server.sockets:
return int(self._server.sockets[0].getsockname()[1])
return self._port
# ---- Watch push --------------------------------------------------------
[docs]
async def push_watch(self, watch_id: str, grid: Grid) -> None:
"""Push a watch change notification to all connected clients.
:param watch_id: The watch identifier.
:param grid: Grid of changed entities.
"""
connections = set(self._connections)
if self._binary:
# Check if chunking is needed for large push payloads
if self._chunked:
payload = encode_grid_json(grid)
if len(payload) > CHUNK_THRESHOLD:
frames = encode_chunked_frames(
FLAG_PUSH,
0,
"watchPoll",
payload,
compression=self._binary_compression,
chunk_size=CHUNK_SIZE,
)
for ws in connections:
with contextlib.suppress(Exception):
for frame in frames:
await ws.send_bytes(frame)
return
frame = encode_binary_push("watchPoll", grid, compression=self._binary_compression)
for ws in connections:
with contextlib.suppress(Exception):
await ws.send_bytes(frame)
else:
grid_bytes = encode_grid_json(grid)
wid_bytes = orjson.dumps(watch_id)
payload = b'{"type":"watch","watchId":' + wid_bytes + b',"grid":' + grid_bytes + b"}"
for ws in connections:
with contextlib.suppress(Exception):
await ws.send_text_preencoded(payload)
def _cached_grid_bytes(
self,
op: str,
msg: dict[str, Any],
grid: Grid,
) -> bytes:
"""Return cached grid bytes for read ops, encode otherwise."""
# Invalidate cache on mutation ops
if op in _MUTATION_OPS and self._ws_grid_cache:
self._ws_grid_cache.clear()
if op == "read":
grid_data = msg.get("grid")
if isinstance(grid_data, dict):
rows = grid_data.get("rows", [])
if rows and isinstance(rows[0], dict):
filt = rows[0].get("filter", "")
limit = rows[0].get("limit", "")
key = f"ws_read:{filt}:{limit}"
cached = self._ws_grid_cache.get(key)
if cached is not None:
return cached
grid_bytes = encode_grid_json(grid)
if len(self._ws_grid_cache) < _MAX_CACHE_SIZE:
self._ws_grid_cache[key] = grid_bytes
return grid_bytes
return encode_grid_json(grid)
# ---- Connection handling -----------------------------------------------
    async def _handle_client(self, reader: StreamReader, writer: StreamWriter) -> None:
        """Per-connection handler: accept WS, authenticate, dispatch ops.

        Lifecycle: enforce the connection cap, upgrade the TCP stream to a
        WebSocket, authenticate the peer (certificate first, then SCRAM, then
        bearer token), start the heartbeat task, and run the message loop
        until the peer disconnects. All cleanup happens in the ``finally``
        block so every exit path balances the connection count.

        :param reader: Stream reader of the accepted TCP connection.
        :param writer: Stream writer of the accepted TCP connection.
        """
        # Hard cap on concurrent connections: reject before the WS handshake.
        if self._connection_count >= _MAX_CONNECTIONS:
            _log.warning("Max connections reached (%d), rejecting", _MAX_CONNECTIONS)
            writer.close()
            with contextlib.suppress(OSError, ConnectionError):
                await writer.wait_closed()
            return
        # Counted from here on; the finally block below decrements.
        self._connection_count += 1
        ws: HaystackWebSocket | None = None
        hb_task: asyncio.Task[None] | None = None
        remote: str = "?"
        try:
            ws = await HaystackWebSocket.accept(
                reader, writer, compression=self._compression, max_size=_MAX_WS_MESSAGE_SIZE
            )
            self._connections.add(ws)
            remote = writer.get_extra_info("peername", ("?",))[0]
            _fire(self._metrics.on_ws_connect, str(remote))
            # Authenticate: cert-based first, then SCRAM, then token-based.
            # If a cert authenticator is configured but does not recognise the
            # peer, we fall back to SCRAM/token before rejecting outright.
            ws_username: str | None = None
            if self._cert_auth is not None:
                peercert = writer.get_extra_info("peercert")
                username = self._cert_auth.authorize(peercert)
                if username is not None:
                    _log.debug("Client authenticated via certificate CN=%s", username)
                    ws_username = username
                elif self._authenticator is not None:
                    ws_username = await self._scram_authenticate(ws)
                    if ws_username is None:
                        return
                elif self._auth_token:
                    authenticated = await self._token_authenticate(ws)
                    if not authenticated:
                        return
                else:
                    # Cert auth was required and no fallback is configured.
                    _log.warning("Client certificate not authorized")
                    return
            elif self._authenticator is not None:
                ws_username = await self._scram_authenticate(ws)
                if ws_username is None:
                    return
            elif self._auth_token:
                authenticated = await self._token_authenticate(ws)
                if not authenticated:
                    return
            # Start heartbeat and message dispatch
            if self._heartbeat > 0:
                hb_task = asyncio.create_task(
                    heartbeat_loop(ws, self._heartbeat), name="hs-ws-srv-heartbeat"
                )
            await self._message_loop(ws, ws_username)
        except (ConnectionClosedOK, ConnectionClosedError):
            # Normal/abnormal WS close — nothing to do beyond cleanup.
            pass
        except (TimeoutError, ConnectionError, OSError):
            _log.debug("WebSocket connection error")
        except Exception:
            _log.exception("Unexpected error in WebSocket connection handler")
        finally:
            # Runs for every path after the count was incremented above.
            await cancel_task(hb_task)
            self._connection_count -= 1
            if ws is not None:
                self._connections.discard(ws)
                _fire(self._metrics.on_ws_disconnect, str(remote))
                with contextlib.suppress(Exception):
                    await ws.close()
async def _token_authenticate(self, ws: HaystackWebSocket) -> bool:
"""Read and validate a bearer token message. Return ``True`` if valid."""
try:
async with asyncio.timeout(10.0):
data = await ws.recv()
msg = orjson.loads(data)
token = msg.get("authToken", "")
if token and hmac_mod.compare_digest(token, self._auth_token):
return True
_log.warning("WebSocket token auth failed")
return False
except Exception:
_log.warning("WebSocket token auth error")
return False
    async def _scram_authenticate(self, ws: HaystackWebSocket) -> str | None:
        """Perform SCRAM-SHA-256 handshake over WebSocket messages.

        Message flow::

            Client → {"type":"hello","username":"<b64>"}
            Server → {"type":"hello","handshakeToken":"...","hash":"SHA-256"}
            Client → {"type":"scram","handshakeToken":"...","data":"<b64>"}
            Server → {"type":"scram","handshakeToken":"...","data":"<b64>"}
            Client → {"type":"scram","handshakeToken":"...","data":"<b64>"}
            Server → {"type":"authOk","authToken":"...","data":"<b64>"}

        Returns the authenticated username on success, or ``None`` on failure.
        """
        # Caller guarantees SCRAM mode is configured. NOTE(review): ``assert``
        # is stripped under ``python -O`` — this relies on caller discipline.
        assert self._authenticator is not None
        try:
            # One 30s timeout covers the entire multi-message handshake.
            async with asyncio.timeout(30.0):
                # Step 1: HELLO
                data = await ws.recv()
                msg = orjson.loads(data)
                # Also accept legacy token auth during SCRAM mode
                if "authToken" in msg:
                    token = msg["authToken"]
                    if self._auth_token and hmac_mod.compare_digest(token, self._auth_token):
                        return ""  # token auth has no username
                    _log.warning("WebSocket token auth failed (SCRAM mode)")
                    await ws.send_text(orjson.dumps({"type": "authErr"}).decode())
                    return None
                if msg.get("type") != "hello":
                    _log.warning("Expected hello message, got: %s", msg.get("type"))
                    await ws.send_text(orjson.dumps({"type": "authErr"}).decode())
                    return None
                # Reuse the HTTP SCRAM machinery by synthesising header strings.
                auth_header = f"HELLO username={msg.get('username', '')}"
                result = await scram_hello(self._authenticator, self._handshakes, auth_header)
                # A well-formed hello must yield 401 plus a handshakeToken in
                # the WWW-Authenticate challenge; anything else is a failure.
                if result.status != 401 or "handshakeToken" not in result.headers.get(
                    "WWW-Authenticate", ""
                ):
                    await ws.send_text(orjson.dumps({"type": "authErr"}).decode())
                    return None
                from hs_py.auth import _parse_header_params

                hello_resp_params = _parse_header_params(result.headers["WWW-Authenticate"])
                await ws.send_text(
                    orjson.dumps(
                        {
                            "type": "hello",
                            "handshakeToken": hello_resp_params.get("handshakeToken", ""),
                            "hash": hello_resp_params.get("hash", "SHA-256"),
                        }
                    ).decode()
                )
                # Step 2: client-first → server-first
                data = await ws.recv()
                msg = orjson.loads(data)
                if msg.get("type") != "scram":
                    await ws.send_text(orjson.dumps({"type": "authErr"}).decode())
                    return None
                ht = msg.get("handshakeToken", "")
                scram_data = msg.get("data", "")
                auth_header = f"SCRAM handshakeToken={ht}, data={scram_data}"
                result = handle_scram(self._handshakes, self._tokens, auth_header)
                if result.status == 401:
                    # 401 is the expected mid-handshake status here: relay the
                    # server-first challenge back over the socket.
                    www_auth = result.headers.get("WWW-Authenticate", "")
                    resp_params = _parse_header_params(www_auth)
                    await ws.send_text(
                        orjson.dumps(
                            {
                                "type": "scram",
                                "handshakeToken": resp_params.get("handshakeToken", ""),
                                "hash": resp_params.get("hash", "SHA-256"),
                                "data": resp_params.get("data", ""),
                            }
                        ).decode()
                    )
                else:
                    await ws.send_text(orjson.dumps({"type": "authErr"}).decode())
                    return None
                # Step 3: client-final → server-final + authToken
                data = await ws.recv()
                msg = orjson.loads(data)
                if msg.get("type") != "scram":
                    await ws.send_text(orjson.dumps({"type": "authErr"}).decode())
                    return None
                ht = msg.get("handshakeToken", "")
                scram_data = msg.get("data", "")
                auth_header = f"SCRAM handshakeToken={ht}, data={scram_data}"
                result = handle_scram(self._handshakes, self._tokens, auth_header)
                if result.status == 200:
                    # Success: hand the client its bearer token for this session.
                    auth_info = result.headers.get("Authentication-Info", "")
                    resp_params = _parse_header_params(auth_info)
                    auth_token = resp_params.get("authToken", "")
                    await ws.send_text(
                        orjson.dumps(
                            {
                                "type": "authOk",
                                "authToken": auth_token,
                                "data": resp_params.get("data", ""),
                            }
                        ).decode()
                    )
                    _log.debug("WebSocket SCRAM auth succeeded")
                    # Look up username from the issued token
                    token_entry = self._tokens.get(auth_token)
                    return token_entry.username if token_entry else ""
                await ws.send_text(orjson.dumps({"type": "authErr"}).decode())
                return None
        except Exception:
            # Any parse, transport or timeout failure aborts authentication.
            _log.warning("WebSocket SCRAM auth error")
            return None
async def _check_permission(self, username: str | None, op_name: str) -> None:
"""Check role permissions for a WebSocket op.
:raises HaystackError: If the user lacks the required role.
"""
if self._user_store is None:
return
from hs_py.user import WRITE_OPS, Role
if username is None:
raise HaystackError("Authentication required")
user = await self._user_store.get_user(username)
if user is None or not user.enabled:
raise HaystackError("Authentication required")
if op_name in WRITE_OPS and user.role < Role.OPERATOR:
raise HaystackError(
f"Insufficient permissions: {op_name} requires operator or admin role"
)
async def _message_loop(self, ws: HaystackWebSocket, username: str | None = None) -> None:
"""Read request messages and dispatch to HaystackOps."""
chunk_assembler = ChunkAssembler() if self._chunked else None
last_cleanup = time.monotonic()
while True:
data = await ws.recv()
_fire(self._metrics.on_ws_message_recv, "", len(data))
# Periodic cleanup of orphaned chunk buffers
if chunk_assembler is not None:
now = time.monotonic()
if now - last_cleanup > 30.0:
chunk_assembler.cleanup(now)
last_cleanup = now
# Binary frame handling — only bytes from binary WS frames
if isinstance(data, bytes) and self._binary and len(data) >= 4:
await self._handle_binary_message(ws, data, username, chunk_assembler)
continue
# JSON text handling
text = data if isinstance(data, str) else data.decode()
try:
msg = orjson.loads(text)
except orjson.JSONDecodeError:
_log.warning("Non-JSON WebSocket message, ignoring")
continue
# Batch: JSON array of envelopes
if isinstance(msg, list):
await self._handle_batch(ws, msg, username)
continue
await self._handle_json_message(ws, msg, username)
async def _handle_capabilities(self, ws: HaystackWebSocket, msg: dict[str, Any]) -> None:
"""Respond to a capabilities negotiation message.
The client sends its supported features and the server responds
with the intersection of what it supports.
"""
client_compression = msg.get("compression", [])
client_chunked = msg.get("chunked", False)
# Determine agreed compression algorithm
server_algos = []
if self._binary_compression is not None:
from hs_py.ws_codec import COMP_LZMA, COMP_ZLIB
algo_names = {COMP_ZLIB: "zlib", COMP_LZMA: "lzma"}
name = algo_names.get(self._binary_compression)
if name:
server_algos.append(name)
agreed_compression: str | None = None
for algo in client_compression:
if algo in server_algos:
agreed_compression = algo
break
agreed_chunked = client_chunked and self._chunked
response = {
"type": "capabilities",
"compression": agreed_compression,
"chunked": agreed_chunked,
"version": 2,
}
await ws.send_text_preencoded(orjson.dumps(response))
_log.debug(
"Capabilities negotiated: compression=%s chunked=%s",
agreed_compression,
agreed_chunked,
)
async def _handle_json_message(
self, ws: HaystackWebSocket, msg: dict[str, Any], username: str | None = None
) -> None:
"""Dispatch a single JSON request envelope."""
# Handle capabilities negotiation
if msg.get("type") == "capabilities":
await self._handle_capabilities(ws, msg)
return
req_id = msg.get("id")
op = msg.get("op", "")
ch = msg.get("ch")
_fire(self._metrics.on_ws_message_recv, op, 0)
try:
await self._check_permission(username, op)
result_grid = await dispatch_op(self._ops, op, msg)
except HaystackError as exc:
_fire(self._metrics.on_error, op, type(exc).__name__)
result_grid = Grid.make_error(str(exc))
except Exception as exc:
_log.exception("Unhandled error in op '%s'", op)
_fire(self._metrics.on_error, op, type(exc).__name__)
result_grid = Grid.make_error("Internal server error")
grid_bytes = self._cached_grid_bytes(op, msg, result_grid)
payload = _ws_envelope(grid_bytes, req_id, ch)
await ws.send_text_preencoded(payload)
_fire(self._metrics.on_ws_message_sent, op, len(payload))
async def _handle_binary_message(
self,
ws: HaystackWebSocket,
data: bytes,
username: str | None = None,
chunk_assembler: ChunkAssembler | None = None,
) -> None:
"""Dispatch a binary frame request."""
try:
flags, req_id, op, grid_bytes = decode_binary_frame(data)
except ValueError:
_log.warning("Invalid binary frame, ignoring")
return
if flags & FLAG_PUSH:
return # Server doesn't process inbound pushes
# Chunked frame — accumulate until complete
if flags & FLAG_CHUNKED:
if chunk_assembler is None:
_log.warning("Received chunked frame but chunking is not enabled, ignoring")
return
assembled = chunk_assembler.feed(flags, req_id, op, grid_bytes)
if assembled is None:
return # waiting for more chunks
grid_bytes = assembled
_fire(self._metrics.on_ws_message_recv, op, len(data))
try:
msg: dict[str, Any] = {"op": op}
if grid_bytes:
msg["grid"] = orjson.loads(grid_bytes)
await self._check_permission(username, op)
result_grid = await dispatch_op(self._ops, op, msg)
except HaystackError as exc:
_fire(self._metrics.on_error, op, type(exc).__name__)
result_grid = Grid.make_error(str(exc))
except Exception as exc:
_log.exception("Unhandled error in binary op '%s'", op)
_fire(self._metrics.on_error, op, type(exc).__name__)
result_grid = Grid.make_error("Internal server error")
# Encode grid payload once, then decide: chunk or single frame
payload = encode_grid_json(result_grid)
resp_flags = FLAG_RESPONSE
if result_grid.is_error:
resp_flags |= FLAG_ERROR
# Chunk based on raw payload size (before compression)
if self._chunked and len(payload) > CHUNK_THRESHOLD:
frames = encode_chunked_frames(
resp_flags,
req_id,
op,
payload,
compression=self._binary_compression,
chunk_size=CHUNK_SIZE,
)
for frame in frames:
await ws.send_bytes(frame)
_fire(self._metrics.on_ws_message_sent, op, sum(len(f) for f in frames))
else:
response = encode_binary_response(
req_id,
op,
result_grid,
is_error=result_grid.is_error,
compression=self._binary_compression,
)
await ws.send_bytes(response)
_fire(self._metrics.on_ws_message_sent, op, len(response))
async def _handle_batch(
self, ws: HaystackWebSocket, batch: list[Any], username: str | None = None
) -> None:
"""Dispatch a batch of JSON request envelopes concurrently."""
items = [item for item in batch if isinstance(item, dict)]
if not items:
return
# Cap batch size to prevent resource exhaustion
if len(items) > _MAX_BATCH_SIZE:
items = items[:_MAX_BATCH_SIZE]
async def _dispatch_item(item: dict[str, Any]) -> bytes:
r_id = item.get("id")
r_op = item.get("op", "")
try:
await self._check_permission(username, r_op)
r_grid = await dispatch_op(self._ops, r_op, item)
except HaystackError as exc:
_fire(self._metrics.on_error, r_op, type(exc).__name__)
r_grid = Grid.make_error(str(exc))
except Exception:
_log.exception("Unhandled error in batch op '%s'", r_op)
_fire(self._metrics.on_error, r_op, "InternalError")
r_grid = Grid.make_error("Internal server error")
grid_bytes = self._cached_grid_bytes(r_op, item, r_grid)
return _ws_envelope(grid_bytes, r_id)
item_bytes = await asyncio.gather(*[_dispatch_item(item) for item in items])
payload = b"[" + b",".join(item_bytes) + b"]"
await ws.send_text_preencoded(payload)
_fire(self._metrics.on_ws_message_sent, "batch", len(payload))