chore: refactor wrapper

This commit is contained in:
2026-01-30 13:17:12 +01:00
parent 10e8397eec
commit 28d3f6853e
4 changed files with 193 additions and 146 deletions

3
engine/lib/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .demand import generate_demand, estimate_demand
from .behavior import sample_behavior
from .render import DashboardRenderer, style_axis

126
engine/lib/render.py Normal file
View File

@@ -0,0 +1,126 @@
"""rendering logic for PHANTOM environment dashboard"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
def style_axis(ax, title: str = None, xlabel: str = None, ylabel: str = None):
    """Apply the dashboard's shared look to a plot axis.

    Hides the top and right spines, then sets the title and the x/y labels
    that were supplied. Falsy values (``None`` or an empty string) are
    skipped, so an axis keeps whatever title/labels it already had.
    """
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    if title:
        ax.set_title(title, fontsize=11, fontweight='bold', pad=8)
    # Both axis labels share the same font size, so handle them together.
    for apply_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        if text:
            apply_label(text, fontsize=9)
class DashboardRenderer:
    """Stateful renderer for PHANTOM market dynamics visualization.

    Owns one matplotlib figure laid out on a 3x3 grid. Every ``render`` call
    clears the figure and redraws all panels from the history lists stored
    on the environment.
    """

    def __init__(self):
        # Both are created lazily by the first render() call.
        self.fig = None
        self.gs = None

    def render(self, env) -> None:
        """Redraw the whole dashboard from the current state of ``env``.

        Reads ``env._step_count``, ``env.alpha``, ``env.n_products``,
        ``env.market`` and the ``_price_history`` / ``_demand_history`` /
        ``_revenue_history`` sequences, and calls
        ``env._compute_elasticity()``.

        When no price or demand history has been recorded yet, only the
        figure title is drawn: the heatmap and dispersion panels cannot be
        built from empty data.
        """
        if self.fig is None:
            plt.ion()
            self.fig = plt.figure(figsize=(14, 10))
            self.gs = GridSpec(3, 3, figure=self.fig, hspace=0.35, wspace=0.3,
                               left=0.07, right=0.95, top=0.92, bottom=0.08)
            plt.show(block=False)
        self.fig.clear()
        self.fig.suptitle(f'PHANTOM Market Dynamics [t={env._step_count}, a={env.alpha:.2f}]',
                          fontsize=14, fontweight='bold')
        if len(env._price_history) > 0 and len(env._demand_history) > 0:
            # Transpose so each row is one product and each column one step.
            demand_mat = np.array(env._demand_history).T
            price_mat = np.array(env._price_history).T
            elasticity = env._compute_elasticity()
            self._render_scatter(env)
            self._render_elasticity_bar(env, elasticity)
            self._render_session_pie(env)
            self._render_price_heatmap(price_mat)
            self._render_demand_heatmap(demand_mat)
            self._render_correlation(env.n_products, price_mat, demand_mat)
            self._render_revenue(env)
        self.fig.canvas.draw_idle()
        self.fig.canvas.flush_events()

    def _render_scatter(self, env):
        """Scatter every recorded (price, demand) pair with a linear trend line."""
        ax = self.fig.add_subplot(self.gs[0, 0])
        prices_flat = np.array(env._price_history).flatten()
        demands_flat = np.array(env._demand_history).flatten()
        # Colour each point by product index; the id pattern repeats per step.
        product_ids = np.tile(np.arange(env.n_products), len(env._price_history))
        ax.scatter(prices_flat, demands_flat, c=product_ids, cmap='plasma',
                   alpha=0.6, s=15, edgecolors='none')
        # A degree-1 fit needs at least two distinct prices; identical prices
        # make the least-squares problem rank-deficient.
        if len(prices_flat) > 1 and prices_flat.max() > prices_flat.min():
            z = np.polyfit(prices_flat, demands_flat, 1)
            p_line = np.linspace(prices_flat.min(), prices_flat.max(), 50)
            ax.plot(p_line, np.polyval(z, p_line), '--', lw=1.5, alpha=0.8)
        style_axis(ax, "Price-Demand Relationship", "Price ($)", "Demand")

    def _render_elasticity_bar(self, env, elasticity):
        """Horizontal bar per product showing the supplied elasticity values."""
        ax = self.fig.add_subplot(self.gs[0, 1])
        ax.barh(range(env.n_products), elasticity, alpha=0.8)
        ax.axvline(0, lw=0.8, alpha=0.5)
        ax.axvline(-1, lw=1, ls='--', alpha=0.5)  # unit-elastic reference
        ax.set_yticks(range(env.n_products))
        ax.set_yticklabels([f'P{i}' for i in range(env.n_products)], fontsize=7)
        style_axis(ax, "Price Elasticity", "(dQ/dP)(P/Q)", None)

    def _render_session_pie(self, env):
        """Pie chart of ``env.market.Nhumans`` versus ``env.market.Nagents``."""
        ax = self.fig.add_subplot(self.gs[0, 2])
        n_h, n_a = env.market.Nhumans, env.market.Nagents
        wedges, _ = ax.pie([n_h, n_a], startangle=90,
                           wedgeprops={'linewidth': 2, 'edgecolor': 'white'})
        ax.legend(wedges, [f'H ({n_h})', f'A ({n_a})'], loc='lower center', fontsize=8,
                  frameon=False, bbox_to_anchor=(0.5, -0.05))
        ax.set_title("Session Mix", fontsize=11, fontweight='bold')

    def _render_price_heatmap(self, price_mat):
        """Heatmap of prices with products on the y-axis and steps on the x-axis."""
        ax = self.fig.add_subplot(self.gs[1, :2])
        im = ax.imshow(price_mat, aspect='auto', cmap='viridis', origin='lower')
        style_axis(ax, "Price Heatmap P(product, t)", "Step", "Product")
        cbar = self.fig.colorbar(im, ax=ax, fraction=0.03, pad=0.02)
        cbar.set_label('$', fontsize=8)

    def _render_demand_heatmap(self, demand_mat):
        """Heatmap of demand with products on the y-axis and steps on the x-axis."""
        ax = self.fig.add_subplot(self.gs[1, 2])
        im = ax.imshow(demand_mat, aspect='auto', cmap='Blues', origin='lower')
        style_axis(ax, "Demand Q(product, t)", "Step", None)
        self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)

    def _render_correlation(self, n_products, price_mat, demand_mat):
        """Price-vs-demand correlation matrix; drawn once more than two steps exist.

        Assumes ``price_mat`` and ``demand_mat`` each have ``n_products`` rows,
        so the upper-right block of the stacked correlation matrix holds the
        price/demand cross terms.
        """
        ax = self.fig.add_subplot(self.gs[2, 0])
        if price_mat.shape[1] > 2:
            corr = np.corrcoef(price_mat, demand_mat)[:n_products, n_products:]
            im = ax.imshow(corr, cmap='RdBu', vmin=-1, vmax=1, aspect='auto')
            ax.set_xticks(range(n_products))
            ax.set_yticks(range(n_products))
            ax.set_xticklabels([f'Q{i}' for i in range(n_products)], fontsize=6)
            ax.set_yticklabels([f'P{i}' for i in range(n_products)], fontsize=6)
            self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
        style_axis(ax, "Price-Demand Correlation", None, None)

    def _render_revenue(self, env):
        """Revenue curve with the per-step demand standard deviation on a twin axis."""
        ax = self.fig.add_subplot(self.gs[2, 1:])
        revenue = env._revenue_history
        n_steps = len(revenue)
        demand_std = [np.std(d) for d in env._demand_history]
        ax.fill_between(range(n_steps), revenue, alpha=0.3)
        ax.plot(revenue, linewidth=2, label='Revenue')
        ax.set_xlim(0, max(n_steps, 1))
        # Keep a non-degenerate y-range when revenue is empty or all zero.
        y_top = max(revenue) * 1.1 if revenue else 1
        ax.set_ylim(0, y_top if y_top > 0 else 1)
        ax2 = ax.twinx()
        if demand_std:
            # Index by the demand series' own length so a mismatch with the
            # revenue history cannot break the plot call.
            ax2.plot(range(len(demand_std)), demand_std, linewidth=2, ls='-',
                     alpha=0.9, label='sigma(Demand)')
            d_min, d_max = min(demand_std), max(demand_std)
            # Pad the range by 20%, or by a fixed 0.5 when the series is flat.
            margin = (d_max - d_min) * 0.2 if d_max > d_min else 0.5
            ax2.set_ylim(max(0, d_min - margin), d_max + margin)
        ax2.set_ylabel('Demand sigma', fontsize=9)
        style_axis(ax, "Revenue & Demand Dispersion", "Step", "Revenue ($)")
        ax.legend(loc='upper left', fontsize=7, frameon=False)
        if demand_std:
            ax2.legend(loc='upper right', fontsize=7, frameon=False)

    def close(self):
        """Close the figure, if one was created, and reset the renderer state."""
        if self.fig is not None:
            plt.close(self.fig)
            self.fig = None
            # The grid spec is bound to the closed figure; drop it as well.
            self.gs = None

45
engine/train.py Normal file
View File

@@ -0,0 +1,45 @@
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
from .wrapper import PHANTOM
class RenderCallback(BaseCallback):
    """Renders the environment during training for live visualization.

    Args:
        env: Environment whose ``render()`` method is called from the callback.
        render_freq: Render once every ``render_freq`` callback steps. The
            default of 1 renders on every step.

    Raises:
        ValueError: If ``render_freq`` is smaller than 1.
    """

    def __init__(self, env: PHANTOM, render_freq: int = 1):
        super().__init__()
        if render_freq < 1:
            raise ValueError("render_freq must be >= 1")
        self.env = env
        self.render_freq = render_freq

    def _on_step(self) -> bool:
        # n_calls is maintained by BaseCallback and counts callback invocations.
        if self.n_calls % self.render_freq == 0:
            self.env.render()
        # Returning True lets training continue.
        return True
# Training environment renders live; the evaluation environment stays headless.
env = PHANTOM(n_products=10, alpha=0.3, render_mode="human")
eval_env = PHANTOM(n_products=10, alpha=0.3, render_mode=None)
model = SAC(
    "MultiInputPolicy",
    env,
    verbose=1,
    learning_rate=3e-4,
    buffer_size=50000,
    batch_size=256,
    tau=0.005,
    gamma=0.99,
)
render_cb = RenderCallback(env)
eval_cb = EvalCallback(eval_env, eval_freq=1000, n_eval_episodes=5, verbose=1)
try:
    model.learn(total_timesteps=50000, callback=[render_cb, eval_cb])
    model.save("phantom_sac")
finally:
    # Release both environments before `env` is rebound below; otherwise the
    # training environment is never closed.
    env.close()
    eval_env.close()

# test trained policy
env = PHANTOM(n_products=10, alpha=0.3, render_mode="human")
try:
    obs, _ = env.reset()
    for _ in range(100):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, term, trunc, _ = env.step(action)
        env.render()
        if term or trunc:
            break
finally:
    env.close()

View File

@@ -1,10 +1,8 @@
import gymnasium as gym import gymnasium as gym
from gymnasium import spaces from gymnasium import spaces
import numpy as np import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.colors as mcolors
from .engine import Limbo, MarketEngine, PricingEngine from .engine import Limbo, MarketEngine, PricingEngine
from .lib.render import DashboardRenderer
class PHANTOM(gym.Env): class PHANTOM(gym.Env):
@@ -16,7 +14,7 @@ class PHANTOM(gym.Env):
alpha: float = 0.3, alpha: float = 0.3,
N: int = 100, N: int = 100,
price_bounds: tuple = (10.0, 150.0), price_bounds: tuple = (10.0, 150.0),
lambda_coi: float = 0.1, # coi leakage penalty weight lambda_coi: float = 0.1,
render_mode: str = None): render_mode: str = None):
super().__init__() super().__init__()
self.n_products = n_products self.n_products = n_products
@@ -30,12 +28,10 @@ class PHANTOM(gym.Env):
self._platform_stub = PricingEngine() self._platform_stub = PricingEngine()
self._limbo = Limbo(self._platform_stub, self.market) self._limbo = Limbo(self._platform_stub, self.market)
# action: continuous prices for each product
self.action_space = spaces.Box( self.action_space = spaces.Box(
low=price_bounds[0], high=price_bounds[1], low=price_bounds[0], high=price_bounds[1],
shape=(n_products,), dtype=np.float32 shape=(n_products,), dtype=np.float32
) )
# observation: demand estimate + previous prices
self.observation_space = spaces.Dict({ self.observation_space = spaces.Dict({
"demand": spaces.Box(low=0.0, high=100.0, shape=(n_products,), dtype=np.float32), "demand": spaces.Box(low=0.0, high=100.0, shape=(n_products,), dtype=np.float32),
"prices": spaces.Box(low=price_bounds[0], high=price_bounds[1], shape=(n_products,), dtype=np.float32), "prices": spaces.Box(low=price_bounds[0], high=price_bounds[1], shape=(n_products,), dtype=np.float32),
@@ -47,30 +43,22 @@ class PHANTOM(gym.Env):
self._demand_history = [] self._demand_history = []
self._price_history = [] self._price_history = []
self._revenue_history = [] self._revenue_history = []
self._fig = None self._renderer = None
self._gs = None
self._dashboard_colors = {
'bg': '#f5f0e8', 'panel': '#ebe3d5', 'accent': '#c9b99a',
'text': '#3d3229', 'green': '#5c7a5c', 'red': '#8b4049',
'blue': '#5a7384', 'orange': '#b87333', 'purple': '#7d6b7d'
}
def _get_obs(self) -> dict: def _get_obs(self) -> dict:
demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)], dtype=np.float32) demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)], dtype=np.float32)
return {"demand": demand_arr, "prices": self._prices.astype(np.float32)} return {"demand": demand_arr, "prices": self._prices.astype(np.float32)}
def _compute_reward(self, prices: np.ndarray, demand: dict) -> float: def _compute_reward(self, prices: np.ndarray, demand: dict) -> float:
demand_arr = np.array([demand.get(i, 0.0) for i in range(self.n_products)]) revenue = np.sum(prices * np.array([demand.get(i, 0.0) for i in range(self.n_products)]))
revenue = np.sum(prices * demand_arr) # revenue = price * quantity proxy # TODO: implement supra-competitive price punishment
base_price = self.price_bounds[0] return float(revenue)
return float(revenue)# - self.lambda_coi * coi_leak)
def _record_history(self): def _record_history(self):
demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)]) demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)])
self._demand_history.append(demand_arr) self._demand_history.append(demand_arr)
self._price_history.append(self._prices.copy()) self._price_history.append(self._prices.copy())
revenue = np.sum(self._prices * demand_arr) self._revenue_history.append(np.sum(self._prices * demand_arr))
self._revenue_history.append(revenue)
def reset(self, seed=None, options=None): def reset(self, seed=None, options=None):
super().reset(seed=seed) super().reset(seed=seed)
@@ -89,149 +77,34 @@ class PHANTOM(gym.Env):
reward = self._compute_reward(self._prices, self._demand) reward = self._compute_reward(self._prices, self._demand)
terminated = self._step_count >= 100 terminated = self._step_count >= 100
truncated = False
return self._get_obs(), reward, terminated, truncated, {"step": self._step_count} return self._get_obs(), reward, terminated, False, {"step": self._step_count}
def _compute_elasticity(self) -> np.ndarray: def _compute_elasticity(self) -> np.ndarray:
"""point elasticity: e = (dQ/dP) * (P/Q) estimated via finite differences, clipped to [-5, 5]""" """point elasticity: e = (dQ/dP) * (P/Q) via finite differences, clipped to [-5, 5]"""
if len(self._price_history) < 2: if len(self._price_history) < 2:
return np.zeros(self.n_products) return np.zeros(self.n_products)
p = np.array(self._price_history) p, q = np.array(self._price_history), np.array(self._demand_history)
q = np.array(self._demand_history) dp, dq = np.diff(p, axis=0), np.diff(q, axis=0)
dp = np.diff(p, axis=0) valid = np.abs(dp) > 0.5
dq = np.diff(q, axis=0)
min_dp = 0.5 # ignore tiny price changes to avoid explosions
valid = np.abs(dp) > min_dp
with np.errstate(divide='ignore', invalid='ignore'): with np.errstate(divide='ignore', invalid='ignore'):
elasticity = np.where(valid, (dq / dp) * (p[:-1] / np.maximum(q[:-1], 1.0)), 0.0) elasticity = np.where(valid, (dq / dp) * (p[:-1] / np.maximum(q[:-1], 1.0)), 0.0)
elasticity = np.clip(elasticity, -5.0, 5.0) elasticity = np.nan_to_num(np.clip(elasticity, -5.0, 5.0), nan=0.0)
elasticity = np.nan_to_num(elasticity, nan=0.0)
return np.mean(elasticity, axis=0) if len(elasticity) > 0 else np.zeros(self.n_products) return np.mean(elasticity, axis=0) if len(elasticity) > 0 else np.zeros(self.n_products)
def _style_axis(self, ax, title: str = None, xlabel: str = None, ylabel: str = None):
c = self._dashboard_colors
ax.set_facecolor(c['panel'])
ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_color(c['accent']); ax.spines['left'].set_color(c['accent'])
ax.tick_params(colors=c['text'], labelsize=8)
if title: ax.set_title(title, color=c['text'], fontsize=11, fontweight='bold', pad=8)
if xlabel: ax.set_xlabel(xlabel, color=c['text'], fontsize=9)
if ylabel: ax.set_ylabel(ylabel, color=c['text'], fontsize=9)
def render(self): def render(self):
if self.render_mode == "human": if self.render_mode == "human":
c = self._dashboard_colors if self._renderer is None:
if self._fig is None: self._renderer = DashboardRenderer()
plt.ion() self._renderer.render(self)
self._fig = plt.figure(figsize=(14, 10), facecolor=c['bg'])
self._gs = GridSpec(3, 3, figure=self._fig, hspace=0.35, wspace=0.3,
left=0.07, right=0.95, top=0.92, bottom=0.08)
plt.show(block=False)
self._fig.clear()
self._fig.suptitle(f'PHANTOM Market Dynamics [t={self._step_count}, α={self.alpha:.2f}]',
color=c['text'], fontsize=14, fontweight='bold')
demand_mat = np.array(self._demand_history).T
price_mat = np.array(self._price_history).T
elasticity = self._compute_elasticity()
cmap = mcolors.LinearSegmentedColormap.from_list('phantom', [c['bg'], c['blue'], c['green']])
cmap_div = mcolors.LinearSegmentedColormap.from_list('elast', [c['red'], c['bg'], c['blue']])
# price-demand elasticity scatter (all historical data points)
ax_elast = self._fig.add_subplot(self._gs[0, 0])
prices_flat = np.array(self._price_history).flatten()
demands_flat = np.array(self._demand_history).flatten()
product_ids = np.tile(np.arange(self.n_products), len(self._price_history))
scatter = ax_elast.scatter(prices_flat, demands_flat, c=product_ids, cmap='plasma',
alpha=0.6, s=15, edgecolors='none')
if len(prices_flat) > 1: # fit regression line
z = np.polyfit(prices_flat, demands_flat, 1)
p_line = np.linspace(prices_flat.min(), prices_flat.max(), 50)
ax_elast.plot(p_line, np.polyval(z, p_line), '--', color=c['red'], lw=1.5, alpha=0.8)
self._style_axis(ax_elast, "Price-Demand Relationship", "Price ($)", "Demand")
# elasticity coefficients bar
ax_ebar = self._fig.add_subplot(self._gs[0, 1])
colors_e = [c['red'] if e < -0.5 else c['blue'] if e > 0.5 else c['accent'] for e in elasticity]
ax_ebar.barh(range(self.n_products), elasticity, color=colors_e, alpha=0.8, edgecolor=c['bg'])
ax_ebar.axvline(0, color=c['text'], lw=0.8, alpha=0.5)
ax_ebar.axvline(-1, color=c['red'], lw=1, ls='--', alpha=0.5) # unit elastic reference
ax_ebar.set_yticks(range(self.n_products))
ax_ebar.set_yticklabels([f'P{i}' for i in range(self.n_products)], fontsize=7)
self._style_axis(ax_ebar, "Price Elasticity ε", "ε = (ΔQ/ΔP)·(P/Q)", None)
# session composition pie
ax_pie = self._fig.add_subplot(self._gs[0, 2])
n_humans, n_agents = self.market.Nhumans, self.market.Nagents
ax_pie.set_facecolor(c['panel'])
wedges, _ = ax_pie.pie([n_humans, n_agents], colors=[c['blue'], c['red']],
startangle=90, wedgeprops={'linewidth': 2, 'edgecolor': c['bg']})
ax_pie.legend(wedges, [f'H ({n_humans})', f'A ({n_agents})'],
loc='lower center', fontsize=8, frameon=False,
labelcolor=c['text'], bbox_to_anchor=(0.5, -0.05))
ax_pie.set_title("Session Mix", color=c['text'], fontsize=11, fontweight='bold')
# price heatmap over time
ax_pheat = self._fig.add_subplot(self._gs[1, :2])
im_p = ax_pheat.imshow(price_mat, aspect='auto', cmap='viridis', origin='lower')
self._style_axis(ax_pheat, "Price Heatmap P(product, t)", "Step", "Product")
cbar_p = self._fig.colorbar(im_p, ax=ax_pheat, fraction=0.03, pad=0.02)
cbar_p.ax.tick_params(colors=c['text'], labelsize=7)
cbar_p.set_label('$', color=c['text'], fontsize=8)
# demand heatmap over time
ax_dheat = self._fig.add_subplot(self._gs[1, 2])
im_d = ax_dheat.imshow(demand_mat, aspect='auto', cmap=cmap, origin='lower')
self._style_axis(ax_dheat, "Demand Q(product, t)", "Step", None)
cbar_d = self._fig.colorbar(im_d, ax=ax_dheat, fraction=0.046, pad=0.02)
cbar_d.ax.tick_params(colors=c['text'], labelsize=7)
# cross-correlation matrix (price-demand covariance per product)
ax_corr = self._fig.add_subplot(self._gs[2, 0])
if len(self._price_history) > 2:
corr_mat = np.corrcoef(price_mat, demand_mat)[:self.n_products, self.n_products:]
im_corr = ax_corr.imshow(corr_mat, cmap=cmap_div, vmin=-1, vmax=1, aspect='auto')
ax_corr.set_xticks(range(self.n_products))
ax_corr.set_yticks(range(self.n_products))
ax_corr.set_xticklabels([f'Q{i}' for i in range(self.n_products)], fontsize=6)
ax_corr.set_yticklabels([f'P{i}' for i in range(self.n_products)], fontsize=6)
cbar_c = self._fig.colorbar(im_corr, ax=ax_corr, fraction=0.046, pad=0.02)
cbar_c.ax.tick_params(colors=c['text'], labelsize=7)
self._style_axis(ax_corr, "Price-Demand Correlation", None, None)
# revenue curve with demand dispersion (std dev shows concentration)
ax_rev = self._fig.add_subplot(self._gs[2, 1:])
n_steps = len(self._revenue_history)
demand_std = [np.std(d) for d in self._demand_history]
ax_rev.fill_between(range(n_steps), self._revenue_history, alpha=0.3, color=c['green'])
ax_rev.plot(self._revenue_history, color=c['green'], linewidth=2, label='Revenue')
ax_rev.set_xlim(0, max(n_steps, 1))
ax_rev.set_ylim(0, max(self._revenue_history) * 1.1 if self._revenue_history else 1)
ax2 = ax_rev.twinx()
ax2.plot(range(n_steps), demand_std, color=c['blue'], linewidth=2, ls='-', alpha=0.9, label='σ(Demand)')
d_min, d_max = min(demand_std), max(demand_std)
margin = (d_max - d_min) * 0.2 if d_max > d_min else 0.5
ax2.set_ylim(max(0, d_min - margin), d_max + margin)
ax2.tick_params(axis='y', colors=c['blue'], labelsize=8)
ax2.spines['right'].set_color(c['blue'])
ax2.set_ylabel('Demand σ', color=c['blue'], fontsize=9)
self._style_axis(ax_rev, "Revenue & Demand Dispersion", "Step", "Revenue ($)")
ax_rev.legend(loc='upper left', fontsize=7, frameon=False, labelcolor=c['text'])
ax2.legend(loc='upper right', fontsize=7, frameon=False, labelcolor=c['text'])
self._fig.canvas.draw_idle()
self._fig.canvas.flush_events()
plt.pause(0.05)
elif self.render_mode == "ansi": elif self.render_mode == "ansi":
return f"step={self._step_count}, prices={self._prices}, demand={self._demand}" return f"step={self._step_count}, prices={self._prices}, demand={self._demand}"
return None return None
def close(self): def close(self):
if self._fig: plt.close(self._fig) if self._renderer:
self._fig = None self._renderer.close()
self._renderer = None
if __name__ == "__main__": if __name__ == "__main__":