diff --git a/paper/src/chapters/mdp_agent.pdf b/paper/src/chapters/mdp_agent.pdf index aeab1b7..24d141e 100644 Binary files a/paper/src/chapters/mdp_agent.pdf and b/paper/src/chapters/mdp_agent.pdf differ diff --git a/paper/src/chapters/mdp_human.pdf b/paper/src/chapters/mdp_human.pdf index b753b4e..6ae3aa3 100644 Binary files a/paper/src/chapters/mdp_human.pdf and b/paper/src/chapters/mdp_human.pdf differ diff --git a/sim/rl/behavior_loader/models.py b/sim/rl/behavior_loader/models.py index 25c8c15..0b1d95e 100644 --- a/sim/rl/behavior_loader/models.py +++ b/sim/rl/behavior_loader/models.py @@ -310,12 +310,24 @@ def visualize_mdp( layout_radius: float = 10.0, node_diameter: float = 1.8, label_threshold: float = 0.08, + drop_isolated_nodes: bool = False, ): if not model.mdp: raise ValueError("build MDP first") evt_trans = aggregate_event_transitions(model.mdp) ordered_events = _resolve_event_order(evt_trans, event_order=event_order) + + edges = [ + (src, dst, prob) + for src, dsts in evt_trans.items() + for dst, prob in dsts.items() + if prob > threshold + ] + if drop_isolated_nodes: + connected = {src for src, _, _ in edges} | {dst for _, dst, _ in edges} + ordered_events = [evt for evt in ordered_events if evt in connected] + positions = _compute_flow_positions(ordered_events, layout_radius=layout_radius) g = graphviz.Digraph(format=fmt, engine="neato") @@ -353,15 +365,14 @@ def visualize_mdp( x, y = positions[evt] g.node(evt, label=_format_node_label(evt), pos=f"{x:.2f},{y:.2f}!", pin="true") - edges = [ - (src, dst, prob) - for src, dsts in evt_trans.items() - for dst, prob in dsts.items() - if prob > threshold - ] - edge_set = {(src, dst) for src, dst, _ in edges} + edge_set = { + (src, dst) for src, dst, _ in edges if src in positions and dst in positions + } for src, dst, prob in sorted(edges, key=lambda row: row[2]): + if src not in positions or dst not in positions: + continue + edge_attrs: Dict[str, str] = _edge_style(prob) if src == dst: @@ -537,6 +548,7 @@ if __name__ == "__main__": fmt="pdf", export_dot=True, event_order=canonical_events, + drop_isolated_nodes=True, ) common = set(human_evt.keys()) & set(agent_evt.keys())