mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
catchup: rogue scripts
This commit is contained in:
7
sim/requirements.txt
Normal file
7
sim/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
gymnasium>=0.29.0
|
||||
numpy>=1.24.0
|
||||
pandas>=2.0.0
|
||||
stable-baselines3>=2.2.0
|
||||
tensorboard>=2.15.0
|
||||
jax>=0.4.20
|
||||
jaxlib>=0.4.20
|
||||
117
sim/rl/behavior_loader/visualize_kl.py
Normal file
117
sim/rl/behavior_loader/visualize_kl.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import defaultdict
|
||||
from models import BehaviorModel, AgentBehaviorModel, aggregate_event_transitions, kl_divergence
|
||||
|
||||
def event_frequency_distribution(mdp):
|
||||
evt_cnt, total = defaultdict(int), 0
|
||||
for s, trans in mdp['transitions'].items():
|
||||
evt = s.split('|')[2]
|
||||
for cnt in mdp['trans_counts'][s].values():
|
||||
evt_cnt[evt] += cnt
|
||||
total += cnt
|
||||
return {evt: cnt/total for evt, cnt in evt_cnt.items()} if total > 0 else {}
|
||||
|
||||
def transition_distribution(mdp):
|
||||
trans_cnt, total = defaultdict(int), 0
|
||||
for s, trans in mdp['trans_counts'].items():
|
||||
src = s.split('|')[2]
|
||||
for s_next, cnt in trans.items():
|
||||
dst = s_next.split('|')[2]
|
||||
trans_cnt[f"{src}->{dst}"] += cnt
|
||||
total += cnt
|
||||
return {t: cnt/total for t, cnt in trans_cnt.items()} if total > 0 else {}
|
||||
|
||||
def kl_color(kl):
|
||||
return '#d62828' if kl > 2.0 else '#f77f00' if kl > 0.5 else '#2a9d8f'
|
||||
|
||||
def plot_comparison(ax, human_vals, agent_vals, labels, title, ylabel, kl_val=None):
|
||||
x, w = np.arange(len(labels)), 0.35
|
||||
ax.bar(x - w/2, human_vals, w, label='Human', alpha=0.8, color='#2E86AB')
|
||||
ax.bar(x + w/2, agent_vals, w, label='Agent', alpha=0.8, color='#A23B72')
|
||||
ax.set_ylabel(ylabel, fontsize=9 if len(labels) > 10 else 11, fontweight='bold')
|
||||
ax.set_title(title if not kl_val else f"{title}\nKL={kl_val:.4f}",
|
||||
fontsize=10 if len(labels) > 10 else 12, fontweight='bold')
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)
|
||||
ax.legend(fontsize=8)
|
||||
ax.grid(axis='y', alpha=0.3, linestyle='--')
|
||||
return ax
|
||||
|
||||
if __name__ == "__main__":
|
||||
base_dir = "/home/velocitatem/Documents/Projects/PHANTOM/experiments"
|
||||
human_dir, agent_dir = f"{base_dir}/collected_data/", f"{base_dir}/agents/collected_data/"
|
||||
|
||||
human_model, agent_model = BehaviorModel(human_dir), AgentBehaviorModel(agent_dir)
|
||||
human_mdp, agent_mdp = human_model.build_MDP(), agent_model.build_MDP()
|
||||
|
||||
human_evt, agent_evt = aggregate_event_transitions(human_mdp), aggregate_event_transitions(agent_mdp)
|
||||
common = set(human_evt.keys()) & set(agent_evt.keys())
|
||||
kl_results = sorted([(e, kl_divergence(human_evt[e], agent_evt[e])) for e in common],
|
||||
key=lambda x: x[1], reverse=True)
|
||||
|
||||
fig = plt.figure(figsize=(16, 10))
|
||||
n_rows, n_cols = (len(kl_results) + 1) // 2, 2
|
||||
|
||||
for idx, (evt, kl) in enumerate(kl_results):
|
||||
ax = plt.subplot(n_rows, n_cols, idx + 1)
|
||||
h_dist, a_dist = human_evt.get(evt, {}), agent_evt.get(evt, {})
|
||||
dests = sorted(set(h_dist.keys()) | set(a_dist.keys()))
|
||||
if not dests: continue
|
||||
|
||||
h_probs, a_probs = [h_dist.get(d, 0) for d in dests], [a_dist.get(d, 0) for d in dests]
|
||||
plot_comparison(ax, h_probs, a_probs, dests, f'From: {evt}', 'Probability')
|
||||
ax.set_ylim([0, max(max(h_probs + a_probs, default=0) * 1.1, 0.1)])
|
||||
ax.text(0.95, 0.95, f'KL={kl:.2f}', transform=ax.transAxes, fontsize=11,
|
||||
fontweight='bold', va='top', ha='right',
|
||||
bbox=dict(boxstyle='round', facecolor=kl_color(kl), alpha=0.3))
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('kl_divergence_comparison.png', dpi=300, bbox_inches='tight')
|
||||
print("Saved visualization to kl_divergence_comparison.png")
|
||||
|
||||
fig2, ax2 = plt.subplots(figsize=(10, 6))
|
||||
evts, kls = zip(*kl_results) if kl_results else ([], [])
|
||||
colors = [kl_color(kl) for kl in kls]
|
||||
bars = ax2.barh(evts, kls, color=colors, alpha=0.8)
|
||||
ax2.set_xlabel('KL Divergence D(Human || Agent)', fontsize=12, fontweight='bold')
|
||||
ax2.set_ylabel('Event Type', fontsize=12, fontweight='bold')
|
||||
ax2.set_title('Behavioral Divergence Between Human and Agent Traffic', fontsize=14, fontweight='bold')
|
||||
if kls:
|
||||
ax2.axvline(x=np.mean(kls), color='black', linestyle='--', linewidth=2,
|
||||
alpha=0.5, label=f'Mean={np.mean(kls):.2f}')
|
||||
for bar, kl in zip(bars, kls):
|
||||
ax2.text(bar.get_width() + 0.1, bar.get_y() + bar.get_height()/2,
|
||||
f'{kl:.2f}', ha='left', va='center', fontsize=10, fontweight='bold')
|
||||
ax2.legend()
|
||||
ax2.grid(axis='x', alpha=0.3, linestyle='--')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('kl_summary.png', dpi=300, bbox_inches='tight')
|
||||
print("Saved KL summary to kl_summary.png")
|
||||
|
||||
h_freq, a_freq = event_frequency_distribution(human_mdp), event_frequency_distribution(agent_mdp)
|
||||
h_trans, a_trans = transition_distribution(human_mdp), transition_distribution(agent_mdp)
|
||||
freq_kl, trans_kl = kl_divergence(h_freq, a_freq), kl_divergence(h_trans, a_trans)
|
||||
|
||||
print(f"\n=== Global Distribution KL Divergence ===")
|
||||
print(f"Event frequency KL: {freq_kl:.4f}")
|
||||
print(f"Transition pair KL: {trans_kl:.4f}")
|
||||
|
||||
fig3, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
|
||||
|
||||
all_evts = sorted(set(h_freq.keys()) | set(a_freq.keys()))
|
||||
h_freqs, a_freqs = [h_freq.get(e, 0) for e in all_evts], [a_freq.get(e, 0) for e in all_evts]
|
||||
plot_comparison(ax1, h_freqs, a_freqs, all_evts, 'Event Frequency Distribution',
|
||||
'Frequency', freq_kl)
|
||||
|
||||
all_trans = sorted(set(h_trans.keys()) | set(a_trans.keys()))
|
||||
top_trans = [t for t, _ in sorted([(t, h_trans.get(t, 0) + a_trans.get(t, 0))
|
||||
for t in all_trans], key=lambda x: x[1], reverse=True)[:15]]
|
||||
h_tprobs, a_tprobs = [h_trans.get(t, 0) for t in top_trans], [a_trans.get(t, 0) for t in top_trans]
|
||||
plot_comparison(ax2, h_tprobs, a_tprobs, top_trans, 'Top Transition Pairs Distribution',
|
||||
'Probability', trans_kl)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('global_distributions.png', dpi=300, bbox_inches='tight')
|
||||
print("Saved global distributions to global_distributions.png")
|
||||
86
sim/rl/thesis_core.py
Normal file
86
sim/rl/thesis_core.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sim.case.thesis_simplified.simplified import Session
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PricingStep:
|
||||
sessions: list[Session]
|
||||
demand_by_session: Dict[str, float]
|
||||
demand_by_product: np.ndarray
|
||||
purchases_by_product: np.ndarray
|
||||
revenue: float
|
||||
cost: float
|
||||
n_agents: int
|
||||
|
||||
|
||||
def clip_prices(prices: np.ndarray, min_price: float, max_price: float) -> np.ndarray:
|
||||
return np.clip(prices, min_price, max_price).astype(np.float32)
|
||||
|
||||
|
||||
def constrain_prices(
|
||||
prev_prices: Optional[np.ndarray],
|
||||
proposed: np.ndarray,
|
||||
*,
|
||||
costs: np.ndarray,
|
||||
min_price: float,
|
||||
max_price: float,
|
||||
max_adjustment: float,
|
||||
min_margin_pct: float,
|
||||
) -> np.ndarray:
|
||||
prices = clip_prices(proposed, min_price, max_price)
|
||||
floor = (costs * (1.0 + float(min_margin_pct))).astype(np.float32)
|
||||
prices = np.maximum(prices, floor)
|
||||
if prev_prices is None:
|
||||
return prices
|
||||
prev_prices = prev_prices.astype(np.float32)
|
||||
ratio = np.clip(prices / (prev_prices + 1e-6), 1.0 - max_adjustment, 1.0 + max_adjustment)
|
||||
return (prev_prices * ratio).astype(np.float32)
|
||||
|
||||
|
||||
def aggregate_demand_by_product(
|
||||
sessions: list[Session],
|
||||
demand_by_session: Dict[str, float],
|
||||
n_products: int,
|
||||
) -> np.ndarray:
|
||||
demand = np.zeros(n_products, dtype=np.float32)
|
||||
sessions_by_id = {s.sid: s for s in sessions}
|
||||
for sid, q in demand_by_session.items():
|
||||
sess = sessions_by_id.get(sid)
|
||||
if not sess or not sess.events:
|
||||
continue
|
||||
pidx = int(sess.events[0].product_idx)
|
||||
if 0 <= pidx < n_products:
|
||||
demand[pidx] += float(q)
|
||||
return demand
|
||||
|
||||
|
||||
def aggregate_purchases(
|
||||
sessions: list[Session],
|
||||
costs: np.ndarray,
|
||||
n_products: int,
|
||||
) -> tuple[np.ndarray, float, float, int]:
|
||||
purchases = np.zeros(n_products, dtype=np.float32)
|
||||
revenue = 0.0
|
||||
cost = 0.0
|
||||
n_agents = 0
|
||||
|
||||
for sess in sessions:
|
||||
if sess.actor == "A":
|
||||
n_agents += 1
|
||||
for e in sess.events:
|
||||
if e.action != "purchase":
|
||||
continue
|
||||
pidx = int(e.product_idx)
|
||||
if 0 <= pidx < n_products:
|
||||
purchases[pidx] += 1.0
|
||||
revenue += float(e.price_seen)
|
||||
cost += float(costs[pidx])
|
||||
|
||||
return purchases, revenue, cost, n_agents
|
||||
|
||||
Reference in New Issue
Block a user