# Source code for hs_py.storage.memory

"""In-memory Haystack storage adapter.

Provides :class:`InMemoryAdapter`, a pure-Python implementation of
:class:`~hs_py.storage.protocol.StorageAdapter` backed by plain dicts.
Suitable for testing, demos, and small deployments that do not require
persistence.
"""

from __future__ import annotations

import secrets
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from hs_py.filter import evaluate
from hs_py.filter.ast import Has
from hs_py.kinds import Number, Ref
from hs_py.user import User, derive_scram_credentials

if TYPE_CHECKING:
    from hs_py.filter.ast import Node

__all__ = ["InMemoryAdapter"]

# Maximum number of history items per point (prevents unbounded memory growth).
_MAX_HISTORY_PER_POINT = 100_000


# ---------------------------------------------------------------------------
# Internal watch state
# ---------------------------------------------------------------------------


@dataclass
class _WatchState:
    """Mutable state for a single active watch subscription."""

    # Human-readable display name given when the watch was created.
    dis: str
    # ``Ref.val`` strings of all entities this watch is subscribed to.
    ids: set[str] = field(default_factory=set)
    # Subset of ``ids`` that changed since the last poll; cleared on poll.
    dirty: set[str] = field(default_factory=set)

# ---------------------------------------------------------------------------
# InMemoryAdapter
# ---------------------------------------------------------------------------


class InMemoryAdapter:
    """In-memory implementation of :class:`~hs_py.storage.protocol.StorageAdapter`.

    All state is held in plain Python dicts; nothing is persisted to disk.
    Thread-safety is not guaranteed — use within a single asyncio event loop.

    :param entities: Optional initial list of entity dicts (each must have
        an ``id`` :class:`~hs_py.kinds.Ref`).
    """

    def __init__(self, entities: list[dict[str, Any]] | None = None) -> None:
        # ref_val -> entity dict
        self._entities: dict[str, dict[str, Any]] = {}
        # tag_name -> set of ref_vals (tag presence index for fast Has queries)
        self._tag_index: dict[str, set[str]] = {}
        # ref_val -> list of {ts, val} dicts
        self._timeseries: dict[str, list[dict[str, Any]]] = {}
        # ref_val -> {level (1-17) -> val}
        self._priority: dict[str, dict[int, Any]] = {}
        # watch_id -> _WatchState
        self._watches: dict[str, _WatchState] = {}
        # username -> User
        self._users: dict[str, User] = {}
        # Cached column names across all entities (invalidated on load)
        self._all_col_names: tuple[str, ...] | None = None
        if entities:
            self.load_entities(entities)

    # ---- Lifecycle -----------------------------------------------------------
[docs] async def start(self) -> None: """No-op initializer (in-memory adapter needs no setup)."""
[docs] async def close(self) -> None: """No-op teardown (no resources to release)."""
# ---- Bulk load -----------------------------------------------------------
[docs] def load_entities(self, entities: list[dict[str, Any]]) -> int: """Bulk-load a list of entity dicts. Each entity must have an ``id`` :class:`~hs_py.kinds.Ref`. Entities without an ``id`` are silently skipped. :param entities: List of entity dicts to load. :returns: Number of entities actually stored. """ count = 0 tag_index = self._tag_index for entity in entities: ref = entity.get("id") if isinstance(ref, Ref): ent = dict(entity) ref_val = ref.val self._entities[ref_val] = ent # Update tag presence index for tag_name in ent: idx = tag_index.get(tag_name) if idx is None: idx = set() tag_index[tag_name] = idx idx.add(ref_val) count += 1 # Invalidate column cache self._all_col_names = None return count
# ---- Internal helpers ---------------------------------------------------- def _resolver(self, ref: Ref) -> dict[str, Any] | None: """Resolve a Ref to an entity dict for multi-segment filter paths.""" return self._entities.get(ref.val) def _ref_val(self, ref_or_str: Any) -> str | None: """Extract a ref string value from a Ref or str tag value.""" if isinstance(ref_or_str, Ref): return ref_or_str.val if isinstance(ref_or_str, str): return ref_or_str return None @property def all_col_names(self) -> tuple[str, ...]: """Return all unique tag names across all entities (cached).""" if self._all_col_names is None: seen: dict[str, None] = {} for entity in self._entities.values(): for key in entity: if key not in seen: seen[key] = None self._all_col_names = tuple(seen) return self._all_col_names # ---- Read ops ------------------------------------------------------------
[docs] async def read_by_filter( self, ast: Node, limit: int | None = None, ) -> list[dict[str, Any]]: """Return entities matching a filter AST. :param ast: Compiled filter AST from :func:`~hs_py.filter.parse`. :param limit: Maximum number of results to return. ``None`` means no limit. :returns: List of matching entity dicts. """ # Fast path: simple Has("tagName") uses the tag index directly. if type(ast) is Has and len(ast.path.names) == 1: tag = ast.path.names[0] ref_vals = self._tag_index.get(tag) if ref_vals is None: return [] entities = self._entities if limit is not None: rows: list[dict[str, Any]] = [] for rv in ref_vals: ent = entities.get(rv) if ent is not None: rows.append(ent) if len(rows) >= limit: break return rows return [entities[rv] for rv in ref_vals if rv in entities] # General path: evaluate filter against every entity. rows = [] for entity in self._entities.values(): if evaluate(ast, entity, self._resolver): rows.append(entity) if limit is not None and len(rows) >= limit: break return rows
[docs] async def read_by_ids(self, ids: list[Ref]) -> list[dict[str, Any] | None]: """Return entities for a list of Refs, preserving input order. :param ids: Ordered list of entity Refs to fetch. :returns: List the same length as *ids*. Each entry is the entity dict if found, or ``None`` if the Ref does not exist. """ return [self._entities.get(ref.val) for ref in ids]
# ---- Navigation ----------------------------------------------------------
[docs] async def nav(self, nav_id: str | None = None) -> list[dict[str, Any]]: """Navigate the site/equip/point hierarchy. - ``nav_id=None`` — return all entities with the ``site`` tag. - ``nav_id`` of a site — return equips whose ``siteRef`` matches. - ``nav_id`` of an equip — return points whose ``equipRef`` matches. :param nav_id: Ref val of the parent entity, or ``None`` for roots. :returns: List of child entity dicts. """ if nav_id is None: return [e for e in self._entities.values() if "site" in e] target = self._entities.get(nav_id) if target is None: return [] if "site" in target: # Return equips with a matching siteRef result: list[dict[str, Any]] = [] for entity in self._entities.values(): if "equip" not in entity: continue site_ref_val = self._ref_val(entity.get("siteRef")) if site_ref_val == nav_id: result.append(entity) return result if "equip" in target: # Return points with a matching equipRef result = [] for entity in self._entities.values(): if "point" not in entity: continue equip_ref_val = self._ref_val(entity.get("equipRef")) if equip_ref_val == nav_id: result.append(entity) return result return []
# ---- History ops ---------------------------------------------------------
[docs] async def his_read( self, ref: Ref, range_str: str | None = None, ) -> list[dict[str, Any]]: """Return time-series history for a point. :param ref: Entity Ref of the point. :param range_str: Optional range string (currently ignored; all data is returned). :returns: List of dicts with ``"ts"`` and ``"val"`` keys. """ return list(self._timeseries.get(ref.val, []))
[docs] async def his_write(self, ref: Ref, items: list[dict[str, Any]]) -> None: """Append time-series data for a point. :param ref: Entity Ref of the point. :param items: List of dicts with ``"ts"`` and ``"val"`` keys to append. """ bucket = self._timeseries.setdefault(ref.val, []) bucket.extend(items) # Cap per-point history to prevent unbounded memory growth if len(bucket) > _MAX_HISTORY_PER_POINT: del bucket[: len(bucket) - _MAX_HISTORY_PER_POINT]
# ---- Priority array ops --------------------------------------------------
[docs] async def point_write( self, ref: Ref, level: int, val: Any, who: str = "", duration: Any = None, ) -> None: """Write a value to a writable point's priority array. :param ref: Entity Ref of the writable point. :param level: Priority level (1-17). Level 17 is the default. :param val: Value to write. Pass ``None`` to clear the level. :param who: Optional identifier of who is writing (stored for reference but not currently used). :param duration: Ignored by this backend. """ pri = self._priority.setdefault(ref.val, {}) if val is None: pri.pop(level, None) else: pri[level] = val
[docs] async def point_read_array(self, ref: Ref) -> list[dict[str, Any]]: """Return the 17-level priority array for a writable point. :param ref: Entity Ref of the writable point. :returns: List of 17 dicts, each containing ``"level"`` (:class:`Number`) and optionally ``"val"`` (absent when the level is unset). """ pri = self._priority.get(ref.val, {}) rows: list[dict[str, Any]] = [] for level in range(1, 18): row: dict[str, Any] = {"level": Number(float(level))} if level in pri: row["val"] = pri[level] rows.append(row) return rows
# ---- Watch ops -----------------------------------------------------------
[docs] async def watch_sub( self, watch_id: str | None, ids: list[Ref], dis: str = "watch", ) -> tuple[str, list[dict[str, Any]]]: """Create or extend a watch subscription. :param watch_id: Existing watch ID to extend, or ``None`` to create a new watch. If the provided ID does not exist it is treated as ``None`` and a new watch is created. :param ids: Entity Refs to subscribe to. :param dis: Human-readable name for a newly created watch. :returns: ``(watch_id, entities)`` — the (possibly new) watch ID and the current state of all subscribed entities. """ # Resolve or create the watch if watch_id is None or watch_id not in self._watches: watch_id = f"w-{secrets.token_hex(4)}" self._watches[watch_id] = _WatchState(dis=dis) state = self._watches[watch_id] for ref in ids: state.ids.add(ref.val) # Return current state of subscribed entities entities = [self._entities[rv] for rv in state.ids if rv in self._entities] return watch_id, entities
[docs] async def watch_unsub( self, watch_id: str, ids: list[Ref], *, close: bool = False, ) -> None: """Remove entities from a watch, or close the watch entirely. :param watch_id: Watch to modify. :param ids: Entity Refs to remove. Ignored when *close* is ``True``. :param close: If ``True``, tear down the entire watch. """ if watch_id not in self._watches: return if close: del self._watches[watch_id] return state = self._watches[watch_id] for ref in ids: state.ids.discard(ref.val) state.dirty.discard(ref.val)
[docs] async def watch_poll( self, watch_id: str, *, refresh: bool = False, ) -> list[dict[str, Any]]: """Poll a watch for changed entities. :param watch_id: Watch to poll. :param refresh: If ``True``, return all watched entities regardless of dirty state. :returns: List of entity dicts that have changed since the last poll, or all watched entities when *refresh* is ``True``. The dirty set is cleared after polling. """ state = self._watches.get(watch_id) if state is None: return [] ref_vals = set(state.ids) if refresh else state.dirty & state.ids state.dirty.clear() return [self._entities[rv] for rv in ref_vals if rv in self._entities]
# ---- Mutation helpers (for testing / server push) -----------------------
[docs] def mark_dirty(self, ref_val: str) -> None: """Mark an entity as changed in all watches that subscribe to it. :param ref_val: The ``Ref.val`` of the entity that changed. """ for state in self._watches.values(): if ref_val in state.ids: state.dirty.add(ref_val)
# ---- UserStore implementation --------------------------------------------
[docs] async def get_user(self, username: str) -> User | None: """Return a user by username, or ``None`` if not found.""" return self._users.get(username)
[docs] async def list_users(self) -> list[User]: """Return all users.""" return list(self._users.values())
[docs] async def create_user(self, user: User) -> None: """Persist a new user. :raises ValueError: If a user with the same username already exists. """ if user.username in self._users: msg = f"User already exists: {user.username!r}" raise ValueError(msg) self._users[user.username] = user
    async def update_user(self, username: str, **fields: Any) -> User:
        """Update fields on an existing user.

        Only ``first_name``, ``last_name``, ``email``, ``role``, ``enabled``,
        ``credentials``, and ``password`` are honored; other keyword fields
        are silently ignored. A ``password`` is never stored raw — it is
        converted to SCRAM credentials.

        :raises KeyError: If the user does not exist.
        """
        existing = self._users.get(username)
        if existing is None:
            msg = f"User not found: {username!r}"
            raise KeyError(msg)
        import time

        # Every successful update bumps the modification timestamp.
        updates: dict[str, Any] = {"updated_at": time.time()}
        if "password" in fields:
            # Replace the raw password with derived SCRAM credentials.
            updates["credentials"] = derive_scram_credentials(fields.pop("password"))
        allowed = {"first_name", "last_name", "email", "role", "enabled", "credentials"}
        for key, val in fields.items():
            if key in allowed:
                updates[key] = val
        from dataclasses import asdict

        merged = {**asdict(existing), **updates}
        # asdict() recurses into nested dataclasses, so credentials would
        # arrive as a plain dict — keep the real object (new or existing).
        merged["credentials"] = updates.get("credentials", existing.credentials)
        # asdict() converts Role enum to its value — restore the enum instance
        if isinstance(merged.get("role"), str):
            from hs_py.user import Role

            merged["role"] = Role(merged["role"])
        new_user = User(**merged)
        self._users[username] = new_user
        return new_user
[docs] async def delete_user(self, username: str) -> bool: """Delete a user by username.""" return self._users.pop(username, None) is not None