chore: cleaning some code

This commit is contained in:
2026-02-28 23:30:16 +01:00
parent 233ce3be34
commit 803e3a2972
6 changed files with 81 additions and 30 deletions

View File

@@ -3,7 +3,10 @@ from typing import Dict
def compute_agent_probability(
trajectory: list, human_transitions: Dict, agent_transitions: Dict
trajectory: list,
human_transitions: Dict,
agent_transitions: Dict,
temperature: float = 1.0,
) -> float:
"""estimate agent probability via KL divergence between trajectory transitions and reference models
@@ -52,9 +55,9 @@ def compute_agent_probability(
kl_agent = kl_div(empirical, agent_transitions)
# convert to probability via softmax (lower KL = higher prob)
# agent_prob = exp(-kl_agent) / (exp(-kl_human) + exp(-kl_agent))
exp_h = np.exp(-kl_human)
exp_a = np.exp(-kl_agent)
t = float(max(temperature, 1e-6))
exp_h = np.exp(-kl_human / t)
exp_a = np.exp(-kl_agent / t)
return float(exp_a / (exp_h + exp_a + 1e-10))