# Source code for bac_py.transport.sc.node_switch

"""BACnet/SC Node Switch (AB.4).

Manages direct peer-to-peer connections between SC nodes, bypassing
the hub for unicast traffic.  Listens for inbound direct connections
and initiates outbound connections via address resolution through
the hub.
"""

from __future__ import annotations

import asyncio
import contextlib
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from bac_py.transport.sc.bvlc import AddressResolutionAckPayload, SCMessage
from bac_py.transport.sc.connection import (
    SCConnection,
    SCConnectionConfig,
    SCConnectionState,
)
from bac_py.transport.sc.tls import SCTLSConfig, build_client_ssl_context, build_server_ssl_context
from bac_py.transport.sc.types import (
    SC_DIRECT_SUBPROTOCOL,
    BvlcSCFunction,
)
from bac_py.transport.sc.websocket import SCWebSocket

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from bac_py.transport.sc.vmac import SCVMAC, DeviceUUID

logger = logging.getLogger(__name__)


@dataclass
class SCNodeSwitchConfig:
    """Configuration for an SC Node Switch."""

    # When False, ``SCNodeSwitch.start()`` returns without opening a listener.
    enable: bool = False
    # Host/interface handed to ``asyncio.start_server`` for inbound connections.
    bind_address: str = "0.0.0.0"
    # TCP port handed to ``asyncio.start_server``.
    bind_port: int = 0
    # Source for the client and server SSL contexts built by the switch.
    tls_config: SCTLSConfig = field(default_factory=SCTLSConfig)
    # Passed through to every ``SCConnection`` the switch creates.
    connection_config: SCConnectionConfig = field(default_factory=SCConnectionConfig)
    # Seconds to wait for an Address-Resolution-ACK before giving up.
    address_resolution_timeout: float = 5.0
    # Upper bound on tracked direct connections; also caps the number of
    # concurrently pending address resolutions.
    max_connections: int = 100
    # Passed as the WebSocket ``max_size`` and to each ``SCConnection``.
    max_bvlc_length: int = 1600
    # Passed to each ``SCConnection``.
    max_npdu_length: int = 1497
[docs] class SCNodeSwitch: """BACnet/SC Node Switch (AB.4). Manages direct connections between SC nodes. Listens for inbound direct connections and initiates outbound connections via address resolution through the hub. """ def __init__( self, local_vmac: SCVMAC, local_uuid: DeviceUUID, config: SCNodeSwitchConfig | None = None, ) -> None: self._config = config or SCNodeSwitchConfig() self._local_vmac = local_vmac self._local_uuid = local_uuid self._direct_connections: dict[SCVMAC, SCConnection] = {} self._server: asyncio.Server | None = None self._client_tasks: set[asyncio.Task[None]] = set() self._pending_resolutions: dict[SCVMAC, asyncio.Future[list[str]]] = {} # Cache TLS contexts (immutable after init) self._client_ssl_ctx = build_client_ssl_context(self._config.tls_config) self._server_ssl_ctx = build_server_ssl_context(self._config.tls_config) # Callbacks self.on_message: Callable[[SCMessage, bytes | None], Awaitable[None] | None] | None = None @property def connections(self) -> dict[SCVMAC, SCConnection]: """Active direct connections indexed by peer VMAC.""" return dict(self._direct_connections) @property def connection_count(self) -> int: """Number of active direct connections.""" return len(self._direct_connections) @property def local_vmac(self) -> SCVMAC: """Local VMAC address.""" return self._local_vmac @local_vmac.setter def local_vmac(self, value: SCVMAC) -> None: self._local_vmac = value
[docs] async def start(self) -> None: """Start the node switch, listening for inbound direct connections.""" if not self._config.enable: return if self._server_ssl_ctx is None: logger.warning( "SC Node Switch starting WITHOUT TLS on %s:%d — " "direct connections will be unencrypted and unauthenticated", self._config.bind_address, self._config.bind_port, ) self._server = await asyncio.start_server( self._handle_inbound, self._config.bind_address, self._config.bind_port, ssl=self._server_ssl_ctx, ) logger.info( "SC Node Switch listening on %s:%d", self._config.bind_address, self._config.bind_port, )
[docs] async def stop(self) -> None: """Stop the node switch and close all direct connections.""" logger.info("SC node switch stopping") # Cancel pending resolutions for fut in self._pending_resolutions.values(): if not fut.done(): fut.cancel() self._pending_resolutions.clear() # Close all direct connections (must happen BEFORE server.close() # because Python 3.13 wait_closed() blocks until handlers finish) for conn in list(self._direct_connections.values()): with contextlib.suppress(Exception): await conn._go_idle() self._direct_connections.clear() # Cancel pending tasks for task in self._client_tasks: if not task.done(): task.cancel() if self._client_tasks: await asyncio.gather(*self._client_tasks, return_exceptions=True) self._client_tasks.clear() # Stop server (after connections are closed) if self._server: self._server.close() await self._server.wait_closed() self._server = None logger.info("SC node switch stopped")
[docs] def has_direct(self, dest: SCVMAC) -> bool: """Check if a direct connection exists to the given VMAC.""" conn = self._direct_connections.get(dest) return conn is not None and conn.state == SCConnectionState.CONNECTED
[docs] async def send_direct(self, dest: SCVMAC, msg: SCMessage) -> bool: """Send via direct connection if available. :returns: True if sent successfully, False if no direct connection. """ conn = self._direct_connections.get(dest) if conn is None or conn.state != SCConnectionState.CONNECTED: return False try: await conn.send_message(msg) logger.debug("SC message sent via direct connection to %s", dest) return True except (ConnectionError, OSError): logger.debug("SC direct send failed to %s", dest) return False
[docs] async def resolve_address( self, dest: SCVMAC, hub_send: Callable[[SCMessage], Awaitable[None]], ) -> list[str]: """Request WebSocket URIs for a peer via Address-Resolution through the hub. :param dest: Target VMAC to resolve. :param hub_send: Callback to send a message through the hub connection. :returns: List of WebSocket URIs for the target, empty if not resolved. """ if dest in self._pending_resolutions: return await self._pending_resolutions[dest] if len(self._pending_resolutions) >= self._config.max_connections: logger.warning("SC address resolution cache full, dropping request") return [] fut: asyncio.Future[list[str]] = asyncio.get_running_loop().create_future() self._pending_resolutions[dest] = fut msg = SCMessage( BvlcSCFunction.ADDRESS_RESOLUTION, message_id=0, originating=self._local_vmac, destination=dest, ) try: await hub_send(msg) async with asyncio.timeout(self._config.address_resolution_timeout): return await fut except (TimeoutError, OSError, ConnectionError, asyncio.CancelledError): return [] finally: self._pending_resolutions.pop(dest, None)
[docs] def handle_address_resolution_ack(self, msg: SCMessage) -> None: """Handle an Address-Resolution-ACK received from the hub. Resolves the pending future for the originating VMAC. """ if msg.originating and msg.originating in self._pending_resolutions: fut = self._pending_resolutions[msg.originating] if not fut.done(): ack = AddressResolutionAckPayload.decode(msg.payload) fut.set_result(list(ack.websocket_uris))
[docs] async def establish_direct(self, dest: SCVMAC, uris: list[str]) -> bool: """Try connecting to a peer using resolved URIs. :returns: True if connected successfully. """ if len(self._direct_connections) >= self._config.max_connections: logger.debug("SC direct connection limit reached (%d)", self._config.max_connections) return False logger.debug("SC establishing direct connection to %s via %s", dest, uris) for uri in uris: if not uri.startswith(("ws://", "wss://")): logger.warning( "SC ignoring non-WebSocket URI in address resolution: %s", uri, ) continue try: ws = await SCWebSocket.connect( uri, self._client_ssl_ctx, SC_DIRECT_SUBPROTOCOL, max_size=self._config.max_bvlc_length, ) except (OSError, ConnectionError): continue conn = SCConnection( self._local_vmac, self._local_uuid, config=self._config.connection_config, max_bvlc_length=self._config.max_bvlc_length, max_npdu_length=self._config.max_npdu_length, ) conn.on_message = self.on_message conn.on_disconnected = self._make_disconnect_cb(conn) await conn.initiate(ws) if conn.state == SCConnectionState.CONNECTED and conn.peer_vmac == dest: self._direct_connections[dest] = conn logger.info("Direct connection established to VMAC=%s via %s", dest, uri) return True # Wrong peer or failed — clean up with contextlib.suppress(Exception): await conn._go_idle() return False
# ------------------------------------------------------------------ # Inbound connection handling # ------------------------------------------------------------------ async def _handle_inbound( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, ) -> None: """Handle an inbound direct connection from a peer.""" if len(self._direct_connections) >= self._config.max_connections: writer.close() return try: ws = await SCWebSocket.accept( reader, writer, SC_DIRECT_SUBPROTOCOL, max_size=self._config.max_bvlc_length ) except Exception: logger.debug("Direct WebSocket accept failed", exc_info=True) writer.close() return conn = SCConnection( self._local_vmac, self._local_uuid, config=self._config.connection_config, max_bvlc_length=self._config.max_bvlc_length, max_npdu_length=self._config.max_npdu_length, ) conn.on_message = self.on_message conn.on_connected = lambda: self._on_inbound_connected(conn) conn.on_disconnected = lambda: self._on_direct_disconnected(conn) task = asyncio.create_task(conn.accept(ws)) self._client_tasks.add(task) task.add_done_callback(self._on_client_task_done) def _on_client_task_done(self, task: asyncio.Task[None]) -> None: """Clean up a finished client task and log unexpected errors.""" self._client_tasks.discard(task) if not task.cancelled() and task.exception() is not None: logger.debug("SC direct connection task failed: %s", task.exception()) def _make_disconnect_cb(self, conn: SCConnection) -> Callable[[], None]: """Create a disconnect callback bound to a specific connection.""" def cb() -> None: self._on_direct_disconnected(conn) return cb def _on_inbound_connected(self, conn: SCConnection) -> None: """Register an inbound direct connection after handshake.""" if conn.peer_vmac is None: return self._direct_connections[conn.peer_vmac] = conn logger.info("Inbound direct connection from VMAC=%s", conn.peer_vmac) def _on_direct_disconnected(self, conn: SCConnection) -> None: """Remove a disconnected direct connection.""" if ( 
conn.peer_vmac and conn.peer_vmac in self._direct_connections and self._direct_connections[conn.peer_vmac] is conn ): del self._direct_connections[conn.peer_vmac] logger.info("Direct connection disconnected: VMAC=%s", conn.peer_vmac)