chore: clean up code; add UX volatility penalty (eta_ux) and softmax temperature parameter

This commit is contained in:
2026-02-28 23:30:16 +01:00
parent 233ce3be34
commit 803e3a2972
6 changed files with 81 additions and 30 deletions

View File

@@ -32,6 +32,7 @@ class EnvParams(NamedTuple):
price_high: float price_high: float
lambda_coi: float lambda_coi: float
info_value: float info_value: float
eta_ux: float
robust_radius: float robust_radius: float
margin_floor: float margin_floor: float
margin_floor_patience: int margin_floor_patience: int
@@ -63,6 +64,7 @@ class CandidateEval(NamedTuple):
agent_prob: jax.Array agent_prob: jax.Array
leakage: jax.Array leakage: jax.Array
discount: jax.Array discount: jax.Array
ux_penalty: jax.Array
n_purchases: jax.Array n_purchases: jax.Array
n_agents: jax.Array n_agents: jax.Array
@@ -76,6 +78,7 @@ def make_env_params(
robust_radius: float, robust_radius: float,
robust_points: int, robust_points: int,
info_value: float, info_value: float,
eta_ux: float = 0.5,
action_levels: int, action_levels: int,
action_scale_low: float, action_scale_low: float,
action_scale_high: float, action_scale_high: float,
@@ -110,6 +113,7 @@ def make_env_params(
price_high=float(price_high), price_high=float(price_high),
lambda_coi=float(lambda_coi), lambda_coi=float(lambda_coi),
info_value=float(info_value), info_value=float(info_value),
eta_ux=float(eta_ux),
robust_radius=float(robust_radius), robust_radius=float(robust_radius),
margin_floor=float(margin_floor), margin_floor=float(margin_floor),
margin_floor_patience=int(margin_floor_patience), margin_floor_patience=int(margin_floor_patience),
@@ -143,6 +147,7 @@ def _evaluate_candidate(
key: jax.Array, key: jax.Array,
alpha_candidate: jax.Array, alpha_candidate: jax.Array,
prices: jax.Array, prices: jax.Array,
ux_volatility: jax.Array,
params: EnvParams, params: EnvParams,
) -> CandidateEval: ) -> CandidateEval:
states, products, actors, lengths = _sample_sessions_jax( states, products, actors, lengths = _sample_sessions_jax(
@@ -167,11 +172,13 @@ def _evaluate_candidate(
demand = weighted_demand(states, products, params.n_products, params.event_weights) demand = weighted_demand(states, products, params.n_products, params.event_weights)
revenue = revenue_from_demand(prices, demand) revenue = revenue_from_demand(prices, demand)
reward, leakage, discount = reward_with_coi_penalty( reward, leakage, discount, ux_penalty = reward_with_coi_penalty(
revenue, revenue,
agent_prob, agent_prob,
params.lambda_coi, params.lambda_coi,
params.info_value, params.info_value,
params.eta_ux,
ux_volatility,
) )
purchases = purchase_flags(states, params.purchase_mask) purchases = purchase_flags(states, params.purchase_mask)
return CandidateEval( return CandidateEval(
@@ -181,6 +188,7 @@ def _evaluate_candidate(
agent_prob=agent_prob, agent_prob=agent_prob,
leakage=leakage, leakage=leakage,
discount=discount, discount=discount,
ux_penalty=ux_penalty,
n_purchases=jnp.sum(purchases.astype(jnp.float32)), n_purchases=jnp.sum(purchases.astype(jnp.float32)),
n_agents=jnp.sum(actors.astype(jnp.float32)), n_agents=jnp.sum(actors.astype(jnp.float32)),
) )
@@ -212,10 +220,16 @@ def step_env(
params: EnvParams, params: EnvParams,
) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]: ) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]:
prices = _decode_action(state.prices, action, params) prices = _decode_action(state.prices, action, params)
baseline = jnp.maximum(state.prices, 1.0)
ux_volatility = jnp.where(
state.step_count > 0, jnp.mean(jnp.abs(prices - state.prices) / baseline), 0.0
)
n_candidates = params.alpha_candidates.shape[0] n_candidates = params.alpha_candidates.shape[0]
cand_keys = jax.random.split(key, n_candidates) cand_keys = jax.random.split(key, n_candidates)
evals = jax.vmap( evals = jax.vmap(
lambda k, a: _evaluate_candidate(k, a, prices, params), lambda k, a: _evaluate_candidate(k, a, prices, ux_volatility, params),
in_axes=(0, 0), in_axes=(0, 0),
)(cand_keys, params.alpha_candidates) )(cand_keys, params.alpha_candidates)
idx = jnp.argmin(evals.reward) idx = jnp.argmin(evals.reward)
@@ -226,6 +240,7 @@ def step_env(
agent_prob = evals.agent_prob[idx] agent_prob = evals.agent_prob[idx]
leakage = evals.leakage[idx] leakage = evals.leakage[idx]
discount = evals.discount[idx] discount = evals.discount[idx]
ux_penalty = evals.ux_penalty[idx]
n_purchases = evals.n_purchases[idx] n_purchases = evals.n_purchases[idx]
n_agents = evals.n_agents[idx] n_agents = evals.n_agents[idx]
alpha_adv = params.alpha_candidates[idx] alpha_adv = params.alpha_candidates[idx]
@@ -255,6 +270,8 @@ def step_env(
"alpha_adv": alpha_adv, "alpha_adv": alpha_adv,
"coi_leakage": leakage, "coi_leakage": leakage,
"coi_discount": discount, "coi_discount": discount,
"ux_penalty": ux_penalty,
"volatility": ux_volatility,
"n_purchases": n_purchases, "n_purchases": n_purchases,
"n_agents": n_agents, "n_agents": n_agents,
"avg_margin": avg_margin, "avg_margin": avg_margin,

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import Mapping, Sequence from typing import Mapping
import numpy as np import numpy as np
@@ -484,11 +484,17 @@ if JAX_AVAILABLE:
def reward_with_coi_penalty( def reward_with_coi_penalty(
revenue, agent_prob: float, lambda_coi: float, info_value: float revenue,
agent_prob: float,
lambda_coi: float,
info_value: float,
eta_ux: float = 0.0,
ux_volatility: float = 0.0,
): ):
leakage = agent_prob * info_value leakage = agent_prob * info_value
discount = jnp.clip(1.0 - lambda_coi * leakage, 0.0, 1.0) discount = jnp.clip(1.0 - lambda_coi * leakage, 0.0, 1.0)
return revenue * discount, leakage, discount ux_penalty = eta_ux * revenue * ux_volatility
return revenue * discount - ux_penalty, leakage, discount, ux_penalty
if JAX_AVAILABLE: if JAX_AVAILABLE:

View File

@@ -3,7 +3,10 @@ from typing import Dict
def compute_agent_probability( def compute_agent_probability(
trajectory: list, human_transitions: Dict, agent_transitions: Dict trajectory: list,
human_transitions: Dict,
agent_transitions: Dict,
temperature: float = 1.0,
) -> float: ) -> float:
"""estimate agent probability via KL divergence between trajectory transitions and reference models """estimate agent probability via KL divergence between trajectory transitions and reference models
@@ -52,9 +55,9 @@ def compute_agent_probability(
kl_agent = kl_div(empirical, agent_transitions) kl_agent = kl_div(empirical, agent_transitions)
# convert to probability via softmax (lower KL = higher prob) # convert to probability via softmax (lower KL = higher prob)
# agent_prob = exp(-kl_agent) / (exp(-kl_human) + exp(-kl_agent)) t = float(max(temperature, 1e-6))
exp_h = np.exp(-kl_human) exp_h = np.exp(-kl_human / t)
exp_a = np.exp(-kl_agent) exp_a = np.exp(-kl_agent / t)
return float(exp_a / (exp_h + exp_a + 1e-10)) return float(exp_a / (exp_h + exp_a + 1e-10))

View File

@@ -1,7 +1,6 @@
"""shared factor definitions for experimental designs""" """shared factor definitions for experimental designs"""
import numpy as np import numpy as np
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import Callable, Any
@dataclass @dataclass
class Factor: class Factor:

View File

@@ -287,7 +287,7 @@ def _sb3_model_cls(algo: str):
raise ValueError(f"unsupported algo '{algo}'") raise ValueError(f"unsupported algo '{algo}'")
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]: def train_qtable(cfg: dict) -> tuple["EventQTable", dict]:
from .lib.discrete import EventQTable from .lib.discrete import EventQTable
np.random.seed(int(cfg["seed"])) np.random.seed(int(cfg["seed"]))

View File

@@ -48,6 +48,7 @@ class PHANTOM(gym.Env):
robust_radius: float = 0.0, robust_radius: float = 0.0,
robust_points: int = 5, robust_points: int = 5,
info_value: float = 1.0, info_value: float = 1.0,
eta_ux: float = 0.5,
action_levels: int = 9, action_levels: int = 9,
action_scale_low: float = 0.9, action_scale_low: float = 0.9,
action_scale_high: float = 1.1, action_scale_high: float = 1.1,
@@ -75,6 +76,7 @@ class PHANTOM(gym.Env):
self.robust_radius = max(0.0, float(robust_radius)) self.robust_radius = max(0.0, float(robust_radius))
self.robust_points = max(1, int(robust_points)) self.robust_points = max(1, int(robust_points))
self.info_value = float(info_value) self.info_value = float(info_value)
self.eta_ux = float(eta_ux)
self.action_levels = max(2, int(action_levels)) self.action_levels = max(2, int(action_levels))
self._action_scales = np.linspace( self._action_scales = np.linspace(
float(action_scale_low), float(action_scale_high), self.action_levels float(action_scale_low), float(action_scale_high), self.action_levels
@@ -179,11 +181,26 @@ class PHANTOM(gym.Env):
revenue = float(np.dot(prices, demand_arr)) revenue = float(np.dot(prices, demand_arr))
purchases = extract_purchases(trajectories) purchases = extract_purchases(trajectories)
coi_mix = compute_uplift_coi(prices, purchases, self.baseline_prices) coi_mix = compute_uplift_coi(prices, purchases, self.baseline_prices)
# multiplicative penalty so COI term scales with revenue magnitude # multiplicative penalty so COI term scales with revenue magnitude
coi_leakage = float(agent_prob * self.info_value) coi_leakage = float(agent_prob * self.info_value)
discount = float(np.clip(1.0 - self.lambda_coi * coi_leakage, 0.0, 1.0)) discount = float(np.clip(1.0 - self.lambda_coi * coi_leakage, 0.0, 1.0))
coi_penalty = revenue * (1.0 - discount) # absolute penalty in revenue units coi_penalty = revenue * (1.0 - discount) # absolute penalty in revenue units
reward = revenue * discount
# calculate UX penalty based on price volatility
if len(self._price_history) > 0:
volatility = float(
np.mean(
np.abs(prices - self._price_history[-1])
/ np.maximum(self.baseline_prices, 1.0)
)
)
else:
volatility = 0.0
ux_penalty = self.eta_ux * revenue * volatility
reward = revenue * discount - ux_penalty
return reward, { return reward, {
"revenue": revenue, "revenue": revenue,
"coi_mix": float(coi_mix), "coi_mix": float(coi_mix),
@@ -191,6 +208,8 @@ class PHANTOM(gym.Env):
"coi_leakage": coi_leakage, "coi_leakage": coi_leakage,
"coi_penalty": coi_penalty, "coi_penalty": coi_penalty,
"coi_discount": discount, "coi_discount": discount,
"ux_penalty": ux_penalty,
"volatility": volatility,
} }
def _alpha_candidates(self) -> np.ndarray: def _alpha_candidates(self) -> np.ndarray:
@@ -200,27 +219,34 @@ class PHANTOM(gym.Env):
hi = min(1.0, self.nominal_alpha + self.robust_radius) hi = min(1.0, self.nominal_alpha + self.robust_radius)
return np.linspace(lo, hi, self.robust_points) return np.linspace(lo, hi, self.robust_points)
def _evaluate_candidate(
self, alpha: float, prices: np.ndarray
) -> tuple[float, dict, list, float]:
self._set_market_mix(alpha)
demand = self.market.act(prices)
trajectories = list(self.market.last_trajectories)
agent_prob = self._compute_agent_prob(trajectories)
reward, _ = self._compute_reward(prices, demand, agent_prob, trajectories)
return reward, demand, trajectories, agent_prob
def _select_adversarial_alpha( def _select_adversarial_alpha(
self, prices: np.ndarray self, prices: np.ndarray
) -> tuple[float, dict, list, float]: ) -> tuple[float, dict, list, float]:
"""inner robust step: pick worst-case alpha and return its outcome directly to avoid double-sampling""" """inner robust step: evaluate candidates and pick worst-case alpha"""
candidates = self._alpha_candidates() candidates = self._alpha_candidates()
best_alpha, worst_reward = float(candidates[0]), np.inf evaluations = [
best_demand, best_trajectories, best_agent_prob = None, [], 0.0 (alpha, *self._evaluate_candidate(float(alpha), prices))
for alpha in candidates: for alpha in candidates
self._set_market_mix(float(alpha)) ]
demand = self.market.act(prices)
trajectories = list(self.market.last_trajectories) # min over alpha in Wasserstein interval
agent_prob = self._compute_agent_prob(trajectories) best_eval = min(evaluations, key=lambda x: x[1]) # index 1 is reward
reward, _ = self._compute_reward(prices, demand, agent_prob, trajectories)
if reward < worst_reward: best_alpha = best_eval[0]
worst_reward = reward best_demand = best_eval[2]
best_alpha, best_demand, best_trajectories, best_agent_prob = ( best_trajectories = best_eval[3]
float(alpha), best_agent_prob = best_eval[4]
demand,
trajectories,
agent_prob,
)
return best_alpha, best_demand, best_trajectories, best_agent_prob return best_alpha, best_demand, best_trajectories, best_agent_prob
def _record_history(self): def _record_history(self):