mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
updating node positinoing
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -3,7 +3,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from sim.rl.behavior_loader.loader import Loader, AgentLoader, JointLoader
|
from sim.rl.behavior_loader.loader import Loader, AgentLoader, JointLoader
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Tuple, Set
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import graphviz
|
import graphviz
|
||||||
import sys
|
import sys
|
||||||
@@ -195,6 +195,35 @@ def aggregate_event_transitions(mdp: Dict) -> Dict[str, Dict[str, float]]:
|
|||||||
return dict(evt_trans)
|
return dict(evt_trans)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_event_order(
|
||||||
|
evt_trans: Dict[str, Dict[str, float]],
|
||||||
|
event_order: Optional[List[str]] = None,
|
||||||
|
) -> List[str]:
|
||||||
|
observed = set(evt_trans.keys()) | {
|
||||||
|
dst for transitions in evt_trans.values() for dst in transitions
|
||||||
|
}
|
||||||
|
if event_order:
|
||||||
|
ordered = list(dict.fromkeys(event_order))
|
||||||
|
missing = sorted(observed - set(ordered))
|
||||||
|
return ordered + missing
|
||||||
|
return sorted(observed)
|
||||||
|
|
||||||
|
|
||||||
|
def _fixed_circle_positions(
|
||||||
|
events: List[str], radius: float
|
||||||
|
) -> Dict[str, Tuple[float, float]]:
|
||||||
|
if not events:
|
||||||
|
return {}
|
||||||
|
step = (2 * np.pi) / len(events)
|
||||||
|
return {
|
||||||
|
evt: (
|
||||||
|
float(radius * np.cos(idx * step)),
|
||||||
|
float(radius * np.sin(idx * step)),
|
||||||
|
)
|
||||||
|
for idx, evt in enumerate(events)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def visualize_mdp(
|
def visualize_mdp(
|
||||||
model: BehaviorModel,
|
model: BehaviorModel,
|
||||||
threshold: float = 0.05,
|
threshold: float = 0.05,
|
||||||
@@ -202,20 +231,31 @@ def visualize_mdp(
|
|||||||
fmt: str = "svg",
|
fmt: str = "svg",
|
||||||
view: bool = False,
|
view: bool = False,
|
||||||
export_dot: bool = False,
|
export_dot: bool = False,
|
||||||
|
event_order: Optional[List[str]] = None,
|
||||||
|
layout_radius: float = 6.0,
|
||||||
|
node_diameter: float = 2.4,
|
||||||
):
|
):
|
||||||
if not model.mdp:
|
if not model.mdp:
|
||||||
raise ValueError("build MDP first")
|
raise ValueError("build MDP first")
|
||||||
|
|
||||||
evt_trans = aggregate_event_transitions(model.mdp)
|
evt_trans = aggregate_event_transitions(model.mdp)
|
||||||
g = graphviz.Digraph(format=fmt)
|
ordered_events = _resolve_event_order(evt_trans, event_order=event_order)
|
||||||
g.attr(rankdir="LR", size="30")
|
positions = _fixed_circle_positions(ordered_events, radius=layout_radius)
|
||||||
g.attr("node", shape="circle", width="1", height="1")
|
|
||||||
|
|
||||||
events = set(evt_trans.keys()) | {
|
g = graphviz.Digraph(format=fmt, engine="neato")
|
||||||
e for trans in evt_trans.values() for e in trans.keys()
|
g.attr(overlap="false", splines="true", outputorder="edgesfirst")
|
||||||
}
|
g.attr(
|
||||||
for evt in events:
|
"node",
|
||||||
g.node(evt)
|
shape="circle",
|
||||||
|
width=f"{node_diameter:.2f}",
|
||||||
|
height=f"{node_diameter:.2f}",
|
||||||
|
fixedsize="true",
|
||||||
|
fontsize="10",
|
||||||
|
)
|
||||||
|
|
||||||
|
for evt in ordered_events:
|
||||||
|
x_pos, y_pos = positions[evt]
|
||||||
|
g.node(evt, pos=f"{x_pos:.3f},{y_pos:.3f}!", pin="true")
|
||||||
|
|
||||||
for src, dsts in evt_trans.items():
|
for src, dsts in evt_trans.items():
|
||||||
for dst, prob in dsts.items():
|
for dst, prob in dsts.items():
|
||||||
@@ -342,11 +382,6 @@ if __name__ == "__main__":
|
|||||||
f"Built MDP: {human_mdp['num_states']} states, "
|
f"Built MDP: {human_mdp['num_states']} states, "
|
||||||
f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions"
|
f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions"
|
||||||
)
|
)
|
||||||
if not human_mdp["states"]:
|
|
||||||
exit("No states found")
|
|
||||||
visualize_mdp(
|
|
||||||
human_model, threshold=0.05, output="human_mdp_viz", fmt="pdf", export_dot=True
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_model = AgentBehaviorModel(agent_dir)
|
agent_model = AgentBehaviorModel(agent_dir)
|
||||||
agent_mdp = agent_model.build_MDP()
|
agent_mdp = agent_model.build_MDP()
|
||||||
@@ -355,14 +390,35 @@ if __name__ == "__main__":
|
|||||||
f"AGENT... Built MDP: {agent_mdp['num_states']} states, "
|
f"AGENT... Built MDP: {agent_mdp['num_states']} states, "
|
||||||
f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions"
|
f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions"
|
||||||
)
|
)
|
||||||
if not agent_mdp["states"]:
|
|
||||||
exit("No states found")
|
|
||||||
visualize_mdp(
|
|
||||||
agent_model, threshold=0.05, output="agent_mdp_viz", fmt="pdf", export_dot=True
|
|
||||||
)
|
|
||||||
|
|
||||||
human_evt = aggregate_event_transitions(human_mdp)
|
human_evt = aggregate_event_transitions(human_mdp)
|
||||||
agent_evt = aggregate_event_transitions(agent_mdp)
|
agent_evt = aggregate_event_transitions(agent_mdp)
|
||||||
|
canonical_events = sorted(
|
||||||
|
(set(human_evt.keys()) | {e for tr in human_evt.values() for e in tr.keys()})
|
||||||
|
| (set(agent_evt.keys()) | {e for tr in agent_evt.values() for e in tr.keys()})
|
||||||
|
)
|
||||||
|
|
||||||
|
if not human_mdp["states"]:
|
||||||
|
exit("No states found")
|
||||||
|
visualize_mdp(
|
||||||
|
human_model,
|
||||||
|
threshold=0.05,
|
||||||
|
output="human_mdp_viz",
|
||||||
|
fmt="pdf",
|
||||||
|
export_dot=True,
|
||||||
|
event_order=canonical_events,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not agent_mdp["states"]:
|
||||||
|
exit("No states found")
|
||||||
|
visualize_mdp(
|
||||||
|
agent_model,
|
||||||
|
threshold=0.05,
|
||||||
|
output="agent_mdp_viz",
|
||||||
|
fmt="pdf",
|
||||||
|
export_dot=True,
|
||||||
|
event_order=canonical_events,
|
||||||
|
)
|
||||||
|
|
||||||
common = set(human_evt.keys()) & set(agent_evt.keys())
|
common = set(human_evt.keys()) & set(agent_evt.keys())
|
||||||
|
|
||||||
@@ -394,6 +450,7 @@ if __name__ == "__main__":
|
|||||||
output="joint_mdp_viz",
|
output="joint_mdp_viz",
|
||||||
fmt="pdf",
|
fmt="pdf",
|
||||||
export_dot=True,
|
export_dot=True,
|
||||||
|
event_order=canonical_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
inter_class_avg = float(np.mean([kl for _, kl in kl_divs]))
|
inter_class_avg = float(np.mean([kl for _, kl in kl_divs]))
|
||||||
|
|||||||
Reference in New Issue
Block a user