mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
Merge pull request #56 from velocitatem/refactor-transition-graphs
updating node positinoing
This commit is contained in:
@@ -23,7 +23,7 @@ where:
|
|||||||
The platform does not directly observe the true underlying demand function $d(p)$. Instead, it observes a behavioral proxy $\hat{q}_t$, which is a composite signal derived from the mixture of actor types. We define the demand proxy for product $i$ at epoch $t$ as a weighted aggregation of events:
|
The platform does not directly observe the true underlying demand function $d(p)$. Instead, it observes a behavioral proxy $\hat{q}_t$, which is a composite signal derived from the mixture of actor types. We define the demand proxy for product $i$ at epoch $t$ as a weighted aggregation of events:
|
||||||
\begin{equation}
|
\begin{equation}
|
||||||
\label{eq:qhat}
|
\label{eq:qhat}
|
||||||
\hat{q}_{t,i} = \sum_{s \in \mathcal{S}_t} \sum_{k=1}^{L_s} \omega(a_{s,k}) \cdot \mathds{1}[i_{s,k} = i]
|
\hat{q}_{t,i} = \sum_{s \in \mathcal{S}_t} \sum_{k=1}^{L_s} \omega(a_{s,k}) \cdot \mathbf{1}[i_{s,k} = i]
|
||||||
\end{equation}
|
\end{equation}
|
||||||
where $\omega: \mathcal{A} \to \mathbb{R}_+$ assigns weights to actions based on their signal strength regarding willingness to pay.
|
where $\omega: \mathcal{A} \to \mathbb{R}_+$ assigns weights to actions based on their signal strength regarding willingness to pay.
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
alpha,n_products,baseline_runs,defended_runs,baseline_coi_level_mean,defended_coi_level_mean,coi_preserved,coi_preserved_pct
|
||||||
|
0.0,5.0,9,10,137.060822623968,136.18680853180368,-0.874014092164316,-0.6376833842316922
|
||||||
|
0.0,25.0,9,2,137.114858903596,136.13793579187393,-0.9769231117220727,-0.7124852255501622
|
||||||
|
0.0,50.0,9,11,137.16224858153575,136.92415566181484,-0.23809291972091273,-0.17358487643878118
|
||||||
|
0.0,100.0,9,12,135.86629045322655,137.3609873086303,1.4946968554037596,1.1001234010420895
|
||||||
|
0.1,5.0,3,6,136.59581715538818,135.6308466787041,-0.9649704766840728,-0.7064421859904723
|
||||||
|
0.1,25.0,11,8,135.9860669350444,136.43616365263273,0.45009671758833747,0.33098737814318313
|
||||||
|
0.1,50.0,10,11,136.28362874897243,136.92880179422633,0.6451730452538982,0.4734046570203046
|
||||||
|
0.1,100.0,8,8,137.35578496752095,137.53394777402949,0.17816280650853855,0.12970899372797937
|
||||||
|
0.2,5.0,8,9,135.55116314329388,137.30311388107864,1.7519507377847674,1.2924645551973204
|
||||||
|
0.2,25.0,10,9,137.01587649612287,137.22137163685403,0.20549514073115915,0.1499790724887083
|
||||||
|
0.2,50.0,4,8,137.45096138958434,137.1307018163465,-0.32025957323784837,-0.2329991511155169
|
||||||
|
0.2,100.0,9,9,137.50780776750915,137.43195025898902,-0.07585750852013007,-0.0551659645744523
|
||||||
|
0.3,5.0,6,6,134.95569459599133,134.21855668602896,-0.7371379099623709,-0.5462073402453271
|
||||||
|
0.3,25.0,9,16,136.38346021911525,136.32131251342705,-0.06214770568820427,-0.04556835967378819
|
||||||
|
0.3,50.0,8,6,136.97414077213367,136.88041560990786,-0.09372516222580884,-0.06842544271310845
|
||||||
|
0.3,100.0,7,16,137.19706520314455,137.31020460277784,0.11313939963329744,0.08246488324351146
|
||||||
|
0.4,5.0,8,11,135.6494813257779,136.5487738152141,0.899292489436192,0.6629531352769695
|
||||||
|
0.4,25.0,7,9,136.38451372914378,136.10614648175604,-0.27836724738773455,-0.20410473284420322
|
||||||
|
0.4,50.0,7,10,137.12976275807247,136.98838321468799,-0.14137954338448822,-0.10309909427460566
|
||||||
|
0.4,100.0,11,8,137.4158065068933,137.4849148270489,0.06910832015560686,0.050291390715769026
|
||||||
|
0.5,5.0,7,19,135.91101413475477,136.145621134976,0.2346070002212457,0.1726180925915501
|
||||||
|
0.5,25.0,8,7,137.0972914279529,137.35620682163616,0.2589153936832531,0.18885522170896996
|
||||||
|
0.5,50.0,8,1,137.0714841014652,135.66696334266234,-1.404520758802846,-1.0246629837050352
|
||||||
|
0.5,100.0,10,8,137.4717672869487,137.35366167964338,-0.11810560730532416,-0.08591262746975456
|
||||||
|
0.6,5.0,8,13,133.13626070539635,136.09936023073067,2.9630995253343144,2.225614201296411
|
||||||
|
0.6,25.0,5,10,136.0741624588533,136.26219778039936,0.18803532154606728,0.13818591137970535
|
||||||
|
0.6,50.0,8,10,135.09036188289087,136.05846380616936,0.968101923278482,0.7166328595060871
|
||||||
|
0.6,100.0,7,8,137.29304001584052,137.07512338179083,-0.2179166340496863,-0.15872372993164377
|
||||||
|
0.7,5.0,7,7,136.0533783988379,135.14350016006424,-0.9098782387736719,-0.6687656341075052
|
||||||
|
0.7,25.0,8,11,137.12781750399415,136.8176582131797,-0.3101592908144539,-0.2261826203172962
|
||||||
|
0.7,50.0,14,11,137.06965735909125,136.7028634119364,-0.3667939471548607,-0.26759674914335285
|
||||||
|
0.7,100.0,11,11,137.48279078937205,137.09121810549402,-0.39157268387802446,-0.28481578067317975
|
||||||
|
0.8,5.0,4,7,135.3095773096514,136.59715728802078,1.2875799783693935,0.9515808148766959
|
||||||
|
0.8,25.0,12,13,136.93488398652164,135.73319876476054,-1.201685221761096,-0.8775596011600497
|
||||||
|
0.8,50.0,6,8,136.4704324290659,136.86568018140107,0.39524775233516607,0.289621528487943
|
||||||
|
0.8,100.0,4,11,137.519864039095,137.4763376137669,-0.04352642532811046,-0.03165100957032396
|
||||||
|
0.9,5.0,5,5,134.77024204025943,136.6651608019597,1.8949187617002679,1.4060364758669837
|
||||||
|
0.9,25.0,9,13,136.7554042236364,136.06108143100832,-0.6943227926280713,-0.507711411164888
|
||||||
|
0.9,50.0,10,12,136.08715955450202,137.07569864767092,0.988539093168896,0.7264014447836223
|
||||||
|
0.9,100.0,11,9,137.57053132642514,137.30115968842037,-0.2693716380047704,-0.19580620602940735
|
||||||
|
1.0,5.0,5,7,136.43177888041947,135.92674388998284,-0.5050349904366271,-0.37017401266847305
|
||||||
|
1.0,25.0,11,9,136.7037183889911,136.22617845471228,-0.47753993427880914,-0.34932475861407586
|
||||||
|
1.0,50.0,11,5,136.93074105866745,137.05826644845806,0.12752538979060546,0.09313130769953819
|
||||||
|
1.0,100.0,8,9,136.4880191421812,137.41913068956546,0.9311115473842619,0.682192879079234
|
||||||
|
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
|
|||||||
|
\includegraphics[width=0.98\linewidth]{chapters/figures/results/generated/final/plots/final_focus_coi_by_alpha.pdf}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
\includegraphics[width=0.98\linewidth]{chapters/figures/results/generated/final/plots/final_focus_coi_preservation_grid.pdf}
|
||||||
Binary file not shown.
Binary file not shown.
@@ -4,15 +4,34 @@ set -euo pipefail
|
|||||||
|
|
||||||
cmd="${1:-}"
|
cmd="${1:-}"
|
||||||
|
|
||||||
|
sync_mdp_figures() {
|
||||||
|
local script_dir project_root sim_dir chapters_dir
|
||||||
|
script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
project_root="$(cd "$script_dir/.." && pwd)"
|
||||||
|
sim_dir="$project_root/sim/rl/behavior_loader"
|
||||||
|
chapters_dir="$project_root/paper/src/chapters"
|
||||||
|
|
||||||
|
printf '%s\n' 'Refreshing MDP figures for paper...'
|
||||||
|
(
|
||||||
|
cd "$sim_dir"
|
||||||
|
python models.py
|
||||||
|
)
|
||||||
|
|
||||||
|
cp "$sim_dir/human_mdp_viz.pdf" "$chapters_dir/mdp_human.pdf"
|
||||||
|
cp "$sim_dir/agent_mdp_viz.pdf" "$chapters_dir/mdp_agent.pdf"
|
||||||
|
}
|
||||||
|
|
||||||
case "$cmd" in
|
case "$cmd" in
|
||||||
build)
|
build)
|
||||||
mkdir -p paper/build
|
mkdir -p paper/build
|
||||||
|
sync_mdp_figures
|
||||||
bash paper/concat_code.sh
|
bash paper/concat_code.sh
|
||||||
cd paper/src
|
cd paper/src
|
||||||
latexmk -pdf -jobname=main -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main.tex
|
latexmk -pdf -jobname=main -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main.tex
|
||||||
;;
|
;;
|
||||||
watch)
|
watch)
|
||||||
mkdir -p paper/build
|
mkdir -p paper/build
|
||||||
|
sync_mdp_figures
|
||||||
cd paper/src
|
cd paper/src
|
||||||
latexmk -pvc -pdf -jobname=main -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main.tex
|
latexmk -pvc -pdf -jobname=main -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main.tex
|
||||||
;;
|
;;
|
||||||
@@ -33,11 +52,13 @@ case "$cmd" in
|
|||||||
;;
|
;;
|
||||||
build-genpop)
|
build-genpop)
|
||||||
mkdir -p paper/build
|
mkdir -p paper/build
|
||||||
|
sync_mdp_figures
|
||||||
cd paper/src
|
cd paper/src
|
||||||
latexmk -pdf -jobname=main-genpop -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main-genpop.tex
|
latexmk -pdf -jobname=main-genpop -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main-genpop.tex
|
||||||
;;
|
;;
|
||||||
watch-genpop)
|
watch-genpop)
|
||||||
mkdir -p paper/build
|
mkdir -p paper/build
|
||||||
|
sync_mdp_figures
|
||||||
cd paper/src
|
cd paper/src
|
||||||
latexmk -pvc -pdf -jobname=main-genpop -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main-genpop.tex
|
latexmk -pvc -pdf -jobname=main-genpop -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main-genpop.tex
|
||||||
;;
|
;;
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from sim.rl.behavior_loader.loader import Loader, AgentLoader, JointLoader
|
from sim.rl.behavior_loader.loader import Loader, AgentLoader, JointLoader
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Tuple, Set
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import graphviz
|
import graphviz
|
||||||
import sys
|
import sys
|
||||||
@@ -195,6 +195,110 @@ def aggregate_event_transitions(mdp: Dict) -> Dict[str, Dict[str, float]]:
|
|||||||
return dict(evt_trans)
|
return dict(evt_trans)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_event_order(
|
||||||
|
evt_trans: Dict[str, Dict[str, float]],
|
||||||
|
event_order: Optional[List[str]] = None,
|
||||||
|
) -> List[str]:
|
||||||
|
observed = set(evt_trans.keys()) | {
|
||||||
|
dst for transitions in evt_trans.values() for dst in transitions
|
||||||
|
}
|
||||||
|
if event_order:
|
||||||
|
ordered = list(dict.fromkeys(event_order))
|
||||||
|
missing = sorted(observed - set(ordered))
|
||||||
|
return ordered + missing
|
||||||
|
return sorted(observed)
|
||||||
|
|
||||||
|
|
||||||
|
def _compass_from_angle(angle_rad: float) -> str:
|
||||||
|
ports = ("e", "ne", "n", "nw", "w", "sw", "s", "se")
|
||||||
|
normalized = (angle_rad + (2 * np.pi)) % (2 * np.pi)
|
||||||
|
step = np.pi / 4
|
||||||
|
idx = int(np.round(normalized / step)) % len(ports)
|
||||||
|
return ports[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def _edge_ports(
|
||||||
|
src: str,
|
||||||
|
dst: str,
|
||||||
|
positions: Dict[str, Tuple[float, float]],
|
||||||
|
has_reverse: bool,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
src_x, src_y = positions[src]
|
||||||
|
dst_x, dst_y = positions[dst]
|
||||||
|
angle = float(np.arctan2(dst_y - src_y, dst_x - src_x))
|
||||||
|
|
||||||
|
if has_reverse:
|
||||||
|
bend = np.pi / 10
|
||||||
|
angle += bend if src < dst else -bend
|
||||||
|
|
||||||
|
tail_port = _compass_from_angle(angle)
|
||||||
|
head_port = _compass_from_angle(angle + np.pi)
|
||||||
|
return tail_port, head_port
|
||||||
|
|
||||||
|
|
||||||
|
def _edge_style(prob: float) -> Dict[str, str]:
|
||||||
|
if prob >= 0.75:
|
||||||
|
edge_color = "#111827"
|
||||||
|
elif prob >= 0.50:
|
||||||
|
edge_color = "#374151"
|
||||||
|
elif prob >= 0.25:
|
||||||
|
edge_color = "#6b7280"
|
||||||
|
else:
|
||||||
|
edge_color = "#9ca3af"
|
||||||
|
return {
|
||||||
|
"color": edge_color,
|
||||||
|
"fontcolor": "#111827",
|
||||||
|
"fontsize": "10",
|
||||||
|
"penwidth": f"{0.9 + 3.6 * prob:.2f}",
|
||||||
|
"arrowsize": f"{0.55 + 0.55 * prob:.2f}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _format_node_label(evt: str) -> str:
|
||||||
|
max_line_len = 16
|
||||||
|
tokens = evt.split("_")
|
||||||
|
if len(tokens) == 1:
|
||||||
|
return evt
|
||||||
|
|
||||||
|
lines: List[str] = []
|
||||||
|
curr = ""
|
||||||
|
for token in tokens:
|
||||||
|
piece = token if not curr else f"_{token}"
|
||||||
|
if curr and len(curr) + len(piece) > max_line_len:
|
||||||
|
lines.append(curr)
|
||||||
|
curr = token
|
||||||
|
else:
|
||||||
|
curr = f"{curr}{piece}" if curr else token
|
||||||
|
if curr:
|
||||||
|
lines.append(curr)
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_flow_positions(
|
||||||
|
events: List[str],
|
||||||
|
layout_radius: float,
|
||||||
|
) -> Dict[str, Tuple[float, float]]:
|
||||||
|
"""Balanced grid layout for paper-friendly diagrams."""
|
||||||
|
if not events:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
num_events = len(events)
|
||||||
|
cols = int(np.ceil(np.sqrt(num_events)))
|
||||||
|
rows = int(np.ceil(num_events / cols))
|
||||||
|
x_step = max(layout_radius * 1.10, 3.6)
|
||||||
|
y_step = max(layout_radius * 0.95, 3.2)
|
||||||
|
|
||||||
|
positions: Dict[str, Tuple[float, float]] = {}
|
||||||
|
for idx, evt in enumerate(events):
|
||||||
|
row = idx // cols
|
||||||
|
col = idx % cols
|
||||||
|
x = (col - (cols - 1) / 2.0) * x_step
|
||||||
|
y = ((rows - 1) / 2.0 - row) * y_step
|
||||||
|
positions[evt] = (float(x), float(y))
|
||||||
|
|
||||||
|
return positions
|
||||||
|
|
||||||
|
|
||||||
def visualize_mdp(
|
def visualize_mdp(
|
||||||
model: BehaviorModel,
|
model: BehaviorModel,
|
||||||
threshold: float = 0.05,
|
threshold: float = 0.05,
|
||||||
@@ -202,25 +306,80 @@ def visualize_mdp(
|
|||||||
fmt: str = "svg",
|
fmt: str = "svg",
|
||||||
view: bool = False,
|
view: bool = False,
|
||||||
export_dot: bool = False,
|
export_dot: bool = False,
|
||||||
|
event_order: Optional[List[str]] = None,
|
||||||
|
layout_radius: float = 10.0,
|
||||||
|
node_diameter: float = 1.8,
|
||||||
|
label_threshold: float = 0.08,
|
||||||
):
|
):
|
||||||
if not model.mdp:
|
if not model.mdp:
|
||||||
raise ValueError("build MDP first")
|
raise ValueError("build MDP first")
|
||||||
|
|
||||||
evt_trans = aggregate_event_transitions(model.mdp)
|
evt_trans = aggregate_event_transitions(model.mdp)
|
||||||
g = graphviz.Digraph(format=fmt)
|
ordered_events = _resolve_event_order(evt_trans, event_order=event_order)
|
||||||
g.attr(rankdir="LR", size="30")
|
positions = _compute_flow_positions(ordered_events, layout_radius=layout_radius)
|
||||||
g.attr("node", shape="circle", width="1", height="1")
|
|
||||||
|
|
||||||
events = set(evt_trans.keys()) | {
|
g = graphviz.Digraph(format=fmt, engine="neato")
|
||||||
e for trans in evt_trans.values() for e in trans.keys()
|
g.attr(
|
||||||
}
|
overlap="false",
|
||||||
for evt in events:
|
splines="true",
|
||||||
g.node(evt)
|
outputorder="edgesfirst",
|
||||||
|
pad="0.5",
|
||||||
|
sep="+9",
|
||||||
|
esep="+4",
|
||||||
|
bgcolor="white",
|
||||||
|
dpi="180",
|
||||||
|
)
|
||||||
|
g.attr(
|
||||||
|
"node",
|
||||||
|
shape="circle",
|
||||||
|
fixedsize="true",
|
||||||
|
width=f"{node_diameter:.2f}",
|
||||||
|
height=f"{node_diameter:.2f}",
|
||||||
|
fontsize="11",
|
||||||
|
fontname="Helvetica",
|
||||||
|
style="filled",
|
||||||
|
fillcolor="white",
|
||||||
|
color="#374151",
|
||||||
|
fontcolor="#111827",
|
||||||
|
penwidth="1.8",
|
||||||
|
peripheries="1",
|
||||||
|
)
|
||||||
|
g.attr(
|
||||||
|
"edge",
|
||||||
|
fontname="Helvetica",
|
||||||
|
)
|
||||||
|
|
||||||
for src, dsts in evt_trans.items():
|
for evt in ordered_events:
|
||||||
for dst, prob in dsts.items():
|
x, y = positions[evt]
|
||||||
if prob > threshold:
|
g.node(evt, label=_format_node_label(evt), pos=f"{x:.2f},{y:.2f}!", pin="true")
|
||||||
g.edge(src, dst, label=f"{prob:.2f}")
|
|
||||||
|
edges = [
|
||||||
|
(src, dst, prob)
|
||||||
|
for src, dsts in evt_trans.items()
|
||||||
|
for dst, prob in dsts.items()
|
||||||
|
if prob > threshold
|
||||||
|
]
|
||||||
|
edge_set = {(src, dst) for src, dst, _ in edges}
|
||||||
|
|
||||||
|
for src, dst, prob in sorted(edges, key=lambda row: row[2]):
|
||||||
|
edge_attrs: Dict[str, str] = _edge_style(prob)
|
||||||
|
|
||||||
|
if src == dst:
|
||||||
|
# pick a loop port away from the main flow
|
||||||
|
sx, sy = positions[src]
|
||||||
|
loop_port = "n" if sy <= 0 else "s"
|
||||||
|
edge_attrs.update({"tailport": loop_port, "headport": loop_port})
|
||||||
|
else:
|
||||||
|
has_reverse = (dst, src) in edge_set
|
||||||
|
tail_port, head_port = _edge_ports(src, dst, positions, has_reverse)
|
||||||
|
edge_attrs.update({"tailport": tail_port, "headport": head_port})
|
||||||
|
if has_reverse:
|
||||||
|
edge_attrs["constraint"] = "false"
|
||||||
|
|
||||||
|
if prob >= label_threshold or src == dst:
|
||||||
|
edge_attrs["label"] = f" {prob:.2f} "
|
||||||
|
|
||||||
|
g.edge(src, dst, **edge_attrs)
|
||||||
|
|
||||||
g.render(output, view=view, cleanup=True)
|
g.render(output, view=view, cleanup=True)
|
||||||
print(f"Saved MDP graph to {output}.{fmt}")
|
print(f"Saved MDP graph to {output}.{fmt}")
|
||||||
@@ -342,11 +501,6 @@ if __name__ == "__main__":
|
|||||||
f"Built MDP: {human_mdp['num_states']} states, "
|
f"Built MDP: {human_mdp['num_states']} states, "
|
||||||
f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions"
|
f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions"
|
||||||
)
|
)
|
||||||
if not human_mdp["states"]:
|
|
||||||
exit("No states found")
|
|
||||||
visualize_mdp(
|
|
||||||
human_model, threshold=0.05, output="human_mdp_viz", fmt="pdf", export_dot=True
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_model = AgentBehaviorModel(agent_dir)
|
agent_model = AgentBehaviorModel(agent_dir)
|
||||||
agent_mdp = agent_model.build_MDP()
|
agent_mdp = agent_model.build_MDP()
|
||||||
@@ -355,14 +509,35 @@ if __name__ == "__main__":
|
|||||||
f"AGENT... Built MDP: {agent_mdp['num_states']} states, "
|
f"AGENT... Built MDP: {agent_mdp['num_states']} states, "
|
||||||
f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions"
|
f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions"
|
||||||
)
|
)
|
||||||
if not agent_mdp["states"]:
|
|
||||||
exit("No states found")
|
|
||||||
visualize_mdp(
|
|
||||||
agent_model, threshold=0.05, output="agent_mdp_viz", fmt="pdf", export_dot=True
|
|
||||||
)
|
|
||||||
|
|
||||||
human_evt = aggregate_event_transitions(human_mdp)
|
human_evt = aggregate_event_transitions(human_mdp)
|
||||||
agent_evt = aggregate_event_transitions(agent_mdp)
|
agent_evt = aggregate_event_transitions(agent_mdp)
|
||||||
|
canonical_events = sorted(
|
||||||
|
(set(human_evt.keys()) | {e for tr in human_evt.values() for e in tr.keys()})
|
||||||
|
| (set(agent_evt.keys()) | {e for tr in agent_evt.values() for e in tr.keys()})
|
||||||
|
)
|
||||||
|
|
||||||
|
if not human_mdp["states"]:
|
||||||
|
exit("No states found")
|
||||||
|
visualize_mdp(
|
||||||
|
human_model,
|
||||||
|
threshold=0.05,
|
||||||
|
output="human_mdp_viz",
|
||||||
|
fmt="pdf",
|
||||||
|
export_dot=True,
|
||||||
|
event_order=canonical_events,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not agent_mdp["states"]:
|
||||||
|
exit("No states found")
|
||||||
|
visualize_mdp(
|
||||||
|
agent_model,
|
||||||
|
threshold=0.05,
|
||||||
|
output="agent_mdp_viz",
|
||||||
|
fmt="pdf",
|
||||||
|
export_dot=True,
|
||||||
|
event_order=canonical_events,
|
||||||
|
)
|
||||||
|
|
||||||
common = set(human_evt.keys()) & set(agent_evt.keys())
|
common = set(human_evt.keys()) & set(agent_evt.keys())
|
||||||
|
|
||||||
@@ -394,6 +569,7 @@ if __name__ == "__main__":
|
|||||||
output="joint_mdp_viz",
|
output="joint_mdp_viz",
|
||||||
fmt="pdf",
|
fmt="pdf",
|
||||||
export_dot=True,
|
export_dot=True,
|
||||||
|
event_order=canonical_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
inter_class_avg = float(np.mean([kl for _, kl in kl_divs]))
|
inter_class_avg = float(np.mean([kl for _, kl in kl_divs]))
|
||||||
|
|||||||
Reference in New Issue
Block a user