mirror of
https://github.com/velocitatem/cvfs.git
synced 2026-05-31 08:43:37 +00:00
120 lines
4.0 KiB
Python
120 lines
4.0 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from jose import JWTError, jwt
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class AuthenticatedUser(BaseModel):
    """Identity of an authenticated caller.

    Built from verified token claims, or used as a fixed development user
    when validation is disabled. Only ``sub`` is required; the profile
    fields are optional and ``roles`` defaults to an empty list.
    """

    # Subject identifier taken from the token's ``sub`` claim.
    sub: str
    email: str | None = Field(default=None)
    name: str | None = Field(default=None)
    picture: str | None = Field(default=None)
    # Role names; each instance gets its own fresh list.
    roles: list[str] = Field(default_factory=list)
|
|
|
|
|
|
class TokenValidationError(Exception):
    """Raised when a bearer token cannot be accepted as authentic."""
|
|
|
|
|
|
def _normalize_issuer(value: str | None) -> tuple[str | None, str | None]:
|
|
if not value:
|
|
return None, None
|
|
raw = value.strip().rstrip("/")
|
|
normalized = raw.replace("/application/o/authorize/", "/application/o/")
|
|
normalized = normalized.replace("/application/o/authorize", "/application/o")
|
|
normalized = normalized.rstrip("/")
|
|
return raw, normalized if normalized != raw else raw
|
|
|
|
|
|
class OidcTokenValidator:
    """Validate OIDC-issued JWTs against the issuer's published signing keys.

    The signing-key set (JWKS) comes from an explicit ``jwks_url``, or from
    the ``jwks_uri`` advertised by the issuer's
    ``/.well-known/openid-configuration`` document. It is cached in memory.

    Validation is disabled when ``disable`` is set or no issuer is configured.
    In that case :meth:`validate` returns a fixed development user without
    inspecting the token.
    """

    # Seconds a fetched JWKS document is reused before being re-downloaded.
    _JWKS_TTL_SECONDS = 3600
    # Minimum seconds between JWKS refreshes forced by an unknown ``kid``.
    # This bounds how often a client can make us call the identity provider.
    _JWKS_MIN_REFRESH_SECONDS = 60

    def __init__(
        self,
        *,
        issuer: str | None,
        audience: str | None,
        jwks_url: str | None = None,
        disable: bool = False,
    ) -> None:
        """Configure the validator.

        Args:
            issuer: Expected ``iss`` value and discovery base. When it is
                empty, validation is disabled.
            audience: Expected ``aud`` value. When ``None``, the audience is
                not checked.
            jwks_url: Explicit JWKS endpoint. When omitted, it is discovered
                from the issuer.
            disable: Force-disable validation (development mode).
        """
        raw_issuer, discovery_issuer = _normalize_issuer(issuer)
        self.issuer = raw_issuer
        self.audience = audience
        self.jwks_url = jwks_url
        base = discovery_issuer or raw_issuer
        self.discovery_url = (
            f"{base.rstrip('/')}/.well-known/openid-configuration" if base else None
        )
        self.disable = disable or not raw_issuer
        # The configured issuer has its trailing slash stripped, but some
        # providers put one in ``iss``. Accept both spellings of the same URL.
        self._accepted_issuers: list[str] = (
            [raw_issuer, f"{raw_issuer}/"] if raw_issuer else []
        )
        self._jwks: dict[str, Any] | None = None
        self._jwks_expiry: float = 0
        self._jwks_fetched_at: float = 0

    async def validate(self, token: str) -> AuthenticatedUser:
        """Verify *token* and return the user it describes.

        Args:
            token: The raw JWT, without any ``Bearer`` prefix.

        Returns:
            The user built from the token's claims, or a fixed development
            user when validation is disabled.

        Raises:
            TokenValidationError: The token is missing or malformed, no
                signing key matches it, a signature or claim check fails, or
                it has no ``sub`` claim.

        Errors from the discovery/JWKS HTTP requests (httpx exceptions)
        propagate unchanged.
        """
        if self.disable:
            return AuthenticatedUser(
                sub="dev-user", email="dev@example.com", name="Developer"
            )
        if not token:
            # Never fall back to the dev user while validation is enabled.
            # That would let an empty credential authenticate.
            raise TokenValidationError("Missing bearer token")

        try:
            header = jwt.get_unverified_header(token)
        except JWTError as exc:
            raise TokenValidationError(str(exc)) from exc

        key = await self._get_key(header.get("kid"))
        if not key:
            raise TokenValidationError("Unable to resolve signing key")

        # Prefer the algorithm the key itself declares. This stops the
        # unverified, client-controlled header from choosing how the
        # signature is checked.
        alg = key.get("alg") or header.get("alg") or "RS256"
        try:
            claims = jwt.decode(
                token,
                key,
                algorithms=[alg],
                audience=self.audience,
                issuer=self._accepted_issuers,
                # python-jose rejects any token carrying ``aud`` when no
                # audience is given. Skip the check if none is configured.
                options={"verify_aud": self.audience is not None},
            )
        except JWTError as exc:
            raise TokenValidationError(str(exc)) from exc

        subject = claims.get("sub")
        if not subject:
            # Without this check, a missing ``sub`` would become the literal user "None".
            raise TokenValidationError("Token has no subject claim")

        return AuthenticatedUser(
            sub=str(subject),
            email=claims.get("email"),
            name=claims.get("name"),
            picture=claims.get("picture"),
            roles=self._extract_roles(claims),
        )

    @staticmethod
    def _extract_roles(claims: dict[str, Any]) -> list[str]:
        """Return the ``roles`` claim, falling back to ``app_metadata.roles``.

        A single string becomes a one-element list. Missing or unusable
        values yield an empty list.
        """
        roles = claims.get("roles")
        if not roles:
            app_metadata = claims.get("app_metadata")
            if isinstance(app_metadata, dict):
                roles = app_metadata.get("roles")
        if not roles:
            return []
        if isinstance(roles, str):
            return [roles]
        if isinstance(roles, (list, tuple)):
            return [str(role) for role in roles]
        return []

    async def _ensure_jwks_url(self) -> None:
        """Fill ``jwks_url`` from the discovery document if it is not set.

        This is a no-op when a JWKS URL is already known or there is no
        discovery URL. It does nothing if the document lacks a string
        ``jwks_uri``.
        """
        if self.jwks_url or not self.discovery_url:
            return
        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.get(self.discovery_url)
            response.raise_for_status()
            data = response.json()
        jwks_uri = data.get("jwks_uri")
        if isinstance(jwks_uri, str):
            self.jwks_url = jwks_uri

    async def _get_key(self, kid: str | None) -> dict[str, Any] | None:
        """Return the JWK that should verify a token with key id *kid*.

        The cached JWKS is refreshed when it is empty or expired. If *kid* is
        not found, the JWKS is refetched once (rate-limited), since an
        unknown kid usually follows key rotation. Returns ``None`` when no
        JWKS endpoint is known or no suitable key exists.
        """
        await self._ensure_jwks_url()
        if not self.jwks_url:
            return None

        if not self._jwks or time.time() > self._jwks_expiry:
            await self._refresh_jwks()

        key = self._find_key(kid)
        if (
            key is None
            and kid
            and time.time() - self._jwks_fetched_at > self._JWKS_MIN_REFRESH_SECONDS
        ):
            await self._refresh_jwks()
            key = self._find_key(kid)
        return key

    async def _refresh_jwks(self) -> None:
        """Download the JWKS document and reset the cache timestamps."""
        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.get(self.jwks_url)
            response.raise_for_status()
            self._jwks = response.json()
        self._jwks_fetched_at = time.time()
        self._jwks_expiry = self._jwks_fetched_at + self._JWKS_TTL_SECONDS

    def _find_key(self, kid: str | None) -> dict[str, Any] | None:
        """Look up *kid* in the cached key set.

        With no *kid*, the first key is returned. A *kid* that matches no
        key yields ``None``, never a different key.
        """
        raw_keys = self._jwks.get("keys", []) if isinstance(self._jwks, dict) else []
        keys = [k for k in raw_keys if isinstance(k, dict)]
        if kid:
            return next((k for k in keys if k.get("kid") == kid), None)
        return keys[0] if keys else None
|
|
|
|
|
|
def build_validator(
    *,
    issuer: str | None,
    audience: str | None,
    disable: bool,
    jwks_url: str | None = None,
) -> OidcTokenValidator:
    """Construct an ``OidcTokenValidator`` from configuration values.

    Args:
        issuer: OIDC issuer URL. When empty, validation is disabled.
        audience: Expected ``aud`` value, passed through to the validator.
        disable: Force-disable validation (development mode).
        jwks_url: Optional explicit JWKS endpoint. When omitted, the
            validator discovers it from the issuer's OpenID configuration.

    Returns:
        A configured validator.
    """
    return OidcTokenValidator(
        issuer=issuer, audience=audience, jwks_url=jwks_url, disable=disable
    )
|