Files
cvfs/dlib/auth/oidc.py

105 lines
3.3 KiB
Python

from __future__ import annotations
import time
from functools import cached_property
from typing import Any
import httpx
from jose import JWTError, jwt
from pydantic import BaseModel, Field
class AuthenticatedUser(BaseModel):
    """Caller identity returned by ``OidcTokenValidator.validate``."""

    # Subject identifier (populated from the token's ``sub`` claim).
    sub: str
    # Optional profile claims copied from the token when present.
    email: str | None = None
    name: str | None = None
    picture: str | None = None
    # Role names; each instance gets its own empty list when none are given.
    roles: list[str] = Field(default_factory=list)
class TokenValidationError(Exception):
    """Raised when a bearer token cannot be validated or its key resolved."""
def _normalize_issuer(value: str | None) -> str | None:
if not value:
return None
normalized = value.strip().rstrip("/")
normalized = normalized.replace("/application/o/authorize", "")
normalized = normalized.rstrip("/")
return f"{normalized}/application/o"
class OidcTokenValidator:
    """Validate OIDC bearer tokens (JWTs) against the issuer's JWKS.

    Signing keys are downloaded from ``jwks_url`` and cached for
    ``_JWKS_TTL_SECONDS``. When validation is disabled -- explicitly via
    ``disable`` or implicitly because no usable issuer is configured --
    :meth:`validate` returns a fixed development user without inspecting the
    token.
    """

    # How long a fetched JWKS document is reused before it is fetched again.
    _JWKS_TTL_SECONDS = 3600
    # Minimum spacing between forced re-fetches triggered by an unknown
    # ``kid``; bounds how often tokens with bogus key ids can hit the IdP.
    _JWKS_MIN_REFETCH_SECONDS = 60

    def __init__(
        self,
        *,
        issuer: str | None,
        audience: str | None,
        jwks_url: str | None = None,
        disable: bool = False,
    ) -> None:
        """Configure the validator.

        Args:
            issuer: Issuer URL. It is normalized with ``_normalize_issuer``
                and passed as ``issuer`` to ``jwt.decode``.
            audience: Expected audience, passed as ``audience`` to
                ``jwt.decode``.
            jwks_url: Explicit JWKS endpoint. Defaults to
                ``<normalized issuer>/.well-known/jwks.json``.
            disable: Skip validation entirely. Validation is also skipped
                when ``issuer`` normalizes to ``None``.
        """
        normalized_issuer = _normalize_issuer(issuer)
        self.issuer = normalized_issuer
        self.audience = audience
        self.jwks_url = jwks_url or (
            f"{normalized_issuer.rstrip('/')}/.well-known/jwks.json"
            if normalized_issuer
            else None
        )
        self.disable = disable or not normalized_issuer
        self._jwks: dict[str, Any] | None = None
        # Cache timestamps use time.monotonic(), so wall-clock adjustments
        # cannot stretch or cut short the cache lifetime.
        self._jwks_expiry: float = 0
        self._jwks_fetched_at: float | None = None

    async def validate(self, token: str) -> AuthenticatedUser:
        """Validate ``token`` and return the user it identifies.

        Args:
            token: The encoded JWT.

        Returns:
            The user built from the verified claims, or a fixed development
            user when validation is disabled.

        Raises:
            TokenValidationError: The token is empty, has an undecodable
                header, references no resolvable key, fails ``jwt.decode``
                verification, or carries no ``sub`` claim.
            httpx.HTTPError: The JWKS request failed.
        """
        if self.disable:
            return AuthenticatedUser(
                sub="dev-user", email="dev@example.com", name="Developer"
            )
        if not token:
            # With validation enabled, a missing token must never
            # authenticate. Previously it yielded the dev user.
            raise TokenValidationError("Missing bearer token")
        try:
            header = jwt.get_unverified_header(token)
        except JWTError as exc:
            # Malformed tokens surface as the module's own error type.
            raise TokenValidationError(str(exc)) from exc
        key = await self._get_key(header.get("kid"))
        if not key:
            raise TokenValidationError("Unable to resolve signing key")
        try:
            claims = jwt.decode(
                token,
                key,
                algorithms=[key.get("alg", "RS256")],
                audience=self.audience,
                issuer=self.issuer,
            )
        except JWTError as exc:
            raise TokenValidationError(str(exc)) from exc
        subject = claims.get("sub")
        if subject is None or subject == "":
            # Without this check the user id would silently become "None".
            raise TokenValidationError("Token is missing the 'sub' claim")
        return AuthenticatedUser(
            sub=str(subject),
            email=claims.get("email"),
            name=claims.get("name"),
            picture=claims.get("picture"),
            roles=self._extract_roles(claims),
        )

    @staticmethod
    def _extract_roles(claims: dict[str, Any]) -> list[str]:
        """Return the role names in ``claims`` as a list of strings.

        A top-level ``roles`` claim wins; otherwise ``app_metadata.roles`` is
        used. A single string is wrapped in a list, and list/tuple/set items
        are converted to ``str``. Any other shape yields no roles.
        """
        roles = claims.get("roles")
        if not roles:
            app_metadata = claims.get("app_metadata")
            if isinstance(app_metadata, dict):
                roles = app_metadata.get("roles")
        if not roles:
            return []
        if isinstance(roles, str):
            return [roles]
        if isinstance(roles, (list, tuple, set)):
            return [str(role) for role in roles]
        return []

    async def _get_key(self, kid: str | None) -> dict[str, Any] | None:
        """Return the JWK that should verify a token whose header has ``kid``.

        A key whose ``kid`` matches is preferred. On a miss the JWKS is
        re-fetched (at most once per ``_JWKS_MIN_REFETCH_SECONDS``) in case
        the IdP rotated its keys. Otherwise this falls back to the first key,
        as it does when the token carries no ``kid``. Returns ``None`` when no
        JWKS URL is configured or the JWKS contains no usable keys.
        """
        if not self.jwks_url:
            return None
        if not self._jwks or time.monotonic() > self._jwks_expiry:
            await self._fetch_jwks()
        if kid:
            key = self._find_key(kid)
            if key is None and self._refetch_allowed():
                # The IdP may have rotated signing keys since we cached them.
                await self._fetch_jwks()
                key = self._find_key(kid)
            if key is not None:
                return key
        keys = self._cached_keys()
        return keys[0] if keys else None

    async def _fetch_jwks(self) -> None:
        """Download the JWKS from ``jwks_url`` and reset the cache timers."""
        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.get(self.jwks_url)
            response.raise_for_status()
            self._jwks = response.json()
        now = time.monotonic()
        self._jwks_fetched_at = now
        self._jwks_expiry = now + self._JWKS_TTL_SECONDS

    def _refetch_allowed(self) -> bool:
        """Return whether enough time has passed to force another JWKS fetch."""
        if self._jwks_fetched_at is None:
            return True
        elapsed = time.monotonic() - self._jwks_fetched_at
        return elapsed >= self._JWKS_MIN_REFETCH_SECONDS

    def _cached_keys(self) -> list[dict[str, Any]]:
        """Return the JWK objects of the cached JWKS, skipping malformed entries."""
        if not isinstance(self._jwks, dict):
            return []
        keys = self._jwks.get("keys")
        if not isinstance(keys, list):
            return []
        return [key for key in keys if isinstance(key, dict)]

    def _find_key(self, kid: str) -> dict[str, Any] | None:
        """Return the cached JWK whose ``kid`` equals ``kid``, or ``None``."""
        return next(
            (key for key in self._cached_keys() if key.get("kid") == kid),
            None,
        )
def build_validator(
    *,
    issuer: str | None,
    audience: str | None,
    disable: bool,
    jwks_url: str | None = None,
) -> OidcTokenValidator:
    """Construct an :class:`OidcTokenValidator` from configuration values.

    Args:
        issuer: Issuer URL (normalized by the validator); ``None`` disables
            validation.
        audience: Expected token audience.
        disable: Force validation off regardless of ``issuer``.
        jwks_url: Optional explicit JWKS endpoint. When omitted, the
            validator derives one from the normalized issuer, as before.

    Returns:
        The configured validator.
    """
    return OidcTokenValidator(
        issuer=issuer,
        audience=audience,
        jwks_url=jwks_url,
        disable=disable,
    )