From 8aa4db1c9e5572aa1eeaffee498d85f8fee49326 Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Wed, 18 Mar 2026 11:39:51 +0100 Subject: [PATCH] chor> competitive wrapping --- engine/engine.py | 20 +++++++++++++++++++- engine/wrapper.py | 42 +++++++++++++++++++++++++++--------------- requirements.txt | 1 + 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/engine/engine.py b/engine/engine.py index d548177..0e6f143 100644 --- a/engine/engine.py +++ b/engine/engine.py @@ -60,7 +60,25 @@ class MarketEngine: ] # store trajectories for agent probability calculation self.last_trajectories = human_t + agent_t - return estimate_demand(self.last_trajectories, self.action_weights) + + demand_proxy = estimate_demand( + self.last_trajectories, + self.action_weights, + normalize=True, + per_session=False, + ) + raw_mix = ((1.0 - float(self.alpha)) * demand_h) + ( + float(self.alpha) * demand_a + ) + total_raw_demand = float(np.sum(raw_mix)) + if not demand_proxy: + return {i: float(raw_mix[i]) for i in range(len(prices))} + if total_raw_demand <= 0.0: + return {i: 0.0 for i in range(len(prices))} + return { + i: total_raw_demand * float(demand_proxy.get(i, 0.0)) / 100.0 + for i in range(len(prices)) + } def measure(self): pass diff --git a/engine/wrapper.py b/engine/wrapper.py index 1748617..0ff75d1 100644 --- a/engine/wrapper.py +++ b/engine/wrapper.py @@ -130,6 +130,13 @@ class PHANTOM(gym.Env): self._initial_episode_prices = None self._trajectories = [] # session trajectories for agent prob calculation self.baseline_prices = np.full(self.n_products, self.price_bounds[0]) + self.anchor_prices = np.full( + self.n_products, + float(np.clip(float(self.human_params[0]), *self.price_bounds)), + ) + self.competitive_cap = float( + min(self.price_bounds[1], float(np.mean(self.anchor_prices)) * 1.15) + ) self._low_margin_streak = 0 # consecutive steps below margin_floor self._last_agent_prob = float(self.alpha) self._last_alpha_adv = float(self.alpha) @@ -169,19 +176,28 @@ class PHANTOM(gym.Env): self.market.Nhumans = self.N - n_agents def _decode_action(self, action) -> np.ndarray: - base = ( - self._prices - if self._prices is not None - else np.full(self.n_products, self.price_bounds[0], dtype=float) - ) + prev = self._prices + base = self.anchor_prices + + def _blend(target: np.ndarray) -> np.ndarray: + if prev is None: + lower = float(self.price_bounds[0]) + return np.clip(target, lower, self.competitive_cap) + blended = 0.75 * np.asarray(prev, dtype=float) + 0.25 * target + lower = float(self.price_bounds[0]) + return np.clip(blended, lower, self.competitive_cap) + if np.isscalar(action): idx = int(np.clip(int(action), 0, self.action_levels - 1)) - return np.clip(base * self._action_scales[idx], *self.price_bounds) + target = base * self._action_scales[idx] + return _blend(target) a = np.asarray(action) if a.size == 1: idx = int(np.clip(int(a.reshape(-1)[0]), 0, self.action_levels - 1)) - return np.clip(base * self._action_scales[idx], *self.price_bounds) - return np.clip(a.astype(float), *self.price_bounds) + target = base * self._action_scales[idx] + return _blend(target) + lower = float(self.price_bounds[0]) + return np.clip(a.astype(float), lower, self.competitive_cap) def _compute_agent_prob(self, trajectories=None) -> float: trajectories = ( @@ -225,14 +241,10 @@ class PHANTOM(gym.Env): upward_volatility = 0.0 ux_penalty = self.eta_ux * info_budget * (volatility + 0.5 * upward_volatility) - competitive_anchor = float( - np.clip(float(self.human_params[0]) * 1.2, *self.price_bounds) - ) + competitive_anchor = float(np.mean(self.anchor_prices)) price_ratio = prices / max(competitive_anchor, 1.0) - supra_excess = np.clip(price_ratio - 1.0, 0.0, None) - supra_penalty = ( - 0.5 * self.eta_ux * info_budget * float(np.mean(np.square(supra_excess))) - ) + supra_excess = np.clip(price_ratio - 1.15, 0.0, None) + supra_penalty = 4.0 * info_budget * float(np.mean(np.square(supra_excess))) supra_share = float(np.mean(supra_excess > 0.0)) reward_revenue = self.reward_profit_weight * profit diff --git a/requirements.txt b/requirements.txt index c1a8686..71af617 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ pandas jupyter ipykernel matplotlib +tikzplotlib graphviz browser-use pytest