chore: better plotting

This commit is contained in:
2026-01-29 13:11:52 +01:00
parent 772772b5b9
commit 10e8397eec

View File

@@ -2,6 +2,8 @@ import gymnasium as gym
from gymnasium import spaces from gymnasium import spaces
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.colors as mcolors
from .engine import Limbo, MarketEngine, PricingEngine from .engine import Limbo, MarketEngine, PricingEngine
@@ -21,6 +23,8 @@ class PHANTOM(gym.Env):
self.price_bounds = price_bounds self.price_bounds = price_bounds
self.lambda_coi = lambda_coi self.lambda_coi = lambda_coi
self.render_mode = render_mode self.render_mode = render_mode
self.alpha = alpha
self.N = N
self.market = MarketEngine(alpha=alpha, N=N) self.market = MarketEngine(alpha=alpha, N=N)
self._platform_stub = PricingEngine() self._platform_stub = PricingEngine()
@@ -40,9 +44,16 @@ class PHANTOM(gym.Env):
self._prices = None self._prices = None
self._demand = None self._demand = None
self._step_count = 0 self._step_count = 0
self._demand_history = [] # list of demand arrays over time self._demand_history = []
self._price_history = [] # list of price arrays over time self._price_history = []
self._fig, self._axes = None, None self._revenue_history = []
self._fig = None
self._gs = None
self._dashboard_colors = {
'bg': '#f5f0e8', 'panel': '#ebe3d5', 'accent': '#c9b99a',
'text': '#3d3229', 'green': '#5c7a5c', 'red': '#8b4049',
'blue': '#5a7384', 'orange': '#b87333', 'purple': '#7d6b7d'
}
def _get_obs(self) -> dict: def _get_obs(self) -> dict:
demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)], dtype=np.float32) demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)], dtype=np.float32)
@@ -58,16 +69,16 @@ class PHANTOM(gym.Env):
demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)]) demand_arr = np.array([self._demand.get(i, 0.0) for i in range(self.n_products)])
self._demand_history.append(demand_arr) self._demand_history.append(demand_arr)
self._price_history.append(self._prices.copy()) self._price_history.append(self._prices.copy())
revenue = np.sum(self._prices * demand_arr)
self._revenue_history.append(revenue)
def reset(self, seed=None, options=None): def reset(self, seed=None, options=None):
super().reset(seed=seed) super().reset(seed=seed)
self._prices = np.random.uniform(*self.price_bounds, size=self.n_products) self._prices = np.random.uniform(*self.price_bounds, size=self.n_products)
self._demand = self.market.act(self._prices) self._demand = self.market.act(self._prices)
self._step_count = 0 self._step_count = 0
self._demand_history, self._price_history = [], [] self._demand_history, self._price_history, self._revenue_history = [], [], []
self._record_history() self._record_history()
if self._fig: plt.close(self._fig)
self._fig, self._axes = None, None
return self._get_obs(), {} return self._get_obs(), {}
def step(self, action: np.ndarray): def step(self, action: np.ndarray):
@@ -82,45 +93,137 @@ class PHANTOM(gym.Env):
return self._get_obs(), reward, terminated, truncated, {"step": self._step_count} return self._get_obs(), reward, terminated, truncated, {"step": self._step_count}
def _compute_elasticity(self) -> np.ndarray:
    """Estimate the mean point price elasticity of demand for each product.

    Applies e = (dQ/dP) * (P/Q) to every pair of consecutive entries in
    ``self._price_history`` / ``self._demand_history`` using finite
    differences. Each per-step estimate is clipped to [-5, 5].

    Steps whose absolute price change is not larger than ``min_dp`` are
    excluded from the estimate altogether: they are neither evaluated nor
    counted in the average, so they cannot drag the result toward zero.

    Returns:
        Array of length ``self.n_products``. A product with no usable
        price change (or fewer than two recorded steps) gets 0.0.
    """
    if len(self._price_history) < 2:
        return np.zeros(self.n_products)

    prices = np.asarray(self._price_history, dtype=float)
    demands = np.asarray(self._demand_history, dtype=float)
    dp = np.diff(prices, axis=0)
    dq = np.diff(demands, axis=0)

    min_dp = 0.5  # ignore tiny price changes to avoid explosions
    valid = np.abs(dp) > min_dp

    # Substitute a harmless denominator where the step is masked out, so no
    # division by (near-)zero is ever performed and no warnings need hiding.
    safe_dp = np.where(valid, dp, 1.0)
    # Demand is floored at 1.0 so P/Q stays bounded when demand hits zero.
    base_demand = np.maximum(demands[:-1], 1.0)

    per_step = np.where(valid, (dq / safe_dp) * (prices[:-1] / base_demand), 0.0)
    per_step = np.clip(per_step, -5.0, 5.0)
    per_step = np.nan_to_num(per_step, nan=0.0)

    # Average over the valid steps only; a product with none reports 0.0.
    n_valid = valid.sum(axis=0)
    return per_step.sum(axis=0) / np.maximum(n_valid, 1)
def _style_axis(self, ax, title: str = None, xlabel: str = None, ylabel: str = None):
    """Apply the dashboard colour theme to a single axes object.

    Sets the panel background, hides the top and right spines, colours the
    bottom and left spines with the accent colour, and restyles the ticks.
    The title and axis labels are set only when a non-empty value is given.

    Args:
        ax: Axes-like object to style (must expose ``spines``,
            ``set_facecolor``, ``tick_params`` and the label setters).
        title: Optional title text.
        xlabel: Optional x-axis label.
        ylabel: Optional y-axis label.
    """
    palette = self._dashboard_colors
    ax.set_facecolor(palette['panel'])

    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    for side in ('bottom', 'left'):
        ax.spines[side].set_color(palette['accent'])

    ax.tick_params(colors=palette['text'], labelsize=8)

    if title:
        ax.set_title(title, color=palette['text'], fontsize=11, fontweight='bold', pad=8)
    if xlabel:
        ax.set_xlabel(xlabel, color=palette['text'], fontsize=9)
    if ylabel:
        ax.set_ylabel(ylabel, color=palette['text'], fontsize=9)
def render(self): def render(self):
if self.render_mode == "human": if self.render_mode == "human":
c = self._dashboard_colors
if self._fig is None: if self._fig is None:
plt.ion() plt.ion()
self._fig, self._axes = plt.subplots(2, 2, figsize=(12, 8)) self._fig = plt.figure(figsize=(14, 10), facecolor=c['bg'])
self._fig.suptitle("PHANTOM Environment") self._gs = GridSpec(3, 3, figure=self._fig, hspace=0.35, wspace=0.3,
left=0.07, right=0.95, top=0.92, bottom=0.08)
plt.show(block=False)
demand_mat = np.array(self._demand_history).T # shape: (n_products, timesteps) self._fig.clear()
self._fig.suptitle(f'PHANTOM Market Dynamics [t={self._step_count}, α={self.alpha:.2f}]',
color=c['text'], fontsize=14, fontweight='bold')
demand_mat = np.array(self._demand_history).T
price_mat = np.array(self._price_history).T price_mat = np.array(self._price_history).T
revenue_per_step = np.sum(demand_mat * price_mat, axis=0) # revenue = demand * price elasticity = self._compute_elasticity()
demand_variance = np.var(demand_mat, axis=0) # how spread demand is across products cmap = mcolors.LinearSegmentedColormap.from_list('phantom', [c['bg'], c['blue'], c['green']])
cmap_div = mcolors.LinearSegmentedColormap.from_list('elast', [c['red'], c['bg'], c['blue']])
for row in self._axes: # price-demand elasticity scatter (all historical data points)
for ax in row: ax.clear() ax_elast = self._fig.add_subplot(self._gs[0, 0])
prices_flat = np.array(self._price_history).flatten()
demands_flat = np.array(self._demand_history).flatten()
product_ids = np.tile(np.arange(self.n_products), len(self._price_history))
scatter = ax_elast.scatter(prices_flat, demands_flat, c=product_ids, cmap='plasma',
alpha=0.6, s=15, edgecolors='none')
if len(prices_flat) > 1: # fit regression line
z = np.polyfit(prices_flat, demands_flat, 1)
p_line = np.linspace(prices_flat.min(), prices_flat.max(), 50)
ax_elast.plot(p_line, np.polyval(z, p_line), '--', color=c['red'], lw=1.5, alpha=0.8)
self._style_axis(ax_elast, "Price-Demand Relationship", "Price ($)", "Demand")
self._axes[0, 0].imshow(demand_mat, aspect='auto', cmap='viridis', origin='lower') # elasticity coefficients bar
self._axes[0, 0].set_xlabel("Time Step") ax_ebar = self._fig.add_subplot(self._gs[0, 1])
self._axes[0, 0].set_ylabel("Product") colors_e = [c['red'] if e < -0.5 else c['blue'] if e > 0.5 else c['accent'] for e in elasticity]
self._axes[0, 0].set_title("Demand Over Time") ax_ebar.barh(range(self.n_products), elasticity, color=colors_e, alpha=0.8, edgecolor=c['bg'])
ax_ebar.axvline(0, color=c['text'], lw=0.8, alpha=0.5)
ax_ebar.axvline(-1, color=c['red'], lw=1, ls='--', alpha=0.5) # unit elastic reference
ax_ebar.set_yticks(range(self.n_products))
ax_ebar.set_yticklabels([f'P{i}' for i in range(self.n_products)], fontsize=7)
self._style_axis(ax_ebar, "Price Elasticity ε", "ε = (ΔQ/ΔP)·(P/Q)", None)
self._axes[0, 1].imshow(price_mat, aspect='auto', cmap='plasma', origin='lower') # session composition pie
self._axes[0, 1].set_xlabel("Time Step") ax_pie = self._fig.add_subplot(self._gs[0, 2])
self._axes[0, 1].set_ylabel("Product") n_humans, n_agents = self.market.Nhumans, self.market.Nagents
self._axes[0, 1].set_title("Price Over Time") ax_pie.set_facecolor(c['panel'])
wedges, _ = ax_pie.pie([n_humans, n_agents], colors=[c['blue'], c['red']],
startangle=90, wedgeprops={'linewidth': 2, 'edgecolor': c['bg']})
ax_pie.legend(wedges, [f'H ({n_humans})', f'A ({n_agents})'],
loc='lower center', fontsize=8, frameon=False,
labelcolor=c['text'], bbox_to_anchor=(0.5, -0.05))
ax_pie.set_title("Session Mix", color=c['text'], fontsize=11, fontweight='bold')
self._axes[1, 0].plot(revenue_per_step, color='teal', linewidth=1.5) # price heatmap over time
self._axes[1, 0].set_xlabel("Time Step") ax_pheat = self._fig.add_subplot(self._gs[1, :2])
self._axes[1, 0].set_ylabel("Revenue") im_p = ax_pheat.imshow(price_mat, aspect='auto', cmap='viridis', origin='lower')
self._axes[1, 0].set_title("Revenue per Step") self._style_axis(ax_pheat, "Price Heatmap P(product, t)", "Step", "Product")
cbar_p = self._fig.colorbar(im_p, ax=ax_pheat, fraction=0.03, pad=0.02)
cbar_p.ax.tick_params(colors=c['text'], labelsize=7)
cbar_p.set_label('$', color=c['text'], fontsize=8)
self._axes[1, 1].plot(demand_variance, color='orangered', linewidth=1.5) # demand heatmap over time
self._axes[1, 1].set_xlabel("Time Step") ax_dheat = self._fig.add_subplot(self._gs[1, 2])
self._axes[1, 1].set_ylabel("Variance") im_d = ax_dheat.imshow(demand_mat, aspect='auto', cmap=cmap, origin='lower')
self._axes[1, 1].set_title("Demand Variance") self._style_axis(ax_dheat, "Demand Q(product, t)", "Step", None)
cbar_d = self._fig.colorbar(im_d, ax=ax_dheat, fraction=0.046, pad=0.02)
cbar_d.ax.tick_params(colors=c['text'], labelsize=7)
self._fig.tight_layout() # cross-correlation matrix (price-demand covariance per product)
self._fig.canvas.draw() ax_corr = self._fig.add_subplot(self._gs[2, 0])
if len(self._price_history) > 2:
corr_mat = np.corrcoef(price_mat, demand_mat)[:self.n_products, self.n_products:]
im_corr = ax_corr.imshow(corr_mat, cmap=cmap_div, vmin=-1, vmax=1, aspect='auto')
ax_corr.set_xticks(range(self.n_products))
ax_corr.set_yticks(range(self.n_products))
ax_corr.set_xticklabels([f'Q{i}' for i in range(self.n_products)], fontsize=6)
ax_corr.set_yticklabels([f'P{i}' for i in range(self.n_products)], fontsize=6)
cbar_c = self._fig.colorbar(im_corr, ax=ax_corr, fraction=0.046, pad=0.02)
cbar_c.ax.tick_params(colors=c['text'], labelsize=7)
self._style_axis(ax_corr, "Price-Demand Correlation", None, None)
# revenue curve with demand dispersion (std dev shows concentration)
ax_rev = self._fig.add_subplot(self._gs[2, 1:])
n_steps = len(self._revenue_history)
demand_std = [np.std(d) for d in self._demand_history]
ax_rev.fill_between(range(n_steps), self._revenue_history, alpha=0.3, color=c['green'])
ax_rev.plot(self._revenue_history, color=c['green'], linewidth=2, label='Revenue')
ax_rev.set_xlim(0, max(n_steps, 1))
ax_rev.set_ylim(0, max(self._revenue_history) * 1.1 if self._revenue_history else 1)
ax2 = ax_rev.twinx()
ax2.plot(range(n_steps), demand_std, color=c['blue'], linewidth=2, ls='-', alpha=0.9, label='σ(Demand)')
d_min, d_max = min(demand_std), max(demand_std)
margin = (d_max - d_min) * 0.2 if d_max > d_min else 0.5
ax2.set_ylim(max(0, d_min - margin), d_max + margin)
ax2.tick_params(axis='y', colors=c['blue'], labelsize=8)
ax2.spines['right'].set_color(c['blue'])
ax2.set_ylabel('Demand σ', color=c['blue'], fontsize=9)
self._style_axis(ax_rev, "Revenue & Demand Dispersion", "Step", "Revenue ($)")
ax_rev.legend(loc='upper left', fontsize=7, frameon=False, labelcolor=c['text'])
ax2.legend(loc='upper right', fontsize=7, frameon=False, labelcolor=c['text'])
self._fig.canvas.draw_idle()
self._fig.canvas.flush_events() self._fig.canvas.flush_events()
plt.pause(0.01) plt.pause(0.05)
elif self.render_mode == "ansi": elif self.render_mode == "ansi":
return f"step={self._step_count}, prices={self._prices}, demand={self._demand}" return f"step={self._step_count}, prices={self._prices}, demand={self._demand}"
@@ -128,15 +231,15 @@ class PHANTOM(gym.Env):
def close(self): def close(self):
if self._fig: plt.close(self._fig) if self._fig: plt.close(self._fig)
self._fig, self._axes = None, None self._fig = None
if __name__ == "__main__": if __name__ == "__main__":
env = PHANTOM(n_products=100, render_mode="human") env = PHANTOM(n_products=15, alpha=0.3, N=100, render_mode="human")
obs, _ = env.reset() obs, _ = env.reset()
for _ in range(100): for step in range(100):
action = env.action_space.sample() action = env.action_space.sample()
obs, reward, term, trunc, info = env.step(action) obs, reward, term, trunc, info = env.step(action)
env.render() env.render()
print(f"Reward: {reward:.2f}")
if term: break if term: break
env.close()