mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
updating node positinoing
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -3,7 +3,7 @@ try:
|
||||
except ImportError:
|
||||
from sim.rl.behavior_loader.loader import Loader, AgentLoader, JointLoader
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple, Set
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
import numpy as np
|
||||
import graphviz
|
||||
import sys
|
||||
@@ -195,6 +195,35 @@ def aggregate_event_transitions(mdp: Dict) -> Dict[str, Dict[str, float]]:
|
||||
return dict(evt_trans)
|
||||
|
||||
|
||||
def _resolve_event_order(
|
||||
evt_trans: Dict[str, Dict[str, float]],
|
||||
event_order: Optional[List[str]] = None,
|
||||
) -> List[str]:
|
||||
observed = set(evt_trans.keys()) | {
|
||||
dst for transitions in evt_trans.values() for dst in transitions
|
||||
}
|
||||
if event_order:
|
||||
ordered = list(dict.fromkeys(event_order))
|
||||
missing = sorted(observed - set(ordered))
|
||||
return ordered + missing
|
||||
return sorted(observed)
|
||||
|
||||
|
||||
def _fixed_circle_positions(
|
||||
events: List[str], radius: float
|
||||
) -> Dict[str, Tuple[float, float]]:
|
||||
if not events:
|
||||
return {}
|
||||
step = (2 * np.pi) / len(events)
|
||||
return {
|
||||
evt: (
|
||||
float(radius * np.cos(idx * step)),
|
||||
float(radius * np.sin(idx * step)),
|
||||
)
|
||||
for idx, evt in enumerate(events)
|
||||
}
|
||||
|
||||
|
||||
def visualize_mdp(
|
||||
model: BehaviorModel,
|
||||
threshold: float = 0.05,
|
||||
@@ -202,20 +231,31 @@ def visualize_mdp(
|
||||
fmt: str = "svg",
|
||||
view: bool = False,
|
||||
export_dot: bool = False,
|
||||
event_order: Optional[List[str]] = None,
|
||||
layout_radius: float = 6.0,
|
||||
node_diameter: float = 2.4,
|
||||
):
|
||||
if not model.mdp:
|
||||
raise ValueError("build MDP first")
|
||||
|
||||
evt_trans = aggregate_event_transitions(model.mdp)
|
||||
g = graphviz.Digraph(format=fmt)
|
||||
g.attr(rankdir="LR", size="30")
|
||||
g.attr("node", shape="circle", width="1", height="1")
|
||||
ordered_events = _resolve_event_order(evt_trans, event_order=event_order)
|
||||
positions = _fixed_circle_positions(ordered_events, radius=layout_radius)
|
||||
|
||||
events = set(evt_trans.keys()) | {
|
||||
e for trans in evt_trans.values() for e in trans.keys()
|
||||
}
|
||||
for evt in events:
|
||||
g.node(evt)
|
||||
g = graphviz.Digraph(format=fmt, engine="neato")
|
||||
g.attr(overlap="false", splines="true", outputorder="edgesfirst")
|
||||
g.attr(
|
||||
"node",
|
||||
shape="circle",
|
||||
width=f"{node_diameter:.2f}",
|
||||
height=f"{node_diameter:.2f}",
|
||||
fixedsize="true",
|
||||
fontsize="10",
|
||||
)
|
||||
|
||||
for evt in ordered_events:
|
||||
x_pos, y_pos = positions[evt]
|
||||
g.node(evt, pos=f"{x_pos:.3f},{y_pos:.3f}!", pin="true")
|
||||
|
||||
for src, dsts in evt_trans.items():
|
||||
for dst, prob in dsts.items():
|
||||
@@ -342,11 +382,6 @@ if __name__ == "__main__":
|
||||
f"Built MDP: {human_mdp['num_states']} states, "
|
||||
f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions"
|
||||
)
|
||||
if not human_mdp["states"]:
|
||||
exit("No states found")
|
||||
visualize_mdp(
|
||||
human_model, threshold=0.05, output="human_mdp_viz", fmt="pdf", export_dot=True
|
||||
)
|
||||
|
||||
agent_model = AgentBehaviorModel(agent_dir)
|
||||
agent_mdp = agent_model.build_MDP()
|
||||
@@ -355,14 +390,35 @@ if __name__ == "__main__":
|
||||
f"AGENT... Built MDP: {agent_mdp['num_states']} states, "
|
||||
f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions"
|
||||
)
|
||||
if not agent_mdp["states"]:
|
||||
exit("No states found")
|
||||
visualize_mdp(
|
||||
agent_model, threshold=0.05, output="agent_mdp_viz", fmt="pdf", export_dot=True
|
||||
)
|
||||
|
||||
human_evt = aggregate_event_transitions(human_mdp)
|
||||
agent_evt = aggregate_event_transitions(agent_mdp)
|
||||
canonical_events = sorted(
|
||||
(set(human_evt.keys()) | {e for tr in human_evt.values() for e in tr.keys()})
|
||||
| (set(agent_evt.keys()) | {e for tr in agent_evt.values() for e in tr.keys()})
|
||||
)
|
||||
|
||||
if not human_mdp["states"]:
|
||||
exit("No states found")
|
||||
visualize_mdp(
|
||||
human_model,
|
||||
threshold=0.05,
|
||||
output="human_mdp_viz",
|
||||
fmt="pdf",
|
||||
export_dot=True,
|
||||
event_order=canonical_events,
|
||||
)
|
||||
|
||||
if not agent_mdp["states"]:
|
||||
exit("No states found")
|
||||
visualize_mdp(
|
||||
agent_model,
|
||||
threshold=0.05,
|
||||
output="agent_mdp_viz",
|
||||
fmt="pdf",
|
||||
export_dot=True,
|
||||
event_order=canonical_events,
|
||||
)
|
||||
|
||||
common = set(human_evt.keys()) & set(agent_evt.keys())
|
||||
|
||||
@@ -394,6 +450,7 @@ if __name__ == "__main__":
|
||||
output="joint_mdp_viz",
|
||||
fmt="pdf",
|
||||
export_dot=True,
|
||||
event_order=canonical_events,
|
||||
)
|
||||
|
||||
inter_class_avg = float(np.mean([kl for _, kl in kl_divs]))
|
||||
|
||||
Reference in New Issue
Block a user