mirror of
https://github.com/velocitatem/cvfs.git
synced 2026-05-31 08:43:37 +00:00
Finish MVP and dockerize
This commit is contained in:
92
dlib/auth/oidc.py
Normal file
92
dlib/auth/oidc.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AuthenticatedUser(BaseModel):
|
||||
sub: str
|
||||
email: str | None = None
|
||||
name: str | None = None
|
||||
picture: str | None = None
|
||||
roles: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TokenValidationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class OidcTokenValidator:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
issuer: str | None,
|
||||
audience: str | None,
|
||||
jwks_url: str | None = None,
|
||||
disable: bool = False,
|
||||
) -> None:
|
||||
self.issuer = issuer
|
||||
self.audience = audience
|
||||
self.jwks_url = jwks_url or (
|
||||
f"{issuer.rstrip('/')}/.well-known/jwks.json" if issuer else None
|
||||
)
|
||||
self.disable = disable or not issuer
|
||||
self._jwks: dict[str, Any] | None = None
|
||||
self._jwks_expiry: float = 0
|
||||
|
||||
async def validate(self, token: str) -> AuthenticatedUser:
|
||||
if self.disable or not token:
|
||||
return AuthenticatedUser(
|
||||
sub="dev-user", email="dev@example.com", name="Developer"
|
||||
)
|
||||
header = jwt.get_unverified_header(token)
|
||||
key = await self._get_key(header.get("kid"))
|
||||
if not key:
|
||||
raise TokenValidationError("Unable to resolve signing key")
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
key,
|
||||
algorithms=[key.get("alg", "RS256")],
|
||||
audience=self.audience,
|
||||
issuer=self.issuer,
|
||||
)
|
||||
except JWTError as exc:
|
||||
raise TokenValidationError(str(exc)) from exc
|
||||
roles = claims.get("roles") or claims.get("app_metadata", {}).get("roles") or []
|
||||
if isinstance(roles, str):
|
||||
roles = [roles]
|
||||
return AuthenticatedUser(
|
||||
sub=str(claims.get("sub")),
|
||||
email=claims.get("email"),
|
||||
name=claims.get("name"),
|
||||
picture=claims.get("picture"),
|
||||
roles=roles,
|
||||
)
|
||||
|
||||
async def _get_key(self, kid: str | None) -> dict[str, Any] | None:
|
||||
if not self.jwks_url:
|
||||
return None
|
||||
if not self._jwks or time.time() > self._jwks_expiry:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get(self.jwks_url)
|
||||
response.raise_for_status()
|
||||
self._jwks = response.json()
|
||||
self._jwks_expiry = time.time() + 3600
|
||||
keys = self._jwks.get("keys", []) if isinstance(self._jwks, dict) else []
|
||||
if kid:
|
||||
for key in keys:
|
||||
if key.get("kid") == kid:
|
||||
return key
|
||||
return keys[0] if keys else None
|
||||
|
||||
|
||||
def build_validator(
|
||||
*, issuer: str | None, audience: str | None, disable: bool
|
||||
) -> OidcTokenValidator:
|
||||
return OidcTokenValidator(issuer=issuer, audience=audience, disable=disable)
|
||||
Reference in New Issue
Block a user