Source code for hs_py.filter.eval

"""Evaluate Haystack filter AST nodes against entity dicts.

Supports single-segment and multi-segment (dereference) path expressions.
Multi-segment paths require a resolver callback to look up Refs.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

from hs_py.filter.ast import And, Cmp, CmpOp, Has, Missing, Node, Or
from hs_py.grid import Grid
from hs_py.kinds import Number, Ref

__all__ = ["Resolver", "evaluate", "evaluate_grid"]

#: Signature of the callback used to dereference a Ref into its entity dict.
#: The callback must return None when the Ref cannot be resolved.
Resolver = Callable[[Ref], dict[str, Any] | None]


def evaluate(
    node: Node,
    entity: dict[str, Any],
    resolver: Resolver | None = None,
) -> bool:
    """Test whether a single entity satisfies a filter.

    :param node: Root node of the parsed filter AST.
    :param entity: Mapping of tag names to values for the entity under test.
    :param resolver: Callback used to dereference Refs when the filter
        contains multi-segment (dereference) paths. When ``None``, every
        multi-segment path is treated as missing.
    :returns: ``True`` when *entity* matches *node*, otherwise ``False``.
    """
    matched = _eval(node, entity, resolver)
    return matched
def evaluate_grid(
    node: Node,
    grid: Grid,
    resolver: Resolver | None = None,
) -> Grid:
    """Return a new grid holding only the rows of *grid* that match *node*.

    :param node: Root node of the parsed filter AST.
    :param grid: Grid whose rows are tested one by one.
    :param resolver: Callback used to dereference Refs for multi-segment
        paths. When omitted, a resolver is built from the grid's own rows
        (only possible if the grid has an ``id`` column holding Refs;
        otherwise no resolver is used).
    :returns: A :class:`~hs_py.grid.Grid` with the original meta and columns
        and the matching rows, or ``Grid.make_empty()`` when no row matches.
    """
    lookup = _grid_resolver(grid) if resolver is None else resolver
    kept = tuple(row for row in grid if _eval(node, row, lookup))
    if kept:
        return Grid(meta=grid.meta, cols=grid.cols, rows=kept)
    # No matches: an empty grid is returned instead of one that keeps the
    # source grid's meta/columns.
    return Grid.make_empty()
# ---- Internal evaluation ----------------------------------------------------

#: Sentinel for missing tag values (distinct from None).
_MISSING = object()

# Pre-compute CmpOp members for identity comparison (avoids Enum __eq__).
_OP_EQ, _OP_NE, _OP_LT, _OP_LE, _OP_GT = (
    CmpOp.EQ,
    CmpOp.NE,
    CmpOp.LT,
    CmpOp.LE,
    CmpOp.GT,
)


def _path_present(
    names: tuple[str, ...], entity: dict[str, Any], resolver: Resolver | None
) -> bool:
    """Report whether the tag path *names* yields a value on *entity*.

    A one-element path is a direct key check on *entity*; any other path is
    walked with ``_resolve_path_multi`` and counts as present unless that
    walk returns ``_MISSING``.
    """
    if len(names) != 1:
        return _resolve_path_multi(names, entity, resolver) is not _MISSING
    return names[0] in entity


def _eval_has(node: Has, entity: dict[str, Any], resolver: Resolver | None) -> bool:
    """Evaluate a ``Has`` node: true when the node's path is present."""
    return _path_present(node.path.names, entity, resolver)


def _eval_missing(node: Missing, entity: dict[str, Any], resolver: Resolver | None) -> bool:
    """Evaluate a ``Missing`` node: true when the node's path is absent."""
    return not _path_present(node.path.names, entity, resolver)


def _eval_cmp(node: Cmp, entity: dict[str, Any], resolver: Resolver | None) -> bool:
    """Evaluate a ``Cmp`` node.

    An absent path never matches, whatever the operator (including ``!=``).
    """
    names = node.path.names
    value = (
        entity.get(names[0], _MISSING)
        if len(names) == 1
        else _resolve_path_multi(names, entity, resolver)
    )
    return value is not _MISSING and _compare(value, node.op, node.val)


def _eval_and(node: And, entity: dict[str, Any], resolver: Resolver | None) -> bool:
    """Evaluate an ``And`` node; the right side is skipped if the left fails."""
    return _eval(node.left, entity, resolver) and _eval(node.right, entity, resolver)


def _eval_or(node: Or, entity: dict[str, Any], resolver: Resolver | None) -> bool:
    """Evaluate an ``Or`` node; the right side is skipped if the left matches."""
    return _eval(node.left, entity, resolver) or _eval(node.right, entity, resolver)


# Type → handler dispatch table (O(1) lookup instead of isinstance chain).
_EVAL_DISPATCH: dict[type, Callable[..., bool]] = {
    Has: _eval_has,
    Missing: _eval_missing,
    Cmp: _eval_cmp,
    And: _eval_and,
    Or: _eval_or,
}


def _eval(node: Node, entity: dict[str, Any], resolver: Resolver | None) -> bool:
    """Dispatch *node* to the evaluator registered for its exact type.

    :raises TypeError: If ``type(node)`` has no entry in ``_EVAL_DISPATCH``.
        The lookup uses the exact type, so subclasses of the node classes
        are not matched.
    """
    handler = _EVAL_DISPATCH.get(type(node))
    if handler is None:
        msg = f"Unknown node type: {type(node).__name__}"
        raise TypeError(msg)
    return handler(node, entity, resolver)


def _resolve_path_multi(
    names: tuple[str, ...], entity: dict[str, Any], resolver: Resolver | None
) -> Any:
    """Walk a multi-segment path, returning ``_MISSING`` if any segment fails.

    Every segment except the last must hold a :class:`Ref`, which *resolver*
    turns into the next entity dict. The final segment is a plain tag lookup
    on the last entity reached.
    """
    current: Any = entity
    for name in names[:-1]:
        if not isinstance(current, dict):
            return _MISSING
        ref = current.get(name)
        # An absent tag, a non-Ref value, or no resolver all end the walk.
        if not isinstance(ref, Ref) or resolver is None:
            return _MISSING
        current = resolver(ref)
        if current is None:
            return _MISSING
    # Last segment: a plain dict lookup.
    if not isinstance(current, dict):
        return _MISSING
    return current.get(names[-1], _MISSING)


def _compare(left: Any, op: CmpOp, right: Any) -> bool:
    """Compare two values using a comparison operator.

    Equality (``==`` / ``!=``) on two Numbers requires both the numeric
    value and the unit to match. Ordered comparisons (``<``, ``<=``, ``>``,
    ``>=``) use the numeric value only and ignore units.
    """
    if op is _OP_EQ:
        return _eq(left, right)
    if op is _OP_NE:
        return not _eq(left, right)
    return _ordered_cmp(left, op, right)


def _eq(left: Any, right: Any) -> bool:
    """Equality check with Number- and Ref-aware comparison.

    Numbers compare ``val`` and ``unit``. Refs compare only ``val``, so any
    other Ref attributes are ignored. Everything else falls back to ``==``.
    """
    if type(left) is Number and type(right) is Number:
        return left.val == right.val and left.unit == right.unit
    if type(left) is Ref and type(right) is Ref:
        return left.val == right.val
    return left == right  # type: ignore[no-any-return]


def _ordered_cmp(left: Any, op: CmpOp, right: Any) -> bool:
    """Ordered comparison (<, <=, >, >=).

    Numbers are unwrapped to their ``val``, so units are ignored. Operands of
    incompatible types compare ``False`` instead of raising.
    """
    # ``bool`` is a subclass of ``int``, so without this guard a Bool tag
    # would order against numbers (``True < 5`` is ``True`` in Python).
    if (type(left) is bool) != (type(right) is bool):
        return False
    lv = left.val if type(left) is Number else left
    rv = right.val if type(right) is Number else right
    try:
        if op is _OP_LT:
            return lv < rv
        if op is _OP_LE:
            return lv <= rv
        if op is _OP_GT:
            return lv > rv
        # Any other operator is treated as >= (expected to be CmpOp.GE).
        return lv >= rv
    except TypeError:
        return False


def _grid_resolver(grid: Grid) -> Resolver | None:
    """Build a resolver from grid rows that have an ``id`` column.

    Rows are indexed by ``id.val``; if two rows share an id, the later row
    wins. Returns ``None`` when the grid has no ``id`` column or no row holds
    a :class:`Ref` id.
    """
    if "id" not in grid.col_names:
        return None

    index: dict[str, dict[str, Any]] = {}
    for row in grid:
        ref = row.get("id")
        if isinstance(ref, Ref):
            index[ref.val] = row
    if not index:
        return None

    def _resolve(ref: Ref) -> dict[str, Any] | None:
        return index.get(ref.val)

    return _resolve