updating bins

This commit is contained in:
2026-03-28 13:18:08 +01:00
parent 473342f103
commit 59b2b46f6e
3 changed files with 19 additions and 7 deletions

Binary file not shown.

Binary file not shown.

View File

@@ -310,12 +310,24 @@ def visualize_mdp(
layout_radius: float = 10.0, layout_radius: float = 10.0,
node_diameter: float = 1.8, node_diameter: float = 1.8,
label_threshold: float = 0.08, label_threshold: float = 0.08,
drop_isolated_nodes: bool = False,
): ):
if not model.mdp: if not model.mdp:
raise ValueError("build MDP first") raise ValueError("build MDP first")
evt_trans = aggregate_event_transitions(model.mdp) evt_trans = aggregate_event_transitions(model.mdp)
ordered_events = _resolve_event_order(evt_trans, event_order=event_order) ordered_events = _resolve_event_order(evt_trans, event_order=event_order)
edges = [
(src, dst, prob)
for src, dsts in evt_trans.items()
for dst, prob in dsts.items()
if prob > threshold
]
if drop_isolated_nodes:
connected = {src for src, _, _ in edges} | {dst for _, dst, _ in edges}
ordered_events = [evt for evt in ordered_events if evt in connected]
positions = _compute_flow_positions(ordered_events, layout_radius=layout_radius) positions = _compute_flow_positions(ordered_events, layout_radius=layout_radius)
g = graphviz.Digraph(format=fmt, engine="neato") g = graphviz.Digraph(format=fmt, engine="neato")
@@ -353,15 +365,14 @@ def visualize_mdp(
x, y = positions[evt] x, y = positions[evt]
g.node(evt, label=_format_node_label(evt), pos=f"{x:.2f},{y:.2f}!", pin="true") g.node(evt, label=_format_node_label(evt), pos=f"{x:.2f},{y:.2f}!", pin="true")
edges = [ edge_set = {
(src, dst, prob) (src, dst) for src, dst, _ in edges if src in positions and dst in positions
for src, dsts in evt_trans.items() }
for dst, prob in dsts.items()
if prob > threshold
]
edge_set = {(src, dst) for src, dst, _ in edges}
for src, dst, prob in sorted(edges, key=lambda row: row[2]): for src, dst, prob in sorted(edges, key=lambda row: row[2]):
if src not in positions or dst not in positions:
continue
edge_attrs: Dict[str, str] = _edge_style(prob) edge_attrs: Dict[str, str] = _edge_style(prob)
if src == dst: if src == dst:
@@ -537,6 +548,7 @@ if __name__ == "__main__":
fmt="pdf", fmt="pdf",
export_dot=True, export_dot=True,
event_order=canonical_events, event_order=canonical_events,
drop_isolated_nodes=True,
) )
common = set(human_evt.keys()) & set(agent_evt.keys()) common = set(human_evt.keys()) & set(agent_evt.keys())