"""FastAPI-based Haystack HTTP server.
Provides a FastAPI application factory with content-negotiated Haystack routes,
SCRAM-SHA-256 authentication middleware, and WebSocket support.
See: https://project-haystack.org/doc/docHaystack/HttpApi
"""
from __future__ import annotations
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any
import orjson
from fastapi import APIRouter, FastAPI, Request, Response, WebSocket, WebSocketDisconnect
from starlette.responses import Response as StarletteResponse
from hs_py._scram_core import (
TOKEN_LIFETIME,
HandshakeState,
TokenEntry,
handle_scram,
scram_hello,
validate_bearer,
)
from hs_py.content_negotiation import (
UnsupportedContentTypeError,
decode_request,
encode_response,
negotiate_format,
)
from hs_py.encoding.json import encode_grid as encode_grid_json
from hs_py.encoding.zinc import decode_val as zinc_decode_val
from hs_py.errors import HaystackError
from hs_py.grid import Grid
from hs_py.ops import _POST_OP_METHODS, HaystackOps, dispatch_op
if TYPE_CHECKING:
from collections.abc import AsyncIterator
from starlette.types import ASGIApp, Receive, Scope, Send
from hs_py.auth_types import Authenticator
from hs_py.ontology.namespace import Namespace
from hs_py.storage.protocol import StorageAdapter, UserStore
__all__ = [
    "ScramAuthMiddleware",
    "create_fastapi_app",
]

_log = logging.getLogger(__name__)

# Ops served via a plain GET with no request grid.
_GET_OPS = ("about", "ops", "formats", "close")
# Every op that accepts a POSTed request grid (the keys of ``_POST_OP_METHODS``).
_POST_OPS = tuple(_POST_OP_METHODS)

# Ops tagged `noSideEffects` that accept GET with Zinc-encoded query params.
# Maps the URL op name to the HaystackOps method name.
_NO_SIDE_EFFECTS_OPS: dict[str, str] = dict(
    read="read",
    nav="nav",
    hisRead="his_read",
    defs="defs",
    libs="libs",
    filetypes="filetypes",
)

# Upper bound on an accepted HTTP request body: 16 MiB.
_MAX_BODY_SIZE = 16 << 20

# Upper bound on entries held in each response/grid cache.
_MAX_CACHE_SIZE = 2_048

# Upper bound on items processed from a single WebSocket batch request.
_MAX_BATCH_SIZE = 1_000

# Ops that mutate state; running one clears the read caches.
_MUTATION_OPS = frozenset(("hisWrite", "pointWrite", "invokeAction"))

# Headers sent with a bare 401 when no recognised auth scheme is presented.
_401_HEADERS = {"WWW-Authenticate": "SCRAM hash=SHA-256"}

# ---------------------------------------------------------------------------
# SCRAM auth middleware (pure ASGI)
# ---------------------------------------------------------------------------
class ScramAuthMiddleware:
    """SCRAM-SHA-256 authentication middleware for FastAPI.

    Pure ASGI middleware — avoids ``BaseHTTPMiddleware`` overhead and streaming
    issues. Implements the Haystack SCRAM handshake:

    1. Client sends ``Authorization: HELLO username=<b64>``
    2. Server returns 401 with ``WWW-Authenticate: SCRAM handshakeToken=...``
    3. Client sends ``Authorization: SCRAM handshakeToken=..., data=<client-first>``
    4. Server returns 401 with server-first message
    5. Client sends ``Authorization: SCRAM handshakeToken=..., data=<client-final>``
    6. Server returns 200 with ``Authentication-Info: authToken=..., data=<server-final>``
    7. Subsequent requests use ``Authorization: BEARER authToken=...``

    Only ``http`` scopes are checked. WebSocket and lifespan scopes pass
    through unchanged; the WebSocket endpoint validates tokens itself. HTTP
    requests without a recognised scheme receive a bare 401.

    :param app: The ASGI application to wrap.
    :param authenticator: Server-side credential store.
    :param auth_tokens: Shared token dict (also used by the WS endpoint). When
        *None*, a private dict is created.
    """

    def __init__(
        self,
        app: ASGIApp,
        authenticator: Authenticator,
        auth_tokens: dict[str, TokenEntry] | None = None,
    ) -> None:
        self.app = app
        self._authenticator = authenticator
        # In-flight handshake state, read and written by scram_hello / handle_scram.
        self._handshakes: dict[str, HandshakeState] = {}
        # Issued bearer tokens; may be the same dict the WebSocket endpoint reads.
        self._tokens: dict[str, TokenEntry] = auth_tokens if auth_tokens is not None else {}

    @property
    def tokens(self) -> dict[str, TokenEntry]:
        """Expose the token store (read-only access for tests)."""
        return self._tokens

    @staticmethod
    def _auth_params(auth_header: str) -> dict[str, str]:
        """Parse the comma-separated ``key=value`` pairs that follow the scheme.

        Keys and values are whitespace-stripped, so ``"X a=1, authToken=t"``
        yields ``{"a": "1", "authToken": "t"}``. A header with no parameters
        yields ``{}``.
        """
        parts = auth_header.split(None, 1)
        if len(parts) < 2:
            return {}
        params: dict[str, str] = {}
        for pair in parts[1].split(","):
            key, sep, value = pair.partition("=")
            if sep:
                params[key.strip()] = value.strip()
        return params

    @staticmethod
    async def _reply(result: Any, scope: Scope, receive: Receive, send: Send) -> None:
        """Send a handshake/validation *result* (``status``/``headers``/``body``)."""
        response = StarletteResponse(
            status_code=result.status, headers=result.headers, content=result.body
        )
        await response(scope, receive, send)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point — intercept HTTP requests for auth."""
        if scope["type"] != "http":
            # Pass WebSocket and lifespan through unchanged
            await self.app(scope, receive, send)
            return

        # ASGI header names are lowercase bytes; use the first Authorization header.
        auth_header = ""
        for key, value in scope.get("headers", []):
            if key == b"authorization":
                auth_header = value.decode("latin-1")
                break

        # A whitespace-only header splits to [], so never index blindly.
        header_words = auth_header.split()
        scheme = header_words[0].upper() if header_words else ""

        if scheme == "HELLO":
            result = await scram_hello(self._authenticator, self._handshakes, auth_header)
            await self._reply(result, scope, receive, send)
            return

        if scheme == "SCRAM":
            result = handle_scram(self._handshakes, self._tokens, auth_header)
            await self._reply(result, scope, receive, send)
            return

        if scheme == "BEARER":
            # A non-None result is sent back as-is (presumably a rejection);
            # None lets the request through to the application.
            bearer_result = validate_bearer(self._tokens, auth_header)
            if bearer_result is not None:
                await self._reply(bearer_result, scope, receive, send)
                return
            # Attach authenticated username to scope state for downstream handlers.
            token = self._auth_params(auth_header).get("authToken", "")
            entry = self._tokens.get(token)
            if entry is not None:
                scope.setdefault("state", {})["username"] = entry.username
            await self.app(scope, receive, send)
            return

        # No recognized auth scheme → 401
        response = StarletteResponse(status_code=401, headers=_401_HEADERS)
        await response(scope, receive, send)
# ---------------------------------------------------------------------------
# Request / response helpers
# ---------------------------------------------------------------------------
def _get_ops(request: Request) -> HaystackOps:
"""Extract the HaystackOps instance from app state."""
return request.app.state.ops # type: ignore[no-any-return]
def _get_default_format(request: Request) -> str:
"""Return the server's configured default format."""
return getattr(request.app.state, "default_format", "json")
async def _parse_grid(request: Request) -> Grid:
    """Decode the request body into a Grid, using content-negotiation.

    The body is streamed, and the ``_MAX_BODY_SIZE`` limit is enforced while
    reading. An oversized upload is therefore rejected without first being
    buffered in full. A declared ``Content-Length`` above the limit is
    rejected before any bytes are read.

    :returns: The decoded grid, or an empty grid when the body is empty.
    :raises HaystackError: If the body exceeds ``_MAX_BODY_SIZE``.
    :raises UnsupportedContentTypeError: Propagated to the error middleware
        which converts it to HTTP 415.
    """
    try:
        declared_size = int(request.headers.get("content-length", ""))
    except ValueError:
        declared_size = -1  # absent or malformed; rely on the streamed byte count
    if declared_size > _MAX_BODY_SIZE:
        raise HaystackError("Request body too large")

    chunks: list[bytes] = []
    received = 0
    async for chunk in request.stream():
        received += len(chunk)
        # Stop as soon as the limit is crossed instead of buffering the rest.
        if received > _MAX_BODY_SIZE:
            raise HaystackError("Request body too large")
        chunks.append(chunk)
    body = b"".join(chunks)

    if not body:
        return Grid.make_empty()
    ct = request.headers.get("content-type", "application/json")
    return decode_request(body, ct)
def _negotiate_or_406(request: Request) -> str:
    """Pick the response format from the ``Accept`` header.

    Falls back to the server's configured default format for wildcard or
    missing ``Accept`` values, as decided by ``negotiate_format``.

    :raises _NotAcceptableError: If no offered media type is supported
        (converted to HTTP 406 by the middleware).
    """
    accept_header = request.headers.get("accept", "")
    chosen = negotiate_format(accept_header, default=_get_default_format(request))
    if chosen is not None:
        return chosen
    raise _NotAcceptableError(accept_header)
class _NotAcceptableError(Exception):
"""Raised when Accept header contains no supported MIME types."""
def __init__(self, accept: str) -> None:
self.accept = accept
super().__init__(f"Not Acceptable: {accept}")
def _grid_response(grid: Grid, request: Request) -> Response:
    """Serialise *grid* in the format negotiated from the request's ``Accept`` header.

    :raises _NotAcceptableError: If no acceptable format is supported.
    """
    content, media_type = encode_response(grid, _negotiate_or_406(request))
    return Response(content=content, media_type=media_type)
def _cached_grid_response(grid: Grid, request: Request, cache_key: str) -> Response:
    """Serialise *grid*, reusing encoded bytes from ``app.state._response_cache``.

    Entries are keyed by ``(cache_key, format)``. Once the cache holds
    ``_MAX_CACHE_SIZE`` entries, new encodings are served but not stored.
    On a hit, *grid* is not re-encoded.
    """
    fmt = _negotiate_or_406(request)
    entry_key = (cache_key, fmt)
    cache: dict[tuple[str, str], tuple[bytes, str]] = request.app.state._response_cache

    hit = cache.get(entry_key)
    if hit is None:
        body, ct = encode_response(grid, fmt)
        hit = (body, ct)
        if len(cache) < _MAX_CACHE_SIZE:
            cache[entry_key] = hit

    content, media_type = hit
    return Response(content=content, media_type=media_type)
# ---------------------------------------------------------------------------
# WebSocket encoding helpers
# ---------------------------------------------------------------------------
def _ws_encode_grid(grid: Grid) -> bytes:
    """Serialise *grid* to Haystack JSON bytes (orjson-backed encoder) for a WS frame."""
    encoded: bytes = encode_grid_json(grid)
    return encoded
def _ws_cached_grid_bytes(
    app: Any,
    op: str,
    msg: dict[str, Any],
    grid: Grid,
) -> bytes:
    """Return JSON-encoded bytes for *grid*, using the WS read cache where safe.

    * Mutation ops (``_MUTATION_OPS``) clear ``app.state._ws_grid_cache``.
    * Filter-based ``read`` requests are cached under ``ws_read:<filter>:<limit>``.
      Reads without a ``filter`` in the first request row (e.g. reads by ``id``)
      are never cached. Otherwise they would all share one key and be served
      each other's results.
    * Everything else is encoded fresh.

    On a cache hit, *grid* is ignored and the cached bytes are returned.
    Callers should therefore pass only successful result grids here; an error
    grid routed through this function could be replaced by a cached result.

    :param app: Application whose ``state._ws_grid_cache`` holds the cache.
    :param op: Op name taken from the WS message (client-supplied).
    :param msg: The decoded WS request envelope.
    :param grid: Result grid to encode on a cache miss.
    :returns: Encoded grid bytes.
    """
    cache: dict[str, bytes] = app.state._ws_grid_cache

    # Invalidate cache on mutation ops. ``op`` comes from client JSON and may be
    # unhashable (e.g. a list), so type-check before the frozenset lookup.
    if isinstance(op, str) and op in _MUTATION_OPS and cache:
        cache.clear()

    if op != "read":
        return _ws_encode_grid(grid)

    grid_data = msg.get("grid")
    rows = grid_data.get("rows") if isinstance(grid_data, dict) else None
    first = rows[0] if isinstance(rows, list) and rows else None
    if not isinstance(first, dict) or "filter" not in first:
        return _ws_encode_grid(grid)

    key = f"ws_read:{first['filter']}:{first.get('limit', '')}"
    cached = cache.get(key)
    if cached is not None:
        return cached
    grid_bytes = _ws_encode_grid(grid)
    if len(cache) < _MAX_CACHE_SIZE:
        cache[key] = grid_bytes
    return grid_bytes
def _ws_envelope(grid_bytes: bytes, req_id: Any = None) -> bytes:
"""Build a JSON envelope around pre-encoded grid bytes."""
if req_id is not None:
id_bytes = orjson.dumps(req_id)
return b'{"grid":' + grid_bytes + b',"id":' + id_bytes + b"}"
return b'{"grid":' + grid_bytes + b"}"
# ---------------------------------------------------------------------------
# Error handler
# ---------------------------------------------------------------------------
async def _haystack_error_handler(request: Request, exc: Exception) -> Response:
    """Render a :class:`~hs_py.errors.HaystackError` as an error grid.

    The grid carries ``str(exc)`` and uses the format negotiated from the
    request's ``Accept`` header.
    """
    return _grid_response(Grid.make_error(str(exc)), request)
async def _generic_error_handler(request: Request, exc: Exception) -> Response:
    """Catch-all for unhandled exceptions — return an error Grid response.

    Mirrors the aiohttp ``_error_middleware`` behaviour so that unexpected
    exceptions never leak a raw 500 to Haystack clients. The exception is
    logged with its traceback, and only a generic message reaches the client.
    """
    _log.exception("Unhandled exception in request handler")
    error_grid = Grid.make_error("Internal server error")
    return _grid_response(error_grid, request)
# ---------------------------------------------------------------------------
# Lifespan context manager
# ---------------------------------------------------------------------------
@asynccontextmanager
async def _lifespan(app: FastAPI) -> AsyncIterator[None]:
"""Manage storage adapter lifecycle — start on startup, close on shutdown."""
storage = getattr(app.state, "storage", None)
if storage is not None:
await storage.start()
# Bootstrap superuser if a user store is available
user_store: UserStore | None = getattr(app.state, "user_store", None)
if user_store is not None:
from hs_py.bootstrap import ensure_superuser
await ensure_superuser(user_store)
yield
if storage is not None:
await storage.close()
# ---------------------------------------------------------------------------
# Role-based access control helpers
# ---------------------------------------------------------------------------
async def _check_op_permission(request: Request, op_name: str) -> None:
    """Verify the authenticated user has permission for the given op.

    Delegates to :func:`_check_ws_op_permission`, so the HTTP and WebSocket
    paths enforce one shared rule set.

    When no ``user_store`` is configured on the app, the check is skipped
    (open access). Otherwise, the username placed on ``request.state`` by the
    auth middleware must resolve to an enabled user. Ops in ``WRITE_OPS``
    additionally require at least the operator role.

    :param request: The incoming HTTP request.
    :param op_name: The Haystack op name (e.g. ``"hisWrite"``).
    :raises HaystackError: If the user is unauthenticated, unknown or
        disabled, or lacks the required role.
    """
    username: str | None = getattr(request.state, "username", None)
    await _check_ws_op_permission(request.app, username, op_name)
# ---------------------------------------------------------------------------
# Route handler factories
# ---------------------------------------------------------------------------
def _make_get_handler(op_name: str) -> Any:
"""Create a GET handler for a named Haystack op."""
async def handler(request: Request) -> Response:
ops = _get_ops(request)
if op_name == "about":
grid = await ops.about()
return _cached_grid_response(grid, request, "about")
elif op_name == "ops":
grid = await ops.ops()
return _cached_grid_response(grid, request, "ops")
elif op_name == "formats":
grid = await ops.formats()
return _cached_grid_response(grid, request, "formats")
elif op_name == "close":
await _check_op_permission(request, "close")
await ops.on_close()
grid = Grid.make_empty()
else:
grid = Grid.make_error(f"Unknown GET operation: {op_name}")
return _grid_response(grid, request)
handler.__name__ = f"{op_name}_get_handler"
return handler
def _make_post_handler(op_name: str, method_name: str) -> Any:
"""Create a POST handler for a named Haystack op.
:param op_name: The URL-level op name (e.g. ``"hisRead"``).
:param method_name: The :class:`~hs_py.ops.HaystackOps` method name (e.g.
``"his_read"``).
"""
async def handler(request: Request) -> Response:
await _check_op_permission(request, op_name)
ops = _get_ops(request)
req_grid = await _parse_grid(request)
method = getattr(ops, method_name)
result_grid: Grid = await method(req_grid)
# Invalidate read caches on mutation ops
if op_name in _MUTATION_OPS:
cache: dict[Any, Any] | None = getattr(request.app.state, "_response_cache", None)
if cache:
cache.clear()
ws_cache: dict[str, bytes] | None = getattr(request.app.state, "_ws_grid_cache", None)
if ws_cache:
ws_cache.clear()
return _grid_response(result_grid, request)
handler.__name__ = f"{op_name}_post_handler"
return handler
def _make_read_handler() -> Any:
"""Create a cached POST handler for the read op.
Read responses are cached by (filter, limit, format) since entity data
is typically stable between mutations. The cache is stored on
``app.state._response_cache`` and is cleared on entity mutations.
"""
async def handler(request: Request) -> Response:
await _check_op_permission(request, "read")
ops = _get_ops(request)
req_grid = await _parse_grid(request)
result_grid: Grid = await ops.read(req_grid)
# Build a cache key from the filter/limit in the request grid.
cache: dict[tuple[str, str], tuple[bytes, str]] | None = getattr(
request.app.state, "_response_cache", None
)
if cache is not None and req_grid.rows:
first = req_grid[0]
filter_str = first.get("filter", "")
limit_val = first.get("limit", "")
cache_key = f"read:{filter_str}:{limit_val}"
return _cached_grid_response(result_grid, request, cache_key)
return _grid_response(result_grid, request)
handler.__name__ = "read_post_handler"
return handler
def _make_query_get_handler(op_name: str, method_name: str) -> Any:
"""Create a GET handler for a ``noSideEffects`` op.
Parses query parameters into a single-row Grid. Values are Zinc-decoded
when possible, otherwise treated as strings per the Haystack spec.
"""
async def handler(request: Request) -> Response:
await _check_op_permission(request, op_name)
ops = _get_ops(request)
params = dict(request.query_params)
if params:
row: dict[str, Any] = {}
for key, val in params.items():
try:
row[key] = zinc_decode_val(val)
except Exception:
row[key] = val
req_grid = Grid.make_rows([row])
else:
req_grid = Grid.make_empty()
method = getattr(ops, method_name)
result_grid: Grid = await method(req_grid)
return _grid_response(result_grid, request)
handler.__name__ = f"{op_name}_query_get_handler"
return handler
# ---------------------------------------------------------------------------
# WebSocket dispatch helpers
# ---------------------------------------------------------------------------
async def _check_ws_op_permission(app: FastAPI, username: str | None, op_name: str) -> None:
"""Check role permissions for a WebSocket op.
:raises HaystackError: If the user lacks the required role.
"""
user_store: UserStore | None = getattr(app.state, "user_store", None)
if user_store is None:
return
from hs_py.user import WRITE_OPS, Role
if username is None:
raise HaystackError("Authentication required")
user = await user_store.get_user(username)
if user is None or not user.enabled:
raise HaystackError("Authentication required")
if op_name in WRITE_OPS and user.role < Role.OPERATOR:
raise HaystackError(f"Insufficient permissions: {op_name} requires operator or admin role")
async def _handle_ws_single(
    websocket: WebSocket, ops: HaystackOps, msg: dict[str, Any], username: str | None = None
) -> None:
    """Dispatch a single WS message and send the response envelope.

    Successful results go through :func:`_ws_cached_grid_bytes`, which handles
    the read cache and mutation invalidation. Error grids are encoded directly
    and never touch the cache. On a cache hit that helper returns previously
    cached bytes regardless of the grid passed in. Routing a permission error
    through it could therefore hand another user's cached read result to an
    unauthorized client.
    """
    req_id = msg.get("id")
    op = msg.get("op", "")
    try:
        await _check_ws_op_permission(websocket.app, username, op)
        result_grid = await dispatch_op(ops, op, msg)
    except HaystackError as exc:
        grid_bytes = _ws_encode_grid(Grid.make_error(str(exc)))
    except Exception:
        _log.exception("Unhandled error in WS op '%s'", op)
        grid_bytes = _ws_encode_grid(Grid.make_error("Internal server error"))
    else:
        grid_bytes = _ws_cached_grid_bytes(websocket.app, op, msg, result_grid)
    payload = _ws_envelope(grid_bytes, req_id)
    await websocket.send_text(payload.decode())
async def _handle_ws_batch(
    websocket: WebSocket, ops: HaystackOps, batch: list[Any], username: str | None = None
) -> None:
    """Dispatch all ops in a batch concurrently, then send an array response.

    Non-dict items are ignored. At most ``_MAX_BATCH_SIZE`` items are
    processed; any excess is dropped with a logged warning. As with single
    messages, only successful results go through the read cache. Error grids
    are encoded directly, so a cached read result can never be returned in
    place of a permission error.
    """
    items = [item for item in batch if isinstance(item, dict)]
    if not items:
        return
    # Cap batch size to prevent resource exhaustion
    if len(items) > _MAX_BATCH_SIZE:
        _log.warning(
            "WebSocket batch of %d ops truncated to %d", len(items), _MAX_BATCH_SIZE
        )
        items = items[:_MAX_BATCH_SIZE]

    async def _dispatch_item(item: dict[str, Any]) -> bytes:
        r_id = item.get("id")
        r_op = item.get("op", "")
        try:
            await _check_ws_op_permission(websocket.app, username, r_op)
            r_grid = await dispatch_op(ops, r_op, item)
        except HaystackError as exc:
            grid_bytes = _ws_encode_grid(Grid.make_error(str(exc)))
        except Exception:
            _log.exception("Unhandled error in batch WS op '%s'", r_op)
            grid_bytes = _ws_encode_grid(Grid.make_error("Internal server error"))
        else:
            grid_bytes = _ws_cached_grid_bytes(websocket.app, r_op, item, r_grid)
        return _ws_envelope(grid_bytes, r_id)

    item_bytes = await asyncio.gather(*[_dispatch_item(item) for item in items])
    payload = b"[" + b",".join(item_bytes) + b"]"
    await websocket.send_text(payload.decode())
# ---------------------------------------------------------------------------
# Router construction
# ---------------------------------------------------------------------------
def _build_router() -> APIRouter:
    """Build and return an APIRouter with all Haystack op endpoints.

    Registers:

    * GET routes for ``_GET_OPS``;
    * POST routes for every op in ``_POST_OP_METHODS`` (``read`` uses the
      cached read handler);
    * GET query-parameter routes for ``_NO_SIDE_EFFECTS_OPS``;
    * the ``/ws`` WebSocket endpoint and ``/ontology/export``.
    """
    router = APIRouter()

    # GET ops
    for op_name in _GET_OPS:
        router.add_api_route(
            f"/{op_name}",
            _make_get_handler(op_name),
            methods=["GET"],
        )

    # POST ops
    for op_name, method_name in _POST_OP_METHODS.items():
        if op_name == "read":
            handler = _make_read_handler()
        else:
            handler = _make_post_handler(op_name, method_name)
        router.add_api_route(
            f"/{op_name}",
            handler,
            methods=["POST"],
        )

    # GET with query params for noSideEffects ops (e.g. GET /read?filter=site)
    for op_name, method_name in _NO_SIDE_EFFECTS_OPS.items():
        if op_name not in _GET_OPS:
            router.add_api_route(
                f"/{op_name}",
                _make_query_get_handler(op_name, method_name),
                methods=["GET"],
            )

    # WebSocket endpoint
    @router.websocket("/ws")
    async def ws_endpoint(websocket: WebSocket) -> None:
        """Handle a WebSocket upgrade and dispatch Haystack operations.

        Uses the same JSON envelope protocol as the aiohttp WebSocket handler.
        Each message (or batch array) is dispatched in its own task via
        ``asyncio.create_task``. Responses may therefore arrive out of order;
        clients match them to requests by correlation ``id``.

        The ``haystack`` subprotocol is selected only when the client offered it.

        When SCRAM auth is enabled, the first message must contain an
        ``authToken`` field matching a valid, unexpired bearer token.
        Connections without a valid token are closed with code 4003.
        """
        # RFC 6455: a client must fail the connection if the server selects a
        # subprotocol it did not request, so only echo "haystack" when offered.
        offered = websocket.scope.get("subprotocols") or ()
        await websocket.accept(subprotocol="haystack" if "haystack" in offered else None)
        ops: HaystackOps = websocket.app.state.ops
        tasks: set[asyncio.Task[None]] = set()

        # Check if SCRAM auth is enabled — require token on WS too
        auth_tokens: dict[str, TokenEntry] | None = getattr(
            websocket.app.state, "auth_tokens", None
        )
        ws_username: str | None = None
        if auth_tokens is not None:
            try:
                hello = orjson.loads(await websocket.receive_text())
                entry = auth_tokens.get(hello.get("authToken", ""))
            except WebSocketDisconnect:
                return  # client left before authenticating; nothing to close
            except Exception:
                entry = None  # malformed first message → treat as unauthenticated
            if entry is None or (time.monotonic() - entry.created) > TOKEN_LIFETIME:
                await websocket.close(code=4003, reason="Authentication required")
                return
            ws_username = entry.username

        try:
            while True:
                raw = await websocket.receive_text()
                try:
                    msg = orjson.loads(raw)
                except orjson.JSONDecodeError:
                    _log.warning("Non-JSON WebSocket message, ignoring")
                    continue

                # Batch: JSON array of envelopes — dispatch all ops concurrently
                if isinstance(msg, list):
                    task = asyncio.create_task(_handle_ws_batch(websocket, ops, msg, ws_username))
                    tasks.add(task)
                    task.add_done_callback(tasks.discard)
                    continue

                # Single message — fire-and-forget via create_task
                if isinstance(msg, dict):
                    task = asyncio.create_task(_handle_ws_single(websocket, ops, msg, ws_username))
                    tasks.add(task)
                    task.add_done_callback(tasks.discard)
        except WebSocketDisconnect:
            pass
        finally:
            # Wait for in-flight tasks before returning
            if tasks:
                await asyncio.gather(*tasks, return_exceptions=True)

    @router.get("/ontology/export")
    async def export_ontology(request: Request, format: str = "turtle") -> Response:
        """Export the loaded ontology namespace as RDF (Turtle or JSON-LD)."""
        ops = _get_ops(request)
        ns = getattr(ops, "_namespace", None)
        if ns is None:
            return _grid_response(Grid.make_error("No namespace loaded"), request)
        from hs_py.ontology.rdf import export_jsonld, export_turtle

        if format == "jsonld":
            return Response(content=export_jsonld(ns), media_type="application/ld+json")
        return Response(content=export_turtle(ns), media_type="text/turtle")

    return router
# ---------------------------------------------------------------------------
# User management endpoints
# ---------------------------------------------------------------------------
def _build_user_router(user_store: UserStore) -> APIRouter:
    """Build an APIRouter with user management CRUD endpoints.

    All endpoints live under ``/users`` and require an authenticated, enabled
    admin. Authentication and authorisation failures raise
    :class:`~hs_py.errors.HaystackError`. Validation problems return JSON
    ``{"error": ...}`` bodies with a 400/404/409 status.

    :param user_store: Backing store used by every endpoint.
    """
    from pydantic import BaseModel
    from starlette.responses import JSONResponse

    from hs_py.user import Role

    router = APIRouter(prefix="/users", tags=["users"])

    # -- Pydantic request/response models ------------------------------------
    class CreateUserRequest(BaseModel):
        """Request body for creating a new user."""

        username: str
        password: str
        first_name: str = ""
        last_name: str = ""
        email: str = ""
        role: str = "viewer"
        enabled: bool = True

    class UpdateUserRequest(BaseModel):
        """Request body for updating a user (all fields optional)."""

        password: str | None = None
        first_name: str | None = None
        last_name: str | None = None
        email: str | None = None
        role: str | None = None
        enabled: bool | None = None

    class UserResponse(BaseModel):
        """Public user representation (no credentials)."""

        username: str
        first_name: str
        last_name: str
        email: str
        role: str
        enabled: bool
        created_at: float
        updated_at: float

    # -- Dependencies --------------------------------------------------------
    async def _get_current_user(request: Request) -> Any:
        """Resolve the authenticated user from ``request.state.username``.

        :returns: The :class:`~hs_py.user.User` object.
        :raises HaystackError: If not authenticated, or the user is unknown or disabled.
        """
        username: str | None = getattr(request.state, "username", None)
        if username is None:
            raise HaystackError("Authentication required")
        user = await user_store.get_user(username)
        if user is None or not user.enabled:
            raise HaystackError("Authentication required")
        return user

    async def _require_admin(request: Request) -> str:
        """Extract authenticated username and verify admin role.

        :returns: The authenticated username.
        :raises HaystackError: If not authenticated or not an admin.
        """
        user = await _get_current_user(request)
        if user.role != Role.ADMIN:
            raise HaystackError("Admin access required")
        return str(user.username)

    def _user_response(user: Any) -> UserResponse:
        """Convert a User to a Pydantic response model (no credentials)."""
        return UserResponse(
            username=user.username,
            first_name=user.first_name,
            last_name=user.last_name,
            email=user.email,
            role=user.role.value,
            enabled=user.enabled,
            created_at=user.created_at,
            updated_at=user.updated_at,
        )

    @router.get("/")
    async def list_users(request: Request) -> JSONResponse:
        """List all users (admin-only)."""
        await _require_admin(request)
        users = await user_store.list_users()
        return JSONResponse([_user_response(u).model_dump() for u in users])

    @router.get("/{username}")
    async def get_user(request: Request, username: str) -> JSONResponse:
        """Get a single user by username (admin-only)."""
        await _require_admin(request)
        user = await user_store.get_user(username)
        if user is None:
            return JSONResponse({"error": f"User not found: {username!r}"}, status_code=404)
        return JSONResponse(_user_response(user).model_dump())

    @router.post("/")
    async def create_user_endpoint(request: Request) -> JSONResponse:
        """Create a new user (admin-only)."""
        await _require_admin(request)
        # Malformed JSON, a non-object body and schema errors are all client
        # errors: answer 400 instead of letting the decode error escape.
        try:
            body = CreateUserRequest(**(await request.json()))
        except Exception:
            return JSONResponse({"error": "Invalid request body"}, status_code=400)
        username = body.username.strip()
        # Reject blank passwords, but store the password exactly as supplied.
        # The update endpoint does not strip passwords either, and silently
        # altering a credential would make it differ from what the user typed.
        if not username or not body.password.strip():
            return JSONResponse({"error": "username and password are required"}, status_code=400)

        from hs_py.user import create_user as _create_user

        try:
            role = Role(body.role)
        except ValueError:
            return JSONResponse(
                {"error": f"Invalid role: {body.role!r}. Must be admin, operator, or viewer"},
                status_code=400,
            )
        try:
            user = _create_user(
                username=username,
                password=body.password,
                first_name=body.first_name,
                last_name=body.last_name,
                email=body.email,
                role=role,
                enabled=body.enabled,
            )
            await user_store.create_user(user)
        except ValueError as exc:
            return JSONResponse({"error": str(exc)}, status_code=409)
        return JSONResponse(_user_response(user).model_dump(), status_code=201)

    @router.put("/{username}")
    async def update_user_endpoint(request: Request, username: str) -> JSONResponse:
        """Update an existing user (admin-only)."""
        await _require_admin(request)
        # Malformed JSON is a client error (400), same as a schema mismatch.
        try:
            body = UpdateUserRequest(**(await request.json()))
        except Exception:
            return JSONResponse({"error": "Invalid request body"}, status_code=400)
        # Mirror create: a blank password is never accepted.
        if body.password is not None and not body.password.strip():
            return JSONResponse({"error": "password must not be empty"}, status_code=400)

        fields: dict[str, Any] = {}
        for field_name in ("password", "first_name", "last_name", "email", "enabled"):
            val = getattr(body, field_name)
            if val is not None:
                fields[field_name] = val
        if body.role is not None:
            try:
                fields["role"] = Role(body.role)
            except ValueError:
                return JSONResponse(
                    {"error": f"Invalid role: {body.role!r}. Must be admin, operator, or viewer"},
                    status_code=400,
                )
        if not fields:
            return JSONResponse({"error": "No valid fields to update"}, status_code=400)
        try:
            updated = await user_store.update_user(username, **fields)
        except KeyError:
            return JSONResponse({"error": f"User not found: {username!r}"}, status_code=404)
        return JSONResponse(_user_response(updated).model_dump())

    @router.delete("/{username}")
    async def delete_user_endpoint(request: Request, username: str) -> JSONResponse:
        """Delete a user (admin-only); admins cannot delete their own account."""
        caller = await _require_admin(request)
        if username == caller:
            return JSONResponse({"error": "Cannot delete your own account"}, status_code=400)
        deleted = await user_store.delete_user(username)
        if not deleted:
            return JSONResponse({"error": f"User not found: {username!r}"}, status_code=404)
        return JSONResponse({"deleted": username})

    return router
# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------
def create_fastapi_app(
    ops: HaystackOps | None = None,
    *,
    storage: StorageAdapter | None = None,
    authenticator: Authenticator | None = None,
    namespace: Namespace | None = None,
    user_store: UserStore | None = None,
    prefix: str = "/api",
    cors_origins: list[str] | None = None,
    default_format: str = "json",
) -> FastAPI:
    """Create a FastAPI application with Haystack HTTP routes.

    The returned app supports content-negotiated responses (JSON, Zinc, Trio,
    CSV, Turtle, JSON-LD), SCRAM-SHA-256 authentication (when an
    *authenticator* is provided), and WebSocket connections at ``{prefix}/ws``.

    :param ops: :class:`~hs_py.ops.HaystackOps` implementation to dispatch to.
        When *None*, a default :class:`~hs_py.ops.HaystackOps` is constructed
        from *storage* and *namespace*.
    :param storage: Optional :class:`~hs_py.storage.protocol.StorageAdapter`
        backend. Its ``start()`` / ``close()`` lifecycle methods are called
        automatically via the FastAPI lifespan.
    :param authenticator: Optional :class:`~hs_py.auth_types.Authenticator` for
        SCRAM-SHA-256 auth. When *None*, all requests are accepted without auth.
    :param namespace: Optional :class:`~hs_py.ontology.namespace.Namespace` for
        ``defs`` and ``libs`` operations.
    :param user_store: Optional :class:`~hs_py.storage.protocol.UserStore` for
        user management. When provided, user CRUD endpoints are mounted under
        ``{prefix}/users/`` and superuser bootstrapping runs at startup.
    :param prefix: URL path prefix for all Haystack routes (default ``"/api"``).
    :param cors_origins: Optional list of allowed CORS origins. When provided,
        ``CORSMiddleware`` is added with credentials support. Example:
        ``["http://localhost:3000", "https://app.example.com"]``.
    :param default_format: Default response format when the client sends ``*/*``
        or no ``Accept`` header. The Haystack spec defines ``"zinc"`` as the
        default; this library defaults to ``"json"`` for ecosystem compatibility.
        Supported values: ``"json"``, ``"zinc"``.
    :returns: Configured :class:`fastapi.FastAPI` application.

    Example::

        from hs_py.fastapi_server import create_fastapi_app
        from hs_py.storage.memory import InMemoryAdapter
        from hs_py.auth_types import StorageAuthenticator

        storage = InMemoryAdapter()
        auth = StorageAuthenticator(storage)
        app = create_fastapi_app(storage=storage, authenticator=auth, user_store=storage)
        # uvicorn.run(app, host="0.0.0.0", port=8080)
    """
    if ops is None:
        ops = HaystackOps(storage=storage, namespace=namespace)

    app = FastAPI(title="Haystack Server", lifespan=_lifespan)
    app.state.ops = ops
    app.state.storage = storage
    app.state.user_store = user_store
    app.state.default_format = default_format
    app.state._response_cache = {}
    app.state._ws_grid_cache = {}

    # Shared token store — set on app.state before middleware so both the
    # SCRAM middleware and the WS endpoint reference the same dict.
    auth_tokens: dict[str, TokenEntry] | None = None
    if authenticator is not None:
        auth_tokens = {}
        app.state.auth_tokens = auth_tokens
        app.add_middleware(
            ScramAuthMiddleware, authenticator=authenticator, auth_tokens=auth_tokens
        )

    if cors_origins:
        from starlette.middleware.cors import CORSMiddleware

        app.add_middleware(
            CORSMiddleware,
            allow_origins=cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    def _not_acceptable_response() -> StarletteResponse:
        """Plain-text 406 used when no supported media type is acceptable."""
        return Response(
            content="Not Acceptable: no supported media type in Accept header",
            status_code=406,
            media_type="text/plain",
        )

    def _error_grid_response(request: Request, message: str) -> StarletteResponse:
        """Render *message* as an error grid in the negotiated format.

        ``_grid_response`` raises ``_NotAcceptableError`` for an unsupported
        ``Accept`` header. Without this fallback, that second exception would
        escape the middleware as a raw 500.
        """
        try:
            return _grid_response(Grid.make_error(message), request)
        except _NotAcceptableError:
            return _not_acceptable_response()

    @app.middleware("http")
    async def error_and_security_headers(request: Request, call_next: Any) -> StarletteResponse:
        """Catch unhandled exceptions and add security headers.

        This outermost middleware mirrors the aiohttp ``_error_middleware``:
        Haystack-specific errors become error-grid responses, unexpected
        exceptions produce a generic error grid, and every response gets
        standard security headers.
        """
        try:
            response: StarletteResponse = await call_next(request)
        except _NotAcceptableError:
            response = _not_acceptable_response()
        except UnsupportedContentTypeError as exc:
            response = Response(
                content=f"Unsupported Media Type: {exc.mime}",
                status_code=415,
                media_type="text/plain",
            )
        except HaystackError as exc:
            response = _error_grid_response(request, str(exc))
        except Exception:
            _log.exception("Unhandled exception in request handler")
            response = _error_grid_response(request, "Internal server error")
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        if request.url.scheme == "https":
            response.headers["Strict-Transport-Security"] = "max-age=63072000; includeSubDomains"
        return response

    app.add_exception_handler(HaystackError, _haystack_error_handler)

    prefix = prefix.rstrip("/")
    app.include_router(_build_router(), prefix=prefix)
    if user_store is not None:
        app.include_router(_build_user_router(user_store), prefix=prefix)
    return app