# Source code for bac_py.transport.sc.hub_function

"""BACnet/SC Hub Function (AB.5.3).

The Hub Function accepts BACnet/SC connections as hub connections and
routes unicast/broadcast BVLC-SC messages between connected nodes.
"""

from __future__ import annotations

import asyncio
import contextlib
import logging
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from bac_py.transport.sc.connection import (
    SCConnection,
    SCConnectionConfig,
    SCConnectionState,
)
from bac_py.transport.sc.tls import SCTLSConfig, build_server_ssl_context
from bac_py.transport.sc.types import SC_HUB_SUBPROTOCOL, VMAC_BROADCAST
from bac_py.transport.sc.websocket import SCWebSocket

if TYPE_CHECKING:
    from bac_py.transport.sc.bvlc import SCMessage
    from bac_py.transport.sc.vmac import SCVMAC, DeviceUUID

logger = logging.getLogger(__name__)
_DEBUG = logging.DEBUG


@dataclass
class SCHubConfig:
    """Configuration for an SC Hub Function.

    Every field has a default, so ``SCHubConfig()`` yields a usable
    configuration. Mutable sub-configurations are created per instance via
    ``default_factory``.
    """

    # Host/interface passed to ``asyncio.start_server`` by the hub.
    bind_address: str = "0.0.0.0"
    # TCP port passed to ``asyncio.start_server`` by the hub.
    bind_port: int = 4443
    # TLS settings handed to ``build_server_ssl_context`` when the hub starts.
    tls_config: SCTLSConfig = field(default_factory=SCTLSConfig)
    # Per-connection settings handed to every accepted ``SCConnection``.
    connection_config: SCConnectionConfig = field(default_factory=SCConnectionConfig)
    # Upper bound on registered connections; also caps pending VMAC
    # reservations made during handshakes.
    max_connections: int = 1000
    # Passed as the WebSocket ``max_size`` and as the connection's
    # ``max_bvlc_length``.
    max_bvlc_length: int = 6000
    # Passed to each accepted connection as ``max_npdu_length``.
    max_npdu_length: int = 1497
class SCHubFunction:
    """BACnet/SC Hub Function (AB.5.3).

    WebSocket server that accepts hub connections from SC nodes. Routes
    unicast messages to the destination VMAC and replicates broadcasts to
    all connected nodes except the source.

    Internal state:

    - ``_connections``: registered connections keyed by peer VMAC.
    - ``_uuid_map``: device UUID -> VMAC of the registered connection.
    - ``_pending_vmacs``: VMACs reserved by in-flight handshakes, mapped to
      the ``time.monotonic()`` timestamp of the reservation.
    - ``_client_tasks``: background tasks owned by the hub (handshakes and
      deferred disconnects); cancelled and awaited by :meth:`stop`.
    """

    # Seconds after which an unreleased pending VMAC reservation is purged.
    _PENDING_VMAC_TTL = 30.0

    def __init__(
        self,
        hub_vmac: SCVMAC,
        hub_uuid: DeviceUUID,
        config: SCHubConfig | None = None,
    ) -> None:
        """Create a hub function.

        :param hub_vmac: VMAC the hub presents to connecting nodes.
        :param hub_uuid: Device UUID the hub presents to connecting nodes.
        :param config: Hub configuration; a default ``SCHubConfig`` is used
            when omitted.
        """
        self._config = config or SCHubConfig()
        self._hub_vmac = hub_vmac
        self._hub_uuid = hub_uuid
        self._connections: dict[SCVMAC, SCConnection] = {}
        self._uuid_map: dict[DeviceUUID, SCVMAC] = {}
        self._pending_vmacs: dict[SCVMAC, float] = {}
        self._server: asyncio.Server | None = None
        self._client_tasks: set[asyncio.Task[None]] = set()

    @property
    def connections(self) -> dict[SCVMAC, SCConnection]:
        """Active connections indexed by peer VMAC (shallow copy)."""
        return dict(self._connections)

    @property
    def connection_count(self) -> int:
        """Number of active hub connections."""
        return len(self._connections)

    async def start(self) -> None:
        """Start the hub function WebSocket server.

        Logs a warning when no TLS context could be built, because the
        listener then runs unencrypted.
        """
        ssl_ctx = build_server_ssl_context(self._config.tls_config)
        if ssl_ctx is None:
            logger.warning(
                "SC Hub Function starting WITHOUT TLS on %s:%d — "
                "client connections will be unencrypted and unauthenticated",
                self._config.bind_address,
                self._config.bind_port,
            )
        self._server = await asyncio.start_server(
            self._handle_client,
            self._config.bind_address,
            self._config.bind_port,
            ssl=ssl_ctx,
        )
        logger.info(
            "SC Hub Function listening on %s:%d",
            self._config.bind_address,
            self._config.bind_port,
        )

    async def stop(self) -> None:
        """Stop the hub function, force-close all connections."""
        logger.info("SC hub function stopping")

        # Stop accepting new connections first to prevent new clients
        # arriving during shutdown cleanup.
        if self._server:
            self._server.close()

        # Force-close all connections (before wait_closed(), since
        # wait_closed() blocks until active connection handlers finish).
        for conn in list(self._connections.values()):
            with contextlib.suppress(Exception):
                await conn._go_idle()

        # Cancel outstanding client tasks and wait for them to unwind.
        for task in self._client_tasks:
            if not task.done():
                task.cancel()
        if self._client_tasks:
            await asyncio.gather(*self._client_tasks, return_exceptions=True)
        self._client_tasks.clear()

        self._connections.clear()
        self._uuid_map.clear()
        self._pending_vmacs.clear()

        if self._server:
            await self._server.wait_closed()
            self._server = None
        logger.info("SC hub function stopped")

    def _track_task(self, task: asyncio.Task[None]) -> None:
        """Register *task* so that ``stop()`` can cancel and await it."""
        self._client_tasks.add(task)
        task.add_done_callback(self._on_client_task_done)

    async def _handle_client(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Handle an inbound WebSocket connection.

        Rejects the socket when the connection table is full, performs the
        WebSocket accept, wires the connection callbacks and then runs the
        BACnet/SC handshake in a tracked background task.
        """
        if len(self._connections) >= self._config.max_connections:
            writer.close()
            return

        try:
            ws = await SCWebSocket.accept(
                reader, writer, SC_HUB_SUBPROTOCOL, max_size=self._config.max_bvlc_length
            )
        except Exception:
            logger.debug("WebSocket accept failed", exc_info=True)
            writer.close()
            return

        conn = SCConnection(
            self._hub_vmac,
            self._hub_uuid,
            config=self._config.connection_config,
            max_bvlc_length=self._config.max_bvlc_length,
            max_npdu_length=self._config.max_npdu_length,
            hub_mode=True,
        )
        conn.on_connected = lambda: self._on_node_connected(conn)
        conn.on_disconnected = lambda: self._on_node_disconnected(conn)

        async def _route(msg: SCMessage, raw: bytes | None = None) -> None:
            await self._on_node_message(conn, msg, raw)

        conn.on_message = _route

        self._track_task(asyncio.create_task(self._accept_with_cleanup(conn, ws)))

    async def _accept_with_cleanup(self, conn: SCConnection, ws: SCWebSocket) -> None:
        """Run the accept handshake, releasing any pending VMAC on failure."""
        try:
            await conn.accept(ws, vmac_checker=self._check_vmac)
        finally:
            # If the handshake failed before on_connected ran, the VMAC may
            # still be reserved in _pending_vmacs. pop(..., None) is a no-op
            # when on_connected already released it.
            if conn.peer_vmac:
                self._pending_vmacs.pop(conn.peer_vmac, None)

    def _on_client_task_done(self, task: asyncio.Task[None]) -> None:
        """Clean up a finished client task and log unexpected errors."""
        self._client_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            logger.debug("SC hub client task failed: %s", task.exception())

    def _check_vmac(self, vmac: SCVMAC, uuid: DeviceUUID) -> bool:
        """Check if VMAC/UUID pair is acceptable (no collision).

        Reserves the VMAC in ``_pending_vmacs`` to prevent a TOCTOU race
        between the check and the registration in ``_on_node_connected``.
        The reservation is released by ``_on_node_connected`` on success or
        ``_accept_with_cleanup`` on failure. Stale pending entries (older
        than ``_PENDING_VMAC_TTL`` seconds) are purged on each call. The
        pending set is capped at ``max_connections``.

        :returns: ``True`` when the VMAC was reserved for the caller,
            ``False`` when it collides with a connection owned by a
            different UUID, is already reserved, or the pending set is full.
        """
        now = time.monotonic()
        stale = [
            key
            for key, reserved_at in self._pending_vmacs.items()
            if now - reserved_at > self._PENDING_VMAC_TTL
        ]
        for key in stale:
            self._pending_vmacs.pop(key, None)

        existing = self._connections.get(vmac)
        if existing is not None and existing.peer_uuid != uuid:
            return False
        if vmac in self._pending_vmacs:
            return False
        if len(self._pending_vmacs) >= self._config.max_connections:
            return False

        self._pending_vmacs[vmac] = now
        return True

    def _on_node_connected(self, conn: SCConnection) -> None:
        """Register a newly connected node.

        Releases the pending VMAC reservation, retires any connection this
        one supersedes and records the connection in both lookup tables.
        """
        peer_vmac = conn.peer_vmac
        peer_uuid = conn.peer_uuid
        if peer_vmac is None or peer_uuid is None:
            return
        self._pending_vmacs.pop(peer_vmac, None)

        # If this UUID was previously connected with a different VMAC,
        # drop the old registration and close that connection.
        old_vmac = self._uuid_map.get(peer_uuid)
        if old_vmac is not None and old_vmac != peer_vmac and old_vmac in self._connections:
            old_conn = self._connections.pop(old_vmac)
            self._track_task(asyncio.create_task(old_conn.disconnect()))

        # A reconnect with the same VMAC replaces the table entry below;
        # close the superseded connection instead of leaving it orphaned
        # (it would no longer be reachable through _connections).
        superseded = self._connections.get(peer_vmac)
        if superseded is not None and superseded is not conn:
            self._track_task(asyncio.create_task(superseded.disconnect()))

        self._connections[peer_vmac] = conn
        self._uuid_map[peer_uuid] = peer_vmac
        logger.info("SC node connected: VMAC=%s", peer_vmac)

    def _on_node_disconnected(self, conn: SCConnection) -> None:
        """Remove a disconnected node from the connection table.

        The lookup tables are only touched when *conn* is the connection
        currently registered for its VMAC. A connection that was superseded
        (or never registered) must not erase the entries of the live
        connection that shares its VMAC and UUID.
        """
        peer_vmac = conn.peer_vmac
        if peer_vmac:
            self._pending_vmacs.pop(peer_vmac, None)
        if peer_vmac and self._connections.get(peer_vmac) is conn:
            del self._connections[peer_vmac]
            peer_uuid = conn.peer_uuid
            if peer_uuid and self._uuid_map.get(peer_uuid) == peer_vmac:
                del self._uuid_map[peer_uuid]
        logger.info("SC node disconnected: VMAC=%s", conn.peer_vmac)

    async def _on_node_message(
        self, source: SCConnection, msg: SCMessage, raw: bytes | None = None
    ) -> None:
        """Route a message received from a connected node.

        - Unicast (destination set): forward to matching VMAC connection.
        - Broadcast (no destination): replicate to all except source.

        Per AB.5.3, the hub rewrites VMAC headers before forwarding:

        - **Unicast (AB.5.3.2):** Sets Originating VMAC to the sender's
          VMAC and removes the Destination VMAC.
        - **Broadcast (AB.5.3.3):** Sets Originating VMAC to the sender's
          VMAC and keeps the Destination VMAC as the broadcast address.

        When the caller supplies no *raw* bytes the message is encoded here
        so that the rewrite is applied on that path as well.
        """
        if source.peer_vmac is None:
            return

        # Validate originating VMAC matches the authenticated peer identity
        # to prevent VMAC spoofing attacks (Annex AB.6.2).
        if msg.originating and msg.originating != source.peer_vmac:
            logger.warning(
                "SC hub dropping message: originating VMAC %s does not match "
                "authenticated peer %s",
                msg.originating,
                source.peer_vmac,
            )
            return

        # Rewrite raw bytes for hub forwarding (AB.5.3). Without received
        # bytes, encode the decoded message so the forwarded frame still
        # carries the rewritten VMAC headers.
        source_mac = source.peer_vmac.address
        frame = raw if raw else msg.encode()

        if msg.destination and not msg.destination.is_broadcast:
            if __debug__ and logger.isEnabledFor(_DEBUG):
                logger.debug(
                    "SC hub routing unicast from %s to %s", source.peer_vmac, msg.destination
                )
            forwarded = _rewrite_for_hub_unicast(frame, source_mac)
            await self._unicast(msg, source.peer_vmac, forwarded)
        else:
            if __debug__ and logger.isEnabledFor(_DEBUG):
                logger.debug("SC hub routing broadcast from %s", source.peer_vmac)
            forwarded = _rewrite_for_hub_broadcast(frame, source_mac)
            await self._broadcast(msg, source.peer_vmac, forwarded)

    async def _unicast(self, msg: SCMessage, exclude: SCVMAC, raw: bytes | None = None) -> None:
        """Forward unicast message to destination VMAC.

        Uses pre-encoded *raw* bytes when available to skip re-encoding.
        Messages for unknown or not-yet-connected destinations are dropped
        silently; send failures are logged at debug level. *exclude* is
        accepted for signature parity with ``_broadcast`` and is not used.
        """
        if msg.destination is None:
            return
        dest_conn = self._connections.get(msg.destination)
        if dest_conn is None or dest_conn.state != SCConnectionState.CONNECTED:
            return
        try:
            if raw is not None:
                await dest_conn.send_raw(raw)
            else:
                await dest_conn.send_message(msg)
        except Exception:
            logger.debug("Hub unicast to %s failed", msg.destination, exc_info=True)

    async def _broadcast(self, msg: SCMessage, exclude: SCVMAC, raw: bytes | None = None) -> None:
        """Send message to all connected nodes except the source.

        Uses pre-encoded *raw* bytes when available to skip re-encoding.
        Batches writes to all connections first (synchronous buffer), then
        drains them concurrently for lower latency.
        """
        targets = [
            conn
            for vmac, conn in self._connections.items()
            if vmac != exclude and conn.state == SCConnectionState.CONNECTED
        ]
        if not targets:
            return
        encoded = raw if raw is not None else msg.encode()

        if len(targets) == 1:
            try:
                await targets[0].send_raw(encoded)
            except Exception:
                logger.debug("Hub broadcast to single target failed", exc_info=True)
            return

        # Phase 1: buffer writes synchronously (no await).
        needs_drain = []
        for conn in targets:
            try:
                if conn.write_raw_no_drain(encoded):
                    needs_drain.append(conn)
            except Exception:
                logger.debug("Hub broadcast buffer write failed", exc_info=True)

        # Phase 2: drain all concurrently.
        if needs_drain:
            await asyncio.gather(
                *(conn.drain() for conn in needs_drain),
                return_exceptions=True,
            )
# ---------------------------------------------------------------------------
# Hub forwarding byte rewriters (AB.5.3)
# ---------------------------------------------------------------------------

# Control flag bit positions (AB.2.2)
_ORIGINATING_VMAC = 0x08
_DESTINATION_VMAC = 0x04
_OPTIONS_MASK = 0x03  # bits 0-1: data_options, dest_options

_FIXED_HEADER_LEN = 4  # function(1) + flags(1) + message_id(2)
_VMAC_FIELD_LEN = 6


def _split_hub_frame(raw: bytes) -> tuple[int, bytes]:
    """Return ``(option_flags, tail)`` for a frame of at least 4 bytes.

    ``option_flags`` are the option bits of the control-flags byte.
    ``tail`` is everything after the fixed header and whichever VMAC fields
    the flags advertise. A frame shorter than its flags claim yields an
    empty tail rather than an error.
    """
    flags = raw[1] & 0x0F
    offset = _FIXED_HEADER_LEN
    if flags & _ORIGINATING_VMAC:
        offset += _VMAC_FIELD_LEN
    if flags & _DESTINATION_VMAC:
        offset += _VMAC_FIELD_LEN
    return flags & _OPTIONS_MASK, raw[offset:]


def _rewrite_for_hub_unicast(raw: bytes | None, source_vmac: bytes) -> bytes | None:
    """Rewrite raw BVLC-SC bytes for hub unicast forwarding (AB.5.3.2).

    Sets Originating VMAC to the source peer's VMAC and removes the
    Destination VMAC. Preserves message function, message ID, header
    options, and payload.

    :param raw: Frame as received; ``None`` or fewer than 4 bytes is
        returned unchanged.
    :param source_vmac: 6-byte VMAC of the authenticated sender.
    :returns: The rewritten frame. Truncated input (flags advertising VMAC
        fields that are not present) produces header + Originating VMAC
        with an empty tail instead of raising.
    """
    if raw is None or len(raw) < _FIXED_HEADER_LEN:
        return raw
    option_flags, tail = _split_hub_frame(raw)
    # New flags: Originating VMAC only + preserved option bits.
    header = bytes((raw[0], option_flags | _ORIGINATING_VMAC, raw[2], raw[3]))
    return b"".join((header, bytes(source_vmac), tail))


def _rewrite_for_hub_broadcast(raw: bytes | None, source_vmac: bytes) -> bytes | None:
    """Rewrite raw BVLC-SC bytes for hub broadcast forwarding (AB.5.3.3).

    Sets Originating VMAC to the source peer's VMAC and sets the
    Destination VMAC to the broadcast address (``FF:FF:FF:FF:FF:FF``).
    Preserves message function, message ID, header options, and payload.

    :param raw: Frame as received; ``None`` or fewer than 4 bytes is
        returned unchanged.
    :param source_vmac: 6-byte VMAC of the authenticated sender.
    :returns: The rewritten frame. Truncated input produces header +
        Originating VMAC + broadcast VMAC with an empty tail instead of
        raising.
    """
    if raw is None or len(raw) < _FIXED_HEADER_LEN:
        return raw
    option_flags, tail = _split_hub_frame(raw)
    # New flags: Originating + Destination VMAC + preserved option bits.
    new_flags = option_flags | _ORIGINATING_VMAC | _DESTINATION_VMAC
    header = bytes((raw[0], new_flags, raw[2], raw[3]))
    return b"".join((header, bytes(source_vmac), bytes(VMAC_BROADCAST), tail))