From 803e3a29727c25a4f5551bf44dc8442902abddba Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Sat, 28 Feb 2026 23:30:16 +0100 Subject: [PATCH] feat: add eta_ux volatility penalty and softmax temperature --- engine/jax/env.py | 21 +++++++++++-- engine/jax/primitives.py | 12 ++++++-- engine/lib/coi.py | 11 ++++--- engine/studies/factors.py | 3 +- engine/train.py | 2 +- engine/wrapper.py | 62 +++++++++++++++++++++++++++------------ 6 files changed, 81 insertions(+), 30 deletions(-) diff --git a/engine/jax/env.py b/engine/jax/env.py index 06542b1..8ecafd1 100644 --- a/engine/jax/env.py +++ b/engine/jax/env.py @@ -32,6 +32,7 @@ class EnvParams(NamedTuple): price_high: float lambda_coi: float info_value: float + eta_ux: float robust_radius: float margin_floor: float margin_floor_patience: int @@ -63,6 +64,7 @@ class CandidateEval(NamedTuple): agent_prob: jax.Array leakage: jax.Array discount: jax.Array + ux_penalty: jax.Array n_purchases: jax.Array n_agents: jax.Array @@ -76,6 +78,7 @@ def make_env_params( robust_radius: float, robust_points: int, info_value: float, + eta_ux: float = 0.5, action_levels: int, action_scale_low: float, action_scale_high: float, @@ -110,6 +113,7 @@ def make_env_params( price_high=float(price_high), lambda_coi=float(lambda_coi), info_value=float(info_value), + eta_ux=float(eta_ux), robust_radius=float(robust_radius), margin_floor=float(margin_floor), margin_floor_patience=int(margin_floor_patience), @@ -143,6 +147,7 @@ def _evaluate_candidate( key: jax.Array, alpha_candidate: jax.Array, prices: jax.Array, + ux_volatility: jax.Array, params: EnvParams, ) -> CandidateEval: states, products, actors, lengths = _sample_sessions_jax( @@ -167,11 +172,13 @@ def _evaluate_candidate( demand = weighted_demand(states, products, params.n_products, params.event_weights) revenue = revenue_from_demand(prices, demand) - reward, leakage, discount = reward_with_coi_penalty( + reward, leakage, discount, ux_penalty = reward_with_coi_penalty( revenue, agent_prob, params.lambda_coi, 
params.info_value, + params.eta_ux, + ux_volatility, ) purchases = purchase_flags(states, params.purchase_mask) return CandidateEval( @@ -181,6 +188,7 @@ def _evaluate_candidate( agent_prob=agent_prob, leakage=leakage, discount=discount, + ux_penalty=ux_penalty, n_purchases=jnp.sum(purchases.astype(jnp.float32)), n_agents=jnp.sum(actors.astype(jnp.float32)), ) @@ -212,10 +220,16 @@ def step_env( params: EnvParams, ) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]: prices = _decode_action(state.prices, action, params) + + baseline = jnp.maximum(state.prices, 1.0) + ux_volatility = jnp.where( + state.step_count > 0, jnp.mean(jnp.abs(prices - state.prices) / baseline), 0.0 + ) + n_candidates = params.alpha_candidates.shape[0] cand_keys = jax.random.split(key, n_candidates) evals = jax.vmap( - lambda k, a: _evaluate_candidate(k, a, prices, params), + lambda k, a: _evaluate_candidate(k, a, prices, ux_volatility, params), in_axes=(0, 0), )(cand_keys, params.alpha_candidates) idx = jnp.argmin(evals.reward) @@ -226,6 +240,7 @@ def step_env( agent_prob = evals.agent_prob[idx] leakage = evals.leakage[idx] discount = evals.discount[idx] + ux_penalty = evals.ux_penalty[idx] n_purchases = evals.n_purchases[idx] n_agents = evals.n_agents[idx] alpha_adv = params.alpha_candidates[idx] @@ -255,6 +270,8 @@ def step_env( "alpha_adv": alpha_adv, "coi_leakage": leakage, "coi_discount": discount, + "ux_penalty": ux_penalty, + "volatility": ux_volatility, "n_purchases": n_purchases, "n_agents": n_agents, "avg_margin": avg_margin, diff --git a/engine/jax/primitives.py b/engine/jax/primitives.py index 37bf326..e638b32 100644 --- a/engine/jax/primitives.py +++ b/engine/jax/primitives.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass from functools import partial -from typing import Mapping, Sequence +from typing import Mapping import numpy as np @@ -484,11 +484,17 @@ if JAX_AVAILABLE: def reward_with_coi_penalty( - revenue, 
agent_prob: float, lambda_coi: float, info_value: float + revenue, + agent_prob: float, + lambda_coi: float, + info_value: float, + eta_ux: float = 0.0, + ux_volatility: float = 0.0, ): leakage = agent_prob * info_value discount = jnp.clip(1.0 - lambda_coi * leakage, 0.0, 1.0) - return revenue * discount, leakage, discount + ux_penalty = eta_ux * revenue * ux_volatility + return revenue * discount - ux_penalty, leakage, discount, ux_penalty if JAX_AVAILABLE: diff --git a/engine/lib/coi.py b/engine/lib/coi.py index 33267b5..ed18672 100644 --- a/engine/lib/coi.py +++ b/engine/lib/coi.py @@ -3,7 +3,10 @@ from typing import Dict def compute_agent_probability( - trajectory: list, human_transitions: Dict, agent_transitions: Dict + trajectory: list, + human_transitions: Dict, + agent_transitions: Dict, + temperature: float = 1.0, ) -> float: """estimate agent probability via KL divergence between trajectory transitions and reference models @@ -52,9 +55,9 @@ def compute_agent_probability( kl_agent = kl_div(empirical, agent_transitions) # convert to probability via softmax (lower KL = higher prob) - # agent_prob = exp(-kl_agent) / (exp(-kl_human) + exp(-kl_agent)) - exp_h = np.exp(-kl_human) - exp_a = np.exp(-kl_agent) + t = float(max(temperature, 1e-6)) + exp_h = np.exp(-kl_human / t) + exp_a = np.exp(-kl_agent / t) return float(exp_a / (exp_h + exp_a + 1e-10)) diff --git a/engine/studies/factors.py b/engine/studies/factors.py index 1fbfbe1..c9e4cec 100644 --- a/engine/studies/factors.py +++ b/engine/studies/factors.py @@ -1,7 +1,6 @@ """shared factor definitions for experimental designs""" import numpy as np -from dataclasses import dataclass, field -from typing import Callable, Any +from dataclasses import dataclass @dataclass class Factor: diff --git a/engine/train.py b/engine/train.py index a77ca94..063f4ae 100644 --- a/engine/train.py +++ b/engine/train.py @@ -287,7 +287,7 @@ def _sb3_model_cls(algo: str): raise ValueError(f"unsupported algo '{algo}'") -def 
train_qtable(cfg: dict) -> tuple[EventQTable, dict]: +def train_qtable(cfg: dict) -> tuple["EventQTable", dict]: from .lib.discrete import EventQTable np.random.seed(int(cfg["seed"])) diff --git a/engine/wrapper.py b/engine/wrapper.py index 3e37d9a..751c104 100644 --- a/engine/wrapper.py +++ b/engine/wrapper.py @@ -48,6 +48,7 @@ class PHANTOM(gym.Env): robust_radius: float = 0.0, robust_points: int = 5, info_value: float = 1.0, + eta_ux: float = 0.5, action_levels: int = 9, action_scale_low: float = 0.9, action_scale_high: float = 1.1, @@ -75,6 +76,7 @@ class PHANTOM(gym.Env): self.robust_radius = max(0.0, float(robust_radius)) self.robust_points = max(1, int(robust_points)) self.info_value = float(info_value) + self.eta_ux = float(eta_ux) self.action_levels = max(2, int(action_levels)) self._action_scales = np.linspace( float(action_scale_low), float(action_scale_high), self.action_levels @@ -179,11 +181,26 @@ class PHANTOM(gym.Env): revenue = float(np.dot(prices, demand_arr)) purchases = extract_purchases(trajectories) coi_mix = compute_uplift_coi(prices, purchases, self.baseline_prices) + # multiplicative penalty so COI term scales with revenue magnitude coi_leakage = float(agent_prob * self.info_value) discount = float(np.clip(1.0 - self.lambda_coi * coi_leakage, 0.0, 1.0)) coi_penalty = revenue * (1.0 - discount) # absolute penalty in revenue units - reward = revenue * discount + + # calculate UX penalty based on price volatility + if len(self._price_history) > 0: + volatility = float( + np.mean( + np.abs(prices - self._price_history[-1]) + / np.maximum(self.baseline_prices, 1.0) + ) + ) + else: + volatility = 0.0 + ux_penalty = self.eta_ux * revenue * volatility + + reward = revenue * discount - ux_penalty + return reward, { "revenue": revenue, "coi_mix": float(coi_mix), @@ -191,6 +208,8 @@ class PHANTOM(gym.Env): "coi_leakage": coi_leakage, "coi_penalty": coi_penalty, "coi_discount": discount, + "ux_penalty": ux_penalty, + "volatility": volatility, } def 
_alpha_candidates(self) -> np.ndarray:
@@ -200,27 +219,34 @@ class PHANTOM(gym.Env):
         hi = min(1.0, self.nominal_alpha + self.robust_radius)
         return np.linspace(lo, hi, self.robust_points)
 
+    def _evaluate_candidate(
+        self, alpha: float, prices: np.ndarray
+    ) -> tuple[float, dict, list, float]:
+        """set the market mix to `alpha` and run the market once at `prices`
+
+        returns (reward, demand, trajectories, agent_prob)"""
+        self._set_market_mix(alpha)
+        demand = self.market.act(prices)
+        trajectories = list(self.market.last_trajectories)
+        agent_prob = self._compute_agent_prob(trajectories)
+        reward, _ = self._compute_reward(prices, demand, agent_prob, trajectories)
+        return reward, demand, trajectories, agent_prob
+
     def _select_adversarial_alpha(
         self, prices: np.ndarray
     ) -> tuple[float, dict, list, float]:
-        """inner robust step: pick worst-case alpha and return its outcome directly to avoid double-sampling"""
+        """inner robust step: evaluate candidates and pick worst-case alpha"""
         candidates = self._alpha_candidates()
-        best_alpha, worst_reward = float(candidates[0]), np.inf
-        best_demand, best_trajectories, best_agent_prob = None, [], 0.0
-        for alpha in candidates:
-            self._set_market_mix(float(alpha))
-            demand = self.market.act(prices)
-            trajectories = list(self.market.last_trajectories)
-            agent_prob = self._compute_agent_prob(trajectories)
-            reward, _ = self._compute_reward(prices, demand, agent_prob, trajectories)
-            if reward < worst_reward:
-                worst_reward = reward
-                best_alpha, best_demand, best_trajectories, best_agent_prob = (
-                    float(alpha),
-                    demand,
-                    trajectories,
-                    agent_prob,
-                )
+        # float() keeps the returned alpha a plain python float, as the old loop did
+        evaluations = [
+            (alpha, *self._evaluate_candidate(alpha, prices))
+            for alpha in map(float, candidates)
+        ]
+
+        # min over alpha in Wasserstein interval; on ties min() keeps the first
+        worst = min(evaluations, key=lambda item: item[1])  # item[1] is reward
+        best_alpha, _, best_demand, best_trajectories, best_agent_prob = worst
+
         return best_alpha, best_demand, best_trajectories, best_agent_prob
 
     def _record_history(self):