diff --git a/paper/src/chapters/mdp_agent.pdf b/paper/src/chapters/mdp_agent.pdf index 0566be9..6845eb5 100644 Binary files a/paper/src/chapters/mdp_agent.pdf and b/paper/src/chapters/mdp_agent.pdf differ diff --git a/paper/src/chapters/mdp_human.pdf b/paper/src/chapters/mdp_human.pdf index 7cef37a..69bc8d3 100644 Binary files a/paper/src/chapters/mdp_human.pdf and b/paper/src/chapters/mdp_human.pdf differ diff --git a/sim/rl/behavior_loader/models.py b/sim/rl/behavior_loader/models.py index cb67cbf..0b1a285 100644 --- a/sim/rl/behavior_loader/models.py +++ b/sim/rl/behavior_loader/models.py @@ -3,7 +3,7 @@ try: except ImportError: from sim.rl.behavior_loader.loader import Loader, AgentLoader, JointLoader from collections import defaultdict -from typing import Dict, List, Tuple, Set +from typing import Dict, List, Optional, Set, Tuple import numpy as np import graphviz import sys @@ -195,6 +195,35 @@ def aggregate_event_transitions(mdp: Dict) -> Dict[str, Dict[str, float]]: return dict(evt_trans) +def _resolve_event_order( + evt_trans: Dict[str, Dict[str, float]], + event_order: Optional[List[str]] = None, +) -> List[str]: + observed = set(evt_trans.keys()) | { + dst for transitions in evt_trans.values() for dst in transitions + } + if event_order: + ordered = list(dict.fromkeys(event_order)) + missing = sorted(observed - set(ordered)) + return ordered + missing + return sorted(observed) + + +def _fixed_circle_positions( + events: List[str], radius: float +) -> Dict[str, Tuple[float, float]]: + if not events: + return {} + step = (2 * np.pi) / len(events) + return { + evt: ( + float(radius * np.cos(idx * step)), + float(radius * np.sin(idx * step)), + ) + for idx, evt in enumerate(events) + } + + def visualize_mdp( model: BehaviorModel, threshold: float = 0.05, @@ -202,20 +231,31 @@ def visualize_mdp( fmt: str = "svg", view: bool = False, export_dot: bool = False, + event_order: Optional[List[str]] = None, + layout_radius: float = 6.0, + node_diameter: float = 2.4, ): if not model.mdp: raise ValueError("build MDP first") evt_trans = aggregate_event_transitions(model.mdp) - g = graphviz.Digraph(format=fmt) - g.attr(rankdir="LR", size="30") - g.attr("node", shape="circle", width="1", height="1") + ordered_events = _resolve_event_order(evt_trans, event_order=event_order) + positions = _fixed_circle_positions(ordered_events, radius=layout_radius) - events = set(evt_trans.keys()) | { - e for trans in evt_trans.values() for e in trans.keys() - } - for evt in events: - g.node(evt) + g = graphviz.Digraph(format=fmt, engine="neato") + g.attr(overlap="false", splines="true", outputorder="edgesfirst") + g.attr( + "node", + shape="circle", + width=f"{node_diameter:.2f}", + height=f"{node_diameter:.2f}", + fixedsize="true", + fontsize="10", + ) + + for evt in ordered_events: + x_pos, y_pos = positions[evt] + g.node(evt, pos=f"{x_pos:.3f},{y_pos:.3f}!", pin="true") for src, dsts in evt_trans.items(): for dst, prob in dsts.items(): @@ -342,11 +382,6 @@ if __name__ == "__main__": f"Built MDP: {human_mdp['num_states']} states, " f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions" ) - if not human_mdp["states"]: - exit("No states found") - visualize_mdp( - human_model, threshold=0.05, output="human_mdp_viz", fmt="pdf", export_dot=True - ) agent_model = AgentBehaviorModel(agent_dir) agent_mdp = agent_model.build_MDP() @@ -355,14 +390,35 @@ if __name__ == "__main__": f"AGENT... Built MDP: {agent_mdp['num_states']} states, " f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions" ) - if not agent_mdp["states"]: - exit("No states found") - visualize_mdp( - agent_model, threshold=0.05, output="agent_mdp_viz", fmt="pdf", export_dot=True - ) human_evt = aggregate_event_transitions(human_mdp) agent_evt = aggregate_event_transitions(agent_mdp) + canonical_events = sorted( + (set(human_evt.keys()) | {e for tr in human_evt.values() for e in tr.keys()}) + | (set(agent_evt.keys()) | {e for tr in agent_evt.values() for e in tr.keys()}) + ) + + if not human_mdp["states"]: + exit("No states found") + visualize_mdp( + human_model, + threshold=0.05, + output="human_mdp_viz", + fmt="pdf", + export_dot=True, + event_order=canonical_events, + ) + + if not agent_mdp["states"]: + exit("No states found") + visualize_mdp( + agent_model, + threshold=0.05, + output="agent_mdp_viz", + fmt="pdf", + export_dot=True, + event_order=canonical_events, + ) common = set(human_evt.keys()) & set(agent_evt.keys()) @@ -394,6 +450,7 @@ if __name__ == "__main__": output="joint_mdp_viz", fmt="pdf", export_dot=True, + event_order=canonical_events, ) inter_class_avg = float(np.mean([kl for _, kl in kl_divs]))