mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
68 lines
1.9 KiB
Python
68 lines
1.9 KiB
Python
"""Vectorized KL divergence for separability scoring."""
|
|
|
|
import numpy as np
|
|
from typing import Tuple
|
|
|
|
from lib.agent_probability import (
|
|
DEFAULT_AGENT_PRIOR,
|
|
estimate_agent_probability_batch,
|
|
)
|
|
|
|
# Optional JAX backend. When JAX cannot be imported, ``jnp`` is bound to NumPy
# and ``jit`` becomes a no-op decorator, so code written against ``jnp`` and
# decorated with ``@jit`` still runs (eagerly, on NumPy).
try:
    import jax.numpy as jnp
    from jax import jit
except ImportError:
    jnp = np
    JAX_AVAILABLE = False

    def jit(f):
        """Stand-in for ``jax.jit``: return ``f`` unchanged."""
        return f
else:
    # Only reached when both JAX imports above succeeded.
    JAX_AVAILABLE = True
|
|
|
|
|
|
@jit
def batch_kl(P, Q_human, Q_agent, eps=1e-10):
    """Compute KL(P || Q) for a batch of sessions against two references.

    Args:
        P: Batched session matrices, shape ``(n, s, s)``.
        Q_human: Human reference matrix, shape ``(s, s)``.
        Q_agent: Agent reference matrix, shape ``(s, s)``.
        eps: Additive smoothing applied to ``P`` and to both references so
            the division and log never see an exact zero.

    Returns:
        ``(delta_h, delta_a)``, each of shape ``(n,)``: the elementwise
        ``p * log(p / q)`` terms summed over the last two axes, where ``p`` is
        the smoothed ``P`` renormalized along its last axis and ``q`` is the
        smoothed reference broadcast across the batch.
    """
    smoothed = P + eps
    # Renormalize each row of the smoothed session matrix to sum to 1.
    p = smoothed / smoothed.sum(axis=-1, keepdims=True)

    def _kl_against(reference):
        # NOTE(review): unlike P, the reference is smoothed but NOT
        # renormalized; an all-zero reference row becomes ``eps`` in every
        # cell and yields very large log terms. Confirm callers pass
        # row-normalized references without empty rows.
        q = reference[None] + eps  # add a leading batch axis for broadcasting
        return jnp.sum(p * jnp.log(p / q), axis=(1, 2))

    return _kl_against(Q_human), _kl_against(Q_agent)
|
|
|
|
|
|
def compute_divergences(
    session_trans: np.ndarray, ref_human: np.ndarray, ref_agent: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Compute KL divergence of each session from human/agent prototypes.

    Both backends go through ``batch_kl``. When JAX is not installed, the
    module binds ``jnp`` to NumPy and ``jit`` to a no-op, so ``batch_kl`` runs
    eagerly on NumPy. A second, hand-copied NumPy implementation is therefore
    unnecessary and could drift out of sync with ``batch_kl``.

    Args:
        session_trans: Batched session matrices, shape ``(n, s, s)``.
            Any array-like is accepted.
        ref_human: Human reference matrix, shape ``(s, s)``.
        ref_agent: Agent reference matrix, shape ``(s, s)``.

    Returns:
        ``(delta_h, delta_a)`` as NumPy arrays, each of shape ``(n,)``.
    """
    delta_h, delta_a = batch_kl(
        jnp.asarray(session_trans), jnp.asarray(ref_human), jnp.asarray(ref_agent)
    )
    # Convert back to host NumPy arrays (a no-op on the NumPy backend).
    return np.asarray(delta_h), np.asarray(delta_a)
|
|
|
|
|
|
def estimate_alpha_batch(
    prob_agent: np.ndarray,
    delta_h: np.ndarray,
    delta_a: np.ndarray,
    temp: float = 1.0,
    prior_agent: float = DEFAULT_AGENT_PRIOR,
) -> np.ndarray:
    """Vectorized alpha estimation from the human/agent divergence pair.

    Thin wrapper over ``estimate_agent_probability_batch``. The divergences
    are coerced to float arrays before being passed on.

    Args:
        prob_agent: Ignored; kept only so existing callers' signatures still work.
        delta_h: Per-session divergence from the human reference.
        delta_a: Per-session divergence from the agent reference.
        temp: Forwarded as ``temperature``.
        prior_agent: Forwarded unchanged.

    Returns:
        Whatever ``estimate_agent_probability_batch`` returns for these inputs.
    """
    del prob_agent  # accepted for backward compatibility; intentionally unused
    human_gap = np.asarray(delta_h, dtype=float)
    agent_gap = np.asarray(delta_a, dtype=float)
    return estimate_agent_probability_batch(
        delta_h=human_gap,
        delta_a=agent_gap,
        temperature=temp,
        prior_agent=prior_agent,
    )
|