Source code for bac_py.segmentation.manager

"""APDU segmentation and reassembly logic per ASHRAE 135-2016 Clause 5.2/5.4.

This module contains pure segmentation logic with no I/O dependencies.
TSMs drive instances of SegmentSender/SegmentReceiver via method calls.
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Literal

from bac_py.types.enums import AbortReason

# Upper bound on the bytes a SegmentReceiver will buffer for one transfer;
# receive_segment() returns ABORT once the running total exceeds it.
_MAX_REASSEMBLY_SIZE = 1_048_576  # 1 MiB cap on total reassembly size

# Module-level logger shared by the sender and receiver classes below.
logger = logging.getLogger(__name__)

# Segment header overhead for each PDU type (when segmented=True).
# These are subtracted from the max APDU length by compute_max_segment_payload().
# ConfirmedRequest: byte0 + byte1(max-seg/max-apdu) + invoke_id + seq_num + window_size + service_choice
CONFIRMED_REQUEST_SEGMENT_OVERHEAD = 6
# ComplexACK: byte0 + invoke_id + seq_num + window_size + service_choice
COMPLEX_ACK_SEGMENT_OVERHEAD = 5

# Default window size: used as SegmentSender.create()'s proposed_window_size
# and as SegmentReceiver.create()'s our_window_size when the caller gives none.
DEFAULT_PROPOSED_WINDOW_SIZE = 16


class SegmentationError(Exception):
    """Signal that a payload could not be segmented (e.g., APDU too long for peer).

    The abort reason supplied by the raiser is kept on the instance as
    ``abort_reason`` so the caller can inspect it; ``message`` becomes the
    regular exception text.
    """

    def __init__(self, abort_reason: AbortReason, message: str = "") -> None:
        super().__init__(message)
        self.abort_reason = abort_reason
class SegmentAction(Enum):
    """What the caller must do after a segment has been processed by the receiver."""

    # Segment stored; no acknowledgement is due yet.
    CONTINUE = "continue"
    # A SegmentACK should be sent now.
    SEND_ACK = "send_ack"
    # A duplicate was seen; repeat the previous SegmentACK.
    RESEND_LAST_ACK = "resend"
    # Every segment has arrived; the payload can be reassembled.
    COMPLETE = "complete"
    # The transfer cannot continue and must be aborted.
    ABORT = "abort"
# --- Pure functions per Clause 5.4 ---
def in_window(seq_a: int, seq_b: int, actual_window_size: int) -> bool:
    """Tell whether sequence number ``seq_a`` lies in the window that begins at ``seq_b``.

    Implements the Clause 5.4 test ``(seqA - seqB) mod 256 < ActualWindowSize``.

    :param seq_a: Sequence number being tested.
    :param seq_b: Sequence number at the start of the window.
    :param actual_window_size: Number of sequence numbers the window covers.
    :returns: ``True`` if ``seq_a`` is inside the window.
    """
    # Forward distance from the window start, wrapping at the 8-bit boundary.
    distance = (seq_a - seq_b) % 256
    return distance < actual_window_size
def duplicate_in_window(
    seq_a: int,
    seq_b: int,
    actual_window_size: int,
    proposed_window_size: int,
) -> bool:
    """Tell whether ``seq_a`` is a segment the receiver has already seen.

    Implements the Clause 5.4 test ``Wm < (seqA - seqB) mod 256 <= 255``
    with ``Wm = max(ActualWindowSize, ProposedWindowSize)``.

    :param seq_a: Sequence number being tested.
    :param seq_b: Sequence number the receiver currently expects.
    :param actual_window_size: Window size in effect.
    :param proposed_window_size: Window size proposed by the sender.
    :returns: ``True`` if ``seq_a`` counts as a duplicate.
    """
    # Forward distance modulo 256; a large value means seq_a is *behind* seq_b.
    distance = (seq_a - seq_b) % 256
    threshold = max(actual_window_size, proposed_window_size)
    return distance > threshold and distance <= 255
def compute_max_segment_payload(
    max_apdu_length: int,
    pdu_type: Literal["confirmed_request", "complex_ack"],
) -> int:
    """Compute how many service-data bytes fit into a single segment.

    :param max_apdu_length: Max APDU size for the link (e.g. 480, 1476).
    :param pdu_type: PDU type whose segmented header overhead is deducted.
        Any value other than ``"confirmed_request"`` uses the ComplexACK
        overhead.
    :returns: Max payload bytes per segment.
    """
    if pdu_type == "confirmed_request":
        return max_apdu_length - CONFIRMED_REQUEST_SEGMENT_OVERHEAD
    return max_apdu_length - COMPLEX_ACK_SEGMENT_OVERHEAD
def split_payload(payload: bytes, max_segment_size: int) -> list[bytes]:
    """Cut ``payload`` into consecutive chunks no longer than ``max_segment_size``.

    :param payload: The raw service data to split.
    :param max_segment_size: Maximum bytes per segment.
    :returns: The chunks in order. The list is never empty: an empty payload
        yields a single empty segment.
    :raises ValueError: If max_segment_size is not positive.
    """
    if max_segment_size <= 0:
        msg = f"max_segment_size must be positive, got {max_segment_size}"
        raise ValueError(msg)

    total = len(payload)
    if total == 0:
        # An empty payload still travels as one (empty) segment.
        return [b""]

    return [
        payload[offset : offset + max_segment_size]
        for offset in range(0, total, max_segment_size)
    ]
def check_segment_count(num_segments: int, max_segments: int | None) -> bool:
    """Verify that a transfer fits within the peer's segment limit.

    :param num_segments: Total number of segments.
    :param max_segments: Peer's max-segments-accepted (``None`` = unlimited).
    :returns: ``True`` if within limits.
    """
    # No advertised limit means any count is acceptable.
    return max_segments is None or num_segments <= max_segments
# --- SegmentSender ---
@dataclass
class SegmentSender:
    """Send-side state for one segmented transaction.

    Segments are addressed internally by absolute 0-based index; the wire
    protocol's 8-bit sequence number is derived as ``index & 0xFF``.
    """

    # Payload chunks, in transmission order.
    segments: list[bytes]
    invoke_id: int
    service_choice: int
    # Window size we proposed when the transfer was set up.
    proposed_window_size: int
    # Window size currently in effect; overwritten by each SegmentACK.
    actual_window_size: int
    # Absolute index of the first segment that is not yet acknowledged.
    _window_start_idx: int = 0

    @classmethod
    def create(
        cls,
        payload: bytes,
        invoke_id: int,
        service_choice: int,
        max_apdu_length: int,
        pdu_type: Literal["confirmed_request", "complex_ack"],
        proposed_window_size: int = DEFAULT_PROPOSED_WINDOW_SIZE,
        peer_max_segments: int | None = None,
    ) -> SegmentSender:
        """Build a sender by cutting ``payload`` into link-sized segments.

        :param payload: Service data to transmit.
        :param invoke_id: Invoke ID of the transaction.
        :param service_choice: Service choice carried in each segment header.
        :param max_apdu_length: Max APDU size for the link.
        :param pdu_type: PDU type, which determines the per-segment overhead.
        :param proposed_window_size: Window size to propose; also used as the
            initial actual window size.
        :param peer_max_segments: Peer's segment limit (``None`` = unlimited).
        :returns: A sender positioned at the first segment.
        :raises SegmentationError: If the segment count exceeds peer_max_segments.
        """
        segments = split_payload(
            payload,
            compute_max_segment_payload(max_apdu_length, pdu_type),
        )

        if check_segment_count(len(segments), peer_max_segments):
            logger.debug(
                "segmented send created: invoke_id=%d segments=%d window=%d",
                invoke_id,
                len(segments),
                proposed_window_size,
            )
            return cls(
                segments=segments,
                invoke_id=invoke_id,
                service_choice=service_choice,
                proposed_window_size=proposed_window_size,
                actual_window_size=proposed_window_size,
            )

        # Too many segments for the peer: report and give up.
        msg = (
            f"Payload requires {len(segments)} segments but peer accepts "
            f"at most {peer_max_segments}"
        )
        logger.warning(
            "segment count exceeded: %d segments, peer max=%s, invoke_id=%d",
            len(segments),
            peer_max_segments,
            invoke_id,
        )
        raise SegmentationError(AbortReason.APDU_TOO_LONG, msg)

    def fill_window(self) -> list[tuple[int, bytes, bool]]:
        """Collect the segments belonging to the current window.

        :returns: List of ``(sequence_number, segment_data, more_follows)`` tuples.
        """
        total = len(self.segments)
        first = self._window_start_idx
        stop = min(total, first + self.actual_window_size)
        final = total - 1
        # more_follows is False only for the very last segment of the transfer.
        window = [
            (idx & 0xFF, self.segments[idx], idx < final)
            for idx in range(first, stop)
        ]
        logger.debug(
            "fill_window: invoke_id=%d segments=%d/%d window_size=%d",
            self.invoke_id,
            len(window),
            self.total_segments,
            self.actual_window_size,
        )
        return window

    def handle_segment_ack(
        self,
        ack_seq: int,
        actual_window_size: int,
        negative: bool,
    ) -> bool:
        """Apply a SegmentACK to the sender state.

        The window is repositioned to start right after the acknowledged
        segment, for positive and negative ACKs alike.

        :param ack_seq: The sequence number in the SegmentACK.
        :param actual_window_size: The receiver's advertised window size.
        :param negative: Whether this is a negative ACK (retransmit request).
        :returns: ``True`` if all segments have been acknowledged.
        """
        # Resolve the index using the window size that was in effect when
        # the acknowledged segments were sent, then adopt the new size.
        acked_idx = self._seq_to_idx(ack_seq)
        self.actual_window_size = actual_window_size
        resume_idx = acked_idx + 1

        if negative:
            logger.debug(
                "negative ack: invoke_id=%d seq=%d, resending from idx=%d",
                self.invoke_id,
                ack_seq,
                resume_idx,
            )

        self._window_start_idx = resume_idx
        finished = self.is_complete
        if finished:
            logger.info(
                "segmented send complete: invoke_id=%d segments=%d",
                self.invoke_id,
                self.total_segments,
            )
        return finished

    @property
    def is_complete(self) -> bool:
        """True when all segments have been acknowledged."""
        return len(self.segments) <= self._window_start_idx

    @property
    def total_segments(self) -> int:
        """Total number of segments."""
        return len(self.segments)

    def _seq_to_idx(self, seq: int) -> int:
        """Resolve an 8-bit sequence number to an absolute segment index.

        Because ``seq = idx & 0xFF`` wraps at 256, several indices can share a
        sequence number once a transfer has more than 256 segments. The
        ambiguity is resolved by taking the first matching index at or above
        ``_window_start_idx - actual_window_size`` (so an ACK naming the last
        segment of the previous window is still resolved correctly).

        If nothing matches, a warning is logged and ``_window_start_idx`` is
        returned.
        """
        lower_bound = max(0, self._window_start_idx - self.actual_window_size)
        match = next(
            (
                idx
                for idx in range(lower_bound, len(self.segments))
                if (idx & 0xFF) == seq
            ),
            None,
        )
        if match is not None:
            return match

        # Fallback: indicates a protocol state mismatch
        logger.warning(
            "Could not map sequence number %d to segment index; falling back to window start %d",
            seq,
            self._window_start_idx,
        )
        return self._window_start_idx
# --- SegmentReceiver ---
@dataclass
class SegmentReceiver:
    """Manages reassembly of received segments.

    Segments are tracked by absolute 0-based index. In-window segments that
    arrive out of order are buffered; acknowledgements always name the last
    segment received *contiguously* from the start of the transfer.
    """

    # Received payloads keyed by absolute segment index.
    _segments: dict[int, bytes] = field(default_factory=dict)
    # Lowest absolute index that has not been received yet.
    _expected_idx: int = 0
    # Window size in effect for this transfer.
    actual_window_size: int = DEFAULT_PROPOSED_WINDOW_SIZE
    # Window size proposed by the sender.
    proposed_window_size: int = 1
    service_choice: int = 0
    # Absolute index of the segment that carried more_follows=False, once seen.
    _final_idx: int | None = None
    # Sequence number to put in a (re)sent SegmentACK.
    _last_ack_seq: int = 0
    # Absolute index at which the current receive window began.
    _window_start_idx: int = 0
    # Running total of buffered payload bytes, checked against _MAX_REASSEMBLY_SIZE.
    _total_bytes: int = 0
    created_at: float = field(default_factory=time.monotonic)

    @classmethod
    def create(
        cls,
        first_segment_data: bytes,
        service_choice: int,
        proposed_window_size: int,
        more_follows: bool = True,
        our_window_size: int = DEFAULT_PROPOSED_WINDOW_SIZE,
    ) -> SegmentReceiver:
        """Create a receiver from the first segment (sequence number 0).

        :param first_segment_data: Payload of the first segment.
        :param service_choice: Service choice from the PDU header.
        :param proposed_window_size: Window size proposed by the sender.
        :param more_follows: The more-follows flag from the first segment.
        :param our_window_size: Our preferred window size.
        :returns: New :class:`SegmentReceiver` with the first segment stored.
        """
        # The effective window is the smaller of what we want and what was proposed.
        actual = min(our_window_size, proposed_window_size)
        return cls(
            _segments={0: first_segment_data},
            _expected_idx=1,
            actual_window_size=actual,
            proposed_window_size=proposed_window_size,
            service_choice=service_choice,
            _final_idx=None if more_follows else 0,
            _last_ack_seq=0,
            _window_start_idx=1,
            _total_bytes=len(first_segment_data),
        )

    def receive_segment(
        self,
        seq_num: int,
        data: bytes,
        more_follows: bool,
    ) -> tuple[SegmentAction, int]:
        """Process a received segment.

        Per Clause 5.4, SegmentACKs are sent at window boundaries (when the
        window is full) or when the transfer is complete, not for every
        individual segment.

        :param seq_num: 8-bit sequence number from the PDU.
        :param data: Segment payload data.
        :param more_follows: The more-follows flag from the PDU.
        :returns: ``(action, ack_sequence_number)`` indicating what the caller
            should do next. ``ack_sequence_number`` is the sequence number of
            the last contiguously received segment, which equals ``seq_num``
            when segments arrive in order. For ABORT it is -1. CONTINUE means
            the segment was stored but no ACK is needed yet.
        """
        expected_seq = self._expected_idx & 0xFF

        if in_window(seq_num, expected_seq, self.actual_window_size):
            # Map sequence to absolute index
            abs_idx = self._seq_to_abs_idx(seq_num)

            # A re-sent in-window segment replaces the stored copy; keep the
            # byte total consistent.
            old = self._segments.get(abs_idx)
            if old is not None:
                self._total_bytes -= len(old)
            self._segments[abs_idx] = data
            self._total_bytes += len(data)
            if self._total_bytes > _MAX_REASSEMBLY_SIZE:
                logger.warning("Reassembly size exceeds %d bytes, aborting", _MAX_REASSEMBLY_SIZE)
                return (SegmentAction.ABORT, -1)

            logger.debug(
                "segment %d/%d received, idx=%d, more_follows=%s",
                seq_num,
                self.actual_window_size,
                abs_idx,
                more_follows,
            )

            if not more_follows:
                self._final_idx = abs_idx

            # Advance expected past contiguously received segments
            while self._expected_idx in self._segments:
                self._expected_idx += 1

            # Acknowledge the last *contiguous* segment, not merely the one
            # that just arrived. If a gap is filled late (e.g. 1, 3, 4, 2),
            # acknowledging seq 2 would make the sender rewind to 3, which the
            # receiver then treats as a duplicate and answers with ACK 2
            # again, so the transfer would never make progress.
            if self._expected_idx > 0:
                self._last_ack_seq = (self._expected_idx - 1) & 0xFF
            ack_seq = self._last_ack_seq

            if self.is_complete:
                self._window_start_idx = self._expected_idx
                logger.info("segmented transfer complete: segments=%s", len(self._segments))
                return (SegmentAction.COMPLETE, ack_seq)

            # ACK at window boundary: when contiguous reception fills the window
            window_end = self._window_start_idx + self.actual_window_size
            if self._expected_idx >= window_end:
                self._window_start_idx = self._expected_idx
                return (SegmentAction.SEND_ACK, ack_seq)

            return (SegmentAction.CONTINUE, ack_seq)

        if duplicate_in_window(
            seq_num, expected_seq, self.actual_window_size, self.proposed_window_size
        ):
            logger.warning("duplicate segment: seq=%s", seq_num)
            return (SegmentAction.RESEND_LAST_ACK, self._last_ack_seq)

        logger.warning("out-of-window segment: seq=%s, expected=%s", seq_num, expected_seq)
        return (SegmentAction.ABORT, -1)

    @property
    def last_ack_seq(self) -> int:
        """The sequence number of the last acknowledged segment."""
        return self._last_ack_seq

    @property
    def is_complete(self) -> bool:
        """True when all segments have been received."""
        if self._final_idx is None:
            return False
        return self._expected_idx > self._final_idx

    def reassemble(self) -> bytes:
        """Concatenate all segments in order.

        :returns: The payload formed by segments ``0`` through the final index.
        :raises ValueError: If not all segments have been received.
        """
        if not self.is_complete:
            msg = "Cannot reassemble: not all segments received"
            raise ValueError(msg)
        # is_complete guarantees _final_idx is set; the assert narrows the type.
        assert self._final_idx is not None
        return b"".join(self._segments[i] for i in range(self._final_idx + 1))

    def _seq_to_abs_idx(self, seq_num: int) -> int:
        """Map an 8-bit sequence number to an absolute index near _expected_idx."""
        # The sequence number corresponds to an index in the vicinity of _expected_idx.
        # Since sequence numbers wrap at 256, compute the offset from expected.
        expected_seq = self._expected_idx & 0xFF
        offset = (seq_num - expected_seq) % 256
        return self._expected_idx + offset