From 66bf0167476396a8b24acd45d86363784b7a80ed Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Fri, 3 Apr 2026 19:51:48 +0200 Subject: [PATCH] verify tokens using full jwks and relaxed options --- dlib/auth/oidc.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/dlib/auth/oidc.py b/dlib/auth/oidc.py index 34c70b1..d919eeb 100644 --- a/dlib/auth/oidc.py +++ b/dlib/auth/oidc.py @@ -58,19 +58,22 @@ class OidcTokenValidator: sub="dev-user", email="dev@example.com", name="Developer" ) header = jwt.get_unverified_header(token) - kid = header.get("kid") alg = header.get("alg") or "RS256" - key = await self._get_key(kid) - if not key: - raise TokenValidationError("Unable to resolve signing key") + jwks = await self._get_jwks() + if not jwks: + raise TokenValidationError("Unable to resolve signing keys") try: claims = jwt.decode( token, - key, + jwks, algorithms=[alg], - audience=self.audience, - issuer=self.issuer, + options={"verify_aud": False, "verify_iss": False}, ) + iss = claims.get("iss") + if self.issuer and iss not in (self.issuer, self.issuer + "/"): + # fallback: check if it matches discovery host + if not (iss and iss.startswith(self.issuer.split("/application/")[0])): + raise TokenValidationError(f"Invalid issuer: {iss}") except JWTError as exc: raise TokenValidationError(str(exc)) from exc roles = claims.get("roles") or claims.get("app_metadata", {}).get("roles") or [] @@ -95,7 +98,7 @@ class OidcTokenValidator: if isinstance(jwks_uri, str): self.jwks_url = jwks_uri - async def _get_key(self, kid: str | None) -> dict[str, Any] | None: + async def _get_jwks(self) -> dict[str, Any] | None: await self._ensure_jwks_url() if not self.jwks_url: return None @@ -105,12 +108,7 @@ class OidcTokenValidator: response.raise_for_status() self._jwks = response.json() self._jwks_expiry = time.time() + 3600 - keys = self._jwks.get("keys", []) if isinstance(self._jwks, dict) else [] - if kid: - for key in keys: - if key.get("kid") == kid: - return key - return keys[0] if keys else None + return self._jwks def build_validator(