mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
unified separability writing
This commit is contained in:
@@ -1,14 +1,24 @@
|
||||
"""Vectorized KL divergence for separability scoring."""
|
||||
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
|
||||
from lib.agent_probability import (
|
||||
DEFAULT_AGENT_PRIOR,
|
||||
estimate_agent_probability_batch,
|
||||
)
|
||||
|
||||
try:
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
|
||||
JAX_AVAILABLE = True
|
||||
except ImportError:
|
||||
jnp, JAX_AVAILABLE = np, False
|
||||
def jit(f): return f
|
||||
|
||||
def jit(f):
|
||||
return f
|
||||
|
||||
|
||||
@jit
|
||||
def batch_kl(P, Q_human, Q_agent, eps=1e-10):
|
||||
@@ -20,10 +30,15 @@ def batch_kl(P, Q_human, Q_agent, eps=1e-10):
|
||||
delta_a = jnp.sum(p * jnp.log(p / qa), axis=(1, 2))
|
||||
return delta_h, delta_a
|
||||
|
||||
def compute_divergences(session_trans: np.ndarray, ref_human: np.ndarray, ref_agent: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
|
||||
def compute_divergences(
|
||||
session_trans: np.ndarray, ref_human: np.ndarray, ref_agent: np.ndarray
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Compute KL divergence of each session from human/agent prototypes."""
|
||||
if JAX_AVAILABLE:
|
||||
dh, da = batch_kl(jnp.array(session_trans), jnp.array(ref_human), jnp.array(ref_agent))
|
||||
dh, da = batch_kl(
|
||||
jnp.array(session_trans), jnp.array(ref_human), jnp.array(ref_agent)
|
||||
)
|
||||
return np.asarray(dh), np.asarray(da)
|
||||
# numpy fallback
|
||||
eps = 1e-10
|
||||
@@ -34,10 +49,19 @@ def compute_divergences(session_trans: np.ndarray, ref_human: np.ndarray, ref_ag
|
||||
delta_a = np.sum(p * np.log(p / qa), axis=(1, 2))
|
||||
return delta_h, delta_a
|
||||
|
||||
def estimate_alpha_batch(prob_agent: np.ndarray, delta_h: np.ndarray, delta_a: np.ndarray, temp: float = 1.0) -> np.ndarray:
|
||||
"""Vectorized alpha estimation from classifier probs and divergences."""
|
||||
mass = delta_h + delta_a
|
||||
ratio = np.where(mass > 1e-8, delta_a / mass, 0.5)
|
||||
blended = 0.5 * prob_agent + 0.5 * ratio
|
||||
if temp <= 0: return np.clip(blended, 0.0, 1.0)
|
||||
return np.clip(1.0 / (1.0 + np.exp(-temp * (blended - 0.5))), 0.0, 1.0)
|
||||
|
||||
def estimate_alpha_batch(
|
||||
prob_agent: np.ndarray,
|
||||
delta_h: np.ndarray,
|
||||
delta_a: np.ndarray,
|
||||
temp: float = 1.0,
|
||||
prior_agent: float = DEFAULT_AGENT_PRIOR,
|
||||
) -> np.ndarray:
|
||||
"""Vectorized alpha estimation using divergence gap mapping."""
|
||||
_ = prob_agent
|
||||
return estimate_agent_probability_batch(
|
||||
delta_h=np.asarray(delta_h, dtype=float),
|
||||
delta_a=np.asarray(delta_a, dtype=float),
|
||||
temperature=temp,
|
||||
prior_agent=prior_agent,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user