updating node positinoing

This commit is contained in:
2026-03-27 17:19:27 +01:00
parent 18b41ff802
commit 58042ba4f2
3 changed files with 76 additions and 19 deletions

Binary file not shown.

Binary file not shown.

View File

@@ -3,7 +3,7 @@ try:
except ImportError:
from sim.rl.behavior_loader.loader import Loader, AgentLoader, JointLoader
from collections import defaultdict
from typing import Dict, List, Tuple, Set
from typing import Dict, List, Optional, Set, Tuple
import numpy as np
import graphviz
import sys
@@ -195,6 +195,35 @@ def aggregate_event_transitions(mdp: Dict) -> Dict[str, Dict[str, float]]:
return dict(evt_trans)
def _resolve_event_order(
evt_trans: Dict[str, Dict[str, float]],
event_order: Optional[List[str]] = None,
) -> List[str]:
observed = set(evt_trans.keys()) | {
dst for transitions in evt_trans.values() for dst in transitions
}
if event_order:
ordered = list(dict.fromkeys(event_order))
missing = sorted(observed - set(ordered))
return ordered + missing
return sorted(observed)
def _fixed_circle_positions(
events: List[str], radius: float
) -> Dict[str, Tuple[float, float]]:
if not events:
return {}
step = (2 * np.pi) / len(events)
return {
evt: (
float(radius * np.cos(idx * step)),
float(radius * np.sin(idx * step)),
)
for idx, evt in enumerate(events)
}
def visualize_mdp(
model: BehaviorModel,
threshold: float = 0.05,
@@ -202,20 +231,31 @@ def visualize_mdp(
fmt: str = "svg",
view: bool = False,
export_dot: bool = False,
event_order: Optional[List[str]] = None,
layout_radius: float = 6.0,
node_diameter: float = 2.4,
):
if not model.mdp:
raise ValueError("build MDP first")
evt_trans = aggregate_event_transitions(model.mdp)
g = graphviz.Digraph(format=fmt)
g.attr(rankdir="LR", size="30")
g.attr("node", shape="circle", width="1", height="1")
ordered_events = _resolve_event_order(evt_trans, event_order=event_order)
positions = _fixed_circle_positions(ordered_events, radius=layout_radius)
events = set(evt_trans.keys()) | {
e for trans in evt_trans.values() for e in trans.keys()
}
for evt in events:
g.node(evt)
g = graphviz.Digraph(format=fmt, engine="neato")
g.attr(overlap="false", splines="true", outputorder="edgesfirst")
g.attr(
"node",
shape="circle",
width=f"{node_diameter:.2f}",
height=f"{node_diameter:.2f}",
fixedsize="true",
fontsize="10",
)
for evt in ordered_events:
x_pos, y_pos = positions[evt]
g.node(evt, pos=f"{x_pos:.3f},{y_pos:.3f}!", pin="true")
for src, dsts in evt_trans.items():
for dst, prob in dsts.items():
@@ -342,11 +382,6 @@ if __name__ == "__main__":
f"Built MDP: {human_mdp['num_states']} states, "
f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions"
)
if not human_mdp["states"]:
exit("No states found")
visualize_mdp(
human_model, threshold=0.05, output="human_mdp_viz", fmt="pdf", export_dot=True
)
agent_model = AgentBehaviorModel(agent_dir)
agent_mdp = agent_model.build_MDP()
@@ -355,14 +390,35 @@ if __name__ == "__main__":
f"AGENT... Built MDP: {agent_mdp['num_states']} states, "
f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions"
)
if not agent_mdp["states"]:
exit("No states found")
visualize_mdp(
agent_model, threshold=0.05, output="agent_mdp_viz", fmt="pdf", export_dot=True
)
human_evt = aggregate_event_transitions(human_mdp)
agent_evt = aggregate_event_transitions(agent_mdp)
canonical_events = sorted(
(set(human_evt.keys()) | {e for tr in human_evt.values() for e in tr.keys()})
| (set(agent_evt.keys()) | {e for tr in agent_evt.values() for e in tr.keys()})
)
if not human_mdp["states"]:
exit("No states found")
visualize_mdp(
human_model,
threshold=0.05,
output="human_mdp_viz",
fmt="pdf",
export_dot=True,
event_order=canonical_events,
)
if not agent_mdp["states"]:
exit("No states found")
visualize_mdp(
agent_model,
threshold=0.05,
output="agent_mdp_viz",
fmt="pdf",
export_dot=True,
event_order=canonical_events,
)
common = set(human_evt.keys()) & set(agent_evt.keys())
@@ -394,6 +450,7 @@ if __name__ == "__main__":
output="joint_mdp_viz",
fmt="pdf",
export_dot=True,
event_order=canonical_events,
)
inter_class_avg = float(np.mean([kl for _, kl in kl_divs]))