unified separability writing

This commit is contained in:
2026-03-23 21:47:31 +01:00
parent 910dba0a7d
commit 220b6ce8c1
4 changed files with 129 additions and 66 deletions

View File

@@ -1,12 +1,15 @@
import numpy as np
from typing import Dict
+from lib.agent_probability import DEFAULT_AGENT_PRIOR, estimate_agent_probability
def compute_agent_probability(
trajectory: list,
human_transitions: Dict,
agent_transitions: Dict,
temperature: float = 1.0,
+prior_agent: float = DEFAULT_AGENT_PRIOR,
) -> float:
"""estimate agent probability via KL divergence between trajectory transitions and reference models
@@ -18,10 +21,10 @@ def compute_agent_probability(
agent_transitions: reference transition dict from agent MDP (event->event->prob)
returns:
-agent probability in [0, 1] via softmax over KL divergences
+agent probability in [0, 1] via sigma((delta_h - delta_a) / T)
"""
if len(trajectory) < 2:
-return 0.0 # insufficient data, assume human
+return float(prior_agent)
# build empirical transition distribution from trajectory
trans_counts = {}
@@ -54,11 +57,12 @@ def compute_agent_probability(
kl_human = kl_div(empirical, human_transitions)
kl_agent = kl_div(empirical, agent_transitions)
-# convert to probability via softmax (lower KL = higher prob)
-t = float(max(temperature, 1e-6))
-exp_h = np.exp(-kl_human / t)
-exp_a = np.exp(-kl_agent / t)
-return float(exp_a / (exp_h + exp_a + 1e-10))
+return estimate_agent_probability(
+delta_h=kl_human,
+delta_a=kl_agent,
+temperature=temperature,
+prior_agent=prior_agent,
+)
def extract_purchases(trajectories: list) -> Dict[int, int]: