Merge pull request #56 from velocitatem/refactor-transition-graphs

updating node positinoing
This commit is contained in:
Daniel Alves Rösel
2026-03-28 13:10:29 +01:00
committed by GitHub
10 changed files with 268 additions and 24 deletions

View File

@@ -23,7 +23,7 @@ where:
The platform does not directly observe the true underlying demand function $d(p)$. Instead, it observes a behavioral proxy $\hat{q}_t$, which is a composite signal derived from the mixture of actor types. We define the demand proxy for product $i$ at epoch $t$ as a weighted aggregation of events: The platform does not directly observe the true underlying demand function $d(p)$. Instead, it observes a behavioral proxy $\hat{q}_t$, which is a composite signal derived from the mixture of actor types. We define the demand proxy for product $i$ at epoch $t$ as a weighted aggregation of events:
\begin{equation} \begin{equation}
\label{eq:qhat} \label{eq:qhat}
\hat{q}_{t,i} = \sum_{s \in \mathcal{S}_t} \sum_{k=1}^{L_s} \omega(a_{s,k}) \cdot \mathds{1}[i_{s,k} = i] \hat{q}_{t,i} = \sum_{s \in \mathcal{S}_t} \sum_{k=1}^{L_s} \omega(a_{s,k}) \cdot \mathbf{1}[i_{s,k} = i]
\end{equation} \end{equation}
where $\omega: \mathcal{A} \to \mathbb{R}_+$ assigns weights to actions based on their signal strength regarding willingness to pay. where $\omega: \mathcal{A} \to \mathbb{R}_+$ assigns weights to actions based on their signal strength regarding willingness to pay.

View File

@@ -0,0 +1,45 @@
alpha,n_products,baseline_runs,defended_runs,baseline_coi_level_mean,defended_coi_level_mean,coi_preserved,coi_preserved_pct
0.0,5.0,9,10,137.060822623968,136.18680853180368,-0.874014092164316,-0.6376833842316922
0.0,25.0,9,2,137.114858903596,136.13793579187393,-0.9769231117220727,-0.7124852255501622
0.0,50.0,9,11,137.16224858153575,136.92415566181484,-0.23809291972091273,-0.17358487643878118
0.0,100.0,9,12,135.86629045322655,137.3609873086303,1.4946968554037596,1.1001234010420895
0.1,5.0,3,6,136.59581715538818,135.6308466787041,-0.9649704766840728,-0.7064421859904723
0.1,25.0,11,8,135.9860669350444,136.43616365263273,0.45009671758833747,0.33098737814318313
0.1,50.0,10,11,136.28362874897243,136.92880179422633,0.6451730452538982,0.4734046570203046
0.1,100.0,8,8,137.35578496752095,137.53394777402949,0.17816280650853855,0.12970899372797937
0.2,5.0,8,9,135.55116314329388,137.30311388107864,1.7519507377847674,1.2924645551973204
0.2,25.0,10,9,137.01587649612287,137.22137163685403,0.20549514073115915,0.1499790724887083
0.2,50.0,4,8,137.45096138958434,137.1307018163465,-0.32025957323784837,-0.2329991511155169
0.2,100.0,9,9,137.50780776750915,137.43195025898902,-0.07585750852013007,-0.0551659645744523
0.3,5.0,6,6,134.95569459599133,134.21855668602896,-0.7371379099623709,-0.5462073402453271
0.3,25.0,9,16,136.38346021911525,136.32131251342705,-0.06214770568820427,-0.04556835967378819
0.3,50.0,8,6,136.97414077213367,136.88041560990786,-0.09372516222580884,-0.06842544271310845
0.3,100.0,7,16,137.19706520314455,137.31020460277784,0.11313939963329744,0.08246488324351146
0.4,5.0,8,11,135.6494813257779,136.5487738152141,0.899292489436192,0.6629531352769695
0.4,25.0,7,9,136.38451372914378,136.10614648175604,-0.27836724738773455,-0.20410473284420322
0.4,50.0,7,10,137.12976275807247,136.98838321468799,-0.14137954338448822,-0.10309909427460566
0.4,100.0,11,8,137.4158065068933,137.4849148270489,0.06910832015560686,0.050291390715769026
0.5,5.0,7,19,135.91101413475477,136.145621134976,0.2346070002212457,0.1726180925915501
0.5,25.0,8,7,137.0972914279529,137.35620682163616,0.2589153936832531,0.18885522170896996
0.5,50.0,8,1,137.0714841014652,135.66696334266234,-1.404520758802846,-1.0246629837050352
0.5,100.0,10,8,137.4717672869487,137.35366167964338,-0.11810560730532416,-0.08591262746975456
0.6,5.0,8,13,133.13626070539635,136.09936023073067,2.9630995253343144,2.225614201296411
0.6,25.0,5,10,136.0741624588533,136.26219778039936,0.18803532154606728,0.13818591137970535
0.6,50.0,8,10,135.09036188289087,136.05846380616936,0.968101923278482,0.7166328595060871
0.6,100.0,7,8,137.29304001584052,137.07512338179083,-0.2179166340496863,-0.15872372993164377
0.7,5.0,7,7,136.0533783988379,135.14350016006424,-0.9098782387736719,-0.6687656341075052
0.7,25.0,8,11,137.12781750399415,136.8176582131797,-0.3101592908144539,-0.2261826203172962
0.7,50.0,14,11,137.06965735909125,136.7028634119364,-0.3667939471548607,-0.26759674914335285
0.7,100.0,11,11,137.48279078937205,137.09121810549402,-0.39157268387802446,-0.28481578067317975
0.8,5.0,4,7,135.3095773096514,136.59715728802078,1.2875799783693935,0.9515808148766959
0.8,25.0,12,13,136.93488398652164,135.73319876476054,-1.201685221761096,-0.8775596011600497
0.8,50.0,6,8,136.4704324290659,136.86568018140107,0.39524775233516607,0.289621528487943
0.8,100.0,4,11,137.519864039095,137.4763376137669,-0.04352642532811046,-0.03165100957032396
0.9,5.0,5,5,134.77024204025943,136.6651608019597,1.8949187617002679,1.4060364758669837
0.9,25.0,9,13,136.7554042236364,136.06108143100832,-0.6943227926280713,-0.507711411164888
0.9,50.0,10,12,136.08715955450202,137.07569864767092,0.988539093168896,0.7264014447836223
0.9,100.0,11,9,137.57053132642514,137.30115968842037,-0.2693716380047704,-0.19580620602940735
1.0,5.0,5,7,136.43177888041947,135.92674388998284,-0.5050349904366271,-0.37017401266847305
1.0,25.0,11,9,136.7037183889911,136.22617845471228,-0.47753993427880914,-0.34932475861407586
1.0,50.0,11,5,136.93074105866745,137.05826644845806,0.12752538979060546,0.09313130769953819
1.0,100.0,8,9,136.4880191421812,137.41913068956546,0.9311115473842619,0.682192879079234
1 alpha n_products baseline_runs defended_runs baseline_coi_level_mean defended_coi_level_mean coi_preserved coi_preserved_pct
2 0.0 5.0 9 10 137.060822623968 136.18680853180368 -0.874014092164316 -0.6376833842316922
3 0.0 25.0 9 2 137.114858903596 136.13793579187393 -0.9769231117220727 -0.7124852255501622
4 0.0 50.0 9 11 137.16224858153575 136.92415566181484 -0.23809291972091273 -0.17358487643878118
5 0.0 100.0 9 12 135.86629045322655 137.3609873086303 1.4946968554037596 1.1001234010420895
6 0.1 5.0 3 6 136.59581715538818 135.6308466787041 -0.9649704766840728 -0.7064421859904723
7 0.1 25.0 11 8 135.9860669350444 136.43616365263273 0.45009671758833747 0.33098737814318313
8 0.1 50.0 10 11 136.28362874897243 136.92880179422633 0.6451730452538982 0.4734046570203046
9 0.1 100.0 8 8 137.35578496752095 137.53394777402949 0.17816280650853855 0.12970899372797937
10 0.2 5.0 8 9 135.55116314329388 137.30311388107864 1.7519507377847674 1.2924645551973204
11 0.2 25.0 10 9 137.01587649612287 137.22137163685403 0.20549514073115915 0.1499790724887083
12 0.2 50.0 4 8 137.45096138958434 137.1307018163465 -0.32025957323784837 -0.2329991511155169
13 0.2 100.0 9 9 137.50780776750915 137.43195025898902 -0.07585750852013007 -0.0551659645744523
14 0.3 5.0 6 6 134.95569459599133 134.21855668602896 -0.7371379099623709 -0.5462073402453271
15 0.3 25.0 9 16 136.38346021911525 136.32131251342705 -0.06214770568820427 -0.04556835967378819
16 0.3 50.0 8 6 136.97414077213367 136.88041560990786 -0.09372516222580884 -0.06842544271310845
17 0.3 100.0 7 16 137.19706520314455 137.31020460277784 0.11313939963329744 0.08246488324351146
18 0.4 5.0 8 11 135.6494813257779 136.5487738152141 0.899292489436192 0.6629531352769695
19 0.4 25.0 7 9 136.38451372914378 136.10614648175604 -0.27836724738773455 -0.20410473284420322
20 0.4 50.0 7 10 137.12976275807247 136.98838321468799 -0.14137954338448822 -0.10309909427460566
21 0.4 100.0 11 8 137.4158065068933 137.4849148270489 0.06910832015560686 0.050291390715769026
22 0.5 5.0 7 19 135.91101413475477 136.145621134976 0.2346070002212457 0.1726180925915501
23 0.5 25.0 8 7 137.0972914279529 137.35620682163616 0.2589153936832531 0.18885522170896996
24 0.5 50.0 8 1 137.0714841014652 135.66696334266234 -1.404520758802846 -1.0246629837050352
25 0.5 100.0 10 8 137.4717672869487 137.35366167964338 -0.11810560730532416 -0.08591262746975456
26 0.6 5.0 8 13 133.13626070539635 136.09936023073067 2.9630995253343144 2.225614201296411
27 0.6 25.0 5 10 136.0741624588533 136.26219778039936 0.18803532154606728 0.13818591137970535
28 0.6 50.0 8 10 135.09036188289087 136.05846380616936 0.968101923278482 0.7166328595060871
29 0.6 100.0 7 8 137.29304001584052 137.07512338179083 -0.2179166340496863 -0.15872372993164377
30 0.7 5.0 7 7 136.0533783988379 135.14350016006424 -0.9098782387736719 -0.6687656341075052
31 0.7 25.0 8 11 137.12781750399415 136.8176582131797 -0.3101592908144539 -0.2261826203172962
32 0.7 50.0 14 11 137.06965735909125 136.7028634119364 -0.3667939471548607 -0.26759674914335285
33 0.7 100.0 11 11 137.48279078937205 137.09121810549402 -0.39157268387802446 -0.28481578067317975
34 0.8 5.0 4 7 135.3095773096514 136.59715728802078 1.2875799783693935 0.9515808148766959
35 0.8 25.0 12 13 136.93488398652164 135.73319876476054 -1.201685221761096 -0.8775596011600497
36 0.8 50.0 6 8 136.4704324290659 136.86568018140107 0.39524775233516607 0.289621528487943
37 0.8 100.0 4 11 137.519864039095 137.4763376137669 -0.04352642532811046 -0.03165100957032396
38 0.9 5.0 5 5 134.77024204025943 136.6651608019597 1.8949187617002679 1.4060364758669837
39 0.9 25.0 9 13 136.7554042236364 136.06108143100832 -0.6943227926280713 -0.507711411164888
40 0.9 50.0 10 12 136.08715955450202 137.07569864767092 0.988539093168896 0.7264014447836223
41 0.9 100.0 11 9 137.57053132642514 137.30115968842037 -0.2693716380047704 -0.19580620602940735
42 1.0 5.0 5 7 136.43177888041947 135.92674388998284 -0.5050349904366271 -0.37017401266847305
43 1.0 25.0 11 9 136.7037183889911 136.22617845471228 -0.47753993427880914 -0.34932475861407586
44 1.0 50.0 11 5 136.93074105866745 137.05826644845806 0.12752538979060546 0.09313130769953819
45 1.0 100.0 8 9 136.4880191421812 137.41913068956546 0.9311115473842619 0.682192879079234

View File

@@ -0,0 +1 @@
\includegraphics[width=0.98\linewidth]{chapters/figures/results/generated/final/plots/final_focus_coi_by_alpha.pdf}

View File

@@ -0,0 +1 @@
\includegraphics[width=0.98\linewidth]{chapters/figures/results/generated/final/plots/final_focus_coi_preservation_grid.pdf}

Binary file not shown.

Binary file not shown.

View File

@@ -4,15 +4,34 @@ set -euo pipefail
cmd="${1:-}" cmd="${1:-}"
sync_mdp_figures() {
local script_dir project_root sim_dir chapters_dir
script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
project_root="$(cd "$script_dir/.." && pwd)"
sim_dir="$project_root/sim/rl/behavior_loader"
chapters_dir="$project_root/paper/src/chapters"
printf '%s\n' 'Refreshing MDP figures for paper...'
(
cd "$sim_dir"
python models.py
)
cp "$sim_dir/human_mdp_viz.pdf" "$chapters_dir/mdp_human.pdf"
cp "$sim_dir/agent_mdp_viz.pdf" "$chapters_dir/mdp_agent.pdf"
}
case "$cmd" in case "$cmd" in
build) build)
mkdir -p paper/build mkdir -p paper/build
sync_mdp_figures
bash paper/concat_code.sh bash paper/concat_code.sh
cd paper/src cd paper/src
latexmk -pdf -jobname=main -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main.tex latexmk -pdf -jobname=main -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main.tex
;; ;;
watch) watch)
mkdir -p paper/build mkdir -p paper/build
sync_mdp_figures
cd paper/src cd paper/src
latexmk -pvc -pdf -jobname=main -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main.tex latexmk -pvc -pdf -jobname=main -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main.tex
;; ;;
@@ -33,11 +52,13 @@ case "$cmd" in
;; ;;
build-genpop) build-genpop)
mkdir -p paper/build mkdir -p paper/build
sync_mdp_figures
cd paper/src cd paper/src
latexmk -pdf -jobname=main-genpop -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main-genpop.tex latexmk -pdf -jobname=main-genpop -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main-genpop.tex
;; ;;
watch-genpop) watch-genpop)
mkdir -p paper/build mkdir -p paper/build
sync_mdp_figures
cd paper/src cd paper/src
latexmk -pvc -pdf -jobname=main-genpop -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main-genpop.tex latexmk -pvc -pdf -jobname=main-genpop -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main-genpop.tex
;; ;;

View File

@@ -3,7 +3,7 @@ try:
except ImportError: except ImportError:
from sim.rl.behavior_loader.loader import Loader, AgentLoader, JointLoader from sim.rl.behavior_loader.loader import Loader, AgentLoader, JointLoader
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Tuple, Set from typing import Dict, List, Optional, Set, Tuple
import numpy as np import numpy as np
import graphviz import graphviz
import sys import sys
@@ -195,6 +195,110 @@ def aggregate_event_transitions(mdp: Dict) -> Dict[str, Dict[str, float]]:
return dict(evt_trans) return dict(evt_trans)
def _resolve_event_order(
evt_trans: Dict[str, Dict[str, float]],
event_order: Optional[List[str]] = None,
) -> List[str]:
observed = set(evt_trans.keys()) | {
dst for transitions in evt_trans.values() for dst in transitions
}
if event_order:
ordered = list(dict.fromkeys(event_order))
missing = sorted(observed - set(ordered))
return ordered + missing
return sorted(observed)
def _compass_from_angle(angle_rad: float) -> str:
ports = ("e", "ne", "n", "nw", "w", "sw", "s", "se")
normalized = (angle_rad + (2 * np.pi)) % (2 * np.pi)
step = np.pi / 4
idx = int(np.round(normalized / step)) % len(ports)
return ports[idx]
def _edge_ports(
src: str,
dst: str,
positions: Dict[str, Tuple[float, float]],
has_reverse: bool,
) -> Tuple[str, str]:
src_x, src_y = positions[src]
dst_x, dst_y = positions[dst]
angle = float(np.arctan2(dst_y - src_y, dst_x - src_x))
if has_reverse:
bend = np.pi / 10
angle += bend if src < dst else -bend
tail_port = _compass_from_angle(angle)
head_port = _compass_from_angle(angle + np.pi)
return tail_port, head_port
def _edge_style(prob: float) -> Dict[str, str]:
if prob >= 0.75:
edge_color = "#111827"
elif prob >= 0.50:
edge_color = "#374151"
elif prob >= 0.25:
edge_color = "#6b7280"
else:
edge_color = "#9ca3af"
return {
"color": edge_color,
"fontcolor": "#111827",
"fontsize": "10",
"penwidth": f"{0.9 + 3.6 * prob:.2f}",
"arrowsize": f"{0.55 + 0.55 * prob:.2f}",
}
def _format_node_label(evt: str) -> str:
max_line_len = 16
tokens = evt.split("_")
if len(tokens) == 1:
return evt
lines: List[str] = []
curr = ""
for token in tokens:
piece = token if not curr else f"_{token}"
if curr and len(curr) + len(piece) > max_line_len:
lines.append(curr)
curr = token
else:
curr = f"{curr}{piece}" if curr else token
if curr:
lines.append(curr)
return "\n".join(lines)
def _compute_flow_positions(
events: List[str],
layout_radius: float,
) -> Dict[str, Tuple[float, float]]:
"""Balanced grid layout for paper-friendly diagrams."""
if not events:
return {}
num_events = len(events)
cols = int(np.ceil(np.sqrt(num_events)))
rows = int(np.ceil(num_events / cols))
x_step = max(layout_radius * 1.10, 3.6)
y_step = max(layout_radius * 0.95, 3.2)
positions: Dict[str, Tuple[float, float]] = {}
for idx, evt in enumerate(events):
row = idx // cols
col = idx % cols
x = (col - (cols - 1) / 2.0) * x_step
y = ((rows - 1) / 2.0 - row) * y_step
positions[evt] = (float(x), float(y))
return positions
def visualize_mdp( def visualize_mdp(
model: BehaviorModel, model: BehaviorModel,
threshold: float = 0.05, threshold: float = 0.05,
@@ -202,25 +306,80 @@ def visualize_mdp(
fmt: str = "svg", fmt: str = "svg",
view: bool = False, view: bool = False,
export_dot: bool = False, export_dot: bool = False,
event_order: Optional[List[str]] = None,
layout_radius: float = 10.0,
node_diameter: float = 1.8,
label_threshold: float = 0.08,
): ):
if not model.mdp: if not model.mdp:
raise ValueError("build MDP first") raise ValueError("build MDP first")
evt_trans = aggregate_event_transitions(model.mdp) evt_trans = aggregate_event_transitions(model.mdp)
g = graphviz.Digraph(format=fmt) ordered_events = _resolve_event_order(evt_trans, event_order=event_order)
g.attr(rankdir="LR", size="30") positions = _compute_flow_positions(ordered_events, layout_radius=layout_radius)
g.attr("node", shape="circle", width="1", height="1")
events = set(evt_trans.keys()) | { g = graphviz.Digraph(format=fmt, engine="neato")
e for trans in evt_trans.values() for e in trans.keys() g.attr(
} overlap="false",
for evt in events: splines="true",
g.node(evt) outputorder="edgesfirst",
pad="0.5",
sep="+9",
esep="+4",
bgcolor="white",
dpi="180",
)
g.attr(
"node",
shape="circle",
fixedsize="true",
width=f"{node_diameter:.2f}",
height=f"{node_diameter:.2f}",
fontsize="11",
fontname="Helvetica",
style="filled",
fillcolor="white",
color="#374151",
fontcolor="#111827",
penwidth="1.8",
peripheries="1",
)
g.attr(
"edge",
fontname="Helvetica",
)
for src, dsts in evt_trans.items(): for evt in ordered_events:
for dst, prob in dsts.items(): x, y = positions[evt]
if prob > threshold: g.node(evt, label=_format_node_label(evt), pos=f"{x:.2f},{y:.2f}!", pin="true")
g.edge(src, dst, label=f"{prob:.2f}")
edges = [
(src, dst, prob)
for src, dsts in evt_trans.items()
for dst, prob in dsts.items()
if prob > threshold
]
edge_set = {(src, dst) for src, dst, _ in edges}
for src, dst, prob in sorted(edges, key=lambda row: row[2]):
edge_attrs: Dict[str, str] = _edge_style(prob)
if src == dst:
# pick a loop port away from the main flow
sx, sy = positions[src]
loop_port = "n" if sy <= 0 else "s"
edge_attrs.update({"tailport": loop_port, "headport": loop_port})
else:
has_reverse = (dst, src) in edge_set
tail_port, head_port = _edge_ports(src, dst, positions, has_reverse)
edge_attrs.update({"tailport": tail_port, "headport": head_port})
if has_reverse:
edge_attrs["constraint"] = "false"
if prob >= label_threshold or src == dst:
edge_attrs["label"] = f" {prob:.2f} "
g.edge(src, dst, **edge_attrs)
g.render(output, view=view, cleanup=True) g.render(output, view=view, cleanup=True)
print(f"Saved MDP graph to {output}.{fmt}") print(f"Saved MDP graph to {output}.{fmt}")
@@ -342,11 +501,6 @@ if __name__ == "__main__":
f"Built MDP: {human_mdp['num_states']} states, " f"Built MDP: {human_mdp['num_states']} states, "
f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions" f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions"
) )
if not human_mdp["states"]:
exit("No states found")
visualize_mdp(
human_model, threshold=0.05, output="human_mdp_viz", fmt="pdf", export_dot=True
)
agent_model = AgentBehaviorModel(agent_dir) agent_model = AgentBehaviorModel(agent_dir)
agent_mdp = agent_model.build_MDP() agent_mdp = agent_model.build_MDP()
@@ -355,14 +509,35 @@ if __name__ == "__main__":
f"AGENT... Built MDP: {agent_mdp['num_states']} states, " f"AGENT... Built MDP: {agent_mdp['num_states']} states, "
f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions" f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions"
) )
if not agent_mdp["states"]:
exit("No states found")
visualize_mdp(
agent_model, threshold=0.05, output="agent_mdp_viz", fmt="pdf", export_dot=True
)
human_evt = aggregate_event_transitions(human_mdp) human_evt = aggregate_event_transitions(human_mdp)
agent_evt = aggregate_event_transitions(agent_mdp) agent_evt = aggregate_event_transitions(agent_mdp)
canonical_events = sorted(
(set(human_evt.keys()) | {e for tr in human_evt.values() for e in tr.keys()})
| (set(agent_evt.keys()) | {e for tr in agent_evt.values() for e in tr.keys()})
)
if not human_mdp["states"]:
exit("No states found")
visualize_mdp(
human_model,
threshold=0.05,
output="human_mdp_viz",
fmt="pdf",
export_dot=True,
event_order=canonical_events,
)
if not agent_mdp["states"]:
exit("No states found")
visualize_mdp(
agent_model,
threshold=0.05,
output="agent_mdp_viz",
fmt="pdf",
export_dot=True,
event_order=canonical_events,
)
common = set(human_evt.keys()) & set(agent_evt.keys()) common = set(human_evt.keys()) & set(agent_evt.keys())
@@ -394,6 +569,7 @@ if __name__ == "__main__":
output="joint_mdp_viz", output="joint_mdp_viz",
fmt="pdf", fmt="pdf",
export_dot=True, export_dot=True,
event_order=canonical_events,
) )
inter_class_avg = float(np.mean([kl for _, kl in kl_divs])) inter_class_avg = float(np.mean([kl for _, kl in kl_divs]))