diff --git a/dlib/auth/oidc.py b/dlib/auth/oidc.py index 50f3de8..55d236c 100644 --- a/dlib/auth/oidc.py +++ b/dlib/auth/oidc.py @@ -55,8 +55,9 @@ class OidcTokenValidator: normalized_issuer = _normalize_issuer(issuer) self.issuer = normalized_issuer self.audience = audience - self.jwks_url = jwks_url or ( - f"{normalized_issuer.rstrip('/')}/.well-known/jwks.json" + self.jwks_url = jwks_url + self.discovery_url = ( + f"{normalized_issuer.rstrip('/')}/.well-known/openid-configuration" if normalized_issuer else None ) @@ -94,7 +95,19 @@ class OidcTokenValidator: roles=roles, ) + async def _ensure_jwks_url(self) -> None: + if self.jwks_url or not self.discovery_url: + return + async with httpx.AsyncClient(timeout=10) as client: + response = await client.get(self.discovery_url) + response.raise_for_status() + data = response.json() + jwks_uri = data.get("jwks_uri") + if isinstance(jwks_uri, str): + self.jwks_url = jwks_uri + async def _get_key(self, kid: str | None) -> dict[str, Any] | None: + await self._ensure_jwks_url() if not self.jwks_url: return None if not self._jwks or time.time() > self._jwks_expiry: