"""Async Haystack WebSocket client.

Provides the same operation API as :class:`~hs_py.client.Client` over a
persistent WebSocket connection. Uses JSON-encoded request/response
envelopes with correlation IDs for concurrent request support.

Message format (client → server)::

    {"id": "1", "op": "read", "grid": {...}}

Message format (server → client)::

    {"id": "1", "grid": {...}}

Server-initiated push::

    {"type": "watch", "watchId": "w-1", "grid": {...}}
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
import time
from typing import TYPE_CHECKING, Any, Literal, overload
import orjson
from hs_py.convert import grid_to_pythonic
from hs_py.encoding.json import decode_grid, decode_grid_dict, encode_grid_dict
from hs_py.errors import AuthError, CallError, NetworkError
from hs_py.grid import Grid, GridBuilder
from hs_py.kinds import MARKER, Number, Ref
from hs_py.metrics import MetricsHooks, _fire
from hs_py.tls import TLSConfig, build_client_ssl_context
from hs_py.ws import HaystackWebSocket, cancel_task, heartbeat_loop
from hs_py.ws_codec import (
FLAG_CHUNKED,
FLAG_ERROR,
FLAG_PUSH,
FLAG_RESPONSE,
ChunkAssembler,
decode_binary_frame,
encode_binary_request,
)
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
# Public names exported by ``from hs_py.<this module> import *``.
__all__ = ["ChannelClient", "ReconnectingWebSocketClient", "WebSocketClient", "WebSocketPool"]

# Module logger, named after this module's dotted path.
_log = logging.getLogger(__name__)
def _resolve_grid_response(pending: dict[str, asyncio.Future[Grid]], msg: dict[str, Any]) -> bool:
    """Resolve the pending request future addressed by a JSON response envelope.

    :param pending: Outstanding request futures keyed by (string) request id.
        A matched entry is removed from the mapping.
    :param msg: Decoded response envelope, expected to look like
        ``{"id": "1", "grid": {...}}``.
    :returns: ``True`` if the envelope matched a pending request (even when
        that request's future was already done, e.g. cancelled), ``False`` if
        it matched nothing or is not an object.
    """
    if not isinstance(msg, dict):
        return False
    req_id = msg.get("id")
    if req_id is None:
        return False
    # Ids are sent as strings. Normalising means a server that echoes a numeric
    # id still matches, and an unhashable id (list/object) cannot raise here.
    fut = pending.pop(str(req_id), None)
    if fut is None:
        return False
    if fut.done():
        return True
    try:
        grid = decode_grid_dict(msg["grid"])
        if grid.is_error:
            fut.set_exception(CallError(grid.meta.get("dis", "Unknown error"), grid))
        else:
            fut.set_result(grid)
    except Exception as exc:
        # Missing "grid" key or an undecodable payload: fail only this request.
        fut.set_exception(NetworkError(f"Failed to decode response: {exc}"))
    return True
[docs]
class WebSocketClient:
"""Async Haystack WebSocket client.
Mirrors the :class:`~hs_py.client.Client` API over a persistent
WebSocket connection.
Usage::
async with WebSocketClient("ws://host:8080/api/ws") as c:
about = await c.about() # returns list[dict] by default
points = await c.read("point and sensor")
raw_grid = await c.about(raw=True) # returns Grid
"""
def __init__(
self,
url: str,
*,
username: str = "",
password: str = "",
auth_token: str = "",
tls: TLSConfig | None = None,
timeout: float = 30.0,
heartbeat: float = 30.0,
metrics: MetricsHooks | None = None,
compression: bool = False,
binary: bool = False,
pythonic: bool = True,
binary_compression: int | None = None,
chunked: bool = False,
) -> None:
"""Initialise the WebSocket client.
:param url: WebSocket URI (e.g. ``ws://host:8080/api/ws``).
:param username: Username for SCRAM-SHA-256 authentication.
:param password: Password for SCRAM-SHA-256 authentication.
:param auth_token: Bearer token sent on connect (used when *username*
is not provided).
:param tls: Optional :class:`~hs_py.tls.TLSConfig` for ``wss://``.
:param timeout: Per-request timeout in seconds.
:param heartbeat: Ping interval in seconds (0 to disable).
:param metrics: Optional :class:`~hs_py.metrics.MetricsHooks` callbacks.
:param compression: Enable per-message deflate compression.
:param binary: Use binary frame encoding instead of JSON envelopes.
:param pythonic: When ``True`` (default) Grid-returning methods return
``list[dict[str, Any]]`` with Haystack kinds converted to plain Python
values. Pass ``False`` to always return raw :class:`~hs_py.grid.Grid`.
:param binary_compression: Codec-level compression algorithm for binary
frames (e.g. ``COMP_ZLIB``). ``None`` disables codec compression.
:param chunked: Enable chunked transfer for large binary payloads.
"""
self._url = url
self._username = username
self._password = password
self._auth_token = auth_token
self._tls = tls
self._timeout = timeout
self._heartbeat = heartbeat
self._metrics = metrics or MetricsHooks()
self._compression = compression
self._binary = binary
self._binary_compression = binary_compression
self._chunked = chunked
self._pythonic = pythonic
self._ws: HaystackWebSocket | None = None
self._next_id = 0
self._pending: dict[str, asyncio.Future[Grid]] = {}
self._recv_task: asyncio.Task[None] | None = None
self._heartbeat_task: asyncio.Task[None] | None = None
self._watch_callback: Callable[[str, Grid], Any] | None = None
async def __aenter__(self) -> WebSocketClient:
ssl_ctx = build_client_ssl_context(self._tls) if self._tls else None
self._ws = await HaystackWebSocket.connect(
self._url, ssl_ctx, compression=self._compression
)
_fire(self._metrics.on_ws_connect, self._url)
# Authenticate before starting recv loop
if self._username:
await self._scram_authenticate()
elif self._auth_token:
await self._ws.send_text(orjson.dumps({"authToken": self._auth_token}).decode())
# Negotiate capabilities before starting recv loop
if self._binary_compression is not None or self._chunked:
await self._negotiate_capabilities()
self._recv_task = asyncio.create_task(self._recv_loop(), name="hs-ws-recv")
if self._heartbeat > 0:
self._heartbeat_task = asyncio.create_task(
heartbeat_loop(self._ws, self._heartbeat), name="hs-ws-heartbeat"
)
return self
async def __aexit__(self, *exc: object) -> None:
await self.close()
[docs]
async def close(self) -> None:
"""Close the WebSocket connection."""
await cancel_task(self._heartbeat_task)
self._heartbeat_task = None
await cancel_task(self._recv_task)
self._recv_task = None
if self._ws is not None:
await self._ws.close()
_fire(self._metrics.on_ws_disconnect, self._url)
self._ws = None
# Cancel any pending requests
for fut in self._pending.values():
if not fut.done():
fut.cancel()
self._pending.clear()
async def _scram_authenticate(self) -> None:
"""Perform SCRAM-SHA-256 handshake over WebSocket messages."""
from hs_py.auth import (
_b64url_decode,
_b64url_encode,
scram_client_final,
scram_client_first,
verify_server_signature,
)
assert self._ws is not None
# Step 1: HELLO
first = scram_client_first(self._username)
await self._ws.send_text(
orjson.dumps(
{"type": "hello", "username": _b64url_encode(self._username.encode())}
).decode()
)
data = await self._ws.recv()
msg = orjson.loads(data)
if msg.get("type") == "authErr":
raise AuthError("Server rejected HELLO")
if msg.get("type") != "hello":
raise AuthError(f"Expected hello response, got: {msg.get('type')}")
handshake_token = msg.get("handshakeToken", "")
hash_name = msg.get("hash", "SHA-256")
# Step 2: client-first → server-first
await self._ws.send_text(
orjson.dumps(
{
"type": "scram",
"handshakeToken": handshake_token,
"data": _b64url_encode(first.client_first_msg.encode()),
}
).decode()
)
data = await self._ws.recv()
msg = orjson.loads(data)
if msg.get("type") == "authErr":
raise AuthError("SCRAM step 1 failed")
if msg.get("type") != "scram":
raise AuthError(f"Expected scram response, got: {msg.get('type')}")
server_first_data = msg.get("data", "")
server_first_msg = _b64url_decode(server_first_data).decode()
new_ht = msg.get("handshakeToken", "")
final = scram_client_final(self._password, first, server_first_msg, hash_name)
# Step 3: client-final → authOk
await self._ws.send_text(
orjson.dumps(
{
"type": "scram",
"handshakeToken": new_ht,
"data": _b64url_encode(final.client_final_msg.encode()),
}
).decode()
)
data = await self._ws.recv()
msg = orjson.loads(data)
if msg.get("type") == "authErr":
raise AuthError("SCRAM authentication failed")
if msg.get("type") != "authOk":
raise AuthError(f"Expected authOk, got: {msg.get('type')}")
# Verify server signature
server_data = msg.get("data", "")
if server_data:
verify_server_signature(final, _b64url_decode(server_data).decode())
_log.debug("WebSocket SCRAM auth succeeded")
async def _negotiate_capabilities(self) -> None:
"""Exchange capabilities with the server for v2 features."""
assert self._ws is not None
from hs_py.ws_codec import COMP_LZMA, COMP_ZLIB
algo_names = {COMP_ZLIB: "zlib", COMP_LZMA: "lzma"}
compression_list: list[str] = []
if self._binary_compression is not None:
name = algo_names.get(self._binary_compression)
if name:
compression_list.append(name)
msg = {
"type": "capabilities",
"compression": compression_list,
"chunked": self._chunked,
"version": 2,
}
await self._ws.send_text(orjson.dumps(msg).decode())
try:
async with asyncio.timeout(5.0):
data = await self._ws.recv()
resp = orjson.loads(data)
if resp.get("type") == "capabilities":
agreed_comp = resp.get("compression")
if agreed_comp is None:
self._binary_compression = None
agreed_chunked = resp.get("chunked", False)
if not agreed_chunked:
self._chunked = False
_log.debug(
"Capabilities negotiated: compression=%s chunked=%s",
agreed_comp,
agreed_chunked,
)
else:
# Server didn't understand capabilities — fall back to v1
self._binary_compression = None
self._chunked = False
_log.debug("Server does not support capabilities negotiation, using v1")
except (TimeoutError, Exception):
# Timeout or error — fall back to v1
self._binary_compression = None
self._chunked = False
_log.debug("Capabilities negotiation failed, falling back to v1")
# ---- Standard ops (mirrors Client) -------------------------------------
@overload
async def about(self, *, raw: Literal[True]) -> Grid: ...
@overload
async def about(self, *, raw: Literal[False] = ...) -> list[dict[str, Any]]: ...
@overload
async def about(self, *, raw: bool = ...) -> Grid | list[dict[str, Any]]: ...
[docs]
async def about(self, *, raw: bool = False) -> Grid | list[dict[str, Any]]:
"""Query server information."""
grid = await self._call("about", Grid.make_empty())
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
@overload
async def ops(self, *, raw: Literal[True]) -> Grid: ...
@overload
async def ops(self, *, raw: Literal[False] = ...) -> list[dict[str, Any]]: ...
@overload
async def ops(self, *, raw: bool = ...) -> Grid | list[dict[str, Any]]: ...
[docs]
async def ops(self, *, raw: bool = False) -> Grid | list[dict[str, Any]]:
"""Query available operations."""
grid = await self._call("ops", Grid.make_empty())
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
@overload
async def formats(self, *, raw: Literal[True]) -> Grid: ...
@overload
async def formats(self, *, raw: Literal[False] = ...) -> list[dict[str, Any]]: ...
@overload
async def formats(self, *, raw: bool = ...) -> Grid | list[dict[str, Any]]: ...
@overload
async def read(self, filter: str, limit: int | None = ..., *, raw: Literal[True]) -> Grid: ...
@overload
async def read(
self, filter: str, limit: int | None = ..., *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def read(
self, filter: str, limit: int | None = ..., *, raw: bool
) -> Grid | list[dict[str, Any]]: ...
[docs]
async def read(
self, filter: str, limit: int | None = None, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read entities matching a filter expression."""
row: dict[str, Any] = {"filter": filter}
if limit is not None:
row["limit"] = Number(float(limit))
grid_req = GridBuilder().add_col("filter").add_col("limit").add_row(row).to_grid()
grid = await self._call("read", grid_req)
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
@overload
async def read_by_ids(self, ids: list[Ref], *, raw: Literal[True]) -> Grid: ...
@overload
async def read_by_ids(
self, ids: list[Ref], *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def read_by_ids(self, ids: list[Ref], *, raw: bool) -> Grid | list[dict[str, Any]]: ...
[docs]
async def read_by_ids(
self, ids: list[Ref], *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read entities by their identifiers."""
builder = GridBuilder().add_col("id")
for ref in ids:
builder.add_row({"id": ref})
grid = await self._call("read", builder.to_grid())
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
@overload
async def nav(self, nav_id: str | None = ..., *, raw: Literal[True]) -> Grid: ...
@overload
async def nav(
self, nav_id: str | None = ..., *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def nav(self, nav_id: str | None = ..., *, raw: bool) -> Grid | list[dict[str, Any]]: ...
[docs]
async def nav(
self, nav_id: str | None = None, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Navigate the entity tree."""
row: dict[str, Any] = {"navId": nav_id}
grid_req = GridBuilder().add_col("navId").add_row(row).to_grid()
grid = await self._call("nav", grid_req)
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
# ---- History ops -------------------------------------------------------
@overload
async def his_read(self, id: Ref, range: str, *, raw: Literal[True]) -> Grid: ...
@overload
async def his_read(
self, id: Ref, range: str, *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def his_read(self, id: Ref, range: str, *, raw: bool) -> Grid | list[dict[str, Any]]: ...
[docs]
async def his_read(
self, id: Ref, range: str, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read time-series data for a single point."""
grid_req = (
GridBuilder()
.add_col("id")
.add_col("range")
.add_row({"id": id, "range": range})
.to_grid()
)
grid = await self._call("hisRead", grid_req)
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
@overload
async def his_read_batch(self, ids: list[Ref], range: str, *, raw: Literal[True]) -> Grid: ...
@overload
async def his_read_batch(
self, ids: list[Ref], range: str, *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def his_read_batch(
self, ids: list[Ref], range: str, *, raw: bool
) -> Grid | list[dict[str, Any]]: ...
[docs]
async def his_read_batch(
self, ids: list[Ref], range: str, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read time-series data for multiple points."""
builder = GridBuilder().set_meta({"range": range}).add_col("id")
for ref in ids:
builder.add_row({"id": ref})
grid = await self._call("hisRead", builder.to_grid())
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
[docs]
async def his_write(self, id: Ref, items: list[dict[str, Any]]) -> None:
"""Write time-series data to a single point."""
builder = GridBuilder().set_meta({"id": id}).add_col("ts").add_col("val")
for item in items:
builder.add_row(item)
await self._call("hisWrite", builder.to_grid())
[docs]
async def his_write_batch(self, grid: Grid) -> None:
"""Write time-series data for multiple points."""
await self._call("hisWrite", grid)
# ---- Point write ops ---------------------------------------------------
@overload
async def point_write_array(self, id: Ref, *, raw: Literal[True]) -> Grid: ...
@overload
async def point_write_array(self, id: Ref, *, raw: Literal[False]) -> list[dict[str, Any]]: ...
@overload
async def point_write_array(self, id: Ref, *, raw: bool) -> Grid | list[dict[str, Any]]: ...
[docs]
async def point_write_array(
self, id: Ref, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read the priority array of a writable point."""
grid_req = GridBuilder().add_col("id").add_row({"id": id}).to_grid()
grid = await self._call("pointWrite", grid_req)
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
[docs]
async def point_write(
self,
id: Ref,
level: int,
val: Any,
who: str = "",
duration: Number | None = None,
) -> None:
"""Write to a priority array level."""
row: dict[str, Any] = {
"id": id,
"level": Number(float(level)),
"val": val,
"who": who,
}
if duration is not None:
row["duration"] = duration
cols = ["id", "level", "val", "who", "duration"]
builder = GridBuilder()
for col in cols:
builder.add_col(col)
builder.add_row(row)
await self._call("pointWrite", builder.to_grid())
# ---- Watch ops ---------------------------------------------------------
@overload
async def watch_sub(
self,
ids: list[Ref],
watch_dis: str,
lease: Number | None = ...,
*,
filter: str | None = ...,
raw: Literal[True],
) -> Grid: ...
@overload
async def watch_sub(
self,
ids: list[Ref],
watch_dis: str,
lease: Number | None = ...,
*,
filter: str | None = ...,
raw: Literal[False],
) -> list[dict[str, Any]]: ...
@overload
async def watch_sub(
self,
ids: list[Ref],
watch_dis: str,
lease: Number | None = ...,
*,
filter: str | None = ...,
raw: bool,
) -> Grid | list[dict[str, Any]]: ...
[docs]
async def watch_sub(
self,
ids: list[Ref],
watch_dis: str,
lease: Number | None = None,
*,
filter: str | None = None,
raw: bool = False,
) -> Grid | list[dict[str, Any]]:
"""Create a new watch or add entities to an existing one.
:param filter: Optional Haystack filter for server-side filtering.
:param raw: If ``True``, return the raw :class:`~hs_py.grid.Grid`.
"""
meta: dict[str, Any] = {"watchDis": watch_dis}
if lease is not None:
meta["lease"] = lease
if filter is not None:
meta["filter"] = filter
builder = GridBuilder().set_meta(meta).add_col("id")
for ref in ids:
builder.add_row({"id": ref})
grid = await self._call("watchSub", builder.to_grid())
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
[docs]
async def watch_unsub(self, watch_id: str, ids: list[Ref]) -> None:
"""Remove entities from a watch."""
builder = GridBuilder().set_meta({"watchId": watch_id}).add_col("id")
for ref in ids:
builder.add_row({"id": ref})
await self._call("watchUnsub", builder.to_grid())
[docs]
async def watch_close(self, watch_id: str) -> None:
"""Close a watch entirely."""
builder = GridBuilder().set_meta({"watchId": watch_id, "close": MARKER}).add_col("id")
await self._call("watchUnsub", builder.to_grid())
@overload
async def watch_poll(
self, watch_id: str, refresh: bool = ..., *, raw: Literal[True]
) -> Grid: ...
@overload
async def watch_poll(
self, watch_id: str, refresh: bool = ..., *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def watch_poll(
self, watch_id: str, refresh: bool = ..., *, raw: bool
) -> Grid | list[dict[str, Any]]: ...
[docs]
async def watch_poll(
self, watch_id: str, refresh: bool = False, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Poll a watch for changes."""
meta: dict[str, Any] = {"watchId": watch_id}
if refresh:
meta["refresh"] = MARKER
grid_req = GridBuilder().set_meta(meta).to_grid()
grid = await self._call("watchPoll", grid_req)
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
# ---- Action ops --------------------------------------------------------
@overload
async def invoke_action(
self,
id: Ref,
action: str,
args: dict[str, Any] | None = ...,
*,
raw: Literal[True],
) -> Grid: ...
@overload
async def invoke_action(
self,
id: Ref,
action: str,
args: dict[str, Any] | None = ...,
*,
raw: Literal[False],
) -> list[dict[str, Any]]: ...
@overload
async def invoke_action(
self,
id: Ref,
action: str,
args: dict[str, Any] | None = ...,
*,
raw: bool,
) -> Grid | list[dict[str, Any]]: ...
[docs]
async def invoke_action(
self,
id: Ref,
action: str,
args: dict[str, Any] | None = None,
*,
raw: bool = False,
) -> Grid | list[dict[str, Any]]:
"""Invoke an action on an entity."""
meta: dict[str, Any] = {"id": id, "action": action}
builder = GridBuilder().set_meta(meta)
if args:
for key in args:
builder.add_col(key)
builder.add_row(args)
grid = await self._call("invokeAction", builder.to_grid())
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
# ---- Batch ops ---------------------------------------------------------
[docs]
async def batch(self, *calls: tuple[str, Grid]) -> list[Grid]:
"""Send multiple operations in a single WebSocket frame.
:param calls: Tuples of ``(op_name, grid)``.
:returns: List of response grids in the same order as *calls*.
:raises NetworkError: If any call times out.
"""
ws = self._require_ws()
loop = asyncio.get_running_loop()
futs: list[tuple[str, asyncio.Future[Grid]]] = []
envelopes: list[dict[str, Any]] = []
for op, grid in calls:
req_id = str(self._next_id)
self._next_id = (self._next_id + 1) & 0xFFFF_FFFF
envelopes.append({"id": req_id, "op": op, "grid": encode_grid_dict(grid)})
fut: asyncio.Future[Grid] = loop.create_future()
self._pending[req_id] = fut
futs.append((req_id, fut))
msg = orjson.dumps(envelopes).decode()
await ws.send_text(msg)
_fire(self._metrics.on_ws_message_sent, "batch", len(msg))
results: list[Grid] = []
try:
for _req_id, fut in futs:
async with asyncio.timeout(self._timeout):
result = await fut
results.append(result)
except TimeoutError as exc:
# Cancel remaining futures
for req_id, fut in futs:
self._pending.pop(req_id, None)
if not fut.done():
fut.cancel()
raise NetworkError("Batch request timed out") from exc
return results
# ---- Watch push --------------------------------------------------------
[docs]
def on_watch_push(self, callback: Callable[[str, Grid], Any]) -> None:
"""Register a callback for server-initiated watch push messages.
:param callback: Called with ``(watch_id, grid)`` for each push.
"""
self._watch_callback = callback
# ---- Internal ----------------------------------------------------------
async def _call(self, op: str, grid: Grid) -> Grid:
"""Send a request and await the correlated response."""
ws = self._require_ws()
req_id_int = self._next_id
req_id = str(req_id_int)
self._next_id = (self._next_id + 1) & 0xFFFF_FFFF
# Create future for response
loop = asyncio.get_running_loop()
fut: asyncio.Future[Grid] = loop.create_future()
self._pending[req_id] = fut
try:
if self._binary:
frame = encode_binary_request(
req_id_int,
op,
grid,
compression=self._binary_compression,
)
await ws.send_bytes(frame)
_fire(self._metrics.on_ws_message_sent, op, len(frame))
else:
envelope = orjson.dumps(
{"id": req_id, "op": op, "grid": encode_grid_dict(grid)}
).decode()
await ws.send_text(envelope)
_fire(self._metrics.on_ws_message_sent, op, len(envelope))
async with asyncio.timeout(self._timeout):
result = await fut
_fire(self._metrics.on_request, op, 0.0)
return result
except TimeoutError as exc:
self._pending.pop(req_id, None)
_fire(self._metrics.on_error, op, "timeout")
raise NetworkError(f"Request '{op}' timed out") from exc
except Exception:
self._pending.pop(req_id, None)
raise
async def _recv_loop(self) -> None:
"""Background task: read messages and dispatch responses/pushes."""
ws = self._require_ws()
chunk_assembler = ChunkAssembler() if self._chunked else None
last_cleanup = time.monotonic()
try:
while True:
data = await ws.recv()
# Periodic cleanup of orphaned chunk buffers
if chunk_assembler is not None:
now = time.monotonic()
if now - last_cleanup > 30.0:
chunk_assembler.cleanup(now)
last_cleanup = now
# Binary frame handling — only bytes from binary WS frames
if isinstance(data, bytes) and self._binary:
self._handle_binary_frame(data, chunk_assembler)
continue
text = data if isinstance(data, str) else data.decode()
try:
msg = orjson.loads(text)
except orjson.JSONDecodeError:
_log.warning("Received non-JSON WebSocket message, ignoring")
continue
# Batch response (array of envelopes)
if isinstance(msg, list):
for item in msg:
self._handle_json_response(item)
continue
# Server-initiated push
msg_type = msg.get("type")
if msg_type == "watch":
self._dispatch_watch_push(msg)
continue
self._handle_json_response(msg)
except asyncio.CancelledError:
return
except Exception as exc:
_log.debug("WebSocket recv loop ended: %s", exc)
# Fail all pending requests
for fut in self._pending.values():
if not fut.done():
fut.set_exception(NetworkError(f"Connection lost: {exc}"))
self._pending.clear()
def _handle_json_response(self, msg: dict[str, Any]) -> None:
"""Process a single JSON response envelope."""
if not _resolve_grid_response(self._pending, msg):
_log.debug("Received unmatched WebSocket message: %s", msg.get("id"))
def _handle_binary_frame(
self,
data: bytes,
chunk_assembler: ChunkAssembler | None = None,
) -> None:
"""Process a binary frame response or push."""
try:
flags, req_id, op, grid_bytes = decode_binary_frame(data)
except ValueError:
_log.warning("Invalid binary frame, ignoring")
return
# Chunked frame — accumulate until complete
if flags & FLAG_CHUNKED:
if chunk_assembler is None:
_log.warning("Received chunked frame but chunking is not enabled, ignoring")
return
assembled = chunk_assembler.feed(flags, req_id, op, grid_bytes)
if assembled is None:
return # waiting for more chunks
grid_bytes = assembled
# Clear the chunked flag for downstream processing
flags &= ~FLAG_CHUNKED
if flags & FLAG_PUSH:
if self._watch_callback is not None:
try:
grid = decode_grid(grid_bytes)
self._watch_callback(op, grid)
except Exception:
_log.exception("Error in binary watch push callback")
return
if flags & FLAG_RESPONSE:
req_key = str(req_id)
fut = self._pending.pop(req_key, None)
if fut is not None and not fut.done():
try:
grid = decode_grid(grid_bytes)
if flags & FLAG_ERROR:
fut.set_exception(CallError(grid.meta.get("dis", "Unknown error"), grid))
else:
fut.set_result(grid)
except Exception as exc:
fut.set_exception(NetworkError(f"Failed to decode binary response: {exc}"))
return
_log.debug("Received unmatched binary frame: req_id=%d op=%s", req_id, op)
def _dispatch_watch_push(self, msg: dict[str, Any]) -> None:
"""Dispatch a watch push message to the registered callback."""
if self._watch_callback is None:
return
watch_id = msg.get("watchId", "")
try:
grid = decode_grid_dict(msg["grid"])
self._watch_callback(watch_id, grid)
except Exception:
_log.exception("Error in watch push callback")
def _require_ws(self) -> HaystackWebSocket:
"""Return the active WebSocket or raise."""
if self._ws is None:
msg = "WebSocketClient is not open. Use 'async with WebSocketClient(...) as c:'"
raise RuntimeError(msg)
return self._ws
[docs]
class ReconnectingWebSocketClient:
"""WebSocket client with automatic reconnection and exponential backoff.
Wraps :class:`WebSocketClient` and manages the connection lifecycle.
On disconnect, reconnects with exponential backoff and re-registers
watch callbacks.
Usage::
client = ReconnectingWebSocketClient("ws://host:8080/api/ws")
await client.start()
try:
rows = await client.about()
finally:
await client.stop()
"""
def __init__(
self,
url: str,
*,
username: str = "",
password: str = "",
auth_token: str = "",
tls: TLSConfig | None = None,
timeout: float = 30.0,
heartbeat: float = 30.0,
min_reconnect_delay: float = 1.0,
max_reconnect_delay: float = 60.0,
on_connect: Callable[[], Awaitable[None]] | None = None,
on_disconnect: Callable[[], Awaitable[None]] | None = None,
metrics: MetricsHooks | None = None,
compression: bool = False,
binary: bool = False,
pythonic: bool = True,
binary_compression: int | None = None,
chunked: bool = False,
) -> None:
"""Initialise the reconnecting client.
:param url: WebSocket URI (e.g. ``ws://host:8080/api/ws``).
:param username: Username for SCRAM-SHA-256 authentication.
:param password: Password for SCRAM-SHA-256 authentication.
:param auth_token: Bearer token sent on connect.
:param tls: Optional :class:`~hs_py.tls.TLSConfig` for ``wss://``.
:param timeout: Per-request timeout in seconds.
:param heartbeat: Ping interval in seconds (0 to disable).
:param min_reconnect_delay: Initial backoff delay in seconds.
:param max_reconnect_delay: Maximum backoff delay in seconds.
:param on_connect: Async callback invoked after each successful connect.
:param on_disconnect: Async callback invoked after each disconnect.
:param metrics: Optional :class:`~hs_py.metrics.MetricsHooks` callbacks.
:param compression: Enable per-message deflate compression.
:param binary: Use binary frame encoding instead of JSON envelopes.
:param pythonic: When ``True`` (default) Grid-returning methods return
``list[dict[str, Any]]``. Pass ``False`` to always return raw Grid.
:param binary_compression: Codec-level compression for binary frames.
:param chunked: Enable chunked transfer for large binary payloads.
"""
self._url = url
self._username = username
self._password = password
self._auth_token = auth_token
self._tls = tls
self._timeout = timeout
self._heartbeat = heartbeat
self._min_delay = min_reconnect_delay
self._max_delay = max_reconnect_delay
self._on_connect = on_connect
self._on_disconnect = on_disconnect
self._metrics = metrics
self._compression = compression
self._binary = binary
self._binary_compression = binary_compression
self._chunked = chunked
self._pythonic = pythonic
self._inner: WebSocketClient | None = None
self._loop_task: asyncio.Task[None] | None = None
self._connected = asyncio.Event()
self._stopping = False
self._watch_callback: Callable[[str, Grid], Any] | None = None
[docs]
async def start(self) -> None:
"""Start the connection loop in the background."""
self._stopping = False
self._loop_task = asyncio.create_task(self._connect_loop(), name="hs-ws-reconnect")
# Wait for initial connection
await self._connected.wait()
[docs]
async def stop(self) -> None:
"""Stop reconnection and close the connection."""
self._stopping = True
self._connected.clear()
if self._loop_task is not None:
self._loop_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._loop_task
self._loop_task = None
if self._inner is not None:
await self._inner.close()
self._inner = None
[docs]
def on_watch_push(self, callback: Callable[[str, Grid], Any]) -> None:
"""Register a watch push callback, preserved across reconnections."""
self._watch_callback = callback
if self._inner is not None:
self._inner.on_watch_push(callback)
# ---- Delegated ops -----------------------------------------------------
@overload
async def about(self, *, raw: Literal[True]) -> Grid: ...
@overload
async def about(self, *, raw: Literal[False] = ...) -> list[dict[str, Any]]: ...
@overload
async def about(self, *, raw: bool = ...) -> Grid | list[dict[str, Any]]: ...
[docs]
async def about(self, *, raw: bool = False) -> Grid | list[dict[str, Any]]:
"""Query server information."""
return await self._require_inner().about(raw=raw)
@overload
async def ops(self, *, raw: Literal[True]) -> Grid: ...
@overload
async def ops(self, *, raw: Literal[False] = ...) -> list[dict[str, Any]]: ...
@overload
async def ops(self, *, raw: bool = ...) -> Grid | list[dict[str, Any]]: ...
[docs]
async def ops(self, *, raw: bool = False) -> Grid | list[dict[str, Any]]:
"""Query available operations."""
return await self._require_inner().ops(raw=raw)
@overload
async def formats(self, *, raw: Literal[True]) -> Grid: ...
@overload
async def formats(self, *, raw: Literal[False] = ...) -> list[dict[str, Any]]: ...
@overload
async def formats(self, *, raw: bool = ...) -> Grid | list[dict[str, Any]]: ...
@overload
async def read(self, filter: str, limit: int | None = ..., *, raw: Literal[True]) -> Grid: ...
@overload
async def read(
self, filter: str, limit: int | None = ..., *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def read(
self, filter: str, limit: int | None = ..., *, raw: bool
) -> Grid | list[dict[str, Any]]: ...
[docs]
async def read(
self, filter: str, limit: int | None = None, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read entities matching a filter expression."""
return await self._require_inner().read(filter, limit, raw=raw)
@overload
async def read_by_ids(self, ids: list[Ref], *, raw: Literal[True]) -> Grid: ...
@overload
async def read_by_ids(
self, ids: list[Ref], *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def read_by_ids(self, ids: list[Ref], *, raw: bool) -> Grid | list[dict[str, Any]]: ...
[docs]
async def read_by_ids(
self, ids: list[Ref], *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read entities by their identifiers."""
return await self._require_inner().read_by_ids(ids, raw=raw)
@overload
async def nav(self, nav_id: str | None = ..., *, raw: Literal[True]) -> Grid: ...
@overload
async def nav(
self, nav_id: str | None = ..., *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def nav(self, nav_id: str | None = ..., *, raw: bool) -> Grid | list[dict[str, Any]]: ...
[docs]
async def nav(
self, nav_id: str | None = None, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Navigate the entity tree."""
return await self._require_inner().nav(nav_id, raw=raw)
@overload
async def his_read(self, id: Ref, range: str, *, raw: Literal[True]) -> Grid: ...
@overload
async def his_read(
self, id: Ref, range: str, *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def his_read(self, id: Ref, range: str, *, raw: bool) -> Grid | list[dict[str, Any]]: ...
[docs]
async def his_read(
self, id: Ref, range: str, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read time-series data for a single point."""
return await self._require_inner().his_read(id, range, raw=raw)
async def his_write(self, id: Ref, items: list[dict[str, Any]]) -> None:
    """Append time-series samples to one point through the inner client.

    :param id: Point reference.
    :param items: Sample rows forwarded as-is.
    """
    inner = self._require_inner()
    await inner.his_write(id, items)
@overload
async def watch_sub(
self,
ids: list[Ref],
watch_dis: str,
lease: Number | None = ...,
*,
filter: str | None = ...,
raw: Literal[True],
) -> Grid: ...
@overload
async def watch_sub(
self,
ids: list[Ref],
watch_dis: str,
lease: Number | None = ...,
*,
filter: str | None = ...,
raw: Literal[False],
) -> list[dict[str, Any]]: ...
@overload
async def watch_sub(
self,
ids: list[Ref],
watch_dis: str,
lease: Number | None = ...,
*,
filter: str | None = ...,
raw: bool,
) -> Grid | list[dict[str, Any]]: ...
[docs]
async def watch_sub(
self,
ids: list[Ref],
watch_dis: str,
lease: Number | None = None,
*,
filter: str | None = None,
raw: bool = False,
) -> Grid | list[dict[str, Any]]:
"""Create a new watch or add entities to an existing one."""
return await self._require_inner().watch_sub(ids, watch_dis, lease, filter=filter, raw=raw)
async def watch_unsub(self, watch_id: str, ids: list[Ref]) -> None:
    """Drop the given entities from watch *watch_id* via the inner client."""
    inner = self._require_inner()
    await inner.watch_unsub(watch_id, ids)
@overload
async def watch_poll(
self, watch_id: str, refresh: bool = ..., *, raw: Literal[True]
) -> Grid: ...
@overload
async def watch_poll(
self, watch_id: str, refresh: bool = ..., *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def watch_poll(
self, watch_id: str, refresh: bool = ..., *, raw: bool
) -> Grid | list[dict[str, Any]]: ...
[docs]
async def watch_poll(
self, watch_id: str, refresh: bool = False, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Poll a watch for changes."""
return await self._require_inner().watch_poll(watch_id, refresh, raw=raw)
@overload
async def invoke_action(
self,
id: Ref,
action: str,
args: dict[str, Any] | None = ...,
*,
raw: Literal[True],
) -> Grid: ...
@overload
async def invoke_action(
self,
id: Ref,
action: str,
args: dict[str, Any] | None = ...,
*,
raw: Literal[False],
) -> list[dict[str, Any]]: ...
@overload
async def invoke_action(
self,
id: Ref,
action: str,
args: dict[str, Any] | None = ...,
*,
raw: bool,
) -> Grid | list[dict[str, Any]]: ...
[docs]
async def invoke_action(
self,
id: Ref,
action: str,
args: dict[str, Any] | None = None,
*,
raw: bool = False,
) -> Grid | list[dict[str, Any]]:
"""Invoke an action on an entity."""
return await self._require_inner().invoke_action(id, action, args, raw=raw)
async def batch(self, *calls: tuple[str, Grid]) -> list[Grid]:
    """Forward several ``(op, grid)`` requests to the inner client as one batch.

    :returns: Whatever the inner client's ``batch`` returns.
    """
    inner = self._require_inner()
    return await inner.batch(*calls)
@overload
async def his_read_batch(self, ids: list[Ref], range: str, *, raw: Literal[True]) -> Grid: ...
@overload
async def his_read_batch(
self, ids: list[Ref], range: str, *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def his_read_batch(
self, ids: list[Ref], range: str, *, raw: bool
) -> Grid | list[dict[str, Any]]: ...
[docs]
async def his_read_batch(
self, ids: list[Ref], range: str, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read time-series data for multiple points."""
return await self._require_inner().his_read_batch(ids, range, raw=raw)
async def his_write_batch(self, grid: Grid) -> None:
    """Hand a multi-point history grid to the inner client for writing."""
    inner = self._require_inner()
    await inner.his_write_batch(grid)
@overload
async def point_write_array(self, id: Ref, *, raw: Literal[True]) -> Grid: ...
@overload
async def point_write_array(self, id: Ref, *, raw: Literal[False]) -> list[dict[str, Any]]: ...
@overload
async def point_write_array(self, id: Ref, *, raw: bool) -> Grid | list[dict[str, Any]]: ...
[docs]
async def point_write_array(
self, id: Ref, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read the priority array of a writable point."""
return await self._require_inner().point_write_array(id, raw=raw)
async def point_write(
    self,
    id: Ref,
    level: int,
    val: Any,
    who: str = "",
    duration: Number | None = None,
) -> None:
    """Set one level of a writable point's priority array via the inner client.

    All arguments are forwarded positionally and unchanged.
    """
    inner = self._require_inner()
    await inner.point_write(id, level, val, who, duration)
async def watch_close(self, watch_id: str) -> None:
    """Tear down watch *watch_id* completely using the inner client."""
    inner = self._require_inner()
    await inner.watch_close(watch_id)
async def close(self) -> None:
    """Shut the client down; equivalent to awaiting :meth:`stop`."""
    stop = self.stop
    await stop()
# ---- Internal ----------------------------------------------------------
async def _connect_loop(self) -> None:
    """Background loop that keeps an inner :class:`WebSocketClient` connected.

    Each iteration builds and enters a fresh client, publishes it as
    ``self._inner``, sets ``self._connected`` and awaits ``on_connect``.
    It then waits for the client's receive task to finish, which is taken
    to mean the connection was lost.

    When an iteration ends (connection lost, connect failure, or an error in
    ``on_connect``), the loop tears down in this order:

    1. Clear the published state.
    2. Close the client. It is always closed, not only when stopping;
       otherwise every dropped connection would leak a socket and its
       background tasks.
    3. Await ``on_disconnect``.
    4. Sleep with exponential back-off: from ``_min_delay``, doubling up to
       ``_max_delay``, and reset after each successful connect.

    The loop exits when ``self._stopping`` is set or the task is cancelled.
    """
    delay = self._min_delay
    while not self._stopping:
        # Fresh per iteration so a failure before a new client exists never
        # refers back to the previous, already-closed client.
        client: WebSocketClient | None = None
        try:
            client = WebSocketClient(
                self._url,
                username=self._username,
                password=self._password,
                auth_token=self._auth_token,
                tls=self._tls,
                timeout=self._timeout,
                heartbeat=self._heartbeat,
                metrics=self._metrics,
                compression=self._compression,
                binary=self._binary,
                pythonic=self._pythonic,
                binary_compression=self._binary_compression,
                chunked=self._chunked,
            )
            await client.__aenter__()
            self._inner = client
            if self._watch_callback is not None:
                client.on_watch_push(self._watch_callback)
            delay = self._min_delay  # Reset on successful connect
            self._connected.set()
            if self._on_connect is not None:
                await self._on_connect()
            _log.info("ReconnectingWebSocketClient connected to %s", self._url)
            # Wait for the recv loop to end (connection lost)
            if client._recv_task is not None:
                await client._recv_task
        except asyncio.CancelledError:
            # Do not leave a closed client published as "connected".
            self._connected.clear()
            self._inner = None
            if client is not None:
                with contextlib.suppress(Exception):
                    await client.close()
            return
        except Exception:
            _log.debug("Connection attempt to %s failed", self._url, exc_info=True)

        # Connection lost or failed
        self._connected.clear()
        self._inner = None
        if client is not None:
            with contextlib.suppress(Exception):
                await client.close()
        if self._on_disconnect is not None:
            with contextlib.suppress(Exception):
                await self._on_disconnect()
        if self._stopping:
            return
        _log.info("Reconnecting to %s in %.1fs", self._url, delay)
        await asyncio.sleep(delay)
        delay = min(delay * 2, self._max_delay)
def _require_inner(self) -> WebSocketClient:
    """Hand back the live inner client, refusing while disconnected.

    :raises NetworkError: When no inner client is published or the
        connected flag is not set.
    """
    inner = self._inner
    if inner is not None and self._connected.is_set():
        return inner
    raise NetworkError("Not connected (reconnecting)")
[docs]
class ChannelClient:
"""Virtual client scoped to a single channel within a :class:`WebSocketPool`.
Each channel has its own request ID space and pending futures.
The channel name is included in every JSON envelope as the ``ch`` field.
"""
def __init__(self, pool: WebSocketPool, channel: str, *, pythonic: bool = True) -> None:
"""Initialise a channel client.
:param pool: Parent :class:`WebSocketPool`.
:param channel: Channel name included in every message.
:param pythonic: When ``True`` (default) Grid-returning methods return
``list[dict[str, Any]]``. Pass ``False`` to return raw Grid.
"""
self._pool = pool
self._channel = channel
self._pythonic = pythonic
@overload
async def about(self, *, raw: Literal[True]) -> Grid: ...
@overload
async def about(self, *, raw: Literal[False] = ...) -> list[dict[str, Any]]: ...
@overload
async def about(self, *, raw: bool = ...) -> Grid | list[dict[str, Any]]: ...
[docs]
async def about(self, *, raw: bool = False) -> Grid | list[dict[str, Any]]:
"""Query server information."""
grid = await self._pool._call(self._channel, "about", Grid.make_empty())
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
@overload
async def ops(self, *, raw: Literal[True]) -> Grid: ...
@overload
async def ops(self, *, raw: Literal[False] = ...) -> list[dict[str, Any]]: ...
@overload
async def ops(self, *, raw: bool = ...) -> Grid | list[dict[str, Any]]: ...
[docs]
async def ops(self, *, raw: bool = False) -> Grid | list[dict[str, Any]]:
"""Query available operations."""
grid = await self._pool._call(self._channel, "ops", Grid.make_empty())
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
@overload
async def read(self, filter: str, limit: int | None = ..., *, raw: Literal[True]) -> Grid: ...
@overload
async def read(
self, filter: str, limit: int | None = ..., *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def read(
self, filter: str, limit: int | None = ..., *, raw: bool
) -> Grid | list[dict[str, Any]]: ...
[docs]
async def read(
self, filter: str, limit: int | None = None, *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read entities matching a filter expression."""
row: dict[str, Any] = {"filter": filter}
if limit is not None:
row["limit"] = Number(float(limit))
grid_req = GridBuilder().add_col("filter").add_col("limit").add_row(row).to_grid()
grid = await self._pool._call(self._channel, "read", grid_req)
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
@overload
async def read_by_ids(self, ids: list[Ref], *, raw: Literal[True]) -> Grid: ...
@overload
async def read_by_ids(
self, ids: list[Ref], *, raw: Literal[False]
) -> list[dict[str, Any]]: ...
@overload
async def read_by_ids(self, ids: list[Ref], *, raw: bool) -> Grid | list[dict[str, Any]]: ...
[docs]
async def read_by_ids(
self, ids: list[Ref], *, raw: bool = False
) -> Grid | list[dict[str, Any]]:
"""Read entities by their identifiers."""
builder = GridBuilder().add_col("id")
for ref in ids:
builder.add_row({"id": ref})
grid = await self._pool._call(self._channel, "read", builder.to_grid())
return grid if raw or not self._pythonic else grid_to_pythonic(grid)
[docs]
class WebSocketPool:
"""Multiplexes multiple logical channels over a single WebSocket.
Each channel is identified by a string name included in the JSON
envelope as the ``ch`` field.
Usage::
async with WebSocketPool("ws://host:8080/api/ws") as pool:
ch1 = pool.channel("tenant-1")
ch2 = pool.channel("tenant-2")
about1 = await ch1.about()
about2 = await ch2.about()
"""
def __init__(
self,
url: str,
*,
username: str = "",
password: str = "",
auth_token: str = "",
tls: TLSConfig | None = None,
timeout: float = 30.0,
heartbeat: float = 30.0,
compression: bool = False,
pythonic: bool = True,
) -> None:
"""Initialise the connection pool.
:param url: WebSocket URI (e.g. ``ws://host:8080/api/ws``).
:param username: Username for SCRAM-SHA-256 authentication.
:param password: Password for SCRAM-SHA-256 authentication.
:param auth_token: Bearer token sent on connect.
:param tls: Optional :class:`~hs_py.tls.TLSConfig` for ``wss://``.
:param timeout: Per-request timeout in seconds.
:param heartbeat: Ping interval in seconds (0 to disable).
:param compression: Enable per-message deflate compression.
:param pythonic: When ``True`` (default) channels return
``list[dict[str, Any]]``. Pass ``False`` to return raw Grid.
"""
self._url = url
self._username = username
self._password = password
self._auth_token = auth_token
self._tls = tls
self._timeout = timeout
self._heartbeat = heartbeat
self._compression = compression
self._pythonic = pythonic
self._ws: HaystackWebSocket | None = None
self._next_id = 0
self._pending: dict[str, asyncio.Future[Grid]] = {}
self._recv_task: asyncio.Task[None] | None = None
self._heartbeat_task: asyncio.Task[None] | None = None
async def __aenter__(self) -> WebSocketPool:
    """Open the WebSocket, authenticate, then start the receive/heartbeat tasks.

    Python does not call ``__aexit__`` when ``__aenter__`` raises, so a
    failed authentication closes the freshly opened socket here before the
    error propagates. Otherwise the connection would leak.
    """
    ssl_ctx = build_client_ssl_context(self._tls) if self._tls else None
    self._ws = await HaystackWebSocket.connect(
        self._url, ssl_ctx, compression=self._compression
    )
    try:
        # Authenticate before starting recv loop: the handshake reads frames
        # directly from the socket.
        if self._username:
            await self._pool_scram_authenticate()
        elif self._auth_token:
            await self._ws.send_text(
                orjson.dumps({"type": "auth", "token": self._auth_token}).decode()
            )
    except BaseException:
        ws, self._ws = self._ws, None
        if ws is not None:
            with contextlib.suppress(Exception):
                await ws.close()
        raise
    self._recv_task = asyncio.create_task(self._recv_loop(), name="hs-pool-recv")
    if self._heartbeat > 0:
        self._heartbeat_task = asyncio.create_task(
            heartbeat_loop(self._ws, self._heartbeat), name="hs-pool-heartbeat"
        )
    return self
async def __aexit__(self, *exc: object) -> None:
await self.close()
async def close(self) -> None:
    """Close the pooled WebSocket connection.

    Cancels the heartbeat and receive tasks, closes the socket, and fails
    every request still awaiting a response with :class:`NetworkError`.
    This is the same error type the receive loop uses when the connection
    drops.

    Waiters are failed rather than cancelled: cancelling the future would
    surface as ``CancelledError`` in tasks that were never cancelled. They
    are failed in ``finally`` so they are resolved even if closing the
    socket raises.
    """
    try:
        await cancel_task(self._heartbeat_task)
        self._heartbeat_task = None
        await cancel_task(self._recv_task)
        self._recv_task = None
        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()
    finally:
        waiters = list(self._pending.values())
        self._pending.clear()
        for fut in waiters:
            if not fut.done():
                fut.set_exception(NetworkError("Connection closed"))
async def _pool_scram_authenticate(self) -> None:
    """Perform the SCRAM-SHA-256 handshake (same protocol as WebSocketClient).

    Runs three request/reply round trips directly on the socket, before the
    receive loop is started:

    1. ``hello`` with the base64url-encoded username. The reply supplies a
       ``handshakeToken`` and the ``hash`` name (default ``SHA-256``).
    2. ``scram`` with the client-first message. The reply carries the
       server-first message and a new handshake token.
    3. ``scram`` with the client-final message. The reply must be
       ``authOk`` and may carry the server signature, which is then verified.

    :raises AuthError: If a reply has an unexpected ``type`` or is not a
        JSON object.
    :raises RuntimeError: If the pool has no open socket.
    """
    from hs_py.auth import (
        _b64url_decode,
        _b64url_encode,
        scram_client_final,
        scram_client_first,
        verify_server_signature,
    )

    # Explicit check instead of ``assert`` (which is stripped under -O).
    ws = self._require_ws()

    async def exchange(payload: dict[str, Any], expected: str, failure: str) -> dict[str, Any]:
        # Send one JSON frame and return the decoded reply, enforcing its type.
        await ws.send_text(orjson.dumps(payload).decode())
        reply = orjson.loads(await ws.recv())
        reply_type = reply.get("type") if isinstance(reply, dict) else None
        if reply_type != expected:
            raise AuthError(f"{failure}, got: {reply_type}")
        return reply

    first = scram_client_first(self._username)
    hello = await exchange(
        {"type": "hello", "username": _b64url_encode(self._username.encode())},
        "hello",
        "Expected hello response",
    )
    hash_name = hello.get("hash", "SHA-256")
    server_first = await exchange(
        {
            "type": "scram",
            "handshakeToken": hello.get("handshakeToken", ""),
            "data": _b64url_encode(first.client_first_msg.encode()),
        },
        "scram",
        "Expected scram response",
    )
    server_first_msg = _b64url_decode(server_first.get("data", "")).decode()
    final = scram_client_final(self._password, first, server_first_msg, hash_name)
    auth_ok = await exchange(
        {
            "type": "scram",
            "handshakeToken": server_first.get("handshakeToken", ""),
            "data": _b64url_encode(final.client_final_msg.encode()),
        },
        "authOk",
        "SCRAM auth failed",
    )
    server_data = auth_ok.get("data", "")
    # NOTE(review): the server signature is only verified when the server
    # sends one; an ``authOk`` without ``data`` is accepted, so the server is
    # not authenticated in that case. Consider requiring it if all target
    # servers supply it.
    if server_data:
        verify_server_signature(final, _b64url_decode(server_data).decode())
def channel(self, name: str) -> ChannelClient:
    """Create a :class:`ChannelClient` that talks to channel *name* through this pool.

    The new client inherits the pool's ``pythonic`` setting.

    :param name: Channel identifier (sent as ``ch`` in each message).
    :returns: A fresh :class:`ChannelClient` bound to this pool.
    """
    pythonic = self._pythonic
    return ChannelClient(self, name, pythonic=pythonic)
# ---- Internal ----------------------------------------------------------
async def _call(self, ch: str, op: str, grid: Grid) -> Grid:
    """Send a channel-scoped request and await the matching response.

    A future is registered under a fresh request id and resolved by the
    receive loop. The entry is removed in ``finally`` so it is also dropped
    when the caller is cancelled. ``CancelledError`` is not an ``Exception``
    subclass, so an ``except Exception`` cleanup would leave it behind.

    :param ch: Channel name placed in the envelope's ``ch`` field.
    :param op: Operation name.
    :param grid: Request grid.
    :returns: The decoded response grid.
    :raises NetworkError: If no response arrives within ``self._timeout``
        seconds, or if the receive loop fails the request.
    :raises CallError: If the server answers with an error grid.
    """
    ws = self._require_ws()
    req_id = str(self._next_id)
    # Request ids wrap at 32 bits.
    self._next_id = (self._next_id + 1) & 0xFFFF_FFFF
    envelope = orjson.dumps(
        {"id": req_id, "op": op, "grid": encode_grid_dict(grid), "ch": ch}
    ).decode()
    fut: asyncio.Future[Grid] = asyncio.get_running_loop().create_future()
    self._pending[req_id] = fut
    try:
        await ws.send_text(envelope)
        async with asyncio.timeout(self._timeout):
            return await fut
    except TimeoutError as exc:
        raise NetworkError(f"Channel '{ch}' request '{op}' timed out") from exc
    finally:
        self._pending.pop(req_id, None)
async def _recv_loop(self) -> None:
    """Receive frames forever and hand each decoded message to the dispatcher.

    A frame may hold one JSON value or a JSON array of them (batch
    responses). Frames that are not valid JSON are skipped. If receiving
    fails, every outstanding request is failed with :class:`NetworkError`.
    Cancellation ends the loop quietly.
    """
    ws = self._require_ws()
    try:
        while True:
            frame = await ws.recv()
            try:
                decoded = orjson.loads(frame)
            except orjson.JSONDecodeError:
                continue
            messages = decoded if isinstance(decoded, list) else [decoded]
            for message in messages:
                self._dispatch_response(message)
    except asyncio.CancelledError:
        return
    except Exception as exc:
        _log.debug("WebSocketPool recv loop ended: %s", exc)
        for waiter in self._pending.values():
            if not waiter.done():
                waiter.set_exception(NetworkError(f"Connection lost: {exc}"))
        self._pending.clear()
def _dispatch_response(self, msg: dict[str, Any]) -> None:
"""Route a response to the correct pending future."""
_resolve_grid_response(self._pending, msg)
def _require_ws(self) -> HaystackWebSocket:
"""Return the active WebSocket or raise."""
if self._ws is None:
msg = "WebSocketPool is not open. Use 'async with WebSocketPool(...) as p:'"
raise RuntimeError(msg)
return self._ws