mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: cleaning the code
This commit is contained in:
@@ -36,7 +36,7 @@ try:
|
||||
except ImportError:
|
||||
HAS_WANDB = False
|
||||
|
||||
from ..wandb_checkpoint import (
|
||||
from ..wandb_checkpoint import ( # noqa: E402
|
||||
checkpoint_artifact_name,
|
||||
download_latest_checkpoint,
|
||||
log_checkpoint_bytes,
|
||||
@@ -83,7 +83,7 @@ except ImportError:
|
||||
|
||||
HAS_JAX_STACK = False
|
||||
|
||||
from .env import PHANTOMJAXEnv, make_env_params
|
||||
from .env import PHANTOMJAXEnv, make_env_params # noqa: E402
|
||||
|
||||
|
||||
class ActorCritic(nn.Module):
|
||||
|
||||
@@ -12,3 +12,27 @@ from .providers import (
|
||||
)
|
||||
from .coi import compute_uplift_coi, extract_purchases, compute_agent_probability
|
||||
from .discrete import EventQTable
|
||||
|
||||
__all__ = [
|
||||
"estimate_demand",
|
||||
"estimate_weighted_demand",
|
||||
"generate_demand_for_actor",
|
||||
"sample_behavior",
|
||||
"get_transition_models",
|
||||
"trajectory_to_events",
|
||||
"DashboardRenderer",
|
||||
"style_axis",
|
||||
"EconomicMetricsWrapper",
|
||||
"MetricsCallback",
|
||||
"EvalMetricsCallback",
|
||||
"CheckpointArtifactCallback",
|
||||
"ProviderBenchmark",
|
||||
"ProviderResult",
|
||||
"BenchmarkConfig",
|
||||
"RandomBaseline",
|
||||
"SurgeBaseline",
|
||||
"compute_uplift_coi",
|
||||
"extract_purchases",
|
||||
"compute_agent_probability",
|
||||
"EventQTable",
|
||||
]
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
"""rendering logic for PHANTOM environment dashboard"""
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.gridspec import GridSpec
|
||||
|
||||
|
||||
def style_axis(ax, title: str = None, xlabel: str = None, ylabel: str = None):
|
||||
ax.spines['top'].set_visible(False)
|
||||
ax.spines['right'].set_visible(False)
|
||||
if title: ax.set_title(title, fontsize=11, fontweight='bold', pad=8)
|
||||
if xlabel: ax.set_xlabel(xlabel, fontsize=9)
|
||||
if ylabel: ax.set_ylabel(ylabel, fontsize=9)
|
||||
ax.spines["top"].set_visible(False)
|
||||
ax.spines["right"].set_visible(False)
|
||||
if title:
|
||||
ax.set_title(title, fontsize=11, fontweight="bold", pad=8)
|
||||
if xlabel:
|
||||
ax.set_xlabel(xlabel, fontsize=9)
|
||||
if ylabel:
|
||||
ax.set_ylabel(ylabel, fontsize=9)
|
||||
|
||||
|
||||
class DashboardRenderer:
|
||||
@@ -23,13 +27,25 @@ class DashboardRenderer:
|
||||
if self.fig is None:
|
||||
plt.ion()
|
||||
self.fig = plt.figure(figsize=(14, 10))
|
||||
self.gs = GridSpec(3, 3, figure=self.fig, hspace=0.35, wspace=0.3,
|
||||
left=0.07, right=0.95, top=0.92, bottom=0.08)
|
||||
self.gs = GridSpec(
|
||||
3,
|
||||
3,
|
||||
figure=self.fig,
|
||||
hspace=0.35,
|
||||
wspace=0.3,
|
||||
left=0.07,
|
||||
right=0.95,
|
||||
top=0.92,
|
||||
bottom=0.08,
|
||||
)
|
||||
plt.show(block=False)
|
||||
|
||||
self.fig.clear()
|
||||
self.fig.suptitle(f'PHANTOM Market Dynamics [t={env._step_count}, a={env.alpha:.2f}]',
|
||||
fontsize=14, fontweight='bold')
|
||||
self.fig.suptitle(
|
||||
f"PHANTOM Market Dynamics [t={env._step_count}, a={env.alpha:.2f}]",
|
||||
fontsize=14,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
demand_mat = np.array(env._demand_history).T
|
||||
price_mat = np.array(env._price_history).T
|
||||
@@ -51,40 +67,56 @@ class DashboardRenderer:
|
||||
prices_flat = np.array(env._price_history).flatten()
|
||||
demands_flat = np.array(env._demand_history).flatten()
|
||||
product_ids = np.tile(np.arange(env.n_products), len(env._price_history))
|
||||
ax.scatter(prices_flat, demands_flat, c=product_ids, cmap='plasma', alpha=0.6, s=15, edgecolors='none')
|
||||
ax.scatter(
|
||||
prices_flat,
|
||||
demands_flat,
|
||||
c=product_ids,
|
||||
cmap="plasma",
|
||||
alpha=0.6,
|
||||
s=15,
|
||||
edgecolors="none",
|
||||
)
|
||||
if len(prices_flat) > 1:
|
||||
z = np.polyfit(prices_flat, demands_flat, 1)
|
||||
p_line = np.linspace(prices_flat.min(), prices_flat.max(), 50)
|
||||
ax.plot(p_line, np.polyval(z, p_line), '--', lw=1.5, alpha=0.8)
|
||||
ax.plot(p_line, np.polyval(z, p_line), "--", lw=1.5, alpha=0.8)
|
||||
style_axis(ax, "Price-Demand Relationship", "Price ($)", "Demand")
|
||||
|
||||
def _render_elasticity_bar(self, env, elasticity):
|
||||
ax = self.fig.add_subplot(self.gs[0, 1])
|
||||
ax.barh(range(env.n_products), elasticity, alpha=0.8)
|
||||
ax.axvline(0, lw=0.8, alpha=0.5)
|
||||
ax.axvline(-1, lw=1, ls='--', alpha=0.5)
|
||||
ax.axvline(-1, lw=1, ls="--", alpha=0.5)
|
||||
ax.set_yticks(range(env.n_products))
|
||||
ax.set_yticklabels([f'P{i}' for i in range(env.n_products)], fontsize=7)
|
||||
ax.set_yticklabels([f"P{i}" for i in range(env.n_products)], fontsize=7)
|
||||
style_axis(ax, "Price Elasticity", "(dQ/dP)(P/Q)", None)
|
||||
|
||||
def _render_session_pie(self, env):
|
||||
ax = self.fig.add_subplot(self.gs[0, 2])
|
||||
n_h, n_a = env.market.Nhumans, env.market.Nagents
|
||||
wedges, _ = ax.pie([n_h, n_a], startangle=90, wedgeprops={'linewidth': 2, 'edgecolor': 'white'})
|
||||
ax.legend(wedges, [f'H ({n_h})', f'A ({n_a})'], loc='lower center', fontsize=8,
|
||||
frameon=False, bbox_to_anchor=(0.5, -0.05))
|
||||
ax.set_title("Session Mix", fontsize=11, fontweight='bold')
|
||||
wedges, _ = ax.pie(
|
||||
[n_h, n_a], startangle=90, wedgeprops={"linewidth": 2, "edgecolor": "white"}
|
||||
)
|
||||
ax.legend(
|
||||
wedges,
|
||||
[f"H ({n_h})", f"A ({n_a})"],
|
||||
loc="lower center",
|
||||
fontsize=8,
|
||||
frameon=False,
|
||||
bbox_to_anchor=(0.5, -0.05),
|
||||
)
|
||||
ax.set_title("Session Mix", fontsize=11, fontweight="bold")
|
||||
|
||||
def _render_price_heatmap(self, price_mat):
|
||||
ax = self.fig.add_subplot(self.gs[1, :2])
|
||||
im = ax.imshow(price_mat, aspect='auto', cmap='viridis', origin='lower')
|
||||
im = ax.imshow(price_mat, aspect="auto", cmap="viridis", origin="lower")
|
||||
style_axis(ax, "Price Heatmap P(product, t)", "Step", "Product")
|
||||
cbar = self.fig.colorbar(im, ax=ax, fraction=0.03, pad=0.02)
|
||||
cbar.set_label('$', fontsize=8)
|
||||
cbar.set_label("$", fontsize=8)
|
||||
|
||||
def _render_demand_heatmap(self, demand_mat):
|
||||
ax = self.fig.add_subplot(self.gs[1, 2])
|
||||
im = ax.imshow(demand_mat, aspect='auto', cmap='Blues', origin='lower')
|
||||
im = ax.imshow(demand_mat, aspect="auto", cmap="Blues", origin="lower")
|
||||
style_axis(ax, "Demand Q(product, t)", "Step", None)
|
||||
self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
|
||||
|
||||
@@ -92,11 +124,11 @@ class DashboardRenderer:
|
||||
ax = self.fig.add_subplot(self.gs[2, 0])
|
||||
if price_mat.shape[1] > 2:
|
||||
corr = np.corrcoef(price_mat, demand_mat)[:n_products, n_products:]
|
||||
im = ax.imshow(corr, cmap='RdBu', vmin=-1, vmax=1, aspect='auto')
|
||||
im = ax.imshow(corr, cmap="RdBu", vmin=-1, vmax=1, aspect="auto")
|
||||
ax.set_xticks(range(n_products))
|
||||
ax.set_yticks(range(n_products))
|
||||
ax.set_xticklabels([f'Q{i}' for i in range(n_products)], fontsize=6)
|
||||
ax.set_yticklabels([f'P{i}' for i in range(n_products)], fontsize=6)
|
||||
ax.set_xticklabels([f"Q{i}" for i in range(n_products)], fontsize=6)
|
||||
ax.set_yticklabels([f"P{i}" for i in range(n_products)], fontsize=6)
|
||||
self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
|
||||
style_axis(ax, "Price-Demand Correlation", None, None)
|
||||
|
||||
@@ -105,20 +137,27 @@ class DashboardRenderer:
|
||||
n_steps = len(env._revenue_history)
|
||||
demand_std = [np.std(d) for d in env._demand_history]
|
||||
ax.fill_between(range(n_steps), env._revenue_history, alpha=0.3)
|
||||
ax.plot(env._revenue_history, linewidth=2, label='Revenue')
|
||||
ax.plot(env._revenue_history, linewidth=2, label="Revenue")
|
||||
ax.set_xlim(0, max(n_steps, 1))
|
||||
ax.set_ylim(0, max(env._revenue_history) * 1.1 if env._revenue_history else 1)
|
||||
|
||||
ax2 = ax.twinx()
|
||||
ax2.plot(range(n_steps), demand_std, linewidth=2, ls='-', alpha=0.9, label='sigma(Demand)')
|
||||
ax2.plot(
|
||||
range(n_steps),
|
||||
demand_std,
|
||||
linewidth=2,
|
||||
ls="-",
|
||||
alpha=0.9,
|
||||
label="sigma(Demand)",
|
||||
)
|
||||
d_min, d_max = min(demand_std), max(demand_std)
|
||||
margin = (d_max - d_min) * 0.2 if d_max > d_min else 0.5
|
||||
ax2.set_ylim(max(0, d_min - margin), d_max + margin)
|
||||
ax2.set_ylabel('Demand sigma', fontsize=9)
|
||||
ax2.set_ylabel("Demand sigma", fontsize=9)
|
||||
|
||||
style_axis(ax, "Revenue & Demand Dispersion", "Step", "Revenue ($)")
|
||||
ax.legend(loc='upper left', fontsize=7, frameon=False)
|
||||
ax2.legend(loc='upper right', fontsize=7, frameon=False)
|
||||
ax.legend(loc="upper left", fontsize=7, frameon=False)
|
||||
ax2.legend(loc="upper right", fontsize=7, frameon=False)
|
||||
|
||||
def close(self):
|
||||
if self.fig:
|
||||
|
||||
@@ -35,7 +35,6 @@ class EconomicMetricsWrapper(gym.Wrapper):
|
||||
prices = self.env.unwrapped._prices
|
||||
demand_dict = self.env.unwrapped._demand
|
||||
demand = np.array([demand_dict.get(i, 0.0) for i in range(len(prices))])
|
||||
alpha = self.env.unwrapped.alpha
|
||||
|
||||
# core calculations
|
||||
revenue = float(np.sum(prices * demand))
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""full factorial design - all factor combinations"""
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "..")
|
||||
import logging
|
||||
from itertools import product
|
||||
@@ -12,6 +14,7 @@ from .factors import FACTORS, DEMAND_FUNCTIONS, SEEDS_PER_CONFIG
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_configs():
|
||||
"""generate all factor combinations with seeds"""
|
||||
all_levels = [f.levels for f in FACTORS]
|
||||
@@ -22,10 +25,13 @@ def generate_configs():
|
||||
base = {names[i]: combo[i] for i in range(len(names))}
|
||||
for seed in range(SEEDS_PER_CONFIG):
|
||||
cfg = {**base, "seed": seed}
|
||||
cfg["id"] = hashlib.md5(json.dumps(cfg, sort_keys=True).encode()).hexdigest()[:8]
|
||||
cfg["id"] = hashlib.md5(
|
||||
json.dumps(cfg, sort_keys=True).encode()
|
||||
).hexdigest()[:8]
|
||||
configs.append(cfg)
|
||||
return configs
|
||||
|
||||
|
||||
def run_single(cfg: dict) -> dict:
|
||||
"""execute one experiment config, return metrics"""
|
||||
from engine.wrapper import PHANTOM
|
||||
@@ -49,7 +55,8 @@ def run_single(cfg: dict) -> dict:
|
||||
obs, reward, term, trunc, _ = env.step(action)
|
||||
total_reward += reward
|
||||
steps += 1
|
||||
if term: break
|
||||
if term:
|
||||
break
|
||||
|
||||
env.close()
|
||||
return {
|
||||
@@ -60,22 +67,28 @@ def run_single(cfg: dict) -> dict:
|
||||
"steps": steps,
|
||||
}
|
||||
|
||||
|
||||
def run_study(max_workers: int = None, output: str = "results_full.jsonl"):
|
||||
configs = generate_configs()
|
||||
log.info(f"full factorial: {len(configs)} configs ({len(configs)//SEEDS_PER_CONFIG} unique × {SEEDS_PER_CONFIG} seeds)")
|
||||
log.info(
|
||||
f"full factorial: {len(configs)} configs ({len(configs) // SEEDS_PER_CONFIG} unique × {SEEDS_PER_CONFIG} seeds)"
|
||||
)
|
||||
|
||||
results = []
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as ex:
|
||||
for i, result in enumerate(ex.map(run_single, configs)):
|
||||
results.append(result)
|
||||
if (i+1) % 100 == 0: log.info(f"progress: {i+1}/{len(configs)}")
|
||||
if (i + 1) % 100 == 0:
|
||||
log.info(f"progress: {i + 1}/{len(configs)}")
|
||||
|
||||
Path(output).write_text("\n".join(json.dumps(r) for r in results))
|
||||
log.info(f"wrote {len(results)} results to {output}")
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--workers", type=int, default=None)
|
||||
p.add_argument("--output", default="results_full.jsonl")
|
||||
@@ -83,7 +96,9 @@ if __name__ == "__main__":
|
||||
args = p.parse_args()
|
||||
|
||||
configs = generate_configs()
|
||||
log.info(f"design: {len(configs)} runs | factors: {[f.name for f in FACTORS]} | levels: {[len(f.levels) for f in FACTORS]}")
|
||||
log.info(
|
||||
f"design: {len(configs)} runs | factors: {[f.name for f in FACTORS]} | levels: {[len(f.levels) for f in FACTORS]}"
|
||||
)
|
||||
|
||||
if not args.dry_run:
|
||||
run_study(args.workers, args.output)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""mixed design: full factorial on primary factors, latin hypercube on secondary"""
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "..")
|
||||
import logging
|
||||
from itertools import product
|
||||
@@ -16,6 +18,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
LH_SAMPLES = 10
|
||||
|
||||
|
||||
def generate_configs(lh_samples: int = LH_SAMPLES):
|
||||
primary = [f for f in FACTORS if f.primary]
|
||||
secondary = [f for f in FACTORS if not f.primary]
|
||||
@@ -28,7 +31,9 @@ def generate_configs(lh_samples: int = LH_SAMPLES):
|
||||
samples = lhs.random(n=lh_samples)
|
||||
for s in samples:
|
||||
sec_vals = {
|
||||
secondary[i].name: secondary[i].levels[int(s[i] * len(secondary[i].levels))]
|
||||
secondary[i].name: secondary[i].levels[
|
||||
int(s[i] * len(secondary[i].levels))
|
||||
]
|
||||
for i in range(len(secondary))
|
||||
}
|
||||
base = {primary[i].name: p_combo[i] for i in range(len(primary))}
|
||||
@@ -36,10 +41,13 @@ def generate_configs(lh_samples: int = LH_SAMPLES):
|
||||
|
||||
for seed in range(SEEDS_PER_CONFIG):
|
||||
cfg = {**base, "seed": seed}
|
||||
cfg["id"] = hashlib.md5(json.dumps(cfg, sort_keys=True).encode()).hexdigest()[:8]
|
||||
cfg["id"] = hashlib.md5(
|
||||
json.dumps(cfg, sort_keys=True).encode()
|
||||
).hexdigest()[:8]
|
||||
configs.append(cfg)
|
||||
return configs
|
||||
|
||||
|
||||
def run_single(cfg: dict) -> dict:
|
||||
from engine.wrapper import PHANTOM
|
||||
import numpy as np
|
||||
@@ -62,7 +70,8 @@ def run_single(cfg: dict) -> dict:
|
||||
obs, reward, term, trunc, _ = env.step(action)
|
||||
total_reward += reward
|
||||
steps += 1
|
||||
if term: break
|
||||
if term:
|
||||
break
|
||||
|
||||
env.close()
|
||||
return {
|
||||
@@ -73,23 +82,33 @@ def run_single(cfg: dict) -> dict:
|
||||
"steps": steps,
|
||||
}
|
||||
|
||||
def run_study(max_workers: int = None, output: str = "results_mixed.jsonl", lh_samples: int = LH_SAMPLES):
|
||||
|
||||
def run_study(
|
||||
max_workers: int = None,
|
||||
output: str = "results_mixed.jsonl",
|
||||
lh_samples: int = LH_SAMPLES,
|
||||
):
|
||||
configs = generate_configs(lh_samples)
|
||||
n_primary_cells = int(np.prod([len(f.levels) for f in FACTORS if f.primary]))
|
||||
log.info(f"mixed LH: {len(configs)} configs ({n_primary_cells} primary × {lh_samples} LH × {SEEDS_PER_CONFIG} seeds)")
|
||||
log.info(
|
||||
f"mixed LH: {len(configs)} configs ({n_primary_cells} primary × {lh_samples} LH × {SEEDS_PER_CONFIG} seeds)"
|
||||
)
|
||||
|
||||
results = []
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as ex:
|
||||
for i, result in enumerate(ex.map(run_single, configs)):
|
||||
results.append(result)
|
||||
if (i+1) % 100 == 0: log.info(f"progress: {i+1}/{len(configs)}")
|
||||
if (i + 1) % 100 == 0:
|
||||
log.info(f"progress: {i + 1}/{len(configs)}")
|
||||
|
||||
Path(output).write_text("\n".join(json.dumps(r) for r in results))
|
||||
log.info(f"wrote {len(results)} results to {output}")
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--workers", type=int, default=None)
|
||||
p.add_argument("--output", default="results_mixed.jsonl")
|
||||
@@ -100,7 +119,9 @@ if __name__ == "__main__":
|
||||
primary = [f for f in FACTORS if f.primary]
|
||||
secondary = [f for f in FACTORS if not f.primary]
|
||||
configs = generate_configs(args.lh_samples)
|
||||
log.info(f"design: {len(configs)} runs | primary: {[f.name for f in primary]} | secondary (LH): {[f.name for f in secondary]}")
|
||||
log.info(
|
||||
f"design: {len(configs)} runs | primary: {[f.name for f in primary]} | secondary (LH): {[f.name for f in secondary]}"
|
||||
)
|
||||
|
||||
if not args.dry_run:
|
||||
run_study(args.workers, args.output, args.lh_samples)
|
||||
|
||||
@@ -4,8 +4,12 @@ import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lib.discrete import EventQTable
|
||||
|
||||
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user