chore: cleaning the code

This commit is contained in:
2026-02-28 23:38:38 +01:00
parent 803e3a2972
commit ec880db444
7 changed files with 145 additions and 43 deletions

View File

@@ -36,7 +36,7 @@ try:
except ImportError: except ImportError:
HAS_WANDB = False HAS_WANDB = False
from ..wandb_checkpoint import ( from ..wandb_checkpoint import ( # noqa: E402
checkpoint_artifact_name, checkpoint_artifact_name,
download_latest_checkpoint, download_latest_checkpoint,
log_checkpoint_bytes, log_checkpoint_bytes,
@@ -83,7 +83,7 @@ except ImportError:
HAS_JAX_STACK = False HAS_JAX_STACK = False
from .env import PHANTOMJAXEnv, make_env_params from .env import PHANTOMJAXEnv, make_env_params # noqa: E402
class ActorCritic(nn.Module): class ActorCritic(nn.Module):

View File

@@ -12,3 +12,27 @@ from .providers import (
) )
from .coi import compute_uplift_coi, extract_purchases, compute_agent_probability from .coi import compute_uplift_coi, extract_purchases, compute_agent_probability
from .discrete import EventQTable from .discrete import EventQTable
__all__ = [
"estimate_demand",
"estimate_weighted_demand",
"generate_demand_for_actor",
"sample_behavior",
"get_transition_models",
"trajectory_to_events",
"DashboardRenderer",
"style_axis",
"EconomicMetricsWrapper",
"MetricsCallback",
"EvalMetricsCallback",
"CheckpointArtifactCallback",
"ProviderBenchmark",
"ProviderResult",
"BenchmarkConfig",
"RandomBaseline",
"SurgeBaseline",
"compute_uplift_coi",
"extract_purchases",
"compute_agent_probability",
"EventQTable",
]

View File

@@ -1,15 +1,19 @@
"""rendering logic for PHANTOM environment dashboard""" """rendering logic for PHANTOM environment dashboard"""
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec from matplotlib.gridspec import GridSpec
def style_axis(ax, title: str = None, xlabel: str = None, ylabel: str = None): def style_axis(ax, title: str = None, xlabel: str = None, ylabel: str = None):
ax.spines['top'].set_visible(False) ax.spines["top"].set_visible(False)
ax.spines['right'].set_visible(False) ax.spines["right"].set_visible(False)
if title: ax.set_title(title, fontsize=11, fontweight='bold', pad=8) if title:
if xlabel: ax.set_xlabel(xlabel, fontsize=9) ax.set_title(title, fontsize=11, fontweight="bold", pad=8)
if ylabel: ax.set_ylabel(ylabel, fontsize=9) if xlabel:
ax.set_xlabel(xlabel, fontsize=9)
if ylabel:
ax.set_ylabel(ylabel, fontsize=9)
class DashboardRenderer: class DashboardRenderer:
@@ -23,13 +27,25 @@ class DashboardRenderer:
if self.fig is None: if self.fig is None:
plt.ion() plt.ion()
self.fig = plt.figure(figsize=(14, 10)) self.fig = plt.figure(figsize=(14, 10))
self.gs = GridSpec(3, 3, figure=self.fig, hspace=0.35, wspace=0.3, self.gs = GridSpec(
left=0.07, right=0.95, top=0.92, bottom=0.08) 3,
3,
figure=self.fig,
hspace=0.35,
wspace=0.3,
left=0.07,
right=0.95,
top=0.92,
bottom=0.08,
)
plt.show(block=False) plt.show(block=False)
self.fig.clear() self.fig.clear()
self.fig.suptitle(f'PHANTOM Market Dynamics [t={env._step_count}, a={env.alpha:.2f}]', self.fig.suptitle(
fontsize=14, fontweight='bold') f"PHANTOM Market Dynamics [t={env._step_count}, a={env.alpha:.2f}]",
fontsize=14,
fontweight="bold",
)
demand_mat = np.array(env._demand_history).T demand_mat = np.array(env._demand_history).T
price_mat = np.array(env._price_history).T price_mat = np.array(env._price_history).T
@@ -51,40 +67,56 @@ class DashboardRenderer:
prices_flat = np.array(env._price_history).flatten() prices_flat = np.array(env._price_history).flatten()
demands_flat = np.array(env._demand_history).flatten() demands_flat = np.array(env._demand_history).flatten()
product_ids = np.tile(np.arange(env.n_products), len(env._price_history)) product_ids = np.tile(np.arange(env.n_products), len(env._price_history))
ax.scatter(prices_flat, demands_flat, c=product_ids, cmap='plasma', alpha=0.6, s=15, edgecolors='none') ax.scatter(
prices_flat,
demands_flat,
c=product_ids,
cmap="plasma",
alpha=0.6,
s=15,
edgecolors="none",
)
if len(prices_flat) > 1: if len(prices_flat) > 1:
z = np.polyfit(prices_flat, demands_flat, 1) z = np.polyfit(prices_flat, demands_flat, 1)
p_line = np.linspace(prices_flat.min(), prices_flat.max(), 50) p_line = np.linspace(prices_flat.min(), prices_flat.max(), 50)
ax.plot(p_line, np.polyval(z, p_line), '--', lw=1.5, alpha=0.8) ax.plot(p_line, np.polyval(z, p_line), "--", lw=1.5, alpha=0.8)
style_axis(ax, "Price-Demand Relationship", "Price ($)", "Demand") style_axis(ax, "Price-Demand Relationship", "Price ($)", "Demand")
def _render_elasticity_bar(self, env, elasticity): def _render_elasticity_bar(self, env, elasticity):
ax = self.fig.add_subplot(self.gs[0, 1]) ax = self.fig.add_subplot(self.gs[0, 1])
ax.barh(range(env.n_products), elasticity, alpha=0.8) ax.barh(range(env.n_products), elasticity, alpha=0.8)
ax.axvline(0, lw=0.8, alpha=0.5) ax.axvline(0, lw=0.8, alpha=0.5)
ax.axvline(-1, lw=1, ls='--', alpha=0.5) ax.axvline(-1, lw=1, ls="--", alpha=0.5)
ax.set_yticks(range(env.n_products)) ax.set_yticks(range(env.n_products))
ax.set_yticklabels([f'P{i}' for i in range(env.n_products)], fontsize=7) ax.set_yticklabels([f"P{i}" for i in range(env.n_products)], fontsize=7)
style_axis(ax, "Price Elasticity", "(dQ/dP)(P/Q)", None) style_axis(ax, "Price Elasticity", "(dQ/dP)(P/Q)", None)
def _render_session_pie(self, env): def _render_session_pie(self, env):
ax = self.fig.add_subplot(self.gs[0, 2]) ax = self.fig.add_subplot(self.gs[0, 2])
n_h, n_a = env.market.Nhumans, env.market.Nagents n_h, n_a = env.market.Nhumans, env.market.Nagents
wedges, _ = ax.pie([n_h, n_a], startangle=90, wedgeprops={'linewidth': 2, 'edgecolor': 'white'}) wedges, _ = ax.pie(
ax.legend(wedges, [f'H ({n_h})', f'A ({n_a})'], loc='lower center', fontsize=8, [n_h, n_a], startangle=90, wedgeprops={"linewidth": 2, "edgecolor": "white"}
frameon=False, bbox_to_anchor=(0.5, -0.05)) )
ax.set_title("Session Mix", fontsize=11, fontweight='bold') ax.legend(
wedges,
[f"H ({n_h})", f"A ({n_a})"],
loc="lower center",
fontsize=8,
frameon=False,
bbox_to_anchor=(0.5, -0.05),
)
ax.set_title("Session Mix", fontsize=11, fontweight="bold")
def _render_price_heatmap(self, price_mat): def _render_price_heatmap(self, price_mat):
ax = self.fig.add_subplot(self.gs[1, :2]) ax = self.fig.add_subplot(self.gs[1, :2])
im = ax.imshow(price_mat, aspect='auto', cmap='viridis', origin='lower') im = ax.imshow(price_mat, aspect="auto", cmap="viridis", origin="lower")
style_axis(ax, "Price Heatmap P(product, t)", "Step", "Product") style_axis(ax, "Price Heatmap P(product, t)", "Step", "Product")
cbar = self.fig.colorbar(im, ax=ax, fraction=0.03, pad=0.02) cbar = self.fig.colorbar(im, ax=ax, fraction=0.03, pad=0.02)
cbar.set_label('$', fontsize=8) cbar.set_label("$", fontsize=8)
def _render_demand_heatmap(self, demand_mat): def _render_demand_heatmap(self, demand_mat):
ax = self.fig.add_subplot(self.gs[1, 2]) ax = self.fig.add_subplot(self.gs[1, 2])
im = ax.imshow(demand_mat, aspect='auto', cmap='Blues', origin='lower') im = ax.imshow(demand_mat, aspect="auto", cmap="Blues", origin="lower")
style_axis(ax, "Demand Q(product, t)", "Step", None) style_axis(ax, "Demand Q(product, t)", "Step", None)
self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02) self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
@@ -92,11 +124,11 @@ class DashboardRenderer:
ax = self.fig.add_subplot(self.gs[2, 0]) ax = self.fig.add_subplot(self.gs[2, 0])
if price_mat.shape[1] > 2: if price_mat.shape[1] > 2:
corr = np.corrcoef(price_mat, demand_mat)[:n_products, n_products:] corr = np.corrcoef(price_mat, demand_mat)[:n_products, n_products:]
im = ax.imshow(corr, cmap='RdBu', vmin=-1, vmax=1, aspect='auto') im = ax.imshow(corr, cmap="RdBu", vmin=-1, vmax=1, aspect="auto")
ax.set_xticks(range(n_products)) ax.set_xticks(range(n_products))
ax.set_yticks(range(n_products)) ax.set_yticks(range(n_products))
ax.set_xticklabels([f'Q{i}' for i in range(n_products)], fontsize=6) ax.set_xticklabels([f"Q{i}" for i in range(n_products)], fontsize=6)
ax.set_yticklabels([f'P{i}' for i in range(n_products)], fontsize=6) ax.set_yticklabels([f"P{i}" for i in range(n_products)], fontsize=6)
self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02) self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
style_axis(ax, "Price-Demand Correlation", None, None) style_axis(ax, "Price-Demand Correlation", None, None)
@@ -105,20 +137,27 @@ class DashboardRenderer:
n_steps = len(env._revenue_history) n_steps = len(env._revenue_history)
demand_std = [np.std(d) for d in env._demand_history] demand_std = [np.std(d) for d in env._demand_history]
ax.fill_between(range(n_steps), env._revenue_history, alpha=0.3) ax.fill_between(range(n_steps), env._revenue_history, alpha=0.3)
ax.plot(env._revenue_history, linewidth=2, label='Revenue') ax.plot(env._revenue_history, linewidth=2, label="Revenue")
ax.set_xlim(0, max(n_steps, 1)) ax.set_xlim(0, max(n_steps, 1))
ax.set_ylim(0, max(env._revenue_history) * 1.1 if env._revenue_history else 1) ax.set_ylim(0, max(env._revenue_history) * 1.1 if env._revenue_history else 1)
ax2 = ax.twinx() ax2 = ax.twinx()
ax2.plot(range(n_steps), demand_std, linewidth=2, ls='-', alpha=0.9, label='sigma(Demand)') ax2.plot(
range(n_steps),
demand_std,
linewidth=2,
ls="-",
alpha=0.9,
label="sigma(Demand)",
)
d_min, d_max = min(demand_std), max(demand_std) d_min, d_max = min(demand_std), max(demand_std)
margin = (d_max - d_min) * 0.2 if d_max > d_min else 0.5 margin = (d_max - d_min) * 0.2 if d_max > d_min else 0.5
ax2.set_ylim(max(0, d_min - margin), d_max + margin) ax2.set_ylim(max(0, d_min - margin), d_max + margin)
ax2.set_ylabel('Demand sigma', fontsize=9) ax2.set_ylabel("Demand sigma", fontsize=9)
style_axis(ax, "Revenue & Demand Dispersion", "Step", "Revenue ($)") style_axis(ax, "Revenue & Demand Dispersion", "Step", "Revenue ($)")
ax.legend(loc='upper left', fontsize=7, frameon=False) ax.legend(loc="upper left", fontsize=7, frameon=False)
ax2.legend(loc='upper right', fontsize=7, frameon=False) ax2.legend(loc="upper right", fontsize=7, frameon=False)
def close(self): def close(self):
if self.fig: if self.fig:

View File

@@ -35,7 +35,6 @@ class EconomicMetricsWrapper(gym.Wrapper):
prices = self.env.unwrapped._prices prices = self.env.unwrapped._prices
demand_dict = self.env.unwrapped._demand demand_dict = self.env.unwrapped._demand
demand = np.array([demand_dict.get(i, 0.0) for i in range(len(prices))]) demand = np.array([demand_dict.get(i, 0.0) for i in range(len(prices))])
alpha = self.env.unwrapped.alpha
# core calculations # core calculations
revenue = float(np.sum(prices * demand)) revenue = float(np.sum(prices * demand))

View File

@@ -1,5 +1,7 @@
"""full factorial design - all factor combinations""" """full factorial design - all factor combinations"""
import sys import sys
sys.path.insert(0, "..") sys.path.insert(0, "..")
import logging import logging
from itertools import product from itertools import product
@@ -12,6 +14,7 @@ from .factors import FACTORS, DEMAND_FUNCTIONS, SEEDS_PER_CONFIG
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def generate_configs(): def generate_configs():
"""generate all factor combinations with seeds""" """generate all factor combinations with seeds"""
all_levels = [f.levels for f in FACTORS] all_levels = [f.levels for f in FACTORS]
@@ -22,10 +25,13 @@ def generate_configs():
base = {names[i]: combo[i] for i in range(len(names))} base = {names[i]: combo[i] for i in range(len(names))}
for seed in range(SEEDS_PER_CONFIG): for seed in range(SEEDS_PER_CONFIG):
cfg = {**base, "seed": seed} cfg = {**base, "seed": seed}
cfg["id"] = hashlib.md5(json.dumps(cfg, sort_keys=True).encode()).hexdigest()[:8] cfg["id"] = hashlib.md5(
json.dumps(cfg, sort_keys=True).encode()
).hexdigest()[:8]
configs.append(cfg) configs.append(cfg)
return configs return configs
def run_single(cfg: dict) -> dict: def run_single(cfg: dict) -> dict:
"""execute one experiment config, return metrics""" """execute one experiment config, return metrics"""
from engine.wrapper import PHANTOM from engine.wrapper import PHANTOM
@@ -49,7 +55,8 @@ def run_single(cfg: dict) -> dict:
obs, reward, term, trunc, _ = env.step(action) obs, reward, term, trunc, _ = env.step(action)
total_reward += reward total_reward += reward
steps += 1 steps += 1
if term: break if term:
break
env.close() env.close()
return { return {
@@ -60,22 +67,28 @@ def run_single(cfg: dict) -> dict:
"steps": steps, "steps": steps,
} }
def run_study(max_workers: int = None, output: str = "results_full.jsonl"): def run_study(max_workers: int = None, output: str = "results_full.jsonl"):
configs = generate_configs() configs = generate_configs()
log.info(f"full factorial: {len(configs)} configs ({len(configs)//SEEDS_PER_CONFIG} unique × {SEEDS_PER_CONFIG} seeds)") log.info(
f"full factorial: {len(configs)} configs ({len(configs) // SEEDS_PER_CONFIG} unique × {SEEDS_PER_CONFIG} seeds)"
)
results = [] results = []
with ProcessPoolExecutor(max_workers=max_workers) as ex: with ProcessPoolExecutor(max_workers=max_workers) as ex:
for i, result in enumerate(ex.map(run_single, configs)): for i, result in enumerate(ex.map(run_single, configs)):
results.append(result) results.append(result)
if (i+1) % 100 == 0: log.info(f"progress: {i+1}/{len(configs)}") if (i + 1) % 100 == 0:
log.info(f"progress: {i + 1}/{len(configs)}")
Path(output).write_text("\n".join(json.dumps(r) for r in results)) Path(output).write_text("\n".join(json.dumps(r) for r in results))
log.info(f"wrote {len(results)} results to {output}") log.info(f"wrote {len(results)} results to {output}")
return results return results
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
p = argparse.ArgumentParser() p = argparse.ArgumentParser()
p.add_argument("--workers", type=int, default=None) p.add_argument("--workers", type=int, default=None)
p.add_argument("--output", default="results_full.jsonl") p.add_argument("--output", default="results_full.jsonl")
@@ -83,7 +96,9 @@ if __name__ == "__main__":
args = p.parse_args() args = p.parse_args()
configs = generate_configs() configs = generate_configs()
log.info(f"design: {len(configs)} runs | factors: {[f.name for f in FACTORS]} | levels: {[len(f.levels) for f in FACTORS]}") log.info(
f"design: {len(configs)} runs | factors: {[f.name for f in FACTORS]} | levels: {[len(f.levels) for f in FACTORS]}"
)
if not args.dry_run: if not args.dry_run:
run_study(args.workers, args.output) run_study(args.workers, args.output)

View File

@@ -1,5 +1,7 @@
"""mixed design: full factorial on primary factors, latin hypercube on secondary""" """mixed design: full factorial on primary factors, latin hypercube on secondary"""
import sys import sys
sys.path.insert(0, "..") sys.path.insert(0, "..")
import logging import logging
from itertools import product from itertools import product
@@ -16,6 +18,7 @@ log = logging.getLogger(__name__)
LH_SAMPLES = 10 LH_SAMPLES = 10
def generate_configs(lh_samples: int = LH_SAMPLES): def generate_configs(lh_samples: int = LH_SAMPLES):
primary = [f for f in FACTORS if f.primary] primary = [f for f in FACTORS if f.primary]
secondary = [f for f in FACTORS if not f.primary] secondary = [f for f in FACTORS if not f.primary]
@@ -28,7 +31,9 @@ def generate_configs(lh_samples: int = LH_SAMPLES):
samples = lhs.random(n=lh_samples) samples = lhs.random(n=lh_samples)
for s in samples: for s in samples:
sec_vals = { sec_vals = {
secondary[i].name: secondary[i].levels[int(s[i] * len(secondary[i].levels))] secondary[i].name: secondary[i].levels[
int(s[i] * len(secondary[i].levels))
]
for i in range(len(secondary)) for i in range(len(secondary))
} }
base = {primary[i].name: p_combo[i] for i in range(len(primary))} base = {primary[i].name: p_combo[i] for i in range(len(primary))}
@@ -36,10 +41,13 @@ def generate_configs(lh_samples: int = LH_SAMPLES):
for seed in range(SEEDS_PER_CONFIG): for seed in range(SEEDS_PER_CONFIG):
cfg = {**base, "seed": seed} cfg = {**base, "seed": seed}
cfg["id"] = hashlib.md5(json.dumps(cfg, sort_keys=True).encode()).hexdigest()[:8] cfg["id"] = hashlib.md5(
json.dumps(cfg, sort_keys=True).encode()
).hexdigest()[:8]
configs.append(cfg) configs.append(cfg)
return configs return configs
def run_single(cfg: dict) -> dict: def run_single(cfg: dict) -> dict:
from engine.wrapper import PHANTOM from engine.wrapper import PHANTOM
import numpy as np import numpy as np
@@ -62,7 +70,8 @@ def run_single(cfg: dict) -> dict:
obs, reward, term, trunc, _ = env.step(action) obs, reward, term, trunc, _ = env.step(action)
total_reward += reward total_reward += reward
steps += 1 steps += 1
if term: break if term:
break
env.close() env.close()
return { return {
@@ -73,23 +82,33 @@ def run_single(cfg: dict) -> dict:
"steps": steps, "steps": steps,
} }
def run_study(max_workers: int = None, output: str = "results_mixed.jsonl", lh_samples: int = LH_SAMPLES):
def run_study(
max_workers: int = None,
output: str = "results_mixed.jsonl",
lh_samples: int = LH_SAMPLES,
):
configs = generate_configs(lh_samples) configs = generate_configs(lh_samples)
n_primary_cells = int(np.prod([len(f.levels) for f in FACTORS if f.primary])) n_primary_cells = int(np.prod([len(f.levels) for f in FACTORS if f.primary]))
log.info(f"mixed LH: {len(configs)} configs ({n_primary_cells} primary × {lh_samples} LH × {SEEDS_PER_CONFIG} seeds)") log.info(
f"mixed LH: {len(configs)} configs ({n_primary_cells} primary × {lh_samples} LH × {SEEDS_PER_CONFIG} seeds)"
)
results = [] results = []
with ProcessPoolExecutor(max_workers=max_workers) as ex: with ProcessPoolExecutor(max_workers=max_workers) as ex:
for i, result in enumerate(ex.map(run_single, configs)): for i, result in enumerate(ex.map(run_single, configs)):
results.append(result) results.append(result)
if (i+1) % 100 == 0: log.info(f"progress: {i+1}/{len(configs)}") if (i + 1) % 100 == 0:
log.info(f"progress: {i + 1}/{len(configs)}")
Path(output).write_text("\n".join(json.dumps(r) for r in results)) Path(output).write_text("\n".join(json.dumps(r) for r in results))
log.info(f"wrote {len(results)} results to {output}") log.info(f"wrote {len(results)} results to {output}")
return results return results
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
p = argparse.ArgumentParser() p = argparse.ArgumentParser()
p.add_argument("--workers", type=int, default=None) p.add_argument("--workers", type=int, default=None)
p.add_argument("--output", default="results_mixed.jsonl") p.add_argument("--output", default="results_mixed.jsonl")
@@ -100,7 +119,9 @@ if __name__ == "__main__":
primary = [f for f in FACTORS if f.primary] primary = [f for f in FACTORS if f.primary]
secondary = [f for f in FACTORS if not f.primary] secondary = [f for f in FACTORS if not f.primary]
configs = generate_configs(args.lh_samples) configs = generate_configs(args.lh_samples)
log.info(f"design: {len(configs)} runs | primary: {[f.name for f in primary]} | secondary (LH): {[f.name for f in secondary]}") log.info(
f"design: {len(configs)} runs | primary: {[f.name for f in primary]} | secondary (LH): {[f.name for f in secondary]}"
)
if not args.dry_run: if not args.dry_run:
run_study(args.workers, args.output, args.lh_samples) run_study(args.workers, args.output, args.lh_samples)

View File

@@ -4,8 +4,12 @@ import argparse
import json import json
import os import os
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np import numpy as np
if TYPE_CHECKING:
from .lib.discrete import EventQTable
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
try: try: