diff --git a/engine/jax/train.py b/engine/jax/train.py index e5c4c03..3860d8b 100644 --- a/engine/jax/train.py +++ b/engine/jax/train.py @@ -36,7 +36,7 @@ try: except ImportError: HAS_WANDB = False -from ..wandb_checkpoint import ( +from ..wandb_checkpoint import ( # noqa: E402 checkpoint_artifact_name, download_latest_checkpoint, log_checkpoint_bytes, @@ -83,7 +83,7 @@ except ImportError: HAS_JAX_STACK = False -from .env import PHANTOMJAXEnv, make_env_params +from .env import PHANTOMJAXEnv, make_env_params # noqa: E402 class ActorCritic(nn.Module): diff --git a/engine/lib/__init__.py b/engine/lib/__init__.py index c2fafc9..4bfb923 100644 --- a/engine/lib/__init__.py +++ b/engine/lib/__init__.py @@ -12,3 +12,27 @@ from .providers import ( ) from .coi import compute_uplift_coi, extract_purchases, compute_agent_probability from .discrete import EventQTable + +__all__ = [ + "estimate_demand", + "estimate_weighted_demand", + "generate_demand_for_actor", + "sample_behavior", + "get_transition_models", + "trajectory_to_events", + "DashboardRenderer", + "style_axis", + "EconomicMetricsWrapper", + "MetricsCallback", + "EvalMetricsCallback", + "CheckpointArtifactCallback", + "ProviderBenchmark", + "ProviderResult", + "BenchmarkConfig", + "RandomBaseline", + "SurgeBaseline", + "compute_uplift_coi", + "extract_purchases", + "compute_agent_probability", + "EventQTable", +] diff --git a/engine/lib/render.py b/engine/lib/render.py index a16f215..bb70ba5 100644 --- a/engine/lib/render.py +++ b/engine/lib/render.py @@ -1,15 +1,19 @@ """rendering logic for PHANTOM environment dashboard""" + import numpy as np import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpec def style_axis(ax, title: str = None, xlabel: str = None, ylabel: str = None): - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - if title: ax.set_title(title, fontsize=11, fontweight='bold', pad=8) - if xlabel: ax.set_xlabel(xlabel, fontsize=9) - if ylabel: ax.set_ylabel(ylabel, fontsize=9) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + if title: + ax.set_title(title, fontsize=11, fontweight="bold", pad=8) + if xlabel: + ax.set_xlabel(xlabel, fontsize=9) + if ylabel: + ax.set_ylabel(ylabel, fontsize=9) class DashboardRenderer: @@ -23,13 +27,25 @@ class DashboardRenderer: if self.fig is None: plt.ion() self.fig = plt.figure(figsize=(14, 10)) - self.gs = GridSpec(3, 3, figure=self.fig, hspace=0.35, wspace=0.3, - left=0.07, right=0.95, top=0.92, bottom=0.08) + self.gs = GridSpec( + 3, + 3, + figure=self.fig, + hspace=0.35, + wspace=0.3, + left=0.07, + right=0.95, + top=0.92, + bottom=0.08, + ) plt.show(block=False) self.fig.clear() - self.fig.suptitle(f'PHANTOM Market Dynamics [t={env._step_count}, a={env.alpha:.2f}]', - fontsize=14, fontweight='bold') + self.fig.suptitle( + f"PHANTOM Market Dynamics [t={env._step_count}, a={env.alpha:.2f}]", + fontsize=14, + fontweight="bold", + ) demand_mat = np.array(env._demand_history).T price_mat = np.array(env._price_history).T @@ -51,40 +67,56 @@ class DashboardRenderer: prices_flat = np.array(env._price_history).flatten() demands_flat = np.array(env._demand_history).flatten() product_ids = np.tile(np.arange(env.n_products), len(env._price_history)) - ax.scatter(prices_flat, demands_flat, c=product_ids, cmap='plasma', alpha=0.6, s=15, edgecolors='none') + ax.scatter( + prices_flat, + demands_flat, + c=product_ids, + cmap="plasma", + alpha=0.6, + s=15, + edgecolors="none", + ) if len(prices_flat) > 1: z = np.polyfit(prices_flat, demands_flat, 1) p_line = np.linspace(prices_flat.min(), prices_flat.max(), 50) - ax.plot(p_line, np.polyval(z, p_line), '--', lw=1.5, alpha=0.8) + ax.plot(p_line, np.polyval(z, p_line), "--", lw=1.5, alpha=0.8) style_axis(ax, "Price-Demand Relationship", "Price ($)", "Demand") def _render_elasticity_bar(self, env, elasticity): ax = self.fig.add_subplot(self.gs[0, 1]) ax.barh(range(env.n_products), elasticity, alpha=0.8) ax.axvline(0, lw=0.8, alpha=0.5) - ax.axvline(-1, lw=1, ls='--', alpha=0.5) + ax.axvline(-1, lw=1, ls="--", alpha=0.5) ax.set_yticks(range(env.n_products)) - ax.set_yticklabels([f'P{i}' for i in range(env.n_products)], fontsize=7) + ax.set_yticklabels([f"P{i}" for i in range(env.n_products)], fontsize=7) style_axis(ax, "Price Elasticity", "(dQ/dP)(P/Q)", None) def _render_session_pie(self, env): ax = self.fig.add_subplot(self.gs[0, 2]) n_h, n_a = env.market.Nhumans, env.market.Nagents - wedges, _ = ax.pie([n_h, n_a], startangle=90, wedgeprops={'linewidth': 2, 'edgecolor': 'white'}) - ax.legend(wedges, [f'H ({n_h})', f'A ({n_a})'], loc='lower center', fontsize=8, - frameon=False, bbox_to_anchor=(0.5, -0.05)) - ax.set_title("Session Mix", fontsize=11, fontweight='bold') + wedges, _ = ax.pie( + [n_h, n_a], startangle=90, wedgeprops={"linewidth": 2, "edgecolor": "white"} + ) + ax.legend( + wedges, + [f"H ({n_h})", f"A ({n_a})"], + loc="lower center", + fontsize=8, + frameon=False, + bbox_to_anchor=(0.5, -0.05), + ) + ax.set_title("Session Mix", fontsize=11, fontweight="bold") def _render_price_heatmap(self, price_mat): ax = self.fig.add_subplot(self.gs[1, :2]) - im = ax.imshow(price_mat, aspect='auto', cmap='viridis', origin='lower') + im = ax.imshow(price_mat, aspect="auto", cmap="viridis", origin="lower") style_axis(ax, "Price Heatmap P(product, t)", "Step", "Product") cbar = self.fig.colorbar(im, ax=ax, fraction=0.03, pad=0.02) - cbar.set_label('$', fontsize=8) + cbar.set_label("$", fontsize=8) def _render_demand_heatmap(self, demand_mat): ax = self.fig.add_subplot(self.gs[1, 2]) - im = ax.imshow(demand_mat, aspect='auto', cmap='Blues', origin='lower') + im = ax.imshow(demand_mat, aspect="auto", cmap="Blues", origin="lower") style_axis(ax, "Demand Q(product, t)", "Step", None) self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02) @@ -92,11 +124,11 @@ class DashboardRenderer: ax = self.fig.add_subplot(self.gs[2, 0]) if price_mat.shape[1] > 2: corr = np.corrcoef(price_mat, demand_mat)[:n_products, n_products:] - im = ax.imshow(corr, cmap='RdBu', vmin=-1, vmax=1, aspect='auto') + im = ax.imshow(corr, cmap="RdBu", vmin=-1, vmax=1, aspect="auto") ax.set_xticks(range(n_products)) ax.set_yticks(range(n_products)) - ax.set_xticklabels([f'Q{i}' for i in range(n_products)], fontsize=6) - ax.set_yticklabels([f'P{i}' for i in range(n_products)], fontsize=6) + ax.set_xticklabels([f"Q{i}" for i in range(n_products)], fontsize=6) + ax.set_yticklabels([f"P{i}" for i in range(n_products)], fontsize=6) self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02) style_axis(ax, "Price-Demand Correlation", None, None) @@ -105,20 +137,27 @@ class DashboardRenderer: n_steps = len(env._revenue_history) demand_std = [np.std(d) for d in env._demand_history] ax.fill_between(range(n_steps), env._revenue_history, alpha=0.3) - ax.plot(env._revenue_history, linewidth=2, label='Revenue') + ax.plot(env._revenue_history, linewidth=2, label="Revenue") ax.set_xlim(0, max(n_steps, 1)) ax.set_ylim(0, max(env._revenue_history) * 1.1 if env._revenue_history else 1) ax2 = ax.twinx() - ax2.plot(range(n_steps), demand_std, linewidth=2, ls='-', alpha=0.9, label='sigma(Demand)') + ax2.plot( + range(n_steps), + demand_std, + linewidth=2, + ls="-", + alpha=0.9, + label="sigma(Demand)", + ) d_min, d_max = min(demand_std), max(demand_std) margin = (d_max - d_min) * 0.2 if d_max > d_min else 0.5 ax2.set_ylim(max(0, d_min - margin), d_max + margin) - ax2.set_ylabel('Demand sigma', fontsize=9) + ax2.set_ylabel("Demand sigma", fontsize=9) style_axis(ax, "Revenue & Demand Dispersion", "Step", "Revenue ($)") - ax.legend(loc='upper left', fontsize=7, frameon=False) - ax2.legend(loc='upper right', fontsize=7, frameon=False) + ax.legend(loc="upper left", fontsize=7, frameon=False) + ax2.legend(loc="upper right", fontsize=7, frameon=False) def close(self): if self.fig: diff --git a/engine/lib/wrappers.py b/engine/lib/wrappers.py index 3d74b79..a1b464b 100644 --- a/engine/lib/wrappers.py +++ b/engine/lib/wrappers.py @@ -35,7 +35,6 @@ class EconomicMetricsWrapper(gym.Wrapper): prices = self.env.unwrapped._prices demand_dict = self.env.unwrapped._demand demand = np.array([demand_dict.get(i, 0.0) for i in range(len(prices))]) - alpha = self.env.unwrapped.alpha # core calculations revenue = float(np.sum(prices * demand)) diff --git a/engine/studies/full_factorial.py b/engine/studies/full_factorial.py index 92210b2..11947a9 100644 --- a/engine/studies/full_factorial.py +++ b/engine/studies/full_factorial.py @@ -1,5 +1,7 @@ """full factorial design - all factor combinations""" + import sys + sys.path.insert(0, "..") import logging from itertools import product @@ -12,6 +14,7 @@ from .factors import FACTORS, DEMAND_FUNCTIONS, SEEDS_PER_CONFIG logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger(__name__) + def generate_configs(): """generate all factor combinations with seeds""" all_levels = [f.levels for f in FACTORS] @@ -22,10 +25,13 @@ def generate_configs(): base = {names[i]: combo[i] for i in range(len(names))} for seed in range(SEEDS_PER_CONFIG): cfg = {**base, "seed": seed} - cfg["id"] = hashlib.md5(json.dumps(cfg, sort_keys=True).encode()).hexdigest()[:8] + cfg["id"] = hashlib.md5( + json.dumps(cfg, sort_keys=True).encode() + ).hexdigest()[:8] configs.append(cfg) return configs + def run_single(cfg: dict) -> dict: """execute one experiment config, return metrics""" from engine.wrapper import PHANTOM @@ -49,7 +55,8 @@ def run_single(cfg: dict) -> dict: obs, reward, term, trunc, _ = env.step(action) total_reward += reward steps += 1 - if term: break + if term: + break env.close() return { @@ -60,22 +67,28 @@ def run_single(cfg: dict) -> dict: "steps": steps, } + def run_study(max_workers: int = None, output: str = "results_full.jsonl"): configs = generate_configs() - log.info(f"full factorial: {len(configs)} configs ({len(configs)//SEEDS_PER_CONFIG} unique × {SEEDS_PER_CONFIG} seeds)") + log.info( + f"full factorial: {len(configs)} configs ({len(configs) // SEEDS_PER_CONFIG} unique × {SEEDS_PER_CONFIG} seeds)" + ) results = [] with ProcessPoolExecutor(max_workers=max_workers) as ex: for i, result in enumerate(ex.map(run_single, configs)): results.append(result) - if (i+1) % 100 == 0: log.info(f"progress: {i+1}/{len(configs)}") + if (i + 1) % 100 == 0: + log.info(f"progress: {i + 1}/{len(configs)}") Path(output).write_text("\n".join(json.dumps(r) for r in results)) log.info(f"wrote {len(results)} results to {output}") return results + if __name__ == "__main__": import argparse + p = argparse.ArgumentParser() p.add_argument("--workers", type=int, default=None) p.add_argument("--output", default="results_full.jsonl") @@ -83,7 +96,9 @@ if __name__ == "__main__": args = p.parse_args() configs = generate_configs() - log.info(f"design: {len(configs)} runs | factors: {[f.name for f in FACTORS]} | levels: {[len(f.levels) for f in FACTORS]}") + log.info( + f"design: {len(configs)} runs | factors: {[f.name for f in FACTORS]} | levels: {[len(f.levels) for f in FACTORS]}" + ) if not args.dry_run: run_study(args.workers, args.output) diff --git a/engine/studies/mixed_lh.py b/engine/studies/mixed_lh.py index 33ea2ee..3b7d7e8 100644 --- a/engine/studies/mixed_lh.py +++ b/engine/studies/mixed_lh.py @@ -1,5 +1,7 @@ """mixed design: full factorial on primary factors, latin hypercube on secondary""" + import sys + sys.path.insert(0, "..") import logging from itertools import product @@ -16,6 +18,7 @@ log = logging.getLogger(__name__) LH_SAMPLES = 10 + def generate_configs(lh_samples: int = LH_SAMPLES): primary = [f for f in FACTORS if f.primary] secondary = [f for f in FACTORS if not f.primary] @@ -28,7 +31,9 @@ def generate_configs(lh_samples: int = LH_SAMPLES): samples = lhs.random(n=lh_samples) for s in samples: sec_vals = { - secondary[i].name: secondary[i].levels[int(s[i] * len(secondary[i].levels))] + secondary[i].name: secondary[i].levels[ + int(s[i] * len(secondary[i].levels)) + ] for i in range(len(secondary)) } base = {primary[i].name: p_combo[i] for i in range(len(primary))} @@ -36,10 +41,13 @@ def generate_configs(lh_samples: int = LH_SAMPLES): for seed in range(SEEDS_PER_CONFIG): cfg = {**base, "seed": seed} - cfg["id"] = hashlib.md5(json.dumps(cfg, sort_keys=True).encode()).hexdigest()[:8] + cfg["id"] = hashlib.md5( + json.dumps(cfg, sort_keys=True).encode() + ).hexdigest()[:8] configs.append(cfg) return configs + def run_single(cfg: dict) -> dict: from engine.wrapper import PHANTOM import numpy as np @@ -62,7 +70,8 @@ def run_single(cfg: dict) -> dict: obs, reward, term, trunc, _ = env.step(action) total_reward += reward steps += 1 - if term: break + if term: + break env.close() return { @@ -73,23 +82,33 @@ def run_single(cfg: dict) -> dict: "steps": steps, } -def run_study(max_workers: int = None, output: str = "results_mixed.jsonl", lh_samples: int = LH_SAMPLES): + +def run_study( + max_workers: int = None, + output: str = "results_mixed.jsonl", + lh_samples: int = LH_SAMPLES, +): configs = generate_configs(lh_samples) n_primary_cells = int(np.prod([len(f.levels) for f in FACTORS if f.primary])) - log.info(f"mixed LH: {len(configs)} configs ({n_primary_cells} primary × {lh_samples} LH × {SEEDS_PER_CONFIG} seeds)") + log.info( + f"mixed LH: {len(configs)} configs ({n_primary_cells} primary × {lh_samples} LH × {SEEDS_PER_CONFIG} seeds)" + ) results = [] with ProcessPoolExecutor(max_workers=max_workers) as ex: for i, result in enumerate(ex.map(run_single, configs)): results.append(result) - if (i+1) % 100 == 0: log.info(f"progress: {i+1}/{len(configs)}") + if (i + 1) % 100 == 0: + log.info(f"progress: {i + 1}/{len(configs)}") Path(output).write_text("\n".join(json.dumps(r) for r in results)) log.info(f"wrote {len(results)} results to {output}") return results + if __name__ == "__main__": import argparse + p = argparse.ArgumentParser() p.add_argument("--workers", type=int, default=None) p.add_argument("--output", default="results_mixed.jsonl") @@ -100,7 +119,9 @@ if __name__ == "__main__": primary = [f for f in FACTORS if f.primary] secondary = [f for f in FACTORS if not f.primary] configs = generate_configs(args.lh_samples) - log.info(f"design: {len(configs)} runs | primary: {[f.name for f in primary]} | secondary (LH): {[f.name for f in secondary]}") + log.info( + f"design: {len(configs)} runs | primary: {[f.name for f in primary]} | secondary (LH): {[f.name for f in secondary]}" + ) if not args.dry_run: run_study(args.workers, args.output, args.lh_samples) diff --git a/engine/train.py b/engine/train.py index 063f4ae..4d52a50 100644 --- a/engine/train.py +++ b/engine/train.py @@ -4,8 +4,12 @@ import argparse import json import os from pathlib import Path +from typing import TYPE_CHECKING import numpy as np +if TYPE_CHECKING: + from .lib.discrete import EventQTable + from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint try: