"""Sans-I/O WebSocket wrapper for BACnet/SC (AB.7).
Uses the ``websockets`` library's sans-I/O protocol objects together with
``asyncio`` TCP/TLS streams. Each :class:`SCWebSocket` instance owns one
WebSocket connection backed by a ``(StreamReader, StreamWriter)`` pair.
"""
from __future__ import annotations
import asyncio
import collections
import contextlib
import logging
import socket
from typing import TYPE_CHECKING
from urllib.parse import urlparse
from websockets.client import ClientProtocol
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK, ProtocolError
from websockets.frames import Close, Frame, Opcode
from websockets.http11 import Request
from websockets.protocol import State as _WSState
from websockets.server import ServerProtocol
from websockets.typing import Subprotocol
from websockets.uri import parse_uri
if TYPE_CHECKING:
import ssl
from asyncio import StreamReader, StreamWriter
logger = logging.getLogger(__name__)
# Read buffer size for asyncio streams
_READ_SIZE = 65536
# Write buffer high/low water marks for StreamWriter. Lower than the
# asyncio default (64 KiB) so that backpressure triggers earlier for
# slow or dead peers — BACnet/SC frames are typically < 1600 bytes.
_WRITE_HIGH_WATER = 32768
_WRITE_LOW_WATER = 8192
def _set_nodelay(writer: StreamWriter) -> None:
"""Enable TCP_NODELAY and tune write buffer limits on the socket.
Disables Nagle's algorithm so small BACnet/SC frames are sent immediately
rather than being buffered. Critical for low-latency request-response
patterns where Nagle + delayed-ACK interaction can add 40-200ms stalls.
Also sets write buffer water marks lower than the asyncio default so
backpressure triggers earlier for slow peers.
"""
sock = writer.get_extra_info("socket")
if sock is not None:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
transport = writer.transport
if transport is not None:
transport.set_write_buffer_limits(high=_WRITE_HIGH_WATER, low=_WRITE_LOW_WATER)
def _drain_to_send(protocol: ClientProtocol | ServerProtocol) -> bytes:
"""Collect all pending outgoing data from the protocol."""
chunks = protocol.data_to_send()
return b"".join(chunks)
def _write_pending(protocol: ClientProtocol | ServerProtocol, writer: StreamWriter) -> bool:
"""Write all pending protocol data directly to the writer.
Writes each chunk individually to avoid the ``b"".join()`` copy.
Returns True if any data was written.
"""
wrote = False
for chunk in protocol.data_to_send():
if chunk:
writer.write(chunk)
wrote = True
return wrote
[docs]
class SCWebSocket:
"""Async WebSocket connection using the websockets sans-I/O protocol.
Each instance represents one open WebSocket connection and is backed
by asyncio ``StreamReader``/``StreamWriter`` objects. All framing is
handled by the ``websockets`` sans-I/O ``ClientProtocol`` or
``ServerProtocol``.
"""
def __init__(
self,
reader: StreamReader,
writer: StreamWriter,
protocol: ClientProtocol | ServerProtocol,
*,
max_frame_size: int = 0,
) -> None:
self._reader = reader
self._writer = writer
self._protocol = protocol
self._max_frame_size = max_frame_size
# Buffer for events not yet consumed by recv(). When multiple
# WebSocket frames arrive in one TCP segment, events_received()
# returns all of them but recv() processes only one at a time.
self._pending_events: collections.deque[Frame] = collections.deque(maxlen=64)
self._oversize_count = 0 # consecutive oversized frame counter
# -- Client factory --
[docs]
@classmethod
async def connect(
cls,
uri: str,
ssl_ctx: ssl.SSLContext | None,
subprotocol: str,
*,
handshake_timeout: float = 10.0,
max_size: int | None = None,
) -> SCWebSocket:
"""Initiate a WebSocket client connection.
:param uri: WebSocket URI (``wss://host:port/path``).
:param ssl_ctx: TLS context, or None for plaintext ``ws://``.
:param subprotocol: WebSocket subprotocol to negotiate.
:param handshake_timeout: Maximum seconds for the WebSocket handshake.
:param max_size: Maximum WebSocket message size. Passed to the
protocol layer so oversized frames are rejected early.
"""
parsed = urlparse(uri)
host = parsed.hostname or "localhost"
default_port = 443 if parsed.scheme == "wss" else 80
port = parsed.port or default_port
use_ssl = ssl_ctx if parsed.scheme == "wss" else None
if use_ssl is None:
logger.warning(
"SC WebSocket connecting WITHOUT TLS to %s:%d — "
"traffic is unencrypted and unauthenticated. "
"BACnet/SC requires TLS 1.3 in production (Annex AB.7.4).",
host,
port,
)
else:
logger.debug("SC WebSocket connecting (TLS) to %s:%d", host, port)
reader, writer = await asyncio.open_connection(host, port, ssl=use_ssl)
_set_nodelay(writer)
ws_uri = parse_uri(uri)
protocol = ClientProtocol(
ws_uri,
subprotocols=[Subprotocol(subprotocol)],
max_size=max_size,
)
request = protocol.connect()
protocol.send_request(request)
outgoing = _drain_to_send(protocol)
if outgoing:
writer.write(outgoing)
await writer.drain()
# Read HTTP response (with timeout to prevent indefinite hangs)
try:
async with asyncio.timeout(handshake_timeout):
while True:
data = await reader.read(_READ_SIZE)
if not data:
msg = "Connection closed during WebSocket handshake"
raise ConnectionError(msg)
protocol.receive_data(data)
if protocol.handshake_exc is not None:
raise protocol.handshake_exc
# Check for events that indicate handshake completion
events = protocol.events_received()
if events:
break
# Also check if there's data to send (e.g. during upgrade)
outgoing = _drain_to_send(protocol)
if outgoing:
writer.write(outgoing)
await writer.drain()
except BaseException:
writer.close()
raise
logger.debug("SC WebSocket client connected to %s:%d", host, port)
return cls(reader, writer, protocol)
# -- Server factory --
[docs]
@classmethod
async def accept(
cls,
reader: StreamReader,
writer: StreamWriter,
subprotocol: str,
*,
handshake_timeout: float = 10.0,
max_size: int | None = None,
) -> SCWebSocket:
"""Accept an inbound WebSocket connection on existing streams.
:param reader: asyncio StreamReader from accepted connection.
:param writer: asyncio StreamWriter from accepted connection.
:param subprotocol: WebSocket subprotocol to accept.
:param handshake_timeout: Maximum seconds to wait for the handshake.
:param max_size: Maximum WebSocket message size. Passed to the
protocol layer so oversized frames are rejected early.
"""
protocol = ServerProtocol(
subprotocols=[Subprotocol(subprotocol)],
max_size=max_size,
)
_set_nodelay(writer)
# Read the HTTP upgrade request (with timeout to prevent slow clients)
async with asyncio.timeout(handshake_timeout):
while True:
data = await reader.read(_READ_SIZE)
if not data:
msg = "Connection closed before WebSocket handshake"
raise ConnectionError(msg)
protocol.receive_data(data)
events = protocol.events_received()
if events:
request = events[0]
if not isinstance(request, Request):
msg = f"Expected HTTP request, got {type(request)}"
raise ProtocolError(msg)
break
# Accept the connection with the desired subprotocol
response = protocol.accept(request)
protocol.send_response(response)
outgoing = _drain_to_send(protocol)
if outgoing:
writer.write(outgoing)
await writer.drain()
if protocol.handshake_exc is not None:
raise protocol.handshake_exc
logger.debug("SC WebSocket server accepted connection")
return cls(reader, writer, protocol)
# -- I/O operations --
[docs]
async def send(self, data: bytes) -> None:
"""Send a binary WebSocket frame."""
logger.debug("SC WebSocket send: %d bytes", len(data))
self._protocol.send_binary(data)
if _write_pending(self._protocol, self._writer):
await self._writer.drain()
send_raw = send # Alias — same operation, named for semantic clarity
[docs]
def write_no_drain(self, data: bytes) -> bool:
"""Buffer a binary WebSocket frame without draining.
Returns True if data was written to the transport buffer.
Call :meth:`drain` afterwards to flush. Used by hub broadcast
to batch writes across connections before draining concurrently.
"""
self._protocol.send_binary(data)
return _write_pending(self._protocol, self._writer)
[docs]
async def drain(self) -> None:
"""Drain the write buffer. Pair with :meth:`write_no_drain`."""
await self._writer.drain()
[docs]
async def recv(self) -> bytes:
"""Receive the next binary WebSocket message.
:raises ConnectionClosedOK: On graceful close.
:raises ConnectionClosedError: On abnormal close.
"""
while True:
# Drain pending events first (from a previous multi-frame read)
while self._pending_events:
event = self._pending_events.popleft()
result = await self._process_frame(event)
if result is not None:
return result
# Fetch new events from the protocol
events = self._protocol.events_received()
for i, raw_event in enumerate(events):
if isinstance(raw_event, Frame):
result = await self._process_frame(raw_event)
if result is not None:
# Stash remaining events for subsequent recv() calls
for remaining in events[i + 1 :]:
if isinstance(remaining, Frame):
self._pending_events.append(remaining)
return result
# Need more data from the network
data = await self._reader.read(_READ_SIZE)
if not data:
logger.warning("SC WebSocket connection closed unexpectedly")
raise ConnectionClosedError(None, None, rcvd_then_sent=None)
self._protocol.receive_data(data)
if self._protocol.handshake_exc is not None:
raise self._protocol.handshake_exc
await self._flush_outgoing()
async def _process_frame(self, event: Frame) -> bytes | None:
"""Process a single WebSocket frame.
Returns the payload bytes for BINARY frames, None for control frames.
Raises on CLOSE frames.
"""
if event.opcode == Opcode.BINARY:
if self._max_frame_size and len(event.data) > self._max_frame_size:
self._oversize_count += 1
logger.warning(
"SC WebSocket frame too large: %d bytes (max %d), dropping (%d consecutive)",
len(event.data),
self._max_frame_size,
self._oversize_count,
)
if self._oversize_count >= 3:
raise ConnectionClosedError(None, None, rcvd_then_sent=None)
return None
self._oversize_count = 0
logger.debug("SC WebSocket recv: %d bytes", len(event.data))
return bytes(event.data)
if event.opcode == Opcode.CLOSE:
rcvd = Close.parse(event.data) if event.data else None
await self._flush_outgoing()
raise ConnectionClosedOK(rcvd, None, rcvd_then_sent=None)
if event.opcode in (Opcode.PING, Opcode.PONG):
await self._flush_outgoing()
return None
# Non-binary data frames (e.g. TEXT): close with 1003 per AB.7.5.3
if event.opcode == Opcode.TEXT:
logger.warning(
"SC WebSocket received TEXT frame, closing with status 1003 per AB.7.5.3"
)
try:
self._protocol.send_close(1003, "Non-binary data not accepted")
_write_pending(self._protocol, self._writer)
except (OSError, ConnectionError):
pass
raise ConnectionClosedError(None, None, rcvd_then_sent=None)
return None
[docs]
async def close(self, code: int = 1000, reason: str = "") -> None:
"""Initiate graceful WebSocket close."""
try:
self._protocol.send_close(code, reason)
outgoing = _drain_to_send(self._protocol)
if outgoing:
self._writer.write(outgoing)
await self._writer.drain()
# Wait briefly for close acknowledgement
try:
async with asyncio.timeout(5):
data = await self._reader.read(_READ_SIZE)
if data:
self._protocol.receive_data(data)
except (TimeoutError, OSError, ConnectionError):
pass
except (OSError, ConnectionError):
pass
finally:
self._close_transport()
async def _flush_outgoing(self) -> None:
"""Write any pending protocol output to the transport."""
if _write_pending(self._protocol, self._writer):
with contextlib.suppress(OSError, ConnectionError):
await self._writer.drain()
def _close_transport(self) -> None:
"""Close the underlying TCP connection."""
try:
if not self._writer.is_closing():
self._writer.close()
except (OSError, RuntimeError):
pass
@property
def is_open(self) -> bool:
"""Return True if the WebSocket connection appears open."""
return self._protocol.state is _WSState.OPEN
@property
def subprotocol(self) -> str | None:
"""Return the negotiated WebSocket subprotocol."""
return self._protocol.subprotocol