chore: refactor wrapper

This commit is contained in:
2026-01-30 13:17:12 +01:00
parent 10e8397eec
commit 28d3f6853e
4 changed files with 193 additions and 146 deletions

3
engine/lib/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .demand import generate_demand, estimate_demand
from .behavior import sample_behavior
from .render import DashboardRenderer, style_axis

126
engine/lib/render.py Normal file
View File

@@ -0,0 +1,126 @@
"""rendering logic for PHANTOM environment dashboard"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
def style_axis(ax, title: str = None, xlabel: str = None, ylabel: str = None):
    """Apply the dashboard's shared look to an axes object.

    Hides the top and right spines, then sets whichever of the title and
    axis labels were supplied.

    Args:
        ax: Axes-like object exposing ``spines``, ``set_title``,
            ``set_xlabel`` and ``set_ylabel``.
        title: Bold title text; skipped when ``None`` or empty.
        xlabel: X-axis label; skipped when ``None`` or empty.
        ylabel: Y-axis label; skipped when ``None`` or empty.
    """
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    if title:
        ax.set_title(title, fontsize=11, fontweight='bold', pad=8)
    if xlabel:
        ax.set_xlabel(xlabel, fontsize=9)
    if ylabel:
        ax.set_ylabel(ylabel, fontsize=9)
class DashboardRenderer:
    """Stateful renderer for the PHANTOM market dynamics dashboard.

    Owns a single matplotlib figure that is created lazily on the first
    ``render`` call and redrawn in place on every later call. The
    environment handed to ``render`` is only read from: ``_step_count``,
    ``alpha``, ``n_products``, ``market.Nhumans`` / ``market.Nagents``, the
    ``_price_history`` / ``_demand_history`` / ``_revenue_history``
    sequences, and ``_compute_elasticity()``.
    """

    def __init__(self):
        # Both are created by the first render() call and dropped by close().
        self.fig = None
        self.gs = None

    def render(self, env) -> None:
        """Redraw every dashboard panel from ``env``'s recorded history.

        When no history has been recorded yet, only the figure title is
        drawn: the panels need at least one recorded step to have data.
        """
        if self.fig is None:
            plt.ion()
            self.fig = plt.figure(figsize=(14, 10))
            self.gs = GridSpec(3, 3, figure=self.fig, hspace=0.35, wspace=0.3,
                               left=0.07, right=0.95, top=0.92, bottom=0.08)
            plt.show(block=False)
        self.fig.clear()
        self.fig.suptitle(f'PHANTOM Market Dynamics [t={env._step_count}, a={env.alpha:.2f}]',
                          fontsize=14, fontweight='bold')
        # With empty histories the matrices below would be 1-D and empty,
        # which imshow / shape[1] / min() cannot handle.
        if len(env._price_history) > 0 and len(env._demand_history) > 0:
            # Transposed so rows index products and columns index steps.
            demand_mat = np.array(env._demand_history).T
            price_mat = np.array(env._price_history).T
            elasticity = env._compute_elasticity()
            self._render_scatter(env)
            self._render_elasticity_bar(env, elasticity)
            self._render_session_pie(env)
            self._render_price_heatmap(price_mat)
            self._render_demand_heatmap(demand_mat)
            self._render_correlation(env.n_products, price_mat, demand_mat)
            self._render_revenue(env)
        self.fig.canvas.draw_idle()
        self.fig.canvas.flush_events()

    def _render_scatter(self, env):
        """Top-left: scatter of every (price, demand) sample with a linear fit."""
        ax = self.fig.add_subplot(self.gs[0, 0])
        prices_flat = np.array(env._price_history).flatten()
        demands_flat = np.array(env._demand_history).flatten()
        # One product id per flattened sample, repeated for each recorded step.
        product_ids = np.tile(np.arange(env.n_products), len(env._price_history))
        ax.scatter(prices_flat, demands_flat, c=product_ids, cmap='plasma',
                   alpha=0.6, s=15, edgecolors='none')
        # A degree-1 fit needs at least two distinct x values; otherwise the
        # least-squares problem is rank deficient and the line is meaningless.
        if len(prices_flat) > 1 and prices_flat.max() > prices_flat.min():
            coeffs = np.polyfit(prices_flat, demands_flat, 1)
            p_line = np.linspace(prices_flat.min(), prices_flat.max(), 50)
            ax.plot(p_line, np.polyval(coeffs, p_line), '--', lw=1.5, alpha=0.8)
        style_axis(ax, "Price-Demand Relationship", "Price ($)", "Demand")

    def _render_elasticity_bar(self, env, elasticity):
        """Top-middle: horizontal bar per product of the supplied elasticity values."""
        ax = self.fig.add_subplot(self.gs[0, 1])
        ax.barh(range(env.n_products), elasticity, alpha=0.8)
        ax.axvline(0, lw=0.8, alpha=0.5)
        ax.axvline(-1, lw=1, ls='--', alpha=0.5)  # reference line at -1
        ax.set_yticks(range(env.n_products))
        ax.set_yticklabels([f'P{i}' for i in range(env.n_products)], fontsize=7)
        style_axis(ax, "Price Elasticity", "(dQ/dP)(P/Q)", None)

    def _render_session_pie(self, env):
        """Top-right: pie of ``market.Nhumans`` versus ``market.Nagents``."""
        ax = self.fig.add_subplot(self.gs[0, 2])
        n_h, n_a = env.market.Nhumans, env.market.Nagents
        wedges, _ = ax.pie([n_h, n_a], startangle=90,
                           wedgeprops={'linewidth': 2, 'edgecolor': 'white'})
        ax.legend(wedges, [f'H ({n_h})', f'A ({n_a})'], loc='lower center', fontsize=8,
                  frameon=False, bbox_to_anchor=(0.5, -0.05))
        ax.set_title("Session Mix", fontsize=11, fontweight='bold')

    def _render_price_heatmap(self, price_mat):
        """Middle-left (two columns): heatmap of prices, products by steps."""
        ax = self.fig.add_subplot(self.gs[1, :2])
        im = ax.imshow(price_mat, aspect='auto', cmap='viridis', origin='lower')
        style_axis(ax, "Price Heatmap P(product, t)", "Step", "Product")
        cbar = self.fig.colorbar(im, ax=ax, fraction=0.03, pad=0.02)
        cbar.set_label('$', fontsize=8)

    def _render_demand_heatmap(self, demand_mat):
        """Middle-right: heatmap of demand, products by steps."""
        ax = self.fig.add_subplot(self.gs[1, 2])
        im = ax.imshow(demand_mat, aspect='auto', cmap='Blues', origin='lower')
        style_axis(ax, "Demand Q(product, t)", "Step", None)
        self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)

    def _render_correlation(self, n_products, price_mat, demand_mat):
        """Bottom-left: price-versus-demand correlation matrix.

        The matrix is only drawn once more than two steps are recorded;
        before that the panel shows just its title.
        """
        ax = self.fig.add_subplot(self.gs[2, 0])
        if price_mat.ndim == 2 and price_mat.shape[1] > 2:
            # corrcoef stacks both inputs; the upper-right quadrant holds the
            # price-row versus demand-row coefficients. Constant rows have
            # zero variance and yield NaN, so silence the 0/0 warning.
            with np.errstate(divide='ignore', invalid='ignore'):
                corr = np.corrcoef(price_mat, demand_mat)[:n_products, n_products:]
            im = ax.imshow(corr, cmap='RdBu', vmin=-1, vmax=1, aspect='auto')
            ax.set_xticks(range(n_products))
            ax.set_yticks(range(n_products))
            ax.set_xticklabels([f'Q{i}' for i in range(n_products)], fontsize=6)
            ax.set_yticklabels([f'P{i}' for i in range(n_products)], fontsize=6)
            self.fig.colorbar(im, ax=ax, fraction=0.046, pad=0.02)
        style_axis(ax, "Price-Demand Correlation", None, None)

    def _render_revenue(self, env):
        """Bottom-right (two columns): revenue curve plus per-step demand std."""
        ax = self.fig.add_subplot(self.gs[2, 1:])
        revenue = env._revenue_history
        n_steps = len(revenue)
        demand_std = [np.std(d) for d in env._demand_history]
        ax.fill_between(range(n_steps), revenue, alpha=0.3)
        ax.plot(revenue, linewidth=2, label='Revenue')
        ax.set_xlim(0, max(n_steps, 1))
        # Fall back to 1 when there is no (or only zero) revenue so the
        # y-range never collapses to (0, 0).
        rev_top = max(revenue) * 1.1 if revenue else 1
        ax.set_ylim(0, rev_top if rev_top > 0 else 1)
        ax2 = ax.twinx()
        ax2.plot(range(len(demand_std)), demand_std, linewidth=2, ls='-',
                 alpha=0.9, label='sigma(Demand)')
        if demand_std:
            d_min, d_max = min(demand_std), max(demand_std)
            # Pad the range by 20%, or by a fixed 0.5 when the series is flat.
            margin = (d_max - d_min) * 0.2 if d_max > d_min else 0.5
            ax2.set_ylim(max(0, d_min - margin), d_max + margin)
        ax2.set_ylabel('Demand sigma', fontsize=9)
        style_axis(ax, "Revenue & Demand Dispersion", "Step", "Revenue ($)")
        ax.legend(loc='upper left', fontsize=7, frameon=False)
        ax2.legend(loc='upper right', fontsize=7, frameon=False)

    def close(self):
        """Close the figure, if one was created, and reset the renderer state."""
        if self.fig is not None:
            plt.close(self.fig)
        self.fig = None
        self.gs = None

45
engine/train.py Normal file
View File

@@ -0,0 +1,45 @@
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
from .wrapper import PHANTOM
class RenderCallback(BaseCallback):
    """Render the environment during training for live visualization.

    Args:
        env: Environment whose ``render()`` method is invoked.
        render_freq: Render on every ``render_freq``-th callback step. The
            default of 1 renders on every step; larger values reduce the
            drawing overhead during long training runs.

    Raises:
        ValueError: If ``render_freq`` is smaller than 1.
    """

    def __init__(self, env: PHANTOM, render_freq: int = 1):
        super().__init__()
        if render_freq < 1:
            raise ValueError("render_freq must be >= 1")
        self.env = env
        self.render_freq = render_freq

    def _on_step(self) -> bool:
        # n_calls is the call counter kept by BaseCallback (per the
        # stable-baselines3 callback docs); with render_freq == 1 the
        # condition is always true, i.e. render on every step.
        if self.n_calls % self.render_freq == 0:
            self.env.render()
        # Returning True tells the algorithm to keep training.
        return True
# Training environment renders live; the evaluation environment stays headless.
env = PHANTOM(n_products=10, alpha=0.3, render_mode="human")
eval_env = PHANTOM(n_products=10, alpha=0.3, render_mode=None)
model = SAC(
    "MultiInputPolicy",
    env,
    verbose=1,
    learning_rate=3e-4,
    buffer_size=50000,
    batch_size=256,
    tau=0.005,
    gamma=0.99,
)
render_cb = RenderCallback(env)
eval_cb = EvalCallback(eval_env, eval_freq=1000, n_eval_episodes=5, verbose=1)
try:
    model.learn(total_timesteps=50000, callback=[render_cb, eval_cb])
    model.save("phantom_sac")
finally:
    # `env` is rebound below, so release the training figure and the
    # evaluation environment here instead of leaking them.
    env.close()
    eval_env.close()

# test trained policy
env = PHANTOM(n_products=10, alpha=0.3, render_mode="human")
try:
    obs, _ = env.reset()
    for _ in range(100):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, term, trunc, _ = env.step(action)
        env.render()
        if term or trunc:
            break
finally:
    # Close the environment even if a step or render call raises.
    env.close()

View File

@@ -1,10 +1,8 @@
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.colors as mcolors
from .engine import Limbo, MarketEngine, PricingEngine
from .lib.render import DashboardRenderer
class PHANTOM(gym.Env):
@@ -16,7 +14,7 @@ class PHANTOM(gym.Env):
alpha: float = 0.3,
N: int = 100,
price_bounds: tuple = (10.0, 150.0),
lambda_coi: float = 0.1, # coi leakage penalty weight
lambda_coi: float = 0.1,
render_mode: str = None):
super().__init__()
self.n_products = n_products
@@ -30,12 +28,10 @@ class PHANTOM(gym.Env):
self._platform_stub = PricingEngine()
self._limbo = Limbo(self._platform_stub, self.market)
# action: continuous prices for each product
self.action_space = spaces.Box(
low=price_bounds[0], high=price_bounds[1],
shape=(n_products,), dtype=np.float32
)
# observation: demand estimate + previous prices
self.observation_space = spaces.Dict({
"demand": spaces.Box(low=0.0, high=100.0, shape=(n_products,), dtype=np.float32),
"prices": spaces.Box(low=price_bounds[0], high=price_bounds[1], shape=(n_products,), dtype=np.float32),
@@ -47,30 +43,22 @@ class PHANTOM(gym.Env):
self._demand_history = []
self._price_history = []
self._revenue_history = []
self._fig = None
self._gs = None
self._dashboard_colors = {
'bg': '#f5f0e8', 'panel': '#ebe3d5', 'accent': '#c9b99a',
'text': '#3d3229', 'green': '#5c7a5c', 'red': '#8b4049',
'blue': '#5a7384', 'orange': '#b87333', 'purple': '#7d6b7d'
}
self._renderer = None
def _get_obs(self) -> dict:
demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)], dtype=np.float32)
return {"demand": demand_arr, "prices": self._prices.astype(np.float32)}
def _compute_reward(self, prices: np.ndarray, demand: dict) -> float:
demand_arr = np.array([demand.get(i, 0.0) for i in range(self.n_products)])
revenue = np.sum(prices * demand_arr) # revenue = price * quantity proxy
base_price = self.price_bounds[0]
return float(revenue)# - self.lambda_coi * coi_leak)
revenue = np.sum(prices * np.array([demand.get(i, 0.0) for i in range(self.n_products)]))
# TODO: implement supra-competitive price punishment
return float(revenue)
def _record_history(self):
demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)])
self._demand_history.append(demand_arr)
self._price_history.append(self._prices.copy())
revenue = np.sum(self._prices * demand_arr)
self._revenue_history.append(revenue)
self._revenue_history.append(np.sum(self._prices * demand_arr))
def reset(self, seed=None, options=None):
super().reset(seed=seed)
@@ -89,149 +77,34 @@ class PHANTOM(gym.Env):
reward = self._compute_reward(self._prices, self._demand)
terminated = self._step_count >= 100
truncated = False
return self._get_obs(), reward, terminated, truncated, {"step": self._step_count}
return self._get_obs(), reward, terminated, False, {"step": self._step_count}
def _compute_elasticity(self) -> np.ndarray:
"""point elasticity: e = (dQ/dP) * (P/Q) estimated via finite differences, clipped to [-5, 5]"""
"""point elasticity: e = (dQ/dP) * (P/Q) via finite differences, clipped to [-5, 5]"""
if len(self._price_history) < 2:
return np.zeros(self.n_products)
p = np.array(self._price_history)
q = np.array(self._demand_history)
dp = np.diff(p, axis=0)
dq = np.diff(q, axis=0)
min_dp = 0.5 # ignore tiny price changes to avoid explosions
valid = np.abs(dp) > min_dp
p, q = np.array(self._price_history), np.array(self._demand_history)
dp, dq = np.diff(p, axis=0), np.diff(q, axis=0)
valid = np.abs(dp) > 0.5
with np.errstate(divide='ignore', invalid='ignore'):
elasticity = np.where(valid, (dq / dp) * (p[:-1] / np.maximum(q[:-1], 1.0)), 0.0)
elasticity = np.clip(elasticity, -5.0, 5.0)
elasticity = np.nan_to_num(elasticity, nan=0.0)
elasticity = np.nan_to_num(np.clip(elasticity, -5.0, 5.0), nan=0.0)
return np.mean(elasticity, axis=0) if len(elasticity) > 0 else np.zeros(self.n_products)
def _style_axis(self, ax, title: str = None, xlabel: str = None, ylabel: str = None):
c = self._dashboard_colors
ax.set_facecolor(c['panel'])
ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_color(c['accent']); ax.spines['left'].set_color(c['accent'])
ax.tick_params(colors=c['text'], labelsize=8)
if title: ax.set_title(title, color=c['text'], fontsize=11, fontweight='bold', pad=8)
if xlabel: ax.set_xlabel(xlabel, color=c['text'], fontsize=9)
if ylabel: ax.set_ylabel(ylabel, color=c['text'], fontsize=9)
def render(self):
if self.render_mode == "human":
c = self._dashboard_colors
if self._fig is None:
plt.ion()
self._fig = plt.figure(figsize=(14, 10), facecolor=c['bg'])
self._gs = GridSpec(3, 3, figure=self._fig, hspace=0.35, wspace=0.3,
left=0.07, right=0.95, top=0.92, bottom=0.08)
plt.show(block=False)
self._fig.clear()
self._fig.suptitle(f'PHANTOM Market Dynamics [t={self._step_count}, α={self.alpha:.2f}]',
color=c['text'], fontsize=14, fontweight='bold')
demand_mat = np.array(self._demand_history).T
price_mat = np.array(self._price_history).T
elasticity = self._compute_elasticity()
cmap = mcolors.LinearSegmentedColormap.from_list('phantom', [c['bg'], c['blue'], c['green']])
cmap_div = mcolors.LinearSegmentedColormap.from_list('elast', [c['red'], c['bg'], c['blue']])
# price-demand elasticity scatter (all historical data points)
ax_elast = self._fig.add_subplot(self._gs[0, 0])
prices_flat = np.array(self._price_history).flatten()
demands_flat = np.array(self._demand_history).flatten()
product_ids = np.tile(np.arange(self.n_products), len(self._price_history))
scatter = ax_elast.scatter(prices_flat, demands_flat, c=product_ids, cmap='plasma',
alpha=0.6, s=15, edgecolors='none')
if len(prices_flat) > 1: # fit regression line
z = np.polyfit(prices_flat, demands_flat, 1)
p_line = np.linspace(prices_flat.min(), prices_flat.max(), 50)
ax_elast.plot(p_line, np.polyval(z, p_line), '--', color=c['red'], lw=1.5, alpha=0.8)
self._style_axis(ax_elast, "Price-Demand Relationship", "Price ($)", "Demand")
# elasticity coefficients bar
ax_ebar = self._fig.add_subplot(self._gs[0, 1])
colors_e = [c['red'] if e < -0.5 else c['blue'] if e > 0.5 else c['accent'] for e in elasticity]
ax_ebar.barh(range(self.n_products), elasticity, color=colors_e, alpha=0.8, edgecolor=c['bg'])
ax_ebar.axvline(0, color=c['text'], lw=0.8, alpha=0.5)
ax_ebar.axvline(-1, color=c['red'], lw=1, ls='--', alpha=0.5) # unit elastic reference
ax_ebar.set_yticks(range(self.n_products))
ax_ebar.set_yticklabels([f'P{i}' for i in range(self.n_products)], fontsize=7)
self._style_axis(ax_ebar, "Price Elasticity ε", "ε = (ΔQ/ΔP)·(P/Q)", None)
# session composition pie
ax_pie = self._fig.add_subplot(self._gs[0, 2])
n_humans, n_agents = self.market.Nhumans, self.market.Nagents
ax_pie.set_facecolor(c['panel'])
wedges, _ = ax_pie.pie([n_humans, n_agents], colors=[c['blue'], c['red']],
startangle=90, wedgeprops={'linewidth': 2, 'edgecolor': c['bg']})
ax_pie.legend(wedges, [f'H ({n_humans})', f'A ({n_agents})'],
loc='lower center', fontsize=8, frameon=False,
labelcolor=c['text'], bbox_to_anchor=(0.5, -0.05))
ax_pie.set_title("Session Mix", color=c['text'], fontsize=11, fontweight='bold')
# price heatmap over time
ax_pheat = self._fig.add_subplot(self._gs[1, :2])
im_p = ax_pheat.imshow(price_mat, aspect='auto', cmap='viridis', origin='lower')
self._style_axis(ax_pheat, "Price Heatmap P(product, t)", "Step", "Product")
cbar_p = self._fig.colorbar(im_p, ax=ax_pheat, fraction=0.03, pad=0.02)
cbar_p.ax.tick_params(colors=c['text'], labelsize=7)
cbar_p.set_label('$', color=c['text'], fontsize=8)
# demand heatmap over time
ax_dheat = self._fig.add_subplot(self._gs[1, 2])
im_d = ax_dheat.imshow(demand_mat, aspect='auto', cmap=cmap, origin='lower')
self._style_axis(ax_dheat, "Demand Q(product, t)", "Step", None)
cbar_d = self._fig.colorbar(im_d, ax=ax_dheat, fraction=0.046, pad=0.02)
cbar_d.ax.tick_params(colors=c['text'], labelsize=7)
# cross-correlation matrix (price-demand covariance per product)
ax_corr = self._fig.add_subplot(self._gs[2, 0])
if len(self._price_history) > 2:
corr_mat = np.corrcoef(price_mat, demand_mat)[:self.n_products, self.n_products:]
im_corr = ax_corr.imshow(corr_mat, cmap=cmap_div, vmin=-1, vmax=1, aspect='auto')
ax_corr.set_xticks(range(self.n_products))
ax_corr.set_yticks(range(self.n_products))
ax_corr.set_xticklabels([f'Q{i}' for i in range(self.n_products)], fontsize=6)
ax_corr.set_yticklabels([f'P{i}' for i in range(self.n_products)], fontsize=6)
cbar_c = self._fig.colorbar(im_corr, ax=ax_corr, fraction=0.046, pad=0.02)
cbar_c.ax.tick_params(colors=c['text'], labelsize=7)
self._style_axis(ax_corr, "Price-Demand Correlation", None, None)
# revenue curve with demand dispersion (std dev shows concentration)
ax_rev = self._fig.add_subplot(self._gs[2, 1:])
n_steps = len(self._revenue_history)
demand_std = [np.std(d) for d in self._demand_history]
ax_rev.fill_between(range(n_steps), self._revenue_history, alpha=0.3, color=c['green'])
ax_rev.plot(self._revenue_history, color=c['green'], linewidth=2, label='Revenue')
ax_rev.set_xlim(0, max(n_steps, 1))
ax_rev.set_ylim(0, max(self._revenue_history) * 1.1 if self._revenue_history else 1)
ax2 = ax_rev.twinx()
ax2.plot(range(n_steps), demand_std, color=c['blue'], linewidth=2, ls='-', alpha=0.9, label='σ(Demand)')
d_min, d_max = min(demand_std), max(demand_std)
margin = (d_max - d_min) * 0.2 if d_max > d_min else 0.5
ax2.set_ylim(max(0, d_min - margin), d_max + margin)
ax2.tick_params(axis='y', colors=c['blue'], labelsize=8)
ax2.spines['right'].set_color(c['blue'])
ax2.set_ylabel('Demand σ', color=c['blue'], fontsize=9)
self._style_axis(ax_rev, "Revenue & Demand Dispersion", "Step", "Revenue ($)")
ax_rev.legend(loc='upper left', fontsize=7, frameon=False, labelcolor=c['text'])
ax2.legend(loc='upper right', fontsize=7, frameon=False, labelcolor=c['text'])
self._fig.canvas.draw_idle()
self._fig.canvas.flush_events()
plt.pause(0.05)
if self._renderer is None:
self._renderer = DashboardRenderer()
self._renderer.render(self)
elif self.render_mode == "ansi":
return f"step={self._step_count}, prices={self._prices}, demand={self._demand}"
return None
def close(self):
if self._fig: plt.close(self._fig)
self._fig = None
if self._renderer:
self._renderer.close()
self._renderer = None
if __name__ == "__main__":