# File: cvfs/dlib/auth/oidc.py
# (Python; 118 lines, 4.0 KiB in the original listing)

from __future__ import annotations
import time
from typing import Any
import httpx
from jose import JWTError, jwt
from pydantic import BaseModel, Field
class AuthenticatedUser(BaseModel):
    """Identity produced by a successful token validation.

    Only ``sub`` is mandatory; the remaining profile fields default to
    ``None`` and ``roles`` defaults to a fresh empty list per instance.
    """

    # Subject identifier of the authenticated principal.
    sub: str

    # Optional profile attributes.
    email: str | None = None
    name: str | None = None
    picture: str | None = None

    # Role names attached to the user; default_factory gives every instance
    # its own list.
    roles: list[str] = Field(default_factory=list)
class TokenValidationError(Exception):
    """Raised when a bearer token cannot be validated."""
def _normalize_issuer(value: str | None) -> tuple[str | None, str | None]:
    """Split a configured issuer into its raw and discovery forms.

    Args:
        value: Issuer URL as configured; may be ``None``, empty or blank.

    Returns:
        ``(raw, normalized)``. ``raw`` is ``value`` with surrounding
        whitespace and trailing slashes removed. ``normalized`` is ``raw``
        with every ``authorize`` path segment that directly follows
        ``/application/o`` removed. Both are ``None`` when ``value`` is
        missing or blank.
    """
    if not value:
        return None, None

    raw = value.strip().rstrip("/")
    if not raw:
        # A whitespace-only setting is treated the same as "not configured".
        return None, None

    # Work on whole path segments so that a slug which merely *starts* with
    # "authorize" (e.g. ".../application/o/authorized-app") is left intact;
    # a plain substring replace would corrupt it.
    kept: list[str] = []
    for segment in raw.split("/"):
        if segment == "authorize" and kept[-2:] == ["application", "o"]:
            continue
        kept.append(segment)

    normalized = "/".join(kept).rstrip("/")
    return raw, normalized
class OidcTokenValidator:
    """Validate OIDC bearer tokens against the issuer's published JWKS.

    Signing keys are fetched from ``jwks_url`` when it is given, otherwise
    from the ``jwks_uri`` advertised by the issuer's
    ``/.well-known/openid-configuration`` document. The key set is cached in
    memory for ``_JWKS_TTL_SECONDS``.

    Validation is switched off (every call returns a fixed development user)
    when ``disable`` is true or when no issuer is configured.
    """

    # Only asymmetric signature algorithms are accepted. The ``alg`` value
    # comes from the still-unverified token header, so it must be checked
    # against an allow-list before it is handed to the decoder.
    _ALLOWED_ALGORITHMS = frozenset(
        {"RS256", "RS384", "RS512", "ES256", "ES384", "ES512"}
    )
    # Lifetime of the in-memory JWKS cache, in seconds.
    _JWKS_TTL_SECONDS = 3600
    # Timeout applied to discovery and JWKS HTTP requests, in seconds.
    _HTTP_TIMEOUT_SECONDS = 10

    def __init__(
        self,
        *,
        issuer: str | None,
        audience: str | None,
        jwks_url: str | None = None,
        disable: bool = False,
    ) -> None:
        """Configure the validator.

        Args:
            issuer: Expected token issuer URL; ``None``/empty disables
                validation.
            audience: Expected ``aud`` value; when set, tokens carrying a
                different audience are rejected.
            jwks_url: Explicit JWKS endpoint; skips discovery when given.
            disable: Force validation off.
        """
        raw_issuer, discovery_issuer = _normalize_issuer(issuer)
        self.issuer = raw_issuer
        self.audience = audience
        self.jwks_url = jwks_url

        discovery_base = discovery_issuer or raw_issuer
        self.discovery_url = (
            f"{discovery_base.rstrip('/')}/.well-known/openid-configuration"
            if discovery_base
            else None
        )

        # NOTE(review): a missing issuer silently turns validation off;
        # confirm deployments always set it where authentication matters.
        self.disable = disable or not raw_issuer
        self._jwks: dict[str, Any] | None = None
        self._jwks_expiry: float = 0

    async def validate(self, token: str) -> AuthenticatedUser:
        """Verify ``token`` and return the user it identifies.

        Args:
            token: Encoded JWT taken from the request.

        Returns:
            The authenticated user, or a fixed development user when
            validation is disabled.

        Raises:
            TokenValidationError: The token is missing or malformed, uses a
                disallowed algorithm, fails signature/claim checks, has an
                unexpected issuer or audience, lacks ``sub``, or the signing
                keys cannot be obtained.
        """
        if self.disable:
            return AuthenticatedUser(
                sub="dev-user", email="dev@example.com", name="Developer"
            )
        # With validation enabled, an absent token must never authenticate.
        if not token:
            raise TokenValidationError("Missing bearer token")

        try:
            header = jwt.get_unverified_header(token)
        except JWTError as exc:
            raise TokenValidationError(str(exc)) from exc

        alg = header.get("alg") or "RS256"
        if alg not in self._ALLOWED_ALGORITHMS:
            raise TokenValidationError(f"Unsupported signing algorithm: {alg}")

        jwks = await self._get_jwks()
        if not jwks:
            raise TokenValidationError("Unable to resolve signing keys")

        # The audience is only checked when one is configured; the issuer is
        # checked below because it needs more lenient matching than the
        # decoder's exact comparison.
        enforce_audience = bool(self.audience)
        try:
            claims = jwt.decode(
                token,
                jwks,
                algorithms=[alg],
                audience=self.audience if enforce_audience else None,
                options={"verify_aud": enforce_audience, "verify_iss": False},
            )
        except JWTError as exc:
            raise TokenValidationError(str(exc)) from exc

        iss = claims.get("iss")
        if not self._issuer_matches(iss):
            raise TokenValidationError(f"Invalid issuer: {iss}")

        subject = claims.get("sub")
        if subject is None or subject == "":
            # Without this check every sub-less token would share the
            # identity "None".
            raise TokenValidationError("Token is missing the 'sub' claim")

        return AuthenticatedUser(
            sub=str(subject),
            email=claims.get("email"),
            name=claims.get("name"),
            picture=claims.get("picture"),
            roles=self._extract_roles(claims),
        )

    def _issuer_matches(self, iss: Any) -> bool:
        """Return whether the ``iss`` claim is acceptable for this validator.

        Accepts the configured issuer with or without a trailing slash. As a
        fallback it also accepts any issuer located on the same base URL
        (the part of the configured issuer before ``/application/``).
        """
        expected = self.issuer
        if not expected:
            return True
        if not isinstance(iss, str) or not iss:
            return False
        if iss in (expected, expected + "/"):
            return True

        # Fallback: same host as the configured issuer. The comparison is
        # anchored on a "/" boundary so "https://host.evil.tld" cannot pass
        # as "https://host".
        # NOTE(review): this still accepts other applications on the same
        # host; tighten if only one application should be trusted.
        base = expected.split("/application/")[0].rstrip("/")
        if not base:
            return False
        return iss == base or iss.startswith(base + "/")

    @staticmethod
    def _extract_roles(claims: dict[str, Any]) -> list[str]:
        """Collect role names from ``roles`` or ``app_metadata.roles``.

        A single string becomes a one-element list; list-like values are
        returned as a list of strings; any other shape yields no roles.
        """
        roles = claims.get("roles")
        if not roles:
            metadata = claims.get("app_metadata")
            roles = metadata.get("roles") if isinstance(metadata, dict) else None
        if not roles:
            return []
        if isinstance(roles, str):
            return [roles]
        if isinstance(roles, (list, tuple, set)):
            return [str(role) for role in roles]
        return []

    async def _fetch_json(self, url: str) -> Any:
        """GET ``url`` and return its decoded JSON body.

        Raises:
            TokenValidationError: On transport errors, non-success status
                codes, or a body that is not valid JSON.
        """
        try:
            async with httpx.AsyncClient(
                timeout=self._HTTP_TIMEOUT_SECONDS
            ) as client:
                response = await client.get(url)
                response.raise_for_status()
                return response.json()
        except (httpx.HTTPError, ValueError) as exc:
            raise TokenValidationError(f"Unable to fetch {url}: {exc}") from exc

    async def _ensure_jwks_url(self) -> None:
        """Resolve ``jwks_url`` from the discovery document when unset."""
        if self.jwks_url or not self.discovery_url:
            return
        document = await self._fetch_json(self.discovery_url)
        jwks_uri = document.get("jwks_uri") if isinstance(document, dict) else None
        if isinstance(jwks_uri, str) and jwks_uri:
            self.jwks_url = jwks_uri

    async def _get_jwks(self) -> dict[str, Any] | None:
        """Return the signing key set, refreshing the cache when expired.

        Returns:
            The JWKS document, or ``None`` when no JWKS URL can be
            determined.

        Raises:
            TokenValidationError: The discovery or JWKS request failed, or
                the JWKS response is not a JSON object.
        """
        await self._ensure_jwks_url()
        if not self.jwks_url:
            return None

        if self._jwks and time.time() <= self._jwks_expiry:
            return self._jwks

        fetched = await self._fetch_json(self.jwks_url)
        if not isinstance(fetched, dict):
            raise TokenValidationError("JWKS document is not a JSON object")
        self._jwks = fetched
        self._jwks_expiry = time.time() + self._JWKS_TTL_SECONDS
        return self._jwks
def build_validator(
    *,
    issuer: str | None,
    audience: str | None,
    disable: bool,
    jwks_url: str | None = None,
) -> OidcTokenValidator:
    """Create an ``OidcTokenValidator`` from configuration values.

    Args:
        issuer: Expected token issuer URL, or ``None``.
        audience: Expected token audience, or ``None``.
        disable: Force validation off.
        jwks_url: Optional explicit JWKS endpoint, forwarded to the
            validator. Defaults to ``None``, which matches the previous
            behaviour of this factory.

    Returns:
        A validator configured with the given values.
    """
    return OidcTokenValidator(
        issuer=issuer,
        audience=audience,
        jwks_url=jwks_url,
        disable=disable,
    )