mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
initial progress
This commit is contained in:
@@ -209,19 +209,94 @@ def _resolve_event_order(
|
||||
return sorted(observed)
|
||||
|
||||
|
||||
def _fixed_circle_positions(
|
||||
events: List[str], radius: float
|
||||
def _compass_from_angle(angle_rad: float) -> str:
|
||||
ports = ("e", "ne", "n", "nw", "w", "sw", "s", "se")
|
||||
normalized = (angle_rad + (2 * np.pi)) % (2 * np.pi)
|
||||
step = np.pi / 4
|
||||
idx = int(np.round(normalized / step)) % len(ports)
|
||||
return ports[idx]
|
||||
|
||||
|
||||
def _edge_ports(
|
||||
src: str,
|
||||
dst: str,
|
||||
positions: Dict[str, Tuple[float, float]],
|
||||
has_reverse: bool,
|
||||
) -> Tuple[str, str]:
|
||||
src_x, src_y = positions[src]
|
||||
dst_x, dst_y = positions[dst]
|
||||
angle = float(np.arctan2(dst_y - src_y, dst_x - src_x))
|
||||
|
||||
if has_reverse:
|
||||
bend = np.pi / 10
|
||||
angle += bend if src < dst else -bend
|
||||
|
||||
tail_port = _compass_from_angle(angle)
|
||||
head_port = _compass_from_angle(angle + np.pi)
|
||||
return tail_port, head_port
|
||||
|
||||
|
||||
def _edge_style(prob: float) -> Dict[str, str]:
|
||||
if prob >= 0.75:
|
||||
edge_color = "#111827"
|
||||
elif prob >= 0.50:
|
||||
edge_color = "#374151"
|
||||
elif prob >= 0.25:
|
||||
edge_color = "#6b7280"
|
||||
else:
|
||||
edge_color = "#9ca3af"
|
||||
return {
|
||||
"color": edge_color,
|
||||
"fontcolor": "#111827",
|
||||
"fontsize": "10",
|
||||
"penwidth": f"{0.9 + 3.6 * prob:.2f}",
|
||||
"arrowsize": f"{0.55 + 0.55 * prob:.2f}",
|
||||
}
|
||||
|
||||
|
||||
def _format_node_label(evt: str) -> str:
|
||||
max_line_len = 16
|
||||
tokens = evt.split("_")
|
||||
if len(tokens) == 1:
|
||||
return evt
|
||||
|
||||
lines: List[str] = []
|
||||
curr = ""
|
||||
for token in tokens:
|
||||
piece = token if not curr else f"_{token}"
|
||||
if curr and len(curr) + len(piece) > max_line_len:
|
||||
lines.append(curr)
|
||||
curr = token
|
||||
else:
|
||||
curr = f"{curr}{piece}" if curr else token
|
||||
if curr:
|
||||
lines.append(curr)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _compute_flow_positions(
|
||||
events: List[str],
|
||||
layout_radius: float,
|
||||
) -> Dict[str, Tuple[float, float]]:
|
||||
"""Balanced grid layout for paper-friendly diagrams."""
|
||||
if not events:
|
||||
return {}
|
||||
step = (2 * np.pi) / len(events)
|
||||
return {
|
||||
evt: (
|
||||
float(radius * np.cos(idx * step)),
|
||||
float(radius * np.sin(idx * step)),
|
||||
)
|
||||
for idx, evt in enumerate(events)
|
||||
}
|
||||
|
||||
num_events = len(events)
|
||||
cols = int(np.ceil(np.sqrt(num_events)))
|
||||
rows = int(np.ceil(num_events / cols))
|
||||
x_step = max(layout_radius * 1.10, 3.6)
|
||||
y_step = max(layout_radius * 0.95, 3.2)
|
||||
|
||||
positions: Dict[str, Tuple[float, float]] = {}
|
||||
for idx, evt in enumerate(events):
|
||||
row = idx // cols
|
||||
col = idx % cols
|
||||
x = (col - (cols - 1) / 2.0) * x_step
|
||||
y = ((rows - 1) / 2.0 - row) * y_step
|
||||
positions[evt] = (float(x), float(y))
|
||||
|
||||
return positions
|
||||
|
||||
|
||||
def visualize_mdp(
|
||||
@@ -232,35 +307,79 @@ def visualize_mdp(
|
||||
view: bool = False,
|
||||
export_dot: bool = False,
|
||||
event_order: Optional[List[str]] = None,
|
||||
layout_radius: float = 6.0,
|
||||
node_diameter: float = 2.4,
|
||||
layout_radius: float = 10.0,
|
||||
node_diameter: float = 1.8,
|
||||
label_threshold: float = 0.08,
|
||||
):
|
||||
if not model.mdp:
|
||||
raise ValueError("build MDP first")
|
||||
|
||||
evt_trans = aggregate_event_transitions(model.mdp)
|
||||
ordered_events = _resolve_event_order(evt_trans, event_order=event_order)
|
||||
positions = _fixed_circle_positions(ordered_events, radius=layout_radius)
|
||||
positions = _compute_flow_positions(ordered_events, layout_radius=layout_radius)
|
||||
|
||||
g = graphviz.Digraph(format=fmt, engine="neato")
|
||||
g.attr(overlap="false", splines="true", outputorder="edgesfirst")
|
||||
g.attr(
|
||||
overlap="false",
|
||||
splines="true",
|
||||
outputorder="edgesfirst",
|
||||
pad="0.5",
|
||||
sep="+9",
|
||||
esep="+4",
|
||||
bgcolor="white",
|
||||
dpi="180",
|
||||
)
|
||||
g.attr(
|
||||
"node",
|
||||
shape="circle",
|
||||
fixedsize="true",
|
||||
width=f"{node_diameter:.2f}",
|
||||
height=f"{node_diameter:.2f}",
|
||||
fixedsize="true",
|
||||
fontsize="10",
|
||||
fontsize="11",
|
||||
fontname="Helvetica",
|
||||
style="filled",
|
||||
fillcolor="white",
|
||||
color="#374151",
|
||||
fontcolor="#111827",
|
||||
penwidth="1.8",
|
||||
peripheries="1",
|
||||
)
|
||||
g.attr(
|
||||
"edge",
|
||||
fontname="Helvetica",
|
||||
)
|
||||
|
||||
for evt in ordered_events:
|
||||
x_pos, y_pos = positions[evt]
|
||||
g.node(evt, pos=f"{x_pos:.3f},{y_pos:.3f}!", pin="true")
|
||||
x, y = positions[evt]
|
||||
g.node(evt, label=_format_node_label(evt), pos=f"{x:.2f},{y:.2f}!", pin="true")
|
||||
|
||||
for src, dsts in evt_trans.items():
|
||||
for dst, prob in dsts.items():
|
||||
if prob > threshold:
|
||||
g.edge(src, dst, label=f"{prob:.2f}")
|
||||
edges = [
|
||||
(src, dst, prob)
|
||||
for src, dsts in evt_trans.items()
|
||||
for dst, prob in dsts.items()
|
||||
if prob > threshold
|
||||
]
|
||||
edge_set = {(src, dst) for src, dst, _ in edges}
|
||||
|
||||
for src, dst, prob in sorted(edges, key=lambda row: row[2]):
|
||||
edge_attrs: Dict[str, str] = _edge_style(prob)
|
||||
|
||||
if src == dst:
|
||||
# pick a loop port away from the main flow
|
||||
sx, sy = positions[src]
|
||||
loop_port = "n" if sy <= 0 else "s"
|
||||
edge_attrs.update({"tailport": loop_port, "headport": loop_port})
|
||||
else:
|
||||
has_reverse = (dst, src) in edge_set
|
||||
tail_port, head_port = _edge_ports(src, dst, positions, has_reverse)
|
||||
edge_attrs.update({"tailport": tail_port, "headport": head_port})
|
||||
if has_reverse:
|
||||
edge_attrs["constraint"] = "false"
|
||||
|
||||
if prob >= label_threshold or src == dst:
|
||||
edge_attrs["label"] = f" {prob:.2f} "
|
||||
|
||||
g.edge(src, dst, **edge_attrs)
|
||||
|
||||
g.render(output, view=view, cleanup=True)
|
||||
print(f"Saved MDP graph to {output}.{fmt}")
|
||||
|
||||
Reference in New Issue
Block a user