diff --git a/paper/src/chapters/03-methodology.tex b/paper/src/chapters/03-methodology.tex index 8c58717..799486e 100644 --- a/paper/src/chapters/03-methodology.tex +++ b/paper/src/chapters/03-methodology.tex @@ -23,7 +23,7 @@ where: The platform does not directly observe the true underlying demand function $d(p)$. Instead, it observes a behavioral proxy $\hat{q}_t$, which is a composite signal derived from the mixture of actor types. We define the demand proxy for product $i$ at epoch $t$ as a weighted aggregation of events: \begin{equation} \label{eq:qhat} -\hat{q}_{t,i} = \sum_{s \in \mathcal{S}_t} \sum_{k=1}^{L_s} \omega(a_{s,k}) \cdot \mathds{1}[i_{s,k} = i] +\hat{q}_{t,i} = \sum_{s \in \mathcal{S}_t} \sum_{k=1}^{L_s} \omega(a_{s,k}) \cdot \mathbf{1}[i_{s,k} = i] \end{equation} where $\omega: \mathcal{A} \to \mathbb{R}_+$ assigns weights to actions based on their signal strength regarding willingness to pay. diff --git a/paper/src/chapters/figures/results/generated/final/final_focus_coi_preservation_grid.csv b/paper/src/chapters/figures/results/generated/final/final_focus_coi_preservation_grid.csv new file mode 100644 index 0000000..a2e5115 --- /dev/null +++ b/paper/src/chapters/figures/results/generated/final/final_focus_coi_preservation_grid.csv @@ -0,0 +1,45 @@ +alpha,n_products,baseline_runs,defended_runs,baseline_coi_level_mean,defended_coi_level_mean,coi_preserved,coi_preserved_pct +0.0,5.0,9,10,137.060822623968,136.18680853180368,-0.874014092164316,-0.6376833842316922 +0.0,25.0,9,2,137.114858903596,136.13793579187393,-0.9769231117220727,-0.7124852255501622 +0.0,50.0,9,11,137.16224858153575,136.92415566181484,-0.23809291972091273,-0.17358487643878118 +0.0,100.0,9,12,135.86629045322655,137.3609873086303,1.4946968554037596,1.1001234010420895 +0.1,5.0,3,6,136.59581715538818,135.6308466787041,-0.9649704766840728,-0.7064421859904723 +0.1,25.0,11,8,135.9860669350444,136.43616365263273,0.45009671758833747,0.33098737814318313 +0.1,50.0,10,11,136.28362874897243,136.92880179422633,0.6451730452538982,0.4734046570203046 +0.1,100.0,8,8,137.35578496752095,137.53394777402949,0.17816280650853855,0.12970899372797937 +0.2,5.0,8,9,135.55116314329388,137.30311388107864,1.7519507377847674,1.2924645551973204 +0.2,25.0,10,9,137.01587649612287,137.22137163685403,0.20549514073115915,0.1499790724887083 +0.2,50.0,4,8,137.45096138958434,137.1307018163465,-0.32025957323784837,-0.2329991511155169 +0.2,100.0,9,9,137.50780776750915,137.43195025898902,-0.07585750852013007,-0.0551659645744523 +0.3,5.0,6,6,134.95569459599133,134.21855668602896,-0.7371379099623709,-0.5462073402453271 +0.3,25.0,9,16,136.38346021911525,136.32131251342705,-0.06214770568820427,-0.04556835967378819 +0.3,50.0,8,6,136.97414077213367,136.88041560990786,-0.09372516222580884,-0.06842544271310845 +0.3,100.0,7,16,137.19706520314455,137.31020460277784,0.11313939963329744,0.08246488324351146 +0.4,5.0,8,11,135.6494813257779,136.5487738152141,0.899292489436192,0.6629531352769695 +0.4,25.0,7,9,136.38451372914378,136.10614648175604,-0.27836724738773455,-0.20410473284420322 +0.4,50.0,7,10,137.12976275807247,136.98838321468799,-0.14137954338448822,-0.10309909427460566 +0.4,100.0,11,8,137.4158065068933,137.4849148270489,0.06910832015560686,0.050291390715769026 +0.5,5.0,7,19,135.91101413475477,136.145621134976,0.2346070002212457,0.1726180925915501 +0.5,25.0,8,7,137.0972914279529,137.35620682163616,0.2589153936832531,0.18885522170896996 +0.5,50.0,8,1,137.0714841014652,135.66696334266234,-1.404520758802846,-1.0246629837050352 +0.5,100.0,10,8,137.4717672869487,137.35366167964338,-0.11810560730532416,-0.08591262746975456 +0.6,5.0,8,13,133.13626070539635,136.09936023073067,2.9630995253343144,2.225614201296411 +0.6,25.0,5,10,136.0741624588533,136.26219778039936,0.18803532154606728,0.13818591137970535 +0.6,50.0,8,10,135.09036188289087,136.05846380616936,0.968101923278482,0.7166328595060871 +0.6,100.0,7,8,137.29304001584052,137.07512338179083,-0.2179166340496863,-0.15872372993164377 +0.7,5.0,7,7,136.0533783988379,135.14350016006424,-0.9098782387736719,-0.6687656341075052 +0.7,25.0,8,11,137.12781750399415,136.8176582131797,-0.3101592908144539,-0.2261826203172962 +0.7,50.0,14,11,137.06965735909125,136.7028634119364,-0.3667939471548607,-0.26759674914335285 +0.7,100.0,11,11,137.48279078937205,137.09121810549402,-0.39157268387802446,-0.28481578067317975 +0.8,5.0,4,7,135.3095773096514,136.59715728802078,1.2875799783693935,0.9515808148766959 +0.8,25.0,12,13,136.93488398652164,135.73319876476054,-1.201685221761096,-0.8775596011600497 +0.8,50.0,6,8,136.4704324290659,136.86568018140107,0.39524775233516607,0.289621528487943 +0.8,100.0,4,11,137.519864039095,137.4763376137669,-0.04352642532811046,-0.03165100957032396 +0.9,5.0,5,5,134.77024204025943,136.6651608019597,1.8949187617002679,1.4060364758669837 +0.9,25.0,9,13,136.7554042236364,136.06108143100832,-0.6943227926280713,-0.507711411164888 +0.9,50.0,10,12,136.08715955450202,137.07569864767092,0.988539093168896,0.7264014447836223 +0.9,100.0,11,9,137.57053132642514,137.30115968842037,-0.2693716380047704,-0.19580620602940735 +1.0,5.0,5,7,136.43177888041947,135.92674388998284,-0.5050349904366271,-0.37017401266847305 +1.0,25.0,11,9,136.7037183889911,136.22617845471228,-0.47753993427880914,-0.34932475861407586 +1.0,50.0,11,5,136.93074105866745,137.05826644845806,0.12752538979060546,0.09313130769953819 +1.0,100.0,8,9,136.4880191421812,137.41913068956546,0.9311115473842619,0.682192879079234 diff --git a/paper/src/chapters/figures/results/generated/final/plots/final_focus_coi_by_alpha.pdf b/paper/src/chapters/figures/results/generated/final/plots/final_focus_coi_by_alpha.pdf new file mode 100644 index 0000000..b47a4fd Binary files /dev/null and b/paper/src/chapters/figures/results/generated/final/plots/final_focus_coi_by_alpha.pdf differ diff --git a/paper/src/chapters/figures/results/generated/final/plots/final_focus_coi_preservation_grid.pdf b/paper/src/chapters/figures/results/generated/final/plots/final_focus_coi_preservation_grid.pdf new file mode 100644 index 0000000..c02bdf9 Binary files /dev/null and b/paper/src/chapters/figures/results/generated/final/plots/final_focus_coi_preservation_grid.pdf differ diff --git a/paper/src/chapters/figures/results/includes/final/final_focus_coi_by_alpha.tex b/paper/src/chapters/figures/results/includes/final/final_focus_coi_by_alpha.tex new file mode 100644 index 0000000..6eafa3f --- /dev/null +++ b/paper/src/chapters/figures/results/includes/final/final_focus_coi_by_alpha.tex @@ -0,0 +1 @@ +\includegraphics[width=0.98\linewidth]{chapters/figures/results/generated/final/plots/final_focus_coi_by_alpha.pdf} diff --git a/paper/src/chapters/figures/results/includes/final/final_focus_coi_preservation_grid.tex b/paper/src/chapters/figures/results/includes/final/final_focus_coi_preservation_grid.tex new file mode 100644 index 0000000..1ca04c8 --- /dev/null +++ b/paper/src/chapters/figures/results/includes/final/final_focus_coi_preservation_grid.tex @@ -0,0 +1 @@ +\includegraphics[width=0.98\linewidth]{chapters/figures/results/generated/final/plots/final_focus_coi_preservation_grid.pdf} diff --git a/paper/src/chapters/mdp_agent.pdf b/paper/src/chapters/mdp_agent.pdf index 0566be9..aeab1b7 100644 Binary files a/paper/src/chapters/mdp_agent.pdf and b/paper/src/chapters/mdp_agent.pdf differ diff --git a/paper/src/chapters/mdp_human.pdf b/paper/src/chapters/mdp_human.pdf index 7cef37a..b753b4e 100644 Binary files a/paper/src/chapters/mdp_human.pdf and b/paper/src/chapters/mdp_human.pdf differ diff --git a/scripts/nx_paper.sh b/scripts/nx_paper.sh index 375db5a..036a3c4 100644 --- a/scripts/nx_paper.sh +++ b/scripts/nx_paper.sh @@ -4,15 +4,34 @@ set -euo pipefail cmd="${1:-}" +sync_mdp_figures() { + local script_dir project_root sim_dir chapters_dir + script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + project_root="$(cd "$script_dir/.." && pwd)" + sim_dir="$project_root/sim/rl/behavior_loader" + chapters_dir="$project_root/paper/src/chapters" + + printf '%s\n' 'Refreshing MDP figures for paper...' + ( + cd "$sim_dir" + python models.py + ) + + cp "$sim_dir/human_mdp_viz.pdf" "$chapters_dir/mdp_human.pdf" + cp "$sim_dir/agent_mdp_viz.pdf" "$chapters_dir/mdp_agent.pdf" +} + case "$cmd" in build) mkdir -p paper/build + sync_mdp_figures bash paper/concat_code.sh cd paper/src latexmk -pdf -jobname=main -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main.tex ;; watch) mkdir -p paper/build + sync_mdp_figures cd paper/src latexmk -pvc -pdf -jobname=main -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main.tex ;; @@ -33,11 +52,13 @@ case "$cmd" in ;; build-genpop) mkdir -p paper/build + sync_mdp_figures cd paper/src latexmk -pdf -jobname=main-genpop -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main-genpop.tex ;; watch-genpop) mkdir -p paper/build + sync_mdp_figures cd paper/src latexmk -pvc -pdf -jobname=main-genpop -f -interaction=nonstopmode -file-line-error -r ../.latexmkrc -outdir=../build main-genpop.tex ;; diff --git a/sim/rl/behavior_loader/models.py b/sim/rl/behavior_loader/models.py index cb67cbf..25c8c15 100644 --- a/sim/rl/behavior_loader/models.py +++ b/sim/rl/behavior_loader/models.py @@ -3,7 +3,7 @@ try: except ImportError: from sim.rl.behavior_loader.loader import Loader, AgentLoader, JointLoader from collections import defaultdict -from typing import Dict, List, Tuple, Set +from typing import Dict, List, Optional, Set, Tuple import numpy as np import graphviz import sys @@ -195,6 +195,110 @@ def aggregate_event_transitions(mdp: Dict) -> Dict[str, Dict[str, float]]: return dict(evt_trans) +def _resolve_event_order( + evt_trans: Dict[str, Dict[str, float]], + event_order: Optional[List[str]] = None, +) -> List[str]: + observed = set(evt_trans.keys()) | { + dst for transitions in evt_trans.values() for dst in transitions + } + if event_order: + ordered = list(dict.fromkeys(event_order)) + missing = sorted(observed - set(ordered)) + return ordered + missing + return sorted(observed) + + +def _compass_from_angle(angle_rad: float) -> str: + ports = ("e", "ne", "n", "nw", "w", "sw", "s", "se") + normalized = (angle_rad + (2 * np.pi)) % (2 * np.pi) + step = np.pi / 4 + idx = int(np.round(normalized / step)) % len(ports) + return ports[idx] + + +def _edge_ports( + src: str, + dst: str, + positions: Dict[str, Tuple[float, float]], + has_reverse: bool, +) -> Tuple[str, str]: + src_x, src_y = positions[src] + dst_x, dst_y = positions[dst] + angle = float(np.arctan2(dst_y - src_y, dst_x - src_x)) + + if has_reverse: + bend = np.pi / 10 + angle += bend if src < dst else -bend + + tail_port = _compass_from_angle(angle) + head_port = _compass_from_angle(angle + np.pi) + return tail_port, head_port + + +def _edge_style(prob: float) -> Dict[str, str]: + if prob >= 0.75: + edge_color = "#111827" + elif prob >= 0.50: + edge_color = "#374151" + elif prob >= 0.25: + edge_color = "#6b7280" + else: + edge_color = "#9ca3af" + return { + "color": edge_color, + "fontcolor": "#111827", + "fontsize": "10", + "penwidth": f"{0.9 + 3.6 * prob:.2f}", + "arrowsize": f"{0.55 + 0.55 * prob:.2f}", + } + + +def _format_node_label(evt: str) -> str: + max_line_len = 16 + tokens = evt.split("_") + if len(tokens) == 1: + return evt + + lines: List[str] = [] + curr = "" + for token in tokens: + piece = token if not curr else f"_{token}" + if curr and len(curr) + len(piece) > max_line_len: + lines.append(curr) + curr = token + else: + curr = f"{curr}{piece}" if curr else token + if curr: + lines.append(curr) + return "\n".join(lines) + + +def _compute_flow_positions( + events: List[str], + layout_radius: float, +) -> Dict[str, Tuple[float, float]]: + """Balanced grid layout for paper-friendly diagrams.""" + if not events: + return {} + + num_events = len(events) + cols = int(np.ceil(np.sqrt(num_events))) + rows = int(np.ceil(num_events / cols)) + x_step = max(layout_radius * 1.10, 3.6) + y_step = max(layout_radius * 0.95, 3.2) + + positions: Dict[str, Tuple[float, float]] = {} + for idx, evt in enumerate(events): + row = idx // cols + col = idx % cols + x = (col - (cols - 1) / 2.0) * x_step + y = ((rows - 1) / 2.0 - row) * y_step + positions[evt] = (float(x), float(y)) + + return positions + + def visualize_mdp( model: BehaviorModel, threshold: float = 0.05, @@ -202,25 +306,80 @@ def visualize_mdp( fmt: str = "svg", view: bool = False, export_dot: bool = False, + event_order: Optional[List[str]] = None, + layout_radius: float = 10.0, + node_diameter: float = 1.8, + label_threshold: float = 0.08, ): if not model.mdp: raise ValueError("build MDP first") evt_trans = aggregate_event_transitions(model.mdp) - g = graphviz.Digraph(format=fmt) - g.attr(rankdir="LR", size="30") - g.attr("node", shape="circle", width="1", height="1") + ordered_events = _resolve_event_order(evt_trans, event_order=event_order) + positions = _compute_flow_positions(ordered_events, layout_radius=layout_radius) - events = set(evt_trans.keys()) | { - e for trans in evt_trans.values() for e in trans.keys() - } - for evt in events: - g.node(evt) + g = graphviz.Digraph(format=fmt, engine="neato") + g.attr( + overlap="false", + splines="true", + outputorder="edgesfirst", + pad="0.5", + sep="+9", + esep="+4", + bgcolor="white", + dpi="180", + ) + g.attr( + "node", + shape="circle", + fixedsize="true", + width=f"{node_diameter:.2f}", + height=f"{node_diameter:.2f}", + fontsize="11", + fontname="Helvetica", + style="filled", + fillcolor="white", + color="#374151", + fontcolor="#111827", + penwidth="1.8", + peripheries="1", + ) + g.attr( + "edge", + fontname="Helvetica", + ) - for src, dsts in evt_trans.items(): - for dst, prob in dsts.items(): - if prob > threshold: - g.edge(src, dst, label=f"{prob:.2f}") + for evt in ordered_events: + x, y = positions[evt] + g.node(evt, label=_format_node_label(evt), pos=f"{x:.2f},{y:.2f}!", pin="true") + + edges = [ + (src, dst, prob) + for src, dsts in evt_trans.items() + for dst, prob in dsts.items() + if prob > threshold + ] + edge_set = {(src, dst) for src, dst, _ in edges} + + for src, dst, prob in sorted(edges, key=lambda row: row[2]): + edge_attrs: Dict[str, str] = _edge_style(prob) + + if src == dst: + # pick a loop port away from the main flow + sx, sy = positions[src] + loop_port = "n" if sy <= 0 else "s" + edge_attrs.update({"tailport": loop_port, "headport": loop_port}) + else: + has_reverse = (dst, src) in edge_set + tail_port, head_port = _edge_ports(src, dst, positions, has_reverse) + edge_attrs.update({"tailport": tail_port, "headport": head_port}) + if has_reverse: + edge_attrs["constraint"] = "false" + + if prob >= label_threshold or src == dst: + edge_attrs["label"] = f" {prob:.2f} " + + g.edge(src, dst, **edge_attrs) g.render(output, view=view, cleanup=True) print(f"Saved MDP graph to {output}.{fmt}") @@ -342,11 +501,6 @@ if __name__ == "__main__": f"Built MDP: {human_mdp['num_states']} states, " f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions" ) - if not human_mdp["states"]: - exit("No states found") - visualize_mdp( - human_model, threshold=0.05, output="human_mdp_viz", fmt="pdf", export_dot=True - ) agent_model = AgentBehaviorModel(agent_dir) agent_mdp = agent_model.build_MDP() @@ -355,14 +509,35 @@ if __name__ == "__main__": f"AGENT... Built MDP: {agent_mdp['num_states']} states, " f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions" ) - if not agent_mdp["states"]: - exit("No states found") - visualize_mdp( - agent_model, threshold=0.05, output="agent_mdp_viz", fmt="pdf", export_dot=True - ) human_evt = aggregate_event_transitions(human_mdp) agent_evt = aggregate_event_transitions(agent_mdp) + canonical_events = sorted( + (set(human_evt.keys()) | {e for tr in human_evt.values() for e in tr.keys()}) + | (set(agent_evt.keys()) | {e for tr in agent_evt.values() for e in tr.keys()}) + ) + + if not human_mdp["states"]: + exit("No states found") + visualize_mdp( + human_model, + threshold=0.05, + output="human_mdp_viz", + fmt="pdf", + export_dot=True, + event_order=canonical_events, + ) + + if not agent_mdp["states"]: + exit("No states found") + visualize_mdp( + agent_model, + threshold=0.05, + output="agent_mdp_viz", + fmt="pdf", + export_dot=True, + event_order=canonical_events, + ) common = set(human_evt.keys()) & set(agent_evt.keys()) @@ -394,6 +569,7 @@ if __name__ == "__main__": output="joint_mdp_viz", fmt="pdf", export_dot=True, + event_order=canonical_events, ) inter_class_avg = float(np.mean([kl for _, kl in kl_divs]))