"""Haystack authentication handshake.
Implements the Project Haystack authentication protocol including
SCRAM-SHA-256 and PLAINTEXT mechanisms.
Uses the ``cryptography`` library for all key derivation, HMAC, and hashing.
See: https://project-haystack.org/doc/docHaystack/Auth
"""
from __future__ import annotations
import base64
import hmac as hmac_mod
import logging
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.hmac import HMAC
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from hs_py.errors import AuthError
if TYPE_CHECKING:
import aiohttp
# Names exported by a star-import of this module.
__all__ = [
"authenticate",
"scram_client_final",
"scram_client_first",
]
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)
# Upper bound on the PBKDF2 iteration count accepted from a server; a larger
# value is rejected so a hostile server cannot force excessive CPU work.
_MAX_SCRAM_ITERATIONS = 1_000_000
# Lower bound on the PBKDF2 iteration count accepted from a server; a smaller
# value is rejected because it would yield a weakly stretched key.
_MIN_SCRAM_ITERATIONS = 4096
# Minimum accepted salt length in bytes (16 bytes = 128 bits, per NIST SP 800-132).
_MIN_SALT_BYTES = 16
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
[docs]
async def authenticate(
    session: aiohttp.ClientSession,
    base_url: str,
    username: str,
    password: str,
) -> str:
    """Perform the Haystack auth handshake and hand back the bearer token.

    A ``HELLO`` request is sent to ``<base_url>/about``. If the server answers
    ``200`` no further handshake is needed; on ``401`` the ``WWW-Authenticate``
    challenge selects SCRAM (when it is the challenge scheme) or PLAINTEXT
    (when the challenge mentions it).

    :param session: An open aiohttp client session.
    :param base_url: Haystack server base URL (e.g. ``http://host/api/``).
    :param username: Username for authentication.
    :param password: Password for authentication.
    :returns: Bearer auth token for subsequent requests, or an empty string
        when the server replied ``200`` without supplying an ``authToken``.
    :raises AuthError: On an unexpected HTTP status, an unsupported
        mechanism, or a failed handshake.
    """
    endpoint = base_url.rstrip("/") + "/about"

    # Step 1: HELLO
    _log.debug("HELLO for user '%s' at %s", username, endpoint)
    encoded_user = _b64url_encode(username.encode())
    hello_headers = {"Authorization": f"HELLO username={encoded_user}"}
    async with session.get(endpoint, headers=hello_headers) as resp:
        status = resp.status
        if status == 200:
            # No authentication required; pass along a token if one was supplied.
            info = resp.headers.get("Authentication-Info", "")
            return _parse_param(info, "authToken") or ""
        if status != 401:
            raise AuthError(f"Unexpected status {status} during HELLO")
        challenge = resp.headers.get("WWW-Authenticate", "")

    # Step 2: choose the mechanism named by the challenge.
    challenge_params = _parse_header_params(challenge)
    scheme = challenge.split()[0].upper() if challenge else ""

    if scheme == "SCRAM":
        _log.debug("SCRAM-SHA-256 handshake starting for '%s'", username)
        return await _scram_auth(
            session,
            endpoint,
            username,
            password,
            challenge_params.get("handshakeToken", ""),
            challenge_params,
        )

    if "PLAINTEXT" in challenge.upper():
        _log.debug("PLAINTEXT auth for '%s'", username)
        return await _plaintext_auth(session, endpoint, username, password)

    _log.warning("No supported auth mechanism in: %s", challenge)
    raise AuthError(f"No supported auth mechanism in: {challenge}")
# ---------------------------------------------------------------------------
# SCRAM-SHA-256 — transport-independent helpers
# ---------------------------------------------------------------------------
@dataclass(frozen=True, slots=True)
class ScramClientFirst:
"""Result of the client-first SCRAM step."""
client_first_msg: str
"""Full client-first message (gs2-header + bare)."""
client_first_bare: str
"""Bare client-first message (without gs2 header)."""
c_nonce: str
"""Client nonce."""
@dataclass(frozen=True, slots=True)
class ScramClientFinal:
"""Result of the client-final SCRAM step."""
client_final_msg: str
"""Full client-final message with proof."""
auth_message: str
"""Concatenated auth message for server-sig verification."""
salted_password: bytes
"""Salted password (needed for server signature verification)."""
[docs]
def scram_client_first(username: str) -> ScramClientFirst:
    """Create the opening (client-first) message of a SCRAM exchange.

    :param username: The username to authenticate as.
    :returns: :class:`ScramClientFirst` holding the full message, the bare
        message and the freshly generated client nonce.
    """
    # RFC 5802 §5.1 escaping: '=' becomes '=3D' and ',' becomes '=2C'.
    escaped_user = username.translate({ord("="): "=3D", ord(","): "=2C"})

    # 24 random bytes, URL-safe base64 encoded with any padding removed.
    nonce_bytes = os.urandom(24)
    nonce = base64.urlsafe_b64encode(nonce_bytes).decode().rstrip("=")

    bare = "n=" + escaped_user + ",r=" + nonce
    return ScramClientFirst(
        client_first_msg=f"n,,{bare}",
        client_first_bare=bare,
        c_nonce=nonce,
    )
[docs]
def scram_client_final(
    password: str,
    first: ScramClientFirst,
    server_first_msg: str,
    hash_name: str = "SHA-256",
) -> ScramClientFinal:
    """Build the SCRAM client-final message from the server-first response.

    The server-first message is validated before any key derivation: the
    server nonce must extend the client nonce, the iteration count must lie
    within the module's accepted bounds, and the salt must be valid base64 of
    at least the minimum length.

    :param password: The plaintext password.
    :param first: The :class:`ScramClientFirst` from :func:`scram_client_first`.
    :param server_first_msg: Raw server-first message (``r=...,s=...,i=...``).
    :param hash_name: Hash algorithm name (default ``SHA-256``).
    :returns: :class:`ScramClientFinal` with the proof message.
    :raises AuthError: On nonce mismatch, an invalid or out-of-range iteration
        count, a malformed or too-short salt, or an unsupported hash.
    """
    sf_params = _parse_scram_msg(server_first_msg)
    s_nonce = sf_params.get("r", "")
    salt_b64 = sf_params.get("s", "")
    try:
        iterations = int(sf_params.get("i", "4096"))
    except (ValueError, TypeError) as exc:
        raise AuthError(f"Invalid SCRAM iteration count: {sf_params.get('i')!r}") from exc

    # The server nonce must begin with the nonce this client generated.
    if not s_nonce.startswith(first.c_nonce):
        raise AuthError("Server nonce does not start with client nonce")
    if iterations > _MAX_SCRAM_ITERATIONS:
        raise AuthError(f"Server requested excessive PBKDF2 iterations: {iterations}")
    if iterations < _MIN_SCRAM_ITERATIONS:
        raise AuthError(f"Server requested insufficient PBKDF2 iterations: {iterations}")

    # binascii.Error (a ValueError subclass) signals malformed base64; surface
    # it as AuthError so callers only need to handle the documented exception.
    try:
        salt = base64.b64decode(salt_b64)
    except ValueError as exc:
        raise AuthError("Server provided a salt that is not valid base64") from exc
    if len(salt) < _MIN_SALT_BYTES:
        raise AuthError(f"Server provided insufficient salt length: {len(salt)} bytes")

    algo = _hash_algo(hash_name)
    salted_password = _derive_key(password.encode(), salt, iterations, algo)
    client_key = _hmac(algo, salted_password, b"Client Key")
    stored_key = _hash_digest(algo, client_key)

    # "n,," is the gs2 header used in the client-first message (no channel binding).
    channel_binding = _b64url_encode(b"n,,")
    client_final_no_proof = f"c={channel_binding},r={s_nonce}"
    auth_message = f"{first.client_first_bare},{server_first_msg},{client_final_no_proof}"

    # ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage)
    client_signature = _hmac(algo, stored_key, auth_message.encode())
    client_proof = bytes(a ^ b for a, b in zip(client_key, client_signature, strict=True))
    proof_b64 = base64.b64encode(client_proof).decode()
    client_final_msg = f"{client_final_no_proof},p={proof_b64}"

    return ScramClientFinal(
        client_final_msg=client_final_msg,
        auth_message=auth_message,
        salted_password=salted_password,
    )
def verify_server_signature(
    final: ScramClientFinal,
    server_final_msg: str,
) -> None:
    """Verify the server's signature from the server-final message.

    The hash used for verification is the registered algorithm whose digest
    size equals the length of ``final.salted_password``, so a handshake that
    derived its key with SHA-512 is verified with SHA-512. If no registered
    algorithm matches, SHA-256 is used.

    :param final: The :class:`ScramClientFinal` from :func:`scram_client_final`.
    :param server_final_msg: Raw server-final message (``v=...``).
    :raises AuthError: If the server reports an error, the signature is
        missing or malformed, or the signature does not match.
    """
    sf_params = _parse_scram_msg(server_final_msg)

    # RFC 5802 §7: a server-final message carries either "e=" (error) or "v=".
    server_error = sf_params.get("e")
    if server_error:
        raise AuthError(f"Server reported SCRAM error: {server_error}")

    server_sig_b64 = sf_params.get("v", "")
    if not server_sig_b64:
        raise AuthError("Server signature missing — server authentication failed")

    # binascii.Error is a ValueError subclass; report it as an auth failure.
    try:
        server_sig = base64.b64decode(server_sig_b64)
    except ValueError as exc:
        raise AuthError("Server signature is not valid base64") from exc

    # Pick the hash that produced the salted password (matched by digest size).
    key_len = len(final.salted_password)
    algo = next(
        (name for name, impl in _HASH_ALGORITHMS.items() if impl.digest_size == key_len),
        "sha256",
    )

    server_key = _hmac(algo, final.salted_password, b"Server Key")
    expected = _hmac(algo, server_key, final.auth_message.encode())
    # Constant-time comparison to avoid leaking how many bytes matched.
    if not hmac_mod.compare_digest(server_sig, expected):
        raise AuthError("Server signature verification failed")
# ---------------------------------------------------------------------------
# SCRAM-SHA-256 — HTTP transport
# ---------------------------------------------------------------------------
async def _scram_auth(
    session: aiohttp.ClientSession,
    url: str,
    username: str,
    password: str,
    handshake_token: str,
    hello_params: dict[str, str],
) -> str:
    """Perform SCRAM authentication over HTTP and return the auth token.

    :param session: An open aiohttp client session.
    :param url: URL that every handshake request is sent to.
    :param username: Username for authentication.
    :param password: Password for authentication.
    :param handshake_token: ``handshakeToken`` from the HELLO challenge.
    :param hello_params: Parameters parsed from the HELLO challenge; its
        ``hash`` entry is used when the server-first response omits one.
    :returns: The ``authToken`` issued by the server.
    :raises AuthError: On an unexpected HTTP status, malformed or missing
        handshake data, a missing token, or a failed server signature check.
    """
    first = scram_client_first(username)
    auth_header = (
        f"SCRAM handshakeToken={handshake_token}, "
        f"data={_b64url_encode(first.client_first_msg.encode())}"
    )
    _log.debug("SCRAM step 1: sending client-first message")
    async with session.get(url, headers={"Authorization": auth_header}) as resp:
        if resp.status != 401:
            _log.warning("SCRAM step 1 failed: expected 401, got %d", resp.status)
            raise AuthError(f"Expected 401 during SCRAM step 1, got {resp.status}")
        www_auth = resp.headers.get("WWW-Authenticate", "")

    step2_params = _parse_header_params(www_auth)
    # The server may rotate the handshake token; otherwise keep the HELLO one.
    handshake_token = step2_params.get("handshakeToken", handshake_token)
    # Prefer the hash named in this response, then the HELLO challenge.
    hash_name = step2_params.get("hash") or hello_params.get("hash") or "SHA-256"
    server_first_data = step2_params.get("data", "")
    if not server_first_data:
        raise AuthError("No SCRAM data in server-first response")
    # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
    try:
        server_first_msg = _b64url_decode(server_first_data).decode()
    except ValueError as exc:
        raise AuthError("Malformed SCRAM data in server-first response") from exc

    final = scram_client_final(password, first, server_first_msg, hash_name)
    auth_header = (
        f"SCRAM handshakeToken={handshake_token}, "
        f"data={_b64url_encode(final.client_final_msg.encode())}"
    )
    _log.debug("SCRAM step 2: sending client-final message")
    async with session.get(url, headers={"Authorization": auth_header}) as resp:
        if resp.status != 200:
            _log.warning("SCRAM auth failed with status %d", resp.status)
            raise AuthError(f"SCRAM auth failed with status {resp.status}")
        auth_info = resp.headers.get("Authentication-Info", "")

    token = _parse_param(auth_info, "authToken")
    if not token:
        _log.warning("No authToken in final SCRAM response")
        raise AuthError("No authToken in final SCRAM response")

    # Verify server signature (mandatory — RFC 5802 mutual authentication)
    server_data = _parse_param(auth_info, "data")
    if not server_data:
        raise AuthError("Server signature missing from final response")
    try:
        server_final_msg = _b64url_decode(server_data).decode()
    except ValueError as exc:
        raise AuthError("Malformed SCRAM data in final response") from exc
    verify_server_signature(final, server_final_msg)
    return token
# ---------------------------------------------------------------------------
# PLAINTEXT
# ---------------------------------------------------------------------------
async def _plaintext_auth(
    session: aiohttp.ClientSession,
    url: str,
    username: str,
    password: str,
) -> str:
    """Perform PLAINTEXT authentication (intended for TLS connections only).

    The username and password are only base64url-encoded, not encrypted, so a
    warning is logged when ``url`` is not an ``https://`` URL. The request is
    still sent, to stay compatible with existing callers.

    :param session: An open aiohttp client session.
    :param url: URL the authentication request is sent to.
    :param username: Username for authentication.
    :param password: Password for authentication.
    :returns: The ``authToken`` issued by the server.
    :raises AuthError: If the response status is not ``200`` or the response
        carries no ``authToken``.
    """
    if not url.lower().startswith("https://"):
        _log.warning("PLAINTEXT auth over a non-HTTPS URL exposes credentials: %s", url)

    auth_header = (
        f"PLAINTEXT username={_b64url_encode(username.encode())}, "
        f"password={_b64url_encode(password.encode())}"
    )
    async with session.get(url, headers={"Authorization": auth_header}) as resp:
        if resp.status != 200:
            raise AuthError(f"PLAINTEXT auth failed with status {resp.status}")
        auth_info = resp.headers.get("Authentication-Info", "")

    token = _parse_param(auth_info, "authToken")
    if not token:
        raise AuthError("No authToken in PLAINTEXT response")
    return token
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _b64url_encode(data: bytes) -> str:
"""Base64url encode without padding (per Haystack spec)."""
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
def _b64url_decode(s: str) -> bytes:
"""Base64url decode, adding padding as needed."""
padding = 4 - len(s) % 4
if padding != 4:
s += "=" * padding
return base64.urlsafe_b64decode(s)
def _parse_header_params(header: str) -> dict[str, str]:
"""Parse ``key=value`` pairs from an HTTP auth header.
Handles both scheme-prefixed headers (``SCRAM key=val, ...``) and
bare param headers (``Authentication-Info: key=val, ...``).
:param header: Raw header value string.
:returns: Dict mapping parameter names to values.
"""
result: dict[str, str] = {}
if not header:
return result
# Determine if first token is a scheme name (no '=' sign)
parts = header.split(None, 1)
param_str = header
if len(parts) >= 2 and "=" not in parts[0]:
param_str = parts[1]
for pair in param_str.split(","):
pair = pair.strip()
if "=" in pair:
k, _, v = pair.partition("=")
result[k.strip()] = v.strip()
return result
def _parse_param(header: str, key: str) -> str | None:
    """Look up one parameter in a header value; ``None`` when it is absent."""
    params = _parse_header_params(header)
    return params.get(key)
def _parse_scram_msg(msg: str) -> dict[str, str]:
"""Parse a SCRAM message into its ``key=value`` parts."""
result: dict[str, str] = {}
for part in msg.split(","):
if "=" in part:
k, _, v = part.partition("=")
result[k] = v
return result
# Internal algorithm name -> ``cryptography`` hash instance. This is the lookup
# table behind _get_hash_algorithm(); names absent here are rejected there.
_HASH_ALGORITHMS: dict[str, hashes.HashAlgorithm] = {
"sha256": hashes.SHA256(),
"sha512": hashes.SHA512(),
}
def _get_hash_algorithm(algo: str) -> hashes.HashAlgorithm:
    """Resolve an internal algorithm name to its cryptography hash instance.

    :raises AuthError: If *algo* is not a registered algorithm name.
    """
    if algo not in _HASH_ALGORITHMS:
        raise AuthError(f"Unsupported hash algorithm: {algo}")
    return _HASH_ALGORITHMS[algo]
def _derive_key(password: bytes, salt: bytes, iterations: int, algo: str) -> bytes:
    """Run PBKDF2-HMAC (cryptography library) and return the derived key.

    The output length equals the digest size of the selected hash.
    """
    digest = _get_hash_algorithm(algo)
    return PBKDF2HMAC(
        algorithm=digest,
        length=digest.digest_size,
        salt=salt,
        iterations=iterations,
    ).derive(password)
def _hmac(algo: str, key: bytes, data: bytes) -> bytes:
    """Return the HMAC of *data* under *key* using the named hash algorithm."""
    mac = HMAC(key, _get_hash_algorithm(algo))
    mac.update(data)
    tag = mac.finalize()
    return tag
def _hash_digest(algo: str, data: bytes) -> bytes:
    """Return the digest of *data* using the named hash algorithm."""
    hasher = hashes.Hash(_get_hash_algorithm(algo))
    hasher.update(data)
    digest = hasher.finalize()
    return digest
# Haystack hash name (as passed to _hash_algo) -> internal algorithm name.
# Lookups are exact, so the spelling and case must match a key below.
_HASH_ALGO_NAMES: dict[str, str] = {
"SHA-256": "sha256",
"SHA-512": "sha512",
}
def _hash_algo(name: str) -> str:
    """Translate a Haystack hash name into the internal algorithm name.

    :raises AuthError: If *name* is not a recognised Haystack hash name.
    """
    if name not in _HASH_ALGO_NAMES:
        raise AuthError(f"Unsupported hash algorithm: {name}")
    return _HASH_ALGO_NAMES[name]