# Source code for hs_py.storage.timescale

"""TimescaleDB StorageAdapter for Haystack server backends.

Provides ``TimescaleAdapter`` implementing the ``StorageAdapter`` Protocol
using asyncpg for async PostgreSQL/TimescaleDB access.

Schema is created automatically on ``start()``. Entities are stored as JSONB
in ``hs_entities``. Time-series history uses ``hs_history`` (optionally a
TimescaleDB hypertable). Priority arrays are stored in ``hs_priority``.
Watches are tracked in ``hs_watches`` and ``hs_watch_entities``.

Usage::

    pool = await create_timescale_pool("postgresql://localhost/haystack")
    adapter = TimescaleAdapter(pool)
    await adapter.start()
    await adapter.load_entities([{"id": Ref("site1"), "site": MARKER, "dis": "My Site"}])
"""

from __future__ import annotations

import datetime
import logging
import re
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    import asyncpg  # type: ignore[import-untyped]

    from hs_py.filter.ast import Node
    from hs_py.kinds import Ref
    from hs_py.tls import TLSConfig
    from hs_py.user import User

__all__ = [
    "TimescaleAdapter",
    "create_timescale_pool",
]

_log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Schema SQL
# ---------------------------------------------------------------------------

# DDL executed on start(); every statement is idempotent (IF NOT EXISTS) so
# repeated startups are safe.  The SQL text itself must not be edited lightly:
# table/column names are referenced throughout the adapter's queries.
_SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS hs_entities (
    id TEXT PRIMARY KEY,
    tags JSONB NOT NULL DEFAULT '{}',
    created_at TIMESTAMPTZ DEFAULT now(),
    updated_at TIMESTAMPTZ DEFAULT now()
);
CREATE INDEX IF NOT EXISTS idx_entities_tags ON hs_entities USING GIN (tags);

CREATE TABLE IF NOT EXISTS hs_history (
    entity_id TEXT NOT NULL,
    ts TIMESTAMPTZ NOT NULL,
    val DOUBLE PRECISION NOT NULL,
    unit TEXT,
    PRIMARY KEY (entity_id, ts)
);

CREATE TABLE IF NOT EXISTS hs_priority (
    entity_id TEXT NOT NULL,
    level SMALLINT NOT NULL,
    val JSONB,
    who TEXT DEFAULT '',
    PRIMARY KEY (entity_id, level)
);

CREATE TABLE IF NOT EXISTS hs_watches (
    watch_id TEXT PRIMARY KEY,
    dis TEXT NOT NULL DEFAULT '',
    lease_secs INTEGER DEFAULT 300,
    created_at TIMESTAMPTZ DEFAULT now()
);

CREATE TABLE IF NOT EXISTS hs_watch_entities (
    watch_id TEXT REFERENCES hs_watches(watch_id) ON DELETE CASCADE,
    entity_id TEXT NOT NULL,
    dirty BOOLEAN DEFAULT FALSE,
    PRIMARY KEY (watch_id, entity_id)
);

CREATE TABLE IF NOT EXISTS hs_users (
    username TEXT PRIMARY KEY,
    credentials JSONB NOT NULL,
    first_name TEXT NOT NULL DEFAULT '',
    last_name TEXT NOT NULL DEFAULT '',
    email TEXT NOT NULL DEFAULT '',
    role TEXT NOT NULL DEFAULT 'viewer',
    enabled BOOLEAN NOT NULL DEFAULT TRUE,
    created_at DOUBLE PRECISION NOT NULL DEFAULT 0,
    updated_at DOUBLE PRECISION NOT NULL DEFAULT 0
);
"""

# Converts hs_history into a TimescaleDB hypertable partitioned on ts.
# Raises UndefinedFunctionError on plain PostgreSQL; start() tolerates that.
_HYPERTABLE_SQL = """
SELECT create_hypertable('hs_history', 'ts', if_not_exists => true);
"""


# ---------------------------------------------------------------------------
# Encoding helpers
# ---------------------------------------------------------------------------


def _encode_tags(entity: dict[str, Any]) -> dict[str, Any]:
    """Serialise an entity's tags into a JSONB-ready dict (Haystack JSON v4)."""
    from hs_py.encoding.json import encode_val

    encoded: dict[str, Any] = {}
    for tag_name, tag_val in entity.items():
        encoded[tag_name] = encode_val(tag_val)
    return encoded


def _decode_tags(tags: dict[str, Any]) -> dict[str, Any]:
    """Rehydrate Haystack kinds from a stored JSONB tags dict (JSON v4)."""
    from hs_py.encoding.json import _decode_val_v4 as decode

    return {name: decode(raw) for name, raw in tags.items()}


# ---------------------------------------------------------------------------
# Filter AST → PostgreSQL JSONB translation
# ---------------------------------------------------------------------------


def _is_simple_path(path: Any) -> bool:
    """Return True if the path has exactly one segment (simple tag name)."""
    return len(path.names) == 1


def _ast_to_sql(
    node: Node,
    params: list[Any],
) -> str | None:
    """Translate a filter AST node to a PostgreSQL JSONB expression.

    Returns a SQL fragment string, or ``None`` if the node cannot be
    translated (falls back to Python evaluation).

    :param node: Filter AST node to translate.
    :param params: Mutable list to append parameter values to.
    :returns: SQL fragment or ``None``.
    """
    from hs_py.filter.ast import And, Cmp, CmpOp, Has, Missing, Or

    if isinstance(node, Has):
        # Only single-segment paths (plain tag names) can be expressed as a
        # JSONB key-existence test; dotted paths fall back to Python.
        if not _is_simple_path(node.path):
            return None
        name = node.path.names[0]
        # JSONB ?-operator: key existence test on the tags column.
        return f"tags ? {_pg_literal(name)}"

    if isinstance(node, Missing):
        if not _is_simple_path(node.path):
            return None
        name = node.path.names[0]
        return f"NOT (tags ? {_pg_literal(name)})"

    if isinstance(node, Cmp):
        if not _is_simple_path(node.path):
            return None
        name = node.path.names[0]
        sql_val = _encode_cmp_val(node.val)
        if sql_val is None:
            return None

        # Positional parameter index for asyncpg ($1-based).  Computed
        # BEFORE appending, so it names the slot the value will occupy.
        idx = len(params) + 1

        if node.op == CmpOp.EQ:
            params.append(sql_val)
            return f"tags->>{_pg_literal(name)} = ${idx}"
        if node.op == CmpOp.NE:
            params.append(sql_val)
            # NOTE(review): ->> yields NULL for a missing tag, so `!=` will
            # NOT match entities lacking the tag — confirm this matches the
            # Python evaluator's semantics for `tag != val`.
            return f"tags->>{_pg_literal(name)} != ${idx}"

        # Numeric comparisons — pass float so asyncpg binds correctly
        try:
            float_val = float(sql_val)
        except (ValueError, TypeError):
            return None
        params.append(float_val)

        if node.op == CmpOp.GT:
            return f"(tags->>{_pg_literal(name)})::float > ${idx}::float"
        if node.op == CmpOp.GE:
            return f"(tags->>{_pg_literal(name)})::float >= ${idx}::float"
        if node.op == CmpOp.LT:
            return f"(tags->>{_pg_literal(name)})::float < ${idx}::float"
        if node.op == CmpOp.LE:
            return f"(tags->>{_pg_literal(name)})::float <= ${idx}::float"
        return None

    if isinstance(node, And):
        # Both sides must be translatable, or the whole conjunction falls
        # back to Python evaluation.
        left = _ast_to_sql(node.left, params)
        right = _ast_to_sql(node.right, params)
        if left is None or right is None:
            return None
        return f"({left}) AND ({right})"

    if isinstance(node, Or):
        left = _ast_to_sql(node.left, params)
        right = _ast_to_sql(node.right, params)
        if left is None or right is None:
            return None
        return f"({left}) OR ({right})"

    # Unknown node type (e.g. parenthesised sub-expressions wrapped in other
    # AST classes) — signal fallback.
    return None


def _encode_cmp_val(val: Any) -> str | None:
    """Render a Haystack comparison value as a bindable SQL string.

    Returns ``None`` for unsupported kinds (including Marker, which has no
    meaningful string comparison).
    """
    from hs_py.kinds import Number, Ref

    # bool is tested before int because bool is a subclass of int.
    if isinstance(val, bool):
        return str(val).lower()
    if isinstance(val, str):
        return val
    if isinstance(val, (int, float)):
        return str(val)
    if isinstance(val, Number):
        return str(val.val)
    if isinstance(val, Ref):
        return val.val
    # Marker and any other kind: not comparable as a string.
    return None


def _pg_literal(name: str) -> str:
    """Quote *name* as a PostgreSQL single-quoted string literal.

    The name is validated against the tag-name pattern first; any embedded
    single quotes are doubled (defensive — the pattern already excludes them).

    :raises ValueError: If *name* contains disallowed characters.
    """
    if _TAG_NAME_RE.match(name) is None:
        raise ValueError(f"Invalid tag name for SQL: {name!r}")
    return "'" + name.replace("'", "''") + "'"


# Tag name pattern: a letter followed by alphanumerics or underscores.
# NOTE(review): strict Haystack tag names start with a *lowercase* letter;
# this pattern also accepts an uppercase first character — confirm that
# looseness is intended before tightening, since _pg_literal raises
# ValueError for names that fail it.
_TAG_NAME_RE = re.compile(r"^[a-zA-Z][a-zA-Z0-9_]*$")

# Maximum rows to scan in Python fallback when SQL translation fails.
_MAX_FALLBACK_ROWS = 50_000


# ---------------------------------------------------------------------------
# History range parsing
# ---------------------------------------------------------------------------


def _parse_his_range(
    range_str: str,
) -> tuple[datetime.datetime, datetime.datetime]:
    """Parse a Haystack history range string to a (start, end) UTC pair.

    Supported formats:
    - ``"today"`` — today 00:00 to tomorrow 00:00 UTC
    - ``"yesterday"`` — yesterday 00:00 to today 00:00 UTC
    - ``"YYYY-MM-DD"`` — single date (00:00 to 24:00)
    - ``"YYYY-MM-DD,YYYY-MM-DD"`` — date range (inclusive)
    - ``"YYYY-MM-DDTHH:MM:SS,YYYY-MM-DDTHH:MM:SS"`` — datetime range
    """
    utc = datetime.UTC
    today = datetime.datetime.now(utc).date()

    range_str = range_str.strip()

    if range_str == "today":
        start = datetime.datetime(today.year, today.month, today.day, tzinfo=utc)
        end = start + datetime.timedelta(days=1)
        return start, end

    if range_str == "yesterday":
        yesterday = today - datetime.timedelta(days=1)
        start = datetime.datetime(yesterday.year, yesterday.month, yesterday.day, tzinfo=utc)
        end = start + datetime.timedelta(days=1)
        return start, end

    if "," in range_str:
        parts = [p.strip() for p in range_str.split(",", 1)]
        start = _parse_datetime_str(parts[0], utc)
        end_raw = _parse_datetime_str(parts[1], utc)
        # If end is a date (no time component), extend to end of day
        if "T" not in parts[1] and len(parts[1]) == 10:
            end_raw = end_raw + datetime.timedelta(days=1)
        return start, end_raw

    # Single date
    start = _parse_datetime_str(range_str, utc)
    end = start + datetime.timedelta(days=1)
    return start, end


def _parse_datetime_str(s: str, utc: datetime.timezone) -> datetime.datetime:
    """Parse a date or datetime string to a UTC datetime."""
    if "T" in s or " " in s:
        dt = datetime.datetime.fromisoformat(s.replace(" ", "T"))
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=utc)
        return dt.astimezone(utc)
    # date-only
    d = datetime.date.fromisoformat(s)
    return datetime.datetime(d.year, d.month, d.day, tzinfo=utc)


# ---------------------------------------------------------------------------
# TimescaleAdapter
# ---------------------------------------------------------------------------


class TimescaleAdapter:
    """StorageAdapter backed by PostgreSQL/TimescaleDB via asyncpg.

    :param pool: asyncpg connection pool to use.
    """

    def __init__(self, pool: asyncpg.Pool[Any]) -> None:
        """Initialise the adapter with an existing asyncpg pool.

        :param pool: Open asyncpg connection pool.
        """
        self._pool = pool
        # Filter-read result cache keyed by (sql_clause, limit).
        # NOTE(review): entries are never invalidated by writes such as
        # load_entities() — confirm stale reads are acceptable.
        self._read_cache: dict[tuple[str, int | None], list[dict[str, Any]]] = {}
        # Soft cap on cache size; once full, new results are simply not cached.
        self._read_cache_max = 64
        # Column names seen in the last load_entities() call (Grid fast path).
        self._all_col_names: tuple[str, ...] | None = None

    # -----------------------------------------------------------------------
    # Lifecycle
    # -----------------------------------------------------------------------
    async def start(self) -> None:
        """Create database schema (idempotent).

        Runs DDL to create tables and indexes if they do not already exist.
        Attempts to create a TimescaleDB hypertable for ``hs_history``; the
        attempt is silently skipped if TimescaleDB is not available.
        """
        import asyncpg as _asyncpg

        async with self._pool.acquire() as conn:
            await conn.execute(_SCHEMA_SQL)
            try:
                await conn.execute(_HYPERTABLE_SQL)
            except _asyncpg.UndefinedFunctionError:
                # create_hypertable() does not exist on plain PostgreSQL —
                # the adapter still works without the hypertable.
                _log.debug("TimescaleDB hypertable creation skipped (not available)")
    async def close(self) -> None:
        """Close the underlying connection pool."""
        await self._pool.close()
# ----------------------------------------------------------------------- # Entity read operations # ----------------------------------------------------------------------- @property def all_col_names(self) -> tuple[str, ...] | None: """Cached column names across all entities, or ``None`` if unknown.""" return self._all_col_names
    async def read_by_filter(
        self,
        ast: Node,
        limit: int | None = None,
    ) -> list[dict[str, Any]]:
        """Return entities matching the filter AST.

        Translates simple (single-segment) filter nodes to JSONB SQL. Falls
        back to Python-side evaluation via :func:`hs_py.filter.evaluate` for
        multi-segment paths or unsupported node types.

        Results are cached by ``(sql_clause, limit)`` to avoid re-decoding on
        repeated reads.

        NOTE(review): the cache is never invalidated when entities change
        (e.g. via ``load_entities``) — confirm stale results are acceptable.

        :param ast: Parsed filter AST.
        :param limit: Maximum number of entities to return. ``None`` means
            no limit.
        :returns: List of entity tag dicts.
        """
        from hs_py.filter.eval import evaluate

        params: list[Any] = []
        sql_clause = _ast_to_sql(ast, params)

        if sql_clause is not None:
            cache_key = (sql_clause, limit)
            cached = self._read_cache.get(cache_key)
            if cached is not None:
                return cached
            base = "SELECT tags FROM hs_entities WHERE " + sql_clause
            if limit is not None:
                # limit is typed int, so direct interpolation is safe here.
                base += f" LIMIT {limit}"
            async with self._pool.acquire() as conn:
                rows = await conn.fetch(base, *params)
            results = [_decode_tags(dict(row["tags"])) for row in rows]
            if len(self._read_cache) < self._read_cache_max:
                self._read_cache[cache_key] = results
        else:
            # Fallback: scan a bounded number of entities and evaluate the
            # filter in Python.
            base = f"SELECT tags FROM hs_entities LIMIT {_MAX_FALLBACK_ROWS}"
            async with self._pool.acquire() as conn:
                rows = await conn.fetch(base)
            results = []
            for row in rows:
                entity = _decode_tags(dict(row["tags"]))
                if evaluate(ast, entity):
                    results.append(entity)
                    if limit is not None and len(results) >= limit:
                        break
        return results
    async def read_by_ids(self, ids: list[Ref]) -> list[dict[str, Any] | None]:
        """Return entities for a list of Refs, preserving input order.

        :param ids: Ordered list of entity Refs to fetch.
        :returns: List the same length as *ids*. Each entry is the entity
            dict if found, or ``None`` if the Ref does not exist.
        """
        if not ids:
            return []
        ref_vals = [ref.val for ref in ids]
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT id, tags FROM hs_entities WHERE id = ANY($1::text[])",
                ref_vals,
            )
        by_id: dict[str, dict[str, Any]] = {}
        for row in rows:
            by_id[row["id"]] = _decode_tags(dict(row["tags"]))
        # Re-order to match the caller's list; missing ids map to None.
        return [by_id.get(rv) for rv in ref_vals]
    async def nav(self, nav_id: str | None = None) -> list[dict[str, Any]]:
        """Navigate the entity tree.

        - ``nav_id=None`` returns all site entities.
        - ``nav_id`` set to a site id returns equip entities for that site.
        - ``nav_id`` set to an equip id returns point entities for that equip.

        :param nav_id: ``None`` for root, or an entity id to navigate into.
        :returns: List of entity tag dicts.
        """
        async with self._pool.acquire() as conn:
            if nav_id is None:
                # Root: return sites
                rows = await conn.fetch("SELECT tags FROM hs_entities WHERE tags ? 'site'")
            else:
                # Determine whether nav_id is a site or equip
                parent = await conn.fetchrow(
                    "SELECT tags FROM hs_entities WHERE id = $1",
                    nav_id,
                )
                if parent is None:
                    return []
                parent_tags = _decode_tags(dict(parent["tags"]))
                if "site" in parent_tags:
                    # Site → return equips referencing this site
                    rows = await conn.fetch(
                        """
                        SELECT tags FROM hs_entities
                        WHERE tags ? 'equip'
                          AND tags->'siteRef'->>'val' = $1
                        """,
                        nav_id,
                    )
                else:
                    # Equip (or anything else) → return points referencing this equip
                    rows = await conn.fetch(
                        """
                        SELECT tags FROM hs_entities
                        WHERE tags ? 'point'
                          AND tags->'equipRef'->>'val' = $1
                        """,
                        nav_id,
                    )
        return [_decode_tags(dict(row["tags"])) for row in rows]
    # -----------------------------------------------------------------------
    # History operations
    # -----------------------------------------------------------------------

    async def his_read(
        self,
        ref: Ref,
        range_str: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return time-series history for a point.

        :param ref: Entity Ref of the point.
        :param range_str: Optional Haystack range string (e.g. ``"today"``,
            ``"2024-01-01,2024-01-31"``). If ``None``, all data is returned.
        :returns: List of dicts with ``"ts"`` (datetime) and ``"val"`` keys.
        """
        ref_val = ref.val
        if range_str is not None:
            start, end = _parse_his_range(range_str)
        else:
            # No range given: query the full representable datetime span.
            start = datetime.datetime.min.replace(tzinfo=datetime.UTC)
            end = datetime.datetime.max.replace(tzinfo=datetime.UTC)

        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT ts, val, unit FROM hs_history
                WHERE entity_id = $1 AND ts >= $2 AND ts < $3
                ORDER BY ts
                """,
                ref_val,
                start,
                end,
            )

        result = []
        for row in rows:
            item: dict[str, Any] = {"ts": row["ts"], "val": row["val"]}
            # Unit is optional; omit the key entirely when absent.
            if row["unit"] is not None:
                item["unit"] = row["unit"]
            result.append(item)
        return result
    async def his_write(self, ref: Ref, items: list[dict[str, Any]]) -> None:
        """Append time-series data for a point.

        Values are upserted: a duplicate ``(entity_id, ts)`` overwrites the
        stored value and unit.

        :param ref: Entity Ref of the point.
        :param items: List of dicts with ``"ts"`` and ``"val"`` keys.
        """
        from hs_py.kinds import Number

        if not items:
            return
        ref_val = ref.val
        records = []
        for item in items:
            ts = item["ts"]
            raw_val = item.get("val", 0)
            unit: str | None = None
            if isinstance(raw_val, Number):
                # Preserve the Number's unit alongside its float value.
                unit = raw_val.unit
                val = raw_val.val
            else:
                val = float(raw_val)
            if isinstance(ts, datetime.datetime) and ts.tzinfo is None:
                # Naive timestamps are interpreted as UTC before storage.
                ts = ts.replace(tzinfo=datetime.UTC)
            records.append((ref_val, ts, val, unit))

        async with self._pool.acquire() as conn:
            await conn.executemany(
                """
                INSERT INTO hs_history (entity_id, ts, val, unit)
                VALUES ($1, $2, $3, $4)
                ON CONFLICT (entity_id, ts)
                DO UPDATE SET val = EXCLUDED.val, unit = EXCLUDED.unit
                """,
                records,
            )
    # -----------------------------------------------------------------------
    # Priority array operations
    # -----------------------------------------------------------------------

    async def point_write(
        self,
        ref: Ref,
        level: int,
        val: Any,
        who: str = "",
        duration: Any = None,
    ) -> None:
        """Write a value to a writable point's priority array.

        :param ref: Entity Ref of the writable point.
        :param level: Priority level (1-17).
        :param val: Value to write. Pass ``None`` to clear the level.
        :param who: Optional identifier of who is writing.
        :param duration: Ignored by this backend.
        """
        import orjson

        from hs_py.encoding.json import encode_val

        ref_val = ref.val
        if val is None:
            # None clears the level: remove the row entirely.
            async with self._pool.acquire() as conn:
                await conn.execute(
                    "DELETE FROM hs_priority WHERE entity_id = $1 AND level = $2",
                    ref_val,
                    level,
                )
        else:
            encoded = encode_val(val)
            val_json = orjson.dumps(encoded).decode()
            async with self._pool.acquire() as conn:
                await conn.execute(
                    """
                    INSERT INTO hs_priority (entity_id, level, val, who)
                    VALUES ($1, $2, $3::jsonb, $4)
                    ON CONFLICT (entity_id, level)
                    DO UPDATE SET val = EXCLUDED.val, who = EXCLUDED.who
                    """,
                    ref_val,
                    level,
                    val_json,
                    who or "",
                )
    async def point_read_array(self, ref: Ref) -> list[dict[str, Any]]:
        """Return the 17-level priority array for a writable point.

        :param ref: Entity Ref of the writable point.
        :returns: List of 17 dicts, each with ``"level"`` and ``"val"`` keys;
            ``"val"`` is ``None`` when the level is unset.
        """
        from hs_py.encoding.json import decode_val
        from hs_py.kinds import Number

        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT level, val, who FROM hs_priority WHERE entity_id = $1",
                ref.val,
            )

        by_level: dict[int, dict[str, Any]] = {}
        for row in rows:
            raw_val = row["val"]
            decoded_val: Any = None
            if raw_val is not None:
                import orjson

                # val may arrive as a JSON string or an already-decoded
                # object depending on the connection's jsonb codec.
                obj = orjson.loads(raw_val) if isinstance(raw_val, str) else raw_val
                decoded_val = decode_val(obj)
            by_level[row["level"]] = {"level": Number(float(row["level"])), "val": decoded_val}

        # Always emit all 17 levels so the caller sees a complete array.
        rows_out: list[dict[str, Any]] = []
        for lvl in range(1, 18):
            if lvl in by_level:
                rows_out.append(by_level[lvl])
            else:
                row_d: dict[str, Any] = {"level": Number(float(lvl)), "val": None}
                rows_out.append(row_d)
        return rows_out
    # -----------------------------------------------------------------------
    # Watch operations
    # -----------------------------------------------------------------------

    async def watch_sub(
        self,
        watch_id: str | None,
        ids: list[Ref],
        dis: str = "watch",
    ) -> tuple[str, list[dict[str, Any]]]:
        """Create or extend a watch subscription.

        :param watch_id: Existing watch ID to extend, or ``None`` to create
            a new watch.
        :param ids: Entity Refs to add to the watch.
        :param dis: Human-readable display name for a new watch.
        :returns: ``(watch_id, entities)`` where *entities* is the current
            state of all newly subscribed entities.
        """
        import secrets

        if not watch_id:
            # secrets (not random) for an unguessable watch id.
            watch_id = f"w-{secrets.token_hex(4)}"
        ref_vals = [ref.val for ref in ids]

        async with self._pool.acquire() as conn, conn.transaction():
            await conn.execute(
                """
                INSERT INTO hs_watches (watch_id, dis) VALUES ($1, $2)
                ON CONFLICT (watch_id) DO NOTHING
                """,
                watch_id,
                dis,
            )
            if ref_vals:
                # New subscriptions start dirty so the first poll reports them.
                records = [(watch_id, rv) for rv in ref_vals]
                await conn.executemany(
                    """
                    INSERT INTO hs_watch_entities (watch_id, entity_id, dirty)
                    VALUES ($1, $2, true)
                    ON CONFLICT (watch_id, entity_id) DO NOTHING
                    """,
                    records,
                )
            entities = await self._fetch_watch_entities_with_conn(conn, watch_id)
        return watch_id, entities
    async def watch_unsub(
        self,
        watch_id: str,
        ids: list[Ref],
        *,
        close: bool = False,
    ) -> None:
        """Remove entities from a watch, or close the watch entirely.

        :param watch_id: Watch to modify.
        :param ids: Entity Refs to remove. Ignored when *close* is ``True``.
        :param close: If ``True``, the entire watch is torn down.
        """
        async with self._pool.acquire() as conn:
            if close:
                # ON DELETE CASCADE removes the hs_watch_entities rows too.
                await conn.execute(
                    "DELETE FROM hs_watches WHERE watch_id = $1",
                    watch_id,
                )
            else:
                ref_vals = [ref.val for ref in ids]
                if ref_vals:
                    await conn.execute(
                        """
                        DELETE FROM hs_watch_entities
                        WHERE watch_id = $1 AND entity_id = ANY($2::text[])
                        """,
                        watch_id,
                        ref_vals,
                    )
    async def watch_poll(
        self,
        watch_id: str,
        *,
        refresh: bool = False,
    ) -> list[dict[str, Any]]:
        """Poll a watch for changed entities.

        Uses a transaction to atomically fetch and clear dirty flags,
        preventing lost notifications from concurrent writers.

        :param watch_id: Watch to poll.
        :param refresh: If ``True``, return all watched entities regardless
            of dirty state.
        :returns: List of entity dicts that have changed since the last poll.
        """
        async with self._pool.acquire() as conn, conn.transaction():
            if refresh:
                entities = await self._fetch_watch_entities_with_conn(conn, watch_id)
            else:
                entities = await self._fetch_dirty_watch_entities_with_conn(conn, watch_id)
            # Clear dirty flags inside the same transaction as the fetch.
            await conn.execute(
                "UPDATE hs_watch_entities SET dirty = false WHERE watch_id = $1",
                watch_id,
            )
        return entities
    # -----------------------------------------------------------------------
    # Helper: load_entities
    # -----------------------------------------------------------------------

    async def load_entities(self, entities: list[dict[str, Any]]) -> int:
        """Bulk-upsert a list of entity dicts into the store.

        Uses a staging table with COPY for fast bulk loading, then upserts
        into the main table. The ``id`` tag (a :class:`~hs_py.kinds.Ref`) is
        extracted and used as the primary key. Entities without an ``id``
        are skipped.

        :param entities: List of entity tag dicts.
        :returns: Number of entities written.
        """
        import orjson

        from hs_py.kinds import Ref

        records = []
        for entity in entities:
            id_tag = entity.get("id")
            if not isinstance(id_tag, Ref):
                # No Ref-valued id: cannot key the row; skip silently.
                continue
            tags_json = orjson.dumps(_encode_tags(entity)).decode()
            records.append((id_tag.val, tags_json))

        if not records:
            return 0

        async with self._pool.acquire() as conn, conn.transaction():
            # COPY into a temp table, then one INSERT ... ON CONFLICT upsert.
            await conn.execute("CREATE TEMP TABLE _staging (id TEXT, tags TEXT) ON COMMIT DROP")
            await conn.copy_records_to_table("_staging", records=records, columns=["id", "tags"])
            await conn.execute(
                """
                INSERT INTO hs_entities (id, tags, updated_at)
                SELECT id, tags::jsonb, now() FROM _staging
                ON CONFLICT (id)
                DO UPDATE SET tags = EXCLUDED.tags, updated_at = now()
                """
            )

        # Compute column names for Grid construction fast path.
        seen: dict[str, None] = {}
        for entity in entities:
            for key in entity:
                if key not in seen:
                    seen[key] = None
        self._all_col_names = tuple(seen)
        return len(records)
# ----------------------------------------------------------------------- # Private helpers # ----------------------------------------------------------------------- async def _fetch_watch_entities_with_conn( self, conn: asyncpg.Connection[Any], watch_id: str ) -> list[dict[str, Any]]: """Fetch all entities subscribed to a watch using an existing connection.""" rows = await conn.fetch( """ SELECT e.tags FROM hs_watch_entities we JOIN hs_entities e ON e.id = we.entity_id WHERE we.watch_id = $1 """, watch_id, ) return [_decode_tags(dict(row["tags"])) for row in rows] async def _fetch_dirty_watch_entities_with_conn( self, conn: asyncpg.Connection[Any], watch_id: str ) -> list[dict[str, Any]]: """Fetch only dirty (changed) entities subscribed to a watch.""" rows = await conn.fetch( """ SELECT e.tags FROM hs_watch_entities we JOIN hs_entities e ON e.id = we.entity_id WHERE we.watch_id = $1 AND we.dirty = true """, watch_id, ) return [_decode_tags(dict(row["tags"])) for row in rows] # ---- UserStore implementation --------------------------------------------
    async def get_user(self, username: str) -> User | None:
        """Return a user by username, or ``None`` if not found."""
        async with self._pool.acquire() as conn:
            row = await conn.fetchrow("SELECT * FROM hs_users WHERE username = $1", username)
        if row is None:
            return None
        return self._row_to_user(row)
    async def list_users(self) -> list[User]:
        """Return all users, ordered by username."""
        async with self._pool.acquire() as conn:
            rows = await conn.fetch("SELECT * FROM hs_users ORDER BY username")
        return [self._row_to_user(row) for row in rows]
    async def create_user(self, user: User) -> None:
        """Persist a new user.

        :raises ValueError: If a user with the same username already exists.
        """
        import base64

        # SCRAM credential bytes are base64-encoded for JSONB storage.
        creds = {
            "salt": base64.b64encode(user.credentials.salt).decode(),
            "iterations": user.credentials.iterations,
            "stored_key": base64.b64encode(user.credentials.stored_key).decode(),
            "server_key": base64.b64encode(user.credentials.server_key).decode(),
        }
        try:
            async with self._pool.acquire() as conn:
                # creds is passed as a dict; relies on the jsonb codec
                # registered in _init_connection to serialise it.
                await conn.execute(
                    """
                    INSERT INTO hs_users
                        (username, credentials, first_name, last_name, email,
                         role, enabled, created_at, updated_at)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
                    """,
                    user.username,
                    creds,
                    user.first_name,
                    user.last_name,
                    user.email,
                    user.role.value,
                    user.enabled,
                    user.created_at,
                    user.updated_at,
                )
        except Exception as exc:
            # Map the driver's unique-violation error to ValueError; anything
            # else is re-raised unchanged.
            if "unique" in str(exc).lower() or "duplicate" in str(exc).lower():
                msg = f"User already exists: {user.username!r}"
                raise ValueError(msg) from exc
            raise
    async def update_user(self, username: str, **fields: Any) -> User:
        """Update fields on an existing user.

        A ``password`` field is converted to new SCRAM credentials; other
        recognised fields are applied directly. Unrecognised fields are
        silently ignored.

        :raises KeyError: If the user does not exist.
        """
        import time

        from hs_py.user import User as _User
        from hs_py.user import derive_scram_credentials as _derive

        existing = await self.get_user(username)
        if existing is None:
            msg = f"User not found: {username!r}"
            raise KeyError(msg)

        updates: dict[str, Any] = {"updated_at": time.time()}
        if "password" in fields:
            updates["credentials"] = _derive(fields.pop("password"))
        allowed = {"first_name", "last_name", "email", "role", "enabled", "credentials"}
        for k, v in fields.items():
            if k in allowed:
                updates[k] = v

        from dataclasses import asdict

        merged = {**asdict(existing), **updates}
        # asdict() recursively converted the credentials dataclass to a dict;
        # restore the real object (new or existing).
        merged["credentials"] = updates.get("credentials", existing.credentials)
        # asdict() converts Role enum to its value — restore the enum instance
        if isinstance(merged.get("role"), str):
            from hs_py.user import Role as _Role

            merged["role"] = _Role(merged["role"])
        new_user = _User(**merged)

        import base64

        creds_dict = {
            "salt": base64.b64encode(new_user.credentials.salt).decode(),
            "iterations": new_user.credentials.iterations,
            "stored_key": base64.b64encode(new_user.credentials.stored_key).decode(),
            "server_key": base64.b64encode(new_user.credentials.server_key).decode(),
        }
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE hs_users SET
                    credentials = $1, first_name = $2, last_name = $3,
                    email = $4, role = $5, enabled = $6, updated_at = $7
                WHERE username = $8
                """,
                creds_dict,
                new_user.first_name,
                new_user.last_name,
                new_user.email,
                new_user.role.value,
                new_user.enabled,
                new_user.updated_at,
                username,
            )
        return new_user
    async def delete_user(self, username: str) -> bool:
        """Delete a user by username.

        :returns: ``True`` if a row was deleted, ``False`` otherwise.
        """
        async with self._pool.acquire() as conn:
            result = await conn.execute("DELETE FROM hs_users WHERE username = $1", username)
        # asyncpg returns the command tag, e.g. "DELETE 1".
        return str(result) == "DELETE 1"
    @staticmethod
    def _row_to_user(row: Any) -> User:
        """Convert a hs_users database row to a :class:`User` object."""
        import base64

        from hs_py.auth_types import ScramCredentials
        from hs_py.user import Role as _Role
        from hs_py.user import User as _User

        creds_data = row["credentials"]
        # Reverse the base64 encoding applied at write time.
        credentials = ScramCredentials(
            salt=base64.b64decode(creds_data["salt"]),
            iterations=creds_data["iterations"],
            stored_key=base64.b64decode(creds_data["stored_key"]),
            server_key=base64.b64decode(creds_data["server_key"]),
        )
        return _User(
            username=row["username"],
            credentials=credentials,
            first_name=row["first_name"],
            last_name=row["last_name"],
            email=row["email"],
            role=_Role(row["role"]),
            enabled=row["enabled"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )
async def _init_connection(conn: asyncpg.Connection[Any]) -> None:
    """Register orjson-backed codecs on a new physical connection.

    Installed as the pool's ``init`` hook so every connection serialises
    ``jsonb``/``json`` values through orjson, matching the rest of the
    codebase and avoiding the slower stdlib json module.
    """
    import orjson

    def _encode(value: Any) -> str:
        return orjson.dumps(value).decode()

    for pg_type in ("jsonb", "json"):
        await conn.set_type_codec(
            pg_type,
            encoder=_encode,
            decoder=orjson.loads,
            schema="pg_catalog",
        )
async def create_timescale_pool(
    dsn: str = "postgresql://localhost:5432/haystack",
    *,
    min_size: int = 2,
    max_size: int = 10,
    command_timeout: float = 60.0,
    tls: TLSConfig | None = None,
) -> asyncpg.Pool[Any]:
    """Create an asyncpg connection pool for TimescaleDB.

    Configures connection recycling, idle connection cleanup, and registers
    orjson-based JSONB codecs on each new connection.

    :param dsn: PostgreSQL DSN string.
    :param min_size: Minimum number of pooled connections.
    :param max_size: Maximum number of pooled connections.
    :param command_timeout: Per-query timeout in seconds (default 60).
    :param tls: Optional TLS configuration for SSL connections.
    :returns: Open :class:`asyncpg.Pool` ready for use.
    """
    import asyncpg as _asyncpg

    ssl_ctx: Any = None
    if tls is not None:
        from hs_py.tls import build_client_ssl_context

        ssl_ctx = build_client_ssl_context(tls)

    pool: asyncpg.Pool[Any] = await _asyncpg.create_pool(
        dsn,
        min_size=min_size,
        max_size=max_size,
        # Recycle a connection after 50k queries; drop idle ones after 5 min.
        max_queries=50_000,
        max_inactive_connection_lifetime=300.0,
        command_timeout=command_timeout,
        init=_init_connection,
        ssl=ssl_ctx,
    )
    return pool