mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
chore: cleaning some code
This commit is contained in:
@@ -3,7 +3,10 @@ from typing import Dict
|
||||
|
||||
|
||||
def compute_agent_probability(
|
||||
trajectory: list, human_transitions: Dict, agent_transitions: Dict
|
||||
trajectory: list,
|
||||
human_transitions: Dict,
|
||||
agent_transitions: Dict,
|
||||
temperature: float = 1.0,
|
||||
) -> float:
|
||||
"""estimate agent probability via KL divergence between trajectory transitions and reference models
|
||||
|
||||
@@ -52,9 +55,9 @@ def compute_agent_probability(
|
||||
kl_agent = kl_div(empirical, agent_transitions)
|
||||
|
||||
# convert to probability via softmax (lower KL = higher prob)
|
||||
# agent_prob = exp(-kl_agent) / (exp(-kl_human) + exp(-kl_agent))
|
||||
exp_h = np.exp(-kl_human)
|
||||
exp_a = np.exp(-kl_agent)
|
||||
t = float(max(temperature, 1e-6))
|
||||
exp_h = np.exp(-kl_human / t)
|
||||
exp_a = np.exp(-kl_agent / t)
|
||||
return float(exp_a / (exp_h + exp_a + 1e-10))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user