unified separability writing

This commit is contained in:
2026-03-23 21:47:31 +01:00
parent 910dba0a7d
commit 220b6ce8c1
4 changed files with 129 additions and 66 deletions

View File

@@ -3,10 +3,13 @@
Computes divergence signals delta_H, delta_A from session trajectories using
transition kernel estimation and KL divergence to prototype behavioral profiles.
"""
from __future__ import annotations
from typing import Dict, List, Tuple, TYPE_CHECKING
import numpy as np
from lib.agent_probability import DEFAULT_AGENT_PRIOR, estimate_agent_probability
if TYPE_CHECKING:
from .simplified import Event, Session
@@ -32,7 +35,10 @@ TRANS_A = {
def kl_div(p: Dict[str, float], q: Dict[str, float], eps: float = 1e-10) -> float:
"""KL divergence D_KL(p || q) for discrete distributions."""
keys = set(p.keys()) | set(q.keys())
return sum(p.get(k, eps) * np.log((p.get(k, eps) + eps) / (q.get(k, eps) + eps)) for k in keys)
return sum(
p.get(k, eps) * np.log((p.get(k, eps) + eps) / (q.get(k, eps) + eps))
for k in keys
)
def build_kernel(events: List["Event"]) -> Dict[str, Dict[str, float]]:
@@ -44,7 +50,11 @@ def build_kernel(events: List["Event"]) -> Dict[str, Dict[str, float]]:
trans.setdefault(prev, {})
trans[prev][curr] = trans[prev].get(curr, 0) + 1
prev = curr
return {s: {d: c / sum(dsts.values()) for d, c in dsts.items()} for s, dsts in trans.items() if sum(dsts.values()) > 0}
return {
s: {d: c / sum(dsts.values()) for d, c in dsts.items()}
for s, dsts in trans.items()
if sum(dsts.values()) > 0
}
def compute_divergence(session: "Session") -> Tuple[float, float]:
@@ -55,18 +65,35 @@ def compute_divergence(session: "Session") -> Tuple[float, float]:
"""
kernel = build_kernel(session.events)
if not kernel:
return 0.5, 0.5
delta_h = sum(kl_div(kernel.get(s, {}), TRANS_H.get(s, {})) for s in kernel) / len(kernel)
delta_a = sum(kl_div(kernel.get(s, {}), TRANS_A.get(s, {})) for s in kernel) / len(kernel)
return 0.0, 0.0
delta_h = sum(kl_div(kernel.get(s, {}), TRANS_H.get(s, {})) for s in kernel) / len(
kernel
)
delta_a = sum(kl_div(kernel.get(s, {}), TRANS_A.get(s, {})) for s in kernel) / len(
kernel
)
return delta_h, delta_a
def estimate_alpha(session: "Session", beta: float = 2.0) -> float:
"""Per-session contamination estimate alpha_hat = sigma(beta*(delta_H - delta_A)).
def estimate_alpha(
session: "Session",
beta: float = 2.0,
prior_agent: float = DEFAULT_AGENT_PRIOR,
) -> float:
"""Per-session contamination estimate alpha_hat = sigma((delta_H - delta_A) / T).
Returns probability session is agent-generated based on behavioral divergence.
"""
dh, da = compute_divergence(session)
if (dh + da) <= 0:
return 0.5
return 1.0 / (1.0 + np.exp(-beta * (dh - da)))
return float(prior_agent)
if beta <= 0:
return estimate_agent_probability(
dh, da, temperature=1.0, prior_agent=prior_agent
)
return estimate_agent_probability(
delta_h=dh,
delta_a=da,
temperature=1.0 / beta,
prior_agent=prior_agent,
)

View File

@@ -1,14 +1,24 @@
"""Vectorized KL divergence for separability scoring."""
import numpy as np
from typing import Tuple
from lib.agent_probability import (
DEFAULT_AGENT_PRIOR,
estimate_agent_probability_batch,
)
try:
import jax.numpy as jnp
from jax import jit
JAX_AVAILABLE = True
except ImportError:
jnp, JAX_AVAILABLE = np, False
def jit(f): return f
def jit(f):
return f
@jit
def batch_kl(P, Q_human, Q_agent, eps=1e-10):
@@ -20,10 +30,15 @@ def batch_kl(P, Q_human, Q_agent, eps=1e-10):
delta_a = jnp.sum(p * jnp.log(p / qa), axis=(1, 2))
return delta_h, delta_a
def compute_divergences(session_trans: np.ndarray, ref_human: np.ndarray, ref_agent: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
def compute_divergences(
session_trans: np.ndarray, ref_human: np.ndarray, ref_agent: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""Compute KL divergence of each session from human/agent prototypes."""
if JAX_AVAILABLE:
dh, da = batch_kl(jnp.array(session_trans), jnp.array(ref_human), jnp.array(ref_agent))
dh, da = batch_kl(
jnp.array(session_trans), jnp.array(ref_human), jnp.array(ref_agent)
)
return np.asarray(dh), np.asarray(da)
# numpy fallback
eps = 1e-10
@@ -34,10 +49,19 @@ def compute_divergences(session_trans: np.ndarray, ref_human: np.ndarray, ref_ag
delta_a = np.sum(p * np.log(p / qa), axis=(1, 2))
return delta_h, delta_a
def estimate_alpha_batch(prob_agent: np.ndarray, delta_h: np.ndarray, delta_a: np.ndarray, temp: float = 1.0) -> np.ndarray:
"""Vectorized alpha estimation from classifier probs and divergences."""
mass = delta_h + delta_a
ratio = np.where(mass > 1e-8, delta_a / mass, 0.5)
blended = 0.5 * prob_agent + 0.5 * ratio
if temp <= 0: return np.clip(blended, 0.0, 1.0)
return np.clip(1.0 / (1.0 + np.exp(-temp * (blended - 0.5))), 0.0, 1.0)
def estimate_alpha_batch(
prob_agent: np.ndarray,
delta_h: np.ndarray,
delta_a: np.ndarray,
temp: float = 1.0,
prior_agent: float = DEFAULT_AGENT_PRIOR,
) -> np.ndarray:
"""Vectorized alpha estimation using divergence gap mapping."""
_ = prob_agent
return estimate_agent_probability_batch(
delta_h=np.asarray(delta_h, dtype=float),
delta_a=np.asarray(delta_a, dtype=float),
temperature=temp,
prior_agent=prior_agent,
)