# Source code for hs_py.storage.redis

"""Redis storage adapter for Haystack servers.

Implements :class:`~hs_py.storage.protocol.StorageAdapter` using Redis 8
with RedisJSON, RedisTimeSeries, and RediSearch.

Requires Redis 8+ (ships with JSON, TimeSeries, and Search modules) and the
``redis[hiredis]`` Python package (installed via ``pip install hs-py[server]``).

Key schema::

    hs:e:{ref_val}          RedisJSON document (entity + _tags index field)
    hs:ids                  Set of all entity ref vals
    hs:tag:{tagname}        Set of ref vals that have this tag
    hs:ts:{ref_val}         TimeSeries key for history data
    hs:pri:{ref_val}        Hash mapping level -> JSON-encoded value
    hs:w:{watch_id}         Hash with watch metadata (dis, lease)
    hs:w:{watch_id}:ids     Set of watched ref vals
    hs:w:{watch_id}:dirty   Set of dirty ref vals

RediSearch index ``hs_idx`` is created on ``hs:e:*`` JSON documents with:

- ``_tags`` TAG field (comma-separated tag names) for Has/Missing queries
- ``siteRef`` TAG field (``$.siteRef.val``) for site navigation and filters
- ``equipRef`` TAG field (``$.equipRef.val``) for equip navigation and filters
"""

from __future__ import annotations

import contextlib
import datetime
import logging
import re
import secrets
from typing import TYPE_CHECKING, Any

import orjson

from hs_py.encoding.json import _decode_val_v4, encode_val
from hs_py.filter import evaluate
from hs_py.filter.ast import And, Cmp, CmpOp, Has, Missing, Node, Or
from hs_py.kinds import Number, Ref
from hs_py.user import User, derive_scram_credentials, user_from_dict, user_to_dict

if TYPE_CHECKING:
    from redis.asyncio import Redis

    from hs_py.tls import TLSConfig

# Public API of this module.
__all__ = ["RedisAdapter", "create_redis_client"]

# One logger per module, per stdlib logging convention.
_log = logging.getLogger(__name__)

# Default connection settings (production best practices).
_SOCKET_TIMEOUT = 10.0  # per-command socket timeout, seconds
_SOCKET_CONNECT_TIMEOUT = 5.0  # TCP connect timeout, seconds
_HEALTH_CHECK_INTERVAL = 30  # seconds between connection health checks
_RETRY_ATTEMPTS = 3  # attempts for the jittered-backoff retry policy
_MAX_CONNECTIONS = 50  # connection-pool upper bound
_PIPELINE_BATCH_SIZE = 500  # max commands buffered per pipeline during bulk loads

# Key prefixes
_E = "hs:e:"
_IDS = "hs:ids"
_TAG = "hs:tag:"
_TS = "hs:ts:"
_PRI = "hs:pri:"
_W = "hs:w:"
_USER = "hs:user:"

# RediSearch index name and schema
_FT_INDEX = "hs_idx"

# Maximum results from a single RediSearch query.
_MAX_FT_RESULTS = 10_000

# Maximum entities to scan in Python fallback when no tag index narrows candidates.
_MAX_FALLBACK_SCAN = 50_000

# Allowed pattern for Redis key components (Haystack identifiers).
_SAFE_KEY_RE = re.compile(r"^[a-zA-Z0-9_:\-.~]+$")

# Ref-valued fields indexed as TAG in RediSearch for efficient querying.
_FT_REF_FIELDS = frozenset({"siteRef", "equipRef"})

# Expected field names in the RediSearch index.
_FT_EXPECTED_FIELDS = frozenset({"_tags", "siteRef", "equipRef"})


def _validate_key_part(value: str, label: str = "key") -> None:
    """Validate that a value is safe for use in a Redis key."""
    if not _SAFE_KEY_RE.match(value):
        msg = f"Invalid characters in {label}: {value!r}"
        raise ValueError(msg)


def _entity_key(ref_val: str) -> str:
    _validate_key_part(ref_val, "ref_val")
    return f"{_E}{ref_val}"


def _tag_key(tag: str) -> str:
    return f"{_TAG}{tag}"


def _ts_key(ref_val: str) -> str:
    return f"{_TS}{ref_val}"


def _pri_key(ref_val: str) -> str:
    return f"{_PRI}{ref_val}"


def _watch_key(watch_id: str) -> str:
    return f"{_W}{watch_id}"


def _watch_ids_key(watch_id: str) -> str:
    return f"{_W}{watch_id}:ids"


def _watch_dirty_key(watch_id: str) -> str:
    return f"{_W}{watch_id}:dirty"


def _encode_entity(entity: dict[str, Any]) -> dict[str, Any]:
    """Encode an entity dict to JSON-serializable form (v4 format).

    The result carries an extra ``_tags`` member — the entity's tag
    names joined with commas — so RediSearch can TAG-index it.  It can
    be handed straight to ``r.json().set()``.
    """
    doc: dict[str, Any] = {}
    for tag, value in entity.items():
        doc[tag] = encode_val(value)
    doc["_tags"] = ",".join(entity)
    return doc


def _decode_entity(raw: dict[str, Any]) -> dict[str, Any]:
    """Decode a JSON-serialized entity dict back to Haystack kinds.

    The internal ``_tags`` index field is skipped.  Unlike the previous
    ``pop``-based implementation, the input mapping is left unmodified,
    so callers may safely reuse *raw* after decoding.

    :param raw: JSON-decoded entity document (v4 encoding).
    :returns: New dict mapping tag names to decoded Haystack values.
    """
    return {k: _decode_val_v4(v) for k, v in raw.items() if k != "_tags"}


def _extract_has_tags(node: Node) -> set[str]:
    """Extract simple tag names from Has nodes in a filter AST.

    Used for candidate narrowing via tag index Sets before full evaluation.
    Only extracts tags from Has nodes connected by And -- Or and Cmp nodes
    are not useful for narrowing since they don't guarantee tag presence.
    """
    if isinstance(node, Has) and len(node.path.names) == 1:
        return {node.path.names[0]}
    if isinstance(node, And):
        return _extract_has_tags(node.left) | _extract_has_tags(node.right)
    return set()


def _build_ft_query(node: Node) -> str | None:
    """Translate a filter AST into a RediSearch query string, if possible.

    Handled constructs:

    - Has/Missing on single-segment tag paths
    - Cmp ``==`` on Ref-valued indexed fields (siteRef, equipRef)
    - And/Or combinations whose operands are all handled

    :returns: Query string, or ``None`` when any part of the filter is
        beyond RediSearch's reach (multi-segment paths, non-EQ
        comparisons, unindexed fields).
    """
    # Boolean combinators: both operands must themselves be translatable.
    if isinstance(node, (And, Or)):
        lhs = _build_ft_query(node.left)
        rhs = _build_ft_query(node.right)
        if lhs is None or rhs is None:
            return None
        return f"({lhs} {rhs})" if isinstance(node, And) else f"({lhs})|({rhs})"

    # Tag presence/absence on a single-segment path.
    if isinstance(node, (Has, Missing)) and len(node.path.names) == 1:
        clause = f"@_tags:{{{_ft_escape(node.path.names[0])}}}"
        return clause if isinstance(node, Has) else f"-{clause}"

    # Equality against an indexed Ref field.
    if isinstance(node, Cmp) and node.op == CmpOp.EQ and len(node.path.names) == 1:
        field = node.path.names[0]
        if field in _FT_REF_FIELDS:
            if isinstance(node.val, Ref):
                return f"@{field}:{{{_ft_escape(node.val.val)}}}"
            if isinstance(node.val, str):
                return f"@{field}:{{{_ft_escape(node.val)}}}"

    return None


# Characters that need escaping in RediSearch TAG values
_FT_SPECIAL = frozenset(r",.<>{}[]\"':;!@#$%^&*()-+=~ |/?`")


def _ft_escape(tag: str) -> str:
    """Escape special characters in a RediSearch TAG value."""
    return "".join(f"\\{c}" if c in _FT_SPECIAL else c for c in tag)


def _parse_ft_fields(info: dict[str, Any]) -> set[str]:
    """Extract indexed field attribute names from an ``ft().info()`` response.

    With RESP3, ``attributes`` is a list of dicts each containing an
    ``'attribute'`` key.
    """
    fields: set[str] = set()
    for attr in info.get("attributes", []):
        if isinstance(attr, dict):
            name = attr.get("attribute")
            if isinstance(name, str):
                fields.add(name)
    return fields


def create_redis_client(
    url: str = "redis://localhost:6379",
    *,
    tls: TLSConfig | None = None,
    max_connections: int = _MAX_CONNECTIONS,
) -> Redis[str]:
    """Create an async Redis client with optional TLS 1.3.

    The client uses the RESP3 protocol, automatic string decoding,
    connection health checks, and jittered exponential-backoff retries.
    When *tls* is provided, the connection enforces TLS 1.3 minimum,
    loads the configured client certificate for mutual authentication,
    and verifies the server certificate against the configured CA.

    :param url: Redis connection URL (``redis://`` or ``rediss://``).
    :param tls: Optional TLS configuration for encrypted connections.
    :param max_connections: Maximum connections in the pool (default 50).
    :returns: An async ``redis.asyncio.Redis`` client.

    Example::

        from hs_py.tls import TLSConfig
        from hs_py.storage.redis import RedisAdapter, create_redis_client

        tls = TLSConfig(
            certificate_path="client.pem",
            private_key_path="client.key",
            ca_certificates_path="ca.pem",
        )
        redis = create_redis_client("rediss://redis:6379", tls=tls)
        adapter = RedisAdapter(redis)
    """
    import ssl as _ssl

    from redis.asyncio import Redis
    from redis.backoff import EqualJitterBackoff
    from redis.exceptions import BusyLoadingError
    from redis.retry import Retry

    retry = Retry(EqualJitterBackoff(), _RETRY_ATTEMPTS)
    common: dict[str, Any] = {
        "protocol": 3,
        "decode_responses": True,
        "max_connections": max_connections,
        "socket_timeout": _SOCKET_TIMEOUT,
        "socket_connect_timeout": _SOCKET_CONNECT_TIMEOUT,
        "socket_keepalive": True,
        "health_check_interval": _HEALTH_CHECK_INTERVAL,
        "retry": retry,
        "retry_on_timeout": True,
        "retry_on_error": [BusyLoadingError],
    }

    if tls is None:
        return Redis.from_url(url, **common)

    # Ensure rediss:// scheme for TLS connections
    tls_url = url.replace("redis://", "rediss://") if url.startswith("redis://") else url
    return Redis.from_url(  # type: ignore[call-overload,no-any-return]
        tls_url,
        **common,
        ssl_certfile=tls.certificate_path,
        ssl_keyfile=tls.private_key_path,
        ssl_ca_certs=tls.ca_certificates_path,
        ssl_password=tls.key_password,
        ssl_min_version=_ssl.TLSVersion.TLSv1_3,
        ssl_cert_reqs="required",
        ssl_check_hostname=True,
    )
class RedisAdapter:
    """StorageAdapter backed by Redis 8 (JSON + TimeSeries + Search).

    Implements :class:`~hs_py.storage.protocol.StorageAdapter` using
    RedisJSON for entity storage, RedisTimeSeries for history data, and
    RediSearch for efficient filter queries.

    :param redis: A ``redis.asyncio.Redis`` client instance created with
        ``protocol=3`` and ``decode_responses=True``.
    """

    def __init__(self, redis: Redis[str]) -> None:
        # Underlying async client; every command goes through this handle.
        self._r = redis
        # Bounded memo of (ft_query, limit) -> result rows for the
        # RediSearch fast path of read_by_filter.
        self._read_cache: dict[tuple[str, int | None], list[dict[str, Any]]] = {}
        self._read_cache_max = 64
        # Union of column names observed during bulk load, if known.
        self._all_col_names: tuple[str, ...] | None = None

    # ---- Lifecycle -----------------------------------------------------------
[docs] async def start(self) -> None: """Verify Redis connection and create RediSearch index.""" await self._r.ping() _log.info("RedisAdapter connected to Redis") await self._ensure_search_index()
[docs] async def close(self) -> None: """Close the Redis connection.""" await self._r.aclose() # type: ignore[attr-defined] _log.info("RedisAdapter disconnected from Redis")
# ---- Search index -------------------------------------------------------- async def _ensure_search_index(self) -> None: """Create or rebuild the RediSearch index with the expected schema. If the index exists but is missing expected fields (e.g. after a schema upgrade), it is dropped and recreated so RediSearch re-indexes all existing JSON documents. """ from redis.commands.search.field import TagField from redis.commands.search.index_definition import IndexDefinition, IndexType from redis.exceptions import ResponseError ft = self._r.ft(_FT_INDEX) try: info = await ft.info() # type: ignore[no-untyped-call] existing = _parse_ft_fields(info) if existing >= _FT_EXPECTED_FIELDS: _log.info("RediSearch index '%s' schema is current", _FT_INDEX) return _log.info("RediSearch index '%s' schema outdated, rebuilding", _FT_INDEX) await ft.dropindex() except ResponseError: pass # Index does not exist yet schema = ( TagField("$._tags", as_name="_tags", separator=","), TagField("$.siteRef.val", as_name="siteRef"), TagField("$.equipRef.val", as_name="equipRef"), ) definition = IndexDefinition( # type: ignore[no-untyped-call] prefix=[_E], index_type=IndexType.JSON ) await ft.create_index(schema, definition=definition) _log.info("Created RediSearch index '%s'", _FT_INDEX) # ---- Internal helpers ---------------------------------------------------- async def _store_entity(self, ref_val: str, entity: dict[str, Any]) -> None: """Store a single entity with tag indexes. When updating an existing entity, stale tag index entries are removed. 
""" old_entity = await self._load_entity(ref_val) old_tags = set(old_entity) if old_entity else set() new_tags = set(entity) removed_tags = old_tags - new_tags encoded = _encode_entity(entity) pipe = self._r.pipeline() pipe.json().set(_entity_key(ref_val), "$", encoded) pipe.sadd(_IDS, ref_val) for tag in new_tags: pipe.sadd(_tag_key(tag), ref_val) for tag in removed_tags: pipe.srem(_tag_key(tag), ref_val) await pipe.execute() async def _load_entity(self, ref_val: str) -> dict[str, Any] | None: """Load a single entity by ref val.""" raw: dict[str, Any] | None = await self._r.json().get(_entity_key(ref_val)) if raw is None: return None return _decode_entity(raw) async def _load_entities(self, ref_vals: list[str]) -> list[dict[str, Any] | None]: """Load multiple entities by ref val via ``json().mget()``.""" if not ref_vals: return [] keys = [_entity_key(rv) for rv in ref_vals] # json().mget with "$" returns [[doc], [doc], None, ...] per key results: list[Any] = await self._r.json().mget(keys, "$") # type: ignore[no-untyped-call] out: list[dict[str, Any] | None] = [] for r in results: if r is None: out.append(None) elif isinstance(r, list) and r: out.append(_decode_entity(r[0])) else: out.append(None) return out async def _ft_search(self, query_str: str, limit: int | None = None) -> list[dict[str, Any]]: """Execute a RediSearch query and return decoded entities. Returns document content inline from FT.SEARCH (no separate load step). 
""" from redis.commands.search.query import Query max_results = min(limit, _MAX_FT_RESULTS) if limit is not None else _MAX_FT_RESULTS q = Query(query_str).paging(0, max_results) # type: ignore[no-untyped-call] result: Any = await self._r.ft(_FT_INDEX).search(q) # type: ignore[misc] # RESP3 returns a dict with inline document content total: int = result.get("total_results", 0) if not total: return [] rows: list[dict[str, Any]] = [] for doc in result["results"]: attrs = doc.get("extra_attributes", {}) raw = attrs.get("$", None) if raw is not None: if isinstance(raw, str): raw = orjson.loads(raw) rows.append(_decode_entity(raw)) else: # Fallback: load by ID if content not inline ref_val = doc["id"][len(_E) :] entity = await self._load_entity(ref_val) if entity is not None: rows.append(entity) return rows # ---- StorageAdapter methods ---------------------------------------------- @property def all_col_names(self) -> tuple[str, ...] | None: """Cached column names across all entities, or ``None`` if unknown.""" return self._all_col_names
[docs] async def read_by_filter( self, ast: Node, limit: int | None = None, ) -> list[dict[str, Any]]: """Return entities matching a filter AST. Attempts to delegate fully to RediSearch; falls back to Python evaluation with tag-index candidate narrowing for unsupported filter constructs. :param ast: Compiled filter AST from :func:`~hs_py.filter.parse`. :param limit: Maximum number of results to return. ``None`` means no limit. :returns: List of matching entity dicts. """ # Try to fully delegate to RediSearch ft_query = _build_ft_query(ast) if ft_query is not None: cache_key = (ft_query, limit) cached = self._read_cache.get(cache_key) if cached is not None: return cached results = await self._ft_search(ft_query, limit) if len(self._read_cache) < self._read_cache_max: self._read_cache[cache_key] = results return results # Fallback: use tag index Sets for candidate narrowing, then Python eval has_tags = _extract_has_tags(ast) if has_tags: tag_keys = [_tag_key(t) for t in has_tags] if len(tag_keys) == 1: candidate_ids: list[str] = [str(v) for v in await self._r.smembers(tag_keys[0])] else: candidate_ids = [str(v) for v in await self._r.sinter(*tag_keys)] else: # No tag narrowing — cap full scan to prevent loading entire dataset all_ids = await self._r.srandmember(_IDS, _MAX_FALLBACK_SCAN) candidate_ids = [str(v) for v in (all_ids or [])] if not candidate_ids: return [] entities = await self._load_entities(candidate_ids) # Build a resolver from loaded entities for multi-segment paths entity_map: dict[str, dict[str, Any]] = {} for ref_val, entity in zip(candidate_ids, entities, strict=True): if entity is not None: entity_map[ref_val] = entity def resolver(ref: Ref) -> dict[str, Any] | None: return entity_map.get(ref.val) rows: list[dict[str, Any]] = [] for entity in entity_map.values(): if evaluate(ast, entity, resolver): rows.append(entity) if limit is not None and len(rows) >= limit: break return rows
[docs] async def read_by_ids(self, ids: list[Ref]) -> list[dict[str, Any] | None]: """Return entities for a list of Refs, preserving order. :param ids: Ordered list of entity Refs to fetch. :returns: List the same length as *ids*. Each entry is the entity dict if found, or ``None`` if the Ref does not exist. """ ref_vals = [ref.val for ref in ids] return await self._load_entities(ref_vals)
[docs] async def nav(self, nav_id: str | None = None) -> list[dict[str, Any]]: """Navigate the site/equip/point hierarchy. Uses RediSearch indexed ``siteRef`` and ``equipRef`` fields for efficient lookups instead of loading all entities into memory. :param nav_id: The ``Ref.val`` of the entity whose children should be returned. Pass ``None`` to get root-level sites. :returns: List of child entity dicts. """ if nav_id is None: # Root: return sites via RediSearch return await self._ft_search("@_tags:{site}") # Load the target entity to determine its type target = await self._load_entity(nav_id) if target is None: return [] escaped_id = _ft_escape(nav_id) if "site" in target: # Site -> equips with matching siteRef via RediSearch return await self._ft_search(f"@_tags:{{equip}} @siteRef:{{{escaped_id}}}") if "equip" in target: # Equip -> points with matching equipRef via RediSearch return await self._ft_search(f"@_tags:{{point}} @equipRef:{{{escaped_id}}}") return []
[docs] async def his_read( self, ref: Ref, range_str: str | None = None, ) -> list[dict[str, Any]]: """Return time-series history for a point. :param ref: Entity Ref of the point. :param range_str: Optional range string (currently ignored; all data is returned). :returns: List of dicts with ``"ts"`` (datetime) and ``"val"`` keys. """ from redis.exceptions import ResponseError ts_key = _ts_key(ref.val) try: samples: list[tuple[int, float]] = await self._r.ts().range(ts_key, "-", "+") except ResponseError: return [] # Look up the entity to get the unit entity = await self._load_entity(ref.val) unit = entity.get("unit") if entity is not None else None rows: list[dict[str, Any]] = [] for ts_ms, val_float in samples: dt = datetime.datetime.fromtimestamp(ts_ms / 1000, tz=datetime.UTC) val = Number(val_float, unit) if unit else Number(val_float) rows.append({"ts": dt, "val": val}) return rows
[docs] async def his_write(self, ref: Ref, items: list[dict[str, Any]]) -> None: """Append time-series data for a point. :param ref: Entity Ref of the point. :param items: List of dicts with ``"ts"`` and ``"val"`` keys. """ from redis.exceptions import ResponseError ts_key = _ts_key(ref.val) # Ensure the TS key exists with labels and duplicate policy. # Use try/except to avoid TOCTOU race with concurrent writers. entity = await self._load_entity(ref.val) ts_labels: dict[str, str] = {"entity": ref.val} if entity is not None: unit = entity.get("unit") if isinstance(unit, str): ts_labels["unit"] = unit with contextlib.suppress(ResponseError): await self._r.ts().create(ts_key, duplicate_policy="last", labels=ts_labels) pipe = self._r.pipeline() for row in items: val = row.get("val") ts = row.get("ts") # Extract numeric value if isinstance(val, Number): float_val = val.val elif isinstance(val, (int, float)): float_val = float(val) else: continue # Convert timestamp: datetime → ms epoch, int/str passthrough ts_arg: int | str if isinstance(ts, datetime.datetime): ts_arg = int(ts.timestamp() * 1000) elif isinstance(ts, (int, float)) or (isinstance(ts, str) and ts.isdigit()): ts_arg = int(ts) else: ts_arg = "*" pipe.ts().add(ts_key, ts_arg, float_val) await pipe.execute()
[docs] async def point_write( self, ref: Ref, level: int, val: Any, who: str = "", duration: Any = None, ) -> None: """Write a value to a writable point's priority array. :param ref: Entity Ref of the writable point. :param level: Priority level (1-17). :param val: Value to write. Pass ``None`` to clear the level. :param who: Optional identifier of who is writing (ignored). :param duration: Optional duration override (ignored). """ pri_key = _pri_key(ref.val) if val is None: await self._r.hdel(pri_key, str(level)) else: encoded = orjson.dumps(encode_val(val)) await self._r.hset(pri_key, str(level), encoded)
[docs] async def point_read_array(self, ref: Ref) -> list[dict[str, Any]]: """Return the 17-level priority array for a writable point. :param ref: Entity Ref of the writable point. :returns: List of 17 dicts, each with a ``"level"`` key and an optional ``"val"`` key (absent when the level is unset). """ pri_key = _pri_key(ref.val) raw = await self._r.hgetall(pri_key) rows: list[dict[str, Any]] = [] for level in range(1, 18): row: dict[str, Any] = {"level": Number(float(level)), "val": None} level_str = str(level) if level_str in raw: val_json = orjson.loads(raw[level_str]) row["val"] = _decode_val_v4(val_json) rows.append(row) return rows
[docs] async def watch_sub( self, watch_id: str | None, ids: list[Ref], dis: str = "watch", ) -> tuple[str, list[dict[str, Any]]]: """Create or extend a watch subscription. :param watch_id: Existing watch ID to extend, or ``None`` to create a new watch. :param ids: Entity Refs to add to the watch. :param dis: Human-readable display name for a new watch. :returns: ``(watch_id, entities)`` where *entities* is the current state of all newly subscribed entities. """ if not watch_id or not await self._r.exists(_watch_key(watch_id)): watch_id = f"w-{secrets.token_hex(4)}" await self._r.hset( _watch_key(watch_id), mapping={"dis": str(dis), "lease": "300"}, ) # Collect entity state for subscribed IDs ids_key = _watch_ids_key(watch_id) ref_vals: list[str] = [ref.val for ref in ids] if ref_vals: await self._r.sadd(ids_key, *ref_vals) # Load current state of watched entities entities = await self._load_entities(ref_vals) return watch_id, [e for e in entities if e is not None]
[docs] async def watch_unsub( self, watch_id: str, ids: list[Ref], *, close: bool = False, ) -> None: """Remove entities from a watch, or close the watch entirely. :param watch_id: Watch to modify. :param ids: Entity Refs to remove. Ignored when *close* is ``True``. :param close: If ``True``, the entire watch is torn down. :raises ValueError: If *watch_id* is not found. """ watch_key = _watch_key(watch_id) if not await self._r.exists(watch_key): msg = f"Unknown watch: {watch_id}" raise ValueError(msg) if close: pipe = self._r.pipeline() pipe.delete(watch_key) pipe.delete(_watch_ids_key(watch_id)) pipe.delete(_watch_dirty_key(watch_id)) await pipe.execute() return # Remove specific IDs ids_key = _watch_ids_key(watch_id) dirty_key = _watch_dirty_key(watch_id) ref_vals = [ref.val for ref in ids] if ref_vals: pipe = self._r.pipeline() pipe.srem(ids_key, *ref_vals) pipe.srem(dirty_key, *ref_vals) await pipe.execute()
[docs] async def watch_poll( self, watch_id: str, *, refresh: bool = False, ) -> list[dict[str, Any]]: """Poll for changed entities. :param watch_id: Watch to poll. :param refresh: If ``True``, return all watched entities (full refresh) regardless of dirty state. :returns: List of entity dicts that have changed since the last poll (or all entities if *refresh* is ``True``). :raises ValueError: If *watch_id* is not found. """ watch_key = _watch_key(watch_id) if not await self._r.exists(watch_key): msg = f"Unknown watch: {watch_id}" raise ValueError(msg) ids_key = _watch_ids_key(watch_id) dirty_key = _watch_dirty_key(watch_id) if refresh: # Atomically get watched IDs and clear dirty set pipe = self._r.pipeline() pipe.smembers(ids_key) pipe.delete(dirty_key) results = await pipe.execute() ref_vals: list[str] = [str(v) for v in results[0]] else: # Atomically read dirty + watched sets and clear dirty pipe = self._r.pipeline() pipe.smembers(dirty_key) pipe.smembers(ids_key) pipe.delete(dirty_key) results = await pipe.execute() dirty_members = results[0] watched_members = results[1] watched_set = {str(v) for v in watched_members} ref_vals = [str(v) for v in dirty_members if str(v) in watched_set] if not ref_vals: return [] entities = await self._load_entities(ref_vals) return [e for e in entities if e is not None]
# ---- Non-protocol helpers ------------------------------------------------
[docs] async def load_entities(self, entities: list[dict[str, Any]]) -> int: """Bulk-load a list of entity dicts into Redis. Each entity must have an ``id`` :class:`~hs_py.kinds.Ref`. Entities without an ``id`` are silently skipped. Large batches are chunked to avoid unbounded pipeline memory usage. :param entities: List of entity dicts to load. :returns: Number of entities actually stored. """ count = 0 skipped = 0 cmds_in_pipe = 0 pipe = self._r.pipeline() for entity in entities: ref = entity.get("id") if not isinstance(ref, Ref): skipped += 1 continue encoded = _encode_entity(entity) pipe.json().set(_entity_key(ref.val), "$", encoded) pipe.sadd(_IDS, ref.val) for tag in entity: pipe.sadd(_tag_key(tag), ref.val) count += 1 cmds_in_pipe += 2 + len(entity) if cmds_in_pipe >= _PIPELINE_BATCH_SIZE: await pipe.execute() pipe = self._r.pipeline() cmds_in_pipe = 0 if cmds_in_pipe: await pipe.execute() if skipped: _log.warning("Skipped %d rows without 'id' Ref during load", skipped) _log.info("Loaded %d entities into Redis", count) # Compute column names for Grid construction fast path. seen: dict[str, None] = {} for entity in entities: for key in entity: if key not in seen: seen[key] = None self._all_col_names = tuple(seen) return count
# ---- UserStore implementation -------------------------------------------- def _user_key(self, username: str) -> str: """Return the Redis key for a user.""" if not _SAFE_KEY_RE.match(username): msg = f"Invalid username for Redis key: {username!r}" raise ValueError(msg) return f"{_USER}{username}"
[docs] async def get_user(self, username: str) -> User | None: """Return a user by username, or ``None`` if not found.""" data = await self._r.json().get(self._user_key(username)) if data is None: return None return user_from_dict(data)
[docs] async def list_users(self) -> list[User]: """Return all users.""" keys: list[str] = [] async for key in self._r.scan_iter(match=f"{_USER}*", count=1000): keys.append(str(key)) if not keys: return [] users: list[User] = [] for key in keys: data = await self._r.json().get(key) if data is not None: users.append(user_from_dict(data)) return users
[docs] async def create_user(self, user: User) -> None: """Persist a new user. :raises ValueError: If a user with the same username already exists. """ key = self._user_key(user.username) existing = await self._r.json().get(key) if existing is not None: msg = f"User already exists: {user.username!r}" raise ValueError(msg) await self._r.json().set(key, "$", user_to_dict(user))
[docs] async def update_user(self, username: str, **fields: Any) -> User: """Update fields on an existing user. :raises KeyError: If the user does not exist. """ import time key = self._user_key(username) data = await self._r.json().get(key) if data is None: msg = f"User not found: {username!r}" raise KeyError(msg) existing = user_from_dict(data) updates: dict[str, Any] = {"updated_at": time.time()} if "password" in fields: updates["credentials"] = derive_scram_credentials(fields.pop("password")) allowed = {"first_name", "last_name", "email", "role", "enabled", "credentials"} for k, v in fields.items(): if k in allowed: updates[k] = v from dataclasses import asdict merged = {**asdict(existing), **updates} merged["credentials"] = updates.get("credentials", existing.credentials) # asdict() converts Role enum to its value — restore the enum instance if isinstance(merged.get("role"), str): from hs_py.user import Role merged["role"] = Role(merged["role"]) new_user = User(**merged) await self._r.json().set(key, "$", user_to_dict(new_user)) return new_user
[docs] async def delete_user(self, username: str) -> bool: """Delete a user by username.""" key = self._user_key(username) return bool(await self._r.delete(key))