From 28d3f6853e197954ebe43872763c2dde6419f8c4 Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Fri, 30 Jan 2026 13:17:12 +0100 Subject: [PATCH] chore: refactor wrapper --- engine/lib/__init__.py | 3 + engine/lib/render.py | 126 +++++++++++++++++++++++++++++++ engine/train.py | 45 +++++++++++ engine/wrapper.py | 165 +++++------------------------------------ 4 files changed, 193 insertions(+), 146 deletions(-) create mode 100644 engine/lib/__init__.py create mode 100644 engine/lib/render.py create mode 100644 engine/train.py diff --git a/engine/lib/__init__.py b/engine/lib/__init__.py new file mode 100644 index 0000000..8e17835 --- /dev/null +++ b/engine/lib/__init__.py @@ -0,0 +1,3 @@ +from .demand import generate_demand, estimate_demand +from .behavior import sample_behavior +from .render import DashboardRenderer, style_axis diff --git a/engine/lib/render.py b/engine/lib/render.py new file mode 100644 index 0000000..a16f215 --- /dev/null +++ b/engine/lib/render.py @@ -0,0 +1,126 @@ +"""rendering logic for PHANTOM environment dashboard""" +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.gridspec import GridSpec + + +def style_axis(ax, title: str = None, xlabel: str = None, ylabel: str = None): + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + if title: ax.set_title(title, fontsize=11, fontweight='bold', pad=8) + if xlabel: ax.set_xlabel(xlabel, fontsize=9) + if ylabel: ax.set_ylabel(ylabel, fontsize=9) + + +class DashboardRenderer: + """stateful renderer for PHANTOM market dynamics visualization""" + + def __init__(self): + self.fig = None + self.gs = None + + def render(self, env) -> None: + if self.fig is None: + plt.ion() + self.fig = plt.figure(figsize=(14, 10)) + self.gs = GridSpec(3, 3, figure=self.fig, hspace=0.35, wspace=0.3, + left=0.07, right=0.95, top=0.92, bottom=0.08) + plt.show(block=False) + + self.fig.clear() + self.fig.suptitle(f'PHANTOM Market Dynamics [t={env._step_count}, a={env.alpha:.2f}]', + fontsize=14, fontweight='bold') + + demand_mat = np.array(env._demand_history).T + price_mat = np.array(env._price_history).T + elasticity = env._compute_elasticity() + + self._render_scatter(env) + self._render_elasticity_bar(env, elasticity) + self._render_session_pie(env) + self._render_price_heatmap(price_mat) + self._render_demand_heatmap(demand_mat) + self._render_correlation(env.n_products, price_mat, demand_mat) + self._render_revenue(env) + + self.fig.canvas.draw_idle() + self.fig.canvas.flush_events() + + def _render_scatter(self, env): + ax = self.fig.add_subplot(self.gs[0, 0]) + prices_flat = np.array(env._price_history).flatten() + demands_flat = np.array(env._demand_history).flatten() + product_ids = np.tile(np.arange(env.n_products), len(env._price_history)) + ax.scatter(prices_flat, demands_flat, c=product_ids, cmap='plasma', alpha=0.6, s=15, edgecolors='none') + if len(prices_flat) > 1: + z = np.polyfit(prices_flat, demands_flat, 1) + p_line = np.linspace(prices_flat.min(), prices_flat.max(), 50) + ax.plot(p_line, np.polyval(z, p_line), '--', lw=1.5, alpha=0.8) + style_axis(ax, "Price-Demand Relationship", "Price ($)", "Demand") + + def _render_elasticity_bar(self, env, elasticity): + ax = self.fig.add_subplot(self.gs[0, 1]) + ax.barh(range(env.n_products), elasticity, alpha=0.8) + ax.axvline(0, lw=0.8, alpha=0.5) + ax.axvline(-1, lw=1, ls='--', alpha=0.5) + ax.set_yticks(range(env.n_products)) + ax.set_yticklabels([f'P{i}' for i in range(env.n_products)], fontsize=7) + style_axis(ax, "Price Elasticity", "(dQ/dP)(P/Q)", None) + + def _render_session_pie(self, env): + ax = self.fig.add_subplot(self.gs[0, 2]) + n_h, n_a = env.market.Nhumans, env.market.Nagents + wedges, _ = ax.pie([n_h, n_a], startangle=90, wedgeprops={'linewidth': 2, 'edgecolor': 'white'}) + ax.legend(wedges, [f'H ({n_h})', f'A ({n_a})'], loc='lower center', fontsize=8, + frameon=False, bbox_to_anchor=(0.5, -0.05)) + ax.set_title("Session Mix", fontsize=11, fontweight='bold') + + def _render_price_heatmap(self, price_mat): + ax = self.fig.add_subplot(self.gs[1, :2]) + im = ax.imshow(price_mat, aspect='auto', cmap='viridis', origin='lower') + style_axis(ax, "Price Heatmap P(product, t)", "Step", "Product") + cbar = self.fig.colorbar(im, ax=ax, fraction=0.03, pad=0.02) + cbar.set_label('$', fontsize=8) + + def _render_demand_heatmap(self, demand_mat): + ax = self.fig.add_subplot(self.gs[1, 2]) + im = ax.imshow(demand_mat, aspect='auto', cmap='Blues', origin='lower') + style_axis(ax, "Demand Q(product, t)", "Step", None) + self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02) + + def _render_correlation(self, n_products, price_mat, demand_mat): + ax = self.fig.add_subplot(self.gs[2, 0]) + if price_mat.shape[1] > 2: + corr = np.corrcoef(price_mat, demand_mat)[:n_products, n_products:] + im = ax.imshow(corr, cmap='RdBu', vmin=-1, vmax=1, aspect='auto') + ax.set_xticks(range(n_products)) + ax.set_yticks(range(n_products)) + ax.set_xticklabels([f'Q{i}' for i in range(n_products)], fontsize=6) + ax.set_yticklabels([f'P{i}' for i in range(n_products)], fontsize=6) + self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02) + style_axis(ax, "Price-Demand Correlation", None, None) + + def _render_revenue(self, env): + ax = self.fig.add_subplot(self.gs[2, 1:]) + n_steps = len(env._revenue_history) + demand_std = [np.std(d) for d in env._demand_history] + ax.fill_between(range(n_steps), env._revenue_history, alpha=0.3) + ax.plot(env._revenue_history, linewidth=2, label='Revenue') + ax.set_xlim(0, max(n_steps, 1)) + ax.set_ylim(0, max(env._revenue_history) * 1.1 if env._revenue_history else 1) + + ax2 = ax.twinx() + ax2.plot(range(n_steps), demand_std, linewidth=2, ls='-', alpha=0.9, label='sigma(Demand)') + d_min, d_max = min(demand_std), max(demand_std) + margin = (d_max - d_min) * 0.2 if d_max > d_min else 0.5 + ax2.set_ylim(max(0, d_min - margin), d_max + margin) + ax2.set_ylabel('Demand sigma', fontsize=9) + + style_axis(ax, "Revenue & Demand Dispersion", "Step", "Revenue ($)") + ax.legend(loc='upper left', fontsize=7, frameon=False) + ax2.legend(loc='upper right', fontsize=7, frameon=False) + + def close(self): + if self.fig: + plt.close(self.fig) + self.fig = None diff --git a/engine/train.py b/engine/train.py new file mode 100644 index 0000000..496ecfd --- /dev/null +++ b/engine/train.py @@ -0,0 +1,45 @@ +from stable_baselines3 import SAC +from stable_baselines3.common.callbacks import EvalCallback, BaseCallback +from .wrapper import PHANTOM + + +class RenderCallback(BaseCallback): + """Renders environment on every step for live visualization.""" + def __init__(self, env: PHANTOM): + super().__init__() + self.env = env + + def _on_step(self) -> bool: + self.env.render() + return True + + +env = PHANTOM(n_products=10, alpha=0.3, render_mode="human") +eval_env = PHANTOM(n_products=10, alpha=0.3, render_mode=None) + +model = SAC( + "MultiInputPolicy", + env, + verbose=1, + learning_rate=3e-4, + buffer_size=50000, + batch_size=256, + tau=0.005, + gamma=0.99, +) + +render_cb = RenderCallback(env) +eval_cb = EvalCallback(eval_env, eval_freq=1000, n_eval_episodes=5, verbose=1) + +model.learn(total_timesteps=50000, callback=[render_cb, eval_cb]) +model.save("phantom_sac") + +# test trained policy +env = PHANTOM(n_products=10, alpha=0.3, render_mode="human") +obs, _ = env.reset() +for _ in range(100): + action, _ = model.predict(obs, deterministic=True) + obs, reward, term, trunc, _ = env.step(action) + env.render() + if term or trunc: break +env.close() diff --git a/engine/wrapper.py b/engine/wrapper.py index 7637998..0301082 100644 --- a/engine/wrapper.py +++ b/engine/wrapper.py @@ -1,10 +1,8 @@ import gymnasium as gym from gymnasium import spaces import numpy as np -import matplotlib.pyplot as plt -from matplotlib.gridspec import GridSpec -import matplotlib.colors as mcolors from .engine import Limbo, MarketEngine, PricingEngine +from .lib.render import DashboardRenderer class PHANTOM(gym.Env): @@ -16,7 +14,7 @@ class PHANTOM(gym.Env): alpha: float = 0.3, N: int = 100, price_bounds: tuple = (10.0, 150.0), - lambda_coi: float = 0.1, # coi leakage penalty weight + lambda_coi: float = 0.1, render_mode: str = None): super().__init__() self.n_products = n_products @@ -30,12 +28,10 @@ class PHANTOM(gym.Env): self._platform_stub = PricingEngine() self._limbo = Limbo(self._platform_stub, self.market) - # action: continuous prices for each product self.action_space = spaces.Box( low=price_bounds[0], high=price_bounds[1], shape=(n_products,), dtype=np.float32 ) - # observation: demand estimate + previous prices self.observation_space = spaces.Dict({ "demand": spaces.Box(low=0.0, high=100.0, shape=(n_products,), dtype=np.float32), "prices": spaces.Box(low=price_bounds[0], high=price_bounds[1], shape=(n_products,), dtype=np.float32), @@ -47,30 +43,22 @@ class PHANTOM(gym.Env): self._demand_history = [] self._price_history = [] self._revenue_history = [] - self._fig = None - self._gs = None - self._dashboard_colors = { - 'bg': '#f5f0e8', 'panel': '#ebe3d5', 'accent': '#c9b99a', - 'text': '#3d3229', 'green': '#5c7a5c', 'red': '#8b4049', - 'blue': '#5a7384', 'orange': '#b87333', 'purple': '#7d6b7d' - } + self._renderer = None def _get_obs(self) -> dict: demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)], dtype=np.float32) return {"demand": demand_arr, "prices": self._prices.astype(np.float32)} def _compute_reward(self, prices: np.ndarray, demand: dict) -> float: - demand_arr = np.array([demand.get(i, 0.0) for i in range(self.n_products)]) - revenue = np.sum(prices * demand_arr) # revenue = price * quantity proxy - base_price = self.price_bounds[0] - return float(revenue)# - self.lambda_coi * coi_leak) + revenue = np.sum(prices * np.array([demand.get(i, 0.0) for i in range(self.n_products)])) + # TODO: implement supra-competitive price punishment + return float(revenue) def _record_history(self): demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)]) self._demand_history.append(demand_arr) self._price_history.append(self._prices.copy()) - revenue = np.sum(self._prices * demand_arr) - self._revenue_history.append(revenue) + self._revenue_history.append(np.sum(self._prices * demand_arr)) def reset(self, seed=None, options=None): super().reset(seed=seed) @@ -89,149 +77,34 @@ class PHANTOM(gym.Env): reward = self._compute_reward(self._prices, self._demand) terminated = self._step_count >= 100 - truncated = False - return self._get_obs(), reward, terminated, truncated, {"step": self._step_count} + return self._get_obs(), reward, terminated, False, {"step": self._step_count} def _compute_elasticity(self) -> np.ndarray: - """point elasticity: e = (dQ/dP) * (P/Q) estimated via finite differences, clipped to [-5, 5]""" + """point elasticity: e = (dQ/dP) * (P/Q) via finite differences, clipped to [-5, 5]""" if len(self._price_history) < 2: return np.zeros(self.n_products) - p = np.array(self._price_history) - q = np.array(self._demand_history) - dp = np.diff(p, axis=0) - dq = np.diff(q, axis=0) - min_dp = 0.5 # ignore tiny price changes to avoid explosions - valid = np.abs(dp) > min_dp + p, q = np.array(self._price_history), np.array(self._demand_history) + dp, dq = np.diff(p, axis=0), np.diff(q, axis=0) + valid = np.abs(dp) > 0.5 with np.errstate(divide='ignore', invalid='ignore'): elasticity = np.where(valid, (dq / dp) * (p[:-1] / np.maximum(q[:-1], 1.0)), 0.0) - elasticity = np.clip(elasticity, -5.0, 5.0) - elasticity = np.nan_to_num(elasticity, nan=0.0) + elasticity = np.nan_to_num(np.clip(elasticity, -5.0, 5.0), nan=0.0) return np.mean(elasticity, axis=0) if len(elasticity) > 0 else np.zeros(self.n_products) - def _style_axis(self, ax, title: str = None, xlabel: str = None, ylabel: str = None): - c = self._dashboard_colors - ax.set_facecolor(c['panel']) - ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False) - ax.spines['bottom'].set_color(c['accent']); ax.spines['left'].set_color(c['accent']) - ax.tick_params(colors=c['text'], labelsize=8) - if title: ax.set_title(title, color=c['text'], fontsize=11, fontweight='bold', pad=8) - if xlabel: ax.set_xlabel(xlabel, color=c['text'], fontsize=9) - if ylabel: ax.set_ylabel(ylabel, color=c['text'], fontsize=9) - def render(self): if self.render_mode == "human": - c = self._dashboard_colors - if self._fig is None: - plt.ion() - self._fig = plt.figure(figsize=(14, 10), facecolor=c['bg']) - self._gs = GridSpec(3, 3, figure=self._fig, hspace=0.35, wspace=0.3, - left=0.07, right=0.95, top=0.92, bottom=0.08) - plt.show(block=False) - - self._fig.clear() - self._fig.suptitle(f'PHANTOM Market Dynamics [t={self._step_count}, α={self.alpha:.2f}]', - color=c['text'], fontsize=14, fontweight='bold') - - demand_mat = np.array(self._demand_history).T - price_mat = np.array(self._price_history).T - elasticity = self._compute_elasticity() - cmap = mcolors.LinearSegmentedColormap.from_list('phantom', [c['bg'], c['blue'], c['green']]) - cmap_div = mcolors.LinearSegmentedColormap.from_list('elast', [c['red'], c['bg'], c['blue']]) - - # price-demand elasticity scatter (all historical data points) - ax_elast = self._fig.add_subplot(self._gs[0, 0]) - prices_flat = np.array(self._price_history).flatten() - demands_flat = np.array(self._demand_history).flatten() - product_ids = np.tile(np.arange(self.n_products), len(self._price_history)) - scatter = ax_elast.scatter(prices_flat, demands_flat, c=product_ids, cmap='plasma', - alpha=0.6, s=15, edgecolors='none') - if len(prices_flat) > 1: # fit regression line - z = np.polyfit(prices_flat, demands_flat, 1) - p_line = np.linspace(prices_flat.min(), prices_flat.max(), 50) - ax_elast.plot(p_line, np.polyval(z, p_line), '--', color=c['red'], lw=1.5, alpha=0.8) - self._style_axis(ax_elast, "Price-Demand Relationship", "Price ($)", "Demand") - - # elasticity coefficients bar - ax_ebar = self._fig.add_subplot(self._gs[0, 1]) - colors_e = [c['red'] if e < -0.5 else c['blue'] if e > 0.5 else c['accent'] for e in elasticity] - ax_ebar.barh(range(self.n_products), elasticity, color=colors_e, alpha=0.8, edgecolor=c['bg']) - ax_ebar.axvline(0, color=c['text'], lw=0.8, alpha=0.5) - ax_ebar.axvline(-1, color=c['red'], lw=1, ls='--', alpha=0.5) # unit elastic reference - ax_ebar.set_yticks(range(self.n_products)) - ax_ebar.set_yticklabels([f'P{i}' for i in range(self.n_products)], fontsize=7) - self._style_axis(ax_ebar, "Price Elasticity ε", "ε = (ΔQ/ΔP)·(P/Q)", None) - - # session composition pie - ax_pie = self._fig.add_subplot(self._gs[0, 2]) - n_humans, n_agents = self.market.Nhumans, self.market.Nagents - ax_pie.set_facecolor(c['panel']) - wedges, _ = ax_pie.pie([n_humans, n_agents], colors=[c['blue'], c['red']], - startangle=90, wedgeprops={'linewidth': 2, 'edgecolor': c['bg']}) - ax_pie.legend(wedges, [f'H ({n_humans})', f'A ({n_agents})'], - loc='lower center', fontsize=8, frameon=False, - labelcolor=c['text'], bbox_to_anchor=(0.5, -0.05)) - ax_pie.set_title("Session Mix", color=c['text'], fontsize=11, fontweight='bold') - - # price heatmap over time - ax_pheat = self._fig.add_subplot(self._gs[1, :2]) - im_p = ax_pheat.imshow(price_mat, aspect='auto', cmap='viridis', origin='lower') - self._style_axis(ax_pheat, "Price Heatmap P(product, t)", "Step", "Product") - cbar_p = self._fig.colorbar(im_p, ax=ax_pheat, fraction=0.03, pad=0.02) - cbar_p.ax.tick_params(colors=c['text'], labelsize=7) - cbar_p.set_label('$', color=c['text'], fontsize=8) - - # demand heatmap over time - ax_dheat = self._fig.add_subplot(self._gs[1, 2]) - im_d = ax_dheat.imshow(demand_mat, aspect='auto', cmap=cmap, origin='lower') - self._style_axis(ax_dheat, "Demand Q(product, t)", "Step", None) - cbar_d = self._fig.colorbar(im_d, ax=ax_dheat, fraction=0.046, pad=0.02) - cbar_d.ax.tick_params(colors=c['text'], labelsize=7) - - # cross-correlation matrix (price-demand covariance per product) - ax_corr = self._fig.add_subplot(self._gs[2, 0]) - if len(self._price_history) > 2: - corr_mat = np.corrcoef(price_mat, demand_mat)[:self.n_products, self.n_products:] - im_corr = ax_corr.imshow(corr_mat, cmap=cmap_div, vmin=-1, vmax=1, aspect='auto') - ax_corr.set_xticks(range(self.n_products)) - ax_corr.set_yticks(range(self.n_products)) - ax_corr.set_xticklabels([f'Q{i}' for i in range(self.n_products)], fontsize=6) - ax_corr.set_yticklabels([f'P{i}' for i in range(self.n_products)], fontsize=6) - cbar_c = self._fig.colorbar(im_corr, ax=ax_corr, fraction=0.046, pad=0.02) - cbar_c.ax.tick_params(colors=c['text'], labelsize=7) - self._style_axis(ax_corr, "Price-Demand Correlation", None, None) - - # revenue curve with demand dispersion (std dev shows concentration) - ax_rev = self._fig.add_subplot(self._gs[2, 1:]) - n_steps = len(self._revenue_history) - demand_std = [np.std(d) for d in self._demand_history] - ax_rev.fill_between(range(n_steps), self._revenue_history, alpha=0.3, color=c['green']) - ax_rev.plot(self._revenue_history, color=c['green'], linewidth=2, label='Revenue') - ax_rev.set_xlim(0, max(n_steps, 1)) - ax_rev.set_ylim(0, max(self._revenue_history) * 1.1 if self._revenue_history else 1) - ax2 = ax_rev.twinx() - ax2.plot(range(n_steps), demand_std, color=c['blue'], linewidth=2, ls='-', alpha=0.9, label='σ(Demand)') - d_min, d_max = min(demand_std), max(demand_std) - margin = (d_max - d_min) * 0.2 if d_max > d_min else 0.5 - ax2.set_ylim(max(0, d_min - margin), d_max + margin) - ax2.tick_params(axis='y', colors=c['blue'], labelsize=8) - ax2.spines['right'].set_color(c['blue']) - ax2.set_ylabel('Demand σ', color=c['blue'], fontsize=9) - self._style_axis(ax_rev, "Revenue & Demand Dispersion", "Step", "Revenue ($)") - ax_rev.legend(loc='upper left', fontsize=7, frameon=False, labelcolor=c['text']) - ax2.legend(loc='upper right', fontsize=7, frameon=False, labelcolor=c['text']) - - self._fig.canvas.draw_idle() - self._fig.canvas.flush_events() - plt.pause(0.05) - + if self._renderer is None: + self._renderer = DashboardRenderer() + self._renderer.render(self) elif self.render_mode == "ansi": return f"step={self._step_count}, prices={self._prices}, demand={self._demand}" return None def close(self): - if self._fig: plt.close(self._fig) - self._fig = None + if self._renderer: + self._renderer.close() + self._renderer = None if __name__ == "__main__":